Skip to content

Commit

Permalink
refactor(NODE-5914): topology.selectServer to async-await (#4020)
Browse files Browse the repository at this point in the history
  • Loading branch information
W-A-James authored Mar 15, 2024
1 parent d86d2ae commit aec8416
Show file tree
Hide file tree
Showing 15 changed files with 441 additions and 541 deletions.
20 changes: 11 additions & 9 deletions src/change_stream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -934,14 +934,16 @@ export class ChangeStream<
this.cursor.close().catch(() => null);

const topology = getTopology(this.parent);
topology.selectServer(
this.cursor.readPreference,
{ operationName: 'reconnect topology in change stream' },
serverSelectionError => {
if (serverSelectionError) return this._closeEmitterModeWithError(changeStreamError);
this.cursor = this._createChangeStreamCursor(this.cursor.resumeOptions);
}
);
topology
.selectServer(this.cursor.readPreference, {
operationName: 'reconnect topology in change stream'
})
.then(
() => {
this.cursor = this._createChangeStreamCursor(this.cursor.resumeOptions);
},
() => this._closeEmitterModeWithError(changeStreamError)
);
} else {
this._closeEmitterModeWithError(changeStreamError);
}
Expand All @@ -966,7 +968,7 @@ export class ChangeStream<
await this.cursor.close().catch(() => null);
const topology = getTopology(this.parent);
try {
await topology.selectServerAsync(this.cursor.readPreference, {
await topology.selectServer(this.cursor.readPreference, {
operationName: 'reconnect topology in change stream'
});
this.cursor = this._createChangeStreamCursor(this.cursor.resumeOptions);
Expand Down
3 changes: 1 addition & 2 deletions src/mongo_client.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import { promises as fs } from 'fs';
import type { TcpNetConnectOpts } from 'net';
import type { ConnectionOptions as TLSConnectionOptions, TLSSocketOptions } from 'tls';
import { promisify } from 'util';

import { type BSONSerializeOptions, type Document, resolveBSONOptions } from './bson';
import { ChangeStream, type ChangeStreamDocument, type ChangeStreamOptions } from './change_stream';
Expand Down Expand Up @@ -550,7 +549,7 @@ export class MongoClient extends TypedEventEmitter<MongoClientEvents> {

const topologyConnect = async () => {
try {
await promisify(callback => this.topology?.connect(options, callback))();
await this.topology?.connect(options);
} catch (error) {
this.topology?.close();
throw error;
Expand Down
4 changes: 2 additions & 2 deletions src/operations/execute_operation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ export async function executeOperation<
selector = readPreference;
}

const server = await topology.selectServerAsync(selector, {
const server = await topology.selectServer(selector, {
session,
operationName: operation.commandName
});
Expand Down Expand Up @@ -244,7 +244,7 @@ async function retryOperation<
}

// select a new server, and attempt to retry the operation
const server = await topology.selectServerAsync(selector, {
const server = await topology.selectServer(selector, {
session,
operationName: operation.commandName,
previousServer
Expand Down
158 changes: 73 additions & 85 deletions src/sdam/topology.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import { promisify } from 'util';

import type { BSONSerializeOptions, Document } from '../bson';
import type { MongoCredentials } from '../cmap/auth/mongo_credentials';
import type { ConnectionEvents } from '../cmap/connection';
Expand Down Expand Up @@ -44,6 +42,7 @@ import {
makeStateMachine,
now,
ns,
promiseWithResolvers,
shuffle,
TimeoutController
} from '../utils';
Expand Down Expand Up @@ -105,7 +104,8 @@ export interface ServerSelectionRequest {
mongoLogger: MongoLogger | undefined;
transaction?: Transaction;
startTime: number;
callback: ServerSelectionCallback;
resolve: (server: Server) => void;
reject: (error: MongoError) => void;
[kCancelled]?: boolean;
timeoutController: TimeoutController;
operationName: string;
Expand Down Expand Up @@ -215,6 +215,9 @@ export class Topology extends TypedEventEmitter<TopologyEvents> {

client!: MongoClient;

/** @internal */
private connectionLock?: Promise<Topology>;

/** @event */
static readonly SERVER_OPENING = SERVER_OPENING;
/** @event */
Expand All @@ -238,11 +241,6 @@ export class Topology extends TypedEventEmitter<TopologyEvents> {
/** @event */
static readonly TIMEOUT = TIMEOUT;

selectServerAsync: (
selector: string | ReadPreference | ServerSelector,
options: SelectServerOptions
) => Promise<Server>;

/**
* @param seedlist - a list of HostAddress instances to connect to
*/
Expand All @@ -254,14 +252,6 @@ export class Topology extends TypedEventEmitter<TopologyEvents> {
super();

this.client = client;
this.selectServerAsync = promisify(
(
selector: string | ReadPreference | ServerSelector,
options: SelectServerOptions,
callback: (e: Error, r: Server) => void
) => this.selectServer(selector, options, callback as any)
);

// Options should only be undefined in tests, MongoClient will always have defined options
options = options ?? {
hosts: [HostAddress.fromString('localhost:27017')],
Expand Down Expand Up @@ -351,6 +341,7 @@ export class Topology extends TypedEventEmitter<TopologyEvents> {

this.on(Topology.TOPOLOGY_DESCRIPTION_CHANGED, this.s.detectShardedTopology);
}
this.connectionLock = undefined;
}

private detectShardedTopology(event: TopologyDescriptionChangedEvent) {
Expand Down Expand Up @@ -411,17 +402,22 @@ export class Topology extends TypedEventEmitter<TopologyEvents> {
}

/** Initiate server connect */
connect(callback: Callback): void;
connect(options: ConnectOptions, callback: Callback): void;
connect(options?: ConnectOptions | Callback, callback?: Callback): void {
if (typeof options === 'function') (callback = options), (options = {});
async connect(options?: ConnectOptions): Promise<Topology> {
this.connectionLock ??= this._connect(options);
try {
await this.connectionLock;
return this;
} finally {
this.connectionLock = undefined;
}

return this;
}

private async _connect(options?: ConnectOptions): Promise<Topology> {
options = options ?? {};
if (this.s.state === STATE_CONNECTED) {
if (typeof callback === 'function') {
callback();
}

return;
return this;
}

stateTransition(this, STATE_CONNECTING);
Expand Down Expand Up @@ -459,40 +455,33 @@ export class Topology extends TypedEventEmitter<TopologyEvents> {
}
}

const exitWithError = (error: Error) =>
callback ? callback(error) : this.emit(Topology.ERROR, error);

const readPreference = options.readPreference ?? ReadPreference.primary;
const selectServerOptions = { operationName: 'ping', ...options };
this.selectServer(
readPreferenceServerSelector(readPreference),
selectServerOptions,
(err, server) => {
if (err) {
this.close();
return exitWithError(err);
}

const skipPingOnConnect = this.s.options[Symbol.for('@@mdb.skipPingOnConnect')] === true;
if (!skipPingOnConnect && server && this.s.credentials) {
server.command(ns('admin.$cmd'), { ping: 1 }, {}).then(() => {
stateTransition(this, STATE_CONNECTED);
this.emit(Topology.OPEN, this);
this.emit(Topology.CONNECT, this);

callback?.(undefined, this);
}, exitWithError);

return;
}
try {
const server = await this.selectServer(
readPreferenceServerSelector(readPreference),
selectServerOptions
);

const skipPingOnConnect = this.s.options[Symbol.for('@@mdb.skipPingOnConnect')] === true;
if (!skipPingOnConnect && server && this.s.credentials) {
await server.command(ns('admin.$cmd'), { ping: 1 }, {});
stateTransition(this, STATE_CONNECTED);
this.emit(Topology.OPEN, this);
this.emit(Topology.CONNECT, this);

callback?.(undefined, this);
return this;
}
);

stateTransition(this, STATE_CONNECTED);
this.emit(Topology.OPEN, this);
this.emit(Topology.CONNECT, this);

return this;
} catch (error) {
this.close();
throw error;
}
}

/** Close this topology */
Expand Down Expand Up @@ -533,11 +522,10 @@ export class Topology extends TypedEventEmitter<TopologyEvents> {
* @param callback - The callback used to indicate success or failure
* @returns An instance of a `Server` meeting the criteria of the predicate provided
*/
selectServer(
async selectServer(
selector: string | ReadPreference | ServerSelector,
options: SelectServerOptions,
callback: Callback<Server>
): void {
options: SelectServerOptions
): Promise<Server> {
let serverSelector;
if (typeof selector !== 'function') {
if (typeof selector === 'string') {
Expand Down Expand Up @@ -588,16 +576,17 @@ export class Topology extends TypedEventEmitter<TopologyEvents> {
)
);
}
callback(undefined, transaction.server);
return;
return transaction.server;
}

const { promise: serverPromise, resolve, reject } = promiseWithResolvers<Server>();
const waitQueueMember: ServerSelectionRequest = {
serverSelector,
topologyDescription: this.description,
mongoLogger: this.client.mongoLogger,
transaction,
callback,
resolve,
reject,
timeoutController: new TimeoutController(options.serverSelectionTimeoutMS),
startTime: now(),
operationName: options.operationName,
Expand Down Expand Up @@ -628,13 +617,14 @@ export class Topology extends TypedEventEmitter<TopologyEvents> {
)
);
}
waitQueueMember.callback(timeoutError);
waitQueueMember.reject(timeoutError);
});

this[kWaitQueue].push(waitQueueMember);
processWaitQueue(this);
}

return serverPromise;
}
/**
* Update the internal TopologyDescription with a ServerDescription
*
Expand Down Expand Up @@ -883,7 +873,7 @@ function updateServers(topology: Topology, incomingServerDescription?: ServerDes
}
}

function drainWaitQueue(queue: List<ServerSelectionRequest>, err?: MongoDriverError) {
function drainWaitQueue(queue: List<ServerSelectionRequest>, drainError: MongoDriverError) {
while (queue.length) {
const waitQueueMember = queue.shift();
if (!waitQueueMember) {
Expand All @@ -893,25 +883,23 @@ function drainWaitQueue(queue: List<ServerSelectionRequest>, err?: MongoDriverEr
waitQueueMember.timeoutController.clear();

if (!waitQueueMember[kCancelled]) {
if (err) {
if (
waitQueueMember.mongoLogger?.willLog(
MongoLoggableComponent.SERVER_SELECTION,
SeverityLevel.DEBUG
if (
waitQueueMember.mongoLogger?.willLog(
MongoLoggableComponent.SERVER_SELECTION,
SeverityLevel.DEBUG
)
) {
waitQueueMember.mongoLogger?.debug(
MongoLoggableComponent.SERVER_SELECTION,
new ServerSelectionFailedEvent(
waitQueueMember.serverSelector,
waitQueueMember.topologyDescription,
drainError,
waitQueueMember.operationName
)
) {
waitQueueMember.mongoLogger?.debug(
MongoLoggableComponent.SERVER_SELECTION,
new ServerSelectionFailedEvent(
waitQueueMember.serverSelector,
waitQueueMember.topologyDescription,
err,
waitQueueMember.operationName
)
);
}
);
}
waitQueueMember.callback(err);
waitQueueMember.reject(drainError);
}
}
}
Expand Down Expand Up @@ -946,7 +934,7 @@ function processWaitQueue(topology: Topology) {
previousServer ? [previousServer] : []
)
: serverDescriptions;
} catch (e) {
} catch (selectorError) {
waitQueueMember.timeoutController.clear();
if (
topology.client.mongoLogger?.willLog(
Expand All @@ -959,12 +947,12 @@ function processWaitQueue(topology: Topology) {
new ServerSelectionFailedEvent(
waitQueueMember.serverSelector,
topology.description,
e,
selectorError,
waitQueueMember.operationName
)
);
}
waitQueueMember.callback(e);
waitQueueMember.reject(selectorError);
continue;
}

Expand Down Expand Up @@ -1007,7 +995,7 @@ function processWaitQueue(topology: Topology) {
}

if (!selectedServer) {
const error = new MongoServerSelectionError(
const serverSelectionError = new MongoServerSelectionError(
'server selection returned a server description but the server was not found in the topology',
topology.description
);
Expand All @@ -1022,12 +1010,12 @@ function processWaitQueue(topology: Topology) {
new ServerSelectionFailedEvent(
waitQueueMember.serverSelector,
topology.description,
error,
serverSelectionError,
waitQueueMember.operationName
)
);
}
waitQueueMember.callback(error);
waitQueueMember.reject(serverSelectionError);
return;
}
const transaction = waitQueueMember.transaction;
Expand All @@ -1053,7 +1041,7 @@ function processWaitQueue(topology: Topology) {
)
);
}
waitQueueMember.callback(undefined, selectedServer);
waitQueueMember.resolve(selectedServer);
}

if (topology[kWaitQueue].length > 0) {
Expand Down
Loading

0 comments on commit aec8416

Please sign in to comment.