Skip to content

Commit

Permalink
feat(NODE-5940): cache the AWS credentials provider in the MONGODB-AW…
Browse files Browse the repository at this point in the history
…S auth logic (#4000)
  • Loading branch information
alenakhineika authored Feb 28, 2024
1 parent 4893330 commit 60bfc48
Show file tree
Hide file tree
Showing 16 changed files with 229 additions and 106 deletions.
4 changes: 4 additions & 0 deletions src/cmap/auth/auth_provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ export class AuthContext {
}
}

/**
* Provider used during authentication.
* @internal
*/
export abstract class AuthProvider {
/**
* Prepare the handshake document before the initial handshake.
Expand Down
123 changes: 62 additions & 61 deletions src/cmap/auth/mongodb_aws.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
import * as crypto from 'crypto';
import * as process from 'process';
import { promisify } from 'util';

import type { Binary, BSONSerializeOptions } from '../../bson';
import * as BSON from '../../bson';
import { aws4, getAwsCredentialProvider } from '../../deps';
import { aws4, type AWSCredentials, getAwsCredentialProvider } from '../../deps';
import {
MongoAWSError,
MongoCompatibilityError,
MongoMissingCredentialsError,
MongoRuntimeError
} from '../../error';
import { ByteUtils, maxWireVersion, ns, request } from '../../utils';
import { ByteUtils, maxWireVersion, ns, randomBytes, request } from '../../utils';
import { type AuthContext, AuthProvider } from './auth_provider';
import { MongoCredentials } from './mongo_credentials';
import { AuthMechanism } from './providers';
Expand Down Expand Up @@ -57,12 +55,40 @@ interface AWSSaslContinuePayload {
}

export class MongoDBAWS extends AuthProvider {
static credentialProvider: ReturnType<typeof getAwsCredentialProvider> | null = null;
randomBytesAsync: (size: number) => Promise<Buffer>;
static credentialProvider: ReturnType<typeof getAwsCredentialProvider>;
provider?: () => Promise<AWSCredentials>;

constructor() {
super();
this.randomBytesAsync = promisify(crypto.randomBytes);
MongoDBAWS.credentialProvider ??= getAwsCredentialProvider();

let { AWS_STS_REGIONAL_ENDPOINTS = '', AWS_REGION = '' } = process.env;
AWS_STS_REGIONAL_ENDPOINTS = AWS_STS_REGIONAL_ENDPOINTS.toLowerCase();
AWS_REGION = AWS_REGION.toLowerCase();

/** The option setting should work only for users who have explicit settings in their environment, the driver should not encode "defaults" */
const awsRegionSettingsExist =
AWS_REGION.length !== 0 && AWS_STS_REGIONAL_ENDPOINTS.length !== 0;

/**
* If AWS_STS_REGIONAL_ENDPOINTS is set to regional, users are opting into the new behavior of respecting the region settings
*
* If AWS_STS_REGIONAL_ENDPOINTS is set to legacy, then "old" regions need to keep using the global setting.
* Technically the SDK gets this wrong, it reaches out to 'sts.us-east-1.amazonaws.com' when it should be 'sts.amazonaws.com'.
* That is not our bug to fix here. We leave that up to the SDK.
*/
const useRegionalSts =
AWS_STS_REGIONAL_ENDPOINTS === 'regional' ||
(AWS_STS_REGIONAL_ENDPOINTS === 'legacy' && !LEGACY_REGIONS.has(AWS_REGION));

if ('fromNodeProviderChain' in MongoDBAWS.credentialProvider) {
this.provider =
awsRegionSettingsExist && useRegionalSts
? MongoDBAWS.credentialProvider.fromNodeProviderChain({
clientConfig: { region: AWS_REGION }
})
: MongoDBAWS.credentialProvider.fromNodeProviderChain();
}
}

