From 68bf155d63446251f63603eebd74a12707bcdf51 Mon Sep 17 00:00:00 2001 From: Preston Vasquez Date: Thu, 7 Sep 2023 10:41:47 -0600 Subject: [PATCH] GODRIVER-2972 Fix wiremessage RequestID race in operation.Execute (#1375) --- x/mongo/driver/operation.go | 23 ++++++++++++++--------- x/mongo/driver/operation_test.go | 9 +++++---- x/mongo/driver/wiremessage/wiremessage.go | 3 --- 3 files changed, 19 insertions(+), 16 deletions(-) diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index 90573daa53..8e52773503 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -322,7 +322,7 @@ func (op Operation) shouldEncrypt() bool { } // selectServer handles performing server selection for an operation. -func (op Operation) selectServer(ctx context.Context) (Server, error) { +func (op Operation) selectServer(ctx context.Context, requestID int32) (Server, error) { if err := op.Validate(); err != nil { return nil, err } @@ -340,14 +340,14 @@ func (op Operation) selectServer(ctx context.Context) (Server, error) { } ctx = logger.WithOperationName(ctx, op.Name) - ctx = logger.WithOperationID(ctx, wiremessage.CurrentRequestID()) + ctx = logger.WithOperationID(ctx, requestID) return op.Deployment.SelectServer(ctx, selector) } // getServerAndConnection should be used to retrieve a Server and Connection to execute an operation. -func (op Operation) getServerAndConnection(ctx context.Context) (Server, Connection, error) { - server, err := op.selectServer(ctx) +func (op Operation) getServerAndConnection(ctx context.Context, requestID int32) (Server, Connection, error) { + server, err := op.selectServer(ctx, requestID) if err != nil { if op.Client != nil && !(op.Client.Committing || op.Client.Aborting) && op.Client.TransactionRunning() { @@ -530,11 +530,11 @@ func (op Operation) Execute(ctx context.Context) error { } }() for { - wiremessage.NextRequestID() + requestID := wiremessage.NextRequestID() // If the server or connection are nil, try to select a new server and get a new connection. if srvr == nil || conn == nil { - srvr, conn, err = op.getServerAndConnection(ctx) + srvr, conn, err = op.getServerAndConnection(ctx, requestID) if err != nil { // If the returned error is retryable and there are retries remaining (negative // retries means retry indefinitely), then retry the operation. Set the server @@ -629,7 +629,7 @@ func (op Operation) Execute(ctx context.Context) error { } var startedInfo startedInformation - *wm, startedInfo, err = op.createMsgWireMessage(ctx, maxTimeMS, (*wm)[:0], desc, conn) + *wm, startedInfo, err = op.createMsgWireMessage(ctx, maxTimeMS, (*wm)[:0], desc, conn, requestID) if err != nil { return err @@ -1103,8 +1103,13 @@ func (op Operation) addBatchArray(dst []byte) []byte { return dst } -func (op Operation) createMsgWireMessage(ctx context.Context, maxTimeMS uint64, dst []byte, desc description.SelectedServer, +func (op Operation) createMsgWireMessage( + ctx context.Context, + maxTimeMS uint64, + dst []byte, + desc description.SelectedServer, conn Connection, + requestID int32, ) ([]byte, startedInformation, error) { var info startedInformation var flags wiremessage.MsgFlag @@ -1120,7 +1125,7 @@ func (op Operation) createMsgWireMessage(ctx context.Context, maxTimeMS uint64, flags |= wiremessage.ExhaustAllowed } - info.requestID = wiremessage.CurrentRequestID() + info.requestID = requestID wmindex, dst = wiremessage.AppendHeaderStart(dst, info.requestID, 0, wiremessage.OpMsg) dst = wiremessage.AppendMsgFlags(dst, flags) // Body diff --git a/x/mongo/driver/operation_test.go b/x/mongo/driver/operation_test.go index d4c5a1b6a0..8509b5da9b 100644 --- a/x/mongo/driver/operation_test.go +++ b/x/mongo/driver/operation_test.go @@ -62,7 +62,7 @@ func TestOperation(t *testing.T) { t.Run("selectServer", func(t *testing.T) { t.Run("returns validation error", func(t *testing.T) { op := &Operation{} - _, err := op.selectServer(context.Background()) + _, err := op.selectServer(context.Background(), 1) if err == nil { t.Error("Expected a validation error from selectServer, but got ") } @@ -76,7 +76,7 @@ func TestOperation(t *testing.T) { Database: "testing", Selector: want, } - _, err := op.selectServer(context.Background()) + _, err := op.selectServer(context.Background(), 1) noerr(t, err) got := d.params.selector if !cmp.Equal(got, want) { @@ -90,7 +90,7 @@ func TestOperation(t *testing.T) { Deployment: d, Database: "testing", } - _, err := op.selectServer(context.Background()) + _, err := op.selectServer(context.Background(), 1) noerr(t, err) if d.params.selector == nil { t.Error("The selectServer method should use a default selector when not specified on Operation, but it passed .") @@ -652,7 +652,8 @@ func TestOperation(t *testing.T) { } func createExhaustServerResponse(response bsoncore.Document, moreToCome bool) []byte { - idx, wm := wiremessage.AppendHeaderStart(nil, 0, wiremessage.CurrentRequestID()+1, wiremessage.OpMsg) + const psuedoRequestID = 1 + idx, wm := wiremessage.AppendHeaderStart(nil, 0, psuedoRequestID, wiremessage.OpMsg) var flags wiremessage.MsgFlag if moreToCome { flags = wiremessage.MoreToCome diff --git a/x/mongo/driver/wiremessage/wiremessage.go b/x/mongo/driver/wiremessage/wiremessage.go index c4d2567bf0..abf09c15bd 100644 --- a/x/mongo/driver/wiremessage/wiremessage.go +++ b/x/mongo/driver/wiremessage/wiremessage.go @@ -19,9 +19,6 @@ type WireMessage []byte var globalRequestID int32 -// CurrentRequestID returns the current request ID. -func CurrentRequestID() int32 { return atomic.LoadInt32(&globalRequestID) } - // NextRequestID returns the next request ID. func NextRequestID() int32 { return atomic.AddInt32(&globalRequestID, 1) }