Skip to content

Commit 74673f8

Browse files
Merge branch 'main' into NODE-6626/server-side
2 parents 689706b + 654069f commit 74673f8

33 files changed

+2038
-184
lines changed

src/client-side-encryption/auto_encrypter.ts

+12-8
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import { kDecorateResult } from '../constants';
1111
import { getMongoDBClientEncryption } from '../deps';
1212
import { MongoRuntimeError } from '../error';
1313
import { MongoClient, type MongoClientOptions } from '../mongo_client';
14+
import { type Abortable } from '../mongo_types';
1415
import { MongoDBCollectionNamespace } from '../utils';
1516
import { autoSelectSocketOptions } from './client_encryption';
1617
import * as cryptoCallbacks from './crypto_callbacks';
@@ -372,8 +373,10 @@ export class AutoEncrypter {
372373
async encrypt(
373374
ns: string,
374375
cmd: Document,
375-
options: CommandOptions = {}
376+
options: CommandOptions & Abortable = {}
376377
): Promise<Document | Uint8Array> {
378+
options.signal?.throwIfAborted();
379+
377380
if (this._bypassEncryption) {
378381
// If `bypassAutoEncryption` has been specified, don't encrypt
379382
return cmd;
@@ -398,7 +401,7 @@ export class AutoEncrypter {
398401
socketOptions: autoSelectSocketOptions(this._client.s.options)
399402
});
400403

401-
return deserialize(await stateMachine.execute(this, context, options.timeoutContext), {
404+
return deserialize(await stateMachine.execute(this, context, options), {
402405
promoteValues: false,
403406
promoteLongs: false
404407
});
@@ -407,7 +410,12 @@ export class AutoEncrypter {
407410
/**
408411
* Decrypt a command response
409412
*/
410-
async decrypt(response: Uint8Array, options: CommandOptions = {}): Promise<Uint8Array> {
413+
async decrypt(
414+
response: Uint8Array,
415+
options: CommandOptions & Abortable = {}
416+
): Promise<Uint8Array> {
417+
options.signal?.throwIfAborted();
418+
411419
const context = this._mongocrypt.makeDecryptionContext(response);
412420

413421
context.id = this._contextCounter++;
@@ -419,11 +427,7 @@ export class AutoEncrypter {
419427
socketOptions: autoSelectSocketOptions(this._client.s.options)
420428
});
421429

422-
return await stateMachine.execute(
423-
this,
424-
context,
425-
options.timeoutContext?.csotEnabled() ? options.timeoutContext : undefined
426-
);
430+
return await stateMachine.execute(this, context, options);
427431
}
428432

429433
/**

src/client-side-encryption/client_encryption.ts

+6-4
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ export class ClientEncryption {
225225
TimeoutContext.create(resolveTimeoutOptions(this._client, { timeoutMS: this._timeoutMS }));
226226

227227
const dataKey = deserialize(
228-
await stateMachine.execute(this, context, timeoutContext)
228+
await stateMachine.execute(this, context, { timeoutContext })
229229
) as DataKey;
230230

231231
const { db: dbName, collection: collectionName } = MongoDBCollectionNamespace.fromString(
@@ -293,7 +293,9 @@ export class ClientEncryption {
293293
resolveTimeoutOptions(this._client, { timeoutMS: this._timeoutMS })
294294
);
295295

296-
const { v: dataKeys } = deserialize(await stateMachine.execute(this, context, timeoutContext));
296+
const { v: dataKeys } = deserialize(
297+
await stateMachine.execute(this, context, { timeoutContext })
298+
);
297299
if (dataKeys.length === 0) {
298300
return {};
299301
}
@@ -696,7 +698,7 @@ export class ClientEncryption {
696698
? TimeoutContext.create(resolveTimeoutOptions(this._client, { timeoutMS: this._timeoutMS }))
697699
: undefined;
698700

699-
const { v } = deserialize(await stateMachine.execute(this, context, timeoutContext));
701+
const { v } = deserialize(await stateMachine.execute(this, context, { timeoutContext }));
700702

701703
return v;
702704
}
@@ -780,7 +782,7 @@ export class ClientEncryption {
780782
this._timeoutMS != null
781783
? TimeoutContext.create(resolveTimeoutOptions(this._client, { timeoutMS: this._timeoutMS }))
782784
: undefined;
783-
const { v } = deserialize(await stateMachine.execute(this, context, timeoutContext));
785+
const { v } = deserialize(await stateMachine.execute(this, context, { timeoutContext }));
784786
return v;
785787
}
786788
}

src/client-side-encryption/state_machine.ts

+80-36
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,15 @@ import { CursorTimeoutContext } from '../cursor/abstract_cursor';
1515
import { getSocks, type SocksLib } from '../deps';
1616
import { MongoOperationTimeoutError } from '../error';
1717
import { type MongoClient, type MongoClientOptions } from '../mongo_client';
18+
import { type Abortable } from '../mongo_types';
1819
import { Timeout, type TimeoutContext, TimeoutError } from '../timeout';
19-
import { BufferPool, MongoDBCollectionNamespace, promiseWithResolvers } from '../utils';
20+
import {
21+
addAbortListener,
22+
BufferPool,
23+
kDispose,
24+
MongoDBCollectionNamespace,
25+
promiseWithResolvers
26+
} from '../utils';
2027
import { autoSelectSocketOptions, type DataKey } from './client_encryption';
2128
import { MongoCryptError } from './errors';
2229
import { type MongocryptdManager } from './mongocryptd_manager';
@@ -189,7 +196,7 @@ export class StateMachine {
189196
async execute(
190197
executor: StateMachineExecutable,
191198
context: MongoCryptContext,
192-
timeoutContext?: TimeoutContext
199+
options: { timeoutContext?: TimeoutContext } & Abortable
193200
): Promise<Uint8Array> {
194201
const keyVaultNamespace = executor._keyVaultNamespace;
195202
const keyVaultClient = executor._keyVaultClient;
@@ -199,6 +206,7 @@ export class StateMachine {
199206
let result: Uint8Array | null = null;
200207

201208
while (context.state !== MONGOCRYPT_CTX_DONE && context.state !== MONGOCRYPT_CTX_ERROR) {
209+
options.signal?.throwIfAborted();
202210
debug(`[context#${context.id}] ${stateToString.get(context.state) || context.state}`);
203211

204212
switch (context.state) {
@@ -214,7 +222,7 @@ export class StateMachine {
214222
metaDataClient,
215223
context.ns,
216224
filter,
217-
timeoutContext
225+
options
218226
);
219227
if (collInfo) {
220228
context.addMongoOperationResponse(collInfo);
@@ -235,9 +243,9 @@ export class StateMachine {
235243
// When we are using the shared library, we don't have a mongocryptd manager.
236244
const markedCommand: Uint8Array = mongocryptdManager
237245
? await mongocryptdManager.withRespawn(
238-
this.markCommand.bind(this, mongocryptdClient, context.ns, command, timeoutContext)
246+
this.markCommand.bind(this, mongocryptdClient, context.ns, command, options)
239247
)
240-
: await this.markCommand(mongocryptdClient, context.ns, command, timeoutContext);
248+
: await this.markCommand(mongocryptdClient, context.ns, command, options);
241249

242250
context.addMongoOperationResponse(markedCommand);
243251
context.finishMongoOperation();
@@ -246,12 +254,7 @@ export class StateMachine {
246254

247255
case MONGOCRYPT_CTX_NEED_MONGO_KEYS: {
248256
const filter = context.nextMongoOperation();
249-
const keys = await this.fetchKeys(
250-
keyVaultClient,
251-
keyVaultNamespace,
252-
filter,
253-
timeoutContext
254-
);
257+
const keys = await this.fetchKeys(keyVaultClient, keyVaultNamespace, filter, options);
255258

256259
if (keys.length === 0) {
257260
// See docs on EMPTY_V
@@ -273,7 +276,7 @@ export class StateMachine {
273276
}
274277

275278
case MONGOCRYPT_CTX_NEED_KMS: {
276-
await Promise.all(this.requests(context, timeoutContext));
279+
await Promise.all(this.requests(context, options));
277280
context.finishKMSRequests();
278281
break;
279282
}
@@ -315,11 +318,13 @@ export class StateMachine {
315318
* @param kmsContext - A C++ KMS context returned from the bindings
316319
* @returns A promise that resolves when the KMS reply has be fully parsed
317320
*/
318-
async kmsRequest(request: MongoCryptKMSRequest, timeoutContext?: TimeoutContext): Promise<void> {
321+
async kmsRequest(
322+
request: MongoCryptKMSRequest,
323+
options?: { timeoutContext?: TimeoutContext } & Abortable
324+
): Promise<void> {
319325
const parsedUrl = request.endpoint.split(':');
320326
const port = parsedUrl[1] != null ? Number.parseInt(parsedUrl[1], 10) : HTTPS_PORT;
321-
const socketOptions = autoSelectSocketOptions(this.options.socketOptions || {});
322-
const options: tls.ConnectionOptions & {
327+
const socketOptions: tls.ConnectionOptions & {
323328
host: string;
324329
port: number;
325330
autoSelectFamily?: boolean;
@@ -328,7 +333,7 @@ export class StateMachine {
328333
host: parsedUrl[0],
329334
servername: parsedUrl[0],
330335
port,
331-
...socketOptions
336+
...autoSelectSocketOptions(this.options.socketOptions || {})
332337
};
333338
const message = request.message;
334339
const buffer = new BufferPool();
@@ -363,7 +368,7 @@ export class StateMachine {
363368
throw error;
364369
}
365370
try {
366-
await this.setTlsOptions(providerTlsOptions, options);
371+
await this.setTlsOptions(providerTlsOptions, socketOptions);
367372
} catch (err) {
368373
throw onerror(err);
369374
}
@@ -380,23 +385,25 @@ export class StateMachine {
380385
.once('close', () => rejectOnNetSocketError(onclose()))
381386
.once('connect', () => resolveOnNetSocketConnect());
382387

388+
let abortListener;
389+
383390
try {
384391
if (this.options.proxyOptions && this.options.proxyOptions.proxyHost) {
385392
const netSocketOptions = {
393+
...socketOptions,
386394
host: this.options.proxyOptions.proxyHost,
387-
port: this.options.proxyOptions.proxyPort || 1080,
388-
...socketOptions
395+
port: this.options.proxyOptions.proxyPort || 1080
389396
};
390397
netSocket.connect(netSocketOptions);
391398
await willConnect;
392399

393400
try {
394401
socks ??= loadSocks();
395-
options.socket = (
402+
socketOptions.socket = (
396403
await socks.SocksClient.createConnection({
397404
existing_socket: netSocket,
398405
command: 'connect',
399-
destination: { host: options.host, port: options.port },
406+
destination: { host: socketOptions.host, port: socketOptions.port },
400407
proxy: {
401408
// host and port are ignored because we pass existing_socket
402409
host: 'iLoveJavaScript',
@@ -412,7 +419,7 @@ export class StateMachine {
412419
}
413420
}
414421

415-
socket = tls.connect(options, () => {
422+
socket = tls.connect(socketOptions, () => {
416423
socket.write(message);
417424
});
418425

@@ -422,6 +429,11 @@ export class StateMachine {
422429
resolve
423430
} = promiseWithResolvers<void>();
424431

432+
abortListener = addAbortListener(options?.signal, function () {
433+
destroySockets();
434+
rejectOnTlsSocketError(this.reason);
435+
});
436+
425437
socket
426438
.once('error', err => rejectOnTlsSocketError(onerror(err)))
427439
.once('close', () => rejectOnTlsSocketError(onclose()))
@@ -436,8 +448,11 @@ export class StateMachine {
436448
resolve();
437449
}
438450
});
439-
await (timeoutContext?.csotEnabled()
440-
? Promise.all([willResolveKmsRequest, Timeout.expires(timeoutContext?.remainingTimeMS)])
451+
await (options?.timeoutContext?.csotEnabled()
452+
? Promise.all([
453+
willResolveKmsRequest,
454+
Timeout.expires(options.timeoutContext?.remainingTimeMS)
455+
])
441456
: willResolveKmsRequest);
442457
} catch (error) {
443458
if (error instanceof TimeoutError)
@@ -446,16 +461,17 @@ export class StateMachine {
446461
} finally {
447462
// There's no need for any more activity on this socket at this point.
448463
destroySockets();
464+
abortListener?.[kDispose]();
449465
}
450466
}
451467

452-
*requests(context: MongoCryptContext, timeoutContext?: TimeoutContext) {
468+
*requests(context: MongoCryptContext, options?: { timeoutContext?: TimeoutContext } & Abortable) {
453469
for (
454470
let request = context.nextKMSRequest();
455471
request != null;
456472
request = context.nextKMSRequest()
457473
) {
458-
yield this.kmsRequest(request, timeoutContext);
474+
yield this.kmsRequest(request, options);
459475
}
460476
}
461477

@@ -516,14 +532,16 @@ export class StateMachine {
516532
client: MongoClient,
517533
ns: string,
518534
filter: Document,
519-
timeoutContext?: TimeoutContext
535+
options?: { timeoutContext?: TimeoutContext } & Abortable
520536
): Promise<Uint8Array | null> {
521537
const { db } = MongoDBCollectionNamespace.fromString(ns);
522538

523539
const cursor = client.db(db).listCollections(filter, {
524540
promoteLongs: false,
525541
promoteValues: false,
526-
timeoutContext: timeoutContext && new CursorTimeoutContext(timeoutContext, Symbol())
542+
timeoutContext:
543+
options?.timeoutContext && new CursorTimeoutContext(options?.timeoutContext, Symbol()),
544+
signal: options?.signal
527545
});
528546

529547
// There is always exactly zero or one matching documents, so this should always exhaust the cursor
@@ -547,17 +565,30 @@ export class StateMachine {
547565
client: MongoClient,
548566
ns: string,
549567
command: Uint8Array,
550-
timeoutContext?: TimeoutContext
568+
options?: { timeoutContext?: TimeoutContext } & Abortable
551569
): Promise<Uint8Array> {
552570
const { db } = MongoDBCollectionNamespace.fromString(ns);
553571
const bsonOptions = { promoteLongs: false, promoteValues: false };
554572
const rawCommand = deserialize(command, bsonOptions);
555573

574+
const commandOptions: {
575+
timeoutMS?: number;
576+
signal?: AbortSignal;
577+
} = {
578+
timeoutMS: undefined,
579+
signal: undefined
580+
};
581+
582+
if (options?.timeoutContext?.csotEnabled()) {
583+
commandOptions.timeoutMS = options.timeoutContext.remainingTimeMS;
584+
}
585+
if (options?.signal) {
586+
commandOptions.signal = options.signal;
587+
}
588+
556589
const response = await client.db(db).command(rawCommand, {
557590
...bsonOptions,
558-
...(timeoutContext?.csotEnabled()
559-
? { timeoutMS: timeoutContext?.remainingTimeMS }
560-
: undefined)
591+
...commandOptions
561592
});
562593

563594
return serialize(response, this.bsonOptions);
@@ -575,17 +606,30 @@ export class StateMachine {
575606
client: MongoClient,
576607
keyVaultNamespace: string,
577608
filter: Uint8Array,
578-
timeoutContext?: TimeoutContext
609+
options?: { timeoutContext?: TimeoutContext } & Abortable
579610
): Promise<Array<DataKey>> {
580611
const { db: dbName, collection: collectionName } =
581612
MongoDBCollectionNamespace.fromString(keyVaultNamespace);
582613

614+
const commandOptions: {
615+
timeoutContext?: CursorTimeoutContext;
616+
signal?: AbortSignal;
617+
} = {
618+
timeoutContext: undefined,
619+
signal: undefined
620+
};
621+
622+
if (options?.timeoutContext != null) {
623+
commandOptions.timeoutContext = new CursorTimeoutContext(options.timeoutContext, Symbol());
624+
}
625+
if (options?.signal != null) {
626+
commandOptions.signal = options.signal;
627+
}
628+
583629
return client
584630
.db(dbName)
585631
.collection<DataKey>(collectionName, { readConcern: { level: 'majority' } })
586-
.find(deserialize(filter), {
587-
timeoutContext: timeoutContext && new CursorTimeoutContext(timeoutContext, Symbol())
588-
})
632+
.find(deserialize(filter), commandOptions)
589633
.toArray();
590634
}
591635
}

0 commit comments

Comments
 (0)