diff --git a/src/client-side-encryption/auto_encrypter.ts b/src/client-side-encryption/auto_encrypter.ts index 3066b8d1f0b..2ca163390bf 100644 --- a/src/client-side-encryption/auto_encrypter.ts +++ b/src/client-side-encryption/auto_encrypter.ts @@ -15,7 +15,11 @@ import * as cryptoCallbacks from './crypto_callbacks'; import { MongoCryptInvalidArgumentError } from './errors'; import { MongocryptdManager } from './mongocryptd_manager'; import { type KMSProviders, refreshKMSCredentials } from './providers'; -import { type CSFLEKMSTlsOptions, StateMachine } from './state_machine'; +import { + type ClientEncryptionSocketOptions, + type CSFLEKMSTlsOptions, + StateMachine +} from './state_machine'; /** @public */ export interface AutoEncryptionOptions { @@ -101,6 +105,8 @@ export interface AutoEncryptionOptions { proxyOptions?: ProxyOptions; /** The TLS options to use connecting to the KMS provider */ tlsOptions?: CSFLEKMSTlsOptions; + /** Options for KMS socket requests. */ + socketOptions?: ClientEncryptionSocketOptions; } /** @@ -150,6 +156,7 @@ export class AutoEncrypter { _kmsProviders: KMSProviders; _bypassMongocryptdAndCryptShared: boolean; _contextCounter: number; + _socketOptions: ClientEncryptionSocketOptions; _mongocryptdManager?: MongocryptdManager; _mongocryptdClient?: MongoClient; @@ -234,6 +241,7 @@ export class AutoEncrypter { this._proxyOptions = options.proxyOptions || {}; this._tlsOptions = options.tlsOptions || {}; this._kmsProviders = options.kmsProviders || {}; + this._socketOptions = options.socketOptions || {}; const mongoCryptOptions: MongoCryptOptions = { cryptoCallbacks @@ -379,7 +387,8 @@ export class AutoEncrypter { promoteValues: false, promoteLongs: false, proxyOptions: this._proxyOptions, - tlsOptions: this._tlsOptions + tlsOptions: this._tlsOptions, + socketOptions: this._socketOptions }); return deserialize(await stateMachine.execute(this, context), { @@ -399,7 +408,8 @@ export class AutoEncrypter { const stateMachine = new StateMachine({ ...options, proxyOptions: this._proxyOptions, - tlsOptions: this._tlsOptions + tlsOptions: this._tlsOptions, + socketOptions: this._socketOptions }); return await stateMachine.execute(this, context); diff --git a/src/client-side-encryption/client_encryption.ts b/src/client-side-encryption/client_encryption.ts index b78b3549962..1ce33279eaf 100644 --- a/src/client-side-encryption/client_encryption.ts +++ b/src/client-side-encryption/client_encryption.ts @@ -28,7 +28,11 @@ import { type KMSProviders, refreshKMSCredentials } from './providers/index'; -import { type CSFLEKMSTlsOptions, StateMachine } from './state_machine'; +import { + type ClientEncryptionSocketOptions, + type CSFLEKMSTlsOptions, + StateMachine +} from './state_machine'; /** * @public @@ -62,6 +66,8 @@ export class ClientEncryption { _tlsOptions: CSFLEKMSTlsOptions; /** @internal */ _kmsProviders: KMSProviders; + /** @internal */ + _socketOptions: ClientEncryptionSocketOptions; /** @internal */ _mongoCrypt: MongoCrypt; @@ -108,6 +114,15 @@ export class ClientEncryption { this._proxyOptions = options.proxyOptions ?? {}; this._tlsOptions = options.tlsOptions ?? {}; this._kmsProviders = options.kmsProviders || {}; + this._socketOptions = {}; + + if ('autoSelectFamily' in client.s.options) { + this._socketOptions.autoSelectFamily = client.s.options.autoSelectFamily; + } + if ('autoSelectFamilyAttemptTimeout' in client.s.options) { + this._socketOptions.autoSelectFamilyAttemptTimeout = + client.s.options.autoSelectFamilyAttemptTimeout; + } if (options.keyVaultNamespace == null) { throw new MongoCryptInvalidArgumentError('Missing required option `keyVaultNamespace`'); @@ -199,7 +214,8 @@ export class ClientEncryption { const stateMachine = new StateMachine({ proxyOptions: this._proxyOptions, - tlsOptions: this._tlsOptions + tlsOptions: this._tlsOptions, + socketOptions: this._socketOptions }); const dataKey = deserialize(await stateMachine.execute(this, context)) as DataKey; @@ -256,7 +272,8 @@ export class ClientEncryption { const context = this._mongoCrypt.makeRewrapManyDataKeyContext(filterBson, keyEncryptionKeyBson); const stateMachine = new StateMachine({ proxyOptions: this._proxyOptions, - tlsOptions: this._tlsOptions + tlsOptions: this._tlsOptions, + socketOptions: this._socketOptions }); const { v: dataKeys } = deserialize(await stateMachine.execute(this, context)); @@ -637,7 +654,8 @@ export class ClientEncryption { const stateMachine = new StateMachine({ proxyOptions: this._proxyOptions, - tlsOptions: this._tlsOptions + tlsOptions: this._tlsOptions, + socketOptions: this._socketOptions }); const { v } = deserialize(await stateMachine.execute(this, context)); @@ -715,7 +733,8 @@ export class ClientEncryption { const valueBuffer = serialize({ v: value }); const stateMachine = new StateMachine({ proxyOptions: this._proxyOptions, - tlsOptions: this._tlsOptions + tlsOptions: this._tlsOptions, + socketOptions: this._socketOptions }); const context = this._mongoCrypt.makeExplicitEncryptionContext(valueBuffer, contextOptions); diff --git a/src/client-side-encryption/state_machine.ts b/src/client-side-encryption/state_machine.ts index f0ae19546aa..fd21fd4f3b9 100644 --- a/src/client-side-encryption/state_machine.ts +++ b/src/client-side-encryption/state_machine.ts @@ -114,6 +114,16 @@ export type CSFLEKMSTlsOptions = { [key: string]: ClientEncryptionTlsOptions | undefined; }; +/** + * @public + * + * Socket options to use for KMS requests. + */ +export type ClientEncryptionSocketOptions = Pick< + MongoClientOptions, + 'autoSelectFamily' | 'autoSelectFamilyAttemptTimeout' +>; + /** * This is kind of a hack. For `rewrapManyDataKey`, we have tests that * guarantee that when there are no matching keys, `rewrapManyDataKey` returns @@ -153,6 +163,9 @@ export type StateMachineOptions = { /** TLS options for KMS requests, if set. */ tlsOptions: CSFLEKMSTlsOptions; + + /** Socket specific options we support. */ + socketOptions: ClientEncryptionSocketOptions; } & Pick; /** @@ -289,7 +302,12 @@ export class StateMachine { async kmsRequest(request: MongoCryptKMSRequest): Promise { const parsedUrl = request.endpoint.split(':'); const port = parsedUrl[1] != null ? Number.parseInt(parsedUrl[1], 10) : HTTPS_PORT; - const options: tls.ConnectionOptions & { host: string; port: number } = { + const options: tls.ConnectionOptions & { + host: string; + port: number; + autoSelectFamily?: boolean; + autoSelectFamilyAttemptTimeout?: number; + } = { host: parsedUrl[0], servername: parsedUrl[0], port @@ -297,6 +315,14 @@ export class StateMachine { const message = request.message; const buffer = new BufferPool(); + const socketOptions = this.options.socketOptions || {}; + if ('autoSelectFamily' in socketOptions) { + options.autoSelectFamily = socketOptions.autoSelectFamily; + } + if ('autoSelectFamilyAttemptTimeout' in socketOptions) { + options.autoSelectFamilyAttemptTimeout = socketOptions.autoSelectFamilyAttemptTimeout; + } + const netSocket: net.Socket = new net.Socket(); let socket: tls.TLSSocket; diff --git a/src/encrypter.ts b/src/encrypter.ts index fbcf7c195d9..7ebda9e61fe 100644 --- a/src/encrypter.ts +++ b/src/encrypter.ts @@ -1,6 +1,7 @@ import { callbackify } from 'util'; import { AutoEncrypter, type AutoEncryptionOptions } from './client-side-encryption/auto_encrypter'; +import { type ClientEncryptionSocketOptions } from './client-side-encryption/state_machine'; import { MONGO_CLIENT_EVENTS } from './constants'; import { getMongoDBClientEncryption } from './deps'; import { MongoInvalidArgumentError, MongoMissingDependencyError } from './error'; @@ -56,6 +57,15 @@ export class Encrypter { }; } + const socketOptions: ClientEncryptionSocketOptions = {}; + if ('autoSelectFamily' in options) { + socketOptions.autoSelectFamily = options.autoSelectFamily; + } + if ('autoSelectFamilyAttemptTimeout' in options) { + socketOptions.autoSelectFamilyAttemptTimeout = options.autoSelectFamilyAttemptTimeout; + } + options.autoEncryption.socketOptions = socketOptions; + this.autoEncrypter = new AutoEncrypter(client, options.autoEncryption); } diff --git a/src/index.ts b/src/index.ts index 0ba8f82c01b..efd0b9d0550 100644 --- a/src/index.ts +++ b/src/index.ts @@ -248,6 +248,7 @@ export type { LocalKMSProviderConfiguration } from './client-side-encryption/providers/index'; export type { + ClientEncryptionSocketOptions, ClientEncryptionTlsOptions, CSFLEKMSTlsOptions, StateMachineExecutable diff --git a/test/unit/client-side-encryption/client_encryption.test.ts b/test/unit/client-side-encryption/client_encryption.test.ts index c83383d4e42..b9b56ff9e58 100644 --- a/test/unit/client-side-encryption/client_encryption.test.ts +++ b/test/unit/client-side-encryption/client_encryption.test.ts @@ -19,6 +19,11 @@ import { Binary, BSON, deserialize } from '../../mongodb'; const { EJSON } = BSON; class MockClient { + s: any; + + constructor(options?: any) { + this.s = { options: options || {} }; + } db(dbName) { return { async createCollection(name, options) { diff --git a/test/unit/client-side-encryption/state_machine.test.ts b/test/unit/client-side-encryption/state_machine.test.ts index 518e63a26db..baec0cbece7 100644 --- a/test/unit/client-side-encryption/state_machine.test.ts +++ b/test/unit/client-side-encryption/state_machine.test.ts @@ -148,6 +148,30 @@ describe('StateMachine', function () { }); }); + context('when socket options are provided', function () { + const stateMachine = new StateMachine({ + socketOptions: { autoSelectFamily: true, autoSelectFamilyAttemptTimeout: 300 } + } as any); + const request = new MockRequest(Buffer.from('foobar'), -1); + let connectOptions; + + it('passes them through to the socket', async function () { + sandbox.stub(tls, 'connect').callsFake((options, callback) => { + connectOptions = options; + this.fakeSocket = new MockSocket(callback); + return this.fakeSocket; + }); + const kmsRequestPromise = stateMachine.kmsRequest(request); + + await setTimeoutAsync(0); + this.fakeSocket.emit('data', Buffer.alloc(0)); + + await kmsRequestPromise; + expect(connectOptions.autoSelectFamily).to.equal(true); + expect(connectOptions.autoSelectFamilyAttemptTimeout).to.equal(300); + }); + }); + context('when tls options are provided', function () { context('when the options are insecure', function () { [