diff --git a/.github/workflows/pr-validation.yaml b/.github/workflows/pr-validation.yaml index 9d8e13e..4671003 100644 --- a/.github/workflows/pr-validation.yaml +++ b/.github/workflows/pr-validation.yaml @@ -17,7 +17,7 @@ jobs: - name: golangci-lint uses: golangci/golangci-lint-action@v3.3.1 with: - version: v1.48.0 + version: v1.50.1 fmt: runs-on: ubuntu-latest diff --git a/.golangci.yaml b/.golangci.yaml index 37a4747..0867a2c 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -99,6 +99,8 @@ linters-settings: - truncatecmp - ruleguard - nestingreduce + disabled-checks: + - newDeref enabled-tags: - performance disabled-tags: diff --git a/README.md b/README.md index 0d26696..b8df471 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,14 @@ This is a middleware for http to make it easy to use OpenID Connect. +## Changelog + +Below, large (breaking) changes will be documented: + +### v0.0.37 + +From `v0.0.37` and forward, the `options.WithRequiredClaims()` has been deprecated and generics are used to provide the claims type. A new validation function can be provided instead of `options.WithRequiredClaims()`. If you don't need claims validation, you can pass `nil` instead of a `ClaimsValidationFn`. + ## Stability notice This library is under active development and the api will have breaking changes until `v0.1.0` - after that only breaking changes will be introduced between minor versions (`v0.1.0` -> `v0.2.0`). @@ -29,6 +37,57 @@ This library is under active development and the api will have breaking changes Import: `"github.com/xenitab/go-oidc-middleware/options"` +### Claims validation example + +From `v0.0.37` and forward, claim validation is done using a `ClaimsValidationFn`. The below examples will use the following claims type and validation function: + +```go +type AzureADClaims struct { + Aio string `json:"aio"` + Audience []string `json:"aud"` + Azp string `json:"azp"` + Azpacr string `json:"azpacr"` + ExpiresAt time.Time `json:"exp"` + IssuedAt time.Time `json:"iat"` + Idp string `json:"idp"` + Issuer string `json:"iss"` + Name string `json:"name"` + NotBefore time.Time `json:"nbf"` + Oid string `json:"oid"` + PreferredUsername string `json:"preferred_username"` + Rh string `json:"rh"` + Scope string `json:"scp"` + Subject string `json:"sub"` + TenantId string `json:"tid"` + Uti string `json:"uti"` + TokenVersion string `json:"ver"` +} + +func GetAzureADClaimsValidationFn(requiredTenantId string) options.ClaimsValidationFn[AzureADClaims] { + return func(claims *AzureADClaims) error { + if requiredTenantId != "" && claims.TenantId != requiredTenantId { + return fmt.Errorf("tid claim is required to be %q but was: %s", requiredTenantId, claims.TenantId) + } + + return nil + } +} +``` + +If you don't want typed claims, use `type Claims map[string]interface{}` and provide it. If you don't want to use a `ClaimsValidationFn` (as it will provide the type) the handlers will need to be configured as below: + +```go +type Claims map[string]interface{} + +oidcHandler := oidchttp.New[Claims](h, nil, opts...) +``` + +or + +```go +oidcHandler := oidchttp.New[map[string]interface{}](h, nil, opts...) +``` + ### net/http, mux & chi **Import** @@ -39,13 +98,11 @@ Import: `"github.com/xenitab/go-oidc-middleware/options"` ```go oidcHandler := oidchttp.New(h, + GetAzureADClaimsValidationFn(cfg.TenantID), options.WithIssuer(cfg.Issuer), options.WithRequiredTokenType("JWT"), options.WithRequiredAudience(cfg.Audience), options.WithFallbackSignatureAlgorithm(cfg.FallbackSignatureAlgorithm), - options.WithRequiredClaims(map[string]interface{}{ - "tid": cfg.TenantID, - }), ) ``` @@ -54,7 +111,7 @@ oidcHandler := oidchttp.New(h, ```go func newClaimsHandler() http.HandlerFunc { fn := func(w http.ResponseWriter, r *http.Request) { - claims, ok := r.Context().Value(options.DefaultClaimsContextKeyName).(map[string]interface{}) + claims, ok := r.Context().Value(options.DefaultClaimsContextKeyName).(AzureADClaims) if !ok { w.WriteHeader(http.StatusUnauthorized) return @@ -82,13 +139,11 @@ func newClaimsHandler() http.HandlerFunc { ```go oidcHandler := oidcgin.New( + GetAzureADClaimsValidationFn(cfg.TenantID), options.WithIssuer(cfg.Issuer), options.WithRequiredTokenType("JWT"), options.WithRequiredAudience(cfg.Audience), options.WithFallbackSignatureAlgorithm(cfg.FallbackSignatureAlgorithm), - options.WithRequiredClaims(map[string]interface{}{ - "tid": cfg.TenantID, - }), ) ``` @@ -103,7 +158,7 @@ func newClaimsHandler() gin.HandlerFunc { return } - claims, ok := claimsValue.(map[string]interface{}) + claims, ok := claimsValue.(AzureADClaims) if !ok { c.AbortWithStatus(http.StatusUnauthorized) return @@ -124,13 +179,11 @@ func newClaimsHandler() gin.HandlerFunc { ```go oidcHandler := oidcfiber.New( + GetAzureADClaimsValidationFn(cfg.TenantID), options.WithIssuer(cfg.Issuer), options.WithRequiredTokenType("JWT"), options.WithRequiredAudience(cfg.Audience), options.WithFallbackSignatureAlgorithm(cfg.FallbackSignatureAlgorithm), - options.WithRequiredClaims(map[string]interface{}{ - "tid": cfg.TenantID, - }), ) ``` @@ -139,7 +192,7 @@ oidcHandler := oidcfiber.New( ```go func newClaimsHandler() fiber.Handler { return func(c *fiber.Ctx) error { - claims, ok := c.Locals("claims").(map[string]interface{}) + claims, ok := c.Locals("claims").(AzureADClaims) if !ok { return c.SendStatus(fiber.StatusUnauthorized) } @@ -160,13 +213,11 @@ func newClaimsHandler() fiber.Handler { ```go e.Use(middleware.JWTWithConfig(middleware.JWTConfig{ ParseTokenFunc: oidcechojwt.New( + GetAzureADClaimsValidationFn(cfg.TenantID), options.WithIssuer(cfg.Issuer), options.WithRequiredTokenType("JWT"), options.WithRequiredAudience(cfg.Audience), options.WithFallbackSignatureAlgorithm(cfg.FallbackSignatureAlgorithm), - options.WithRequiredClaims(map[string]interface{}{ - "tid": cfg.TenantID, - }), ), })) ``` @@ -175,7 +226,7 @@ e.Use(middleware.JWTWithConfig(middleware.JWTConfig{ ```go func newClaimsHandler(c echo.Context) error { - claims, ok := c.Get("user").(map[string]interface{}) + claims, ok := c.Get("user").(AzureADClaims) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "invalid token") } @@ -194,13 +245,11 @@ func newClaimsHandler(c echo.Context) error { ```go oidcTokenHandler := oidctoken.New(h, + GetAzureADClaimsValidationFn(cfg.TenantID), options.WithIssuer(cfg.Issuer), options.WithRequiredTokenType("JWT"), options.WithRequiredAudience(cfg.Audience), options.WithFallbackSignatureAlgorithm(cfg.FallbackSignatureAlgorithm), - options.WithRequiredClaims(map[string]interface{}{ - "tid": cfg.TenantID, - }), ) // oidctoken.GetTokenString is optional, but you will need the JWT token as a string @@ -217,54 +266,15 @@ if err != nil { ## Other options -### Deeply nested required claims - -If you want to use `options.WithRequiredClaims()` with nested values, you need to specify the actual type when configuring it and not an interface and the middleware will use this to infer what types the token claims are. - -Example claims could look like this: - -```json -{ - "foo": { - "bar": ["uno", "dos", "baz", "tres"] - } -} -``` - -This would then be interpreted as the following inside the code: - -```go -"foo": map[string]interface {}{ - "bar":[]interface {}{ - "uno", - "dos", - "baz", - "tres" - }, -} -``` - -If you want to require the claim `foo.bar` to contain the value `baz`, it would look like this: - -```go -options.WithRequiredClaims(map[string]interface{}{ - "foo": map[string][]string{ - "bar": {"baz"}, - } -}) -``` - ### Extract token from multiple headers Example for `Authorization` and `Foo` headers. If token is found in `Authorization`, `Foo` will not be tried. If `Authorization` extraction fails but there's a header `Foo = Bar_baz` then `baz` would be extracted as the token. ```go oidcHandler := oidcgin.New( + GetAzureADClaimsValidationFn(cfg.TenantID), options.WithIssuer(cfg.Issuer), options.WithFallbackSignatureAlgorithm(cfg.FallbackSignatureAlgorithm), - options.WithRequiredClaims(map[string]interface{}{ - "cid": cfg.ClientID, - }), options.WithTokenString( options.WithTokenStringHeaderName("Authorization"), options.WithTokenStringTokenPrefix("Bearer "), @@ -284,11 +294,9 @@ The following would be used by a the Kubernetes api server, where the kubernetes ```go oidcHandler := oidcgin.New( + GetAzureADClaimsValidationFn(cfg.TenantID), options.WithIssuer(cfg.Issuer), options.WithFallbackSignatureAlgorithm(cfg.FallbackSignatureAlgorithm), - options.WithRequiredClaims(map[string]interface{}{ - "cid": cfg.ClientID, - }), options.WithTokenString( options.WithTokenStringHeaderName("Authorization"), options.WithTokenStringTokenPrefix("Bearer "), @@ -319,11 +327,9 @@ errorHandler := func(description options.ErrorDescription, err error) { } oidcHandler := oidcgin.New( + GetAzureADClaimsValidationFn(cfg.TenantID), options.WithIssuer(cfg.Issuer), options.WithFallbackSignatureAlgorithm(cfg.FallbackSignatureAlgorithm), - options.WithRequiredClaims(map[string]interface{}{ - "cid": cfg.ClientID, - }), options.WithErrorHandler(errorHandler), ) ``` @@ -348,6 +354,7 @@ func TestFoobar(t *testing.T) { [...] oidcHandler := oidchttp.New(h, + GetAzureADClaimsValidationFn(cfg.TenantID), options.WithIssuer(op.GetURL(t)), options.WithRequiredTokenType("JWT+AT"), options.WithRequiredAudience("test-client"), diff --git a/examples/PROVIDER_AUTH0.md b/examples/PROVIDER_AUTH0.md index 08327ca..da0d814 100644 --- a/examples/PROVIDER_AUTH0.md +++ b/examples/PROVIDER_AUTH0.md @@ -8,7 +8,7 @@ Create an Auth0 account and an api as well as a native app. TOKEN_ISSUER="https://.auth0.com/" TOKEN_AUDIENCE="https://localhost:8081" CLIENT_ID="Auth0NativeAppClientID" -go run ./api/main.go --server [server] --provider auth0 --token-issuer ${TOKEN_ISSUER} --token-audience ${TOKEN_AUDIENCE} --required-claims azp:${CLIENT_ID} --port 8081 +go run ./api/main.go --server [server] --provider auth0 --token-issuer ${TOKEN_ISSUER} --token-audience ${TOKEN_AUDIENCE} --required-auth0-client-id ${CLIENT_ID} --port 8081 ``` ## Test with curl @@ -17,4 +17,4 @@ go run ./api/main.go --server [server] --provider auth0 --token-issuer ${TOKEN_I ACCESS_TOKEN=$(go run ./pkce-cli/main.go --issuer ${TOKEN_ISSUER} --client-id ${CLIENT_ID} --extra-authz-params audience:${TOKEN_AUDIENCE} | jq -r ".access_token") curl -s http://localhost:8081 | jq curl -s -H "Authorization: Bearer ${ACCESS_TOKEN}" http://localhost:8081 | jq -``` \ No newline at end of file +``` diff --git a/examples/PROVIDER_AZUREAD.md b/examples/PROVIDER_AZUREAD.md index 1d9f560..e571063 100644 --- a/examples/PROVIDER_AZUREAD.md +++ b/examples/PROVIDER_AZUREAD.md @@ -27,7 +27,7 @@ az rest --method PATCH --uri "https://graph.microsoft.com/beta/applications/${AZ az rest --method PATCH --uri "https://graph.microsoft.com/beta/applications/${AZ_APP_OBJECT_ID}" --body "{\"api\":{\"preAuthorizedApplications\":[{\"appId\":\"04b07795-8ddb-461a-bbee-02f9e1bf7b46\",\"permissionIds\":[\"${AZ_APP_PERMISSION_ID}\"]}]}}" # Add PKCE-CLI as allowed client az rest --method PATCH --uri "https://graph.microsoft.com/beta/applications/${AZ_APP_OBJECT_ID}" --body "{\"api\":{\"preAuthorizedApplications\":[{\"appId\":\"04b07795-8ddb-461a-bbee-02f9e1bf7b46\",\"permissionIds\":[\"${AZ_APP_PERMISSION_ID}\"]},{\"appId\":\"${AZ_APP_PKCECLI_ID}\",\"permissionIds\":[\"${AZ_APP_PERMISSION_ID}\"]}]}}" -``` +``` ## Run web server @@ -35,7 +35,7 @@ az rest --method PATCH --uri "https://graph.microsoft.com/beta/applications/${AZ TENANT_ID=$(az account show -o json | jq -r .tenantId) TOKEN_ISSUER="https://login.microsoftonline.com/${TENANT_ID}/v2.0" TOKEN_AUDIENCE=$(az ad app list --identifier-uri ${AZ_APP_URI} | jq -r ".[0].appId") -go run ./api/main.go --server [server] --provider azuread --token-issuer ${TOKEN_ISSUER} --token-audience ${TOKEN_AUDIENCE} --required-claims tid:${TENANT_ID} --port 8081 +go run ./api/main.go --server [server] --provider azuread --token-issuer ${TOKEN_ISSUER} --token-audience ${TOKEN_AUDIENCE} --required-azure-ad-tenant-id ${TENANT_ID} --port 8081 ``` ## Test with curl diff --git a/examples/PROVIDER_COGNITO.md b/examples/PROVIDER_COGNITO.md index 7ce9133..e911ff7 100644 --- a/examples/PROVIDER_COGNITO.md +++ b/examples/PROVIDER_COGNITO.md @@ -7,7 +7,7 @@ Create a Cognito user pool, app client and configure the callback for the app cl ```shell TOKEN_ISSUER="https://cognito-idp.{region}.amazonaws.com/{userPoolId}" CLIENT_ID="CognitoClientID" -go run ./api/main.go --server [server] --provider cognito --token-issuer ${TOKEN_ISSUER} --required-claims client_id:${CLIENT_ID} --port 8081 +go run ./api/main.go --server [server] --provider cognito --token-issuer ${TOKEN_ISSUER} --required-cognito-client-id ${CLIENT_ID} --port 8081 ``` ## Test with curl @@ -17,4 +17,4 @@ CLIENT_SECRET="CognitoAppSecret" ACCESS_TOKEN=$(go run ./pkce-cli/main.go --issuer ${TOKEN_ISSUER} --client-id ${CLIENT_ID} --extra-token-params client_secret:${CLIENT_SECRET} | jq -r ".access_token") curl -s http://localhost:8081 | jq curl -s -H "Authorization: Bearer ${ACCESS_TOKEN}" http://localhost:8081 | jq -``` \ No newline at end of file +``` diff --git a/examples/PROVIDER_OKTA.md b/examples/PROVIDER_OKTA.md index fd0ff26..7e84ee6 100644 --- a/examples/PROVIDER_OKTA.md +++ b/examples/PROVIDER_OKTA.md @@ -7,7 +7,7 @@ Create an Okta organization and a native app. Copy the issuer and client id. ```shell TOKEN_ISSUER="https://.okta.com/oauth2/default" CLIENT_ID="OktaClientID" -go run ./api/main.go --server [server] --provider okta --token-issuer ${TOKEN_ISSUER} --required-claims cid:${CLIENT_ID} --port 8081 +go run ./api/main.go --server [server] --provider okta --token-issuer ${TOKEN_ISSUER} --required-okta-client-id ${CLIENT_ID} --port 8081 ``` ## Test with curl @@ -16,4 +16,4 @@ go run ./api/main.go --server [server] --provider okta --token-issuer ${TOKEN_IS ACCESS_TOKEN=$(go run ./pkce-cli/main.go --issuer ${TOKEN_ISSUER} --client-id ${CLIENT_ID} | jq -r ".access_token") curl -s http://localhost:8081 | jq curl -s -H "Authorization: Bearer ${ACCESS_TOKEN}" http://localhost:8081 | jq -``` \ No newline at end of file +``` diff --git a/examples/api/main.go b/examples/api/main.go index 52674b4..90e5f2d 100644 --- a/examples/api/main.go +++ b/examples/api/main.go @@ -29,18 +29,13 @@ func main() { func run(cfg shared.RuntimeConfig) error { var opts []options.Option - requiredClaims := make(map[string]interface{}) - for k, v := range cfg.RequiredClaims { - requiredClaims[k] = v - } - switch cfg.Provider { case shared.Auth0Provider: inputs := map[string]string{ "issuer": cfg.Issuer, "audience": cfg.Audience, "fallbackSignatureAlgorithm": cfg.FallbackSignatureAlgorithm, - "requiredClaims azp": cfg.RequiredClaims["azp"], + "RequiredAuth0ClientId": cfg.RequiredAuth0ClientId, } err := stringNotEmpty(inputs) @@ -53,14 +48,15 @@ func run(cfg shared.RuntimeConfig) error { options.WithRequiredTokenType("JWT"), options.WithRequiredAudience(cfg.Audience), options.WithFallbackSignatureAlgorithm(cfg.FallbackSignatureAlgorithm), - options.WithRequiredClaims(requiredClaims), } + claimsValidationFn := shared.GetAuth0ClaimsValidationFn(cfg.RequiredAuth0ClientId) + return getHandler(cfg, claimsValidationFn, opts...) case shared.AzureADProvider: inputs := map[string]string{ "issuer": cfg.Issuer, "audience": cfg.Audience, "fallbackSignatureAlgorithm": cfg.FallbackSignatureAlgorithm, - "requiredClaims tid": cfg.RequiredClaims["tid"], + "RequiredAzureADTenantId": cfg.RequiredAzureADTenantId, } err := stringNotEmpty(inputs) @@ -73,13 +69,14 @@ func run(cfg shared.RuntimeConfig) error { options.WithRequiredTokenType("JWT"), options.WithRequiredAudience(cfg.Audience), options.WithFallbackSignatureAlgorithm(cfg.FallbackSignatureAlgorithm), - options.WithRequiredClaims(requiredClaims), } + claimsValidationFn := shared.GetAzureADClaimsValidationFn(cfg.RequiredAzureADTenantId) + return getHandler(cfg, claimsValidationFn, opts...) case shared.CognitoProvider: inputs := map[string]string{ "issuer": cfg.Issuer, "fallbackSignatureAlgorithm": cfg.FallbackSignatureAlgorithm, - "requiredClaims client_id": cfg.RequiredClaims["client_id"], + "RequiredCognitoClientId": cfg.RequiredCognitoClientId, } err := stringNotEmpty(inputs) @@ -90,13 +87,14 @@ func run(cfg shared.RuntimeConfig) error { opts = []options.Option{ options.WithIssuer(cfg.Issuer), options.WithFallbackSignatureAlgorithm(cfg.FallbackSignatureAlgorithm), - options.WithRequiredClaims(requiredClaims), } + claimsValidationFn := shared.GetCognitoClaimsValidationFn(cfg.RequiredCognitoClientId) + return getHandler(cfg, claimsValidationFn, opts...) case shared.OktaProvider: inputs := map[string]string{ "issuer": cfg.Issuer, "fallbackSignatureAlgorithm": cfg.FallbackSignatureAlgorithm, - "requiredClaims cid": cfg.RequiredClaims["cid"], + "RequiredOktaClientId": cfg.RequiredOktaClientId, } err := stringNotEmpty(inputs) @@ -107,30 +105,33 @@ func run(cfg shared.RuntimeConfig) error { opts = []options.Option{ options.WithIssuer(cfg.Issuer), options.WithFallbackSignatureAlgorithm(cfg.FallbackSignatureAlgorithm), - options.WithRequiredClaims(requiredClaims), } + claimsValidationFn := shared.GetOktaClaimsValidationFn(cfg.RequiredOktaClientId) + return getHandler(cfg, claimsValidationFn, opts...) default: return fmt.Errorf("unknown provider: %s", cfg.Provider) } +} +func getHandler[T any](cfg shared.RuntimeConfig, claimsValidationFn options.ClaimsValidationFn[T], opts ...options.Option) error { switch cfg.Server { case shared.HttpServer: - h := shared.NewHttpClaimsHandler() - oidcHandler := oidchttp.New(h, opts...) + h := shared.NewHttpClaimsHandler[T]() + oidcHandler := oidchttp.New(h, claimsValidationFn, opts...) return shared.RunHttp(oidcHandler, cfg.Address, cfg.Port) case shared.GinServer: - oidcHandler := oidcgin.New(opts...) + oidcHandler := oidcgin.New(claimsValidationFn, opts...) - return shared.RunGin(oidcHandler, cfg.Address, cfg.Port) + return shared.RunGin[T](oidcHandler, cfg.Address, cfg.Port) case shared.EchoJwtServer: - parseToken := oidcechojwt.New(opts...) + parseToken := oidcechojwt.New(claimsValidationFn, opts...) - return shared.RunEchoJWT(parseToken, cfg.Address, cfg.Port) + return shared.RunEchoJWT[T](parseToken, cfg.Address, cfg.Port) case shared.FiberServer: - oidcHandler := oidcfiber.New(opts...) + oidcHandler := oidcfiber.New(claimsValidationFn, opts...) - return shared.RunFiber(oidcHandler, cfg.Address, cfg.Port) + return shared.RunFiber[T](oidcHandler, cfg.Address, cfg.Port) default: return fmt.Errorf("unknown server: %s", cfg.Server) } diff --git a/examples/go.mod b/examples/go.mod index 25a0c73..8cb5e11 100644 --- a/examples/go.mod +++ b/examples/go.mod @@ -3,11 +3,11 @@ module examples go 1.19 require ( - github.com/xenitab/go-oidc-middleware v0.0.35 - github.com/xenitab/go-oidc-middleware/oidcechojwt v0.0.35 - github.com/xenitab/go-oidc-middleware/oidcfiber v0.0.35 - github.com/xenitab/go-oidc-middleware/oidcgin v0.0.35 - github.com/xenitab/go-oidc-middleware/oidchttp v0.0.35 + github.com/xenitab/go-oidc-middleware v0.0.36 + github.com/xenitab/go-oidc-middleware/oidcechojwt v0.0.36 + github.com/xenitab/go-oidc-middleware/oidcfiber v0.0.36 + github.com/xenitab/go-oidc-middleware/oidcgin v0.0.36 + github.com/xenitab/go-oidc-middleware/oidchttp v0.0.36 ) require ( diff --git a/examples/go.sum b/examples/go.sum index db87859..c830e63 100644 --- a/examples/go.sum +++ b/examples/go.sum @@ -134,16 +134,16 @@ github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+ github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8= github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc= github.com/xenitab/dispans v0.0.10 h1:S+gSUM14rDJWK7MYNrjb8JbjeQPip6mlNJyLX+g7Agc= -github.com/xenitab/go-oidc-middleware v0.0.35 h1:9u1rQ/MqYXg4IpeJcOKyCSA2Xo8Pji3IiIZ+ZbAoqFI= -github.com/xenitab/go-oidc-middleware v0.0.35/go.mod h1:a8lpsTfdmiEsbclX4oIQE2gXj+8cYLLGRKUtgccwR94= -github.com/xenitab/go-oidc-middleware/oidcechojwt v0.0.35 h1:MTjxN9H5Ymwo6LTb3nR1btLReVkV7daUwkfGqqIBpgA= -github.com/xenitab/go-oidc-middleware/oidcechojwt v0.0.35/go.mod h1:DvLMJloowJ60zGx/OpI/TclBsEiQHVldVgJi78MKmFM= -github.com/xenitab/go-oidc-middleware/oidcfiber v0.0.35 h1:JwTDajjdApaJHotl6Gqsw59c/H8Nie32TE9c6/qOwj8= -github.com/xenitab/go-oidc-middleware/oidcfiber v0.0.35/go.mod h1:wyIeNsvNni7++FgMHNXbuXlPz4ws5Q8O+d7HjWAxePA= -github.com/xenitab/go-oidc-middleware/oidcgin v0.0.35 h1:FZwKzh5KdCQ+hBBESL+SNihWxWj+iv8iHvmLqcKaYGk= -github.com/xenitab/go-oidc-middleware/oidcgin v0.0.35/go.mod h1:h1g6NRJybNyL72U6irl5zext94tnqPClvrRZJz/DDsY= -github.com/xenitab/go-oidc-middleware/oidchttp v0.0.35 h1:6zeXgAcQcs48TzTHQDwPgnSFfCgFFDEKQTND6NXdxdQ= -github.com/xenitab/go-oidc-middleware/oidchttp v0.0.35/go.mod h1:6CyXTQ3EuILXY6Aj+N+P3rNKVqbJ2xLQOShKhgptFSY= +github.com/xenitab/go-oidc-middleware v0.0.36 h1:iBm+8usJZg9mCWrZWliHpzNatWn6g31AAcbb1q4M6go= +github.com/xenitab/go-oidc-middleware v0.0.36/go.mod h1:dUakIYup0Grr7Bn/88xTTKtlS6MWoWZtrrnzdt/SUZU= +github.com/xenitab/go-oidc-middleware/oidcechojwt v0.0.36 h1:zJD6glq1w5BGyNUejenGnN/gIpPtzmv+XSGtvWrWnrE= +github.com/xenitab/go-oidc-middleware/oidcechojwt v0.0.36/go.mod h1:XQY+d76KId7Q0FSs2Bos4NhJHElwNtdTMGwTymUmEqg= +github.com/xenitab/go-oidc-middleware/oidcfiber v0.0.36 h1:0Xgthl0LFDinlxdSOOF/xTEaa7daMGNqejL0oSbLLuQ= +github.com/xenitab/go-oidc-middleware/oidcfiber v0.0.36/go.mod h1:J/t3JB185krNB5SXAnsaip+FysV9AKJxeg5GPfAUEHM= +github.com/xenitab/go-oidc-middleware/oidcgin v0.0.36 h1:MnAWuFadBi1YJuR3xlVABAZ1DB5vTATKjsq43CKesNs= +github.com/xenitab/go-oidc-middleware/oidcgin v0.0.36/go.mod h1:nmbp/vOda0HjOyz83IqveGhY90oH6Cln2MzVmVKB1I8= +github.com/xenitab/go-oidc-middleware/oidchttp v0.0.36 h1:1KAdu+gQbHsj+3gsSDm1VOqvHA0g65qOgB7Wgun3ddU= +github.com/xenitab/go-oidc-middleware/oidchttp v0.0.36/go.mod h1:BTN8cQlYi+ZS4Hqwx7P8LIrR252K5oO2dcTAJY7kaAg= github.com/zclconf/go-cty v1.12.1 h1:PcupnljUm9EIvbgSHQnHhUr3fO6oFmkOrvs2BAFNXXY= github.com/zclconf/go-cty v1.12.1/go.mod h1:s9IfD1LK5ccNMSWCVFCE2rJfHiZgi7JijgeWIMfhLvA= go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= diff --git a/examples/shared/runtime.go b/examples/shared/runtime.go index 5cea417..039faef 100644 --- a/examples/shared/runtime.go +++ b/examples/shared/runtime.go @@ -2,8 +2,10 @@ package shared import ( "fmt" + "time" "github.com/cristalhq/aconfig" + "github.com/xenitab/go-oidc-middleware/options" ) type Server string @@ -43,15 +45,18 @@ func (p Provider) Validate() error { } type RuntimeConfig struct { - Server Server `flag:"server" env:"server" usage:"what server to use" required:"true"` - Provider Provider `flag:"provider" env:"PROVIDER" usage:"what provider to use" required:"true"` - Address string `flag:"address" env:"ADDRESS" default:"127.0.0.1" usage:"address webserver will listen to"` - Port int `flag:"port" env:"PORT" default:"8080" usage:"port webserver will listen to"` - Issuer string `flag:"token-issuer" env:"TOKEN_ISSUER" usage:"the oidc issuer url for tokens"` - Audience string `flag:"token-audience" env:"TOKEN_AUDIENCE" usage:"the audience that tokens need to contain"` - ClientID string `flag:"client-id" env:"CLIENT_ID" usage:"the client id that tokens need to contain"` - FallbackSignatureAlgorithm string `flag:"fallback-signature-algorithm" env:"FALLBACK_SIGNATURE_ALGORITHM" default:"RS256" usage:"if the issue jwks doesn't contain key alg, use the following signature algorithm to verify the signature of the tokens"` - RequiredClaims map[string]string `flag:"required-claims" env:"REQUIRED_CLAIMS" usage:"adds required claims"` + Server Server `flag:"server" env:"server" usage:"what server to use" required:"true"` + Provider Provider `flag:"provider" env:"PROVIDER" usage:"what provider to use" required:"true"` + Address string `flag:"address" env:"ADDRESS" default:"127.0.0.1" usage:"address webserver will listen to"` + Port int `flag:"port" env:"PORT" default:"8080" usage:"port webserver will listen to"` + Issuer string `flag:"token-issuer" env:"TOKEN_ISSUER" usage:"the oidc issuer url for tokens"` + Audience string `flag:"token-audience" env:"TOKEN_AUDIENCE" usage:"the audience that tokens need to contain"` + ClientID string `flag:"client-id" env:"CLIENT_ID" usage:"the client id that tokens need to contain"` + FallbackSignatureAlgorithm string `flag:"fallback-signature-algorithm" env:"FALLBACK_SIGNATURE_ALGORITHM" default:"RS256" usage:"if the issue jwks doesn't contain key alg, use the following signature algorithm to verify the signature of the tokens"` + RequiredAuth0ClientId string `flag:"required-auth0-client-id" env:"REQUIRED_AUTH0_CLIENT_ID" usage:"the required Auth0 Client ID"` + RequiredAzureADTenantId string `flag:"required-azure-ad-tenant-id" env:"REQUIRED_AZURE_AD_TENANT_ID" usage:"the required Azure AD Tenant ID"` + RequiredCognitoClientId string `flag:"required-cognito-client-id" env:"REQUIRED_COGNITO_CLIENT_ID" usage:"the required Cognito Client ID"` + RequiredOktaClientId string `flag:"required-okta-client-id" env:"REQUIRED_OKTA_CLIENT_ID" usage:"the required Okta Client ID"` } func NewRuntimeConfig() (RuntimeConfig, error) { @@ -83,3 +88,104 @@ func NewRuntimeConfig() (RuntimeConfig, error) { return cfg, nil } + +type Auth0Claims struct { + Audience []string `json:"aud"` + ClientId string `json:"azp"` + ExpiresAt time.Time `json:"exp"` + IssuedAt time.Time `json:"iat"` + Issuer string `json:"iss"` + Scope string `json:"scope"` + Subject string `json:"sub"` +} + +func GetAuth0ClaimsValidationFn(requiredClientId string) options.ClaimsValidationFn[Auth0Claims] { + return func(claims *Auth0Claims) error { + if requiredClientId != "" && claims.ClientId != requiredClientId { + return fmt.Errorf("azp claim is required to be %q but was: %s", requiredClientId, claims.ClientId) + } + + return nil + } +} + +type AzureADClaims struct { + Aio string `json:"aio"` + Audience []string `json:"aud"` + Azp string `json:"azp"` + Azpacr string `json:"azpacr"` + ExpiresAt time.Time `json:"exp"` + IssuedAt time.Time `json:"iat"` + Idp string `json:"idp"` + Issuer string `json:"iss"` + Name string `json:"name"` + NotBefore time.Time `json:"nbf"` + Oid string `json:"oid"` + PreferredUsername string `json:"preferred_username"` + Rh string `json:"rh"` + Scope string `json:"scp"` + Subject string `json:"sub"` + TenantId string `json:"tid"` + Uti string `json:"uti"` + TokenVersion string `json:"ver"` +} + +func GetAzureADClaimsValidationFn(requiredTenantId string) options.ClaimsValidationFn[AzureADClaims] { + return func(claims *AzureADClaims) error { + if requiredTenantId != "" && claims.TenantId != requiredTenantId { + return fmt.Errorf("tid claim is required to be %q but was: %s", requiredTenantId, claims.TenantId) + } + + return nil + } +} + +type CognitoClaims struct { + AuthTime int64 `json:"auth_time"` + ClientId string `json:"client_id"` + EventId string `json:"event_id"` + ExpiresAt time.Time `json:"exp"` + IssuedAt time.Time `json:"iat"` + Issuer string `json:"iss"` + Jti string `json:"jti"` + OriginJti string `json:"origin_jti"` + Scope string `json:"scope"` + Subject string `json:"sub"` + TokenUse string `json:"token_use"` + Username string `json:"username"` + Version int `json:"version"` +} + +func GetCognitoClaimsValidationFn(requiredClientId string) options.ClaimsValidationFn[CognitoClaims] { + return func(claims *CognitoClaims) error { + if requiredClientId != "" && claims.ClientId != requiredClientId { + return fmt.Errorf("client_id claim is required to be %q but was: %s", requiredClientId, claims.ClientId) + } + + return nil + } +} + +type OktaClaims struct { + Audience []string `json:"aud"` + AuthTime int64 `json:"auth_time"` + ClientId string `json:"cid"` + ExpiresAt time.Time `json:"exp"` + IssuedAt time.Time `json:"iat"` + Issuer string `json:"iss"` + Jti string `json:"jti"` + Scope []string `json:"scp"` + Subject string `json:"sub"` + Uid string `json:"uid"` + Version int `json:"ver"` +} + +func GetOktaClaimsValidationFn(requiredClientId string) options.ClaimsValidationFn[OktaClaims] { + return func(claims *OktaClaims) error { + if requiredClientId != "" && claims.ClientId != requiredClientId { + return fmt.Errorf("cid claim is required to be %q but was: %s", requiredClientId, claims.ClientId) + } + + return nil + } +} diff --git a/examples/shared/server_echo_jwt.go b/examples/shared/server_echo_jwt.go index 7d1b620..8f6f3c5 100644 --- a/examples/shared/server_echo_jwt.go +++ b/examples/shared/server_echo_jwt.go @@ -11,8 +11,8 @@ import ( type echoJWTParseTokenFunc func(auth string, c echo.Context) (interface{}, error) -func newEchoJWTClaimsHandler(c echo.Context) error { - claims, ok := c.Get("user").(map[string]interface{}) +func newEchoJWTClaimsHandler[T any](c echo.Context) error { + claims, ok := c.Get("user").(T) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "invalid token") } @@ -20,7 +20,7 @@ func newEchoJWTClaimsHandler(c echo.Context) error { return c.JSON(http.StatusOK, claims) } -func RunEchoJWT(parseToken echoJWTParseTokenFunc, address string, port int) error { +func RunEchoJWT[T any](parseToken echoJWTParseTokenFunc, address string, port int) error { e := echo.New() e.HideBanner = true @@ -32,7 +32,7 @@ func RunEchoJWT(parseToken echoJWTParseTokenFunc, address string, port int) erro ParseTokenFunc: parseToken, })) - handler := newEchoJWTClaimsHandler + handler := newEchoJWTClaimsHandler[T] e.GET("/", handler) diff --git a/examples/shared/server_fiber.go b/examples/shared/server_fiber.go index 86c49b8..6552b3d 100644 --- a/examples/shared/server_fiber.go +++ b/examples/shared/server_fiber.go @@ -12,9 +12,9 @@ import ( "golang.org/x/sync/errgroup" ) -func newFiberClaimsHandler() fiber.Handler { +func newFiberClaimsHandler[T any]() fiber.Handler { return func(c *fiber.Ctx) error { - claims, ok := c.Locals("claims").(map[string]interface{}) + claims, ok := c.Locals("claims").(T) if !ok { return c.SendStatus(fiber.StatusUnauthorized) } @@ -23,7 +23,7 @@ func newFiberClaimsHandler() fiber.Handler { } } -func RunFiber(oidcHandler fiber.Handler, address string, port int) error { +func RunFiber[T any](oidcHandler fiber.Handler, address string, port int) error { ctx := context.Background() ctx, cancel := context.WithCancel(ctx) @@ -39,7 +39,7 @@ func RunFiber(oidcHandler fiber.Handler, address string, port int) error { app.Use(oidcHandler) - claimsHandler := newFiberClaimsHandler() + claimsHandler := newFiberClaimsHandler[T]() app.Get("/", claimsHandler) g.Go(func() error { diff --git a/examples/shared/server_gin.go b/examples/shared/server_gin.go index 81df6c0..c7abf6b 100644 --- a/examples/shared/server_gin.go +++ b/examples/shared/server_gin.go @@ -8,7 +8,7 @@ import ( "github.com/gin-gonic/gin" ) -func newGinClaimsHandler() gin.HandlerFunc { +func newGinClaimsHandler[T any]() gin.HandlerFunc { return func(c *gin.Context) { claimsValue, found := c.Get("claims") if !found { @@ -16,7 +16,7 @@ func newGinClaimsHandler() gin.HandlerFunc { return } - claims, ok := claimsValue.(map[string]interface{}) + claims, ok := claimsValue.(T) if !ok { c.AbortWithStatus(http.StatusUnauthorized) return @@ -26,7 +26,7 @@ func newGinClaimsHandler() gin.HandlerFunc { } } -func RunGin(oidcHandler gin.HandlerFunc, address string, port int) error { +func RunGin[T any](oidcHandler gin.HandlerFunc, address string, port int) error { addr := net.JoinHostPort(address, fmt.Sprintf("%d", port)) gin.SetMode(gin.ReleaseMode) @@ -34,7 +34,7 @@ func RunGin(oidcHandler gin.HandlerFunc, address string, port int) error { r.Use(oidcHandler) - claimsHandler := newGinClaimsHandler() + claimsHandler := newGinClaimsHandler[T]() r.GET("/", claimsHandler) return r.Run(addr) diff --git a/examples/shared/server_http.go b/examples/shared/server_http.go index 3027424..044a9a2 100644 --- a/examples/shared/server_http.go +++ b/examples/shared/server_http.go @@ -16,9 +16,9 @@ import ( "golang.org/x/sync/errgroup" ) -func NewHttpClaimsHandler() http.HandlerFunc { +func NewHttpClaimsHandler[T any]() http.HandlerFunc { fn := func(w http.ResponseWriter, r *http.Request) { - claims, ok := r.Context().Value(options.DefaultClaimsContextKeyName).(map[string]interface{}) + claims, ok := r.Context().Value(options.DefaultClaimsContextKeyName).(T) if !ok { w.WriteHeader(http.StatusUnauthorized) return diff --git a/go.mod b/go.mod index 2bde93a..8ec1de9 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,6 @@ require ( github.com/lestrrat-go/jwx v1.2.25 github.com/stretchr/testify v1.8.1 github.com/xenitab/dispans v0.0.10 - github.com/zclconf/go-cty v1.12.1 go.uber.org/ratelimit v0.2.0 golang.org/x/sync v0.1.0 ) @@ -44,7 +43,6 @@ require ( golang.org/x/crypto v0.3.0 // indirect golang.org/x/net v0.2.0 // indirect golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8 // indirect - golang.org/x/text v0.4.0 // indirect google.golang.org/appengine v1.6.7 // indirect google.golang.org/protobuf v1.28.0 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect diff --git a/go.sum b/go.sum index 88a3530..e4c936d 100644 --- a/go.sum +++ b/go.sum @@ -274,8 +274,6 @@ github.com/yudai/pp v2.0.1+incompatible/go.mod h1:PuxR/8QJ7cyCkFp/aUDS+JY727OFEZ github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= -github.com/zclconf/go-cty v1.12.1 h1:PcupnljUm9EIvbgSHQnHhUr3fO6oFmkOrvs2BAFNXXY= -github.com/zclconf/go-cty v1.12.1/go.mod h1:s9IfD1LK5ccNMSWCVFCE2rJfHiZgi7JijgeWIMfhLvA= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= @@ -419,8 +417,6 @@ golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.4.0 h1:BrVqGRd7+k1DiOgtnFvAkoQEWQvBc25ouMJM6429SFg= -golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= diff --git a/internal/coverage/go.mod b/internal/coverage/go.mod index 1ec9955..a11f6f0 100644 --- a/internal/coverage/go.mod +++ b/internal/coverage/go.mod @@ -3,11 +3,11 @@ module coverage go 1.19 require ( - github.com/xenitab/go-oidc-middleware v0.0.35 - github.com/xenitab/go-oidc-middleware/oidcechojwt v0.0.35 - github.com/xenitab/go-oidc-middleware/oidcfiber v0.0.35 - github.com/xenitab/go-oidc-middleware/oidcgin v0.0.35 - github.com/xenitab/go-oidc-middleware/oidchttp v0.0.35 + github.com/xenitab/go-oidc-middleware v0.0.36 + github.com/xenitab/go-oidc-middleware/oidcechojwt v0.0.36 + github.com/xenitab/go-oidc-middleware/oidcfiber v0.0.36 + github.com/xenitab/go-oidc-middleware/oidcgin v0.0.36 + github.com/xenitab/go-oidc-middleware/oidchttp v0.0.36 ) require ( diff --git a/internal/coverage/go.sum b/internal/coverage/go.sum index c8174f3..5ddee90 100644 --- a/internal/coverage/go.sum +++ b/internal/coverage/go.sum @@ -129,16 +129,16 @@ github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+ github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8= github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc= github.com/xenitab/dispans v0.0.10 h1:S+gSUM14rDJWK7MYNrjb8JbjeQPip6mlNJyLX+g7Agc= -github.com/xenitab/go-oidc-middleware v0.0.35 h1:9u1rQ/MqYXg4IpeJcOKyCSA2Xo8Pji3IiIZ+ZbAoqFI= -github.com/xenitab/go-oidc-middleware v0.0.35/go.mod h1:a8lpsTfdmiEsbclX4oIQE2gXj+8cYLLGRKUtgccwR94= -github.com/xenitab/go-oidc-middleware/oidcechojwt v0.0.35 h1:MTjxN9H5Ymwo6LTb3nR1btLReVkV7daUwkfGqqIBpgA= -github.com/xenitab/go-oidc-middleware/oidcechojwt v0.0.35/go.mod h1:DvLMJloowJ60zGx/OpI/TclBsEiQHVldVgJi78MKmFM= -github.com/xenitab/go-oidc-middleware/oidcfiber v0.0.35 h1:JwTDajjdApaJHotl6Gqsw59c/H8Nie32TE9c6/qOwj8= -github.com/xenitab/go-oidc-middleware/oidcfiber v0.0.35/go.mod h1:wyIeNsvNni7++FgMHNXbuXlPz4ws5Q8O+d7HjWAxePA= -github.com/xenitab/go-oidc-middleware/oidcgin v0.0.35 h1:FZwKzh5KdCQ+hBBESL+SNihWxWj+iv8iHvmLqcKaYGk= -github.com/xenitab/go-oidc-middleware/oidcgin v0.0.35/go.mod h1:h1g6NRJybNyL72U6irl5zext94tnqPClvrRZJz/DDsY= -github.com/xenitab/go-oidc-middleware/oidchttp v0.0.35 h1:6zeXgAcQcs48TzTHQDwPgnSFfCgFFDEKQTND6NXdxdQ= -github.com/xenitab/go-oidc-middleware/oidchttp v0.0.35/go.mod h1:6CyXTQ3EuILXY6Aj+N+P3rNKVqbJ2xLQOShKhgptFSY= +github.com/xenitab/go-oidc-middleware v0.0.36 h1:iBm+8usJZg9mCWrZWliHpzNatWn6g31AAcbb1q4M6go= +github.com/xenitab/go-oidc-middleware v0.0.36/go.mod h1:dUakIYup0Grr7Bn/88xTTKtlS6MWoWZtrrnzdt/SUZU= +github.com/xenitab/go-oidc-middleware/oidcechojwt v0.0.36 h1:zJD6glq1w5BGyNUejenGnN/gIpPtzmv+XSGtvWrWnrE= +github.com/xenitab/go-oidc-middleware/oidcechojwt v0.0.36/go.mod h1:XQY+d76KId7Q0FSs2Bos4NhJHElwNtdTMGwTymUmEqg= +github.com/xenitab/go-oidc-middleware/oidcfiber v0.0.36 h1:0Xgthl0LFDinlxdSOOF/xTEaa7daMGNqejL0oSbLLuQ= +github.com/xenitab/go-oidc-middleware/oidcfiber v0.0.36/go.mod h1:J/t3JB185krNB5SXAnsaip+FysV9AKJxeg5GPfAUEHM= +github.com/xenitab/go-oidc-middleware/oidcgin v0.0.36 h1:MnAWuFadBi1YJuR3xlVABAZ1DB5vTATKjsq43CKesNs= +github.com/xenitab/go-oidc-middleware/oidcgin v0.0.36/go.mod h1:nmbp/vOda0HjOyz83IqveGhY90oH6Cln2MzVmVKB1I8= +github.com/xenitab/go-oidc-middleware/oidchttp v0.0.36 h1:1KAdu+gQbHsj+3gsSDm1VOqvHA0g65qOgB7Wgun3ddU= +github.com/xenitab/go-oidc-middleware/oidchttp v0.0.36/go.mod h1:BTN8cQlYi+ZS4Hqwx7P8LIrR252K5oO2dcTAJY7kaAg= github.com/zclconf/go-cty v1.12.1 h1:PcupnljUm9EIvbgSHQnHhUr3fO6oFmkOrvs2BAFNXXY= github.com/zclconf/go-cty v1.12.1/go.mod h1:s9IfD1LK5ccNMSWCVFCE2rJfHiZgi7JijgeWIMfhLvA= go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= diff --git a/internal/coverage/main.go b/internal/coverage/main.go index fa400a4..b7a0b5e 100644 --- a/internal/coverage/main.go +++ b/internal/coverage/main.go @@ -10,12 +10,14 @@ import ( "github.com/xenitab/go-oidc-middleware/options" ) +type testClaims map[string]interface{} + func main() { f := &foo{} - _ = oidcechojwt.New() - _ = oidcfiber.New() - _ = oidcgin.New() - _ = oidchttp.New(f) + _ = oidcechojwt.New[testClaims](nil) + _ = oidcfiber.New[testClaims](nil) + _ = oidcgin.New[testClaims](nil) + _ = oidchttp.New[testClaims](f, nil) _ = options.New() } diff --git a/internal/oidc/cty.go b/internal/oidc/cty.go deleted file mode 100644 index f4289eb..0000000 --- a/internal/oidc/cty.go +++ /dev/null @@ -1,191 +0,0 @@ -package oidc - -import ( - "fmt" - - "github.com/zclconf/go-cty/cty" - "github.com/zclconf/go-cty/cty/gocty" -) - -func getCtyValueWithImpliedType(a interface{}) (cty.Value, error) { - if a == nil { - return cty.NilVal, fmt.Errorf("input is nil") - } - - valueType, err := gocty.ImpliedType(a) - if err != nil { - return cty.NilVal, fmt.Errorf("unable to get cty.Type: %w", err) - } - - return getCtyValueWithType(a, valueType) -} - -func getCtyValueWithType(a interface{}, vt cty.Type) (cty.Value, error) { - if a == nil { - return cty.NilVal, fmt.Errorf("input value is nil") - } - - if vt == cty.NilType { - return cty.NilVal, fmt.Errorf("input type is nil") - } - - value, err := gocty.ToCtyValue(a, vt) - if err != nil { - // we should never receive this error - return cty.NilVal, fmt.Errorf("unable to get cty.Value: %w", err) - } - - return value, nil -} - -func getCtyValues(a interface{}, b interface{}) (cty.Value, cty.Value, error) { - first, err := getCtyValueWithImpliedType(a) - if err != nil { - return cty.NilVal, cty.NilVal, err - } - - second, err := getCtyValueWithType(b, first.Type()) - if err != nil { - return cty.NilVal, cty.NilVal, err - } - - return first, second, nil -} - -func isCtyPrimitiveValueValid(a cty.Value, b cty.Value) bool { - if !isCtyTypeSame(a, b) { - return false - } - - if getCtyType(a) != primitiveCtyType { - return false - } - - return a.Equals(b) == cty.True -} - -func isCtyListValid(a cty.Value, b cty.Value) bool { - if !isCtyTypeSame(a, b) { - return false - } - - if getCtyType(a) != listCtyType { - return false - } - - listA := a.AsValueSlice() - listB := b.AsValueSlice() - - for i := range listA { - if !ctyListContains(listB, listA[i]) { - return false - } - } - - return true -} - -func isCtyMapValid(a cty.Value, b cty.Value) bool { - if !isCtyTypeSame(a, b) { - return false - } - - if getCtyType(a) != mapCtyType { - return false - } - - mapA := a.AsValueMap() - mapB := b.AsValueMap() - - for k := range mapA { - mapBValue, ok := mapB[k] - if !ok { - return false - } - - err := isCtyValueValid(mapA[k], mapBValue) - if err != nil { - return false - } - } - - return true -} - -func ctyListContains(a []cty.Value, b cty.Value) bool { - for i := range a { - err := isCtyValueValid(a[i], b) - if err == nil { - return true - } - } - - return false -} - -func isCtyTypeSame(a cty.Value, b cty.Value) bool { - return a.Type().Equals(b.Type()) -} - -func isCtyValueValid(a cty.Value, b cty.Value) error { - if !isCtyTypeSame(a, b) { - return fmt.Errorf("should be type %s, was type: %s", a.Type().GoString(), b.Type().GoString()) - } - - switch getCtyType(a) { - case primitiveCtyType: - valid := isCtyPrimitiveValueValid(a, b) - if !valid { - return fmt.Errorf("should be %s, was: %s", a.GoString(), b.GoString()) - } - case listCtyType: - valid := isCtyListValid(a, b) - if !valid { - return fmt.Errorf("should contain %s, received: %s", a.GoString(), b.GoString()) - } - case mapCtyType: - valid := isCtyMapValid(a, b) - if !valid { - return fmt.Errorf("should contain %s, received: %s", a.GoString(), b.GoString()) - } - default: - return fmt.Errorf("non-implemented type - should be %s, received: %s", a.GoString(), b.GoString()) - } - - return nil -} - -type ctyType int - -const ( - unknownCtyType = iota - primitiveCtyType - listCtyType - mapCtyType -) - -func getCtyType(a cty.Value) ctyType { - if a.Type().IsPrimitiveType() { - return primitiveCtyType - } - - switch { - case a.Type().IsListType(): - return listCtyType - - // Adding the other cases to make it easier in the - // future to build logic for more types. - case a.Type().IsMapType(): - return mapCtyType - case a.Type().IsSetType(): - return unknownCtyType - case a.Type().IsObjectType(): - return unknownCtyType - case a.Type().IsTupleType(): - return unknownCtyType - case a.Type().IsCapsuleType(): - return unknownCtyType - } - - return unknownCtyType -} diff --git a/internal/oidc/cty_test.go b/internal/oidc/cty_test.go deleted file mode 100644 index 5160067..0000000 --- a/internal/oidc/cty_test.go +++ /dev/null @@ -1,642 +0,0 @@ -package oidc - -import ( - "testing" - - "github.com/stretchr/testify/require" - "github.com/zclconf/go-cty/cty" -) - -func TestGetCtyValueWithImpliedType(t *testing.T) { - cases := []struct { - testDescription string - input interface{} - expectedCtyType cty.Type - expectedError bool - }{ - { - testDescription: "string as cty.String", - input: "foo", - expectedCtyType: cty.String, - expectedError: false, - }, - { - testDescription: "string number as cty.String", - input: "1234", - expectedCtyType: cty.String, - expectedError: false, - }, - { - testDescription: "int as cty.Number", - input: int(1234), - expectedCtyType: cty.Number, - expectedError: false, - }, - { - testDescription: "float64 as cty.Number", - input: float64(1234), - expectedCtyType: cty.Number, - expectedError: false, - }, - { - testDescription: "list of strings as cty.List(cty.String)", - input: []string{"foo"}, - expectedCtyType: cty.List(cty.String), - expectedError: false, - }, - { - testDescription: "string map as cty.Map(cty.String)", - input: map[string]string{"foo": "bar"}, - expectedCtyType: cty.Map(cty.String), - expectedError: false, - }, - { - testDescription: "empty array of interfaces as cty.NilType and error", - input: []interface{}{}, - expectedCtyType: cty.NilType, - expectedError: true, - }, - { - testDescription: "nil as cty.NilType and error", - input: nil, - expectedCtyType: cty.NilType, - expectedError: true, - }, - } - - for i, c := range cases { - t.Logf("Test iteration %d: %s", i, c.testDescription) - - v, err := getCtyValueWithImpliedType(c.input) - require.Equal(t, c.expectedCtyType, v.Type()) - - if c.expectedError { - require.Error(t, err) - } else { - require.NoError(t, err) - } - } -} - -func TestGetCtyValueWithType(t *testing.T) { - cases := []struct { - testDescription string - input interface{} - inputType cty.Type - expectedCtyType cty.Type - expectedError bool - }{ - { - testDescription: "string as cty.String", - input: "foo", - inputType: cty.String, - expectedCtyType: cty.String, - expectedError: false, - }, - { - testDescription: "string number as cty.String", - input: "1234", - inputType: cty.String, - expectedCtyType: cty.String, - expectedError: false, - }, - { - testDescription: "int as cty.Number", - input: int(1234), - inputType: cty.Number, - expectedCtyType: cty.Number, - expectedError: false, - }, - { - testDescription: "float64 as cty.Number", - input: float64(1234), - inputType: cty.Number, - expectedCtyType: cty.Number, - expectedError: false, - }, - { - testDescription: "list of strings as cty.List(cty.String)", - input: []string{"foo"}, - inputType: cty.List(cty.String), - expectedCtyType: cty.List(cty.String), - expectedError: false, - }, - { - testDescription: "string map as cty.Map(cty.String)", - input: map[string]string{"foo": "bar"}, - inputType: cty.Map(cty.String), - expectedCtyType: cty.Map(cty.String), - expectedError: false, - }, - { - testDescription: "empty array of interfaces as cty.NilType and error", - input: []interface{}{}, - inputType: cty.NilType, - expectedCtyType: cty.NilType, - expectedError: true, - }, - { - testDescription: "nil as cty.NilType and error", - input: nil, - inputType: cty.NilType, - expectedCtyType: cty.NilType, - expectedError: true, - }, - { - testDescription: "interface list in an interface map", - input: map[string]interface{}{ - "foo": map[string]interface{}{ - "bar": []interface{}{ - "uno", - "dos", - "tres", - }, - }, - }, - inputType: cty.Map(cty.Map(cty.List(cty.String))), - expectedCtyType: cty.Map(cty.Map(cty.List(cty.String))), - expectedError: false, - }, - { - testDescription: "interface list in an interface map with wrong input type", - input: map[string]interface{}{ - "foo": map[string]interface{}{ - "bar": []interface{}{ - "uno", - "dos", - "tres", - }, - }, - }, - inputType: cty.String, - expectedCtyType: cty.NilType, - expectedError: true, - }, - } - - for i, c := range cases { - t.Logf("Test iteration %d: %s", i, c.testDescription) - - v, err := getCtyValueWithType(c.input, c.inputType) - require.Equal(t, c.expectedCtyType, v.Type()) - - if c.expectedError { - require.Error(t, err) - } else { - require.NoError(t, err) - } - } -} - -func TestGetCtyValues(t *testing.T) { - var a, b interface{} - - a = "foo" - b = "bar" - - ctyA, ctyB, err := getCtyValues(a, b) - - require.NoError(t, err) - require.Equal(t, "cty.StringVal(\"foo\")", ctyA.GoString()) - require.Equal(t, "cty.StringVal(\"bar\")", ctyB.GoString()) -} - -func TestIsCtyPrimitiveValueValid(t *testing.T) { - cases := []struct { - testDescription string - firstValue cty.Value - secondValue cty.Value - expectedResult bool - }{ - { - testDescription: "same input strings", - firstValue: cty.StringVal("foo"), - secondValue: cty.StringVal("foo"), - expectedResult: true, - }, - { - testDescription: "same input numbers", - firstValue: cty.NumberIntVal(1337), - secondValue: cty.NumberIntVal(1337), - expectedResult: true, - }, - { - testDescription: "different input strings", - firstValue: cty.StringVal("foo"), - secondValue: cty.StringVal("bar"), - expectedResult: false, - }, - { - testDescription: "different input numbers", - firstValue: cty.NumberIntVal(1337), - secondValue: cty.NumberIntVal(7331), - expectedResult: false, - }, - { - testDescription: "different types", - firstValue: cty.StringVal("bar"), - secondValue: cty.NumberIntVal(7331), - expectedResult: false, - }, - { - testDescription: "input list", - firstValue: cty.ListVal([]cty.Value{cty.StringVal("foo")}), - secondValue: cty.ListVal([]cty.Value{cty.StringVal("foo")}), - expectedResult: false, - }, - { - testDescription: "input map", - firstValue: cty.MapVal(map[string]cty.Value{"foo": cty.StringVal("foo")}), - secondValue: cty.MapVal(map[string]cty.Value{"foo": cty.StringVal("foo")}), - expectedResult: false, - }, - } - - for i, c := range cases { - t.Logf("Test iteration %d: %s", i, c.testDescription) - - result := isCtyPrimitiveValueValid(c.firstValue, c.secondValue) - require.Equal(t, c.expectedResult, result) - } -} - -func TestIsCtyListValid(t *testing.T) { - cases := []struct { - testDescription string - firstValue cty.Value - secondValue cty.Value - expectedResult bool - }{ - { - testDescription: "same input string", - firstValue: cty.ListVal([]cty.Value{cty.StringVal("foo")}), - secondValue: cty.ListVal([]cty.Value{cty.StringVal("foo")}), - expectedResult: true, - }, - { - testDescription: "same input int", - firstValue: cty.ListVal([]cty.Value{cty.NumberIntVal(1337)}), - secondValue: cty.ListVal([]cty.Value{cty.NumberIntVal(1337)}), - expectedResult: true, - }, - { - testDescription: "different input string", - firstValue: cty.ListVal([]cty.Value{cty.StringVal("foo")}), - secondValue: cty.ListVal([]cty.Value{cty.StringVal("bar")}), - expectedResult: false, - }, - { - testDescription: "different input int", - firstValue: cty.ListVal([]cty.Value{cty.NumberIntVal(1337)}), - secondValue: cty.ListVal([]cty.Value{cty.NumberIntVal(7331)}), - expectedResult: false, - }, - { - testDescription: "same input multiple second", - firstValue: cty.ListVal([]cty.Value{cty.StringVal("bar")}), - secondValue: cty.ListVal([]cty.Value{cty.StringVal("foo"), cty.StringVal("bar"), cty.StringVal("baz")}), - expectedResult: true, - }, - { - testDescription: "input string", - firstValue: cty.StringVal("foo"), - secondValue: cty.StringVal("foo"), - expectedResult: false, - }, - { - testDescription: "same input map", - firstValue: cty.MapVal(map[string]cty.Value{"foo": cty.StringVal("foo")}), - secondValue: cty.MapVal(map[string]cty.Value{"foo": cty.StringVal("foo")}), - expectedResult: false, - }, - { - testDescription: "different types", - firstValue: cty.ListVal([]cty.Value{cty.StringVal("foo")}), - secondValue: cty.ListVal([]cty.Value{cty.NumberIntVal(1337)}), - expectedResult: false, - }, - } - - for i, c := range cases { - t.Logf("Test iteration %d: %s", i, c.testDescription) - - result := isCtyListValid(c.firstValue, c.secondValue) - require.Equal(t, c.expectedResult, result) - } -} - -func TestIsCtyMapValid(t *testing.T) { - cases := []struct { - testDescription string - firstValue cty.Value - secondValue cty.Value - expectedResult bool - }{ - { - testDescription: "same input string", - firstValue: cty.MapVal(map[string]cty.Value{"foo": cty.StringVal("foo")}), - secondValue: cty.MapVal(map[string]cty.Value{"foo": cty.StringVal("foo")}), - expectedResult: true, - }, - { - testDescription: "same input int", - firstValue: cty.MapVal(map[string]cty.Value{"foo": cty.NumberIntVal(1337)}), - secondValue: cty.MapVal(map[string]cty.Value{"foo": cty.NumberIntVal(1337)}), - expectedResult: true, - }, - { - testDescription: "different input string", - firstValue: cty.MapVal(map[string]cty.Value{"foo": cty.StringVal("foo")}), - secondValue: cty.MapVal(map[string]cty.Value{"foo": cty.StringVal("bar")}), - expectedResult: false, - }, - { - testDescription: "different input int", - firstValue: cty.MapVal(map[string]cty.Value{"foo": cty.NumberIntVal(1337)}), - secondValue: cty.MapVal(map[string]cty.Value{"foo": cty.NumberIntVal(7331)}), - expectedResult: false, - }, - { - testDescription: "different types", - firstValue: cty.MapVal(map[string]cty.Value{"foo": cty.StringVal("foo")}), - secondValue: cty.MapVal(map[string]cty.Value{"foo": cty.NumberIntVal(1337)}), - expectedResult: false, - }, - { - testDescription: "input string", - firstValue: cty.StringVal("foo"), - secondValue: cty.StringVal("foo"), - expectedResult: false, - }, - { - testDescription: "input list", - firstValue: cty.ListVal([]cty.Value{cty.StringVal("foo")}), - secondValue: cty.ListVal([]cty.Value{cty.StringVal("foo")}), - expectedResult: false, - }, - { - testDescription: "same input multiple second", - firstValue: cty.MapVal(map[string]cty.Value{"foo": cty.StringVal("foo")}), - secondValue: cty.MapVal(map[string]cty.Value{"a": cty.StringVal("b"), "foo": cty.StringVal("foo"), "c": cty.StringVal("d")}), - expectedResult: true, - }, - } - - for i, c := range cases { - t.Logf("Test iteration %d: %s", i, c.testDescription) - - result := isCtyMapValid(c.firstValue, c.secondValue) - require.Equal(t, c.expectedResult, result) - } -} - -func TestCtyListContains(t *testing.T) { - cases := []struct { - testDescription string - slice []cty.Value - value cty.Value - expectedResult bool - }{ - { - testDescription: "same input string", - slice: []cty.Value{cty.StringVal("foo")}, - value: cty.StringVal("foo"), - expectedResult: true, - }, - { - testDescription: "same input int", - slice: []cty.Value{cty.NumberIntVal(1337)}, - value: cty.NumberIntVal(1337), - expectedResult: true, - }, - { - testDescription: "different input string", - slice: []cty.Value{cty.StringVal("foo")}, - value: cty.StringVal("bar"), - expectedResult: false, - }, - { - testDescription: "different input int", - slice: []cty.Value{cty.NumberIntVal(1337)}, - value: cty.NumberIntVal(7331), - expectedResult: false, - }, - { - testDescription: "same input string multiple", - slice: []cty.Value{cty.StringVal("foo"), cty.StringVal("bar"), cty.StringVal("baz")}, - value: cty.StringVal("bar"), - expectedResult: true, - }, - } - - for i, c := range cases { - t.Logf("Test iteration %d: %s", i, c.testDescription) - - result := ctyListContains(c.slice, c.value) - require.Equal(t, c.expectedResult, result) - } -} - -func TestIsCtyTypeSame(t *testing.T) { - cases := []struct { - testDescription string - firstValue cty.Value - secondValue cty.Value - expectedResult bool - }{ - { - testDescription: "same input strings", - firstValue: cty.StringVal("foo"), - secondValue: cty.StringVal("foo"), - expectedResult: true, - }, - { - testDescription: "same input numbers", - firstValue: cty.NumberIntVal(1337), - secondValue: cty.NumberIntVal(1337), - expectedResult: true, - }, - { - testDescription: "different input strings", - firstValue: cty.StringVal("foo"), - secondValue: cty.StringVal("bar"), - expectedResult: true, - }, - { - testDescription: "different input numbers", - firstValue: cty.NumberIntVal(1337), - secondValue: cty.NumberIntVal(7331), - expectedResult: true, - }, - { - testDescription: "different types", - firstValue: cty.StringVal("foo"), - secondValue: cty.NumberIntVal(1337), - expectedResult: false, - }, - } - - for i, c := range cases { - t.Logf("Test iteration %d: %s", i, c.testDescription) - - result := isCtyTypeSame(c.firstValue, c.secondValue) - require.Equal(t, c.expectedResult, result) - } -} - -func TestIsCtyValueValid(t *testing.T) { - cases := []struct { - testDescription string - firstValue cty.Value - secondValue cty.Value - expectedError bool - }{ - { - testDescription: "same input strings", - firstValue: cty.StringVal("foo"), - secondValue: cty.StringVal("foo"), - expectedError: false, - }, - { - testDescription: "same input numbers", - firstValue: cty.NumberIntVal(1337), - secondValue: cty.NumberIntVal(1337), - expectedError: false, - }, - { - testDescription: "different input strings", - firstValue: cty.StringVal("foo"), - secondValue: cty.StringVal("bar"), - expectedError: true, - }, - { - testDescription: "different input numbers", - firstValue: cty.NumberIntVal(1337), - secondValue: cty.NumberIntVal(7331), - expectedError: true, - }, - { - testDescription: "different types", - firstValue: cty.StringVal("foo"), - secondValue: cty.NumberIntVal(1337), - expectedError: true, - }, - { - testDescription: "same input list string", - firstValue: cty.ListVal([]cty.Value{cty.StringVal("foo")}), - secondValue: cty.ListVal([]cty.Value{cty.StringVal("foo")}), - expectedError: false, - }, - { - testDescription: "same input list int", - firstValue: cty.ListVal([]cty.Value{cty.NumberIntVal(1337)}), - secondValue: cty.ListVal([]cty.Value{cty.NumberIntVal(1337)}), - expectedError: false, - }, - { - testDescription: "different input list string", - firstValue: cty.ListVal([]cty.Value{cty.StringVal("foo")}), - secondValue: cty.ListVal([]cty.Value{cty.StringVal("bar")}), - expectedError: true, - }, - { - testDescription: "different input list int", - firstValue: cty.ListVal([]cty.Value{cty.NumberIntVal(1337)}), - secondValue: cty.ListVal([]cty.Value{cty.NumberIntVal(7331)}), - expectedError: true, - }, - { - testDescription: "same input map string", - firstValue: cty.MapVal(map[string]cty.Value{"foo": cty.StringVal("foo")}), - secondValue: cty.MapVal(map[string]cty.Value{"foo": cty.StringVal("foo")}), - expectedError: false, - }, - { - testDescription: "same input map int", - firstValue: cty.MapVal(map[string]cty.Value{"foo": cty.NumberIntVal(1337)}), - secondValue: cty.MapVal(map[string]cty.Value{"foo": cty.NumberIntVal(1337)}), - expectedError: false, - }, - { - testDescription: "different input map string", - firstValue: cty.MapVal(map[string]cty.Value{"foo": cty.StringVal("foo")}), - secondValue: cty.MapVal(map[string]cty.Value{"foo": cty.StringVal("bar")}), - expectedError: true, - }, - { - testDescription: "different input map int", - firstValue: cty.MapVal(map[string]cty.Value{"foo": cty.NumberIntVal(1337)}), - secondValue: cty.MapVal(map[string]cty.Value{"foo": cty.NumberIntVal(7331)}), - expectedError: true, - }, - { - testDescription: "non-imlemented type", - firstValue: cty.SetVal([]cty.Value{cty.StringVal("foo")}), - secondValue: cty.SetVal([]cty.Value{cty.StringVal("foo")}), - expectedError: true, - }, - } - - for i, c := range cases { - t.Logf("Test iteration %d: %s", i, c.testDescription) - - err := isCtyValueValid(c.firstValue, c.secondValue) - - if c.expectedError { - require.Error(t, err) - } else { - require.NoError(t, err) - } - } -} - -func TestGetCtyType(t *testing.T) { - cases := []struct { - testDescription string - input cty.Value - expectedType ctyType - }{ - { - testDescription: "string is primitiveCtyType", - input: cty.StringVal("foo"), - expectedType: primitiveCtyType, - }, - { - testDescription: "int is primitiveCtyType", - input: cty.NumberIntVal(1337), - expectedType: primitiveCtyType, - }, - { - testDescription: "float is primitiveCtyType", - input: cty.NumberFloatVal(1337), - expectedType: primitiveCtyType, - }, - { - testDescription: "bool is primitiveCtyType", - input: cty.BoolVal(true), - expectedType: primitiveCtyType, - }, - { - testDescription: "slice is listCtyType", - input: cty.ListVal([]cty.Value{cty.StringVal("foo")}), - expectedType: listCtyType, - }, - { - testDescription: "map is mapCtyType", - input: cty.MapVal(map[string]cty.Value{"foo": cty.StringVal("foo")}), - expectedType: mapCtyType, - }, - { - testDescription: "set is unknownCtyType", - input: cty.SetVal([]cty.Value{cty.StringVal("foo")}), - expectedType: unknownCtyType, - }, - } - - for i, c := range cases { - t.Logf("Test iteration %d: %s", i, c.testDescription) - - resultType := getCtyType(c.input) - require.Equal(t, c.expectedType, resultType) - } -} diff --git a/internal/oidc/oidc.go b/internal/oidc/oidc.go index dba0247..ee5ecea 100644 --- a/internal/oidc/oidc.go +++ b/internal/oidc/oidc.go @@ -22,7 +22,7 @@ var ( errSignatureVerification = fmt.Errorf("failed to verify signature") ) -type handler struct { +type handler[T any] struct { issuer string discoveryUri string discoveryFetchTimeout time.Duration @@ -33,17 +33,16 @@ type handler struct { allowedTokenDrift time.Duration requiredAudience string requiredTokenType string - requiredClaims map[string]interface{} disableKeyID bool httpClient *http.Client - - keyHandler *keyHandler + keyHandler *keyHandler + claimsValidationFn options.ClaimsValidationFn[T] } -func NewHandler(setters ...options.Option) (*handler, error) { +func NewHandler[T any](claimsValidationFn options.ClaimsValidationFn[T], setters ...options.Option) (*handler[T], error) { opts := options.New(setters...) - h := &handler{ + h := &handler[T]{ issuer: opts.Issuer, discoveryUri: opts.DiscoveryUri, discoveryFetchTimeout: opts.DiscoveryFetchTimeout, @@ -53,9 +52,9 @@ func NewHandler(setters ...options.Option) (*handler, error) { allowedTokenDrift: opts.AllowedTokenDrift, requiredTokenType: opts.RequiredTokenType, requiredAudience: opts.RequiredAudience, - requiredClaims: opts.RequiredClaims, disableKeyID: opts.DisableKeyID, httpClient: opts.HttpClient, + claimsValidationFn: claimsValidationFn, } if h.issuer == "" { @@ -82,7 +81,7 @@ func NewHandler(setters ...options.Option) (*handler, error) { return h, nil } -func (h *handler) loadJwks() error { +func (h *handler[T]) loadJwks() error { if h.jwksUri == "" { jwksUri, err := getJwksUriFromDiscoveryUri(h.httpClient, h.discoveryUri, h.discoveryFetchTimeout) if err != nil { @@ -101,27 +100,27 @@ func (h *handler) loadJwks() error { return nil } -func (h *handler) SetIssuer(issuer string) { +func (h *handler[T]) SetIssuer(issuer string) { h.issuer = issuer } -func (h *handler) SetDiscoveryUri(discoveryUri string) { +func (h *handler[T]) SetDiscoveryUri(discoveryUri string) { h.discoveryUri = discoveryUri } -type ParseTokenFunc func(ctx context.Context, tokenString string) (jwt.Token, error) +type ParseTokenFunc[T any] func(ctx context.Context, tokenString string) (T, error) -func (h *handler) ParseToken(ctx context.Context, tokenString string) (jwt.Token, error) { +func (h *handler[T]) ParseToken(ctx context.Context, tokenString string) (T, error) { if h.keyHandler == nil { err := h.loadJwks() if err != nil { - return nil, fmt.Errorf("unable to load jwks: %w", err) + return *new(T), fmt.Errorf("unable to load jwks: %w", err) } } tokenTypeValid := isTokenTypeValid(h.requiredTokenType, tokenString) if !tokenTypeValid { - return nil, fmt.Errorf("token type %q required", h.requiredTokenType) + return *new(T), fmt.Errorf("token type %q required", h.requiredTokenType) } keyID := "" @@ -129,18 +128,18 @@ func (h *handler) ParseToken(ctx context.Context, tokenString string) (jwt.Token var err error keyID, err = getKeyIDFromTokenString(tokenString) if err != nil { - return nil, err + return *new(T), err } } key, err := h.keyHandler.getKey(ctx, keyID) if err != nil { - return nil, fmt.Errorf("unable to get public key: %w", err) + return *new(T), fmt.Errorf("unable to get public key: %w", err) } alg, err := getSignatureAlgorithm(key.KeyType(), key.Algorithm(), h.fallbackSignatureAlgorithm) if err != nil { - return nil, err + return *new(T), err } token, err := getAndValidateTokenFromString(tokenString, key, alg) @@ -148,51 +147,77 @@ func (h *handler) ParseToken(ctx context.Context, tokenString string) (jwt.Token if h.disableKeyID && errors.Is(err, errSignatureVerification) { updatedKey, err := h.keyHandler.waitForUpdateKeySetAndGetKey(ctx) if err != nil { - return nil, err + return *new(T), err } alg, err := getSignatureAlgorithm(key.KeyType(), key.Algorithm(), h.fallbackSignatureAlgorithm) if err != nil { - return nil, err + return *new(T), err } token, err = getAndValidateTokenFromString(tokenString, updatedKey, alg) if err != nil { - return nil, err + return *new(T), err } } else { - return nil, err + return *new(T), err } } validExpiration := isTokenExpirationValid(token.Expiration(), h.allowedTokenDrift) if !validExpiration { - return nil, fmt.Errorf("token has expired: %s", token.Expiration()) + return *new(T), fmt.Errorf("token has expired: %s", token.Expiration()) } validIssuer := isTokenIssuerValid(h.issuer, token.Issuer()) if !validIssuer { - return nil, fmt.Errorf("required issuer %q was not found, received: %s", h.issuer, token.Issuer()) + return *new(T), fmt.Errorf("required issuer %q was not found, received: %s", h.issuer, token.Issuer()) } validAudience := isTokenAudienceValid(h.requiredAudience, token.Audience()) if !validAudience { - return nil, fmt.Errorf("required audience %q was not found, received: %v", h.requiredAudience, token.Audience()) + return *new(T), fmt.Errorf("required audience %q was not found, received: %v", h.requiredAudience, token.Audience()) } - if h.requiredClaims != nil { - tokenClaims, err := token.AsMap(ctx) - if err != nil { - return nil, fmt.Errorf("unable to get token claims: %w", err) - } + claims, err := h.jwtTokenToClaims(ctx, token) + if err != nil { + return *new(T), fmt.Errorf("unable to convert jwt.Token to claims: %w", err) + } - err = isRequiredClaimsValid(h.requiredClaims, tokenClaims) - if err != nil { - return nil, fmt.Errorf("unable to validate required claims: %w", err) - } + err = h.validateClaims(&claims) + if err != nil { + return *new(T), fmt.Errorf("claims validation returned an error: %w", err) } - return token, nil + return claims, nil +} + +func (h *handler[T]) validateClaims(claims *T) error { + if h.claimsValidationFn == nil { + return nil + } + + return h.claimsValidationFn(claims) +} + +func (h *handler[T]) jwtTokenToClaims(ctx context.Context, token jwt.Token) (T, error) { + rawClaims, err := token.AsMap(ctx) + if err != nil { + return *new(T), fmt.Errorf("unable to convert token to claims: %w", err) + } + + claimsBytes, err := json.Marshal(rawClaims) + if err != nil { + return *new(T), fmt.Errorf("unable to marshal raw claims to json: %w", err) + } + + claims := *new(T) + err = json.Unmarshal(claimsBytes, &claims) + if err != nil { + return *new(T), fmt.Errorf("unable to unmarshal claims from json: %w", err) + } + + return claims, nil } func GetDiscoveryUriFromIssuer(issuer string) string { @@ -330,27 +355,6 @@ func isTokenTypeValid(requiredTokenType string, tokenString string) bool { return true } -func isRequiredClaimsValid(requiredClaims map[string]interface{}, tokenClaims map[string]interface{}) error { - for requiredKey, requiredValue := range requiredClaims { - tokenValue, ok := tokenClaims[requiredKey] - if !ok { - return fmt.Errorf("token does not have the claim: %s", requiredKey) - } - - required, received, err := getCtyValues(requiredValue, tokenValue) - if err != nil { - return err - } - - err = isCtyValueValid(required, received) - if err != nil { - return fmt.Errorf("claim %q not valid: %w", requiredKey, err) - } - } - - return nil -} - func getAndValidateTokenFromString(tokenString string, key jwk.Key, alg jwa.SignatureAlgorithm) (jwt.Token, error) { token, err := jwt.ParseString(tokenString, jwt.WithVerify(alg, key)) if err != nil { diff --git a/internal/oidc/oidc_test.go b/internal/oidc/oidc_test.go index c09e8ab..d740085 100644 --- a/internal/oidc/oidc_test.go +++ b/internal/oidc/oidc_test.go @@ -21,6 +21,8 @@ import ( "github.com/xenitab/dispans/server" ) +type testClaims map[string]interface{} + func TestGetHeadersFromTokenString(t *testing.T) { key, _ := testNewKey(t) @@ -640,37 +642,6 @@ func TestParseToken(t *testing.T) { customExpirationMinutes: -1, expectedErrorContains: "token has expired", }, - { - testDescription: "correct requiredClaim", - options: []options.Option{ - options.WithIssuer("http://foo.bar"), - options.WithDiscoveryUri("http://foo.bar"), - options.WithJwksUri(testServer.URL), - options.WithRequiredClaims(map[string]interface{}{ - "foo": "bar", - }), - options.WithDisableKeyID(false), - }, - numKeys: 1, - expectedErrorContains: "", - }, - { - testDescription: "correct requiredClaim", - options: []options.Option{ - options.WithIssuer("http://foo.bar"), - options.WithDiscoveryUri("http://foo.bar"), - options.WithJwksUri(testServer.URL), - options.WithRequiredClaims(map[string]interface{}{ - "foo": "bar", - }), - options.WithDisableKeyID(false), - }, - numKeys: 1, - customClaims: map[string]string{ - "foo": "baz", - }, - expectedErrorContains: "unable to validate required claims", - }, } for i, c := range cases { @@ -684,7 +655,7 @@ func TestParseToken(t *testing.T) { keySets.setKeys(testNewKeySet(t, c.numKeys, opts.DisableKeyID)) - h, err := NewHandler(c.options...) + h, err := NewHandler[testClaims](nil, c.options...) require.NoError(t, err) parseTokenFunc := h.ParseToken @@ -735,7 +706,7 @@ func TestParseTokenWithKeyID(t *testing.T) { options.WithJwksRateLimit(100), } - h, err := NewHandler(opts...) + h, err := NewHandler[testClaims](nil, opts...) require.NoError(t, err) parseTokenFunc := h.ParseToken @@ -822,7 +793,7 @@ func TestParseTokenWithoutKeyID(t *testing.T) { options.WithJwksRateLimit(100), } - h, err := NewHandler(opts...) + h, err := NewHandler[testClaims](nil, opts...) require.NoError(t, err) parseTokenFunc := h.ParseToken @@ -928,393 +899,6 @@ func TestGetAndValidateTokenFromStringWithoutKeyID(t *testing.T) { require.ErrorIs(t, err, errSignatureVerification) } -func TestIsRequiredClaimsValid(t *testing.T) { - cases := []struct { - testDescription string - requiredClaims map[string]interface{} - tokenClaims map[string]interface{} - expectedResult bool - }{ - { - testDescription: "both are nil", - requiredClaims: nil, - tokenClaims: nil, - expectedResult: true, - }, - { - testDescription: "both are empty", - requiredClaims: map[string]interface{}{}, - tokenClaims: map[string]interface{}{}, - expectedResult: true, - }, - { - testDescription: "required claims are nil", - requiredClaims: nil, - tokenClaims: map[string]interface{}{ - "foo": "bar", - }, - expectedResult: true, - }, - { - testDescription: "required claims are empty", - requiredClaims: map[string]interface{}{}, - tokenClaims: map[string]interface{}{ - "foo": "bar", - }, - expectedResult: true, - }, - { - testDescription: "token claims are nil", - requiredClaims: map[string]interface{}{ - "foo": "bar", - }, - tokenClaims: nil, - expectedResult: false, - }, - { - testDescription: "token claims are empty", - requiredClaims: map[string]interface{}{ - "foo": "bar", - }, - tokenClaims: map[string]interface{}{}, - expectedResult: false, - }, - { - testDescription: "required is string, token is int", - requiredClaims: map[string]interface{}{ - "foo": "bar", - }, - tokenClaims: map[string]interface{}{ - "foo": 1337, - }, - expectedResult: false, - }, - { - testDescription: "matching with string", - requiredClaims: map[string]interface{}{ - "foo": "bar", - }, - tokenClaims: map[string]interface{}{ - "foo": "bar", - }, - expectedResult: true, - }, - { - testDescription: "matching with string and int", - requiredClaims: map[string]interface{}{ - "foo": "bar", - "bar": 1337, - }, - tokenClaims: map[string]interface{}{ - "foo": "bar", - "bar": 1337, - }, - expectedResult: true, - }, - { - testDescription: "matching with string and int in different orders", - requiredClaims: map[string]interface{}{ - "foo": "bar", - "bar": 1337, - }, - tokenClaims: map[string]interface{}{ - "bar": 1337, - "foo": "bar", - }, - expectedResult: true, - }, - { - testDescription: "matching with string, int and float", - requiredClaims: map[string]interface{}{ - "foo": "bar", - "bar": 1337, - "baz": 13.37, - }, - tokenClaims: map[string]interface{}{ - "foo": "bar", - "bar": 1337, - "baz": 13.37, - }, - expectedResult: true, - }, - { - testDescription: "not matching with string, int and float", - requiredClaims: map[string]interface{}{ - "foo": "bar", - "bar": 1337, - "baz": 13.37, - }, - tokenClaims: map[string]interface{}{ - "foo": "bar", - "bar": 1337, - "baz": 12.27, - }, - expectedResult: false, - }, - { - testDescription: "matching slice", - requiredClaims: map[string]interface{}{ - "foo": "bar", - "bar": 1337, - "baz": []string{"foo"}, - }, - tokenClaims: map[string]interface{}{ - "foo": "bar", - "bar": 1337, - "baz": []string{"foo"}, - }, - expectedResult: true, - }, - { - testDescription: "matching slice with multiple values", - requiredClaims: map[string]interface{}{ - "oof": []string{"foo", "bar"}, - }, - tokenClaims: map[string]interface{}{ - "oof": []string{"foo", "bar", "baz"}, - }, - expectedResult: true, - }, - { - testDescription: "required slice contains in token slice", - requiredClaims: map[string]interface{}{ - "foo": "bar", - "bar": 1337, - "baz": []string{"foo"}, - }, - tokenClaims: map[string]interface{}{ - "foo": "bar", - "bar": 1337, - "baz": []string{"foo", "bar", "baz"}, - }, - expectedResult: true, - }, - { - testDescription: "not matching slice", - requiredClaims: map[string]interface{}{ - "foo": "bar", - "bar": 1337, - "baz": []string{"foo"}, - }, - tokenClaims: map[string]interface{}{ - "foo": "bar", - "bar": 1337, - "baz": []string{"bar"}, - }, - expectedResult: false, - }, - { - testDescription: "matching map", - requiredClaims: map[string]interface{}{ - "foo": map[string]string{ - "foo": "bar", - }, - }, - tokenClaims: map[string]interface{}{ - "foo": map[string]string{ - "foo": "bar", - }, - }, - expectedResult: true, - }, - { - testDescription: "matching map with multiple values", - requiredClaims: map[string]interface{}{ - "foo": map[string]string{ - "foo": "bar", - "bar": "foo", - }, - }, - tokenClaims: map[string]interface{}{ - "foo": map[string]string{ - "a": "b", - "foo": "bar", - "bar": "foo", - "c": "d", - }, - }, - expectedResult: true, - }, - { - testDescription: "matching map with multiple keys in token claims", - requiredClaims: map[string]interface{}{ - "foo": map[string]string{ - "foo": "bar", - }, - }, - tokenClaims: map[string]interface{}{ - "foo": map[string]string{ - "a": "b", - "foo": "bar", - "c": "d", - }, - }, - expectedResult: true, - }, - { - testDescription: "not matching map", - requiredClaims: map[string]interface{}{ - "foo": map[string]string{ - "foo": "bar", - }, - }, - tokenClaims: map[string]interface{}{ - "foo": map[string]int{ - "foo": 1337, - }, - }, - expectedResult: false, - }, - { - testDescription: "matching map with string slice", - requiredClaims: map[string]interface{}{ - "foo": map[string][]string{ - "foo": {"bar"}, - }, - }, - tokenClaims: map[string]interface{}{ - "foo": map[string][]string{ - "foo": {"foo", "bar", "baz"}, - }, - }, - expectedResult: true, - }, - { - testDescription: "not matching map with string slice", - requiredClaims: map[string]interface{}{ - "foo": map[string][]string{ - "foo": {"foobar"}, - }, - }, - tokenClaims: map[string]interface{}{ - "foo": map[string][]string{ - "foo": {"foo", "bar", "baz"}, - }, - }, - expectedResult: false, - }, - { - testDescription: "matching slice with map", - requiredClaims: map[string]interface{}{ - "foo": []map[string]string{ - {"bar": "baz"}, - }, - }, - tokenClaims: map[string]interface{}{ - "foo": []map[string]string{ - {"bar": "baz"}, - }, - }, - expectedResult: true, - }, - { - testDescription: "not matching slice with map", - requiredClaims: map[string]interface{}{ - "foo": []map[string]string{ - {"bar": "foobar"}, - }, - }, - tokenClaims: map[string]interface{}{ - "foo": []map[string]string{ - {"bar": "baz"}, - }, - }, - expectedResult: false, - }, - { - testDescription: "matching primitive types, slice and map", - requiredClaims: map[string]interface{}{ - "foo": "bar", - "bar": 1337, - "baz": []string{"foo"}, - "oof": []map[string]string{ - {"bar": "baz"}, - }, - }, - tokenClaims: map[string]interface{}{ - "foo": "bar", - "bar": 1337, - "baz": []string{"foo"}, - "oof": []map[string]string{ - {"bar": "baz"}, - }, - }, - expectedResult: true, - }, - { - testDescription: "matching primitive types, slice and map where token contains multiple values", - requiredClaims: map[string]interface{}{ - "foo": "bar", - "bar": 1337, - "baz": []string{"bar"}, - "oof": []map[string]string{ - {"bar": "baz"}, - }, - }, - tokenClaims: map[string]interface{}{ - "foo": "bar", - "bar": 1337, - "baz": []string{"foo", "bar", "baz"}, - "oof": []map[string]string{ - {"a": "b"}, - {"bar": "baz"}, - {"c": "d"}, - }, - }, - expectedResult: true, - }, - { - testDescription: "valid interface list in an interface map", - requiredClaims: map[string]interface{}{ - "foo": map[string][]string{ - "bar": {"baz"}, - }, - }, - tokenClaims: map[string]interface{}{ - "foo": map[string]interface{}{ - "bar": []interface{}{ - "uno", - "dos", - "baz", - "tres", - }, - }, - }, - expectedResult: true, - }, - { - testDescription: "invalid interface list in an interface map", - requiredClaims: map[string]interface{}{ - "foo": map[string][]string{ - "bar": {"baz"}, - }, - }, - tokenClaims: map[string]interface{}{ - "foo": map[string]interface{}{ - "bar": []interface{}{ - "uno", - "dos", - "tres", - }, - }, - }, - expectedResult: false, - }, - } - - for i, c := range cases { - t.Logf("Test iteration %d: %s", i, c.testDescription) - - err := isRequiredClaimsValid(c.requiredClaims, c.tokenClaims) - - if c.expectedResult { - require.NoError(t, err) - } else { - require.Error(t, err) - } - } -} - func TestGetSignatureAlgorithm(t *testing.T) { cases := []struct { inputKty jwa.KeyType diff --git a/internal/oidctesting/benchmarks.go b/internal/oidctesting/benchmarks.go index 72728b4..4029c98 100644 --- a/internal/oidctesting/benchmarks.go +++ b/internal/oidctesting/benchmarks.go @@ -28,6 +28,7 @@ func runBenchmarkHandler(b *testing.B, testName string, tester tester) { defer op.Close(b) handler := tester.NewHandlerFn( + nil, options.WithIssuer(op.GetURL(b)), ) @@ -47,12 +48,12 @@ func runBenchmarkRequirements(b *testing.B, testName string, tester tester) { defer op.Close(b) handler := tester.NewHandlerFn( + func(claims *TestClaims) error { + return testClaimsValueEq(claims, "sub", "test") + }, options.WithIssuer(op.GetURL(b)), options.WithRequiredTokenType("JWT+AT"), options.WithRequiredAudience("test-client"), - options.WithRequiredClaims(map[string]interface{}{ - "sub": "test", - }), ) fn := func(token *optest.TokenResponse) { diff --git a/internal/oidctesting/tests.go b/internal/oidctesting/tests.go index d57cccd..fc5ce70 100644 --- a/internal/oidctesting/tests.go +++ b/internal/oidctesting/tests.go @@ -14,14 +14,34 @@ import ( "github.com/xenitab/go-oidc-middleware/options" ) +type TestClaims map[string]interface{} + +func testClaimsValueEq(claims *TestClaims, key string, expectedValue string) error { + rawValue, ok := (*claims)[key] + if !ok { + return fmt.Errorf("key %s not found", key) + } + + value, ok := rawValue.(string) + if !ok { + return fmt.Errorf("key %s not expected type %T, received: %v", key, expectedValue, rawValue) + } + + if expectedValue != value { + return fmt.Errorf("key %s %v != %v", key, expectedValue, value) + } + + return nil +} + type ServerTester interface { Close() URL() string } type tester interface { - NewHandlerFn(opts ...options.Option) http.Handler - ToHandlerFn(parseToken oidc.ParseTokenFunc, opts ...options.Option) http.Handler + NewHandlerFn(claimsValidationFn options.ClaimsValidationFn[TestClaims], opts ...options.Option) http.Handler + ToHandlerFn(parseToken oidc.ParseTokenFunc[TestClaims], opts ...options.Option) http.Handler NewTestServer(opts ...options.Option) ServerTester } @@ -112,9 +132,9 @@ func runTestNew(t *testing.T, testName string, tester tester) { c := cases[i] t.Logf("Test iteration %d: %s", i, c.testDescription) if c.expectPanic { - require.Panics(t, func() { tester.NewHandlerFn(c.config...) }) + require.Panics(t, func() { tester.NewHandlerFn(nil, c.config...) }) } else { - require.NotPanics(t, func() { tester.NewHandlerFn(c.config...) }) + require.NotPanics(t, func() { tester.NewHandlerFn(nil, c.config...) }) } } }) @@ -128,6 +148,7 @@ func runTestHandler(t *testing.T, testName string, tester tester) { defer op.Close(t) handler := tester.NewHandlerFn( + nil, options.WithIssuer(op.GetURL(t)), options.WithRequiredAudience("test-client"), options.WithRequiredTokenType("JWT+AT"), @@ -163,7 +184,8 @@ func runTestLazyLoad(t *testing.T, testName string, tester tester) { op := optest.NewTesting(t) defer op.Close(t) - oidcHandler, err := oidc.NewHandler( + oidcHandler, err := oidc.NewHandler[TestClaims]( + nil, options.WithIssuer("http://foo.bar/baz"), options.WithRequiredAudience("test-client"), options.WithRequiredTokenType("JWT+AT"), @@ -199,9 +221,10 @@ func runTestRequirements(t *testing.T, testName string, tester tester) { defer op.Close(t) cases := []struct { - testDescription string - options []options.Option - succeeds bool + testDescription string + options []options.Option + claimsValidationFn options.ClaimsValidationFn[TestClaims] + succeeds bool }{ { testDescription: "no requirements", @@ -246,9 +269,12 @@ func runTestRequirements(t *testing.T, testName string, tester tester) { testDescription: "required sub matches", options: []options.Option{ options.WithIssuer(op.GetURL(t)), - options.WithRequiredClaims(map[string]interface{}{ - "sub": "test", - }), + // options.WithRequiredClaims(map[string]interface{}{ + // "sub": "test", + // }), + }, + claimsValidationFn: func(claims *TestClaims) error { + return testClaimsValueEq(claims, "sub", "test") }, succeeds: true, }, @@ -256,9 +282,9 @@ func runTestRequirements(t *testing.T, testName string, tester tester) { testDescription: "required sub doesn't match", options: []options.Option{ options.WithIssuer(op.GetURL(t)), - options.WithRequiredClaims(map[string]interface{}{ - "sub": "foo", - }), + }, + claimsValidationFn: func(claims *TestClaims) error { + return testClaimsValueEq(claims, "sub", "foo") }, succeeds: false, }, @@ -268,7 +294,7 @@ func runTestRequirements(t *testing.T, testName string, tester tester) { c := cases[i] t.Logf("Test iteration %d: %s", i, c.testDescription) - handler := tester.NewHandlerFn(c.options...) + handler := tester.NewHandlerFn(c.claimsValidationFn, c.options...) token := op.GetToken(t) if c.succeeds { @@ -318,7 +344,7 @@ func runTestErrorHandler(t *testing.T, testName string, tester tester) { options.WithErrorHandler(errorHandler), } - oidcHandler, err := oidc.NewHandler(opts...) + oidcHandler, err := oidc.NewHandler[TestClaims](nil, opts...) require.NoError(t, err) handler := tester.ToHandlerFn(oidcHandler.ParseToken, opts...) diff --git a/oidcechojwt/echo_jwt.go b/oidcechojwt/echo_jwt.go index f09477d..bfac0b9 100644 --- a/oidcechojwt/echo_jwt.go +++ b/oidcechojwt/echo_jwt.go @@ -10,8 +10,8 @@ import ( // New returns an OpenID Connect (OIDC) discovery `ParseTokenFunc` // to be used with the the echo `JWT` middleware. -func New(setters ...options.Option) func(auth string, c echo.Context) (interface{}, error) { - h, err := oidc.NewHandler(setters...) +func New[T any](claimsValidationFn options.ClaimsValidationFn[T], setters ...options.Option) func(auth string, c echo.Context) (interface{}, error) { + h, err := oidc.NewHandler(claimsValidationFn, setters...) if err != nil { panic(fmt.Sprintf("oidc discovery: %v", err)) } @@ -27,25 +27,19 @@ func onError(errorHandler options.ErrorHandler, description options.ErrorDescrip } } -func toEchoJWTParseTokenFunc(parseToken oidc.ParseTokenFunc, setters ...options.Option) echoJWTParseTokenFunc { +func toEchoJWTParseTokenFunc[T any](parseToken oidc.ParseTokenFunc[T], setters ...options.Option) echoJWTParseTokenFunc { opts := options.New(setters...) echoJWTParseTokenFunc := func(auth string, c echo.Context) (interface{}, error) { ctx := c.Request().Context() - token, err := parseToken(ctx, auth) + claims, err := parseToken(ctx, auth) if err != nil { onError(opts.ErrorHandler, options.ParseTokenErrorDescription, err) return nil, err } - tokenClaims, err := token.AsMap(ctx) - if err != nil { - onError(opts.ErrorHandler, options.ConvertTokenErrorDescription, err) - return nil, err - } - - return tokenClaims, nil + return claims, nil } return echoJWTParseTokenFunc diff --git a/oidcechojwt/echo_jwt_test.go b/oidcechojwt/echo_jwt_test.go index c02be93..5eded2b 100644 --- a/oidcechojwt/echo_jwt_test.go +++ b/oidcechojwt/echo_jwt_test.go @@ -40,7 +40,7 @@ func testGetEchoRouter(tb testing.TB, parseToken echoJWTParseTokenFunc) *echo.Ec })) e.GET("/", func(c echo.Context) error { - claims, ok := c.Get("user").(map[string]interface{}) + claims, ok := c.Get("user").(oidctesting.TestClaims) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "invalid token") } @@ -105,14 +105,14 @@ func newTestHandler(tb testing.TB) *testHandler { } } -func (h *testHandler) NewHandlerFn(opts ...options.Option) http.Handler { +func (h *testHandler) NewHandlerFn(claimsValidationFn options.ClaimsValidationFn[oidctesting.TestClaims], opts ...options.Option) http.Handler { h.tb.Helper() - echoParseToken := New(opts...) + echoParseToken := New(claimsValidationFn, opts...) return testGetEchoRouter(h.tb, echoParseToken) } -func (h *testHandler) ToHandlerFn(parseToken oidc.ParseTokenFunc, opts ...options.Option) http.Handler { +func (h *testHandler) ToHandlerFn(parseToken oidc.ParseTokenFunc[oidctesting.TestClaims], opts ...options.Option) http.Handler { h.tb.Helper() echoParseToken := toEchoJWTParseTokenFunc(parseToken, opts...) @@ -122,6 +122,6 @@ func (h *testHandler) ToHandlerFn(parseToken oidc.ParseTokenFunc, opts ...option func (h *testHandler) NewTestServer(opts ...options.Option) oidctesting.ServerTester { h.tb.Helper() - echoParseToken := New(opts...) + echoParseToken := New[oidctesting.TestClaims](nil, opts...) return newTestServer(h.tb, testGetEchoRouter(h.tb, echoParseToken)) } diff --git a/oidcechojwt/go.sum b/oidcechojwt/go.sum index f97dbf1..003f847 100644 --- a/oidcechojwt/go.sum +++ b/oidcechojwt/go.sum @@ -79,8 +79,8 @@ github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+ github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo= github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= github.com/xenitab/dispans v0.0.10 h1:S+gSUM14rDJWK7MYNrjb8JbjeQPip6mlNJyLX+g7Agc= -github.com/xenitab/go-oidc-middleware v0.0.35 h1:9u1rQ/MqYXg4IpeJcOKyCSA2Xo8Pji3IiIZ+ZbAoqFI= -github.com/xenitab/go-oidc-middleware v0.0.35/go.mod h1:a8lpsTfdmiEsbclX4oIQE2gXj+8cYLLGRKUtgccwR94= +github.com/xenitab/go-oidc-middleware v0.0.36 h1:iBm+8usJZg9mCWrZWliHpzNatWn6g31AAcbb1q4M6go= +github.com/xenitab/go-oidc-middleware v0.0.36/go.mod h1:dUakIYup0Grr7Bn/88xTTKtlS6MWoWZtrrnzdt/SUZU= github.com/zclconf/go-cty v1.12.1 h1:PcupnljUm9EIvbgSHQnHhUr3fO6oFmkOrvs2BAFNXXY= github.com/zclconf/go-cty v1.12.1/go.mod h1:s9IfD1LK5ccNMSWCVFCE2rJfHiZgi7JijgeWIMfhLvA= go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= diff --git a/oidcfiber/fiber.go b/oidcfiber/fiber.go index 6be62eb..e83abd4 100644 --- a/oidcfiber/fiber.go +++ b/oidcfiber/fiber.go @@ -10,8 +10,8 @@ import ( // New returns an OpenID Connect (OIDC) discovery handler (middleware) // to be used with `fiber`. -func New(setters ...options.Option) fiber.Handler { - oidcHandler, err := oidc.NewHandler(setters...) +func New[T any](claimsValidationFn options.ClaimsValidationFn[T], setters ...options.Option) fiber.Handler { + oidcHandler, err := oidc.NewHandler(claimsValidationFn, setters...) if err != nil { panic(fmt.Sprintf("oidc discovery: %v", err)) } @@ -27,7 +27,7 @@ func onError(c *fiber.Ctx, errorHandler options.ErrorHandler, statusCode int, de return c.SendStatus(statusCode) } -func toFiberHandler(parseToken oidc.ParseTokenFunc, setters ...options.Option) fiber.Handler { +func toFiberHandler[T any](parseToken oidc.ParseTokenFunc[T], setters ...options.Option) fiber.Handler { opts := options.New(setters...) return func(c *fiber.Ctx) error { @@ -42,17 +42,12 @@ func toFiberHandler(parseToken oidc.ParseTokenFunc, setters ...options.Option) f return onError(c, opts.ErrorHandler, fiber.StatusBadRequest, options.GetTokenErrorDescription, err) } - token, err := parseToken(ctx, tokenString) + claims, err := parseToken(ctx, tokenString) if err != nil { return onError(c, opts.ErrorHandler, fiber.StatusUnauthorized, options.ParseTokenErrorDescription, err) } - tokenClaims, err := token.AsMap(ctx) - if err != nil { - return onError(c, opts.ErrorHandler, fiber.StatusUnauthorized, options.ConvertTokenErrorDescription, err) - } - - c.Locals(string(opts.ClaimsContextKeyName), tokenClaims) + c.Locals(string(opts.ClaimsContextKeyName), claims) return c.Next() } diff --git a/oidcfiber/fiber_test.go b/oidcfiber/fiber_test.go index 9dc4116..56e644b 100644 --- a/oidcfiber/fiber_test.go +++ b/oidcfiber/fiber_test.go @@ -34,7 +34,7 @@ func testGetFiberRouter(tb testing.TB, middleware fiber.Handler) *fiber.App { app.Use(middleware) app.Get("/", func(c *fiber.Ctx) error { - claims, ok := c.Locals("claims").(map[string]interface{}) + claims, ok := c.Locals("claims").(oidctesting.TestClaims) if !ok { return c.SendStatus(fiber.StatusUnauthorized) } @@ -94,16 +94,16 @@ func newTestHandler(tb testing.TB) *testHandler { } } -func (h *testHandler) NewHandlerFn(opts ...options.Option) http.Handler { +func (h *testHandler) NewHandlerFn(claimsValidationFn options.ClaimsValidationFn[oidctesting.TestClaims], opts ...options.Option) http.Handler { h.tb.Helper() - middleware := New(opts...) + middleware := New(claimsValidationFn, opts...) app := testGetFiberRouter(h.tb, middleware) return newTestFiberHandler(h.tb, app) } -func (h *testHandler) ToHandlerFn(parseToken oidc.ParseTokenFunc, opts ...options.Option) http.Handler { +func (h *testHandler) ToHandlerFn(parseToken oidc.ParseTokenFunc[oidctesting.TestClaims], opts ...options.Option) http.Handler { h.tb.Helper() middleware := toFiberHandler(parseToken, opts...) @@ -115,7 +115,7 @@ func (h *testHandler) ToHandlerFn(parseToken oidc.ParseTokenFunc, opts ...option func (h *testHandler) NewTestServer(opts ...options.Option) oidctesting.ServerTester { h.tb.Helper() - middleware := New(opts...) + middleware := New[oidctesting.TestClaims](nil, opts...) app := testGetFiberRouter(h.tb, middleware) return newTestServer(h.tb, app) diff --git a/oidcfiber/go.sum b/oidcfiber/go.sum index ef73613..07ff0f5 100644 --- a/oidcfiber/go.sum +++ b/oidcfiber/go.sum @@ -85,8 +85,8 @@ github.com/valyala/fasthttp v1.42.0/go.mod h1:f6VbjjoI3z1NDOZOv17o6RvtRSWxC77seB github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8= github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc= github.com/xenitab/dispans v0.0.10 h1:S+gSUM14rDJWK7MYNrjb8JbjeQPip6mlNJyLX+g7Agc= -github.com/xenitab/go-oidc-middleware v0.0.35 h1:9u1rQ/MqYXg4IpeJcOKyCSA2Xo8Pji3IiIZ+ZbAoqFI= -github.com/xenitab/go-oidc-middleware v0.0.35/go.mod h1:a8lpsTfdmiEsbclX4oIQE2gXj+8cYLLGRKUtgccwR94= +github.com/xenitab/go-oidc-middleware v0.0.36 h1:iBm+8usJZg9mCWrZWliHpzNatWn6g31AAcbb1q4M6go= +github.com/xenitab/go-oidc-middleware v0.0.36/go.mod h1:dUakIYup0Grr7Bn/88xTTKtlS6MWoWZtrrnzdt/SUZU= github.com/zclconf/go-cty v1.12.1 h1:PcupnljUm9EIvbgSHQnHhUr3fO6oFmkOrvs2BAFNXXY= github.com/zclconf/go-cty v1.12.1/go.mod h1:s9IfD1LK5ccNMSWCVFCE2rJfHiZgi7JijgeWIMfhLvA= go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= diff --git a/oidcgin/gin.go b/oidcgin/gin.go index 701af6f..8774112 100644 --- a/oidcgin/gin.go +++ b/oidcgin/gin.go @@ -11,8 +11,8 @@ import ( // New returns an OpenID Connect (OIDC) discovery handler (middleware) // to be used with `gin`. -func New(setters ...options.Option) gin.HandlerFunc { - oidcHandler, err := oidc.NewHandler(setters...) +func New[T any](claimsValidationFn options.ClaimsValidationFn[T], setters ...options.Option) gin.HandlerFunc { + oidcHandler, err := oidc.NewHandler(claimsValidationFn, setters...) if err != nil { panic(fmt.Sprintf("oidc discovery: %v", err)) } @@ -29,7 +29,7 @@ func onError(c *gin.Context, errorHandler options.ErrorHandler, statusCode int, c.AbortWithError(statusCode, err) } -func toGinHandler(parseToken oidc.ParseTokenFunc, setters ...options.Option) gin.HandlerFunc { +func toGinHandler[T any](parseToken oidc.ParseTokenFunc[T], setters ...options.Option) gin.HandlerFunc { opts := options.New(setters...) return func(c *gin.Context) { @@ -41,19 +41,13 @@ func toGinHandler(parseToken oidc.ParseTokenFunc, setters ...options.Option) gin return } - token, err := parseToken(ctx, tokenString) + claims, err := parseToken(ctx, tokenString) if err != nil { onError(c, opts.ErrorHandler, http.StatusUnauthorized, options.ParseTokenErrorDescription, err) return } - tokenClaims, err := token.AsMap(ctx) - if err != nil { - onError(c, opts.ErrorHandler, http.StatusUnauthorized, options.ConvertTokenErrorDescription, err) - return - } - - c.Set(string(opts.ClaimsContextKeyName), tokenClaims) + c.Set(string(opts.ClaimsContextKeyName), claims) c.Next() } diff --git a/oidcgin/gin_test.go b/oidcgin/gin_test.go index fa522b6..7712460 100644 --- a/oidcgin/gin_test.go +++ b/oidcgin/gin_test.go @@ -42,7 +42,7 @@ func testGetGinRouter(tb testing.TB, middleware gin.HandlerFunc) *gin.Engine { return } - claims, ok := claimsValue.(map[string]interface{}) + claims, ok := claimsValue.(oidctesting.TestClaims) if !ok { c.AbortWithStatus(http.StatusUnauthorized) return @@ -94,14 +94,14 @@ func newTestHandler(tb testing.TB) *testHandler { } } -func (h *testHandler) NewHandlerFn(opts ...options.Option) http.Handler { +func (h *testHandler) NewHandlerFn(claimsValidationFn options.ClaimsValidationFn[oidctesting.TestClaims], opts ...options.Option) http.Handler { h.tb.Helper() - middleware := New(opts...) + middleware := New(claimsValidationFn, opts...) return testGetGinRouter(h.tb, middleware) } -func (h *testHandler) ToHandlerFn(parseToken oidc.ParseTokenFunc, opts ...options.Option) http.Handler { +func (h *testHandler) ToHandlerFn(parseToken oidc.ParseTokenFunc[oidctesting.TestClaims], opts ...options.Option) http.Handler { h.tb.Helper() middleware := toGinHandler(parseToken, opts...) @@ -111,6 +111,6 @@ func (h *testHandler) ToHandlerFn(parseToken oidc.ParseTokenFunc, opts ...option func (h *testHandler) NewTestServer(opts ...options.Option) oidctesting.ServerTester { h.tb.Helper() - middleware := New(opts...) + middleware := New[oidctesting.TestClaims](nil, opts...) return newTestServer(h.tb, testGetGinRouter(h.tb, middleware)) } diff --git a/oidcgin/go.sum b/oidcgin/go.sum index add1b31..ad84a57 100644 --- a/oidcgin/go.sum +++ b/oidcgin/go.sum @@ -99,8 +99,8 @@ github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6 github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0= github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY= github.com/xenitab/dispans v0.0.10 h1:S+gSUM14rDJWK7MYNrjb8JbjeQPip6mlNJyLX+g7Agc= -github.com/xenitab/go-oidc-middleware v0.0.35 h1:9u1rQ/MqYXg4IpeJcOKyCSA2Xo8Pji3IiIZ+ZbAoqFI= -github.com/xenitab/go-oidc-middleware v0.0.35/go.mod h1:a8lpsTfdmiEsbclX4oIQE2gXj+8cYLLGRKUtgccwR94= +github.com/xenitab/go-oidc-middleware v0.0.36 h1:iBm+8usJZg9mCWrZWliHpzNatWn6g31AAcbb1q4M6go= +github.com/xenitab/go-oidc-middleware v0.0.36/go.mod h1:dUakIYup0Grr7Bn/88xTTKtlS6MWoWZtrrnzdt/SUZU= github.com/zclconf/go-cty v1.12.1 h1:PcupnljUm9EIvbgSHQnHhUr3fO6oFmkOrvs2BAFNXXY= github.com/zclconf/go-cty v1.12.1/go.mod h1:s9IfD1LK5ccNMSWCVFCE2rJfHiZgi7JijgeWIMfhLvA= go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= diff --git a/oidchttp/go.sum b/oidchttp/go.sum index 83bd15e..e9d8a04 100644 --- a/oidchttp/go.sum +++ b/oidchttp/go.sum @@ -61,8 +61,8 @@ github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/rtred v0.1.2 h1:exmoQtOLvDoO8ud++6LwVsAMTu0KPzLTUrMln8u1yu8= github.com/tidwall/tinyqueue v0.1.1 h1:SpNEvEggbpyN5DIReaJ2/1ndroY8iyEGxPYxoSaymYE= github.com/xenitab/dispans v0.0.10 h1:S+gSUM14rDJWK7MYNrjb8JbjeQPip6mlNJyLX+g7Agc= -github.com/xenitab/go-oidc-middleware v0.0.35 h1:9u1rQ/MqYXg4IpeJcOKyCSA2Xo8Pji3IiIZ+ZbAoqFI= -github.com/xenitab/go-oidc-middleware v0.0.35/go.mod h1:a8lpsTfdmiEsbclX4oIQE2gXj+8cYLLGRKUtgccwR94= +github.com/xenitab/go-oidc-middleware v0.0.36 h1:iBm+8usJZg9mCWrZWliHpzNatWn6g31AAcbb1q4M6go= +github.com/xenitab/go-oidc-middleware v0.0.36/go.mod h1:dUakIYup0Grr7Bn/88xTTKtlS6MWoWZtrrnzdt/SUZU= github.com/zclconf/go-cty v1.12.1 h1:PcupnljUm9EIvbgSHQnHhUr3fO6oFmkOrvs2BAFNXXY= github.com/zclconf/go-cty v1.12.1/go.mod h1:s9IfD1LK5ccNMSWCVFCE2rJfHiZgi7JijgeWIMfhLvA= go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= diff --git a/oidchttp/http.go b/oidchttp/http.go index 60e23c2..6ec15a8 100644 --- a/oidchttp/http.go +++ b/oidchttp/http.go @@ -11,8 +11,8 @@ import ( // New returns an OpenID Connect (OIDC) discovery handler (middleware) // to be used with `net/http`, `mux` and `chi`. -func New(h http.Handler, setters ...options.Option) http.Handler { - oidcHandler, err := oidc.NewHandler(setters...) +func New[T any](h http.Handler, claimsValidationFn options.ClaimsValidationFn[T], setters ...options.Option) http.Handler { + oidcHandler, err := oidc.NewHandler(claimsValidationFn, setters...) if err != nil { panic(fmt.Sprintf("oidc discovery: %v", err)) } @@ -28,7 +28,7 @@ func onError(w http.ResponseWriter, errorHandler options.ErrorHandler, statusCod w.WriteHeader(statusCode) } -func toHttpHandler(h http.Handler, parseToken oidc.ParseTokenFunc, setters ...options.Option) http.Handler { +func toHttpHandler[T any](h http.Handler, parseToken oidc.ParseTokenFunc[T], setters ...options.Option) http.Handler { opts := options.New(setters...) fn := func(w http.ResponseWriter, r *http.Request) { @@ -40,19 +40,13 @@ func toHttpHandler(h http.Handler, parseToken oidc.ParseTokenFunc, setters ...op return } - token, err := parseToken(ctx, tokenString) + claims, err := parseToken(ctx, tokenString) if err != nil { onError(w, opts.ErrorHandler, http.StatusUnauthorized, options.ParseTokenErrorDescription, err) return } - tokenClaims, err := token.AsMap(ctx) - if err != nil { - onError(w, opts.ErrorHandler, http.StatusUnauthorized, options.ConvertTokenErrorDescription, err) - return - } - - ctxWithClaims := context.WithValue(ctx, opts.ClaimsContextKeyName, tokenClaims) + ctxWithClaims := context.WithValue(ctx, opts.ClaimsContextKeyName, claims) reqWithClaims := r.WithContext(ctxWithClaims) h.ServeHTTP(w, reqWithClaims) diff --git a/oidchttp/http_test.go b/oidchttp/http_test.go index 06857bd..2073224 100644 --- a/oidchttp/http_test.go +++ b/oidchttp/http_test.go @@ -31,7 +31,7 @@ func testNewClaimsHandler(tb testing.TB) func(w http.ResponseWriter, r *http.Req tb.Helper() return func(w http.ResponseWriter, r *http.Request) { - claims, ok := r.Context().Value(options.DefaultClaimsContextKeyName).(map[string]interface{}) + claims, ok := r.Context().Value(options.DefaultClaimsContextKeyName).(oidctesting.TestClaims) if !ok { w.WriteHeader(http.StatusUnauthorized) return @@ -86,14 +86,14 @@ func newTestHttpHandler(tb testing.TB) *testHttpHandler { } } -func (h *testHttpHandler) NewHandlerFn(opts ...options.Option) http.Handler { +func (h *testHttpHandler) NewHandlerFn(claimsValidationFn options.ClaimsValidationFn[oidctesting.TestClaims], opts ...options.Option) http.Handler { h.tb.Helper() handler := testGetHttpHandler(h.tb) - return New(handler, opts...) + return New(handler, claimsValidationFn, opts...) } -func (h *testHttpHandler) ToHandlerFn(parseToken oidc.ParseTokenFunc, opts ...options.Option) http.Handler { +func (h *testHttpHandler) ToHandlerFn(parseToken oidc.ParseTokenFunc[oidctesting.TestClaims], opts ...options.Option) http.Handler { h.tb.Helper() handler := testGetHttpHandler(h.tb) @@ -104,5 +104,5 @@ func (h *testHttpHandler) NewTestServer(opts ...options.Option) oidctesting.Serv h.tb.Helper() handler := testGetHttpHandler(h.tb) - return newTestServer(h.tb, New(handler, opts...)) + return newTestServer(h.tb, New[oidctesting.TestClaims](handler, nil, opts...)) } diff --git a/oidctoken/token.go b/oidctoken/token.go index 4053918..08febc9 100644 --- a/oidctoken/token.go +++ b/oidctoken/token.go @@ -3,28 +3,27 @@ package oidctoken import ( "context" - "github.com/lestrrat-go/jwx/jwt" "github.com/xenitab/go-oidc-middleware/internal/oidc" "github.com/xenitab/go-oidc-middleware/options" ) // TokenHandler is used to parse tokens. -type TokenHandler struct { - parseTokenFunc oidc.ParseTokenFunc +type TokenHandler[T any] struct { + parseTokenFunc oidc.ParseTokenFunc[T] tokenOptions *options.Options } // New returns an OpenID Connect (OIDC) discovery token handler. // Can be used to create your own middleware. -func New(setters ...options.Option) (*TokenHandler, error) { - oidcHandler, err := oidc.NewHandler(setters...) +func New[T any](claimsValidationFn options.ClaimsValidationFn[T], setters ...options.Option) (*TokenHandler[T], error) { + oidcHandler, err := oidc.NewHandler(claimsValidationFn, setters...) if err != nil { return nil, err } tokenOpts := options.New(setters...) - return &TokenHandler{ + return &TokenHandler[T]{ parseTokenFunc: oidcHandler.ParseToken, tokenOptions: tokenOpts, }, nil @@ -32,13 +31,13 @@ func New(setters ...options.Option) (*TokenHandler, error) { // ParseToken takes a context and a string and returns a jwt.Token or an error. // jwt.Token is from `github.com/lestrrat-go/jwx/jwt`. -func (t *TokenHandler) ParseToken(ctx context.Context, tokenString string) (jwt.Token, error) { - token, err := t.parseTokenFunc(ctx, tokenString) +func (t *TokenHandler[T]) ParseToken(ctx context.Context, tokenString string) (T, error) { + claims, err := t.parseTokenFunc(ctx, tokenString) if err != nil { - return nil, err + return *new(T), err } - return token, nil + return claims, nil } // GetTokenString takes a GetHeaderFn `func(key string) string` and [][]options.TokenStringOption and diff --git a/oidctoken/token_test.go b/oidctoken/token_test.go index 600fd3a..fc1ab31 100644 --- a/oidctoken/token_test.go +++ b/oidctoken/token_test.go @@ -33,7 +33,7 @@ func testNewClaimsHandler(tb testing.TB) func(w http.ResponseWriter, r *http.Req tb.Helper() return func(w http.ResponseWriter, r *http.Request) { - claims, ok := r.Context().Value(options.DefaultClaimsContextKeyName).(map[string]interface{}) + claims, ok := r.Context().Value(options.DefaultClaimsContextKeyName).(oidctesting.TestClaims) if !ok { w.WriteHeader(http.StatusUnauthorized) return @@ -88,14 +88,14 @@ func newTestHttpHandler(tb testing.TB) *testHttpHandler { } } -func (h *testHttpHandler) NewHandlerFn(opts ...options.Option) http.Handler { +func (h *testHttpHandler) NewHandlerFn(claimsValidationFn options.ClaimsValidationFn[oidctesting.TestClaims], opts ...options.Option) http.Handler { h.tb.Helper() handler := testGetHttpHandler(h.tb) - return testNew(h.tb, handler, opts...) + return testNew(h.tb, handler, claimsValidationFn, opts...) } -func (h *testHttpHandler) ToHandlerFn(parseToken oidc.ParseTokenFunc, opts ...options.Option) http.Handler { +func (h *testHttpHandler) ToHandlerFn(parseToken oidc.ParseTokenFunc[oidctesting.TestClaims], opts ...options.Option) http.Handler { h.tb.Helper() handler := testGetHttpHandler(h.tb) @@ -106,7 +106,7 @@ func (h *testHttpHandler) NewTestServer(opts ...options.Option) oidctesting.Serv h.tb.Helper() handler := testGetHttpHandler(h.tb) - return newTestServer(h.tb, testNew(h.tb, handler, opts...)) + return newTestServer(h.tb, testNew(h.tb, handler, nil, opts...)) } func testOnError(tb testing.TB, w http.ResponseWriter, errorHandler options.ErrorHandler, statusCode int, description options.ErrorDescription, err error) { @@ -119,10 +119,10 @@ func testOnError(tb testing.TB, w http.ResponseWriter, errorHandler options.Erro w.WriteHeader(statusCode) } -func testNew(tb testing.TB, h http.Handler, setters ...options.Option) http.Handler { +func testNew(tb testing.TB, h http.Handler, claimsValidationFn options.ClaimsValidationFn[oidctesting.TestClaims], setters ...options.Option) http.Handler { tb.Helper() - tokenHandler, err := New(setters...) + tokenHandler, err := New(claimsValidationFn, setters...) if err != nil { panic(fmt.Sprintf("oidc discovery: %v", err)) } @@ -130,7 +130,7 @@ func testNew(tb testing.TB, h http.Handler, setters ...options.Option) http.Hand return testToHttpHandler(tb, h, tokenHandler.ParseToken, setters...) } -func testToHttpHandler(tb testing.TB, h http.Handler, parseToken oidc.ParseTokenFunc, setters ...options.Option) http.Handler { +func testToHttpHandler[T any](tb testing.TB, h http.Handler, parseToken oidc.ParseTokenFunc[T], setters ...options.Option) http.Handler { tb.Helper() opts := options.New(setters...) @@ -144,19 +144,13 @@ func testToHttpHandler(tb testing.TB, h http.Handler, parseToken oidc.ParseToken return } - token, err := parseToken(ctx, tokenString) + claims, err := parseToken(ctx, tokenString) if err != nil { testOnError(tb, w, opts.ErrorHandler, http.StatusUnauthorized, options.ParseTokenErrorDescription, err) return } - tokenClaims, err := token.AsMap(ctx) - if err != nil { - testOnError(tb, w, opts.ErrorHandler, http.StatusUnauthorized, options.ConvertTokenErrorDescription, err) - return - } - - ctxWithClaims := context.WithValue(ctx, opts.ClaimsContextKeyName, tokenClaims) + ctxWithClaims := context.WithValue(ctx, opts.ClaimsContextKeyName, claims) reqWithClaims := r.WithContext(ctxWithClaims) h.ServeHTTP(w, reqWithClaims) diff --git a/options/options.go b/options/options.go index 1203313..becbe60 100644 --- a/options/options.go +++ b/options/options.go @@ -5,6 +5,12 @@ import ( "time" ) +// ClaimsValidationFn is a generic function to validate calims. +// If an error is returned, the claims failed the validation. +// If `nil` is provided instead of a function when configuration the handler, +// no additional validation of the claims will be done. +type ClaimsValidationFn[T any] func(*T) error + // ClaimsContextKeyName is the type for they key value used to pass claims using request context. // Using separate type because of the following: https://staticcheck.io/docs/checks#SA1029 type ClaimsContextKeyName string @@ -40,7 +46,6 @@ type Options struct { LazyLoadJwks bool RequiredTokenType string RequiredAudience string - RequiredClaims map[string]interface{} DisableKeyID bool HttpClient *http.Client TokenString [][]TokenStringOption @@ -184,34 +189,6 @@ func WithRequiredAudience(opt string) Option { } } -// WithRequiredClaims sets the RequiredClaims parameter for an Options pointer. -// RequiredClaims is used to require specific claims in the token -// Defaults to empty map (nil) and won't check for anything else -// Works with primitive types, slices and maps. -// Please observe: slices and strings checks that the token contains it, but more is allowed. -// Required claim []string{"bar"} matches token []string{"foo", "bar", "baz"} -// Required claim map[string]string{{"foo": "bar"}} matches token map[string]string{{"a": "b"},{"foo": "bar"},{"c": "d"}} -// -// Example: -// -// ```go -// -// map[string]interface{}{ -// "foo": "bar", -// "bar": 1337, -// "baz": []string{"bar"}, -// "oof": []map[string]string{ -// {"bar": "baz"}, -// }, -// }, -// -// ``` -func WithRequiredClaims(opt map[string]interface{}) Option { - return func(opts *Options) { - opts.RequiredClaims = opt - } -} - // WithDisableKeyID sets the DisableKeyID parameter for an Options pointer. // DisableKeyID adjusts if a KeyID needs to be extracted from the token or not // Defaults to false and means KeyID is required to be present in both the jwks and token diff --git a/options/options_test.go b/options/options_test.go index 26459dc..6660bec 100644 --- a/options/options_test.go +++ b/options/options_test.go @@ -21,10 +21,7 @@ func TestOptions(t *testing.T) { LazyLoadJwks: true, RequiredTokenType: "foo", RequiredAudience: "foo", - RequiredClaims: map[string]interface{}{ - "foo": "bar", - }, - DisableKeyID: true, + DisableKeyID: true, HttpClient: &http.Client{ Timeout: 1234 * time.Second, }, @@ -57,9 +54,6 @@ func TestOptions(t *testing.T) { WithLazyLoadJwks(true), WithRequiredTokenType("foo"), WithRequiredAudience("foo"), - WithRequiredClaims(map[string]interface{}{ - "foo": "bar", - }), WithDisableKeyID(true), WithHttpClient(&http.Client{ Timeout: 1234 * time.Second,