diff --git a/Cargo.toml b/Cargo.toml index bb7bc87..e976c7d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "oxhttp" -version = "0.2.7" +version = "0.3.0-dev" authors = ["Tpt "] license = "MIT OR Apache-2.0" readme = "README.md" @@ -15,14 +15,15 @@ rust-version = "1.74" [dependencies] flate2 = { version = "1", optional = true } +http = "1.1" httparse = "1.8" +url = { version = "2.4", optional = true } native-tls = { version = "0.2.11", optional = true } rustls = { version = "0.23.16", optional = true, default-features = false, features = ["std", "tls12"] } rustls-native-certs = { version = "0.8", optional = true } rustls-pki-types = { version = "1.10", optional = true } rustls-platform-verifier = { version = "0.5", optional = true } webpki-roots = { version = "0.26", optional = true } -url = "2.4" [dev-dependencies] codspeed-criterion-compat = "2" @@ -35,7 +36,7 @@ rustls-ring-webpki = ["rustls/ring", "rustls-pki-types", "webpki-roots"] rustls-aws-lc-platform-verifier = ["rustls/aws_lc_rs", "rustls-pki-types", "rustls-platform-verifier"] rustls-aws-lc-native = ["rustls/aws_lc_rs", "rustls-native-certs", "rustls-pki-types"] rustls-aws-lc-webpki = ["rustls/aws_lc_rs", "rustls-pki-types", "webpki-roots"] -client = [] +client = ["dep:url"] server = [] [[bench]] diff --git a/README.md b/README.md index cba21dd..c12920d 100644 --- a/README.md +++ b/README.md @@ -39,13 +39,14 @@ Example: ```rust use oxhttp::Client; -use oxhttp::model::{Request, Method, Status, HeaderName}; +use oxhttp::model::{Body, Request, Method, StatusCode, HeaderName}; +use oxhttp::model::header::CONTENT_TYPE; use std::io::Read; let client = Client::new(); -let response = client.request(Request::builder(Method::GET, "http://example.com".parse().unwrap()).build()).unwrap(); -assert_eq!(response.status(), Status::OK); -assert_eq!(response.header(&HeaderName::CONTENT_TYPE).unwrap().as_ref(), b"text/html"); +let response = client.request(Request::builder().uri("http://example.com").body(Body::empty()).unwrap()).unwrap(); +assert_eq!(response.status(), StatusCode::OK); +assert_eq!(response.headers().get(CONTENT_TYPE).unwrap(), "text/html"); let body = response.into_body().to_string().unwrap(); ``` @@ -60,15 +61,15 @@ Example: ```rust no_run use std::net::{Ipv4Addr, Ipv6Addr}; use oxhttp::Server; -use oxhttp::model::{Response, Status}; +use oxhttp::model::{Body, Response, StatusCode}; use std::time::Duration; // Builds a new server that returns a 404 everywhere except for "/" where it returns the body 'home' let mut server = Server::new( | request| { -if request.url().path() == "/" { -Response::builder(Status::OK).with_body("home") +if request.uri().path() == "/" { +Response::builder().body(Body::from("home")).unwrap() } else { -Response::builder(Status::NOT_FOUND).build() +Response::builder().status(StatusCode::NOT_FOUND).body(Body::empty()).unwrap() } }); // We bind the server to localhost on both IPv4 and v6 diff --git a/benches/lib.rs b/benches/lib.rs index 1be9624..7753c97 100644 --- a/benches/lib.rs +++ b/benches/lib.rs @@ -1,24 +1,23 @@ use codspeed_criterion_compat::{criterion_group, criterion_main, Criterion}; -use oxhttp::model::{Body, Method, Request, Response, Status}; +use oxhttp::model::{Body, Request, Response, Uri}; use oxhttp::{Client, Server}; use std::io; use std::io::Read; use std::net::{Ipv4Addr, SocketAddrV4}; -use url::Url; fn client_server_no_body(c: &mut Criterion) { - Server::new(|_| Response::builder(Status::OK).build()) + Server::new(|_| Response::builder().body(Body::empty()).unwrap()) .bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 3456)) .spawn() .unwrap(); let client = Client::new(); - let url = Url::parse("http://localhost:3456").unwrap(); + let uri = Uri::try_from("http://localhost:3456").unwrap(); c.bench_function("client_server_no_body", |b| { b.iter(|| { client - .request(Request::builder(Method::GET, url.clone()).build()) + .request(Request::builder().uri(uri.clone()).body(()).unwrap()) .unwrap(); }) }); @@ -28,20 +27,25 @@ fn client_server_fixed_body(c: &mut Criterion) { Server::new(|request| { let mut body = Vec::new(); request.body_mut().read_to_end(&mut body).unwrap(); - Response::builder(Status::OK).with_body(body) + Response::builder().body(Body::from(body)).unwrap() }) .bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 3457)) .spawn() .unwrap(); let client = Client::new(); - let url = Url::parse("http://localhost:3456").unwrap(); + let uri = Uri::try_from("http://localhost:3456").unwrap(); let body = vec![16u8; 1024]; c.bench_function("client_server_fixed_body", |b| { b.iter(|| { client - .request(Request::builder(Method::GET, url.clone()).with_body(body.clone())) + .request( + Request::builder() + .uri(uri.clone()) + .body(body.clone()) + .unwrap(), + ) .unwrap(); }) }); @@ -51,21 +55,23 @@ fn client_server_chunked_body(c: &mut Criterion) { Server::new(|request| { let mut body = Vec::new(); request.body_mut().read_to_end(&mut body).unwrap(); - Response::builder(Status::OK).build() + Response::builder().body(Body::empty()).unwrap() }) .bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 3458)) .spawn() .unwrap(); let client = Client::new(); - let url = Url::parse("http://localhost:3456").unwrap(); + let uri = Uri::try_from("http://localhost:3456").unwrap(); c.bench_function("client_server_chunked_body", |b| { b.iter(|| { client .request( - Request::builder(Method::GET, url.clone()) - .with_body(Body::from_read(ChunkedReader::default())), + Request::builder() + .uri(uri.clone()) + .body(Body::from_read(ChunkedReader::default())) + .unwrap(), ) .unwrap(); }) diff --git a/src/client.rs b/src/client.rs index d333a0c..78566b2 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,9 +1,11 @@ #![allow(unreachable_code, clippy::needless_return)] use crate::io::{decode_response, encode_request, BUFFER_CAPACITY}; -use crate::model::{ - HeaderName, HeaderValue, InvalidHeader, Method, Request, Response, Status, Url, +use crate::model::header::{ + InvalidHeaderValue, ACCEPT_ENCODING, CONNECTION, LOCATION, RANGE, USER_AGENT, }; +use crate::model::uri::Scheme; +use crate::model::{Body, HeaderValue, Method, Request, Response, StatusCode, Uri}; use crate::utils::{invalid_data_error, invalid_input_error}; #[cfg(feature = "native-tls")] use native_tls::TlsConnector; @@ -30,12 +32,13 @@ use rustls_pki_types::ServerName; ))] use rustls_platform_verifier::ConfigVerifierExt; use std::io::{BufReader, BufWriter, Error, ErrorKind, Result}; -use std::net::{SocketAddr, TcpStream}; +use std::net::{SocketAddr, TcpStream, ToSocketAddrs}; #[cfg(all(feature = "rustls", not(feature = "native-tls")))] use std::sync::Arc; #[cfg(any(feature = "native-tls", feature = "rustls"))] use std::sync::OnceLock; use std::time::Duration; +use url::Url; #[cfg(all(feature = "webpki-roots", not(feature = "rustls-native-certs")))] use webpki_roots::TLS_SERVER_ROOTS; @@ -65,14 +68,19 @@ use webpki_roots::TLS_SERVER_ROOTS; /// Missing: HSTS support, authentication and keep alive. /// /// ``` +/// use http::header::CONTENT_TYPE; +/// use oxhttp::model::{Body, HeaderName, Method, Request, StatusCode}; /// use oxhttp::Client; -/// use oxhttp::model::{Request, Method, Status, HeaderName}; /// use std::io::Read; /// /// let client = Client::new(); -/// let response = client.request(Request::builder(Method::GET, "http://example.com".parse()?).build())?; -/// assert_eq!(response.status(), Status::OK); -/// assert_eq!(response.header(&HeaderName::CONTENT_TYPE).unwrap().as_ref(), b"text/html"); +/// let response = client.request( +/// Request::builder() +/// .uri("http://example.com") +/// .body(Body::empty())?, +/// )?; +/// assert_eq!(response.status(), StatusCode::OK); +/// assert_eq!(response.headers().get(CONTENT_TYPE).unwrap(), "text/html"); /// let body = response.into_body().to_string()?; /// # Result::<_,Box>::Ok(()) /// ``` @@ -101,7 +109,7 @@ impl Client { pub fn with_user_agent( mut self, user_agent: impl Into, - ) -> std::result::Result { + ) -> std::result::Result { self.user_agent = Some(HeaderValue::try_from(user_agent.into())?); Ok(self) } @@ -114,178 +122,171 @@ impl Client { self } - pub fn request(&self, mut request: Request) -> Result { + pub fn request(&self, request: Request>) -> Result> { + let mut request = request.map(Into::into); // Loops the number of allowed redirections + 1 for _ in 0..(self.redirection_limit + 1) { let previous_method = request.method().clone(); let response = self.single_request(&mut request)?; - let Some(location) = response.header(&HeaderName::LOCATION) else { + let Some(location) = response.headers().get(LOCATION) else { return Ok(response); }; - let new_method = match response.status() { - Status::MOVED_PERMANENTLY | Status::FOUND | Status::SEE_OTHER => { + let mut request_builder = Request::builder(); + request_builder = request_builder.method(match response.status() { + StatusCode::MOVED_PERMANENTLY | StatusCode::FOUND | StatusCode::SEE_OTHER => { if previous_method == Method::HEAD { Method::HEAD } else { Method::GET } } - Status::TEMPORARY_REDIRECT | Status::PERMANENT_REDIRECT + StatusCode::TEMPORARY_REDIRECT | StatusCode::PERMANENT_REDIRECT if previous_method.is_safe() => { previous_method } _ => return Ok(response), - }; + }); let location = location.to_str().map_err(invalid_data_error)?; - let new_url = request.url().join(location).map_err(|e| { - invalid_data_error(format!( - "Invalid URL in Location header raising error {e}: {location}" - )) - })?; - let mut request_builder = Request::builder(new_method, new_url); + request_builder = request_builder.uri(join_urls(request.uri(), location)?); for (header_name, header_value) in request.headers() { - request_builder - .headers_mut() - .set(header_name.clone(), header_value.clone()); + request_builder = request_builder.header(header_name, header_value); } - request = request_builder.build(); + request = request_builder.body(Body::empty()).map_err(|e| { + invalid_input_error(format!( + "Failure when trying to build the redirected request: {e}" + )) + })?; } Err(Error::new( ErrorKind::Other, format!( "The server requested too many redirects ({}). The latest redirection target is {}", self.redirection_limit + 1, - request.url() + request.uri() ), )) } - fn single_request(&self, request: &mut Request) -> Result { + fn single_request(&self, request: &mut Request) -> Result> { // Additional headers { let headers = request.headers_mut(); - headers.set( - HeaderName::CONNECTION, - HeaderValue::new_unchecked("close".as_bytes()), - ); + headers.insert(CONNECTION, HeaderValue::from_static("close")); if let Some(user_agent) = &self.user_agent { - if !headers.contains(&HeaderName::USER_AGENT) { - headers.set(HeaderName::USER_AGENT, user_agent.clone()) - } + headers + .entry(USER_AGENT) + .or_insert_with(|| user_agent.clone()); } - if cfg!(feature = "flate2") - && !headers.contains(&HeaderName::ACCEPT_ENCODING) - && !headers.contains(&HeaderName::RANGE) - { - headers.set( - HeaderName::ACCEPT_ENCODING, - HeaderValue::new_unchecked("gzip,deflate".as_bytes()), - ); + if cfg!(feature = "flate2") && !headers.contains_key(RANGE) { + headers + .entry(ACCEPT_ENCODING) + .or_insert_with(|| HeaderValue::from_static("gzip,deflate")); } } #[cfg(any(feature = "native-tls", feature = "rustls"))] let host = request - .url() - .host_str() + .uri() + .host() .ok_or_else(|| invalid_input_error("No host provided"))?; - match request.url().scheme() { - "http" => { - let addresses = get_and_validate_socket_addresses(request.url(), 80)?; - let stream = self.connect(&addresses)?; - let stream = - encode_request(request, BufWriter::with_capacity(BUFFER_CAPACITY, stream))? - .into_inner() - .map_err(|e| e.into_error())?; - decode_response(BufReader::with_capacity(BUFFER_CAPACITY, stream)) - } - "https" => { - #[cfg(feature = "native-tls")] - { - static TLS_CONNECTOR: OnceLock = OnceLock::new(); - - let addresses = get_and_validate_socket_addresses(request.url(), 443)?; - let stream = self.connect(&addresses)?; - let stream = TLS_CONNECTOR - .get_or_init(|| match TlsConnector::new() { - Ok(connector) => connector, - Err(e) => panic!("Error while loading TLS configuration: {}", e), // TODO: use get_or_try_init - }) - .connect(host, stream) - .map_err(|e| Error::new(ErrorKind::Other, e))?; - let stream = - encode_request(request, BufWriter::with_capacity(BUFFER_CAPACITY, stream))? - .into_inner() - .map_err(|e| e.into_error())?; - return decode_response(BufReader::with_capacity(BUFFER_CAPACITY, stream)); - } - #[cfg(all(feature = "rustls", not(feature = "native-tls")))] - { - #[cfg(not(any( - feature = "rustls-platform-verifier", - feature = "rustls-native-certs", - feature = "webpki-roots" - )))] - compile_error!( + let scheme = request.uri().scheme().ok_or_else(|| { + invalid_input_error(format!("A URI scheme must be set, found {}", request.uri())) + })?; + + if *scheme == Scheme::HTTP { + let addresses = get_and_validate_socket_addresses(request.uri(), 80)?; + let stream = self.connect(&addresses)?; + let stream = + encode_request(request, BufWriter::with_capacity(BUFFER_CAPACITY, stream))? + .into_inner() + .map_err(|e| e.into_error())?; + return decode_response(BufReader::with_capacity(BUFFER_CAPACITY, stream)); + } + + #[cfg(feature = "native-tls")] + if *scheme == Scheme::HTTPS { + static TLS_CONNECTOR: OnceLock = OnceLock::new(); + + let addresses = get_and_validate_socket_addresses(request.uri(), 443)?; + let stream = self.connect(&addresses)?; + let stream = TLS_CONNECTOR + .get_or_init(|| match TlsConnector::new() { + Ok(connector) => connector, + Err(e) => panic!("Error while loading TLS configuration: {}", e), // TODO: use get_or_try_init + }) + .connect(host, stream) + .map_err(|e| Error::new(ErrorKind::Other, e))?; + let stream = + encode_request(request, BufWriter::with_capacity(BUFFER_CAPACITY, stream))? + .into_inner() + .map_err(|e| e.into_error())?; + return decode_response(BufReader::with_capacity(BUFFER_CAPACITY, stream)); + } + #[cfg(all(feature = "rustls", not(feature = "native-tls")))] + if *scheme == Scheme::HTTPS { + #[cfg(not(any( + feature = "rustls-platform-verifier", + feature = "rustls-native-certs", + feature = "webpki-roots" + )))] + compile_error!( "rustls-platform-verifier or rustls-native-certs or webpki-roots must be installed to use OxHTTP with Rustls" ); - static RUSTLS_CONFIG: OnceLock> = OnceLock::new(); + static RUSTLS_CONFIG: OnceLock> = OnceLock::new(); - let rustls_config = RUSTLS_CONFIG.get_or_init(|| { - #[cfg(feature = "rustls-platform-verifier")] - { - Arc::new(ClientConfig::with_platform_verifier()) - } - #[cfg(not(feature = "rustls-platform-verifier"))] - { - #[cfg(feature = "rustls-native-certs")] - let root_store = { - let mut root_store = RootCertStore::empty(); - for cert in load_native_certs().certs { - root_store.add(cert).unwrap(); - } - root_store - }; - - #[cfg(all( - feature = "webpki-roots", - not(feature = "rustls-native-certs") - ))] - let root_store = RootCertStore { - roots: TLS_SERVER_ROOTS.to_vec(), - }; - - Arc::new( - ClientConfig::builder() - .with_root_certificates(root_store) - .with_no_client_auth(), - ) + let rustls_config = RUSTLS_CONFIG.get_or_init(|| { + #[cfg(feature = "rustls-platform-verifier")] + { + Arc::new(ClientConfig::with_platform_verifier()) + } + #[cfg(not(feature = "rustls-platform-verifier"))] + { + #[cfg(feature = "rustls-native-certs")] + let root_store = { + let mut root_store = RootCertStore::empty(); + for cert in load_native_certs().certs { + root_store.add(cert).unwrap(); } - }); - let addresses = get_and_validate_socket_addresses(request.url(), 443)?; - let dns_name = ServerName::try_from(host) - .map_err(invalid_input_error)? - .to_owned(); - let connection = ClientConnection::new(Arc::clone(rustls_config), dns_name) - .map_err(|e| Error::new(ErrorKind::Other, e))?; - let stream = StreamOwned::new(connection, self.connect(&addresses)?); - let stream = - encode_request(request, BufWriter::with_capacity(BUFFER_CAPACITY, stream))? - .into_inner() - .map_err(|e| e.into_error())?; - return decode_response(BufReader::with_capacity(BUFFER_CAPACITY, stream)); + root_store + }; + + #[cfg(all(feature = "webpki-roots", not(feature = "rustls-native-certs")))] + let root_store = RootCertStore { + roots: TLS_SERVER_ROOTS.to_vec(), + }; + + Arc::new( + ClientConfig::builder() + .with_root_certificates(root_store) + .with_no_client_auth(), + ) } - #[cfg(not(any(feature = "native-tls", feature = "rustls")))] - return Err(invalid_input_error("HTTPS is not supported by the client. You should enable the `native-tls` or `rustls` feature of the `oxhttp` crate")); - } - _ => Err(invalid_input_error(format!( - "Not supported URL scheme: {}", - request.url().scheme() - ))), + }); + let addresses = get_and_validate_socket_addresses(request.uri(), 443)?; + let dns_name = ServerName::try_from(host) + .map_err(invalid_input_error)? + .to_owned(); + let connection = ClientConnection::new(Arc::clone(rustls_config), dns_name) + .map_err(|e| Error::new(ErrorKind::Other, e))?; + let stream = StreamOwned::new(connection, self.connect(&addresses)?); + let stream = + encode_request(request, BufWriter::with_capacity(BUFFER_CAPACITY, stream))? + .into_inner() + .map_err(|e| e.into_error())?; + return decode_response(BufReader::with_capacity(BUFFER_CAPACITY, stream)); + } + + #[cfg(not(any(feature = "native-tls", feature = "rustls")))] + if *scheme == Scheme::HTTPS { + return Err(invalid_input_error("HTTPS is not supported by the client. You should enable the `native-tls` or `rustls` feature of the `oxhttp` crate")); } + + Err(invalid_input_error(format!( + "Not supported URL scheme: {scheme}" + ))) } fn connect(&self, addresses: &[SocketAddr]) -> Result { @@ -325,8 +326,12 @@ const BAD_PORTS: [u16; 80] = [ 6697, 10080, ]; -fn get_and_validate_socket_addresses(url: &Url, default_port: u16) -> Result> { - let addresses = url.socket_addrs(|| Some(default_port))?; +fn get_and_validate_socket_addresses(uri: &Uri, default_port: u16) -> Result> { + let host = uri + .host() + .ok_or_else(|| invalid_input_error(format!("No host in request URL {uri}")))?; + let port = uri.port_u16().unwrap_or(default_port); + let addresses = (host, port).to_socket_addrs()?.collect::>(); for address in &addresses { if BAD_PORTS.binary_search(&address.port()).is_ok() { return Err(invalid_input_error(format!( @@ -338,22 +343,48 @@ fn get_and_validate_socket_addresses(url: &Url, default_port: u16) -> Result Result { + Uri::try_from( + Url::parse(&base.to_string()) + .map_err(|e| { + Error::new( + ErrorKind::InvalidInput, + format!("Invalid base URL '{base}': {e}"), + ) + })? + .join(relative) + .map_err(|e| { + Error::new( + ErrorKind::InvalidData, + format!("Invalid location header URL '{relative}': {e}"), + ) + })? + .to_string(), + ) + .map_err(|e| { + Error::new( + ErrorKind::InvalidData, + format!("Invalid location header URL '{relative}': {e}"), + ) + }) +} + #[cfg(test)] mod tests { use super::*; - use crate::model::{Method, Status}; + use crate::model::header::CONTENT_TYPE; #[test] fn test_http_get_ok() -> Result<()> { let client = Client::new(); let response = client.request( - Request::builder(Method::GET, "http://example.com".parse().unwrap()).build(), + Request::builder() + .uri("http://example.com") + .body(()) + .unwrap(), )?; - assert_eq!(response.status(), Status::OK); - assert_eq!( - response.header(&HeaderName::CONTENT_TYPE).unwrap().as_ref(), - b"text/html" - ); + assert_eq!(response.status(), StatusCode::OK); + assert_eq!(response.headers().get(CONTENT_TYPE).unwrap(), "text/html"); let body = response.into_body().to_string()?; assert!(body.contains(" Result<()> { let client = Client::new(); let response = client.request( - Request::builder(Method::GET, "http://example.com:80".parse().unwrap()).build(), + Request::builder() + .uri("http://example.com:80") + .body(()) + .unwrap(), )?; - assert_eq!(response.status(), Status::OK); - assert_eq!( - response.header(&HeaderName::CONTENT_TYPE).unwrap().as_ref(), - b"text/html" - ); + assert_eq!(response.status(), StatusCode::OK); + assert_eq!(response.headers().get(CONTENT_TYPE).unwrap(), "text/html"); Ok(()) } @@ -395,7 +426,10 @@ mod tests { let client = Client::new(); assert!(client .request( - Request::builder(Method::GET, "http://example.com:22".parse().unwrap()).build(), + Request::builder() + .uri("http://example.com:22") + .body(()) + .unwrap(), ) .is_err()); } @@ -405,13 +439,13 @@ mod tests { fn test_https_get_ok() -> Result<()> { let client = Client::new(); let response = client.request( - Request::builder(Method::GET, "https://example.com".parse().unwrap()).build(), + Request::builder() + .uri("https://example.com") + .body(()) + .unwrap(), )?; - assert_eq!(response.status(), Status::OK); - assert_eq!( - response.header(&HeaderName::CONTENT_TYPE).unwrap().as_ref(), - b"text/html" - ); + assert_eq!(response.status(), StatusCode::OK); + assert_eq!(response.headers().get(CONTENT_TYPE).unwrap(), "text/html"); Ok(()) } @@ -420,7 +454,12 @@ mod tests { fn test_https_get_err() { let client = Client::new(); assert!(client - .request(Request::builder(Method::GET, "https://example.com".parse().unwrap()).build()) + .request( + Request::builder() + .uri("https://example.com") + .body(()) + .unwrap() + ) .is_err()); } @@ -428,15 +467,14 @@ mod tests { fn test_http_get_not_found() -> Result<()> { let client = Client::new(); let response = client.request( - Request::builder( - Method::GET, - "http://example.com/not_existing".parse().unwrap(), - ) - .build(), + Request::builder() + .uri("http://example.com/not_existing") + .body(()) + .unwrap(), )?; assert!(matches!( response.status(), - Status::NOT_FOUND | Status::INTERNAL_SERVER_ERROR + StatusCode::NOT_FOUND | StatusCode::INTERNAL_SERVER_ERROR )); Ok(()) } @@ -446,11 +484,10 @@ mod tests { let client = Client::new(); assert!(client .request( - Request::builder( - Method::GET, - "file://example.com/not_existing".parse().unwrap(), - ) - .build(), + Request::builder() + .uri("file://example.com/not_existing") + .body(()) + .unwrap(), ) .is_err()); } @@ -460,9 +497,12 @@ mod tests { fn test_redirection() -> Result<()> { let client = Client::new().with_redirection_limit(5); let response = client.request( - Request::builder(Method::GET, "http://wikipedia.org".parse().unwrap()).build(), + Request::builder() + .uri("http://wikipedia.org") + .body(()) + .unwrap(), )?; - assert_eq!(response.status(), Status::OK); + assert_eq!(response.status(), StatusCode::OK); Ok(()) } } diff --git a/src/io/decoder.rs b/src/io/decoder.rs index 85b73b5..40f5637 100644 --- a/src/io/decoder.rs +++ b/src/io/decoder.rs @@ -1,11 +1,15 @@ +use crate::model::header::{CONTENT_ENCODING, CONTENT_LENGTH, HOST, TRANSFER_ENCODING}; +use crate::model::request::Builder as RequestBuilder; +use crate::model::uri::{Authority, Parts as UriParts, PathAndQuery, Scheme}; use crate::model::{ - Body, ChunkedTransferPayload, HeaderName, HeaderValue, Headers, Method, Request, - RequestBuilder, Response, Status, Url, + Body, ChunkedTransferPayload, HeaderMap, HeaderName, HeaderValue, Method, Request, Response, + StatusCode, Uri, Version, }; use crate::utils::invalid_data_error; +use httparse::Header; use std::cmp::min; use std::io::{BufRead, Error, ErrorKind, Read, Result}; -use std::str::{self, FromStr}; +use std::str::FromStr; const DEFAULT_SIZE: usize = 1024; const MAX_HEADER_SIZE: u64 = 8 * 1024; @@ -28,86 +32,83 @@ pub fn decode_request_headers( )); } - let method = Method::from_str( - parsed_request - .method - .ok_or_else(|| invalid_data_error("No method in the HTTP request"))?, - ) - .map_err(invalid_data_error)?; + // We build the request + let mut request = Request::builder(); + decode_headers(parsed_request.headers, request.headers_mut().unwrap())?; + if let Some(version) = parsed_request.version { + request = request.version(match version { + 0 => Version::HTTP_10, + 1 => Version::HTTP_11, + _ => { + return Err(invalid_data_error(format!( + "Unsupported HTTP version {version}" + ))) + } + }); + } + // Method + request = request.method( + Method::from_str( + parsed_request + .method + .ok_or_else(|| invalid_data_error("No method in the HTTP request"))?, + ) + .map_err(invalid_data_error)?, + ); + + // URI let path = parsed_request .path .ok_or_else(|| invalid_data_error("No path in the HTTP request"))?; - let url = if let Some(host) = parsed_request.headers.iter().find_map(|header| { - if header.name.eq_ignore_ascii_case("host") { - Some(header.value) - } else { - None - } - }) { - let host = str::from_utf8(host) - .map_err(|e| invalid_data_error(format!("Invalid host header value: {e}")))?; - let base_url = Url::parse(&if is_connection_secure { - format!("https://{host}") - } else { - format!("http://{host}") - }) - .map_err(|e| invalid_data_error(format!("Invalid host header value '{host}': {e}")))?; - if path == "*" { - base_url - } else { - base_url - .join(path) - .map_err(|e| invalid_data_error(format!("Invalid request path '{path}': {e}")))? - } + let mut uri_parts = if path == "*" { + let mut uri_parts = UriParts::default(); + uri_parts.path_and_query = Some(PathAndQuery::from_static("")); + uri_parts } else { - Url::parse(path).map_err(|e| { - invalid_data_error(format!( - "No host header in HTTP request and not absolute path '{path}': {e}" - )) - })? + Uri::try_from(if path == "*" { "" } else { path }) + .map_err(invalid_data_error)? + .into_parts() }; - - // We validate that the URL is valid - if !url.has_authority() { - return Err(invalid_data_error("No host header in HTTP request")); - } if is_connection_secure { - if url.scheme() != "https" { - return Err(invalid_data_error("The HTTPS URL scheme should be 'https")); + if *uri_parts.scheme.get_or_insert(Scheme::HTTPS) != Scheme::HTTPS { + return Err(invalid_data_error("The HTTPS URL scheme must be 'https")); } - } else if url.scheme() != "http" { - return Err(invalid_data_error("The HTTP URL scheme should be 'http")); - } - - let mut request = Request::builder(method, url); - for header in parsed_request.headers { - request.headers_mut().append( - HeaderName::new_unchecked(header.name.to_ascii_lowercase()), - HeaderValue::new_unchecked(header.value.to_vec()), - ); - } - if parsed_request.version == Some(0) { - // Hack to fallback to default HTTP 1.0 behavior of closing connections - if !request.headers().contains(&HeaderName::CONNECTION) { - request.headers_mut().append( - HeaderName::CONNECTION, - HeaderValue::new_unchecked("close".as_bytes()), + } else if *uri_parts.scheme.get_or_insert(Scheme::HTTP) != Scheme::HTTP { + return Err(invalid_data_error("The HTTP URL scheme must be 'http")); + } + if uri_parts.authority.is_none() { + uri_parts.authority = Some( + Authority::try_from( + request + .headers_ref() + .unwrap() + .get(HOST) + .ok_or_else(|| invalid_data_error("No host header in HTTP request"))? + .as_bytes(), ) - } + .map_err(|e| invalid_data_error(format!("Invalid host header value: {e}")))?, + ); } + request = request.uri(Uri::from_parts(uri_parts).unwrap()); Ok(request) } pub fn decode_request_body( request: RequestBuilder, reader: impl BufRead + 'static, -) -> Result { - let body = decode_body(request.headers(), reader)?; - Ok(request.with_body(body)) +) -> Result> { + let body = if let Some(headers) = request.headers_ref() { + decode_body(headers, reader)? + } else { + Body::empty() + }; + request + .body(body) + .map_err(|e| invalid_data_error(format!("Unexpected error when parsing the request: {e}"))) } -pub fn decode_response(mut reader: impl BufRead + 'static) -> Result { +pub fn decode_response(mut reader: impl BufRead + 'static) -> Result> { // Let's read the headers let buffer = read_header_bytes(&mut reader)?; let mut headers = [httparse::EMPTY_HEADER; DEFAULT_SIZE]; @@ -122,7 +123,7 @@ pub fn decode_response(mut reader: impl BufRead + 'static) -> Result { )); } - let status = Status::try_from( + let status = StatusCode::from_u16( parsed_response .code .ok_or_else(|| invalid_data_error("No status code in the HTTP response"))?, @@ -130,16 +131,15 @@ pub fn decode_response(mut reader: impl BufRead + 'static) -> Result { .map_err(invalid_data_error)?; // Let's build the response - let mut response = Response::builder(status); - for header in parsed_response.headers { - response.headers_mut().append( - HeaderName::new_unchecked(header.name.to_ascii_lowercase()), - HeaderValue::new_unchecked(header.value.to_vec()), - ); - } + let mut response = Response::builder().status(status); + decode_headers(parsed_response.headers, response.headers_mut().unwrap())?; - let body = decode_body(response.headers(), reader)?; - Ok(response.with_body(body)) + let body = if let Some(headers) = response.headers_ref() { + decode_body(headers, reader)? + } else { + Body::empty() + }; + Ok(response.body(body).unwrap()) } fn read_header_bytes(reader: impl BufRead) -> Result> { @@ -166,15 +166,15 @@ fn read_header_bytes(reader: impl BufRead) -> Result> { return Err(invalid_data_error("The headers size should fit in 8kb")); } if buffer.ends_with(b"\n\n") { - break; //end of buffer + break; // end of buffer } } Ok(buffer) } -fn decode_body(headers: &Headers, reader: impl BufRead + 'static) -> Result { - let content_length = headers.get(&HeaderName::CONTENT_LENGTH); - let transfer_encoding = headers.get(&HeaderName::TRANSFER_ENCODING); +fn decode_body(headers: &HeaderMap, reader: impl BufRead + 'static) -> Result { + let content_length = headers.get(CONTENT_LENGTH); + let transfer_encoding = headers.get(TRANSFER_ENCODING); if transfer_encoding.is_some() && content_length.is_some() { return Err(invalid_data_error( "Transfer-Encoding and Content-Length should not be set at the same time", @@ -205,14 +205,27 @@ fn decode_body(headers: &Headers, reader: impl BufRead + 'static) -> Result Result { - let Some(content_encoding) = headers.get(&HeaderName::CONTENT_ENCODING) else { +fn decode_headers(from: &[Header<'_>], to: &mut HeaderMap) -> Result<()> { + for header in from { + to.try_append( + HeaderName::try_from(header.name) + .map_err(|e| invalid_data_error(format!("Invalid header name: {e}")))?, + HeaderValue::try_from(header.value) + .map_err(|e| invalid_data_error(format!("Invalid header value: {e}")))?, + ) + .map_err(|e| invalid_data_error(format!("Too many headers: {e}")))?; + } + Ok(()) +} + +fn decode_content_encoding(body: Body, headers: &HeaderMap) -> Result { + let Some(content_encoding) = headers.get(CONTENT_ENCODING) else { return Ok(body); }; match content_encoding.as_ref() { @@ -231,7 +244,7 @@ struct ChunkedDecoder { is_start: bool, chunk_position: usize, chunk_size: usize, - trailers: Option, + trailers: Option, } impl Read for ChunkedDecoder { @@ -300,7 +313,7 @@ impl Read for ChunkedDecoder { self.buffer.push(b'\n') } if self.buffer.ends_with(b"\n\n") { - break; //end of buffer + break; // end of buffer } } let mut trailers = [httparse::EMPTY_HEADER; DEFAULT_SIZE]; @@ -317,13 +330,8 @@ impl Read for ChunkedDecoder { "Invalid data at the end of the trailer section", )); } - let mut trailers = Headers::new(); - for trailer in parsed_trailers { - trailers.append( - HeaderName::new_unchecked(trailer.name.to_ascii_lowercase()), - HeaderValue::new_unchecked(trailer.value.to_vec()), - ); - } + let mut trailers = HeaderMap::new(); + decode_headers(parsed_trailers, &mut trailers)?; self.trailers = Some(trailers); return Ok(0); } @@ -332,7 +340,7 @@ impl Read for ChunkedDecoder { } impl ChunkedTransferPayload for ChunkedDecoder { - fn trailers(&self) -> Option<&Headers> { + fn trailers(&self) -> Option<&HeaderMap> { self.trailers.as_ref() } } @@ -340,15 +348,21 @@ impl ChunkedTransferPayload for ChunkedDecoder { #[cfg(test)] mod tests { use super::*; - use std::ops::Deref; + use crate::model::header::CONTENT_TYPE; + use crate::model::HeaderName; #[test] fn decode_request_target_origin_form() -> Result<()> { let request = decode_request_headers( &mut b"GET /where?q=now HTTP/1.1\nHost: www.example.org\n\n".as_slice(), false, - )?; - assert_eq!(request.url().as_str(), "http://www.example.org/where?q=now"); + )? + .body(()) + .unwrap(); + assert_eq!( + request.uri().to_string(), + "http://www.example.org/where?q=now" + ); Ok(()) } @@ -359,9 +373,9 @@ mod tests { b"GET http://www.example.org/pub/WWW/TheProject.html HTTP/1.1\nHost: example.com\n\n".as_slice() , false, - )?; + )?.body(()).unwrap(); assert_eq!( - request.url().as_str(), + request.uri().to_string(), "http://www.example.org/pub/WWW/TheProject.html" ); Ok(()) @@ -372,9 +386,11 @@ mod tests { let request = decode_request_headers( &mut b"GET http://www.example.org/pub/WWW/TheProject.html HTTP/1.1\n\n".as_slice(), false, - )?; + )? + .body(()) + .unwrap(); assert_eq!( - request.url().as_str(), + request.uri().to_string(), "http://www.example.org/pub/WWW/TheProject.html" ); Ok(()) @@ -417,8 +433,10 @@ mod tests { let request = decode_request_headers( &mut b"OPTIONS * HTTP/1.1\nHost: www.example.org:8001\n\n".as_slice(), false, - )?; - assert_eq!(request.url().as_str(), "http://www.example.org:8001/"); //TODO: should be http://www.example.org:8001 + )? + .body(()) + .unwrap(); + assert_eq!(request.uri().to_string(), "http://www.example.org:8001/"); // TODO: should be http://www.example.org:8001 Ok(()) } @@ -428,21 +446,24 @@ mod tests { &mut b"GET / HTTP/1.1\nHost: www.example.org:8001\nFoo: v1\nbar: vbar\nfoo: v2\n\n" .as_slice(), true, - )?; - assert_eq!(request.url().as_str(), "https://www.example.org:8001/"); + )? + .body(()) + .unwrap(); + assert_eq!(request.uri().to_string(), "https://www.example.org:8001/"); assert_eq!( request - .header(&HeaderName::from_str("foo").unwrap()) - .unwrap() - .as_ref(), - b"v1, v2".as_ref() + .headers() + .get_all(HeaderName::from_str("foo").unwrap()) + .into_iter() + .collect::>(), + vec!["v1", "v2"] ); assert_eq!( request - .header(&HeaderName::from_str("Bar").unwrap()) - .unwrap() - .as_ref(), - b"vbar".as_ref() + .headers() + .get(HeaderName::from_str("Bar").unwrap()) + .unwrap(), + "vbar" ); Ok(()) } @@ -525,11 +546,8 @@ mod tests { let mut read = b"POST http://example.com/foo HTTP/1.0\r\ncontent-length: 12\r\n\r\nfoobar".as_slice(); let request = decode_request_body(decode_request_headers(&mut read, false)?, read)?; - assert_eq!(request.url().as_str(), "http://example.com/foo"); - assert_eq!( - request.header(&HeaderName::CONNECTION).unwrap().deref(), - b"close" - ); + assert_eq!(request.version(), Version::HTTP_10); + assert_eq!(request.uri().to_string(), "http://example.com/foo"); Ok(()) } @@ -543,7 +561,7 @@ mod tests { #[test] fn decode_response_without_payload() -> Result<()> { let response = decode_response(b"HTTP/1.1 404 Not Found\r\n\r\n".as_slice())?; - assert_eq!(response.status(), Status::NOT_FOUND); + assert_eq!(response.status(), StatusCode::NOT_FOUND); assert_eq!(response.body().len(), Some(0)); Ok(()) } @@ -554,15 +572,8 @@ mod tests { b"HTTP/1.1 200 OK\r\ncontent-type: text/plain\r\ncontent-length:12\r\n\r\ntestbodybody" .as_slice(), )?; - assert_eq!(response.status(), Status::OK); - assert_eq!( - response - .header(&HeaderName::CONTENT_TYPE) - .unwrap() - .to_str() - .unwrap(), - "text/plain" - ); + assert_eq!(response.status(), StatusCode::OK); + assert_eq!(response.headers().get(CONTENT_TYPE).unwrap(), "text/plain"); assert_eq!(response.into_body().to_string()?, "testbodybody"); Ok(()) } @@ -572,15 +583,8 @@ mod tests { let response = decode_response( b"HTTP/1.1 200 OK\r\ncontent-type: text/plain\r\ntransfer-encoding:chunked\r\n\r\n4\r\nWiki\r\n5\r\npedia\r\nE\r\n in\r\n\r\nchunks.\r\n0\r\n\r\n".as_slice() )?; - assert_eq!(response.status(), Status::OK); - assert_eq!( - response - .header(&HeaderName::CONTENT_TYPE) - .unwrap() - .to_str() - .unwrap(), - "text/plain" - ); + assert_eq!(response.status(), StatusCode::OK); + assert_eq!(response.headers().get(CONTENT_TYPE).unwrap(), "text/plain"); assert_eq!( response.into_body().to_string()?, "Wikipedia in\r\n\r\nchunks." @@ -593,15 +597,8 @@ mod tests { let response = decode_response( b"HTTP/1.1 200 OK\r\ncontent-type: text/plain\r\ntransfer-encoding:chunked\r\n\r\n4\r\nWiki\r\n5\r\npedia\r\nE\r\n in\r\n\r\nchunks.\r\n0\r\ntest: foo\r\n\r\n".as_slice() )?; - assert_eq!(response.status(), Status::OK); - assert_eq!( - response - .header(&HeaderName::CONTENT_TYPE) - .unwrap() - .to_str() - .unwrap(), - "text/plain" - ); + assert_eq!(response.status(), StatusCode::OK); + assert_eq!(response.headers().get(CONTENT_TYPE).unwrap(), "text/plain"); let mut buf = String::new(); let mut body = response.into_body(); body.read_to_string(&mut buf)?; @@ -609,10 +606,9 @@ mod tests { assert_eq!( body.trailers() .unwrap() - .get(&HeaderName::from_str("test").unwrap()) - .unwrap() - .as_ref(), - b"foo" + .get(HeaderName::from_static("test")) + .unwrap(), + "foo" ); Ok(()) } @@ -636,10 +632,7 @@ mod tests { #[test] fn decode_unknown_response() -> Result<()> { let response = decode_response(b"HTTP/1.1 200 OK\r\ncontent-type: text/plain\r\ncontent-encoding: foo\r\ncontent-length: 5\r\n\r\nfoooo".as_slice())?; - assert_eq!( - response.headers().get(&HeaderName::CONTENT_ENCODING), - Some(&HeaderValue::new_unchecked("foo".as_bytes())) - ); + assert_eq!(response.headers().get(CONTENT_ENCODING).unwrap(), "foo"); assert_eq!(response.into_body().to_string()?, "foooo"); Ok(()) } @@ -728,7 +721,7 @@ mod tests { let response = decode_response( b"HTTP/1.1 200 OK\r\ntransfer-encoding:chunked\r\n\r\n4\r\nWiki\r\n5\r\npedia\r\nE\r\n in\r\n\r\nchunks.\r\n0\r\n\r\n".as_slice() )?; - assert_eq!(response.status(), Status::OK); + assert_eq!(response.status(), StatusCode::OK); let mut body = response.into_body(); body.read_to_end(&mut Vec::new())?; assert_eq!(body.read(&mut [0; 1])?, 0); diff --git a/src/io/encoder.rs b/src/io/encoder.rs index b54aa60..2cb3dc6 100644 --- a/src/io/encoder.rs +++ b/src/io/encoder.rs @@ -1,38 +1,47 @@ -use crate::model::{Body, HeaderName, Headers, Method, Request, Response, Status}; +use crate::model::header::{ + ACCEPT_CHARSET, ACCEPT_ENCODING, ACCESS_CONTROL_ALLOW_METHODS, ACCESS_CONTROL_REQUEST_HEADERS, + CONNECTION, CONTENT_LENGTH, DATE, EXPECT, HOST, ORIGIN, TE, TRAILER, TRANSFER_ENCODING, + UPGRADE, VIA, +}; +use crate::model::{Body, HeaderMap, HeaderName, Method, Request, Response, StatusCode}; use crate::utils::invalid_input_error; use std::io::{copy, Read, Result, Write}; -pub fn encode_request(request: &mut Request, mut writer: W) -> Result { - if !request.url().username().is_empty() || request.url().password().is_some() { +pub fn encode_request(request: &mut Request, mut writer: W) -> Result { + if request + .uri() + .authority() + .is_some_and(|a| a.as_str().contains('@')) + { return Err(invalid_input_error( "Username and password are not allowed in HTTP URLs", )); } let host = request - .url() - .host_str() + .uri() + .host() .ok_or_else(|| invalid_input_error("No host provided"))?; - if let Some(query) = request.url().query() { + if let Some(query) = request.uri().query() { write!( &mut writer, "{} {}?{} HTTP/1.1\r\n", - request.method(), - request.url().path(), + request.method().as_str(), + request.uri().path(), query )?; } else { write!( &mut writer, "{} {} HTTP/1.1\r\n", - request.method(), - request.url().path(), + request.method().as_str(), + request.uri().path(), )?; } // host - if let Some(port) = request.url().port() { - write!(writer, "host: {host}:{port}\r\n")?; + if let Some(port) = request.uri().port() { + write!(writer, "host: {host}:{}\r\n", port.as_str())?; } else { write!(writer, "host: {host}\r\n")?; } @@ -47,19 +56,25 @@ pub fn encode_request(request: &mut Request, mut writer: W) -> Result< Ok(writer) } -pub fn encode_response(response: &mut Response, mut writer: W) -> Result { - write!(&mut writer, "HTTP/1.1 {}\r\n", response.status())?; +pub fn encode_response(response: &mut Response, mut writer: W) -> Result { + let status = response.status(); + write!( + &mut writer, + "HTTP/1.1 {} {}\r\n", + status.as_u16(), + status.canonical_reason().unwrap_or_default() + )?; encode_headers(response.headers(), &mut writer)?; let must_include_body = does_response_must_include_body(response.status()); encode_body(response.body_mut(), &mut writer, must_include_body)?; Ok(writer) } -fn encode_headers(headers: &Headers, writer: &mut impl Write) -> Result<()> { +fn encode_headers(headers: &HeaderMap, writer: &mut impl Write) -> Result<()> { for (name, value) in headers { if !is_forbidden_name(name) { write!(writer, "{name}: ")?; - writer.write_all(value)?; + writer.write_all(value.as_bytes())?; write!(writer, "\r\n")?; } } @@ -107,49 +122,57 @@ fn encode_body(body: &mut Body, writer: &mut impl Write, must_include_body: bool /// /// We removed some of them not managed by this library (`Access-Control-Request-Headers`, `Access-Control-Request-Method`, `DNT`, `Cookie`, `Cookie2`, `Referer`, `Proxy-`, `Sec-`, `Via`...) fn is_forbidden_name(header: &HeaderName) -> bool { - header.as_ref() == "accept-charset" - || *header == HeaderName::ACCEPT_ENCODING - || header.as_ref() == "access-control-request-headers" - || header.as_ref() == "access-control-request-method" - || *header == HeaderName::CONNECTION - || *header == HeaderName::CONTENT_LENGTH - || *header == HeaderName::DATE - || *header == HeaderName::EXPECT - || *header == HeaderName::HOST - || header.as_ref() == "keep-alive" - || header.as_ref() == "origin" - || *header == HeaderName::TE - || *header == HeaderName::TRAILER - || *header == HeaderName::TRANSFER_ENCODING - || *header == HeaderName::UPGRADE - || *header == HeaderName::VIA + header == ACCEPT_CHARSET + || header == ACCEPT_ENCODING + || header == ACCESS_CONTROL_REQUEST_HEADERS + || header == ACCESS_CONTROL_ALLOW_METHODS + || header == CONNECTION + || header == CONTENT_LENGTH + || header == DATE + || header == EXPECT + || header == HOST + || header.as_str() == "keep-alive" + || header == ORIGIN + || header == TE + || header == TRAILER + || header == TRANSFER_ENCODING + || header == UPGRADE + || header == VIA } fn does_request_must_include_body(method: &Method) -> bool { *method == Method::POST || *method == Method::PUT } -fn does_response_must_include_body(status: Status) -> bool { - !(status.is_informational() || status == Status::NO_CONTENT || status == Status::NOT_MODIFIED) +fn does_response_must_include_body(status: StatusCode) -> bool { + !(status.is_informational() + || status == StatusCode::NO_CONTENT + || status == StatusCode::NOT_MODIFIED) } #[cfg(test)] mod tests { use super::*; - use crate::model::{ChunkedTransferPayload, Headers, Method, Status}; + use crate::model::header::{ACCEPT, CONTENT_LANGUAGE}; + use crate::model::{ChunkedTransferPayload, HeaderMap, HeaderValue}; use std::str; #[test] fn user_password_not_allowed_in_request() { let mut buffer = Vec::new(); assert!(encode_request( - &mut Request::builder(Method::GET, "http://foo@example.com/".parse().unwrap()).build(), + &mut Request::builder() + .uri("http://foo@example.com/") + .body(Body::empty()) + .unwrap(), &mut buffer ) .is_err()); assert!(encode_request( - &mut Request::builder(Method::GET, "http://foo:bar@example.com/".parse().unwrap()) - .build(), + &mut Request::builder() + .uri("http://foo:bar@example.com/") + .body(Body::empty()) + .unwrap(), &mut buffer ) .is_err()); @@ -157,15 +180,11 @@ mod tests { #[test] fn encode_get_request() -> Result<()> { - let mut request = Request::builder( - Method::GET, - "http://example.com:81/foo/bar?query#fragment" - .parse() - .unwrap(), - ) - .with_header(HeaderName::ACCEPT, "application/json") - .unwrap() - .build(); + let mut request = Request::builder() + .uri("http://example.com:81/foo/bar?query#fragment") + .header(ACCEPT, "application/json") + .body(Body::empty()) + .unwrap(); let buffer = encode_request(&mut request, Vec::new())?; assert_eq!( str::from_utf8(&buffer).unwrap(), @@ -176,13 +195,12 @@ mod tests { #[test] fn encode_post_request() -> Result<()> { - let mut request = Request::builder( - Method::POST, - "http://example.com/foo/bar?query#fragment".parse().unwrap(), - ) - .with_header(HeaderName::ACCEPT, "application/json") - .unwrap() - .with_body(b"testbodybody".as_ref()); + let mut request = Request::builder() + .method(Method::POST) + .uri("http://example.com/foo/bar?query#fragment") + .header(ACCEPT, "application/json") + .body(Body::from("testbodybody")) + .unwrap(); let buffer = encode_request(&mut request, Vec::new())?; assert_eq!( str::from_utf8(&buffer).unwrap(), @@ -193,11 +211,11 @@ mod tests { #[test] fn encode_post_request_without_body() -> Result<()> { - let mut request = Request::builder( - Method::POST, - "http://example.com/foo/bar?query#fragment".parse().unwrap(), - ) - .build(); + let mut request = Request::builder() + .method(Method::POST) + .uri("http://example.com/foo/bar?query#fragment") + .body(Body::empty()) + .unwrap(); let buffer = encode_request(&mut request, Vec::new())?; assert_eq!( str::from_utf8(&buffer).unwrap(), @@ -208,17 +226,17 @@ mod tests { #[test] fn encode_post_request_with_chunked() -> Result<()> { - let mut trailers = Headers::new(); - trailers.append(HeaderName::CONTENT_LANGUAGE, "foo".parse().unwrap()); + let mut trailers = HeaderMap::new(); + trailers.append(CONTENT_LANGUAGE, HeaderValue::from_static("foo")); - let mut request = Request::builder( - Method::POST, - "http://example.com/foo/bar?query#fragment".parse().unwrap(), - ) - .with_body(Body::from_chunked_transfer_payload(SimpleTrailers { - read: b"testbodybody".as_slice(), - trailers, - })); + let mut request = Request::builder() + .method(Method::POST) + .uri("http://example.com/foo/bar?query#fragment") + .body(Body::from_chunked_transfer_payload(SimpleTrailers { + read: b"testbodybody".as_slice(), + trailers, + })) + .unwrap(); let buffer = encode_request(&mut request, Vec::new())?; assert_eq!( str::from_utf8(&buffer).unwrap(), @@ -229,10 +247,10 @@ mod tests { #[test] fn encode_response_ok() -> Result<()> { - let mut response = Response::builder(Status::OK) - .with_header(HeaderName::ACCEPT, "application/json") - .unwrap() - .with_body("test test2"); + let mut response = Response::builder() + .header(ACCEPT, "application/json") + .body(Body::from("test test2")) + .unwrap(); let buffer = encode_response(&mut response, Vec::new())?; assert_eq!( str::from_utf8(&buffer).unwrap(), @@ -243,7 +261,10 @@ mod tests { #[test] fn encode_response_not_found() -> Result<()> { - let mut response = Response::builder(Status::NOT_FOUND).build(); + let mut response = Response::builder() + .status(StatusCode::NOT_FOUND) + .body(Body::empty()) + .unwrap(); let buffer = encode_response(&mut response, Vec::new())?; assert_eq!( str::from_utf8(&buffer).unwrap(), @@ -254,7 +275,7 @@ mod tests { #[test] fn encode_response_custom_code() -> Result<()> { - let mut response = Response::builder(Status::try_from(499).unwrap()).build(); + let mut response = Response::builder().status(499).body(Body::empty()).unwrap(); let buffer = encode_response(&mut response, Vec::new())?; assert_eq!( str::from_utf8(&buffer).unwrap(), @@ -265,7 +286,7 @@ mod tests { struct SimpleTrailers { read: &'static [u8], - trailers: Headers, + trailers: HeaderMap, } impl Read for SimpleTrailers { @@ -275,7 +296,7 @@ mod tests { } impl ChunkedTransferPayload for SimpleTrailers { - fn trailers(&self) -> Option<&Headers> { + fn trailers(&self) -> Option<&HeaderMap> { Some(&self.trailers) } } diff --git a/src/model/body.rs b/src/model/body.rs index 1fd122b..0d1f789 100644 --- a/src/model/body.rs +++ b/src/model/body.rs @@ -1,6 +1,7 @@ -use crate::model::Headers; +use crate::model::HeaderMap; #[cfg(feature = "flate2")] use flate2::read::{DeflateDecoder, GzDecoder}; +use std::borrow::Cow; use std::fmt; use std::io::{Cursor, Error, ErrorKind, Read, Result}; @@ -60,6 +61,12 @@ impl Body { )))) } + /// The empty body + #[inline] + pub fn empty() -> Self { + Self(BodyAlt::SimpleBorrowed(b"")) + } + /// The number of bytes in the body (if known). #[allow(clippy::len_without_is_empty)] #[inline] @@ -77,7 +84,7 @@ impl Body { /// Returns the chunked transfer encoding trailers if they exists and are already received. /// You should fully consume the body before attempting to fetch them. #[inline] - pub fn trailers(&self) -> Option<&Headers> { + pub fn trailers(&self) -> Option<&HeaderMap> { match &self.0 { BodyAlt::SimpleOwned(_) | BodyAlt::SimpleBorrowed(_) | BodyAlt::Sized { .. } => None, BodyAlt::Chunked(c) => c.trailers(), @@ -185,7 +192,7 @@ impl Read for Body { impl Default for Body { #[inline] fn default() -> Self { - b"".as_ref().into() + Self::empty() } } @@ -217,6 +224,33 @@ impl From<&'static str> for Body { } } +impl From> for Body { + #[inline] + fn from(data: Cow<'static, [u8]>) -> Self { + match data { + Cow::Borrowed(data) => data.into(), + Cow::Owned(data) => data.into(), + } + } +} + +impl From> for Body { + #[inline] + fn from(data: Cow<'static, str>) -> Self { + match data { + Cow::Borrowed(data) => data.into(), + Cow::Owned(data) => data.into(), + } + } +} + +impl From<()> for Body { + #[inline] + fn from(_: ()) -> Self { + Self::empty() + } +} + impl fmt::Debug for Body { #[inline] fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -230,7 +264,7 @@ impl fmt::Debug for Body { /// It allows to provide [trailers](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#trailer.fields) to serialize. pub trait ChunkedTransferPayload: Read { /// The [trailers](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#trailer.fields) to serialize. - fn trailers(&self) -> Option<&Headers>; + fn trailers(&self) -> Option<&HeaderMap>; } struct SimpleChunkedTransferEncoding(R); @@ -244,7 +278,7 @@ impl Read for SimpleChunkedTransferEncoding { impl ChunkedTransferPayload for SimpleChunkedTransferEncoding { #[inline] - fn trailers(&self) -> Option<&Headers> { + fn trailers(&self) -> Option<&HeaderMap> { None } } diff --git a/src/model/header.rs b/src/model/header.rs deleted file mode 100644 index 083c6b5..0000000 --- a/src/model/header.rs +++ /dev/null @@ -1,610 +0,0 @@ -use std::borrow::{Borrow, Cow}; -use std::collections::btree_map::Entry; -use std::collections::BTreeMap; -use std::convert::Infallible; -use std::error::Error; -use std::fmt; -use std::fmt::Debug; -use std::ops::Deref; -use std::str; -use std::str::{FromStr, Utf8Error}; - -/// A list of headers aka [fields](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#fields). -/// -/// ``` -/// use oxhttp::model::{Headers, HeaderName, HeaderValue}; -/// use std::str::FromStr; -/// -/// let mut headers = Headers::new(); -/// headers.append(HeaderName::ACCEPT_LANGUAGE, "en".parse()?); -/// headers.append(HeaderName::ACCEPT_LANGUAGE, "fr".parse()?); -/// assert_eq!(headers.get(&HeaderName::ACCEPT_LANGUAGE).unwrap().as_ref(), b"en, fr"); -/// # Result::<_,Box>::Ok(()) -/// ``` -#[derive(PartialEq, Eq, Debug, Clone, Hash, Default)] -pub struct Headers(BTreeMap); - -impl Headers { - #[inline] - pub fn new() -> Self { - Self::default() - } - - /// Adds a header to the list. - /// - /// It does not override the existing value(s) for the same header. - #[inline] - pub fn append(&mut self, name: HeaderName, value: HeaderValue) { - match self.0.entry(name) { - Entry::Occupied(e) => { - let val = &mut e.into_mut().0.to_mut(); - val.extend_from_slice(b", "); - val.extend_from_slice(&value.0); - } - Entry::Vacant(e) => { - e.insert(value); - } - } - } - - /// Removes an header from the list. - #[inline] - pub fn remove(&mut self, name: &HeaderName) { - self.0.remove(name); - } - - /// Get an header value(s) from the list. - #[inline] - pub fn get(&self, name: &HeaderName) -> Option<&HeaderValue> { - self.0.get(name) - } - - #[inline] - pub fn contains(&self, name: &HeaderName) -> bool { - self.0.contains_key(name) - } - - /// Sets a header it the list. - /// - /// It overrides the existing value(s) for the same header. - #[inline] - pub fn set(&mut self, name: HeaderName, value: HeaderValue) { - self.0.insert(name, value); - } - - #[inline] - pub fn iter(&self) -> Iter<'_> { - Iter(self.0.iter()) - } - - /// Number of distinct headers - #[inline] - pub fn len(&self) -> usize { - self.0.len() - } - - #[inline] - pub fn is_empty(&self) -> bool { - self.0.is_empty() - } -} - -impl IntoIterator for Headers { - type Item = (HeaderName, HeaderValue); - type IntoIter = IntoIter; - - #[inline] - fn into_iter(self) -> IntoIter { - IntoIter(self.0.into_iter()) - } -} - -impl<'a> IntoIterator for &'a Headers { - type Item = (&'a HeaderName, &'a HeaderValue); - type IntoIter = Iter<'a>; - - #[inline] - fn into_iter(self) -> Iter<'a> { - self.iter() - } -} - -/// A [header/field name](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#fields.names). -/// -/// It is also normalized to lower case to ease equality checks. -/// -/// ``` -/// use oxhttp::model::HeaderName; -/// use std::str::FromStr; -/// -/// assert_eq!(HeaderName::from_str("content-Type")?, HeaderName::CONTENT_TYPE); -/// # Result::<_,Box>::Ok(()) -/// ``` -#[derive(PartialEq, Eq, PartialOrd, Ord, Debug, Clone, Hash)] -pub struct HeaderName(Cow<'static, str>); - -impl HeaderName { - #[inline] - pub(crate) fn new_unchecked(name: impl Into>) -> Self { - Self(name.into()) - } - - /// [`Accept`](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#field.accept) - pub const ACCEPT: Self = Self(Cow::Borrowed("accept")); - /// [`Accept-Encoding`](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#field.accept-encoding) - pub const ACCEPT_ENCODING: Self = Self(Cow::Borrowed("accept-encoding")); - /// [`Accept-Language`](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#field.accept-language) - pub const ACCEPT_LANGUAGE: Self = Self(Cow::Borrowed("accept-language")); - /// [`Allow`](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#field.allow) - pub const ACCEPT_RANGES: Self = Self(Cow::Borrowed("accept-ranges")); - /// [`Allow`](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#field.allow) - pub const ALLOW: Self = Self(Cow::Borrowed("allow")); - /// [`Authentication-Info`](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#field.authentication-info) - pub const AUTHENTICATION_INFO: Self = Self(Cow::Borrowed("authentication-info")); - /// [`Authorization`](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#field.authorization) - pub const AUTHORIZATION: Self = Self(Cow::Borrowed("authorization")); - /// [`Connection`](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#field.connection) - pub const CONNECTION: Self = Self(Cow::Borrowed("connection")); - /// [`Content-Encoding`](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#field.content-encoding) - pub const CONTENT_ENCODING: Self = Self(Cow::Borrowed("content-encoding")); - /// [`Content-Language`](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#field.content-language) - pub const CONTENT_LANGUAGE: Self = Self(Cow::Borrowed("content-language")); - /// [`Content-Length`](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#field.content-length) - pub const CONTENT_LENGTH: Self = Self(Cow::Borrowed("content-length")); - /// [`Content-Location`](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#field.content-location) - pub const CONTENT_LOCATION: Self = Self(Cow::Borrowed("content-location")); - /// [`Content-Range`](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#field.content-range) - pub const CONTENT_RANGE: Self = Self(Cow::Borrowed("content-range")); - /// [`Content-Type`](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#field.content-type) - pub const CONTENT_TYPE: Self = Self(Cow::Borrowed("content-type")); - /// [`Date`](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#field.date) - pub const DATE: Self = Self(Cow::Borrowed("date")); - /// [`ETag`](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#field.etag) - pub const ETAG: Self = Self(Cow::Borrowed("etag")); - /// [`Expect`](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#field.expect) - pub const EXPECT: Self = Self(Cow::Borrowed("expect")); - /// [`From`](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#field.from) - pub const FROM: Self = Self(Cow::Borrowed("from")); - /// [`Host`](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#field.host) - pub const HOST: Self = Self(Cow::Borrowed("host")); - /// [`If-Match`](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#field.if-match) - pub const IF_MATCH: Self = Self(Cow::Borrowed("if-match")); - /// [`If-Modified-Since`](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#field.if-modified-since) - pub const IF_MODIFIED_SINCE: Self = Self(Cow::Borrowed("if-modified-since")); - /// [`If-None-Match`](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#field.if-none-match) - pub const IF_NONE_MATCH: Self = Self(Cow::Borrowed("if-none-match")); - /// [`If-Range`](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#field.if-range) - pub const IF_RANGE: Self = Self(Cow::Borrowed("if-range")); - /// [`If-Unmodified-Since`](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#field.if-unmodified-since) - pub const IF_UNMODIFIED_SINCE: Self = Self(Cow::Borrowed("if-unmodified-since")); - /// [`Last-Modified`](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#field.last-modified) - pub const LAST_MODIFIED: Self = Self(Cow::Borrowed("last-modified")); - /// [`Location`](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#field.location) - pub const LOCATION: Self = Self(Cow::Borrowed("location")); - /// [`Max-Forwards`](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#field.max-forwards) - pub const MAX_FORWARDS: Self = Self(Cow::Borrowed("max-forwards")); - /// [`Proxy-Authenticate`](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#field.proxy-authenticate) - pub const PROXY_AUTHENTICATE: Self = Self(Cow::Borrowed("proxy-authenticate")); - /// [`Proxy-Authentication-Info`](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#field.proxy-authentication-info) - pub const PROXY_AUTHENTICATION_INFO: Self = Self(Cow::Borrowed("proxy-authentication-info")); - /// [`Proxy-Authorization`](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#field.proxy-authorization) - pub const PROXY_AUTHORIZATION: Self = Self(Cow::Borrowed("proxy-authorization")); - /// [`Range`](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#field.range) - pub const RANGE: Self = Self(Cow::Borrowed("range")); - /// [`Referer`](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#field.referer) - pub const REFERER: Self = Self(Cow::Borrowed("referer")); - /// [`Retry-After`](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#field.retry-after) - pub const RETRY_AFTER: Self = Self(Cow::Borrowed("retry-after")); - /// [`Server`](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#field.server) - pub const SERVER: Self = Self(Cow::Borrowed("server")); - /// [`TE`](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#field.te) - pub const TE: Self = Self(Cow::Borrowed("te")); - /// [`Trailer`](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#field.trailer) - pub const TRAILER: Self = Self(Cow::Borrowed("trailer")); - /// [`Transfer-Encoding`](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#field.transfer-encoding) - pub const TRANSFER_ENCODING: Self = Self(Cow::Borrowed("transfer-encoding")); - /// [`Upgrade`](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#field.upgrade) - pub const UPGRADE: Self = Self(Cow::Borrowed("upgrade")); - /// [`User-Agent`](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#field.user-agent) - pub const USER_AGENT: Self = Self(Cow::Borrowed("user-agent")); - /// [`Vary`](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#field.vary) - pub const VARY: Self = Self(Cow::Borrowed("vary")); - /// [`Via`](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#field.via) - pub const VIA: Self = Self(Cow::Borrowed("via")); - /// [`WWW-Authenticate`](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#field.www-authenticate) - pub const WWW_AUTHENTICATE: Self = Self(Cow::Borrowed("www-authenticate")); -} - -impl Deref for HeaderName { - type Target = str; - - #[inline] - fn deref(&self) -> &str { - &self.0 - } -} - -impl AsRef for HeaderName { - #[inline] - fn as_ref(&self) -> &str { - &self.0 - } -} - -impl Borrow for HeaderName { - #[inline] - fn borrow(&self) -> &str { - &self.0 - } -} - -impl FromStr for HeaderName { - type Err = InvalidHeader; - - #[inline] - fn from_str(name: &str) -> Result { - Self::try_from(name.to_owned()) - } -} - -impl TryFrom<&'static str> for HeaderName { - type Error = InvalidHeader; - - #[inline] - fn try_from(value: &'static str) -> Result { - Self::try_from(Cow::Borrowed(value)) - } -} - -impl TryFrom for HeaderName { - type Error = InvalidHeader; - - #[inline] - fn try_from(value: String) -> Result { - Self::try_from(Cow::Owned(value)) - } -} - -impl TryFrom> for HeaderName { - type Error = InvalidHeader; - - #[inline] - fn try_from(mut name: Cow<'static, str>) -> Result { - if name.contains(|c: char| c.is_ascii_uppercase()) { - name.to_mut().make_ascii_lowercase(); // We normalize to lowercase - } - if name.is_empty() { - Err(InvalidHeader(InvalidHeaderAlt::EmptyName)) - } else { - for c in name.chars() { - if !matches!(c, '!' | '#' | '$' | '%' | '&' | '\'' | '*' - | '+' | '-' | '.' | '^' | '_' | '`' | '|' | '~' - | '0'..='9' | 'a'..='z') - { - return Err(InvalidHeader(InvalidHeaderAlt::InvalidNameChar { - name, - invalid_char: c, - })); - } - } - Ok(Self(name)) - } - } -} - -impl fmt::Display for HeaderName { - #[inline] - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.0) - } -} - -pub trait IntoHeaderName { - fn try_into(self) -> Result; -} - -impl IntoHeaderName for HeaderName { - #[inline] - fn try_into(self) -> Result { - Ok(self) - } -} - -impl> IntoHeaderName for T { - #[inline] - fn try_into(self) -> Result { - self.try_into() - } -} - -/// A [header/field value](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#fields.values). -/// -/// ``` -/// use oxhttp::model::HeaderValue; -/// use std::str::FromStr; -/// -/// assert_eq!(HeaderValue::from_str("foo")?.as_ref(), b"foo"); -/// # Result::<_,Box>::Ok(()) -/// ``` -#[derive(PartialEq, Eq, PartialOrd, Ord, Debug, Clone, Hash, Default)] -pub struct HeaderValue(Cow<'static, [u8]>); - -impl HeaderValue { - #[inline] - pub(crate) fn new_unchecked(value: impl Into>) -> Self { - Self(value.into()) - } - - #[inline] - pub fn to_str(&self) -> Result<&str, Utf8Error> { - str::from_utf8(self) - } -} -impl Deref for HeaderValue { - type Target = [u8]; - - #[inline] - fn deref(&self) -> &[u8] { - &self.0 - } -} - -impl AsRef<[u8]> for HeaderValue { - #[inline] - fn as_ref(&self) -> &[u8] { - &self.0 - } -} - -impl Borrow<[u8]> for HeaderValue { - #[inline] - fn borrow(&self) -> &[u8] { - &self.0 - } -} - -impl FromStr for HeaderValue { - type Err = InvalidHeader; - - #[inline] - fn from_str(value: &str) -> Result { - Self::try_from(value.to_string().into_bytes()) - } -} - -impl TryFrom<&'static str> for HeaderValue { - type Error = InvalidHeader; - - #[inline] - fn try_from(value: &str) -> Result { - Self::try_from(value.to_owned()) - } -} - -impl TryFrom for HeaderValue { - type Error = InvalidHeader; - - #[inline] - fn try_from(value: String) -> Result { - Self::try_from(value.into_bytes()) - } -} - -impl TryFrom> for HeaderValue { - type Error = InvalidHeader; - - #[inline] - fn try_from(value: Cow<'static, str>) -> Result { - Self::try_from(match value { - Cow::Owned(value) => Cow::Owned(value.into_bytes()), - Cow::Borrowed(value) => Cow::Borrowed(value.as_bytes()), - }) - } -} - -impl TryFrom<&'static [u8]> for HeaderValue { - type Error = InvalidHeader; - - #[inline] - fn try_from(value: &'static [u8]) -> Result { - Cow::Borrowed(value).try_into() - } -} - -impl TryFrom> for HeaderValue { - type Error = InvalidHeader; - - #[inline] - fn try_from(value: Vec) -> Result { - Cow::<'static, [u8]>::Owned(value).try_into() - } -} - -impl TryFrom> for HeaderValue { - type Error = InvalidHeader; - - #[inline] - fn try_from(value: Cow<'static, [u8]>) -> Result { - // no tab or space at the beginning - if let Some(c) = value.first().cloned() { - if matches!(c, b'\t' | b' ') { - return Err(InvalidHeader(InvalidHeaderAlt::InvalidValueByte { - value, - invalid_byte: c, - })); - } - } - // no tab or space at the end - if let Some(c) = value.last().cloned() { - if matches!(c, b'\t' | b' ') { - return Err(InvalidHeader(InvalidHeaderAlt::InvalidValueByte { - value, - invalid_byte: c, - })); - } - } - // no line jump - for c in value.iter() { - if matches!(*c, b'\r' | b'\n') { - return Err(InvalidHeader(InvalidHeaderAlt::InvalidValueByte { - value: value.clone(), - invalid_byte: *c, - })); - } - } - Ok(HeaderValue(value)) - } -} - -impl fmt::Display for HeaderValue { - #[inline] - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", String::from_utf8_lossy(&self.0)) - } -} - -#[derive(Debug)] -pub struct Iter<'a>(std::collections::btree_map::Iter<'a, HeaderName, HeaderValue>); - -impl<'a> Iterator for Iter<'a> { - type Item = (&'a HeaderName, &'a HeaderValue); - - #[inline] - fn next(&mut self) -> Option<(&'a HeaderName, &'a HeaderValue)> { - self.0.next() - } - - #[inline] - fn size_hint(&self) -> (usize, Option) { - self.0.size_hint() - } - - #[inline] - fn last(self) -> Option<(&'a HeaderName, &'a HeaderValue)> { - self.0.last() - } -} - -impl<'a> DoubleEndedIterator for Iter<'a> { - #[inline] - fn next_back(&mut self) -> Option<(&'a HeaderName, &'a HeaderValue)> { - self.0.next_back() - } -} - -impl ExactSizeIterator for Iter<'_> { - #[inline] - fn len(&self) -> usize { - self.0.len() - } -} - -#[derive(Debug)] -pub struct IntoIter(std::collections::btree_map::IntoIter); - -impl Iterator for IntoIter { - type Item = (HeaderName, HeaderValue); - - #[inline] - fn next(&mut self) -> Option<(HeaderName, HeaderValue)> { - self.0.next() - } - - #[inline] - fn size_hint(&self) -> (usize, Option) { - self.0.size_hint() - } - - #[inline] - fn last(self) -> Option<(HeaderName, HeaderValue)> { - self.0.last() - } -} - -impl DoubleEndedIterator for IntoIter { - #[inline] - fn next_back(&mut self) -> Option<(HeaderName, HeaderValue)> { - self.0.next_back() - } -} - -impl ExactSizeIterator for IntoIter { - #[inline] - fn len(&self) -> usize { - self.0.len() - } -} - -/// Error returned by [`HeaderName::try_from`]. -#[derive(Debug, Clone)] -pub struct InvalidHeader(InvalidHeaderAlt); - -#[derive(Debug, Clone)] -enum InvalidHeaderAlt { - EmptyName, - InvalidNameChar { - name: Cow<'static, str>, - invalid_char: char, - }, - InvalidValueByte { - value: Cow<'static, [u8]>, - invalid_byte: u8, - }, -} - -impl fmt::Display for InvalidHeader { - #[inline] - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match &self.0 { - InvalidHeaderAlt::EmptyName => f.write_str("header names should not be empty"), - InvalidHeaderAlt::InvalidNameChar { name, invalid_char } => write!( - f, - "The character '{invalid_char}' is not valid inside of header name '{name}'" - ), - InvalidHeaderAlt::InvalidValueByte { - value, - invalid_byte, - } => write!( - f, - "The byte '{}' is not valid inside of header value '{}'", - invalid_byte, - String::from_utf8_lossy(value) - ), - } - } -} - -impl Error for InvalidHeader {} - -impl From for InvalidHeader { - #[inline] - fn from(e: Infallible) -> Self { - match e {} - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn validate_header_name() { - assert!(HeaderName::from_str("").is_err()); - assert!(HeaderName::from_str("ffo bar").is_err()); - assert!(HeaderName::from_str("ffo\tbar").is_err()); - assert!(HeaderName::from_str("ffo\rbar").is_err()); - assert!(HeaderName::from_str("ffo\nbar").is_err()); - assert!(HeaderName::from_str("ffoébar").is_err()); - assert!(HeaderName::from_str("foo-bar").is_ok()); - } - - #[test] - fn validate_header_value() { - assert!(HeaderValue::from_str("").is_ok()); - assert!(HeaderValue::from_str(" ffobar").is_err()); - assert!(HeaderValue::from_str("ffobar ").is_err()); - assert!(HeaderValue::from_str("ffo\rbar").is_err()); - assert!(HeaderValue::from_str("ffo\nbar").is_err()); - assert!(HeaderValue::from_str("ffoébar").is_ok()); - } -} diff --git a/src/model/method.rs b/src/model/method.rs deleted file mode 100644 index 76cb7e4..0000000 --- a/src/model/method.rs +++ /dev/null @@ -1,165 +0,0 @@ -use std::borrow::{Borrow, Cow}; -use std::error::Error; -use std::fmt; -use std::ops::Deref; -use std::str::FromStr; - -/// An [HTTP method](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#methods) like `GET` or `POST`. -/// -/// ``` -/// use oxhttp::model::Method; -/// use std::str::FromStr; -/// -/// assert_eq!(Method::from_str("get")?, Method::GET); -/// # Result::<_,Box>::Ok(()) -/// ``` -#[derive(PartialEq, Eq, Debug, Clone, Hash)] -pub struct Method(Cow<'static, str>); - -impl Method { - /// Is the method [safe](https://httpwg.org/specs/rfc7231.html#safe.methods) - pub(crate) fn is_safe(&self) -> bool { - matches!(self.as_ref(), "GET" | "HEAD" | "OPTIONS" | "TRACE") - } - - /// [CONNECT](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#CONNECT). - pub const CONNECT: Method = Self(Cow::Borrowed("CONNECT")); - /// [DELETE](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#DELETE). - pub const DELETE: Method = Self(Cow::Borrowed("DELETE")); - /// [GET](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#GET). - pub const GET: Method = Self(Cow::Borrowed("GET")); - /// [HEAD](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#HEAD). - pub const HEAD: Method = Self(Cow::Borrowed("HEAD")); - /// [OPTIONS](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#OPTIONS). - pub const OPTIONS: Method = Self(Cow::Borrowed("OPTIONS")); - /// [POST](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#POST). - pub const POST: Method = Self(Cow::Borrowed("POST")); - /// [PUT](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#PUT). - pub const PUT: Method = Self(Cow::Borrowed("PUT")); - /// [TRACE](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#TRACE). - pub const TRACE: Method = Self(Cow::Borrowed("TRACE")); -} - -impl Deref for Method { - type Target = str; - - #[inline] - fn deref(&self) -> &str { - &self.0 - } -} - -impl AsRef for Method { - #[inline] - fn as_ref(&self) -> &str { - &self.0 - } -} - -impl Borrow for Method { - #[inline] - fn borrow(&self) -> &str { - &self.0 - } -} - -impl FromStr for Method { - type Err = InvalidMethod; - - #[inline] - fn from_str(name: &str) -> Result { - for method in STATIC_METHODS { - if method.eq_ignore_ascii_case(name) { - return Ok(method); - } - } - name.to_owned().try_into() - } -} - -impl TryFrom for Method { - type Error = InvalidMethod; - - #[inline] - fn try_from(name: String) -> Result { - for method in STATIC_METHODS { - if method.eq_ignore_ascii_case(&name) { - return Ok(method); - } - } - if name.is_empty() { - Err(InvalidMethod(InvalidMethodAlt::Empty)) - } else { - for c in name.chars() { - if !matches!(c, '!' | '#' | '$' | '%' | '&' | '\'' | '*' - | '+' | '-' | '.' | '^' | '_' | '`' | '|' | '~' - | '0'..='9' | 'a'..='z' | 'A'..='Z') - { - return Err(InvalidMethod(InvalidMethodAlt::InvalidChar { - name: name.to_owned(), - invalid_char: c, - })); - } - } - Ok(Self(name.into())) - } - } -} - -impl fmt::Display for Method { - #[inline] - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.as_ref()) - } -} - -const STATIC_METHODS: [Method; 8] = [ - Method::CONNECT, - Method::DELETE, - Method::GET, - Method::HEAD, - Method::OPTIONS, - Method::POST, - Method::PUT, - Method::TRACE, -]; - -/// Error returned by [`Method::try_from`]. -#[derive(Debug, Clone)] -pub struct InvalidMethod(InvalidMethodAlt); - -#[derive(Debug, Clone)] -enum InvalidMethodAlt { - Empty, - InvalidChar { name: String, invalid_char: char }, -} - -impl fmt::Display for InvalidMethod { - #[inline] - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match &self.0 { - InvalidMethodAlt::Empty => f.write_str("HTTP methods should not be empty"), - InvalidMethodAlt::InvalidChar { name, invalid_char } => write!( - f, - "The character '{invalid_char}' is not valid inside of HTTP method '{name}'" - ), - } - } -} - -impl Error for InvalidMethod {} -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn validate_header_name() { - assert!(Method::from_str("").is_err()); - assert!(Method::from_str("ffo bar").is_err()); - assert!(Method::from_str("ffo\tbar").is_err()); - assert!(Method::from_str("ffo\rbar").is_err()); - assert!(Method::from_str("ffo\nbar").is_err()); - assert!(Method::from_str("ffoébar").is_err()); - assert!(Method::from_str("foo-bar").is_ok()); - } -} diff --git a/src/model/mod.rs b/src/model/mod.rs index 9baf70a..870dfcd 100644 --- a/src/model/mod.rs +++ b/src/model/mod.rs @@ -1,17 +1,9 @@ //! The HTTP model encoded in Rust type system. //! +//! This reexport the [`http`](https://docs.rs/http) crate except for [`Body`]. +//! //! The main entry points are [`Request`] and [`Response`]. mod body; -mod header; -mod method; -mod request; -mod response; -mod status; pub use body::{Body, ChunkedTransferPayload}; -pub use header::{HeaderName, HeaderValue, Headers, InvalidHeader}; -pub use method::{InvalidMethod, Method}; -pub use request::{Request, RequestBuilder}; -pub use response::{Response, ResponseBuilder}; -pub use status::{InvalidStatus, Status}; -pub use url::Url; +pub use http::*; diff --git a/src/model/request.rs b/src/model/request.rs deleted file mode 100644 index 650dd40..0000000 --- a/src/model/request.rs +++ /dev/null @@ -1,147 +0,0 @@ -use crate::model::header::IntoHeaderName; -use crate::model::{Body, HeaderName, HeaderValue, Headers, InvalidHeader, Method, Url}; - -/// A HTTP request. -/// -/// ``` -/// use oxhttp::model::{Request, Method, HeaderName, Body}; -/// -/// let request = Request::builder(Method::POST, "http://example.com:80/foo".parse()?) -/// .with_header(HeaderName::CONTENT_TYPE, "application/json")? -/// .with_body("{\"foo\": \"bar\"}"); -/// -/// assert_eq!(*request.method(), Method::POST); -/// assert_eq!(request.url().as_str(), "http://example.com/foo"); -/// assert_eq!(request.header(&HeaderName::CONTENT_TYPE).unwrap().as_ref(), b"application/json"); -/// assert_eq!(&request.into_body().to_vec()?, b"{\"foo\": \"bar\"}"); -/// # Result::<_,Box>::Ok(()) -/// ``` -#[derive(Debug)] -pub struct Request { - method: Method, - url: Url, - headers: Headers, - body: Body, -} - -impl Request { - #[inline] - pub fn builder(method: Method, url: Url) -> RequestBuilder { - RequestBuilder { - method, - url, - headers: Headers::new(), - } - } - - #[inline] - pub fn method(&self) -> &Method { - &self.method - } - - #[inline] - pub fn url(&self) -> &Url { - &self.url - } - - #[inline] - pub fn headers(&self) -> &Headers { - &self.headers - } - - #[inline] - pub fn headers_mut(&mut self) -> &mut Headers { - &mut self.headers - } - - #[inline] - pub fn header(&self, name: &HeaderName) -> Option<&HeaderValue> { - self.headers.get(name) - } - - #[inline] - pub fn append_header>( - &mut self, - name: impl IntoHeaderName, - value: impl TryInto, - ) -> Result<(), InvalidHeader> { - self.headers_mut() - .append(name.try_into()?, value.try_into().map_err(Into::into)?); - Ok(()) - } - - #[inline] - pub fn body(&self) -> &Body { - &self.body - } - - #[inline] - pub fn body_mut(&mut self) -> &mut Body { - &mut self.body - } - - #[inline] - pub fn into_body(self) -> Body { - self.body - } -} - -/// Builder for [`Request`] -pub struct RequestBuilder { - method: Method, - url: Url, - headers: Headers, -} - -impl RequestBuilder { - #[inline] - pub fn method(&self) -> &Method { - &self.method - } - - #[inline] - pub fn url(&self) -> &Url { - &self.url - } - - #[inline] - pub fn headers(&self) -> &Headers { - &self.headers - } - - #[inline] - pub fn headers_mut(&mut self) -> &mut Headers { - &mut self.headers - } - - #[inline] - pub fn header(&self, name: &HeaderName) -> Option<&HeaderValue> { - self.headers.get(name) - } - - #[inline] - pub fn with_header>( - mut self, - name: impl IntoHeaderName, - value: impl TryInto, - ) -> Result { - self.headers_mut() - .append(name.try_into()?, value.try_into().map_err(Into::into)?); - Ok(self) - } - - #[inline] - pub fn with_body(self, body: impl Into) -> Request { - Request { - method: self.method, - url: self.url, - headers: self.headers, - body: body.into(), - } - } - - #[inline] - pub fn build(self) -> Request { - self.with_body(Body::default()) - } -} diff --git a/src/model/response.rs b/src/model/response.rs deleted file mode 100644 index fc95808..0000000 --- a/src/model/response.rs +++ /dev/null @@ -1,133 +0,0 @@ -use crate::model::header::IntoHeaderName; -use crate::model::{Body, HeaderName, HeaderValue, Headers, InvalidHeader, Status}; - -/// A HTTP response. -/// -/// ``` -/// use oxhttp::model::{HeaderName, Body, Response, Status}; -/// -/// let response = Response::builder(Status::OK) -/// .with_header(HeaderName::CONTENT_TYPE, "application/json")? -/// .with_header("X-Custom", "foo")? -/// .with_body("{\"foo\": \"bar\"}"); -/// -/// assert_eq!(response.status(), Status::OK); -/// assert_eq!(response.header(&HeaderName::CONTENT_TYPE).unwrap().as_ref(), b"application/json"); -/// assert_eq!(&response.into_body().to_vec()?, b"{\"foo\": \"bar\"}"); -/// # Result::<_,Box>::Ok(()) -/// ``` -#[derive(Debug)] -pub struct Response { - status: Status, - headers: Headers, - body: Body, -} - -impl Response { - #[inline] - pub fn builder(status: Status) -> ResponseBuilder { - ResponseBuilder { - status, - headers: Headers::new(), - } - } - - #[inline] - pub fn status(&self) -> Status { - self.status - } - - #[inline] - pub fn headers(&self) -> &Headers { - &self.headers - } - - #[inline] - pub fn headers_mut(&mut self) -> &mut Headers { - &mut self.headers - } - - #[inline] - pub fn header(&self, name: &HeaderName) -> Option<&HeaderValue> { - self.headers.get(name) - } - - #[inline] - pub fn append_header>( - &mut self, - name: impl IntoHeaderName, - value: impl TryInto, - ) -> Result<(), InvalidHeader> { - self.headers_mut() - .append(name.try_into()?, value.try_into().map_err(Into::into)?); - Ok(()) - } - - #[inline] - pub fn body(&self) -> &Body { - &self.body - } - - #[inline] - pub fn body_mut(&mut self) -> &mut Body { - &mut self.body - } - - #[inline] - pub fn into_body(self) -> Body { - self.body - } -} - -/// Builder for [`Response`] -pub struct ResponseBuilder { - status: Status, - headers: Headers, -} - -impl ResponseBuilder { - #[inline] - pub fn status(&self) -> Status { - self.status - } - - #[inline] - pub fn headers(&self) -> &Headers { - &self.headers - } - - #[inline] - pub fn headers_mut(&mut self) -> &mut Headers { - &mut self.headers - } - - #[inline] - pub fn header(&self, name: &HeaderName) -> Option<&HeaderValue> { - self.headers.get(name) - } - - #[inline] - pub fn with_header>( - mut self, - name: impl IntoHeaderName, - value: impl TryInto, - ) -> Result { - self.headers_mut() - .append(name.try_into()?, value.try_into().map_err(Into::into)?); - Ok(self) - } - - #[inline] - pub fn with_body(self, body: impl Into) -> Response { - Response { - status: self.status, - headers: self.headers, - body: body.into(), - } - } - - #[inline] - pub fn build(self) -> Response { - self.with_body(Body::default()) - } -} diff --git a/src/model/status.rs b/src/model/status.rs deleted file mode 100644 index 5cce358..0000000 --- a/src/model/status.rs +++ /dev/null @@ -1,271 +0,0 @@ -use std::borrow::Borrow; -use std::error::Error; -use std::fmt; -use std::ops::Deref; - -/// An HTTP [status](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.codes). -/// -/// ``` -/// use oxhttp::model::Status; -/// -/// assert_eq!(Status::OK, Status::try_from(200)?); -/// # Result::<_,Box>::Ok(()) -/// ``` -#[derive(PartialEq, Eq, Debug, Clone, Copy, Hash)] -pub struct Status(u16); - -impl Status { - /// Is the status [informational](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.1xx). - #[inline] - pub fn is_informational(&self) -> bool { - (100..=199).contains(&self.0) - } - - /// Is the status [successful](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.2xx). - #[inline] - pub fn is_successful(&self) -> bool { - (200..=299).contains(&self.0) - } - - /// Is the status [related to redirections](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.3xx). - #[inline] - pub fn is_redirection(&self) -> bool { - (300..=399).contains(&self.0) - } - - /// Is the status a [client error](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.4xx). - #[inline] - pub fn is_client_error(&self) -> bool { - (400..=499).contains(&self.0) - } - - /// Is the status [server error](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.5xx). - #[inline] - pub fn is_server_error(&self) -> bool { - (500..=599).contains(&self.0) - } - - /// [100 Continue](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.100) - pub const CONTINUE: Self = Self(100); - /// [101 Switching Protocols](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.101) - pub const SWITCHING_PROTOCOLS: Self = Self(101); - /// [200 OK](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.200) - pub const OK: Self = Self(200); - /// [201 Created](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.201) - pub const CREATED: Self = Self(201); - /// [202 Accepted](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.202) - pub const ACCEPTED: Self = Self(202); - /// [203 Non-Authoritative Information](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.203) - pub const NON_AUTHORITATIVE_INFORMATION: Self = Self(203); - /// [204 No Content](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.204) - pub const NO_CONTENT: Self = Self(204); - /// [205 Reset Content](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.205) - pub const RESET_CONTENT: Self = Self(205); - /// [206 Partial Content](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.206) - pub const PARTIAL_CONTENT: Self = Self(206); - /// [300 Multiple Choices](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.300) - pub const MULTIPLE_CHOICES: Self = Self(300); - /// [301 Moved Permanently](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.301) - pub const MOVED_PERMANENTLY: Self = Self(301); - /// [302 Found](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.302) - pub const FOUND: Self = Self(302); - /// [303 See Other](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.303) - pub const SEE_OTHER: Self = Self(303); - /// [304 Not Modified](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.304) - pub const NOT_MODIFIED: Self = Self(304); - /// [305 Use Proxy](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.305) - pub const USE_PROXY: Self = Self(305); - /// [307 Temporary Redirect](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.307) - pub const TEMPORARY_REDIRECT: Self = Self(307); - /// [308 Permanent Redirect](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.308) - pub const PERMANENT_REDIRECT: Self = Self(308); - /// [400 Bad Request](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.400) - pub const BAD_REQUEST: Self = Self(400); - /// [401 Unauthorized](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.401) - pub const UNAUTHORIZED: Self = Self(401); - /// [402 Payment Required](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.402) - pub const PAYMENT_REQUIRED: Self = Self(402); - /// [403 Forbidden](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.403) - pub const FORBIDDEN: Self = Self(403); - /// [404 Not Found](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.404) - pub const NOT_FOUND: Self = Self(404); - /// [405 Method Not Allowed](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.405) - pub const METHOD_NOT_ALLOWED: Self = Self(405); - /// [406 Not Acceptable](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.406) - pub const NOT_ACCEPTABLE: Self = Self(406); - /// [407 Proxy Authentication Required](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.407) - pub const PROXY_AUTHENTICATION_REQUIRED: Self = Self(407); - /// [408 Request Timeout](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.408) - pub const REQUEST_TIMEOUT: Self = Self(408); - /// [409 Conflict](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.409) - pub const CONFLICT: Self = Self(409); - /// [410 Gone](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.410) - pub const GONE: Self = Self(410); - /// [411 Length Required](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.411) - pub const LENGTH_REQUIRED: Self = Self(411); - /// [412 Precondition Failed](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.412) - pub const PRECONDITION_FAILED: Self = Self(412); - /// [413 Content Too Large](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.413) - pub const CONTENT_TOO_LARGE: Self = Self(413); - /// [414 URI Too Long](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.414) - pub const URI_TOO_LONG: Self = Self(414); - /// [415 Unsupported Media Type](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.415) - pub const UNSUPPORTED_MEDIA_TYPE: Self = Self(415); - /// [416 Range Not Satisfiable](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.416) - pub const RANGE_NOT_SATISFIABLE: Self = Self(416); - /// [417 Expectation Failed](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.417) - pub const EXPECTATION_FAILED: Self = Self(417); - /// [421 Misdirected Request](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.421) - pub const MISDIRECTED_REQUEST: Self = Self(421); - /// [422 Unprocessable Content](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.422) - pub const UNPROCESSABLE_CONTENT: Self = Self(422); - /// [426 Upgrade Required](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.426) - pub const UPGRADE_REQUIRED: Self = Self(426); - /// [500 Internal Server Error](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.500) - pub const INTERNAL_SERVER_ERROR: Self = Self(500); - /// [501 Not Implemented](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.501) - pub const NOT_IMPLEMENTED: Self = Self(501); - /// [502 Bad Gateway](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.502) - pub const BAD_GATEWAY: Self = Self(502); - /// [503 Service Unavailable](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.503) - pub const SERVICE_UNAVAILABLE: Self = Self(503); - /// [504 Gateway Timeout](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.504) - pub const GATEWAY_TIMEOUT: Self = Self(504); - /// [505 HTTP Version Not Supported](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#status.505) - pub const HTTP_VERSION_NOT_SUPPORTED: Self = Self(505); - - pub(crate) fn reason_phrase(&self) -> Option<&'static str> { - match self.0 { - 100 => Some("Continue"), - 101 => Some("Switching Protocols"), - 102 => Some("Processing"), - 103 => Some("Early Hints"), - 200 => Some("OK"), - 201 => Some("Created"), - 202 => Some("Accepted"), - 203 => Some("Non-Authoritative Information"), - 204 => Some("No Content"), - 205 => Some("Reset Content"), - 206 => Some("Partial Content"), - 207 => Some("Multi-Status"), - 208 => Some("Already Reported"), - 226 => Some("IM Used"), - 300 => Some("Multiple Choices"), - 301 => Some("Moved Permanently"), - 302 => Some("Found"), - 303 => Some("See Other"), - 304 => Some("Not Modified"), - 305 => Some("Use Proxy"), - 307 => Some("Temporary Redirect"), - 308 => Some("Permanent Redirect"), - 400 => Some("Bad Request"), - 401 => Some("Unauthorized"), - 402 => Some("Payment Required"), - 403 => Some("Forbidden"), - 404 => Some("Not Found"), - 405 => Some("Method Not Allowed"), - 406 => Some("Not Acceptable"), - 407 => Some("Proxy Authentication Required"), - 408 => Some("Request Timeout"), - 409 => Some("Conflict"), - 410 => Some("Gone"), - 411 => Some("Length Required"), - 412 => Some("Precondition Failed"), - 413 => Some("Content Too Large"), - 414 => Some("URI Too Long"), - 415 => Some("Unsupported Media Type"), - 416 => Some("Range Not Satisfiable"), - 417 => Some("Expectation Failed"), - 421 => Some("Misdirected Request"), - 422 => Some("Unprocessable Content"), - 423 => Some("Locked"), - 424 => Some("Failed Dependency"), - 425 => Some("Too Early"), - 426 => Some("Upgrade Required"), - 428 => Some("Precondition Required"), - 429 => Some("Too Many Requests"), - 431 => Some("Request Header Fields Too Large"), - 451 => Some("Unavailable For Legal Reasons"), - 500 => Some("Internal Server Error"), - 501 => Some("Not Implemented"), - 502 => Some("Bad Gateway"), - 503 => Some("Service Unavailable"), - 504 => Some("Gateway Timeout"), - 505 => Some("HTTP Version Not Supported"), - 506 => Some("Variant Also Negotiates"), - 507 => Some("Insufficient Storage"), - 508 => Some("Loop Detected"), - 510 => Some("Not Extended"), - 511 => Some("Network Authentication Required"), - _ => None, - } - } -} - -impl Deref for Status { - type Target = u16; - - #[inline] - fn deref(&self) -> &u16 { - &self.0 - } -} - -impl AsRef for Status { - #[inline] - fn as_ref(&self) -> &u16 { - &self.0 - } -} - -impl Borrow for Status { - #[inline] - fn borrow(&self) -> &u16 { - &self.0 - } -} - -impl TryFrom for Status { - type Error = InvalidStatus; - - #[inline] - fn try_from(code: u16) -> Result { - if (0..=999).contains(&code) { - Ok(Self(code)) - } else { - Err(InvalidStatus(code)) - } - } -} - -impl From for u16 { - #[inline] - fn from(status: Status) -> Self { - status.0 - } -} - -impl fmt::Display for Status { - #[inline] - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{} {}", self.0, self.reason_phrase().unwrap_or("")) - } -} - -/// Error returned by [`Status::try_from`]. -#[allow(missing_copy_implementations)] -#[derive(Debug, Clone)] -pub struct InvalidStatus(u16); - -impl fmt::Display for InvalidStatus { - #[inline] - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "The HTTP status code should be between 0 and 999, '{}' found", - self.0 - ) - } -} - -impl Error for InvalidStatus {} diff --git a/src/server.rs b/src/server.rs index 07f8692..ed78e92 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,13 +1,12 @@ -use crate::io::{decode_request_body, decode_request_headers}; -use crate::io::{encode_response, BUFFER_CAPACITY}; -use crate::model::{ - HeaderName, HeaderValue, InvalidHeader, Request, RequestBuilder, Response, Status, -}; +use crate::io::{decode_request_body, decode_request_headers, encode_response, BUFFER_CAPACITY}; +use crate::model::header::{InvalidHeaderValue, CONNECTION, CONTENT_TYPE, EXPECT, SERVER}; +use crate::model::request::Builder as RequestBuilder; +use crate::model::{Body, HeaderValue, Request, Response, StatusCode, Version}; use std::fmt; use std::io::{copy, sink, BufReader, BufWriter, Error, ErrorKind, Result, Write}; use std::net::{SocketAddr, TcpListener, TcpStream}; use std::sync::{Arc, Condvar, Mutex}; -use std::thread::{Builder, JoinHandle}; +use std::thread::{Builder as ThreadBuilder, JoinHandle}; use std::time::Duration; /// An HTTP server. @@ -18,15 +17,15 @@ use std::time::Duration; /// ```no_run /// use std::net::{Ipv4Addr, Ipv6Addr}; /// use oxhttp::Server; -/// use oxhttp::model::{Response, Status}; +/// use oxhttp::model::{Body, Response, StatusCode}; /// use std::time::Duration; /// /// // Builds a new server that returns a 404 everywhere except for "/" where it returns the body 'home' /// let mut server = Server::new(|request| { -/// if request.url().path() == "/" { -/// Response::builder(Status::OK).with_body("home") +/// if request.uri().path() == "/" { +/// Response::builder().body(Body::from("home")).unwrap() /// } else { -/// Response::builder(Status::NOT_FOUND).build() +/// Response::builder().status(StatusCode::NOT_FOUND).body(Body::empty()).unwrap() /// } /// }); /// // We bind the server to localhost on both IPv4 and v6 @@ -41,7 +40,8 @@ use std::time::Duration; /// ``` #[allow(missing_copy_implementations)] pub struct Server { - on_request: Arc Response + Send + Sync + 'static>, + #[allow(clippy::type_complexity)] + on_request: Arc) -> Response + Send + Sync + 'static>, socket_addrs: Vec, timeout: Option, server: Option, @@ -51,7 +51,9 @@ pub struct Server { impl Server { /// Builds the server using the given `on_request` method that builds a `Response` from a given `Request`. #[inline] - pub fn new(on_request: impl Fn(&mut Request) -> Response + Send + Sync + 'static) -> Self { + pub fn new( + on_request: impl Fn(&mut Request) -> Response + Send + Sync + 'static, + ) -> Self { Self { on_request: Arc::new(on_request), socket_addrs: Vec::new(), @@ -75,7 +77,7 @@ impl Server { pub fn with_server_name( mut self, server: impl Into, - ) -> std::result::Result { + ) -> std::result::Result { self.server = Some(HeaderValue::try_from(server.into())?); Ok(self) } @@ -109,7 +111,7 @@ impl Server { let thread_limit = thread_limit.clone(); let on_request = Arc::clone(&self.on_request); let server = self.server.clone(); - Builder::new().name(thread_name).spawn(move || { + ThreadBuilder::new().name(thread_name).spawn(move || { for stream in listener.incoming() { match stream { Ok(stream) => { @@ -127,7 +129,7 @@ impl Server { let thread_guard = thread_limit.as_ref().map(|s| s.lock()); let on_request = Arc::clone(&on_request); let server = server.clone(); - if let Err(error) = Builder::new().name(thread_name).spawn( + if let Err(error) = ThreadBuilder::new().name(thread_name).spawn( move || { if let Err(error) = accept_request(stream, &*on_request, timeout, &server) @@ -182,7 +184,7 @@ impl ListeningServer { fn accept_request( mut stream: TcpStream, - on_request: &dyn Fn(&mut Request) -> Response, + on_request: &dyn Fn(&mut Request) -> Response, timeout: Option, server: &Option, ) -> Result<()> { @@ -195,14 +197,14 @@ fn accept_request( { Ok(request) => { // Handles Expect header - if let Some(expect) = request.header(&HeaderName::EXPECT).cloned() { - if expect.eq_ignore_ascii_case(b"100-continue") { + if let Some(expect) = request.headers_ref().unwrap().get(EXPECT).cloned() { + if expect.as_bytes().eq_ignore_ascii_case(b"100-continue") { stream.write_all(b"HTTP/1.1 100 Continue\r\n\r\n")?; read_body_and_build_response(request, reader, on_request) } else { ( build_text_response( - Status::EXPECTATION_FAILED, + StatusCode::EXPECTATION_FAILED, format!( "Expect header value '{}' is not supported.", String::from_utf8_lossy(expect.as_ref()) @@ -227,11 +229,10 @@ fn accept_request( // Additional headers if let Some(server) = server { - if !response.headers().contains(&HeaderName::SERVER) { - response - .headers_mut() - .set(HeaderName::SERVER, server.clone()) - } + response + .headers_mut() + .entry(SERVER) + .or_insert_with(|| server.clone()); } stream = encode_response( @@ -253,22 +254,30 @@ enum ConnectionState { fn read_body_and_build_response( request: RequestBuilder, reader: BufReader, - on_request: &dyn Fn(&mut Request) -> Response, -) -> (Response, ConnectionState) { + on_request: &dyn Fn(&mut Request) -> Response, +) -> (Response, ConnectionState) { match decode_request_body(request, reader) { Ok(mut request) => { let response = on_request(&mut request); // We make sure to finish reading the body if let Err(error) = copy(request.body_mut(), &mut sink()) { - (build_error(error), ConnectionState::Close) //TODO: ignore? + (build_error(error), ConnectionState::Close) // TODO: ignore? } else { let connection_state = request - .header(&HeaderName::CONNECTION) + .headers() + .get(CONNECTION) .and_then(|v| { - v.eq_ignore_ascii_case(b"close") + v.as_bytes() + .eq_ignore_ascii_case(b"close") .then_some(ConnectionState::Close) }) - .unwrap_or(ConnectionState::KeepAlive); + .unwrap_or_else(|| { + if request.version() <= Version::HTTP_10 { + ConnectionState::Close + } else { + ConnectionState::KeepAlive + } + }); (response, connection_state) } } @@ -276,22 +285,23 @@ fn read_body_and_build_response( } } -fn build_error(error: Error) -> Response { +fn build_error(error: Error) -> Response { build_text_response( match error.kind() { - ErrorKind::TimedOut => Status::REQUEST_TIMEOUT, - ErrorKind::InvalidData => Status::BAD_REQUEST, - _ => Status::INTERNAL_SERVER_ERROR, + ErrorKind::TimedOut => StatusCode::REQUEST_TIMEOUT, + ErrorKind::InvalidData => StatusCode::BAD_REQUEST, + _ => StatusCode::INTERNAL_SERVER_ERROR, }, error.to_string(), ) } -fn build_text_response(status: Status, text: String) -> Response { - Response::builder(status) - .with_header(HeaderName::CONTENT_TYPE, "text/plain; charset=utf-8") +fn build_text_response(status: StatusCode, text: String) -> Response { + Response::builder() + .status(status) + .header(CONTENT_TYPE, "text/plain; charset=utf-8") + .body(Body::from(text)) .unwrap() - .with_body(text) } /// Dumb semaphore allowing to overflow capacity @@ -344,7 +354,6 @@ impl Drop for SemaphoreGuard { #[cfg(test)] mod tests { use super::*; - use crate::model::Status; use std::io::Read; use std::net::{Ipv4Addr, Ipv6Addr}; use std::thread::sleep; @@ -385,10 +394,13 @@ mod tests { responses: impl IntoIterator, ) -> Result<()> { Server::new(|request| { - if request.url().path() == "/" { - Response::builder(Status::OK).with_body("home") + if request.uri().path() == "/" { + Response::builder().body(Body::from("home")).unwrap() } else { - Response::builder(Status::NOT_FOUND).build() + Response::builder() + .status(StatusCode::NOT_FOUND) + .body(Body::empty()) + .unwrap() } }) .bind((Ipv4Addr::LOCALHOST, server_port)) @@ -413,7 +425,7 @@ mod tests { let server_port = 9996; let request = b"GET / HTTP/1.1\nhost: localhost:9999\n\n"; let response = b"HTTP/1.1 200 OK\r\nserver: OxHTTP/1.0\r\ncontent-length: 4\r\n\r\nhome"; - Server::new(|_| Response::builder(Status::OK).with_body("home")) + Server::new(|_| Response::builder().body(Body::from("home")).unwrap()) .bind((Ipv4Addr::LOCALHOST, server_port)) .bind((Ipv6Addr::LOCALHOST, server_port)) .with_server_name("OxHTTP/1.0")