fixed slow response reading by using a fixed buffer

This commit is contained in:
Frederik Palmø 2024-03-02 13:56:24 +01:00
parent 8b997df4c2
commit 9e52588e0a

View file

@ -1,16 +1,13 @@
#![warn(clippy::all, clippy::pedantic)] #![warn(clippy::all, clippy::pedantic, missing_docs)]
#![deny(unsafe_code)] #![deny(unsafe_code)]
#![doc = include_str!("../README.md")] #![doc = include_str!("../README.md")]
#[cfg(test)]
mod tests;
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use regex::Regex; use regex::Regex;
use std::{ use std::{
collections::HashMap, collections::HashMap,
fmt, fmt,
io::{BufRead, BufReader, Error as IoError, Write}, io::{self, Read, Write},
iter, iter,
net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream, UdpSocket}, net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream, UdpSocket},
}; };
@ -93,6 +90,13 @@ impl<'a> Request<'a> {
} }
} }
/// Set the HTTP method of the request.
pub fn method(self, method: Method) -> Self {
let mut request = self;
request.method = method;
request
}
/// Set the URL of the request. /// Set the URL of the request.
/// ///
/// # Examples /// # Examples
@ -181,10 +185,9 @@ impl<'a> Request<'a> {
/// let response = request.send().expect("request failed"); /// let response = request.send().expect("request failed");
/// assert_eq!(response.status, 200); /// assert_eq!(response.status, 200);
/// ``` /// ```
pub fn send(&self) -> Result<Response, IoError> { pub fn send(&self) -> Result<Response, io::Error> {
// format the message // format the message
let message = format!("{self}"); let message = format!("{self}");
dbg!(&message);
// create the stream // create the stream
let host = resolve(host(self.url).unwrap())?; let host = resolve(host(self.url).unwrap())?;
@ -195,26 +198,25 @@ impl<'a> Request<'a> {
stream.write_all(message.as_bytes())?; stream.write_all(message.as_bytes())?;
// receive the response // receive the response
let lines = BufReader::new(stream) // todo: allow larger responses by resizing response buffer
.lines() let mut buf = vec![0u8; 4096];
.map_while(Result::ok) let n = stream.read(&mut buf)?;
.collect::<Vec<_>>(); buf.resize(n, 0);
let received = lines.join("\n"); let received = String::from_utf8(buf).unwrap();
// check for redirects // check for redirects
if lines[0].contains("301") && self.redirects > 0 { let status: u16 = received[9..12].parse().unwrap();
assert!(self.redirects > 0, "maximum redirect limit reached"); // todo: error for maximum redirect limit reached if (300..400).contains(&status) {
// todo: error for maximum redirect limit reached
// find new location assert!(self.redirects > 0, "maximum redirect limit reached");
let location = lines let location = received
.iter() .lines()
.find_map(|l| l.strip_prefix("Location: ")) .find_map(|l| l.strip_prefix("Location: "))
.unwrap(); // todo: error for missing location in redirect .unwrap(); // todo: error for missing location in redirect
return self let request = self.clone().redirects(self.redirects - 1).url(location);
.clone() return (status == 303)
.redirects(self.redirects - 1) .then(|| request.send())
.url(location) .unwrap_or_else(|| request.method(Method::GET).send());
.send();
} }
// process response // process response
@ -241,6 +243,7 @@ impl<'a> fmt::Display for Request<'a> {
/// HTTP methods. /// HTTP methods.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[allow(missing_docs)]
pub enum Method { pub enum Method {
GET, GET,
HEAD, HEAD,
@ -256,10 +259,19 @@ pub enum Method {
/// An HTTP response. /// An HTTP response.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Response { pub struct Response {
/// HTTP version.
///
/// Should be one of HTTP/1.0, HTTP/1.1, HTTP/2.0, or HTTP/3.0.
pub version: String, pub version: String,
/// Status code.
///
/// 100-199: info, 200-299: success, 300-399: redir, 400-499: client error, 500-599: server error.
pub status: u16, pub status: u16,
/// Message associated to the status code.
pub reason: String, pub reason: String,
/// Map of headers.
pub headers: HashMap<String, String>, pub headers: HashMap<String, String>,
/// Message body.
pub body: Option<String>, pub body: Option<String>,
} }
impl Response { impl Response {
@ -325,7 +337,7 @@ fn path(url: &str) -> Option<&str> {
} }
/// Resolve DNS request using system nameservers. /// Resolve DNS request using system nameservers.
fn resolve(query: &str) -> Result<IpAddr, IoError> { fn resolve(query: &str) -> Result<IpAddr, io::Error> {
// todo: local overrides // todo: local overrides
if query.starts_with("localhost") { if query.starts_with("localhost") {
return Ok(IpAddr::V4(Ipv4Addr::LOCALHOST)); return Ok(IpAddr::V4(Ipv4Addr::LOCALHOST));
@ -359,8 +371,9 @@ fn resolve(query: &str) -> Result<IpAddr, IoError> {
socket.send_to(&message, &DNS_SERVERS[..]).unwrap(); socket.send_to(&message, &DNS_SERVERS[..]).unwrap();
// read dns response // read dns response
let mut buf = vec![0; 1024]; let mut buf = vec![0u8; 256];
let (n, _addr) = socket.recv_from(&mut buf)?; socket.peek_from(&mut buf)?;
let n = socket.recv(&mut buf)?;
buf.resize(n, 0); buf.resize(n, 0);
// parse out the address // parse out the address
@ -392,3 +405,6 @@ static DNS_SERVERS: Lazy<Vec<SocketAddr>> = Lazy::new(|| {
vec![SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53)] vec![SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53)]
} }
}); });
#[cfg(test)]
mod tests;