diff --git a/support/oidc-discovery-provider/config.go b/support/oidc-discovery-provider/config.go index 7df3eaa615..3b77cfc252 100644 --- a/support/oidc-discovery-provider/config.go +++ b/support/oidc-discovery-provider/config.go @@ -79,6 +79,10 @@ type Config struct { // JWTIssuer specifies the issuer for the OIDC provider configuration request. JWTIssuer string `hcl:"jwt_issuer"` + + // AdvertisedURL specifies the absolute urls to return in documents. Use this if you are fronting the + // discovery provider with a load balancer or reverse proxy + AdvertisedURL string `hcl:"advertised_url"` } type ServingCertFileConfig struct { @@ -297,6 +301,12 @@ func ParseConfig(hclConfig string) (_ *Config, err error) { return nil, errs.New("the jwt_issuer url could not be parsed") } } + if c.AdvertisedURL != "" { + advertisedURL, err := url.Parse(c.AdvertisedURL) + if err != nil || advertisedURL.Scheme == "" || advertisedURL.Host == "" { + return nil, errs.New("the advertised_url setting could not be parsed") + } + } return c, nil } diff --git a/support/oidc-discovery-provider/handler.go b/support/oidc-discovery-provider/handler.go index 7d970ce9e9..299b06e0b8 100644 --- a/support/oidc-discovery-provider/handler.go +++ b/support/oidc-discovery-provider/handler.go @@ -25,11 +25,12 @@ type Handler struct { setKeyUse bool log logrus.FieldLogger jwtIssuer string + advertisedURL string http.Handler } -func NewHandler(log logrus.FieldLogger, domainPolicy DomainPolicy, source JWKSSource, allowInsecureScheme bool, setKeyUse bool, jwtIssuer string) *Handler { +func NewHandler(log logrus.FieldLogger, domainPolicy DomainPolicy, source JWKSSource, allowInsecureScheme bool, setKeyUse bool, jwtIssuer string, advertisedURL string) *Handler { h := &Handler{ domainPolicy: domainPolicy, source: source, @@ -37,6 +38,7 @@ func NewHandler(log logrus.FieldLogger, domainPolicy DomainPolicy, source JWKSSo setKeyUse: setKeyUse, log: log, jwtIssuer: jwtIssuer, + advertisedURL: advertisedURL, } mux := http.NewServeMux() @@ -56,6 +58,7 @@ func (h *Handler) serveWellKnown(w http.ResponseWriter, r *http.Request) { var host string var path string var urlScheme string + var keysURL url.URL if h.jwtIssuer != "" { jwtIssuerURL, _ := url.Parse(h.jwtIssuer) host = jwtIssuerURL.Host @@ -68,6 +71,29 @@ func (h *Handler) serveWellKnown(w http.ResponseWriter, r *http.Request) { urlScheme = "http" } } + if h.advertisedURL != "" { + tmpURL, _ := url.Parse(h.advertisedURL) + keysPath, err := url.JoinPath(tmpURL.Path, "/keys") + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + keysURL = url.URL{ + Scheme: tmpURL.Scheme, + Host: tmpURL.Host, + Path: keysPath, + } + } else { + tmpURLScheme := "https" + if h.allowInsecureScheme && r.TLS == nil && r.URL.Scheme != "https" { + tmpURLScheme = "http" + } + keysURL = url.URL{ + Scheme: tmpURLScheme, + Host: r.Host, + Path: "/keys", + } + } if err := h.verifyHost(host); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) @@ -80,12 +106,6 @@ func (h *Handler) serveWellKnown(w http.ResponseWriter, r *http.Request) { Path: path, } - jwksURI := url.URL{ - Scheme: urlScheme, - Host: r.Host, - Path: "/keys", - } - doc := struct { Issuer string `json:"issuer"` JWKSURI string `json:"jwks_uri"` @@ -98,7 +118,7 @@ func (h *Handler) serveWellKnown(w http.ResponseWriter, r *http.Request) { IDTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported"` }{ Issuer: issuerURL.String(), - JWKSURI: jwksURI.String(), + JWKSURI: keysURL.String(), AuthorizationEndpoint: "", ResponseTypesSupported: []string{"id_token"}, diff --git a/support/oidc-discovery-provider/handler_test.go b/support/oidc-discovery-provider/handler_test.go index 86a041a5ec..498ffe0a8f 100644 --- a/support/oidc-discovery-provider/handler_test.go +++ b/support/oidc-discovery-provider/handler_test.go @@ -173,7 +173,7 @@ func TestHandlerHTTPS(t *testing.T) { require.NoError(t, err) w := httptest.NewRecorder() - h := NewHandler(log, domainAllowlist(t, "localhost", "domain.test"), source, false, testCase.setKeyUse, "") + h := NewHandler(log, domainAllowlist(t, "localhost", "domain.test"), source, false, testCase.setKeyUse, "", "") h.ServeHTTP(w, r) t.Logf("HEADERS: %q", w.Header()) @@ -286,7 +286,7 @@ func TestHandlerHTTPInsecure(t *testing.T) { require.NoError(t, err) w := httptest.NewRecorder() - h := NewHandler(log, domainAllowlist(t, "localhost", "domain.test"), source, true, false, "") + h := NewHandler(log, domainAllowlist(t, "localhost", "domain.test"), source, true, false, "", "") h.ServeHTTP(w, r) t.Logf("HEADERS: %q", w.Header()) @@ -456,7 +456,7 @@ func TestHandlerHTTP(t *testing.T) { require.NoError(t, err) w := httptest.NewRecorder() - h := NewHandler(log, domainAllowlist(t, "domain.test", "xn--n38h.test"), source, false, false, "") + h := NewHandler(log, domainAllowlist(t, "domain.test", "xn--n38h.test"), source, false, false, "", "") h.ServeHTTP(w, r) t.Logf("HEADERS: %q", w.Header()) @@ -568,7 +568,7 @@ func TestHandlerProxied(t *testing.T) { r.Header.Add("X-Forwarded-Scheme", "https") r.Header.Add("X-Forwarded-Host", "domain.test") w := httptest.NewRecorder() - h := NewHandler(log, domainAllowlist(t, "domain.test"), source, false, false, "") + h := NewHandler(log, domainAllowlist(t, "domain.test"), source, false, false, "", "") h.ServeHTTP(w, r) t.Logf("HEADERS: %q", w.Header()) assert.Equal(t, testCase.code, w.Code) @@ -619,7 +619,7 @@ func TestHandlerJWTIssuer(t *testing.T) { code: http.StatusOK, body: `{ "issuer": "http://domain.test/some/issuer/path/issuer1", - "jwks_uri": "http://domain.test/keys", + "jwks_uri": "https://domain.test/keys", "authorization_endpoint": "", "response_types_supported": [ "id_token" @@ -640,7 +640,7 @@ func TestHandlerJWTIssuer(t *testing.T) { code: http.StatusOK, body: `{ "issuer": "http://domain.test/some/issuer/path/issuer1/", - "jwks_uri": "http://domain.test/keys", + "jwks_uri": "https://domain.test/keys", "authorization_endpoint": "", "response_types_supported": [ "id_token" @@ -661,7 +661,7 @@ func TestHandlerJWTIssuer(t *testing.T) { code: http.StatusOK, body: `{ "issuer": "http://domain.test/", - "jwks_uri": "http://domain.test/keys", + "jwks_uri": "https://domain.test/keys", "authorization_endpoint": "", "response_types_supported": [ "id_token" @@ -682,6 +682,147 @@ func TestHandlerJWTIssuer(t *testing.T) { code: http.StatusOK, body: `{ "issuer": "http://domain.test", + "jwks_uri": "https://domain.test/keys", + "authorization_endpoint": "", + "response_types_supported": [ + "id_token" + ], + "subject_types_supported": [], + "id_token_signing_alg_values_supported": [ + "RS256", + "ES256", + "ES384" + ] +}`, + }, + } + for _, testCase := range testCases { + testCase := testCase + t.Run(testCase.name, func(t *testing.T) { + source := new(FakeKeySetSource) + source.SetKeySet(testCase.jwks, testCase.modTime, testCase.pollTime) + + r, err := http.NewRequest(testCase.method, "http://localhost"+testCase.path, nil) + require.NoError(t, err) + r.Header.Add("X-Forwarded-Scheme", "https") + r.Header.Add("X-Forwarded-Host", "domain.test") + w := httptest.NewRecorder() + + h := NewHandler(log, domainAllowlist(t, "domain.test"), source, false, false, testCase.jwtIssuer, "") + h.ServeHTTP(w, r) + + t.Logf("HEADERS: %q", w.Header()) + assert.Equal(t, testCase.code, w.Code) + assert.Equal(t, testCase.body, w.Body.String()) + }) + } +} +func TestHandlerAdvertisedURL(t *testing.T) { + log, _ := test.NewNullLogger() + log.Level = logrus.DebugLevel + testCases := []struct { + name string + advertisedURL string + method string + path string + jwks *jose.JSONWebKeySet + modTime time.Time + pollTime time.Time + code int + body string + }{ + { + name: "GET well-known HTTPS JWT Issuer", + advertisedURL: "https://domain.test/some/issuer/path/issuer1", + method: "GET", + path: "/.well-known/openid-configuration", + code: http.StatusOK, + body: `{ + "issuer": "https://domain.test", + "jwks_uri": "https://domain.test/some/issuer/path/issuer1/keys", + "authorization_endpoint": "", + "response_types_supported": [ + "id_token" + ], + "subject_types_supported": [], + "id_token_signing_alg_values_supported": [ + "RS256", + "ES256", + "ES384" + ] +}`, + }, + { + name: "GET well-known HTTP JWT Issuer", + advertisedURL: "http://domain.test/some/issuer/path/issuer1", + method: "GET", + path: "/.well-known/openid-configuration", + code: http.StatusOK, + body: `{ + "issuer": "https://domain.test", + "jwks_uri": "http://domain.test/some/issuer/path/issuer1/keys", + "authorization_endpoint": "", + "response_types_supported": [ + "id_token" + ], + "subject_types_supported": [], + "id_token_signing_alg_values_supported": [ + "RS256", + "ES256", + "ES384" + ] +}`, + }, + { + name: "GET well-known JWT Issuer with trailing forward-slash", + advertisedURL: "http://domain.test/some/issuer/path/issuer1/", + method: "GET", + path: "/.well-known/openid-configuration", + code: http.StatusOK, + body: `{ + "issuer": "https://domain.test", + "jwks_uri": "http://domain.test/some/issuer/path/issuer1/keys", + "authorization_endpoint": "", + "response_types_supported": [ + "id_token" + ], + "subject_types_supported": [], + "id_token_signing_alg_values_supported": [ + "RS256", + "ES256", + "ES384" + ] +}`, + }, + { + name: "GET well-known JWT Issuer without a path with trailing forward-slash", + advertisedURL: "http://domain.test/", + method: "GET", + path: "/.well-known/openid-configuration", + code: http.StatusOK, + body: `{ + "issuer": "https://domain.test", + "jwks_uri": "http://domain.test/keys", + "authorization_endpoint": "", + "response_types_supported": [ + "id_token" + ], + "subject_types_supported": [], + "id_token_signing_alg_values_supported": [ + "RS256", + "ES256", + "ES384" + ] +}`, + }, + { + name: "GET well-known JWT Issuer without a path", + advertisedURL: "http://domain.test", + method: "GET", + path: "/.well-known/openid-configuration", + code: http.StatusOK, + body: `{ + "issuer": "https://domain.test", "jwks_uri": "http://domain.test/keys", "authorization_endpoint": "", "response_types_supported": [ @@ -708,7 +849,7 @@ func TestHandlerJWTIssuer(t *testing.T) { r.Header.Add("X-Forwarded-Host", "domain.test") w := httptest.NewRecorder() - h := NewHandler(log, domainAllowlist(t, "domain.test"), source, false, false, testCase.jwtIssuer) + h := NewHandler(log, domainAllowlist(t, "domain.test"), source, false, false, "", testCase.advertisedURL) h.ServeHTTP(w, r) t.Logf("HEADERS: %q", w.Header()) diff --git a/support/oidc-discovery-provider/main.go b/support/oidc-discovery-provider/main.go index 0e65d9cb68..d60e0f49e6 100644 --- a/support/oidc-discovery-provider/main.go +++ b/support/oidc-discovery-provider/main.go @@ -67,7 +67,7 @@ func run(configPath string) error { return err } - var handler http.Handler = NewHandler(log, domainPolicy, source, config.AllowInsecureScheme, config.SetKeyUse, config.JWTIssuer) + var handler http.Handler = NewHandler(log, domainPolicy, source, config.AllowInsecureScheme, config.SetKeyUse, config.JWTIssuer, config.AdvertisedURL) if config.LogRequests { log.Info("Logging all requests") handler = logHandler(log, handler)