From 2280ccd6cd4ae8fa8d1751dd5fa80e45ad8a880b Mon Sep 17 00:00:00 2001 From: gllm-dev Date: Wed, 1 May 2024 11:12:44 +0200 Subject: [PATCH 1/2] feat: add pre register user endpoint --- .../authenticationmgr/manager.go | 40 +++++++++++++++++++ .../handlers/rest/api/errors.go | 3 ++ .../handlers/rest/authmdw/middleware.go | 32 +++++++++++++++ .../infrastructure/handlers/rest/server.go | 5 +++ pkg/contexter/context.go | 26 ------------ pkg/logger/handler.go | 15 ++++++- 6 files changed, 93 insertions(+), 28 deletions(-) diff --git a/internal/infrastructure/authenticationmgr/manager.go b/internal/infrastructure/authenticationmgr/manager.go index e96a40f..623e7be 100644 --- a/internal/infrastructure/authenticationmgr/manager.go +++ b/internal/infrastructure/authenticationmgr/manager.go @@ -2,6 +2,10 @@ package authenticationmgr import ( "context" + "errors" + "go.openfort.xyz/shield/pkg/contexter" + "go.openfort.xyz/shield/pkg/logger" + "log/slog" "strings" "go.openfort.xyz/shield/internal/core/domain" @@ -16,15 +20,21 @@ type Manager struct { APISecretAuthenticator authentication.APISecretAuthenticator UserAuthenticator authentication.UserAuthenticator repo repositories.ProjectRepository + providerManager *providersmgr.Manager + userService services.UserService mapOrigins map[string][]string + logger *slog.Logger } func NewManager(repo repositories.ProjectRepository, providerManager *providersmgr.Manager, userService services.UserService) *Manager { return &Manager{ repo: repo, APISecretAuthenticator: newAPISecretAuthenticator(repo), + providerManager: providerManager, UserAuthenticator: newUserAuthenticator(repo, providerManager, userService), + userService: userService, mapOrigins: make(map[string][]string), + logger: logger.New("authentication_manager"), } } @@ -47,6 +57,36 @@ func (m *Manager) GetAuthProvider(providerStr string) (provider.Type, error) { } } +func (m *Manager) PreRegisterUser(ctx context.Context, userID string, providerType provider.Type) (string, error) { + projID := contexter.GetProjectID(ctx) + prov, err := m.providerManager.GetProvider(ctx, projID, providerType) + if err != nil { + m.logger.ErrorContext(ctx, "failed to get provider", logger.Error(err)) + return "", err + } + + usr, err := m.userService.GetByExternal(ctx, userID, prov.GetProviderID()) + if err != nil { + if !errors.Is(err, domain.ErrUserNotFound) && !errors.Is(err, domain.ErrExternalUserNotFound) { + m.logger.ErrorContext(ctx, "failed to get user by external", logger.Error(err)) + return "", err + } + usr, err = m.userService.Create(ctx, projID) + if err != nil { + m.logger.ErrorContext(ctx, "failed to create user", logger.Error(err)) + return "", err + } + + _, err = m.userService.CreateExternal(ctx, projID, usr.ID, userID, prov.GetProviderID()) + if err != nil { + m.logger.ErrorContext(ctx, "failed to create external user", logger.Error(err)) + return "", err + } + } + + return usr.ID, nil +} + func (m *Manager) IsAllowedOrigin(ctx context.Context, apiKey string, origin string) (bool, error) { if cachedOrigins, cached := m.mapOrigins[apiKey]; cached { for _, o := range cachedOrigins { diff --git a/internal/infrastructure/handlers/rest/api/errors.go b/internal/infrastructure/handlers/rest/api/errors.go index 3c98e74..b542dbd 100644 --- a/internal/infrastructure/handlers/rest/api/errors.go +++ b/internal/infrastructure/handlers/rest/api/errors.go @@ -30,10 +30,13 @@ var ( ErrInvalidProviderConfig = &Error{"Invalid provider config", "PV_CFG_INVALID", http.StatusBadRequest} ErrMissingKeyType = &Error{"Missing key type", "PV_CFG_INVALID", http.StatusBadRequest} ErrProviderAlreadyExists = &Error{"Custom authentication already registered for this project", "PV_EXISTS", http.StatusConflict} + ErrMissingUserID = &Error{"Missing user ID", "US_ID_MISSING", http.StatusBadRequest} ErrShareNotFound = &Error{"Share not found", "SH_NOT_FOUND", http.StatusNotFound} ErrShareAlreadyExists = &Error{"Share already exists", "SH_EXISTS", http.StatusConflict} + ErrPreRegisterUser = &Error{"Failed to pre-register user", "US_PREREG_FAILED", http.StatusInternalServerError} + ErrUserNotFound = &Error{"User not found", "US_NOT_FOUND", http.StatusNotFound} ErrExternalUserNotFound = &Error{"External user not found", "US_EXT_NOT_FOUND", http.StatusNotFound} ErrExternalUserAlreadyExists = &Error{"External user already exists", "US_EXT_EXISTS", http.StatusConflict} diff --git a/internal/infrastructure/handlers/rest/authmdw/middleware.go b/internal/infrastructure/handlers/rest/authmdw/middleware.go index 0be500f..e7b92a7 100644 --- a/internal/infrastructure/handlers/rest/authmdw/middleware.go +++ b/internal/infrastructure/handlers/rest/authmdw/middleware.go @@ -20,6 +20,7 @@ const OpenfortProviderHeader = "X-Openfort-Provider" //nolint:go const OpenfortTokenTypeHeader = "X-Openfort-Token-Type" //nolint:gosec const AccessControlAllowOriginHeader = "Access-Control-Allow-Origin" //nolint:gosec const EncryptionPartHeader = "X-Encryption-Part" //nolint:gosec +const UserIDHeader = "X-User-ID" //nolint:gosec type Middleware struct { manager *authenticationmgr.Manager @@ -56,6 +57,37 @@ func (m *Middleware) AuthenticateAPISecret(next http.Handler) http.Handler { }) } +func (m *Middleware) PreRegisterUser(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + userID := r.Header.Get(UserIDHeader) + if userID == "" { + api.RespondWithError(w, api.ErrMissingUserID) + return + } + + providerStr := r.Header.Get(AuthProviderHeader) + if providerStr == "" { + api.RespondWithError(w, api.ErrMissingAuthProvider) + return + } + + provider, err := m.manager.GetAuthProvider(providerStr) + if err != nil { + api.RespondWithError(w, api.ErrInvalidAuthProvider) + return + } + + usr, err := m.manager.PreRegisterUser(r.Context(), userID, provider) + if err != nil { + api.RespondWithError(w, api.ErrPreRegisterUser) + return + } + + ctx := contexter.WithUserID(r.Context(), usr) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + func (m *Middleware) AuthenticateUser(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { apiKey := r.Header.Get(APIKeyHeader) diff --git a/internal/infrastructure/handlers/rest/server.go b/internal/infrastructure/handlers/rest/server.go index 609dba5..513806a 100644 --- a/internal/infrastructure/handlers/rest/server.go +++ b/internal/infrastructure/handlers/rest/server.go @@ -75,6 +75,11 @@ func (s *Server) Start(ctx context.Context) error { u.HandleFunc("", shareHdl.RegisterShare).Methods(http.MethodPost) u.HandleFunc("", shareHdl.DeleteShare).Methods(http.MethodDelete) + a := r.PathPrefix("/admin").Subrouter() + a.Use(authMdw.AuthenticateAPISecret) + a.Use(authMdw.PreRegisterUser) + a.HandleFunc("/preregister", shareHdl.RegisterShare).Methods(http.MethodPost) + extraHeaders := strings.Split(s.config.CORSExtraAllowedHeaders, ",") c := cors.New(cors.Options{ AllowOriginRequestFunc: authMdw.AllowedOrigin, diff --git a/pkg/contexter/context.go b/pkg/contexter/context.go index 2b99b5f..1c0c1a3 100644 --- a/pkg/contexter/context.go +++ b/pkg/contexter/context.go @@ -28,32 +28,6 @@ func GetProjectID(ctx context.Context) string { return projectID } -func WithAPIKey(ctx context.Context, apiKey string) context.Context { - return context.WithValue(ctx, ContextKeyAPIKey, apiKey) -} - -func GetAPIKey(ctx context.Context) string { - apiKey, ok := ctx.Value(ContextKeyAPIKey).(string) - if !ok { - return "" - } - - return apiKey -} - -func WithAPISecret(ctx context.Context, apiSecret string) context.Context { - return context.WithValue(ctx, ContextKeyAPISecret, apiSecret) -} - -func GetAPISecret(ctx context.Context) string { - apiSecret, ok := ctx.Value(ContextKeyAPISecret).(string) - if !ok { - return "" - } - - return apiSecret -} - func WithUserID(ctx context.Context, userID string) context.Context { return context.WithValue(ctx, ContextKeyUserID, userID) } diff --git a/pkg/logger/handler.go b/pkg/logger/handler.go index 336c5e5..f601064 100644 --- a/pkg/logger/handler.go +++ b/pkg/logger/handler.go @@ -8,31 +8,40 @@ import ( "go.openfort.xyz/shield/pkg/contexter" ) +// New creates a new standard logger with a context handler. func New(name string) *slog.Logger { - return slog.New(NewContextHandler(slog.NewTextHandler(os.Stdout, nil))).WithGroup(name) + return slog.New(NewContextHandler(name, slog.NewTextHandler(os.Stdout, nil))) } +// Error returns an attribute for an error string value. func Error(err error) slog.Attr { return slog.String("error", err.Error()) } +// ContextHandler is a logger handler that adds context attributes to log records. type ContextHandler struct { + name string baseHandler slog.Handler } -func NewContextHandler(baseHandler slog.Handler) *ContextHandler { +// NewContextHandler creates a new context handler. +func NewContextHandler(name string, baseHandler slog.Handler) *ContextHandler { return &ContextHandler{ + name: name, baseHandler: baseHandler, } } var _ slog.Handler = (*ContextHandler)(nil) +// Enabled wraps the base handler's Enabled method. func (c *ContextHandler) Enabled(ctx context.Context, level slog.Level) bool { return c.baseHandler.Enabled(ctx, level) } +// Handle warps the base handler's Handle method and adds context attributes to the log record. func (c *ContextHandler) Handle(ctx context.Context, record slog.Record) error { + record.Add(slog.String("logger", c.name)) if projID := contexter.GetProjectID(ctx); projID != "" { record.Add(slog.String(ProjectID, projID)) } @@ -44,10 +53,12 @@ func (c *ContextHandler) Handle(ctx context.Context, record slog.Record) error { return c.baseHandler.Handle(ctx, record) } +// WithAttrs wraps the base handler's WithAttrs method. func (c *ContextHandler) WithAttrs(attrs []slog.Attr) slog.Handler { return c.baseHandler.WithAttrs(attrs) } +// WithGroup wraps the base handler's WithGroup method. func (c *ContextHandler) WithGroup(name string) slog.Handler { return c.baseHandler.WithGroup(name) } From 70f158140d0f862ec67eb7159144ecf3a2c4fe10 Mon Sep 17 00:00:00 2001 From: gllm-dev Date: Wed, 1 May 2024 11:13:08 +0200 Subject: [PATCH 2/2] fix: linter --- internal/infrastructure/authenticationmgr/manager.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/internal/infrastructure/authenticationmgr/manager.go b/internal/infrastructure/authenticationmgr/manager.go index 623e7be..036c424 100644 --- a/internal/infrastructure/authenticationmgr/manager.go +++ b/internal/infrastructure/authenticationmgr/manager.go @@ -3,11 +3,12 @@ package authenticationmgr import ( "context" "errors" - "go.openfort.xyz/shield/pkg/contexter" - "go.openfort.xyz/shield/pkg/logger" "log/slog" "strings" + "go.openfort.xyz/shield/pkg/contexter" + "go.openfort.xyz/shield/pkg/logger" + "go.openfort.xyz/shield/internal/core/domain" "go.openfort.xyz/shield/internal/core/domain/provider" "go.openfort.xyz/shield/internal/core/ports/authentication"