Skip to content

Commit

Permalink
Merge pull request #7 from openfort-xyz/feat/pregenerate-wallet
Browse files Browse the repository at this point in the history
Feat/pregenerate wallet
  • Loading branch information
gllm-dev authored May 1, 2024
2 parents 85d1a0f + 70f1581 commit 29493e5
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 28 deletions.
41 changes: 41 additions & 0 deletions internal/infrastructure/authenticationmgr/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,13 @@ package authenticationmgr

import (
"context"
"errors"
"log/slog"
"strings"

"go.openfort.xyz/shield/pkg/contexter"
"go.openfort.xyz/shield/pkg/logger"

"go.openfort.xyz/shield/internal/core/domain"
"go.openfort.xyz/shield/internal/core/domain/provider"
"go.openfort.xyz/shield/internal/core/ports/authentication"
Expand All @@ -16,15 +21,21 @@ type Manager struct {
APISecretAuthenticator authentication.APISecretAuthenticator
UserAuthenticator authentication.UserAuthenticator
repo repositories.ProjectRepository
providerManager *providersmgr.Manager
userService services.UserService
mapOrigins map[string][]string
logger *slog.Logger
}

func NewManager(repo repositories.ProjectRepository, providerManager *providersmgr.Manager, userService services.UserService) *Manager {
return &Manager{
repo: repo,
APISecretAuthenticator: newAPISecretAuthenticator(repo),
providerManager: providerManager,
UserAuthenticator: newUserAuthenticator(repo, providerManager, userService),
userService: userService,
mapOrigins: make(map[string][]string),
logger: logger.New("authentication_manager"),
}
}

Expand All @@ -47,6 +58,36 @@ func (m *Manager) GetAuthProvider(providerStr string) (provider.Type, error) {
}
}

func (m *Manager) PreRegisterUser(ctx context.Context, userID string, providerType provider.Type) (string, error) {
projID := contexter.GetProjectID(ctx)
prov, err := m.providerManager.GetProvider(ctx, projID, providerType)
if err != nil {
m.logger.ErrorContext(ctx, "failed to get provider", logger.Error(err))
return "", err
}

usr, err := m.userService.GetByExternal(ctx, userID, prov.GetProviderID())
if err != nil {
if !errors.Is(err, domain.ErrUserNotFound) && !errors.Is(err, domain.ErrExternalUserNotFound) {
m.logger.ErrorContext(ctx, "failed to get user by external", logger.Error(err))
return "", err
}
usr, err = m.userService.Create(ctx, projID)
if err != nil {
m.logger.ErrorContext(ctx, "failed to create user", logger.Error(err))
return "", err
}

_, err = m.userService.CreateExternal(ctx, projID, usr.ID, userID, prov.GetProviderID())
if err != nil {
m.logger.ErrorContext(ctx, "failed to create external user", logger.Error(err))
return "", err
}
}

return usr.ID, nil
}

