Skip to content

Commit

Permalink
Add mock, refactor validator to make testing possible
Browse files Browse the repository at this point in the history
  • Loading branch information
Richard87 committed Sep 27, 2024
1 parent 2538465 commit ebe79be
Show file tree
Hide file tree
Showing 8 changed files with 149 additions and 24 deletions.
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ mocks: bootstrap
mockgen -source ./api/events/event_handler.go -destination ./api/events/mock/event_handler_mock.go -package mock
mockgen -source ./api/environmentvariables/env_vars_handler.go -destination ./api/environmentvariables/env_vars_handler_mock.go -package environmentvariables
mockgen -source ./api/environmentvariables/env_vars_handler_factory.go -destination ./api/environmentvariables/env_vars_handler_factory_mock.go -package environmentvariables
mockgen -source ./api/utils/authn/validator.go -destination ./api/utils/authn/mock/validator_mock.go -package mock

.PHONY: test
test:
Expand Down
16 changes: 2 additions & 14 deletions api/middleware/auth/authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package auth
import (
"context"
"net/http"
"net/url"

token "github.com/equinor/radix-api/api/utils/authn"
"github.com/equinor/radix-common/models"
Expand All @@ -15,18 +14,7 @@ import (
type ctxUserKey struct{}
type ctxImpersonationKey struct{}

func CreateAuthenticationMiddleware(issuer, audience string) negroni.HandlerFunc {
issuerUrl, err := url.Parse(issuer)
if err != nil {
log.Fatal().Err(err).Msg("Error parsing issuer url")
}

// Set up the validator.
jwtValidator, err := token.NewValidator(issuerUrl, audience)
if err != nil {
log.Fatal().Err(err).Msg("Error creating JWT validator")
}

func CreateAuthenticationMiddleware(validator token.ValidatorInterface) negroni.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
ctx := r.Context()
logger := log.Ctx(ctx)
Expand All @@ -43,7 +31,7 @@ func CreateAuthenticationMiddleware(issuer, audience string) negroni.HandlerFunc
}
return
}
principal, err := jwtValidator.ValidateToken(ctx, token)
principal, err := validator.ValidateToken(ctx, token)
if err != nil {
logger.Warn().Err(err).Msg("authentication error")
if err = radixhttp.ErrorResponse(w, r, err); err != nil {
Expand Down
5 changes: 3 additions & 2 deletions api/router/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/equinor/radix-api/api/middleware/logger"
"github.com/equinor/radix-api/api/middleware/recovery"
"github.com/equinor/radix-api/api/utils"
token "github.com/equinor/radix-api/api/utils/authn"
"github.com/equinor/radix-api/models"
"github.com/equinor/radix-api/swaggerui"
"github.com/gorilla/mux"
Expand All @@ -19,7 +20,7 @@ const (
)

// NewAPIHandler Constructor function
func NewAPIHandler(clusterName, oidcIssuer, oidcAudience, radixDNSZone string, kubeUtil utils.KubeUtil, controllers ...models.Controller) http.Handler {
func NewAPIHandler(clusterName string, validator token.ValidatorInterface, radixDNSZone string, kubeUtil utils.KubeUtil, controllers ...models.Controller) http.Handler {
serveMux := http.NewServeMux()
serveMux.Handle("/health/", createHealthHandler())
serveMux.Handle("/swaggerui/", createSwaggerHandler())
Expand All @@ -30,7 +31,7 @@ func NewAPIHandler(clusterName, oidcIssuer, oidcAudience, radixDNSZone string, k
cors.CreateMiddleware(clusterName, radixDNSZone),
logger.CreateZerologRequestIdMiddleware(),
logger.CreateZerologRequestDetailsMiddleware(),
auth.CreateAuthenticationMiddleware(oidcIssuer, oidcAudience),
auth.CreateAuthenticationMiddleware(validator),
logger.CreateZerologRequestLoggerMiddleware(),
)
n.UseHandler(serveMux)
Expand Down
8 changes: 4 additions & 4 deletions api/utils/authn/azure_principal.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,19 @@ func (c *azureClaims) Validate(_ context.Context) error {
return nil
}

type azurePrincipal struct {
type AzurePrincipal struct {
token string
claims *validator.ValidatedClaims
azureClaims *azureClaims
}

func (p *azurePrincipal) Token() string {
func (p *AzurePrincipal) Token() string {
return p.token
}
func (p *azurePrincipal) IsAuthenticated() bool {
func (p *AzurePrincipal) IsAuthenticated() bool {
return true
}
func (p *azurePrincipal) Subject() string {
func (p *AzurePrincipal) Subject() string {
if p.azureClaims.Upn != "" {
return p.azureClaims.Upn
}
Expand Down
116 changes: 116 additions & 0 deletions api/utils/authn/mock/validator_mock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions api/utils/authn/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@ type TokenPrincipal interface {
}

type ValidatorInterface interface {
ValidateToken(ctx context.Context, token string) (*validator.RegisteredClaims, error)
ValidateToken(context.Context, string) (TokenPrincipal, error)
}

type Validator struct {
validator *validator.Validator
}

var _ ValidatorInterface = &Validator{}

func NewValidator(issuerUrl *url.URL, audience string) (*Validator, error) {
provider := jwks.NewCachingProvider(issuerUrl, 5*time.Minute)

Expand Down Expand Up @@ -59,6 +61,6 @@ func (v *Validator) ValidateToken(ctx context.Context, token string) (TokenPrinc
return nil, errors.New("invalid azure token")
}

principal := &azurePrincipal{token: token, claims: claims, azureClaims: azureClaims}
principal := &AzurePrincipal{token: token, claims: claims, azureClaims: azureClaims}
return principal, nil
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ require (
github.com/kedacore/keda/v2 v2.15.1
github.com/kelseyhightower/envconfig v1.4.0
github.com/marstr/guid v1.1.0
github.com/mitchellh/mapstructure v1.5.0
github.com/prometheus-operator/prometheus-operator/pkg/client v0.76.0
github.com/prometheus/client_golang v1.20.2
github.com/rs/cors v1.11.0
Expand Down Expand Up @@ -79,6 +78,7 @@ require (
github.com/mailru/easyjson v0.7.7 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
Expand Down
19 changes: 18 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@ import (
"io"
"net/http"
_ "net/http/pprof"
"net/url"
"os"
"os/signal"
"sync"
"syscall"
"time"

"github.com/equinor/radix-api/api/secrets"
token "github.com/equinor/radix-api/api/utils/authn"
"github.com/equinor/radix-api/api/utils/tlsvalidation"
"github.com/equinor/radix-api/internal/config"
"github.com/rs/zerolog"
Expand Down Expand Up @@ -76,12 +78,13 @@ func main() {
}

func initializeServer(c config.Config) *http.Server {
jwtValidator := initializeTokenValidator(c)
controllers, err := getControllers(c)
if err != nil {
log.Fatal().Err(err).Msgf("failed to initialize controllers: %v", err)
}

handler := router.NewAPIHandler(c.ClusterName, c.OidcIssuer, c.OidcAudience, c.DNSZone, utils.NewKubeUtil(c.UseOutClusterClient, c.KubernetesApiServer), controllers...)
handler := router.NewAPIHandler(c.ClusterName, jwtValidator, c.DNSZone, utils.NewKubeUtil(c.UseOutClusterClient, c.KubernetesApiServer), controllers...)
srv := &http.Server{
Addr: fmt.Sprintf(":%d", c.Port),
Handler: handler,
Expand All @@ -90,6 +93,20 @@ func initializeServer(c config.Config) *http.Server {
return srv
}

func initializeTokenValidator(c config.Config) *token.Validator {
issuerUrl, err := url.Parse(c.OidcIssuer)
if err != nil {
log.Fatal().Err(err).Msg("Error parsing issuer url")
}

// Set up the validator.
jwtValidator, err := token.NewValidator(issuerUrl, c.OidcAudience)
if err != nil {
log.Fatal().Err(err).Msg("Error creating JWT validator")
}
return jwtValidator
}

func initializeMetricsServer(c config.Config) *http.Server {
log.Info().Msgf("Initializing metrics server on port %d", c.Port)
return &http.Server{
Expand Down

0 comments on commit ebe79be

Please sign in to comment.