Skip to content

Commit

Permalink
feat: optional client_id on token endpoint for private_key_jwt (#55)
Browse files Browse the repository at this point in the history
  • Loading branch information
vdbulcke committed Nov 1, 2023
1 parent 6abb85b commit 84ab65a
Show file tree
Hide file tree
Showing 10 changed files with 98 additions and 44 deletions.
2 changes: 1 addition & 1 deletion .goreleaser.yml
Original file line number Diff line number Diff line change
Expand Up @@ -114,4 +114,4 @@ signs:
# footer: |
# ## Thanks!

# Those were the changes on {{ .Tag }}!
# Those were the changes on {{ .Tag }}!
6 changes: 6 additions & 0 deletions example/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ client_secret: bar
### NOTE: tls_client_auth requires '--pem-key' and '--pem-certificate' flag
auth_method: client_secret_basic


## Always Set 'client_id' for token endpoint (optional)
### Since version v0.19.0
# always_set_client_id_for_token_endpoint: true


## Private Key Jwt (Optional)
### Since version v0.16.0
###
Expand Down
41 changes: 41 additions & 0 deletions src/client/access_token.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package oidcclient

import (
"context"
"net/url"

internaloauth2 "github.com/vdbulcke/oidc-client-demo/src/client/internal/oauth2"
)

// TokenExchange call oauth2 token endpoint with the configured auth method
// expects the querystring parameters as input (token=..., grant_type=...)
func (c *OIDCClient) TokenExchange(params url.Values) (*internaloauth2.Token, error) {

if c.config.AuthMethod == "private_key_jwt" {

// signedJwt, err := c.GenerateJwtProfile(c.config.IntrospectEndpoint)
signedJwt, err := c.GenerateJwtProfile(c.Wellknown.TokenEndpoint)
if err != nil {
c.logger.Error("Failed to generate jwt client_assertion", "err", err)
return nil, err
}
c.logger.Debug("introspect setting client_assertion", "jwt", signedJwt)
params.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
params.Set("client_assertion", signedJwt)

} else if c.config.AuthMethod == "tls_client_auth" {
c.logger.Debug("set client_id", c.config.ClientID)
params.Set("client_id", c.config.ClientID)
}

if c.config.ClientIDParamForTokenEndpoint {
params.Set("client_id", c.config.ClientID)
}
c.logger.Debug("retrieve token ", "param", params)
oauth2Token, err := internaloauth2.RetrieveToken(context.TODO(), c.config.ClientID, c.config.ClientSecret, c.Wellknown.TokenEndpoint, params, c.oAuthConfig.Endpoint.AuthStyle)
if err != nil {
return nil, err
}

return oauth2Token, nil
}
38 changes: 33 additions & 5 deletions src/client/authorize.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,11 @@ func (c *OIDCClient) OIDCAuthorizationCodeFlow() error {
authZCode := r.URL.Query().Get("code")
c.logger.Info("Received AuthZ Code", "code", authZCode)

v := url.Values{
"grant_type": {"authorization_code"},
"code": {authZCode},
"redirect_uri": {c.config.RedirectUri},
}
// Extra parameter for authorize endpoint
var tokenOpts []oauth2.AuthCodeOption
// var oauth2Token *oauth2.Token
Expand All @@ -209,6 +214,7 @@ func (c *OIDCClient) OIDCAuthorizationCodeFlow() error {

// set extra pkce param
pkceVerifierOption := oauth2.SetAuthURLParam("code_verifier", codeVerifier)
v.Set("code_verifier", codeVerifier)
tokenOpts = append(tokenOpts, pkceVerifierOption)
c.logger.Debug("using pkce code_verifier for getting access token")

Expand All @@ -218,6 +224,8 @@ func (c *OIDCClient) OIDCAuthorizationCodeFlow() error {

assertionTypeOption := oauth2.SetAuthURLParam("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")

v.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")

signedJwt, err := c.GenerateJwtProfile(c.Wellknown.TokenEndpoint)
if err != nil {
c.logger.Error("Failed to generate jwt client_assertion", "err", err)
Expand All @@ -226,22 +234,35 @@ func (c *OIDCClient) OIDCAuthorizationCodeFlow() error {
}

assertionOption := oauth2.SetAuthURLParam("client_assertion", signedJwt)

v.Set("client_assertion", signedJwt)
c.logger.Debug("generated jwt", "client_assertion", signedJwt)
tokenOpts = append(tokenOpts, assertionTypeOption)
tokenOpts = append(tokenOpts, assertionOption)

}

c.logger.Debug("Token exchange", "opts", tokenOpts)
// Access Token Response
oauth2Token, err := c.oAuthConfig.Exchange(c.ctx, authZCode, tokenOpts...)
// oauth2Token, err := c.oAuthConfig.Exchange(c.ctx, authZCode, tokenOpts...)
// if err != nil {
// c.logger.Error("Failed to get Access Token", "err", err)
// http.Error(w, "Failed to exchange token: "+err.Error(), http.StatusInternalServerError)
// return
// }

// for _, opt := range tokenOpts {
// opt.setValue(v)
// }
oauth2Token, err := c.TokenExchange(v)
if err != nil {
c.logger.Error("Failed to get Access Token", "err", err)
http.Error(w, "Failed to exchange token: "+err.Error(), http.StatusInternalServerError)
return
}

// Parse Access Token
accessTokenResponse, err := c.parseAccessTokenResponse(oauth2Token)
// accessTokenResponse, err := c.parseAccessTokenResponse(oauth2Token)
accessTokenResponse, err := c.parseInternalAccessTokenResponse(oauth2Token)
if err != nil {
c.logger.Error("Error Parsing Access Token", "err", err)
http.Error(w, "Error Parsing Access Token", http.StatusBadRequest)
Expand Down Expand Up @@ -316,7 +337,14 @@ func (c *OIDCClient) OIDCAuthorizationCodeFlow() error {
}

// Fetch Userinfo
err = c.userinfo(oauth2Token)

tok := &oauth2.Token{
AccessToken: oauth2Token.AccessToken,
RefreshToken: oauth2Token.RefreshToken,
TokenType: oauth2Token.TokenType,
Expiry: oauth2Token.Expiry,
}
err = c.userinfo(tok)
if err != nil {
http.Error(w, "Failed to get userinfo: "+err.Error(), http.StatusInternalServerError)
return
Expand All @@ -326,7 +354,7 @@ func (c *OIDCClient) OIDCAuthorizationCodeFlow() error {
resp := struct {
OAuth2Token *oauth2.Token
AccessTokenResp *JSONAccessTokenResponse
}{oauth2Token, accessTokenResponse}
}{tok, accessTokenResponse}

// Format in JSON global HTTP response
data, err := json.MarshalIndent(resp, "", " ")
Expand Down
5 changes: 3 additions & 2 deletions src/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"net/http"
"time"

"golang.org/x/oauth2"
internaloauth2 "github.com/vdbulcke/oidc-client-demo/src/client/internal/oauth2"
)

// JSONAccessTokenResponse ...
Expand Down Expand Up @@ -42,7 +42,7 @@ func (c *OIDCClient) setCallbackCookie(w http.ResponseWriter, r *http.Request, n
http.SetCookie(w, cookie)
}

func (c *OIDCClient) parseAccessTokenResponse(oauth2Token *oauth2.Token) (*JSONAccessTokenResponse, error) {
func (c *OIDCClient) parseInternalAccessTokenResponse(oauth2Token *internaloauth2.Token) (*JSONAccessTokenResponse, error) {
// common logger text
commonLoggerText := "Access Token Response"

Expand All @@ -54,6 +54,7 @@ func (c *OIDCClient) parseAccessTokenResponse(oauth2Token *oauth2.Token) (*JSONA

// Parse Token Type
tokenType := oauth2Token.Type()
// tokenType := "Bearer"
if c.logger.IsDebug() {
c.logger.Debug(commonLoggerText, "token_type", tokenType)
}
Expand Down
10 changes: 6 additions & 4 deletions src/client/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ type OIDCClientConfig struct {
ClientSecret string `yaml:"client_secret" `
AuthMethod string `yaml:"auth_method" validate:"required,oneof=client_secret_basic client_secret_post private_key_jwt tls_client_auth"`

ClientIDParamForTokenEndpoint bool `yaml:"always_set_client_id_for_token_endpoint" default:"false"`

UsePKCE bool `yaml:"use_pkce"`
PKCEChallengeMethod string `yaml:"pkce_challenge_method"`
PKCECodeLength int
Expand Down Expand Up @@ -130,10 +132,10 @@ func ValidateConfig(config *OIDCClientConfig) bool {
}
}

if !config.UsePKCE && config.ClientSecret == "" {
fmt.Println("Error 'client_secret' not set")
return false
}
// if !config.UsePKCE && config.ClientSecret == "" {
// fmt.Println("Error 'client_secret' not set")
// return false
// }

if errs == nil {
return true
Expand Down
4 changes: 4 additions & 0 deletions src/client/internal/oauth2/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ func (t *Token) Extra(key string) interface{} {
return v
}

func (t *Token) Type() string {
return t.TokenType
}

// tokenJSON is the struct representing the HTTP response from OAuth2
// providers returning a token or error in JSON form.
// https://datatracker.ietf.org/doc/html/rfc6749#section-5.1
Expand Down
3 changes: 1 addition & 2 deletions src/client/internal/oidc/discovery/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"fmt"
"io"
"net/http"
"strings"

"github.com/go-playground/validator"
"github.com/hashicorp/go-hclog"
Expand Down Expand Up @@ -46,7 +45,7 @@ func NewWellKnown(wellKnown string) (*OIDCWellKnownOpenidConfiguration, error) {
// ValidWellKnown validate config
func ValidWellKnown(w *OIDCWellKnownOpenidConfiguration, issuer string, logger hclog.Logger) bool {

issuer = strings.TrimSuffix(issuer, "/")
// issuer = strings.TrimSuffix(issuer, "/")
if issuer != w.Issuer {
logger.Error("Issuer not matching discovery", "issuer", issuer, "discovery", w.Issuer)
return false
Expand Down
2 changes: 2 additions & 0 deletions src/client/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,8 @@ func NewOIDCClient(c *OIDCClientConfig, jwtsigner signer.JwtSigner, clientCert t

case "client_secret_post":
oAuthConfig.Endpoint.AuthStyle = oauth2.AuthStyleInParams
default:
oAuthConfig.Endpoint.AuthStyle = oauth2.AuthStyleAutoDetect

}

Expand Down
31 changes: 1 addition & 30 deletions src/client/refresh_token.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
package oidcclient

import (
"context"
"net/url"

internaloauth2 "github.com/vdbulcke/oidc-client-demo/src/client/internal/oauth2"
"golang.org/x/oauth2"
)

Expand All @@ -18,34 +16,7 @@ func (c *OIDCClient) RefreshTokenFlow(refreshToken string, skipIdTokenVerificati
"refresh_token": {refreshToken},
}

if c.config.AuthMethod == "private_key_jwt" {

// signedJwt, err := c.GenerateJwtProfile(c.config.IntrospectEndpoint)
signedJwt, err := c.GenerateJwtProfile(c.Wellknown.TokenEndpoint)
if err != nil {
c.logger.Error("Failed to generate jwt client_assertion", "err", err)
return err
}
c.logger.Debug("introspect setting client_assertion", "jwt", signedJwt)
params.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
params.Set("client_assertion", signedJwt)

} else if c.config.AuthMethod == "tls_client_auth" {
c.logger.Debug("set client_id", c.config.ClientID)
params.Set("client_id", c.config.ClientID)
}

// token := new(oauth2.Token)
// token.RefreshToken = refreshToken
// token.Expiry = time.Now()

// TokenSource will refresh the token if needed (which is likely in this
// use case)
// ts := c.oAuthConfig.TokenSource(context.TODO(), token)

// get the oauth Token
// oauth2Token, err := ts.Token()
oauth2Token, err := internaloauth2.RetrieveToken(context.TODO(), c.config.ClientID, c.config.ClientSecret, c.Wellknown.TokenEndpoint, params, c.oAuthConfig.Endpoint.AuthStyle)
oauth2Token, err := c.TokenExchange(params)
if err != nil {
c.logger.Error("Failed to Renew Access Token from refresh token", "refresh-token", refreshToken, "error", err)
return err
Expand Down

0 comments on commit 84ab65a

Please sign in to comment.