diff --git a/Makefile b/Makefile index 573f4abe..6e74f912 100644 --- a/Makefile +++ b/Makefile @@ -28,6 +28,7 @@ mocks: bootstrap mockgen -source ./api/events/event_handler.go -destination ./api/events/mock/event_handler_mock.go -package mock mockgen -source ./api/environmentvariables/env_vars_handler.go -destination ./api/environmentvariables/env_vars_handler_mock.go -package environmentvariables mockgen -source ./api/environmentvariables/env_vars_handler_factory.go -destination ./api/environmentvariables/env_vars_handler_factory_mock.go -package environmentvariables + mockgen -source ./api/utils/authn/validator.go -destination ./api/utils/authn/mock/validator_mock.go -package mock .PHONY: test test: diff --git a/api/middleware/auth/authentication.go b/api/middleware/auth/authentication.go index 8dfba5f5..3af7b314 100644 --- a/api/middleware/auth/authentication.go +++ b/api/middleware/auth/authentication.go @@ -3,7 +3,6 @@ package auth import ( "context" "net/http" - "net/url" token "github.com/equinor/radix-api/api/utils/authn" "github.com/equinor/radix-common/models" @@ -15,18 +14,7 @@ import ( type ctxUserKey struct{} type ctxImpersonationKey struct{} -func CreateAuthenticationMiddleware(issuer, audience string) negroni.HandlerFunc { - issuerUrl, err := url.Parse(issuer) - if err != nil { - log.Fatal().Err(err).Msg("Error parsing issuer url") - } - - // Set up the validator. - jwtValidator, err := token.NewValidator(issuerUrl, audience) - if err != nil { - log.Fatal().Err(err).Msg("Error creating JWT validator") - } - +func CreateAuthenticationMiddleware(validator token.ValidatorInterface) negroni.HandlerFunc { return func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { ctx := r.Context() logger := log.Ctx(ctx) @@ -43,7 +31,7 @@ func CreateAuthenticationMiddleware(issuer, audience string) negroni.HandlerFunc } return } - principal, err := jwtValidator.ValidateToken(ctx, token) + principal, err := validator.ValidateToken(ctx, token) if err != nil { logger.Warn().Err(err).Msg("authentication error") if err = radixhttp.ErrorResponse(w, r, err); err != nil { diff --git a/api/router/api.go b/api/router/api.go index feb1ff34..c7126b90 100644 --- a/api/router/api.go +++ b/api/router/api.go @@ -8,6 +8,7 @@ import ( "github.com/equinor/radix-api/api/middleware/logger" "github.com/equinor/radix-api/api/middleware/recovery" "github.com/equinor/radix-api/api/utils" + token "github.com/equinor/radix-api/api/utils/authn" "github.com/equinor/radix-api/models" "github.com/equinor/radix-api/swaggerui" "github.com/gorilla/mux" @@ -19,7 +20,7 @@ const ( ) // NewAPIHandler Constructor function -func NewAPIHandler(clusterName, oidcIssuer, oidcAudience, radixDNSZone string, kubeUtil utils.KubeUtil, controllers ...models.Controller) http.Handler { +func NewAPIHandler(clusterName string, validator token.ValidatorInterface, radixDNSZone string, kubeUtil utils.KubeUtil, controllers ...models.Controller) http.Handler { serveMux := http.NewServeMux() serveMux.Handle("/health/", createHealthHandler()) serveMux.Handle("/swaggerui/", createSwaggerHandler()) @@ -30,7 +31,7 @@ func NewAPIHandler(clusterName, oidcIssuer, oidcAudience, radixDNSZone string, k cors.CreateMiddleware(clusterName, radixDNSZone), logger.CreateZerologRequestIdMiddleware(), logger.CreateZerologRequestDetailsMiddleware(), - auth.CreateAuthenticationMiddleware(oidcIssuer, oidcAudience), + auth.CreateAuthenticationMiddleware(validator), logger.CreateZerologRequestLoggerMiddleware(), ) n.UseHandler(serveMux) diff --git a/api/utils/authn/azure_principal.go b/api/utils/authn/azure_principal.go index 74d3a121..b469c53a 100644 --- a/api/utils/authn/azure_principal.go +++ b/api/utils/authn/azure_principal.go @@ -24,19 +24,19 @@ func (c *azureClaims) Validate(_ context.Context) error { return nil } -type azurePrincipal struct { +type AzurePrincipal struct { token string claims *validator.ValidatedClaims azureClaims *azureClaims } -func (p *azurePrincipal) Token() string { +func (p *AzurePrincipal) Token() string { return p.token } -func (p *azurePrincipal) IsAuthenticated() bool { +func (p *AzurePrincipal) IsAuthenticated() bool { return true } -func (p *azurePrincipal) Subject() string { +func (p *AzurePrincipal) Subject() string { if p.azureClaims.Upn != "" { return p.azureClaims.Upn } diff --git a/api/utils/authn/mock/validator_mock.go b/api/utils/authn/mock/validator_mock.go new file mode 100644 index 00000000..4c99f918 --- /dev/null +++ b/api/utils/authn/mock/validator_mock.go @@ -0,0 +1,116 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ./api/utils/authn/validator.go + +// Package mock is a generated GoMock package. +package mock + +import ( + context "context" + reflect "reflect" + + token "github.com/equinor/radix-api/api/utils/authn" + gomock "github.com/golang/mock/gomock" +) + +// MockTokenPrincipal is a mock of TokenPrincipal interface. +type MockTokenPrincipal struct { + ctrl *gomock.Controller + recorder *MockTokenPrincipalMockRecorder +} + +// MockTokenPrincipalMockRecorder is the mock recorder for MockTokenPrincipal. +type MockTokenPrincipalMockRecorder struct { + mock *MockTokenPrincipal +} + +// NewMockTokenPrincipal creates a new mock instance. +func NewMockTokenPrincipal(ctrl *gomock.Controller) *MockTokenPrincipal { + mock := &MockTokenPrincipal{ctrl: ctrl} + mock.recorder = &MockTokenPrincipalMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTokenPrincipal) EXPECT() *MockTokenPrincipalMockRecorder { + return m.recorder +} + +// IsAuthenticated mocks base method. +func (m *MockTokenPrincipal) IsAuthenticated() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsAuthenticated") + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsAuthenticated indicates an expected call of IsAuthenticated. +func (mr *MockTokenPrincipalMockRecorder) IsAuthenticated() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsAuthenticated", reflect.TypeOf((*MockTokenPrincipal)(nil).IsAuthenticated)) +} + +// Subject mocks base method. +func (m *MockTokenPrincipal) Subject() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Subject") + ret0, _ := ret[0].(string) + return ret0 +} + +// Subject indicates an expected call of Subject. +func (mr *MockTokenPrincipalMockRecorder) Subject() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Subject", reflect.TypeOf((*MockTokenPrincipal)(nil).Subject)) +} + +// Token mocks base method. +func (m *MockTokenPrincipal) Token() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Token") + ret0, _ := ret[0].(string) + return ret0 +} + +// Token indicates an expected call of Token. +func (mr *MockTokenPrincipalMockRecorder) Token() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Token", reflect.TypeOf((*MockTokenPrincipal)(nil).Token)) +} + +// MockValidatorInterface is a mock of ValidatorInterface interface. +type MockValidatorInterface struct { + ctrl *gomock.Controller + recorder *MockValidatorInterfaceMockRecorder +} + +// MockValidatorInterfaceMockRecorder is the mock recorder for MockValidatorInterface. +type MockValidatorInterfaceMockRecorder struct { + mock *MockValidatorInterface +} + +// NewMockValidatorInterface creates a new mock instance. +func NewMockValidatorInterface(ctrl *gomock.Controller) *MockValidatorInterface { + mock := &MockValidatorInterface{ctrl: ctrl} + mock.recorder = &MockValidatorInterfaceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockValidatorInterface) EXPECT() *MockValidatorInterfaceMockRecorder { + return m.recorder +} + +// ValidateToken mocks base method. +func (m *MockValidatorInterface) ValidateToken(arg0 context.Context, arg1 string) (token.TokenPrincipal, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ValidateToken", arg0, arg1) + ret0, _ := ret[0].(token.TokenPrincipal) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ValidateToken indicates an expected call of ValidateToken. +func (mr *MockValidatorInterfaceMockRecorder) ValidateToken(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateToken", reflect.TypeOf((*MockValidatorInterface)(nil).ValidateToken), arg0, arg1) +} diff --git a/api/utils/authn/validator.go b/api/utils/authn/validator.go index 097b154a..b11e2bae 100644 --- a/api/utils/authn/validator.go +++ b/api/utils/authn/validator.go @@ -17,13 +17,15 @@ type TokenPrincipal interface { } type ValidatorInterface interface { - ValidateToken(ctx context.Context, token string) (*validator.RegisteredClaims, error) + ValidateToken(context.Context, string) (TokenPrincipal, error) } type Validator struct { validator *validator.Validator } +var _ ValidatorInterface = &Validator{} + func NewValidator(issuerUrl *url.URL, audience string) (*Validator, error) { provider := jwks.NewCachingProvider(issuerUrl, 5*time.Minute) @@ -59,6 +61,6 @@ func (v *Validator) ValidateToken(ctx context.Context, token string) (TokenPrinc return nil, errors.New("invalid azure token") } - principal := &azurePrincipal{token: token, claims: claims, azureClaims: azureClaims} + principal := &AzurePrincipal{token: token, claims: claims, azureClaims: azureClaims} return principal, nil } diff --git a/go.mod b/go.mod index 276e6854..ae98b843 100644 --- a/go.mod +++ b/go.mod @@ -19,7 +19,6 @@ require ( github.com/kedacore/keda/v2 v2.15.1 github.com/kelseyhightower/envconfig v1.4.0 github.com/marstr/guid v1.1.0 - github.com/mitchellh/mapstructure v1.5.0 github.com/prometheus-operator/prometheus-operator/pkg/client v0.76.0 github.com/prometheus/client_golang v1.20.2 github.com/rs/cors v1.11.0 @@ -79,6 +78,7 @@ require ( github.com/mailru/easyjson v0.7.7 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect diff --git a/main.go b/main.go index c16f841e..8b42790d 100644 --- a/main.go +++ b/main.go @@ -7,6 +7,7 @@ import ( "io" "net/http" _ "net/http/pprof" + "net/url" "os" "os/signal" "sync" @@ -14,6 +15,7 @@ import ( "time" "github.com/equinor/radix-api/api/secrets" + token "github.com/equinor/radix-api/api/utils/authn" "github.com/equinor/radix-api/api/utils/tlsvalidation" "github.com/equinor/radix-api/internal/config" "github.com/rs/zerolog" @@ -76,12 +78,13 @@ func main() { } func initializeServer(c config.Config) *http.Server { + jwtValidator := initializeTokenValidator(c) controllers, err := getControllers(c) if err != nil { log.Fatal().Err(err).Msgf("failed to initialize controllers: %v", err) } - handler := router.NewAPIHandler(c.ClusterName, c.OidcIssuer, c.OidcAudience, c.DNSZone, utils.NewKubeUtil(c.UseOutClusterClient, c.KubernetesApiServer), controllers...) + handler := router.NewAPIHandler(c.ClusterName, jwtValidator, c.DNSZone, utils.NewKubeUtil(c.UseOutClusterClient, c.KubernetesApiServer), controllers...) srv := &http.Server{ Addr: fmt.Sprintf(":%d", c.Port), Handler: handler, @@ -90,6 +93,20 @@ func initializeServer(c config.Config) *http.Server { return srv } +func initializeTokenValidator(c config.Config) *token.Validator { + issuerUrl, err := url.Parse(c.OidcIssuer) + if err != nil { + log.Fatal().Err(err).Msg("Error parsing issuer url") + } + + // Set up the validator. + jwtValidator, err := token.NewValidator(issuerUrl, c.OidcAudience) + if err != nil { + log.Fatal().Err(err).Msg("Error creating JWT validator") + } + return jwtValidator +} + func initializeMetricsServer(c config.Config) *http.Server { log.Info().Msgf("Initializing metrics server on port %d", c.Port) return &http.Server{