@@ -15,8 +15,15 @@ import { CursorTimeoutContext } from '../cursor/abstract_cursor';
15
15
import { getSocks , type SocksLib } from '../deps' ;
16
16
import { MongoOperationTimeoutError } from '../error' ;
17
17
import { type MongoClient , type MongoClientOptions } from '../mongo_client' ;
18
+ import { type Abortable } from '../mongo_types' ;
18
19
import { Timeout , type TimeoutContext , TimeoutError } from '../timeout' ;
19
- import { BufferPool , MongoDBCollectionNamespace , promiseWithResolvers } from '../utils' ;
20
+ import {
21
+ addAbortListener ,
22
+ BufferPool ,
23
+ kDispose ,
24
+ MongoDBCollectionNamespace ,
25
+ promiseWithResolvers
26
+ } from '../utils' ;
20
27
import { autoSelectSocketOptions , type DataKey } from './client_encryption' ;
21
28
import { MongoCryptError } from './errors' ;
22
29
import { type MongocryptdManager } from './mongocryptd_manager' ;
@@ -189,7 +196,7 @@ export class StateMachine {
189
196
async execute (
190
197
executor : StateMachineExecutable ,
191
198
context : MongoCryptContext ,
192
- timeoutContext ?: TimeoutContext
199
+ options : { timeoutContext ?: TimeoutContext } & Abortable
193
200
) : Promise < Uint8Array > {
194
201
const keyVaultNamespace = executor . _keyVaultNamespace ;
195
202
const keyVaultClient = executor . _keyVaultClient ;
@@ -199,6 +206,7 @@ export class StateMachine {
199
206
let result : Uint8Array | null = null ;
200
207
201
208
while ( context . state !== MONGOCRYPT_CTX_DONE && context . state !== MONGOCRYPT_CTX_ERROR ) {
209
+ options . signal ?. throwIfAborted ( ) ;
202
210
debug ( `[context#${ context . id } ] ${ stateToString . get ( context . state ) || context . state } ` ) ;
203
211
204
212
switch ( context . state ) {
@@ -214,7 +222,7 @@ export class StateMachine {
214
222
metaDataClient ,
215
223
context . ns ,
216
224
filter ,
217
- timeoutContext
225
+ options
218
226
) ;
219
227
if ( collInfo ) {
220
228
context . addMongoOperationResponse ( collInfo ) ;
@@ -235,9 +243,9 @@ export class StateMachine {
235
243
// When we are using the shared library, we don't have a mongocryptd manager.
236
244
const markedCommand : Uint8Array = mongocryptdManager
237
245
? await mongocryptdManager . withRespawn (
238
- this . markCommand . bind ( this , mongocryptdClient , context . ns , command , timeoutContext )
246
+ this . markCommand . bind ( this , mongocryptdClient , context . ns , command , options )
239
247
)
240
- : await this . markCommand ( mongocryptdClient , context . ns , command , timeoutContext ) ;
248
+ : await this . markCommand ( mongocryptdClient , context . ns , command , options ) ;
241
249
242
250
context . addMongoOperationResponse ( markedCommand ) ;
243
251
context . finishMongoOperation ( ) ;
@@ -246,12 +254,7 @@ export class StateMachine {
246
254
247
255
case MONGOCRYPT_CTX_NEED_MONGO_KEYS : {
248
256
const filter = context . nextMongoOperation ( ) ;
249
- const keys = await this . fetchKeys (
250
- keyVaultClient ,
251
- keyVaultNamespace ,
252
- filter ,
253
- timeoutContext
254
- ) ;
257
+ const keys = await this . fetchKeys ( keyVaultClient , keyVaultNamespace , filter , options ) ;
255
258
256
259
if ( keys . length === 0 ) {
257
260
// See docs on EMPTY_V
@@ -273,7 +276,7 @@ export class StateMachine {
273
276
}
274
277
275
278
case MONGOCRYPT_CTX_NEED_KMS : {
276
- await Promise . all ( this . requests ( context , timeoutContext ) ) ;
279
+ await Promise . all ( this . requests ( context , options ) ) ;
277
280
context . finishKMSRequests ( ) ;
278
281
break ;
279
282
}
@@ -315,11 +318,13 @@ export class StateMachine {
315
318
* @param kmsContext - A C++ KMS context returned from the bindings
316
319
* @returns A promise that resolves when the KMS reply has be fully parsed
317
320
*/
318
- async kmsRequest ( request : MongoCryptKMSRequest , timeoutContext ?: TimeoutContext ) : Promise < void > {
321
+ async kmsRequest (
322
+ request : MongoCryptKMSRequest ,
323
+ options ?: { timeoutContext ?: TimeoutContext } & Abortable
324
+ ) : Promise < void > {
319
325
const parsedUrl = request . endpoint . split ( ':' ) ;
320
326
const port = parsedUrl [ 1 ] != null ? Number . parseInt ( parsedUrl [ 1 ] , 10 ) : HTTPS_PORT ;
321
- const socketOptions = autoSelectSocketOptions ( this . options . socketOptions || { } ) ;
322
- const options : tls . ConnectionOptions & {
327
+ const socketOptions : tls . ConnectionOptions & {
323
328
host : string ;
324
329
port : number ;
325
330
autoSelectFamily ?: boolean ;
@@ -328,7 +333,7 @@ export class StateMachine {
328
333
host : parsedUrl [ 0 ] ,
329
334
servername : parsedUrl [ 0 ] ,
330
335
port,
331
- ...socketOptions
336
+ ...autoSelectSocketOptions ( this . options . socketOptions || { } )
332
337
} ;
333
338
const message = request . message ;
334
339
const buffer = new BufferPool ( ) ;
@@ -363,7 +368,7 @@ export class StateMachine {
363
368
throw error ;
364
369
}
365
370
try {
366
- await this . setTlsOptions ( providerTlsOptions , options ) ;
371
+ await this . setTlsOptions ( providerTlsOptions , socketOptions ) ;
367
372
} catch ( err ) {
368
373
throw onerror ( err ) ;
369
374
}
@@ -380,23 +385,25 @@ export class StateMachine {
380
385
. once ( 'close' , ( ) => rejectOnNetSocketError ( onclose ( ) ) )
381
386
. once ( 'connect' , ( ) => resolveOnNetSocketConnect ( ) ) ;
382
387
388
+ let abortListener ;
389
+
383
390
try {
384
391
if ( this . options . proxyOptions && this . options . proxyOptions . proxyHost ) {
385
392
const netSocketOptions = {
393
+ ...socketOptions ,
386
394
host : this . options . proxyOptions . proxyHost ,
387
- port : this . options . proxyOptions . proxyPort || 1080 ,
388
- ...socketOptions
395
+ port : this . options . proxyOptions . proxyPort || 1080
389
396
} ;
390
397
netSocket . connect ( netSocketOptions ) ;
391
398
await willConnect ;
392
399
393
400
try {
394
401
socks ??= loadSocks ( ) ;
395
- options . socket = (
402
+ socketOptions . socket = (
396
403
await socks . SocksClient . createConnection ( {
397
404
existing_socket : netSocket ,
398
405
command : 'connect' ,
399
- destination : { host : options . host , port : options . port } ,
406
+ destination : { host : socketOptions . host , port : socketOptions . port } ,
400
407
proxy : {
401
408
// host and port are ignored because we pass existing_socket
402
409
host : 'iLoveJavaScript' ,
@@ -412,7 +419,7 @@ export class StateMachine {
412
419
}
413
420
}
414
421
415
- socket = tls . connect ( options , ( ) => {
422
+ socket = tls . connect ( socketOptions , ( ) => {
416
423
socket . write ( message ) ;
417
424
} ) ;
418
425
@@ -422,6 +429,11 @@ export class StateMachine {
422
429
resolve
423
430
} = promiseWithResolvers < void > ( ) ;
424
431
432
+ abortListener = addAbortListener ( options ?. signal , function ( ) {
433
+ destroySockets ( ) ;
434
+ rejectOnTlsSocketError ( this . reason ) ;
435
+ } ) ;
436
+
425
437
socket
426
438
. once ( 'error' , err => rejectOnTlsSocketError ( onerror ( err ) ) )
427
439
. once ( 'close' , ( ) => rejectOnTlsSocketError ( onclose ( ) ) )
@@ -436,8 +448,11 @@ export class StateMachine {
436
448
resolve ( ) ;
437
449
}
438
450
} ) ;
439
- await ( timeoutContext ?. csotEnabled ( )
440
- ? Promise . all ( [ willResolveKmsRequest , Timeout . expires ( timeoutContext ?. remainingTimeMS ) ] )
451
+ await ( options ?. timeoutContext ?. csotEnabled ( )
452
+ ? Promise . all ( [
453
+ willResolveKmsRequest ,
454
+ Timeout . expires ( options . timeoutContext ?. remainingTimeMS )
455
+ ] )
441
456
: willResolveKmsRequest ) ;
442
457
} catch ( error ) {
443
458
if ( error instanceof TimeoutError )
@@ -446,16 +461,17 @@ export class StateMachine {
446
461
} finally {
447
462
// There's no need for any more activity on this socket at this point.
448
463
destroySockets ( ) ;
464
+ abortListener ?. [ kDispose ] ( ) ;
449
465
}
450
466
}
451
467
452
- * requests ( context : MongoCryptContext , timeoutContext ?: TimeoutContext ) {
468
+ * requests ( context : MongoCryptContext , options ?: { timeoutContext ?: TimeoutContext } & Abortable ) {
453
469
for (
454
470
let request = context . nextKMSRequest ( ) ;
455
471
request != null ;
456
472
request = context . nextKMSRequest ( )
457
473
) {
458
- yield this . kmsRequest ( request , timeoutContext ) ;
474
+ yield this . kmsRequest ( request , options ) ;
459
475
}
460
476
}
461
477
@@ -516,14 +532,16 @@ export class StateMachine {
516
532
client : MongoClient ,
517
533
ns : string ,
518
534
filter : Document ,
519
- timeoutContext ?: TimeoutContext
535
+ options ?: { timeoutContext ?: TimeoutContext } & Abortable
520
536
) : Promise < Uint8Array | null > {
521
537
const { db } = MongoDBCollectionNamespace . fromString ( ns ) ;
522
538
523
539
const cursor = client . db ( db ) . listCollections ( filter , {
524
540
promoteLongs : false ,
525
541
promoteValues : false ,
526
- timeoutContext : timeoutContext && new CursorTimeoutContext ( timeoutContext , Symbol ( ) )
542
+ timeoutContext :
543
+ options ?. timeoutContext && new CursorTimeoutContext ( options ?. timeoutContext , Symbol ( ) ) ,
544
+ signal : options ?. signal
527
545
} ) ;
528
546
529
547
// There is always exactly zero or one matching documents, so this should always exhaust the cursor
@@ -547,17 +565,30 @@ export class StateMachine {
547
565
client : MongoClient ,
548
566
ns : string ,
549
567
command : Uint8Array ,
550
- timeoutContext ?: TimeoutContext
568
+ options ?: { timeoutContext ?: TimeoutContext } & Abortable
551
569
) : Promise < Uint8Array > {
552
570
const { db } = MongoDBCollectionNamespace . fromString ( ns ) ;
553
571
const bsonOptions = { promoteLongs : false , promoteValues : false } ;
554
572
const rawCommand = deserialize ( command , bsonOptions ) ;
555
573
574
+ const commandOptions : {
575
+ timeoutMS ?: number ;
576
+ signal ?: AbortSignal ;
577
+ } = {
578
+ timeoutMS : undefined ,
579
+ signal : undefined
580
+ } ;
581
+
582
+ if ( options ?. timeoutContext ?. csotEnabled ( ) ) {
583
+ commandOptions . timeoutMS = options . timeoutContext . remainingTimeMS ;
584
+ }
585
+ if ( options ?. signal ) {
586
+ commandOptions . signal = options . signal ;
587
+ }
588
+
556
589
const response = await client . db ( db ) . command ( rawCommand , {
557
590
...bsonOptions ,
558
- ...( timeoutContext ?. csotEnabled ( )
559
- ? { timeoutMS : timeoutContext ?. remainingTimeMS }
560
- : undefined )
591
+ ...commandOptions
561
592
} ) ;
562
593
563
594
return serialize ( response , this . bsonOptions ) ;
@@ -575,17 +606,30 @@ export class StateMachine {
575
606
client : MongoClient ,
576
607
keyVaultNamespace : string ,
577
608
filter : Uint8Array ,
578
- timeoutContext ?: TimeoutContext
609
+ options ?: { timeoutContext ?: TimeoutContext } & Abortable
579
610
) : Promise < Array < DataKey > > {
580
611
const { db : dbName , collection : collectionName } =
581
612
MongoDBCollectionNamespace . fromString ( keyVaultNamespace ) ;
582
613
614
+ const commandOptions : {
615
+ timeoutContext ?: CursorTimeoutContext ;
616
+ signal ?: AbortSignal ;
617
+ } = {
618
+ timeoutContext : undefined ,
619
+ signal : undefined
620
+ } ;
621
+
622
+ if ( options ?. timeoutContext != null ) {
623
+ commandOptions . timeoutContext = new CursorTimeoutContext ( options . timeoutContext , Symbol ( ) ) ;
624
+ }
625
+ if ( options ?. signal != null ) {
626
+ commandOptions . signal = options . signal ;
627
+ }
628
+
583
629
return client
584
630
. db ( dbName )
585
631
. collection < DataKey > ( collectionName , { readConcern : { level : 'majority' } } )
586
- . find ( deserialize ( filter ) , {
587
- timeoutContext : timeoutContext && new CursorTimeoutContext ( timeoutContext , Symbol ( ) )
588
- } )
632
+ . find ( deserialize ( filter ) , commandOptions )
589
633
. toArray ( ) ;
590
634
}
591
635
}
0 commit comments