From 9e52588e0ac8d885b38c78908c10bf6309e409ad Mon Sep 17 00:00:00 2001 From: vodofrede Date: Sat, 2 Mar 2024 13:56:24 +0100 Subject: [PATCH] fixed slow response reading by using a fixed buffer --- src/lib.rs | 68 +++++++++++++++++++++++++++++++++--------------------- 1 file changed, 42 insertions(+), 26 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 5138a85..df87d6c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,16 +1,13 @@ -#![warn(clippy::all, clippy::pedantic)] +#![warn(clippy::all, clippy::pedantic, missing_docs)] #![deny(unsafe_code)] #![doc = include_str!("../README.md")] -#[cfg(test)] -mod tests; - use once_cell::sync::Lazy; use regex::Regex; use std::{ collections::HashMap, fmt, - io::{BufRead, BufReader, Error as IoError, Write}, + io::{self, Read, Write}, iter, net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream, UdpSocket}, }; @@ -93,6 +90,13 @@ impl<'a> Request<'a> { } } + /// Set the HTTP method of the request. + pub fn method(self, method: Method) -> Self { + let mut request = self; + request.method = method; + request + } + /// Set the URL of the request. /// /// # Examples @@ -181,10 +185,9 @@ impl<'a> Request<'a> { /// let response = request.send().expect("request failed"); /// assert_eq!(response.status, 200); /// ``` - pub fn send(&self) -> Result { + pub fn send(&self) -> Result { // format the message let message = format!("{self}"); - dbg!(&message); // create the stream let host = resolve(host(self.url).unwrap())?; @@ -195,26 +198,25 @@ impl<'a> Request<'a> { stream.write_all(message.as_bytes())?; // receive the response - let lines = BufReader::new(stream) - .lines() - .map_while(Result::ok) - .collect::>(); - let received = lines.join("\n"); + // todo: allow larger responses by resizing response buffer + let mut buf = vec![0u8; 4096]; + let n = stream.read(&mut buf)?; + buf.resize(n, 0); + let received = String::from_utf8(buf).unwrap(); // check for redirects - if lines[0].contains("301") && self.redirects > 0 { - assert!(self.redirects > 0, "maximum redirect limit reached"); // todo: error for maximum redirect limit reached - - // find new location - let location = lines - .iter() + let status: u16 = received[9..12].parse().unwrap(); + if (300..400).contains(&status) { + // todo: error for maximum redirect limit reached + assert!(self.redirects > 0, "maximum redirect limit reached"); + let location = received + .lines() .find_map(|l| l.strip_prefix("Location: ")) .unwrap(); // todo: error for missing location in redirect - return self - .clone() - .redirects(self.redirects - 1) - .url(location) - .send(); + let request = self.clone().redirects(self.redirects - 1).url(location); + return (status == 303) + .then(|| request.send()) + .unwrap_or_else(|| request.method(Method::GET).send()); } // process response @@ -241,6 +243,7 @@ impl<'a> fmt::Display for Request<'a> { /// HTTP methods. #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +#[allow(missing_docs)] pub enum Method { GET, HEAD, @@ -256,10 +259,19 @@ pub enum Method { /// An HTTP response. #[derive(Debug, Clone)] pub struct Response { + /// HTTP version. + /// + /// Should be one of HTTP/1.0, HTTP/1.1, HTTP/2.0, or HTTP/3.0. pub version: String, + /// Status code. + /// + /// 100-199: info, 200-299: success, 300-399: redir, 400-499: client error, 500-599: server error. pub status: u16, + /// Message associated to the status code. pub reason: String, + /// Map of headers. pub headers: HashMap, + /// Message body. pub body: Option, } impl Response { @@ -325,7 +337,7 @@ fn path(url: &str) -> Option<&str> { } /// Resolve DNS request using system nameservers. -fn resolve(query: &str) -> Result { +fn resolve(query: &str) -> Result { // todo: local overrides if query.starts_with("localhost") { return Ok(IpAddr::V4(Ipv4Addr::LOCALHOST)); @@ -359,8 +371,9 @@ fn resolve(query: &str) -> Result { socket.send_to(&message, &DNS_SERVERS[..]).unwrap(); // read dns response - let mut buf = vec![0; 1024]; - let (n, _addr) = socket.recv_from(&mut buf)?; + let mut buf = vec![0u8; 256]; + socket.peek_from(&mut buf)?; + let n = socket.recv(&mut buf)?; buf.resize(n, 0); // parse out the address @@ -392,3 +405,6 @@ static DNS_SERVERS: Lazy> = Lazy::new(|| { vec![SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53)] } }); + +#[cfg(test)] +mod tests;