diff --git a/.evergreen/config.yml b/.evergreen/config.yml index 81738eb513..0603f385ac 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -437,37 +437,18 @@ functions: make -s evg-test-enterprise-auth run-atlas-test: + - command: ec2.assume_role + params: + role_arn: "${aws_test_secrets_role}" - command: shell.exec type: test params: shell: "bash" working_dir: src/go.mongodb.org/mongo-driver + include_expansions_in_env: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"] script: | - # DO NOT ECHO WITH XTRACE - if [ "Windows_NT" = "$OS" ]; then - export GOPATH=$(cygpath -w $(dirname $(dirname $(dirname `pwd`)))) - export GOCACHE=$(cygpath -w "$(pwd)/.cache") - else - export GOPATH=$(dirname $(dirname $(dirname `pwd`))) - export GOCACHE="$(pwd)/.cache" - fi; - export GOPATH="$GOPATH" - export GOROOT="${GO_DIST}" - export GOCACHE="$GOCACHE" - export PATH="${GCC_PATH}:${GO_DIST}/bin:$PATH" - export ATLAS_FREE="${atlas_free_tier_uri}" - export ATLAS_REPLSET="${atlas_replica_set_uri}" - export ATLAS_SHARD="${atlas_sharded_uri}" - export ATLAS_TLS11="${atlas_tls_v11_uri}" - export ATLAS_TLS12="${atlas_tls_v12_uri}" - export ATLAS_FREE_SRV="${atlas_free_tier_uri_srv}" - export ATLAS_REPLSET_SRV="${atlas_replica_set_uri_srv}" - export ATLAS_SHARD_SRV="${atlas_sharded_uri_srv}" - export ATLAS_TLS11_SRV="${atlas_tls_v11_uri_srv}" - export ATLAS_TLS12_SRV="${atlas_tls_v12_uri_srv}" - export ATLAS_SERVERLESS="${atlas_serverless_uri}" - export ATLAS_SERVERLESS_SRV="${atlas_serverless_uri_srv}" - make -s evg-test-atlas + ${PREPARE_SHELL} + bash etc/run-atlas-test.sh run-ocsp-test: - command: shell.exec diff --git a/.gitignore b/.gitignore index e9609ce1e9..16b52325e4 100644 --- a/.gitignore +++ b/.gitignore @@ -24,4 +24,8 @@ internal/test/compilecheck/compilecheck.so # Ignore api report files api-report.md -api-report.txt \ No newline at end of file +api-report.txt + +# Ignore secrets files +secrets-expansion.yml +secrets-export.sh diff --git a/Makefile b/Makefile index 5531b61a29..66f5b137e5 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,3 @@ -ATLAS_URIS = "$(ATLAS_FREE)" "$(ATLAS_REPLSET)" "$(ATLAS_SHARD)" "$(ATLAS_TLS11)" "$(ATLAS_TLS12)" "$(ATLAS_FREE_SRV)" "$(ATLAS_REPLSET_SRV)" "$(ATLAS_SHARD_SRV)" "$(ATLAS_TLS11_SRV)" "$(ATLAS_TLS12_SRV)" "$(ATLAS_SERVERLESS)" "$(ATLAS_SERVERLESS_SRV)" TEST_TIMEOUT = 1800 ### Utility targets. ### @@ -128,10 +127,6 @@ build-aws-ecs-test: evg-test: go test -exec "env PKG_CONFIG_PATH=$(PKG_CONFIG_PATH) LD_LIBRARY_PATH=$(LD_LIBRARY_PATH)" $(BUILD_TAGS) -v -timeout $(TEST_TIMEOUT)s -p 1 ./... >> test.suite -.PHONY: evg-test-atlas -evg-test-atlas: - go run ./cmd/testatlas/main.go $(ATLAS_URIS) - .PHONY: evg-test-atlas-data-lake evg-test-atlas-data-lake: ATLAS_DATA_LAKE_INTEGRATION_TEST=true go test -v ./mongo/integration -run TestUnifiedSpecs/atlas-data-lake-testing >> spec_test.suite diff --git a/cmd/testatlas/main.go b/cmd/testatlas/main.go index d98631d468..fa50d7cde7 100644 --- a/cmd/testatlas/main.go +++ b/cmd/testatlas/main.go @@ -23,7 +23,11 @@ func main() { uris := flag.Args() ctx := context.Background() + fmt.Printf("Running atlas tests for %d uris\n", len(uris)) + for idx, uri := range uris { + fmt.Printf("Running test %d\n", idx) + // Set a low server selection timeout so we fail fast if there are errors. clientOpts := options.Client(). ApplyURI(uri). @@ -41,6 +45,8 @@ func main() { panic(fmt.Sprintf("error running test with tlsInsecure at index %d: %v", idx, err)) } } + + fmt.Println("Finished!") } func runTest(ctx context.Context, clientOpts *options.ClientOptions) error { diff --git a/etc/get_aws_secrets.sh b/etc/get_aws_secrets.sh new file mode 100644 index 0000000000..894016553b --- /dev/null +++ b/etc/get_aws_secrets.sh @@ -0,0 +1,12 @@ +#!/usr/bin/env bash +# get-aws-secrets +# Gets AWS secrets from the vault +set -eu + +if [ -z "$DRIVERS_TOOLS" ]; then + echo "Please define DRIVERS_TOOLS variable" + exit 1 +fi + +bash $DRIVERS_TOOLS/.evergreen/auth_aws/setup_secrets.sh $@ +. ./secrets-export.sh diff --git a/etc/run-atlas-test.sh b/etc/run-atlas-test.sh new file mode 100644 index 0000000000..aa89b2dd4b --- /dev/null +++ b/etc/run-atlas-test.sh @@ -0,0 +1,11 @@ +#!/usr/bin/env bash +# run-atlas-test +# Run atlas connectivity tests. +set -eu +set +x + +# Get the atlas secrets. +. etc/get_aws_secrets.sh drivers/atlas_connect + +echo "Running cmd/testatlas/main.go" +go run ./cmd/testatlas/main.go "$ATLAS_REPL" "$ATLAS_SHRD" "$ATLAS_FREE" "$ATLAS_TLS11" "$ATLAS_TLS12" "$ATLAS_SERVERLESS" "$ATLAS_SRV_REPL" "$ATLAS_SRV_SHRD" "$ATLAS_SRV_FREE" "$ATLAS_SRV_TLS11" "$ATLAS_SRV_TLS12" "$ATLAS_SRV_SERVERLESS" diff --git a/mongo/errors.go b/mongo/errors.go index 5f2b1b819b..72c3bcc243 100644 --- a/mongo/errors.go +++ b/mongo/errors.go @@ -102,50 +102,56 @@ func replaceErrors(err error) error { return err } -// IsDuplicateKeyError returns true if err is a duplicate key error +// IsDuplicateKeyError returns true if err is a duplicate key error. func IsDuplicateKeyError(err error) bool { - // handles SERVER-7164 and SERVER-11493 - for ; err != nil; err = unwrap(err) { - if e, ok := err.(ServerError); ok { - return e.HasErrorCode(11000) || e.HasErrorCode(11001) || e.HasErrorCode(12582) || - e.HasErrorCodeWithMessage(16460, " E11000 ") - } + if se := ServerError(nil); errors.As(err, &se) { + return se.HasErrorCode(11000) || // Duplicate key error. + se.HasErrorCode(11001) || // Duplicate key error on update. + // Duplicate key error in a capped collection. See SERVER-7164. + se.HasErrorCode(12582) || + // Mongos insert error caused by a duplicate key error. See + // SERVER-11493. + se.HasErrorCodeWithMessage(16460, " E11000 ") } return false } -// IsTimeout returns true if err is from a timeout +// timeoutErrs is a list of error values that indicate a timeout happened. +var timeoutErrs = [...]error{ + context.DeadlineExceeded, + driver.ErrDeadlineWouldBeExceeded, + topology.ErrServerSelectionTimeout, +} + +// IsTimeout returns true if err was caused by a timeout. For error chains, +// IsTimeout returns true if any error in the chain was caused by a timeout. func IsTimeout(err error) bool { - for ; err != nil; err = unwrap(err) { - // check unwrappable errors together - if err == context.DeadlineExceeded { - return true - } - if err == driver.ErrDeadlineWouldBeExceeded { - return true - } - if err == topology.ErrServerSelectionTimeout { - return true - } - if _, ok := err.(topology.WaitQueueTimeoutError); ok { - return true - } - if ce, ok := err.(CommandError); ok && ce.IsMaxTimeMSExpiredError() { + // Check if the error chain contains any of the timeout error values. + for _, target := range timeoutErrs { + if errors.Is(err, target) { return true } - if we, ok := err.(WriteException); ok && we.WriteConcernError != nil && - we.WriteConcernError.IsMaxTimeMSExpiredError() { + } + + // Check if the error chain contains any error types that can indicate + // timeout. + if errors.As(err, &topology.WaitQueueTimeoutError{}) { + return true + } + if ce := (CommandError{}); errors.As(err, &ce) && ce.IsMaxTimeMSExpiredError() { + return true + } + if we := (WriteException{}); errors.As(err, &we) && we.WriteConcernError != nil && we.WriteConcernError.IsMaxTimeMSExpiredError() { + return true + } + if ne := net.Error(nil); errors.As(err, &ne) { + return ne.Timeout() + } + // Check timeout error labels. + if le := LabeledError(nil); errors.As(err, &le) { + if le.HasErrorLabel("NetworkTimeoutError") || le.HasErrorLabel("ExceededTimeLimitError") { return true } - if ne, ok := err.(net.Error); ok { - return ne.Timeout() - } - //timeout error labels - if le, ok := err.(LabeledError); ok { - if le.HasErrorLabel("NetworkTimeoutError") || le.HasErrorLabel("ExceededTimeLimitError") { - return true - } - } } return false diff --git a/mongo/integration/client_test.go b/mongo/integration/client_test.go index 914ca863b7..038ed25d72 100644 --- a/mongo/integration/client_test.go +++ b/mongo/integration/client_test.go @@ -126,8 +126,6 @@ func TestClient(t *testing.T) { "expected security field to be type %v, got %v", bson.TypeMaxKey, security.Type) _, found := security.Document().LookupErr("SSLServerSubjectName") assert.Nil(mt, found, "SSLServerSubjectName not found in result") - _, found = security.Document().LookupErr("SSLServerHasCertificateAuthority") - assert.Nil(mt, found, "SSLServerHasCertificateAuthority not found in result") }) mt.RunOpts("x509", mtest.NewOptions().Auth(true).SSL(true), func(mt *mtest.T) { testCases := []struct { @@ -711,7 +709,7 @@ func TestClient(t *testing.T) { err := mt.Client.Ping(ctx, nil) cancel() assert.NotNil(mt, err, "expected Ping to return an error") - assert.True(mt, mongo.IsTimeout(err), "expected a timeout error: got %v", err) + assert.True(mt, mongo.IsTimeout(err), "expected a timeout error, got: %v", err) } // Assert that the Ping timeouts result in no connections being closed. @@ -733,8 +731,8 @@ func TestClient(t *testing.T) { pair := msgPairs[0] assert.Equal(mt, handshake.LegacyHello, pair.CommandName, "expected command name %s at index 0, got %s", handshake.LegacyHello, pair.CommandName) - assert.Equal(mt, wiremessage.OpMsg, pair.Sent.OpCode, - "expected 'OP_MSG' OpCode in wire message, got %q", pair.Sent.OpCode.String()) + assert.Equal(mt, wiremessage.OpQuery, pair.Sent.OpCode, + "expected 'OP_QUERY' OpCode in wire message, got %q", pair.Sent.OpCode.String()) // Look for a saslContinue in the remaining proxied messages and assert that it uses the OP_MSG OpCode, as wire // version is now known to be >= 6. diff --git a/mongo/integration/mtest/sent_message.go b/mongo/integration/mtest/sent_message.go index d36075bf81..6b96e061bc 100644 --- a/mongo/integration/mtest/sent_message.go +++ b/mongo/integration/mtest/sent_message.go @@ -37,6 +37,8 @@ type sentMsgParseFn func([]byte) (*SentMessage, error) func getSentMessageParser(opcode wiremessage.OpCode) (sentMsgParseFn, bool) { switch opcode { + case wiremessage.OpQuery: + return parseOpQuery, true case wiremessage.OpMsg: return parseSentOpMsg, true case wiremessage.OpCompressed: @@ -46,6 +48,69 @@ func getSentMessageParser(opcode wiremessage.OpCode) (sentMsgParseFn, bool) { } } +func parseOpQuery(wm []byte) (*SentMessage, error) { + var ok bool + + if _, wm, ok = wiremessage.ReadQueryFlags(wm); !ok { + return nil, errors.New("failed to read query flags") + } + if _, wm, ok = wiremessage.ReadQueryFullCollectionName(wm); !ok { + return nil, errors.New("failed to read full collection name") + } + if _, wm, ok = wiremessage.ReadQueryNumberToSkip(wm); !ok { + return nil, errors.New("failed to read number to skip") + } + if _, wm, ok = wiremessage.ReadQueryNumberToReturn(wm); !ok { + return nil, errors.New("failed to read number to return") + } + + query, wm, ok := wiremessage.ReadQueryQuery(wm) + if !ok { + return nil, errors.New("failed to read query") + } + + // If there is no read preference document, the command document is query. + // Otherwise, query is in the format {$query: , $readPreference: }. + commandDoc := query + var rpDoc bsoncore.Document + + dollarQueryVal, err := query.LookupErr("$query") + if err == nil { + commandDoc = dollarQueryVal.Document() + + rpVal, err := query.LookupErr("$readPreference") + if err != nil { + return nil, fmt.Errorf("query %s contains $query but not $readPreference fields", query) + } + rpDoc = rpVal.Document() + } + + // For OP_QUERY, inserts, updates, and deletes are sent as a BSON array of documents inside the main command + // document. Pull these sequences out into an ArrayStyle DocumentSequence. + var docSequence *bsoncore.DocumentSequence + cmdElems, _ := commandDoc.Elements() + for _, elem := range cmdElems { + switch elem.Key() { + case "documents", "updates", "deletes": + docSequence = &bsoncore.DocumentSequence{ + Style: bsoncore.ArrayStyle, + Data: elem.Value().Array(), + } + } + if docSequence != nil { + // There can only be one of these arrays in a well-formed command, so we exit the loop once one is found. + break + } + } + + sm := &SentMessage{ + Command: commandDoc, + ReadPreference: rpDoc, + DocumentSequence: docSequence, + } + return sm, nil +} + func parseSentMessage(wm []byte) (*SentMessage, error) { // Re-assign the wire message to "remaining" so "wm" continues to point to the entire message after parsing. _, requestID, _, opcode, remaining, ok := wiremessage.ReadHeader(wm) diff --git a/x/mongo/driver/auth/speculative_scram_test.go b/x/mongo/driver/auth/speculative_scram_test.go index f2234e227c..a159891adc 100644 --- a/x/mongo/driver/auth/speculative_scram_test.go +++ b/x/mongo/driver/auth/speculative_scram_test.go @@ -93,7 +93,7 @@ func TestSpeculativeSCRAM(t *testing.T) { // Assert that the driver sent hello with the speculative authentication message. assert.Equal(t, len(tc.payloads), len(conn.Written), "expected %d wire messages to be sent, got %d", len(tc.payloads), (conn.Written)) - helloCmd, err := drivertest.GetCommandFromMsgWireMessage(<-conn.Written) + helloCmd, err := drivertest.GetCommandFromQueryWireMessage(<-conn.Written) assert.Nil(t, err, "error parsing hello command: %v", err) assertCommandName(t, helloCmd, handshake.LegacyHello) @@ -177,7 +177,7 @@ func TestSpeculativeSCRAM(t *testing.T) { assert.Equal(t, numResponses, len(conn.Written), "expected %d wire messages to be sent, got %d", numResponses, len(conn.Written)) - hello, err := drivertest.GetCommandFromMsgWireMessage(<-conn.Written) + hello, err := drivertest.GetCommandFromQueryWireMessage(<-conn.Written) assert.Nil(t, err, "error parsing hello command: %v", err) assertCommandName(t, hello, handshake.LegacyHello) _, err = hello.LookupErr("speculativeAuthenticate") diff --git a/x/mongo/driver/auth/speculative_x509_test.go b/x/mongo/driver/auth/speculative_x509_test.go index 13fdf2b185..cf46de6ffd 100644 --- a/x/mongo/driver/auth/speculative_x509_test.go +++ b/x/mongo/driver/auth/speculative_x509_test.go @@ -58,7 +58,7 @@ func TestSpeculativeX509(t *testing.T) { assert.Equal(t, numResponses, len(conn.Written), "expected %d wire messages to be sent, got %d", numResponses, len(conn.Written)) - hello, err := drivertest.GetCommandFromMsgWireMessage(<-conn.Written) + hello, err := drivertest.GetCommandFromQueryWireMessage(<-conn.Written) assert.Nil(t, err, "error parsing hello command: %v", err) assertCommandName(t, hello, handshake.LegacyHello) @@ -103,7 +103,7 @@ func TestSpeculativeX509(t *testing.T) { assert.Equal(t, numResponses, len(conn.Written), "expected %d wire messages to be sent, got %d", numResponses, len(conn.Written)) - hello, err := drivertest.GetCommandFromMsgWireMessage(<-conn.Written) + hello, err := drivertest.GetCommandFromQueryWireMessage(<-conn.Written) assert.Nil(t, err, "error parsing hello command: %v", err) assertCommandName(t, hello, handshake.LegacyHello) _, err = hello.LookupErr("speculativeAuthenticate") diff --git a/x/mongo/driver/compression.go b/x/mongo/driver/compression.go index 7f355f61a4..d79b024b74 100644 --- a/x/mongo/driver/compression.go +++ b/x/mongo/driver/compression.go @@ -26,48 +26,72 @@ type CompressionOpts struct { UncompressedSize int32 } -var zstdEncoders sync.Map // map[zstd.EncoderLevel]*zstd.Encoder +// mustZstdNewWriter creates a zstd.Encoder with the given level and a nil +// destination writer. It panics on any errors and should only be used at +// package initialization time. +func mustZstdNewWriter(lvl zstd.EncoderLevel) *zstd.Encoder { + enc, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(lvl)) + if err != nil { + panic(err) + } + return enc +} + +var zstdEncoders = [zstd.SpeedBestCompression + 1]*zstd.Encoder{ + 0: nil, // zstd.speedNotSet + zstd.SpeedFastest: mustZstdNewWriter(zstd.SpeedFastest), + zstd.SpeedDefault: mustZstdNewWriter(zstd.SpeedDefault), + zstd.SpeedBetterCompression: mustZstdNewWriter(zstd.SpeedBetterCompression), + zstd.SpeedBestCompression: mustZstdNewWriter(zstd.SpeedBestCompression), +} func getZstdEncoder(level zstd.EncoderLevel) (*zstd.Encoder, error) { - if v, ok := zstdEncoders.Load(level); ok { - return v.(*zstd.Encoder), nil - } - encoder, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(level)) - if err != nil { - return nil, err + if zstd.SpeedFastest <= level && level <= zstd.SpeedBestCompression { + return zstdEncoders[level], nil } - zstdEncoders.Store(level, encoder) - return encoder, nil + // The level is outside the expected range, return an error. + return nil, fmt.Errorf("invalid zstd compression level: %d", level) } -var zlibEncoders sync.Map // map[int /*level*/]*zlibEncoder +// zlibEncodersOffset is the offset into the zlibEncoders array for a given +// compression level. +const zlibEncodersOffset = -zlib.HuffmanOnly // HuffmanOnly == -2 + +var zlibEncoders [zlib.BestCompression + zlibEncodersOffset + 1]sync.Pool func getZlibEncoder(level int) (*zlibEncoder, error) { - if v, ok := zlibEncoders.Load(level); ok { - return v.(*zlibEncoder), nil - } - writer, err := zlib.NewWriterLevel(nil, level) - if err != nil { - return nil, err + if zlib.HuffmanOnly <= level && level <= zlib.BestCompression { + if enc, _ := zlibEncoders[level+zlibEncodersOffset].Get().(*zlibEncoder); enc != nil { + return enc, nil + } + writer, err := zlib.NewWriterLevel(nil, level) + if err != nil { + return nil, err + } + enc := &zlibEncoder{writer: writer, level: level} + return enc, nil } - encoder := &zlibEncoder{writer: writer, buf: new(bytes.Buffer)} - zlibEncoders.Store(level, encoder) + // The level is outside the expected range, return an error. + return nil, fmt.Errorf("invalid zlib compression level: %d", level) +} - return encoder, nil +func putZlibEncoder(enc *zlibEncoder) { + if enc != nil { + zlibEncoders[enc.level+zlibEncodersOffset].Put(enc) + } } type zlibEncoder struct { - mu sync.Mutex writer *zlib.Writer - buf *bytes.Buffer + buf bytes.Buffer + level int } func (e *zlibEncoder) Encode(dst, src []byte) ([]byte, error) { - e.mu.Lock() - defer e.mu.Unlock() + defer putZlibEncoder(e) e.buf.Reset() - e.writer.Reset(e.buf) + e.writer.Reset(&e.buf) _, err := e.writer.Write(src) if err != nil { @@ -105,8 +129,15 @@ func CompressPayload(in []byte, opts CompressionOpts) ([]byte, error) { } } +var zstdReaderPool = sync.Pool{ + New: func() interface{} { + r, _ := zstd.NewReader(nil) + return r + }, +} + // DecompressPayload takes a byte slice that has been compressed and undoes it according to the options passed -func DecompressPayload(in []byte, opts CompressionOpts) (uncompressed []byte, err error) { +func DecompressPayload(in []byte, opts CompressionOpts) ([]byte, error) { switch opts.Compressor { case wiremessage.CompressorNoOp: return in, nil @@ -117,34 +148,29 @@ func DecompressPayload(in []byte, opts CompressionOpts) (uncompressed []byte, er } else if int32(l) != opts.UncompressedSize { return nil, fmt.Errorf("unexpected decompression size, expected %v but got %v", opts.UncompressedSize, l) } - uncompressed = make([]byte, opts.UncompressedSize) - return snappy.Decode(uncompressed, in) + out := make([]byte, opts.UncompressedSize) + return snappy.Decode(out, in) case wiremessage.CompressorZLib: r, err := zlib.NewReader(bytes.NewReader(in)) if err != nil { return nil, err } - defer func() { - err = r.Close() - }() - uncompressed = make([]byte, opts.UncompressedSize) - _, err = io.ReadFull(r, uncompressed) - if err != nil { + out := make([]byte, opts.UncompressedSize) + if _, err := io.ReadFull(r, out); err != nil { return nil, err } - return uncompressed, nil - case wiremessage.CompressorZstd: - r, err := zstd.NewReader(bytes.NewBuffer(in)) - if err != nil { - return nil, err - } - defer r.Close() - uncompressed = make([]byte, opts.UncompressedSize) - _, err = io.ReadFull(r, uncompressed) - if err != nil { + if err := r.Close(); err != nil { return nil, err } - return uncompressed, nil + return out, nil + case wiremessage.CompressorZstd: + buf := make([]byte, 0, opts.UncompressedSize) + // Using a pool here is about ~20% faster + // than using a single global zstd.Reader + r := zstdReaderPool.Get().(*zstd.Decoder) + out, err := r.DecodeAll(in, buf) + zstdReaderPool.Put(r) + return out, err default: return nil, fmt.Errorf("unknown compressor ID %v", opts.Compressor) } diff --git a/x/mongo/driver/compression_test.go b/x/mongo/driver/compression_test.go index b477cb32c1..75a7ff072b 100644 --- a/x/mongo/driver/compression_test.go +++ b/x/mongo/driver/compression_test.go @@ -7,9 +7,14 @@ package driver import ( + "bytes" + "compress/zlib" "os" "testing" + "github.com/golang/snappy" + "github.com/klauspost/compress/zstd" + "go.mongodb.org/mongo-driver/internal/assert" "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" ) @@ -41,6 +46,43 @@ func TestCompression(t *testing.T) { } } +func TestCompressionLevels(t *testing.T) { + in := []byte("abc") + wr := new(bytes.Buffer) + + t.Run("ZLib", func(t *testing.T) { + opts := CompressionOpts{ + Compressor: wiremessage.CompressorZLib, + } + for lvl := zlib.HuffmanOnly - 2; lvl < zlib.BestCompression+2; lvl++ { + opts.ZlibLevel = lvl + _, err1 := CompressPayload(in, opts) + _, err2 := zlib.NewWriterLevel(wr, lvl) + if err2 != nil { + assert.Error(t, err1, "expected an error for ZLib level %d", lvl) + } else { + assert.NoError(t, err1, "unexpected error for ZLib level %d", lvl) + } + } + }) + + t.Run("Zstd", func(t *testing.T) { + opts := CompressionOpts{ + Compressor: wiremessage.CompressorZstd, + } + for lvl := zstd.SpeedFastest - 2; lvl < zstd.SpeedBestCompression+2; lvl++ { + opts.ZstdLevel = int(lvl) + _, err1 := CompressPayload(in, opts) + _, err2 := zstd.NewWriter(wr, zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(opts.ZstdLevel))) + if err2 != nil { + assert.Error(t, err1, "expected an error for Zstd level %d", lvl) + } else { + assert.NoError(t, err1, "unexpected error for Zstd level %d", lvl) + } + } + }) +} + func TestDecompressFailures(t *testing.T) { t.Parallel() @@ -62,18 +104,57 @@ func TestDecompressFailures(t *testing.T) { }) } -func BenchmarkCompressPayload(b *testing.B) { - payload := func() []byte { - buf, err := os.ReadFile("compression.go") +var ( + compressionPayload []byte + compressedSnappyPayload []byte + compressedZLibPayload []byte + compressedZstdPayload []byte +) + +func initCompressionPayload(b *testing.B) { + if compressionPayload != nil { + return + } + data, err := os.ReadFile("testdata/compression.go") + if err != nil { + b.Fatal(err) + } + for i := 1; i < 10; i++ { + data = append(data, data...) + } + compressionPayload = data + + compressedSnappyPayload = snappy.Encode(compressedSnappyPayload[:0], data) + + { + var buf bytes.Buffer + enc, err := zstd.NewWriter(&buf, zstd.WithEncoderLevel(zstd.SpeedDefault)) if err != nil { - b.Log(err) - b.FailNow() + b.Fatal(err) } - for i := 1; i < 10; i++ { - buf = append(buf, buf...) + compressedZstdPayload = enc.EncodeAll(data, nil) + } + + { + var buf bytes.Buffer + enc := zlib.NewWriter(&buf) + if _, err := enc.Write(data); err != nil { + b.Fatal(err) } - return buf - }() + if err := enc.Close(); err != nil { + b.Fatal(err) + } + if err := enc.Close(); err != nil { + b.Fatal(err) + } + compressedZLibPayload = append(compressedZLibPayload[:0], buf.Bytes()...) + } + + b.ResetTimer() +} + +func BenchmarkCompressPayload(b *testing.B) { + initCompressionPayload(b) compressors := []wiremessage.CompressorID{ wiremessage.CompressorSnappy, @@ -88,6 +169,9 @@ func BenchmarkCompressPayload(b *testing.B) { ZlibLevel: wiremessage.DefaultZlibLevel, ZstdLevel: wiremessage.DefaultZstdLevel, } + payload := compressionPayload + b.SetBytes(int64(len(payload))) + b.ReportAllocs() b.RunParallel(func(pb *testing.PB) { for pb.Next() { _, err := CompressPayload(payload, opts) @@ -99,3 +183,38 @@ func BenchmarkCompressPayload(b *testing.B) { }) } } + +func BenchmarkDecompressPayload(b *testing.B) { + initCompressionPayload(b) + + benchmarks := []struct { + compressor wiremessage.CompressorID + payload []byte + }{ + {wiremessage.CompressorSnappy, compressedSnappyPayload}, + {wiremessage.CompressorZLib, compressedZLibPayload}, + {wiremessage.CompressorZstd, compressedZstdPayload}, + } + + for _, bench := range benchmarks { + b.Run(bench.compressor.String(), func(b *testing.B) { + opts := CompressionOpts{ + Compressor: bench.compressor, + ZlibLevel: wiremessage.DefaultZlibLevel, + ZstdLevel: wiremessage.DefaultZstdLevel, + UncompressedSize: int32(len(compressionPayload)), + } + payload := bench.payload + b.SetBytes(int64(len(compressionPayload))) + b.ReportAllocs() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _, err := DecompressPayload(payload, opts) + if err != nil { + b.Fatal(err) + } + } + }) + }) + } +} diff --git a/x/mongo/driver/drivertest/channel_conn.go b/x/mongo/driver/drivertest/channel_conn.go index d2ae8df248..27be4c264d 100644 --- a/x/mongo/driver/drivertest/channel_conn.go +++ b/x/mongo/driver/drivertest/channel_conn.go @@ -99,6 +99,38 @@ func MakeReply(doc bsoncore.Document) []byte { return bsoncore.UpdateLength(dst, idx, int32(len(dst[idx:]))) } +// GetCommandFromQueryWireMessage returns the command sent in an OP_QUERY wire message. +func GetCommandFromQueryWireMessage(wm []byte) (bsoncore.Document, error) { + var ok bool + _, _, _, _, wm, ok = wiremessage.ReadHeader(wm) + if !ok { + return nil, errors.New("could not read header") + } + _, wm, ok = wiremessage.ReadQueryFlags(wm) + if !ok { + return nil, errors.New("could not read flags") + } + _, wm, ok = wiremessage.ReadQueryFullCollectionName(wm) + if !ok { + return nil, errors.New("could not read fullCollectionName") + } + _, wm, ok = wiremessage.ReadQueryNumberToSkip(wm) + if !ok { + return nil, errors.New("could not read numberToSkip") + } + _, wm, ok = wiremessage.ReadQueryNumberToReturn(wm) + if !ok { + return nil, errors.New("could not read numberToReturn") + } + + var query bsoncore.Document + query, wm, ok = wiremessage.ReadQueryQuery(wm) + if !ok { + return nil, errors.New("could not read query") + } + return query, nil +} + // GetCommandFromMsgWireMessage returns the command document sent in an OP_MSG wire message. func GetCommandFromMsgWireMessage(wm []byte) (bsoncore.Document, error) { var ok bool diff --git a/x/mongo/driver/errors.go b/x/mongo/driver/errors.go index 55f2fb37eb..177aa1234b 100644 --- a/x/mongo/driver/errors.go +++ b/x/mongo/driver/errors.go @@ -264,10 +264,15 @@ func (e Error) UnsupportedStorageEngine() bool { // Error implements the error interface. func (e Error) Error() string { + var msg string if e.Name != "" { - return fmt.Sprintf("(%v) %v", e.Name, e.Message) + msg = fmt.Sprintf("(%v)", e.Name) } - return e.Message + msg += " " + e.Message + if e.Wrapped != nil { + msg += ": " + e.Wrapped.Error() + } + return msg } // Unwrap returns the underlying error. diff --git a/x/mongo/driver/legacy.go b/x/mongo/driver/legacy.go index 9f3b8a39ac..c40f1f8091 100644 --- a/x/mongo/driver/legacy.go +++ b/x/mongo/driver/legacy.go @@ -19,4 +19,5 @@ const ( LegacyKillCursors LegacyListCollections LegacyListIndexes + LegacyHandshake ) diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index 4c759c1505..6b56191a01 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -19,6 +19,7 @@ import ( "time" "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/bsontype" "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/csot" @@ -383,7 +384,11 @@ func (oss *opServerSelector) SelectServer( } // selectServer handles performing server selection for an operation. -func (op Operation) selectServer(ctx context.Context, deprioritized []description.Server) (Server, error) { +func (op Operation) selectServer( + ctx context.Context, + requestID int32, + deprioritized []description.Server, +) (Server, error) { if err := op.Validate(); err != nil { return nil, err } @@ -406,7 +411,7 @@ func (op Operation) selectServer(ctx context.Context, deprioritized []descriptio } ctx = logger.WithOperationName(ctx, op.Name) - ctx = logger.WithOperationID(ctx, wiremessage.CurrentRequestID()) + ctx = logger.WithOperationID(ctx, requestID) return op.Deployment.SelectServer(ctx, oss) } @@ -414,9 +419,10 @@ func (op Operation) selectServer(ctx context.Context, deprioritized []descriptio // getServerAndConnection should be used to retrieve a Server and Connection to execute an operation. func (op Operation) getServerAndConnection( ctx context.Context, + requestID int32, deprioritized []description.Server, ) (Server, Connection, error) { - server, err := op.selectServer(ctx, deprioritized) + server, err := op.selectServer(ctx, requestID, deprioritized) if err != nil { if op.Client != nil && !(op.Client.Committing || op.Client.Aborting) && op.Client.TransactionRunning() { @@ -611,11 +617,11 @@ func (op Operation) Execute(ctx context.Context) error { } }() for { - wiremessage.NextRequestID() + requestID := wiremessage.NextRequestID() // If the server or connection are nil, try to select a new server and get a new connection. if srvr == nil || conn == nil { - srvr, conn, err = op.getServerAndConnection(ctx, deprioritizedServers) + srvr, conn, err = op.getServerAndConnection(ctx, requestID, deprioritizedServers) if err != nil { // If the returned error is retryable and there are retries remaining (negative // retries means retry indefinitely), then retry the operation. Set the server @@ -710,7 +716,7 @@ func (op Operation) Execute(ctx context.Context) error { } var startedInfo startedInformation - *wm, startedInfo, err = op.createMsgWireMessage(ctx, maxTimeMS, (*wm)[:0], desc, conn) + *wm, startedInfo, err = op.createWireMessage(ctx, maxTimeMS, (*wm)[:0], desc, conn, requestID) if err != nil { return err @@ -1184,8 +1190,92 @@ func (op Operation) addBatchArray(dst []byte) []byte { return dst } -func (op Operation) createMsgWireMessage(ctx context.Context, maxTimeMS uint64, dst []byte, desc description.SelectedServer, +func (op Operation) createLegacyHandshakeWireMessage( + maxTimeMS uint64, + dst []byte, + desc description.SelectedServer, +) ([]byte, startedInformation, error) { + var info startedInformation + flags := op.secondaryOK(desc) + var wmindex int32 + info.requestID = wiremessage.NextRequestID() + wmindex, dst = wiremessage.AppendHeaderStart(dst, info.requestID, 0, wiremessage.OpQuery) + dst = wiremessage.AppendQueryFlags(dst, flags) + + dollarCmd := [...]byte{'.', '$', 'c', 'm', 'd'} + + // FullCollectionName + dst = append(dst, op.Database...) + dst = append(dst, dollarCmd[:]...) + dst = append(dst, 0x00) + dst = wiremessage.AppendQueryNumberToSkip(dst, 0) + dst = wiremessage.AppendQueryNumberToReturn(dst, -1) + + wrapper := int32(-1) + rp, err := op.createReadPref(desc, true) + if err != nil { + return dst, info, err + } + if len(rp) > 0 { + wrapper, dst = bsoncore.AppendDocumentStart(dst) + dst = bsoncore.AppendHeader(dst, bsontype.EmbeddedDocument, "$query") + } + idx, dst := bsoncore.AppendDocumentStart(dst) + dst, err = op.CommandFn(dst, desc) + if err != nil { + return dst, info, err + } + + if op.Batches != nil && len(op.Batches.Current) > 0 { + dst = op.addBatchArray(dst) + } + + dst, err = op.addReadConcern(dst, desc) + if err != nil { + return dst, info, err + } + + dst, err = op.addWriteConcern(dst, desc) + if err != nil { + return dst, info, err + } + + dst, err = op.addSession(dst, desc) + if err != nil { + return dst, info, err + } + + dst = op.addClusterTime(dst, desc) + dst = op.addServerAPI(dst) + // If maxTimeMS is greater than 0 append it to wire message. A maxTimeMS value of 0 only explicitly + // specifies the default behavior of no timeout server-side. + if maxTimeMS > 0 { + dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", int64(maxTimeMS)) + } + + dst, _ = bsoncore.AppendDocumentEnd(dst, idx) + // Command monitoring only reports the document inside $query + info.cmd = dst[idx:] + + if len(rp) > 0 { + var err error + dst = bsoncore.AppendDocumentElement(dst, "$readPreference", rp) + dst, err = bsoncore.AppendDocumentEnd(dst, wrapper) + if err != nil { + return dst, info, err + } + } + + return bsoncore.UpdateLength(dst, wmindex, int32(len(dst[wmindex:]))), info, nil +} + +func (op Operation) createMsgWireMessage( + ctx context.Context, + maxTimeMS uint64, + dst []byte, + desc description.SelectedServer, conn Connection, + requestID int32, ) ([]byte, startedInformation, error) { var info startedInformation var flags wiremessage.MsgFlag @@ -1201,7 +1291,7 @@ func (op Operation) createMsgWireMessage(ctx context.Context, maxTimeMS uint64, flags |= wiremessage.ExhaustAllowed } - info.requestID = wiremessage.CurrentRequestID() + info.requestID = requestID wmindex, dst = wiremessage.AppendHeaderStart(dst, info.requestID, 0, wiremessage.OpMsg) dst = wiremessage.AppendMsgFlags(dst, flags) // Body @@ -1267,6 +1357,29 @@ func (op Operation) createMsgWireMessage(ctx context.Context, maxTimeMS uint64, return bsoncore.UpdateLength(dst, wmindex, int32(len(dst[wmindex:]))), info, nil } +// isLegacyHandshake returns True if the operation is the first message of +// the initial handshake and should use a legacy hello. +func isLegacyHandshake(op Operation, desc description.SelectedServer) bool { + isInitialHandshake := desc.WireVersion == nil || desc.WireVersion.Max == 0 + + return op.Legacy == LegacyHandshake && isInitialHandshake +} + +func (op Operation) createWireMessage( + ctx context.Context, + maxTimeMS uint64, + dst []byte, + desc description.SelectedServer, + conn Connection, + requestID int32, +) ([]byte, startedInformation, error) { + if isLegacyHandshake(op, desc) { + return op.createLegacyHandshakeWireMessage(maxTimeMS, dst, desc) + } + + return op.createMsgWireMessage(ctx, maxTimeMS, dst, desc, conn, requestID) +} + // addCommandFields adds the fields for a command to the wire message in dst. This assumes that the start of the document // has already been added and does not add the final 0 byte. func (op Operation) addCommandFields(ctx context.Context, dst []byte, desc description.SelectedServer) ([]byte, error) { diff --git a/x/mongo/driver/operation/hello.go b/x/mongo/driver/operation/hello.go index 3cfa2d450a..16d5809130 100644 --- a/x/mongo/driver/operation/hello.go +++ b/x/mongo/driver/operation/hello.go @@ -537,8 +537,16 @@ func (h *Hello) StreamResponse(ctx context.Context, conn driver.StreamerConnecti return h.createOperation().ExecuteExhaust(ctx, conn) } +// isLegacyHandshake returns True if server API version is not requested and +// loadBalanced is False. If this is the case, then the drivers MUST use legacy +// hello for the first message of the initial handshake with the OP_QUERY +// protocol +func isLegacyHandshake(srvAPI *driver.ServerAPIOptions, deployment driver.Deployment) bool { + return srvAPI == nil && deployment.Kind() != description.LoadBalanced +} + func (h *Hello) createOperation() driver.Operation { - return driver.Operation{ + op := driver.Operation{ Clock: h.clock, CommandFn: h.command, Database: "admin", @@ -549,23 +557,36 @@ func (h *Hello) createOperation() driver.Operation { }, ServerAPI: h.serverAPI, } + + if isLegacyHandshake(h.serverAPI, h.d) { + op.Legacy = driver.LegacyHandshake + } + + return op } // GetHandshakeInformation performs the MongoDB handshake for the provided connection and returns the relevant // information about the server. This function implements the driver.Handshaker interface. func (h *Hello) GetHandshakeInformation(ctx context.Context, _ address.Address, c driver.Connection) (driver.HandshakeInformation, error) { - err := driver.Operation{ + deployment := driver.SingleConnectionDeployment{C: c} + + op := driver.Operation{ Clock: h.clock, CommandFn: h.handshakeCommand, - Deployment: driver.SingleConnectionDeployment{C: c}, + Deployment: deployment, Database: "admin", ProcessResponseFn: func(info driver.ResponseInfo) error { h.res = info.ServerResponse return nil }, ServerAPI: h.serverAPI, - }.Execute(ctx) - if err != nil { + } + + if isLegacyHandshake(h.serverAPI, deployment) { + op.Legacy = driver.LegacyHandshake + } + + if err := op.Execute(ctx); err != nil { return driver.HandshakeInformation{}, err } @@ -578,6 +599,9 @@ func (h *Hello) GetHandshakeInformation(ctx context.Context, _ address.Address, if serverConnectionID, ok := h.res.Lookup("connectionId").AsInt64OK(); ok { info.ServerConnectionID = &serverConnectionID } + + var err error + // Cast to bson.Raw to lookup saslSupportedMechs to avoid converting from bsoncore.Value to bson.RawValue for the // StringSliceFromRawValue call. if saslSupportedMechs, lookupErr := bson.Raw(h.res).LookupErr("saslSupportedMechs"); lookupErr == nil { diff --git a/x/mongo/driver/operation_test.go b/x/mongo/driver/operation_test.go index c308f97241..e6c9d4cf95 100644 --- a/x/mongo/driver/operation_test.go +++ b/x/mongo/driver/operation_test.go @@ -63,7 +63,7 @@ func TestOperation(t *testing.T) { t.Run("selectServer", func(t *testing.T) { t.Run("returns validation error", func(t *testing.T) { op := &Operation{} - _, err := op.selectServer(context.Background(), nil) + _, err := op.selectServer(context.Background(), 1, nil) if err == nil { t.Error("Expected a validation error from selectServer, but got ") } @@ -77,7 +77,7 @@ func TestOperation(t *testing.T) { Database: "testing", Selector: want, } - _, err := op.selectServer(context.Background(), nil) + _, err := op.selectServer(context.Background(), 1, nil) noerr(t, err) // Assert the the selector is an operation selector wrapper. @@ -95,7 +95,7 @@ func TestOperation(t *testing.T) { Deployment: d, Database: "testing", } - _, err := op.selectServer(context.Background(), nil) + _, err := op.selectServer(context.Background(), 1, nil) noerr(t, err) if d.params.selector == nil { t.Error("The selectServer method should use a default selector when not specified on Operation, but it passed .") @@ -657,7 +657,8 @@ func TestOperation(t *testing.T) { } func createExhaustServerResponse(response bsoncore.Document, moreToCome bool) []byte { - idx, wm := wiremessage.AppendHeaderStart(nil, 0, wiremessage.CurrentRequestID()+1, wiremessage.OpMsg) + const psuedoRequestID = 1 + idx, wm := wiremessage.AppendHeaderStart(nil, 0, psuedoRequestID, wiremessage.OpMsg) var flags wiremessage.MsgFlag if moreToCome { flags = wiremessage.MoreToCome diff --git a/x/mongo/driver/testdata/compression.go b/x/mongo/driver/testdata/compression.go new file mode 100644 index 0000000000..7f355f61a4 --- /dev/null +++ b/x/mongo/driver/testdata/compression.go @@ -0,0 +1,151 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package driver + +import ( + "bytes" + "compress/zlib" + "fmt" + "io" + "sync" + + "github.com/golang/snappy" + "github.com/klauspost/compress/zstd" + "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" +) + +// CompressionOpts holds settings for how to compress a payload +type CompressionOpts struct { + Compressor wiremessage.CompressorID + ZlibLevel int + ZstdLevel int + UncompressedSize int32 +} + +var zstdEncoders sync.Map // map[zstd.EncoderLevel]*zstd.Encoder + +func getZstdEncoder(level zstd.EncoderLevel) (*zstd.Encoder, error) { + if v, ok := zstdEncoders.Load(level); ok { + return v.(*zstd.Encoder), nil + } + encoder, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(level)) + if err != nil { + return nil, err + } + zstdEncoders.Store(level, encoder) + return encoder, nil +} + +var zlibEncoders sync.Map // map[int /*level*/]*zlibEncoder + +func getZlibEncoder(level int) (*zlibEncoder, error) { + if v, ok := zlibEncoders.Load(level); ok { + return v.(*zlibEncoder), nil + } + writer, err := zlib.NewWriterLevel(nil, level) + if err != nil { + return nil, err + } + encoder := &zlibEncoder{writer: writer, buf: new(bytes.Buffer)} + zlibEncoders.Store(level, encoder) + + return encoder, nil +} + +type zlibEncoder struct { + mu sync.Mutex + writer *zlib.Writer + buf *bytes.Buffer +} + +func (e *zlibEncoder) Encode(dst, src []byte) ([]byte, error) { + e.mu.Lock() + defer e.mu.Unlock() + + e.buf.Reset() + e.writer.Reset(e.buf) + + _, err := e.writer.Write(src) + if err != nil { + return nil, err + } + err = e.writer.Close() + if err != nil { + return nil, err + } + dst = append(dst[:0], e.buf.Bytes()...) + return dst, nil +} + +// CompressPayload takes a byte slice and compresses it according to the options passed +func CompressPayload(in []byte, opts CompressionOpts) ([]byte, error) { + switch opts.Compressor { + case wiremessage.CompressorNoOp: + return in, nil + case wiremessage.CompressorSnappy: + return snappy.Encode(nil, in), nil + case wiremessage.CompressorZLib: + encoder, err := getZlibEncoder(opts.ZlibLevel) + if err != nil { + return nil, err + } + return encoder.Encode(nil, in) + case wiremessage.CompressorZstd: + encoder, err := getZstdEncoder(zstd.EncoderLevelFromZstd(opts.ZstdLevel)) + if err != nil { + return nil, err + } + return encoder.EncodeAll(in, nil), nil + default: + return nil, fmt.Errorf("unknown compressor ID %v", opts.Compressor) + } +} + +// DecompressPayload takes a byte slice that has been compressed and undoes it according to the options passed +func DecompressPayload(in []byte, opts CompressionOpts) (uncompressed []byte, err error) { + switch opts.Compressor { + case wiremessage.CompressorNoOp: + return in, nil + case wiremessage.CompressorSnappy: + l, err := snappy.DecodedLen(in) + if err != nil { + return nil, fmt.Errorf("decoding compressed length %w", err) + } else if int32(l) != opts.UncompressedSize { + return nil, fmt.Errorf("unexpected decompression size, expected %v but got %v", opts.UncompressedSize, l) + } + uncompressed = make([]byte, opts.UncompressedSize) + return snappy.Decode(uncompressed, in) + case wiremessage.CompressorZLib: + r, err := zlib.NewReader(bytes.NewReader(in)) + if err != nil { + return nil, err + } + defer func() { + err = r.Close() + }() + uncompressed = make([]byte, opts.UncompressedSize) + _, err = io.ReadFull(r, uncompressed) + if err != nil { + return nil, err + } + return uncompressed, nil + case wiremessage.CompressorZstd: + r, err := zstd.NewReader(bytes.NewBuffer(in)) + if err != nil { + return nil, err + } + defer r.Close() + uncompressed = make([]byte, opts.UncompressedSize) + _, err = io.ReadFull(r, uncompressed) + if err != nil { + return nil, err + } + return uncompressed, nil + default: + return nil, fmt.Errorf("unknown compressor ID %v", opts.Compressor) + } +} diff --git a/x/mongo/driver/topology/server_test.go b/x/mongo/driver/topology/server_test.go index a2abd1fb1f..ba92b6dd94 100644 --- a/x/mongo/driver/topology/server_test.go +++ b/x/mongo/driver/topology/server_test.go @@ -31,9 +31,11 @@ import ( "go.mongodb.org/mongo-driver/internal/require" "go.mongodb.org/mongo-driver/mongo/address" "go.mongodb.org/mongo-driver/mongo/description" + "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driver/auth" "go.mongodb.org/mongo-driver/x/mongo/driver/drivertest" + "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" ) type channelNetConnDialer struct{} @@ -1207,12 +1209,41 @@ func TestServer_ProcessError(t *testing.T) { func includesClientMetadata(t *testing.T, wm []byte) bool { t.Helper() - doc, err := drivertest.GetCommandFromMsgWireMessage(wm) - assert.NoError(t, err) + var ok bool + _, _, _, _, wm, ok = wiremessage.ReadHeader(wm) + if !ok { + t.Fatal("could not read header") + } + _, wm, ok = wiremessage.ReadQueryFlags(wm) + if !ok { + t.Fatal("could not read flags") + } + _, wm, ok = wiremessage.ReadQueryFullCollectionName(wm) + if !ok { + t.Fatal("could not read fullCollectionName") + } + _, wm, ok = wiremessage.ReadQueryNumberToSkip(wm) + if !ok { + t.Fatal("could not read numberToSkip") + } + _, wm, ok = wiremessage.ReadQueryNumberToReturn(wm) + if !ok { + t.Fatal("could not read numberToReturn") + } + var query bsoncore.Document + query, wm, ok = wiremessage.ReadQueryQuery(wm) + if !ok { + t.Fatal("could not read query") + } - _, err = doc.LookupErr("client") + if _, err := query.LookupErr("client"); err == nil { + return true + } + if _, err := query.LookupErr("$query", "client"); err == nil { + return true + } - return err == nil + return false } // processErrorTestConn is a driver.Connection implementation used by tests diff --git a/x/mongo/driver/wiremessage/wiremessage.go b/x/mongo/driver/wiremessage/wiremessage.go index c4d2567bf0..abf09c15bd 100644 --- a/x/mongo/driver/wiremessage/wiremessage.go +++ b/x/mongo/driver/wiremessage/wiremessage.go @@ -19,9 +19,6 @@ type WireMessage []byte var globalRequestID int32 -// CurrentRequestID returns the current request ID. -func CurrentRequestID() int32 { return atomic.LoadInt32(&globalRequestID) } - // NextRequestID returns the next request ID. func NextRequestID() int32 { return atomic.AddInt32(&globalRequestID, 1) }