diff --git a/src/cmap/wire_protocol/on_demand/document.ts b/src/cmap/wire_protocol/on_demand/document.ts index 944916f10b..67f5b3a091 100644 --- a/src/cmap/wire_protocol/on_demand/document.ts +++ b/src/cmap/wire_protocol/on_demand/document.ts @@ -1,15 +1,16 @@ import { Binary, - BSON, type BSONElement, BSONError, type BSONSerializeOptions, BSONType, + deserialize, getBigInt64LE, getFloat64LE, getInt32LE, ObjectId, parseToElementsToArray, + pluckBSONSerializeOptions, Timestamp, toUTF8 } from '../../../bson'; @@ -330,11 +331,23 @@ export class OnDemandDocument { * @param options - BSON deserialization options */ public toObject(options?: BSONSerializeOptions): Record { - return BSON.deserialize(this.bson, { - ...options, + const exactBSONOptions = { + ...pluckBSONSerializeOptions(options ?? {}), + validation: this.parseBsonSerializationOptions(options), index: this.offset, allowObjectSmallerThanBufferSize: true - }); + }; + return deserialize(this.bson, exactBSONOptions); + } + + private parseBsonSerializationOptions(options?: { enableUtf8Validation?: boolean }): { + utf8: { writeErrors: false } | false; + } { + const enableUtf8Validation = options?.enableUtf8Validation; + if (enableUtf8Validation === false) { + return { utf8: false }; + } + return { utf8: { writeErrors: false } }; } /** Returns this document's bytes only */ diff --git a/src/cmap/wire_protocol/responses.ts b/src/cmap/wire_protocol/responses.ts index 0ef048e8da..9837634bfb 100644 --- a/src/cmap/wire_protocol/responses.ts +++ b/src/cmap/wire_protocol/responses.ts @@ -5,7 +5,6 @@ import { type Document, Long, parseToElementsToArray, - pluckBSONSerializeOptions, type Timestamp } from '../../bson'; import { MongoUnexpectedServerResponseError } from '../../error'; @@ -166,24 +165,6 @@ export class MongoDBResponse extends OnDemandDocument { } return this.clusterTime ?? null; } - - public override toObject(options?: BSONSerializeOptions): Record { - const exactBSONOptions = { - ...pluckBSONSerializeOptions(options ?? {}), - validation: this.parseBsonSerializationOptions(options) - }; - return super.toObject(exactBSONOptions); - } - - private parseBsonSerializationOptions(options?: { enableUtf8Validation?: boolean }): { - utf8: { writeErrors: false } | false; - } { - const enableUtf8Validation = options?.enableUtf8Validation; - if (enableUtf8Validation === false) { - return { utf8: false }; - } - return { utf8: { writeErrors: false } }; - } } /** @internal */ diff --git a/test/integration/node-specific/bson-options/utf8_validation.test.ts b/test/integration/node-specific/bson-options/utf8_validation.test.ts index 5c3f94e7fb..d6345a884d 100644 --- a/test/integration/node-specific/bson-options/utf8_validation.test.ts +++ b/test/integration/node-specific/bson-options/utf8_validation.test.ts @@ -1,11 +1,16 @@ import { expect } from 'chai'; +import * as net from 'net'; import * as sinon from 'sinon'; +import { inspect } from 'util'; import { BSON, + BSONError, + type Collection, + deserialize, type MongoClient, - MongoDBResponse, MongoServerError, + OnDemandDocument, OpMsgResponse } from '../../../mongodb'; @@ -23,12 +28,12 @@ describe('class MongoDBResponse', () => { let bsonSpy: sinon.SinonSpy; beforeEach(() => { - bsonSpy = sinon.spy(MongoDBResponse.prototype, 'parseBsonSerializationOptions'); + // @ts-expect-error private function + bsonSpy = sinon.spy(OnDemandDocument.prototype, 'parseBsonSerializationOptions'); }); afterEach(() => { bsonSpy?.restore(); - // @ts-expect-error: Allow this to be garbage collected bsonSpy = null; }); @@ -153,3 +158,180 @@ describe('class MongoDBResponse', () => { } ); }); + +describe('utf8 validation with cursors', function () { + let client: MongoClient; + let collection: Collection; + + /** + * Inserts a document with malformed utf8 bytes. This method spies on socket.write, and then waits + * for an OP_MSG payload corresponding to `collection.insertOne({ field: 'é' })`, and then modifies the + * bytes of the character 'é', to produce invalid utf8. + */ + async function insertDocumentWithInvalidUTF8() { + const stub = sinon.stub(net.Socket.prototype, 'write').callsFake(function (...args) { + const providedBuffer = args[0].toString('hex'); + const targetBytes = Buffer.from(document.field, 'utf-8').toString('hex'); + + if (providedBuffer.includes(targetBytes)) { + if (providedBuffer.split(targetBytes).length !== 2) { + sinon.restore(); + const message = `too many target bytes sequences: received ${providedBuffer.split(targetBytes).length}\n. command: ${inspect(deserialize(args[0]), { depth: Infinity })}`; + throw new Error(message); + } + const buffer = Buffer.from(providedBuffer.replace(targetBytes, 'c301'.repeat(8)), 'hex'); + const result = stub.wrappedMethod.apply(this, [buffer]); + sinon.restore(); + return result; + } + const result = stub.wrappedMethod.apply(this, args); + return result; + }); + + const document = { + field: 'é'.repeat(8) + }; + + await collection.insertOne(document); + + sinon.restore(); + } + + beforeEach(async function () { + client = this.configuration.newClient(); + await client.connect(); + const db = client.db('test'); + collection = db.collection('invalidutf'); + + await collection.deleteMany({}); + await insertDocumentWithInvalidUTF8(); + }); + + afterEach(async function () { + sinon.restore(); + await client.close(); + }); + + context('when utf-8 validation is explicitly disabled', function () { + it('documents can be read using a for-await loop without errors', async function () { + for await (const _doc of collection.find({}, { enableUtf8Validation: false })); + }); + it('documents can be read using next() without errors', async function () { + const cursor = collection.find({}, { enableUtf8Validation: false }); + + while (await cursor.hasNext()) { + await cursor.next(); + } + }); + + it('documents can be read using toArray() without errors', async function () { + const cursor = collection.find({}, { enableUtf8Validation: false }); + await cursor.toArray(); + }); + + it('documents can be read using .stream() without errors', async function () { + const cursor = collection.find({}, { enableUtf8Validation: false }); + await cursor.stream().toArray(); + }); + + it('documents can be read with tryNext() without error', async function () { + const cursor = collection.find({}, { enableUtf8Validation: false }); + + while (await cursor.hasNext()) { + await cursor.tryNext(); + } + }); + }); + + async function expectReject(fn: () => Promise) { + try { + await fn(); + expect.fail('expected the provided callback function to reject, but it did not.'); + } catch (error) { + expect(error).to.match(/Invalid UTF-8 string in BSON document/); + expect(error).to.be.instanceOf(BSONError); + } + } + + context('when utf-8 validation is explicitly enabled', function () { + it('a for-await loop throw a BSON error', async function () { + await expectReject(async () => { + for await (const _doc of collection.find({}, { enableUtf8Validation: true })); + }); + }); + it('next() throws a BSON error', async function () { + await expectReject(async () => { + const cursor = collection.find({}, { enableUtf8Validation: true }); + + while (await cursor.hasNext()) { + await cursor.next(); + } + }); + }); + + it('toArray() throws a BSON error', async function () { + await expectReject(async () => { + const cursor = collection.find({}, { enableUtf8Validation: true }); + await cursor.toArray(); + }); + }); + + it('.stream() throws a BSONError', async function () { + await expectReject(async () => { + const cursor = collection.find({}, { enableUtf8Validation: true }); + await cursor.stream().toArray(); + }); + }); + + it('tryNext() throws a BSONError', async function () { + await expectReject(async () => { + const cursor = collection.find({}, { enableUtf8Validation: true }); + + while (await cursor.hasNext()) { + await cursor.tryNext(); + } + }); + }); + }); + + context('utf-8 validation defaults to enabled', function () { + it('a for-await loop throw a BSON error', async function () { + await expectReject(async () => { + for await (const _doc of collection.find({})); + }); + }); + it('next() throws a BSON error', async function () { + await expectReject(async () => { + const cursor = collection.find({}); + + while (await cursor.hasNext()) { + await cursor.next(); + } + }); + }); + + it('toArray() throws a BSON error', async function () { + await expectReject(async () => { + const cursor = collection.find({}); + await cursor.toArray(); + }); + }); + + it('.stream() throws a BSONError', async function () { + await expectReject(async () => { + const cursor = collection.find({}); + await cursor.stream().toArray(); + }); + }); + + it('tryNext() throws a BSONError', async function () { + await expectReject(async () => { + const cursor = collection.find({}, { enableUtf8Validation: true }); + + while (await cursor.hasNext()) { + await cursor.tryNext(); + } + }); + }); + }); +}); diff --git a/test/unit/cmap/wire_protocol/responses.test.ts b/test/unit/cmap/wire_protocol/responses.test.ts index 9498765cf4..7fccbfc7fc 100644 --- a/test/unit/cmap/wire_protocol/responses.test.ts +++ b/test/unit/cmap/wire_protocol/responses.test.ts @@ -1,54 +1,87 @@ import { expect } from 'chai'; import * as sinon from 'sinon'; +// to spy on the bson module, we must import it from the driver +// eslint-disable-next-line @typescript-eslint/no-restricted-imports +import * as mdb from '../../../../src/bson'; import { - BSON, CursorResponse, Int32, MongoDBResponse, MongoUnexpectedServerResponseError, - OnDemandDocument + OnDemandDocument, + serialize } from '../../../mongodb'; describe('class MongoDBResponse', () => { it('is a subclass of OnDemandDocument', () => { - expect(new MongoDBResponse(BSON.serialize({ ok: 1 }))).to.be.instanceOf(OnDemandDocument); + expect(new MongoDBResponse(serialize({ ok: 1 }))).to.be.instanceOf(OnDemandDocument); }); context('utf8 validation', () => { - afterEach(() => sinon.restore()); + let deseriailzeSpy: sinon.SinonStub>; + beforeEach(function () { + const deserialize = mdb.deserialize; + deseriailzeSpy = sinon.stub>().callsFake(deserialize); + sinon.stub(mdb, 'deserialize').get(() => { + return deseriailzeSpy; + }); + }); + afterEach(function () { + sinon.restore(); + }); context('when enableUtf8Validation is not specified', () => { const options = { enableUtf8Validation: undefined }; it('calls BSON deserialize with writeErrors validation turned off', () => { - const res = new MongoDBResponse(BSON.serialize({})); - const toObject = sinon.spy(Object.getPrototypeOf(Object.getPrototypeOf(res)), 'toObject'); + const res = new MongoDBResponse(serialize({})); res.toObject(options); - expect(toObject).to.have.been.calledWith( - sinon.match({ validation: { utf8: { writeErrors: false } } }) - ); + + expect(deseriailzeSpy).to.have.been.called; + + const [ + { + args: [_buffer, { validation }] + } + ] = deseriailzeSpy.getCalls(); + + expect(validation).to.deep.equal({ utf8: { writeErrors: false } }); }); }); context('when enableUtf8Validation is true', () => { const options = { enableUtf8Validation: true }; it('calls BSON deserialize with writeErrors validation turned off', () => { - const res = new MongoDBResponse(BSON.serialize({})); - const toObject = sinon.spy(Object.getPrototypeOf(Object.getPrototypeOf(res)), 'toObject'); + const res = new MongoDBResponse(serialize({})); res.toObject(options); - expect(toObject).to.have.been.calledWith( - sinon.match({ validation: { utf8: { writeErrors: false } } }) - ); + + expect(deseriailzeSpy).to.have.been.called; + + const [ + { + args: [_buffer, { validation }] + } + ] = deseriailzeSpy.getCalls(); + + expect(validation).to.deep.equal({ utf8: { writeErrors: false } }); }); }); context('when enableUtf8Validation is false', () => { const options = { enableUtf8Validation: false }; it('calls BSON deserialize with all validation disabled', () => { - const res = new MongoDBResponse(BSON.serialize({})); - const toObject = sinon.spy(Object.getPrototypeOf(Object.getPrototypeOf(res)), 'toObject'); + const res = new MongoDBResponse(serialize({})); res.toObject(options); - expect(toObject).to.have.been.calledWith(sinon.match({ validation: { utf8: false } })); + + expect(deseriailzeSpy).to.have.been.called; + + const [ + { + args: [_buffer, { validation }] + } + ] = deseriailzeSpy.getCalls(); + + expect(validation).to.deep.equal({ utf8: false }); }); }); }); @@ -57,7 +90,7 @@ describe('class MongoDBResponse', () => { describe('class CursorResponse', () => { describe('get cursor()', () => { it('throws if input does not contain cursor embedded document', () => { - expect(() => new CursorResponse(BSON.serialize({ ok: 1 })).cursor).to.throw( + expect(() => new CursorResponse(serialize({ ok: 1 })).cursor).to.throw( MongoUnexpectedServerResponseError, /"cursor" is missing/ ); @@ -66,7 +99,7 @@ describe('class CursorResponse', () => { describe('get id()', () => { it('throws if input does not contain cursor.id int64', () => { - expect(() => new CursorResponse(BSON.serialize({ ok: 1, cursor: {} })).id).to.throw( + expect(() => new CursorResponse(serialize({ ok: 1, cursor: {} })).id).to.throw( MongoUnexpectedServerResponseError, /"id" is missing/ ); @@ -77,22 +110,22 @@ describe('class CursorResponse', () => { it('throws if input does not contain firstBatch nor nextBatch', () => { expect( // @ts-expect-error: testing private getter - () => new CursorResponse(BSON.serialize({ ok: 1, cursor: { id: 0n, batch: [] } })).batch + () => new CursorResponse(serialize({ ok: 1, cursor: { id: 0n, batch: [] } })).batch ).to.throw(MongoUnexpectedServerResponseError, /did not contain a batch/); }); }); describe('get ns()', () => { it('sets namespace to null if input does not contain cursor.ns', () => { - expect(new CursorResponse(BSON.serialize({ ok: 1, cursor: { id: 0n, firstBatch: [] } })).ns) - .to.be.null; + expect(new CursorResponse(serialize({ ok: 1, cursor: { id: 0n, firstBatch: [] } })).ns).to.be + .null; }); }); describe('get batchSize()', () => { it('reports the returned batch size', () => { const response = new CursorResponse( - BSON.serialize({ ok: 1, cursor: { id: 0n, nextBatch: [{}, {}, {}] } }) + serialize({ ok: 1, cursor: { id: 0n, nextBatch: [{}, {}, {}] } }) ); expect(response.batchSize).to.equal(3); expect(response.shift()).to.deep.equal({}); @@ -103,7 +136,7 @@ describe('class CursorResponse', () => { describe('get length()', () => { it('reports number of documents remaining in the batch', () => { const response = new CursorResponse( - BSON.serialize({ ok: 1, cursor: { id: 0n, nextBatch: [{}, {}, {}] } }) + serialize({ ok: 1, cursor: { id: 0n, nextBatch: [{}, {}, {}] } }) ); expect(response).to.have.lengthOf(3); expect(response.shift()).to.deep.equal({}); @@ -116,7 +149,7 @@ describe('class CursorResponse', () => { beforeEach(async function () { response = new CursorResponse( - BSON.serialize({ + serialize({ ok: 1, cursor: { id: 0n, nextBatch: [{ _id: 1 }, { _id: 2 }, { _id: 3 }] } }) @@ -143,7 +176,7 @@ describe('class CursorResponse', () => { beforeEach(async function () { response = new CursorResponse( - BSON.serialize({ + serialize({ ok: 1, cursor: { id: 0n, nextBatch: [{ _id: 1 }, { _id: 2 }, { _id: 3 }] } })