Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Introduce pre redirect hook plugin during auth callback #601

Merged
merged 10 commits into from
Aug 10, 2023
Merged
50 changes: 41 additions & 9 deletions auth/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,6 @@
"strings"
"time"

"github.com/flyteorg/flyteadmin/auth/interfaces"
"github.com/flyteorg/flyteadmin/pkg/common"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flytestdlib/errors"
"github.com/flyteorg/flytestdlib/logger"
"github.com/grpc-ecosystem/go-grpc-middleware/util/metautils"
"golang.org/x/oauth2"
"google.golang.org/grpc"
Expand All @@ -21,6 +16,13 @@
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/runtime/protoiface"

"github.com/flyteorg/flyteadmin/auth/interfaces"
"github.com/flyteorg/flyteadmin/pkg/common"
"github.com/flyteorg/flyteadmin/plugins"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flytestdlib/errors"
"github.com/flyteorg/flytestdlib/logger"
)

const (
Expand All @@ -29,6 +31,23 @@
FromHTTPVal = "true"
)

type PreRedirectHookError struct {
Message string
Code int
}

func (e *PreRedirectHookError) Error() string {
return e.Message

Check warning on line 40 in auth/handlers.go

View check run for this annotation

Codecov / codecov/patch

auth/handlers.go#L39-L40

Added lines #L39 - L40 were not covered by tests
}

// PreRedirectHookFunc Interface used for running custom code before the redirect happens during a successful auth flow.
// This might be useful in cases where the auth flow allows the user to login since the IDP has been configured
// for eg: to allow all users from a particular domain to login
// but you want to restrict access to only a particular set of user ids. eg : users@domain.com are allowed to login but user user1@domain.com, user2@domain.com
// should only be allowed
// PreRedirectHookError is the error interface which allows the user to set correct http status code and Message to be set in case the function returns an error
// without which the current usage in GetCallbackHandler will set this to InternalServerError
type PreRedirectHookFunc func(ctx context.Context, authCtx interfaces.AuthenticationContext, request *http.Request, w http.ResponseWriter) *PreRedirectHookError
type HTTPRequestToMetadataAnnotator func(ctx context.Context, request *http.Request) metadata.MD
type UserInfoForwardResponseHandler func(ctx context.Context, w http.ResponseWriter, m protoiface.MessageV1) error

Expand All @@ -39,11 +58,11 @@
Subject string
}

func RegisterHandlers(ctx context.Context, handler interfaces.HandlerRegisterer, authCtx interfaces.AuthenticationContext) {
func RegisterHandlers(ctx context.Context, handler interfaces.HandlerRegisterer, authCtx interfaces.AuthenticationContext, pluginRegistry *plugins.Registry) {

Check warning on line 61 in auth/handlers.go

View check run for this annotation

Codecov / codecov/patch

auth/handlers.go#L61

Added line #L61 was not covered by tests
// Add HTTP handlers for OAuth2 endpoints
handler.HandleFunc("/login", RefreshTokensIfExists(ctx, authCtx,
GetLoginHandler(ctx, authCtx)))
handler.HandleFunc("/callback", GetCallbackHandler(ctx, authCtx))
handler.HandleFunc("/callback", GetCallbackHandler(ctx, authCtx, pluginRegistry))

Check warning on line 65 in auth/handlers.go

View check run for this annotation

Codecov / codecov/patch

auth/handlers.go#L65

Added line #L65 was not covered by tests

// The metadata endpoint is an RFC-defined constant, but we need a leading / for the handler to pattern match correctly.
handler.HandleFunc(fmt.Sprintf("/%s", OIdCMetadataEndpoint), GetOIdCMetadataEndpointRedirectHandler(ctx, authCtx))
Expand Down Expand Up @@ -129,14 +148,13 @@
logger.Errorf(ctx, "Was not able to create a redirect cookie")
}
}

