diff --git a/cmd/spire-server/cli/bundle/common.go b/cmd/spire-server/cli/bundle/common.go index a8b24780e8..26a3f52cb8 100644 --- a/cmd/spire-server/cli/bundle/common.go +++ b/cmd/spire-server/cli/bundle/common.go @@ -17,7 +17,6 @@ import ( "github.com/spiffe/spire-api-sdk/proto/spire/api/types" "github.com/spiffe/spire/cmd/spire-server/util" "github.com/spiffe/spire/pkg/common/jwtutil" - "github.com/zeebo/errs" ) const ( @@ -78,7 +77,7 @@ func printBundle(out io.Writer, bundle *types.Bundle) error { docBytes, err := b.Marshal() if err != nil { - return errs.Wrap(err) + return err } var o bytes.Buffer @@ -87,7 +86,7 @@ func printBundle(out io.Writer, bundle *types.Bundle) error { } if _, err := fmt.Fprintln(out, o.String()); err != nil { - return errs.Wrap(err) + return err } return nil diff --git a/go.mod b/go.mod index a1a258efe2..6318a150fb 100644 --- a/go.mod +++ b/go.mod @@ -77,7 +77,6 @@ require ( github.com/stretchr/testify v1.10.0 github.com/uber-go/tally/v4 v4.1.16 github.com/valyala/fastjson v1.6.4 - github.com/zeebo/errs v1.4.0 golang.org/x/crypto v0.32.0 golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 golang.org/x/net v0.34.0 @@ -280,6 +279,7 @@ require ( github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect github.com/yashtewari/glob-intersection v0.2.0 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect + github.com/zeebo/errs v1.4.0 // indirect go.mongodb.org/mongo-driver v1.14.0 // indirect go.opencensus.io v0.24.0 // indirect go.opentelemetry.io/contrib/detectors/gcp v1.31.0 // indirect diff --git a/pkg/agent/attestor/node/node.go b/pkg/agent/attestor/node/node.go index c7d0cdca3e..da7f024f98 100644 --- a/pkg/agent/attestor/node/node.go +++ b/pkg/agent/attestor/node/node.go @@ -28,7 +28,6 @@ import ( "github.com/spiffe/spire/pkg/common/tlspolicy" "github.com/spiffe/spire/pkg/common/util" "github.com/spiffe/spire/pkg/common/x509util" - "github.com/zeebo/errs" "google.golang.org/grpc" "google.golang.org/grpc/credentials" ) @@ -101,7 +100,7 @@ func (a *attestor) Attest(ctx context.Context) (res *AttestationResult, err erro // This is a bizarre case where we have an SVID but were unable to // load a bundle from the cache which suggests some tampering with the // cache on disk. - return nil, errs.New("SVID loaded but no bundle in cache") + return nil, errors.New("SVID loaded but no bundle in cache") default: log.WithField(telemetry.SPIFFEID, svid[0].URIs[0].String()).Info("SVID loaded") } @@ -265,7 +264,7 @@ func (a *attestor) serverConn(ctx context.Context, bundle *spiffebundle.Bundle) if !a.c.InsecureBootstrap { // We shouldn't get here since loadBundle() should fail if the bundle // is empty, but just in case... - return nil, errs.New("no bundle and not doing insecure bootstrap") + return nil, errors.New("no bundle and not doing insecure bootstrap") } // Insecure bootstrapping. Do not verify the server chain but rather do a @@ -279,7 +278,7 @@ func (a *attestor) serverConn(ctx context.Context, bundle *spiffebundle.Bundle) if len(rawCerts) == 0 { // This is not really possible without a catastrophic bug // creeping into the TLS stack. - return errs.New("server chain is unexpectedly empty") + return errors.New("server chain is unexpectedly empty") } expectedServerID, err := idutil.ServerID(a.c.TrustDomain) @@ -292,7 +291,7 @@ func (a *attestor) serverConn(ctx context.Context, bundle *spiffebundle.Bundle) return err } if len(serverCert.URIs) != 1 || serverCert.URIs[0].String() != expectedServerID.String() { - return errs.New("expected server SPIFFE ID %q; got %q", expectedServerID, serverCert.URIs) + return fmt.Errorf("expected server SPIFFE ID %q; got %q", expectedServerID, serverCert.URIs) } return nil }, diff --git a/pkg/agent/endpoints/sdsv3/handler.go b/pkg/agent/endpoints/sdsv3/handler.go index 664e9c9f85..64188a4ae1 100644 --- a/pkg/agent/endpoints/sdsv3/handler.go +++ b/pkg/agent/endpoints/sdsv3/handler.go @@ -22,7 +22,6 @@ import ( "github.com/spiffe/spire/pkg/common/pemutil" "github.com/spiffe/spire/pkg/common/telemetry" "github.com/spiffe/spire/proto/spire/common" - "github.com/zeebo/errs" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/anypb" @@ -99,7 +98,7 @@ func (h *Handler) StreamSecrets(stream secret_v3.SecretDiscoveryService_StreamSe }() var versionCounter int64 - var versionInfo = strconv.FormatInt(versionCounter, 10) + versionInfo := strconv.FormatInt(versionCounter, 10) var lastNonce string var lastNode *core_v3.Node var upd *cache.WorkloadUpdate @@ -150,7 +149,7 @@ func (h *Handler) StreamSecrets(stream secret_v3.SecretDiscoveryService_StreamSe // We need to send updates if the requested resource list has changed // either explicitly, or implicitly because this is the first request. - var sendUpdates = lastReq == nil || subListChanged(lastReq.ResourceNames, newReq.ResourceNames) + sendUpdates := lastReq == nil || subListChanged(lastReq.ResourceNames, newReq.ResourceNames) // save request so that all future workload updates lead to SDS updates for the last request lastReq = newReq @@ -206,7 +205,7 @@ func subListChanged(oldSubs []string, newSubs []string) (b bool) { if len(oldSubs) != len(newSubs) { return true } - var subMap = make(map[string]bool) + subMap := make(map[string]bool) for _, sub := range oldSubs { subMap[sub] = true } @@ -582,7 +581,7 @@ func nextNonce() (string, error) { b := make([]byte, 4) _, err := rand.Read(b) if err != nil { - return "", errs.Wrap(err) + return "", err } return hex.EncodeToString(b), nil } diff --git a/pkg/agent/endpoints/workload/handler.go b/pkg/agent/endpoints/workload/handler.go index 9f191a1471..68cf81087a 100644 --- a/pkg/agent/endpoints/workload/handler.go +++ b/pkg/agent/endpoints/workload/handler.go @@ -22,7 +22,6 @@ import ( "github.com/spiffe/spire/pkg/common/telemetry" "github.com/spiffe/spire/pkg/common/x509util" "github.com/spiffe/spire/proto/spire/common" - "github.com/zeebo/errs" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/encoding/protojson" @@ -512,12 +511,12 @@ func keyStoreFromBundles(bundles []*spiffebundle.Bundle) (jwtsvid.KeyStore, erro func structFromValues(values map[string]any) (*structpb.Struct, error) { valuesJSON, err := json.Marshal(values) if err != nil { - return nil, errs.Wrap(err) + return nil, err } s := new(structpb.Struct) if err := protojson.Unmarshal(valuesJSON, s); err != nil { - return nil, errs.Wrap(err) + return nil, err } return s, nil diff --git a/pkg/agent/plugin/nodeattestor/k8spsat/psat.go b/pkg/agent/plugin/nodeattestor/k8spsat/psat.go index 20e33c4c84..47f95ba21b 100644 --- a/pkg/agent/plugin/nodeattestor/k8spsat/psat.go +++ b/pkg/agent/plugin/nodeattestor/k8spsat/psat.go @@ -3,6 +3,7 @@ package k8spsat import ( "context" "encoding/json" + "fmt" "os" "sync" @@ -12,7 +13,6 @@ import ( "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/pkg/common/plugin/k8s" "github.com/spiffe/spire/pkg/common/pluginconf" - "github.com/zeebo/errs" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -145,10 +145,10 @@ func (p *AttestorPlugin) getConfig() (*attestorConfig, error) { func loadTokenFromFile(path string) (string, error) { data, err := os.ReadFile(path) if err != nil { - return "", errs.Wrap(err) + return "", err } if len(data) == 0 { - return "", errs.New("%q is empty", path) + return "", fmt.Errorf("%q is empty", path) } return string(data), nil } diff --git a/pkg/agent/plugin/nodeattestor/k8ssat/sat.go b/pkg/agent/plugin/nodeattestor/k8ssat/sat.go index bce6fd91e6..d93d39a1d9 100644 --- a/pkg/agent/plugin/nodeattestor/k8ssat/sat.go +++ b/pkg/agent/plugin/nodeattestor/k8ssat/sat.go @@ -14,7 +14,6 @@ import ( "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/pkg/common/plugin/k8s" "github.com/spiffe/spire/pkg/common/pluginconf" - "github.com/zeebo/errs" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -148,10 +147,10 @@ func (p *AttestorPlugin) getConfig() (*attestorConfig, error) { func loadTokenFromFile(path string) (string, error) { data, err := os.ReadFile(path) if err != nil { - return "", errs.Wrap(err) + return "", err } if len(data) == 0 { - return "", errs.New("%q is empty", path) + return "", fmt.Errorf("%q is empty", path) } return string(data), nil } diff --git a/pkg/common/bundleutil/unmarshal.go b/pkg/common/bundleutil/unmarshal.go index c49fbadcb2..ff86b79a17 100644 --- a/pkg/common/bundleutil/unmarshal.go +++ b/pkg/common/bundleutil/unmarshal.go @@ -8,7 +8,6 @@ import ( "github.com/spiffe/go-spiffe/v2/bundle/spiffebundle" "github.com/spiffe/go-spiffe/v2/spiffeid" - "github.com/zeebo/errs" ) func Decode(trustDomain spiffeid.TrustDomain, r io.Reader) (*spiffebundle.Bundle, error) { @@ -22,7 +21,7 @@ func Decode(trustDomain spiffeid.TrustDomain, r io.Reader) (*spiffebundle.Bundle func Unmarshal(trustDomain spiffeid.TrustDomain, data []byte) (*spiffebundle.Bundle, error) { doc := new(bundleDoc) if err := json.Unmarshal(data, doc); err != nil { - return nil, errs.Wrap(err) + return nil, err } return unmarshal(trustDomain, doc) } @@ -35,20 +34,20 @@ func unmarshal(trustDomain spiffeid.TrustDomain, doc *bundleDoc) (*spiffebundle. switch key.Use { case x509SVIDUse: if len(key.Certificates) != 1 { - return nil, errs.New("expected a single certificate in x509-svid entry %d; got %d", i, len(key.Certificates)) + return nil, fmt.Errorf("expected a single certificate in x509-svid entry %d; got %d", i, len(key.Certificates)) } bundle.AddX509Authority(key.Certificates[0]) case jwtSVIDUse: if key.KeyID == "" { - return nil, errs.New("missing key ID in jwt-svid entry %d", i) + return nil, fmt.Errorf("missing key ID in jwt-svid entry %d", i) } if err := bundle.AddJWTAuthority(key.KeyID, key.Key); err != nil { - return nil, errs.New("failed to add jwt-svid entry %d: %v", i, err) + return nil, fmt.Errorf("failed to add jwt-svid entry %d: %w", i, err) } case "": - return nil, errs.New("missing use for key entry %d", i) + return nil, fmt.Errorf("missing use for key entry %d", i) default: - return nil, errs.New("unrecognized use %q for key entry %d", key.Use, i) + return nil, fmt.Errorf("unrecognized use %q for key entry %d", key.Use, i) } } diff --git a/pkg/common/catalog/builtin.go b/pkg/common/catalog/builtin.go index ae246e8164..919681fdba 100644 --- a/pkg/common/catalog/builtin.go +++ b/pkg/common/catalog/builtin.go @@ -11,7 +11,6 @@ import ( "github.com/spiffe/spire-plugin-sdk/pluginsdk" "github.com/spiffe/spire-plugin-sdk/private" "github.com/spiffe/spire/pkg/common/log" - "github.com/zeebo/errs" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" ) @@ -147,7 +146,7 @@ func startPipeServer(server *grpc.Server, log logrus.FieldLogger) (_ *pipeConn, // Dial the server conn, err := grpc.Dial("IGNORED", grpc.WithBlock(), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithContextDialer(pipeNet.DialContext)) //nolint: staticcheck // It is going to be resolved on #5152 if err != nil { - return nil, errs.Wrap(err) + return nil, err } closers = append(closers, conn) diff --git a/pkg/common/catalog/closers.go b/pkg/common/catalog/closers.go index 4e418ca905..d72a186fae 100644 --- a/pkg/common/catalog/closers.go +++ b/pkg/common/catalog/closers.go @@ -1,10 +1,10 @@ package catalog import ( + "errors" "io" "time" - "github.com/zeebo/errs" "google.golang.org/grpc" ) @@ -12,11 +12,12 @@ type closerGroup []io.Closer func (cs closerGroup) Close() error { // Close in reverse order. - var errs errs.Group + var errs error for i := len(cs) - 1; i >= 0; i-- { - errs.Add(cs[i].Close()) + errs = errors.Join(errs, cs[i].Close()) } - return errs.Err() + + return errs } type closerFunc func() diff --git a/pkg/common/catalog/external.go b/pkg/common/catalog/external.go index 1a65b19f53..177de77b59 100644 --- a/pkg/common/catalog/external.go +++ b/pkg/common/catalog/external.go @@ -14,7 +14,6 @@ import ( "github.com/spiffe/spire-plugin-sdk/pluginsdk" "github.com/spiffe/spire-plugin-sdk/private" "github.com/spiffe/spire/pkg/common/log" - "github.com/zeebo/errs" "google.golang.org/grpc" ) @@ -154,7 +153,7 @@ func (p *hcClientPlugin) GRPCClient(ctx context.Context, b *goplugin.GRPCBroker, // does not work yet anyway, so it is a moot point. listener, err := b.Accept(private.HostServiceProviderID) if err != nil { - return nil, errs.Wrap(err) + return nil, err } server := newHostServer(p.config.Log, p.config.Name, p.config.HostServices) diff --git a/pkg/common/cryptoutil/keys.go b/pkg/common/cryptoutil/keys.go index db73567185..fa4a1e938a 100644 --- a/pkg/common/cryptoutil/keys.go +++ b/pkg/common/cryptoutil/keys.go @@ -7,7 +7,6 @@ import ( "fmt" "github.com/go-jose/go-jose/v4" - "github.com/zeebo/errs" ) func RSAPublicKeyEqual(a, b *rsa.PublicKey) bool { @@ -58,7 +57,7 @@ func JoseAlgFromPublicKey(publicKey any) (jose.SignatureAlgorithm, error) { case *rsa.PublicKey: // Prevent the use of keys smaller than 2048 bits if publicKey.Size() < 256 { - return "", errs.New("unsupported RSA key size: %d", publicKey.Size()) + return "", fmt.Errorf("unsupported RSA key size: %d", publicKey.Size()) } alg = jose.RS256 case *ecdsa.PublicKey: @@ -69,10 +68,10 @@ func JoseAlgFromPublicKey(publicKey any) (jose.SignatureAlgorithm, error) { case 384: alg = jose.ES384 default: - return "", errs.New("unable to determine signature algorithm for EC public key size %d", params.BitSize) + return "", fmt.Errorf("unable to determine signature algorithm for EC public key size %d", params.BitSize) } default: - return "", errs.New("unable to determine signature algorithm for public key type %T", publicKey) + return "", fmt.Errorf("unable to determine signature algorithm for public key type %T", publicKey) } return alg, nil } diff --git a/pkg/common/jwtsvid/common.go b/pkg/common/jwtsvid/common.go index 6d529bedbf..b1e84e30a3 100644 --- a/pkg/common/jwtsvid/common.go +++ b/pkg/common/jwtsvid/common.go @@ -5,18 +5,17 @@ import ( "time" "github.com/go-jose/go-jose/v4/jwt" - "github.com/zeebo/errs" ) func GetTokenExpiry(token string) (time.Time, time.Time, error) { tok, err := jwt.ParseSigned(token, AllowedSignatureAlgorithms) if err != nil { - return time.Time{}, time.Time{}, errs.Wrap(err) + return time.Time{}, time.Time{}, err } claims := jwt.Claims{} if err := tok.UnsafeClaimsWithoutVerification(&claims); err != nil { - return time.Time{}, time.Time{}, errs.Wrap(err) + return time.Time{}, time.Time{}, err } if claims.IssuedAt == nil { return time.Time{}, time.Time{}, errors.New("JWT missing iat claim") diff --git a/pkg/common/jwtsvid/validate.go b/pkg/common/jwtsvid/validate.go index dce51831d5..33e46fa349 100644 --- a/pkg/common/jwtsvid/validate.go +++ b/pkg/common/jwtsvid/validate.go @@ -9,7 +9,6 @@ import ( "github.com/go-jose/go-jose/v4/jwt" "github.com/spiffe/go-spiffe/v2/spiffeid" - "github.com/zeebo/errs" ) type KeyStore interface { @@ -41,17 +40,17 @@ func (t *keyStore) FindPublicKey(_ context.Context, td spiffeid.TrustDomain, key func ValidateToken(ctx context.Context, token string, keyStore KeyStore, audience []string) (spiffeid.ID, map[string]any, error) { tok, err := jwt.ParseSigned(token, AllowedSignatureAlgorithms) if err != nil { - return spiffeid.ID{}, nil, errs.New("unable to parse JWT token: %v", err) + return spiffeid.ID{}, nil, fmt.Errorf("unable to parse JWT token: %w", err) } if len(tok.Headers) != 1 { - return spiffeid.ID{}, nil, errs.New("expected a single token header; got %d", len(tok.Headers)) + return spiffeid.ID{}, nil, fmt.Errorf("expected a single token header; got %d", len(tok.Headers)) } // Obtain the key ID from the header keyID := tok.Headers[0].KeyID if keyID == "" { - return spiffeid.ID{}, nil, errs.New("token header missing key id") + return spiffeid.ID{}, nil, fmt.Errorf("token header missing key id") } // Parse out the unverified claims. We need to look up the key by the trust @@ -59,14 +58,14 @@ func ValidateToken(ctx context.Context, token string, keyStore KeyStore, audienc // when creating the generic map of claims that we return to the caller. var claims jwt.Claims if err := tok.UnsafeClaimsWithoutVerification(&claims); err != nil { - return spiffeid.ID{}, nil, errs.Wrap(err) + return spiffeid.ID{}, nil, err } if claims.Subject == "" { - return spiffeid.ID{}, nil, errs.New("token missing subject claim") + return spiffeid.ID{}, nil, errors.New("token missing subject claim") } spiffeID, err := spiffeid.FromString(claims.Subject) if err != nil { - return spiffeid.ID{}, nil, errs.New("token has in invalid subject claim: %v", err) + return spiffeid.ID{}, nil, fmt.Errorf("token has in invalid subject claim: %w", err) } // Construct the trust domain id from the SPIFFE ID and look up key by ID @@ -78,7 +77,7 @@ func ValidateToken(ctx context.Context, token string, keyStore KeyStore, audienc // Now obtain the generic claims map verified using the obtained key claimsMap := make(map[string]any) if err := tok.Claims(key, &claimsMap); err != nil { - return spiffeid.ID{}, nil, errs.Wrap(err) + return spiffeid.ID{}, nil, err } // Now that the signature over the claims has been verified, validate the @@ -90,11 +89,9 @@ func ValidateToken(ctx context.Context, token string, keyStore KeyStore, audienc // Convert expected validation errors for pretty errors switch { case errors.Is(err, jwt.ErrExpired): - err = errs.New("token has expired") + err = errors.New("token has expired") case errors.Is(err, jwt.ErrInvalidAudience): - err = errs.New("expected audience in %q (audience=%q)", audience, claims.Audience) - default: - err = errs.Wrap(err) + err = fmt.Errorf("expected audience in %q (audience=%q)", audience, claims.Audience) } return spiffeid.ID{}, nil, err } diff --git a/pkg/common/jwtutil/keyset.go b/pkg/common/jwtutil/keyset.go index a188fe7b29..a233dc2cf6 100644 --- a/pkg/common/jwtutil/keyset.go +++ b/pkg/common/jwtutil/keyset.go @@ -3,6 +3,8 @@ package jwtutil import ( "context" "encoding/json" + "errors" + "fmt" "io" "net/http" "net/url" @@ -12,7 +14,6 @@ import ( "github.com/go-jose/go-jose/v4" "github.com/sirupsen/logrus" - "github.com/zeebo/errs" ) const ( @@ -34,7 +35,7 @@ type OIDCIssuer string func (c OIDCIssuer) GetKeySet(ctx context.Context) (*jose.JSONWebKeySet, error) { u, err := url.Parse(string(c)) if err != nil { - return nil, errs.Wrap(err) + return nil, err } u.Path = path.Join(u.Path, wellKnownOpenIDConfiguration) @@ -86,7 +87,7 @@ func (c *CachingKeySetProvider) GetKeySet(ctx context.Context) (*jose.JSONWebKey } else { logrus.WithError(err).Warn("Unable to refresh key set") if c.jwks == nil { - return nil, errs.Wrap(err) + return nil, err } } @@ -96,27 +97,27 @@ func (c *CachingKeySetProvider) GetKeySet(ctx context.Context) (*jose.JSONWebKey func DiscoverKeySetURI(ctx context.Context, configURL string) (string, error) { req, err := http.NewRequest("GET", configURL, nil) if err != nil { - return "", errs.Wrap(err) + return "", err } req = req.WithContext(ctx) resp, err := http.DefaultClient.Do(req) if err != nil { - return "", errs.Wrap(err) + return "", err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return "", errs.New("unexpected status code %d: %s", resp.StatusCode, tryRead(resp.Body)) + return "", fmt.Errorf("unexpected status code %d: %s", resp.StatusCode, tryRead(resp.Body)) } config := &struct { JWKSURI string `json:"jwks_uri"` }{} if err := json.NewDecoder(resp.Body).Decode(config); err != nil { - return "", errs.New("failed to decode configuration: %v", err) + return "", fmt.Errorf("failed to decode configuration: %w", err) } if config.JWKSURI == "" { - return "", errs.New("configuration missing JWKS URI") + return "", errors.New("configuration missing JWKS URI") } return config.JWKSURI, nil @@ -125,22 +126,22 @@ func DiscoverKeySetURI(ctx context.Context, configURL string) (string, error) { func FetchKeySet(ctx context.Context, jwksURI string) (*jose.JSONWebKeySet, error) { req, err := http.NewRequest("GET", jwksURI, nil) if err != nil { - return nil, errs.Wrap(err) + return nil, err } req = req.WithContext(ctx) resp, err := http.DefaultClient.Do(req) if err != nil { - return nil, errs.Wrap(err) + return nil, err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return nil, errs.New("unexpected status code %d: %s", resp.StatusCode, tryRead(resp.Body)) + return nil, fmt.Errorf("unexpected status code %d: %s", resp.StatusCode, tryRead(resp.Body)) } jwks := new(jose.JSONWebKeySet) if err := json.NewDecoder(resp.Body).Decode(jwks); err != nil { - return nil, errs.New("failed to decode key set: %v", err) + return nil, fmt.Errorf("failed to decode key set: %w", err) } return jwks, nil diff --git a/pkg/common/plugin/aws/iid.go b/pkg/common/plugin/aws/iid.go index 8b8fcea741..6da18e5c82 100644 --- a/pkg/common/plugin/aws/iid.go +++ b/pkg/common/plugin/aws/iid.go @@ -1,19 +1,12 @@ package aws -import ( - "github.com/zeebo/errs" -) +import "fmt" const ( // PluginName for AWS IID PluginName = "aws_iid" ) -var ( - IidErrorClass = errs.Class("aws-iid") - iidError = IidErrorClass -) - // IIDAttestationData AWS IID attestation data type IIDAttestationData struct { Document string `json:"document"` @@ -23,5 +16,5 @@ type IIDAttestationData struct { // AttestationStepError error with attestation func AttestationStepError(step string, cause error) error { - return iidError.New("attempted attestation but an error occurred %s: %w", step, cause) + return fmt.Errorf("aws-iid: attempted attestation but an error occurred %s: %w", step, cause) } diff --git a/pkg/common/plugin/azure/msi.go b/pkg/common/plugin/azure/msi.go index 99356cbbc3..129c4dbdde 100644 --- a/pkg/common/plugin/azure/msi.go +++ b/pkg/common/plugin/azure/msi.go @@ -2,6 +2,8 @@ package azure import ( "encoding/json" + "errors" + "fmt" "io" "net/http" @@ -9,7 +11,6 @@ import ( "github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/spire/pkg/common/agentpathtemplate" "github.com/spiffe/spire/pkg/common/idutil" - "github.com/zeebo/errs" ) const ( @@ -56,7 +57,7 @@ func (fn HTTPClientFunc) Do(req *http.Request) (*http.Response, error) { func FetchMSIToken(cl HTTPClient, resource string) (string, error) { req, err := http.NewRequest("GET", "http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01", nil) if err != nil { - return "", errs.Wrap(err) + return "", err } req.Header.Add("Metadata", "true") @@ -66,11 +67,11 @@ func FetchMSIToken(cl HTTPClient, resource string) (string, error) { resp, err := cl.Do(req) if err != nil { - return "", errs.Wrap(err) + return "", err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return "", errs.New("unexpected status code %d: %s", resp.StatusCode, tryRead(resp.Body)) + return "", fmt.Errorf("unexpected status code %d: %s", resp.StatusCode, tryRead(resp.Body)) } r := struct { @@ -78,11 +79,11 @@ func FetchMSIToken(cl HTTPClient, resource string) (string, error) { }{} if err := json.NewDecoder(resp.Body).Decode(&r); err != nil { - return "", errs.New("unable to decode response: %v", err) + return "", fmt.Errorf("unable to decode response: %w", err) } if r.AccessToken == "" { - return "", errs.New("response missing access token") + return "", fmt.Errorf("response missing access token") } return r.AccessToken, nil @@ -91,31 +92,31 @@ func FetchMSIToken(cl HTTPClient, resource string) (string, error) { func FetchInstanceMetadata(cl HTTPClient) (*InstanceMetadata, error) { req, err := http.NewRequest("GET", "http://169.254.169.254/metadata/instance?api-version=2017-08-01&format=json", nil) if err != nil { - return nil, errs.Wrap(err) + return nil, err } req.Header.Add("Metadata", "true") resp, err := cl.Do(req) if err != nil { - return nil, errs.Wrap(err) + return nil, err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return nil, errs.New("unexpected status code %d: %s", resp.StatusCode, tryRead(resp.Body)) + return nil, fmt.Errorf("unexpected status code %d: %s", resp.StatusCode, tryRead(resp.Body)) } metadata := new(InstanceMetadata) if err := json.NewDecoder(resp.Body).Decode(metadata); err != nil { - return nil, errs.New("unable to decode response: %v", err) + return nil, fmt.Errorf("unable to decode response: %w", err) } switch { case metadata.Compute.Name == "": - return nil, errs.New("response missing instance name") + return nil, errors.New("response missing instance name") case metadata.Compute.SubscriptionID == "": - return nil, errs.New("response missing instance subscription id") + return nil, errors.New("response missing instance subscription id") case metadata.Compute.ResourceGroupName == "": - return nil, errs.New("response missing instance resource group name") + return nil, errors.New("response missing instance resource group name") } return metadata, nil diff --git a/pkg/common/profiling/dumpers.go b/pkg/common/profiling/dumpers.go index e7fa6e5442..5e47b414fa 100644 --- a/pkg/common/profiling/dumpers.go +++ b/pkg/common/profiling/dumpers.go @@ -6,8 +6,6 @@ import ( "runtime/pprof" "runtime/trace" "strings" - - "github.com/zeebo/errs" ) const ( @@ -99,7 +97,7 @@ func (d *traceDumper) Dump(timestamp string, name string) error { d.data.Close() filename := getFilename(timestamp, d.c.Tag, name) if err := os.Rename(getTempFilename(d.c.Tag, traceProfTmpFilename), filename); err != nil { - return errs.Wrap(err) + return err } return d.Prepare() } @@ -133,7 +131,7 @@ func (d *cpuDumper) Dump(timestamp string, name string) error { d.data.Close() filename := getFilename(timestamp, d.c.Tag, name) if err := os.Rename(getTempFilename(d.c.Tag, cpuProfTmpFilename), filename); err != nil { - return errs.Wrap(err) + return err } return d.Prepare() } diff --git a/pkg/common/util/csr.go b/pkg/common/util/csr.go index 089ae61393..bdd98f7d92 100644 --- a/pkg/common/util/csr.go +++ b/pkg/common/util/csr.go @@ -7,7 +7,6 @@ import ( "net/url" "github.com/spiffe/go-spiffe/v2/spiffeid" - "github.com/zeebo/errs" ) func MakeCSR(privateKey any, spiffeID spiffeid.ID) ([]byte, error) { @@ -33,7 +32,7 @@ func MakeCSRWithoutURISAN(privateKey any) ([]byte, error) { func makeCSR(privateKey any, template *x509.CertificateRequest) ([]byte, error) { csr, err := x509.CreateCertificateRequest(rand.Reader, template, privateKey) if err != nil { - return nil, errs.Wrap(err) + return nil, err } return csr, nil } diff --git a/pkg/server/bundle/client/client.go b/pkg/server/bundle/client/client.go index 2462a0917b..009dc3721a 100644 --- a/pkg/server/bundle/client/client.go +++ b/pkg/server/bundle/client/client.go @@ -14,7 +14,6 @@ import ( "github.com/spiffe/go-spiffe/v2/spiffetls/tlsconfig" "github.com/spiffe/spire/pkg/common/bundleutil" "github.com/spiffe/spire/pkg/common/tlspolicy" - "github.com/zeebo/errs" ) type SPIFFEAuthConfig struct { @@ -92,15 +91,15 @@ func (c *client) FetchBundle(context.Context) (*spiffebundle.Bundle, error) { var hostnameError x509.HostnameError if errors.As(err, &hostnameError) && c.c.SPIFFEAuth == nil && len(hostnameError.Certificate.URIs) > 0 { if id, idErr := spiffeid.FromString(hostnameError.Certificate.URIs[0].String()); idErr == nil { - return nil, errs.New("failed to authenticate bundle endpoint using web authentication but the server certificate contains SPIFFE ID %q: maybe use https_spiffe instead of https_web: %v", id, err) + return nil, fmt.Errorf("failed to authenticate bundle endpoint using web authentication but the server certificate contains SPIFFE ID %q: maybe use https_spiffe instead of https_web: %w", id, err) } } - return nil, errs.New("failed to fetch bundle: %v", err) + return nil, fmt.Errorf("failed to fetch bundle: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return nil, errs.New("unexpected status %d fetching bundle: %s", resp.StatusCode, tryRead(resp.Body)) + return nil, fmt.Errorf("unexpected status %d fetching bundle: %s", resp.StatusCode, tryRead(resp.Body)) } b, err := bundleutil.Decode(c.c.TrustDomain, resp.Body) diff --git a/pkg/server/bundle/client/manager_test.go b/pkg/server/bundle/client/manager_test.go index b2a4855bfc..e883e1a520 100644 --- a/pkg/server/bundle/client/manager_test.go +++ b/pkg/server/bundle/client/manager_test.go @@ -4,6 +4,7 @@ import ( "context" "crypto/x509" "errors" + "fmt" "sync" "testing" "time" @@ -17,7 +18,6 @@ import ( "github.com/spiffe/spire/test/fakes/fakedatastore" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/zeebo/errs" ) func TestManagerPeriodicBundleRefresh(t *testing.T) { @@ -278,7 +278,7 @@ func newManagerTest(t *testing.T, source TrustDomainConfigSource, localBundles, go func() { defer func() { if r := recover(); r != nil { - errCh <- errs.New("%+v", r) + errCh <- fmt.Errorf("%+v", r) } }() errCh <- test.manager.Run(ctx) diff --git a/pkg/server/bundle/client/updater.go b/pkg/server/bundle/client/updater.go index 3e906d4d62..b268570f0b 100644 --- a/pkg/server/bundle/client/updater.go +++ b/pkg/server/bundle/client/updater.go @@ -10,7 +10,6 @@ import ( "github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/spire/pkg/common/bundleutil" "github.com/spiffe/spire/pkg/server/datastore" - "github.com/zeebo/errs" ) type BundleUpdaterConfig struct { @@ -141,7 +140,7 @@ func fetchBundleIfExists(ctx context.Context, ds datastore.DataStore, trustDomai // Load the current bundle and extract the root CA certificates bundle, err := ds.FetchBundle(ctx, trustDomain.IDString()) if err != nil { - return nil, errs.Wrap(err) + return nil, err } if bundle == nil { return nil, nil diff --git a/pkg/server/ca/manager/journal.go b/pkg/server/ca/manager/journal.go index cc280e90cc..be95fad938 100644 --- a/pkg/server/ca/manager/journal.go +++ b/pkg/server/ca/manager/journal.go @@ -14,7 +14,6 @@ import ( "github.com/spiffe/spire/pkg/server/catalog" "github.com/spiffe/spire/pkg/server/datastore" "github.com/spiffe/spire/proto/private/server/journal" - "github.com/zeebo/errs" "google.golang.org/protobuf/proto" ) @@ -125,7 +124,7 @@ func (j *Journal) AppendJWTKey(ctx context.Context, slotID string, issuedAt time pkixBytes, err := x509.MarshalPKIXPublicKey(jwtKey.Signer.Public()) if err != nil { - return errs.Wrap(err) + return err } backup := j.entries.JwtKeys @@ -273,7 +272,7 @@ func (j *Journal) findCAJournal(ctx context.Context) (*datastore.CAJournal, erro func (j *Journal) save(ctx context.Context) error { entriesBytes, err := proto.Marshal(j.entries) if err != nil { - return errs.Wrap(err) + return err } caJournalID, err := j.saveInDatastore(ctx, entriesBytes) @@ -315,7 +314,7 @@ func loadJournalFromDS(ctx context.Context, config *journalConfig) (*Journal, er j.caJournalID = caJournal.ID if err := proto.Unmarshal(caJournal.Data, j.entries); err != nil { - return nil, errs.New("unable to unmarshal entries from CA journal record: %v", err) + return nil, fmt.Errorf("unable to unmarshal entries from CA journal record: %w", err) } return j, nil } diff --git a/pkg/server/ca/manager/manager.go b/pkg/server/ca/manager/manager.go index 0240697fee..bb2975c1a9 100644 --- a/pkg/server/ca/manager/manager.go +++ b/pkg/server/ca/manager/manager.go @@ -28,7 +28,6 @@ import ( "github.com/spiffe/spire/pkg/server/plugin/notifier" "github.com/spiffe/spire/proto/private/server/journal" "github.com/spiffe/spire/proto/spire/common" - "github.com/zeebo/errs" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -456,7 +455,6 @@ func (m *Manager) PruneBundle(ctx context.Context) (err error) { expiresBefore := m.c.Clock.Now().Add(-safetyThresholdBundle) changed, err := ds.PruneBundle(ctx, m.c.TrustDomain.IDString(), expiresBefore) - if err != nil { return fmt.Errorf("unable to prune bundle: %w", err) } @@ -478,7 +476,6 @@ func (m *Manager) PruneCAJournals(ctx context.Context) (err error) { expiresBefore := m.c.Clock.Now().Add(-safetyThresholdCAJournals) err = ds.PruneCAJournals(ctx, expiresBefore.Unix()) - if err != nil { return fmt.Errorf("unable to prune CA journals: %w", err) } @@ -735,17 +732,18 @@ func (m *Manager) notify(ctx context.Context, event string, advise bool, pre fun }(n) } - var allErrs errs.Group + var allErrs error for range notifiers { // don't select on the ctx here as we can rely on the plugins to // respond to context cancellation and return an error. if err := <-errsCh; err != nil { - allErrs.Add(err) + allErrs = errors.Join(allErrs, err) } } - if err := allErrs.Err(); err != nil { - return errs.New("one or more notifiers returned an error: %v", err) + if allErrs != nil { + return fmt.Errorf("one or more notifiers returned an error: %w", allErrs) } + return nil } @@ -755,7 +753,7 @@ func (m *Manager) fetchRequiredBundle(ctx context.Context) (*common.Bundle, erro return nil, err } if bundle == nil { - return nil, errs.New("trust domain bundle is missing") + return nil, errors.New("trust domain bundle is missing") } return bundle, nil } @@ -764,7 +762,7 @@ func (m *Manager) fetchOptionalBundle(ctx context.Context) (*common.Bundle, erro ds := m.c.Catalog.GetDataStore() bundle, err := ds.FetchBundle(ctx, m.c.TrustDomain.IDString()) if err != nil { - return nil, errs.Wrap(err) + return nil, err } return bundle, nil } @@ -1052,7 +1050,7 @@ func keyIDFromBytes(choices []byte) string { func publicKeyFromJWTKey(jwtKey *ca.JWTKey) (*common.PublicKey, error) { pkixBytes, err := x509.MarshalPKIXPublicKey(jwtKey.Signer.Public()) if err != nil { - return nil, errs.Wrap(err) + return nil, err } return &common.PublicKey{ diff --git a/pkg/server/ca/manager/slot.go b/pkg/server/ca/manager/slot.go index cbfff8c768..fa0be6a33c 100644 --- a/pkg/server/ca/manager/slot.go +++ b/pkg/server/ca/manager/slot.go @@ -19,7 +19,6 @@ import ( "github.com/spiffe/spire/pkg/server/catalog" "github.com/spiffe/spire/proto/private/server/journal" "github.com/spiffe/spire/proto/spire/common" - "github.com/zeebo/errs" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -274,7 +273,6 @@ func (s *SlotLoader) getJWTKeysSlots(ctx context.Context, entries []*journal.JWT // Instead, we'll rotate into a new one. func (s *SlotLoader) filterInvalidEntries(ctx context.Context, entries *journal.Entries) ([]*journal.JWTKeyEntry, []*journal.X509CAEntry, error) { bundle, err := s.fetchOptionalBundle(ctx) - if err != nil { return nil, nil, err } @@ -314,7 +312,7 @@ func (s *SlotLoader) fetchOptionalBundle(ctx context.Context) (*common.Bundle, e ds := s.Catalog.GetDataStore() bundle, err := ds.FetchBundle(ctx, s.TrustDomain.IDString()) if err != nil { - return nil, errs.Wrap(err) + return nil, err } return bundle, nil } @@ -351,14 +349,14 @@ func (s *SlotLoader) loadX509CASlotFromEntry(ctx context.Context, entry *journal cert, err := x509.ParseCertificate(entry.Certificate) if err != nil { - return nil, "", errs.New("unable to parse CA certificate: %v", err) + return nil, "", fmt.Errorf("unable to parse CA certificate: %w", err) } var upstreamChain []*x509.Certificate for _, certDER := range entry.UpstreamChain { cert, err := x509.ParseCertificate(certDER) if err != nil { - return nil, "", errs.New("unable to parse upstream chain certificate: %v", err) + return nil, "", fmt.Errorf("unable to parse upstream chain certificate: %w", err) } upstreamChain = append(upstreamChain, cert) } @@ -421,7 +419,7 @@ func (s *SlotLoader) loadJWTKeySlotFromEntry(ctx context.Context, entry *journal publicKey, err := x509.ParsePKIXPublicKey(entry.PublicKey) if err != nil { - return nil, "", errs.Wrap(err) + return nil, "", err } signer, err := s.makeSigner(ctx, jwtKeyKmKeyID(entry.SlotId)) @@ -460,7 +458,7 @@ func (s *SlotLoader) makeSigner(ctx context.Context, keyID string) (crypto.Signe case codes.NotFound: return nil, nil default: - return nil, errs.Wrap(err) + return nil, err } } diff --git a/pkg/server/ca/rotator/rotator.go b/pkg/server/ca/rotator/rotator.go index 923a020ca7..000f10494e 100644 --- a/pkg/server/ca/rotator/rotator.go +++ b/pkg/server/ca/rotator/rotator.go @@ -11,7 +11,6 @@ import ( "github.com/spiffe/spire/pkg/common/health" "github.com/spiffe/spire/pkg/common/util" "github.com/spiffe/spire/pkg/server/ca/manager" - "github.com/zeebo/errs" ) const ( @@ -138,7 +137,7 @@ func (r *Rotator) rotate(ctx context.Context) error { r.c.Log.WithError(jwtKeyErr).Error("Unable to rotate JWT key") } - return errs.Combine(x509CAErr, jwtKeyErr) + return errors.Join(x509CAErr, jwtKeyErr) } func (r *Rotator) rotateJWTKey(ctx context.Context) error { diff --git a/pkg/server/datastore/sqlstore/errors.go b/pkg/server/datastore/sqlstore/errors.go new file mode 100644 index 0000000000..1aaf152470 --- /dev/null +++ b/pkg/server/datastore/sqlstore/errors.go @@ -0,0 +1,92 @@ +package sqlstore + +import ( + "fmt" +) + +const ( + datastoreSQLErrorPrefix = "datastore-sql" + datastoreValidationErrorPrefix = "datastore-validation" +) + +type sqlError struct { + err error + msg string +} + +func (s *sqlError) Error() string { + if s == nil { + return "" + } + + if s.err != nil { + return fmt.Sprintf("%s: %s", datastoreSQLErrorPrefix, s.err) + } + + return fmt.Sprintf("%s: %s", datastoreSQLErrorPrefix, s.msg) +} + +func (s *sqlError) Unwrap() error { + if s == nil { + return nil + } + + return s.err +} + +type validationError struct { + err error + msg string +} + +func (v *validationError) Error() string { + if v == nil { + return "" + } + + if v.err != nil { + return fmt.Sprintf("%s: %s", datastoreValidationErrorPrefix, v.err) + } + + return fmt.Sprintf("%s: %s", datastoreValidationErrorPrefix, v.msg) +} + +func (v *validationError) Unwrap() error { + if v == nil { + return nil + } + + return v.err +} + +func newSQLError(fmtMsg string, args ...any) error { + return &sqlError{ + msg: fmt.Sprintf(fmtMsg, args...), + } +} + +func newWrappedSQLError(err error) error { + if err == nil { + return nil + } + + return &sqlError{ + err: err, + } +} + +func newValidationError(fmtMsg string, args ...any) error { + return &validationError{ + msg: fmt.Sprintf(fmtMsg, args...), + } +} + +func newWrappedValidationError(err error) error { + if err == nil { + return nil + } + + return &validationError{ + err: err, + } +} diff --git a/pkg/server/datastore/sqlstore/errors_test.go b/pkg/server/datastore/sqlstore/errors_test.go new file mode 100644 index 0000000000..5d2079aa81 --- /dev/null +++ b/pkg/server/datastore/sqlstore/errors_test.go @@ -0,0 +1,58 @@ +package sqlstore + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSQLError(t *testing.T) { + err := newSQLError("an error with two dynamic fields: %s, %d", "hello", 1) + assert.EqualError(t, err, "datastore-sql: an error with two dynamic fields: hello, 1") + + var sErr *sqlError + assert.ErrorAs(t, err, &sErr) +} + +func TestWrappedSQLError(t *testing.T) { + t.Run("nil error", func(t *testing.T) { + err := newWrappedSQLError(nil) + assert.NoError(t, err) + }) + + t.Run("non-nil error", func(t *testing.T) { + wrappedErr := errors.New("foo") + err := newWrappedSQLError(wrappedErr) + + assert.EqualError(t, err, "datastore-sql: foo") + + var sErr *sqlError + assert.ErrorAs(t, err, &sErr) + }) +} + +func TestValidationError(t *testing.T) { + err := newValidationError("an error with two dynamic fields: %s, %d", "hello", 1) + assert.EqualError(t, err, "datastore-validation: an error with two dynamic fields: hello, 1") + + var vErr *validationError + assert.ErrorAs(t, err, &vErr) +} + +func TestWrappedValidationError(t *testing.T) { + t.Run("nil error", func(t *testing.T) { + err := newWrappedValidationError(nil) + assert.NoError(t, err) + }) + + t.Run("non-nil error", func(t *testing.T) { + wrappedErr := errors.New("bar") + err := newWrappedValidationError(wrappedErr) + + assert.EqualError(t, err, "datastore-validation: bar") + + var vErr *validationError + assert.ErrorAs(t, err, &vErr) + }) +} diff --git a/pkg/server/datastore/sqlstore/migration.go b/pkg/server/datastore/sqlstore/migration.go index c9febb270a..0d8eece2c7 100644 --- a/pkg/server/datastore/sqlstore/migration.go +++ b/pkg/server/datastore/sqlstore/migration.go @@ -271,12 +271,12 @@ func migrateDB(db *gorm.DB, dbType string, disableMigration bool, log logrus.Fie // version before continuing, and fail if we're not. if codeVersion.Major > 1 { log.Error("Migration code needs updating for current release version") - return sqlError.New("current migration code not compatible with current release version") + return newSQLError("current migration code not compatible with current release version") } isNew := !db.HasTable(&Migration{}) if err := db.Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } if isNew { @@ -285,12 +285,12 @@ func migrateDB(db *gorm.DB, dbType string, disableMigration bool, log logrus.Fie // ensure migrations table exists so we can check versioning in all cases if err := db.AutoMigrate(&Migration{}).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } migration := new(Migration) if err := db.Assign(Migration{}).FirstOrCreate(migration).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } schemaVersion := migration.Version @@ -300,7 +300,7 @@ func migrateDB(db *gorm.DB, dbType string, disableMigration bool, log logrus.Fie dbCodeVersion, err := getDBCodeVersion(*migration) if err != nil { log.WithError(err).Error("Error getting DB code version") - return sqlError.New("error getting DB code version: %v", err) + return newSQLError("error getting DB code version: %v", err) } log = log.WithField(telemetry.VersionInfo, dbCodeVersion.String()) @@ -316,7 +316,7 @@ func migrateDB(db *gorm.DB, dbType string, disableMigration bool, log logrus.Fie } if err := db.Model(&Migration{}).Updates(newMigration).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } } return nil @@ -325,7 +325,7 @@ func migrateDB(db *gorm.DB, dbType string, disableMigration bool, log logrus.Fie if disableMigration { if err = isDisabledMigrationAllowed(codeVersion, dbCodeVersion); err != nil { log.WithError(err).Error("Auto-migrate must be enabled") - return sqlError.Wrap(err) + return newWrappedSQLError(err) } return nil } @@ -336,7 +336,7 @@ func migrateDB(db *gorm.DB, dbType string, disableMigration bool, log logrus.Fie if schemaVersion > latestSchemaVersion { if !isCompatibleCodeVersion(codeVersion, dbCodeVersion) { log.Error("Incompatible DB schema is too new for code version, upgrade SPIRE Server") - return sqlError.New("incompatible DB schema and code version") + return newSQLError("incompatible DB schema and code version") } log.Warn("DB schema is ahead of code version, upgrading SPIRE Server is recommended") return nil @@ -350,7 +350,7 @@ func migrateDB(db *gorm.DB, dbType string, disableMigration bool, log logrus.Fie for schemaVersion < latestSchemaVersion { tx := db.Begin() if err := tx.Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } schemaVersion, err = migrateVersion(tx, schemaVersion, log) if err != nil { @@ -358,7 +358,7 @@ func migrateDB(db *gorm.DB, dbType string, disableMigration bool, log logrus.Fie return err } if err := tx.Commit().Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } } @@ -401,7 +401,7 @@ func initDB(db *gorm.DB, dbType string, log logrus.FieldLogger) (err error) { log.Info("Initializing new database") tx := db.Begin() if err := tx.Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } tables := []any{ @@ -421,7 +421,7 @@ func initDB(db *gorm.DB, dbType string, log logrus.FieldLogger) (err error) { if err := tableOptionsForDialect(tx, dbType).AutoMigrate(tables...).Error; err != nil { tx.Rollback() - return sqlError.Wrap(err) + return newWrappedSQLError(err) } if err := tx.Assign(Migration{ @@ -429,7 +429,7 @@ func initDB(db *gorm.DB, dbType string, log logrus.FieldLogger) (err error) { CodeVersion: codeVersion.String(), }).FirstOrCreate(&Migration{}).Error; err != nil { tx.Rollback() - return sqlError.Wrap(err) + return newWrappedSQLError(err) } if err := addFederatedRegistrationEntriesRegisteredEntryIDIndex(tx); err != nil { @@ -437,7 +437,7 @@ func initDB(db *gorm.DB, dbType string, log logrus.FieldLogger) (err error) { } if err := tx.Commit().Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } return nil @@ -461,11 +461,11 @@ func migrateVersion(tx *gorm.DB, currVersion int, log logrus.FieldLogger) (versi Version: nextVersion, CodeVersion: version.Version(), }).Error; err != nil { - return 0, sqlError.Wrap(err) + return 0, newWrappedSQLError(err) } if currVersion < lastMinorReleaseSchemaVersion { - return 0, sqlError.New("migrating from schema version %d requires a previous SPIRE release; please follow the upgrade strategy at doc/upgrading.md", currVersion) + return 0, newSQLError("migrating from schema version %d requires a previous SPIRE release; please follow the upgrade strategy at doc/upgrading.md", currVersion) } // Place all migrations handled by the current minor release here. This @@ -489,7 +489,7 @@ func migrateVersion(tx *gorm.DB, currVersion int, log logrus.FieldLogger) (versi // switch currVersion { //nolint: gocritic // No upgrade required yet, keeping switch for future additions default: - err = sqlError.New("no migration support for unknown schema version %d", currVersion) + err = newSQLError("no migration support for unknown schema version %d", currVersion) } if err != nil { return 0, err @@ -506,7 +506,7 @@ func addFederatedRegistrationEntriesRegisteredEntryIDIndex(tx *gorm.DB) error { // to introduce the index since there is no explicit struct to add tags to // so we have to manually create it. if err := tx.Table("federated_registration_entries").AddIndex("idx_federated_registration_entries_registered_entry_id", "registered_entry_id").Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } return nil } diff --git a/pkg/server/datastore/sqlstore/mysql.go b/pkg/server/datastore/sqlstore/mysql.go index 8e626330f1..a7ee2faeff 100644 --- a/pkg/server/datastore/sqlstore/mysql.go +++ b/pkg/server/datastore/sqlstore/mysql.go @@ -169,11 +169,11 @@ func hasTLSConfig(cfg *configuration) bool { func validateMySQLConfig(cfg *configuration, isReadOnly bool) error { opts, err := mysql.ParseDSN(getConnectionString(cfg, isReadOnly)) if err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } if !opts.ParseTime { - return sqlError.Wrap(errors.New("invalid mysql config: missing parseTime=true param in connection_string")) + return newSQLError("invalid mysql config: missing parseTime=true param in connection_string") } return nil diff --git a/pkg/server/datastore/sqlstore/sqlite.go b/pkg/server/datastore/sqlstore/sqlite.go index a3e4ff56e2..c911f2920e 100644 --- a/pkg/server/datastore/sqlstore/sqlite.go +++ b/pkg/server/datastore/sqlstore/sqlite.go @@ -55,7 +55,7 @@ func openSQLite3(connString string) (*gorm.DB, error) { } db, err := gorm.Open("sqlite3", embellished) if err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } return db, nil } @@ -74,7 +74,7 @@ func embellishSQLite3ConnString(connectionString string) (string, error) { u, err := url.Parse(connectionString) if err != nil { - return "", sqlError.Wrap(err) + return "", newWrappedSQLError(err) } switch { @@ -88,7 +88,7 @@ func embellishSQLite3ConnString(connectionString string) (string, error) { u.Opaque, u.Path = u.Path, "" case u.Scheme != "file": // only no scheme (i.e. file path) or file scheme is supported - return "", sqlError.New("unsupported scheme %q", u.Scheme) + return "", newSQLError("unsupported scheme %q", u.Scheme) } q := u.Query() diff --git a/pkg/server/datastore/sqlstore/sqlstore.go b/pkg/server/datastore/sqlstore/sqlstore.go index 1f1034ad16..45ed51340f 100644 --- a/pkg/server/datastore/sqlstore/sqlstore.go +++ b/pkg/server/datastore/sqlstore/sqlstore.go @@ -30,26 +30,21 @@ import ( "github.com/spiffe/spire/pkg/server/datastore" "github.com/spiffe/spire/proto/private/server/journal" "github.com/spiffe/spire/proto/spire/common" - "github.com/zeebo/errs" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" ) -var ( - sqlError = errs.Class("datastore-sql") - validationError = errs.Class("datastore-validation") - validEntryIDChars = &unicode.RangeTable{ - R16: []unicode.Range16{ - {0x002d, 0x002e, 1}, // - | . - {0x0030, 0x0039, 1}, // [0-9] - {0x0041, 0x005a, 1}, // [A-Z] - {0x005f, 0x005f, 1}, // _ - {0x0061, 0x007a, 1}, // [a-z] - }, - LatinOffset: 5, - } -) +var validEntryIDChars = &unicode.RangeTable{ + R16: []unicode.Range16{ + {0x002d, 0x002e, 1}, // - | . + {0x0030, 0x0039, 1}, // [0-9] + {0x0041, 0x005a, 1}, // [A-Z] + {0x005f, 0x005f, 1}, // _ + {0x0061, 0x007a, 1}, // [a-z] + }, + LatinOffset: 5, +} const ( PluginName = "sql" @@ -104,7 +99,7 @@ type awsConfig struct { func (a *awsConfig) validate() error { if a.Region == "" { - return sqlError.New("region must be specified") + return newSQLError("region must be specified") } return nil } @@ -288,7 +283,7 @@ func (ds *Plugin) RevokeJWTKey(ctx context.Context, trustDoaminID string, author // CreateAttestedNode stores the given attested node func (ds *Plugin) CreateAttestedNode(ctx context.Context, node *common.AttestedNode) (attestedNode *common.AttestedNode, err error) { if node == nil { - return nil, sqlError.New("invalid request: missing attested node") + return nil, newSQLError("invalid request: missing attested node") } if err = ds.withWriteTx(ctx, func(tx *gorm.DB) (err error) { @@ -801,7 +796,7 @@ func (ds *Plugin) PruneCAJournals(ctx context.Context, allAuthoritiesExpireBefor func (ds *Plugin) pruneCAJournals(tx *gorm.DB, allAuthoritiesExpireBefore int64) error { var caJournals []CAJournal if err := tx.Find(&caJournals).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } checkAuthorities: @@ -884,7 +879,7 @@ func (ds *Plugin) openConnection(config *configuration, isReadOnly bool) error { raw := db.DB() if raw == nil { - return sqlError.New("unable to get raw database object") + return newSQLError("unable to get raw database object") } if sqlDb != nil { @@ -919,15 +914,15 @@ func (ds *Plugin) openConnection(config *configuration, isReadOnly bool) error { } func (ds *Plugin) Close() error { - var errs errs.Group + var errs error if ds.db != nil { - errs.Add(ds.db.Close()) + errs = errors.Join(errs, ds.db.Close()) } if ds.roDb != nil { - errs.Add(ds.roDb.Close()) + errs = errors.Join(errs, ds.roDb.Close()) } - return errs.Err() + return errs } // withReadModifyWriteTx wraps the operation in a transaction appropriate for @@ -987,7 +982,7 @@ func (ds *Plugin) withTx(ctx context.Context, op func(tx *gorm.DB) error, readOn tx := db.BeginTx(ctx, nil) if err := tx.Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } if err := op(tx); err != nil { @@ -999,9 +994,9 @@ func (ds *Plugin) withTx(ctx context.Context, op func(tx *gorm.DB) error, readOn // rolling back makes sure that functions that are invoked with // withReadTx, and then do writes, will not pass unit tests, since the // writes won't be committed. - return sqlError.Wrap(tx.Rollback().Error) + return newWrappedSQLError(tx.Rollback().Error) } - return sqlError.Wrap(tx.Commit().Error) + return newWrappedSQLError(tx.Commit().Error) } // gormToGRPCStatus takes an error, and converts it to a GRPC error. If the @@ -1020,7 +1015,8 @@ func (ds *Plugin) gormToGRPCStatus(err error) error { } code := codes.Unknown - if validationError.Has(err) { + var vErr *validationError + if errors.As(err, &vErr) { code = codes.InvalidArgument } @@ -1050,12 +1046,12 @@ func (ds *Plugin) openDB(cfg *configuration, isReadOnly bool) (*gorm.DB, string, logger: ds.log, } default: - return nil, "", false, nil, sqlError.New("unsupported database_type: %v", cfg.databaseTypeConfig.databaseType) + return nil, "", false, nil, newSQLError("unsupported database_type: %v", cfg.databaseTypeConfig.databaseType) } db, version, supportsCTE, err := dialect.connect(cfg, isReadOnly) if err != nil { - return nil, "", false, nil, sqlError.Wrap(err) + return nil, "", false, nil, newWrappedSQLError(err) } db.SetLogger(gormLogger{ @@ -1107,7 +1103,7 @@ func createBundle(tx *gorm.DB, bundle *common.Bundle) (*common.Bundle, error) { } if err := tx.Create(model).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } return bundle, nil @@ -1121,16 +1117,16 @@ func updateBundle(tx *gorm.DB, newBundle *common.Bundle, mask *common.BundleMask model := &Bundle{} if err := tx.Find(model, "trust_domain = ?", newModel.TrustDomain).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } model.Data, newBundle, err = applyBundleMask(model, newBundle, mask) if err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } if err := tx.Save(model).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } return newBundle, nil @@ -1186,7 +1182,7 @@ func setBundle(tx *gorm.DB, b *common.Bundle) (*common.Bundle, error) { } return bundle, nil } else if result.Error != nil { - return nil, sqlError.Wrap(result.Error) + return nil, newWrappedSQLError(result.Error) } bundle, err := updateBundle(tx, b, nil) @@ -1212,7 +1208,7 @@ func appendBundle(tx *gorm.DB, b *common.Bundle) (*common.Bundle, error) { } return bundle, nil } else if result.Error != nil { - return nil, sqlError.Wrap(result.Error) + return nil, newWrappedSQLError(result.Error) } // parse the bundle data and add missing elements @@ -1230,7 +1226,7 @@ func appendBundle(tx *gorm.DB, b *common.Bundle) (*common.Bundle, error) { } model.Data = newModel.Data if err := tx.Save(model).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } } @@ -1240,14 +1236,14 @@ func appendBundle(tx *gorm.DB, b *common.Bundle) (*common.Bundle, error) { func deleteBundle(tx *gorm.DB, trustDomainID string, mode datastore.DeleteMode) error { model := new(Bundle) if err := tx.Find(model, "trust_domain = ?", trustDomainID).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } // Get a count of associated registration entries entriesAssociation := tx.Model(model).Association("FederatedEntries") entriesCount := entriesAssociation.Count() if err := entriesAssociation.Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } if entriesCount > 0 { @@ -1261,11 +1257,11 @@ func deleteBundle(tx *gorm.DB, trustDomainID string, mode datastore.DeleteMode) federated_registration_entries WHERE bundle_id = ?)`), model.ID).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } case datastore.Dissociate: if err := entriesAssociation.Clear().Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } default: return status.Newf(codes.FailedPrecondition, "datastore-sql: cannot delete bundle; federated with %d registration entries", entriesCount).Err() @@ -1273,7 +1269,7 @@ func deleteBundle(tx *gorm.DB, trustDomainID string, mode datastore.DeleteMode) } if err := tx.Delete(model).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } return nil @@ -1287,7 +1283,7 @@ func fetchBundle(tx *gorm.DB, trustDomainID string) (*common.Bundle, error) { case errors.Is(err, gorm.ErrRecordNotFound): return nil, nil case err != nil: - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } bundle, err := modelToBundle(model) @@ -1304,7 +1300,7 @@ func countBundles(tx *gorm.DB) (int32, error) { var count int if err := tx.Count(&count).Error; err != nil { - return 0, sqlError.Wrap(err) + return 0, newWrappedSQLError(err) } return int32(count), nil @@ -1327,7 +1323,7 @@ func listBundles(tx *gorm.DB, req *datastore.ListBundlesRequest) (*datastore.Lis var bundles []Bundle if err := tx.Find(&bundles).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } if p != nil { @@ -1546,7 +1542,7 @@ func revokeJWTKey(tx *gorm.DB, trustDomainID string, authorityID string) (*commo func getBundle(tx *gorm.DB, trustDomainID string) (*common.Bundle, error) { model := &Bundle{} if err := tx.Find(model, "trust_domain = ?", trustDomainID).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } bundle, err := modelToBundle(model) @@ -1569,7 +1565,7 @@ func createAttestedNode(tx *gorm.DB, node *common.AttestedNode) (*common.Atteste } if err := tx.Create(&model).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } return modelToAttestedNode(model), nil @@ -1582,7 +1578,7 @@ func fetchAttestedNode(tx *gorm.DB, spiffeID string) (*common.AttestedNode, erro case errors.Is(err, gorm.ErrRecordNotFound): return nil, nil case err != nil: - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } return modelToAttestedNode(model), nil } @@ -1590,7 +1586,7 @@ func fetchAttestedNode(tx *gorm.DB, spiffeID string) (*common.AttestedNode, erro func countAttestedNodes(tx *gorm.DB) (int32, error) { var count int if err := tx.Model(&AttestedNode{}).Count(&count).Error; err != nil { - return 0, sqlError.Wrap(err) + return 0, newWrappedSQLError(err) } return int32(count), nil @@ -1705,7 +1701,7 @@ func createAttestedNodeEvent(tx *gorm.DB, event *datastore.AttestedNodeEvent) er }, SpiffeID: event.SpiffeID, }).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } return nil @@ -1717,15 +1713,15 @@ func listAttestedNodeEvents(tx *gorm.DB, req *datastore.ListAttestedNodeEventsRe if req.GreaterThanEventID != 0 || req.LessThanEventID != 0 { query, id, err := buildListEventsQueryString(req.GreaterThanEventID, req.LessThanEventID) if err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } if err := tx.Find(&events, query.String(), id).Order("id asc").Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } } else { if err := tx.Find(&events).Order("id asc").Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } } @@ -1742,7 +1738,7 @@ func listAttestedNodeEvents(tx *gorm.DB, req *datastore.ListAttestedNodeEventsRe func pruneAttestedNodeEvents(tx *gorm.DB, olderThan time.Duration) error { if err := tx.Where("created_at < ?", time.Now().Add(-olderThan)).Delete(&AttestedNodeEvent{}).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } return nil @@ -1751,7 +1747,7 @@ func pruneAttestedNodeEvents(tx *gorm.DB, olderThan time.Duration) error { func fetchAttestedNodeEvent(db *sqlDB, eventID uint) (*datastore.AttestedNodeEvent, error) { event := AttestedNodeEvent{} if err := db.Find(&event, "id = ?", eventID).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } return &datastore.AttestedNodeEvent{ @@ -1766,7 +1762,7 @@ func deleteAttestedNodeEvent(tx *gorm.DB, eventID uint) error { ID: eventID, }, }).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } return nil @@ -1805,12 +1801,12 @@ func filterNodesBySelectorSet(nodes []*common.AttestedNode, selectors []*common. func listAttestedNodesOnce(ctx context.Context, db *sqlDB, req *datastore.ListAttestedNodesRequest) (*datastore.ListAttestedNodesResponse, error) { query, args, err := buildListAttestedNodesQuery(db.databaseType, db.supportsCTE, req) if err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } rows, err := db.QueryContext(ctx, query, args...) if err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } defer rows.Close() @@ -1842,7 +1838,7 @@ func listAttestedNodesOnce(ctx context.Context, db *sqlDB, req *datastore.ListAt pushNode(node) if err := rows.Err(); err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } resp := &datastore.ListAttestedNodesResponse{ @@ -1878,7 +1874,7 @@ func buildListAttestedNodesQuery(dbType string, supportsCTE bool, req *datastore } return buildListAttestedNodesQueryMySQL(req) default: - return "", nil, sqlError.New("unsupported db type: %q", dbType) + return "", nil, newSQLError("unsupported db type: %q", dbType) } } @@ -2022,7 +2018,7 @@ SELECT } builder.WriteString(query) if len(req.BySelectorMatch.Selectors) > 1 { - builder.WriteString(fmt.Sprintf(") c_%d\n", i)) + fmt.Fprintf(builder, ") c_%d\n", i) } // First subquery does not need USING(ID) if i > 0 { @@ -2041,7 +2037,7 @@ SELECT } } default: - return "", nil, errs.New("unhandled match behavior %q", req.BySelectorMatch.Match) + return "", nil, fmt.Errorf("unhandled match behavior %q", req.BySelectorMatch.Match) } // Add all selectors as arguments @@ -2206,11 +2202,11 @@ FROM attested_node_entries N builder.WriteString("\t\t\tINNER JOIN\n") builder.WriteString("\t\t\t(") builder.WriteString(query) - builder.WriteString(fmt.Sprintf(") c_%d\n", i+1)) + fmt.Fprintf(builder, ") c_%d\n", i+1) builder.WriteString("\t\t\tUSING(spiffe_id)\n") } default: - return "", nil, errs.New("unhandled match behavior %q", req.BySelectorMatch.Match) + return "", nil, fmt.Errorf("unhandled match behavior %q", req.BySelectorMatch.Match) } for _, selector := range req.BySelectorMatch.Selectors { @@ -2244,7 +2240,7 @@ FROM attested_node_entries N func updateAttestedNode(tx *gorm.DB, n *common.AttestedNode, mask *common.AttestedNodeMask) (*common.AttestedNode, error) { var model AttestedNode if err := tx.Find(&model, "spiffe_id = ?", n.SpiffeId).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } if mask == nil { @@ -2268,7 +2264,7 @@ func updateAttestedNode(tx *gorm.DB, n *common.AttestedNode, mask *common.Attest updates["can_reattest"] = n.CanReattest } if err := tx.Model(&model).Updates(updates).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } return modelToAttestedNode(model), nil @@ -2282,15 +2278,15 @@ func deleteAttestedNodeAndSelectors(tx *gorm.DB, spiffeID string) (*common.Attes // batch delete all associated node selectors if err := tx.Where("spiffe_id = ?", spiffeID).Delete(&nodeSelectorModel).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } if err := tx.Find(&nodeModel, "spiffe_id = ?", spiffeID).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } if err := tx.Delete(&nodeModel).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } return modelToAttestedNode(nodeModel), nil @@ -2310,11 +2306,11 @@ func setNodeSelectors(tx *gorm.DB, spiffeID string, selectors []*common.Selector // gap locks on the index. var ids []int64 if err := tx.Model(&NodeSelector{}).Where("spiffe_id = ?", spiffeID).Pluck("id", &ids).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } if len(ids) > 0 { if err := tx.Where("id IN (?)", ids).Delete(&NodeSelector{}).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } } @@ -2325,7 +2321,7 @@ func setNodeSelectors(tx *gorm.DB, spiffeID string, selectors []*common.Selector Value: selector.Value, } if err := tx.Create(model).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } } @@ -2336,7 +2332,7 @@ func getNodeSelectors(ctx context.Context, db *sqlDB, spiffeID string) ([]*commo query := maybeRebind(db.databaseType, "SELECT type, value FROM node_resolver_map_entries WHERE spiffe_id=? ORDER BY id") rows, err := db.QueryContext(ctx, query, spiffeID) if err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } defer rows.Close() @@ -2344,13 +2340,13 @@ func getNodeSelectors(ctx context.Context, db *sqlDB, spiffeID string) ([]*commo for rows.Next() { selector := new(common.Selector) if err := rows.Scan(&selector.Type, &selector.Value); err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } selectors = append(selectors, selector) } if err := rows.Err(); err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } return selectors, nil @@ -2361,7 +2357,7 @@ func listNodeSelectors(ctx context.Context, db *sqlDB, req *datastore.ListNodeSe query := maybeRebind(db.databaseType, rawQuery) rows, err := db.QueryContext(ctx, query, args...) if err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } defer rows.Close() @@ -2403,7 +2399,7 @@ func listNodeSelectors(ctx context.Context, db *sqlDB, req *datastore.ListNodeSe push("", nil) if err := rows.Err(); err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } return resp, nil @@ -2447,7 +2443,7 @@ func createRegistrationEntry(tx *gorm.DB, entry *common.RegistrationEntry) (*com } if err := tx.Create(&newRegisteredEntry).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } federatesWith, err := makeFederatesWith(tx, entry.FederatesWith) @@ -2467,7 +2463,7 @@ func createRegistrationEntry(tx *gorm.DB, entry *common.RegistrationEntry) (*com } if err := tx.Create(&newSelector).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } } @@ -2478,7 +2474,7 @@ func createRegistrationEntry(tx *gorm.DB, entry *common.RegistrationEntry) (*com } if err := tx.Create(&newDNS).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } } @@ -2493,12 +2489,12 @@ func createRegistrationEntry(tx *gorm.DB, entry *common.RegistrationEntry) (*com func fetchRegistrationEntry(ctx context.Context, db *sqlDB, entryID string) (*common.RegistrationEntry, error) { query, args, err := buildFetchRegistrationEntryQuery(db.databaseType, db.supportsCTE, entryID) if err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } rows, err := db.QueryContext(ctx, query, args...) if err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } defer rows.Close() @@ -2518,7 +2514,7 @@ func fetchRegistrationEntry(ctx context.Context, db *sqlDB, entryID string) (*co } if err := rows.Err(); err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } return entry, nil @@ -2540,7 +2536,7 @@ func buildFetchRegistrationEntryQuery(dbType string, supportsCTE bool, entryID s } return buildFetchRegistrationEntryQueryMySQL(entryID) default: - return "", nil, sqlError.New("unsupported db type: %q", dbType) + return "", nil, newSQLError("unsupported db type: %q", dbType) } } @@ -2857,12 +2853,12 @@ type queryContext interface { func listRegistrationEntriesOnce(ctx context.Context, db queryContext, databaseType string, supportsCTE bool, req *datastore.ListRegistrationEntriesRequest) (*datastore.ListRegistrationEntriesResponse, error) { query, args, err := buildListRegistrationEntriesQuery(databaseType, supportsCTE, req) if err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } rows, err := db.QueryContext(ctx, query, args...) if err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } defer rows.Close() entries := make([]*common.RegistrationEntry, 0, calculateResultPreallocation(req.Pagination)) @@ -2898,7 +2894,7 @@ func listRegistrationEntriesOnce(ctx context.Context, db queryContext, databaseT pushEntry(entry) if err := rows.Err(); err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } resp := &datastore.ListRegistrationEntriesResponse{ @@ -2933,7 +2929,7 @@ func buildListRegistrationEntriesQuery(dbType string, supportsCTE bool, req *dat } return buildListRegistrationEntriesQueryMySQL(req) default: - return "", nil, sqlError.New("unsupported db type: %q", dbType) + return "", nil, newSQLError("unsupported db type: %q", dbType) } } @@ -3525,7 +3521,7 @@ func appendListRegistrationEntriesFilterQuery(filterExp string, builder *strings }) } default: - return false, nil, errs.New("unhandled selectors match behavior %q", req.BySelectors.Match) + return false, nil, fmt.Errorf("unhandled selectors match behavior %q", req.BySelectors.Match) } for _, selector := range req.BySelectors.Selectors { args = append(args, selector.Type, selector.Value) @@ -3598,7 +3594,7 @@ func appendListRegistrationEntriesFilterQuery(filterExp string, builder *strings args = append(args, len(trustDomains)) default: - return false, nil, errs.New("unhandled federates with match behavior %q", req.ByFederatesWith.Match) + return false, nil, fmt.Errorf("unhandled federates with match behavior %q", req.ByFederatesWith.Match) } root.children = append(root.children, filterNode) } @@ -3689,7 +3685,7 @@ type nodeRow struct { } func scanNodeRow(rs *sql.Rows, r *nodeRow) error { - return sqlError.Wrap(rs.Scan( + return newWrappedSQLError(rs.Scan( &r.EId, &r.SpiffeID, &r.DataType, @@ -3730,7 +3726,7 @@ func fillNodeFromRow(node *common.AttestedNode, r *nodeRow) error { if r.SelectorType.Valid { if !r.SelectorValue.Valid { - return sqlError.New("expected non-nil selector.value value for attested node %s", node.SpiffeId) + return newSQLError("expected non-nil selector.value value for attested node %s", node.SpiffeId) } node.Selectors = append(node.Selectors, &common.Selector{ Type: r.SelectorType.String, @@ -3752,7 +3748,7 @@ type nodeSelectorRow struct { } func scanNodeSelectorRow(rs *sql.Rows, r *nodeSelectorRow) error { - return sqlError.Wrap(rs.Scan( + return newWrappedSQLError(rs.Scan( &r.SpiffeID, &r.Type, &r.Value, @@ -3792,7 +3788,7 @@ type entryRow struct { } func scanEntryRow(rs *sql.Rows, r *entryRow) error { - return sqlError.Wrap(rs.Scan( + return newWrappedSQLError(rs.Scan( &r.EId, &r.EntryID, &r.SpiffeID, @@ -3842,7 +3838,7 @@ func fillEntryFromRow(entry *common.RegistrationEntry, r *entryRow) error { } if r.SelectorType.Valid { if !r.SelectorValue.Valid { - return sqlError.New("expected non-nil selector.value value for entry id %s", entry.EntryId) + return newSQLError("expected non-nil selector.value value for entry id %s", entry.EntryId) } entry.Selectors = append(entry.Selectors, &common.Selector{ Type: r.SelectorType.String, @@ -3896,7 +3892,7 @@ func updateRegistrationEntry(tx *gorm.DB, e *common.RegistrationEntry, mask *com // Get the existing entry entry := RegisteredEntry{} if err := tx.Find(&entry, "entry_id = ?", e.EntryId).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } if mask == nil || mask.StoreSvid { entry.StoreSvid = e.StoreSvid @@ -3904,7 +3900,7 @@ func updateRegistrationEntry(tx *gorm.DB, e *common.RegistrationEntry, mask *com if mask == nil || mask.Selectors { // Delete existing selectors - we will write new ones if err := tx.Exec("DELETE FROM selectors WHERE registered_entry_id = ?", entry.ID).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } selectors := []Selector{} @@ -3921,13 +3917,13 @@ func updateRegistrationEntry(tx *gorm.DB, e *common.RegistrationEntry, mask *com // Verify that final selectors contains the same 'type' when entry is used for store SVIDs if entry.StoreSvid && !equalSelectorTypes(entry.Selectors) { - return nil, validationError.New("invalid registration entry: selector types must be the same when store SVID is enabled") + return nil, newValidationError("invalid registration entry: selector types must be the same when store SVID is enabled") } if mask == nil || mask.DnsNames { // Delete existing DNSs - we will write new ones if err := tx.Exec("DELETE FROM dns_names WHERE registered_entry_id = ?", entry.ID).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } dnsList := []DNSName{} @@ -3970,7 +3966,7 @@ func updateRegistrationEntry(tx *gorm.DB, e *common.RegistrationEntry, mask *com entry.RevisionNumber++ if err := tx.Save(&entry).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } if mask == nil || mask.FederatesWith { @@ -3996,7 +3992,7 @@ func updateRegistrationEntry(tx *gorm.DB, e *common.RegistrationEntry, mask *com func deleteRegistrationEntry(tx *gorm.DB, entryID string) (*common.RegistrationEntry, error) { entry := RegisteredEntry{} if err := tx.Find(&entry, "entry_id = ?", entryID).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } registrationEntry, err := modelToEntry(tx, entry) @@ -4018,17 +4014,17 @@ func deleteRegistrationEntrySupport(tx *gorm.DB, entry RegisteredEntry) error { } if err := tx.Delete(&entry).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } // Delete existing selectors if err := tx.Exec("DELETE FROM selectors WHERE registered_entry_id = ?", entry.ID).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } // Delete existing dns_names if err := tx.Exec("DELETE FROM dns_names WHERE registered_entry_id = ?", entry.ID).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } return nil @@ -4066,7 +4062,7 @@ func createRegistrationEntryEvent(tx *gorm.DB, event *datastore.RegistrationEntr }, EntryID: event.EntryID, }).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } return nil @@ -4075,7 +4071,7 @@ func createRegistrationEntryEvent(tx *gorm.DB, event *datastore.RegistrationEntr func fetchRegistrationEntryEvent(db *sqlDB, eventID uint) (*datastore.RegistrationEntryEvent, error) { event := RegisteredEntryEvent{} if err := db.Find(&event, "id = ?", eventID).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } return &datastore.RegistrationEntryEvent{ @@ -4090,7 +4086,7 @@ func deleteRegistrationEntryEvent(tx *gorm.DB, eventID uint) error { ID: eventID, }, }).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } return nil @@ -4102,15 +4098,15 @@ func listRegistrationEntryEvents(tx *gorm.DB, req *datastore.ListRegistrationEnt if req.GreaterThanEventID != 0 || req.LessThanEventID != 0 { query, id, err := buildListEventsQueryString(req.GreaterThanEventID, req.LessThanEventID) if err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } if err := tx.Find(&events, query.String(), id).Order("id asc").Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } } else { if err := tx.Find(&events).Order("id asc").Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } } @@ -4127,7 +4123,7 @@ func listRegistrationEntryEvents(tx *gorm.DB, req *datastore.ListRegistrationEnt func pruneRegistrationEntryEvents(tx *gorm.DB, olderThan time.Duration) error { if err := tx.Where("created_at < ?", time.Now().Add(-olderThan)).Delete(&RegisteredEntryEvent{}).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } return nil @@ -4160,7 +4156,7 @@ func createJoinToken(tx *gorm.DB, token *datastore.JoinToken) error { } if err := tx.Create(&t).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } return nil @@ -4172,7 +4168,7 @@ func fetchJoinToken(tx *gorm.DB, token string) (*datastore.JoinToken, error) { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil } else if err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } return modelToJoinToken(model), nil @@ -4181,11 +4177,11 @@ func fetchJoinToken(tx *gorm.DB, token string) (*datastore.JoinToken, error) { func deleteJoinToken(tx *gorm.DB, token string) error { var model JoinToken if err := tx.Find(&model, "token = ?", token).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } if err := tx.Delete(&model).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } return nil @@ -4193,7 +4189,7 @@ func deleteJoinToken(tx *gorm.DB, token string) error { func pruneJoinTokens(tx *gorm.DB, expiresBefore time.Time) error { if err := tx.Where("expiry < ?", expiresBefore.Unix()).Delete(&JoinToken{}).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } return nil @@ -4219,7 +4215,7 @@ func createFederationRelationship(tx *gorm.DB, fr *datastore.FederationRelations } if err := tx.Create(&model).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } return fr, nil @@ -4228,10 +4224,10 @@ func createFederationRelationship(tx *gorm.DB, fr *datastore.FederationRelations func deleteFederationRelationship(tx *gorm.DB, trustDomain spiffeid.TrustDomain) error { model := new(FederatedTrustDomain) if err := tx.Find(model, "trust_domain = ?", trustDomain.Name()).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } if err := tx.Delete(model).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } return nil } @@ -4243,7 +4239,7 @@ func fetchFederationRelationship(tx *gorm.DB, trustDomain spiffeid.TrustDomain) case errors.Is(err, gorm.ErrRecordNotFound): return nil, nil case err != nil: - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } return modelToFederationRelationship(tx, &model) @@ -4266,7 +4262,7 @@ func listFederationRelationships(tx *gorm.DB, req *datastore.ListFederationRelat var federationRelationships []FederatedTrustDomain if err := tx.Find(&federationRelationships).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } if p != nil { @@ -4323,7 +4319,7 @@ func updateFederationRelationship(tx *gorm.DB, fr *datastore.FederationRelations } if err := tx.Save(&model).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } return modelToFederationRelationship(tx, &model) @@ -4365,7 +4361,7 @@ func modelToFederationRelationship(tx *gorm.DB, model *FederatedTrustDomain) (*d td, err := spiffeid.TrustDomainFromString(model.TrustDomain) if err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } fr := &datastore.FederationRelationship{ @@ -4400,7 +4396,7 @@ func modelToFederationRelationship(tx *gorm.DB, model *FederatedTrustDomain) (*d func modelToBundle(model *Bundle) (*common.Bundle, error) { bundle := new(common.Bundle) if err := proto.Unmarshal(model.Data, bundle); err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } return bundle, nil @@ -4408,11 +4404,11 @@ func modelToBundle(model *Bundle) (*common.Bundle, error) { func validateRegistrationEntry(entry *common.RegistrationEntry) error { if entry == nil { - return validationError.New("invalid request: missing registered entry") + return newValidationError("invalid request: missing registered entry") } if len(entry.Selectors) == 0 { - return validationError.New("invalid registration entry: missing selector list") + return newValidationError("invalid registration entry: missing selector list") } // In case of StoreSvid is set, all entries 'must' be the same type, @@ -4423,31 +4419,31 @@ func validateRegistrationEntry(entry *common.RegistrationEntry) error { tpe := entry.Selectors[0].Type for _, t := range entry.Selectors { if tpe != t.Type { - return validationError.New("invalid registration entry: selector types must be the same when store SVID is enabled") + return newValidationError("invalid registration entry: selector types must be the same when store SVID is enabled") } } } if len(entry.EntryId) > 255 { - return validationError.New("invalid registration entry: entry ID too long") + return newValidationError("invalid registration entry: entry ID too long") } for _, e := range entry.EntryId { if !unicode.In(e, validEntryIDChars) { - return validationError.New("invalid registration entry: entry ID contains invalid characters") + return newValidationError("invalid registration entry: entry ID contains invalid characters") } } if len(entry.SpiffeId) == 0 { - return validationError.New("invalid registration entry: missing SPIFFE ID") + return newValidationError("invalid registration entry: missing SPIFFE ID") } if entry.X509SvidTtl < 0 { - return validationError.New("invalid registration entry: X509SvidTtl is not set") + return newValidationError("invalid registration entry: X509SvidTtl is not set") } if entry.JwtSvidTtl < 0 { - return validationError.New("invalid registration entry: JwtSvidTtl is not set") + return newValidationError("invalid registration entry: JwtSvidTtl is not set") } return nil @@ -4469,26 +4465,26 @@ func equalSelectorTypes(selectors []Selector) bool { func validateRegistrationEntryForUpdate(entry *common.RegistrationEntry, mask *common.RegistrationEntryMask) error { if entry == nil { - return validationError.New("invalid request: missing registered entry") + return newValidationError("invalid request: missing registered entry") } if (mask == nil || mask.Selectors) && len(entry.Selectors) == 0 { - return validationError.New("invalid registration entry: missing selector list") + return newValidationError("invalid registration entry: missing selector list") } if (mask == nil || mask.SpiffeId) && entry.SpiffeId == "" { - return validationError.New("invalid registration entry: missing SPIFFE ID") + return newValidationError("invalid registration entry: missing SPIFFE ID") } if (mask == nil || mask.X509SvidTtl) && (entry.X509SvidTtl < 0) { - return validationError.New("invalid registration entry: X509SvidTtl is not set") + return newValidationError("invalid registration entry: X509SvidTtl is not set") } if (mask == nil || mask.JwtSvidTtl) && (entry.JwtSvidTtl < 0) { - return validationError.New("invalid registration entry: JwtSvidTtl is not set") + return newValidationError("invalid registration entry: JwtSvidTtl is not set") } return nil @@ -4498,11 +4494,11 @@ func validateRegistrationEntryForUpdate(entry *common.RegistrationEntry, mask *c // performs validation, and fully parses certificates to form CACert embedded models. func bundleToModel(pb *common.Bundle) (*Bundle, error) { if pb == nil { - return nil, sqlError.New("missing bundle in request") + return nil, newSQLError("missing bundle in request") } data, err := proto.Marshal(pb) if err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } return &Bundle{ @@ -4514,7 +4510,7 @@ func bundleToModel(pb *common.Bundle) (*Bundle, error) { func modelToEntry(tx *gorm.DB, model RegisteredEntry) (*common.RegistrationEntry, error) { var fetchedSelectors []*Selector if err := tx.Model(&model).Related(&fetchedSelectors).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } selectors := make([]*common.Selector, 0, len(fetchedSelectors)) @@ -4527,7 +4523,7 @@ func modelToEntry(tx *gorm.DB, model RegisteredEntry) (*common.RegistrationEntry var fetchedDNSs []*DNSName if err := tx.Model(&model).Related(&fetchedDNSs).Order("registered_entry_id ASC").Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } var dnsList []string @@ -4540,7 +4536,7 @@ func modelToEntry(tx *gorm.DB, model RegisteredEntry) (*common.RegistrationEntry var fetchedBundles []*Bundle if err := tx.Model(&model).Association("FederatesWith").Find(&fetchedBundles).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } var federatesWith []string @@ -4655,11 +4651,11 @@ func bindVarsFn(fn func(int) string, query string) string { func (cfg *configuration) Validate() error { if cfg.databaseTypeConfig.databaseType == "" { - return sqlError.New("database_type must be set") + return newSQLError("database_type must be set") } if cfg.ConnectionString == "" { - return sqlError.New("connection_string must be set") + return newSQLError("connection_string must be set") } if isMySQLDbType(cfg.databaseTypeConfig.databaseType) { @@ -4701,12 +4697,12 @@ func getConnectionString(cfg *configuration, isReadOnly bool) string { func queryVersion(gormDB *gorm.DB, query string) (string, error) { db := gormDB.DB() if db == nil { - return "", sqlError.New("unable to get raw database object") + return "", newSQLError("unable to get raw database object") } var version string if err := db.QueryRow(query).Scan(&version); err != nil { - return "", sqlError.Wrap(err) + return "", newWrappedSQLError(err) } return version, nil } @@ -4762,7 +4758,7 @@ func createCAJournal(tx *gorm.DB, caJournal *datastore.CAJournal) (*datastore.CA } if err := tx.Create(&model).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } return modelToCAJournal(model), nil @@ -4775,7 +4771,7 @@ func fetchCAJournal(tx *gorm.DB, activeX509AuthorityID string) (*datastore.CAJou case errors.Is(err, gorm.ErrRecordNotFound): return nil, nil case err != nil: - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } return modelToCAJournal(model), nil @@ -4784,7 +4780,7 @@ func fetchCAJournal(tx *gorm.DB, activeX509AuthorityID string) (*datastore.CAJou func listCAJournalsForTesting(tx *gorm.DB) (caJournals []*datastore.CAJournal, err error) { var caJournalsModel []CAJournal if err := tx.Find(&caJournalsModel).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } for _, model := range caJournalsModel { @@ -4797,14 +4793,14 @@ func listCAJournalsForTesting(tx *gorm.DB) (caJournals []*datastore.CAJournal, e func updateCAJournal(tx *gorm.DB, caJournal *datastore.CAJournal) (*datastore.CAJournal, error) { var model CAJournal if err := tx.Find(&model, "id = ?", caJournal.ID).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } model.ActiveX509AuthorityID = caJournal.ActiveX509AuthorityID model.Data = caJournal.Data if err := tx.Save(&model).Error; err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } return modelToCAJournal(model), nil @@ -4821,10 +4817,10 @@ func validateCAJournal(caJournal *datastore.CAJournal) error { func deleteCAJournal(tx *gorm.DB, caJournalID uint) error { model := new(CAJournal) if err := tx.Find(model, "id = ?", caJournalID).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } if err := tx.Delete(model).Error; err != nil { - return sqlError.Wrap(err) + return newWrappedSQLError(err) } return nil } diff --git a/pkg/server/datastore/sqlstore/sqlstore_test.go b/pkg/server/datastore/sqlstore/sqlstore_test.go index 40700d42f6..fbd7481afa 100644 --- a/pkg/server/datastore/sqlstore/sqlstore_test.go +++ b/pkg/server/datastore/sqlstore/sqlstore_test.go @@ -421,7 +421,8 @@ func (s *PluginSuite) TestListBundlesWithPagination() { PageSize: 2, }, expectedList: []*common.Bundle{bundle1, bundle2}, - expectedPagination: &datastore.Pagination{Token: "2", + expectedPagination: &datastore.Pagination{ + Token: "2", PageSize: 2, }, }, @@ -2858,8 +2859,8 @@ func (s *PluginSuite) testListRegistrationEntries(dataConsistency datastore.Data } var tokensIn []string - var actualEntriesOut = make(map[string]*common.RegistrationEntry) - var expectedEntriesOut = make(map[string]*common.RegistrationEntry) + actualEntriesOut := make(map[string]*common.RegistrationEntry) + expectedEntriesOut := make(map[string]*common.RegistrationEntry) req := &datastore.ListRegistrationEntriesRequest{ Pagination: pagination, ByParentID: tt.byParentID, @@ -3095,111 +3096,160 @@ func (s *PluginSuite) TestUpdateRegistrationEntryWithMask() { result func(*common.RegistrationEntry) err error }{ // SPIFFE ID FIELD -- this field is validated so we check with good and bad data - {name: "Update Spiffe ID, Good Data, Mask True", + { + name: "Update Spiffe ID, Good Data, Mask True", mask: &common.RegistrationEntryMask{SpiffeId: true}, update: func(e *common.RegistrationEntry) { e.SpiffeId = newEntry.SpiffeId }, - result: func(e *common.RegistrationEntry) { e.SpiffeId = newEntry.SpiffeId }}, - {name: "Update Spiffe ID, Good Data, Mask False", + result: func(e *common.RegistrationEntry) { e.SpiffeId = newEntry.SpiffeId }, + }, + { + name: "Update Spiffe ID, Good Data, Mask False", mask: &common.RegistrationEntryMask{SpiffeId: false}, update: func(e *common.RegistrationEntry) { e.SpiffeId = newEntry.SpiffeId }, - result: func(e *common.RegistrationEntry) {}}, - {name: "Update Spiffe ID, Bad Data, Mask True", + result: func(e *common.RegistrationEntry) {}, + }, + { + name: "Update Spiffe ID, Bad Data, Mask True", mask: &common.RegistrationEntryMask{SpiffeId: true}, update: func(e *common.RegistrationEntry) { e.SpiffeId = badEntry.SpiffeId }, - err: errors.New("invalid registration entry: missing SPIFFE ID")}, - {name: "Update Spiffe ID, Bad Data, Mask False", + err: errors.New("invalid registration entry: missing SPIFFE ID"), + }, + { + name: "Update Spiffe ID, Bad Data, Mask False", mask: &common.RegistrationEntryMask{SpiffeId: false}, update: func(e *common.RegistrationEntry) { e.SpiffeId = badEntry.SpiffeId }, - result: func(e *common.RegistrationEntry) {}}, + result: func(e *common.RegistrationEntry) {}, + }, // PARENT ID FIELD -- This field isn't validated so we just check with good data - {name: "Update Parent ID, Good Data, Mask True", + { + name: "Update Parent ID, Good Data, Mask True", mask: &common.RegistrationEntryMask{ParentId: true}, update: func(e *common.RegistrationEntry) { e.ParentId = newEntry.ParentId }, - result: func(e *common.RegistrationEntry) { e.ParentId = newEntry.ParentId }}, - {name: "Update Parent ID, Good Data, Mask False", + result: func(e *common.RegistrationEntry) { e.ParentId = newEntry.ParentId }, + }, + { + name: "Update Parent ID, Good Data, Mask False", mask: &common.RegistrationEntryMask{ParentId: false}, update: func(e *common.RegistrationEntry) { e.ParentId = newEntry.ParentId }, - result: func(e *common.RegistrationEntry) {}}, + result: func(e *common.RegistrationEntry) {}, + }, // X509 SVID TTL FIELD -- This field is validated so we check with good and bad data - {name: "Update X509 SVID TTL, Good Data, Mask True", + { + name: "Update X509 SVID TTL, Good Data, Mask True", mask: &common.RegistrationEntryMask{X509SvidTtl: true}, update: func(e *common.RegistrationEntry) { e.X509SvidTtl = newEntry.X509SvidTtl }, - result: func(e *common.RegistrationEntry) { e.X509SvidTtl = newEntry.X509SvidTtl }}, - {name: "Update X509 SVID TTL, Good Data, Mask False", + result: func(e *common.RegistrationEntry) { e.X509SvidTtl = newEntry.X509SvidTtl }, + }, + { + name: "Update X509 SVID TTL, Good Data, Mask False", mask: &common.RegistrationEntryMask{X509SvidTtl: false}, update: func(e *common.RegistrationEntry) { e.X509SvidTtl = badEntry.X509SvidTtl }, - result: func(e *common.RegistrationEntry) {}}, - {name: "Update X509 SVID TTL, Bad Data, Mask True", + result: func(e *common.RegistrationEntry) {}, + }, + { + name: "Update X509 SVID TTL, Bad Data, Mask True", mask: &common.RegistrationEntryMask{X509SvidTtl: true}, update: func(e *common.RegistrationEntry) { e.X509SvidTtl = badEntry.X509SvidTtl }, - err: errors.New("invalid registration entry: X509SvidTtl is not set")}, - {name: "Update X509 SVID TTL, Bad Data, Mask False", + err: errors.New("invalid registration entry: X509SvidTtl is not set"), + }, + { + name: "Update X509 SVID TTL, Bad Data, Mask False", mask: &common.RegistrationEntryMask{X509SvidTtl: false}, update: func(e *common.RegistrationEntry) { e.X509SvidTtl = badEntry.X509SvidTtl }, - result: func(e *common.RegistrationEntry) {}}, + result: func(e *common.RegistrationEntry) {}, + }, // JWT SVID TTL FIELD -- This field is validated so we check with good and bad data - {name: "Update JWT SVID TTL, Good Data, Mask True", + { + name: "Update JWT SVID TTL, Good Data, Mask True", mask: &common.RegistrationEntryMask{JwtSvidTtl: true}, update: func(e *common.RegistrationEntry) { e.JwtSvidTtl = newEntry.JwtSvidTtl }, - result: func(e *common.RegistrationEntry) { e.JwtSvidTtl = newEntry.JwtSvidTtl }}, - {name: "Update JWT SVID TTL, Good Data, Mask False", + result: func(e *common.RegistrationEntry) { e.JwtSvidTtl = newEntry.JwtSvidTtl }, + }, + { + name: "Update JWT SVID TTL, Good Data, Mask False", mask: &common.RegistrationEntryMask{JwtSvidTtl: false}, update: func(e *common.RegistrationEntry) { e.JwtSvidTtl = badEntry.JwtSvidTtl }, - result: func(e *common.RegistrationEntry) {}}, - {name: "Update JWT SVID TTL, Bad Data, Mask True", + result: func(e *common.RegistrationEntry) {}, + }, + { + name: "Update JWT SVID TTL, Bad Data, Mask True", mask: &common.RegistrationEntryMask{JwtSvidTtl: true}, update: func(e *common.RegistrationEntry) { e.JwtSvidTtl = badEntry.JwtSvidTtl }, - err: errors.New("invalid registration entry: JwtSvidTtl is not set")}, - {name: "Update JWT SVID TTL, Bad Data, Mask False", + err: errors.New("invalid registration entry: JwtSvidTtl is not set"), + }, + { + name: "Update JWT SVID TTL, Bad Data, Mask False", mask: &common.RegistrationEntryMask{JwtSvidTtl: false}, update: func(e *common.RegistrationEntry) { e.JwtSvidTtl = badEntry.JwtSvidTtl }, - result: func(e *common.RegistrationEntry) {}}, + result: func(e *common.RegistrationEntry) {}, + }, // SELECTORS FIELD -- This field is validated so we check with good and bad data - {name: "Update Selectors, Good Data, Mask True", + { + name: "Update Selectors, Good Data, Mask True", mask: &common.RegistrationEntryMask{Selectors: true}, update: func(e *common.RegistrationEntry) { e.Selectors = newEntry.Selectors }, - result: func(e *common.RegistrationEntry) { e.Selectors = newEntry.Selectors }}, - {name: "Update Selectors, Good Data, Mask False", + result: func(e *common.RegistrationEntry) { e.Selectors = newEntry.Selectors }, + }, + { + name: "Update Selectors, Good Data, Mask False", mask: &common.RegistrationEntryMask{Selectors: false}, update: func(e *common.RegistrationEntry) { e.Selectors = badEntry.Selectors }, - result: func(e *common.RegistrationEntry) {}}, - {name: "Update Selectors, Bad Data, Mask True", + result: func(e *common.RegistrationEntry) {}, + }, + { + name: "Update Selectors, Bad Data, Mask True", mask: &common.RegistrationEntryMask{Selectors: true}, update: func(e *common.RegistrationEntry) { e.Selectors = badEntry.Selectors }, - err: errors.New("invalid registration entry: missing selector list")}, - {name: "Update Selectors, Bad Data, Mask False", + err: errors.New("invalid registration entry: missing selector list"), + }, + { + name: "Update Selectors, Bad Data, Mask False", mask: &common.RegistrationEntryMask{Selectors: false}, update: func(e *common.RegistrationEntry) { e.Selectors = badEntry.Selectors }, - result: func(e *common.RegistrationEntry) {}}, + result: func(e *common.RegistrationEntry) {}, + }, // FEDERATESWITH FIELD -- This field isn't validated so we just check with good data - {name: "Update FederatesWith, Good Data, Mask True", + { + name: "Update FederatesWith, Good Data, Mask True", mask: &common.RegistrationEntryMask{FederatesWith: true}, update: func(e *common.RegistrationEntry) { e.FederatesWith = newEntry.FederatesWith }, - result: func(e *common.RegistrationEntry) { e.FederatesWith = newEntry.FederatesWith }}, - {name: "Update FederatesWith Good Data, Mask False", + result: func(e *common.RegistrationEntry) { e.FederatesWith = newEntry.FederatesWith }, + }, + { + name: "Update FederatesWith Good Data, Mask False", mask: &common.RegistrationEntryMask{FederatesWith: false}, update: func(e *common.RegistrationEntry) { e.FederatesWith = newEntry.FederatesWith }, - result: func(e *common.RegistrationEntry) {}}, + result: func(e *common.RegistrationEntry) {}, + }, // ADMIN FIELD -- This field isn't validated so we just check with good data - {name: "Update Admin, Good Data, Mask True", + { + name: "Update Admin, Good Data, Mask True", mask: &common.RegistrationEntryMask{Admin: true}, update: func(e *common.RegistrationEntry) { e.Admin = newEntry.Admin }, - result: func(e *common.RegistrationEntry) { e.Admin = newEntry.Admin }}, - {name: "Update Admin, Good Data, Mask False", + result: func(e *common.RegistrationEntry) { e.Admin = newEntry.Admin }, + }, + { + name: "Update Admin, Good Data, Mask False", mask: &common.RegistrationEntryMask{Admin: false}, update: func(e *common.RegistrationEntry) { e.Admin = newEntry.Admin }, - result: func(e *common.RegistrationEntry) {}}, + result: func(e *common.RegistrationEntry) {}, + }, // STORESVID FIELD -- This field isn't validated so we just check with good data - {name: "Update StoreSvid, Good Data, Mask True", + { + name: "Update StoreSvid, Good Data, Mask True", mask: &common.RegistrationEntryMask{StoreSvid: true}, update: func(e *common.RegistrationEntry) { e.StoreSvid = newEntry.StoreSvid }, - result: func(e *common.RegistrationEntry) { e.StoreSvid = newEntry.StoreSvid }}, - {name: "Update StoreSvid, Good Data, Mask False", + result: func(e *common.RegistrationEntry) { e.StoreSvid = newEntry.StoreSvid }, + }, + { + name: "Update StoreSvid, Good Data, Mask False", mask: &common.RegistrationEntryMask{Admin: false}, update: func(e *common.RegistrationEntry) { e.StoreSvid = newEntry.StoreSvid }, - result: func(e *common.RegistrationEntry) {}}, - {name: "Update StoreSvid, Invalid selectors, Mask True", + result: func(e *common.RegistrationEntry) {}, + }, + { + name: "Update StoreSvid, Invalid selectors, Mask True", mask: &common.RegistrationEntryMask{StoreSvid: true, Selectors: true}, update: func(e *common.RegistrationEntry) { e.StoreSvid = newEntry.StoreSvid @@ -3208,50 +3258,68 @@ func (s *PluginSuite) TestUpdateRegistrationEntryWithMask() { {Type: "Type2", Value: "Value2"}, } }, - err: validationError.New("invalid registration entry: selector types must be the same when store SVID is enabled"), + err: newValidationError("invalid registration entry: selector types must be the same when store SVID is enabled"), }, // ENTRYEXPIRY FIELD -- This field isn't validated so we just check with good data - {name: "Update EntryExpiry, Good Data, Mask True", + { + name: "Update EntryExpiry, Good Data, Mask True", mask: &common.RegistrationEntryMask{EntryExpiry: true}, update: func(e *common.RegistrationEntry) { e.EntryExpiry = newEntry.EntryExpiry }, - result: func(e *common.RegistrationEntry) { e.EntryExpiry = newEntry.EntryExpiry }}, - {name: "Update EntryExpiry, Good Data, Mask False", + result: func(e *common.RegistrationEntry) { e.EntryExpiry = newEntry.EntryExpiry }, + }, + { + name: "Update EntryExpiry, Good Data, Mask False", mask: &common.RegistrationEntryMask{EntryExpiry: false}, update: func(e *common.RegistrationEntry) { e.EntryExpiry = newEntry.EntryExpiry }, - result: func(e *common.RegistrationEntry) {}}, + result: func(e *common.RegistrationEntry) {}, + }, // DNSNAMES FIELD -- This field isn't validated so we just check with good data - {name: "Update DnsNames, Good Data, Mask True", + { + name: "Update DnsNames, Good Data, Mask True", mask: &common.RegistrationEntryMask{DnsNames: true}, update: func(e *common.RegistrationEntry) { e.DnsNames = newEntry.DnsNames }, - result: func(e *common.RegistrationEntry) { e.DnsNames = newEntry.DnsNames }}, - {name: "Update DnsNames, Good Data, Mask False", + result: func(e *common.RegistrationEntry) { e.DnsNames = newEntry.DnsNames }, + }, + { + name: "Update DnsNames, Good Data, Mask False", mask: &common.RegistrationEntryMask{DnsNames: false}, update: func(e *common.RegistrationEntry) { e.DnsNames = newEntry.DnsNames }, - result: func(e *common.RegistrationEntry) {}}, + result: func(e *common.RegistrationEntry) {}, + }, // DOWNSTREAM FIELD -- This field isn't validated so we just check with good data - {name: "Update DnsNames, Good Data, Mask True", + { + name: "Update DnsNames, Good Data, Mask True", mask: &common.RegistrationEntryMask{Downstream: true}, update: func(e *common.RegistrationEntry) { e.Downstream = newEntry.Downstream }, - result: func(e *common.RegistrationEntry) { e.Downstream = newEntry.Downstream }}, - {name: "Update DnsNames, Good Data, Mask False", + result: func(e *common.RegistrationEntry) { e.Downstream = newEntry.Downstream }, + }, + { + name: "Update DnsNames, Good Data, Mask False", mask: &common.RegistrationEntryMask{Downstream: false}, update: func(e *common.RegistrationEntry) { e.Downstream = newEntry.Downstream }, - result: func(e *common.RegistrationEntry) {}}, + result: func(e *common.RegistrationEntry) {}, + }, // HINT -- This field isn't validated so we just check with good data - {name: "Update Hint, Good Data, Mask True", + { + name: "Update Hint, Good Data, Mask True", mask: &common.RegistrationEntryMask{Hint: true}, update: func(e *common.RegistrationEntry) { e.Hint = newEntry.Hint }, - result: func(e *common.RegistrationEntry) { e.Hint = newEntry.Hint }}, - {name: "Update Hint, Good Data, Mask False", + result: func(e *common.RegistrationEntry) { e.Hint = newEntry.Hint }, + }, + { + name: "Update Hint, Good Data, Mask False", mask: &common.RegistrationEntryMask{Hint: false}, update: func(e *common.RegistrationEntry) { e.Hint = newEntry.Hint }, - result: func(e *common.RegistrationEntry) {}}, + result: func(e *common.RegistrationEntry) {}, + }, // This should update all fields - {name: "Test With Nil Mask", + { + name: "Test With Nil Mask", mask: nil, update: func(e *common.RegistrationEntry) { proto.Merge(e, oldEntry) }, - result: func(e *common.RegistrationEntry) {}}, + result: func(e *common.RegistrationEntry) {}, + }, } { tt := testcase s.Run(tt.name, func() { @@ -3350,7 +3418,6 @@ func (s *PluginSuite) TestListParentIDEntries() { expectedList []*common.RegistrationEntry }{ { - name: "test_parentID_found", registrationEntries: allEntries, parentID: "spiffe://parent", @@ -4627,7 +4694,8 @@ func (s *PluginSuite) TestListFederationRelationships() { PageSize: 2, }, expectedList: []*datastore.FederationRelationship{fr1, fr2}, - expectedPagination: &datastore.Pagination{Token: "2", + expectedPagination: &datastore.Pagination{ + Token: "2", PageSize: 2, }, }, diff --git a/pkg/server/datastore/sqlstore/stmt_cache.go b/pkg/server/datastore/sqlstore/stmt_cache.go index f3fb354140..a934d2a880 100644 --- a/pkg/server/datastore/sqlstore/stmt_cache.go +++ b/pkg/server/datastore/sqlstore/stmt_cache.go @@ -25,7 +25,7 @@ func (cache *stmtCache) get(ctx context.Context, query string) (*sql.Stmt, error stmt, err := cache.db.PrepareContext(ctx, query) if err != nil { - return nil, sqlError.Wrap(err) + return nil, newWrappedSQLError(err) } value, loaded = cache.stmts.LoadOrStore(query, stmt) if loaded { diff --git a/pkg/server/endpoints/bundle/acme_auth.go b/pkg/server/endpoints/bundle/acme_auth.go index a9d12c5bcc..45e5fbb72b 100644 --- a/pkg/server/endpoints/bundle/acme_auth.go +++ b/pkg/server/endpoints/bundle/acme_auth.go @@ -4,12 +4,12 @@ import ( "context" "crypto" "crypto/tls" + "fmt" "github.com/sirupsen/logrus" "github.com/spiffe/spire/pkg/common/version" "github.com/spiffe/spire/pkg/server/endpoints/bundle/internal/autocert" "github.com/spiffe/spire/pkg/server/plugin/keymanager" - "github.com/zeebo/errs" "golang.org/x/crypto/acme" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -122,7 +122,7 @@ func (ks *acmeKeyStore) NewPrivateKey(ctx context.Context, id string, keyType au case autocert.EC256: kmKeyType = keymanager.ECP256 default: - return nil, errs.New("unsupported key type: %d", keyType) + return nil, fmt.Errorf("unsupported key type: %d", keyType) } key, err := ks.km.GenerateKey(ctx, keyID, kmKeyType) diff --git a/pkg/server/endpoints/bundle/server.go b/pkg/server/endpoints/bundle/server.go index d96490e476..e9c7a39bdf 100644 --- a/pkg/server/endpoints/bundle/server.go +++ b/pkg/server/endpoints/bundle/server.go @@ -11,7 +11,6 @@ import ( "github.com/sirupsen/logrus" "github.com/spiffe/go-spiffe/v2/bundle/spiffebundle" "github.com/spiffe/spire/pkg/common/bundleutil" - "github.com/zeebo/errs" ) type Getter interface { @@ -57,7 +56,7 @@ func (s *Server) ListenAndServe(ctx context.Context) error { // it gives us the ability to use/inspect an ephemeral port during testing. listener, err := s.c.listen("tcp", s.c.Address) if err != nil { - return errs.Wrap(err) + return err } // Set up the TLS config, setting TLS 1.2 as the minimum. @@ -72,7 +71,7 @@ func (s *Server) ListenAndServe(ctx context.Context) error { errCh := make(chan error, 1) go func() { - errCh <- errs.Wrap(server.ServeTLS(listener, "", "")) + errCh <- server.ServeTLS(listener, "", "") }() select { diff --git a/pkg/server/hostservice/identityprovider/identityprovider.go b/pkg/server/hostservice/identityprovider/identityprovider.go index 79213beff8..7aaf243c36 100644 --- a/pkg/server/hostservice/identityprovider/identityprovider.go +++ b/pkg/server/hostservice/identityprovider/identityprovider.go @@ -13,7 +13,6 @@ import ( "github.com/spiffe/spire/pkg/common/coretypes/jwtkey" "github.com/spiffe/spire/pkg/common/coretypes/x509certificate" "github.com/spiffe/spire/pkg/server/datastore" - "github.com/zeebo/errs" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -124,7 +123,7 @@ func (v1 *identityProviderV1) FetchX509Identity(ctx context.Context, _ *identity privateKey, err := x509.MarshalPKCS8PrivateKey(x509Identity.PrivateKey) if err != nil { - return nil, errs.Wrap(err) + return nil, err } return &identityproviderv1.FetchX509IdentityResponse{ diff --git a/support/oidc-discovery-provider/config.go b/support/oidc-discovery-provider/config.go index c32600f940..f0cb81ed69 100644 --- a/support/oidc-discovery-provider/config.go +++ b/support/oidc-discovery-provider/config.go @@ -1,6 +1,8 @@ package main import ( + "errors" + "fmt" "net" "net/url" "os" @@ -8,7 +10,6 @@ import ( "github.com/hashicorp/hcl" "github.com/spiffe/spire/pkg/common/config" - "github.com/zeebo/errs" ) const ( @@ -189,7 +190,7 @@ type experimentalWorkloadAPIConfig struct { func LoadConfig(path string, expandEnv bool) (*Config, error) { hclBytes, err := os.ReadFile(path) if err != nil { - return nil, errs.New("unable to load configuration: %v", err) + return nil, fmt.Errorf("unable to load configuration: %w", err) } hclString := string(hclBytes) if expandEnv { @@ -201,7 +202,7 @@ func LoadConfig(path string, expandEnv bool) (*Config, error) { func ParseConfig(hclConfig string) (_ *Config, err error) { c := new(Config) if err := hcl.Decode(c, hclConfig); err != nil { - return nil, errs.New("unable to decode configuration: %v", err) + return nil, fmt.Errorf("unable to decode configuration: %w", err) } if c.LogLevel == "" { @@ -209,7 +210,7 @@ func ParseConfig(hclConfig string) (_ *Config, err error) { } if len(c.Domains) == 0 { - return nil, errs.New("at least one domain must be configured") + return nil, errors.New("at least one domain must be configured") } c.Domains = dedupeList(c.Domains) @@ -220,20 +221,20 @@ func ParseConfig(hclConfig string) (_ *Config, err error) { } switch { case c.InsecureAddr != "": - return nil, errs.New("insecure_addr and the acme section are mutually exclusive") + return nil, errors.New("insecure_addr and the acme section are mutually exclusive") case !c.ACME.ToSAccepted: - return nil, errs.New("tos_accepted must be set to true in the acme configuration section") + return nil, errors.New("tos_accepted must be set to true in the acme configuration section") case c.ACME.Email == "": - return nil, errs.New("email must be configured in the acme configuration section") + return nil, errors.New("email must be configured in the acme configuration section") } } if c.ServingCertFile != nil { if c.ServingCertFile.CertFilePath == "" { - return nil, errs.New("cert_file_path must be configured in the serving_cert_file configuration section") + return nil, errors.New("cert_file_path must be configured in the serving_cert_file configuration section") } if c.ServingCertFile.KeyFilePath == "" { - return nil, errs.New("key_file_path must be configured in the serving_cert_file configuration section") + return nil, errors.New("key_file_path must be configured in the serving_cert_file configuration section") } if c.ServingCertFile.RawAddr == "" { @@ -242,13 +243,13 @@ func ParseConfig(hclConfig string) (_ *Config, err error) { addr, err := net.ResolveTCPAddr("tcp", c.ServingCertFile.RawAddr) if err != nil { - return nil, errs.New("invalid addr in the serving_cert_file configuration section: %v", err) + return nil, fmt.Errorf("invalid addr in the serving_cert_file configuration section: %w", err) } c.ServingCertFile.Addr = addr c.ServingCertFile.FileSyncInterval, err = parseDurationField(c.ServingCertFile.RawFileSyncInterval, defaultFileSyncInterval) if err != nil { - return nil, errs.New("invalid file_sync_interval in the serving_cert_file configuration section: %v", err) + return nil, fmt.Errorf("invalid file_sync_interval in the serving_cert_file configuration section: %w", err) } } @@ -257,18 +258,18 @@ func ParseConfig(hclConfig string) (_ *Config, err error) { if c.ServerAPI != nil { c.ServerAPI.PollInterval, err = parseDurationField(c.ServerAPI.RawPollInterval, defaultPollInterval) if err != nil { - return nil, errs.New("invalid poll_interval in the server_api configuration section: %v", err) + return nil, fmt.Errorf("invalid poll_interval in the server_api configuration section: %w", err) } methodCount++ } if c.WorkloadAPI != nil { if c.WorkloadAPI.TrustDomain == "" { - return nil, errs.New("trust_domain must be configured in the workload_api configuration section") + return nil, errors.New("trust_domain must be configured in the workload_api configuration section") } c.WorkloadAPI.PollInterval, err = parseDurationField(c.WorkloadAPI.RawPollInterval, defaultPollInterval) if err != nil { - return nil, errs.New("invalid poll_interval in the workload_api configuration section: %v", err) + return nil, fmt.Errorf("invalid poll_interval in the workload_api configuration section: %w", err) } methodCount++ } @@ -291,15 +292,20 @@ func ParseConfig(hclConfig string) (_ *Config, err error) { switch methodCount { case 0: - return nil, errs.New("either the server_api or workload_api section must be configured") + return nil, errors.New("either the server_api or workload_api section must be configured") case 1: default: - return nil, errs.New("the server_api and workload_api sections are mutually exclusive") + return nil, errors.New("the server_api and workload_api sections are mutually exclusive") } if c.JWTIssuer != "" { jwtIssuer, err := url.Parse(c.JWTIssuer) - if err != nil || jwtIssuer.Scheme == "" || jwtIssuer.Host == "" { - return nil, errs.New("the jwt_issuer url could not be parsed") + switch { + case err != nil: + return nil, fmt.Errorf("the jwt_issuer url could not be parsed: %w", err) + case jwtIssuer.Scheme == "": + return nil, errors.New("the jwt_issuer url must contain a scheme") + case jwtIssuer.Host == "": + return nil, errors.New("the jwt_issuer url must contain a host") } } return c, nil diff --git a/support/oidc-discovery-provider/config_posix_test.go b/support/oidc-discovery-provider/config_posix_test.go index bba9706483..3401c6818f 100644 --- a/support/oidc-discovery-provider/config_posix_test.go +++ b/support/oidc-discovery-provider/config_posix_test.go @@ -697,7 +697,7 @@ func parseConfigCasesOS() []parseConfigCase { address = "unix:///some/socket/path" } `, - err: "the jwt_issuer url could not be parsed", + err: "the jwt_issuer url must contain a scheme", }, { name: "JWT issuer with missing host", @@ -712,7 +712,7 @@ func parseConfigCasesOS() []parseConfigCase { address = "unix:///some/socket/path" } `, - err: "the jwt_issuer url could not be parsed", + err: "the jwt_issuer url must contain a host", }, { name: "JWT issuer is invalid", diff --git a/support/oidc-discovery-provider/config_test.go b/support/oidc-discovery-provider/config_test.go index ebe3e6b1a4..f194d3fdc6 100644 --- a/support/oidc-discovery-provider/config_test.go +++ b/support/oidc-discovery-provider/config_test.go @@ -27,7 +27,7 @@ func TestLoadConfig(t *testing.T) { require.Error(err) require.Contains(err.Error(), "unable to load configuration:") - err = os.WriteFile(confPath, []byte(minimalEnvServerAPIConfig), 0600) + err = os.WriteFile(confPath, []byte(minimalEnvServerAPIConfig), 0o600) require.NoError(err) os.Setenv("SPIFFE_TRUST_DOMAIN", "domain.test") @@ -45,7 +45,7 @@ func TestLoadConfig(t *testing.T) { ServerAPI: serverAPIConfig, }, config) - err = os.WriteFile(confPath, []byte(minimalServerAPIConfig), 0600) + err = os.WriteFile(confPath, []byte(minimalServerAPIConfig), 0o600) require.NoError(err) config, err = LoadConfig(confPath, false) diff --git a/support/oidc-discovery-provider/config_windows_test.go b/support/oidc-discovery-provider/config_windows_test.go index 728b81f440..7fd2efc266 100644 --- a/support/oidc-discovery-provider/config_windows_test.go +++ b/support/oidc-discovery-provider/config_windows_test.go @@ -645,7 +645,7 @@ func parseConfigCasesOS() []parseConfigCase { } } `, - err: "the jwt_issuer url could not be parsed", + err: "the jwt_issuer url must contain a scheme", }, { name: "JWT issuer with missing host", @@ -663,7 +663,7 @@ func parseConfigCasesOS() []parseConfigCase { } } `, - err: "the jwt_issuer url could not be parsed", + err: "the jwt_issuer url must contain a host", }, { name: "JWT issuer is invalid", diff --git a/support/oidc-discovery-provider/main.go b/support/oidc-discovery-provider/main.go index feb5e9b216..de2d70bbf2 100644 --- a/support/oidc-discovery-provider/main.go +++ b/support/oidc-discovery-provider/main.go @@ -3,6 +3,7 @@ package main import ( "context" "crypto/tls" + "errors" "flag" "fmt" "net" @@ -17,7 +18,6 @@ import ( "github.com/spiffe/spire/pkg/common/log" "github.com/spiffe/spire/pkg/common/telemetry" "github.com/spiffe/spire/pkg/common/version" - "github.com/zeebo/errs" "golang.org/x/crypto/acme" "golang.org/x/crypto/acme/autocert" ) @@ -50,7 +50,7 @@ func run(configPath string, expandEnv bool) error { log, err := log.NewLogger(log.WithLevel(config.LogLevel), log.WithFormat(config.LogFormat), log.WithOutputFile(config.LogPath)) if err != nil { - return errs.Wrap(err) + return err } defer log.Close() @@ -158,7 +158,7 @@ func newSource(log logrus.FieldLogger, config *Config) (JWKSSource, error) { case config.WorkloadAPI != nil: workloadAPIAddr, err := config.getWorkloadAPIAddr() if err != nil { - return nil, errs.Wrap(err) + return nil, err } return NewWorkloadAPISource(WorkloadAPISourceConfig{ Log: log, @@ -168,7 +168,7 @@ func newSource(log logrus.FieldLogger, config *Config) (JWKSSource, error) { }) default: // This is defensive; LoadConfig should prevent this from happening. - return nil, errs.New("no source has been configured") + return nil, errors.New("no source has been configured") } } diff --git a/support/oidc-discovery-provider/main_posix.go b/support/oidc-discovery-provider/main_posix.go index d61d1091a6..4e6e75cee8 100644 --- a/support/oidc-discovery-provider/main_posix.go +++ b/support/oidc-discovery-provider/main_posix.go @@ -3,12 +3,12 @@ package main import ( + "errors" "net" "os" "strings" "github.com/spiffe/spire/pkg/common/util" - "github.com/zeebo/errs" ) func (c *Config) getWorkloadAPIAddr() (net.Addr, error) { @@ -23,33 +23,33 @@ func (c *Config) getServerAPITargetName() string { func (c *Config) validateOS() (err error) { switch { case c.ACME == nil && c.ListenSocketPath == "" && c.ServingCertFile == nil && c.InsecureAddr == "": - return errs.New("either acme, serving_cert_file, insecure_addr or listen_socket_path must be configured") + return errors.New("either acme, serving_cert_file, insecure_addr or listen_socket_path must be configured") case c.ACME != nil && c.ServingCertFile != nil: - return errs.New("acme and serving_cert_file are mutually exclusive") + return errors.New("acme and serving_cert_file are mutually exclusive") case c.ACME != nil && c.ListenSocketPath != "": - return errs.New("listen_socket_path and the acme section are mutually exclusive") + return errors.New("listen_socket_path and the acme section are mutually exclusive") case c.ServingCertFile != nil && c.InsecureAddr != "": - return errs.New("serving_cert_file and insecure_addr are mutually exclusive") + return errors.New("serving_cert_file and insecure_addr are mutually exclusive") case c.ServingCertFile != nil && c.ListenSocketPath != "": - return errs.New("serving_cert_file and listen_socket_path are mutually exclusive") + return errors.New("serving_cert_file and listen_socket_path are mutually exclusive") case c.ACME != nil && c.InsecureAddr != "": - return errs.New("acme and insecure_addr are mutually exclusive") + return errors.New("acme and insecure_addr are mutually exclusive") case c.InsecureAddr != "" && c.ListenSocketPath != "": - return errs.New("insecure_addr and listen_socket_path are mutually exclusive") + return errors.New("insecure_addr and listen_socket_path are mutually exclusive") } if c.ServerAPI != nil { if c.ServerAPI.Address == "" { - return errs.New("address must be configured in the server_api configuration section") + return errors.New("address must be configured in the server_api configuration section") } if !strings.HasPrefix(c.ServerAPI.Address, "unix:") { - return errs.New("address must use the unix name system in the server_api configuration section") + return errors.New("address must use the unix name system in the server_api configuration section") } } if c.WorkloadAPI != nil { if c.WorkloadAPI.SocketPath == "" { - return errs.New("socket_path must be configured in the workload_api configuration section") + return errors.New("socket_path must be configured in the workload_api configuration section") } } diff --git a/support/oidc-discovery-provider/main_windows.go b/support/oidc-discovery-provider/main_windows.go index a05b5fd32e..55d24ebdb6 100644 --- a/support/oidc-discovery-provider/main_windows.go +++ b/support/oidc-discovery-provider/main_windows.go @@ -3,6 +3,7 @@ package main import ( + "errors" "fmt" "net" "path/filepath" @@ -10,7 +11,6 @@ import ( "github.com/Microsoft/go-winio" "github.com/spiffe/spire/pkg/common/namedpipe" "github.com/spiffe/spire/pkg/common/sddl" - "github.com/zeebo/errs" ) func (c *Config) getWorkloadAPIAddr() (net.Addr, error) { @@ -25,29 +25,29 @@ func (c *Config) getServerAPITargetName() string { func (c *Config) validateOS() (err error) { switch { case c.ACME == nil && c.Experimental.ListenNamedPipeName == "" && c.ServingCertFile == nil && c.InsecureAddr == "": - return errs.New("either acme, serving_cert_file, insecure_addr or listen_named_pipe_name must be configured") + return errors.New("either acme, serving_cert_file, insecure_addr or listen_named_pipe_name must be configured") case c.ACME != nil && c.ServingCertFile != nil: - return errs.New("acme and serving_cert_file are mutually exclusive") + return errors.New("acme and serving_cert_file are mutually exclusive") case c.ACME != nil && c.Experimental.ListenNamedPipeName != "": - return errs.New("listen_named_pipe_name and the acme section are mutually exclusive") + return errors.New("listen_named_pipe_name and the acme section are mutually exclusive") case c.ACME != nil && c.InsecureAddr != "": - return errs.New("acme and insecure_addr are mutually exclusive") + return errors.New("acme and insecure_addr are mutually exclusive") case c.ServingCertFile != nil && c.InsecureAddr != "": - return errs.New("serving_cert_file and insecure_addr are mutually exclusive") + return errors.New("serving_cert_file and insecure_addr are mutually exclusive") case c.ServingCertFile != nil && c.Experimental.ListenNamedPipeName != "": - return errs.New("serving_cert_file and listen_named_pipe_name are mutually exclusive") + return errors.New("serving_cert_file and listen_named_pipe_name are mutually exclusive") case c.InsecureAddr != "" && c.Experimental.ListenNamedPipeName != "": - return errs.New("insecure_addr and listen_named_pipe_name are mutually exclusive") + return errors.New("insecure_addr and listen_named_pipe_name are mutually exclusive") } if c.ServerAPI != nil { if c.ServerAPI.Experimental.NamedPipeName == "" { - return errs.New("named_pipe_name must be configured in the server_api configuration section") + return errors.New("named_pipe_name must be configured in the server_api configuration section") } } if c.WorkloadAPI != nil { if c.WorkloadAPI.Experimental.NamedPipeName == "" { - return errs.New("named_pipe_name must be configured in the workload_api configuration section") + return errors.New("named_pipe_name must be configured in the workload_api configuration section") } } diff --git a/support/oidc-discovery-provider/server_api.go b/support/oidc-discovery-provider/server_api.go index 74724f1b36..5a9c98c444 100644 --- a/support/oidc-discovery-provider/server_api.go +++ b/support/oidc-discovery-provider/server_api.go @@ -12,7 +12,6 @@ import ( bundlev1 "github.com/spiffe/spire-api-sdk/proto/spire/api/server/bundle/v1" "github.com/spiffe/spire-api-sdk/proto/spire/api/types" "github.com/spiffe/spire/pkg/common/util" - "github.com/zeebo/errs" "google.golang.org/grpc" "google.golang.org/protobuf/proto" ) @@ -51,7 +50,7 @@ func NewServerAPISource(config ServerAPISourceConfig) (*ServerAPISource, error) conn, err := util.GRPCDialContext(context.Background(), config.GRPCTarget) if err != nil { - return nil, errs.Wrap(err) + return nil, err } ctx, cancel := context.WithCancel(context.Background()) diff --git a/support/oidc-discovery-provider/workload_api.go b/support/oidc-discovery-provider/workload_api.go index 8db442d4f3..caaabf9c78 100644 --- a/support/oidc-discovery-provider/workload_api.go +++ b/support/oidc-discovery-provider/workload_api.go @@ -16,7 +16,6 @@ import ( "github.com/spiffe/go-spiffe/v2/workloadapi" "github.com/spiffe/spire/pkg/common/telemetry" "github.com/spiffe/spire/pkg/common/util" - "github.com/zeebo/errs" ) const ( @@ -56,19 +55,19 @@ func NewWorkloadAPISource(config WorkloadAPISourceConfig) (*WorkloadAPISource, e if config.Addr != nil { o, err := util.GetWorkloadAPIClientOption(config.Addr) if err != nil { - return nil, errs.Wrap(err) + return nil, err } opts = append(opts, o) } trustDomain, err := spiffeid.TrustDomainFromString(config.TrustDomain) if err != nil { - return nil, errs.Wrap(err) + return nil, err } client, err := workloadapi.New(context.Background(), opts...) if err != nil { - return nil, errs.Wrap(err) + return nil, err } ctx, cancel := context.WithCancel(context.Background())