Skip to content

Commit 7d8db8a

Browse files
authored
Merge pull request #86 from Fishrock123/configuration
feat: unstable HttpClient Config
2 parents 6ba80b3 + a268be7 commit 7d8db8a

File tree

8 files changed

+464
-36
lines changed

8 files changed

+464
-36
lines changed

Cargo.toml

+6-2
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,12 @@ h1_client = ["async-h1", "async-std", "deadpool", "futures"]
2727
native_client = ["curl_client", "wasm_client"]
2828
curl_client = ["isahc", "async-std"]
2929
wasm_client = ["js-sys", "web-sys", "wasm-bindgen", "wasm-bindgen-futures", "futures"]
30-
hyper_client = ["hyper", "hyper-tls", "http-types/hyperium_http", "futures-util"]
30+
hyper_client = ["hyper", "hyper-tls", "http-types/hyperium_http", "futures-util", "tokio"]
3131

3232
native-tls = ["async-native-tls"]
33-
rustls = ["async-tls"]
33+
rustls = ["async-tls", "rustls_crate"]
34+
35+
unstable-config = []
3436

3537
[dependencies]
3638
async-trait = "0.1.37"
@@ -48,11 +50,13 @@ futures = { version = "0.3.8", optional = true }
4850

4951
# h1_client_rustls
5052
async-tls = { version = "0.10.0", optional = true }
53+
rustls_crate = { version = "0.18", optional = true, package = "rustls" }
5154

5255
# hyper_client
5356
hyper = { version = "0.13.6", features = ["tcp"], optional = true }
5457
hyper-tls = { version = "0.4.3", optional = true }
5558
futures-util = { version = "0.3.5", features = ["io"], optional = true }
59+
tokio = { version = "0.2", features = ["time"], optional = true }
5660

5761
# curl_client
5862
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]

src/config.rs

+111
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
//! Configuration for `HttpClient`s.
2+
3+
use std::fmt::Debug;
4+
use std::time::Duration;
5+
6+
/// Configuration for `HttpClient`s.
7+
#[non_exhaustive]
8+
#[derive(Clone)]
9+
pub struct Config {
10+
/// HTTP/1.1 `keep-alive` (connection pooling).
11+
///
12+
/// Default: `true`.
13+
pub http_keep_alive: bool,
14+
/// TCP `NO_DELAY`.
15+
///
16+
/// Default: `false`.
17+
pub tcp_no_delay: bool,
18+
/// Connection timeout duration.
19+
///
20+
/// Default: `Some(Duration::from_secs(60))`.
21+
pub timeout: Option<Duration>,
22+
/// TLS Configuration (Rustls)
23+
#[cfg(all(feature = "h1_client", feature = "rustls"))]
24+
pub tls_config: Option<std::sync::Arc<rustls_crate::ClientConfig>>,
25+
/// TLS Configuration (Native TLS)
26+
#[cfg(all(feature = "h1_client", feature = "native-tls", not(feature = "rustls")))]
27+
pub tls_config: Option<std::sync::Arc<async_native_tls::TlsConnector>>,
28+
}
29+
30+
impl Debug for Config {
31+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32+
let mut dbg_struct = f.debug_struct("Config");
33+
dbg_struct
34+
.field("http_keep_alive", &self.http_keep_alive)
35+
.field("tcp_no_delay", &self.tcp_no_delay)
36+
.field("timeout", &self.timeout);
37+
38+
#[cfg(all(feature = "h1_client", feature = "rustls"))]
39+
{
40+
if self.tls_config.is_some() {
41+
dbg_struct.field("tls_config", &"Some(rustls::ClientConfig)");
42+
} else {
43+
dbg_struct.field("tls_config", &"None");
44+
}
45+
}
46+
#[cfg(all(feature = "h1_client", feature = "native-tls", not(feature = "rustls")))]
47+
{
48+
dbg_struct.field("tls_config", &self.tls_config);
49+
}
50+
51+
dbg_struct.finish()
52+
}
53+
}
54+
55+
impl Config {
56+
/// Construct new empty config.
57+
pub fn new() -> Self {
58+
Self {
59+
http_keep_alive: true,
60+
tcp_no_delay: false,
61+
timeout: Some(Duration::from_secs(60)),
62+
#[cfg(all(feature = "h1_client", any(feature = "rustls", feature = "native-tls")))]
63+
tls_config: None,
64+
}
65+
}
66+
}
67+
68+
impl Default for Config {
69+
fn default() -> Self {
70+
Self::new()
71+
}
72+
}
73+
74+
impl Config {
75+
/// Set HTTP/1.1 `keep-alive` (connection pooling).
76+
pub fn set_http_keep_alive(mut self, keep_alive: bool) -> Self {
77+
self.http_keep_alive = keep_alive;
78+
self
79+
}
80+
81+
/// Set TCP `NO_DELAY`.
82+
pub fn set_tcp_no_delay(mut self, no_delay: bool) -> Self {
83+
self.tcp_no_delay = no_delay;
84+
self
85+
}
86+
87+
/// Set connection timeout duration.
88+
pub fn set_timeout(mut self, timeout: Option<Duration>) -> Self {
89+
self.timeout = timeout;
90+
self
91+
}
92+
93+
/// Set TLS Configuration (Rustls)
94+
#[cfg(all(feature = "h1_client", feature = "rustls"))]
95+
pub fn set_tls_config(
96+
mut self,
97+
tls_config: Option<std::sync::Arc<rustls_crate::ClientConfig>>,
98+
) -> Self {
99+
self.tls_config = tls_config;
100+
self
101+
}
102+
/// Set TLS Configuration (Native TLS)
103+
#[cfg(all(feature = "h1_client", feature = "native-tls", not(feature = "rustls")))]
104+
pub fn set_tls_config(
105+
mut self,
106+
tls_config: Option<std::sync::Arc<async_native_tls::TlsConnector>>,
107+
) -> Self {
108+
self.tls_config = tls_config;
109+
self
110+
}
111+
}

