Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GODRIVER-2579 Retry connection check-out in a loop. #1368

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions internal/errutil/join.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

//go:build go1.20
// +build go1.20

package errutil

import "errors"

// Join calls [errors.Join].
func Join(errs ...error) error {
return errors.Join(errs...)
}
88 changes: 88 additions & 0 deletions internal/errutil/join_go1.19.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

//go:build !go1.20
// +build !go1.20

package errutil

import "errors"

// Join returns an error that wraps the given errors. Any nil error values are
// discarded. Join returns nil if every value in errs is nil. The error formats
// as the concatenation of the strings obtained by calling the Error method of
// each element of errs, with a newline between each string.
//
// A non-nil error returned by Join implements the "Unwrap() error" method.
func Join(errs ...error) error {
n := 0
for _, err := range errs {
if err != nil {
n++
}
}
if n == 0 {
return nil
}
e := &joinError{
errs: make([]error, 0, n),
}
for _, err := range errs {
if err != nil {
e.errs = append(e.errs, err)
}
}
return e
}

// joinError is a Go 1.13-1.19 compatible joinable error type. Its error
// message is identical to [errors.Join], but it implements "Unwrap() error"
// instead of "Unwrap() []error".
//
// It is heavily based on the joinError from
// https://cs.opensource.google/go/go/+/refs/tags/go1.21.0:src/errors/join.go
type joinError struct {
errs []error
}

func (e *joinError) Error() string {
var b []byte
for i, err := range e.errs {
if i > 0 {
b = append(b, '\n')
}
b = append(b, err.Error()...)
}
return string(b)
}

// Unwrap returns another joinError with the same errors as the current
// joinError except the first error in the slice. Continuing to call Unwrap
// on each returned error will increment through every error in the slice. The
// resulting behavior when using [errors.Is] and [errors.As] is similar to an
// error created using [errors.Join] in Go 1.20+.
func (e *joinError) Unwrap() error {
if len(e.errs) == 1 {
return e.errs[0]
}
return &joinError{errs: e.errs[1:]}
}

// Is calls [errors.Is] with the first error in the slice.
func (e *joinError) Is(target error) bool {
if len(e.errs) == 0 {
return false
}
return errors.Is(e.errs[0], target)
}

// As calls [errors.As] with the first error in the slice.
func (e *joinError) As(target interface{}) bool {
if len(e.errs) == 0 {
return false
}
return errors.As(e.errs[0], target)
}
163 changes: 163 additions & 0 deletions internal/errutil/join_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package errutil_test

import (
"context"
"errors"
"fmt"
"testing"

"go.mongodb.org/mongo-driver/internal/assert"
"go.mongodb.org/mongo-driver/internal/errutil"
)

func TestJoinReturnsNil(t *testing.T) {
t.Parallel()

if err := errutil.Join(); err != nil {
t.Errorf("errutil.Join() = %v, want nil", err)
}
if err := errutil.Join(nil); err != nil {
t.Errorf("errutil.Join(nil) = %v, want nil", err)
}
if err := errutil.Join(nil, nil); err != nil {
t.Errorf("errutil.Join(nil, nil) = %v, want nil", err)
}
}

func TestJoin_Error(t *testing.T) {
t.Parallel()

err1 := errors.New("err1")
err2 := errors.New("err2")

tests := []struct {
errs []error
want string
}{
{
errs: []error{err1},
want: "err1",
},
{
errs: []error{err1, err2},
want: "err1\nerr2",
},
{
errs: []error{err1, nil, err2},
want: "err1\nerr2",
},
}

for _, test := range tests {
test := test // Capture range variable.

t.Run(fmt.Sprintf("Join(%v)", test.errs), func(t *testing.T) {
t.Parallel()

got := errutil.Join(test.errs...).Error()
assert.Equal(t, test.want, got, "expected and actual error strings are different")
})
}
}

func TestJoin_ErrorsIs(t *testing.T) {
t.Parallel()

err1 := errors.New("err1")
err2 := errors.New("err2")

tests := []struct {
errs []error
target error
want bool
}{
{
errs: []error{err1},
target: err1,
want: true,
},
{
errs: []error{err1},
target: err2,
want: false,
},
{
errs: []error{err1, err2},
target: err2,
want: true,
},
{
errs: []error{err1, nil, context.DeadlineExceeded, err2},
target: context.DeadlineExceeded,
want: true,
},
}

for _, test := range tests {
test := test // Capture range variable.

t.Run(fmt.Sprintf("Join(%v)", test.errs), func(t *testing.T) {
err := errutil.Join(test.errs...)
got := errors.Is(err, test.target)
assert.Equal(t, test.want, got, "expected and actual errors.Is result are different")
})
}
}

type errType1 struct{}

func (errType1) Error() string { return "" }

type errType2 struct{}

