added redirect handling, fixed some dns resolution issues
This commit is contained in:
		
							parent
							
								
									5dffb202a5
								
							
						
					
					
						commit
						8b997df4c2
					
				
					 1 changed files with 90 additions and 42 deletions
				
			
		
							
								
								
									
										158
									
								
								src/lib.rs
									
									
									
									
									
								
							
							
						
						
									
										158
									
								
								src/lib.rs
									
									
									
									
									
								
							| 
						 | 
				
			
			@ -12,7 +12,7 @@ use std::{
 | 
			
		|||
    fmt,
 | 
			
		||||
    io::{BufRead, BufReader, Error as IoError, Write},
 | 
			
		||||
    iter,
 | 
			
		||||
    net::{IpAddr, Ipv4Addr, TcpStream, ToSocketAddrs, UdpSocket},
 | 
			
		||||
    net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream, UdpSocket},
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
/// An HTTP request.
 | 
			
		||||
| 
						 | 
				
			
			@ -26,7 +26,7 @@ use std::{
 | 
			
		|||
/// // ... start a local server on port 8000 ...
 | 
			
		||||
/// let request = Request::get("localhost:8000");
 | 
			
		||||
/// let response = request.send().unwrap();
 | 
			
		||||
/// assert_eq!(response.status, 200)
 | 
			
		||||
/// assert_eq!(response.status, 200);
 | 
			
		||||
/// ```
 | 
			
		||||
///
 | 
			
		||||
/// Adding headers:
 | 
			
		||||
| 
						 | 
				
			
			@ -67,6 +67,8 @@ pub struct Request<'a> {
 | 
			
		|||
    headers: HashMap<&'a str, &'a str>,
 | 
			
		||||
    /// Request body.
 | 
			
		||||
    body: &'a str,
 | 
			
		||||
    /// How many redirects are followed before an error is emitted.
 | 
			
		||||
    redirects: usize,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl<'a> Request<'a> {
 | 
			
		||||
| 
						 | 
				
			
			@ -87,9 +89,25 @@ impl<'a> Request<'a> {
 | 
			
		|||
            method,
 | 
			
		||||
            headers: HashMap::new(),
 | 
			
		||||
            body: "",
 | 
			
		||||
            redirects: 4,
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /// Set the URL of the request.
 | 
			
		||||
    ///
 | 
			
		||||
    /// # Examples
 | 
			
		||||
    ///
 | 
			
		||||
    /// ```rust
 | 
			
		||||
    /// # use request::*;
 | 
			
		||||
    /// let request = Request::get("http://example.org/a").url("http://example.org/b");
 | 
			
		||||
    /// assert_eq!(format!("{request}"), "GET /b HTTP/1.1\r\nHost: example.org\r\n\r\n");
 | 
			
		||||
    /// ```
 | 
			
		||||
    pub fn url(self, url: &'a str) -> Self {
 | 
			
		||||
        let mut request = self;
 | 
			
		||||
        request.url = url;
 | 
			
		||||
        request
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /// Add a body to the request.
 | 
			
		||||
    ///
 | 
			
		||||
    /// # Examples
 | 
			
		||||
| 
						 | 
				
			
			@ -119,6 +137,13 @@ impl<'a> Request<'a> {
 | 
			
		|||
        request
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /// Set the maximum allowed redirects.
 | 
			
		||||
    pub fn redirects(self, max: usize) -> Self {
 | 
			
		||||
        let mut request = self;
 | 
			
		||||
        request.redirects = max;
 | 
			
		||||
        request
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /// Construct a new GET request.
 | 
			
		||||
    ///
 | 
			
		||||
    /// # Examples
 | 
			
		||||
| 
						 | 
				
			
			@ -163,7 +188,8 @@ impl<'a> Request<'a> {
 | 
			
		|||
 | 
			
		||||
        // create the stream
 | 
			
		||||
        let host = resolve(host(self.url).unwrap())?;
 | 
			
		||||
        let mut stream = TcpStream::connect((host, 80))?;
 | 
			
		||||
        let port = port(self.url).unwrap_or("80").parse::<u16>().unwrap();
 | 
			
		||||
        let mut stream = TcpStream::connect((host, port))?;
 | 
			
		||||
 | 
			
		||||
        // send the message
 | 
			
		||||
        stream.write_all(message.as_bytes())?;
 | 
			
		||||
| 
						 | 
				
			
			@ -175,6 +201,22 @@ impl<'a> Request<'a> {
 | 
			
		|||
            .collect::<Vec<_>>();
 | 
			
		||||
        let received = lines.join("\n");
 | 
			
		||||
 | 
			
		||||
        // check for redirects
 | 
			
		||||
        if lines[0].contains("301") && self.redirects > 0 {
 | 
			
		||||
            assert!(self.redirects > 0, "maximum redirect limit reached"); // todo: error for maximum redirect limit reached
 | 
			
		||||
 | 
			
		||||
            // find new location
 | 
			
		||||
            let location = lines
 | 
			
		||||
                .iter()
 | 
			
		||||
                .find_map(|l| l.strip_prefix("Location: "))
 | 
			
		||||
                .unwrap(); // todo: error for missing location in redirect
 | 
			
		||||
            return self
 | 
			
		||||
                .clone()
 | 
			
		||||
                .redirects(self.redirects - 1)
 | 
			
		||||
                .url(location)
 | 
			
		||||
                .send();
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // process response
 | 
			
		||||
        let response = Response::parse(&received).unwrap();
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -183,13 +225,10 @@ impl<'a> Request<'a> {
 | 
			
		|||
}
 | 
			
		||||
impl<'a> fmt::Display for Request<'a> {
 | 
			
		||||
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 | 
			
		||||
        let (method, path, host, body) = (
 | 
			
		||||
            self.method,
 | 
			
		||||
            path(self.url).ok_or(fmt::Error)?,
 | 
			
		||||
            host(self.url).ok_or(fmt::Error)?,
 | 
			
		||||
            self.body,
 | 
			
		||||
        );
 | 
			
		||||
 | 
			
		||||
        let method = self.method;
 | 
			
		||||
        let path = path(self.url).ok_or(fmt::Error)?;
 | 
			
		||||
        let host = host(self.url).ok_or(fmt::Error)?;
 | 
			
		||||
        let body = self.body;
 | 
			
		||||
        let headers = iter::once(format!("Host: {host}"))
 | 
			
		||||
            .chain(self.headers.iter().map(|(k, v)| format!("{k}: {v}")))
 | 
			
		||||
            .collect::<Vec<_>>()
 | 
			
		||||
| 
						 | 
				
			
			@ -248,11 +287,6 @@ impl Response {
 | 
			
		|||
            .map(|(a, b)| (a.to_string(), b.to_string()))
 | 
			
		||||
            .collect::<HashMap<String, String>>();
 | 
			
		||||
 | 
			
		||||
        // check if redirect
 | 
			
		||||
        if status == 301 {
 | 
			
		||||
            todo!()
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // parse body
 | 
			
		||||
        let body = parts.name("body").map(|m| m.as_str().to_string());
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -270,7 +304,7 @@ impl Response {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
static URI_REGEX: Lazy<Regex> = Lazy::new(|| {
 | 
			
		||||
    Regex::new("(?:(?P<scheme>https?)://)?(?P<host>[0-9a-zA-Z:\\.\\-]+)(?P<path>/(?:.)*)?").unwrap()
 | 
			
		||||
    Regex::new("(?:(?P<scheme>https?)://)?(?P<host>[0-9a-zA-Z\\.\\-]+)(?:\\:(?P<port>\\d+))?(?P<path>/(?:.)*)?").unwrap()
 | 
			
		||||
});
 | 
			
		||||
#[allow(dead_code)]
 | 
			
		||||
fn scheme(url: &str) -> Option<&str> {
 | 
			
		||||
| 
						 | 
				
			
			@ -279,6 +313,9 @@ fn scheme(url: &str) -> Option<&str> {
 | 
			
		|||
fn host(url: &str) -> Option<&str> {
 | 
			
		||||
    URI_REGEX.captures(url)?.name("host").map(|m| m.as_str())
 | 
			
		||||
}
 | 
			
		||||
fn port(url: &str) -> Option<&str> {
 | 
			
		||||
    URI_REGEX.captures(url)?.name("port").map(|m| m.as_str())
 | 
			
		||||
}
 | 
			
		||||
fn path(url: &str) -> Option<&str> {
 | 
			
		||||
    URI_REGEX
 | 
			
		||||
        .captures(url)?
 | 
			
		||||
| 
						 | 
				
			
			@ -289,8 +326,51 @@ fn path(url: &str) -> Option<&str> {
 | 
			
		|||
 | 
			
		||||
/// Resolve DNS request using system nameservers.
 | 
			
		||||
fn resolve(query: &str) -> Result<IpAddr, IoError> {
 | 
			
		||||
    // todo: local overrides
 | 
			
		||||
    if query.starts_with("localhost") {
 | 
			
		||||
        return Ok(IpAddr::V4(Ipv4Addr::LOCALHOST));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // todo: dns caching
 | 
			
		||||
    // create dns query header: [id, flags, questions, answers, authority, additional]
 | 
			
		||||
    let header: [u16; 6] = [0xabcd, 0x0100, 0x0001, 0x0000, 0x0000, 0x0000].map(|b: u16| b.to_be());
 | 
			
		||||
    let question: [u16; 2] = [0x0001, 0x0001].map(|b: u16| b.to_be()); // [qtype, qclass] = [A, IN(ternet)]
 | 
			
		||||
 | 
			
		||||
    // convert query to standard dns name notation (max 63 characters for each label)
 | 
			
		||||
    let ascii = query.chars().filter(char::is_ascii).collect::<String>();
 | 
			
		||||
    let name = ascii
 | 
			
		||||
        .split('.')
 | 
			
		||||
        .flat_map(|l| {
 | 
			
		||||
            iter::once(u8::try_from(l.len()).unwrap_or(63).min(63)).chain(l.bytes().take(63))
 | 
			
		||||
        })
 | 
			
		||||
        .chain(iter::once(0))
 | 
			
		||||
        .collect::<Vec<u8>>();
 | 
			
		||||
 | 
			
		||||
    // construct the message
 | 
			
		||||
    let mut message = bytemuck::cast::<[u16; 6], [u8; 12]>(header).to_vec();
 | 
			
		||||
    message.extend(&name[..]);
 | 
			
		||||
    message.extend(bytemuck::cast_slice(&question));
 | 
			
		||||
 | 
			
		||||
    // create the socket
 | 
			
		||||
    let socket = UdpSocket::bind("0.0.0.0:0")?;
 | 
			
		||||
    socket.connect(&DNS_SERVERS[..])?;
 | 
			
		||||
 | 
			
		||||
    // write dns lookup message
 | 
			
		||||
    socket.send_to(&message, &DNS_SERVERS[..]).unwrap();
 | 
			
		||||
 | 
			
		||||
    // read dns response
 | 
			
		||||
    let mut buf = vec![0; 1024];
 | 
			
		||||
    let (n, _addr) = socket.recv_from(&mut buf)?;
 | 
			
		||||
    buf.resize(n, 0);
 | 
			
		||||
 | 
			
		||||
    // parse out the address
 | 
			
		||||
    let ip = &buf.get(message.len()..).unwrap()[12..];
 | 
			
		||||
    let address = IpAddr::V4(Ipv4Addr::new(ip[0], ip[1], ip[2], ip[3]));
 | 
			
		||||
 | 
			
		||||
    Ok(address)
 | 
			
		||||
}
 | 
			
		||||
static DNS_SERVERS: Lazy<Vec<SocketAddr>> = Lazy::new(|| {
 | 
			
		||||
    // find name servers (platform-dependent)
 | 
			
		||||
    let servers = {
 | 
			
		||||
    #[cfg(unix)]
 | 
			
		||||
    {
 | 
			
		||||
        use std::fs;
 | 
			
		||||
| 
						 | 
				
			
			@ -304,43 +384,11 @@ fn resolve(query: &str) -> Result<IpAddr, IoError> {
 | 
			
		|||
    }
 | 
			
		||||
    #[cfg(windows)]
 | 
			
		||||
    {
 | 
			
		||||
            ("8.8.8.8", 53).to_socket_addrs()?.collect::<Vec<_>>()
 | 
			
		||||
        // todo: get windows name servers
 | 
			
		||||
        vec![SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53)]
 | 
			
		||||
    }
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    // request dns resolution from nameservers
 | 
			
		||||
    let header: [u16; 6] = [0xabcd, 0x0100, 0x0001, 0x0000, 0x0000, 0x0000].map(|b: u16| b.to_be());
 | 
			
		||||
    let question: [u16; 2] = [0x0001, 0x0001].map(|b: u16| b.to_be());
 | 
			
		||||
 | 
			
		||||
    // convert query to standard dns name notation
 | 
			
		||||
    let ascii = query.chars().filter(char::is_ascii).collect::<String>();
 | 
			
		||||
    let name = ascii
 | 
			
		||||
        .split('.')
 | 
			
		||||
        .flat_map(|l| iter::once(u8::try_from(l.len()).unwrap_or(63)).chain(l.bytes().take(63)))
 | 
			
		||||
        .chain(iter::once(0))
 | 
			
		||||
        .collect::<Vec<u8>>();
 | 
			
		||||
 | 
			
		||||
    // construct the message
 | 
			
		||||
    let mut message = bytemuck::cast::<[u16; 6], [u8; 12]>(header).to_vec();
 | 
			
		||||
    message.extend(&name[..]);
 | 
			
		||||
    message.extend(bytemuck::cast_slice(&question));
 | 
			
		||||
 | 
			
		||||
    // create the socket
 | 
			
		||||
    let socket = UdpSocket::bind("0.0.0.0:0")?;
 | 
			
		||||
    socket.connect(&servers[..])?;
 | 
			
		||||
 | 
			
		||||
    // write dns lookup message
 | 
			
		||||
    socket.send_to(&message, &servers[..]).unwrap();
 | 
			
		||||
 | 
			
		||||
    // read dns response
 | 
			
		||||
    let mut buf = vec![0; 1024];
 | 
			
		||||
    let (n, _addr) = socket.recv_from(&mut buf)?;
 | 
			
		||||
    buf.resize(n, 0);
 | 
			
		||||
 | 
			
		||||
    // parse out the address
 | 
			
		||||
    let answers = &buf[message.len()..];
 | 
			
		||||
    let ip = &answers[12..];
 | 
			
		||||
    let address = IpAddr::V4(Ipv4Addr::new(ip[0], ip[1], ip[2], ip[3]));
 | 
			
		||||
 | 
			
		||||
    Ok(address)
 | 
			
		||||
    #[cfg(not(any(unix, windows)))]
 | 
			
		||||
    {
 | 
			
		||||
        vec![SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53)]
 | 
			
		||||
    }
 | 
			
		||||
});
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue