From 8b997df4c22b2cf586439dcc53393462fe2be995 Mon Sep 17 00:00:00 2001 From: vodofrede Date: Fri, 1 Mar 2024 13:24:01 +0100 Subject: [PATCH] added redirect handling, fixed some dns resolution issues --- src/lib.rs | 132 ++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 90 insertions(+), 42 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index d0258f5..5138a85 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,7 +12,7 @@ use std::{ fmt, io::{BufRead, BufReader, Error as IoError, Write}, iter, - net::{IpAddr, Ipv4Addr, TcpStream, ToSocketAddrs, UdpSocket}, + net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream, UdpSocket}, }; /// An HTTP request. @@ -26,7 +26,7 @@ use std::{ /// // ... start a local server on port 8000 ... /// let request = Request::get("localhost:8000"); /// let response = request.send().unwrap(); -/// assert_eq!(response.status, 200) +/// assert_eq!(response.status, 200); /// ``` /// /// Adding headers: @@ -67,6 +67,8 @@ pub struct Request<'a> { headers: HashMap<&'a str, &'a str>, /// Request body. body: &'a str, + /// How many redirects are followed before an error is emitted. + redirects: usize, } impl<'a> Request<'a> { @@ -87,9 +89,25 @@ impl<'a> Request<'a> { method, headers: HashMap::new(), body: "", + redirects: 4, } } + /// Set the URL of the request. + /// + /// # Examples + /// + /// ```rust + /// # use request::*; + /// let request = Request::get("http://example.org/a").url("http://example.org/b"); + /// assert_eq!(format!("{request}"), "GET /b HTTP/1.1\r\nHost: example.org\r\n\r\n"); + /// ``` + pub fn url(self, url: &'a str) -> Self { + let mut request = self; + request.url = url; + request + } + /// Add a body to the request. /// /// # Examples @@ -119,6 +137,13 @@ impl<'a> Request<'a> { request } + /// Set the maximum allowed redirects. + pub fn redirects(self, max: usize) -> Self { + let mut request = self; + request.redirects = max; + request + } + /// Construct a new GET request. /// /// # Examples @@ -163,7 +188,8 @@ impl<'a> Request<'a> { // create the stream let host = resolve(host(self.url).unwrap())?; - let mut stream = TcpStream::connect((host, 80))?; + let port = port(self.url).unwrap_or("80").parse::().unwrap(); + let mut stream = TcpStream::connect((host, port))?; // send the message stream.write_all(message.as_bytes())?; @@ -175,6 +201,22 @@ impl<'a> Request<'a> { .collect::>(); let received = lines.join("\n"); + // check for redirects + if lines[0].contains("301") && self.redirects > 0 { + assert!(self.redirects > 0, "maximum redirect limit reached"); // todo: error for maximum redirect limit reached + + // find new location + let location = lines + .iter() + .find_map(|l| l.strip_prefix("Location: ")) + .unwrap(); // todo: error for missing location in redirect + return self + .clone() + .redirects(self.redirects - 1) + .url(location) + .send(); + } + // process response let response = Response::parse(&received).unwrap(); @@ -183,13 +225,10 @@ impl<'a> Request<'a> { } impl<'a> fmt::Display for Request<'a> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let (method, path, host, body) = ( - self.method, - path(self.url).ok_or(fmt::Error)?, - host(self.url).ok_or(fmt::Error)?, - self.body, - ); - + let method = self.method; + let path = path(self.url).ok_or(fmt::Error)?; + let host = host(self.url).ok_or(fmt::Error)?; + let body = self.body; let headers = iter::once(format!("Host: {host}")) .chain(self.headers.iter().map(|(k, v)| format!("{k}: {v}"))) .collect::>() @@ -248,11 +287,6 @@ impl Response { .map(|(a, b)| (a.to_string(), b.to_string())) .collect::>(); - // check if redirect - if status == 301 { - todo!() - } - // parse body let body = parts.name("body").map(|m| m.as_str().to_string()); @@ -270,7 +304,7 @@ impl Response { } static URI_REGEX: Lazy = Lazy::new(|| { - Regex::new("(?:(?Phttps?)://)?(?P[0-9a-zA-Z:\\.\\-]+)(?P/(?:.)*)?").unwrap() + Regex::new("(?:(?Phttps?)://)?(?P[0-9a-zA-Z\\.\\-]+)(?:\\:(?P\\d+))?(?P/(?:.)*)?").unwrap() }); #[allow(dead_code)] fn scheme(url: &str) -> Option<&str> { @@ -279,6 +313,9 @@ fn scheme(url: &str) -> Option<&str> { fn host(url: &str) -> Option<&str> { URI_REGEX.captures(url)?.name("host").map(|m| m.as_str()) } +fn port(url: &str) -> Option<&str> { + URI_REGEX.captures(url)?.name("port").map(|m| m.as_str()) +} fn path(url: &str) -> Option<&str> { URI_REGEX .captures(url)? @@ -289,34 +326,23 @@ fn path(url: &str) -> Option<&str> { /// Resolve DNS request using system nameservers. fn resolve(query: &str) -> Result { - // find name servers (platform-dependent) - let servers = { - #[cfg(unix)] - { - use std::fs; - let resolv = fs::read_to_string("/etc/resolv.conf")?; - let servers = resolv - .lines() - .filter_map(|l| l.split_once("nameserver ").map(|(_, s)| s.to_string())) - .flat_map(|ns| ns.to_socket_addrs().into_iter().flatten()) - .collect::>(); - servers - } - #[cfg(windows)] - { - ("8.8.8.8", 53).to_socket_addrs()?.collect::>() - } - }; + // todo: local overrides + if query.starts_with("localhost") { + return Ok(IpAddr::V4(Ipv4Addr::LOCALHOST)); + } - // request dns resolution from nameservers + // todo: dns caching + // create dns query header: [id, flags, questions, answers, authority, additional] let header: [u16; 6] = [0xabcd, 0x0100, 0x0001, 0x0000, 0x0000, 0x0000].map(|b: u16| b.to_be()); - let question: [u16; 2] = [0x0001, 0x0001].map(|b: u16| b.to_be()); + let question: [u16; 2] = [0x0001, 0x0001].map(|b: u16| b.to_be()); // [qtype, qclass] = [A, IN(ternet)] - // convert query to standard dns name notation + // convert query to standard dns name notation (max 63 characters for each label) let ascii = query.chars().filter(char::is_ascii).collect::(); let name = ascii .split('.') - .flat_map(|l| iter::once(u8::try_from(l.len()).unwrap_or(63)).chain(l.bytes().take(63))) + .flat_map(|l| { + iter::once(u8::try_from(l.len()).unwrap_or(63).min(63)).chain(l.bytes().take(63)) + }) .chain(iter::once(0)) .collect::>(); @@ -327,10 +353,10 @@ fn resolve(query: &str) -> Result { // create the socket let socket = UdpSocket::bind("0.0.0.0:0")?; - socket.connect(&servers[..])?; + socket.connect(&DNS_SERVERS[..])?; // write dns lookup message - socket.send_to(&message, &servers[..]).unwrap(); + socket.send_to(&message, &DNS_SERVERS[..]).unwrap(); // read dns response let mut buf = vec![0; 1024]; @@ -338,9 +364,31 @@ fn resolve(query: &str) -> Result { buf.resize(n, 0); // parse out the address - let answers = &buf[message.len()..]; - let ip = &answers[12..]; + let ip = &buf.get(message.len()..).unwrap()[12..]; let address = IpAddr::V4(Ipv4Addr::new(ip[0], ip[1], ip[2], ip[3])); Ok(address) } +static DNS_SERVERS: Lazy> = Lazy::new(|| { + // find name servers (platform-dependent) + #[cfg(unix)] + { + use std::fs; + let resolv = fs::read_to_string("/etc/resolv.conf")?; + let servers = resolv + .lines() + .filter_map(|l| l.split_once("nameserver ").map(|(_, s)| s.to_string())) + .flat_map(|ns| ns.to_socket_addrs().into_iter().flatten()) + .collect::>(); + servers + } + #[cfg(windows)] + { + // todo: get windows name servers + vec![SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53)] + } + #[cfg(not(any(unix, windows)))] + { + vec![SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53)] + } +});