func (errType2) Error() string { return "" }

func TestJoin_ErrorsAs(t *testing.T) {
t.Parallel()

err1 := errType1{}
err2 := errType2{}

tests := []struct {
errs []error
target interface{}
want bool
}{
{
errs: []error{err1},
target: &errType1{},
want: true,
},
{
errs: []error{err1},
target: &errType2{},
want: false,
},
{
errs: []error{err1, err2},
target: &errType2{},
want: true,
},
{
errs: []error{err1, nil, context.DeadlineExceeded, err2},
target: &errType2{},
want: true,
},
}

for _, test := range tests {
test := test // Capture range variable.

t.Run(fmt.Sprintf("Join(%v)", test.errs), func(t *testing.T) {
err := errutil.Join(test.errs...)
got := errors.As(err, test.target)
assert.Equal(t, test.want, got, "expected and actual errors.Is result are different")
})
}
}
58 changes: 31 additions & 27 deletions mongo/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,38 +114,42 @@ func IsDuplicateKeyError(err error) bool {
return false
}

// IsTimeout returns true if err is from a timeout
// timeoutErrs is a list of error values that indicate a timeout happened.
var timeoutErrs = [...]error{
context.DeadlineExceeded,
driver.ErrDeadlineWouldBeExceeded,
topology.ErrServerSelectionTimeout,
}

// IsTimeout returns true if err was caused by a timeout. For error chains,
// IsTimeout returns true if any error in the chain was caused by a timeout.
func IsTimeout(err error) bool {
for ; err != nil; err = unwrap(err) {
// check unwrappable errors together
if err == context.DeadlineExceeded {
return true
}
if err == driver.ErrDeadlineWouldBeExceeded {
return true
}
if err == topology.ErrServerSelectionTimeout {
return true
}
if _, ok := err.(topology.WaitQueueTimeoutError); ok {
// Check if the error chain contains any of the timeout error values.
for _, target := range timeoutErrs {
if errors.Is(err, target) {
return true
}
if ce, ok := err.(CommandError); ok && ce.IsMaxTimeMSExpiredError() {
return true
}
if we, ok := err.(WriteException); ok && we.WriteConcernError != nil &&
we.WriteConcernError.IsMaxTimeMSExpiredError() {
}

// Check if the error chain contains any error types that can indicate
// timeout.
if errors.As(err, &topology.WaitQueueTimeoutError{}) {
return true
}
if ce := (CommandError{}); errors.As(err, &ce) && ce.IsMaxTimeMSExpiredError() {
return true
}
if we := (WriteException{}); errors.As(err, &we) && we.WriteConcernError != nil && we.WriteConcernError.IsMaxTimeMSExpiredError() {
return true
}
if ne := net.Error(nil); errors.As(err, &ne) {
return ne.Timeout()
}
// Check timeout error labels.
if le := LabeledError(nil); errors.As(err, &le) {
if le.HasErrorLabel("NetworkTimeoutError") || le.HasErrorLabel("ExceededTimeLimitError") {
return true
}
if ne, ok := err.(net.Error); ok {
return ne.Timeout()
}
//timeout error labels
if le, ok := err.(LabeledError); ok {
if le.HasErrorLabel("NetworkTimeoutError") || le.HasErrorLabel("ExceededTimeLimitError") {
return true
}
}
}

return false
Expand Down
2 changes: 1 addition & 1 deletion mongo/integration/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,7 @@ func TestClient(t *testing.T) {
err := mt.Client.Ping(ctx, nil)
cancel()
assert.NotNil(mt, err, "expected Ping to return an error")
assert.True(mt, mongo.IsTimeout(err), "expected a timeout error: got %v", err)
assert.True(mt, mongo.IsTimeout(err), "expected a timeout error, got %v", err)
}

// Assert that the Ping timeouts result in no connections being closed.
Expand Down
10 changes: 10 additions & 0 deletions mongo/integration/mtest/opmsg_deployment.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ type mockDeployment struct {
}

var _ driver.Deployment = &mockDeployment{}
var _ driver.ConnDeployment = &mockDeployment{}
var _ driver.Server = &mockDeployment{}
var _ driver.Connector = &mockDeployment{}
var _ driver.Disconnector = &mockDeployment{}
Expand All @@ -141,6 +142,15 @@ func (md *mockDeployment) Kind() description.TopologyKind {
return description.Single
}

// TODO: How should this behave?
func (md *mockDeployment) SelectServerAndConnection(
ctx context.Context,
_ description.ServerSelector,
) (driver.Server, driver.Connection, error) {
conn, err := md.Connection(ctx)
return md, conn, err
}

// Connection implements the driver.Server interface.
func (md *mockDeployment) Connection(context.Context) (driver.Connection, error) {
return md.conn, nil
Expand Down
Loading
Loading