added redirect handling, fixed some dns resolution issues
This commit is contained in:
		
							parent
							
								
									5dffb202a5
								
							
						
					
					
						commit
						8b997df4c2
					
				
					 1 changed files with 90 additions and 42 deletions
				
			
		
							
								
								
									
										158
									
								
								src/lib.rs
									
									
									
									
									
								
							
							
						
						
									
										158
									
								
								src/lib.rs
									
									
									
									
									
								
							| 
						 | 
					@ -12,7 +12,7 @@ use std::{
 | 
				
			||||||
    fmt,
 | 
					    fmt,
 | 
				
			||||||
    io::{BufRead, BufReader, Error as IoError, Write},
 | 
					    io::{BufRead, BufReader, Error as IoError, Write},
 | 
				
			||||||
    iter,
 | 
					    iter,
 | 
				
			||||||
    net::{IpAddr, Ipv4Addr, TcpStream, ToSocketAddrs, UdpSocket},
 | 
					    net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream, UdpSocket},
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/// An HTTP request.
 | 
					/// An HTTP request.
 | 
				
			||||||
| 
						 | 
					@ -26,7 +26,7 @@ use std::{
 | 
				
			||||||
/// // ... start a local server on port 8000 ...
 | 
					/// // ... start a local server on port 8000 ...
 | 
				
			||||||
/// let request = Request::get("localhost:8000");
 | 
					/// let request = Request::get("localhost:8000");
 | 
				
			||||||
/// let response = request.send().unwrap();
 | 
					/// let response = request.send().unwrap();
 | 
				
			||||||
/// assert_eq!(response.status, 200)
 | 
					/// assert_eq!(response.status, 200);
 | 
				
			||||||
/// ```
 | 
					/// ```
 | 
				
			||||||
///
 | 
					///
 | 
				
			||||||
/// Adding headers:
 | 
					/// Adding headers:
 | 
				
			||||||
| 
						 | 
					@ -67,6 +67,8 @@ pub struct Request<'a> {
 | 
				
			||||||
    headers: HashMap<&'a str, &'a str>,
 | 
					    headers: HashMap<&'a str, &'a str>,
 | 
				
			||||||
    /// Request body.
 | 
					    /// Request body.
 | 
				
			||||||
    body: &'a str,
 | 
					    body: &'a str,
 | 
				
			||||||
 | 
					    /// How many redirects are followed before an error is emitted.
 | 
				
			||||||
 | 
					    redirects: usize,
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
impl<'a> Request<'a> {
 | 
					impl<'a> Request<'a> {
 | 
				
			||||||
| 
						 | 
					@ -87,9 +89,25 @@ impl<'a> Request<'a> {
 | 
				
			||||||
            method,
 | 
					            method,
 | 
				
			||||||
            headers: HashMap::new(),
 | 
					            headers: HashMap::new(),
 | 
				
			||||||
            body: "",
 | 
					            body: "",
 | 
				
			||||||
 | 
					            redirects: 4,
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    /// Set the URL of the request.
 | 
				
			||||||
 | 
					    ///
 | 
				
			||||||
 | 
					    /// # Examples
 | 
				
			||||||
 | 
					    ///
 | 
				
			||||||
 | 
					    /// ```rust
 | 
				
			||||||
 | 
					    /// # use request::*;
 | 
				
			||||||
 | 
					    /// let request = Request::get("http://example.org/a").url("http://example.org/b");
 | 
				
			||||||
 | 
					    /// assert_eq!(format!("{request}"), "GET /b HTTP/1.1\r\nHost: example.org\r\n\r\n");
 | 
				
			||||||
 | 
					    /// ```
 | 
				
			||||||
 | 
					    pub fn url(self, url: &'a str) -> Self {
 | 
				
			||||||
 | 
					        let mut request = self;
 | 
				
			||||||
 | 
					        request.url = url;
 | 
				
			||||||
 | 
					        request
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    /// Add a body to the request.
 | 
					    /// Add a body to the request.
 | 
				
			||||||
    ///
 | 
					    ///
 | 
				
			||||||
    /// # Examples
 | 
					    /// # Examples
 | 
				
			||||||
| 
						 | 
					@ -119,6 +137,13 @@ impl<'a> Request<'a> {
 | 
				
			||||||
        request
 | 
					        request
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    /// Set the maximum allowed redirects.
 | 
				
			||||||
 | 
					    pub fn redirects(self, max: usize) -> Self {
 | 
				
			||||||
 | 
					        let mut request = self;
 | 
				
			||||||
 | 
					        request.redirects = max;
 | 
				
			||||||
 | 
					        request
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    /// Construct a new GET request.
 | 
					    /// Construct a new GET request.
 | 
				
			||||||
    ///
 | 
					    ///
 | 
				
			||||||
    /// # Examples
 | 
					    /// # Examples
 | 
				
			||||||
| 
						 | 
					@ -163,7 +188,8 @@ impl<'a> Request<'a> {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // create the stream
 | 
					        // create the stream
 | 
				
			||||||
        let host = resolve(host(self.url).unwrap())?;
 | 
					        let host = resolve(host(self.url).unwrap())?;
 | 
				
			||||||
        let mut stream = TcpStream::connect((host, 80))?;
 | 
					        let port = port(self.url).unwrap_or("80").parse::<u16>().unwrap();
 | 
				
			||||||
 | 
					        let mut stream = TcpStream::connect((host, port))?;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // send the message
 | 
					        // send the message
 | 
				
			||||||
        stream.write_all(message.as_bytes())?;
 | 
					        stream.write_all(message.as_bytes())?;
 | 
				
			||||||
| 
						 | 
					@ -175,6 +201,22 @@ impl<'a> Request<'a> {
 | 
				
			||||||
            .collect::<Vec<_>>();
 | 
					            .collect::<Vec<_>>();
 | 
				
			||||||
        let received = lines.join("\n");
 | 
					        let received = lines.join("\n");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        // check for redirects
 | 
				
			||||||
 | 
					        if lines[0].contains("301") && self.redirects > 0 {
 | 
				
			||||||
 | 
					            assert!(self.redirects > 0, "maximum redirect limit reached"); // todo: error for maximum redirect limit reached
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            // find new location
 | 
				
			||||||
 | 
					            let location = lines
 | 
				
			||||||
 | 
					                .iter()
 | 
				
			||||||
 | 
					                .find_map(|l| l.strip_prefix("Location: "))
 | 
				
			||||||
 | 
					                .unwrap(); // todo: error for missing location in redirect
 | 
				
			||||||
 | 
					            return self
 | 
				
			||||||
 | 
					                .clone()
 | 
				
			||||||
 | 
					                .redirects(self.redirects - 1)
 | 
				
			||||||
 | 
					                .url(location)
 | 
				
			||||||
 | 
					                .send();
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // process response
 | 
					        // process response
 | 
				
			||||||
        let response = Response::parse(&received).unwrap();
 | 
					        let response = Response::parse(&received).unwrap();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -183,13 +225,10 @@ impl<'a> Request<'a> {
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
impl<'a> fmt::Display for Request<'a> {
 | 
					impl<'a> fmt::Display for Request<'a> {
 | 
				
			||||||
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 | 
					    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 | 
				
			||||||
        let (method, path, host, body) = (
 | 
					        let method = self.method;
 | 
				
			||||||
            self.method,
 | 
					        let path = path(self.url).ok_or(fmt::Error)?;
 | 
				
			||||||
            path(self.url).ok_or(fmt::Error)?,
 | 
					        let host = host(self.url).ok_or(fmt::Error)?;
 | 
				
			||||||
            host(self.url).ok_or(fmt::Error)?,
 | 
					        let body = self.body;
 | 
				
			||||||
            self.body,
 | 
					 | 
				
			||||||
        );
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        let headers = iter::once(format!("Host: {host}"))
 | 
					        let headers = iter::once(format!("Host: {host}"))
 | 
				
			||||||
            .chain(self.headers.iter().map(|(k, v)| format!("{k}: {v}")))
 | 
					            .chain(self.headers.iter().map(|(k, v)| format!("{k}: {v}")))
 | 
				
			||||||
            .collect::<Vec<_>>()
 | 
					            .collect::<Vec<_>>()
 | 
				
			||||||
| 
						 | 
					@ -248,11 +287,6 @@ impl Response {
 | 
				
			||||||
            .map(|(a, b)| (a.to_string(), b.to_string()))
 | 
					            .map(|(a, b)| (a.to_string(), b.to_string()))
 | 
				
			||||||
            .collect::<HashMap<String, String>>();
 | 
					            .collect::<HashMap<String, String>>();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // check if redirect
 | 
					 | 
				
			||||||
        if status == 301 {
 | 
					 | 
				
			||||||
            todo!()
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        // parse body
 | 
					        // parse body
 | 
				
			||||||
        let body = parts.name("body").map(|m| m.as_str().to_string());
 | 
					        let body = parts.name("body").map(|m| m.as_str().to_string());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -270,7 +304,7 @@ impl Response {
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
static URI_REGEX: Lazy<Regex> = Lazy::new(|| {
 | 
					static URI_REGEX: Lazy<Regex> = Lazy::new(|| {
 | 
				
			||||||
    Regex::new("(?:(?P<scheme>https?)://)?(?P<host>[0-9a-zA-Z:\\.\\-]+)(?P<path>/(?:.)*)?").unwrap()
 | 
					    Regex::new("(?:(?P<scheme>https?)://)?(?P<host>[0-9a-zA-Z\\.\\-]+)(?:\\:(?P<port>\\d+))?(?P<path>/(?:.)*)?").unwrap()
 | 
				
			||||||
});
 | 
					});
 | 
				
			||||||
#[allow(dead_code)]
 | 
					#[allow(dead_code)]
 | 
				
			||||||
fn scheme(url: &str) -> Option<&str> {
 | 
					fn scheme(url: &str) -> Option<&str> {
 | 
				
			||||||
| 
						 | 
					@ -279,6 +313,9 @@ fn scheme(url: &str) -> Option<&str> {
 | 
				
			||||||
fn host(url: &str) -> Option<&str> {
 | 
					fn host(url: &str) -> Option<&str> {
 | 
				
			||||||
    URI_REGEX.captures(url)?.name("host").map(|m| m.as_str())
 | 
					    URI_REGEX.captures(url)?.name("host").map(|m| m.as_str())
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					fn port(url: &str) -> Option<&str> {
 | 
				
			||||||
 | 
					    URI_REGEX.captures(url)?.name("port").map(|m| m.as_str())
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
fn path(url: &str) -> Option<&str> {
 | 
					fn path(url: &str) -> Option<&str> {
 | 
				
			||||||
    URI_REGEX
 | 
					    URI_REGEX
 | 
				
			||||||
        .captures(url)?
 | 
					        .captures(url)?
 | 
				
			||||||
| 
						 | 
					@ -289,8 +326,51 @@ fn path(url: &str) -> Option<&str> {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/// Resolve DNS request using system nameservers.
 | 
					/// Resolve DNS request using system nameservers.
 | 
				
			||||||
fn resolve(query: &str) -> Result<IpAddr, IoError> {
 | 
					fn resolve(query: &str) -> Result<IpAddr, IoError> {
 | 
				
			||||||
 | 
					    // todo: local overrides
 | 
				
			||||||
 | 
					    if query.starts_with("localhost") {
 | 
				
			||||||
 | 
					        return Ok(IpAddr::V4(Ipv4Addr::LOCALHOST));
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // todo: dns caching
 | 
				
			||||||
 | 
					    // create dns query header: [id, flags, questions, answers, authority, additional]
 | 
				
			||||||
 | 
					    let header: [u16; 6] = [0xabcd, 0x0100, 0x0001, 0x0000, 0x0000, 0x0000].map(|b: u16| b.to_be());
 | 
				
			||||||
 | 
					    let question: [u16; 2] = [0x0001, 0x0001].map(|b: u16| b.to_be()); // [qtype, qclass] = [A, IN(ternet)]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // convert query to standard dns name notation (max 63 characters for each label)
 | 
				
			||||||
 | 
					    let ascii = query.chars().filter(char::is_ascii).collect::<String>();
 | 
				
			||||||
 | 
					    let name = ascii
 | 
				
			||||||
 | 
					        .split('.')
 | 
				
			||||||
 | 
					        .flat_map(|l| {
 | 
				
			||||||
 | 
					            iter::once(u8::try_from(l.len()).unwrap_or(63).min(63)).chain(l.bytes().take(63))
 | 
				
			||||||
 | 
					        })
 | 
				
			||||||
 | 
					        .chain(iter::once(0))
 | 
				
			||||||
 | 
					        .collect::<Vec<u8>>();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // construct the message
 | 
				
			||||||
 | 
					    let mut message = bytemuck::cast::<[u16; 6], [u8; 12]>(header).to_vec();
 | 
				
			||||||
 | 
					    message.extend(&name[..]);
 | 
				
			||||||
 | 
					    message.extend(bytemuck::cast_slice(&question));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // create the socket
 | 
				
			||||||
 | 
					    let socket = UdpSocket::bind("0.0.0.0:0")?;
 | 
				
			||||||
 | 
					    socket.connect(&DNS_SERVERS[..])?;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // write dns lookup message
 | 
				
			||||||
 | 
					    socket.send_to(&message, &DNS_SERVERS[..]).unwrap();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // read dns response
 | 
				
			||||||
 | 
					    let mut buf = vec![0; 1024];
 | 
				
			||||||
 | 
					    let (n, _addr) = socket.recv_from(&mut buf)?;
 | 
				
			||||||
 | 
					    buf.resize(n, 0);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // parse out the address
 | 
				
			||||||
 | 
					    let ip = &buf.get(message.len()..).unwrap()[12..];
 | 
				
			||||||
 | 
					    let address = IpAddr::V4(Ipv4Addr::new(ip[0], ip[1], ip[2], ip[3]));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Ok(address)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					static DNS_SERVERS: Lazy<Vec<SocketAddr>> = Lazy::new(|| {
 | 
				
			||||||
    // find name servers (platform-dependent)
 | 
					    // find name servers (platform-dependent)
 | 
				
			||||||
    let servers = {
 | 
					 | 
				
			||||||
    #[cfg(unix)]
 | 
					    #[cfg(unix)]
 | 
				
			||||||
    {
 | 
					    {
 | 
				
			||||||
        use std::fs;
 | 
					        use std::fs;
 | 
				
			||||||
| 
						 | 
					@ -304,43 +384,11 @@ fn resolve(query: &str) -> Result<IpAddr, IoError> {
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    #[cfg(windows)]
 | 
					    #[cfg(windows)]
 | 
				
			||||||
    {
 | 
					    {
 | 
				
			||||||
            ("8.8.8.8", 53).to_socket_addrs()?.collect::<Vec<_>>()
 | 
					        // todo: get windows name servers
 | 
				
			||||||
 | 
					        vec![SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53)]
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    };
 | 
					    #[cfg(not(any(unix, windows)))]
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
    // request dns resolution from nameservers
 | 
					        vec![SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53)]
 | 
				
			||||||
    let header: [u16; 6] = [0xabcd, 0x0100, 0x0001, 0x0000, 0x0000, 0x0000].map(|b: u16| b.to_be());
 | 
					 | 
				
			||||||
    let question: [u16; 2] = [0x0001, 0x0001].map(|b: u16| b.to_be());
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    // convert query to standard dns name notation
 | 
					 | 
				
			||||||
    let ascii = query.chars().filter(char::is_ascii).collect::<String>();
 | 
					 | 
				
			||||||
    let name = ascii
 | 
					 | 
				
			||||||
        .split('.')
 | 
					 | 
				
			||||||
        .flat_map(|l| iter::once(u8::try_from(l.len()).unwrap_or(63)).chain(l.bytes().take(63)))
 | 
					 | 
				
			||||||
        .chain(iter::once(0))
 | 
					 | 
				
			||||||
        .collect::<Vec<u8>>();
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    // construct the message
 | 
					 | 
				
			||||||
    let mut message = bytemuck::cast::<[u16; 6], [u8; 12]>(header).to_vec();
 | 
					 | 
				
			||||||
    message.extend(&name[..]);
 | 
					 | 
				
			||||||
    message.extend(bytemuck::cast_slice(&question));
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    // create the socket
 | 
					 | 
				
			||||||
    let socket = UdpSocket::bind("0.0.0.0:0")?;
 | 
					 | 
				
			||||||
    socket.connect(&servers[..])?;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    // write dns lookup message
 | 
					 | 
				
			||||||
    socket.send_to(&message, &servers[..]).unwrap();
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    // read dns response
 | 
					 | 
				
			||||||
    let mut buf = vec![0; 1024];
 | 
					 | 
				
			||||||
    let (n, _addr) = socket.recv_from(&mut buf)?;
 | 
					 | 
				
			||||||
    buf.resize(n, 0);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    // parse out the address
 | 
					 | 
				
			||||||
    let answers = &buf[message.len()..];
 | 
					 | 
				
			||||||
    let ip = &answers[12..];
 | 
					 | 
				
			||||||
    let address = IpAddr::V4(Ipv4Addr::new(ip[0], ip[1], ip[2], ip[3]));
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    Ok(address)
 | 
					 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					});
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue