From 43962b82a22fc391fb3e9653cdd4e25769b8ce04 Mon Sep 17 00:00:00 2001 From: Matt Dale <9760375+matthewdale@users.noreply.github.com> Date: Thu, 17 Aug 2023 17:51:06 -0700 Subject: [PATCH] GODRIVER-2929 Replace all uses of errutil.WrapErrorf with fmt.Errorf (#1354) --- internal/credproviders/imds_provider.go | 11 ++-- internal/errutil/errutil.go | 75 ------------------------- internal/integtest/integtest.go | 3 +- mongo/integration/mtest/setup.go | 2 +- mongo/options/clientoptions_test.go | 41 ++++++-------- x/mongo/driver/auth/creds/gcpcreds.go | 17 ++++-- x/mongo/driver/connstring/connstring.go | 25 ++++----- x/mongo/driver/mongocrypt/mongocrypt.go | 3 +- x/mongo/driver/operation.go | 16 ++++-- 9 files changed, 60 insertions(+), 133 deletions(-) delete mode 100644 internal/errutil/errutil.go diff --git a/internal/credproviders/imds_provider.go b/internal/credproviders/imds_provider.go index b20d8ad78d..96dad1a829 100644 --- a/internal/credproviders/imds_provider.go +++ b/internal/credproviders/imds_provider.go @@ -16,7 +16,6 @@ import ( "time" "go.mongodb.org/mongo-driver/internal/aws/credentials" - "go.mongodb.org/mongo-driver/internal/errutil" ) const ( @@ -47,7 +46,7 @@ func (a *AzureProvider) RetrieveWithContext(ctx context.Context) (credentials.Va v := credentials.Value{ProviderName: AzureProviderName} req, err := http.NewRequest(http.MethodGet, azureURI, nil) if err != nil { - return v, errutil.WrapErrorf(err, "unable to retrieve Azure credentials") + return v, fmt.Errorf("unable to retrieve Azure credentials: %w", err) } q := make(url.Values) q.Set("api-version", "2018-02-01") @@ -58,15 +57,15 @@ func (a *AzureProvider) RetrieveWithContext(ctx context.Context) (credentials.Va resp, err := a.httpClient.Do(req.WithContext(ctx)) if err != nil { - return v, errutil.WrapErrorf(err, "unable to retrieve Azure credentials") + return v, fmt.Errorf("unable to retrieve Azure credentials: %w", err) } defer resp.Body.Close() body, err := ioutil.ReadAll(resp.Body) if err != nil { - return v, errutil.WrapErrorf(err, "unable to retrieve Azure credentials: error reading response body") + return v, fmt.Errorf("unable to retrieve Azure credentials: error reading response body: %w", err) } if resp.StatusCode != http.StatusOK { - return v, errutil.WrapErrorf(err, "unable to retrieve Azure credentials: expected StatusCode 200, got StatusCode: %v. Response body: %s", resp.StatusCode, body) + return v, fmt.Errorf("unable to retrieve Azure credentials: expected StatusCode 200, got StatusCode: %v. Response body: %s", resp.StatusCode, body) } var tokenResponse struct { AccessToken string `json:"access_token"` @@ -75,7 +74,7 @@ func (a *AzureProvider) RetrieveWithContext(ctx context.Context) (credentials.Va // Attempt to read body as JSON err = json.Unmarshal(body, &tokenResponse) if err != nil { - return v, errutil.WrapErrorf(err, "unable to retrieve Azure credentials: error reading body JSON. Response body: %s", body) + return v, fmt.Errorf("unable to retrieve Azure credentials: error reading body JSON: %w (response body: %s)", err, body) } if tokenResponse.AccessToken == "" { return v, fmt.Errorf("unable to retrieve Azure credentials: got unexpected empty accessToken from Azure Metadata Server. Response body: %s", body) diff --git a/internal/errutil/errutil.go b/internal/errutil/errutil.go deleted file mode 100644 index 9779f38fa5..0000000000 --- a/internal/errutil/errutil.go +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package errutil - -import ( - "fmt" -) - -// WrappedError represents an error that contains another error. -type WrappedError interface { - // Message gets the basic message of the error. - Message() string - // Inner gets the inner error if one exists. - Inner() error -} - -// rolledUpErrorMessage gets a flattened error message. -func rolledUpErrorMessage(err error) string { - if wrappedErr, ok := err.(WrappedError); ok { - inner := wrappedErr.Inner() - if inner != nil { - return fmt.Sprintf("%s: %s", wrappedErr.Message(), rolledUpErrorMessage(inner)) - } - - return wrappedErr.Message() - } - - return err.Error() -} - -// UnwrapError attempts to unwrap the error down to its root cause. -func UnwrapError(err error) error { - - switch tErr := err.(type) { - case WrappedError: - return UnwrapError(tErr.Inner()) - } - - return err -} - -// WrapError wraps an error with a message. -func WrapError(inner error, message string) error { - return &wrappedError{message, inner} -} - -// WrapErrorf wraps an error with a message. -func WrapErrorf(inner error, format string, args ...interface{}) error { - return &wrappedError{fmt.Sprintf(format, args...), inner} -} - -type wrappedError struct { - message string - inner error -} - -func (e *wrappedError) Message() string { - return e.message -} - -func (e *wrappedError) Error() string { - return rolledUpErrorMessage(e) -} - -func (e *wrappedError) Inner() error { - return e.inner -} - -func (e *wrappedError) Unwrap() error { - return e.inner -} diff --git a/internal/integtest/integtest.go b/internal/integtest/integtest.go index 8cdfbf6741..d89bcd7539 100644 --- a/internal/integtest/integtest.go +++ b/internal/integtest/integtest.go @@ -8,6 +8,7 @@ package integtest import ( "context" + "errors" "fmt" "math" "os" @@ -202,7 +203,7 @@ func AddServerlessAuthCredentials(uri string) (string, error) { } else if strings.HasPrefix(uri, "mongodb://") { scheme = "mongodb://" } else { - return "", fmt.Errorf("scheme must be \"mongodb\" or \"mongodb+srv\"") + return "", errors.New(`scheme must be "mongodb" or "mongodb+srv"`) } uri = scheme + user + ":" + password + "@" + uri[len(scheme):] diff --git a/mongo/integration/mtest/setup.go b/mongo/integration/mtest/setup.go index fd26629069..be2dae93b8 100644 --- a/mongo/integration/mtest/setup.go +++ b/mongo/integration/mtest/setup.go @@ -323,7 +323,7 @@ func addServerlessAuthCredentials(uri string) (string, error) { } else if strings.HasPrefix(uri, "mongodb://") { scheme = "mongodb://" } else { - return "", fmt.Errorf("scheme must be \"mongodb\" or \"mongodb+srv\"") + return "", errors.New(`scheme must be "mongodb" or "mongodb+srv"`) } uri = scheme + user + ":" + password + "@" + uri[len(scheme):] diff --git a/mongo/options/clientoptions_test.go b/mongo/options/clientoptions_test.go index 2ae3a673ce..7c148ca0bd 100644 --- a/mongo/options/clientoptions_test.go +++ b/mongo/options/clientoptions_test.go @@ -28,7 +28,6 @@ import ( "go.mongodb.org/mongo-driver/bson/bsoncodec" "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/assert" - "go.mongodb.org/mongo-driver/internal/errutil" "go.mongodb.org/mongo-driver/internal/httputil" "go.mongodb.org/mongo-driver/mongo/readconcern" "go.mongodb.org/mongo-driver/mongo/readpref" @@ -41,9 +40,9 @@ var tClientOptions = reflect.TypeOf(&ClientOptions{}) func TestClientOptions(t *testing.T) { t.Run("ApplyURI/doesn't overwrite previous errors", func(t *testing.T) { uri := "not-mongo-db-uri://" - want := errutil.WrapErrorf( - errors.New(`scheme must be "mongodb" or "mongodb+srv"`), "error parsing uri", - ) + want := fmt.Errorf( + "error parsing uri: %w", + errors.New(`scheme must be "mongodb" or "mongodb+srv"`)) co := Client().ApplyURI(uri).ApplyURI("mongodb://localhost/") got := co.Validate() if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) { @@ -209,9 +208,9 @@ func TestClientOptions(t *testing.T) { "ParseError", "not-mongo-db-uri://", &ClientOptions{ - err: errutil.WrapErrorf( - errors.New(`scheme must be "mongodb" or "mongodb+srv"`), "error parsing uri", - ), + err: fmt.Errorf( + "error parsing uri: %w", + errors.New(`scheme must be "mongodb" or "mongodb+srv"`)), HTTPClient: httputil.DefaultHTTPClient, }, }, @@ -285,10 +284,9 @@ func TestClientOptions(t *testing.T) { "Unescaped slash in username", "mongodb:///:pwd@localhost", &ClientOptions{ - err: errutil.WrapErrorf( - errors.New("unescaped slash in username"), - "error parsing uri", - ), + err: fmt.Errorf( + "error parsing uri: %w", + errors.New("unescaped slash in username")), HTTPClient: httputil.DefaultHTTPClient, }, }, @@ -472,10 +470,9 @@ func TestClientOptions(t *testing.T) { "TLS only tlsCertificateFile", "mongodb://localhost/?tlsCertificateFile=testdata/nopass/cert.pem", &ClientOptions{ - err: errutil.WrapErrorf( - errors.New("the tlsPrivateKeyFile URI option must be provided if the tlsCertificateFile option is specified"), - "error validating uri", - ), + err: fmt.Errorf( + "error validating uri: %w", + errors.New("the tlsPrivateKeyFile URI option must be provided if the tlsCertificateFile option is specified")), HTTPClient: httputil.DefaultHTTPClient, }, }, @@ -483,10 +480,9 @@ func TestClientOptions(t *testing.T) { "TLS only tlsPrivateKeyFile", "mongodb://localhost/?tlsPrivateKeyFile=testdata/nopass/key.pem", &ClientOptions{ - err: errutil.WrapErrorf( - errors.New("the tlsCertificateFile URI option must be provided if the tlsPrivateKeyFile option is specified"), - "error validating uri", - ), + err: fmt.Errorf( + "error validating uri: %w", + errors.New("the tlsCertificateFile URI option must be provided if the tlsPrivateKeyFile option is specified")), HTTPClient: httputil.DefaultHTTPClient, }, }, @@ -494,11 +490,10 @@ func TestClientOptions(t *testing.T) { "TLS tlsCertificateFile and tlsPrivateKeyFile and tlsCertificateKeyFile", "mongodb://localhost/?tlsCertificateFile=testdata/nopass/cert.pem&tlsPrivateKeyFile=testdata/nopass/key.pem&tlsCertificateKeyFile=testdata/nopass/certificate.pem", &ClientOptions{ - err: errutil.WrapErrorf( + err: fmt.Errorf( + "error validating uri: %w", errors.New("the sslClientCertificateKeyFile/tlsCertificateKeyFile URI option cannot be provided "+ - "along with tlsCertificateFile or tlsPrivateKeyFile"), - "error validating uri", - ), + "along with tlsCertificateFile or tlsPrivateKeyFile")), HTTPClient: httputil.DefaultHTTPClient, }, }, diff --git a/x/mongo/driver/auth/creds/gcpcreds.go b/x/mongo/driver/auth/creds/gcpcreds.go index 821071665c..74f352e36e 100644 --- a/x/mongo/driver/auth/creds/gcpcreds.go +++ b/x/mongo/driver/auth/creds/gcpcreds.go @@ -14,7 +14,6 @@ import ( "net/http" "os" - "go.mongodb.org/mongo-driver/internal/errutil" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" ) @@ -37,20 +36,23 @@ func (p GCPCredentialProvider) GetCredentialsDoc(ctx context.Context) (bsoncore. url := fmt.Sprintf("http://%s/computeMetadata/v1/instance/service-accounts/default/token", metadataHost) req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { - return nil, errutil.WrapErrorf(err, "unable to retrieve GCP credentials") + return nil, fmt.Errorf("unable to retrieve GCP credentials: %w", err) } req.Header.Set("Metadata-Flavor", "Google") resp, err := p.httpClient.Do(req.WithContext(ctx)) if err != nil { - return nil, errutil.WrapErrorf(err, "unable to retrieve GCP credentials") + return nil, fmt.Errorf("unable to retrieve GCP credentials: %w", err) } defer resp.Body.Close() body, err := ioutil.ReadAll(resp.Body) if err != nil { - return nil, errutil.WrapErrorf(err, "unable to retrieve GCP credentials: error reading response body") + return nil, fmt.Errorf("unable to retrieve GCP credentials: error reading response body: %w", err) } if resp.StatusCode != http.StatusOK { - return nil, errutil.WrapErrorf(err, "unable to retrieve GCP credentials: expected StatusCode 200, got StatusCode: %v. Response body: %s", resp.StatusCode, body) + return nil, fmt.Errorf( + "unable to retrieve GCP credentials: expected StatusCode 200, got StatusCode: %v. Response body: %s", + resp.StatusCode, + body) } var tokenResponse struct { AccessToken string `json:"access_token"` @@ -58,7 +60,10 @@ func (p GCPCredentialProvider) GetCredentialsDoc(ctx context.Context) (bsoncore. // Attempt to read body as JSON err = json.Unmarshal(body, &tokenResponse) if err != nil { - return nil, errutil.WrapErrorf(err, "unable to retrieve GCP credentials: error reading body JSON. Response body: %s", body) + return nil, fmt.Errorf( + "unable to retrieve GCP credentials: error reading body JSON: %w (response body: %s)", + err, + body) } if tokenResponse.AccessToken == "" { return nil, fmt.Errorf("unable to retrieve GCP credentials: got unexpected empty accessToken from GCP Metadata Server. Response body: %s", body) diff --git a/x/mongo/driver/connstring/connstring.go b/x/mongo/driver/connstring/connstring.go index 819c63b8ae..983c1dab22 100644 --- a/x/mongo/driver/connstring/connstring.go +++ b/x/mongo/driver/connstring/connstring.go @@ -15,7 +15,6 @@ import ( "strings" "time" - "go.mongodb.org/mongo-driver/internal/errutil" "go.mongodb.org/mongo-driver/internal/randutil" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/mongo/driver/dns" @@ -58,11 +57,11 @@ func ParseAndValidate(s string) (ConnString, error) { p := parser{dnsResolver: dns.DefaultResolver} err := p.parse(s) if err != nil { - return p.ConnString, errutil.WrapErrorf(err, "error parsing uri") + return p.ConnString, fmt.Errorf("error parsing uri: %w", err) } err = p.ConnString.Validate() if err != nil { - return p.ConnString, errutil.WrapErrorf(err, "error validating uri") + return p.ConnString, fmt.Errorf("error validating uri: %w", err) } return p.ConnString, nil } @@ -74,7 +73,7 @@ func Parse(s string) (ConnString, error) { p := parser{dnsResolver: dns.DefaultResolver} err := p.parse(s) if err != nil { - err = errutil.WrapErrorf(err, "error parsing uri") + err = fmt.Errorf("error parsing uri: %w", err) } return p.ConnString, err } @@ -240,7 +239,7 @@ func (p *parser) parse(original string) error { // remove the scheme uri = uri[len(SchemeMongoDB)+3:] } else { - return fmt.Errorf("scheme must be \"mongodb\" or \"mongodb+srv\"") + return errors.New(`scheme must be "mongodb" or "mongodb+srv"`) } if idx := strings.Index(uri, "@"); idx != -1 { @@ -262,7 +261,7 @@ func (p *parser) parse(original string) error { } p.Username, err = url.PathUnescape(username) if err != nil { - return errutil.WrapErrorf(err, "invalid username") + return fmt.Errorf("invalid username: %w", err) } p.UsernameSet = true @@ -275,7 +274,7 @@ func (p *parser) parse(original string) error { } p.Password, err = url.PathUnescape(password) if err != nil { - return errutil.WrapErrorf(err, "invalid password") + return fmt.Errorf("invalid password: %w", err) } } @@ -352,7 +351,7 @@ func (p *parser) parse(original string) error { for _, host := range parsedHosts { err = p.addHost(host) if err != nil { - return errutil.WrapErrorf(err, "invalid host %q", host) + return fmt.Errorf("invalid host %q: %w", host, err) } } if len(p.Hosts) == 0 { @@ -597,7 +596,7 @@ func (p *parser) addHost(host string) error { } host, err := url.QueryUnescape(host) if err != nil { - return errutil.WrapErrorf(err, "invalid host %q", host) + return fmt.Errorf("invalid host %q: %w", host, err) } _, port, err := net.SplitHostPort(host) @@ -612,7 +611,7 @@ func (p *parser) addHost(host string) error { if port != "" { d, err := strconv.Atoi(port) if err != nil { - return errutil.WrapErrorf(err, "port must be an integer") + return fmt.Errorf("port must be an integer: %w", err) } if d <= 0 || d >= 65536 { return fmt.Errorf("port must be in the range [1, 65535]") @@ -630,12 +629,12 @@ func (p *parser) addOption(pair string) error { key, err := url.QueryUnescape(kv[0]) if err != nil { - return errutil.WrapErrorf(err, "invalid option key %q", kv[0]) + return fmt.Errorf("invalid option key %q: %w", kv[0], err) } value, err := url.QueryUnescape(kv[1]) if err != nil { - return errutil.WrapErrorf(err, "invalid option value %q", kv[1]) + return fmt.Errorf("invalid option value %q: %w", kv[1], err) } lowerKey := strings.ToLower(key) @@ -1051,7 +1050,7 @@ func extractDatabaseFromURI(uri string) (extractedDatabase, error) { escapedDatabase, err := url.QueryUnescape(database) if err != nil { - return extractedDatabase{}, errutil.WrapErrorf(err, "invalid database %q", database) + return extractedDatabase{}, fmt.Errorf("invalid database %q: %w", database, err) } uri = uri[len(database):] diff --git a/x/mongo/driver/mongocrypt/mongocrypt.go b/x/mongo/driver/mongocrypt/mongocrypt.go index f8138fc00e..20f6ff0aa9 100644 --- a/x/mongo/driver/mongocrypt/mongocrypt.go +++ b/x/mongo/driver/mongocrypt/mongocrypt.go @@ -23,7 +23,6 @@ import ( "unsafe" "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/internal/errutil" "go.mongodb.org/mongo-driver/internal/httputil" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver/auth/creds" @@ -512,7 +511,7 @@ func (m *MongoCrypt) GetKmsProviders(ctx context.Context) (bsoncore.Document, er for k, p := range m.kmsProviders { doc, err := p.GetCredentialsDoc(ctx) if err != nil { - return nil, errutil.WrapErrorf(err, "unable to retrieve %s credentials", k) + return nil, fmt.Errorf("unable to retrieve %s credentials: %w", k, err) } builder.AppendDocument(k, doc) } diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index 8234482ac0..90573daa53 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -23,7 +23,6 @@ import ( "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/csot" "go.mongodb.org/mongo-driver/internal/driverutil" - "go.mongodb.org/mongo-driver/internal/errutil" "go.mongodb.org/mongo-driver/internal/handshake" "go.mongodb.org/mongo-driver/internal/logger" "go.mongodb.org/mongo-driver/mongo/address" @@ -690,8 +689,11 @@ func (op Operation) Execute(ctx context.Context) error { err = ctx.Err() } else if deadline, ok := ctx.Deadline(); ok { if csot.IsTimeoutContext(ctx) && time.Now().Add(srvr.RTTMonitor().P90()).After(deadline) { - err = errutil.WrapErrorf(ErrDeadlineWouldBeExceeded, - "remaining time %v until context deadline is less than 90th percentile RTT\n%v", time.Until(deadline), srvr.RTTMonitor().Stats()) + err = fmt.Errorf( + "remaining time %v until context deadline is less than 90th percentile RTT: %w\n%v", + time.Until(deadline), + ErrDeadlineWouldBeExceeded, + srvr.RTTMonitor().Stats()) } else if time.Now().Add(srvr.RTTMonitor().Min()).After(deadline) { err = context.DeadlineExceeded } @@ -1376,9 +1378,11 @@ func (op Operation) calculateMaxTimeMS(ctx context.Context, rtt90 time.Duration, // maxTimeMS value (e.g. 400 microseconds evaluates to 1ms, not 0ms). maxTimeMS := int64((maxTime + (time.Millisecond - 1)) / time.Millisecond) if maxTimeMS <= 0 { - return 0, errutil.WrapErrorf(ErrDeadlineWouldBeExceeded, - "remaining time %v until context deadline is less than or equal to 90th percentile RTT\n%v", - remainingTimeout, rttStats) + return 0, fmt.Errorf( + "remaining time %v until context deadline is less than or equal to 90th percentile RTT: %w\n%v", + remainingTimeout, + ErrDeadlineWouldBeExceeded, + rttStats) } return uint64(maxTimeMS), nil }