override async auth(authContext: AuthContext): Promise<void> {
Expand All @@ -83,7 +109,7 @@ export class MongoDBAWS extends AuthProvider {
}

if (!authContext.credentials.username) {
authContext.credentials = await makeTempCredentials(authContext.credentials);
authContext.credentials = await makeTempCredentials(authContext.credentials, this.provider);
}

const { credentials } = authContext;
Expand All @@ -101,7 +127,7 @@ export class MongoDBAWS extends AuthProvider {
: undefined;

const db = credentials.source;
const nonce = await this.randomBytesAsync(32);
const nonce = await randomBytes(32);

const saslStart = {
saslStart: 1,
Expand Down Expand Up @@ -181,7 +207,10 @@ interface AWSTempCredentials {
Expiration?: Date;
}

async function makeTempCredentials(credentials: MongoCredentials): Promise<MongoCredentials> {
async function makeTempCredentials(
credentials: MongoCredentials,
provider?: () => Promise<AWSCredentials>
): Promise<MongoCredentials> {
function makeMongoCredentialsFromAWSTemp(creds: AWSTempCredentials) {
if (!creds.AccessKeyId || !creds.SecretAccessKey || !creds.Token) {
throw new MongoMissingCredentialsError('Could not obtain temporary MONGODB-AWS credentials');
Expand All @@ -198,11 +227,31 @@ async function makeTempCredentials(credentials: MongoCredentials): Promise<Mongo
});
}

MongoDBAWS.credentialProvider ??= getAwsCredentialProvider();

// Check if the AWS credential provider from the SDK is present. If not,
// use the old method.
if ('kModuleError' in MongoDBAWS.credentialProvider) {
if (provider && !('kModuleError' in MongoDBAWS.credentialProvider)) {
/*
* Creates a credential provider that will attempt to find credentials from the
* following sources (listed in order of precedence):
*
* - Environment variables exposed via process.env
* - SSO credentials from token cache
* - Web identity token credentials
* - Shared credentials and config ini files
* - The EC2/ECS Instance Metadata Service
*/
try {
const creds = await provider();
return makeMongoCredentialsFromAWSTemp({
AccessKeyId: creds.accessKeyId,
SecretAccessKey: creds.secretAccessKey,
Token: creds.sessionToken,
Expiration: creds.expiration
});
} catch (error) {
throw new MongoAWSError(error.message);
}
} else {
// If the environment variable AWS_CONTAINER_CREDENTIALS_RELATIVE_URI
// is set then drivers MUST assume that it was set by an AWS ECS agent
if (process.env.AWS_CONTAINER_CREDENTIALS_RELATIVE_URI) {
Expand Down Expand Up @@ -232,54 +281,6 @@ async function makeTempCredentials(credentials: MongoCredentials): Promise<Mongo
});

return makeMongoCredentialsFromAWSTemp(creds);
} else {
let { AWS_STS_REGIONAL_ENDPOINTS = '', AWS_REGION = '' } = process.env;
AWS_STS_REGIONAL_ENDPOINTS = AWS_STS_REGIONAL_ENDPOINTS.toLowerCase();
AWS_REGION = AWS_REGION.toLowerCase();

/** The option setting should work only for users who have explicit settings in their environment, the driver should not encode "defaults" */
const awsRegionSettingsExist =
AWS_REGION.length !== 0 && AWS_STS_REGIONAL_ENDPOINTS.length !== 0;

/**
* If AWS_STS_REGIONAL_ENDPOINTS is set to regional, users are opting into the new behavior of respecting the region settings
*
* If AWS_STS_REGIONAL_ENDPOINTS is set to legacy, then "old" regions need to keep using the global setting.
* Technically the SDK gets this wrong, it reaches out to 'sts.us-east-1.amazonaws.com' when it should be 'sts.amazonaws.com'.
* That is not our bug to fix here. We leave that up to the SDK.
*/
const useRegionalSts =
AWS_STS_REGIONAL_ENDPOINTS === 'regional' ||
(AWS_STS_REGIONAL_ENDPOINTS === 'legacy' && !LEGACY_REGIONS.has(AWS_REGION));

const provider =
awsRegionSettingsExist && useRegionalSts
? MongoDBAWS.credentialProvider.fromNodeProviderChain({
clientConfig: { region: AWS_REGION }
})
: MongoDBAWS.credentialProvider.fromNodeProviderChain();

/*
* Creates a credential provider that will attempt to find credentials from the
* following sources (listed in order of precedence):
*
* - Environment variables exposed via process.env
* - SSO credentials from token cache
* - Web identity token credentials
* - Shared credentials and config ini files
* - The EC2/ECS Instance Metadata Service
*/
try {
const creds = await provider();
return makeMongoCredentialsFromAWSTemp({
AccessKeyId: creds.accessKeyId,
SecretAccessKey: creds.secretAccessKey,
Token: creds.sessionToken,
Expiration: creds.expiration
});
} catch (error) {
throw new MongoAWSError(error.message);
}
}
}

Expand Down
7 changes: 2 additions & 5 deletions src/cmap/auth/scram.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import * as crypto from 'crypto';
import { promisify } from 'util';

import { Binary, type Document } from '../../bson';
import { saslprep } from '../../deps';
Expand All @@ -8,7 +7,7 @@ import {
MongoMissingCredentialsError,
MongoRuntimeError
} from '../../error';
import { emitWarning, ns } from '../../utils';
import { emitWarning, ns, randomBytes } from '../../utils';
import type { HandshakeDocument } from '../connect';
import { type AuthContext, AuthProvider } from './auth_provider';
import type { MongoCredentials } from './mongo_credentials';
Expand All @@ -18,11 +17,9 @@ type CryptoMethod = 'sha1' | 'sha256';

class ScramSHA extends AuthProvider {
cryptoMethod: CryptoMethod;
randomBytesAsync: (size: number) => Promise<Buffer>;
constructor(cryptoMethod: CryptoMethod) {
super();
this.cryptoMethod = cryptoMethod || 'sha1';
this.randomBytesAsync = promisify(crypto.randomBytes);
}

override async prepare(
Expand All @@ -41,7 +38,7 @@ class ScramSHA extends AuthProvider {
emitWarning('Warning: no saslprep library specified. Passwords will not be sanitized');
}

const nonce = await this.randomBytesAsync(24);
const nonce = await randomBytes(24);
// store the nonce for later use
authContext.nonce = nonce;

Expand Down
33 changes: 11 additions & 22 deletions src/cmap/connect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,8 @@ import {
needsRetryableWriteLabel
} from '../error';
import { type Callback, HostAddress, ns } from '../utils';
import { AuthContext, type AuthProvider } from './auth/auth_provider';
import { GSSAPI } from './auth/gssapi';
import { MongoCR } from './auth/mongocr';
import { MongoDBAWS } from './auth/mongodb_aws';
import { Plain } from './auth/plain';
import { AuthContext } from './auth/auth_provider';
import { AuthMechanism } from './auth/providers';
import { ScramSHA1, ScramSHA256 } from './auth/scram';
import { X509 } from './auth/x509';
import {
type CommandOptions,
Connection,
Expand All @@ -39,17 +33,6 @@ import {
MIN_SUPPORTED_WIRE_VERSION
} from './wire_protocol/constants';

/** @internal */
export const AUTH_PROVIDERS = new Map<AuthMechanism | string, AuthProvider>([
[AuthMechanism.MONGODB_AWS, new MongoDBAWS()],
[AuthMechanism.MONGODB_CR, new MongoCR()],
[AuthMechanism.MONGODB_GSSAPI, new GSSAPI()],
[AuthMechanism.MONGODB_PLAIN, new Plain()],
[AuthMechanism.MONGODB_SCRAM_SHA1, new ScramSHA1()],
[AuthMechanism.MONGODB_SCRAM_SHA256, new ScramSHA256()],
[AuthMechanism.MONGODB_X509, new X509()]
]);

/** @public */
export type Stream = Socket | TLSSocket;

Expand Down Expand Up @@ -110,7 +93,7 @@ async function performInitialHandshake(
if (credentials) {
if (
!(credentials.mechanism === AuthMechanism.MONGODB_DEFAULT) &&
!AUTH_PROVIDERS.get(credentials.mechanism)
!options.authProviders.getOrCreateProvider(credentials.mechanism)
) {
throw new MongoInvalidArgumentError(`AuthMechanism '${credentials.mechanism}' not supported`);
}
Expand Down Expand Up @@ -165,7 +148,7 @@ async function performInitialHandshake(
authContext.response = response;

const resolvedCredentials = credentials.resolveAuthMechanism(response);
const provider = AUTH_PROVIDERS.get(resolvedCredentials.mechanism);
const provider = options.authProviders.getOrCreateProvider(resolvedCredentials.mechanism);
if (!provider) {
throw new MongoInvalidArgumentError(
`No AuthProvider for ${resolvedCredentials.mechanism} defined.`
Expand All @@ -186,6 +169,10 @@ async function performInitialHandshake(
}
}

/**
* HandshakeDocument used during authentication.
* @internal
*/
export interface HandshakeDocument extends Document {
/**
* @deprecated Use hello instead
Expand Down Expand Up @@ -227,7 +214,9 @@ export async function prepareHandshakeDocument(
if (credentials.mechanism === AuthMechanism.MONGODB_DEFAULT && credentials.username) {
handshakeDoc.saslSupportedMechs = `${credentials.source}.${credentials.username}`;

const provider = AUTH_PROVIDERS.get(AuthMechanism.MONGODB_SCRAM_SHA256);
const provider = authContext.options.authProviders.getOrCreateProvider(
AuthMechanism.MONGODB_SCRAM_SHA256
);
if (!provider) {
// This auth mechanism is always present.
throw new MongoInvalidArgumentError(
Expand All @@ -236,7 +225,7 @@ export async function prepareHandshakeDocument(
}
return provider.prepare(handshakeDoc, authContext);
}
const provider = AUTH_PROVIDERS.get(credentials.mechanism);
const provider = authContext.options.authProviders.getOrCreateProvider(credentials.mechanism);
if (!provider) {
throw new MongoInvalidArgumentError(`No AuthProvider for ${credentials.mechanism} defined.`);
}
Expand Down
3 changes: 3 additions & 0 deletions src/cmap/connection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import {
MongoWriteConcernError
} from '../error';
import type { ServerApi, SupportedNodeConnectionOptions } from '../mongo_client';
import { type MongoClientAuthProviders } from '../mongo_client_auth_providers';
import { type CancellationToken, TypedEventEmitter } from '../mongo_types';
import type { ReadPreferenceLike } from '../read_preference';
import { applySession, type ClientSession, updateSessionFromResponse } from '../sessions';
Expand Down Expand Up @@ -120,6 +121,8 @@ export interface ConnectionOptions
/** @internal */
connectionType?: typeof Connection;
credentials?: MongoCredentials;
/** @internal */
authProviders: MongoClientAuthProviders;
connectTimeoutMS?: number;
tls: boolean;
/** @deprecated - Will not be able to turn off in the future. */
Expand Down
9 changes: 6 additions & 3 deletions src/cmap/connection_pool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import {
import { CancellationToken, TypedEventEmitter } from '../mongo_types';
import type { Server } from '../sdam/server';
import { type Callback, eachAsync, List, makeCounter } from '../utils';
import { AUTH_PROVIDERS, connect } from './connect';
import { connect } from './connect';
import { Connection, type ConnectionEvents, type ConnectionOptions } from './connection';
import {
ConnectionCheckedInEvent,
Expand Down Expand Up @@ -620,7 +620,9 @@ export class ConnectionPool extends TypedEventEmitter<ConnectionPoolEvents> {
);
}
const resolvedCredentials = credentials.resolveAuthMechanism(connection.hello || undefined);
const provider = AUTH_PROVIDERS.get(resolvedCredentials.mechanism);
const provider = this[kServer].topology.client.s.authProviders.getOrCreateProvider(
resolvedCredentials.mechanism
);
if (!provider) {
return callback(
new MongoMissingCredentialsError(
Expand Down Expand Up @@ -697,7 +699,8 @@ export class ConnectionPool extends TypedEventEmitter<ConnectionPoolEvents> {
...this.options,
id: this[kConnectionCounter].next().value,
generation: this[kGeneration],
cancellationToken: this[kCancellationToken]
cancellationToken: this[kCancellationToken],
authProviders: this[kServer].topology.client.s.authProviders
};

this[kPending]++;
Expand Down
4 changes: 3 additions & 1 deletion src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ export type {
ResumeToken,
UpdateDescription
} from './change_stream';
export type { AuthContext } from './cmap/auth/auth_provider';
export type { AuthContext, AuthProvider } from './cmap/auth/auth_provider';
export type {
AuthMechanismProperties,
MongoCredentials,
Expand All @@ -217,6 +217,7 @@ export type {
Response,
WriteProtocolMessageType
} from './cmap/commands';
export type { HandshakeDocument } from './cmap/connect';
export type { LEGAL_TCP_SOCKET_OPTIONS, LEGAL_TLS_SOCKET_OPTIONS, Stream } from './cmap/connect';
export type {
CommandOptions,
Expand Down Expand Up @@ -304,6 +305,7 @@ export type {
SupportedTLSSocketOptions,
WithSessionCallback
} from './mongo_client';
export { MongoClientAuthProviders } from './mongo_client_auth_providers';
export type {
Log,
LogConvertible,
Expand Down
Loading

0 comments on commit 60bfc48

Please sign in to comment.