func (m *Manager) IsAllowedOrigin(ctx context.Context, apiKey string, origin string) (bool, error) {
if cachedOrigins, cached := m.mapOrigins[apiKey]; cached {
for _, o := range cachedOrigins {
Expand Down
3 changes: 3 additions & 0 deletions internal/infrastructure/handlers/rest/api/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,13 @@ var (
ErrInvalidProviderConfig = &Error{"Invalid provider config", "PV_CFG_INVALID", http.StatusBadRequest}
ErrMissingKeyType = &Error{"Missing key type", "PV_CFG_INVALID", http.StatusBadRequest}
ErrProviderAlreadyExists = &Error{"Custom authentication already registered for this project", "PV_EXISTS", http.StatusConflict}
ErrMissingUserID = &Error{"Missing user ID", "US_ID_MISSING", http.StatusBadRequest}

ErrShareNotFound = &Error{"Share not found", "SH_NOT_FOUND", http.StatusNotFound}
ErrShareAlreadyExists = &Error{"Share already exists", "SH_EXISTS", http.StatusConflict}

ErrPreRegisterUser = &Error{"Failed to pre-register user", "US_PREREG_FAILED", http.StatusInternalServerError}

ErrUserNotFound = &Error{"User not found", "US_NOT_FOUND", http.StatusNotFound}
ErrExternalUserNotFound = &Error{"External user not found", "US_EXT_NOT_FOUND", http.StatusNotFound}
ErrExternalUserAlreadyExists = &Error{"External user already exists", "US_EXT_EXISTS", http.StatusConflict}
Expand Down
32 changes: 32 additions & 0 deletions internal/infrastructure/handlers/rest/authmdw/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ const OpenfortProviderHeader = "X-Openfort-Provider" //nolint:go
const OpenfortTokenTypeHeader = "X-Openfort-Token-Type" //nolint:gosec
const AccessControlAllowOriginHeader = "Access-Control-Allow-Origin" //nolint:gosec
const EncryptionPartHeader = "X-Encryption-Part" //nolint:gosec
const UserIDHeader = "X-User-ID" //nolint:gosec

type Middleware struct {
manager *authenticationmgr.Manager
Expand Down Expand Up @@ -56,6 +57,37 @@ func (m *Middleware) AuthenticateAPISecret(next http.Handler) http.Handler {
})
}

func (m *Middleware) PreRegisterUser(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
userID := r.Header.Get(UserIDHeader)
if userID == "" {
api.RespondWithError(w, api.ErrMissingUserID)
return
}

providerStr := r.Header.Get(AuthProviderHeader)
if providerStr == "" {
api.RespondWithError(w, api.ErrMissingAuthProvider)
return
}

provider, err := m.manager.GetAuthProvider(providerStr)
if err != nil {
api.RespondWithError(w, api.ErrInvalidAuthProvider)
return
}

usr, err := m.manager.PreRegisterUser(r.Context(), userID, provider)
if err != nil {
api.RespondWithError(w, api.ErrPreRegisterUser)
return
}

ctx := contexter.WithUserID(r.Context(), usr)
next.ServeHTTP(w, r.WithContext(ctx))
})
}

func (m *Middleware) AuthenticateUser(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
apiKey := r.Header.Get(APIKeyHeader)
Expand Down
5 changes: 5 additions & 0 deletions internal/infrastructure/handlers/rest/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ func (s *Server) Start(ctx context.Context) error {
u.HandleFunc("", shareHdl.RegisterShare).Methods(http.MethodPost)
u.HandleFunc("", shareHdl.DeleteShare).Methods(http.MethodDelete)

a := r.PathPrefix("/admin").Subrouter()
a.Use(authMdw.AuthenticateAPISecret)
a.Use(authMdw.PreRegisterUser)
a.HandleFunc("/preregister", shareHdl.RegisterShare).Methods(http.MethodPost)

extraHeaders := strings.Split(s.config.CORSExtraAllowedHeaders, ",")
c := cors.New(cors.Options{
AllowOriginRequestFunc: authMdw.AllowedOrigin,
Expand Down
26 changes: 0 additions & 26 deletions pkg/contexter/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,32 +28,6 @@ func GetProjectID(ctx context.Context) string {
return projectID
}

func WithAPIKey(ctx context.Context, apiKey string) context.Context {
return context.WithValue(ctx, ContextKeyAPIKey, apiKey)
}

func GetAPIKey(ctx context.Context) string {
apiKey, ok := ctx.Value(ContextKeyAPIKey).(string)
if !ok {
return ""
}

return apiKey
}

func WithAPISecret(ctx context.Context, apiSecret string) context.Context {
return context.WithValue(ctx, ContextKeyAPISecret, apiSecret)
}

func GetAPISecret(ctx context.Context) string {
apiSecret, ok := ctx.Value(ContextKeyAPISecret).(string)
if !ok {
return ""
}

return apiSecret
}

func WithUserID(ctx context.Context, userID string) context.Context {
return context.WithValue(ctx, ContextKeyUserID, userID)
}
Expand Down
15 changes: 13 additions & 2 deletions pkg/logger/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,31 +8,40 @@ import (
"go.openfort.xyz/shield/pkg/contexter"
)

// New creates a new standard logger with a context handler.
func New(name string) *slog.Logger {
return slog.New(NewContextHandler(slog.NewTextHandler(os.Stdout, nil))).WithGroup(name)
return slog.New(NewContextHandler(name, slog.NewTextHandler(os.Stdout, nil)))
}

// Error returns an attribute for an error string value.
func Error(err error) slog.Attr {
return slog.String("error", err.Error())
}

// ContextHandler is a logger handler that adds context attributes to log records.
type ContextHandler struct {
name string
baseHandler slog.Handler
}

func NewContextHandler(baseHandler slog.Handler) *ContextHandler {
// NewContextHandler creates a new context handler.
func NewContextHandler(name string, baseHandler slog.Handler) *ContextHandler {
return &ContextHandler{
name: name,
baseHandler: baseHandler,
}
}

var _ slog.Handler = (*ContextHandler)(nil)

// Enabled wraps the base handler's Enabled method.
func (c *ContextHandler) Enabled(ctx context.Context, level slog.Level) bool {
return c.baseHandler.Enabled(ctx, level)
}

// Handle warps the base handler's Handle method and adds context attributes to the log record.
func (c *ContextHandler) Handle(ctx context.Context, record slog.Record) error {
record.Add(slog.String("logger", c.name))
if projID := contexter.GetProjectID(ctx); projID != "" {
record.Add(slog.String(ProjectID, projID))
}
Expand All @@ -44,10 +53,12 @@ func (c *ContextHandler) Handle(ctx context.Context, record slog.Record) error {
return c.baseHandler.Handle(ctx, record)
}

// WithAttrs wraps the base handler's WithAttrs method.
func (c *ContextHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
return c.baseHandler.WithAttrs(attrs)
}

// WithGroup wraps the base handler's WithGroup method.
func (c *ContextHandler) WithGroup(name string) slog.Handler {
return c.baseHandler.WithGroup(name)
}

0 comments on commit 29493e5

Please sign in to comment.