src/h1/mod.rs

+95-6
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
//! http-client implementation for async-h1, with connecton pooling ("Keep-Alive").
22
3+
#[cfg(feature = "unstable-config")]
4+
use std::convert::{Infallible, TryFrom};
5+
36
use std::fmt::Debug;
47
use std::net::SocketAddr;
8+
use std::sync::Arc;
59

610
use async_h1::client;
711
use async_std::net::TcpStream;
@@ -17,6 +21,8 @@ cfg_if::cfg_if! {
1721
}
1822
}
1923

24+
use crate::Config;
25+
2026
use super::{async_trait, Error, HttpClient, Request, Response};
2127

2228
mod tcp;
@@ -40,6 +46,7 @@ pub struct H1Client {
4046
#[cfg(any(feature = "native-tls", feature = "rustls"))]
4147
https_pools: HttpsPool,
4248
max_concurrent_connections: usize,
49+
config: Arc<Config>,
4350
}
4451

4552
impl Debug for H1Client {
@@ -79,6 +86,7 @@ impl Debug for H1Client {
7986
"max_concurrent_connections",
8087
&self.max_concurrent_connections,
8188
)
89+
.field("config", &self.config)
8290
.finish()
8391
}
8492
}
@@ -97,6 +105,7 @@ impl H1Client {
97105
#[cfg(any(feature = "native-tls", feature = "rustls"))]
98106
https_pools: DashMap::new(),
99107
max_concurrent_connections: DEFAULT_MAX_CONCURRENT_CONNECTIONS,
108+
config: Arc::new(Config::default()),
100109
}
101110
}
102111

