diff --git a/mongo/integration/retryable_reads_prose_test.go b/mongo/integration/retryable_reads_prose_test.go
index 80d7937e8c..80f4d3329a 100644
--- a/mongo/integration/retryable_reads_prose_test.go
+++ b/mongo/integration/retryable_reads_prose_test.go
@@ -16,6 +16,7 @@ import (
 	"go.mongodb.org/mongo-driver/event"
 	"go.mongodb.org/mongo-driver/internal/assert"
 	"go.mongodb.org/mongo-driver/internal/eventtest"
+	"go.mongodb.org/mongo-driver/internal/require"
 	"go.mongodb.org/mongo-driver/mongo/integration/mtest"
 	"go.mongodb.org/mongo-driver/mongo/options"
 )
@@ -102,4 +103,95 @@ func TestRetryableReadsProse(t *testing.T) {
 				"expected a find event, got a(n) %v event", cmdEvt.CommandName)
 		}
 	})
+
+	mtOpts = mtest.NewOptions().Topologies(mtest.Sharded).MinServerVersion("4.2")
+	mt.RunOpts("retrying in sharded cluster", mtOpts, func(mt *mtest.T) {
+		tests := []struct {
+			name string
+
+			// Note that setting this value greater than 2 will result in false
+			// negatives. The current specification does not account for CSOT, which
+			// might allow for an "infinite" number of retries over a period of time.
+			// Because of this, we only track the "previous server".
+			hostCount            int
+			failpointErrorCode   int32
+			expectedFailCount    int
+			expectedSuccessCount int
+		}{
+			{
+				name:                 "retry on different mongos",
+				hostCount:            2,
+				failpointErrorCode:   6, // HostUnreachable
+				expectedFailCount:    2,
+				expectedSuccessCount: 0,
+			},
+			{
+				name:                 "retry on same mongos",
+				hostCount:            1,
+				failpointErrorCode:   6, // HostUnreachable
+				expectedFailCount:    1,
+				expectedSuccessCount: 1,
+			},
+		}
+
+		for _, tc := range tests {
+			mt.Run(tc.name, func(mt *mtest.T) {
+				hosts := options.Client().ApplyURI(mtest.ClusterURI()).Hosts
+				require.GreaterOrEqualf(mt, len(hosts), tc.hostCount,
+					"test cluster must have at least %v mongos hosts", tc.hostCount)
+
+				// Configure the failpoint options for each mongos.
+				failPoint := mtest.FailPoint{
+					ConfigureFailPoint: "failCommand",
+					Mode: mtest.FailPointMode{
+						Times: 1,
+					},
+					Data: mtest.FailPointData{
+						FailCommands:    []string{"find"},
+						ErrorCode:       tc.failpointErrorCode,
+						CloseConnection: false,
+					},
+				}
+
+				// In order to ensure that each of the hostCount-many mongos hosts is
+				// tried at least once (i.e. failures are deprioritized), we set a
+				// failpoint on all mongos hosts. The idea is that if we get
+				// hostCount-many failures, then by the pigeonhole principle all mongos
+				// hosts must have been tried.
+				for i := 0; i < tc.hostCount; i++ {
+					mt.ResetClient(options.Client().SetHosts([]string{hosts[i]}))
+					mt.SetFailPoint(failPoint)
+
+					// The automatic failpoint clearing may not clear failpoints set on
+					// specific hosts, so manually clear the failpoint we set on the
+					// specific mongos when the test is done.
+					defer mt.ResetClient(options.Client().SetHosts([]string{hosts[i]}))
+					defer mt.ClearFailPoints()
+				}
+
+				failCount := 0
+				successCount := 0
+
+				commandMonitor := &event.CommandMonitor{
+					Failed: func(context.Context, *event.CommandFailedEvent) {
+						failCount++
+					},
+					Succeeded: func(context.Context, *event.CommandSucceededEvent) {
+						successCount++
+					},
+				}
+
+				// Reset the client with exactly hostCount-many mongos hosts.
+				mt.ResetClient(options.Client().
+					SetHosts(hosts[:tc.hostCount]).
+					SetRetryReads(true).
+					SetMonitor(commandMonitor))
+
+				mt.Coll.FindOne(context.Background(), bson.D{})
+
+				assert.Equal(mt, tc.expectedFailCount, failCount)
+				assert.Equal(mt, tc.expectedSuccessCount, successCount)
+			})
+		}
+	})
 }
