Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
qingyang-hu committed Aug 27, 2024
1 parent 2bdacda commit 686ef6d
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 78 deletions.
122 changes: 62 additions & 60 deletions mongo/bulk_write.go
Original file line number Diff line number Diff line change
Expand Up @@ -333,44 +333,39 @@ func (bw *bulkWrite) runUpdate(ctx context.Context, batch bulkWriteBatch) (opera

switch converted := model.(type) {
case *ReplaceOneModel:
doc, err = createUpdateDoc(
converted.Filter,
converted.Replacement,
converted.Hint,
nil,
converted.Collation,
converted.Upsert,
false,
false,
bw.collection.bsonOpts,
bw.collection.registry)
doc, err = updateDoc{
filter: converted.Filter,
update: converted.Replacement,
hint: converted.Hint,
sort: converted.Sort,
collation: converted.Collation,
upsert: converted.Upsert,
}.marshal(bw.collection.bsonOpts, bw.collection.registry)
hasHint = hasHint || (converted.Hint != nil)
case *UpdateOneModel:
doc, err = createUpdateDoc(
converted.Filter,
converted.Update,
converted.Hint,
converted.ArrayFilters,
converted.Collation,
converted.Upsert,
false,
true,
bw.collection.bsonOpts,
bw.collection.registry)
doc, err = updateDoc{
filter: converted.Filter,
update: converted.Update,
hint: converted.Hint,
sort: converted.Sort,
arrayFilters: converted.ArrayFilters,
collation: converted.Collation,
upsert: converted.Upsert,
checkDollarKey: true,
}.marshal(bw.collection.bsonOpts, bw.collection.registry)
hasHint = hasHint || (converted.Hint != nil)
hasArrayFilters = hasArrayFilters || (converted.ArrayFilters != nil)
case *UpdateManyModel:
doc, err = createUpdateDoc(
converted.Filter,
converted.Update,
converted.Hint,
converted.ArrayFilters,
converted.Collation,
converted.Upsert,
true,
true,
bw.collection.bsonOpts,
bw.collection.registry)
doc, err = updateDoc{
filter: converted.Filter,
update: converted.Update,
hint: converted.Hint,
arrayFilters: converted.ArrayFilters,
collation: converted.Collation,
upsert: converted.Upsert,
multi: true,
checkDollarKey: true,
}.marshal(bw.collection.bsonOpts, bw.collection.registry)
hasHint = hasHint || (converted.Hint != nil)
hasArrayFilters = hasArrayFilters || (converted.ArrayFilters != nil)
}
Expand Down Expand Up @@ -420,62 +415,69 @@ func (bw *bulkWrite) runUpdate(ctx context.Context, batch bulkWriteBatch) (opera
return op.Result(), err
}

func createUpdateDoc(
filter interface{},
update interface{},
hint interface{},
arrayFilters *options.ArrayFilters,
collation *options.Collation,
upsert *bool,
multi bool,
checkDollarKey bool,
bsonOpts *options.BSONOptions,
registry *bsoncodec.Registry,
) (bsoncore.Document, error) {
f, err := marshal(filter, bsonOpts, registry)
type updateDoc struct {
filter interface{}
update interface{}
hint interface{}
sort interface{}
arrayFilters *options.ArrayFilters
collation *options.Collation
upsert *bool
multi bool
checkDollarKey bool
}

func (doc updateDoc) marshal(bsonOpts *options.BSONOptions, registry *bsoncodec.Registry) (bsoncore.Document, error) {
f, err := marshal(doc.filter, bsonOpts, registry)
if err != nil {
return nil, err
}

uidx, updateDoc := bsoncore.AppendDocumentStart(nil)
updateDoc = bsoncore.AppendDocumentElement(updateDoc, "q", f)

u, err := marshalUpdateValue(update, bsonOpts, registry, checkDollarKey)
u, err := marshalUpdateValue(doc.update, bsonOpts, registry, doc.checkDollarKey)
if err != nil {
return nil, err
}

updateDoc = bsoncore.AppendValueElement(updateDoc, "u", u)

if multi {
updateDoc = bsoncore.AppendBooleanElement(updateDoc, "multi", multi)
if doc.multi {
updateDoc = bsoncore.AppendBooleanElement(updateDoc, "multi", doc.multi)
} else if doc.sort != nil {
s, err := marshal(doc.sort, bsonOpts, registry)
if err != nil {
return nil, err
}
updateDoc = bsoncore.AppendDocumentElement(updateDoc, "sort", s)
}

if arrayFilters != nil {
if doc.arrayFilters != nil {
reg := registry
if arrayFilters.Registry != nil {
reg = arrayFilters.Registry
if doc.arrayFilters.Registry != nil {
reg = doc.arrayFilters.Registry
}
arr, err := marshalValue(arrayFilters.Filters, bsonOpts, reg)
arr, err := marshalValue(doc.arrayFilters.Filters, bsonOpts, reg)
if err != nil {
return nil, err
}
updateDoc = bsoncore.AppendArrayElement(updateDoc, "arrayFilters", arr.Data)
}

if collation != nil {
updateDoc = bsoncore.AppendDocumentElement(updateDoc, "collation", bsoncore.Document(collation.ToDocument()))
if doc.collation != nil {
updateDoc = bsoncore.AppendDocumentElement(updateDoc, "collation", bsoncore.Document(doc.collation.ToDocument()))
}

if upsert != nil {
updateDoc = bsoncore.AppendBooleanElement(updateDoc, "upsert", *upsert)
if doc.upsert != nil {
updateDoc = bsoncore.AppendBooleanElement(updateDoc, "upsert", *doc.upsert)
}

if hint != nil {
if isUnorderedMap(hint) {
if doc.hint != nil {
if isUnorderedMap(doc.hint) {
return nil, ErrMapForOrderedArgument{"hint"}
}
hintVal, err := marshalValue(hint, bsonOpts, registry)
hintVal, err := marshalValue(doc.hint, bsonOpts, registry)
if err != nil {
return nil, err
}
Expand Down
2 changes: 2 additions & 0 deletions mongo/bulk_write_models.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ type ReplaceOneModel struct {
Filter interface{}
Replacement interface{}
Hint interface{}
Sort interface{}
}

// NewReplaceOneModel creates a new ReplaceOneModel.
Expand Down Expand Up @@ -183,6 +184,7 @@ type UpdateOneModel struct {
Update interface{}
ArrayFilters *options.ArrayFilters
Hint interface{}
Sort interface{}
}

// NewUpdateOneModel creates a new UpdateOneModel.
Expand Down
29 changes: 11 additions & 18 deletions mongo/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -548,17 +548,17 @@ func (coll *Collection) updateOrReplace(ctx context.Context, filter bsoncore.Doc

// collation, arrayFilters, upsert, and hint are included on the individual update documents rather than as part of the
// command
updateDoc, err := createUpdateDoc(
filter,
update,
uo.Hint,
uo.ArrayFilters,
uo.Collation,
uo.Upsert,
multi,
checkDollarKey,
coll.bsonOpts,
coll.registry)
updateDoc, err := updateDoc{
filter: filter,
update: update,
hint: uo.Hint,
sort: uo.Sort,
arrayFilters: uo.ArrayFilters,
collation: uo.Collation,
upsert: uo.Upsert,
multi: multi,
checkDollarKey: checkDollarKey,
}.marshal(coll.bsonOpts, coll.registry)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -598,13 +598,6 @@ func (coll *Collection) updateOrReplace(ctx context.Context, filter bsoncore.Doc
}
op = op.Let(let)
}
if !multi && uo.Sort != nil {
sort, err := marshal(uo.Sort, coll.bsonOpts, coll.registry)
if err != nil {
return nil, err
}
op = op.Sort(sort)
}
if uo.BypassDocumentValidation != nil && *uo.BypassDocumentValidation {
op = op.BypassDocumentValidation(*uo.BypassDocumentValidation)
}
Expand Down

0 comments on commit 686ef6d

Please sign in to comment.