Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
qingyang-hu committed Oct 16, 2024
1 parent 45d22f8 commit b8309ff
Show file tree
Hide file tree
Showing 9 changed files with 191 additions and 184 deletions.
92 changes: 55 additions & 37 deletions mongo/client_bulk_write.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import (

// bulkWrite performs a bulkwrite operation
type clientBulkWrite struct {
models []interface{}
models []clientWriteModel
errorsOnly bool
ordered *bool
bypassDocumentValidation *bool
Expand All @@ -45,12 +45,17 @@ func (bw *clientBulkWrite) execute(ctx context.Context) error {
if len(bw.models) == 0 {
return errors.New("empty write models")
}
retryMode := driver.RetryNone
if bw.client.retryWrites {
retryMode = driver.RetryOncePerCommand
}
batches := &modelBatches{
session: bw.session,
client: bw.client,
ordered: bw.ordered,
models: bw.models,
result: &bw.result,
session: bw.session,
client: bw.client,
ordered: bw.ordered,
models: bw.models,
result: &bw.result,
retryMode: retryMode,
}
err := driver.Operation{
CommandFn: bw.newCommand(),
Expand Down Expand Up @@ -142,7 +147,7 @@ type modelBatches struct {
client *Client

ordered *bool
models []interface{}
models []clientWriteModel

offset int

Expand Down Expand Up @@ -222,17 +227,14 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, maxD
mb.newIDMap = make(map[int]interface{})

nsMap := make(map[string]int)
getNsIndex := func(namespace string) (int, bsoncore.Document) {
idx, doc := bsoncore.AppendDocumentStart(nil)
doc = bsoncore.AppendStringElement(doc, "ns", namespace)
doc, _ = bsoncore.AppendDocumentEnd(doc, idx)

if v, ok := nsMap[namespace]; ok {
return v, doc
getNsIndex := func(namespace string) (int, bool) {
v, ok := nsMap[namespace]
if ok {
return v, ok
}
nsIdx := len(nsMap)
nsMap[namespace] = nsIdx
return nsIdx, doc
return nsIdx, ok
}

canRetry := true
Expand All @@ -249,12 +251,13 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, maxD
break
}

var nsIdx int
var ns, doc bsoncore.Document
ns := mb.models[i].namespace
nsIdx, exists := getNsIndex(ns)

var doc bsoncore.Document
var err error
switch model := mb.models[i].(type) {
switch model := mb.models[i].model.(type) {
case *ClientInsertOneModel:
nsIdx, ns = getNsIndex(model.Namespace)
mb.cursorHandlers[i] = mb.appendInsertResult
var id interface{}
id, doc, err = (&clientInsertDoc{
Expand All @@ -266,7 +269,6 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, maxD
}
mb.newIDMap[i] = id
case *ClientUpdateOneModel:
nsIdx, ns = getNsIndex(model.Namespace)
mb.cursorHandlers[i] = mb.appendUpdateResult
doc, err = (&clientUpdateDoc{
namespace: nsIdx,
Expand All @@ -281,7 +283,6 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, maxD
}).marshal(mb.client.bsonOpts, mb.client.registry)
case *ClientUpdateManyModel:
canRetry = false
nsIdx, ns = getNsIndex(model.Namespace)
mb.cursorHandlers[i] = mb.appendUpdateResult
doc, err = (&clientUpdateDoc{
namespace: nsIdx,
Expand All @@ -295,7 +296,6 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, maxD
checkDollarKey: true,
}).marshal(mb.client.bsonOpts, mb.client.registry)
case *ClientReplaceOneModel:
nsIdx, ns = getNsIndex(model.Namespace)
mb.cursorHandlers[i] = mb.appendUpdateResult
doc, err = (&clientUpdateDoc{
namespace: nsIdx,
Expand All @@ -309,7 +309,6 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, maxD
checkDollarKey: false,
}).marshal(mb.client.bsonOpts, mb.client.registry)
case *ClientDeleteOneModel:
nsIdx, ns = getNsIndex(model.Namespace)
mb.cursorHandlers[i] = mb.appendDeleteResult
doc, err = (&clientDeleteDoc{
namespace: nsIdx,
Expand All @@ -320,7 +319,6 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, maxD
}).marshal(mb.client.bsonOpts, mb.client.registry)
case *ClientDeleteManyModel:
canRetry = false
nsIdx, ns = getNsIndex(model.Namespace)
mb.cursorHandlers[i] = mb.appendDeleteResult
doc, err = (&clientDeleteDoc{
namespace: nsIdx,
Expand All @@ -343,7 +341,12 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, maxD
}

dst = fn.appendDocument(dst, strconv.Itoa(n), doc)
nsDst = fn.appendDocument(nsDst, strconv.Itoa(n), ns)
if !exists {
idx, doc := bsoncore.AppendDocumentStart(nil)
doc = bsoncore.AppendStringElement(doc, "ns", ns)
doc, _ = bsoncore.AppendDocumentEnd(doc, idx)
nsDst = fn.appendDocument(nsDst, strconv.Itoa(n), doc)
}
n++
}
if n == 0 {
Expand Down Expand Up @@ -430,7 +433,7 @@ func (mb *modelBatches) processResponse(ctx context.Context, resp bsoncore.Docum
if int(cur.Idx) >= len(mb.cursorHandlers) {
continue
}
ok = ok && mb.cursorHandlers[int(cur.Idx)](&cur, cursor.Current)
ok = mb.cursorHandlers[int(cur.Idx)](&cur, cursor.Current) && ok
}
err = cursor.Err()
if err != nil {
Expand All @@ -456,32 +459,51 @@ func (mb *modelBatches) processResponse(ctx context.Context, resp bsoncore.Docum
}

func (mb *modelBatches) appendDeleteResult(cur *cursorInfo, raw bson.Raw) bool {
if err := cur.extractError(); err != nil {
err.Raw = raw
if mb.writeErrors == nil {
mb.writeErrors = make(map[int]WriteError)
}
mb.writeErrors[int(cur.Idx)] = *err
return false
}

if mb.result.DeleteResults == nil {
mb.result.DeleteResults = make(map[int]ClientDeleteResult)
}
mb.result.DeleteResults[int(cur.Idx)] = ClientDeleteResult{int64(cur.N)}

return true
}

func (mb *modelBatches) appendInsertResult(cur *cursorInfo, raw bson.Raw) bool {
if err := cur.extractError(); err != nil {
err.Raw = raw
if mb.writeErrors == nil {
mb.writeErrors = make(map[int]WriteError)
}
mb.writeErrors[int(cur.Idx)] = *err
return false
}
return true
}

func (mb *modelBatches) appendInsertResult(cur *cursorInfo, raw bson.Raw) bool {
if mb.result.InsertResults == nil {
mb.result.InsertResults = make(map[int]ClientInsertResult)
}
mb.result.InsertResults[int(cur.Idx)] = ClientInsertResult{mb.newIDMap[int(cur.Idx)]}

return true
}

func (mb *modelBatches) appendUpdateResult(cur *cursorInfo, raw bson.Raw) bool {
if err := cur.extractError(); err != nil {
err.Raw = raw
if mb.writeErrors == nil {
mb.writeErrors = make(map[int]WriteError)
}
mb.writeErrors[int(cur.Idx)] = *err
return false
}
return true
}

func (mb *modelBatches) appendUpdateResult(cur *cursorInfo, raw bson.Raw) bool {
if mb.result.UpdateResults == nil {
mb.result.UpdateResults = make(map[int]ClientUpdateResult)
}
Expand All @@ -495,11 +517,7 @@ func (mb *modelBatches) appendUpdateResult(cur *cursorInfo, raw bson.Raw) bool {
result.UpsertedID = cur.Upserted.ID
}
mb.result.UpdateResults[int(cur.Idx)] = result
if err := cur.extractError(); err != nil {
err.Raw = raw
mb.writeErrors[int(cur.Idx)] = *err
return false
}

return true
}

Expand Down
Loading

0 comments on commit b8309ff

Please sign in to comment.