added redirect handling, fixed some dns resolution issues

This commit is contained in:
Frederik Palmø 2024-03-01 13:24:01 +01:00
parent 5dffb202a5
commit 8b997df4c2

View file

@ -12,7 +12,7 @@ use std::{
fmt,
io::{BufRead, BufReader, Error as IoError, Write},
iter,
net::{IpAddr, Ipv4Addr, TcpStream, ToSocketAddrs, UdpSocket},
net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream, UdpSocket},
};
/// An HTTP request.
@ -26,7 +26,7 @@ use std::{
/// // ... start a local server on port 8000 ...
/// let request = Request::get("localhost:8000");
/// let response = request.send().unwrap();
/// assert_eq!(response.status, 200)
/// assert_eq!(response.status, 200);
/// ```
///
/// Adding headers:
@ -67,6 +67,8 @@ pub struct Request<'a> {
headers: HashMap<&'a str, &'a str>,
/// Request body.
body: &'a str,
/// How many redirects are followed before an error is emitted.
redirects: usize,
}
impl<'a> Request<'a> {
@ -87,9 +89,25 @@ impl<'a> Request<'a> {
method,
headers: HashMap::new(),
body: "",
redirects: 4,
}
}
/// Set the URL of the request.
///
/// # Examples
///
/// ```rust
/// # use request::*;
/// let request = Request::get("http://example.org/a").url("http://example.org/b");
/// assert_eq!(format!("{request}"), "GET /b HTTP/1.1\r\nHost: example.org\r\n\r\n");
/// ```
pub fn url(self, url: &'a str) -> Self {
let mut request = self;
request.url = url;
request
}
/// Add a body to the request.
///
/// # Examples
@ -119,6 +137,13 @@ impl<'a> Request<'a> {
request
}
/// Set the maximum allowed redirects.
pub fn redirects(self, max: usize) -> Self {
let mut request = self;
request.redirects = max;
request
}
/// Construct a new GET request.
///
/// # Examples
@ -163,7 +188,8 @@ impl<'a> Request<'a> {
// create the stream
let host = resolve(host(self.url).unwrap())?;
let mut stream = TcpStream::connect((host, 80))?;
let port = port(self.url).unwrap_or("80").parse::<u16>().unwrap();
let mut stream = TcpStream::connect((host, port))?;
// send the message
stream.write_all(message.as_bytes())?;
@ -175,6 +201,22 @@ impl<'a> Request<'a> {
.collect::<Vec<_>>();
let received = lines.join("\n");
// check for redirects
if lines[0].contains("301") && self.redirects > 0 {
assert!(self.redirects > 0, "maximum redirect limit reached"); // todo: error for maximum redirect limit reached
// find new location
let location = lines
.iter()
.find_map(|l| l.strip_prefix("Location: "))
.unwrap(); // todo: error for missing location in redirect
return self
.clone()
.redirects(self.redirects - 1)
.url(location)
.send();
}
// process response
let response = Response::parse(&received).unwrap();
@ -183,13 +225,10 @@ impl<'a> Request<'a> {
}
impl<'a> fmt::Display for Request<'a> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let (method, path, host, body) = (
self.method,
path(self.url).ok_or(fmt::Error)?,
host(self.url).ok_or(fmt::Error)?,
self.body,
);
let method = self.method;
let path = path(self.url).ok_or(fmt::Error)?;
let host = host(self.url).ok_or(fmt::Error)?;
let body = self.body;
let headers = iter::once(format!("Host: {host}"))
.chain(self.headers.iter().map(|(k, v)| format!("{k}: {v}")))
.collect::<Vec<_>>()
@ -248,11 +287,6 @@ impl Response {
.map(|(a, b)| (a.to_string(), b.to_string()))
.collect::<HashMap<String, String>>();
// check if redirect
if status == 301 {
todo!()
}
// parse body
let body = parts.name("body").map(|m| m.as_str().to_string());
@ -270,7 +304,7 @@ impl Response {
}
static URI_REGEX: Lazy<Regex> = Lazy::new(|| {
Regex::new("(?:(?P<scheme>https?)://)?(?P<host>[0-9a-zA-Z:\\.\\-]+)(?P<path>/(?:.)*)?").unwrap()
Regex::new("(?:(?P<scheme>https?)://)?(?P<host>[0-9a-zA-Z\\.\\-]+)(?:\\:(?P<port>\\d+))?(?P<path>/(?:.)*)?").unwrap()
});
#[allow(dead_code)]
fn scheme(url: &str) -> Option<&str> {
@ -279,6 +313,9 @@ fn scheme(url: &str) -> Option<&str> {
fn host(url: &str) -> Option<&str> {
URI_REGEX.captures(url)?.name("host").map(|m| m.as_str())
}
fn port(url: &str) -> Option<&str> {
URI_REGEX.captures(url)?.name("port").map(|m| m.as_str())
}
fn path(url: &str) -> Option<&str> {
URI_REGEX
.captures(url)?
@ -289,34 +326,23 @@ fn path(url: &str) -> Option<&str> {
/// Resolve DNS request using system nameservers.
fn resolve(query: &str) -> Result<IpAddr, IoError> {
// find name servers (platform-dependent)
let servers = {
#[cfg(unix)]
{
use std::fs;
let resolv = fs::read_to_string("/etc/resolv.conf")?;
let servers = resolv
.lines()
.filter_map(|l| l.split_once("nameserver ").map(|(_, s)| s.to_string()))
.flat_map(|ns| ns.to_socket_addrs().into_iter().flatten())
.collect::<Vec<_>>();
servers
}
#[cfg(windows)]
{
("8.8.8.8", 53).to_socket_addrs()?.collect::<Vec<_>>()
}
};
// todo: local overrides
if query.starts_with("localhost") {
return Ok(IpAddr::V4(Ipv4Addr::LOCALHOST));
}
// request dns resolution from nameservers
// todo: dns caching
// create dns query header: [id, flags, questions, answers, authority, additional]
let header: [u16; 6] = [0xabcd, 0x0100, 0x0001, 0x0000, 0x0000, 0x0000].map(|b: u16| b.to_be());
let question: [u16; 2] = [0x0001, 0x0001].map(|b: u16| b.to_be());
let question: [u16; 2] = [0x0001, 0x0001].map(|b: u16| b.to_be()); // [qtype, qclass] = [A, IN(ternet)]
// convert query to standard dns name notation
// convert query to standard dns name notation (max 63 characters for each label)
let ascii = query.chars().filter(char::is_ascii).collect::<String>();
let name = ascii
.split('.')
.flat_map(|l| iter::once(u8::try_from(l.len()).unwrap_or(63)).chain(l.bytes().take(63)))
.flat_map(|l| {
iter::once(u8::try_from(l.len()).unwrap_or(63).min(63)).chain(l.bytes().take(63))
})
.chain(iter::once(0))
.collect::<Vec<u8>>();
@ -327,10 +353,10 @@ fn resolve(query: &str) -> Result<IpAddr, IoError> {
// create the socket
let socket = UdpSocket::bind("0.0.0.0:0")?;
socket.connect(&servers[..])?;
socket.connect(&DNS_SERVERS[..])?;
// write dns lookup message
socket.send_to(&message, &servers[..]).unwrap();
socket.send_to(&message, &DNS_SERVERS[..]).unwrap();
// read dns response
let mut buf = vec![0; 1024];
@ -338,9 +364,31 @@ fn resolve(query: &str) -> Result<IpAddr, IoError> {
buf.resize(n, 0);
// parse out the address
let answers = &buf[message.len()..];
let ip = &answers[12..];
let ip = &buf.get(message.len()..).unwrap()[12..];
let address = IpAddr::V4(Ipv4Addr::new(ip[0], ip[1], ip[2], ip[3]));
Ok(address)
}
static DNS_SERVERS: Lazy<Vec<SocketAddr>> = Lazy::new(|| {
// find name servers (platform-dependent)
#[cfg(unix)]
{
use std::fs;
let resolv = fs::read_to_string("/etc/resolv.conf")?;
let servers = resolv
.lines()
.filter_map(|l| l.split_once("nameserver ").map(|(_, s)| s.to_string()))
.flat_map(|ns| ns.to_socket_addrs().into_iter().flatten())
.collect::<Vec<_>>();
servers
}
#[cfg(windows)]
{
// todo: get windows name servers
vec![SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53)]
}
#[cfg(not(any(unix, windows)))]
{
vec![SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53)]
}
});