Skip to content

Commit

Permalink
Add advertised_url support
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Fox <Kevin.Fox@pnnl.gov>
  • Loading branch information
kfox1111 committed Dec 15, 2024
1 parent 6b33c7f commit 29bdd98
Show file tree
Hide file tree
Showing 4 changed files with 188 additions and 17 deletions.
10 changes: 10 additions & 0 deletions support/oidc-discovery-provider/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ type Config struct {

// JWTIssuer specifies the issuer for the OIDC provider configuration request.
JWTIssuer string `hcl:"jwt_issuer"`

// AdvertisedURL specifies the absolute urls to return in documents. Use this if you are fronting the
// discovery provider with a load balancer or reverse proxy
AdvertisedURL string `hcl:"advertised_url"`
}

type ServingCertFileConfig struct {
Expand Down Expand Up @@ -297,6 +301,12 @@ func ParseConfig(hclConfig string) (_ *Config, err error) {
return nil, errs.New("the jwt_issuer url could not be parsed")
}
}
if c.AdvertisedURL != "" {
advertisedURL, err := url.Parse(c.AdvertisedURL)
if err != nil || advertisedURL.Scheme == "" || advertisedURL.Host == "" {
return nil, errs.New("the advertised_url setting could not be parsed")
}
}
return c, nil
}

Expand Down
36 changes: 28 additions & 8 deletions support/oidc-discovery-provider/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,20 @@ type Handler struct {
setKeyUse bool
log logrus.FieldLogger
jwtIssuer string
advertisedURL string

http.Handler
}

