From c481fdff87244758b75c8552a4c124bd06ab49d1 Mon Sep 17 00:00:00 2001 From: Preston Vasquez Date: Thu, 2 May 2024 12:08:54 -0600 Subject: [PATCH] GODRIVER-2443 Make Distinct return a decodable struct (#1603) --- internal/integration/collection_test.go | 25 ++-- internal/integration/crud_helpers_test.go | 59 +++++--- internal/integration/sessions_test.go | 2 + internal/integration/unified/admin_helpers.go | 2 +- .../unified/collection_operation_execution.go | 12 +- internal/integration/unified_spec_test.go | 4 +- mongo/client_examples_test.go | 2 +- mongo/collection.go | 36 ++--- mongo/collection_test.go | 4 +- mongo/crud_examples_test.go | 7 +- mongo/results.go | 48 +++++++ mongo/single_result.go | 9 +- mongo/single_result_test.go | 128 +++++++++--------- 13 files changed, 211 insertions(+), 127 deletions(-) diff --git a/internal/integration/collection_test.go b/internal/integration/collection_test.go index 43dee0f2cc..c98441b9d2 100644 --- a/internal/integration/collection_test.go +++ b/internal/integration/collection_test.go @@ -874,24 +874,29 @@ func TestCollection(t *testing.T) { } }) mt.RunOpts("distinct", noClientOpts, func(mt *mtest.T) { - all := []interface{}{int32(1), int32(2), int32(3), int32(4), int32(5)} - last3 := []interface{}{int32(3), int32(4), int32(5)} + all := []int32{1, 2, 3, 4, 5} + testCases := []struct { - name string - filter bson.D - opts *options.DistinctOptions - expected []interface{} + name string + filter bson.D + opts *options.DistinctOptions + want []int32 }{ {"no options", bson.D{}, nil, all}, - {"filter", bson.D{{"x", bson.D{{"$gt", 2}}}}, nil, last3}, + {"filter", bson.D{{"x", bson.D{{"$gt", 2}}}}, nil, all[2:]}, {"options", bson.D{}, options.Distinct().SetMaxTime(5000000000), all}, } for _, tc := range testCases { mt.Run(tc.name, func(mt *mtest.T) { initCollection(mt, mt.Coll) - res, err := mt.Coll.Distinct(context.Background(), "x", tc.filter, tc.opts) - assert.Nil(mt, err, "Distinct error: %v", err) - assert.Equal(mt, tc.expected, res, "expected result %v, got %v", tc.expected, res) + res := mt.Coll.Distinct(context.Background(), "x", tc.filter, tc.opts) + assert.Nil(mt, res.Err(), "Distinct error: %v", res.Err()) + + var got []int32 + err := res.Decode(&got) + assert.NoError(t, err) + + assert.EqualValues(mt, tc.want, got, "expected result %v, got %v", tc.want, got) }) } }) diff --git a/internal/integration/crud_helpers_test.go b/internal/integration/crud_helpers_test.go index e6337cb210..355f934add 100644 --- a/internal/integration/crud_helpers_test.go +++ b/internal/integration/crud_helpers_test.go @@ -602,7 +602,7 @@ func executeListIndexes(mt *mtest.T, sess *mongo.Session, args bson.Raw) (*mongo return mt.Coll.Indexes().List(context.Background()) } -func executeDistinct(mt *mtest.T, sess *mongo.Session, args bson.Raw) ([]interface{}, error) { +func executeDistinct(mt *mtest.T, sess *mongo.Session, args bson.Raw) (bson.RawArray, error) { mt.Helper() var fieldName string @@ -627,16 +627,22 @@ func executeDistinct(mt *mtest.T, sess *mongo.Session, args bson.Raw) ([]interfa } } + var res *mongo.DistinctResult if sess != nil { - var res []interface{} - err := mongo.WithSession(context.Background(), sess, func(sc context.Context) error { - var derr error - res, derr = mt.Coll.Distinct(sc, fieldName, filter, opts) - return derr + err := mongo.WithSession(context.Background(), sess, func(ctx context.Context) error { + res = mt.Coll.Distinct(ctx, fieldName, filter, opts) + + return res.Err() }) - return res, err + + if err != nil { + return nil, err + } + } else { + res = mt.Coll.Distinct(context.Background(), fieldName, filter, opts) } - return mt.Coll.Distinct(context.Background(), fieldName, filter, opts) + + return res.Raw() } func executeFindOneAndDelete(mt *mtest.T, sess *mongo.Session, args bson.Raw) *mongo.SingleResult { @@ -1529,25 +1535,34 @@ func verifyDeleteResult(mt *mtest.T, res *mongo.DeleteResult, result interface{} "deleted count mismatch; expected %v, got %v", expected.DeletedCount, res.DeletedCount) } -func verifyDistinctResult(mt *mtest.T, actualResult []interface{}, expectedResult interface{}) { +func verifyDistinctResult( + mt *mtest.T, + got bson.RawArray, + want interface{}, +) { mt.Helper() - if expectedResult == nil { + if got == nil { return } - for i, expected := range expectedResult.(bson.A) { - actual := actualResult[i] - iExpected := getIntFromInterface(expected) - iActual := getIntFromInterface(actual) + assert.NotNil(mt, want, "expected want to be non-nil") - if iExpected != nil { - assert.NotNil(mt, iActual, "expected nil but got %v", iActual) - assert.Equal(mt, *iExpected, *iActual, "expected value %v but got %v", *iExpected, *iActual) - continue + arr, ok := want.(bson.A) + assert.True(mt, ok, "expected want to be a BSON array") + + for i, iwant := range arr { + gotRawValue := got.Index(uint(i)) + + iwantType, iwantBytes, err := bson.MarshalValue(iwant) + assert.NoError(mt, err) + + wantRawValue := bson.RawValue{ + Type: iwantType, + Value: iwantBytes, } - assert.Equal(mt, expected, actual, "expected value %v but got %v", expected, actual) + assert.EqualValues(mt, wantRawValue, gotRawValue, "expected value %v but got %v", wantRawValue, gotRawValue) } } @@ -1636,7 +1651,11 @@ func verifyCursorResult(mt *mtest.T, cur *mongo.Cursor, result interface{}) { } } -func verifySingleResult(mt *mtest.T, actualResult *mongo.SingleResult, expectedResult interface{}) { +func verifySingleResult( + mt *mtest.T, + actualResult *mongo.SingleResult, + expectedResult interface{}, +) { mt.Helper() if expectedResult == nil { diff --git a/internal/integration/sessions_test.go b/internal/integration/sessions_test.go index d368e4a76a..0150a21fa2 100644 --- a/internal/integration/sessions_test.go +++ b/internal/integration/sessions_test.go @@ -620,6 +620,8 @@ func extractReturnError(returnValues []reflect.Value) error { return converted case *mongo.SingleResult: return converted.Err() + case *mongo.DistinctResult: + return converted.Err() default: return nil } diff --git a/internal/integration/unified/admin_helpers.go b/internal/integration/unified/admin_helpers.go index 5e5379256d..a5184432ab 100644 --- a/internal/integration/unified/admin_helpers.go +++ b/internal/integration/unified/admin_helpers.go @@ -67,7 +67,7 @@ func performDistinctWorkaround(ctx context.Context) error { commandFn := func(ctx context.Context, client *mongo.Client) error { for _, coll := range entities(ctx).collections() { newColl := client.Database(coll.Database().Name()).Collection(coll.Name()) - _, err := newColl.Distinct(ctx, "x", bson.D{}) + err := newColl.Distinct(ctx, "x", bson.D{}).Err() if err != nil { ns := fmt.Sprintf("%s.%s", coll.Database().Name(), coll.Name()) return fmt.Errorf("error running distinct for collection %q: %w", ns, err) diff --git a/internal/integration/unified/collection_operation_execution.go b/internal/integration/unified/collection_operation_execution.go index 5be40a391c..796ab01344 100644 --- a/internal/integration/unified/collection_operation_execution.go +++ b/internal/integration/unified/collection_operation_execution.go @@ -531,15 +531,17 @@ func executeDistinct(ctx context.Context, operation *operation) (*operationResul return nil, newMissingArgumentError("filter") } - res, err := coll.Distinct(ctx, fieldName, filter, opts) - if err != nil { + res := coll.Distinct(ctx, fieldName, filter, opts) + if err := res.Err(); err != nil { return newErrorResult(err), nil } - _, rawRes, err := bson.MarshalValue(res) + + arr, err := res.Raw() if err != nil { - return nil, fmt.Errorf("error converting Distinct result to raw BSON: %w", err) + return newErrorResult(err), nil } - return newValueResult(bson.TypeArray, rawRes, nil), nil + + return newValueResult(bson.TypeArray, arr, nil), nil } func executeDropIndex(ctx context.Context, operation *operation) (*operationResult, error) { diff --git a/internal/integration/unified_spec_test.go b/internal/integration/unified_spec_test.go index 256071ef56..cba3244db3 100644 --- a/internal/integration/unified_spec_test.go +++ b/internal/integration/unified_spec_test.go @@ -259,8 +259,8 @@ func runSpecTestCase(mt *mtest.T, test *testCase, testFile testFile) { if mtest.ClusterTopologyKind() == mtest.Sharded && test.Description == "distinct" { err := runCommandOnAllServers(func(mongosClient *mongo.Client) error { coll := mongosClient.Database(mt.DB.Name()).Collection(mt.Coll.Name()) - _, err := coll.Distinct(context.Background(), "x", bson.D{}) - return err + + return coll.Distinct(context.Background(), "x", bson.D{}).Err() }) assert.Nil(mt, err, "error running distinct against all mongoses: %v", err) } diff --git a/mongo/client_examples_test.go b/mongo/client_examples_test.go index e6654ba84d..fae3083580 100644 --- a/mongo/client_examples_test.go +++ b/mongo/client_examples_test.go @@ -398,7 +398,7 @@ func ExampleConnect_stableAPI() { coll := serverAPIStrictClient.Database("db").Collection("coll") // Fails with error: (APIStrictError) Provided apiStrict:true, but the // command distinct is not in API Version 1 - _, err = coll.Distinct(context.TODO(), "distinct", bson.D{}) + err = coll.Distinct(context.TODO(), "distinct", bson.D{}).Err() log.Println(err) // ServerAPIOptions can be declared with a DeprecationErrors option. diff --git a/mongo/collection.go b/mongo/collection.go index 1767df8aa6..f5661c3be4 100644 --- a/mongo/collection.go +++ b/mongo/collection.go @@ -1298,8 +1298,12 @@ func (coll *Collection) EstimatedDocumentCount(ctx context.Context, // The opts parameter can be used to specify options for the operation (see the options.DistinctOptions documentation). // // For more information about the command, see https://www.mongodb.com/docs/manual/reference/command/distinct/. -func (coll *Collection) Distinct(ctx context.Context, fieldName string, filter interface{}, - opts ...*options.DistinctOptions) ([]interface{}, error) { +func (coll *Collection) Distinct( + ctx context.Context, + fieldName string, + filter interface{}, + opts ...*options.DistinctOptions, +) *DistinctResult { if ctx == nil { ctx = context.Background() @@ -1307,7 +1311,7 @@ func (coll *Collection) Distinct(ctx context.Context, fieldName string, filter i f, err := marshal(filter, coll.bsonOpts, coll.registry) if err != nil { - return nil, err + return &DistinctResult{err: err} } sess := sessionFromContext(ctx) @@ -1319,7 +1323,7 @@ func (coll *Collection) Distinct(ctx context.Context, fieldName string, filter i err = coll.client.validSession(sess) if err != nil { - return nil, err + return &DistinctResult{err: err} } rc := coll.readConcern @@ -1357,7 +1361,7 @@ func (coll *Collection) Distinct(ctx context.Context, fieldName string, filter i if option.Comment != nil { comment, err := marshalValue(option.Comment, coll.bsonOpts, coll.registry) if err != nil { - return nil, err + return &DistinctResult{err: err} } op.Comment(comment) } @@ -1369,30 +1373,20 @@ func (coll *Collection) Distinct(ctx context.Context, fieldName string, filter i err = op.Execute(ctx) if err != nil { - return nil, replaceErrors(err) + return &DistinctResult{err: replaceErrors(err)} } arr, ok := op.Result().Values.ArrayOK() if !ok { - return nil, fmt.Errorf("response field 'values' is type array, but received BSON type %s", op.Result().Values.Type) - } + err := fmt.Errorf("response field 'values' is type array, but received BSON type %s", op.Result().Values.Type) - values, err := arr.Values() - if err != nil { - return nil, err + return &DistinctResult{err: err} } - retArray := make([]interface{}, len(values)) - - for i, val := range values { - raw := bson.RawValue{Type: bson.Type(val.Type), Value: val.Data} - err = raw.Unmarshal(&retArray[i]) - if err != nil { - return nil, err - } + return &DistinctResult{ + reg: coll.registry, + arr: bson.RawArray(arr), } - - return retArray, replaceErrors(err) } // mergeFindOptions combines the given FindOptions instances into a single FindOptions in a last-property-wins fashion. diff --git a/mongo/collection_test.go b/mongo/collection_test.go index 355cae09cd..ffff38753d 100644 --- a/mongo/collection_test.go +++ b/mongo/collection_test.go @@ -110,7 +110,7 @@ func TestCollection(t *testing.T) { _, err = coll.CountDocuments(bgCtx, doc) assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) - _, err = coll.Distinct(bgCtx, "x", doc) + err = coll.Distinct(bgCtx, "x", doc).Err() assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err) _, err = coll.Find(bgCtx, doc) @@ -176,7 +176,7 @@ func TestCollection(t *testing.T) { _, err = coll.CountDocuments(bgCtx, nil) assert.Equal(t, ErrNilDocument, err, "expected error %v, got %v", ErrNilDocument, err) - _, err = coll.Distinct(bgCtx, "x", nil) + err = coll.Distinct(bgCtx, "x", nil).Err() assert.Equal(t, ErrNilDocument, err, "expected error %v, got %v", ErrNilDocument, err) _, err = coll.Find(bgCtx, nil) diff --git a/mongo/crud_examples_test.go b/mongo/crud_examples_test.go index e17be4bce4..9ef1a63acd 100644 --- a/mongo/crud_examples_test.go +++ b/mongo/crud_examples_test.go @@ -323,7 +323,12 @@ func ExampleCollection_Distinct() { // run on the server. filter := bson.D{{"age", bson.D{{"$gt", 25}}}} opts := options.Distinct().SetMaxTime(2 * time.Second) - values, err := coll.Distinct(context.TODO(), "name", filter, opts) + res := coll.Distinct(context.TODO(), "name", filter, opts) + if err := res.Err(); err != nil { + log.Fatal(err) + } + + values, err := res.Raw() if err != nil { log.Fatal(err) } diff --git a/mongo/results.go b/mongo/results.go index 0dfe510440..818d283fb7 100644 --- a/mongo/results.go +++ b/mongo/results.go @@ -173,3 +173,51 @@ type CollectionSpecification struct { // option is used and for MongoDB versions < 3.4. IDIndex *IndexSpecification } + +// DistinctResult represents an array of BSON data returned from an operation. +// If the operation resulted in an error, all DistinctResult methods will return +// that error. If the operation did not return any data, all DistinctResult +// methods will return ErrNoDocuments. +type DistinctResult struct { + err error + arr bson.RawArray + reg *bson.Registry +} + +// Decode will unmarshal the array represented by this DistinctResult into v. If +// there was an error from the operation that created this DistinctReuslt, that +// error will be returned. If the operation returned no array, Decode will +// return ErrNoDocuments. +// +// If the operation was successful and returned an array, Decode will return any +// errors from the unmarshalling process without any modification. If v is nil +// or is a typed nil, an error will be returned. +func (dr *DistinctResult) Decode(v any) error { + val := bson.RawValue{ + Value: dr.arr, + Type: bson.TypeArray, + } + + return val.UnmarshalWithRegistry(dr.reg, v) +} + +// Err provides a way to check for query errors without calling Decode. Err +// returns the error, if any, that was encountered while running the operation. +// If the operation was successful but did not return any documents, Err returns +// ErrNoDocuments. If this error is not nil, this error will also be returned +// from Decode. +func (dr *DistinctResult) Err() error { + return dr.err +} + +// Raw returns the document represented by this DistinctResult as a bson.Raw. If +// there was an error from the operation that created this DistinctResult, both +// the result and that error will be returned. If the operation returned no +// documents, this will return (nil, ErrNoDocuments). +func (dr *DistinctResult) Raw() (bson.RawArray, error) { + if dr.err != nil { + return nil, dr.err + } + + return dr.arr, nil +} diff --git a/mongo/single_result.go b/mongo/single_result.go index e0639e4069..6a0a695685 100644 --- a/mongo/single_result.go +++ b/mongo/single_result.go @@ -35,7 +35,11 @@ type SingleResult struct { // from the one provided occurs during creation of the SingleResult, that error will be stored on the returned SingleResult. // // The document parameter must be a non-nil document. -func NewSingleResultFromDocument(document interface{}, err error, registry *bson.Registry) *SingleResult { +func NewSingleResultFromDocument( + document interface{}, + err error, + registry *bson.Registry, +) *SingleResult { if document == nil { return &SingleResult{err: ErrNilDocument} } @@ -90,6 +94,7 @@ func (sr *SingleResult) Raw() (bson.Raw, error) { if sr.err = sr.setRdrContents(); sr.err != nil { return nil, sr.err } + return sr.rdr, nil } @@ -110,7 +115,9 @@ func (sr *SingleResult) setRdrContents() error { return ErrNoDocuments } + sr.rdr = sr.cur.Current + return nil } diff --git a/mongo/single_result_test.go b/mongo/single_result_test.go index 1338fe90c6..a9f409eeb0 100644 --- a/mongo/single_result_test.go +++ b/mongo/single_result_test.go @@ -14,71 +14,8 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/internal/assert" - "go.mongodb.org/mongo-driver/internal/require" - "go.mongodb.org/mongo-driver/mongo/options" ) -func TestSingleResult(t *testing.T) { - t.Run("Decode", func(t *testing.T) { - t.Run("decode twice", func(t *testing.T) { - // Test that Decode and Raw can be called more than once - c, err := newCursor(newTestBatchCursor(1, 1), nil, bson.DefaultRegistry) - assert.Nil(t, err, "newCursor error: %v", err) - - sr := &SingleResult{cur: c, reg: bson.DefaultRegistry} - var firstDecode, secondDecode bson.Raw - err = sr.Decode(&firstDecode) - assert.Nil(t, err, "Decode error: %v", err) - err = sr.Decode(&secondDecode) - assert.Nil(t, err, "Decode error: %v", err) - - rawBytes, err := sr.Raw() - assert.Nil(t, err, "Raw error: %v", err) - - assert.Equal(t, firstDecode, secondDecode, "expected contents %v, got %v", firstDecode, secondDecode) - assert.Equal(t, firstDecode, rawBytes, "expected contents %v, got %v", firstDecode, rawBytes) - }) - t.Run("decode with error", func(t *testing.T) { - r := []byte("foo") - sr := &SingleResult{rdr: r, err: errors.New("Raw error")} - res, err := sr.Raw() - resBytes := []byte(res) - assert.Equal(t, r, resBytes, "expected contents %v, got %v", r, resBytes) - assert.Equal(t, sr.err, err, "expected error %v, got %v", sr.err, err) - }) - t.Run("with BSONOptions", func(t *testing.T) { - c, err := newCursor(newTestBatchCursor(1, 1), nil, bson.DefaultRegistry) - require.NoError(t, err, "newCursor error") - - sr := &SingleResult{ - cur: c, - bsonOpts: &options.BSONOptions{ - UseJSONStructTags: true, - }, - reg: bson.DefaultRegistry, - } - - type myDocument struct { - A *int32 `json:"foo"` - } - - var got myDocument - err = sr.Decode(&got) - require.NoError(t, err, "Decode error") - - i := int32(0) - want := myDocument{A: &i} - - assert.Equal(t, want, got, "expected and actual Decode results are different") - }) - }) - - t.Run("Err", func(t *testing.T) { - sr := &SingleResult{} - assert.Equal(t, ErrNoDocuments, sr.Err(), "expected error %v, got %v", ErrNoDocuments, sr.Err()) - }) -} - func TestNewSingleResultFromDocument(t *testing.T) { // Mock a document returned by FindOne in SingleResult. t.Run("mock FindOne", func(t *testing.T) { @@ -132,4 +69,69 @@ func TestNewSingleResultFromDocument(t *testing.T) { assert.Equal(t, mockErr, res.cur.err, "expected underlying cursor %v, got %v", mockErr, res.cur.err) }) + + // Mock an error in SingleResult. + t.Run("mock FindOne with error", func(t *testing.T) { + mockErr := fmt.Errorf("mock error") + res := NewSingleResultFromDocument(bson.D{}, mockErr, nil) + + // Assert that the raw bytes returns the mocked error. + _, err := res.Raw() + assert.NotNil(t, err, "expected Raw error, got nil") + assert.Equal(t, mockErr, err, "expected error %v, got %v", mockErr, err) + + // Check for error on SingleResult. + assert.NotNil(t, res.Err(), "expected SingleResult error, got nil") + assert.Equal(t, mockErr, res.Err(), "expected SingleResult error %v, got %v", + mockErr, res.Err()) + + // Assert that error is propagated to underlying cursor. + assert.NotNil(t, res.cur.err, "expected underlying cursor, got nil") + assert.Equal(t, mockErr, res.cur.err, "expected underlying cursor %v, got %v", + mockErr, res.cur.err) + }) +} + +func TestSingleResult_Decode(t *testing.T) { + t.Run("decode twice", func(t *testing.T) { + t.Run("bson.Raw", func(t *testing.T) { + // Test that Decode and Raw can be called more than once + c, err := newCursor(newTestBatchCursor(1, 1), nil, bson.DefaultRegistry) + assert.Nil(t, err, "newCursor error: %v", err) + + sr := &SingleResult{cur: c, reg: bson.DefaultRegistry} + var firstDecode, secondDecode bson.Raw + err = sr.Decode(&firstDecode) + assert.Nil(t, err, "Decode error: %v", err) + err = sr.Decode(&secondDecode) + assert.Nil(t, err, "Decode error: %v", err) + + rawBytes, err := sr.Raw() + assert.Nil(t, err, "Raw error: %v", err) + + assert.Equal(t, firstDecode, secondDecode, "expected contents %v, got %v", firstDecode, secondDecode) + assert.Equal(t, firstDecode, rawBytes, "expected contents %v, got %v", firstDecode, rawBytes) + }) + }) + + t.Run("decode with error", func(t *testing.T) { + t.Run("bson.Raw", func(t *testing.T) { + r := []byte("foo") + sr := &SingleResult{ + rdr: r, + err: errors.New("Raw error"), + } + res, err := sr.Raw() + resBytes := []byte(res) + assert.Equal(t, r, resBytes, "expected contents %v, got %v", r, resBytes) + assert.Equal(t, sr.err, err, "expected error %v, got %v", sr.err, err) + }) + }) +} + +func TestSingleResult_Err(t *testing.T) { + t.Run("bson.Raw", func(t *testing.T) { + sr := &SingleResult{} + assert.Equal(t, ErrNoDocuments, sr.Err(), "expected error %v, got %v", ErrNoDocuments, sr.Err()) + }) }