@@ -107,6 +116,7 @@ impl H1Client {
107116
#[cfg(any(feature = "native-tls", feature = "rustls"))]
108117
https_pools: DashMap::new(),
109118
max_concurrent_connections: max,
119+
config: Arc::new(Config::default()),
110120
}
111121
}
112122
}
@@ -147,12 +157,43 @@ impl HttpClient for H1Client {
147157
for (idx, addr) in addrs.into_iter().enumerate() {
148158
let has_another_addr = idx != max_addrs_idx;
149159

160+
#[cfg(feature = "unstable-config")]
161+
if !self.config.http_keep_alive {
162+
match scheme {
163+
"http" => {
164+
let stream = async_std::net::TcpStream::connect(addr).await?;
165+
req.set_peer_addr(stream.peer_addr().ok());
166+
req.set_local_addr(stream.local_addr().ok());
167+
let tcp_conn = client::connect(stream, req);
168+
return if let Some(timeout) = self.config.timeout {
169+
async_std::future::timeout(timeout, tcp_conn).await?
170+
} else {
171+
tcp_conn.await
172+
};
173+
}
174+
#[cfg(any(feature = "native-tls", feature = "rustls"))]
175+
"https" => {
176+
let raw_stream = async_std::net::TcpStream::connect(addr).await?;
177+
req.set_peer_addr(raw_stream.peer_addr().ok());
178+
req.set_local_addr(raw_stream.local_addr().ok());
179+
let tls_stream = tls::add_tls(&host, raw_stream, &self.config).await?;
180+
let tsl_conn = client::connect(tls_stream, req);
181+
return if let Some(timeout) = self.config.timeout {
182+
async_std::future::timeout(timeout, tsl_conn).await?
183+
} else {
184+
tsl_conn.await
185+
};
186+
}
187+
_ => unreachable!(),
188+
}
189+
}
190+
150191
match scheme {
151192
"http" => {
152193
let pool_ref = if let Some(pool_ref) = self.http_pools.get(&addr) {
153194
pool_ref
154195
} else {
155-
let manager = TcpConnection::new(addr);
196+
let manager = TcpConnection::new(addr, self.config.clone());
156197
let pool = Pool::<TcpStream, std::io::Error>::new(
157198
manager,
158199
self.max_concurrent_connections,
@@ -168,19 +209,28 @@ impl HttpClient for H1Client {
168209
let stream = match pool.get().await {
169210
Ok(s) => s,
170211
Err(_) if has_another_addr => continue,
171-
Err(e) => return Err(Error::from_str(400, e.to_string()))?,
212+
Err(e) => return Err(Error::from_str(400, e.to_string())),
172213
};
173214

174215
req.set_peer_addr(stream.peer_addr().ok());
175216
req.set_local_addr(stream.local_addr().ok());
176-
return client::connect(TcpConnWrapper::new(stream), req).await;
217+
218+
let tcp_conn = client::connect(TcpConnWrapper::new(stream), req);
219+
#[cfg(feature = "unstable-config")]
220+
return if let Some(timeout) = self.config.timeout {
221+
async_std::future::timeout(timeout, tcp_conn).await?
222+
} else {
223+
tcp_conn.await
224+
};
225+
#[cfg(not(feature = "unstable-config"))]
226+
return tcp_conn.await;
177227
}
178228
#[cfg(any(feature = "native-tls", feature = "rustls"))]
179229
"https" => {
180230
let pool_ref = if let Some(pool_ref) = self.https_pools.get(&addr) {
181231
pool_ref
182232
} else {
183-
let manager = TlsConnection::new(host.clone(), addr);
233+
let manager = TlsConnection::new(host.clone(), addr, self.config.clone());
184234
let pool = Pool::<TlsStream<TcpStream>, Error>::new(
185235
manager,
186236
self.max_concurrent_connections,
@@ -196,13 +246,21 @@ impl HttpClient for H1Client {
196246
let stream = match pool.get().await {
197247
Ok(s) => s,
198248
Err(_) if has_another_addr => continue,
199-
Err(e) => return Err(Error::from_str(400, e.to_string()))?,
249+
Err(e) => return Err(Error::from_str(400, e.to_string())),
200250
};
201251

202252
req.set_peer_addr(stream.get_ref().peer_addr().ok());
203253
req.set_local_addr(stream.get_ref().local_addr().ok());
204254

205-
return client::connect(TlsConnWrapper::new(stream), req).await;
255+
let tls_conn = client::connect(TlsConnWrapper::new(stream), req);
256+
#[cfg(feature = "unstable-config")]
257+
return if let Some(timeout) = self.config.timeout {
258+
async_std::future::timeout(timeout, tls_conn).await?
259+
} else {
260+
tls_conn.await
261+
};
262+
#[cfg(not(feature = "unstable-config"))]
263+
return tls_conn.await;
206264
}
207265
_ => unreachable!(),
208266
}
@@ -213,6 +271,37 @@ impl HttpClient for H1Client {
213271
"missing valid address",
214272
))
215273
}
274+
275+
#[cfg(feature = "unstable-config")]
276+
/// Override the existing configuration with new configuration.
277+
///
278+
/// Config options may not impact existing connections.
279+
fn set_config(&mut self, config: Config) -> http_types::Result<()> {
280+
self.config = Arc::new(config);
281+
282+
Ok(())
283+
}
284+
285+
#[cfg(feature = "unstable-config")]
286+
/// Get the current configuration.
287+
fn config(&self) -> &Config {
288+
&*self.config
289+
}
290+
}
291+
292+
#[cfg(feature = "unstable-config")]
293+
impl TryFrom<Config> for H1Client {
294+
type Error = Infallible;
295+
296+
fn try_from(config: Config) -> Result<Self, Self::Error> {
297+
Ok(Self {
298+
http_pools: DashMap::new(),
299+
#[cfg(any(feature = "native-tls", feature = "rustls"))]
300+
https_pools: DashMap::new(),
301+
max_concurrent_connections: DEFAULT_MAX_CONCURRENT_CONNECTIONS,
302+
config: Arc::new(config),
303+
})
304+
}
216305
}
217306

218307
#[cfg(test)]

src/h1/tcp.rs

+19-5
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,25 @@
1-
use std::fmt::Debug;
21
use std::net::SocketAddr;
32
use std::pin::Pin;
3+
use std::sync::Arc;
44

55
use async_std::net::TcpStream;
66
use async_trait::async_trait;
77
use deadpool::managed::{Manager, Object, RecycleResult};
88
use futures::io::{AsyncRead, AsyncWrite};
99
use futures::task::{Context, Poll};
1010

11-
#[derive(Clone, Debug)]
11+
use crate::Config;
12+
13+
#[derive(Clone)]
14+
#[cfg_attr(not(feature = "rustls"), derive(std::fmt::Debug))]
1215
pub(crate) struct TcpConnection {
1316
addr: SocketAddr,
17+
config: Arc<Config>,
1418
}
19+
1520
impl TcpConnection {
16-
pub(crate) fn new(addr: SocketAddr) -> Self {
17-
Self { addr }
21+
pub(crate) fn new(addr: SocketAddr, config: Arc<Config>) -> Self {
22+
Self { addr, config }
1823
}
1924
}
2025

@@ -58,12 +63,21 @@ impl AsyncWrite for TcpConnWrapper {
5863
#[async_trait]
5964
impl Manager<TcpStream, std::io::Error> for TcpConnection {
6065
async fn create(&self) -> Result<TcpStream, std::io::Error> {
61-
TcpStream::connect(self.addr).await
66+
let tcp_stream = TcpStream::connect(self.addr).await?;
67+
68+
#[cfg(feature = "unstable-config")]
69+
tcp_stream.set_nodelay(self.config.tcp_no_delay)?;
70+
71+
Ok(tcp_stream)
6272
}
6373

6474
async fn recycle(&self, conn: &mut TcpStream) -> RecycleResult<std::io::Error> {
6575
let mut buf = [0; 4];
6676
let mut cx = Context::from_waker(futures::task::noop_waker_ref());
77+
78+
#[cfg(feature = "unstable-config")]
79+
conn.set_nodelay(self.config.tcp_no_delay)?;
80+
6781
match Pin::new(conn).poll_read(&mut cx, &mut buf) {
6882
Poll::Ready(Err(error)) => Err(error),
6983
Poll::Ready(Ok(bytes)) if bytes == 0 => Err(std::io::Error::new(

0 commit comments

Comments
 (0)