diff --git a/src/oidcc_token.erl b/src/oidcc_token.erl index 784b583..4e22637 100644 --- a/src/oidcc_token.erl +++ b/src/oidcc_token.erl @@ -230,6 +230,7 @@ See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3. | {invalid_property, { Field :: id_token | refresh_token | access_token | expires_in | scopes, GivenValue :: term() }} + | no_supported_code_challenge | oidcc_jwt_util:error() | oidcc_http_util:error(). @@ -359,7 +360,6 @@ retrieve(AuthCode, ClientContext, Opts) -> case lists:member(<<"authorization_code">>, GrantTypesSupported) of true -> - PkceVerifier = maps:get(pkce_verifier, Opts, none), QsBody = [ {<<"grant_type">>, <<"authorization_code">>}, @@ -375,7 +375,7 @@ retrieve(AuthCode, ClientContext, Opts) -> maybe {ok, Token} ?= retrieve_a_token( - QsBody, PkceVerifier, ClientContext, Opts, TelemetryOpts, true + QsBody, ClientContext, Opts, TelemetryOpts, true ), extract_response(Token, ClientContext, Opts) end; @@ -535,7 +535,7 @@ refresh(RefreshToken, ClientContext, Opts) -> maybe {ok, Token} ?= - retrieve_a_token(QueryString1, none, ClientContext, Opts, TelemetryOpts, true), + retrieve_a_token(QueryString1, ClientContext, Opts, TelemetryOpts, true), {ok, TokenRecord} ?= extract_response(Token, ClientContext, maps:put(nonce, any, Opts)), case TokenRecord of @@ -631,7 +631,7 @@ jwt_profile(Subject, ClientContext, Jwk, Opts) -> maybe {ok, Token} ?= - retrieve_a_token(QueryString1, none, ClientContext, Opts, TelemetryOpts, false), + retrieve_a_token(QueryString1, ClientContext, Opts, TelemetryOpts, false), {ok, TokenRecord} ?= extract_response(Token, ClientContext, maps:put(nonce, any, Opts)), case TokenRecord of @@ -692,7 +692,7 @@ client_credentials(ClientContext, Opts) -> maybe {ok, Token} ?= - retrieve_a_token(QueryString1, none, ClientContext, Opts, TelemetryOpts, true), + retrieve_a_token(QueryString1, ClientContext, Opts, TelemetryOpts, true), extract_response(Token, ClientContext, maps:put(nonce, any, Opts)) end; false -> @@ -1139,21 +1139,16 @@ verify_missing_required_claims(Claims) -> end. -spec retrieve_a_token( - QsBodyIn, PkceVerifier, ClientContext, Opts, TelemetryOpts, AuthenticateClient + QsBodyIn, ClientContext, Opts, TelemetryOpts, AuthenticateClient ) -> {ok, map()} | {error, error()} when QsBodyIn :: oidcc_http_util:query_params(), - PkceVerifier :: binary() | none, ClientContext :: oidcc_client_context:t(), Opts :: retrieve_opts() | refresh_opts(), TelemetryOpts :: oidcc_http_util:telemetry_opts(), AuthenticateClient :: boolean(). -retrieve_a_token( - _QsBodyIn, none, _ClientContext, #{require_pkce := true}, _TelemetryOpts, _AuthenticateClient -) -> - {error, pkce_verifier_required}; -retrieve_a_token(QsBodyIn, PkceVerifier, ClientContext, Opts, TelemetryOpts, AuthenticateClient) -> +retrieve_a_token(QsBodyIn, ClientContext, Opts, TelemetryOpts, AuthenticateClient) -> #oidcc_client_context{provider_configuration = Configuration} = ClientContext, #oidcc_provider_configuration{ @@ -1168,7 +1163,6 @@ retrieve_a_token(QsBodyIn, PkceVerifier, ClientContext, Opts, TelemetryOpts, Aut Header0 = [{"accept", "application/jwt, application/json"}], QsBody0 = QsBodyIn ++ maps:get(body_extension, Opts, []), - QsBody = add_pkce_verifier(QsBody0, PkceVerifier), SupportedAuthMethods = case AuthenticateClient of @@ -1184,6 +1178,7 @@ retrieve_a_token(QsBodyIn, PkceVerifier, ClientContext, Opts, TelemetryOpts, Aut #{} end, maybe + {ok, QsBody} ?= add_pkce_verifier(QsBody0, Opts, ClientContext), {ok, {Body, Header1}, AuthMethod} ?= oidcc_auth_util:add_client_authentication( QsBody, Header0, SupportedAuthMethods, SigningAlgs, Opts, ClientContext @@ -1211,7 +1206,6 @@ retrieve_a_token(QsBodyIn, PkceVerifier, ClientContext, Opts, TelemetryOpts, Aut %% (to avoid infinite loops) retrieve_a_token( QsBodyIn, - PkceVerifier, ClientContext, Opts#{dpop_nonce => NewDpopNonce}, TelemetryOpts, @@ -1221,10 +1215,37 @@ retrieve_a_token(QsBodyIn, PkceVerifier, ClientContext, Opts, TelemetryOpts, Aut {error, Reason} end. --spec add_pkce_verifier(QueryList, PkceVerifier) -> oidcc_http_util:query_params() when +-spec add_pkce_verifier(QueryList, Opts, ClientContext) -> + {ok, oidcc_http_util:query_params()} | {error, error()} +when QueryList :: oidcc_http_util:query_params(), - PkceVerifier :: binary() | none. -add_pkce_verifier(BodyQs, none) -> - BodyQs; -add_pkce_verifier(BodyQs, PkceVerifier) -> - [{<<"code_verifier">>, PkceVerifier} | BodyQs]. + Opts :: retrieve_opts() | refresh_opts(), + ClientContext :: oidcc_client_context:t(). +add_pkce_verifier(BodyQs, #{pkce_verifier := PkceVerifier} = Opts, ClientContext) -> + #oidcc_client_context{provider_configuration = ProviderConfiguration} = ClientContext, + #oidcc_provider_configuration{code_challenge_methods_supported = CodeChallengeMethodsSupported} = + ProviderConfiguration, + RequirePkce = maps:get(require_pkce, Opts, false), + + case CodeChallengeMethodsSupported of + undefined when RequirePkce =:= true -> + {error, no_supported_code_challenge}; + undefined -> + {ok, BodyQs}; + Methods when is_list(Methods) -> + case + lists:member(<<"S256">>, CodeChallengeMethodsSupported) or + lists:member(<<"plain">>, CodeChallengeMethodsSupported) + of + true -> + {ok, [{<<"code_verifier">>, PkceVerifier} | BodyQs]}; + false when RequirePkce =:= true -> + {error, no_supported_code_challenge}; + false -> + {ok, BodyQs} + end + end; +add_pkce_verifier(_BodyQs, #{require_pkce := true}, _ClientContext) -> + {error, pkce_verifier_required}; +add_pkce_verifier(BodyQs, _Opts, _ClientContext) -> + {ok, BodyQs}. diff --git a/test/oidcc_token_test.erl b/test/oidcc_token_test.erl index 60021a4..2eef3ae 100644 --- a/test/oidcc_token_test.erl +++ b/test/oidcc_token_test.erl @@ -1710,18 +1710,66 @@ trusted_audiences_test() -> ok. -retrieve_pkce_required_test() -> - ClientContext = client_context_fapi2_fixture(), +retrieve_pkce_test() -> + ok = meck:new(httpc, [no_link]), + HttpFun = + fun( + post, + {_TokenEndpoint, _Header, "application/x-www-form-urlencoded", _Body}, + _HttpOpts, + _Opts, + _Profile + ) -> + {ok, {{"HTTP/1.1", 500, "Server Error"}, [], "SUCCESS"}} + end, + ok = meck:expect(httpc, request, HttpFun), + + PkceSupportedClientContext = client_context_fapi2_fixture(), + PkceUnsupportedClientContext = PkceSupportedClientContext#oidcc_client_context{ + provider_configuration = PkceSupportedClientContext#oidcc_client_context.provider_configuration#oidcc_provider_configuration{ + code_challenge_methods_supported = undefined + } + }, RedirectUri = <<"https://redirect.example/">>, ?assertEqual( {error, pkce_verifier_required}, - oidcc_token:retrieve(<<"code">>, ClientContext, #{ + oidcc_token:retrieve(<<"code">>, PkceSupportedClientContext, #{ redirect_uri => RedirectUri, require_pkce => true }) ), + ?assertEqual( + {error, {http_error, 500, "SUCCESS"}}, + oidcc_token:retrieve(<<"code">>, PkceSupportedClientContext, #{ + redirect_uri => RedirectUri, + require_pkce => true, + pkce_verifier => <<"verifier">> + }) + ), + + ?assertEqual( + {error, no_supported_code_challenge}, + oidcc_token:retrieve(<<"code">>, PkceUnsupportedClientContext, #{ + redirect_uri => RedirectUri, + require_pkce => true, + pkce_verifier => <<"verifier">> + }) + ), + + ?assertEqual( + {error, {http_error, 500, "SUCCESS"}}, + oidcc_token:retrieve(<<"code">>, PkceUnsupportedClientContext, #{ + redirect_uri => RedirectUri, + pkce_verifier => <<"verifier">> + }) + ), + + true = meck:validate(httpc), + + meck:unload(httpc), + ok. validate_jarm_test() ->