func NewHandler(log logrus.FieldLogger, domainPolicy DomainPolicy, source JWKSSource, allowInsecureScheme bool, setKeyUse bool, jwtIssuer string) *Handler {
func NewHandler(log logrus.FieldLogger, domainPolicy DomainPolicy, source JWKSSource, allowInsecureScheme bool, setKeyUse bool, jwtIssuer string, advertisedURL string) *Handler {
h := &Handler{
domainPolicy: domainPolicy,
source: source,
allowInsecureScheme: allowInsecureScheme,
setKeyUse: setKeyUse,
log: log,
jwtIssuer: jwtIssuer,
advertisedURL: advertisedURL,
}

mux := http.NewServeMux()
Expand All @@ -56,6 +58,7 @@ func (h *Handler) serveWellKnown(w http.ResponseWriter, r *http.Request) {
var host string
var path string
var urlScheme string
var keysURL url.URL
if h.jwtIssuer != "" {
jwtIssuerURL, _ := url.Parse(h.jwtIssuer)
host = jwtIssuerURL.Host
Expand All @@ -68,6 +71,29 @@ func (h *Handler) serveWellKnown(w http.ResponseWriter, r *http.Request) {
urlScheme = "http"
}
}
if h.advertisedURL != "" {
tmpURL, _ := url.Parse(h.advertisedURL)
keysPath, err := url.JoinPath(tmpURL.Path, "/keys")
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
keysURL = url.URL{
Scheme: tmpURL.Scheme,
Host: tmpURL.Host,
Path: keysPath,
}
} else {
tmpURLScheme := "https"
if h.allowInsecureScheme && r.TLS == nil && r.URL.Scheme != "https" {
tmpURLScheme = "http"
}
keysURL = url.URL{
Scheme: tmpURLScheme,
Host: r.Host,
Path: "/keys",
}
}

if err := h.verifyHost(host); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
Expand All @@ -80,12 +106,6 @@ func (h *Handler) serveWellKnown(w http.ResponseWriter, r *http.Request) {
Path: path,
}

jwksURI := url.URL{
Scheme: urlScheme,
Host: r.Host,
Path: "/keys",
}

doc := struct {
Issuer string `json:"issuer"`
JWKSURI string `json:"jwks_uri"`
Expand All @@ -98,7 +118,7 @@ func (h *Handler) serveWellKnown(w http.ResponseWriter, r *http.Request) {
IDTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported"`
}{
Issuer: issuerURL.String(),
JWKSURI: jwksURI.String(),
JWKSURI: keysURL.String(),

AuthorizationEndpoint: "",
ResponseTypesSupported: []string{"id_token"},
Expand Down
157 changes: 149 additions & 8 deletions support/oidc-discovery-provider/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ func TestHandlerHTTPS(t *testing.T) {
require.NoError(t, err)
w := httptest.NewRecorder()

h := NewHandler(log, domainAllowlist(t, "localhost", "domain.test"), source, false, testCase.setKeyUse, "")
h := NewHandler(log, domainAllowlist(t, "localhost", "domain.test"), source, false, testCase.setKeyUse, "", "")
h.ServeHTTP(w, r)

t.Logf("HEADERS: %q", w.Header())
Expand Down Expand Up @@ -286,7 +286,7 @@ func TestHandlerHTTPInsecure(t *testing.T) {
require.NoError(t, err)
w := httptest.NewRecorder()

h := NewHandler(log, domainAllowlist(t, "localhost", "domain.test"), source, true, false, "")
h := NewHandler(log, domainAllowlist(t, "localhost", "domain.test"), source, true, false, "", "")
h.ServeHTTP(w, r)

t.Logf("HEADERS: %q", w.Header())
Expand Down Expand Up @@ -456,7 +456,7 @@ func TestHandlerHTTP(t *testing.T) {
require.NoError(t, err)
w := httptest.NewRecorder()

h := NewHandler(log, domainAllowlist(t, "domain.test", "xn--n38h.test"), source, false, false, "")
h := NewHandler(log, domainAllowlist(t, "domain.test", "xn--n38h.test"), source, false, false, "", "")
h.ServeHTTP(w, r)

t.Logf("HEADERS: %q", w.Header())
Expand Down Expand Up @@ -568,7 +568,7 @@ func TestHandlerProxied(t *testing.T) {
r.Header.Add("X-Forwarded-Scheme", "https")
r.Header.Add("X-Forwarded-Host", "domain.test")
w := httptest.NewRecorder()
h := NewHandler(log, domainAllowlist(t, "domain.test"), source, false, false, "")
h := NewHandler(log, domainAllowlist(t, "domain.test"), source, false, false, "", "")
h.ServeHTTP(w, r)
t.Logf("HEADERS: %q", w.Header())
assert.Equal(t, testCase.code, w.Code)
Expand Down Expand Up @@ -619,7 +619,7 @@ func TestHandlerJWTIssuer(t *testing.T) {
code: http.StatusOK,
body: `{
"issuer": "http://domain.test/some/issuer/path/issuer1",
"jwks_uri": "http://domain.test/keys",
"jwks_uri": "https://domain.test/keys",
"authorization_endpoint": "",
"response_types_supported": [
"id_token"
Expand All @@ -640,7 +640,7 @@ func TestHandlerJWTIssuer(t *testing.T) {
code: http.StatusOK,
body: `{
"issuer": "http://domain.test/some/issuer/path/issuer1/",
"jwks_uri": "http://domain.test/keys",
"jwks_uri": "https://domain.test/keys",
"authorization_endpoint": "",
"response_types_supported": [
"id_token"
Expand All @@ -661,7 +661,7 @@ func TestHandlerJWTIssuer(t *testing.T) {
code: http.StatusOK,
body: `{
"issuer": "http://domain.test/",
"jwks_uri": "http://domain.test/keys",
"jwks_uri": "https://domain.test/keys",
"authorization_endpoint": "",
"response_types_supported": [
"id_token"
Expand All @@ -682,6 +682,147 @@ func TestHandlerJWTIssuer(t *testing.T) {
code: http.StatusOK,
body: `{
"issuer": "http://domain.test",
"jwks_uri": "https://domain.test/keys",
"authorization_endpoint": "",
"response_types_supported": [
"id_token"
],
"subject_types_supported": [],
"id_token_signing_alg_values_supported": [
"RS256",
"ES256",
"ES384"
]
}`,
},
}
for _, testCase := range testCases {
testCase := testCase
t.Run(testCase.name, func(t *testing.T) {
source := new(FakeKeySetSource)
source.SetKeySet(testCase.jwks, testCase.modTime, testCase.pollTime)

r, err := http.NewRequest(testCase.method, "http://localhost"+testCase.path, nil)
require.NoError(t, err)
r.Header.Add("X-Forwarded-Scheme", "https")
r.Header.Add("X-Forwarded-Host", "domain.test")
w := httptest.NewRecorder()

h := NewHandler(log, domainAllowlist(t, "domain.test"), source, false, false, testCase.jwtIssuer, "")
h.ServeHTTP(w, r)

t.Logf("HEADERS: %q", w.Header())
assert.Equal(t, testCase.code, w.Code)
assert.Equal(t, testCase.body, w.Body.String())
})
}
}
func TestHandlerAdvertisedURL(t *testing.T) {
log, _ := test.NewNullLogger()
log.Level = logrus.DebugLevel
testCases := []struct {
name string
advertisedURL string
method string
path string
jwks *jose.JSONWebKeySet
modTime time.Time
pollTime time.Time
code int
body string
}{
{
name: "GET well-known HTTPS JWT Issuer",
advertisedURL: "https://domain.test/some/issuer/path/issuer1",
method: "GET",
path: "/.well-known/openid-configuration",
code: http.StatusOK,
body: `{
"issuer": "https://domain.test",
"jwks_uri": "https://domain.test/some/issuer/path/issuer1/keys",
"authorization_endpoint": "",
"response_types_supported": [
"id_token"
],
"subject_types_supported": [],
"id_token_signing_alg_values_supported": [
"RS256",
"ES256",
"ES384"
]
}`,
},
{
name: "GET well-known HTTP JWT Issuer",
advertisedURL: "http://domain.test/some/issuer/path/issuer1",
method: "GET",
path: "/.well-known/openid-configuration",
code: http.StatusOK,
body: `{
"issuer": "https://domain.test",
"jwks_uri": "http://domain.test/some/issuer/path/issuer1/keys",
"authorization_endpoint": "",
"response_types_supported": [
"id_token"
],
"subject_types_supported": [],
"id_token_signing_alg_values_supported": [
"RS256",
"ES256",
"ES384"
]
}`,
},
{
name: "GET well-known JWT Issuer with trailing forward-slash",
advertisedURL: "http://domain.test/some/issuer/path/issuer1/",
method: "GET",
path: "/.well-known/openid-configuration",
code: http.StatusOK,
body: `{
"issuer": "https://domain.test",
"jwks_uri": "http://domain.test/some/issuer/path/issuer1/keys",
"authorization_endpoint": "",
"response_types_supported": [
"id_token"
],
"subject_types_supported": [],
"id_token_signing_alg_values_supported": [
"RS256",
"ES256",
"ES384"
]
}`,
},
{
name: "GET well-known JWT Issuer without a path with trailing forward-slash",
advertisedURL: "http://domain.test/",
method: "GET",
path: "/.well-known/openid-configuration",
code: http.StatusOK,
body: `{
"issuer": "https://domain.test",
"jwks_uri": "http://domain.test/keys",
"authorization_endpoint": "",
"response_types_supported": [
"id_token"
],
"subject_types_supported": [],
"id_token_signing_alg_values_supported": [
"RS256",
"ES256",
"ES384"
]
}`,
},
{
name: "GET well-known JWT Issuer without a path",
advertisedURL: "http://domain.test",
method: "GET",
path: "/.well-known/openid-configuration",
code: http.StatusOK,
body: `{
"issuer": "https://domain.test",
"jwks_uri": "http://domain.test/keys",
"authorization_endpoint": "",
"response_types_supported": [
Expand All @@ -708,7 +849,7 @@ func TestHandlerJWTIssuer(t *testing.T) {
r.Header.Add("X-Forwarded-Host", "domain.test")
w := httptest.NewRecorder()

h := NewHandler(log, domainAllowlist(t, "domain.test"), source, false, false, testCase.jwtIssuer)
h := NewHandler(log, domainAllowlist(t, "domain.test"), source, false, false, "", testCase.advertisedURL)
h.ServeHTTP(w, r)

t.Logf("HEADERS: %q", w.Header())
Expand Down
2 changes: 1 addition & 1 deletion support/oidc-discovery-provider/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func run(configPath string) error {
return err
}

var handler http.Handler = NewHandler(log, domainPolicy, source, config.AllowInsecureScheme, config.SetKeyUse, config.JWTIssuer)
var handler http.Handler = NewHandler(log, domainPolicy, source, config.AllowInsecureScheme, config.SetKeyUse, config.JWTIssuer, config.AdvertisedURL)
if config.LogRequests {
log.Info("Logging all requests")
handler = logHandler(log, handler)
Expand Down

0 comments on commit 29bdd98

Please sign in to comment.