Skip to content

Commit

Permalink
GODRIVER-2579 Retry connection check-out in a loop.
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewdale committed Aug 30, 2023
1 parent 43962b8 commit 89db4e6
Show file tree
Hide file tree
Showing 16 changed files with 536 additions and 185 deletions.
16 changes: 16 additions & 0 deletions internal/errutil/join.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

//go:build go1.20

package errutil

import "errors"

// Join calls [errors.Join].
func Join(errs ...error) error {
return errors.Join(errs...)
}
87 changes: 87 additions & 0 deletions internal/errutil/join_go1.19.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
// Copied from https://cs.opensource.google/go/go/+/refs/tags/go1.21.0:src/errors/join.go;l=15

// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

//go:build !go1.20

package errutil

import "errors"

// Join returns an error that wraps the given errors. Any nil error values are
// discarded. Join returns nil if every value in errs is nil. The error formats
// as the concatenation of the strings obtained by calling the Error method of
// each element of errs, with a newline between each string.
//
// A non-nil error returned by Join implements the "Unwrap() error" method.
func Join(errs ...error) error {
n := 0
for _, err := range errs {
if err != nil {
n++
}
}
if n == 0 {
return nil
}
e := &go119JoinError{
errs: make([]error, 0, n),
}
for _, err := range errs {
if err != nil {
e.errs = append(e.errs, err)
}
}
return e
}

// go119JoinError is a Go 1.13-1.19 compatible joinable error type. Its error
// message is identical to [errors.Join], but it implements "Unwrap() error"
// instead of "Unwrap() []error".
//
// It is heavily based on the joinError from
// https://cs.opensource.google/go/go/+/refs/tags/go1.21.0:src/errors/join.go
type go119JoinError struct {
errs []error
}

func (e *go119JoinError) Error() string {
var b []byte
for i, err := range e.errs {
if i > 0 {
b = append(b, '\n')
}
b = append(b, err.Error()...)
}
return string(b)
}

// Unwrap returns another go119JoinError with the same errors as the current
// go119JoinError except the first error in the slice. Continuing to call Unwrap
// on each returned error will increment through every error in the slice. The
// resulting behavior when using [errors.Is] and [errors.As] is similar to an
// error created using [errors.Join] in Go 1.20+.
func (e *go119JoinError) Unwrap() error {
if len(e.errs) == 1 {
return e.errs[0]
}
return &go119JoinError{errs: e.errs[1:]}
}

// Is calls [errors.Is] with the first error in the slice.
func (e *go119JoinError) Is(target error) bool {
if len(e.errs) == 0 {
return false
}
return errors.Is(e.errs[0], target)
}

// As calls [errors.As] with the first error in the slice.
func (e *go119JoinError) As(target interface{}) bool {
if len(e.errs) == 0 {
return false
}
return errors.As(e.errs[0], target)
}
165 changes: 165 additions & 0 deletions internal/errutil/join_go1.19_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

//go:build !go1.20

package errutil_test

import (
"context"
"errors"
"fmt"
"testing"

"go.mongodb.org/mongo-driver/internal/assert"
"go.mongodb.org/mongo-driver/internal/errutil"
)

func TestJoinReturnsNil(t *testing.T) {
t.Parallel()

if err := errutil.Join(); err != nil {
t.Errorf("errutil.Join() = %v, want nil", err)
}
if err := errutil.Join(nil); err != nil {
t.Errorf("errutil.Join(nil) = %v, want nil", err)
}
if err := errutil.Join(nil, nil); err != nil {
t.Errorf("errutil.Join(nil, nil) = %v, want nil", err)
}
}

func TestJoin_Error(t *testing.T) {
t.Parallel()

err1 := errors.New("err1")
err2 := errors.New("err2")

tests := []struct {
errs []error
want string
}{
{
errs: []error{err1},
want: "err1",
},
{
errs: []error{err1, err2},
want: "err1\nerr2",
},
{
errs: []error{err1, nil, err2},
want: "err1\nerr2",
},
}

for _, test := range tests {
test := test // Capture range variable.

t.Run(fmt.Sprintf("Join(%v)", test.errs), func(t *testing.T) {
t.Parallel()

got := errutil.Join(test.errs...).Error()
assert.Equal(t, test.want, got, "expected and actual error strings are different")
})
}
}

func TestJoin_ErrorsIs(t *testing.T) {
t.Parallel()

err1 := errors.New("err1")
err2 := errors.New("err2")

tests := []struct {
errs []error
target error
want bool
}{
{
errs: []error{err1},
target: err1,
want: true,
},
{
errs: []error{err1},
target: err2,
want: false,
},
{
errs: []error{err1, err2},
target: err2,
want: true,
},
{
errs: []error{err1, nil, context.DeadlineExceeded, err2},
target: context.DeadlineExceeded,
want: true,
},
}

for _, test := range tests {
test := test // Capture range variable.

t.Run(fmt.Sprintf("Join(%v)", test.errs), func(t *testing.T) {
err := errutil.Join(test.errs...)
got := errors.Is(err, test.target)
assert.Equal(t, test.want, got, "expected and actual errors.Is result are different")
})
}
}

type errType1 struct{}

func (errType1) Error() string { return "" }

type errType2 struct{}

func (errType2) Error() string { return "" }

func TestJoin_ErrorsAs(t *testing.T) {
t.Parallel()

err1 := errType1{}
err2 := errType2{}

tests := []struct {
errs []error
target interface{}
want bool
}{
{
errs: []error{err1},
target: &errType1{},
want: true,
},
{
errs: []error{err1},
target: &errType2{},
want: false,
},
{
errs: []error{err1, err2},
target: &errType2{},
want: true,
},
{
errs: []error{err1, nil, context.DeadlineExceeded, err2},
target: &errType2{},
want: true,
},
}

for _, test := range tests {
test := test // Capture range variable.

t.Run(fmt.Sprintf("Join(%v)", test.errs), func(t *testing.T) {
err := errutil.Join(test.errs...)
got := errors.As(err, test.target)
assert.Equal(t, test.want, got, "expected and actual errors.Is result are different")
})
}
}
57 changes: 30 additions & 27 deletions mongo/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,38 +114,41 @@ func IsDuplicateKeyError(err error) bool {
return false
}

// IsTimeout returns true if err is from a timeout
var timeoutErrs = [...]error{
context.DeadlineExceeded,
driver.ErrDeadlineWouldBeExceeded,
topology.ErrServerSelectionTimeout,
}

// IsTimeout returns true if err was caused by a timeout. For error chains,
// IsTimeout returns true if any error in the chain was caused by a timeout.
func IsTimeout(err error) bool {
for ; err != nil; err = unwrap(err) {
// check unwrappable errors together
if err == context.DeadlineExceeded {
return true
}
if err == driver.ErrDeadlineWouldBeExceeded {
return true
}
if err == topology.ErrServerSelectionTimeout {
return true
}
if _, ok := err.(topology.WaitQueueTimeoutError); ok {
// Check if the error chain contains any of the timeout error values.
for _, target := range timeoutErrs {
if errors.Is(err, target) {
return true
}
if ce, ok := err.(CommandError); ok && ce.IsMaxTimeMSExpiredError() {
return true
}
if we, ok := err.(WriteException); ok && we.WriteConcernError != nil &&
we.WriteConcernError.IsMaxTimeMSExpiredError() {
}

// Check if the error chain contains any error types that can indicate
// timeout.
if errors.As(err, &topology.WaitQueueTimeoutError{}) {
return true
}
if ce := (CommandError{}); errors.As(err, &ce) && ce.IsMaxTimeMSExpiredError() {
return true
}
if we := (WriteException{}); errors.As(err, &we) && we.WriteConcernError != nil && we.WriteConcernError.IsMaxTimeMSExpiredError() {
return true
}
if ne := net.Error(nil); errors.As(err, &ne) {
return ne.Timeout()
}
// Check timeout error labels.
if le := LabeledError(nil); errors.As(err, &le) {
if le.HasErrorLabel("NetworkTimeoutError") || le.HasErrorLabel("ExceededTimeLimitError") {
return true
}
if ne, ok := err.(net.Error); ok {
return ne.Timeout()
}
//timeout error labels
if le, ok := err.(LabeledError); ok {
if le.HasErrorLabel("NetworkTimeoutError") || le.HasErrorLabel("ExceededTimeLimitError") {
return true
}
}
}

return false
Expand Down
2 changes: 1 addition & 1 deletion mongo/integration/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,7 @@ func TestClient(t *testing.T) {
err := mt.Client.Ping(ctx, nil)
cancel()
assert.NotNil(mt, err, "expected Ping to return an error")
assert.True(mt, mongo.IsTimeout(err), "expected a timeout error: got %v", err)
assert.True(mt, mongo.IsTimeout(err), "expected a timeout error, got %v", err)
}

// Assert that the Ping timeouts result in no connections being closed.
Expand Down
10 changes: 10 additions & 0 deletions mongo/integration/mtest/opmsg_deployment.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ type mockDeployment struct {
}

var _ driver.Deployment = &mockDeployment{}
var _ driver.ConnDeployment = &mockDeployment{}
var _ driver.Server = &mockDeployment{}
var _ driver.Connector = &mockDeployment{}
var _ driver.Disconnector = &mockDeployment{}
Expand All @@ -141,6 +142,15 @@ func (md *mockDeployment) Kind() description.TopologyKind {
return description.Single
}

// TODO: How should this behave?
func (md *mockDeployment) SelectServerAndConnection(
ctx context.Context,
_ description.ServerSelector,
) (driver.Server, driver.Connection, error) {
conn, err := md.Connection(ctx)
return md, conn, err
}

// Connection implements the driver.Server interface.
func (md *mockDeployment) Connection(context.Context) (driver.Connection, error) {
return md.conn, nil
Expand Down
5 changes: 5 additions & 0 deletions x/bsonx/bsoncore/bsoncore.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ func AppendHeader(dst []byte, t bsontype.Type, key string) []byte {
panic(invalidKeyPanicMsg)
}

// Disallow the zero byte in a cstring because the zero byte is used as the
// terminating character. It's safe to check bytes instead of runes because
// all multibyte UTF-8 code points start with "11xxxxxx" or "10xxxxxx", so
// "00000000" will never be part of a multibyte UTF-8 code point.

dst = AppendType(dst, t)
dst = append(dst, key...)
return append(dst, 0x00)
Expand Down
Loading

0 comments on commit 89db4e6

Please sign in to comment.