added redirect handling, fixed some dns resolution issues
This commit is contained in:
parent
5dffb202a5
commit
8b997df4c2
1 changed files with 90 additions and 42 deletions
158
src/lib.rs
158
src/lib.rs
|
@ -12,7 +12,7 @@ use std::{
|
|||
fmt,
|
||||
io::{BufRead, BufReader, Error as IoError, Write},
|
||||
iter,
|
||||
net::{IpAddr, Ipv4Addr, TcpStream, ToSocketAddrs, UdpSocket},
|
||||
net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream, UdpSocket},
|
||||
};
|
||||
|
||||
/// An HTTP request.
|
||||
|
@ -26,7 +26,7 @@ use std::{
|
|||
/// // ... start a local server on port 8000 ...
|
||||
/// let request = Request::get("localhost:8000");
|
||||
/// let response = request.send().unwrap();
|
||||
/// assert_eq!(response.status, 200)
|
||||
/// assert_eq!(response.status, 200);
|
||||
/// ```
|
||||
///
|
||||
/// Adding headers:
|
||||
|
@ -67,6 +67,8 @@ pub struct Request<'a> {
|
|||
headers: HashMap<&'a str, &'a str>,
|
||||
/// Request body.
|
||||
body: &'a str,
|
||||
/// How many redirects are followed before an error is emitted.
|
||||
redirects: usize,
|
||||
}
|
||||
|
||||
impl<'a> Request<'a> {
|
||||
|
@ -87,9 +89,25 @@ impl<'a> Request<'a> {
|
|||
method,
|
||||
headers: HashMap::new(),
|
||||
body: "",
|
||||
redirects: 4,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the URL of the request.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```rust
|
||||
/// # use request::*;
|
||||
/// let request = Request::get("http://example.org/a").url("http://example.org/b");
|
||||
/// assert_eq!(format!("{request}"), "GET /b HTTP/1.1\r\nHost: example.org\r\n\r\n");
|
||||
/// ```
|
||||
pub fn url(self, url: &'a str) -> Self {
|
||||
let mut request = self;
|
||||
request.url = url;
|
||||
request
|
||||
}
|
||||
|
||||
/// Add a body to the request.
|
||||
///
|
||||
/// # Examples
|
||||
|
@ -119,6 +137,13 @@ impl<'a> Request<'a> {
|
|||
request
|
||||
}
|
||||
|
||||
/// Set the maximum allowed redirects.
|
||||
pub fn redirects(self, max: usize) -> Self {
|
||||
let mut request = self;
|
||||
request.redirects = max;
|
||||
request
|
||||
}
|
||||
|
||||
/// Construct a new GET request.
|
||||
///
|
||||
/// # Examples
|
||||
|
@ -163,7 +188,8 @@ impl<'a> Request<'a> {
|
|||
|
||||
// create the stream
|
||||
let host = resolve(host(self.url).unwrap())?;
|
||||
let mut stream = TcpStream::connect((host, 80))?;
|
||||
let port = port(self.url).unwrap_or("80").parse::<u16>().unwrap();
|
||||
let mut stream = TcpStream::connect((host, port))?;
|
||||
|
||||
// send the message
|
||||
stream.write_all(message.as_bytes())?;
|
||||
|
@ -175,6 +201,22 @@ impl<'a> Request<'a> {
|
|||
.collect::<Vec<_>>();
|
||||
let received = lines.join("\n");
|
||||
|
||||
// check for redirects
|
||||
if lines[0].contains("301") && self.redirects > 0 {
|
||||
assert!(self.redirects > 0, "maximum redirect limit reached"); // todo: error for maximum redirect limit reached
|
||||
|
||||
// find new location
|
||||
let location = lines
|
||||
.iter()
|
||||
.find_map(|l| l.strip_prefix("Location: "))
|
||||
.unwrap(); // todo: error for missing location in redirect
|
||||
return self
|
||||
.clone()
|
||||
.redirects(self.redirects - 1)
|
||||
.url(location)
|
||||
.send();
|
||||
}
|
||||
|
||||
// process response
|
||||
let response = Response::parse(&received).unwrap();
|
||||
|
||||
|
@ -183,13 +225,10 @@ impl<'a> Request<'a> {
|
|||
}
|
||||
impl<'a> fmt::Display for Request<'a> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let (method, path, host, body) = (
|
||||
self.method,
|
||||
path(self.url).ok_or(fmt::Error)?,
|
||||
host(self.url).ok_or(fmt::Error)?,
|
||||
self.body,
|
||||
);
|
||||
|
||||
let method = self.method;
|
||||
let path = path(self.url).ok_or(fmt::Error)?;
|
||||
let host = host(self.url).ok_or(fmt::Error)?;
|
||||
let body = self.body;
|
||||
let headers = iter::once(format!("Host: {host}"))
|
||||
.chain(self.headers.iter().map(|(k, v)| format!("{k}: {v}")))
|
||||
.collect::<Vec<_>>()
|
||||
|
@ -248,11 +287,6 @@ impl Response {
|
|||
.map(|(a, b)| (a.to_string(), b.to_string()))
|
||||
.collect::<HashMap<String, String>>();
|
||||
|
||||
// check if redirect
|
||||
if status == 301 {
|
||||
todo!()
|
||||
}
|
||||
|
||||
// parse body
|
||||
let body = parts.name("body").map(|m| m.as_str().to_string());
|
||||
|
||||
|
@ -270,7 +304,7 @@ impl Response {
|
|||
}
|
||||
|
||||
static URI_REGEX: Lazy<Regex> = Lazy::new(|| {
|
||||
Regex::new("(?:(?P<scheme>https?)://)?(?P<host>[0-9a-zA-Z:\\.\\-]+)(?P<path>/(?:.)*)?").unwrap()
|
||||
Regex::new("(?:(?P<scheme>https?)://)?(?P<host>[0-9a-zA-Z\\.\\-]+)(?:\\:(?P<port>\\d+))?(?P<path>/(?:.)*)?").unwrap()
|
||||
});
|
||||
#[allow(dead_code)]
|
||||
fn scheme(url: &str) -> Option<&str> {
|
||||
|
@ -279,6 +313,9 @@ fn scheme(url: &str) -> Option<&str> {
|
|||
fn host(url: &str) -> Option<&str> {
|
||||
URI_REGEX.captures(url)?.name("host").map(|m| m.as_str())
|
||||
}
|
||||
fn port(url: &str) -> Option<&str> {
|
||||
URI_REGEX.captures(url)?.name("port").map(|m| m.as_str())
|
||||
}
|
||||
fn path(url: &str) -> Option<&str> {
|
||||
URI_REGEX
|
||||
.captures(url)?
|
||||
|
@ -289,8 +326,51 @@ fn path(url: &str) -> Option<&str> {
|
|||
|
||||
/// Resolve DNS request using system nameservers.
|
||||
fn resolve(query: &str) -> Result<IpAddr, IoError> {
|
||||
// todo: local overrides
|
||||
if query.starts_with("localhost") {
|
||||
return Ok(IpAddr::V4(Ipv4Addr::LOCALHOST));
|
||||
}
|
||||
|
||||
// todo: dns caching
|
||||
// create dns query header: [id, flags, questions, answers, authority, additional]
|
||||
let header: [u16; 6] = [0xabcd, 0x0100, 0x0001, 0x0000, 0x0000, 0x0000].map(|b: u16| b.to_be());
|
||||
let question: [u16; 2] = [0x0001, 0x0001].map(|b: u16| b.to_be()); // [qtype, qclass] = [A, IN(ternet)]
|
||||
|
||||
// convert query to standard dns name notation (max 63 characters for each label)
|
||||
let ascii = query.chars().filter(char::is_ascii).collect::<String>();
|
||||
let name = ascii
|
||||
.split('.')
|
||||
.flat_map(|l| {
|
||||
iter::once(u8::try_from(l.len()).unwrap_or(63).min(63)).chain(l.bytes().take(63))
|
||||
})
|
||||
.chain(iter::once(0))
|
||||
.collect::<Vec<u8>>();
|
||||
|
||||
// construct the message
|
||||
let mut message = bytemuck::cast::<[u16; 6], [u8; 12]>(header).to_vec();
|
||||
message.extend(&name[..]);
|
||||
message.extend(bytemuck::cast_slice(&question));
|
||||
|
||||
// create the socket
|
||||
let socket = UdpSocket::bind("0.0.0.0:0")?;
|
||||
socket.connect(&DNS_SERVERS[..])?;
|
||||
|
||||
// write dns lookup message
|
||||
socket.send_to(&message, &DNS_SERVERS[..]).unwrap();
|
||||
|
||||
// read dns response
|
||||
let mut buf = vec![0; 1024];
|
||||
let (n, _addr) = socket.recv_from(&mut buf)?;
|
||||
buf.resize(n, 0);
|
||||
|
||||
// parse out the address
|
||||
let ip = &buf.get(message.len()..).unwrap()[12..];
|
||||
let address = IpAddr::V4(Ipv4Addr::new(ip[0], ip[1], ip[2], ip[3]));
|
||||
|
||||
Ok(address)
|
||||
}
|
||||
static DNS_SERVERS: Lazy<Vec<SocketAddr>> = Lazy::new(|| {
|
||||
// find name servers (platform-dependent)
|
||||
let servers = {
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::fs;
|
||||
|
@ -304,43 +384,11 @@ fn resolve(query: &str) -> Result<IpAddr, IoError> {
|
|||
}
|
||||
#[cfg(windows)]
|
||||
{
|
||||
("8.8.8.8", 53).to_socket_addrs()?.collect::<Vec<_>>()
|
||||
// todo: get windows name servers
|
||||
vec![SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53)]
|
||||
}
|
||||
};
|
||||
|
||||
// request dns resolution from nameservers
|
||||
let header: [u16; 6] = [0xabcd, 0x0100, 0x0001, 0x0000, 0x0000, 0x0000].map(|b: u16| b.to_be());
|
||||
let question: [u16; 2] = [0x0001, 0x0001].map(|b: u16| b.to_be());
|
||||
|
||||
// convert query to standard dns name notation
|
||||
let ascii = query.chars().filter(char::is_ascii).collect::<String>();
|
||||
let name = ascii
|
||||
.split('.')
|
||||
.flat_map(|l| iter::once(u8::try_from(l.len()).unwrap_or(63)).chain(l.bytes().take(63)))
|
||||
.chain(iter::once(0))
|
||||
.collect::<Vec<u8>>();
|
||||
|
||||
// construct the message
|
||||
let mut message = bytemuck::cast::<[u16; 6], [u8; 12]>(header).to_vec();
|
||||
message.extend(&name[..]);
|
||||
message.extend(bytemuck::cast_slice(&question));
|
||||
|
||||
// create the socket
|
||||
let socket = UdpSocket::bind("0.0.0.0:0")?;
|
||||
socket.connect(&servers[..])?;
|
||||
|
||||
// write dns lookup message
|
||||
socket.send_to(&message, &servers[..]).unwrap();
|
||||
|
||||
// read dns response
|
||||
let mut buf = vec![0; 1024];
|
||||
let (n, _addr) = socket.recv_from(&mut buf)?;
|
||||
buf.resize(n, 0);
|
||||
|
||||
// parse out the address
|
||||
let answers = &buf[message.len()..];
|
||||
let ip = &answers[12..];
|
||||
let address = IpAddr::V4(Ipv4Addr::new(ip[0], ip[1], ip[2], ip[3]));
|
||||
|
||||
Ok(address)
|
||||
#[cfg(not(any(unix, windows)))]
|
||||
{
|
||||
vec![SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53)]
|
||||
}
|
||||
});
|
||||
|
|
Loading…
Reference in a new issue