diff --git a/mongo/integration/retryable_writes_prose_test.go b/mongo/integration/retryable_writes_prose_test.go
index b378cdcbb5..1c8d353f14 100644
--- a/mongo/integration/retryable_writes_prose_test.go
+++ b/mongo/integration/retryable_writes_prose_test.go
@@ -284,4 +284,96 @@ func TestRetryableWritesProse(t *testing.T) {
 		// Assert that the "ShutdownInProgress" error is returned.
 		require.True(mt, err.(mongo.WriteException).HasErrorCode(int(shutdownInProgressErrorCode)))
 	})
+
+	mtOpts = mtest.NewOptions().Topologies(mtest.Sharded).MinServerVersion("4.2")
+	mt.RunOpts("retrying in sharded cluster", mtOpts, func(mt *mtest.T) {
+		tests := []struct {
+			name string
+
+			// Note that setting this value greater than 2 will result in false
+			// negatives. The current specification does not account for CSOT, which
+			// might allow for an "infinite" number of retries over a period of time.
+			// Because of this, we only track the "previous server".
+			hostCount            int
+			failpointErrorCode   int32
+			expectedFailCount    int
+			expectedSuccessCount int
+		}{
+			{
+				name:                 "retry on different mongos",
+				hostCount:            2,
+				failpointErrorCode:   6, // HostUnreachable
+				expectedFailCount:    2,
+				expectedSuccessCount: 0,
+			},
+			{
+				name:                 "retry on same mongos",
+				hostCount:            1,
+				failpointErrorCode:   6, // HostUnreachable
+				expectedFailCount:    1,
+				expectedSuccessCount: 1,
+			},
+		}
+
+		for _, tc := range tests {
+			mt.Run(tc.name, func(mt *mtest.T) {
+				hosts := options.Client().ApplyURI(mtest.ClusterURI()).Hosts
+				require.GreaterOrEqualf(mt, len(hosts), tc.hostCount,
+					"test cluster must have at least %v mongos hosts", tc.hostCount)
+
+				// Configure the failpoint options for each mongos.
+				failPoint := mtest.FailPoint{
+					ConfigureFailPoint: "failCommand",
+					Mode: mtest.FailPointMode{
+						Times: 1,
+					},
+					Data: mtest.FailPointData{
+						FailCommands:    []string{"insert"},
+						ErrorLabels:     &[]string{"RetryableWriteError"},
+						ErrorCode:       tc.failpointErrorCode,
+						CloseConnection: false,
+					},
+				}
+
+				// In order to ensure that each of the hostCount-many mongos hosts is
+				// tried at least once (i.e. failures are deprioritized), we set a
+				// failpoint on all mongos hosts. The idea is that if we get
+				// hostCount-many failures, then by the pigeonhole principle all mongos
+				// hosts must have been tried.
+				for i := 0; i < tc.hostCount; i++ {
+					mt.ResetClient(options.Client().SetHosts([]string{hosts[i]}))
+					mt.SetFailPoint(failPoint)
+
+					// The automatic failpoint clearing may not clear failpoints set on
+					// specific hosts, so manually clear the failpoint we set on the
+					// specific mongos when the test is done.
+					defer mt.ResetClient(options.Client().SetHosts([]string{hosts[i]}))
+					defer mt.ClearFailPoints()
+				}
+
+				failCount := 0
+				successCount := 0
+
+				commandMonitor := &event.CommandMonitor{
+					Failed: func(context.Context, *event.CommandFailedEvent) {
+						failCount++
+					},
+					Succeeded: func(context.Context, *event.CommandSucceededEvent) {
+						successCount++
+					},
+				}
+
+				// Reset the client with exactly hostCount-many mongos hosts.
+				mt.ResetClient(options.Client().
+					SetHosts(hosts[:tc.hostCount]).
+					SetRetryWrites(true).
+					SetMonitor(commandMonitor))
+
+				_, _ = mt.Coll.InsertOne(context.Background(), bson.D{})
+
+				assert.Equal(mt, tc.expectedFailCount, failCount)
+				assert.Equal(mt, tc.expectedSuccessCount, successCount)
+			})
+		}
+	})
 }
diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go
index 229988e133..6b56191a01 100644
--- a/x/mongo/driver/operation.go
+++ b/x/mongo/driver/operation.go
@@ -322,8 +322,73 @@ func (op Operation) shouldEncrypt() bool {
 	return op.Crypt != nil && !op.Crypt.BypassAutoEncryption()
 }
 
+// filterDeprioritizedServers will filter out the server candidates that have
+// been deprioritized by the operation due to failure.
+//
+// The server selector should try to select a server that is not in the
+// deprioritization list. However, if this is not possible (e.g. there are no
+// other healthy servers in the cluster), the selector may return a
+// deprioritized server.
+func filterDeprioritizedServers(candidates, deprioritized []description.Server) []description.Server {
+	if len(deprioritized) == 0 {
+		return candidates
+	}
+
+	dpaSet := make(map[address.Address]*description.Server)
+	for i, srv := range deprioritized {
+		dpaSet[srv.Addr] = &deprioritized[i]
+	}
+
+	allowed := []description.Server{}
+
+	// Iterate over the candidates and append them to the allowed slice if
+	// they are not in the deprioritized list.
+	for _, candidate := range candidates {
+		if srv, ok := dpaSet[candidate.Addr]; !ok || !srv.Equal(candidate) {
+			allowed = append(allowed, candidate)
+		}
+	}
+
+	// If nothing is allowed, then all available servers must have been
+	// deprioritized. In this case, return the candidates list as-is so that the
+	// selector can find a suitable server.
+	if len(allowed) == 0 {
+		return candidates
+	}
+
+	return allowed
+}
+
+// opServerSelector is a wrapper for the server selector that is assigned to the
+// operation. The purpose of this wrapper is to filter candidates with
+// operation-specific logic, such as deprioritizing failing servers.
+type opServerSelector struct {
+	selector             description.ServerSelector
+	deprioritizedServers []description.Server
+}
+
+// SelectServer will filter candidates with operation-specific logic before
+// passing them on to the user-defined or default selector.
+func (oss *opServerSelector) SelectServer(
+	topo description.Topology,
+	candidates []description.Server,
+) ([]description.Server, error) {
+	selectedServers, err := oss.selector.SelectServer(topo, candidates)
+	if err != nil {
+		return nil, err
+	}
+
+	filteredServers := filterDeprioritizedServers(selectedServers, oss.deprioritizedServers)
+
+	return filteredServers, nil
+}
+
 // selectServer handles performing server selection for an operation.
-func (op Operation) selectServer(ctx context.Context, requestID int32) (Server, error) {
+func (op Operation) selectServer(
+	ctx context.Context,
+	requestID int32,
+	deprioritized []description.Server,
+) (Server, error) {
 	if err := op.Validate(); err != nil {
 		return nil, err
 	}
@@ -340,15 +405,24 @@ func (op Operation) selectServer(ctx context.Context, requestID int32) (Server,
 		})
 	}
 
+	oss := &opServerSelector{
+		selector:             selector,
+		deprioritizedServers: deprioritized,
+	}
+
 	ctx = logger.WithOperationName(ctx, op.Name)
 	ctx = logger.WithOperationID(ctx, requestID)
 
-	return op.Deployment.SelectServer(ctx, selector)
+	return op.Deployment.SelectServer(ctx, oss)
 }
 
 // getServerAndConnection should be used to retrieve a Server and Connection to execute an operation.
