diff --git a/compose/compose_pkce.go b/compose/compose_pkce.go index cd0e0a138..d87a53b6f 100644 --- a/compose/compose_pkce.go +++ b/compose/compose_pkce.go @@ -7,14 +7,12 @@ import ( "github.com/ory/fosite" "github.com/ory/fosite/handler/oauth2" "github.com/ory/fosite/handler/pkce" - "github.com/ory/fosite/handler/rfc8628" ) // OAuth2PKCEFactory creates a PKCE handler. func OAuth2PKCEFactory(config fosite.Configurator, storage interface{}, strategy interface{}) interface{} { return &pkce.Handler{ AuthorizeCodeStrategy: strategy.(oauth2.AuthorizeCodeStrategy), - DeviceCodeStrategy: strategy.(rfc8628.DeviceCodeStrategy), Storage: storage.(pkce.PKCERequestStorage), Config: config, } diff --git a/handler/pkce/handler.go b/handler/pkce/handler.go index dc527839a..f457b8bea 100644 --- a/handler/pkce/handler.go +++ b/handler/pkce/handler.go @@ -15,14 +15,12 @@ import ( "github.com/ory/fosite" "github.com/ory/fosite/handler/oauth2" - "github.com/ory/fosite/handler/rfc8628" ) var _ fosite.TokenEndpointHandler = (*Handler)(nil) type Handler struct { AuthorizeCodeStrategy oauth2.AuthorizeCodeStrategy - DeviceCodeStrategy rfc8628.DeviceCodeStrategy Storage PKCERequestStorage Config interface { fosite.EnforcePKCEProvider @@ -35,51 +33,27 @@ var _ fosite.TokenEndpointHandler = (*Handler)(nil) var verifierWrongFormat = regexp.MustCompile("[^\\w\\.\\-~]") -func (c *Handler) HandleDeviceEndpointRequest(ctx context.Context, dr fosite.DeviceRequester, resp fosite.DeviceResponder) error { - return c.handlePkceEndpointRequest(ctx, dr, resp) -} - func (c *Handler) HandleAuthorizeEndpointRequest(ctx context.Context, ar fosite.AuthorizeRequester, resp fosite.AuthorizeResponder) error { - return c.handlePkceEndpointRequest(ctx, ar, resp) -} - -func (c *Handler) handlePkceEndpointRequest(ctx context.Context, r fosite.Requester, resp fosite.Responder) error { // This let's us define multiple response types, for example open id connect's id_token - if !(isAuthorizationCode(r) || isDeviceCode(r)) { + if !ar.GetResponseTypes().Has("code") { return nil } - challenge := r.GetRequestForm().Get("code_challenge") - method := r.GetRequestForm().Get("code_challenge_method") - client := r.GetClient() + challenge := ar.GetRequestForm().Get("code_challenge") + method := ar.GetRequestForm().Get("code_challenge_method") + client := ar.GetClient() if err := c.validate(ctx, challenge, method, client); err != nil { return err } - var signature string - if authorizeResp, ok := resp.(fosite.AuthorizeResponder); ok { - code := authorizeResp.GetCode() - if len(code) == 0 { - return errorsx.WithStack(fosite.ErrServerError.WithDebug("The PKCE handler must be loaded after the authorize/device code handler.")) - } - signature = c.AuthorizeCodeStrategy.AuthorizeCodeSignature(ctx, code) - } else if deviceResp, ok := resp.(fosite.DeviceResponder); ok { - code := deviceResp.GetDeviceCode() - if len(code) == 0 { - return errorsx.WithStack(fosite.ErrServerError.WithDebug("The PKCE handler must be loaded after the device code handler.")) - } - - var err error - signature, err = c.DeviceCodeStrategy.DeviceCodeSignature(ctx, code) - if err != nil { - return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) - } - } else { - return errorsx.WithStack(fosite.ErrServerError.WithDebug("This PKCE handle could not find the proper response type")) + code := resp.GetCode() + if len(code) == 0 { + return errorsx.WithStack(fosite.ErrServerError.WithDebug("The PKCE handler must be loaded after the authorize code handler.")) } - if err := c.Storage.CreatePKCERequestSession(ctx, signature, r.Sanitize([]string{ + signature := c.AuthorizeCodeStrategy.AuthorizeCodeSignature(ctx, code) + if err := c.Storage.CreatePKCERequestSession(ctx, signature, ar.Sanitize([]string{ "code_challenge", "code_challenge_method", })); err != nil { @@ -89,7 +63,7 @@ func (c *Handler) handlePkceEndpointRequest(ctx context.Context, r fosite.Reques return nil } -func (c *Handler) validate(ctx context.Context, challenge string, method string, client fosite.Client) error { +func (c *Handler) validate(ctx context.Context, challenge, method string, client fosite.Client) error { if challenge == "" { // If the server requires Proof Key for Code Exchange (PKCE) by OAuth // clients and the client does not send the "code_challenge" in @@ -147,19 +121,8 @@ func (c *Handler) HandleTokenEndpointRequest(ctx context.Context, request fosite // endpoint MUST use to verify the "code_verifier". verifier := request.GetRequestForm().Get("code_verifier") - var signature string - if request.GetGrantTypes().ExactOne("authorization_code") { - code := request.GetRequestForm().Get("code") - signature = c.AuthorizeCodeStrategy.AuthorizeCodeSignature(ctx, code) - } else if request.GetGrantTypes().ExactOne(string(fosite.GrantTypeDeviceCode)) { - var err error - code := request.GetRequestForm().Get("device_code") - signature, err = c.DeviceCodeStrategy.DeviceCodeSignature(ctx, code) - if err != nil { - return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) - } - } - + code := request.GetRequestForm().Get("code") + signature := c.AuthorizeCodeStrategy.AuthorizeCodeSignature(ctx, code) authorizeRequest, err := c.Storage.GetPKCERequestSession(ctx, signature, request.GetSession()) if errors.Is(err, fosite.ErrNotFound) { return errorsx.WithStack(fosite.ErrInvalidGrant.WithHint("Unable to find initial PKCE data tied to this request").WithWrap(err).WithDebug(err.Error())) @@ -254,15 +217,7 @@ func (c *Handler) CanSkipClientAuth(ctx context.Context, requester fosite.Access } func (c *Handler) CanHandleTokenEndpointRequest(ctx context.Context, requester fosite.AccessRequester) bool { - return requester.GetGrantTypes().ExactOne("authorization_code") || - requester.GetGrantTypes().ExactOne(string(fosite.GrantTypeDeviceCode)) -} - -func isDeviceCode(r fosite.Requester) bool { - return r.GetClient().GetGrantTypes().Has(string(fosite.GrantTypeDeviceCode)) -} - -func isAuthorizationCode(r fosite.Requester) bool { - ar, ok := r.(*fosite.AuthorizeRequest) - return ok && ar.GetResponseTypes().Has("code") + // grant_type REQUIRED. + // Value MUST be set to "authorization_code" + return requester.GetGrantTypes().ExactOne("authorization_code") } diff --git a/handler/pkce/handler_device_test.go b/handler/pkce/handler_device_test.go deleted file mode 100644 index 414effc18..000000000 --- a/handler/pkce/handler_device_test.go +++ /dev/null @@ -1,353 +0,0 @@ -// Copyright © 2023 Ory Corp -// SPDX-License-Identifier: Apache-2.0 - -package pkce - -import ( - "context" - "crypto/sha256" - "encoding/base64" - "fmt" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/ory/fosite" - "github.com/ory/fosite/handler/rfc8628" - "github.com/ory/fosite/storage" -) - -type mockDeviceCodeStrategy struct { - signature string -} - -func (m *mockDeviceCodeStrategy) DeviceCodeSignature(ctx context.Context, token string) (signature string, err error) { - return m.signature, nil -} - -func (m *mockDeviceCodeStrategy) GenerateDeviceCode(ctx context.Context) (token string, signature string, err error) { - return "", "", nil -} - -func (m *mockDeviceCodeStrategy) ValidateDeviceCode(ctx context.Context, requester fosite.Requester, token string) (err error) { - return nil -} - -func TestPKCEHandlerDevice_HandleAuthorizeEndpointRequest(t *testing.T) { - var config fosite.Config - h := &Handler{ - Storage: storage.NewMemoryStore(), - DeviceCodeStrategy: new(rfc8628.DefaultDeviceStrategy), - Config: &config, - } - w := fosite.NewDeviceResponse() - r := fosite.NewDeviceRequest() - config.GlobalSecret = []byte("thisissecret") - - w.SetDeviceCode("foo") - - r.Form.Add("code_challenge", "challenge") - r.Form.Add("code_challenge_method", "plain") - - c := &fosite.DefaultClient{} - r.Client = c - require.NoError(t, h.HandleDeviceEndpointRequest(context.Background(), r, w)) - - c = &fosite.DefaultClient{ - GrantTypes: []string{"urn:ietf:params:oauth:grant-type:device_code"}, - } - r.Client = c - require.Error(t, h.HandleDeviceEndpointRequest(context.Background(), r, w)) - - c.Public = true - config.EnablePKCEPlainChallengeMethod = true - require.NoError(t, h.HandleDeviceEndpointRequest(context.Background(), r, w)) - - c.Public = false - config.EnablePKCEPlainChallengeMethod = true - require.NoError(t, h.HandleDeviceEndpointRequest(context.Background(), r, w)) - - config.EnablePKCEPlainChallengeMethod = false - require.Error(t, h.HandleDeviceEndpointRequest(context.Background(), r, w)) - - r.Form.Set("code_challenge_method", "S256") - r.Form.Set("code_challenge", "") - config.EnforcePKCE = true - require.Error(t, h.HandleDeviceEndpointRequest(context.Background(), r, w)) - - r.Form.Set("code_challenge", "challenge") - require.NoError(t, h.HandleDeviceEndpointRequest(context.Background(), r, w)) -} - -func TestPKCEHandlerDevice_HandlerValidate(t *testing.T) { - s := storage.NewMemoryStore() - ds := &mockDeviceCodeStrategy{} - config := &fosite.Config{} - h := &Handler{ - Storage: s, - DeviceCodeStrategy: ds, - Config: config, - } - pc := &fosite.DefaultClient{Public: true} - - s256verifier := "KGCt4m8AmjUvIR5ArTByrmehjtbxn1A49YpTZhsH8N7fhDr7LQayn9xx6mck" - hash := sha256.New() - hash.Write([]byte(s256verifier)) - s256challenge := base64.RawURLEncoding.EncodeToString(hash.Sum([]byte{})) - - for k, tc := range []struct { - d string - grant string - force bool - enablePlain bool - challenge string - method string - verifier string - code string - expectErr error - client *fosite.DefaultClient - }{ - { - d: "fails because not auth code flow", - grant: "not_urn:ietf:params:oauth:grant-type:device_code", - expectErr: fosite.ErrUnknownRequest, - client: &fosite.DefaultClient{Public: false}, - }, - { - d: "passes with private client", - grant: "urn:ietf:params:oauth:grant-type:device_code", - challenge: "foofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoo", - verifier: "foofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoo", - method: "plain", - client: &fosite.DefaultClient{Public: false}, - enablePlain: true, - force: true, - code: "valid-code-1", - }, - { - d: "fails because invalid code", - grant: "urn:ietf:params:oauth:grant-type:device_code", - expectErr: fosite.ErrInvalidGrant, - client: pc, - code: "invalid-code-2", - }, - { - d: "passes because auth code flow but pkce is not forced and no challenge given", - grant: "urn:ietf:params:oauth:grant-type:device_code", - client: pc, - code: "valid-code-3", - }, - { - d: "fails because auth code flow and pkce challenge given but plain is disabled", - grant: "urn:ietf:params:oauth:grant-type:device_code", - challenge: "foo", - client: pc, - expectErr: fosite.ErrInvalidRequest, - code: "valid-code-4", - }, - { - d: "passes", - grant: "urn:ietf:params:oauth:grant-type:device_code", - challenge: "foofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoo", - verifier: "foofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoo", - client: pc, - enablePlain: true, - force: true, - code: "valid-code-5", - }, - { - d: "passes", - grant: "urn:ietf:params:oauth:grant-type:device_code", - challenge: "foofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoo", - verifier: "foofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoo", - method: "plain", - client: pc, - enablePlain: true, - force: true, - code: "valid-code-6", - }, - { - d: "fails because challenge and verifier do not match", - grant: "urn:ietf:params:oauth:grant-type:device_code", - challenge: "not-foo", - verifier: "foofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoo", - method: "plain", - client: pc, - enablePlain: true, - code: "valid-code-7", - expectErr: fosite.ErrInvalidGrant, - }, - { - d: "fails because challenge and verifier do not match", - grant: "urn:ietf:params:oauth:grant-type:device_code", - challenge: "not-foonot-foonot-foonot-foonot-foonot-foonot-foonot-foo", - verifier: "foofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoo", - client: pc, - enablePlain: true, - code: "valid-code-8", - expectErr: fosite.ErrInvalidGrant, - }, - { - d: "fails because verifier is too short", - grant: "urn:ietf:params:oauth:grant-type:device_code", - challenge: "foo", - verifier: "foo", - method: "S256", - client: pc, - force: true, - code: "valid-code-9a", - expectErr: fosite.ErrInvalidGrant, - }, - { - d: "fails because verifier is too long", - grant: "urn:ietf:params:oauth:grant-type:device_code", - challenge: "foo", - verifier: "foofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoofoo", - method: "S256", - client: pc, - force: true, - code: "valid-code-10", - expectErr: fosite.ErrInvalidGrant, - }, - { - d: "fails because verifier is malformed", - grant: "urn:ietf:params:oauth:grant-type:device_code", - challenge: "foo", - verifier: `(!"/$%Z&$T()/)OUZI>$"&=/T(PUOI>"%/)TUOI&/(O/()RGTE>=/(%"/()="$/)(=()=/R/()=))`, - method: "S256", - client: pc, - force: true, - code: "valid-code-11", - expectErr: fosite.ErrInvalidGrant, - }, - { - d: "fails because challenge and verifier do not match", - grant: "urn:ietf:params:oauth:grant-type:device_code", - challenge: "Zm9vZm9vZm9vZm9vZm9vZm9vZm9vZm9vZm9vZm9vZm9vZm9vZm9vZm9vZm9vZm9vZm9vZm9vZm9vZm9vZm9vZm9vZm9vZm9v", - verifier: "Zm9vZm9vZm9vZm9vZm9vZm9vZm9vZm9vZm9vZm9vZm9vZm9vZm9vZm9vZm9vZm9vZm9vZm9vZm9vZm9vZm9vZm9vZm9vZm9v", - method: "S256", - client: pc, - force: true, - code: "valid-code-12", - expectErr: fosite.ErrInvalidGrant, - }, - { - d: "passes because challenge and verifier match", - grant: "urn:ietf:params:oauth:grant-type:device_code", - challenge: s256challenge, - verifier: s256verifier, - method: "S256", - client: pc, - force: true, - code: "valid-code-13", - }, - } { - t.Run(fmt.Sprintf("case=%d/description=%s", k, tc.d), func(t *testing.T) { - config.EnablePKCEPlainChallengeMethod = tc.enablePlain - config.EnforcePKCE = tc.force - ds.signature = tc.code - ar := fosite.NewAuthorizeRequest() - ar.Form.Add("code_challenge", tc.challenge) - ar.Form.Add("code_challenge_method", tc.method) - require.NoError(t, s.CreatePKCERequestSession(context.TODO(), fmt.Sprintf("valid-code-%d", k), ar)) - - r := fosite.NewAccessRequest(nil) - r.Client = tc.client - r.GrantTypes = fosite.Arguments{tc.grant} - r.Form.Add("code_verifier", tc.verifier) - r.Form.Add("device_code", tc.code) - if tc.expectErr == nil { - require.NoError(t, h.HandleTokenEndpointRequest(context.Background(), r)) - } else { - err := h.HandleTokenEndpointRequest(context.Background(), r) - require.EqualError(t, err, tc.expectErr.Error(), "%+v", err) - } - }) - } -} - -func TestPKCEHandlerDevice_HandleTokenEndpointRequest(t *testing.T) { - for k, tc := range []struct { - d string - force bool - forcePublic bool - enablePlain bool - challenge string - method string - expectErr bool - client *fosite.DefaultClient - }{ - { - d: "should pass because pkce is not enforced", - }, - { - d: "should fail because plain is not enabled and method is empty which defaults to plain", - expectErr: true, - force: true, - }, - { - d: "should fail because force is enabled and no challenge was given", - force: true, - enablePlain: true, - expectErr: true, - method: "S256", - }, - { - d: "should fail because forcePublic is enabled, the client is public, and no challenge was given", - forcePublic: true, - client: &fosite.DefaultClient{Public: true}, - expectErr: true, - method: "S256", - }, - { - d: "should fail because although force is enabled and a challenge was given, plain is disabled", - force: true, - expectErr: true, - method: "plain", - challenge: "challenge", - }, - { - d: "should fail because although force is enabled and a challenge was given, plain is disabled and method is empty", - force: true, - expectErr: true, - challenge: "challenge", - }, - { - d: "should fail because invalid challenge method", - force: true, - expectErr: true, - method: "invalid", - challenge: "challenge", - }, - { - d: "should pass because force is enabled with challenge given and method is S256", - force: true, - method: "S256", - challenge: "challenge", - }, - { - d: "should pass because forcePublic is enabled with challenge given and method is S256", - forcePublic: true, - client: &fosite.DefaultClient{Public: true}, - method: "S256", - challenge: "challenge", - }, - } { - t.Run(fmt.Sprintf("case=%d/description=%s", k, tc.d), func(t *testing.T) { - h := &Handler{ - Config: &fosite.Config{ - EnforcePKCE: tc.force, - EnforcePKCEForPublicClients: tc.forcePublic, - EnablePKCEPlainChallengeMethod: tc.enablePlain, - }, - } - - if tc.expectErr { - assert.Error(t, h.validate(context.Background(), tc.challenge, tc.method, tc.client)) - } else { - assert.NoError(t, h.validate(context.Background(), tc.challenge, tc.method, tc.client)) - } - }) - } -} diff --git a/handler/pkce/handler_test.go b/handler/pkce/handler_test.go index 774d90196..68c42b438 100644 --- a/handler/pkce/handler_test.go +++ b/handler/pkce/handler_test.go @@ -251,7 +251,6 @@ func TestPKCEHandlerValidate(t *testing.T) { r.Client = tc.client r.GrantTypes = fosite.Arguments{tc.grant} r.Form.Add("code_verifier", tc.verifier) - r.Form.Add("code", tc.code) if tc.expectErr == nil { require.NoError(t, h.HandleTokenEndpointRequest(context.Background(), r)) } else {