diff --git a/internal/errutil/join.go b/internal/errutil/join.go
new file mode 100644
index 0000000000..69b9ad2231
--- /dev/null
+++ b/internal/errutil/join.go
@@ -0,0 +1,17 @@
+// Copyright (C) MongoDB, Inc. 2023-present.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License. You may obtain
+// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
+
+//go:build go1.20
+// +build go1.20
+
+package errutil
+
+import "errors"
+
+// Join calls [errors.Join].
+func Join(errs ...error) error {
+	return errors.Join(errs...)
+}
diff --git a/internal/errutil/join_go1.19.go b/internal/errutil/join_go1.19.go
new file mode 100644
index 0000000000..995f353f2f
--- /dev/null
+++ b/internal/errutil/join_go1.19.go
@@ -0,0 +1,94 @@
+// Copyright (C) MongoDB, Inc. 2023-present.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License. You may obtain
+// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
+
+//go:build !go1.20
+// +build !go1.20
+
+package errutil
+
+import "errors"
+
+// Join returns an error that wraps the given errors. Any nil error values are
+// discarded. Join returns nil if every value in errs is nil. The error formats
+// as the concatenation of the strings obtained by calling the Error method of
+// each element of errs, with a newline between each string.
+//
+// A non-nil error returned by Join implements the "Unwrap() error" method.
+func Join(errs ...error) error {
+	n := 0
+	for _, err := range errs {
+		if err != nil {
+			n++
+		}
+	}
+	if n == 0 {
+		return nil
+	}
+	e := &joinError{
+		errs: make([]error, 0, n),
+	}
+	for _, err := range errs {
+		if err != nil {
+			e.errs = append(e.errs, err)
+		}
+	}
+	return e
+}
+
+// joinError is a Go 1.13-1.19 compatible joinable error type. Its error
+// message is identical to [errors.Join], but it implements "Unwrap() error"
+// instead of "Unwrap() []error".
+//
+// It is heavily based on the joinError from
+// https://cs.opensource.google/go/go/+/refs/tags/go1.21.0:src/errors/join.go
+type joinError struct {
+	errs []error
+}
+
+func (e *joinError) Error() string {
+	var b []byte
+	for i, err := range e.errs {
+		if i > 0 {
+			b = append(b, '\n')
+		}
+		b = append(b, err.Error()...)
+	}
+	return string(b)
+}
+
+// Unwrap returns another joinError with the same errors as the current
+// joinError except the first error in the slice. Continuing to call Unwrap
+// on each returned error steps through every error in the slice. The
+// resulting behavior when using [errors.Is] and [errors.As] is similar to an
+// error created using [errors.Join] in Go 1.20+.
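+//
+// For example, with placeholder errors err1, err2, and err3 (assumed here,
+// not part of this patch):
+//
+//	err := Join(err1, err2, err3) // joinError{err1, err2, err3}
+//	errors.Unwrap(err)            // joinError{err2, err3}, and so on to err3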
+func (e *joinError) Unwrap() error {
+	if len(e.errs) == 1 {
+		return e.errs[0]
+	}
+	return &joinError{errs: e.errs[1:]}
+}
+
+// Is calls [errors.Is] with the first error in the slice.
+func (e *joinError) Is(target error) bool {
+	if len(e.errs) == 0 {
+		return false
+	}
+	return errors.Is(e.errs[0], target)
+}
+
+// As calls [errors.As] with the first error in the slice.
+func (e *joinError) As(target interface{}) bool {
+	if len(e.errs) == 0 {
+		return false
+	}
+	return errors.As(e.errs[0], target)
+}
diff --git a/internal/errutil/join_test.go b/internal/errutil/join_test.go
new file mode 100644
index 0000000000..06a539ac3b
--- /dev/null
+++ b/internal/errutil/join_test.go
@@ -0,0 +1,163 @@
+// Copyright (C) MongoDB, Inc. 2023-present.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License. You may obtain
+// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
+
+package errutil_test
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"testing"
+
+	"go.mongodb.org/mongo-driver/internal/assert"
+	"go.mongodb.org/mongo-driver/internal/errutil"
+)
+
+func TestJoinReturnsNil(t *testing.T) {
+	t.Parallel()
+
+	if err := errutil.Join(); err != nil {
+		t.Errorf("errutil.Join() = %v, want nil", err)
+	}
+	if err := errutil.Join(nil); err != nil {
+		t.Errorf("errutil.Join(nil) = %v, want nil", err)
+	}
+	if err := errutil.Join(nil, nil); err != nil {
+		t.Errorf("errutil.Join(nil, nil) = %v, want nil", err)
+	}
+}
+
+func TestJoin_Error(t *testing.T) {
+	t.Parallel()
+
+	err1 := errors.New("err1")
+	err2 := errors.New("err2")
+
+	tests := []struct {
+		errs []error
+		want string
+	}{
+		{
+			errs: []error{err1},
+			want: "err1",
+		},
+		{
+			errs: []error{err1, err2},
+			want: "err1\nerr2",
+		},
+		{
+			errs: []error{err1, nil, err2},
+			want: "err1\nerr2",
+		},
+	}
+
+	for _, test := range tests {
+		test := test // Capture range variable.
+
+		t.Run(fmt.Sprintf("Join(%v)", test.errs), func(t *testing.T) {
+			t.Parallel()
+
+			got := errutil.Join(test.errs...).Error()
+			assert.Equal(t, test.want, got, "expected and actual error strings are different")
+		})
+	}
+}
+
+func TestJoin_ErrorsIs(t *testing.T) {
+	t.Parallel()
+
+	err1 := errors.New("err1")
+	err2 := errors.New("err2")
+
+	tests := []struct {
+		errs   []error
+		target error
+		want   bool
+	}{
+		{
+			errs:   []error{err1},
+			target: err1,
+			want:   true,
+		},
+		{
+			errs:   []error{err1},
+			target: err2,
+			want:   false,
+		},
+		{
+			errs:   []error{err1, err2},
+			target: err2,
+			want:   true,
+		},
+		{
+			errs:   []error{err1, nil, context.DeadlineExceeded, err2},
+			target: context.DeadlineExceeded,
+			want:   true,
+		},
+	}
+
+	for _, test := range tests {
+		test := test // Capture range variable.
+
+		t.Run(fmt.Sprintf("Join(%v)", test.errs), func(t *testing.T) {
+			err := errutil.Join(test.errs...)
+			got := errors.Is(err, test.target)
+			assert.Equal(t, test.want, got, "expected and actual errors.Is results are different")
+		})
+	}
+}
+
+type errType1 struct{}
+
+func (errType1) Error() string { return "" }
+
+type errType2 struct{}
+
+func (errType2) Error() string { return "" }
+
+func TestJoin_ErrorsAs(t *testing.T) {
+	t.Parallel()
+
+	err1 := errType1{}
+	err2 := errType2{}
+
+	tests := []struct {
+		errs   []error
+		target interface{}
+		want   bool
+	}{
+		{
+			errs:   []error{err1},
+			target: &errType1{},
+			want:   true,
+		},
+		{
+			errs:   []error{err1},
+			target: &errType2{},
+			want:   false,
+		},
+		{
+			errs:   []error{err1, err2},
+			target: &errType2{},
+			want:   true,
+		},
+		{
+			errs:   []error{err1, nil, context.DeadlineExceeded, err2},
+			target: &errType2{},
+			want:   true,
+		},
+	}
+
+	for _, test := range tests {
+		test := test // Capture range variable.
+
+		t.Run(fmt.Sprintf("Join(%v)", test.errs), func(t *testing.T) {
+			err := errutil.Join(test.errs...)
+			got := errors.As(err, test.target)
+			assert.Equal(t, test.want, got, "expected and actual errors.As results are different")
+		})
+	}
+}
diff --git a/mongo/errors.go b/mongo/errors.go
index 5f2b1b819b..7843f2260b 100644
--- a/mongo/errors.go
+++ b/mongo/errors.go
@@ -114,38 +114,48 @@ func IsDuplicateKeyError(err error) bool {
 	return false
 }
 
-// IsTimeout returns true if err is from a timeout
+// timeoutErrs is a list of error values that indicate a timeout happened.
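+//
+// For example, a timeout wrapped by a hypothetical operation error (assumed
+// here, not part of this patch) still matches via [errors.Is]:
+//
+//	err := fmt.Errorf("round trip: %w", context.DeadlineExceeded)
+//	errors.Is(err, context.DeadlineExceeded) // true, so IsTimeout reports true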
+var timeoutErrs = [...]error{
+	context.DeadlineExceeded,
+	driver.ErrDeadlineWouldBeExceeded,
+	topology.ErrServerSelectionTimeout,
+}
+
+// IsTimeout returns true if err was caused by a timeout. For error chains,
+// IsTimeout returns true if any error in the chain was caused by a timeout.
 func IsTimeout(err error) bool {
-	for ; err != nil; err = unwrap(err) {
-		// check unwrappable errors together
-		if err == context.DeadlineExceeded {
-			return true
-		}
-		if err == driver.ErrDeadlineWouldBeExceeded {
-			return true
-		}
-		if err == topology.ErrServerSelectionTimeout {
-			return true
-		}
-		if _, ok := err.(topology.WaitQueueTimeoutError); ok {
+	// Check if the error chain contains any of the timeout error values.
+	for _, target := range timeoutErrs {
+		if errors.Is(err, target) {
 			return true
 		}
-		if ce, ok := err.(CommandError); ok && ce.IsMaxTimeMSExpiredError() {
-			return true
-		}
-		if we, ok := err.(WriteException); ok && we.WriteConcernError != nil &&
-			we.WriteConcernError.IsMaxTimeMSExpiredError() {
+	}
+
+	// Check if the error chain contains any error types that can indicate
+	// timeout.
+	if errors.As(err, &topology.WaitQueueTimeoutError{}) {
+		return true
+	}
+	if ce := (CommandError{}); errors.As(err, &ce) && ce.IsMaxTimeMSExpiredError() {
+		return true
+	}
+	if we := (WriteException{}); errors.As(err, &we) && we.WriteConcernError != nil && we.WriteConcernError.IsMaxTimeMSExpiredError() {
+		return true
+	}
+	if ne := net.Error(nil); errors.As(err, &ne) {
+		return ne.Timeout()
+	}
+	// Check timeout error labels.
+	if le := LabeledError(nil); errors.As(err, &le) {
+		if le.HasErrorLabel("NetworkTimeoutError") || le.HasErrorLabel("ExceededTimeLimitError") {
 			return true
 		}
-		if ne, ok := err.(net.Error); ok {
-			return ne.Timeout()
-		}
-		//timeout error labels
-		if le, ok := err.(LabeledError); ok {
-			if le.HasErrorLabel("NetworkTimeoutError") || le.HasErrorLabel("ExceededTimeLimitError") {
-				return true
-			}
-		}
 	}
 
 	return false
diff --git a/mongo/integration/client_test.go b/mongo/integration/client_test.go
index 914ca863b7..911d785a70 100644
--- a/mongo/integration/client_test.go
+++ b/mongo/integration/client_test.go
@@ -711,7 +711,7 @@ func TestClient(t *testing.T) {
 			err := mt.Client.Ping(ctx, nil)
 			cancel()
 			assert.NotNil(mt, err, "expected Ping to return an error")
-			assert.True(mt, mongo.IsTimeout(err), "expected a timeout error: got %v", err)
+			assert.True(mt, mongo.IsTimeout(err), "expected a timeout error, got %v", err)
 		}
 
 		// Assert that the Ping timeouts result in no connections being closed.
diff --git a/mongo/integration/mtest/opmsg_deployment.go b/mongo/integration/mtest/opmsg_deployment.go
index ae4e359380..bdc852293a 100644
--- a/mongo/integration/mtest/opmsg_deployment.go
+++ b/mongo/integration/mtest/opmsg_deployment.go
@@ -124,6 +124,7 @@ type mockDeployment struct {
 }
 
 var _ driver.Deployment = &mockDeployment{}
+var _ driver.ConnDeployment = &mockDeployment{}
 var _ driver.Server = &mockDeployment{}
 var _ driver.Connector = &mockDeployment{}
 var _ driver.Disconnector = &mockDeployment{}
@@ -141,6 +142,15 @@ func (md *mockDeployment) Kind() description.TopologyKind {
 	return description.Single
 }
 
+// TODO: How should this behave?
+func (md *mockDeployment) SelectServerAndConnection(
+	ctx context.Context,
+	_ description.ServerSelector,
+) (driver.Server, driver.Connection, error) {
+	conn, err := md.Connection(ctx)
+	return md, conn, err
+}
+
 // Connection implements the driver.Server interface.
 func (md *mockDeployment) Connection(context.Context) (driver.Connection, error) {
 	return md.conn, nil
diff --git a/x/mongo/driver/driver.go b/x/mongo/driver/driver.go
index 5fd3ddcb42..0b7a62f68b 100644
--- a/x/mongo/driver/driver.go
+++ b/x/mongo/driver/driver.go
@@ -23,6 +23,43 @@ type Deployment interface {
 	Kind() description.TopologyKind
 }
 
+// TODO: Can we integrate this into a type that we pass into
+// getServerAndConnection instead of type-asserting it there?
+// TODO: Name?
+type ConnDeployment interface {
+	SelectServerAndConnection(context.Context, description.ServerSelector) (Server, Connection, error)
+}
+
+var _ ConnDeployment = &connDeployment{}
+
+// TODO: Name?
+type connDeployment struct {
+	deployment Deployment
+}
+
+func (cd *connDeployment) SelectServerAndConnection(
+	ctx context.Context,
+	selector description.ServerSelector,
+) (Server, Connection, error) {
+	server, err := cd.deployment.SelectServer(ctx, selector)
+	if err != nil {
+		return nil, nil, err
+	}
+	conn, err := server.Connection(ctx)
+	if err != nil {
+		return nil, nil, err
+	}
+	return server, conn, nil
+}
+
+// TODO: Name?
+func makeConnDeployment(d Deployment) ConnDeployment {
+	if cd, ok := d.(ConnDeployment); ok {
+		return cd
+	}
+	return &connDeployment{deployment: d}
+}
+
 // Connector represents a type that can connect to a server.
 type Connector interface {
 	Connect() error
@@ -258,12 +295,9 @@ const (
 	// is not specified. For example, if an insert is batch split into 4 commands then each of
 	// those commands is eligible for one retry.
 	RetryOncePerCommand
-	// RetryContext will enable retrying until the context.Context's deadline is exceeded or it is
-	// cancelled.
-	RetryContext
 )
 
 // Enabled returns if this RetryMode enables retrying.
 func (rm RetryMode) Enabled() bool {
-	return rm == RetryOnce || rm == RetryOncePerCommand || rm == RetryContext
+	return rm == RetryOnce || rm == RetryOncePerCommand
 }
diff --git a/x/mongo/driver/errors.go b/x/mongo/driver/errors.go
index 55f2fb37eb..177aa1234b 100644
--- a/x/mongo/driver/errors.go
+++ b/x/mongo/driver/errors.go
@@ -264,10 +264,19 @@ func (e Error) UnsupportedStorageEngine() bool {
 
 // Error implements the error interface.
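+//
+// A hypothetical formatted message, assuming Name "Timeout", Message
+// "operation failed", and a wrapped context.DeadlineExceeded:
+//
+//	(Timeout) operation failed: context deadline exceeded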
 func (e Error) Error() string {
+	msg := e.Message
 	if e.Name != "" {
-		return fmt.Sprintf("(%v) %v", e.Name, e.Message)
+		msg = fmt.Sprintf("(%v) %v", e.Name, e.Message)
 	}
-	return e.Message
+	if e.Wrapped != nil {
+		msg += ": " + e.Wrapped.Error()
+	}
+	return msg
 }
 
 // Unwrap returns the underlying error.
diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go
index 90573daa53..101cf388d8 100644
--- a/x/mongo/driver/operation.go
+++ b/x/mongo/driver/operation.go
@@ -59,11 +59,6 @@ const (
 	readSnapshotMinWireVersion int32 = 13
 )
 
-// RetryablePoolError is a connection pool error that can be retried while executing an operation.
-type RetryablePoolError interface {
-	Retryable() bool
-}
-
 // labeledError is an error that can have error labels added to it.
 type labeledError interface {
 	error
@@ -259,11 +254,11 @@ type Operation struct {
 	// cluster clocks to be only updated as far as the last command that's been run.
 	Clock *session.ClusterClock
 
-	// RetryMode specifies how to retry. There are three modes that enable retry: RetryOnce,
-	// RetryOncePerCommand, and RetryContext. For more information about what these modes do, please
-	// refer to their definitions. Both RetryMode and Type must be set for retryability to be enabled.
-	// If Timeout is set on the Client, the operation will automatically retry as many times as
-	// possible unless RetryNone is used.
+	// RetryMode specifies how to retry. There are two modes that enable retry: RetryOnce and
+	// RetryOncePerCommand. For more information about what these modes do, please refer to their
+	// definitions. Both RetryMode and Type must be set for retryability to be enabled. If Timeout
+	// is set on the Client, the operation will automatically retry as many times as possible unless
+	// RetryNone is used.
 	RetryMode *RetryMode
 
 	// Type specifies the kind of operation this is. There is only one mode that enables retry: Write.
@@ -321,37 +316,37 @@ func (op Operation) shouldEncrypt() bool {
 	return op.Crypt != nil && !op.Crypt.BypassAutoEncryption()
 }
 
-// selectServer handles performing server selection for an operation.
-func (op Operation) selectServer(ctx context.Context) (Server, error) {
-	if err := op.Validate(); err != nil {
-		return nil, err
-	}
+// getServerAndConnection selects a server and checks out a connection for
+// executing an operation, preferring the client session's pinned connection
+// when one exists.
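+//
+// A sketch of the two paths (as implemented below):
+//
+//	pinned connection exists -> reuse it and select the server only
+//	otherwise                -> wrap the Deployment with makeConnDeployment
+//	                            and call SelectServerAndConnection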
+func (op Operation) getServerAndConnection(
+	ctx context.Context,
+	selector description.ServerSelector,
+) (Server, Connection, error) {
+	var server Server
+	var conn Connection
+	var err error
 
-	selector := op.Selector
-	if selector == nil {
-		rp := op.ReadPreference
-		if rp == nil {
-			rp = readpref.Primary()
-		}
-		selector = description.CompositeSelector([]description.ServerSelector{
-			description.ReadPrefSelector(rp),
-			description.LatencySelector(defaultLocalThreshold),
-		})
+	if op.Client != nil && op.Client.PinnedConnection != nil {
+		// If the provided client session has a pinned connection, it should be
+		// used for the operation because this indicates that we're in a
+		// transaction and the target server is behind a load balancer.
+		conn = op.Client.PinnedConnection
+		server, err = op.Deployment.SelectServer(ctx, selector)
+	} else {
+		cd := makeConnDeployment(op.Deployment)
+		server, conn, err = cd.SelectServerAndConnection(ctx, selector)
 	}
-
-	ctx = logger.WithOperationName(ctx, op.Name)
-	ctx = logger.WithOperationID(ctx, wiremessage.CurrentRequestID())
-
-	return op.Deployment.SelectServer(ctx, selector)
-}
-
-// getServerAndConnection should be used to retrieve a Server and Connection to execute an operation.
-func (op Operation) getServerAndConnection(ctx context.Context) (Server, Connection, error) {
-	server, err := op.selectServer(ctx)
 	if err != nil {
-		if op.Client != nil &&
-			!(op.Client.Committing || op.Client.Aborting) && op.Client.TransactionRunning() {
-			err = Error{
+		if op.Client != nil && !(op.Client.Committing || op.Client.Aborting) && op.Client.TransactionRunning() {
+			return nil, nil, Error{
+				// TODO: Don't cause a repeating error message.
 				Message: err.Error(),
 				Labels:  []string{TransientTransactionError},
 				Wrapped: err,
 			}
@@ -360,18 +355,6 @@ func (op Operation) getServerAndConnection(ctx context.Context) (Server, Connection, error) {
 		return nil, nil, err
 	}
 
-	// If the provided client session has a pinned connection, it should be used for the operation because this
-	// indicates that we're in a transaction and the target server is behind a load balancer.
-	if op.Client != nil && op.Client.PinnedConnection != nil {
-		return server, op.Client.PinnedConnection, nil
-	}
-
-	// Otherwise, default to checking out a connection from the server's pool.
-	conn, err := server.Connection(ctx)
-	if err != nil {
-		return nil, nil, err
-	}
-
 	// If we're in load balanced mode and this is the first operation in a transaction, pin the session to a connection.
 	if conn.Description().LoadBalanced() && op.Client != nil && op.Client.TransactionStarting() {
 		pinnedConn, ok := conn.(PinnedConnection)
@@ -450,15 +433,11 @@ func (op Operation) Execute(ctx context.Context) error {
 		switch *op.RetryMode {
 		case RetryOnce, RetryOncePerCommand:
 			retries = 1
-		case RetryContext:
-			retries = -1
 		}
 	case Read:
 		switch *op.RetryMode {
 		case RetryOnce, RetryOncePerCommand:
 			retries = 1
-		case RetryContext:
-			retries = -1
 		}
 	}
@@ -529,17 +508,37 @@
 			memoryPool.Put(wm)
 		}
 	}()
+
+	// Use the operation's selector if specified; otherwise, default to a
+	// selector built from the operation's read preference (or primary).
+	selector := op.Selector
+	if selector == nil {
+		rp := op.ReadPreference
+		if rp == nil {
+			rp = readpref.Primary()
+		}
+		selector = description.CompositeSelector([]description.ServerSelector{
+			description.ReadPrefSelector(rp),
+			description.LatencySelector(defaultLocalThreshold),
+		})
+	}
+
+	ctx = logger.WithOperationName(ctx, op.Name)
+	ctx = logger.WithOperationID(ctx, wiremessage.CurrentRequestID())
+
 	for {
 		wiremessage.NextRequestID()
 
 		// If the server or connection are nil, try to select a new server and get a new connection.
 		if srvr == nil || conn == nil {
-			srvr, conn, err = op.getServerAndConnection(ctx)
+			srvr, conn, err = op.getServerAndConnection(ctx, selector)
 			if err != nil {
-				// If the returned error is retryable and there are retries remaining (negative
-				// retries means retry indefinitely), then retry the operation. Set the server
-				// and connection to nil to request a new server and connection.
-				if rerr, ok := err.(RetryablePoolError); ok && rerr.Retryable() && retries != 0 {
+				if retrySupported && ctx.Err() == nil && retries > 0 {
+					if op.Client != nil && op.Client.Committing {
+						// Apply majority write concern for retries.
+						op.Client.UpdateCommitTransactionWriteConcern()
+						op.WriteConcern = op.Client.CurrentWc
+					}
 					resetForRetry(err)
 					continue
 				}
diff --git a/x/mongo/driver/operation_test.go b/x/mongo/driver/operation_test.go
index d4c5a1b6a0..a604b8b688 100644
--- a/x/mongo/driver/operation_test.go
+++ b/x/mongo/driver/operation_test.go
@@ -56,47 +56,43 @@ func compareErrors(err1, err2 error) bool {
 	return true
 }
 
+// TODO:
+// func TestOperationExecute(t *testing.T) {
+// 	t.Run("uses specified server selector", func(t *testing.T) {
+// 		want := new(mockServerSelector)
+// 		d := new(mockDeployment)
+// 		op := &Operation{
+// 			CommandFn:  func([]byte, description.SelectedServer) ([]byte, error) { return nil, nil },
+// 			Deployment: d,
+// 			Database:   "testing",
+// 			Selector:   want,
+// 		}
+// 		_, _, err := op.getServerAndConnection(context.Background(), want)
+// 		noerr(t, err)
+// 		got := d.params.selector
+// 		if !cmp.Equal(got, want) {
+// 			t.Errorf("Did not get expected server selector. got %v; want %v", got, want)
+// 		}
+// 	})
+// 	t.Run("uses a default server selector", func(t *testing.T) {
+// 		d := new(mockDeployment)
+// 		op := &Operation{
+// 			CommandFn:  func([]byte, description.SelectedServer) ([]byte, error) { return nil, nil },
+// 			Deployment: d,
+// 			Database:   "testing",
+// 		}
+// 		err := op.Execute(context.Background())
+// 		require.NoError(t, err, "Execute error")
+
+// 		if d.params.selector == nil {
+// 			t.Error("The selectServer method should use a default selector when not specified on Operation, but it passed <nil>.")
+// 		}
+// 	})
+// }
+
 func TestOperation(t *testing.T) {
 	int64ToPtr := func(i64 int64) *int64 { return &i64 }
 
-	t.Run("selectServer", func(t *testing.T) {
-		t.Run("returns validation error", func(t *testing.T) {
-			op := &Operation{}
-			_, err := op.selectServer(context.Background())
-			if err == nil {
-				t.Error("Expected a validation error from selectServer, but got <nil>")
-			}
-		})
-		t.Run("uses specified server selector", func(t *testing.T) {
-			want := new(mockServerSelector)
-			d := new(mockDeployment)
-			op := &Operation{
-				CommandFn:  func([]byte, description.SelectedServer) ([]byte, error) { return nil, nil },
-				Deployment: d,
-				Database:   "testing",
-				Selector:   want,
-			}
-			_, err := op.selectServer(context.Background())
-			noerr(t, err)
-			got := d.params.selector
-			if !cmp.Equal(got, want) {
-				t.Errorf("Did not get expected server selector. got %v; want %v", got, want)
-			}
-		})
-		t.Run("uses a default server selector", func(t *testing.T) {
-			d := new(mockDeployment)
-			op := &Operation{
-				CommandFn:  func([]byte, description.SelectedServer) ([]byte, error) { return nil, nil },
-				Deployment: d,
-				Database:   "testing",
-			}
-			_, err := op.selectServer(context.Background())
-			noerr(t, err)
-			if d.params.selector == nil {
-				t.Error("The selectServer method should use a default selector when not specified on Operation, but it passed <nil>.")
-			}
-		})
-	})
 	t.Run("Validate", func(t *testing.T) {
 		cmdFn := func([]byte, description.SelectedServer) ([]byte, error) { return nil, nil }
 		d := new(mockDeployment)
@@ -746,14 +742,6 @@ func (m *mockConnection) ReadWireMessage(_ context.Context) ([]byte, error) {
 	return m.rReadWM, m.rReadErr
 }
 
-type retryableError struct {
-	error
-}
-
-func (retryableError) Retryable() bool { return true }
-
-var _ RetryablePoolError = retryableError{}
-
 // mockRetryServer is used to test retry of connection checkout. Returns a retryable error from
 // Connection().
 type mockRetryServer struct {
@@ -770,7 +758,7 @@ func (ms *mockRetryServer) Connection(ctx context.Context) (Connection, error) {
 	}
 
 	time.Sleep(1 * time.Millisecond)
-	return nil, retryableError{error: errors.New("test error")}
+	return nil, errors.New("test error")
 }
 
 func (ms *mockRetryServer) RTTMonitor() RTTMonitor {
@@ -778,36 +766,7 @@
 }
 
 func TestRetry(t *testing.T) {
-	t.Run("retries multiple times with RetryContext", func(t *testing.T) {
-		d := new(mockDeployment)
-		ms := new(mockRetryServer)
-		d.returns.server = ms
-
-		ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
-		defer cancel()
-
-		retry := RetryContext
-		err := Operation{
-			CommandFn:  func([]byte, description.SelectedServer) ([]byte, error) { return nil, nil },
-			Deployment: d,
-			Database:   "testing",
-			RetryMode:  &retry,
-			Type:       Read,
-		}.Execute(ctx)
-		assert.NotNil(t, err, "expected an error from Execute()")
-
-		// Expect Connection() to be called at least 3 times. The first call is the initial attempt
-		// to run the operation and the second is the retry. The third indicates that we retried
-		// more than once, which is the behavior we want to assert.
-		assert.True(t,
-			ms.numCallsToConnection >= 3,
-			"expected Connection() to be called at least 3 times")
-
-		deadline, _ := ctx.Deadline()
-		assert.True(t,
-			time.Now().After(deadline),
-			"expected operation to complete only after the context deadline is exceeded")
-	})
+	// TODO: Add new tests for connection checkout retry behavior now that RetryContext is removed.
 }
 
 func TestConvertI64PtrToI32Ptr(t *testing.T) {
diff --git a/x/mongo/driver/topology/errors_test.go b/x/mongo/driver/topology/errors_test.go
new file mode 100644
index 0000000000..e0c7eedf5f
--- /dev/null
+++ b/x/mongo/driver/topology/errors_test.go
@@ -0,0 +1,22 @@
+// Copyright (C) MongoDB, Inc. 2023-present.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License. You may obtain
+// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
+
+package topology
+
+import (
+	"context"
+	"errors"
+	"testing"
+
+	"go.mongodb.org/mongo-driver/internal/assert"
+)
+
+func TestWaitQueueTimeoutError(t *testing.T) {
+	t.Run("wraps the underlying error", func(t *testing.T) {
+		err := WaitQueueTimeoutError{Wrapped: context.DeadlineExceeded}
+		assert.True(t, errors.Is(err, context.DeadlineExceeded), "expected errors.Is to match the wrapped error")
+	})
+}
diff --git a/x/mongo/driver/topology/pool.go b/x/mongo/driver/topology/pool.go
index 5d2369352e..dafea24c80 100644
--- a/x/mongo/driver/topology/pool.go
+++ b/x/mongo/driver/topology/pool.go
@@ -18,7 +18,6 @@ import (
 	"go.mongodb.org/mongo-driver/event"
 	"go.mongodb.org/mongo-driver/internal/logger"
 	"go.mongodb.org/mongo-driver/mongo/address"
-	"go.mongodb.org/mongo-driver/x/mongo/driver"
 )
 
 // Connection pool state constants.
@@ -46,8 +45,8 @@ type PoolError string
 
 func (pe PoolError) Error() string { return string(pe) }
 
-// poolClearedError is an error returned when the connection pool is cleared or currently paused. It
-// is a retryable error.
+// poolClearedError is an error returned when the connection pool is cleared or
+// currently paused.
 type poolClearedError struct {
 	err     error
 	address address.Address
@@ -60,12 +59,6 @@ func (pce poolClearedError) Error() string {
 		pce.err)
 }
 
-// Retryable returns true. All poolClearedErrors are retryable.
-func (poolClearedError) Retryable() bool { return true }
-
-// Assert that poolClearedError is a driver.RetryablePoolError.
-var _ driver.RetryablePoolError = poolClearedError{}
-
 // poolConfig contains all aspects of the pool that can be configured
 type poolConfig struct {
 	Address address.Address
diff --git a/x/mongo/driver/topology/topology.go b/x/mongo/driver/topology/topology.go
index b0683021ee..44965dcfa7 100644
--- a/x/mongo/driver/topology/topology.go
+++ b/x/mongo/driver/topology/topology.go
@@ -24,6 +24,7 @@ import (
 	"go.mongodb.org/mongo-driver/bson/primitive"
 	"go.mongodb.org/mongo-driver/event"
+	"go.mongodb.org/mongo-driver/internal/errutil"
 	"go.mongodb.org/mongo-driver/internal/logger"
 	"go.mongodb.org/mongo-driver/internal/randutil"
 	"go.mongodb.org/mongo-driver/mongo/address"
@@ -621,6 +622,74 @@ func (t *Topology) SelectServer(ctx context.Context, ss description.ServerSelector) (driver.Server, error) {
 	}
 }
 
+// SelectServerAndConnection selects a server using the given server selector
+// and checks out a connection from it, retrying checkout failures until the
+// server selection timeout elapses or the context is canceled.
+func (t *Topology) SelectServerAndConnection(
+	ctx context.Context,
+	ss description.ServerSelector,
+) (driver.Server, driver.Connection, error) {
+	if atomic.LoadInt64(&t.state) != topologyConnected {
+		if mustLogServerSelection(t, logger.LevelDebug) {
+			logServerSelectionFailed(ctx, t, ss, ErrTopologyClosed)
+		}
+
+		return nil, nil, ErrTopologyClosed
+	}
+
+	// Retry selecting a server and checking out a connection until checkout
+	// succeeds, the server selection timeout elapses, or the context expires.
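+	//
+	// A sketch of each iteration's outcome (as implemented below):
+	//
+	//	context expired             -> join errors and return
+	//	past ServerSelectionTimeout -> join errors and return
+	//	SelectServer failed         -> join errors and return (not retried)
+	//	Connection checkout failed  -> record the error and retry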
+	var server driver.Server
+	var conn driver.Connection
+	var errs []error
+	start := time.Now()
+	for {
+		if ctx.Err() != nil {
+			// If the context is expired, append the context error and return.
+			// Note that this may result in a duplicated context error in the
+			// errors list if a previous retry error wraps the context error,
+			// but it's better to have too much information here than not
+			// enough.
+			errs = append(errs, ctx.Err())
+			return nil, nil, errutil.Join(errs...)
+		}
+		// TODO: Should we add connect timeout?
+		if timeout := t.cfg.ServerSelectionTimeout; timeout > 0 && time.Since(start) > timeout {
+			// TODO: What error here?
+			errs = append(errs, ServerSelectionError{
+				Wrapped: ErrServerSelectionTimeout,
+				Desc:    t.Description(),
+			})
+			return nil, nil, errutil.Join(errs...)
+		}
+
+		var err error
+		server, err = t.SelectServer(ctx, ss)
+		if err != nil {
+			// TODO: Are there any SelectServer errors that should be retried?
+			errs = append(errs, fmt.Errorf("error selecting a server: %w", err))
+			return nil, nil, errutil.Join(errs...)
+		}
+
+		conn, err = server.Connection(ctx)
+		if err != nil {
+			errs = append(errs, fmt.Errorf("error checking out a connection: %w", err))
+			continue
+		}
+
+		// We successfully got a server and connection. Return.
+		return server, conn, nil
+	}
+}
+
 // pick2 returns 2 random server descriptions from the input slice of server descriptions,
 // guaranteeing that the same element from the slice is not picked twice. The order of server
 // descriptions in the input slice may be modified. If fewer than 2 server descriptions are