Skip to content

Commit

Permalink
GODRIVER-2972 Fix wiremessage RequestID race in operation.Execute (#1375
Browse files Browse the repository at this point in the history
)
  • Loading branch information
prestonvasquez committed Sep 7, 2023
1 parent 84a4385 commit 68bf155
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 16 deletions.
23 changes: 14 additions & 9 deletions x/mongo/driver/operation.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ func (op Operation) shouldEncrypt() bool {
}

// selectServer handles performing server selection for an operation.
func (op Operation) selectServer(ctx context.Context) (Server, error) {
func (op Operation) selectServer(ctx context.Context, requestID int32) (Server, error) {
if err := op.Validate(); err != nil {
return nil, err
}
Expand All @@ -340,14 +340,14 @@ func (op Operation) selectServer(ctx context.Context) (Server, error) {
}

ctx = logger.WithOperationName(ctx, op.Name)
ctx = logger.WithOperationID(ctx, wiremessage.CurrentRequestID())
ctx = logger.WithOperationID(ctx, requestID)

return op.Deployment.SelectServer(ctx, selector)
}

// getServerAndConnection should be used to retrieve a Server and Connection to execute an operation.
func (op Operation) getServerAndConnection(ctx context.Context) (Server, Connection, error) {
server, err := op.selectServer(ctx)
func (op Operation) getServerAndConnection(ctx context.Context, requestID int32) (Server, Connection, error) {
server, err := op.selectServer(ctx, requestID)
if err != nil {
if op.Client != nil &&
!(op.Client.Committing || op.Client.Aborting) && op.Client.TransactionRunning() {
Expand Down Expand Up @@ -530,11 +530,11 @@ func (op Operation) Execute(ctx context.Context) error {
}
}()
for {
wiremessage.NextRequestID()
requestID := wiremessage.NextRequestID()

// If the server or connection are nil, try to select a new server and get a new connection.
if srvr == nil || conn == nil {
srvr, conn, err = op.getServerAndConnection(ctx)
srvr, conn, err = op.getServerAndConnection(ctx, requestID)
if err != nil {
// If the returned error is retryable and there are retries remaining (negative
// retries means retry indefinitely), then retry the operation. Set the server
Expand Down Expand Up @@ -629,7 +629,7 @@ func (op Operation) Execute(ctx context.Context) error {
}

var startedInfo startedInformation
*wm, startedInfo, err = op.createMsgWireMessage(ctx, maxTimeMS, (*wm)[:0], desc, conn)
*wm, startedInfo, err = op.createMsgWireMessage(ctx, maxTimeMS, (*wm)[:0], desc, conn, requestID)

if err != nil {
return err
Expand Down Expand Up @@ -1103,8 +1103,13 @@ func (op Operation) addBatchArray(dst []byte) []byte {
return dst
}

func (op Operation) createMsgWireMessage(ctx context.Context, maxTimeMS uint64, dst []byte, desc description.SelectedServer,
func (op Operation) createMsgWireMessage(
ctx context.Context,
maxTimeMS uint64,
dst []byte,
desc description.SelectedServer,
conn Connection,
requestID int32,
) ([]byte, startedInformation, error) {
var info startedInformation
var flags wiremessage.MsgFlag
Expand All @@ -1120,7 +1125,7 @@ func (op Operation) createMsgWireMessage(ctx context.Context, maxTimeMS uint64,
flags |= wiremessage.ExhaustAllowed
}

info.requestID = wiremessage.CurrentRequestID()
info.requestID = requestID
wmindex, dst = wiremessage.AppendHeaderStart(dst, info.requestID, 0, wiremessage.OpMsg)
dst = wiremessage.AppendMsgFlags(dst, flags)
// Body
Expand Down
9 changes: 5 additions & 4 deletions x/mongo/driver/operation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func TestOperation(t *testing.T) {
t.Run("selectServer", func(t *testing.T) {
t.Run("returns validation error", func(t *testing.T) {
op := &Operation{}
_, err := op.selectServer(context.Background())
_, err := op.selectServer(context.Background(), 1)
if err == nil {
t.Error("Expected a validation error from selectServer, but got <nil>")
}
Expand All @@ -76,7 +76,7 @@ func TestOperation(t *testing.T) {
Database: "testing",
Selector: want,
}
_, err := op.selectServer(context.Background())
_, err := op.selectServer(context.Background(), 1)
noerr(t, err)
got := d.params.selector
if !cmp.Equal(got, want) {
Expand All @@ -90,7 +90,7 @@ func TestOperation(t *testing.T) {
Deployment: d,
Database: "testing",
}
_, err := op.selectServer(context.Background())
_, err := op.selectServer(context.Background(), 1)
noerr(t, err)
if d.params.selector == nil {
t.Error("The selectServer method should use a default selector when not specified on Operation, but it passed <nil>.")
Expand Down Expand Up @@ -652,7 +652,8 @@ func TestOperation(t *testing.T) {
}

func createExhaustServerResponse(response bsoncore.Document, moreToCome bool) []byte {
idx, wm := wiremessage.AppendHeaderStart(nil, 0, wiremessage.CurrentRequestID()+1, wiremessage.OpMsg)
const psuedoRequestID = 1
idx, wm := wiremessage.AppendHeaderStart(nil, 0, psuedoRequestID, wiremessage.OpMsg)
var flags wiremessage.MsgFlag
if moreToCome {
flags = wiremessage.MoreToCome
Expand Down
3 changes: 0 additions & 3 deletions x/mongo/driver/wiremessage/wiremessage.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@ type WireMessage []byte

var globalRequestID int32

// CurrentRequestID returns the current request ID.
func CurrentRequestID() int32 { return atomic.LoadInt32(&globalRequestID) }

// NextRequestID returns the next request ID.
func NextRequestID() int32 { return atomic.AddInt32(&globalRequestID, 1) }

Expand Down

0 comments on commit 68bf155

Please sign in to comment.