Commit

GODRIVER-2101 Direct read/write retries to another mongos if possible (#1358)

* GODRIVER-2101 Expand test to use pigeonhole principle

* GODRIVER-2101 Direct read/write retries to another mongos if possible

* GODRIVER-2101 Revert unecessary changes

* GODRIVER-2101 revert changes to collection and cursor

* GODRIVER-2101 Apply opServerSelector

* GODRIVER-2101 Fix static analysis errors

* GODRIVER-2101 Remove empty line

* GODRIVER-2101 Use map 'ok' value
prestonvasquez authored Sep 12, 2023
1 parent b191b72 commit d92f20d
Showing 4 changed files with 408 additions and 13 deletions.
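At a high level, the change records the mongos that most recently failed an operation and filters it out of the candidate list on the next server-selection pass, falling back to the full candidate list when no other mongos is available. Below is a minimal, self-contained Go sketch of that filtering idea; the server type and filterDeprioritized function are hypothetical stand-ins for the driver's description.Server and the unexported helper added in this commit, not the driver's actual API.

package main

import "fmt"

// server is a hypothetical stand-in for the driver's description.Server type,
// used here only to illustrate the filtering idea.
type server struct{ addr string }

// filterDeprioritized prefers candidates that are not in the deprioritized
// list, but falls back to the full candidate list when every candidate has
// been deprioritized (e.g. a single-mongos cluster).
func filterDeprioritized(candidates, deprioritized []server) []server {
	if len(deprioritized) == 0 {
		return candidates
	}
	skip := make(map[string]bool, len(deprioritized))
	for _, s := range deprioritized {
		skip[s.addr] = true
	}
	allowed := make([]server, 0, len(candidates))
	for _, c := range candidates {
		if !skip[c.addr] {
			allowed = append(allowed, c)
		}
	}
	if len(allowed) == 0 {
		// Retrying on a previously failed mongos is still better than
		// failing the operation outright.
		return candidates
	}
	return allowed
}

func main() {
	mongoses := []server{{addr: "mongos-a:27017"}, {addr: "mongos-b:27017"}}
	failed := []server{{addr: "mongos-a:27017"}}
	fmt.Println(filterDeprioritized(mongoses, failed)) // [{mongos-b:27017}]
}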
92 changes: 92 additions & 0 deletions mongo/integration/retryable_reads_prose_test.go
@@ -16,6 +16,7 @@ import (
"go.mongodb.org/mongo-driver/event"
"go.mongodb.org/mongo-driver/internal/assert"
"go.mongodb.org/mongo-driver/internal/eventtest"
"go.mongodb.org/mongo-driver/internal/require"
"go.mongodb.org/mongo-driver/mongo/integration/mtest"
"go.mongodb.org/mongo-driver/mongo/options"
)
@@ -102,4 +103,95 @@ func TestRetryableReadsProse(t *testing.T) {
"expected a find event, got a(n) %v event", cmdEvt.CommandName)
}
})

mtOpts = mtest.NewOptions().Topologies(mtest.Sharded).MinServerVersion("4.2")
mt.RunOpts("retrying in sharded cluster", mtOpts, func(mt *mtest.T) {
tests := []struct {
name string

// Note that setting this value greater than 2 will result in false
// negatives. The current specification does not account for CSOT, which
// might allow for an "infinite" number of retries over a period of time.
// Because of this, we only track the "previous server".
hostCount int
failpointErrorCode int32
expectedFailCount int
expectedSuccessCount int
}{
{
name: "retry on different mongos",
hostCount: 2,
failpointErrorCode: 6, // HostUnreachable
expectedFailCount: 2,
expectedSuccessCount: 0,
},
{
name: "retry on same mongos",
hostCount: 1,
failpointErrorCode: 6, // HostUnreachable
expectedFailCount: 1,
expectedSuccessCount: 1,
},
}

for _, tc := range tests {
mt.Run(tc.name, func(mt *mtest.T) {
hosts := options.Client().ApplyURI(mtest.ClusterURI()).Hosts
require.GreaterOrEqualf(mt, len(hosts), tc.hostCount,
"test cluster must have at least %v mongos hosts", tc.hostCount)

// Configure the failpoint options for each mongos.
failPoint := mtest.FailPoint{
ConfigureFailPoint: "failCommand",
Mode: mtest.FailPointMode{
Times: 1,
},
Data: mtest.FailPointData{
FailCommands: []string{"find"},
ErrorCode: tc.failpointErrorCode,
CloseConnection: false,
},
}

// To ensure that each of the hostCount-many mongos hosts is tried at
// least once (i.e. failures are deprioritized), we set a failpoint on
// all mongos hosts. The idea is that if we get hostCount-many failures,
// then by the pigeonhole principle all mongos hosts must have been
// tried.
for i := 0; i < tc.hostCount; i++ {
mt.ResetClient(options.Client().SetHosts([]string{hosts[i]}))
mt.SetFailPoint(failPoint)

// The automatic failpoint clearing may not clear failpoints set on
// specific hosts, so manually clear the failpoint we set on the
// specific mongos when the test is done.
defer mt.ResetClient(options.Client().SetHosts([]string{hosts[i]}))
defer mt.ClearFailPoints()
}

failCount := 0
successCount := 0

commandMonitor := &event.CommandMonitor{
Failed: func(context.Context, *event.CommandFailedEvent) {
failCount++
},
Succeeded: func(context.Context, *event.CommandSucceededEvent) {
successCount++
},
}

// Reset the client with exactly hostCount-many mongos hosts.
mt.ResetClient(options.Client().
SetHosts(hosts[:tc.hostCount]).
SetRetryReads(true).
SetMonitor(commandMonitor))

mt.Coll.FindOne(context.Background(), bson.D{})

assert.Equal(mt, tc.expectedFailCount, failCount)
assert.Equal(mt, tc.expectedSuccessCount, successCount)
})
}
})
}
92 changes: 92 additions & 0 deletions mongo/integration/retryable_writes_prose_test.go
@@ -284,4 +284,96 @@ func TestRetryableWritesProse(t *testing.T) {
// Assert that the "ShutdownInProgress" error is returned.
require.True(mt, err.(mongo.WriteException).HasErrorCode(int(shutdownInProgressErrorCode)))
})

mtOpts = mtest.NewOptions().Topologies(mtest.Sharded).MinServerVersion("4.2")
mt.RunOpts("retrying in sharded cluster", mtOpts, func(mt *mtest.T) {
tests := []struct {
name string

// Note that setting this value greater than 2 will result in false
// negatives. The current specification does not account for CSOT, which
// might allow for an "infinite" number of retries over a period of time.
// Because of this, we only track the "previous server".
hostCount int
failpointErrorCode int32
expectedFailCount int
expectedSuccessCount int
}{
{
name: "retry on different mongos",
hostCount: 2,
failpointErrorCode: 6, // HostUnreachable
expectedFailCount: 2,
expectedSuccessCount: 0,
},
{
name: "retry on same mongos",
hostCount: 1,
failpointErrorCode: 6, // HostUnreachable
expectedFailCount: 1,
expectedSuccessCount: 1,
},
}

for _, tc := range tests {
mt.Run(tc.name, func(mt *mtest.T) {
hosts := options.Client().ApplyURI(mtest.ClusterURI()).Hosts
require.GreaterOrEqualf(mt, len(hosts), tc.hostCount,
"test cluster must have at least %v mongos hosts", tc.hostCount)

// Configure the failpoint options for each mongos.
failPoint := mtest.FailPoint{
ConfigureFailPoint: "failCommand",
Mode: mtest.FailPointMode{
Times: 1,
},
Data: mtest.FailPointData{
FailCommands: []string{"insert"},
ErrorLabels: &[]string{"RetryableWriteError"},
ErrorCode: tc.failpointErrorCode,
CloseConnection: false,
},
}

// To ensure that each of the hostCount-many mongos hosts is tried at
// least once (i.e. failures are deprioritized), we set a failpoint on
// all mongos hosts. The idea is that if we get hostCount-many failures,
// then by the pigeonhole principle all mongos hosts must have been
// tried.
for i := 0; i < tc.hostCount; i++ {
mt.ResetClient(options.Client().SetHosts([]string{hosts[i]}))
mt.SetFailPoint(failPoint)

// The automatic failpoint clearing may not clear failpoints set on
// specific hosts, so manually clear the failpoint we set on the
// specific mongos when the test is done.
defer mt.ResetClient(options.Client().SetHosts([]string{hosts[i]}))
defer mt.ClearFailPoints()
}

failCount := 0
successCount := 0

commandMonitor := &event.CommandMonitor{
Failed: func(context.Context, *event.CommandFailedEvent) {
failCount++
},
Succeeded: func(context.Context, *event.CommandSucceededEvent) {
successCount++
},
}

// Reset the client with exactly hostCount-many mongos hosts.
mt.ResetClient(options.Client().
SetHosts(hosts[:tc.hostCount]).
SetRetryWrites(true).
SetMonitor(commandMonitor))

_, _ = mt.Coll.InsertOne(context.Background(), bson.D{})

assert.Equal(mt, tc.expectedFailCount, failCount)
assert.Equal(mt, tc.expectedSuccessCount, successCount)
})
}
})
}
100 changes: 93 additions & 7 deletions x/mongo/driver/operation.go
@@ -322,8 +322,73 @@ func (op Operation) shouldEncrypt() bool {
return op.Crypt != nil && !op.Crypt.BypassAutoEncryption()
}

// filterDeprioritizedServers will filter out the server candidates that have
// been deprioritized by the operation due to failure.
//
// The server selector should try to select a server that is not in the
// deprioritization list. However, if this is not possible (e.g. there are no
// other healthy servers in the cluster), the selector may return a
// deprioritized server.
func filterDeprioritizedServers(candidates, deprioritized []description.Server) []description.Server {
if len(deprioritized) == 0 {
return candidates
}

dpaSet := make(map[address.Address]*description.Server)
for i, srv := range deprioritized {
dpaSet[srv.Addr] = &deprioritized[i]
}

allowed := []description.Server{}

// Iterate over the candidates and append them to the allowed slice if
// they are not in the deprioritized list.
for _, candidate := range candidates {
if srv, ok := dpaSet[candidate.Addr]; !ok || !srv.Equal(candidate) {
allowed = append(allowed, candidate)
}
}

// If nothing is allowed, then all available servers must have been
// deprioritized. In this case, return the candidates list as-is so that
// the selector can find a suitable server.
if len(allowed) == 0 {
return candidates
}

return allowed
}

// opServerSelector is a wrapper for the server selector that is assigned to the
// operation. The purpose of this wrapper is to filter candidates with
// operation-specific logic, such as deprioritizing failing servers.
type opServerSelector struct {
selector description.ServerSelector
deprioritizedServers []description.Server
}

// SelectServer will filter candidates with operation-specific logic before
// passing them on to the user-defined or default selector.
func (oss *opServerSelector) SelectServer(
topo description.Topology,
candidates []description.Server,
) ([]description.Server, error) {
selectedServers, err := oss.selector.SelectServer(topo, candidates)
if err != nil {
return nil, err
}

filteredServers := filterDeprioritizedServers(selectedServers, oss.deprioritizedServers)

return filteredServers, nil
}

// selectServer handles performing server selection for an operation.
func (op Operation) selectServer(ctx context.Context, requestID int32) (Server, error) {
func (op Operation) selectServer(
ctx context.Context,
requestID int32,
deprioritized []description.Server,
) (Server, error) {
if err := op.Validate(); err != nil {
return nil, err
}
@@ -340,15 +405,24 @@ func (op Operation) selectServer(ctx context.Context, requestID int32) (Server,
})
}

oss := &opServerSelector{
selector: selector,
deprioritizedServers: deprioritized,
}

ctx = logger.WithOperationName(ctx, op.Name)
ctx = logger.WithOperationID(ctx, requestID)

return op.Deployment.SelectServer(ctx, selector)
return op.Deployment.SelectServer(ctx, oss)
}

// getServerAndConnection should be used to retrieve a Server and Connection to execute an operation.
func (op Operation) getServerAndConnection(ctx context.Context, requestID int32) (Server, Connection, error) {
server, err := op.selectServer(ctx, requestID)
func (op Operation) getServerAndConnection(
ctx context.Context,
requestID int32,
deprioritized []description.Server,
) (Server, Connection, error) {
server, err := op.selectServer(ctx, requestID, deprioritized)
if err != nil {
if op.Client != nil &&
!(op.Client.Committing || op.Client.Aborting) && op.Client.TransactionRunning() {
@@ -481,6 +555,11 @@ func (op Operation) Execute(ctx context.Context) error {
first := true
currIndex := 0

// deprioritizedServers is a running list of servers that should be
// deprioritized during server selection. Per the specifications, we should
// only ever deprioritize the "previous server".
var deprioritizedServers []description.Server

// resetForRetry records the error that caused the retry, decrements retries, and resets the
// retry loop variables to request a new server and a new connection for the next attempt.
resetForRetry := func(err error) {
@@ -506,11 +585,18 @@ func (op Operation) Execute(ctx context.Context) error {
}
}

// If we got a connection, close it immediately to release pool resources for
// subsequent retries.
// If we got a connection, close it immediately to release pool resources
// for subsequent retries.
if conn != nil {
// If we are dealing with a sharded cluster, then mark the failed server
// as "deprioritized".
if desc := conn.Description; desc != nil && op.Deployment.Kind() == description.Sharded {
deprioritizedServers = []description.Server{conn.Description()}
}

conn.Close()
}

// Set the server and connection to nil to request a new server and connection.
srvr = nil
conn = nil
Expand All @@ -535,7 +621,7 @@ func (op Operation) Execute(ctx context.Context) error {

// If the server or connection are nil, try to select a new server and get a new connection.
if srvr == nil || conn == nil {
srvr, conn, err = op.getServerAndConnection(ctx, requestID)
srvr, conn, err = op.getServerAndConnection(ctx, requestID, deprioritizedServers)
if err != nil {
// If the returned error is retryable and there are retries remaining (negative
// retries means retry indefinitely), then retry the operation. Set the server