diff --git a/mongo/integration/mtest/wiremessage_helpers.go b/mongo/integration/mtest/wiremessage_helpers.go index c6d8a677f6..5fd2bc9fa9 100644 --- a/mongo/integration/mtest/wiremessage_helpers.go +++ b/mongo/integration/mtest/wiremessage_helpers.go @@ -49,16 +49,11 @@ func parseOpCompressed(wm []byte) (wiremessage.OpCode, []byte, error) { return originalOpcode, nil, errors.New("failed to read uncompressed size") } - compressorID, wm, ok := wiremessage.ReadCompressedCompressorID(wm) + compressorID, compressedMsg, ok := wiremessage.ReadCompressedCompressorID(wm) if !ok { return originalOpcode, nil, errors.New("failed to read compressor ID") } - compressedMsg, _, ok := wiremessage.ReadCompressedCompressedMessage(wm, int32(len(wm))) - if !ok { - return originalOpcode, nil, errors.New("failed to read compressed message") - } - opts := driver.CompressionOpts{ Compressor: compressorID, UncompressedSize: uncompressedSize, diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index cea3543d14..9dc243fdf0 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -1199,18 +1199,12 @@ func (Operation) decompressWireMessage(wm []byte) (wiremessage.OpCode, []byte, e if !ok { return 0, nil, errors.New("malformed OP_COMPRESSED: missing compressor ID") } - compressedSize := len(wm) - 9 // original opcode (4) + uncompressed size (4) + compressor ID (1) - // return the original wiremessage - msg, _, ok := wiremessage.ReadCompressedCompressedMessage(rem, int32(compressedSize)) - if !ok { - return 0, nil, errors.New("malformed OP_COMPRESSED: insufficient bytes for compressed wiremessage") - } opts := CompressionOpts{ Compressor: compressorID, UncompressedSize: uncompressedSize, } - uncompressed, err := DecompressPayload(msg, opts) + uncompressed, err := DecompressPayload(rem, opts) if err != nil { return 0, nil, err } diff --git a/x/mongo/driver/wiremessage/wiremessage.go b/x/mongo/driver/wiremessage/wiremessage.go index 3e81249f89..987ae16c08 100644 --- a/x/mongo/driver/wiremessage/wiremessage.go +++ b/x/mongo/driver/wiremessage/wiremessage.go @@ -381,16 +381,9 @@ func ReadMsgSectionSingleDocument(src []byte) (doc bsoncore.Document, rem []byte // ReadMsgSectionDocumentSequence reads an identifier and document sequence from src and returns the document sequence // data parsed into a slice of BSON documents. func ReadMsgSectionDocumentSequence(src []byte) (identifier string, docs []bsoncore.Document, rem []byte, ok bool) { - length, rem, ok := readi32(src) - if !ok || int(length) > len(src) || length-4 < 0 { - return "", nil, rem, false - } - - rem, ret := rem[:length-4], rem[length-4:] // reslice so we can just iterate a loop later - - identifier, rem, ok = readcstring(rem) + identifier, rem, ret, ok := ReadMsgSectionRawDocumentSequence(src) if !ok { - return "", nil, rem, false + return "", nil, src, false } docs = make([]bsoncore.Document, 0) @@ -403,7 +396,7 @@ func ReadMsgSectionDocumentSequence(src []byte) (identifier string, docs []bsonc docs = append(docs, doc) } if len(rem) > 0 { - return "", nil, append(rem, ret...), false + return "", nil, src, false } return identifier, docs, ret, true @@ -414,7 +407,7 @@ func ReadMsgSectionDocumentSequence(src []byte) (identifier string, docs []bsonc func ReadMsgSectionRawDocumentSequence(src []byte) (identifier string, data []byte, rem []byte, ok bool) { length, rem, ok := readi32(src) if !ok || int(length) > len(src) || length-4 < 0 { - return "", nil, rem, false + return "", nil, src, false } // After these assignments, rem will be the data containing the identifier string + the document sequence bytes and @@ -423,7 +416,7 @@ func ReadMsgSectionRawDocumentSequence(src []byte) (identifier string, data []by identifier, rem, ok = readcstring(rem) if !ok { - return "", nil, rem, false + return "", nil, src, false } return identifier, rem, rest, true @@ -546,14 +539,6 @@ func ReadCompressedCompressorID(src []byte) (id CompressorID, rem []byte, ok boo return CompressorID(src[0]), src[1:], true } -// ReadCompressedCompressedMessage reads the compressed wiremessage to dst. -func ReadCompressedCompressedMessage(src []byte, length int32) (msg []byte, rem []byte, ok bool) { - if len(src) < int(length) || length < 0 { - return nil, src, false - } - return src[:length], src[length:], true -} - // ReadKillCursorsZero reads the zero field from src. func ReadKillCursorsZero(src []byte) (zero int32, rem []byte, ok bool) { return readi32(src) diff --git a/x/mongo/driver/wiremessage/wiremessage_test.go b/x/mongo/driver/wiremessage/wiremessage_test.go index 26cb2637a6..fc8e2e75c7 100644 --- a/x/mongo/driver/wiremessage/wiremessage_test.go +++ b/x/mongo/driver/wiremessage/wiremessage_test.go @@ -175,6 +175,22 @@ func TestReadMsgSectionDocumentSequence(t *testing.T) { wantRem: []byte{0, 1}, wantOK: false, }, + { + desc: "incorrect size", + src: []byte{3, 0, 0}, + wantIdentifier: "", + wantDocs: nil, + wantRem: []byte{3, 0, 0}, + wantOK: false, + }, + { + desc: "insufficient size", + src: []byte{4, 0, 0}, + wantIdentifier: "", + wantDocs: nil, + wantRem: []byte{4, 0, 0}, + wantOK: false, + }, { desc: "nil", src: nil,