-func (op Operation) getServerAndConnection(ctx context.Context, requestID int32) (Server, Connection, error) {
-	server, err := op.selectServer(ctx, requestID)
+func (op Operation) getServerAndConnection(
+	ctx context.Context,
+	requestID int32,
+	deprioritized []description.Server,
+) (Server, Connection, error) {
+	server, err := op.selectServer(ctx, requestID, deprioritized)
 	if err != nil {
 		if op.Client != nil &&
 			!(op.Client.Committing || op.Client.Aborting) && op.Client.TransactionRunning() {
@@ -481,6 +555,11 @@ func (op Operation) Execute(ctx context.Context) error {
 	first := true
 	currIndex := 0
 
+	// deprioritizedServers is a running list of servers that should be
+	// deprioritized during server selection. Per the specifications, we should
+	// only ever deprioritize the "previous server".
+	var deprioritizedServers []description.Server
+
 	// resetForRetry records the error that caused the retry, decrements retries, and resets the
 	// retry loop variables to request a new server and a new connection for the next attempt.
 	resetForRetry := func(err error) {
@@ -506,11 +585,18 @@ func (op Operation) Execute(ctx context.Context) error {
 			}
 		}
 
-		// If we got a connection, close it immediately to release pool resources for
-		// subsequent retries.
+		// If we got a connection, close it immediately to release pool resources
+		// for subsequent retries.
 		if conn != nil {
+			// If we are dealing with a sharded cluster, then mark the failed server
+			// as "deprioritized".
+			if desc := conn.Description; desc != nil && op.Deployment.Kind() == description.Sharded {
+				deprioritizedServers = []description.Server{conn.Description()}
+			}
+
 			conn.Close()
 		}
+
+		// Set the server and connection to nil to request a new server and connection.
 		srvr = nil
 		conn = nil
@@ -535,7 +621,7 @@ func (op Operation) Execute(ctx context.Context) error {
 
 		// If the server or connection are nil, try to select a new server and get a new connection.
 		if srvr == nil || conn == nil {
-			srvr, conn, err = op.getServerAndConnection(ctx, requestID)
+			srvr, conn, err = op.getServerAndConnection(ctx, requestID, deprioritizedServers)
 			if err != nil {
 				// If the returned error is retryable and there are retries remaining (negative
 				// retries means retry indefinitely), then retry the operation. Set the server
diff --git a/x/mongo/driver/operation_test.go b/x/mongo/driver/operation_test.go
index 8509b5da9b..e6c9d4cf95 100644
--- a/x/mongo/driver/operation_test.go
+++ b/x/mongo/driver/operation_test.go
@@ -20,6 +20,7 @@ import (
 	"go.mongodb.org/mongo-driver/internal/assert"
 	"go.mongodb.org/mongo-driver/internal/csot"
 	"go.mongodb.org/mongo-driver/internal/handshake"
+	"go.mongodb.org/mongo-driver/internal/require"
 	"go.mongodb.org/mongo-driver/internal/uuid"
 	"go.mongodb.org/mongo-driver/mongo/address"
 	"go.mongodb.org/mongo-driver/mongo/description"
@@ -62,7 +63,7 @@ func TestOperation(t *testing.T) {
 	t.Run("selectServer", func(t *testing.T) {
 		t.Run("returns validation error", func(t *testing.T) {
 			op := &Operation{}
-			_, err := op.selectServer(context.Background(), 1)
+			_, err := op.selectServer(context.Background(), 1, nil)
 			if err == nil {
 				t.Error("Expected a validation error from selectServer, but got <nil>")
 			}
 		})
@@ -76,11 +77,15 @@ func TestOperation(t *testing.T) {
 				Database:   "testing",
 				Selector:   want,
 			}
-			_, err := op.selectServer(context.Background(), 1)
+			_, err := op.selectServer(context.Background(), 1, nil)
 			noerr(t, err)
-			got := d.params.selector
-			if !cmp.Equal(got, want) {
-				t.Errorf("Did not get expected server selector. got %v; want %v", got, want)
+
+			// Assert that the selector is an operation selector wrapper.
+			oss, ok := d.params.selector.(*opServerSelector)
+			require.True(t, ok)
+
+			if !cmp.Equal(oss.selector, want) {
+				t.Errorf("Did not get expected server selector. got %v; want %v", oss.selector, want)
 			}
 		})
 		t.Run("uses a default server selector", func(t *testing.T) {
@@ -90,7 +95,7 @@ func TestOperation(t *testing.T) {
 				Deployment: d,
 				Database:   "testing",
 			}
-			_, err := op.selectServer(context.Background(), 1)
+			_, err := op.selectServer(context.Background(), 1, nil)
 			noerr(t, err)
 			if d.params.selector == nil {
 				t.Error("The selectServer method should use a default selector when not specified on Operation, but it passed <nil>.")
 			}
 		})
@@ -881,3 +886,123 @@ func TestDecodeOpReply(t *testing.T) {
 		assert.Equal(t, []bsoncore.Document(nil), reply.documents)
 	})
 }
+
+func TestFilterDeprioritizedServers(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name          string
+		deprioritized []description.Server
+		candidates    []description.Server
+		want          []description.Server
+	}{
+		{
+			name:       "empty",
+			candidates: []description.Server{},
+			want:       []description.Server{},
+		},
+		{
+			name:       "nil candidates",
+			candidates: nil,
+			want:       []description.Server{},
+		},
+		{
+			name: "nil deprioritized server list",
+			candidates: []description.Server{
+				{
+					Addr: address.Address("mongodb://localhost:27017"),
+				},
+			},
+			want: []description.Server{
+				{
+					Addr: address.Address("mongodb://localhost:27017"),
+				},
+			},
+		},
+		{
+			name: "deprioritize single server candidate list",
+			candidates: []description.Server{
+				{
+					Addr: address.Address("mongodb://localhost:27017"),
+				},
+			},
+			deprioritized: []description.Server{
+				{
+					Addr: address.Address("mongodb://localhost:27017"),
+				},
+			},
+			want: []description.Server{
+				// Since all available servers were deprioritized, the selector
+				// should return all candidates.
+				{
+					Addr: address.Address("mongodb://localhost:27017"),
+				},
+			},
+		},
+		{
+			name: "deprioritize one server in multi server candidate list",
+			candidates: []description.Server{
+				{
+					Addr: address.Address("mongodb://localhost:27017"),
+				},
+				{
+					Addr: address.Address("mongodb://localhost:27018"),
+				},
+				{
+					Addr: address.Address("mongodb://localhost:27019"),
+				},
+			},
+			deprioritized: []description.Server{
+				{
+					Addr: address.Address("mongodb://localhost:27017"),
+				},
+			},
+			want: []description.Server{
+				{
+					Addr: address.Address("mongodb://localhost:27018"),
+				},
+				{
+					Addr: address.Address("mongodb://localhost:27019"),
+				},
+			},
+		},
+		{
+			name: "deprioritize multiple servers in multi server candidate list",
+			deprioritized: []description.Server{
+				{
+					Addr: address.Address("mongodb://localhost:27017"),
+				},
+				{
+					Addr: address.Address("mongodb://localhost:27018"),
+				},
+			},
+			candidates: []description.Server{
+				{
+					Addr: address.Address("mongodb://localhost:27017"),
+				},
+				{
+					Addr: address.Address("mongodb://localhost:27018"),
+				},
+				{
+					Addr: address.Address("mongodb://localhost:27019"),
+				},
+			},
+			want: []description.Server{
+				{
+					Addr: address.Address("mongodb://localhost:27019"),
+				},
+			},
+		},
+	}
+
+	for _, tc := range tests {
+		tc := tc // Capture the range variable.
+
+		t.Run(tc.name, func(t *testing.T) {
+			t.Parallel()
+
+			got := filterDeprioritizedServers(tc.candidates, tc.deprioritized)
+			assert.ElementsMatch(t, got, tc.want)
+		})
+	}
+}