Skip to content

Commit 4139b56

Browse files
RUST-2006 Add option to configure DEK cache lifetime (#1284)
1 parent 2dadbed commit 4139b56

File tree

16 files changed

+1889
-1016
lines changed

16 files changed

+1889
-1016
lines changed

src/client/csfle.rs

+10
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,16 @@ impl ClientState {
122122
if opts.bypass_query_analysis == Some(true) {
123123
builder = builder.bypass_query_analysis();
124124
}
125+
if let Some(key_cache_expiration) = opts.key_cache_expiration {
126+
let expiration_ms: u64 = key_cache_expiration.as_millis().try_into().map_err(|_| {
127+
Error::invalid_argument(format!(
128+
"key_cache_expiration must not exceed {} milliseconds, got {:?}",
129+
u64::MAX,
130+
key_cache_expiration
131+
))
132+
})?;
133+
builder = builder.key_cache_expiration(expiration_ms)?;
134+
}
125135
let crypt = builder.build()?;
126136
if opts.extra_option(&EO_CRYPT_SHARED_REQUIRED)? == Some(true)
127137
&& crypt.shared_lib_version().is_none()

src/client/csfle/client_builder.rs

+9
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::time::Duration;
2+
13
use crate::{bson::Document, error::Result, options::ClientOptions, Client};
24

35
use super::options::AutoEncryptionOptions;
@@ -101,6 +103,13 @@ impl EncryptedClientBuilder {
101103
self
102104
}
103105

106+
/// Set the duration of time after which the data encryption key cache should expire. Defaults
107+
/// to 60 seconds if unset.
108+
pub fn key_cache_expiration(mut self, expiration: impl Into<Option<Duration>>) -> Self {
109+
self.enc_opts.key_cache_expiration = expiration.into();
110+
self
111+
}
112+
104113
/// Constructs a new `Client` using automatic encryption. May perform DNS lookups and/or spawn
105114
/// mongocryptd as part of `Client` initialization.
106115
pub async fn build(self) -> Result<Client> {

src/client/csfle/client_encryption.rs

+102-26
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
mod create_data_key;
44
mod encrypt;
55

6+
use std::time::Duration;
7+
68
use mongocrypt::{ctx::KmsProvider, Crypt};
79
use serde::{Deserialize, Serialize};
810
use typed_builder::TypedBuilder;
@@ -61,32 +63,44 @@ impl ClientEncryption {
6163
key_vault_namespace: Namespace,
6264
kms_providers: impl IntoIterator<Item = (KmsProvider, bson::Document, Option<TlsOptions>)>,
6365
) -> Result<Self> {
64-
let kms_providers = KmsProviders::new(kms_providers)?;
65-
let crypt = Crypt::builder()
66-
.kms_providers(&kms_providers.credentials_doc()?)?
67-
.use_need_kms_credentials_state()
68-
.retry_kms(true)?
69-
.use_range_v2()?
70-
.build()?;
71-
let exec = CryptExecutor::new_explicit(
72-
key_vault_client.weak(),
73-
key_vault_namespace.clone(),
74-
kms_providers,
75-
)?;
76-
let key_vault = key_vault_client
77-
.database(&key_vault_namespace.db)
78-
.collection_with_options(
79-
&key_vault_namespace.coll,
80-
CollectionOptions::builder()
81-
.write_concern(WriteConcern::majority())
82-
.read_concern(ReadConcern::majority())
83-
.build(),
84-
);
85-
Ok(ClientEncryption {
86-
crypt,
87-
exec,
88-
key_vault,
89-
})
66+
Self::builder(key_vault_client, key_vault_namespace, kms_providers).build()
67+
}
68+
69+
/// Initialize a builder to construct a [`ClientEncryption`]. Methods on
70+
/// [`ClientEncryptionBuilder`] can be chained to set options.
71+
///
72+
/// ```no_run
73+
/// # use bson::doc;
74+
/// # use mongocrypt::ctx::KmsProvider;
75+
/// # use mongodb::client_encryption::ClientEncryption;
76+
/// # use mongodb::error::Result;
77+
/// # fn func() -> Result<()> {
78+
/// # let kv_client = todo!();
79+
/// # let kv_namespace = todo!();
80+
/// # let local_key = doc! { };
81+
/// let enc = ClientEncryption::builder(
82+
/// kv_client,
83+
/// kv_namespace,
84+
/// [
85+
/// (KmsProvider::Local, doc! { "key": local_key }, None),
86+
/// (KmsProvider::Kmip, doc! { "endpoint": "localhost:5698" }, None),
87+
/// ]
88+
/// )
89+
/// .build()?;
90+
/// # Ok(())
91+
/// # }
92+
/// ```
93+
pub fn builder(
94+
key_vault_client: Client,
95+
key_vault_namespace: Namespace,
96+
kms_providers: impl IntoIterator<Item = (KmsProvider, bson::Document, Option<TlsOptions>)>,
97+
) -> ClientEncryptionBuilder {
98+
ClientEncryptionBuilder {
99+
key_vault_client,
100+
key_vault_namespace,
101+
kms_providers: kms_providers.into_iter().collect(),
102+
key_cache_expiration: None,
103+
}
90104
}
91105

92106
// pub async fn rewrap_many_data_key(&self, _filter: Document, _opts: impl
@@ -189,6 +203,68 @@ impl ClientEncryption {
189203
}
190204
}
191205

206+
/// Builder for constructing a [`ClientEncryption`]. Construct by calling
207+
/// [`ClientEncryption::builder`].
208+
pub struct ClientEncryptionBuilder {
209+
key_vault_client: Client,
210+
key_vault_namespace: Namespace,
211+
kms_providers: Vec<(KmsProvider, bson::Document, Option<TlsOptions>)>,
212+
key_cache_expiration: Option<Duration>,
213+
}
214+
215+
impl ClientEncryptionBuilder {
216+
/// Set the duration of time after which the data encryption key cache should expire. Defaults
217+
/// to 60 seconds if unset.
218+
pub fn key_cache_expiration(mut self, expiration: impl Into<Option<Duration>>) -> Self {
219+
self.key_cache_expiration = expiration.into();
220+
self
221+
}
222+
223+
/// Build the [`ClientEncryption`].
224+
pub fn build(self) -> Result<ClientEncryption> {
225+
let kms_providers = KmsProviders::new(self.kms_providers)?;
226+
227+
let mut crypt_builder = Crypt::builder()
228+
.kms_providers(&kms_providers.credentials_doc()?)?
229+
.use_need_kms_credentials_state()
230+
.use_range_v2()?
231+
.retry_kms(true)?;
232+
if let Some(key_cache_expiration) = self.key_cache_expiration {
233+
let expiration_ms: u64 = key_cache_expiration.as_millis().try_into().map_err(|_| {
234+
Error::invalid_argument(format!(
235+
"key_cache_expiration must not exceed {} milliseconds, got {:?}",
236+
u64::MAX,
237+
key_cache_expiration
238+
))
239+
})?;
240+
crypt_builder = crypt_builder.key_cache_expiration(expiration_ms)?;
241+
}
242+
let crypt = crypt_builder.build()?;
243+
244+
let exec = CryptExecutor::new_explicit(
245+
self.key_vault_client.weak(),
246+
self.key_vault_namespace.clone(),
247+
kms_providers,
248+
)?;
249+
let key_vault = self
250+
.key_vault_client
251+
.database(&self.key_vault_namespace.db)
252+
.collection_with_options(
253+
&self.key_vault_namespace.coll,
254+
CollectionOptions::builder()
255+
.write_concern(WriteConcern::majority())
256+
.read_concern(ReadConcern::majority())
257+
.build(),
258+
);
259+
260+
Ok(ClientEncryption {
261+
crypt,
262+
exec,
263+
key_vault,
264+
})
265+
}
266+
}
267+
192268
/// A KMS-specific key used to encrypt data keys.
193269
#[derive(Debug, Clone, Serialize, Deserialize)]
194270
#[serde(untagged)]

src/client/csfle/options.rs

+11-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::collections::HashMap;
1+
use std::{collections::HashMap, time::Duration};
22

33
use bson::Array;
44
use mongocrypt::ctx::KmsProvider;
@@ -8,6 +8,7 @@ use crate::{
88
bson::{Bson, Document},
99
client::options::TlsOptions,
1010
error::{Error, Result},
11+
serde_util,
1112
Namespace,
1213
};
1314

@@ -59,6 +60,14 @@ pub(crate) struct AutoEncryptionOptions {
5960
#[cfg(test)]
6061
#[serde(skip)]
6162
pub(crate) disable_crypt_shared: Option<bool>,
63+
/// The duration after which the data encryption key cache expires. Defaults to 60 seconds if
64+
/// unset.
65+
#[serde(
66+
default,
67+
rename = "keyExpirationMS",
68+
deserialize_with = "serde_util::deserialize_duration_option_from_u64_millis"
69+
)]
70+
pub(crate) key_cache_expiration: Option<Duration>,
6271
}
6372

6473
fn default_key_vault_namespace() -> Namespace {
@@ -81,6 +90,7 @@ impl AutoEncryptionOptions {
8190
bypass_query_analysis: None,
8291
#[cfg(test)]
8392
disable_crypt_shared: None,
93+
key_cache_expiration: None,
8494
}
8595
}
8696
}

src/cmap/conn.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -327,8 +327,8 @@ impl PinnedConnectionHandle {
327327
}
328328
}
329329

330-
/// Retrieve the pinned connection. Will fail if the connection has been unpinned or is still in
331-
/// use.
330+
/// Retrieve the pinned connection. Will fail if the connection has been unpinned or is still
331+
/// in use.
332332
pub(crate) async fn take_connection(&self) -> Result<PooledConnection> {
333333
use tokio::sync::mpsc::error::TryRecvError;
334334
let mut receiver = self.receiver.lock().await;

0 commit comments

Comments
 (0)