From 3dae6e2fed819d2c27ddf1d083ac66c2a9a87891 Mon Sep 17 00:00:00 2001 From: Jim Date: Thu, 1 Aug 2024 14:32:34 -0400 Subject: [PATCH] feat (config): add support for a http.RoundTripper (#137) Add support for specifying an optional http.RoundTripper for a provider config. If specified the http client will use the RoundTripper when making requests to the provider. --- oidc/config.go | 40 +++++++++++- oidc/config_test.go | 141 +++++++++++++++++++++++++++++++++++++++++- oidc/docs_test.go | 4 +- oidc/provider.go | 13 ++-- oidc/provider_test.go | 25 ++++++++ 5 files changed, 214 insertions(+), 9 deletions(-) diff --git a/oidc/config.go b/oidc/config.go index db2f4a6..4f24b37 100644 --- a/oidc/config.go +++ b/oidc/config.go @@ -12,6 +12,7 @@ import ( "fmt" "hash" "hash/fnv" + "net/http" "net/url" "reflect" "runtime" @@ -89,9 +90,15 @@ type Config struct { // ProviderCA is an optional CA certs (PEM encoded) to use when sending // requests to the provider. If you have a list of *x509.Certificates, then - // see EncodeCertificates(...) to PEM encode them. + // see EncodeCertificates(...) to PEM encode them. Note: specifying both + // ProviderCA and RoundTripper is an error. ProviderCA string + // RoundTripper is an optional http.RoundTripper to use when sending requests + // to the provider. Note: specifying both ProviderCA and RoundTripper is an + // error. + RoundTripper http.RoundTripper + // NowFunc is a time func that returns the current time. NowFunc func() time.Time `json:"-"` @@ -118,6 +125,7 @@ func NewConfig(issuer string, clientID string, clientSecret ClientSecret, suppor SupportedSigningAlgs: supported, Scopes: opts.withScopes, ProviderCA: opts.withProviderCA, + RoundTripper: opts.withRoundTripper, Audiences: opts.withAudiences, NowFunc: opts.withNowFunc, AllowedRedirectURLs: allowedRedirectURLs, @@ -168,6 +176,16 @@ func (c *Config) Hash() (uint64, error) { args = append(args, audiences...) args = append(args, redirects...) + if c.RoundTripper != nil { + v := reflect.ValueOf(c.RoundTripper) + switch { + case v.CanAddr(): + args = append(args, v.Addr().String()) + default: + args = append(args, v.String()) + } + } + if c.ProviderConfig != nil { args = append( args, @@ -269,6 +287,9 @@ func (c *Config) Validate() error { return fmt.Errorf("%s: %w", op, ErrInvalidCACert) } } + if c.ProviderCA != "" && c.RoundTripper != nil { + return fmt.Errorf("%s: you cannot specify both a ProviderCA and RoundTripper: %w", op, ErrInvalidParameter) + } if c.ProviderConfig != nil { switch { @@ -300,6 +321,7 @@ type configOptions struct { withProviderCA string withNowFunc func() time.Time withProviderConfig *ProviderConfig + withRoundTripper http.RoundTripper } // configDefaults is a handy way to get the defaults at runtime and @@ -319,12 +341,14 @@ func getConfigOpts(opt ...Option) configOptions { } // WithProviderCA provides optional CA certs (PEM encoded) for the provider's -// config. These certs will can be used when making http requests to the +// config. These certs will be used when making http requests to the // provider. // // Valid for: Config // // See EncodeCertificates(...) to PEM encode a number of certs. +// +// Note: specifying both WithProviderCA and WithRoundTripper is a error. func WithProviderCA(cert string) Option { return func(o interface{}) { if o, ok := o.(*configOptions); ok { @@ -333,6 +357,18 @@ func WithProviderCA(cert string) Option { } } +// WithRoundTripper provides and optional RoundTripper for the provider's +// config. This RoundTripper will be used when making http requests to the +// provider. Note: specifying both WithProviderCA and WithRoundTripper is a +// error. +func WithRoundTripper(rt http.RoundTripper) Option { + return func(o interface{}) { + if o, ok := o.(*configOptions); ok { + o.withRoundTripper = rt + } + } +} + // EncodeCertificates will encode a number of x509 certificates to PEM. It will // help encode certs for use with the WithProviderCA(...) option. func EncodeCertificates(certs ...*x509.Certificate) (string, error) { diff --git a/oidc/config_test.go b/oidc/config_test.go index 4dd3b00..85b576f 100644 --- a/oidc/config_test.go +++ b/oidc/config_test.go @@ -7,6 +7,7 @@ import ( "crypto/x509" "errors" "fmt" + "net/http" "testing" "time" @@ -44,6 +45,8 @@ func TestNewConfig(t *testing.T) { return time.Now().Add(-1 * time.Minute) } + testRt := newTestRoundTripper(t) + type args struct { issuer string clientID string @@ -61,7 +64,7 @@ func TestNewConfig(t *testing.T) { wantErrContains string }{ { - name: "valid-with-all-valid-opts", + name: "valid-with-all-valid-opts-except-with-round-tripper", args: args{ issuer: "http://your_issuer/", clientID: "your_client_id", @@ -103,6 +106,49 @@ func TestNewConfig(t *testing.T) { }, }, }, + { + name: "with-round-tripper", + args: args{ + issuer: "http://your_issuer/", + clientID: "your_client_id", + clientSecret: "your_client_secret", + supported: []Alg{RS512}, + allowedRedirectURLs: []string{"http://your_redirect_url", "http://redirect_url_two", "http://redirect_url_three"}, + opt: []Option{ + WithAudiences("your_aud1", "your_aud2"), + WithScopes("email", "profile"), + WithRoundTripper(testRt), + WithNow(testNow), + WithProviderConfig(&ProviderConfig{ + AuthURL: "https://auth-endpoint", + JWKSURL: "https://jwks-endpoint", + TokenURL: "https://token-endpoint", + UserInfoURL: "https://userinfo-endpoint", + }), + }, + }, + want: &Config{ + Issuer: "http://your_issuer/", + ClientID: "your_client_id", + ClientSecret: "your_client_secret", + SupportedSigningAlgs: []Alg{RS512}, + Audiences: []string{"your_aud1", "your_aud2"}, + Scopes: []string{oidc.ScopeOpenID, "email", "profile"}, + RoundTripper: testRt, + NowFunc: testNow, + AllowedRedirectURLs: []string{ + "http://your_redirect_url", + "http://redirect_url_two", + "http://redirect_url_three", + }, + ProviderConfig: &ProviderConfig{ + AuthURL: "https://auth-endpoint", + JWKSURL: "https://jwks-endpoint", + TokenURL: "https://token-endpoint", + UserInfoURL: "https://userinfo-endpoint", + }, + }, + }, { name: "missing-provider-config-auth-url", args: args{ @@ -282,6 +328,22 @@ func TestNewConfig(t *testing.T) { wantErr: true, wantIsErr: ErrInvalidCACert, }, + { + name: "invalid-both-cert-and-round-tripper", + args: args{ + issuer: "http://your_issuer/", + clientID: "your_client_id", + clientSecret: "your_client_secret", + supported: []Alg{RS512}, + allowedRedirectURLs: []string{"http://your_redirect_url"}, + opt: []Option{ + WithProviderCA(testCaPem), + WithRoundTripper(testRt), + }, + }, + wantErr: true, + wantIsErr: ErrInvalidParameter, + }, { name: "invalid-alg", args: args{ @@ -430,6 +492,7 @@ func TestConfig_Hash(t *testing.T) { require.NoError(t, err) return c } + testRt := newTestRoundTripper(t) tests := []struct { name string c1 *Config @@ -473,6 +536,42 @@ func TestConfig_Hash(t *testing.T) { ), wantEqual: true, }, + { + name: "equal-with-round-tripper", + c1: newCfg( + "https://www.alice.com", + "client-id", "client-secret", + []Alg{RS256}, + []string{"www.alice.com/callback", "www.bob.com/callback"}, + WithScopes("email", "profile"), + WithAudiences("alice.com", "bob.com"), + WithRoundTripper(testRt), + WithNow(time.Now), + WithProviderConfig(&ProviderConfig{ + AuthURL: "https://auth-endpoint", + JWKSURL: "https://jwks-endpoint", + TokenURL: "https://token-endpoint", + UserInfoURL: "https://userinfo-endpoint", + }), + ), + c2: newCfg( + "https://www.alice.com", + "client-id", "client-secret", + []Alg{RS256}, + []string{"www.bob.com/callback", "www.alice.com/callback"}, + WithScopes("profile", "email"), + WithAudiences("bob.com", "alice.com"), + WithRoundTripper(testRt), + WithNow(time.Now), + WithProviderConfig(&ProviderConfig{ + AuthURL: "https://auth-endpoint", + JWKSURL: "https://jwks-endpoint", + TokenURL: "https://token-endpoint", + UserInfoURL: "https://userinfo-endpoint", + }), + ), + wantEqual: true, + }, { name: "diff-issuer", c1: newCfg( @@ -664,6 +763,29 @@ func TestConfig_Hash(t *testing.T) { ), wantEqual: false, }, + { + name: "diff-round-trippers", + c1: newCfg( + "https://www.alice.com", + "client-id", "client-secret", + []Alg{RS256}, + []string{"www.alice.com/callback"}, + WithScopes("email", "profile"), + WithAudiences("alice.com", "bob.com"), + WithRoundTripper(newTestRoundTripper(t)), + WithNow(time.Now), + ), + c2: newCfg( + "https://www.alice.com", + "client-id", "client-secret", + []Alg{RS256}, + []string{"www.alice.com/callback"}, + WithScopes("email", "profile"), + WithAudiences("alice.com", "bob.com"), + WithNow(time.Now), + ), + wantEqual: false, + }, { name: "diff-now-func", c1: newCfg( @@ -855,3 +977,20 @@ func TestConfig_Hash(t *testing.T) { }) } } + +type testRoundTripper struct { + transport http.RoundTripper + called int +} + +func newTestRoundTripper(t *testing.T) *testRoundTripper { + t.Helper() + return &testRoundTripper{ + transport: http.DefaultTransport, + } +} + +func (rt *testRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + rt.called++ + return rt.transport.RoundTrip(req) +} diff --git a/oidc/docs_test.go b/oidc/docs_test.go index 0e50db0..4b84ec5 100644 --- a/oidc/docs_test.go +++ b/oidc/docs_test.go @@ -95,7 +95,7 @@ func ExampleNewConfig() { fmt.Println(pc) // Output: - // &{your_client_id [REDACTED: client secret] [openid] http://your_issuer/ [RS256] [http://your_redirect_url/callback] [] } + // &{your_client_id [REDACTED: client secret] [openid] http://your_issuer/ [RS256] [http://your_redirect_url/callback] [] } } func ExampleWithProviderConfig() { @@ -120,7 +120,7 @@ func ExampleWithProviderConfig() { fmt.Println(string(val)) // Output: - // {"ClientID":"your_client_id","ClientSecret":"[REDACTED: client secret]","Scopes":["openid"],"Issuer":"https://your_issuer/","SupportedSigningAlgs":["RS256"],"AllowedRedirectURLs":["https://your_redirect_url/callback"],"Audiences":null,"ProviderCA":"","ProviderConfig":{"AuthURL":"https://your_issuer/authorize","TokenURL":"https://your_issuer/token","UserInfoURL":"https://your_issuer/userinfo","JWKSURL":"https://your_issuer/.well-known/jwks.json"}} + // {"ClientID":"your_client_id","ClientSecret":"[REDACTED: client secret]","Scopes":["openid"],"Issuer":"https://your_issuer/","SupportedSigningAlgs":["RS256"],"AllowedRedirectURLs":["https://your_redirect_url/callback"],"Audiences":null,"ProviderCA":"","RoundTripper":null,"ProviderConfig":{"AuthURL":"https://your_issuer/authorize","TokenURL":"https://your_issuer/token","UserInfoURL":"https://your_issuer/userinfo","JWKSURL":"https://your_issuer/.well-known/jwks.json"}} } func ExampleNewProvider() { diff --git a/oidc/provider.go b/oidc/provider.go index 4c0f491..f8b342e 100644 --- a/oidc/provider.go +++ b/oidc/provider.go @@ -635,17 +635,22 @@ func (p *Provider) HTTPClient() (*http.Client, error) { // to the same host. On the downside, this transport can leak file // descriptors over time, so we'll be sure to call // client.CloseIdleConnections() in the Provider.Done() to stave that off. - tr := cleanhttp.DefaultPooledTransport() + var tr http.RoundTripper - if p.config.ProviderCA != "" { + switch { + case p.config.RoundTripper != nil && p.config.ProviderCA != "": + return nil, fmt.Errorf("%s: you cannot specify config for both a ProviderCA and RoundTripper: %w", op, ErrInvalidParameter) + case p.config.ProviderCA != "": certPool := x509.NewCertPool() if ok := certPool.AppendCertsFromPEM([]byte(p.config.ProviderCA)); !ok { return nil, fmt.Errorf("%s: %w", op, ErrInvalidCACert) } - - tr.TLSClientConfig = &tls.Config{ + tr = cleanhttp.DefaultPooledTransport() + tr.(*http.Transport).TLSClientConfig = &tls.Config{ RootCAs: certPool, } + case p.config.RoundTripper != nil: + tr = p.config.RoundTripper } c := &http.Client{ diff --git a/oidc/provider_test.go b/oidc/provider_test.go index b715863..b07b128 100644 --- a/oidc/provider_test.go +++ b/oidc/provider_test.go @@ -714,6 +714,31 @@ func TestHTTPClient(t *testing.T) { require.NoError(t, err) assert.Equal(t, c.Transport, p.client.Transport) }) + t.Run("check-transport-with-round-tripper", func(t *testing.T) { + testRt := newTestRoundTripper(t) + p := &Provider{ + config: &Config{ + RoundTripper: testRt, + }, + } + c, err := p.HTTPClient() + require.NoError(t, err) + assert.Equal(t, c.Transport, p.client.Transport) + }) + t.Run("err-both-ca-and-round-trippe", func(t *testing.T) { + _, testCaPem := TestGenerateCA(t, []string{"localhost"}) + + p := &Provider{ + config: &Config{ + ProviderCA: testCaPem, + RoundTripper: newTestRoundTripper(t), + }, + } + _, err := p.HTTPClient() + require.Error(t, err) + assert.ErrorIs(t, err, ErrInvalidParameter) + assert.ErrorContains(t, err, "you cannot specify config for both a ProviderCA and RoundTripper") + }) } func TestProvider_UserInfo(t *testing.T) {