added redirect handling, fixed some dns resolution issues
This commit is contained in:
parent
5dffb202a5
commit
8b997df4c2
1 changed files with 90 additions and 42 deletions
160
src/lib.rs
160
src/lib.rs
|
@ -12,7 +12,7 @@ use std::{
|
||||||
fmt,
|
fmt,
|
||||||
io::{BufRead, BufReader, Error as IoError, Write},
|
io::{BufRead, BufReader, Error as IoError, Write},
|
||||||
iter,
|
iter,
|
||||||
net::{IpAddr, Ipv4Addr, TcpStream, ToSocketAddrs, UdpSocket},
|
net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream, UdpSocket},
|
||||||
};
|
};
|
||||||
|
|
||||||
/// An HTTP request.
|
/// An HTTP request.
|
||||||
|
@ -26,7 +26,7 @@ use std::{
|
||||||
/// // ... start a local server on port 8000 ...
|
/// // ... start a local server on port 8000 ...
|
||||||
/// let request = Request::get("localhost:8000");
|
/// let request = Request::get("localhost:8000");
|
||||||
/// let response = request.send().unwrap();
|
/// let response = request.send().unwrap();
|
||||||
/// assert_eq!(response.status, 200)
|
/// assert_eq!(response.status, 200);
|
||||||
/// ```
|
/// ```
|
||||||
///
|
///
|
||||||
/// Adding headers:
|
/// Adding headers:
|
||||||
|
@ -67,6 +67,8 @@ pub struct Request<'a> {
|
||||||
headers: HashMap<&'a str, &'a str>,
|
headers: HashMap<&'a str, &'a str>,
|
||||||
/// Request body.
|
/// Request body.
|
||||||
body: &'a str,
|
body: &'a str,
|
||||||
|
/// How many redirects are followed before an error is emitted.
|
||||||
|
redirects: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> Request<'a> {
|
impl<'a> Request<'a> {
|
||||||
|
@ -87,9 +89,25 @@ impl<'a> Request<'a> {
|
||||||
method,
|
method,
|
||||||
headers: HashMap::new(),
|
headers: HashMap::new(),
|
||||||
body: "",
|
body: "",
|
||||||
|
redirects: 4,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Set the URL of the request.
|
||||||
|
///
|
||||||
|
/// # Examples
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// # use request::*;
|
||||||
|
/// let request = Request::get("http://example.org/a").url("http://example.org/b");
|
||||||
|
/// assert_eq!(format!("{request}"), "GET /b HTTP/1.1\r\nHost: example.org\r\n\r\n");
|
||||||
|
/// ```
|
||||||
|
pub fn url(self, url: &'a str) -> Self {
|
||||||
|
let mut request = self;
|
||||||
|
request.url = url;
|
||||||
|
request
|
||||||
|
}
|
||||||
|
|
||||||
/// Add a body to the request.
|
/// Add a body to the request.
|
||||||
///
|
///
|
||||||
/// # Examples
|
/// # Examples
|
||||||
|
@ -119,6 +137,13 @@ impl<'a> Request<'a> {
|
||||||
request
|
request
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Set the maximum allowed redirects.
|
||||||
|
pub fn redirects(self, max: usize) -> Self {
|
||||||
|
let mut request = self;
|
||||||
|
request.redirects = max;
|
||||||
|
request
|
||||||
|
}
|
||||||
|
|
||||||
/// Construct a new GET request.
|
/// Construct a new GET request.
|
||||||
///
|
///
|
||||||
/// # Examples
|
/// # Examples
|
||||||
|
@ -163,7 +188,8 @@ impl<'a> Request<'a> {
|
||||||
|
|
||||||
// create the stream
|
// create the stream
|
||||||
let host = resolve(host(self.url).unwrap())?;
|
let host = resolve(host(self.url).unwrap())?;
|
||||||
let mut stream = TcpStream::connect((host, 80))?;
|
let port = port(self.url).unwrap_or("80").parse::<u16>().unwrap();
|
||||||
|
let mut stream = TcpStream::connect((host, port))?;
|
||||||
|
|
||||||
// send the message
|
// send the message
|
||||||
stream.write_all(message.as_bytes())?;
|
stream.write_all(message.as_bytes())?;
|
||||||
|
@ -175,6 +201,22 @@ impl<'a> Request<'a> {
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
let received = lines.join("\n");
|
let received = lines.join("\n");
|
||||||
|
|
||||||
|
// check for redirects
|
||||||
|
if lines[0].contains("301") && self.redirects > 0 {
|
||||||
|
assert!(self.redirects > 0, "maximum redirect limit reached"); // todo: error for maximum redirect limit reached
|
||||||
|
|
||||||
|
// find new location
|
||||||
|
let location = lines
|
||||||
|
.iter()
|
||||||
|
.find_map(|l| l.strip_prefix("Location: "))
|
||||||
|
.unwrap(); // todo: error for missing location in redirect
|
||||||
|
return self
|
||||||
|
.clone()
|
||||||
|
.redirects(self.redirects - 1)
|
||||||
|
.url(location)
|
||||||
|
.send();
|
||||||
|
}
|
||||||
|
|
||||||
// process response
|
// process response
|
||||||
let response = Response::parse(&received).unwrap();
|
let response = Response::parse(&received).unwrap();
|
||||||
|
|
||||||
|
@ -183,13 +225,10 @@ impl<'a> Request<'a> {
|
||||||
}
|
}
|
||||||
impl<'a> fmt::Display for Request<'a> {
|
impl<'a> fmt::Display for Request<'a> {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
let (method, path, host, body) = (
|
let method = self.method;
|
||||||
self.method,
|
let path = path(self.url).ok_or(fmt::Error)?;
|
||||||
path(self.url).ok_or(fmt::Error)?,
|
let host = host(self.url).ok_or(fmt::Error)?;
|
||||||
host(self.url).ok_or(fmt::Error)?,
|
let body = self.body;
|
||||||
self.body,
|
|
||||||
);
|
|
||||||
|
|
||||||
let headers = iter::once(format!("Host: {host}"))
|
let headers = iter::once(format!("Host: {host}"))
|
||||||
.chain(self.headers.iter().map(|(k, v)| format!("{k}: {v}")))
|
.chain(self.headers.iter().map(|(k, v)| format!("{k}: {v}")))
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
|
@ -248,11 +287,6 @@ impl Response {
|
||||||
.map(|(a, b)| (a.to_string(), b.to_string()))
|
.map(|(a, b)| (a.to_string(), b.to_string()))
|
||||||
.collect::<HashMap<String, String>>();
|
.collect::<HashMap<String, String>>();
|
||||||
|
|
||||||
// check if redirect
|
|
||||||
if status == 301 {
|
|
||||||
todo!()
|
|
||||||
}
|
|
||||||
|
|
||||||
// parse body
|
// parse body
|
||||||
let body = parts.name("body").map(|m| m.as_str().to_string());
|
let body = parts.name("body").map(|m| m.as_str().to_string());
|
||||||
|
|
||||||
|
@ -270,7 +304,7 @@ impl Response {
|
||||||
}
|
}
|
||||||
|
|
||||||
static URI_REGEX: Lazy<Regex> = Lazy::new(|| {
|
static URI_REGEX: Lazy<Regex> = Lazy::new(|| {
|
||||||
Regex::new("(?:(?P<scheme>https?)://)?(?P<host>[0-9a-zA-Z:\\.\\-]+)(?P<path>/(?:.)*)?").unwrap()
|
Regex::new("(?:(?P<scheme>https?)://)?(?P<host>[0-9a-zA-Z\\.\\-]+)(?:\\:(?P<port>\\d+))?(?P<path>/(?:.)*)?").unwrap()
|
||||||
});
|
});
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
fn scheme(url: &str) -> Option<&str> {
|
fn scheme(url: &str) -> Option<&str> {
|
||||||
|
@ -279,6 +313,9 @@ fn scheme(url: &str) -> Option<&str> {
|
||||||
fn host(url: &str) -> Option<&str> {
|
fn host(url: &str) -> Option<&str> {
|
||||||
URI_REGEX.captures(url)?.name("host").map(|m| m.as_str())
|
URI_REGEX.captures(url)?.name("host").map(|m| m.as_str())
|
||||||
}
|
}
|
||||||
|
fn port(url: &str) -> Option<&str> {
|
||||||
|
URI_REGEX.captures(url)?.name("port").map(|m| m.as_str())
|
||||||
|
}
|
||||||
fn path(url: &str) -> Option<&str> {
|
fn path(url: &str) -> Option<&str> {
|
||||||
URI_REGEX
|
URI_REGEX
|
||||||
.captures(url)?
|
.captures(url)?
|
||||||
|
@ -289,8 +326,51 @@ fn path(url: &str) -> Option<&str> {
|
||||||
|
|
||||||
/// Resolve DNS request using system nameservers.
|
/// Resolve DNS request using system nameservers.
|
||||||
fn resolve(query: &str) -> Result<IpAddr, IoError> {
|
fn resolve(query: &str) -> Result<IpAddr, IoError> {
|
||||||
|
// todo: local overrides
|
||||||
|
if query.starts_with("localhost") {
|
||||||
|
return Ok(IpAddr::V4(Ipv4Addr::LOCALHOST));
|
||||||
|
}
|
||||||
|
|
||||||
|
// todo: dns caching
|
||||||
|
// create dns query header: [id, flags, questions, answers, authority, additional]
|
||||||
|
let header: [u16; 6] = [0xabcd, 0x0100, 0x0001, 0x0000, 0x0000, 0x0000].map(|b: u16| b.to_be());
|
||||||
|
let question: [u16; 2] = [0x0001, 0x0001].map(|b: u16| b.to_be()); // [qtype, qclass] = [A, IN(ternet)]
|
||||||
|
|
||||||
|
// convert query to standard dns name notation (max 63 characters for each label)
|
||||||
|
let ascii = query.chars().filter(char::is_ascii).collect::<String>();
|
||||||
|
let name = ascii
|
||||||
|
.split('.')
|
||||||
|
.flat_map(|l| {
|
||||||
|
iter::once(u8::try_from(l.len()).unwrap_or(63).min(63)).chain(l.bytes().take(63))
|
||||||
|
})
|
||||||
|
.chain(iter::once(0))
|
||||||
|
.collect::<Vec<u8>>();
|
||||||
|
|
||||||
|
// construct the message
|
||||||
|
let mut message = bytemuck::cast::<[u16; 6], [u8; 12]>(header).to_vec();
|
||||||
|
message.extend(&name[..]);
|
||||||
|
message.extend(bytemuck::cast_slice(&question));
|
||||||
|
|
||||||
|
// create the socket
|
||||||
|
let socket = UdpSocket::bind("0.0.0.0:0")?;
|
||||||
|
socket.connect(&DNS_SERVERS[..])?;
|
||||||
|
|
||||||
|
// write dns lookup message
|
||||||
|
socket.send_to(&message, &DNS_SERVERS[..]).unwrap();
|
||||||
|
|
||||||
|
// read dns response
|
||||||
|
let mut buf = vec![0; 1024];
|
||||||
|
let (n, _addr) = socket.recv_from(&mut buf)?;
|
||||||
|
buf.resize(n, 0);
|
||||||
|
|
||||||
|
// parse out the address
|
||||||
|
let ip = &buf.get(message.len()..).unwrap()[12..];
|
||||||
|
let address = IpAddr::V4(Ipv4Addr::new(ip[0], ip[1], ip[2], ip[3]));
|
||||||
|
|
||||||
|
Ok(address)
|
||||||
|
}
|
||||||
|
static DNS_SERVERS: Lazy<Vec<SocketAddr>> = Lazy::new(|| {
|
||||||
// find name servers (platform-dependent)
|
// find name servers (platform-dependent)
|
||||||
let servers = {
|
|
||||||
#[cfg(unix)]
|
#[cfg(unix)]
|
||||||
{
|
{
|
||||||
use std::fs;
|
use std::fs;
|
||||||
|
@ -304,43 +384,11 @@ fn resolve(query: &str) -> Result<IpAddr, IoError> {
|
||||||
}
|
}
|
||||||
#[cfg(windows)]
|
#[cfg(windows)]
|
||||||
{
|
{
|
||||||
("8.8.8.8", 53).to_socket_addrs()?.collect::<Vec<_>>()
|
// todo: get windows name servers
|
||||||
|
vec![SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53)]
|
||||||
}
|
}
|
||||||
};
|
#[cfg(not(any(unix, windows)))]
|
||||||
|
{
|
||||||
// request dns resolution from nameservers
|
vec![SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53)]
|
||||||
let header: [u16; 6] = [0xabcd, 0x0100, 0x0001, 0x0000, 0x0000, 0x0000].map(|b: u16| b.to_be());
|
}
|
||||||
let question: [u16; 2] = [0x0001, 0x0001].map(|b: u16| b.to_be());
|
});
|
||||||
|
|
||||||
// convert query to standard dns name notation
|
|
||||||
let ascii = query.chars().filter(char::is_ascii).collect::<String>();
|
|
||||||
let name = ascii
|
|
||||||
.split('.')
|
|
||||||
.flat_map(|l| iter::once(u8::try_from(l.len()).unwrap_or(63)).chain(l.bytes().take(63)))
|
|
||||||
.chain(iter::once(0))
|
|
||||||
.collect::<Vec<u8>>();
|
|
||||||
|
|
||||||
// construct the message
|
|
||||||
let mut message = bytemuck::cast::<[u16; 6], [u8; 12]>(header).to_vec();
|
|
||||||
message.extend(&name[..]);
|
|
||||||
message.extend(bytemuck::cast_slice(&question));
|
|
||||||
|
|
||||||
// create the socket
|
|
||||||
let socket = UdpSocket::bind("0.0.0.0:0")?;
|
|
||||||
socket.connect(&servers[..])?;
|
|
||||||
|
|
||||||
// write dns lookup message
|
|
||||||
socket.send_to(&message, &servers[..]).unwrap();
|
|
||||||
|
|
||||||
// read dns response
|
|
||||||
let mut buf = vec![0; 1024];
|
|
||||||
let (n, _addr) = socket.recv_from(&mut buf)?;
|
|
||||||
buf.resize(n, 0);
|
|
||||||
|
|
||||||
// parse out the address
|
|
||||||
let answers = &buf[message.len()..];
|
|
||||||
let ip = &answers[12..];
|
|
||||||
let address = IpAddr::V4(Ipv4Addr::new(ip[0], ip[1], ip[2], ip[3]));
|
|
||||||
|
|
||||||
Ok(address)
|
|
||||||
}
|
|
||||||
|
|
Loading…
Reference in a new issue