http.Redirect(writer, request, url, http.StatusTemporaryRedirect)
}
}

// GetCallbackHandler returns a handler that is called by the OIdC provider with the authorization code to complete
// the user authentication flow.
func GetCallbackHandler(ctx context.Context, authCtx interfaces.AuthenticationContext) http.HandlerFunc {
func GetCallbackHandler(ctx context.Context, authCtx interfaces.AuthenticationContext, pluginRegistry *plugins.Registry) http.HandlerFunc {
return func(writer http.ResponseWriter, request *http.Request) {
logger.Debugf(ctx, "Running callback handler... for RequestURI %v", request.RequestURI)
authorizationCode := request.FormValue(AuthorizationResponseCodeType)
Expand Down Expand Up @@ -178,6 +196,20 @@
return
}

preRedirectHook := plugins.Get[PreRedirectHookFunc](pluginRegistry, plugins.PluginIDPreRedirectHook)
if preRedirectHook != nil {
logger.Infof(ctx, "preRedirect hook is set")
if err := preRedirectHook(ctx, authCtx, request, writer); err != nil {
logger.Errorf(ctx, "failed the preRedirect hook due %v with status code %v", err.Message, err.Code)
if http.StatusText(err.Code) != "" {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we also check that it isn't status ok? https://go.dev/src/net/http/status.go

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So i am relying on the user code which set the hook to set the correct http status code .
If they choose to set OK, then thats what we will set on the header.

writer.WriteHeader(err.Code)
} else {
writer.WriteHeader(http.StatusInternalServerError)
}

Check warning on line 208 in auth/handlers.go

View check run for this annotation

Codecov / codecov/patch

auth/handlers.go#L207-L208

Added lines #L207 - L208 were not covered by tests
return
}
logger.Info(ctx, "Successfully called the preRedirect hook")
}
redirectURL := getAuthFlowEndRedirect(ctx, authCtx, request)
http.Redirect(writer, request, redirectURL, http.StatusTemporaryRedirect)
}
Expand Down
61 changes: 51 additions & 10 deletions auth/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,19 @@ import (
"strings"
"testing"

"github.com/coreos/go-oidc"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"golang.org/x/oauth2"
"google.golang.org/protobuf/types/known/structpb"

"github.com/flyteorg/flyteadmin/auth/config"
"github.com/flyteorg/flyteadmin/auth/interfaces"
"github.com/flyteorg/flyteadmin/auth/interfaces/mocks"
"github.com/flyteorg/flyteadmin/pkg/common"
"github.com/flyteorg/flyteadmin/plugins"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service"
stdConfig "github.com/flyteorg/flytestdlib/config"

"github.com/coreos/go-oidc"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"golang.org/x/oauth2"
)

const (
Expand Down Expand Up @@ -81,7 +82,8 @@ func TestGetCallbackHandlerWithErrorOnToken(t *testing.T) {
defer localServer.Close()
http.DefaultClient = localServer.Client()
mockAuthCtx := setupMockedAuthContextAtEndpoint(localServer.URL)
callbackHandlerFunc := GetCallbackHandler(ctx, mockAuthCtx)
r := plugins.NewRegistry()
callbackHandlerFunc := GetCallbackHandler(ctx, mockAuthCtx, r)
request := httptest.NewRequest("GET", localServer.URL+"/callback", nil)
addCsrfCookie(request)
addStateString(request)
Expand All @@ -102,7 +104,8 @@ func TestGetCallbackHandlerWithUnAuthorized(t *testing.T) {
defer localServer.Close()
http.DefaultClient = localServer.Client()
mockAuthCtx := setupMockedAuthContextAtEndpoint(localServer.URL)
callbackHandlerFunc := GetCallbackHandler(ctx, mockAuthCtx)
r := plugins.NewRegistry()
callbackHandlerFunc := GetCallbackHandler(ctx, mockAuthCtx, r)
request := httptest.NewRequest("GET", localServer.URL+"/callback", nil)
writer := httptest.NewRecorder()
callbackHandlerFunc(writer, request)
Expand Down Expand Up @@ -153,7 +156,8 @@ func TestGetCallbackHandler(t *testing.T) {

t.Run("forbidden request when accessing user info", func(t *testing.T) {
mockAuthCtx := setupMockedAuthContextAtEndpoint(localServer.URL)
callbackHandlerFunc := GetCallbackHandler(ctx, mockAuthCtx)
r := plugins.NewRegistry()
callbackHandlerFunc := GetCallbackHandler(ctx, mockAuthCtx, r)
request := httptest.NewRequest("GET", localServer.URL+"/callback", nil)
addCsrfCookie(request)
addStateString(request)
Expand All @@ -172,9 +176,15 @@ func TestGetCallbackHandler(t *testing.T) {
assert.Equal(t, "403 Forbidden", writer.Result().Status)
})

t.Run("successful callback and redirect", func(t *testing.T) {
t.Run("successful callback with redirect and successful preredirect hook call", func(t *testing.T) {
mockAuthCtx := setupMockedAuthContextAtEndpoint(localServer.URL)
callbackHandlerFunc := GetCallbackHandler(ctx, mockAuthCtx)
r := plugins.NewRegistry()
var redirectFunc PreRedirectHookFunc = func(redirectContext context.Context, authCtx interfaces.AuthenticationContext, request *http.Request, responseWriter http.ResponseWriter) *PreRedirectHookError {
return nil
}

r.RegisterDefault(plugins.PluginIDPreRedirectHook, redirectFunc)
callbackHandlerFunc := GetCallbackHandler(ctx, mockAuthCtx, r)
request := httptest.NewRequest("GET", localServer.URL+"/callback", nil)
addCsrfCookie(request)
addStateString(request)
Expand All @@ -193,6 +203,37 @@ func TestGetCallbackHandler(t *testing.T) {
callbackHandlerFunc(writer, request)
assert.Equal(t, "307 Temporary Redirect", writer.Result().Status)
})

t.Run("successful callback with pre-redirecthook failure", func(t *testing.T) {
mockAuthCtx := setupMockedAuthContextAtEndpoint(localServer.URL)
r := plugins.NewRegistry()
var redirectFunc PreRedirectHookFunc = func(redirectContext context.Context, authCtx interfaces.AuthenticationContext, request *http.Request, responseWriter http.ResponseWriter) *PreRedirectHookError {
return &PreRedirectHookError{
Code: http.StatusPreconditionFailed,
Message: "precondition error",
}
}

r.RegisterDefault(plugins.PluginIDPreRedirectHook, redirectFunc)
callbackHandlerFunc := GetCallbackHandler(ctx, mockAuthCtx, r)
request := httptest.NewRequest("GET", localServer.URL+"/callback", nil)
addCsrfCookie(request)
addStateString(request)
writer := httptest.NewRecorder()
openIDConfigJSON = fmt.Sprintf(`{
"userinfo_endpoint": "%v/userinfo",
"issuer": "%v",
"authorization_endpoint": "%v/auth",
"token_endpoint": "%v/token",
"jwks_uri": "%v/keys",
"id_token_signing_alg_values_supported": ["RS256"]
}`, issuer, issuer, issuer, issuer, issuer)
oidcProvider, err := oidc.NewProvider(ctx, issuer)
assert.Nil(t, err)
mockAuthCtx.OnOidcProviderMatch().Return(oidcProvider)
callbackHandlerFunc(writer, request)
assert.Equal(t, "412 Precondition Failed", writer.Result().Status)
})
}

func TestGetLoginHandler(t *testing.T) {
Expand Down
53 changes: 26 additions & 27 deletions pkg/server/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,42 +9,39 @@ import (
"strings"
"time"

"github.com/gorilla/handlers"
grpcmiddleware "github.com/grpc-ecosystem/go-grpc-middleware"
grpcauth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
grpcprometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
"github.com/grpc-ecosystem/grpc-gateway/runtime"
"github.com/pkg/errors"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/health"
"google.golang.org/grpc/health/grpc_health_v1"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/reflection"
"k8s.io/apimachinery/pkg/util/rand"

"github.com/flyteorg/flytestdlib/contextutils"
"github.com/flyteorg/flytestdlib/promutils/labeled"

runtime2 "github.com/flyteorg/flyteadmin/pkg/runtime"
"github.com/flyteorg/flytestdlib/promutils"
"github.com/flyteorg/flytestdlib/storage"

"github.com/flyteorg/flyteadmin/dataproxy"
"github.com/flyteorg/flyteadmin/plugins"

"github.com/flyteorg/flyteadmin/auth"
"github.com/flyteorg/flyteadmin/auth/authzserver"
authConfig "github.com/flyteorg/flyteadmin/auth/config"
"github.com/flyteorg/flyteadmin/auth/interfaces"
"github.com/flyteorg/flyteadmin/dataproxy"
"github.com/flyteorg/flyteadmin/pkg/common"
"github.com/flyteorg/flyteadmin/pkg/config"
"github.com/flyteorg/flyteadmin/pkg/rpc"
"github.com/flyteorg/flyteadmin/pkg/rpc/adminservice"
runtime2 "github.com/flyteorg/flyteadmin/pkg/runtime"
runtimeIfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces"
"github.com/flyteorg/flyteadmin/plugins"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/secretmanager"
"github.com/flyteorg/flytestdlib/contextutils"
"github.com/flyteorg/flytestdlib/logger"
"github.com/gorilla/handlers"
grpcmiddleware "github.com/grpc-ecosystem/go-grpc-middleware"
grpcauth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
grpcprometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
"github.com/grpc-ecosystem/grpc-gateway/runtime"
"github.com/pkg/errors"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/health"
"google.golang.org/grpc/health/grpc_health_v1"
"google.golang.org/grpc/reflection"
"github.com/flyteorg/flytestdlib/promutils"
"github.com/flyteorg/flytestdlib/promutils/labeled"
"github.com/flyteorg/flytestdlib/storage"
)

var defaultCorsHeaders = []string{"Content-Type"}
Expand Down Expand Up @@ -163,7 +160,7 @@ func healthCheckFunc(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}

func newHTTPServer(ctx context.Context, cfg *config.ServerConfig, _ *authConfig.Config, authCtx interfaces.AuthenticationContext,
func newHTTPServer(ctx context.Context, pluginRegistry *plugins.Registry, cfg *config.ServerConfig, _ *authConfig.Config, authCtx interfaces.AuthenticationContext,
additionalHandlers map[string]func(http.ResponseWriter, *http.Request),
grpcAddress string, grpcConnectionOpts ...grpc.DialOption) (*http.ServeMux, error) {

Expand Down Expand Up @@ -191,7 +188,7 @@ func newHTTPServer(ctx context.Context, cfg *config.ServerConfig, _ *authConfig.

if cfg.Security.UseAuth {
// Add HTTP handlers for OIDC endpoints
auth.RegisterHandlers(ctx, mux, authCtx)
auth.RegisterHandlers(ctx, mux, authCtx, pluginRegistry)

// Add HTTP handlers for OAuth2 endpoints
authzserver.RegisterHandlers(mux, authCtx)
Expand Down Expand Up @@ -278,7 +275,8 @@ func generateRequestID() string {

func serveGatewayInsecure(ctx context.Context, pluginRegistry *plugins.Registry, cfg *config.ServerConfig,
authCfg *authConfig.Config, storageConfig *storage.Config,
additionalHandlers map[string]func(http.ResponseWriter, *http.Request), scope promutils.Scope) error {
additionalHandlers map[string]func(http.ResponseWriter, *http.Request),
scope promutils.Scope) error {
logger.Infof(ctx, "Serving Flyte Admin Insecure")

// This will parse configuration and create the necessary objects for dealing with auth
Expand Down Expand Up @@ -343,7 +341,7 @@ func serveGatewayInsecure(ctx context.Context, pluginRegistry *plugins.Registry,
grpcOptions = append(grpcOptions,
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(cfg.GrpcConfig.MaxMessageSizeBytes)))
}
httpServer, err := newHTTPServer(ctx, cfg, authCfg, authCtx, additionalHandlers, cfg.GetGrpcHostAddress(), grpcOptions...)
httpServer, err := newHTTPServer(ctx, pluginRegistry, cfg, authCfg, authCtx, additionalHandlers, cfg.GetGrpcHostAddress(), grpcOptions...)
if err != nil {
return err
}
Expand Down Expand Up @@ -390,7 +388,8 @@ func grpcHandlerFunc(grpcServer *grpc.Server, otherHandler http.Handler) http.Ha

func serveGatewaySecure(ctx context.Context, pluginRegistry *plugins.Registry, cfg *config.ServerConfig, authCfg *authConfig.Config,
storageCfg *storage.Config,
additionalHandlers map[string]func(http.ResponseWriter, *http.Request), scope promutils.Scope) error {
additionalHandlers map[string]func(http.ResponseWriter, *http.Request),
scope promutils.Scope) error {
certPool, cert, err := GetSslCredentials(ctx, cfg.Security.Ssl.CertificateFile, cfg.Security.Ssl.KeyFile)
if err != nil {
return err
Expand Down Expand Up @@ -445,7 +444,7 @@ func serveGatewaySecure(ctx context.Context, pluginRegistry *plugins.Registry, c
serverOpts = append(serverOpts,
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(cfg.GrpcConfig.MaxMessageSizeBytes)))
}
httpServer, err := newHTTPServer(ctx, cfg, authCfg, authCtx, additionalHandlers, cfg.GetHostAddress(), serverOpts...)
httpServer, err := newHTTPServer(ctx, pluginRegistry, cfg, authCfg, authCtx, additionalHandlers, cfg.GetHostAddress(), serverOpts...)
if err != nil {
return err
}
Expand Down
1 change: 1 addition & 0 deletions plugins/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ const (
PluginIDWorkflowExecutor PluginID = "WorkflowExecutor"
PluginIDDataProxy PluginID = "DataProxy"
PluginIDUnaryServiceMiddleware PluginID = "UnaryServiceMiddleware"
PluginIDPreRedirectHook PluginID = "PreRedirectHook"
)

type AtomicRegistry struct {
Expand Down
20 changes: 20 additions & 0 deletions plugins/registry_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package plugins

import (
"context"
"fmt"
"testing"

"github.com/stretchr/testify/assert"
Expand All @@ -21,6 +23,24 @@ func TestNewAtomicRegistry(t *testing.T) {
assert.Equal(t, 5, r.Get(PluginIDDataProxy))
}

type PreRedirectHookFunc func(ctx context.Context) error

func TestRedirectHook(t *testing.T) {
ar := NewAtomicRegistry(nil)
r := NewRegistry()

var redirectHookfn PreRedirectHookFunc = func(ctx context.Context) error {
return fmt.Errorf("redirect hook error")
}
err := r.Register(PluginIDPreRedirectHook, redirectHookfn)
assert.NoError(t, err)
ar.Store(r)
r = ar.Load()
fn := Get[PreRedirectHookFunc](r, PluginIDPreRedirectHook)
err = fn(context.Background())
assert.Equal(t, fmt.Errorf("redirect hook error"), err)
}

func TestRegistry_RegisterDefault(t *testing.T) {
r := NewRegistry()
r.RegisterDefault("hello", 5)
Expand Down
Loading