From 92d2a8bea01a0c5de03fac79cdcd49c8343c7d45 Mon Sep 17 00:00:00 2001 From: pmahindrakar-oss Date: Wed, 12 Jul 2023 15:51:22 -0700 Subject: [PATCH 01/10] Adding a predredirect hook plugin Signed-off-by: pmahindrakar-oss --- auth/handlers.go | 27 ++++++++++++++-------- pkg/server/service.go | 53 +++++++++++++++++++++---------------------- plugins/registry.go | 1 + 3 files changed, 45 insertions(+), 36 deletions(-) diff --git a/auth/handlers.go b/auth/handlers.go index 4133af4e7..fff4cf5b2 100644 --- a/auth/handlers.go +++ b/auth/handlers.go @@ -8,11 +8,6 @@ import ( "strings" "time" - "github.com/flyteorg/flyteadmin/auth/interfaces" - "github.com/flyteorg/flyteadmin/pkg/common" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" - "github.com/flyteorg/flytestdlib/errors" - "github.com/flyteorg/flytestdlib/logger" "github.com/grpc-ecosystem/go-grpc-middleware/util/metautils" "golang.org/x/oauth2" "google.golang.org/grpc" @@ -21,6 +16,13 @@ import ( "google.golang.org/grpc/peer" "google.golang.org/grpc/status" "google.golang.org/protobuf/runtime/protoiface" + + "github.com/flyteorg/flyteadmin/auth/interfaces" + "github.com/flyteorg/flyteadmin/pkg/common" + "github.com/flyteorg/flyteadmin/plugins" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" + "github.com/flyteorg/flytestdlib/errors" + "github.com/flyteorg/flytestdlib/logger" ) const ( @@ -29,6 +31,7 @@ const ( FromHTTPVal = "true" ) +type PreRedirectHookFunc func(ctx context.Context) error type HTTPRequestToMetadataAnnotator func(ctx context.Context, request *http.Request) metadata.MD type UserInfoForwardResponseHandler func(ctx context.Context, w http.ResponseWriter, m protoiface.MessageV1) error @@ -39,11 +42,11 @@ type AuthenticatedClientMeta struct { Subject string } -func RegisterHandlers(ctx context.Context, handler interfaces.HandlerRegisterer, authCtx interfaces.AuthenticationContext) { +func RegisterHandlers(ctx context.Context, handler interfaces.HandlerRegisterer, authCtx interfaces.AuthenticationContext, pluginRegistry *plugins.Registry) { // Add HTTP handlers for OAuth2 endpoints handler.HandleFunc("/login", RefreshTokensIfExists(ctx, authCtx, GetLoginHandler(ctx, authCtx))) - handler.HandleFunc("/callback", GetCallbackHandler(ctx, authCtx)) + handler.HandleFunc("/callback", GetCallbackHandler(ctx, authCtx, pluginRegistry)) // The metadata endpoint is an RFC-defined constant, but we need a leading / for the handler to pattern match correctly. handler.HandleFunc(fmt.Sprintf("/%s", OIdCMetadataEndpoint), GetOIdCMetadataEndpointRedirectHandler(ctx, authCtx)) @@ -129,14 +132,13 @@ func GetLoginHandler(ctx context.Context, authCtx interfaces.AuthenticationConte logger.Errorf(ctx, "Was not able to create a redirect cookie") } } - http.Redirect(writer, request, url, http.StatusTemporaryRedirect) } } // GetCallbackHandler returns a handler that is called by the OIdC provider with the authorization code to complete // the user authentication flow. -func GetCallbackHandler(ctx context.Context, authCtx interfaces.AuthenticationContext) http.HandlerFunc { +func GetCallbackHandler(ctx context.Context, authCtx interfaces.AuthenticationContext, pluginRegistry *plugins.Registry) http.HandlerFunc { return func(writer http.ResponseWriter, request *http.Request) { logger.Debugf(ctx, "Running callback handler... for RequestURI %v", request.RequestURI) authorizationCode := request.FormValue(AuthorizationResponseCodeType) @@ -178,6 +180,13 @@ func GetCallbackHandler(ctx context.Context, authCtx interfaces.AuthenticationCo return } + preRedirectHook := plugins.Get[PreRedirectHookFunc](pluginRegistry, plugins.PluginIDPreRedirectHook) + if preRedirectHook != nil { + if err := preRedirectHook(ctx); err != nil { + logger.Errorf(ctx, "failed the preRedirect hook due to %v", err) + writer.WriteHeader(http.StatusInternalServerError) + } + } redirectURL := getAuthFlowEndRedirect(ctx, authCtx, request) http.Redirect(writer, request, redirectURL, http.StatusTemporaryRedirect) } diff --git a/pkg/server/service.go b/pkg/server/service.go index 4a7983087..b060c3ec7 100644 --- a/pkg/server/service.go +++ b/pkg/server/service.go @@ -9,42 +9,39 @@ import ( "strings" "time" + "github.com/gorilla/handlers" + grpcmiddleware "github.com/grpc-ecosystem/go-grpc-middleware" + grpcauth "github.com/grpc-ecosystem/go-grpc-middleware/auth" + grpcprometheus "github.com/grpc-ecosystem/go-grpc-prometheus" + "github.com/grpc-ecosystem/grpc-gateway/runtime" + "github.com/pkg/errors" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/health" + "google.golang.org/grpc/health/grpc_health_v1" "google.golang.org/grpc/metadata" + "google.golang.org/grpc/reflection" "k8s.io/apimachinery/pkg/util/rand" - "github.com/flyteorg/flytestdlib/contextutils" - "github.com/flyteorg/flytestdlib/promutils/labeled" - - runtime2 "github.com/flyteorg/flyteadmin/pkg/runtime" - "github.com/flyteorg/flytestdlib/promutils" - "github.com/flyteorg/flytestdlib/storage" - - "github.com/flyteorg/flyteadmin/dataproxy" - "github.com/flyteorg/flyteadmin/plugins" - "github.com/flyteorg/flyteadmin/auth" "github.com/flyteorg/flyteadmin/auth/authzserver" authConfig "github.com/flyteorg/flyteadmin/auth/config" "github.com/flyteorg/flyteadmin/auth/interfaces" + "github.com/flyteorg/flyteadmin/dataproxy" "github.com/flyteorg/flyteadmin/pkg/common" "github.com/flyteorg/flyteadmin/pkg/config" "github.com/flyteorg/flyteadmin/pkg/rpc" "github.com/flyteorg/flyteadmin/pkg/rpc/adminservice" + runtime2 "github.com/flyteorg/flyteadmin/pkg/runtime" runtimeIfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" + "github.com/flyteorg/flyteadmin/plugins" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/secretmanager" + "github.com/flyteorg/flytestdlib/contextutils" "github.com/flyteorg/flytestdlib/logger" - "github.com/gorilla/handlers" - grpcmiddleware "github.com/grpc-ecosystem/go-grpc-middleware" - grpcauth "github.com/grpc-ecosystem/go-grpc-middleware/auth" - grpcprometheus "github.com/grpc-ecosystem/go-grpc-prometheus" - "github.com/grpc-ecosystem/grpc-gateway/runtime" - "github.com/pkg/errors" - "google.golang.org/grpc" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/health" - "google.golang.org/grpc/health/grpc_health_v1" - "google.golang.org/grpc/reflection" + "github.com/flyteorg/flytestdlib/promutils" + "github.com/flyteorg/flytestdlib/promutils/labeled" + "github.com/flyteorg/flytestdlib/storage" ) var defaultCorsHeaders = []string{"Content-Type"} @@ -163,7 +160,7 @@ func healthCheckFunc(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) } -func newHTTPServer(ctx context.Context, cfg *config.ServerConfig, _ *authConfig.Config, authCtx interfaces.AuthenticationContext, +func newHTTPServer(ctx context.Context, pluginRegistry *plugins.Registry, cfg *config.ServerConfig, _ *authConfig.Config, authCtx interfaces.AuthenticationContext, additionalHandlers map[string]func(http.ResponseWriter, *http.Request), grpcAddress string, grpcConnectionOpts ...grpc.DialOption) (*http.ServeMux, error) { @@ -191,7 +188,7 @@ func newHTTPServer(ctx context.Context, cfg *config.ServerConfig, _ *authConfig. if cfg.Security.UseAuth { // Add HTTP handlers for OIDC endpoints - auth.RegisterHandlers(ctx, mux, authCtx) + auth.RegisterHandlers(ctx, mux, authCtx, pluginRegistry) // Add HTTP handlers for OAuth2 endpoints authzserver.RegisterHandlers(mux, authCtx) @@ -278,7 +275,8 @@ func generateRequestID() string { func serveGatewayInsecure(ctx context.Context, pluginRegistry *plugins.Registry, cfg *config.ServerConfig, authCfg *authConfig.Config, storageConfig *storage.Config, - additionalHandlers map[string]func(http.ResponseWriter, *http.Request), scope promutils.Scope) error { + additionalHandlers map[string]func(http.ResponseWriter, *http.Request), + scope promutils.Scope) error { logger.Infof(ctx, "Serving Flyte Admin Insecure") // This will parse configuration and create the necessary objects for dealing with auth @@ -343,7 +341,7 @@ func serveGatewayInsecure(ctx context.Context, pluginRegistry *plugins.Registry, grpcOptions = append(grpcOptions, grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(cfg.GrpcConfig.MaxMessageSizeBytes))) } - httpServer, err := newHTTPServer(ctx, cfg, authCfg, authCtx, additionalHandlers, cfg.GetGrpcHostAddress(), grpcOptions...) + httpServer, err := newHTTPServer(ctx, pluginRegistry, cfg, authCfg, authCtx, additionalHandlers, cfg.GetGrpcHostAddress(), grpcOptions...) if err != nil { return err } @@ -390,7 +388,8 @@ func grpcHandlerFunc(grpcServer *grpc.Server, otherHandler http.Handler) http.Ha func serveGatewaySecure(ctx context.Context, pluginRegistry *plugins.Registry, cfg *config.ServerConfig, authCfg *authConfig.Config, storageCfg *storage.Config, - additionalHandlers map[string]func(http.ResponseWriter, *http.Request), scope promutils.Scope) error { + additionalHandlers map[string]func(http.ResponseWriter, *http.Request), + scope promutils.Scope) error { certPool, cert, err := GetSslCredentials(ctx, cfg.Security.Ssl.CertificateFile, cfg.Security.Ssl.KeyFile) if err != nil { return err @@ -445,7 +444,7 @@ func serveGatewaySecure(ctx context.Context, pluginRegistry *plugins.Registry, c serverOpts = append(serverOpts, grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(cfg.GrpcConfig.MaxMessageSizeBytes))) } - httpServer, err := newHTTPServer(ctx, cfg, authCfg, authCtx, additionalHandlers, cfg.GetHostAddress(), serverOpts...) + httpServer, err := newHTTPServer(ctx, pluginRegistry, cfg, authCfg, authCtx, additionalHandlers, cfg.GetHostAddress(), serverOpts...) if err != nil { return err } diff --git a/plugins/registry.go b/plugins/registry.go index 3c2186326..14682f7e8 100644 --- a/plugins/registry.go +++ b/plugins/registry.go @@ -12,6 +12,7 @@ const ( PluginIDWorkflowExecutor PluginID = "WorkflowExecutor" PluginIDDataProxy PluginID = "DataProxy" PluginIDUnaryServiceMiddleware PluginID = "UnaryServiceMiddleware" + PluginIDPreRedirectHook PluginID = "PreRedirectHook" ) type AtomicRegistry struct { From 19ddf62270bec25c7039f8c793a40086c63e6519 Mon Sep 17 00:00:00 2001 From: pmahindrakar-oss Date: Thu, 13 Jul 2023 18:29:35 -0700 Subject: [PATCH 02/10] Add test logs Signed-off-by: pmahindrakar-oss --- auth/handlers.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/auth/handlers.go b/auth/handlers.go index fff4cf5b2..c01f2755e 100644 --- a/auth/handlers.go +++ b/auth/handlers.go @@ -182,10 +182,12 @@ func GetCallbackHandler(ctx context.Context, authCtx interfaces.AuthenticationCo preRedirectHook := plugins.Get[PreRedirectHookFunc](pluginRegistry, plugins.PluginIDPreRedirectHook) if preRedirectHook != nil { + logger.Infof(ctx, "preRedirect hook is set") if err := preRedirectHook(ctx); err != nil { logger.Errorf(ctx, "failed the preRedirect hook due to %v", err) writer.WriteHeader(http.StatusInternalServerError) } + logger.Infof(ctx, "Successfully called the preRedirect hook") } redirectURL := getAuthFlowEndRedirect(ctx, authCtx, request) http.Redirect(writer, request, redirectURL, http.StatusTemporaryRedirect) From 256a582000aa156744e59ee8f27101196a5c8a07 Mon Sep 17 00:00:00 2001 From: pmahindrakar-oss Date: Fri, 14 Jul 2023 11:58:51 -0700 Subject: [PATCH 03/10] test logs Signed-off-by: pmahindrakar-oss --- auth/handlers.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/auth/handlers.go b/auth/handlers.go index c01f2755e..c0105eaff 100644 --- a/auth/handlers.go +++ b/auth/handlers.go @@ -145,6 +145,7 @@ func GetCallbackHandler(ctx context.Context, authCtx interfaces.AuthenticationCo ctx = context.WithValue(ctx, oauth2.HTTPClient, authCtx.GetHTTPClient()) + logger.Debugf(ctx, "Going to verify th CSRF cookie... for RequestURI %v", request.RequestURI) err := VerifyCsrfCookie(ctx, request) if err != nil { logger.Errorf(ctx, "Invalid CSRF token cookie %s", err) @@ -152,6 +153,7 @@ func GetCallbackHandler(ctx context.Context, authCtx interfaces.AuthenticationCo return } + logger.Debugf(ctx, "Going to exchange the token for the authorizationCode %v ... for RequestURI %v", authorizationCode, request.RequestURI) token, err := authCtx.OAuth2ClientConfig(GetPublicURL(ctx, request, authCtx.Options())).Exchange(ctx, authorizationCode) if err != nil { logger.Errorf(ctx, "Error when exchanging code %s", err) @@ -159,6 +161,7 @@ func GetCallbackHandler(ctx context.Context, authCtx interfaces.AuthenticationCo return } + logger.Debugf(ctx, "Going to set token cookies ... for RequestURI %v", request.RequestURI) err = authCtx.CookieManager().SetTokenCookies(ctx, writer, token) if err != nil { logger.Errorf(ctx, "Error setting encrypted JWT cookie %s", err) @@ -166,6 +169,7 @@ func GetCallbackHandler(ctx context.Context, authCtx interfaces.AuthenticationCo return } + logger.Debugf(ctx, "Going to query user info token ... for RequestURI %v", request.RequestURI) userInfo, err := QueryUserInfoUsingAccessToken(ctx, request, authCtx, token.AccessToken) if err != nil { logger.Errorf(ctx, "Failed to query user info. Error: %v", err) @@ -173,6 +177,7 @@ func GetCallbackHandler(ctx context.Context, authCtx interfaces.AuthenticationCo return } + logger.Debugf(ctx, "Going to set user info cookie ... for RequestURI %v with userInfo %v", request.RequestURI, userInfo) err = authCtx.CookieManager().SetUserInfoCookie(ctx, writer, userInfo) if err != nil { logger.Errorf(ctx, "Error setting encrypted user info cookie. Error: %v", err) @@ -180,6 +185,7 @@ func GetCallbackHandler(ctx context.Context, authCtx interfaces.AuthenticationCo return } + logger.Infof(ctx, "Going to look up the preredirect hook in the registry") preRedirectHook := plugins.Get[PreRedirectHookFunc](pluginRegistry, plugins.PluginIDPreRedirectHook) if preRedirectHook != nil { logger.Infof(ctx, "preRedirect hook is set") @@ -189,7 +195,9 @@ func GetCallbackHandler(ctx context.Context, authCtx interfaces.AuthenticationCo } logger.Infof(ctx, "Successfully called the preRedirect hook") } + redirectURL := getAuthFlowEndRedirect(ctx, authCtx, request) + logger.Infof(ctx, "Going to perform the redirect with redirectURl %v", redirectURL) http.Redirect(writer, request, redirectURL, http.StatusTemporaryRedirect) } } From 7c7915d9a43323cbba4dc30a7d1cd7690fd045eb Mon Sep 17 00:00:00 2001 From: pmahindrakar-oss Date: Fri, 14 Jul 2023 14:37:10 -0700 Subject: [PATCH 04/10] fix Signed-off-by: pmahindrakar-oss --- auth/handlers.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/auth/handlers.go b/auth/handlers.go index c0105eaff..d359ccc5c 100644 --- a/auth/handlers.go +++ b/auth/handlers.go @@ -31,7 +31,7 @@ const ( FromHTTPVal = "true" ) -type PreRedirectHookFunc func(ctx context.Context) error +type PreRedirectHookFunc func(ctx context.Context, authCtx interfaces.AuthenticationContext, request *http.Request, w http.ResponseWriter) error type HTTPRequestToMetadataAnnotator func(ctx context.Context, request *http.Request) metadata.MD type UserInfoForwardResponseHandler func(ctx context.Context, w http.ResponseWriter, m protoiface.MessageV1) error @@ -189,13 +189,13 @@ func GetCallbackHandler(ctx context.Context, authCtx interfaces.AuthenticationCo preRedirectHook := plugins.Get[PreRedirectHookFunc](pluginRegistry, plugins.PluginIDPreRedirectHook) if preRedirectHook != nil { logger.Infof(ctx, "preRedirect hook is set") - if err := preRedirectHook(ctx); err != nil { + if err := preRedirectHook(ctx, authCtx, request, writer); err != nil { logger.Errorf(ctx, "failed the preRedirect hook due to %v", err) writer.WriteHeader(http.StatusInternalServerError) + return } logger.Infof(ctx, "Successfully called the preRedirect hook") } - redirectURL := getAuthFlowEndRedirect(ctx, authCtx, request) logger.Infof(ctx, "Going to perform the redirect with redirectURl %v", redirectURL) http.Redirect(writer, request, redirectURL, http.StatusTemporaryRedirect) From 799f72d7c1be70624452b7a5863cb815674b02a1 Mon Sep 17 00:00:00 2001 From: pmahindrakar-oss Date: Tue, 18 Jul 2023 11:52:07 -0700 Subject: [PATCH 05/10] Reading identity token for getting subject Signed-off-by: pmahindrakar-oss --- auth/handlers.go | 21 ++++++++++++++++++--- auth/handlers_test.go | 2 +- plugins/registry_test.go | 21 +++++++++++++++++++++ 3 files changed, 40 insertions(+), 4 deletions(-) diff --git a/auth/handlers.go b/auth/handlers.go index d359ccc5c..be420b20f 100644 --- a/auth/handlers.go +++ b/auth/handlers.go @@ -189,12 +189,27 @@ func GetCallbackHandler(ctx context.Context, authCtx interfaces.AuthenticationCo preRedirectHook := plugins.Get[PreRedirectHookFunc](pluginRegistry, plugins.PluginIDPreRedirectHook) if preRedirectHook != nil { logger.Infof(ctx, "preRedirect hook is set") - if err := preRedirectHook(ctx, authCtx, request, writer); err != nil { - logger.Errorf(ctx, "failed the preRedirect hook due to %v", err) + redirectHookCtx := ctx + if idTokenRaw, converted := token.Extra(idTokenExtra).(string); converted { + identityContext, err := IdentityContextFromIDTokenToken(redirectHookCtx, idTokenRaw, authCtx.Options().UserAuth.OpenID.ClientID, + authCtx.OidcProvider(), userInfo) + if err != nil { + logger.Errorf(redirectHookCtx, "failed to get identity context from the IDToken due to %v", err) + writer.WriteHeader(http.StatusInternalServerError) + return + } + redirectHookCtx = identityContext.WithContext(ctx) + } else { + logger.Errorf(redirectHookCtx, "failed to get IDToken from the exchanged token with the auth server") + writer.WriteHeader(http.StatusInternalServerError) + } + + if err := preRedirectHook(redirectHookCtx, authCtx, request, writer); err != nil { + logger.Errorf(redirectHookCtx, "failed the preRedirect hook due to %v", err) writer.WriteHeader(http.StatusInternalServerError) return } - logger.Infof(ctx, "Successfully called the preRedirect hook") + logger.Infof(redirectHookCtx, "Successfully called the preRedirect hook") } redirectURL := getAuthFlowEndRedirect(ctx, authCtx, request) logger.Infof(ctx, "Going to perform the redirect with redirectURl %v", redirectURL) diff --git a/auth/handlers_test.go b/auth/handlers_test.go index 88232de1c..532e3cd73 100644 --- a/auth/handlers_test.go +++ b/auth/handlers_test.go @@ -153,7 +153,7 @@ func TestGetCallbackHandler(t *testing.T) { t.Run("forbidden request when accessing user info", func(t *testing.T) { mockAuthCtx := setupMockedAuthContextAtEndpoint(localServer.URL) - callbackHandlerFunc := GetCallbackHandler(ctx, mockAuthCtx) + callbackHandlerFunc := GetCallbackHandler(ctx, mockAuthCtx, nil) request := httptest.NewRequest("GET", localServer.URL+"/callback", nil) addCsrfCookie(request) addStateString(request) diff --git a/plugins/registry_test.go b/plugins/registry_test.go index 757b596fd..cb9ae1ffc 100644 --- a/plugins/registry_test.go +++ b/plugins/registry_test.go @@ -1,6 +1,8 @@ package plugins import ( + "context" + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -21,6 +23,25 @@ func TestNewAtomicRegistry(t *testing.T) { assert.Equal(t, 5, r.Get(PluginIDDataProxy)) } +type PreRedirectHookFunc func(ctx context.Context) error + +func TestRedirectHook(t *testing.T) { + ar := NewAtomicRegistry(nil) + r := NewRegistry() + + var redirectHookfn PreRedirectHookFunc + redirectHookfn = func(ctx context.Context) error { + return fmt.Errorf("redirect hook error") + } + err := r.Register(PluginIDPreRedirectHook, redirectHookfn) + assert.NoError(t, err) + ar.Store(r) + r = ar.Load() + fn := Get[PreRedirectHookFunc](r, PluginIDPreRedirectHook) + err = fn(context.Background()) + assert.Equal(t, fmt.Errorf("redirect hook error"), err) +} + func TestRegistry_RegisterDefault(t *testing.T) { r := NewRegistry() r.RegisterDefault("hello", 5) From ae9e5ccc82f08e324e9872f5ef181c083a7b80b5 Mon Sep 17 00:00:00 2001 From: pmahindrakar-oss Date: Tue, 18 Jul 2023 12:20:00 -0700 Subject: [PATCH 06/10] reverting Signed-off-by: pmahindrakar-oss --- auth/handlers.go | 21 +++------------------ 1 file changed, 3 insertions(+), 18 deletions(-) diff --git a/auth/handlers.go b/auth/handlers.go index be420b20f..d359ccc5c 100644 --- a/auth/handlers.go +++ b/auth/handlers.go @@ -189,27 +189,12 @@ func GetCallbackHandler(ctx context.Context, authCtx interfaces.AuthenticationCo preRedirectHook := plugins.Get[PreRedirectHookFunc](pluginRegistry, plugins.PluginIDPreRedirectHook) if preRedirectHook != nil { logger.Infof(ctx, "preRedirect hook is set") - redirectHookCtx := ctx - if idTokenRaw, converted := token.Extra(idTokenExtra).(string); converted { - identityContext, err := IdentityContextFromIDTokenToken(redirectHookCtx, idTokenRaw, authCtx.Options().UserAuth.OpenID.ClientID, - authCtx.OidcProvider(), userInfo) - if err != nil { - logger.Errorf(redirectHookCtx, "failed to get identity context from the IDToken due to %v", err) - writer.WriteHeader(http.StatusInternalServerError) - return - } - redirectHookCtx = identityContext.WithContext(ctx) - } else { - logger.Errorf(redirectHookCtx, "failed to get IDToken from the exchanged token with the auth server") - writer.WriteHeader(http.StatusInternalServerError) - } - - if err := preRedirectHook(redirectHookCtx, authCtx, request, writer); err != nil { - logger.Errorf(redirectHookCtx, "failed the preRedirect hook due to %v", err) + if err := preRedirectHook(ctx, authCtx, request, writer); err != nil { + logger.Errorf(ctx, "failed the preRedirect hook due to %v", err) writer.WriteHeader(http.StatusInternalServerError) return } - logger.Infof(redirectHookCtx, "Successfully called the preRedirect hook") + logger.Infof(ctx, "Successfully called the preRedirect hook") } redirectURL := getAuthFlowEndRedirect(ctx, authCtx, request) logger.Infof(ctx, "Going to perform the redirect with redirectURl %v", redirectURL) From f81f441e2af84892a2c13843cf381b355c8ba686 Mon Sep 17 00:00:00 2001 From: pmahindrakar-oss Date: Wed, 9 Aug 2023 15:58:48 -0700 Subject: [PATCH 07/10] Adding PreRedirectHookError Signed-off-by: pmahindrakar-oss --- auth/handlers.go | 34 ++++++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/auth/handlers.go b/auth/handlers.go index d359ccc5c..e9e458da0 100644 --- a/auth/handlers.go +++ b/auth/handlers.go @@ -31,7 +31,23 @@ const ( FromHTTPVal = "true" ) -type PreRedirectHookFunc func(ctx context.Context, authCtx interfaces.AuthenticationContext, request *http.Request, w http.ResponseWriter) error +type PreRedirectHookError struct { + Message string + Code int +} + +func (e *PreRedirectHookError) Error() string { + return e.Message +} + +// PreRedirectHookFunc Interface used for running custom code before the redirect happens during a successful auth flow. +// This might be useful in cases where the auth flow allows the user to login since the IDP has been configured +// for eg: to allow all users from a particular domain to login +// but you want to restrict access to only a particular set of user ids. eg : users@domain.com are allowed to login but user user1@domain.com, user2@domain.com +// should only be allowed +// PreRedirectHookError is the error interface which allows the user to set correct http status code and Message to be set in case the function returns an error +// without which the current usage in GetCallbackHandler will set this to InternalServerError +type PreRedirectHookFunc func(ctx context.Context, authCtx interfaces.AuthenticationContext, request *http.Request, w http.ResponseWriter) *PreRedirectHookError type HTTPRequestToMetadataAnnotator func(ctx context.Context, request *http.Request) metadata.MD type UserInfoForwardResponseHandler func(ctx context.Context, w http.ResponseWriter, m protoiface.MessageV1) error @@ -145,7 +161,6 @@ func GetCallbackHandler(ctx context.Context, authCtx interfaces.AuthenticationCo ctx = context.WithValue(ctx, oauth2.HTTPClient, authCtx.GetHTTPClient()) - logger.Debugf(ctx, "Going to verify th CSRF cookie... for RequestURI %v", request.RequestURI) err := VerifyCsrfCookie(ctx, request) if err != nil { logger.Errorf(ctx, "Invalid CSRF token cookie %s", err) @@ -153,7 +168,6 @@ func GetCallbackHandler(ctx context.Context, authCtx interfaces.AuthenticationCo return } - logger.Debugf(ctx, "Going to exchange the token for the authorizationCode %v ... for RequestURI %v", authorizationCode, request.RequestURI) token, err := authCtx.OAuth2ClientConfig(GetPublicURL(ctx, request, authCtx.Options())).Exchange(ctx, authorizationCode) if err != nil { logger.Errorf(ctx, "Error when exchanging code %s", err) @@ -161,7 +175,6 @@ func GetCallbackHandler(ctx context.Context, authCtx interfaces.AuthenticationCo return } - logger.Debugf(ctx, "Going to set token cookies ... for RequestURI %v", request.RequestURI) err = authCtx.CookieManager().SetTokenCookies(ctx, writer, token) if err != nil { logger.Errorf(ctx, "Error setting encrypted JWT cookie %s", err) @@ -169,7 +182,6 @@ func GetCallbackHandler(ctx context.Context, authCtx interfaces.AuthenticationCo return } - logger.Debugf(ctx, "Going to query user info token ... for RequestURI %v", request.RequestURI) userInfo, err := QueryUserInfoUsingAccessToken(ctx, request, authCtx, token.AccessToken) if err != nil { logger.Errorf(ctx, "Failed to query user info. Error: %v", err) @@ -177,7 +189,6 @@ func GetCallbackHandler(ctx context.Context, authCtx interfaces.AuthenticationCo return } - logger.Debugf(ctx, "Going to set user info cookie ... for RequestURI %v with userInfo %v", request.RequestURI, userInfo) err = authCtx.CookieManager().SetUserInfoCookie(ctx, writer, userInfo) if err != nil { logger.Errorf(ctx, "Error setting encrypted user info cookie. Error: %v", err) @@ -185,16 +196,19 @@ func GetCallbackHandler(ctx context.Context, authCtx interfaces.AuthenticationCo return } - logger.Infof(ctx, "Going to look up the preredirect hook in the registry") preRedirectHook := plugins.Get[PreRedirectHookFunc](pluginRegistry, plugins.PluginIDPreRedirectHook) if preRedirectHook != nil { logger.Infof(ctx, "preRedirect hook is set") if err := preRedirectHook(ctx, authCtx, request, writer); err != nil { - logger.Errorf(ctx, "failed the preRedirect hook due to %v", err) - writer.WriteHeader(http.StatusInternalServerError) + logger.Errorf(ctx, "failed the preRedirect hook due %v with status code %v", err.Message, err.Code) + if http.StatusText(err.Code) != "" { + writer.WriteHeader(err.Code) + } else { + writer.WriteHeader(http.StatusInternalServerError) + } return } - logger.Infof(ctx, "Successfully called the preRedirect hook") + logger.Info(ctx, "Successfully called the preRedirect hook") } redirectURL := getAuthFlowEndRedirect(ctx, authCtx, request) logger.Infof(ctx, "Going to perform the redirect with redirectURl %v", redirectURL) From 253e0b22b74de837c7906ea9d2130e39fad9f938 Mon Sep 17 00:00:00 2001 From: pmahindrakar-oss Date: Wed, 9 Aug 2023 17:33:59 -0700 Subject: [PATCH 08/10] Add some more tests Signed-off-by: pmahindrakar-oss --- auth/handlers_test.go | 63 ++++++++++++++++++++++++++++++++++++------- 1 file changed, 53 insertions(+), 10 deletions(-) diff --git a/auth/handlers_test.go b/auth/handlers_test.go index 532e3cd73..797ba3350 100644 --- a/auth/handlers_test.go +++ b/auth/handlers_test.go @@ -10,18 +10,19 @@ import ( "strings" "testing" + "github.com/coreos/go-oidc" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "golang.org/x/oauth2" "google.golang.org/protobuf/types/known/structpb" "github.com/flyteorg/flyteadmin/auth/config" + "github.com/flyteorg/flyteadmin/auth/interfaces" "github.com/flyteorg/flyteadmin/auth/interfaces/mocks" "github.com/flyteorg/flyteadmin/pkg/common" + "github.com/flyteorg/flyteadmin/plugins" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" stdConfig "github.com/flyteorg/flytestdlib/config" - - "github.com/coreos/go-oidc" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "golang.org/x/oauth2" ) const ( @@ -81,7 +82,8 @@ func TestGetCallbackHandlerWithErrorOnToken(t *testing.T) { defer localServer.Close() http.DefaultClient = localServer.Client() mockAuthCtx := setupMockedAuthContextAtEndpoint(localServer.URL) - callbackHandlerFunc := GetCallbackHandler(ctx, mockAuthCtx) + r := plugins.NewRegistry() + callbackHandlerFunc := GetCallbackHandler(ctx, mockAuthCtx, r) request := httptest.NewRequest("GET", localServer.URL+"/callback", nil) addCsrfCookie(request) addStateString(request) @@ -102,7 +104,8 @@ func TestGetCallbackHandlerWithUnAuthorized(t *testing.T) { defer localServer.Close() http.DefaultClient = localServer.Client() mockAuthCtx := setupMockedAuthContextAtEndpoint(localServer.URL) - callbackHandlerFunc := GetCallbackHandler(ctx, mockAuthCtx) + r := plugins.NewRegistry() + callbackHandlerFunc := GetCallbackHandler(ctx, mockAuthCtx, r) request := httptest.NewRequest("GET", localServer.URL+"/callback", nil) writer := httptest.NewRecorder() callbackHandlerFunc(writer, request) @@ -153,7 +156,8 @@ func TestGetCallbackHandler(t *testing.T) { t.Run("forbidden request when accessing user info", func(t *testing.T) { mockAuthCtx := setupMockedAuthContextAtEndpoint(localServer.URL) - callbackHandlerFunc := GetCallbackHandler(ctx, mockAuthCtx, nil) + r := plugins.NewRegistry() + callbackHandlerFunc := GetCallbackHandler(ctx, mockAuthCtx, r) request := httptest.NewRequest("GET", localServer.URL+"/callback", nil) addCsrfCookie(request) addStateString(request) @@ -172,9 +176,16 @@ func TestGetCallbackHandler(t *testing.T) { assert.Equal(t, "403 Forbidden", writer.Result().Status) }) - t.Run("successful callback and redirect", func(t *testing.T) { + t.Run("successful callback with redirect and successful preredirect hook call", func(t *testing.T) { mockAuthCtx := setupMockedAuthContextAtEndpoint(localServer.URL) - callbackHandlerFunc := GetCallbackHandler(ctx, mockAuthCtx) + r := plugins.NewRegistry() + var redirectFunc PreRedirectHookFunc + redirectFunc = func(redirectContext context.Context, authCtx interfaces.AuthenticationContext, request *http.Request, responseWriter http.ResponseWriter) *PreRedirectHookError { + return nil + } + + r.RegisterDefault(plugins.PluginIDPreRedirectHook, redirectFunc) + callbackHandlerFunc := GetCallbackHandler(ctx, mockAuthCtx, r) request := httptest.NewRequest("GET", localServer.URL+"/callback", nil) addCsrfCookie(request) addStateString(request) @@ -193,6 +204,38 @@ func TestGetCallbackHandler(t *testing.T) { callbackHandlerFunc(writer, request) assert.Equal(t, "307 Temporary Redirect", writer.Result().Status) }) + + t.Run("successful callback with pre-redirecthook failure", func(t *testing.T) { + mockAuthCtx := setupMockedAuthContextAtEndpoint(localServer.URL) + r := plugins.NewRegistry() + var redirectFunc PreRedirectHookFunc + redirectFunc = func(redirectContext context.Context, authCtx interfaces.AuthenticationContext, request *http.Request, responseWriter http.ResponseWriter) *PreRedirectHookError { + return &PreRedirectHookError{ + Code: http.StatusPreconditionFailed, + Message: "precondition error", + } + } + + r.RegisterDefault(plugins.PluginIDPreRedirectHook, redirectFunc) + callbackHandlerFunc := GetCallbackHandler(ctx, mockAuthCtx, r) + request := httptest.NewRequest("GET", localServer.URL+"/callback", nil) + addCsrfCookie(request) + addStateString(request) + writer := httptest.NewRecorder() + openIDConfigJSON = fmt.Sprintf(`{ + "userinfo_endpoint": "%v/userinfo", + "issuer": "%v", + "authorization_endpoint": "%v/auth", + "token_endpoint": "%v/token", + "jwks_uri": "%v/keys", + "id_token_signing_alg_values_supported": ["RS256"] + }`, issuer, issuer, issuer, issuer, issuer) + oidcProvider, err := oidc.NewProvider(ctx, issuer) + assert.Nil(t, err) + mockAuthCtx.OnOidcProviderMatch().Return(oidcProvider) + callbackHandlerFunc(writer, request) + assert.Equal(t, "412 Precondition Failed", writer.Result().Status) + }) } func TestGetLoginHandler(t *testing.T) { From 269b0e1dcba1092fd90af0ceacf3689afcf08cd4 Mon Sep 17 00:00:00 2001 From: pmahindrakar-oss Date: Wed, 9 Aug 2023 18:00:09 -0700 Subject: [PATCH 09/10] lint fixes Signed-off-by: pmahindrakar-oss --- auth/handlers_test.go | 6 ++---- plugins/registry_test.go | 3 +-- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/auth/handlers_test.go b/auth/handlers_test.go index 797ba3350..449b13c4a 100644 --- a/auth/handlers_test.go +++ b/auth/handlers_test.go @@ -179,8 +179,7 @@ func TestGetCallbackHandler(t *testing.T) { t.Run("successful callback with redirect and successful preredirect hook call", func(t *testing.T) { mockAuthCtx := setupMockedAuthContextAtEndpoint(localServer.URL) r := plugins.NewRegistry() - var redirectFunc PreRedirectHookFunc - redirectFunc = func(redirectContext context.Context, authCtx interfaces.AuthenticationContext, request *http.Request, responseWriter http.ResponseWriter) *PreRedirectHookError { + var redirectFunc PreRedirectHookFunc = func(redirectContext context.Context, authCtx interfaces.AuthenticationContext, request *http.Request, responseWriter http.ResponseWriter) *PreRedirectHookError { return nil } @@ -208,8 +207,7 @@ func TestGetCallbackHandler(t *testing.T) { t.Run("successful callback with pre-redirecthook failure", func(t *testing.T) { mockAuthCtx := setupMockedAuthContextAtEndpoint(localServer.URL) r := plugins.NewRegistry() - var redirectFunc PreRedirectHookFunc - redirectFunc = func(redirectContext context.Context, authCtx interfaces.AuthenticationContext, request *http.Request, responseWriter http.ResponseWriter) *PreRedirectHookError { + var redirectFunc PreRedirectHookFunc = func(redirectContext context.Context, authCtx interfaces.AuthenticationContext, request *http.Request, responseWriter http.ResponseWriter) *PreRedirectHookError { return &PreRedirectHookError{ Code: http.StatusPreconditionFailed, Message: "precondition error", diff --git a/plugins/registry_test.go b/plugins/registry_test.go index cb9ae1ffc..0737c1281 100644 --- a/plugins/registry_test.go +++ b/plugins/registry_test.go @@ -29,8 +29,7 @@ func TestRedirectHook(t *testing.T) { ar := NewAtomicRegistry(nil) r := NewRegistry() - var redirectHookfn PreRedirectHookFunc - redirectHookfn = func(ctx context.Context) error { + var redirectHookfn PreRedirectHookFunc = func(ctx context.Context) error { return fmt.Errorf("redirect hook error") } err := r.Register(PluginIDPreRedirectHook, redirectHookfn) From c7b92c7b9f01458394e86472bbe56ceedaf1f9c5 Mon Sep 17 00:00:00 2001 From: pmahindrakar-oss Date: Thu, 10 Aug 2023 10:43:58 -0700 Subject: [PATCH 10/10] removed log line Signed-off-by: pmahindrakar-oss --- auth/handlers.go | 1 - 1 file changed, 1 deletion(-) diff --git a/auth/handlers.go b/auth/handlers.go index e9e458da0..9604d90ec 100644 --- a/auth/handlers.go +++ b/auth/handlers.go @@ -211,7 +211,6 @@ func GetCallbackHandler(ctx context.Context, authCtx interfaces.AuthenticationCo logger.Info(ctx, "Successfully called the preRedirect hook") } redirectURL := getAuthFlowEndRedirect(ctx, authCtx, request) - logger.Infof(ctx, "Going to perform the redirect with redirectURl %v", redirectURL) http.Redirect(writer, request, redirectURL, http.StatusTemporaryRedirect) } }