added redirect handling, fixed some dns resolution issues

This commit is contained in:
Frederik Palmø 2024-03-01 13:24:01 +01:00
parent 5dffb202a5
commit 8b997df4c2

View file

@ -12,7 +12,7 @@ use std::{
fmt, fmt,
io::{BufRead, BufReader, Error as IoError, Write}, io::{BufRead, BufReader, Error as IoError, Write},
iter, iter,
net::{IpAddr, Ipv4Addr, TcpStream, ToSocketAddrs, UdpSocket}, net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream, UdpSocket},
}; };
/// An HTTP request. /// An HTTP request.
@ -26,7 +26,7 @@ use std::{
/// // ... start a local server on port 8000 ... /// // ... start a local server on port 8000 ...
/// let request = Request::get("localhost:8000"); /// let request = Request::get("localhost:8000");
/// let response = request.send().unwrap(); /// let response = request.send().unwrap();
/// assert_eq!(response.status, 200) /// assert_eq!(response.status, 200);
/// ``` /// ```
/// ///
/// Adding headers: /// Adding headers:
@ -67,6 +67,8 @@ pub struct Request<'a> {
headers: HashMap<&'a str, &'a str>, headers: HashMap<&'a str, &'a str>,
/// Request body. /// Request body.
body: &'a str, body: &'a str,
/// How many redirects are followed before an error is emitted.
redirects: usize,
} }
impl<'a> Request<'a> { impl<'a> Request<'a> {
@ -87,9 +89,25 @@ impl<'a> Request<'a> {
method, method,
headers: HashMap::new(), headers: HashMap::new(),
body: "", body: "",
redirects: 4,
} }
} }
/// Set the URL of the request.
///
/// # Examples
///
/// ```rust
/// # use request::*;
/// let request = Request::get("http://example.org/a").url("http://example.org/b");
/// assert_eq!(format!("{request}"), "GET /b HTTP/1.1\r\nHost: example.org\r\n\r\n");
/// ```
pub fn url(self, url: &'a str) -> Self {
let mut request = self;
request.url = url;
request
}
/// Add a body to the request. /// Add a body to the request.
/// ///
/// # Examples /// # Examples
@ -119,6 +137,13 @@ impl<'a> Request<'a> {
request request
} }
/// Set the maximum allowed redirects.
pub fn redirects(self, max: usize) -> Self {
let mut request = self;
request.redirects = max;
request
}
/// Construct a new GET request. /// Construct a new GET request.
/// ///
/// # Examples /// # Examples
@ -163,7 +188,8 @@ impl<'a> Request<'a> {
// create the stream // create the stream
let host = resolve(host(self.url).unwrap())?; let host = resolve(host(self.url).unwrap())?;
let mut stream = TcpStream::connect((host, 80))?; let port = port(self.url).unwrap_or("80").parse::<u16>().unwrap();
let mut stream = TcpStream::connect((host, port))?;
// send the message // send the message
stream.write_all(message.as_bytes())?; stream.write_all(message.as_bytes())?;
@ -175,6 +201,22 @@ impl<'a> Request<'a> {
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let received = lines.join("\n"); let received = lines.join("\n");
// check for redirects
if lines[0].contains("301") && self.redirects > 0 {
assert!(self.redirects > 0, "maximum redirect limit reached"); // todo: error for maximum redirect limit reached
// find new location
let location = lines
.iter()
.find_map(|l| l.strip_prefix("Location: "))
.unwrap(); // todo: error for missing location in redirect
return self
.clone()
.redirects(self.redirects - 1)
.url(location)
.send();
}
// process response // process response
let response = Response::parse(&received).unwrap(); let response = Response::parse(&received).unwrap();
@ -183,13 +225,10 @@ impl<'a> Request<'a> {
} }
impl<'a> fmt::Display for Request<'a> { impl<'a> fmt::Display for Request<'a> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let (method, path, host, body) = ( let method = self.method;
self.method, let path = path(self.url).ok_or(fmt::Error)?;
path(self.url).ok_or(fmt::Error)?, let host = host(self.url).ok_or(fmt::Error)?;
host(self.url).ok_or(fmt::Error)?, let body = self.body;
self.body,
);
let headers = iter::once(format!("Host: {host}")) let headers = iter::once(format!("Host: {host}"))
.chain(self.headers.iter().map(|(k, v)| format!("{k}: {v}"))) .chain(self.headers.iter().map(|(k, v)| format!("{k}: {v}")))
.collect::<Vec<_>>() .collect::<Vec<_>>()
@ -248,11 +287,6 @@ impl Response {
.map(|(a, b)| (a.to_string(), b.to_string())) .map(|(a, b)| (a.to_string(), b.to_string()))
.collect::<HashMap<String, String>>(); .collect::<HashMap<String, String>>();
// check if redirect
if status == 301 {
todo!()
}
// parse body // parse body
let body = parts.name("body").map(|m| m.as_str().to_string()); let body = parts.name("body").map(|m| m.as_str().to_string());
@ -270,7 +304,7 @@ impl Response {
} }
static URI_REGEX: Lazy<Regex> = Lazy::new(|| { static URI_REGEX: Lazy<Regex> = Lazy::new(|| {
Regex::new("(?:(?P<scheme>https?)://)?(?P<host>[0-9a-zA-Z:\\.\\-]+)(?P<path>/(?:.)*)?").unwrap() Regex::new("(?:(?P<scheme>https?)://)?(?P<host>[0-9a-zA-Z\\.\\-]+)(?:\\:(?P<port>\\d+))?(?P<path>/(?:.)*)?").unwrap()
}); });
#[allow(dead_code)] #[allow(dead_code)]
fn scheme(url: &str) -> Option<&str> { fn scheme(url: &str) -> Option<&str> {
@ -279,6 +313,9 @@ fn scheme(url: &str) -> Option<&str> {
fn host(url: &str) -> Option<&str> { fn host(url: &str) -> Option<&str> {
URI_REGEX.captures(url)?.name("host").map(|m| m.as_str()) URI_REGEX.captures(url)?.name("host").map(|m| m.as_str())
} }
fn port(url: &str) -> Option<&str> {
URI_REGEX.captures(url)?.name("port").map(|m| m.as_str())
}
fn path(url: &str) -> Option<&str> { fn path(url: &str) -> Option<&str> {
URI_REGEX URI_REGEX
.captures(url)? .captures(url)?
@ -289,8 +326,51 @@ fn path(url: &str) -> Option<&str> {
/// Resolve DNS request using system nameservers. /// Resolve DNS request using system nameservers.
fn resolve(query: &str) -> Result<IpAddr, IoError> { fn resolve(query: &str) -> Result<IpAddr, IoError> {
// todo: local overrides
if query.starts_with("localhost") {
return Ok(IpAddr::V4(Ipv4Addr::LOCALHOST));
}
// todo: dns caching
// create dns query header: [id, flags, questions, answers, authority, additional]
let header: [u16; 6] = [0xabcd, 0x0100, 0x0001, 0x0000, 0x0000, 0x0000].map(|b: u16| b.to_be());
let question: [u16; 2] = [0x0001, 0x0001].map(|b: u16| b.to_be()); // [qtype, qclass] = [A, IN(ternet)]
// convert query to standard dns name notation (max 63 characters for each label)
let ascii = query.chars().filter(char::is_ascii).collect::<String>();
let name = ascii
.split('.')
.flat_map(|l| {
iter::once(u8::try_from(l.len()).unwrap_or(63).min(63)).chain(l.bytes().take(63))
})
.chain(iter::once(0))
.collect::<Vec<u8>>();
// construct the message
let mut message = bytemuck::cast::<[u16; 6], [u8; 12]>(header).to_vec();
message.extend(&name[..]);
message.extend(bytemuck::cast_slice(&question));
// create the socket
let socket = UdpSocket::bind("0.0.0.0:0")?;
socket.connect(&DNS_SERVERS[..])?;
// write dns lookup message
socket.send_to(&message, &DNS_SERVERS[..]).unwrap();
// read dns response
let mut buf = vec![0; 1024];
let (n, _addr) = socket.recv_from(&mut buf)?;
buf.resize(n, 0);
// parse out the address
let ip = &buf.get(message.len()..).unwrap()[12..];
let address = IpAddr::V4(Ipv4Addr::new(ip[0], ip[1], ip[2], ip[3]));
Ok(address)
}
static DNS_SERVERS: Lazy<Vec<SocketAddr>> = Lazy::new(|| {
// find name servers (platform-dependent) // find name servers (platform-dependent)
let servers = {
#[cfg(unix)] #[cfg(unix)]
{ {
use std::fs; use std::fs;
@ -304,43 +384,11 @@ fn resolve(query: &str) -> Result<IpAddr, IoError> {
} }
#[cfg(windows)] #[cfg(windows)]
{ {
("8.8.8.8", 53).to_socket_addrs()?.collect::<Vec<_>>() // todo: get windows name servers
vec![SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53)]
} }
}; #[cfg(not(any(unix, windows)))]
{
// request dns resolution from nameservers vec![SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53)]
let header: [u16; 6] = [0xabcd, 0x0100, 0x0001, 0x0000, 0x0000, 0x0000].map(|b: u16| b.to_be()); }
let question: [u16; 2] = [0x0001, 0x0001].map(|b: u16| b.to_be()); });
// convert query to standard dns name notation
let ascii = query.chars().filter(char::is_ascii).collect::<String>();
let name = ascii
.split('.')
.flat_map(|l| iter::once(u8::try_from(l.len()).unwrap_or(63)).chain(l.bytes().take(63)))
.chain(iter::once(0))
.collect::<Vec<u8>>();
// construct the message
let mut message = bytemuck::cast::<[u16; 6], [u8; 12]>(header).to_vec();
message.extend(&name[..]);
message.extend(bytemuck::cast_slice(&question));
// create the socket
let socket = UdpSocket::bind("0.0.0.0:0")?;
socket.connect(&servers[..])?;
// write dns lookup message
socket.send_to(&message, &servers[..]).unwrap();
// read dns response
let mut buf = vec![0; 1024];
let (n, _addr) = socket.recv_from(&mut buf)?;
buf.resize(n, 0);
// parse out the address
let answers = &buf[message.len()..];
let ip = &answers[12..];
let address = IpAddr::V4(Ipv4Addr::new(ip[0], ip[1], ip[2], ip[3]));
Ok(address)
}