Skip to content

Commit

Permalink
GODRIVER-3285 [v2]Allow update to supply sort option.
Browse files Browse the repository at this point in the history
  • Loading branch information
qingyang-hu committed Sep 9, 2024
1 parent 7910023 commit fab1a09
Show file tree
Hide file tree
Showing 19 changed files with 1,749 additions and 116 deletions.
4 changes: 2 additions & 2 deletions internal/integration/collection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ func TestCollection(t *testing.T) {
filter := bson.D{{"x", 0}}
update := bson.D{{"$inc", bson.D{{"x", 1}}}}

res, err := mt.Coll.UpdateOne(context.Background(), filter, update, options.Update().SetUpsert(true))
res, err := mt.Coll.UpdateOne(context.Background(), filter, update, options.UpdateOne().SetUpsert(true))
assert.Nil(mt, err, "UpdateOne error: %v", err)
assert.Equal(mt, int64(0), res.MatchedCount, "expected matched count 0, got %v", res.MatchedCount)
assert.Equal(mt, int64(0), res.ModifiedCount, "expected matched count 0, got %v", res.ModifiedCount)
Expand Down Expand Up @@ -570,7 +570,7 @@ func TestCollection(t *testing.T) {
update := bson.D{{"$inc", bson.D{{"x", 1}}}}

id := "blah"
res, err := mt.Coll.UpdateByID(context.Background(), id, update, options.Update().SetUpsert(true))
res, err := mt.Coll.UpdateByID(context.Background(), id, update, options.UpdateOne().SetUpsert(true))
assert.Nil(mt, err, "UpdateByID error: %v", err)
assert.Equal(mt, int64(0), res.MatchedCount, "expected matched count 0, got %v", res.MatchedCount)
assert.Equal(mt, int64(0), res.ModifiedCount, "expected modified count 0, got %v", res.ModifiedCount)
Expand Down
24 changes: 21 additions & 3 deletions internal/integration/crud_helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,20 @@ func createHint(mt *mtest.T, val bson.RawValue) interface{} {
return hint
}

// create a sort document from a bson.RawValue
func createSort(mt *mtest.T, val bson.RawValue) interface{} {
mt.Helper()

var sort interface{}
switch val.Type {
case bson.TypeEmbeddedDocument:
sort = val.Document()
default:
mt.Fatalf("unrecognized sort value type: %s\n", val.Type)
}
return sort
}

// returns true if err is a mongo.CommandError containing a code that is expected from a killAllSessions command.
func isExpectedKillAllSessionsError(err error) bool {
cmdErr, ok := err.(mongo.CommandError)
Expand Down Expand Up @@ -874,7 +888,7 @@ func executeUpdateOne(mt *mtest.T, sess *mongo.Session, args bson.Raw) (*mongo.U

filter := emptyDoc
var update interface{} = emptyDoc
opts := options.Update()
opts := options.UpdateOne()

elems, _ := args.Elements()
for _, elem := range elems {
Expand All @@ -896,13 +910,15 @@ func executeUpdateOne(mt *mtest.T, sess *mongo.Session, args bson.Raw) (*mongo.U
opts = opts.SetCollation(createCollation(mt, val.Document()))
case "hint":
opts = opts.SetHint(createHint(mt, val))
case "sort":
opts = opts.SetSort(createSort(mt, val))
case "session":
default:
mt.Fatalf("unrecognized updateOne option: %v", key)
}
}

updateArgs, err := mongoutil.NewOptions[options.UpdateOptions](opts)
updateArgs, err := mongoutil.NewOptions[options.UpdateOneOptions](opts)
require.NoError(mt, err, "failed to construct options from builder")

if updateArgs.Upsert == nil {
Expand Down Expand Up @@ -954,7 +970,7 @@ func executeUpdateMany(mt *mtest.T, sess *mongo.Session, args bson.Raw) (*mongo.
}
}

updateArgs, err := mongoutil.NewOptions[options.UpdateOptions](opts)
updateArgs, err := mongoutil.NewOptions[options.UpdateManyOptions](opts)
require.NoError(mt, err, "failed to construct options from builder")

if updateArgs.Upsert == nil {
Expand Down Expand Up @@ -996,6 +1012,8 @@ func executeReplaceOne(mt *mtest.T, sess *mongo.Session, args bson.Raw) (*mongo.
opts = opts.SetCollation(createCollation(mt, val.Document()))
case "hint":
opts = opts.SetHint(createHint(mt, val))
case "sort":
opts = opts.SetSort(createSort(mt, val))
case "session":
default:
mt.Fatalf("unrecognized replaceOne option: %v", key)
Expand Down
12 changes: 12 additions & 0 deletions internal/integration/unified/bulkwrite_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,12 @@ func createBulkWriteModel(rawModel bson.Raw) (mongo.WriteModel, error) {
if err != nil {
return nil, fmt.Errorf("error creating update: %w", err)
}
case "sort":
sort, err := createSort(val)
if err != nil {
return nil, fmt.Errorf("error creating sort: %w", err)
}
uom.SetSort(sort)
case "upsert":
uom.SetUpsert(val.Boolean())
default:
Expand Down Expand Up @@ -249,6 +255,12 @@ func createBulkWriteModel(rawModel bson.Raw) (mongo.WriteModel, error) {
return nil, fmt.Errorf("error creating hint: %w", err)
}
rom.SetHint(hint)
case "sort":
sort, err := createSort(val)
if err != nil {
return nil, fmt.Errorf("error creating sort: %w", err)
}
rom.SetSort(sort)
case "replacement":
replacement = val.Document()
case "upsert":
Expand Down
10 changes: 8 additions & 2 deletions internal/integration/unified/collection_operation_execution.go
Original file line number Diff line number Diff line change
Expand Up @@ -1291,6 +1291,12 @@ func executeReplaceOne(ctx context.Context, operation *operation) (*operationRes
return nil, fmt.Errorf("error creating hint: %w", err)
}
opts.SetHint(hint)
case "sort":
sort, err := createSort(val)
if err != nil {
return nil, fmt.Errorf("error creating sort: %w", err)
}
opts.SetSort(sort)
case "replacement":
replacement = val.Document()
case "upsert":
Expand All @@ -1316,7 +1322,7 @@ func executeUpdateOne(ctx context.Context, operation *operation) (*operationResu
return nil, err
}

updateArgs, err := createUpdateArguments(operation.Arguments)
updateArgs, err := createUpdateArguments[options.UpdateOneOptions](operation.Arguments)
if err != nil {
return nil, err
}
Expand All @@ -1335,7 +1341,7 @@ func executeUpdateMany(ctx context.Context, operation *operation) (*operationRes
return nil, err
}

updateArgs, err := createUpdateArguments(operation.Arguments)
updateArgs, err := createUpdateArguments[options.UpdateManyOptions](operation.Arguments)
if err != nil {
return nil, err
}
Expand Down
58 changes: 44 additions & 14 deletions internal/integration/unified/crud_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ package unified

import (
"fmt"
"reflect"
"strings"

"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/internal/bsonutil"
Expand All @@ -20,65 +22,81 @@ func newMissingArgumentError(arg string) error {
return fmt.Errorf("operation arguments document is missing required field %q", arg)
}

type updateArguments struct {
type updateArguments[Options options.UpdateManyOptions | options.UpdateOneOptions] struct {
filter bson.Raw
update interface{}
opts *options.UpdateOptionsBuilder
opts options.Lister[Options]
}

func createUpdateArguments(args bson.Raw) (*updateArguments, error) {
ua := &updateArguments{
opts: options.Update(),
func createUpdateArguments[Options options.UpdateManyOptions | options.UpdateOneOptions](args bson.Raw) (*updateArguments[Options], error) {
ua := &updateArguments[Options]{}
var builder reflect.Value
switch any((*Options)(nil)).(type) {
case *options.UpdateManyOptions:
builder = reflect.ValueOf(options.Update())
case *options.UpdateOneOptions:
builder = reflect.ValueOf(options.UpdateOne())
}
var err error

elems, _ := args.Elements()
for _, elem := range elems {
key := elem.Key()
val := elem.Value()

var arg reflect.Value
switch key {
case "arrayFilters":
ua.opts.SetArrayFilters(
arg = reflect.ValueOf(
bsonutil.RawToInterfaces(bsonutil.RawArrayToDocuments(val.Array())...),
)
case "bypassDocumentValidation":
ua.opts.SetBypassDocumentValidation(val.Boolean())
arg = reflect.ValueOf(val.Boolean())
case "collation":
collation, err := createCollation(val.Document())
if err != nil {
return nil, fmt.Errorf("error creating collation: %w", err)
}
ua.opts.SetCollation(collation)
arg = reflect.ValueOf(collation)
case "comment":
ua.opts.SetComment(val)
arg = reflect.ValueOf(val)
case "filter":
ua.filter = val.Document()
case "hint":
hint, err := createHint(val)
if err != nil {
return nil, fmt.Errorf("error creating hint: %w", err)
}
ua.opts.SetHint(hint)
case "let":
ua.opts.SetLet(val.Document())
arg = reflect.ValueOf(hint)
case "let", "sort":
arg = reflect.ValueOf(val.Document())
case "update":
var err error
ua.update, err = createUpdateValue(val)
if err != nil {
return nil, fmt.Errorf("error processing update value: %w", err)
}
case "upsert":
ua.opts.SetUpsert(val.Boolean())
arg = reflect.ValueOf(val.Boolean())
default:
return nil, fmt.Errorf("unrecognized update option %q", key)
}
if arg.IsValid() {
fn := builder.MethodByName(
fmt.Sprintf("Set%s%s", strings.ToUpper(string(key[0])), key[1:]),
)
if !fn.IsValid() {
return nil, fmt.Errorf("unrecognized update option %q", key)
}
fn.Call([]reflect.Value{arg})
}
}
if ua.filter == nil {
return nil, newMissingArgumentError("filter")
}
if ua.update == nil {
return nil, newMissingArgumentError("update")
}
ua.opts = builder.Interface().(options.Lister[Options])

return ua, nil
}
Expand Down Expand Up @@ -158,3 +176,15 @@ func createHint(val bson.RawValue) (interface{}, error) {
}
return hint, nil
}

func createSort(val bson.RawValue) (interface{}, error) {
var sort interface{}

switch val.Type {
case bson.TypeEmbeddedDocument:
sort = val.Document()
default:
return nil, fmt.Errorf("unrecognized sort value type %s", val.Type)
}
return sort, nil
}
Loading

0 comments on commit fab1a09

Please sign in to comment.