diff --git a/.evergreen/config.yml b/.evergreen/config.yml index db6b51f680..6e71fd660f 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -141,8 +141,8 @@ functions: export UPLOAD_BUCKET="$UPLOAD_BUCKET" export PROJECT="$PROJECT" export TMPDIR="$MONGO_ORCHESTRATION_HOME/db" - export PKG_CONFIG_PATH=$(pwd)/install/libmongocrypt/lib/pkgconfig:$(pwd)/install/mongo-c-driver/lib/pkgconfig - export LD_LIBRARY_PATH=$(pwd)/install/libmongocrypt/lib + export PKG_CONFIG_PATH=$(pwd)/install/libmongocrypt/lib64/pkgconfig:$(pwd)/install/mongo-c-driver/lib/pkgconfig + export LD_LIBRARY_PATH=$(pwd)/install/libmongocrypt/lib64 export PATH="$PATH" EOT # See what we variables we've set. @@ -245,6 +245,7 @@ functions: params: shell: "bash" script: | + set -x ${PREPARE_SHELL} MONGODB_VERSION=${VERSION} \ @@ -437,37 +438,18 @@ functions: make -s evg-test-enterprise-auth run-atlas-test: + - command: ec2.assume_role + params: + role_arn: "${aws_test_secrets_role}" - command: shell.exec type: test params: shell: "bash" working_dir: src/go.mongodb.org/mongo-driver + include_expansions_in_env: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"] script: | - # DO NOT ECHO WITH XTRACE - if [ "Windows_NT" = "$OS" ]; then - export GOPATH=$(cygpath -w $(dirname $(dirname $(dirname `pwd`)))) - export GOCACHE=$(cygpath -w "$(pwd)/.cache") - else - export GOPATH=$(dirname $(dirname $(dirname `pwd`))) - export GOCACHE="$(pwd)/.cache" - fi; - export GOPATH="$GOPATH" - export GOROOT="${GO_DIST}" - export GOCACHE="$GOCACHE" - export PATH="${GCC_PATH}:${GO_DIST}/bin:$PATH" - export ATLAS_FREE="${atlas_free_tier_uri}" - export ATLAS_REPLSET="${atlas_replica_set_uri}" - export ATLAS_SHARD="${atlas_sharded_uri}" - export ATLAS_TLS11="${atlas_tls_v11_uri}" - export ATLAS_TLS12="${atlas_tls_v12_uri}" - export ATLAS_FREE_SRV="${atlas_free_tier_uri_srv}" - export ATLAS_REPLSET_SRV="${atlas_replica_set_uri_srv}" - export ATLAS_SHARD_SRV="${atlas_sharded_uri_srv}" - export ATLAS_TLS11_SRV="${atlas_tls_v11_uri_srv}" - export ATLAS_TLS12_SRV="${atlas_tls_v12_uri_srv}" - export ATLAS_SERVERLESS="${atlas_serverless_uri}" - export ATLAS_SERVERLESS_SRV="${atlas_serverless_uri_srv}" - make -s evg-test-atlas + ${PREPARE_SHELL} + bash etc/run-atlas-test.sh run-ocsp-test: - command: shell.exec @@ -2101,7 +2083,7 @@ tasks: export GCPKMS_PROJECT=${GCPKMS_PROJECT} export GCPKMS_ZONE=${GCPKMS_ZONE} export GCPKMS_INSTANCENAME=${GCPKMS_INSTANCENAME} - tar czf testgcpkms.tgz ./testkms ./install/libmongocrypt/lib/libmongocrypt.* + tar czf testgcpkms.tgz ./testkms ./install/libmongocrypt/lib64/libmongocrypt.* GCPKMS_SRC=testgcpkms.tgz GCPKMS_DST=$GCPKMS_INSTANCENAME: $DRIVERS_TOOLS/.evergreen/csfle/gcpkms/copy-file.sh echo "Copying files ... end" @@ -2120,7 +2102,7 @@ tasks: export GCPKMS_PROJECT=${GCPKMS_PROJECT} export GCPKMS_ZONE=${GCPKMS_ZONE} export GCPKMS_INSTANCENAME=${GCPKMS_INSTANCENAME} - GCPKMS_CMD="LD_LIBRARY_PATH=./install/libmongocrypt/lib MONGODB_URI='mongodb://localhost:27017' PROVIDER='gcp' ./testkms" $DRIVERS_TOOLS/.evergreen/csfle/gcpkms/run-command.sh + GCPKMS_CMD="LD_LIBRARY_PATH=./install/libmongocrypt/lib64 MONGODB_URI='mongodb://localhost:27017' PROVIDER='gcp' ./testkms" $DRIVERS_TOOLS/.evergreen/csfle/gcpkms/run-command.sh - name: "testgcpkms-fail-task" # testgcpkms-fail-task runs in a non-GCE environment. @@ -2138,7 +2120,7 @@ tasks: PKG_CONFIG_PATH=$PKG_CONFIG_PATH \ make build-kms-test echo "Building build-kms-test ... 
end" - LD_LIBRARY_PATH=./install/libmongocrypt/lib \ + LD_LIBRARY_PATH=./install/libmongocrypt/lib64 \ MONGODB_URI='mongodb://localhost:27017/' \ EXPECT_ERROR='unable to retrieve GCP credentials' \ PROVIDER='gcp' \ @@ -2162,7 +2144,7 @@ tasks: export AWS_ACCESS_KEY_ID="${cse_aws_access_key_id}" export AWS_SECRET_ACCESS_KEY="${cse_aws_secret_access_key}" - LD_LIBRARY_PATH=./install/libmongocrypt/lib \ + LD_LIBRARY_PATH=./install/libmongocrypt/lib64 \ MONGODB_URI='${atlas_free_tier_uri}' \ PROVIDER='aws' \ ./testkms @@ -2184,9 +2166,9 @@ tasks: make build-kms-test echo "Building build-kms-test ... end" - LD_LIBRARY_PATH=./install/libmongocrypt/lib \ + LD_LIBRARY_PATH=./install/libmongocrypt/lib64 \ MONGODB_URI='${atlas_free_tier_uri}' \ - EXPECT_ERROR='unable to retrieve aws credentials' \ + EXPECT_ERROR='status=400' \ PROVIDER='aws' \ ./testkms @@ -2210,7 +2192,7 @@ tasks: export AZUREKMS_VMNAME=${AZUREKMS_VMNAME} echo '${testazurekms_privatekey}' > /tmp/testazurekms.prikey export AZUREKMS_PRIVATEKEYPATH=/tmp/testazurekms.prikey - tar czf testazurekms.tgz ./testkms ./install/libmongocrypt/lib/libmongocrypt.* + tar czf testazurekms.tgz ./testkms ./install/libmongocrypt/lib64/libmongocrypt.* AZUREKMS_SRC=testazurekms.tgz AZUREKMS_DST=/tmp $DRIVERS_TOOLS/.evergreen/csfle/azurekms/copy-file.sh echo "Copying files ... end" echo "Untarring file ... begin" @@ -2228,7 +2210,7 @@ tasks: export AZUREKMS_VMNAME=${AZUREKMS_VMNAME} echo '${testazurekms_privatekey}' > /tmp/testazurekms.prikey export AZUREKMS_PRIVATEKEYPATH=/tmp/testazurekms.prikey - AZUREKMS_CMD="LD_LIBRARY_PATH=./install/libmongocrypt/lib MONGODB_URI='mongodb://localhost:27017' PROVIDER='azure' ./testkms" $DRIVERS_TOOLS/.evergreen/csfle/azurekms/run-command.sh + AZUREKMS_CMD="LD_LIBRARY_PATH=./install/libmongocrypt/lib64 MONGODB_URI='mongodb://localhost:27017' PROVIDER='azure' AZUREKMS_KEY_NAME='${AZUREKMS_KEY_NAME}' AZUREKMS_KEY_VAULT_ENDPOINT='${AZUREKMS_KEY_VAULT_ENDPOINT}' ./testkms" $DRIVERS_TOOLS/.evergreen/csfle/azurekms/run-command.sh - name: "testazurekms-fail-task" # testazurekms-fail-task runs without environment variables. @@ -2247,10 +2229,10 @@ tasks: make build-kms-test echo "Building build-kms-test ... 
end" - LD_LIBRARY_PATH=./install/libmongocrypt/lib \ + LD_LIBRARY_PATH=./install/libmongocrypt/lib64 \ MONGODB_URI='mongodb://localhost:27017' \ EXPECT_ERROR='unable to retrieve azure credentials' \ - PROVIDER='azure' \ + PROVIDER='azure' AZUREKMS_KEY_NAME='${AZUREKMS_KEY_NAME}' AZUREKMS_KEY_VAULT_ENDPOINT='${AZUREKMS_KEY_VAULT_ENDPOINT}' \ ./testkms - name: "test-fuzz" @@ -2330,9 +2312,9 @@ axes: GCC_PATH: "/cygdrive/c/ProgramData/chocolatey/lib/mingw/tools/install/mingw64/bin" GO_DIST: "C:\\golang\\go1.20" VENV_BIN_DIR: "Scripts" - - id: "ubuntu1604-64-go-1-20" - display_name: "Ubuntu 16.04" - run_on: ubuntu1604-build + - id: "rhel87-64-go-1-20" + display_name: "RHEL 8.7" + run_on: rhel8.7-large variables: GO_DIST: "/opt/golang/go1.20" - id: "macos11-go-1-20" @@ -2354,9 +2336,9 @@ axes: GCC_PATH: "/cygdrive/c/ProgramData/chocolatey/lib/mingw/tools/install/mingw64/bin" GO_DIST: "C:\\golang\\go1.20" VENV_BIN_DIR: "Scripts" - - id: "ubuntu1804-64-go-1-20" - display_name: "Ubuntu 18.04" - run_on: ubuntu1804-build + - id: "rhel87-64-go-1-20" + display_name: "RHEL 8.7" + run_on: rhel8.7-large variables: GO_DIST: "/opt/golang/go1.20" - id: "macos11-go-1-20" @@ -2366,13 +2348,12 @@ axes: variables: GO_DIST: "/opt/golang/go1.20" - # OCSP linux tasks need to run against this OS since stapling is disabled on Ubuntu 18.04 (SERVER-51364) - - id: ocsp-rhel-70 + - id: ocsp-rhel-87 display_name: OS values: - - id: "rhel70-go-1-20" - display_name: "RHEL 7.0" - run_on: rhel70-build + - id: "rhel87-go-1-20" + display_name: "RHEL 8.7" + run_on: rhel8.7-large variables: GO_DIST: "/opt/golang/go1.20" @@ -2387,6 +2368,8 @@ axes: GCC_PATH: "/cygdrive/c/ProgramData/chocolatey/lib/mingw/tools/install/mingw64/bin" GO_DIST: "C:\\golang\\go1.20" SKIP_ECS_AUTH_TEST: true + # TODO(BUILD-17329): Update this to Ubuntu 22 after we add a new ECS task + # definition. 
- id: "ubuntu1804-64-go-1-20" display_name: "Ubuntu 18.04" run_on: ubuntu1804-test @@ -2405,18 +2388,18 @@ axes: - id: os-faas-80 display_name: OS values: - - id: "rhel80-large-go-1-20" - display_name: "RHEL 8.0" - run_on: rhel80-large + - id: "rhel87-large-go-1-20" + display_name: "RHEL 8.7" + run_on: rhel8.7-large variables: GO_DIST: "/opt/golang/go1.20" - id: os-serverless display_name: OS values: - - id: "ubuntu2204-go-1-20" - display_name: "Ubuntu 22.04" - run_on: ubuntu2204-small + - id: "rhel87-go-1-20" + display_name: "RHEL 8.7" + run_on: rhel8.7-small variables: GO_DIST: "/opt/golang/go1.20" @@ -2639,7 +2622,7 @@ buildvariants: - name: static-analysis display_name: "Static Analysis" run_on: - - ubuntu1804-build + - rhel8.7-large expansions: GO_DIST: "/opt/golang/go1.20" tasks: @@ -2648,7 +2631,7 @@ buildvariants: - name: perf display_name: "Performance" run_on: - - ubuntu1804-build + - rhel8.7-large expansions: GO_DIST: "/opt/golang/go1.20" tasks: @@ -2657,16 +2640,16 @@ buildvariants: - name: build-check display_name: "Compile Only Checks" run_on: - - ubuntu1804-test + - rhel8.7-large expansions: GO_DIST: "/opt/golang/go1.20" tasks: - name: ".compile-check" - + - name: atlas-test display_name: "Atlas test" run_on: - - ubuntu1804-build + - rhel8.7-large expansions: GO_DIST: "/opt/golang/go1.20" tasks: @@ -2675,7 +2658,7 @@ buildvariants: - name: atlas-data-lake-test display_name: "Atlas Data Lake Test" run_on: - - ubuntu1804-build + - rhel8.7-large expansions: GO_DIST: "/opt/golang/go1.20" tasks: @@ -2712,8 +2695,8 @@ buildvariants: - name: "aws-auth-test" - matrix_name: "ocsp-test" - matrix_spec: { version: ["4.4", "5.0", "6.0", "7.0", "latest"], ocsp-rhel-70: ["rhel70-go-1-20"] } - display_name: "OCSP ${version} ${ocsp-rhel-70}" + matrix_spec: { version: ["4.4", "5.0", "6.0", "7.0", "latest"], ocsp-rhel-87: ["rhel87-go-1-20"] } + display_name: "OCSP ${version} ${ocsp-rhel-87}" batchtime: 20160 # Use a batchtime of 14 days as suggested by the OCSP test README tasks: - name: ".ocsp" @@ -2735,7 +2718,7 @@ buildvariants: - name: ".ocsp-rsa !.ocsp-staple" - matrix_name: "race-test" - matrix_spec: { version: ["latest"], os-ssl-40: ["ubuntu1804-64-go-1-20"] } + matrix_spec: { version: ["latest"], os-ssl-40: ["rhel87-64-go-1-20"] } display_name: "Race Detector Test" tasks: - name: ".race" @@ -2747,14 +2730,13 @@ buildvariants: - name: ".versioned-api" - matrix_name: "kms-tls-test" - matrix_spec: { version: ["latest"], os-ssl-40: ["ubuntu1804-64-go-1-20"] } + matrix_spec: { version: ["latest"], os-ssl-40: ["rhel87-64-go-1-20"] } display_name: "KMS TLS ${os-ssl-40}" tasks: - name: ".kms-tls" - matrix_name: "load-balancer-test" - # The LB software is only available on Ubuntu 18.04, so we don't test on all OSes. 
- matrix_spec: { version: ["5.0", "6.0", "7.0", "latest", "rapid"], os-ssl-40: ["ubuntu1804-64-go-1-20"] } + matrix_spec: { version: ["5.0", "6.0", "7.0", "latest", "rapid"], os-ssl-40: ["rhel87-64-go-1-20"] } display_name: "Load Balancer Support ${version} ${os-ssl-40}" tasks: - name: ".load-balancer" @@ -2766,20 +2748,20 @@ buildvariants: - "serverless_task_group" - matrix_name: "kms-kmip-test" - matrix_spec: { version: ["latest"], os-ssl-40: ["ubuntu1804-64-go-1-20"] } + matrix_spec: { version: ["latest"], os-ssl-40: ["rhel87-64-go-1-20"] } display_name: "KMS KMIP ${os-ssl-40}" tasks: - name: ".kms-kmip" - matrix_name: "fuzz-test" - matrix_spec: { version: ["5.0"], os-ssl-40: ["ubuntu1804-64-go-1-20"] } + matrix_spec: { version: ["5.0"], os-ssl-40: ["rhel87-64-go-1-20"] } display_name: "Fuzz ${version} ${os-ssl-40}" tasks: - name: "test-fuzz" batchtime: 1440 # Run at most once per 24 hours. - matrix_name: "faas-test" - matrix_spec: { version: ["latest"], os-faas-80: ["rhel80-large-go-1-20"] } + matrix_spec: { version: ["latest"], os-faas-80: ["rhel87-large-go-1-20"] } display_name: "FaaS ${version} ${os-faas-80}" tasks: - test-aws-lambda-task-group @@ -2787,7 +2769,7 @@ buildvariants: - name: testgcpkms-variant display_name: "GCP KMS" run_on: - - debian11-small + - rhel8.7-small expansions: GO_DIST: "/opt/golang/go1.20" tasks: @@ -2798,7 +2780,7 @@ buildvariants: - name: testawskms-variant display_name: "AWS KMS" run_on: - - debian11-small + - rhel8.7-small expansions: GO_DIST: "/opt/golang/go1.20" tasks: @@ -2808,7 +2790,7 @@ buildvariants: - name: testazurekms-variant display_name: "AZURE KMS" run_on: - - debian11-small + - rhel8.7-small expansions: GO_DIST: "/opt/golang/go1.20" tasks: diff --git a/.evergreen/run-tests.sh b/.evergreen/run-tests.sh index 392531613a..a223f309f8 100644 --- a/.evergreen/run-tests.sh +++ b/.evergreen/run-tests.sh @@ -15,8 +15,8 @@ fi export GOROOT="${GOROOT}" export PATH="${GOROOT}/bin:${GCC_PATH}:$GOPATH/bin:$PATH" export PROJECT="${project}" -export PKG_CONFIG_PATH=$(pwd)/install/libmongocrypt/lib/pkgconfig:$(pwd)/install/mongo-c-driver/lib/pkgconfig -export LD_LIBRARY_PATH=$(pwd)/install/libmongocrypt/lib +export PKG_CONFIG_PATH=$(pwd)/install/libmongocrypt/lib64/pkgconfig:$(pwd)/install/mongo-c-driver/lib/pkgconfig +export LD_LIBRARY_PATH=$(pwd)/install/libmongocrypt/lib64 export GOFLAGS=-mod=vendor SSL=${SSL:-nossl} diff --git a/.gitignore b/.gitignore index e9609ce1e9..16b52325e4 100644 --- a/.gitignore +++ b/.gitignore @@ -24,4 +24,8 @@ internal/test/compilecheck/compilecheck.so # Ignore api report files api-report.md -api-report.txt \ No newline at end of file +api-report.txt + +# Ignore secrets files +secrets-expansion.yml +secrets-export.sh diff --git a/Makefile b/Makefile index 5531b61a29..66f5b137e5 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,3 @@ -ATLAS_URIS = "$(ATLAS_FREE)" "$(ATLAS_REPLSET)" "$(ATLAS_SHARD)" "$(ATLAS_TLS11)" "$(ATLAS_TLS12)" "$(ATLAS_FREE_SRV)" "$(ATLAS_REPLSET_SRV)" "$(ATLAS_SHARD_SRV)" "$(ATLAS_TLS11_SRV)" "$(ATLAS_TLS12_SRV)" "$(ATLAS_SERVERLESS)" "$(ATLAS_SERVERLESS_SRV)" TEST_TIMEOUT = 1800 ### Utility targets. ### @@ -128,10 +127,6 @@ build-aws-ecs-test: evg-test: go test -exec "env PKG_CONFIG_PATH=$(PKG_CONFIG_PATH) LD_LIBRARY_PATH=$(LD_LIBRARY_PATH)" $(BUILD_TAGS) -v -timeout $(TEST_TIMEOUT)s -p 1 ./... 
>> test.suite -.PHONY: evg-test-atlas -evg-test-atlas: - go run ./cmd/testatlas/main.go $(ATLAS_URIS) - .PHONY: evg-test-atlas-data-lake evg-test-atlas-data-lake: ATLAS_DATA_LAKE_INTEGRATION_TEST=true go test -v ./mongo/integration -run TestUnifiedSpecs/atlas-data-lake-testing >> spec_test.suite diff --git a/bson/bsoncodec/slice_codec.go b/bson/bsoncodec/slice_codec.go index 20c3e7549c..a43daf005f 100644 --- a/bson/bsoncodec/slice_codec.go +++ b/bson/bsoncodec/slice_codec.go @@ -62,7 +62,7 @@ func (sc SliceCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val re } // If we have a []primitive.E we want to treat it as a document instead of as an array. - if val.Type().ConvertibleTo(tD) { + if val.Type() == tD || val.Type().ConvertibleTo(tD) { d := val.Convert(tD).Interface().(primitive.D) dw, err := vw.WriteDocument() diff --git a/bson/bsoncodec/struct_codec.go b/bson/bsoncodec/struct_codec.go index 29ea76d19c..4cde0a4d6b 100644 --- a/bson/bsoncodec/struct_codec.go +++ b/bson/bsoncodec/struct_codec.go @@ -190,15 +190,14 @@ func (sc *StructCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val encoder := desc.encoder var zero bool - rvInterface := rv.Interface() if cz, ok := encoder.(CodecZeroer); ok { - zero = cz.IsTypeZero(rvInterface) + zero = cz.IsTypeZero(rv.Interface()) } else if rv.Kind() == reflect.Interface { // isZero will not treat an interface rv as an interface, so we need to check for the // zero interface separately. zero = rv.IsNil() } else { - zero = isZero(rvInterface, sc.EncodeOmitDefaultStruct || ec.omitZeroStruct) + zero = isZero(rv, sc.EncodeOmitDefaultStruct || ec.omitZeroStruct) } if desc.omitEmpty && zero { continue @@ -392,56 +391,32 @@ func (sc *StructCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val return nil } -func isZero(i interface{}, omitZeroStruct bool) bool { - v := reflect.ValueOf(i) - - // check the value validity - if !v.IsValid() { - return true +func isZero(v reflect.Value, omitZeroStruct bool) bool { + kind := v.Kind() + if (kind != reflect.Ptr || !v.IsNil()) && v.Type().Implements(tZeroer) { + return v.Interface().(Zeroer).IsZero() } - - if z, ok := v.Interface().(Zeroer); ok && (v.Kind() != reflect.Ptr || !v.IsNil()) { - return z.IsZero() - } - - switch v.Kind() { - case reflect.Array, reflect.Map, reflect.Slice, reflect.String: - return v.Len() == 0 - case reflect.Bool: - return !v.Bool() - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return v.Int() == 0 - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - return v.Uint() == 0 - case reflect.Float32, reflect.Float64: - return v.Float() == 0 - case reflect.Interface, reflect.Ptr: - return v.IsNil() - case reflect.Struct: + if kind == reflect.Struct { if !omitZeroStruct { return false } - - // TODO(GODRIVER-2820): Update the logic to be able to handle private struct fields. - // TODO Use condition "reflect.Zero(v.Type()).Equal(v)" instead. 
- vt := v.Type() if vt == tTime { return v.Interface().(time.Time).IsZero() } - for i := 0; i < v.NumField(); i++ { - if vt.Field(i).PkgPath != "" && !vt.Field(i).Anonymous { + numField := vt.NumField() + for i := 0; i < numField; i++ { + ff := vt.Field(i) + if ff.PkgPath != "" && !ff.Anonymous { continue // Private field } - fld := v.Field(i) - if !isZero(fld.Interface(), omitZeroStruct) { + if !isZero(v.Field(i), omitZeroStruct) { return false } } return true } - - return false + return !v.IsValid() || v.IsZero() } type structDescription struct { @@ -708,21 +683,21 @@ func getInlineField(val reflect.Value, index []int) (reflect.Value, error) { // DeepZero returns recursive zero object func deepZero(st reflect.Type) (result reflect.Value) { - result = reflect.Indirect(reflect.New(st)) - - if result.Kind() == reflect.Struct { - for i := 0; i < result.NumField(); i++ { - if f := result.Field(i); f.Kind() == reflect.Ptr { - if f.CanInterface() { - if ft := reflect.TypeOf(f.Interface()); ft.Elem().Kind() == reflect.Struct { - result.Field(i).Set(recursivePointerTo(deepZero(ft.Elem()))) - } + if st.Kind() == reflect.Struct { + numField := st.NumField() + for i := 0; i < numField; i++ { + if result == emptyValue { + result = reflect.Indirect(reflect.New(st)) + } + f := result.Field(i) + if f.CanInterface() { + if f.Type().Kind() == reflect.Struct { + result.Field(i).Set(recursivePointerTo(deepZero(f.Type().Elem()))) } } } } - - return + return result } // recursivePointerTo calls reflect.New(v.Type) but recursively for its fields inside diff --git a/bson/bsoncodec/struct_codec_test.go b/bson/bsoncodec/struct_codec_test.go index 008fc11528..573b374b14 100644 --- a/bson/bsoncodec/struct_codec_test.go +++ b/bson/bsoncodec/struct_codec_test.go @@ -7,6 +7,7 @@ package bsoncodec import ( + "reflect" "testing" "time" @@ -147,7 +148,7 @@ func TestIsZero(t *testing.T) { t.Run(tc.description, func(t *testing.T) { t.Parallel() - got := isZero(tc.value, tc.omitZeroStruct) + got := isZero(reflect.ValueOf(tc.value), tc.omitZeroStruct) assert.Equal(t, tc.want, got, "expected and actual isZero return are different") }) } diff --git a/bson/bsoncodec/types.go b/bson/bsoncodec/types.go index 07f4b70e6d..6ade17b7d3 100644 --- a/bson/bsoncodec/types.go +++ b/bson/bsoncodec/types.go @@ -34,6 +34,7 @@ var tValueUnmarshaler = reflect.TypeOf((*ValueUnmarshaler)(nil)).Elem() var tMarshaler = reflect.TypeOf((*Marshaler)(nil)).Elem() var tUnmarshaler = reflect.TypeOf((*Unmarshaler)(nil)).Elem() var tProxy = reflect.TypeOf((*Proxy)(nil)).Elem() +var tZeroer = reflect.TypeOf((*Zeroer)(nil)).Elem() var tBinary = reflect.TypeOf(primitive.Binary{}) var tUndefined = reflect.TypeOf(primitive.Undefined{}) diff --git a/bson/bsonrw/copier.go b/bson/bsonrw/copier.go index 33d59bd258..c146d02e58 100644 --- a/bson/bsonrw/copier.go +++ b/bson/bsonrw/copier.go @@ -193,7 +193,7 @@ func (c Copier) AppendDocumentBytes(dst []byte, src ValueReader) ([]byte, error) } vw := vwPool.Get().(*valueWriter) - defer vwPool.Put(vw) + defer putValueWriter(vw) vw.reset(dst) @@ -213,7 +213,7 @@ func (c Copier) AppendArrayBytes(dst []byte, src ValueReader) ([]byte, error) { } vw := vwPool.Get().(*valueWriter) - defer vwPool.Put(vw) + defer putValueWriter(vw) vw.reset(dst) @@ -258,7 +258,7 @@ func (c Copier) AppendValueBytes(dst []byte, src ValueReader) (bsontype.Type, [] } vw := vwPool.Get().(*valueWriter) - defer vwPool.Put(vw) + defer putValueWriter(vw) start := len(dst) diff --git a/bson/bsonrw/value_reader.go b/bson/bsonrw/value_reader.go index 
9bf24fae0b..a242bb57cf 100644 --- a/bson/bsonrw/value_reader.go +++ b/bson/bsonrw/value_reader.go @@ -739,8 +739,7 @@ func (vr *valueReader) ReadValue() (ValueReader, error) { return nil, ErrEOA } - _, err = vr.readCString() - if err != nil { + if err := vr.skipCString(); err != nil { return nil, err } @@ -794,6 +793,15 @@ func (vr *valueReader) readByte() (byte, error) { return vr.d[vr.offset-1], nil } +func (vr *valueReader) skipCString() error { + idx := bytes.IndexByte(vr.d[vr.offset:], 0x00) + if idx < 0 { + return io.EOF + } + vr.offset += int64(idx) + 1 + return nil +} + func (vr *valueReader) readCString() (string, error) { idx := bytes.IndexByte(vr.d[vr.offset:], 0x00) if idx < 0 { diff --git a/bson/bsonrw/value_writer.go b/bson/bsonrw/value_writer.go index a6dd8d34f5..311518a80d 100644 --- a/bson/bsonrw/value_writer.go +++ b/bson/bsonrw/value_writer.go @@ -28,6 +28,13 @@ var vwPool = sync.Pool{ }, } +func putValueWriter(vw *valueWriter) { + if vw != nil { + vw.w = nil // don't leak the writer + vwPool.Put(vw) + } +} + // BSONValueWriterPool is a pool for BSON ValueWriters. // // Deprecated: BSONValueWriterPool will not be supported in Go Driver 2.0. @@ -149,32 +156,21 @@ type valueWriter struct { } func (vw *valueWriter) advanceFrame() { - if vw.frame+1 >= int64(len(vw.stack)) { // We need to grow the stack - length := len(vw.stack) - if length+1 >= cap(vw.stack) { - // double it - buf := make([]vwState, 2*cap(vw.stack)+1) - copy(buf, vw.stack) - vw.stack = buf - } - vw.stack = vw.stack[:length+1] - } vw.frame++ + if vw.frame >= int64(len(vw.stack)) { + vw.stack = append(vw.stack, vwState{}) + } } func (vw *valueWriter) push(m mode) { vw.advanceFrame() // Clean the stack - vw.stack[vw.frame].mode = m - vw.stack[vw.frame].key = "" - vw.stack[vw.frame].arrkey = 0 - vw.stack[vw.frame].start = 0 + vw.stack[vw.frame] = vwState{mode: m} - vw.stack[vw.frame].mode = m switch m { case mDocument, mArray, mCodeWithScope: - vw.reserveLength() + vw.reserveLength() // WARN: this is not needed } } @@ -213,6 +209,7 @@ func newValueWriter(w io.Writer) *valueWriter { return vw } +// TODO: only used in tests func newValueWriterFromSlice(buf []byte) *valueWriter { vw := new(valueWriter) stack := make([]vwState, 1, 5) @@ -249,17 +246,16 @@ func (vw *valueWriter) invalidTransitionError(destination mode, name string, mod } func (vw *valueWriter) writeElementHeader(t bsontype.Type, destination mode, callerName string, addmodes ...mode) error { - switch vw.stack[vw.frame].mode { + frame := &vw.stack[vw.frame] + switch frame.mode { case mElement: - key := vw.stack[vw.frame].key + key := frame.key if !isValidCString(key) { return errors.New("BSON element key cannot contain null bytes") } - - vw.buf = bsoncore.AppendHeader(vw.buf, t, key) + vw.appendHeader(t, key) case mValue: - // TODO: Do this with a cache of the first 1000 or so array keys. 
- vw.buf = bsoncore.AppendHeader(vw.buf, t, strconv.Itoa(vw.stack[vw.frame].arrkey)) + vw.appendIntHeader(t, frame.arrkey) default: modes := []mode{mElement, mValue} if addmodes != nil { @@ -601,9 +597,11 @@ func (vw *valueWriter) writeLength() error { if length > maxSize { return errMaxDocumentSizeExceeded{size: int64(len(vw.buf))} } - length = length - int(vw.stack[vw.frame].start) - start := vw.stack[vw.frame].start + frame := &vw.stack[vw.frame] + length = length - int(frame.start) + start := frame.start + _ = vw.buf[start+3] // BCE vw.buf[start+0] = byte(length) vw.buf[start+1] = byte(length >> 8) vw.buf[start+2] = byte(length >> 16) @@ -612,5 +610,31 @@ func (vw *valueWriter) writeLength() error { } func isValidCString(cs string) bool { - return !strings.ContainsRune(cs, '\x00') + // Disallow the zero byte in a cstring because the zero byte is used as the + // terminating character. + // + // It's safe to check bytes instead of runes because all multibyte UTF-8 + // code points start with (binary) 11xxxxxx or 10xxxxxx, so 00000000 (i.e. + // 0) will never be part of a multibyte UTF-8 code point. This logic is the + // same as the "r < utf8.RuneSelf" case in strings.IndexRune but can be + // inlined. + // + // https://cs.opensource.google/go/go/+/refs/tags/go1.21.1:src/strings/strings.go;l=127 + return strings.IndexByte(cs, 0) == -1 +} + +// appendHeader is the same as bsoncore.AppendHeader but does not check if the +// key is a valid C string since the caller has already checked for that. +// +// The caller of this function must check if key is a valid C string. +func (vw *valueWriter) appendHeader(t bsontype.Type, key string) { + vw.buf = bsoncore.AppendType(vw.buf, t) + vw.buf = append(vw.buf, key...) + vw.buf = append(vw.buf, 0x00) +} + +func (vw *valueWriter) appendIntHeader(t bsontype.Type, key int) { + vw.buf = bsoncore.AppendType(vw.buf, t) + vw.buf = strconv.AppendInt(vw.buf, int64(key), 10) + vw.buf = append(vw.buf, 0x00) } diff --git a/bson/marshal.go b/bson/marshal.go index f2c48d049e..17ce6697e0 100644 --- a/bson/marshal.go +++ b/bson/marshal.go @@ -9,6 +9,7 @@ package bson import ( "bytes" "encoding/json" + "sync" "go.mongodb.org/mongo-driver/bson/bsoncodec" "go.mongodb.org/mongo-driver/bson/bsonrw" @@ -141,6 +142,13 @@ func MarshalAppendWithRegistry(r *bsoncodec.Registry, dst []byte, val interface{ return MarshalAppendWithContext(bsoncodec.EncodeContext{Registry: r}, dst, val) } +// Pool of buffers for marshalling BSON. +var bufPool = sync.Pool{ + New: func() interface{} { + return new(bytes.Buffer) + }, +} + // MarshalAppendWithContext will encode val as a BSON document using Registry r and EncodeContext ec and append the // bytes to dst. If dst is not large enough to hold the bytes, it will be grown. If val is not a type that can be // transformed into a document, MarshalValueAppendWithContext should be used instead. @@ -162,8 +170,26 @@ func MarshalAppendWithRegistry(r *bsoncodec.Registry, dst []byte, val interface{ // // See [Encoder] for more examples. func MarshalAppendWithContext(ec bsoncodec.EncodeContext, dst []byte, val interface{}) ([]byte, error) { - sw := new(bsonrw.SliceWriter) - *sw = dst + sw := bufPool.Get().(*bytes.Buffer) + defer func() { + // Proper usage of a sync.Pool requires each entry to have approximately + // the same memory cost. To obtain this property when the stored type + // contains a variably-sized buffer, we add a hard limit on the maximum + // buffer to place back in the pool. 
We limit the size to 16MiB because
+ // that's the maximum wire message size supported by any current MongoDB
+ // server.
+ //
+ // Comment based on
+ // https://cs.opensource.google/go/go/+/refs/tags/go1.19:src/fmt/print.go;l=147
+ //
+ // Recycle byte slices that are smaller than 16MiB and at least half
+ // occupied.
+ if sw.Cap() < 16*1024*1024 && sw.Cap()/2 < sw.Len() {
+ bufPool.Put(sw)
+ }
+ }()
+
+ sw.Reset()
 vw := bvwPool.Get(sw)
 defer bvwPool.Put(vw)
@@ -184,7 +210,7 @@ func MarshalAppendWithContext(ec bsoncodec.EncodeContext, dst []byte, val interf
 return nil, err
 }
- return *sw, nil
+ return append(dst, sw.Bytes()...), nil
 }
 // MarshalValue returns the BSON encoding of val.
diff --git a/cmd/testatlas/main.go b/cmd/testatlas/main.go
index d98631d468..fa50d7cde7 100644
--- a/cmd/testatlas/main.go
+++ b/cmd/testatlas/main.go
@@ -23,7 +23,11 @@ func main() {
 uris := flag.Args()
 ctx := context.Background()
+ fmt.Printf("Running atlas tests for %d uris\n", len(uris))
+
 for idx, uri := range uris {
+ fmt.Printf("Running test %d\n", idx)
+
 // Set a low server selection timeout so we fail fast if there are errors.
 clientOpts := options.Client().
 ApplyURI(uri).
@@ -41,6 +45,8 @@ func main() {
 panic(fmt.Sprintf("error running test with tlsInsecure at index %d: %v", idx, err))
 }
 }
+
+ fmt.Println("Finished!")
 }
 func runTest(ctx context.Context, clientOpts *options.ClientOptions) error {
diff --git a/cmd/testkms/main.go b/cmd/testkms/main.go
index 5d7532c4b2..af86eca523 100644
--- a/cmd/testkms/main.go
+++ b/cmd/testkms/main.go
@@ -24,8 +24,8 @@ var datakeyopts = map[string]primitive.M{
 "key": "arn:aws:kms:us-east-1:579766882180:key/89fcc2c4-08b0-4bd9-9f25-e30687b580d0",
 },
 "azure": bson.M{
- "keyVaultEndpoint": "https://keyvault-drivers-2411.vault.azure.net/keys/",
- "keyName": "KEY-NAME",
+ "keyVaultEndpoint": "",
+ "keyName": "",
 },
 "gcp": bson.M{
 "projectId": "devprod-drivers",
@@ -53,6 +53,20 @@ func main() {
 default:
 ok = true
 }
+ if provider == "azure" {
+ azureKmsKeyName := os.Getenv("AZUREKMS_KEY_NAME")
+ azureKmsKeyVaultEndpoint := os.Getenv("AZUREKMS_KEY_VAULT_ENDPOINT")
+ if azureKmsKeyName == "" {
+ fmt.Println("ERROR: Please set required AZUREKMS_KEY_NAME environment variable.")
+ ok = false
+ }
+ if azureKmsKeyVaultEndpoint == "" {
+ fmt.Println("ERROR: Please set required AZUREKMS_KEY_VAULT_ENDPOINT environment variable.")
+ ok = false
+ }
+ datakeyopts["azure"]["keyName"] = azureKmsKeyName
+ datakeyopts["azure"]["keyVaultEndpoint"] = azureKmsKeyVaultEndpoint
+ }
 if !ok {
 providers := make([]string, 0, len(datakeyopts))
 for p := range datakeyopts {
@@ -63,6 +77,8 @@ func main() {
 fmt.Println("- MONGODB_URI as a MongoDB URI. Example: 'mongodb://localhost:27017'")
 fmt.Println("- EXPECT_ERROR as an optional expected error substring.")
 fmt.Println("- PROVIDER as a KMS provider, which supports:", strings.Join(providers, ", "))
+ fmt.Println("- AZUREKMS_KEY_NAME as the Azure key name. Required if PROVIDER=azure.")
+ fmt.Println("- AZUREKMS_KEY_VAULT_ENDPOINT as the Azure Key Vault endpoint. 
Required if PROVIDER=azure.") os.Exit(1) } diff --git a/etc/api_report.sh b/etc/api_report.sh index 1cfbac5025..ffdac7f975 100755 --- a/etc/api_report.sh +++ b/etc/api_report.sh @@ -10,6 +10,7 @@ if [ -z $cmd ]; then fi branch=${GITHUB_BASE_REF:-master} +git fetch origin $branch:$branch sha=$(git merge-base $branch HEAD) gorelease -base=$sha > api-report.txt || true diff --git a/etc/get_aws_secrets.sh b/etc/get_aws_secrets.sh new file mode 100644 index 0000000000..894016553b --- /dev/null +++ b/etc/get_aws_secrets.sh @@ -0,0 +1,12 @@ +#!/usr/bin/env bash +# get-aws-secrets +# Gets AWS secrets from the vault +set -eu + +if [ -z "$DRIVERS_TOOLS" ]; then + echo "Please define DRIVERS_TOOLS variable" + exit 1 +fi + +bash $DRIVERS_TOOLS/.evergreen/auth_aws/setup_secrets.sh $@ +. ./secrets-export.sh diff --git a/etc/run-atlas-test.sh b/etc/run-atlas-test.sh new file mode 100644 index 0000000000..aa89b2dd4b --- /dev/null +++ b/etc/run-atlas-test.sh @@ -0,0 +1,11 @@ +#!/usr/bin/env bash +# run-atlas-test +# Run atlas connectivity tests. +set -eu +set +x + +# Get the atlas secrets. +. etc/get_aws_secrets.sh drivers/atlas_connect + +echo "Running cmd/testatlas/main.go" +go run ./cmd/testatlas/main.go "$ATLAS_REPL" "$ATLAS_SHRD" "$ATLAS_FREE" "$ATLAS_TLS11" "$ATLAS_TLS12" "$ATLAS_SERVERLESS" "$ATLAS_SRV_REPL" "$ATLAS_SRV_SHRD" "$ATLAS_SRV_FREE" "$ATLAS_SRV_TLS11" "$ATLAS_SRV_TLS12" "$ATLAS_SRV_SERVERLESS" diff --git a/mongo/errors.go b/mongo/errors.go index 5f2b1b819b..72c3bcc243 100644 --- a/mongo/errors.go +++ b/mongo/errors.go @@ -102,50 +102,56 @@ func replaceErrors(err error) error { return err } -// IsDuplicateKeyError returns true if err is a duplicate key error +// IsDuplicateKeyError returns true if err is a duplicate key error. func IsDuplicateKeyError(err error) bool { - // handles SERVER-7164 and SERVER-11493 - for ; err != nil; err = unwrap(err) { - if e, ok := err.(ServerError); ok { - return e.HasErrorCode(11000) || e.HasErrorCode(11001) || e.HasErrorCode(12582) || - e.HasErrorCodeWithMessage(16460, " E11000 ") - } + if se := ServerError(nil); errors.As(err, &se) { + return se.HasErrorCode(11000) || // Duplicate key error. + se.HasErrorCode(11001) || // Duplicate key error on update. + // Duplicate key error in a capped collection. See SERVER-7164. + se.HasErrorCode(12582) || + // Mongos insert error caused by a duplicate key error. See + // SERVER-11493. + se.HasErrorCodeWithMessage(16460, " E11000 ") } return false } -// IsTimeout returns true if err is from a timeout +// timeoutErrs is a list of error values that indicate a timeout happened. +var timeoutErrs = [...]error{ + context.DeadlineExceeded, + driver.ErrDeadlineWouldBeExceeded, + topology.ErrServerSelectionTimeout, +} + +// IsTimeout returns true if err was caused by a timeout. For error chains, +// IsTimeout returns true if any error in the chain was caused by a timeout. func IsTimeout(err error) bool { - for ; err != nil; err = unwrap(err) { - // check unwrappable errors together - if err == context.DeadlineExceeded { - return true - } - if err == driver.ErrDeadlineWouldBeExceeded { - return true - } - if err == topology.ErrServerSelectionTimeout { - return true - } - if _, ok := err.(topology.WaitQueueTimeoutError); ok { - return true - } - if ce, ok := err.(CommandError); ok && ce.IsMaxTimeMSExpiredError() { + // Check if the error chain contains any of the timeout error values. 
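// For illustration only, a hedged caller-side sketch (the exported name is
// mongo.IsTimeout; the wrapping error text here is made up): because errors.Is
// unwraps errors created with fmt.Errorf("...: %w", err), a wrapped deadline
// error is still classified as a timeout.
//
//	err := fmt.Errorf("load profile: %w", context.DeadlineExceeded)
//	mongo.IsTimeout(err) // true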
+ for _, target := range timeoutErrs { + if errors.Is(err, target) { return true } - if we, ok := err.(WriteException); ok && we.WriteConcernError != nil && - we.WriteConcernError.IsMaxTimeMSExpiredError() { + } + + // Check if the error chain contains any error types that can indicate + // timeout. + if errors.As(err, &topology.WaitQueueTimeoutError{}) { + return true + } + if ce := (CommandError{}); errors.As(err, &ce) && ce.IsMaxTimeMSExpiredError() { + return true + } + if we := (WriteException{}); errors.As(err, &we) && we.WriteConcernError != nil && we.WriteConcernError.IsMaxTimeMSExpiredError() { + return true + } + if ne := net.Error(nil); errors.As(err, &ne) { + return ne.Timeout() + } + // Check timeout error labels. + if le := LabeledError(nil); errors.As(err, &le) { + if le.HasErrorLabel("NetworkTimeoutError") || le.HasErrorLabel("ExceededTimeLimitError") { return true } - if ne, ok := err.(net.Error); ok { - return ne.Timeout() - } - //timeout error labels - if le, ok := err.(LabeledError); ok { - if le.HasErrorLabel("NetworkTimeoutError") || le.HasErrorLabel("ExceededTimeLimitError") { - return true - } - } } return false diff --git a/mongo/integration/change_stream_test.go b/mongo/integration/change_stream_test.go index 868706ad1e..b3d0469c36 100644 --- a/mongo/integration/change_stream_test.go +++ b/mongo/integration/change_stream_test.go @@ -770,7 +770,7 @@ func TestChangeStream_ReplicaSet(t *testing.T) { require.NoError(mt, err, "failed to update idValue") }() - nextCtx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + nextCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) t.Cleanup(cancel) type splitEvent struct { diff --git a/mongo/integration/client_side_encryption_prose_test.go b/mongo/integration/client_side_encryption_prose_test.go index f650aa6b47..45a2ef01c8 100644 --- a/mongo/integration/client_side_encryption_prose_test.go +++ b/mongo/integration/client_side_encryption_prose_test.go @@ -19,7 +19,6 @@ import ( "net/http" "os" "path/filepath" - "runtime" "strings" "testing" "time" @@ -55,6 +54,16 @@ const ( maxBsonObjSize = 16777216 // max bytes in BSON object ) +func containsSubstring(possibleSubstrings []string, str string) bool { + for _, possibleSubstring := range possibleSubstrings { + if strings.Contains(str, possibleSubstring) { + return true + } + } + + return false +} + func TestClientSideEncryptionProse(t *testing.T) { t.Parallel() @@ -378,9 +387,6 @@ func TestClientSideEncryptionProse(t *testing.T) { } }) mt.Run("4. bson size limits", func(mt *mtest.T) { - // TODO(GODRIVER-2872): Fix and unskip this test case. - mt.Skip("Test fails frequently, skipping. 
See GODRIVER-2872") - kmsProviders := map[string]map[string]interface{}{ "local": { "key": localMasterKey, @@ -868,26 +874,119 @@ func TestClientSideEncryptionProse(t *testing.T) { "endpoint": "doesnotexist.local:5698", } + const ( + errConnectionRefused = "connection refused" + errInvalidKMSResponse = "Invalid KMS response" + errMongocryptError = "mongocrypt error" + errNoSuchHost = "no such host" + errServerMisbehaving = "server misbehaving" + errWindowsTLSConnectionRefused = "No connection could be made because the target machine actively refused it" + ) + testCases := []struct { name string provider string masterKey interface{} - errorSubstring string + errorSubstring []string testInvalidClientEncryption bool - invalidClientEncryptionErrorSubstring string + invalidClientEncryptionErrorSubstring []string }{ - {"Case 1: aws success without endpoint", "aws", awsSuccessWithoutEndpoint, "", false, ""}, - {"Case 2: aws success with endpoint", "aws", awsSuccessWithEndpoint, "", false, ""}, - {"Case 3: aws success with https endpoint", "aws", awsSuccessWithHTTPSEndpoint, "", false, ""}, - {"Case 4: aws failure with connection error", "aws", awsFailureConnectionError, "connection refused", false, ""}, - {"Case 5: aws failure with wrong endpoint", "aws", awsFailureInvalidEndpoint, "mongocrypt error", false, ""}, - {"Case 6: aws failure with parse error", "aws", awsFailureParseError, "no such host", false, ""}, - {"Case 7: azure success", "azure", azure, "", true, "no such host"}, - {"Case 8: gcp success", "gcp", gcpSuccess, "", true, "no such host"}, - {"Case 9: gcp failure", "gcp", gcpFailure, "Invalid KMS response", false, ""}, - {"Case 10: kmip success without endpoint", "kmip", kmipSuccessWithoutEndpoint, "", true, "no such host"}, - {"Case 11: kmip success with endpoint", "kmip", kmipSuccessWithEndpoint, "", false, ""}, - {"Case 12: kmip failure with invalid endpoint", "kmip", kmipFailureInvalidEndpoint, "no such host", false, ""}, + { + name: "Case 1: aws success without endpoint", + provider: "aws", + masterKey: awsSuccessWithoutEndpoint, + errorSubstring: []string{}, + testInvalidClientEncryption: false, + invalidClientEncryptionErrorSubstring: []string{}, + }, + { + name: "Case 2: aws success with endpoint", + provider: "aws", + masterKey: awsSuccessWithEndpoint, + errorSubstring: []string{}, + testInvalidClientEncryption: false, + invalidClientEncryptionErrorSubstring: []string{}, + }, + { + name: "Case 3: aws success with https endpoint", + provider: "aws", + masterKey: awsSuccessWithHTTPSEndpoint, + errorSubstring: []string{}, + testInvalidClientEncryption: false, + invalidClientEncryptionErrorSubstring: []string{}, + }, + { + name: "Case 4: aws failure with connection error", + provider: "aws", + masterKey: awsFailureConnectionError, + errorSubstring: []string{errConnectionRefused, errWindowsTLSConnectionRefused}, + testInvalidClientEncryption: false, + invalidClientEncryptionErrorSubstring: []string{}, + }, + { + name: "Case 5: aws failure with wrong endpoint", + provider: "aws", + masterKey: awsFailureInvalidEndpoint, + errorSubstring: []string{errMongocryptError}, + testInvalidClientEncryption: false, + invalidClientEncryptionErrorSubstring: []string{}, + }, + { + name: "Case 6: aws failure with parse error", + provider: "aws", + masterKey: awsFailureParseError, + errorSubstring: []string{errNoSuchHost, errServerMisbehaving}, + testInvalidClientEncryption: false, + invalidClientEncryptionErrorSubstring: []string{}, + }, + { + name: "Case 7: azure success", + provider: "azure", + 
masterKey: azure, + errorSubstring: []string{}, + testInvalidClientEncryption: true, + invalidClientEncryptionErrorSubstring: []string{errNoSuchHost, errServerMisbehaving}, + }, + { + name: "Case 8: gcp success", + provider: "gcp", + masterKey: gcpSuccess, + errorSubstring: []string{}, + testInvalidClientEncryption: true, + invalidClientEncryptionErrorSubstring: []string{errNoSuchHost, errServerMisbehaving}, + }, + { + name: "Case 9: gcp failure", + provider: "gcp", + masterKey: gcpFailure, + errorSubstring: []string{errInvalidKMSResponse}, + testInvalidClientEncryption: false, + invalidClientEncryptionErrorSubstring: []string{}, + }, + { + name: "Case 10: kmip success without endpoint", + provider: "kmip", + masterKey: kmipSuccessWithoutEndpoint, + errorSubstring: []string{}, + testInvalidClientEncryption: true, + invalidClientEncryptionErrorSubstring: []string{errNoSuchHost, errServerMisbehaving}, + }, + { + name: "Case 11: kmip success with endpoint", + provider: "kmip", + masterKey: kmipSuccessWithEndpoint, + errorSubstring: []string{}, + testInvalidClientEncryption: false, + invalidClientEncryptionErrorSubstring: []string{}, + }, + { + name: "Case 12: kmip failure with invalid endpoint", + provider: "kmip", + masterKey: kmipFailureInvalidEndpoint, + errorSubstring: []string{errNoSuchHost, errServerMisbehaving}, + testInvalidClientEncryption: false, + invalidClientEncryptionErrorSubstring: []string{}, + }, } for _, tc := range testCases { mt.Run(tc.name, func(mt *mtest.T) { @@ -899,16 +998,12 @@ func TestClientSideEncryptionProse(t *testing.T) { dkOpts := options.DataKey().SetMasterKey(tc.masterKey) createdKey, err := cpt.clientEnc.CreateDataKey(context.Background(), tc.provider, dkOpts) - if tc.errorSubstring != "" { + if len(tc.errorSubstring) > 0 { assert.NotNil(mt, err, "expected error, got nil") - errSubstr := tc.errorSubstring - if runtime.GOOS == "windows" && errSubstr == "connection refused" { - // tls.Dial returns an error that does not contain the substring "connection refused" - // on Windows machines - errSubstr = "No connection could be made because the target machine actively refused it" - } - assert.True(mt, strings.Contains(err.Error(), errSubstr), - "expected error '%s' to contain '%s'", err.Error(), errSubstr) + + assert.True(t, containsSubstring(tc.errorSubstring, err.Error()), + "expected tc.errorSubstring=%v to contain %v, but it didn't", tc.errorSubstring, err.Error()) + return } assert.Nil(mt, err, "CreateDataKey error: %v", err) @@ -935,8 +1030,10 @@ func TestClientSideEncryptionProse(t *testing.T) { invalidKeyOpts := options.DataKey().SetMasterKey(tc.masterKey) _, err = invalidClientEncryption.CreateDataKey(context.Background(), tc.provider, invalidKeyOpts) assert.NotNil(mt, err, "expected CreateDataKey error, got nil") - assert.True(mt, strings.Contains(err.Error(), tc.invalidClientEncryptionErrorSubstring), - "expected error %v to contain substring '%v'", err, tc.invalidClientEncryptionErrorSubstring) + + assert.True(t, containsSubstring(tc.invalidClientEncryptionErrorSubstring, err.Error()), + "expected tc.invalidClientEncryptionErrorSubstring=%v to contain %v, but it didn't", + tc.invalidClientEncryptionErrorSubstring, err.Error()) }) } }) diff --git a/mongo/integration/client_test.go b/mongo/integration/client_test.go index 914ca863b7..038ed25d72 100644 --- a/mongo/integration/client_test.go +++ b/mongo/integration/client_test.go @@ -126,8 +126,6 @@ func TestClient(t *testing.T) { "expected security field to be type %v, got %v", bson.TypeMaxKey, 
security.Type) _, found := security.Document().LookupErr("SSLServerSubjectName") assert.Nil(mt, found, "SSLServerSubjectName not found in result") - _, found = security.Document().LookupErr("SSLServerHasCertificateAuthority") - assert.Nil(mt, found, "SSLServerHasCertificateAuthority not found in result") }) mt.RunOpts("x509", mtest.NewOptions().Auth(true).SSL(true), func(mt *mtest.T) { testCases := []struct { @@ -711,7 +709,7 @@ func TestClient(t *testing.T) { err := mt.Client.Ping(ctx, nil) cancel() assert.NotNil(mt, err, "expected Ping to return an error") - assert.True(mt, mongo.IsTimeout(err), "expected a timeout error: got %v", err) + assert.True(mt, mongo.IsTimeout(err), "expected a timeout error, got: %v", err) } // Assert that the Ping timeouts result in no connections being closed. @@ -733,8 +731,8 @@ func TestClient(t *testing.T) { pair := msgPairs[0] assert.Equal(mt, handshake.LegacyHello, pair.CommandName, "expected command name %s at index 0, got %s", handshake.LegacyHello, pair.CommandName) - assert.Equal(mt, wiremessage.OpMsg, pair.Sent.OpCode, - "expected 'OP_MSG' OpCode in wire message, got %q", pair.Sent.OpCode.String()) + assert.Equal(mt, wiremessage.OpQuery, pair.Sent.OpCode, + "expected 'OP_QUERY' OpCode in wire message, got %q", pair.Sent.OpCode.String()) // Look for a saslContinue in the remaining proxied messages and assert that it uses the OP_MSG OpCode, as wire // version is now known to be >= 6. diff --git a/mongo/integration/mtest/sent_message.go b/mongo/integration/mtest/sent_message.go index d36075bf81..6b96e061bc 100644 --- a/mongo/integration/mtest/sent_message.go +++ b/mongo/integration/mtest/sent_message.go @@ -37,6 +37,8 @@ type sentMsgParseFn func([]byte) (*SentMessage, error) func getSentMessageParser(opcode wiremessage.OpCode) (sentMsgParseFn, bool) { switch opcode { + case wiremessage.OpQuery: + return parseOpQuery, true case wiremessage.OpMsg: return parseSentOpMsg, true case wiremessage.OpCompressed: @@ -46,6 +48,69 @@ func getSentMessageParser(opcode wiremessage.OpCode) (sentMsgParseFn, bool) { } } +func parseOpQuery(wm []byte) (*SentMessage, error) { + var ok bool + + if _, wm, ok = wiremessage.ReadQueryFlags(wm); !ok { + return nil, errors.New("failed to read query flags") + } + if _, wm, ok = wiremessage.ReadQueryFullCollectionName(wm); !ok { + return nil, errors.New("failed to read full collection name") + } + if _, wm, ok = wiremessage.ReadQueryNumberToSkip(wm); !ok { + return nil, errors.New("failed to read number to skip") + } + if _, wm, ok = wiremessage.ReadQueryNumberToReturn(wm); !ok { + return nil, errors.New("failed to read number to return") + } + + query, wm, ok := wiremessage.ReadQueryQuery(wm) + if !ok { + return nil, errors.New("failed to read query") + } + + // If there is no read preference document, the command document is query. + // Otherwise, query is in the format {$query: , $readPreference: }. + commandDoc := query + var rpDoc bsoncore.Document + + dollarQueryVal, err := query.LookupErr("$query") + if err == nil { + commandDoc = dollarQueryVal.Document() + + rpVal, err := query.LookupErr("$readPreference") + if err != nil { + return nil, fmt.Errorf("query %s contains $query but not $readPreference fields", query) + } + rpDoc = rpVal.Document() + } + + // For OP_QUERY, inserts, updates, and deletes are sent as a BSON array of documents inside the main command + // document. Pull these sequences out into an ArrayStyle DocumentSequence. 
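// As an illustrative aside (the document shape below is assumed from the
// legacy OP_QUERY write command format, not taken from this patch): a legacy
// insert travels as a single BSON document such as
//
//	{"insert": "coll", "documents": [{"_id": 1}, {"_id": 2}]}
//
// whereas OP_MSG carries "documents" in a separate document sequence, so
// extracting the embedded array here lets tests assert on both opcodes the
// same way.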
+ var docSequence *bsoncore.DocumentSequence + cmdElems, _ := commandDoc.Elements() + for _, elem := range cmdElems { + switch elem.Key() { + case "documents", "updates", "deletes": + docSequence = &bsoncore.DocumentSequence{ + Style: bsoncore.ArrayStyle, + Data: elem.Value().Array(), + } + } + if docSequence != nil { + // There can only be one of these arrays in a well-formed command, so we exit the loop once one is found. + break + } + } + + sm := &SentMessage{ + Command: commandDoc, + ReadPreference: rpDoc, + DocumentSequence: docSequence, + } + return sm, nil +} + func parseSentMessage(wm []byte) (*SentMessage, error) { // Re-assign the wire message to "remaining" so "wm" continues to point to the entire message after parsing. _, requestID, _, opcode, remaining, ok := wiremessage.ReadHeader(wm) diff --git a/mongo/integration/retryable_reads_prose_test.go b/mongo/integration/retryable_reads_prose_test.go index 80d7937e8c..80f4d3329a 100644 --- a/mongo/integration/retryable_reads_prose_test.go +++ b/mongo/integration/retryable_reads_prose_test.go @@ -16,6 +16,7 @@ import ( "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/assert" "go.mongodb.org/mongo-driver/internal/eventtest" + "go.mongodb.org/mongo-driver/internal/require" "go.mongodb.org/mongo-driver/mongo/integration/mtest" "go.mongodb.org/mongo-driver/mongo/options" ) @@ -102,4 +103,95 @@ func TestRetryableReadsProse(t *testing.T) { "expected a find event, got a(n) %v event", cmdEvt.CommandName) } }) + + mtOpts = mtest.NewOptions().Topologies(mtest.Sharded).MinServerVersion("4.2") + mt.RunOpts("retrying in sharded cluster", mtOpts, func(mt *mtest.T) { + tests := []struct { + name string + + // Note that setting this value greater than 2 will result in false + // negatives. The current specification does not account for CSOT, which + // might allow for an "inifinite" number of retries over a period of time. + // Because of this, we only track the "previous server". + hostCount int + failpointErrorCode int32 + expectedFailCount int + expectedSuccessCount int + }{ + { + name: "retry on different mongos", + hostCount: 2, + failpointErrorCode: 6, // HostUnreachable + expectedFailCount: 2, + expectedSuccessCount: 0, + }, + { + name: "retry on same mongos", + hostCount: 1, + failpointErrorCode: 6, // HostUnreachable + expectedFailCount: 1, + expectedSuccessCount: 1, + }, + } + + for _, tc := range tests { + mt.Run(tc.name, func(mt *mtest.T) { + hosts := options.Client().ApplyURI(mtest.ClusterURI()).Hosts + require.GreaterOrEqualf(mt, len(hosts), tc.hostCount, + "test cluster must have at least %v mongos hosts", tc.hostCount) + + // Configure the failpoint options for each mongos. + failPoint := mtest.FailPoint{ + ConfigureFailPoint: "failCommand", + Mode: mtest.FailPointMode{ + Times: 1, + }, + Data: mtest.FailPointData{ + FailCommands: []string{"find"}, + ErrorCode: tc.failpointErrorCode, + CloseConnection: false, + }, + } + + // In order to ensure that each mongos in the hostCount-many mongos + // hosts are tried at least once (i.e. failures are deprioritized), we + // set a failpoint on all mongos hosts. The idea is that if we get + // hostCount-many failures, then by the pigeonhole principal all mongos + // hosts must have been tried. 
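// For reference, a sketch of the server command implied by the FailPoint
// struct above (field casing follows the failCommand fail point; the exact
// marshaling is an assumption, not shown in this patch):
//
//	{ configureFailPoint: "failCommand",
//	  mode: { times: 1 },
//	  data: { failCommands: ["find"], errorCode: 6, closeConnection: false } }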
+ for i := 0; i < tc.hostCount; i++ { + mt.ResetClient(options.Client().SetHosts([]string{hosts[i]})) + mt.SetFailPoint(failPoint) + + // The automatic failpoint clearing may not clear failpoints set on + // specific hosts, so manually clear the failpoint we set on the + // specific mongos when the test is done. + defer mt.ResetClient(options.Client().SetHosts([]string{hosts[i]})) + defer mt.ClearFailPoints() + } + + failCount := 0 + successCount := 0 + + commandMonitor := &event.CommandMonitor{ + Failed: func(context.Context, *event.CommandFailedEvent) { + failCount++ + }, + Succeeded: func(context.Context, *event.CommandSucceededEvent) { + successCount++ + }, + } + + // Reset the client with exactly hostCount-many mongos hosts. + mt.ResetClient(options.Client(). + SetHosts(hosts[:tc.hostCount]). + SetRetryReads(true). + SetMonitor(commandMonitor)) + + mt.Coll.FindOne(context.Background(), bson.D{}) + + assert.Equal(mt, tc.expectedFailCount, failCount) + assert.Equal(mt, tc.expectedSuccessCount, successCount) + }) + } + }) } diff --git a/mongo/integration/retryable_writes_prose_test.go b/mongo/integration/retryable_writes_prose_test.go index b378cdcbb5..1c8d353f14 100644 --- a/mongo/integration/retryable_writes_prose_test.go +++ b/mongo/integration/retryable_writes_prose_test.go @@ -284,4 +284,96 @@ func TestRetryableWritesProse(t *testing.T) { // Assert that the "ShutdownInProgress" error is returned. require.True(mt, err.(mongo.WriteException).HasErrorCode(int(shutdownInProgressErrorCode))) }) + + mtOpts = mtest.NewOptions().Topologies(mtest.Sharded).MinServerVersion("4.2") + mt.RunOpts("retrying in sharded cluster", mtOpts, func(mt *mtest.T) { + tests := []struct { + name string + + // Note that setting this value greater than 2 will result in false + // negatives. The current specification does not account for CSOT, which + // might allow for an "inifinite" number of retries over a period of time. + // Because of this, we only track the "previous server". + hostCount int + failpointErrorCode int32 + expectedFailCount int + expectedSuccessCount int + }{ + { + name: "retry on different mongos", + hostCount: 2, + failpointErrorCode: 6, // HostUnreachable + expectedFailCount: 2, + expectedSuccessCount: 0, + }, + { + name: "retry on same mongos", + hostCount: 1, + failpointErrorCode: 6, // HostUnreachable + expectedFailCount: 1, + expectedSuccessCount: 1, + }, + } + + for _, tc := range tests { + mt.Run(tc.name, func(mt *mtest.T) { + hosts := options.Client().ApplyURI(mtest.ClusterURI()).Hosts + require.GreaterOrEqualf(mt, len(hosts), tc.hostCount, + "test cluster must have at least %v mongos hosts", tc.hostCount) + + // Configure the failpoint options for each mongos. + failPoint := mtest.FailPoint{ + ConfigureFailPoint: "failCommand", + Mode: mtest.FailPointMode{ + Times: 1, + }, + Data: mtest.FailPointData{ + FailCommands: []string{"insert"}, + ErrorLabels: &[]string{"RetryableWriteError"}, + ErrorCode: tc.failpointErrorCode, + CloseConnection: false, + }, + } + + // In order to ensure that each mongos in the hostCount-many mongos + // hosts are tried at least once (i.e. failures are deprioritized), we + // set a failpoint on all mongos hosts. The idea is that if we get + // hostCount-many failures, then by the pigeonhole principal all mongos + // hosts must have been tried. 
+ for i := 0; i < tc.hostCount; i++ { + mt.ResetClient(options.Client().SetHosts([]string{hosts[i]})) + mt.SetFailPoint(failPoint) + + // The automatic failpoint clearing may not clear failpoints set on + // specific hosts, so manually clear the failpoint we set on the + // specific mongos when the test is done. + defer mt.ResetClient(options.Client().SetHosts([]string{hosts[i]})) + defer mt.ClearFailPoints() + } + + failCount := 0 + successCount := 0 + + commandMonitor := &event.CommandMonitor{ + Failed: func(context.Context, *event.CommandFailedEvent) { + failCount++ + }, + Succeeded: func(context.Context, *event.CommandSucceededEvent) { + successCount++ + }, + } + + // Reset the client with exactly hostCount-many mongos hosts. + mt.ResetClient(options.Client(). + SetHosts(hosts[:tc.hostCount]). + SetRetryWrites(true). + SetMonitor(commandMonitor)) + + _, _ = mt.Coll.InsertOne(context.Background(), bson.D{}) + + assert.Equal(mt, tc.expectedFailCount, failCount) + assert.Equal(mt, tc.expectedSuccessCount, successCount) + }) + } + }) } diff --git a/mongo/integration/unified/client_entity.go b/mongo/integration/unified/client_entity.go index e63c891039..ff7d9d5fc3 100644 --- a/mongo/integration/unified/client_entity.go +++ b/mongo/integration/unified/client_entity.go @@ -66,6 +66,7 @@ type clientEntity struct { eventsCountLock sync.RWMutex serverDescriptionChangedEventsCountLock sync.RWMutex + eventProcessMu sync.RWMutex entityMap *EntityMap @@ -471,6 +472,9 @@ func (c *clientEntity) processPoolEvent(evt *event.PoolEvent) { } func (c *clientEntity) processServerDescriptionChangedEvent(evt *event.ServerDescriptionChangedEvent) { + c.eventProcessMu.Lock() + defer c.eventProcessMu.Unlock() + if !c.getRecordEvents() { return } @@ -487,6 +491,9 @@ func (c *clientEntity) processServerDescriptionChangedEvent(evt *event.ServerDes } func (c *clientEntity) processServerHeartbeatFailedEvent(evt *event.ServerHeartbeatFailedEvent) { + c.eventProcessMu.Lock() + defer c.eventProcessMu.Unlock() + if !c.getRecordEvents() { return } @@ -499,6 +506,9 @@ func (c *clientEntity) processServerHeartbeatFailedEvent(evt *event.ServerHeartb } func (c *clientEntity) processServerHeartbeatStartedEvent(evt *event.ServerHeartbeatStartedEvent) { + c.eventProcessMu.Lock() + defer c.eventProcessMu.Unlock() + if !c.getRecordEvents() { return } @@ -511,6 +521,9 @@ func (c *clientEntity) processServerHeartbeatStartedEvent(evt *event.ServerHeart } func (c *clientEntity) processServerHeartbeatSucceededEvent(evt *event.ServerHeartbeatSucceededEvent) { + c.eventProcessMu.Lock() + defer c.eventProcessMu.Unlock() + if !c.getRecordEvents() { return } @@ -523,6 +536,9 @@ func (c *clientEntity) processServerHeartbeatSucceededEvent(evt *event.ServerHea } func (c *clientEntity) processTopologyDescriptionChangedEvent(evt *event.TopologyDescriptionChangedEvent) { + c.eventProcessMu.Lock() + defer c.eventProcessMu.Unlock() + if !c.getRecordEvents() { return } diff --git a/mongo/integration/unified/logger.go b/mongo/integration/unified/logger.go index 6dcadacf4a..1d9a612092 100644 --- a/mongo/integration/unified/logger.go +++ b/mongo/integration/unified/logger.go @@ -7,6 +7,8 @@ package unified import ( + "sync" + "go.mongodb.org/mongo-driver/internal/logger" ) @@ -20,9 +22,19 @@ type orderedLogMessage struct { // Logger is the Sink used to captured log messages for logger verification in // the unified spec tests. type Logger struct { + // bufSize is the number of logs expected to be sent to the logger for a + // unified spec test. 
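// A minimal sketch of how the runner sizes this buffer (the accounting below
// appears in unified_spec_runner.go later in this diff, summed per client
// entity); once lastOrder reaches bufSize, logQueue is closed.
//
//	var expectedLogCount int
//	for _, clientLog := range tc.ExpectLogMessages {
//		expectedLogCount += len(clientLog.LogMessages)
//	}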
+ bufSize int + + // lastOrder increments each time the "Info" method is called, and is used to + // determine when to close the logQueue. lastOrder int - logQueue chan orderedLogMessage - bufSize int + + // orderMu guards the order value, which increments each time the "Info" + // method is called. This is necessary since "Info" could be called from + // multiple go routines, e.g. SDAM logs. + orderMu sync.RWMutex + logQueue chan orderedLogMessage } func newLogger(olm *observeLogMessages, bufSize int) *Logger { @@ -44,14 +56,17 @@ func (log *Logger) Info(level int, msg string, args ...interface{}) { return } - defer func() { log.lastOrder++ }() - // If the order is greater than the buffer size, we must return. This // would indicate that the logQueue channel has been closed. if log.lastOrder > log.bufSize { return } + log.orderMu.Lock() + defer log.orderMu.Unlock() + + defer func() { log.lastOrder++ }() + // Add the Diff back to the level, as there is no need to create a // logging offset. level = level + logger.DiffToInfo @@ -68,7 +83,7 @@ func (log *Logger) Info(level int, msg string, args ...interface{}) { logMessage: logMessage, } - // If the order has reached the buffer size, then close the channe. + // If the order has reached the buffer size, then close the channel. if log.lastOrder == log.bufSize { close(log.logQueue) } diff --git a/mongo/integration/unified/unified_spec_runner.go b/mongo/integration/unified/unified_spec_runner.go index 1b1cbeb533..7b92d07204 100644 --- a/mongo/integration/unified/unified_spec_runner.go +++ b/mongo/integration/unified/unified_spec_runner.go @@ -224,7 +224,7 @@ func (tc *TestCase) Run(ls LoggerSkipper) error { } // Count the number of expected log messages over all clients. - expectedLogCount := 0 + var expectedLogCount int for _, clientLog := range tc.ExpectLogMessages { expectedLogCount += len(clientLog.LogMessages) } diff --git a/x/mongo/driver/auth/speculative_scram_test.go b/x/mongo/driver/auth/speculative_scram_test.go index f2234e227c..a159891adc 100644 --- a/x/mongo/driver/auth/speculative_scram_test.go +++ b/x/mongo/driver/auth/speculative_scram_test.go @@ -93,7 +93,7 @@ func TestSpeculativeSCRAM(t *testing.T) { // Assert that the driver sent hello with the speculative authentication message. 
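// (A hedged note on why the helper below switched to
// GetCommandFromQueryWireMessage: the first handshake message is now expected
// to be a legacy OP_QUERY hello, presumably when no server API version is
// declared, matching the client_test.go OpCode assertion earlier in this diff.
// A sketch of the check, reusing the ReadHeader destructuring shown above:
//
//	_, _, _, opcode, _, _ := wiremessage.ReadHeader(wm) // wm := <-conn.Written
//	// opcode == wiremessage.OpQuery for the initial hello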
assert.Equal(t, len(tc.payloads), len(conn.Written), "expected %d wire messages to be sent, got %d", len(tc.payloads), (conn.Written)) - helloCmd, err := drivertest.GetCommandFromMsgWireMessage(<-conn.Written) + helloCmd, err := drivertest.GetCommandFromQueryWireMessage(<-conn.Written) assert.Nil(t, err, "error parsing hello command: %v", err) assertCommandName(t, helloCmd, handshake.LegacyHello) @@ -177,7 +177,7 @@ func TestSpeculativeSCRAM(t *testing.T) { assert.Equal(t, numResponses, len(conn.Written), "expected %d wire messages to be sent, got %d", numResponses, len(conn.Written)) - hello, err := drivertest.GetCommandFromMsgWireMessage(<-conn.Written) + hello, err := drivertest.GetCommandFromQueryWireMessage(<-conn.Written) assert.Nil(t, err, "error parsing hello command: %v", err) assertCommandName(t, hello, handshake.LegacyHello) _, err = hello.LookupErr("speculativeAuthenticate") diff --git a/x/mongo/driver/auth/speculative_x509_test.go b/x/mongo/driver/auth/speculative_x509_test.go index 13fdf2b185..cf46de6ffd 100644 --- a/x/mongo/driver/auth/speculative_x509_test.go +++ b/x/mongo/driver/auth/speculative_x509_test.go @@ -58,7 +58,7 @@ func TestSpeculativeX509(t *testing.T) { assert.Equal(t, numResponses, len(conn.Written), "expected %d wire messages to be sent, got %d", numResponses, len(conn.Written)) - hello, err := drivertest.GetCommandFromMsgWireMessage(<-conn.Written) + hello, err := drivertest.GetCommandFromQueryWireMessage(<-conn.Written) assert.Nil(t, err, "error parsing hello command: %v", err) assertCommandName(t, hello, handshake.LegacyHello) @@ -103,7 +103,7 @@ func TestSpeculativeX509(t *testing.T) { assert.Equal(t, numResponses, len(conn.Written), "expected %d wire messages to be sent, got %d", numResponses, len(conn.Written)) - hello, err := drivertest.GetCommandFromMsgWireMessage(<-conn.Written) + hello, err := drivertest.GetCommandFromQueryWireMessage(<-conn.Written) assert.Nil(t, err, "error parsing hello command: %v", err) assertCommandName(t, hello, handshake.LegacyHello) _, err = hello.LookupErr("speculativeAuthenticate") diff --git a/x/mongo/driver/compression.go b/x/mongo/driver/compression.go index 7f355f61a4..d79b024b74 100644 --- a/x/mongo/driver/compression.go +++ b/x/mongo/driver/compression.go @@ -26,48 +26,72 @@ type CompressionOpts struct { UncompressedSize int32 } -var zstdEncoders sync.Map // map[zstd.EncoderLevel]*zstd.Encoder +// mustZstdNewWriter creates a zstd.Encoder with the given level and a nil +// destination writer. It panics on any errors and should only be used at +// package initialization time. 
+func mustZstdNewWriter(lvl zstd.EncoderLevel) *zstd.Encoder { + enc, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(lvl)) + if err != nil { + panic(err) + } + return enc +} + +var zstdEncoders = [zstd.SpeedBestCompression + 1]*zstd.Encoder{ + 0: nil, // zstd.speedNotSet + zstd.SpeedFastest: mustZstdNewWriter(zstd.SpeedFastest), + zstd.SpeedDefault: mustZstdNewWriter(zstd.SpeedDefault), + zstd.SpeedBetterCompression: mustZstdNewWriter(zstd.SpeedBetterCompression), + zstd.SpeedBestCompression: mustZstdNewWriter(zstd.SpeedBestCompression), +} func getZstdEncoder(level zstd.EncoderLevel) (*zstd.Encoder, error) { - if v, ok := zstdEncoders.Load(level); ok { - return v.(*zstd.Encoder), nil - } - encoder, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(level)) - if err != nil { - return nil, err + if zstd.SpeedFastest <= level && level <= zstd.SpeedBestCompression { + return zstdEncoders[level], nil } - zstdEncoders.Store(level, encoder) - return encoder, nil + // The level is outside the expected range, return an error. + return nil, fmt.Errorf("invalid zstd compression level: %d", level) } -var zlibEncoders sync.Map // map[int /*level*/]*zlibEncoder +// zlibEncodersOffset is the offset into the zlibEncoders array for a given +// compression level. +const zlibEncodersOffset = -zlib.HuffmanOnly // HuffmanOnly == -2 + +var zlibEncoders [zlib.BestCompression + zlibEncodersOffset + 1]sync.Pool func getZlibEncoder(level int) (*zlibEncoder, error) { - if v, ok := zlibEncoders.Load(level); ok { - return v.(*zlibEncoder), nil - } - writer, err := zlib.NewWriterLevel(nil, level) - if err != nil { - return nil, err + if zlib.HuffmanOnly <= level && level <= zlib.BestCompression { + if enc, _ := zlibEncoders[level+zlibEncodersOffset].Get().(*zlibEncoder); enc != nil { + return enc, nil + } + writer, err := zlib.NewWriterLevel(nil, level) + if err != nil { + return nil, err + } + enc := &zlibEncoder{writer: writer, level: level} + return enc, nil } - encoder := &zlibEncoder{writer: writer, buf: new(bytes.Buffer)} - zlibEncoders.Store(level, encoder) + // The level is outside the expected range, return an error. 
+ return nil, fmt.Errorf("invalid zlib compression level: %d", level) +} - return encoder, nil +func putZlibEncoder(enc *zlibEncoder) { + if enc != nil { + zlibEncoders[enc.level+zlibEncodersOffset].Put(enc) + } } type zlibEncoder struct { - mu sync.Mutex writer *zlib.Writer - buf *bytes.Buffer + buf bytes.Buffer + level int } func (e *zlibEncoder) Encode(dst, src []byte) ([]byte, error) { - e.mu.Lock() - defer e.mu.Unlock() + defer putZlibEncoder(e) e.buf.Reset() - e.writer.Reset(e.buf) + e.writer.Reset(&e.buf) _, err := e.writer.Write(src) if err != nil { @@ -105,8 +129,15 @@ func CompressPayload(in []byte, opts CompressionOpts) ([]byte, error) { } } +var zstdReaderPool = sync.Pool{ + New: func() interface{} { + r, _ := zstd.NewReader(nil) + return r + }, +} + // DecompressPayload takes a byte slice that has been compressed and undoes it according to the options passed -func DecompressPayload(in []byte, opts CompressionOpts) (uncompressed []byte, err error) { +func DecompressPayload(in []byte, opts CompressionOpts) ([]byte, error) { switch opts.Compressor { case wiremessage.CompressorNoOp: return in, nil @@ -117,34 +148,29 @@ func DecompressPayload(in []byte, opts CompressionOpts) (uncompressed []byte, er } else if int32(l) != opts.UncompressedSize { return nil, fmt.Errorf("unexpected decompression size, expected %v but got %v", opts.UncompressedSize, l) } - uncompressed = make([]byte, opts.UncompressedSize) - return snappy.Decode(uncompressed, in) + out := make([]byte, opts.UncompressedSize) + return snappy.Decode(out, in) case wiremessage.CompressorZLib: r, err := zlib.NewReader(bytes.NewReader(in)) if err != nil { return nil, err } - defer func() { - err = r.Close() - }() - uncompressed = make([]byte, opts.UncompressedSize) - _, err = io.ReadFull(r, uncompressed) - if err != nil { + out := make([]byte, opts.UncompressedSize) + if _, err := io.ReadFull(r, out); err != nil { return nil, err } - return uncompressed, nil - case wiremessage.CompressorZstd: - r, err := zstd.NewReader(bytes.NewBuffer(in)) - if err != nil { - return nil, err - } - defer r.Close() - uncompressed = make([]byte, opts.UncompressedSize) - _, err = io.ReadFull(r, uncompressed) - if err != nil { + if err := r.Close(); err != nil { return nil, err } - return uncompressed, nil + return out, nil + case wiremessage.CompressorZstd: + buf := make([]byte, 0, opts.UncompressedSize) + // Using a pool here is about ~20% faster + // than using a single global zstd.Reader + r := zstdReaderPool.Get().(*zstd.Decoder) + out, err := r.DecodeAll(in, buf) + zstdReaderPool.Put(r) + return out, err default: return nil, fmt.Errorf("unknown compressor ID %v", opts.Compressor) } diff --git a/x/mongo/driver/compression_test.go b/x/mongo/driver/compression_test.go index b477cb32c1..75a7ff072b 100644 --- a/x/mongo/driver/compression_test.go +++ b/x/mongo/driver/compression_test.go @@ -7,9 +7,14 @@ package driver import ( + "bytes" + "compress/zlib" "os" "testing" + "github.com/golang/snappy" + "github.com/klauspost/compress/zstd" + "go.mongodb.org/mongo-driver/internal/assert" "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" ) @@ -41,6 +46,43 @@ func TestCompression(t *testing.T) { } } +func TestCompressionLevels(t *testing.T) { + in := []byte("abc") + wr := new(bytes.Buffer) + + t.Run("ZLib", func(t *testing.T) { + opts := CompressionOpts{ + Compressor: wiremessage.CompressorZLib, + } + for lvl := zlib.HuffmanOnly - 2; lvl < zlib.BestCompression+2; lvl++ { + opts.ZlibLevel = lvl + _, err1 := CompressPayload(in, opts) + _, err2 := 
zlib.NewWriterLevel(wr, lvl) + if err2 != nil { + assert.Error(t, err1, "expected an error for ZLib level %d", lvl) + } else { + assert.NoError(t, err1, "unexpected error for ZLib level %d", lvl) + } + } + }) + + t.Run("Zstd", func(t *testing.T) { + opts := CompressionOpts{ + Compressor: wiremessage.CompressorZstd, + } + for lvl := zstd.SpeedFastest - 2; lvl < zstd.SpeedBestCompression+2; lvl++ { + opts.ZstdLevel = int(lvl) + _, err1 := CompressPayload(in, opts) + _, err2 := zstd.NewWriter(wr, zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(opts.ZstdLevel))) + if err2 != nil { + assert.Error(t, err1, "expected an error for Zstd level %d", lvl) + } else { + assert.NoError(t, err1, "unexpected error for Zstd level %d", lvl) + } + } + }) +} + func TestDecompressFailures(t *testing.T) { t.Parallel() @@ -62,18 +104,57 @@ func TestDecompressFailures(t *testing.T) { }) } -func BenchmarkCompressPayload(b *testing.B) { - payload := func() []byte { - buf, err := os.ReadFile("compression.go") +var ( + compressionPayload []byte + compressedSnappyPayload []byte + compressedZLibPayload []byte + compressedZstdPayload []byte +) + +func initCompressionPayload(b *testing.B) { + if compressionPayload != nil { + return + } + data, err := os.ReadFile("testdata/compression.go") + if err != nil { + b.Fatal(err) + } + for i := 1; i < 10; i++ { + data = append(data, data...) + } + compressionPayload = data + + compressedSnappyPayload = snappy.Encode(compressedSnappyPayload[:0], data) + + { + var buf bytes.Buffer + enc, err := zstd.NewWriter(&buf, zstd.WithEncoderLevel(zstd.SpeedDefault)) if err != nil { - b.Log(err) - b.FailNow() + b.Fatal(err) } - for i := 1; i < 10; i++ { - buf = append(buf, buf...) + compressedZstdPayload = enc.EncodeAll(data, nil) + } + + { + var buf bytes.Buffer + enc := zlib.NewWriter(&buf) + if _, err := enc.Write(data); err != nil { + b.Fatal(err) } - return buf - }() + if err := enc.Close(); err != nil { + b.Fatal(err) + } + if err := enc.Close(); err != nil { + b.Fatal(err) + } + compressedZLibPayload = append(compressedZLibPayload[:0], buf.Bytes()...) 
+ } + + b.ResetTimer() +} + +func BenchmarkCompressPayload(b *testing.B) { + initCompressionPayload(b) compressors := []wiremessage.CompressorID{ wiremessage.CompressorSnappy, @@ -88,6 +169,9 @@ func BenchmarkCompressPayload(b *testing.B) { ZlibLevel: wiremessage.DefaultZlibLevel, ZstdLevel: wiremessage.DefaultZstdLevel, } + payload := compressionPayload + b.SetBytes(int64(len(payload))) + b.ReportAllocs() b.RunParallel(func(pb *testing.PB) { for pb.Next() { _, err := CompressPayload(payload, opts) @@ -99,3 +183,38 @@ func BenchmarkCompressPayload(b *testing.B) { }) } } + +func BenchmarkDecompressPayload(b *testing.B) { + initCompressionPayload(b) + + benchmarks := []struct { + compressor wiremessage.CompressorID + payload []byte + }{ + {wiremessage.CompressorSnappy, compressedSnappyPayload}, + {wiremessage.CompressorZLib, compressedZLibPayload}, + {wiremessage.CompressorZstd, compressedZstdPayload}, + } + + for _, bench := range benchmarks { + b.Run(bench.compressor.String(), func(b *testing.B) { + opts := CompressionOpts{ + Compressor: bench.compressor, + ZlibLevel: wiremessage.DefaultZlibLevel, + ZstdLevel: wiremessage.DefaultZstdLevel, + UncompressedSize: int32(len(compressionPayload)), + } + payload := bench.payload + b.SetBytes(int64(len(compressionPayload))) + b.ReportAllocs() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _, err := DecompressPayload(payload, opts) + if err != nil { + b.Fatal(err) + } + } + }) + }) + } +} diff --git a/x/mongo/driver/drivertest/channel_conn.go b/x/mongo/driver/drivertest/channel_conn.go index d2ae8df248..27be4c264d 100644 --- a/x/mongo/driver/drivertest/channel_conn.go +++ b/x/mongo/driver/drivertest/channel_conn.go @@ -99,6 +99,38 @@ func MakeReply(doc bsoncore.Document) []byte { return bsoncore.UpdateLength(dst, idx, int32(len(dst[idx:]))) } +// GetCommandFromQueryWireMessage returns the command sent in an OP_QUERY wire message. +func GetCommandFromQueryWireMessage(wm []byte) (bsoncore.Document, error) { + var ok bool + _, _, _, _, wm, ok = wiremessage.ReadHeader(wm) + if !ok { + return nil, errors.New("could not read header") + } + _, wm, ok = wiremessage.ReadQueryFlags(wm) + if !ok { + return nil, errors.New("could not read flags") + } + _, wm, ok = wiremessage.ReadQueryFullCollectionName(wm) + if !ok { + return nil, errors.New("could not read fullCollectionName") + } + _, wm, ok = wiremessage.ReadQueryNumberToSkip(wm) + if !ok { + return nil, errors.New("could not read numberToSkip") + } + _, wm, ok = wiremessage.ReadQueryNumberToReturn(wm) + if !ok { + return nil, errors.New("could not read numberToReturn") + } + + var query bsoncore.Document + query, wm, ok = wiremessage.ReadQueryQuery(wm) + if !ok { + return nil, errors.New("could not read query") + } + return query, nil +} + // GetCommandFromMsgWireMessage returns the command document sent in an OP_MSG wire message. func GetCommandFromMsgWireMessage(wm []byte) (bsoncore.Document, error) { var ok bool diff --git a/x/mongo/driver/errors.go b/x/mongo/driver/errors.go index 55f2fb37eb..177aa1234b 100644 --- a/x/mongo/driver/errors.go +++ b/x/mongo/driver/errors.go @@ -264,10 +264,15 @@ func (e Error) UnsupportedStorageEngine() bool { // Error implements the error interface. 
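+//
+// The formatted message has the form "(Name) Message" when a code name is
+// present, and is suffixed with ": <wrapped error>" when Wrapped is non-nil.
+// For example, Error{Name: "NotWritablePrimary", Message: "not primary",
+// Wrapped: io.EOF}.Error() yields "(NotWritablePrimary) not primary: EOF".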
func (e Error) Error() string { + var msg string if e.Name != "" { - return fmt.Sprintf("(%v) %v", e.Name, e.Message) + msg = fmt.Sprintf("(%v)", e.Name) } - return e.Message + msg += " " + e.Message + if e.Wrapped != nil { + msg += ": " + e.Wrapped.Error() + } + return msg } // Unwrap returns the underlying error. diff --git a/x/mongo/driver/legacy.go b/x/mongo/driver/legacy.go index 9f3b8a39ac..c40f1f8091 100644 --- a/x/mongo/driver/legacy.go +++ b/x/mongo/driver/legacy.go @@ -19,4 +19,5 @@ const ( LegacyKillCursors LegacyListCollections LegacyListIndexes + LegacyHandshake ) diff --git a/x/mongo/driver/mongocrypt/binary.go b/x/mongo/driver/mongocrypt/binary.go index 9e887375a9..4e4b51d74b 100644 --- a/x/mongo/driver/mongocrypt/binary.go +++ b/x/mongo/driver/mongocrypt/binary.go @@ -9,7 +9,10 @@ package mongocrypt -// #include +/* +#include +#include +*/ import "C" import ( "unsafe" @@ -17,6 +20,7 @@ import ( // binary is a wrapper type around a mongocrypt_binary_t* type binary struct { + p *C.uint8_t wrapped *C.mongocrypt_binary_t } @@ -33,11 +37,11 @@ func newBinaryFromBytes(data []byte) *binary { return newBinary() } - // We don't need C.CBytes here because data cannot go out of scope. Any mongocrypt function that takes a - // mongocrypt_binary_t will make a copy of the data so the data can be garbage collected after calling. - addr := (*C.uint8_t)(unsafe.Pointer(&data[0])) // uint8_t* - dataLen := C.uint32_t(len(data)) // uint32_t + // TODO: Consider using runtime.Pinner to replace the C.CBytes after using go1.21.0. + addr := (*C.uint8_t)(C.CBytes(data)) // uint8_t* + dataLen := C.uint32_t(len(data)) // uint32_t return &binary{ + p: addr, wrapped: C.mongocrypt_binary_new_from_data(addr, dataLen), } } @@ -52,5 +56,8 @@ func (b *binary) toBytes() []byte { // close cleans up any resources associated with the given binary instance. func (b *binary) close() { + if b.p != nil { + C.free(unsafe.Pointer(b.p)) + } C.mongocrypt_binary_destroy(b.wrapped) } diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index 90573daa53..6b56191a01 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -19,6 +19,7 @@ import ( "time" "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/bsontype" "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/csot" @@ -321,8 +322,73 @@ func (op Operation) shouldEncrypt() bool { return op.Crypt != nil && !op.Crypt.BypassAutoEncryption() } +// filterDeprioritizedServers will filter out the server candidates that have +// been deprioritized by the operation due to failure. +// +// The server selector should try to select a server that is not in the +// deprioritization list. However, if this is not possible (e.g. there are no +// other healthy servers in the cluster), the selector may return a +// deprioritized server. +func filterDeprioritizedServers(candidates, deprioritized []description.Server) []description.Server { + if len(deprioritized) == 0 { + return candidates + } + + dpaSet := make(map[address.Address]*description.Server) + for i, srv := range deprioritized { + dpaSet[srv.Addr] = &deprioritized[i] + } + + allowed := []description.Server{} + + // Iterate over the candidates and append them to the allowdIndexes slice if + // they are not in the deprioritizedServers list. 
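+	// A candidate is only excluded when both its address and its full server
+	// description match a deprioritized entry; otherwise it remains eligible.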
+ for _, candidate := range candidates { + if srv, ok := dpaSet[candidate.Addr]; !ok || !srv.Equal(candidate) { + allowed = append(allowed, candidate) + } + } + + // If nothing is allowed, then all available servers must have been + // deprioritized. In this case, return the candidates list as-is so that the + // selector can find a suitable server + if len(allowed) == 0 { + return candidates + } + + return allowed +} + +// opServerSelector is a wrapper for the server selector that is assigned to the +// operation. The purpose of this wrapper is to filter candidates with +// operation-specific logic, such as deprioritizing failing servers. +type opServerSelector struct { + selector description.ServerSelector + deprioritizedServers []description.Server +} + +// SelectServer will filter candidates with operation-specific logic before +// passing them onto the user-defined or default selector. +func (oss *opServerSelector) SelectServer( + topo description.Topology, + candidates []description.Server, +) ([]description.Server, error) { + selectedServers, err := oss.selector.SelectServer(topo, candidates) + if err != nil { + return nil, err + } + + filteredServers := filterDeprioritizedServers(selectedServers, oss.deprioritizedServers) + + return filteredServers, nil +} + // selectServer handles performing server selection for an operation. -func (op Operation) selectServer(ctx context.Context) (Server, error) { +func (op Operation) selectServer( + ctx context.Context, + requestID int32, + deprioritized []description.Server, +) (Server, error) { if err := op.Validate(); err != nil { return nil, err } @@ -339,15 +405,24 @@ func (op Operation) selectServer(ctx context.Context) (Server, error) { }) } + oss := &opServerSelector{ + selector: selector, + deprioritizedServers: deprioritized, + } + ctx = logger.WithOperationName(ctx, op.Name) - ctx = logger.WithOperationID(ctx, wiremessage.CurrentRequestID()) + ctx = logger.WithOperationID(ctx, requestID) - return op.Deployment.SelectServer(ctx, selector) + return op.Deployment.SelectServer(ctx, oss) } // getServerAndConnection should be used to retrieve a Server and Connection to execute an operation. -func (op Operation) getServerAndConnection(ctx context.Context) (Server, Connection, error) { - server, err := op.selectServer(ctx) +func (op Operation) getServerAndConnection( + ctx context.Context, + requestID int32, + deprioritized []description.Server, +) (Server, Connection, error) { + server, err := op.selectServer(ctx, requestID, deprioritized) if err != nil { if op.Client != nil && !(op.Client.Committing || op.Client.Aborting) && op.Client.TransactionRunning() { @@ -480,6 +555,11 @@ func (op Operation) Execute(ctx context.Context) error { first := true currIndex := 0 + // deprioritizedServers are a running list of servers that should be + // deprioritized during server selection. Per the specifications, we should + // only ever deprioritize the "previous server". + var deprioritizedServers []description.Server + // resetForRetry records the error that caused the retry, decrements retries, and resets the // retry loop variables to request a new server and a new connection for the next attempt. resetForRetry := func(err error) { @@ -505,11 +585,18 @@ func (op Operation) Execute(ctx context.Context) error { } } - // If we got a connection, close it immediately to release pool resources for - // subsequent retries. + // If we got a connection, close it immediately to release pool resources + // for subsequent retries. 
if conn != nil { + // If we are dealing with a sharded cluster, then mark the failed server + // as "deprioritized". + if desc := conn.Description; desc != nil && op.Deployment.Kind() == description.Sharded { + deprioritizedServers = []description.Server{conn.Description()} + } + conn.Close() } + // Set the server and connection to nil to request a new server and connection. srvr = nil conn = nil @@ -530,11 +617,11 @@ func (op Operation) Execute(ctx context.Context) error { } }() for { - wiremessage.NextRequestID() + requestID := wiremessage.NextRequestID() // If the server or connection are nil, try to select a new server and get a new connection. if srvr == nil || conn == nil { - srvr, conn, err = op.getServerAndConnection(ctx) + srvr, conn, err = op.getServerAndConnection(ctx, requestID, deprioritizedServers) if err != nil { // If the returned error is retryable and there are retries remaining (negative // retries means retry indefinitely), then retry the operation. Set the server @@ -629,7 +716,7 @@ func (op Operation) Execute(ctx context.Context) error { } var startedInfo startedInformation - *wm, startedInfo, err = op.createMsgWireMessage(ctx, maxTimeMS, (*wm)[:0], desc, conn) + *wm, startedInfo, err = op.createWireMessage(ctx, maxTimeMS, (*wm)[:0], desc, conn, requestID) if err != nil { return err @@ -1103,8 +1190,92 @@ func (op Operation) addBatchArray(dst []byte) []byte { return dst } -func (op Operation) createMsgWireMessage(ctx context.Context, maxTimeMS uint64, dst []byte, desc description.SelectedServer, +func (op Operation) createLegacyHandshakeWireMessage( + maxTimeMS uint64, + dst []byte, + desc description.SelectedServer, +) ([]byte, startedInformation, error) { + var info startedInformation + flags := op.secondaryOK(desc) + var wmindex int32 + info.requestID = wiremessage.NextRequestID() + wmindex, dst = wiremessage.AppendHeaderStart(dst, info.requestID, 0, wiremessage.OpQuery) + dst = wiremessage.AppendQueryFlags(dst, flags) + + dollarCmd := [...]byte{'.', '$', 'c', 'm', 'd'} + + // FullCollectionName + dst = append(dst, op.Database...) + dst = append(dst, dollarCmd[:]...) + dst = append(dst, 0x00) + dst = wiremessage.AppendQueryNumberToSkip(dst, 0) + dst = wiremessage.AppendQueryNumberToReturn(dst, -1) + + wrapper := int32(-1) + rp, err := op.createReadPref(desc, true) + if err != nil { + return dst, info, err + } + if len(rp) > 0 { + wrapper, dst = bsoncore.AppendDocumentStart(dst) + dst = bsoncore.AppendHeader(dst, bsontype.EmbeddedDocument, "$query") + } + idx, dst := bsoncore.AppendDocumentStart(dst) + dst, err = op.CommandFn(dst, desc) + if err != nil { + return dst, info, err + } + + if op.Batches != nil && len(op.Batches.Current) > 0 { + dst = op.addBatchArray(dst) + } + + dst, err = op.addReadConcern(dst, desc) + if err != nil { + return dst, info, err + } + + dst, err = op.addWriteConcern(dst, desc) + if err != nil { + return dst, info, err + } + + dst, err = op.addSession(dst, desc) + if err != nil { + return dst, info, err + } + + dst = op.addClusterTime(dst, desc) + dst = op.addServerAPI(dst) + // If maxTimeMS is greater than 0 append it to wire message. A maxTimeMS value of 0 only explicitly + // specifies the default behavior of no timeout server-side. 
+ if maxTimeMS > 0 { + dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", int64(maxTimeMS)) + } + + dst, _ = bsoncore.AppendDocumentEnd(dst, idx) + // Command monitoring only reports the document inside $query + info.cmd = dst[idx:] + + if len(rp) > 0 { + var err error + dst = bsoncore.AppendDocumentElement(dst, "$readPreference", rp) + dst, err = bsoncore.AppendDocumentEnd(dst, wrapper) + if err != nil { + return dst, info, err + } + } + + return bsoncore.UpdateLength(dst, wmindex, int32(len(dst[wmindex:]))), info, nil +} + +func (op Operation) createMsgWireMessage( + ctx context.Context, + maxTimeMS uint64, + dst []byte, + desc description.SelectedServer, conn Connection, + requestID int32, ) ([]byte, startedInformation, error) { var info startedInformation var flags wiremessage.MsgFlag @@ -1120,7 +1291,7 @@ func (op Operation) createMsgWireMessage(ctx context.Context, maxTimeMS uint64, flags |= wiremessage.ExhaustAllowed } - info.requestID = wiremessage.CurrentRequestID() + info.requestID = requestID wmindex, dst = wiremessage.AppendHeaderStart(dst, info.requestID, 0, wiremessage.OpMsg) dst = wiremessage.AppendMsgFlags(dst, flags) // Body @@ -1186,6 +1357,29 @@ func (op Operation) createMsgWireMessage(ctx context.Context, maxTimeMS uint64, return bsoncore.UpdateLength(dst, wmindex, int32(len(dst[wmindex:]))), info, nil } +// isLegacyHandshake returns True if the operation is the first message of +// the initial handshake and should use a legacy hello. +func isLegacyHandshake(op Operation, desc description.SelectedServer) bool { + isInitialHandshake := desc.WireVersion == nil || desc.WireVersion.Max == 0 + + return op.Legacy == LegacyHandshake && isInitialHandshake +} + +func (op Operation) createWireMessage( + ctx context.Context, + maxTimeMS uint64, + dst []byte, + desc description.SelectedServer, + conn Connection, + requestID int32, +) ([]byte, startedInformation, error) { + if isLegacyHandshake(op, desc) { + return op.createLegacyHandshakeWireMessage(maxTimeMS, dst, desc) + } + + return op.createMsgWireMessage(ctx, maxTimeMS, dst, desc, conn, requestID) +} + // addCommandFields adds the fields for a command to the wire message in dst. This assumes that the start of the document // has already been added and does not add the final 0 byte. func (op Operation) addCommandFields(ctx context.Context, dst []byte, desc description.SelectedServer) ([]byte, error) { diff --git a/x/mongo/driver/operation/hello.go b/x/mongo/driver/operation/hello.go index 3cfa2d450a..16d5809130 100644 --- a/x/mongo/driver/operation/hello.go +++ b/x/mongo/driver/operation/hello.go @@ -537,8 +537,16 @@ func (h *Hello) StreamResponse(ctx context.Context, conn driver.StreamerConnecti return h.createOperation().ExecuteExhaust(ctx, conn) } +// isLegacyHandshake returns True if server API version is not requested and +// loadBalanced is False. 
If this is the case, then the drivers MUST use legacy +// hello for the first message of the initial handshake with the OP_QUERY +// protocol +func isLegacyHandshake(srvAPI *driver.ServerAPIOptions, deployment driver.Deployment) bool { + return srvAPI == nil && deployment.Kind() != description.LoadBalanced +} + func (h *Hello) createOperation() driver.Operation { - return driver.Operation{ + op := driver.Operation{ Clock: h.clock, CommandFn: h.command, Database: "admin", @@ -549,23 +557,36 @@ func (h *Hello) createOperation() driver.Operation { }, ServerAPI: h.serverAPI, } + + if isLegacyHandshake(h.serverAPI, h.d) { + op.Legacy = driver.LegacyHandshake + } + + return op } // GetHandshakeInformation performs the MongoDB handshake for the provided connection and returns the relevant // information about the server. This function implements the driver.Handshaker interface. func (h *Hello) GetHandshakeInformation(ctx context.Context, _ address.Address, c driver.Connection) (driver.HandshakeInformation, error) { - err := driver.Operation{ + deployment := driver.SingleConnectionDeployment{C: c} + + op := driver.Operation{ Clock: h.clock, CommandFn: h.handshakeCommand, - Deployment: driver.SingleConnectionDeployment{C: c}, + Deployment: deployment, Database: "admin", ProcessResponseFn: func(info driver.ResponseInfo) error { h.res = info.ServerResponse return nil }, ServerAPI: h.serverAPI, - }.Execute(ctx) - if err != nil { + } + + if isLegacyHandshake(h.serverAPI, deployment) { + op.Legacy = driver.LegacyHandshake + } + + if err := op.Execute(ctx); err != nil { return driver.HandshakeInformation{}, err } @@ -578,6 +599,9 @@ func (h *Hello) GetHandshakeInformation(ctx context.Context, _ address.Address, if serverConnectionID, ok := h.res.Lookup("connectionId").AsInt64OK(); ok { info.ServerConnectionID = &serverConnectionID } + + var err error + // Cast to bson.Raw to lookup saslSupportedMechs to avoid converting from bsoncore.Value to bson.RawValue for the // StringSliceFromRawValue call. if saslSupportedMechs, lookupErr := bson.Raw(h.res).LookupErr("saslSupportedMechs"); lookupErr == nil { diff --git a/x/mongo/driver/operation_test.go b/x/mongo/driver/operation_test.go index d4c5a1b6a0..e6c9d4cf95 100644 --- a/x/mongo/driver/operation_test.go +++ b/x/mongo/driver/operation_test.go @@ -20,6 +20,7 @@ import ( "go.mongodb.org/mongo-driver/internal/assert" "go.mongodb.org/mongo-driver/internal/csot" "go.mongodb.org/mongo-driver/internal/handshake" + "go.mongodb.org/mongo-driver/internal/require" "go.mongodb.org/mongo-driver/internal/uuid" "go.mongodb.org/mongo-driver/mongo/address" "go.mongodb.org/mongo-driver/mongo/description" @@ -62,7 +63,7 @@ func TestOperation(t *testing.T) { t.Run("selectServer", func(t *testing.T) { t.Run("returns validation error", func(t *testing.T) { op := &Operation{} - _, err := op.selectServer(context.Background()) + _, err := op.selectServer(context.Background(), 1, nil) if err == nil { t.Error("Expected a validation error from selectServer, but got ") } @@ -76,11 +77,15 @@ func TestOperation(t *testing.T) { Database: "testing", Selector: want, } - _, err := op.selectServer(context.Background()) + _, err := op.selectServer(context.Background(), 1, nil) noerr(t, err) - got := d.params.selector - if !cmp.Equal(got, want) { - t.Errorf("Did not get expected server selector. got %v; want %v", got, want) + + // Assert the the selector is an operation selector wrapper. 
+ oss, ok := d.params.selector.(*opServerSelector) + require.True(t, ok) + + if !cmp.Equal(oss.selector, want) { + t.Errorf("Did not get expected server selector. got %v; want %v", oss.selector, want) } }) t.Run("uses a default server selector", func(t *testing.T) { @@ -90,7 +95,7 @@ func TestOperation(t *testing.T) { Deployment: d, Database: "testing", } - _, err := op.selectServer(context.Background()) + _, err := op.selectServer(context.Background(), 1, nil) noerr(t, err) if d.params.selector == nil { t.Error("The selectServer method should use a default selector when not specified on Operation, but it passed .") @@ -652,7 +657,8 @@ func TestOperation(t *testing.T) { } func createExhaustServerResponse(response bsoncore.Document, moreToCome bool) []byte { - idx, wm := wiremessage.AppendHeaderStart(nil, 0, wiremessage.CurrentRequestID()+1, wiremessage.OpMsg) + const psuedoRequestID = 1 + idx, wm := wiremessage.AppendHeaderStart(nil, 0, psuedoRequestID, wiremessage.OpMsg) var flags wiremessage.MsgFlag if moreToCome { flags = wiremessage.MoreToCome @@ -880,3 +886,123 @@ func TestDecodeOpReply(t *testing.T) { assert.Equal(t, []bsoncore.Document(nil), reply.documents) }) } + +func TestFilterDeprioritizedServers(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + deprioritized []description.Server + candidates []description.Server + want []description.Server + }{ + { + name: "empty", + candidates: []description.Server{}, + want: []description.Server{}, + }, + { + name: "nil candidates", + candidates: nil, + want: []description.Server{}, + }, + { + name: "nil deprioritized server list", + candidates: []description.Server{ + { + Addr: address.Address("mongodb://localhost:27017"), + }, + }, + want: []description.Server{ + { + Addr: address.Address("mongodb://localhost:27017"), + }, + }, + }, + { + name: "deprioritize single server candidate list", + candidates: []description.Server{ + { + Addr: address.Address("mongodb://localhost:27017"), + }, + }, + deprioritized: []description.Server{ + { + Addr: address.Address("mongodb://localhost:27017"), + }, + }, + want: []description.Server{ + // Since all available servers were deprioritized, then the selector + // should return all candidates. + { + Addr: address.Address("mongodb://localhost:27017"), + }, + }, + }, + { + name: "depriotirize one server in multi server candidate list", + candidates: []description.Server{ + { + Addr: address.Address("mongodb://localhost:27017"), + }, + { + Addr: address.Address("mongodb://localhost:27018"), + }, + { + Addr: address.Address("mongodb://localhost:27019"), + }, + }, + deprioritized: []description.Server{ + { + Addr: address.Address("mongodb://localhost:27017"), + }, + }, + want: []description.Server{ + { + Addr: address.Address("mongodb://localhost:27018"), + }, + { + Addr: address.Address("mongodb://localhost:27019"), + }, + }, + }, + { + name: "depriotirize multiple servers in multi server candidate list", + deprioritized: []description.Server{ + { + Addr: address.Address("mongodb://localhost:27017"), + }, + { + Addr: address.Address("mongodb://localhost:27018"), + }, + }, + candidates: []description.Server{ + { + Addr: address.Address("mongodb://localhost:27017"), + }, + { + Addr: address.Address("mongodb://localhost:27018"), + }, + { + Addr: address.Address("mongodb://localhost:27019"), + }, + }, + want: []description.Server{ + { + Addr: address.Address("mongodb://localhost:27019"), + }, + }, + }, + } + + for _, tc := range tests { + tc := tc // Capture the range variable. 
+ + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + got := filterDeprioritizedServers(tc.candidates, tc.deprioritized) + assert.ElementsMatch(t, got, tc.want) + }) + } +} diff --git a/x/mongo/driver/testdata/compression.go b/x/mongo/driver/testdata/compression.go new file mode 100644 index 0000000000..7f355f61a4 --- /dev/null +++ b/x/mongo/driver/testdata/compression.go @@ -0,0 +1,151 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package driver + +import ( + "bytes" + "compress/zlib" + "fmt" + "io" + "sync" + + "github.com/golang/snappy" + "github.com/klauspost/compress/zstd" + "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" +) + +// CompressionOpts holds settings for how to compress a payload +type CompressionOpts struct { + Compressor wiremessage.CompressorID + ZlibLevel int + ZstdLevel int + UncompressedSize int32 +} + +var zstdEncoders sync.Map // map[zstd.EncoderLevel]*zstd.Encoder + +func getZstdEncoder(level zstd.EncoderLevel) (*zstd.Encoder, error) { + if v, ok := zstdEncoders.Load(level); ok { + return v.(*zstd.Encoder), nil + } + encoder, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(level)) + if err != nil { + return nil, err + } + zstdEncoders.Store(level, encoder) + return encoder, nil +} + +var zlibEncoders sync.Map // map[int /*level*/]*zlibEncoder + +func getZlibEncoder(level int) (*zlibEncoder, error) { + if v, ok := zlibEncoders.Load(level); ok { + return v.(*zlibEncoder), nil + } + writer, err := zlib.NewWriterLevel(nil, level) + if err != nil { + return nil, err + } + encoder := &zlibEncoder{writer: writer, buf: new(bytes.Buffer)} + zlibEncoders.Store(level, encoder) + + return encoder, nil +} + +type zlibEncoder struct { + mu sync.Mutex + writer *zlib.Writer + buf *bytes.Buffer +} + +func (e *zlibEncoder) Encode(dst, src []byte) ([]byte, error) { + e.mu.Lock() + defer e.mu.Unlock() + + e.buf.Reset() + e.writer.Reset(e.buf) + + _, err := e.writer.Write(src) + if err != nil { + return nil, err + } + err = e.writer.Close() + if err != nil { + return nil, err + } + dst = append(dst[:0], e.buf.Bytes()...) 
+ return dst, nil +} + +// CompressPayload takes a byte slice and compresses it according to the options passed +func CompressPayload(in []byte, opts CompressionOpts) ([]byte, error) { + switch opts.Compressor { + case wiremessage.CompressorNoOp: + return in, nil + case wiremessage.CompressorSnappy: + return snappy.Encode(nil, in), nil + case wiremessage.CompressorZLib: + encoder, err := getZlibEncoder(opts.ZlibLevel) + if err != nil { + return nil, err + } + return encoder.Encode(nil, in) + case wiremessage.CompressorZstd: + encoder, err := getZstdEncoder(zstd.EncoderLevelFromZstd(opts.ZstdLevel)) + if err != nil { + return nil, err + } + return encoder.EncodeAll(in, nil), nil + default: + return nil, fmt.Errorf("unknown compressor ID %v", opts.Compressor) + } +} + +// DecompressPayload takes a byte slice that has been compressed and undoes it according to the options passed +func DecompressPayload(in []byte, opts CompressionOpts) (uncompressed []byte, err error) { + switch opts.Compressor { + case wiremessage.CompressorNoOp: + return in, nil + case wiremessage.CompressorSnappy: + l, err := snappy.DecodedLen(in) + if err != nil { + return nil, fmt.Errorf("decoding compressed length %w", err) + } else if int32(l) != opts.UncompressedSize { + return nil, fmt.Errorf("unexpected decompression size, expected %v but got %v", opts.UncompressedSize, l) + } + uncompressed = make([]byte, opts.UncompressedSize) + return snappy.Decode(uncompressed, in) + case wiremessage.CompressorZLib: + r, err := zlib.NewReader(bytes.NewReader(in)) + if err != nil { + return nil, err + } + defer func() { + err = r.Close() + }() + uncompressed = make([]byte, opts.UncompressedSize) + _, err = io.ReadFull(r, uncompressed) + if err != nil { + return nil, err + } + return uncompressed, nil + case wiremessage.CompressorZstd: + r, err := zstd.NewReader(bytes.NewBuffer(in)) + if err != nil { + return nil, err + } + defer r.Close() + uncompressed = make([]byte, opts.UncompressedSize) + _, err = io.ReadFull(r, uncompressed) + if err != nil { + return nil, err + } + return uncompressed, nil + default: + return nil, fmt.Errorf("unknown compressor ID %v", opts.Compressor) + } +} diff --git a/x/mongo/driver/topology/errors.go b/x/mongo/driver/topology/errors.go index 4f7b485405..7ce41864e6 100644 --- a/x/mongo/driver/topology/errors.go +++ b/x/mongo/driver/topology/errors.go @@ -9,6 +9,7 @@ package topology import ( "context" "fmt" + "time" "go.mongodb.org/mongo-driver/mongo/description" ) @@ -69,11 +70,17 @@ func (e ServerSelectionError) Unwrap() error { // WaitQueueTimeoutError represents a timeout when requesting a connection from the pool type WaitQueueTimeoutError struct { - Wrapped error - PinnedCursorConnections uint64 - PinnedTransactionConnections uint64 - maxPoolSize uint64 - totalConnectionCount int + Wrapped error + pinnedConnections *pinnedConnections + maxPoolSize uint64 + totalConnections int + availableConnections int + waitDuration time.Duration +} + +type pinnedConnections struct { + cursorConnections uint64 + transactionConnections uint64 } // Error implements the error interface. 
@@ -95,14 +102,19 @@ func (w WaitQueueTimeoutError) Error() string { ) } - return fmt.Sprintf( - "%s; maxPoolSize: %d, connections in use by cursors: %d"+ - ", connections in use by transactions: %d, connections in use by other operations: %d", - errorMsg, - w.maxPoolSize, - w.PinnedCursorConnections, - w.PinnedTransactionConnections, - uint64(w.totalConnectionCount)-w.PinnedCursorConnections-w.PinnedTransactionConnections) + msg := fmt.Sprintf("%s; total connections: %d, maxPoolSize: %d, ", errorMsg, w.totalConnections, w.maxPoolSize) + if pinnedConnections := w.pinnedConnections; pinnedConnections != nil { + openConnectionCount := uint64(w.totalConnections) - + pinnedConnections.cursorConnections - + pinnedConnections.transactionConnections + msg += fmt.Sprintf("connections in use by cursors: %d, connections in use by transactions: %d, connections in use by other operations: %d, ", + pinnedConnections.cursorConnections, + pinnedConnections.transactionConnections, + openConnectionCount, + ) + } + msg += fmt.Sprintf("idle connections: %d, wait duration: %s", w.availableConnections, w.waitDuration.String()) + return msg } // Unwrap returns the underlying error. diff --git a/x/mongo/driver/topology/polling_srv_records_test.go b/x/mongo/driver/topology/polling_srv_records_test.go index 0ca5c7cbce..7484109d4e 100644 --- a/x/mongo/driver/topology/polling_srv_records_test.go +++ b/x/mongo/driver/topology/polling_srv_records_test.go @@ -105,6 +105,7 @@ func (ss serverSorter) Less(i, j int) bool { } func compareHosts(t *testing.T, received []description.Server, expected []string) { + t.Helper() if len(received) != len(expected) { t.Fatalf("Number of hosts in topology does not match expected value. Got %v; want %v.", len(received), len(expected)) } diff --git a/x/mongo/driver/topology/pool.go b/x/mongo/driver/topology/pool.go index 5d2369352e..6e150344db 100644 --- a/x/mongo/driver/topology/pool.go +++ b/x/mongo/driver/topology/pool.go @@ -74,6 +74,7 @@ type poolConfig struct { MaxConnecting uint64 MaxIdleTime time.Duration MaintainInterval time.Duration + LoadBalanced bool PoolMonitor *event.PoolMonitor Logger *logger.Logger handshakeErrFn func(error, uint64, *primitive.ObjectID) @@ -93,6 +94,7 @@ type pool struct { minSize uint64 maxSize uint64 maxConnecting uint64 + loadBalanced bool monitor *event.PoolMonitor logger *logger.Logger @@ -206,6 +208,7 @@ func newPool(config poolConfig, connOpts ...ConnectionOption) *pool { minSize: config.MinPoolSize, maxSize: config.MaxPoolSize, maxConnecting: maxConnecting, + loadBalanced: config.LoadBalanced, monitor: config.PoolMonitor, logger: config.Logger, handshakeErrFn: config.handshakeErrFn, @@ -574,6 +577,7 @@ func (p *pool) checkOut(ctx context.Context) (conn *connection, err error) { p.stateMu.RUnlock() // Wait for either the wantConn to be ready or for the Context to time out. 
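+	// Record when the wait begins so that the total time spent in the wait
+	// queue can be reported if checkout fails with a WaitQueueTimeoutError.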
+ start := time.Now() select { case <-w.ready: if w.err != nil { @@ -615,6 +619,8 @@ func (p *pool) checkOut(ctx context.Context) (conn *connection, err error) { } return w.conn, nil case <-ctx.Done(): + duration := time.Since(start) + if mustLogPoolMessage(p) { keysAndValues := logger.KeyValues{ logger.KeyReason, logger.ReasonConnCheckoutFailedTimout, @@ -632,13 +638,20 @@ func (p *pool) checkOut(ctx context.Context) (conn *connection, err error) { }) } - return nil, WaitQueueTimeoutError{ - Wrapped: ctx.Err(), - PinnedCursorConnections: atomic.LoadUint64(&p.pinnedCursorConnections), - PinnedTransactionConnections: atomic.LoadUint64(&p.pinnedTransactionConnections), - maxPoolSize: p.maxSize, - totalConnectionCount: p.totalConnectionCount(), + err := WaitQueueTimeoutError{ + Wrapped: ctx.Err(), + maxPoolSize: p.maxSize, + totalConnections: p.totalConnectionCount(), + availableConnections: p.availableConnectionCount(), + waitDuration: duration, + } + if p.loadBalanced { + err.pinnedConnections = &pinnedConnections{ + cursorConnections: atomic.LoadUint64(&p.pinnedCursorConnections), + transactionConnections: atomic.LoadUint64(&p.pinnedTransactionConnections), + } } + return nil, err } } diff --git a/x/mongo/driver/topology/server.go b/x/mongo/driver/topology/server.go index 600797df40..88b93b15e6 100644 --- a/x/mongo/driver/topology/server.go +++ b/x/mongo/driver/topology/server.go @@ -177,6 +177,7 @@ func NewServer(addr address.Address, topologyID primitive.ObjectID, opts ...Serv MaxConnecting: cfg.maxConnecting, MaxIdleTime: cfg.poolMaxIdleTime, MaintainInterval: cfg.poolMaintainInterval, + LoadBalanced: cfg.loadBalanced, PoolMonitor: cfg.poolMonitor, Logger: cfg.logger, handshakeErrFn: s.ProcessHandshakeError, diff --git a/x/mongo/driver/topology/server_test.go b/x/mongo/driver/topology/server_test.go index a2abd1fb1f..ba92b6dd94 100644 --- a/x/mongo/driver/topology/server_test.go +++ b/x/mongo/driver/topology/server_test.go @@ -31,9 +31,11 @@ import ( "go.mongodb.org/mongo-driver/internal/require" "go.mongodb.org/mongo-driver/mongo/address" "go.mongodb.org/mongo-driver/mongo/description" + "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driver/auth" "go.mongodb.org/mongo-driver/x/mongo/driver/drivertest" + "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" ) type channelNetConnDialer struct{} @@ -1207,12 +1209,41 @@ func TestServer_ProcessError(t *testing.T) { func includesClientMetadata(t *testing.T, wm []byte) bool { t.Helper() - doc, err := drivertest.GetCommandFromMsgWireMessage(wm) - assert.NoError(t, err) + var ok bool + _, _, _, _, wm, ok = wiremessage.ReadHeader(wm) + if !ok { + t.Fatal("could not read header") + } + _, wm, ok = wiremessage.ReadQueryFlags(wm) + if !ok { + t.Fatal("could not read flags") + } + _, wm, ok = wiremessage.ReadQueryFullCollectionName(wm) + if !ok { + t.Fatal("could not read fullCollectionName") + } + _, wm, ok = wiremessage.ReadQueryNumberToSkip(wm) + if !ok { + t.Fatal("could not read numberToSkip") + } + _, wm, ok = wiremessage.ReadQueryNumberToReturn(wm) + if !ok { + t.Fatal("could not read numberToReturn") + } + var query bsoncore.Document + query, wm, ok = wiremessage.ReadQueryQuery(wm) + if !ok { + t.Fatal("could not read query") + } - _, err = doc.LookupErr("client") + if _, err := query.LookupErr("client"); err == nil { + return true + } + if _, err := query.LookupErr("$query", "client"); err == nil { + return true + } - return err == nil + return false } // 
processErrorTestConn is a driver.Connection implementation used by tests diff --git a/x/mongo/driver/topology/topology.go b/x/mongo/driver/topology/topology.go index b0683021ee..b79efed4ed 100644 --- a/x/mongo/driver/topology/topology.go +++ b/x/mongo/driver/topology/topology.go @@ -161,13 +161,13 @@ func New(cfg *Config) (*Topology, error) { return t, nil } -func mustLogTopologyMessage(topo *Topology) bool { +func mustLogTopologyMessage(topo *Topology, level logger.Level) bool { return topo.cfg.logger != nil && topo.cfg.logger.LevelComponentEnabled( - logger.LevelDebug, logger.ComponentTopology) + level, logger.ComponentTopology) } -func logTopologyMessage(topo *Topology, msg string, keysAndValues ...interface{}) { - topo.cfg.logger.Print(logger.LevelDebug, +func logTopologyMessage(topo *Topology, level logger.Level, msg string, keysAndValues ...interface{}) { + topo.cfg.logger.Print(level, logger.ComponentTopology, msg, logger.SerializeTopology(logger.Topology{ @@ -176,6 +176,36 @@ func logTopologyMessage(topo *Topology, msg string, keysAndValues ...interface{} }, keysAndValues...)...) } +func logTopologyThirdPartyUsage(topo *Topology, parsedHosts []string) { + thirdPartyMessages := [2]string{ + `You appear to be connected to a CosmosDB cluster. For more information regarding feature compatibility and support please visit https://www.mongodb.com/supportability/cosmosdb`, + `You appear to be connected to a DocumentDB cluster. For more information regarding feature compatibility and support please visit https://www.mongodb.com/supportability/documentdb`, + } + + thirdPartySuffixes := map[string]int{ + ".cosmos.azure.com": 0, + ".docdb.amazonaws.com": 1, + ".docdb-elastic.amazonaws.com": 1, + } + + hostSet := make([]bool, len(thirdPartyMessages)) + for _, host := range parsedHosts { + if h, _, err := net.SplitHostPort(host); err == nil { + host = h + } + for suffix, env := range thirdPartySuffixes { + if !strings.HasSuffix(host, suffix) { + continue + } + if hostSet[env] { + break + } + hostSet[env] = true + logTopologyMessage(topo, logger.LevelInfo, thirdPartyMessages[env]) + } + } +} + func mustLogServerSelection(topo *Topology, level logger.Level) bool { return topo.cfg.logger != nil && topo.cfg.logger.LevelComponentEnabled( level, logger.ComponentServerSelection) @@ -183,8 +213,8 @@ func mustLogServerSelection(topo *Topology, level logger.Level) bool { func logServerSelection( ctx context.Context, - level logger.Level, topo *Topology, + level logger.Level, msg string, srvSelector description.ServerSelector, keysAndValues ...interface{}, @@ -224,7 +254,7 @@ func logServerSelectionSucceeded( portInt64, _ := strconv.ParseInt(port, 10, 32) - logServerSelection(ctx, logger.LevelDebug, topo, logger.ServerSelectionSucceeded, srvSelector, + logServerSelection(ctx, topo, logger.LevelDebug, logger.ServerSelectionSucceeded, srvSelector, logger.KeyServerHost, host, logger.KeyServerPort, portInt64) } @@ -235,7 +265,7 @@ func logServerSelectionFailed( srvSelector description.ServerSelector, err error, ) { - logServerSelection(ctx, logger.LevelDebug, topo, logger.ServerSelectionFailed, srvSelector, + logServerSelection(ctx, topo, logger.LevelDebug, logger.ServerSelectionFailed, srvSelector, logger.KeyFailure, err.Error()) } @@ -321,13 +351,17 @@ func (t *Topology) Connect() error { } t.serversLock.Unlock() + uri, err := url.Parse(t.cfg.URI) + if err != nil { + return err + } + parsedHosts := strings.Split(uri.Host, ",") + if mustLogTopologyMessage(t, logger.LevelInfo) { + logTopologyThirdPartyUsage(t, 
parsedHosts) + } if t.pollingRequired { - uri, err := url.Parse(t.cfg.URI) - if err != nil { - return err - } // sanity check before passing the hostname to resolver - if parsedHosts := strings.Split(uri.Host, ","); len(parsedHosts) != 1 { + if len(parsedHosts) != 1 { return fmt.Errorf("URI with SRV must include one and only one hostname") } _, _, err = net.SplitHostPort(uri.Host) @@ -492,7 +526,7 @@ func (t *Topology) SelectServer(ctx context.Context, ss description.ServerSelect if !doneOnce { if mustLogServerSelection(t, logger.LevelDebug) { - logServerSelection(ctx, logger.LevelDebug, t, logger.ServerSelectionStarted, ss) + logServerSelection(ctx, t, logger.LevelDebug, logger.ServerSelectionStarted, ss) } // for the first pass, select a server from the current description. @@ -531,7 +565,7 @@ func (t *Topology) SelectServer(ctx context.Context, ss description.ServerSelect elapsed := time.Since(startTime) remainingTimeMS := t.cfg.ServerSelectionTimeout - elapsed - logServerSelection(ctx, logger.LevelInfo, t, logger.ServerSelectionWaiting, ss, + logServerSelection(ctx, t, logger.LevelInfo, logger.ServerSelectionWaiting, ss, logger.KeyRemainingTimeMS, remainingTimeMS.Milliseconds()) } @@ -970,7 +1004,7 @@ func (t *Topology) publishServerClosedEvent(addr address.Address) { t.cfg.ServerMonitor.ServerClosed(serverClosed) } - if mustLogTopologyMessage(t) { + if mustLogTopologyMessage(t, logger.LevelDebug) { serverHost, serverPort, err := net.SplitHostPort(addr.String()) if err != nil { serverHost = addr.String() @@ -979,7 +1013,7 @@ func (t *Topology) publishServerClosedEvent(addr address.Address) { portInt64, _ := strconv.ParseInt(serverPort, 10, 32) - logTopologyMessage(t, logger.TopologyServerClosed, + logTopologyMessage(t, logger.LevelDebug, logger.TopologyServerClosed, logger.KeyServerHost, serverHost, logger.KeyServerPort, portInt64) } @@ -997,8 +1031,8 @@ func (t *Topology) publishTopologyDescriptionChangedEvent(prev description.Topol t.cfg.ServerMonitor.TopologyDescriptionChanged(topologyDescriptionChanged) } - if mustLogTopologyMessage(t) { - logTopologyMessage(t, logger.TopologyDescriptionChanged, + if mustLogTopologyMessage(t, logger.LevelDebug) { + logTopologyMessage(t, logger.LevelDebug, logger.TopologyDescriptionChanged, logger.KeyPreviousDescription, prev.String(), logger.KeyNewDescription, current.String()) } @@ -1014,8 +1048,8 @@ func (t *Topology) publishTopologyOpeningEvent() { t.cfg.ServerMonitor.TopologyOpening(topologyOpening) } - if mustLogTopologyMessage(t) { - logTopologyMessage(t, logger.TopologyOpening) + if mustLogTopologyMessage(t, logger.LevelDebug) { + logTopologyMessage(t, logger.LevelDebug, logger.TopologyOpening) } } @@ -1029,7 +1063,7 @@ func (t *Topology) publishTopologyClosedEvent() { t.cfg.ServerMonitor.TopologyClosed(topologyClosed) } - if mustLogTopologyMessage(t) { - logTopologyMessage(t, logger.TopologyClosed) + if mustLogTopologyMessage(t, logger.LevelDebug) { + logTopologyMessage(t, logger.LevelDebug, logger.TopologyClosed) } } diff --git a/x/mongo/driver/topology/topology_test.go b/x/mongo/driver/topology/topology_test.go index 773a8b6475..6cf540a95e 100644 --- a/x/mongo/driver/topology/topology_test.go +++ b/x/mongo/driver/topology/topology_test.go @@ -693,6 +693,240 @@ func TestTopologyConstruction(t *testing.T) { }) } +type mockLogSink struct { + msgs []string +} + +func (s *mockLogSink) Info(_ int, msg string, _ ...interface{}) { + s.msgs = append(s.msgs, msg) +} +func (*mockLogSink) Error(error, string, ...interface{}) { + // Do nothing. 
+} + +// Note: SRV connection strings are intentionally untested, since initial +// lookup responses cannot be easily mocked. +func TestTopologyConstructionLogging(t *testing.T) { + const ( + cosmosDBMsg = `You appear to be connected to a CosmosDB cluster. For more information regarding feature compatibility and support please visit https://www.mongodb.com/supportability/cosmosdb` + documentDBMsg = `You appear to be connected to a DocumentDB cluster. For more information regarding feature compatibility and support please visit https://www.mongodb.com/supportability/documentdb` + ) + + newLoggerOptions := func(sink options.LogSink) *options.LoggerOptions { + return options. + Logger(). + SetSink(sink). + SetComponentLevel(options.LogComponentTopology, options.LogLevelInfo) + } + + t.Run("CosmosDB URIs", func(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + uri string + msgs []string + }{ + { + name: "normal", + uri: "mongodb://a.mongo.cosmos.azure.com:19555/", + msgs: []string{cosmosDBMsg}, + }, + { + name: "multiple hosts", + uri: "mongodb://a.mongo.cosmos.azure.com:1955,b.mongo.cosmos.azure.com:19555/", + msgs: []string{cosmosDBMsg}, + }, + { + name: "case-insensitive matching", + uri: "mongodb://a.MONGO.COSMOS.AZURE.COM:19555/", + msgs: []string{}, + }, + { + name: "Mixing genuine and nongenuine hosts (unlikely in practice)", + uri: "mongodb://a.example.com:27017,b.mongo.cosmos.azure.com:19555/", + msgs: []string{cosmosDBMsg}, + }, + } + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + sink := &mockLogSink{} + cfg, err := NewConfig(options.Client().ApplyURI(tc.uri).SetLoggerOptions(newLoggerOptions(sink)), nil) + require.Nil(t, err, "error constructing topology config: %v", err) + + topo, err := New(cfg) + require.Nil(t, err, "topology.New error: %v", err) + + err = topo.Connect() + assert.Nil(t, err, "Connect error: %v", err) + + assert.ElementsMatch(t, tc.msgs, sink.msgs, "expected messages to be %v, got %v", tc.msgs, sink.msgs) + }) + } + }) + t.Run("DocumentDB URIs", func(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + uri string + msgs []string + }{ + { + name: "normal", + uri: "mongodb://a.docdb.amazonaws.com:27017/", + msgs: []string{documentDBMsg}, + }, + { + name: "normal", + uri: "mongodb://a.docdb-elastic.amazonaws.com:27017/", + msgs: []string{documentDBMsg}, + }, + { + name: "multiple hosts", + uri: "mongodb://a.docdb.amazonaws.com:27017,a.docdb-elastic.amazonaws.com:27017/", + msgs: []string{documentDBMsg}, + }, + { + name: "case-insensitive matching", + uri: "mongodb://a.DOCDB.AMAZONAWS.COM:27017/", + msgs: []string{}, + }, + { + name: "case-insensitive matching", + uri: "mongodb://a.DOCDB-ELASTIC.AMAZONAWS.COM:27017/", + msgs: []string{}, + }, + { + name: "Mixing genuine and nongenuine hosts (unlikely in practice)", + uri: "mongodb://a.example.com:27017,b.docdb.amazonaws.com:27017/", + msgs: []string{documentDBMsg}, + }, + { + name: "Mixing genuine and nongenuine hosts (unlikely in practice)", + uri: "mongodb://a.example.com:27017,b.docdb-elastic.amazonaws.com:27017/", + msgs: []string{documentDBMsg}, + }, + } + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + sink := &mockLogSink{} + cfg, err := NewConfig(options.Client().ApplyURI(tc.uri).SetLoggerOptions(newLoggerOptions(sink)), nil) + require.Nil(t, err, "error constructing topology config: %v", err) + + topo, err := New(cfg) + require.Nil(t, err, "topology.New 
error: %v", err) + + err = topo.Connect() + assert.Nil(t, err, "Connect error: %v", err) + + assert.ElementsMatch(t, tc.msgs, sink.msgs, "expected messages to be %v, got %v", tc.msgs, sink.msgs) + }) + } + }) + t.Run("Mixing CosmosDB and DocumentDB URIs", func(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + uri string + msgs []string + }{ + { + name: "Mixing hosts", + uri: "mongodb://a.mongo.cosmos.azure.com:19555,a.docdb.amazonaws.com:27017/", + msgs: []string{cosmosDBMsg, documentDBMsg}, + }, + } + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + sink := &mockLogSink{} + cfg, err := NewConfig(options.Client().ApplyURI(tc.uri).SetLoggerOptions(newLoggerOptions(sink)), nil) + require.Nil(t, err, "error constructing topology config: %v", err) + + topo, err := New(cfg) + require.Nil(t, err, "topology.New error: %v", err) + + err = topo.Connect() + assert.Nil(t, err, "Connect error: %v", err) + + assert.ElementsMatch(t, tc.msgs, sink.msgs, "expected messages to be %v, got %v", tc.msgs, sink.msgs) + }) + } + }) + t.Run("genuine URIs", func(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + uri string + msgs []string + }{ + { + name: "normal", + uri: "mongodb://a.example.com:27017/", + msgs: []string{}, + }, + { + name: "srv", + uri: "mongodb+srv://test22.test.build.10gen.cc/?srvServiceName=customname", + msgs: []string{}, + }, + { + name: "multiple hosts", + uri: "mongodb://a.example.com:27017,b.example.com:27017/", + msgs: []string{}, + }, + { + name: "unexpected suffix", + uri: "mongodb://a.mongo.cosmos.azure.com.tld:19555/", + msgs: []string{}, + }, + { + name: "unexpected suffix", + uri: "mongodb://a.docdb.amazonaws.com.tld:27017/", + msgs: []string{}, + }, + { + name: "unexpected suffix", + uri: "mongodb://a.docdb-elastic.amazonaws.com.tld:27017/", + msgs: []string{}, + }, + } + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + sink := &mockLogSink{} + cfg, err := NewConfig(options.Client().ApplyURI(tc.uri).SetLoggerOptions(newLoggerOptions(sink)), nil) + require.Nil(t, err, "error constructing topology config: %v", err) + + topo, err := New(cfg) + require.Nil(t, err, "topology.New error: %v", err) + + err = topo.Connect() + assert.Nil(t, err, "Connect error: %v", err) + + assert.ElementsMatch(t, tc.msgs, sink.msgs, "expected messages to be %v, got %v", tc.msgs, sink.msgs) + }) + } + }) +} + type inWindowServer struct { Address string `json:"address"` Type string `json:"type"` diff --git a/x/mongo/driver/wiremessage/wiremessage.go b/x/mongo/driver/wiremessage/wiremessage.go index c4d2567bf0..abf09c15bd 100644 --- a/x/mongo/driver/wiremessage/wiremessage.go +++ b/x/mongo/driver/wiremessage/wiremessage.go @@ -19,9 +19,6 @@ type WireMessage []byte var globalRequestID int32 -// CurrentRequestID returns the current request ID. -func CurrentRequestID() int32 { return atomic.LoadInt32(&globalRequestID) } - // NextRequestID returns the next request ID. func NextRequestID() int32 { return atomic.AddInt32(&globalRequestID, 1) }