diff --git a/api/api.go b/api/api.go index 14f6d2b48a..c1bc130d21 100644 --- a/api/api.go +++ b/api/api.go @@ -269,15 +269,6 @@ func (a *ApplicationHandler) BuildControlPlaneRoutes() *chi.Mux { uiRouter.Route("/god-mode/configs", func(godModeRouter chi.Router) { godModeRouter.Use(middleware.RequireInstanceAdmin(handler.A)) - godModeRouter.Route("/defaults", func(defaultsRouter chi.Router) { - defaultsRouter.With(middleware.Pagination).Get("/", handler.GetInstanceDefaultsPaged) - defaultsRouter.Post("/", handler.CreateInstanceDefaults) - defaultsRouter.Route("/{configID}", func(configSubRouter chi.Router) { - configSubRouter.Get("/", handler.GetInstanceDefaults) - configSubRouter.Put("/", handler.UpdateInstanceDefaults) - }) - }) - godModeRouter.Route("/overrides", func(overridesRouter chi.Router) { overridesRouter.With(middleware.Pagination).Get("/", handler.GetInstanceOverridesPaged) overridesRouter.Post("/", handler.CreateInstanceOverrides) diff --git a/api/handlers/instance_defaults.go b/api/handlers/instance_defaults.go deleted file mode 100644 index 4d243ace64..0000000000 --- a/api/handlers/instance_defaults.go +++ /dev/null @@ -1,105 +0,0 @@ -package handlers - -import ( - "github.com/frain-dev/convoy/api/models" - "github.com/frain-dev/convoy/database/postgres" - "github.com/frain-dev/convoy/datastore" - m "github.com/frain-dev/convoy/internal/pkg/middleware" - "github.com/go-chi/chi/v5" - "net/http" - - "github.com/frain-dev/convoy/pkg/log" - - "github.com/frain-dev/convoy/util" - "github.com/go-chi/render" -) - -func (h *Handler) CreateInstanceDefaults(w http.ResponseWriter, r *http.Request) { - var instanceDefaults datastore.InstanceDefaults - err := util.ReadJSON(r, &instanceDefaults) - if err != nil { - _ = render.Render(w, r, util.NewErrorResponse(err.Error(), http.StatusBadRequest)) - return - } - - defaultsRepo := postgres.NewInstanceDefaultsRepo(h.A.DB) - - var result *datastore.InstanceDefaults - if result, err = defaultsRepo.Create(r.Context(), &instanceDefaults); err != nil { - _ = render.Render(w, r, util.NewErrorResponse(err.Error(), http.StatusUnprocessableEntity)) - return - } - - _ = render.Render(w, r, util.NewServerResponse("Instance default created successfully", result, http.StatusOK)) -} - -func (h *Handler) UpdateInstanceDefaults(w http.ResponseWriter, r *http.Request) { - var instanceDefaults datastore.InstanceDefaults - err := util.ReadJSON(r, &instanceDefaults) - if err != nil { - _ = render.Render(w, r, util.NewErrorResponse(err.Error(), http.StatusBadRequest)) - return - } - defaults, err := h.retrieveInstanceDefaults(r) - if err != nil { - _ = render.Render(w, r, util.NewErrorResponse(err.Error(), http.StatusBadRequest)) - return - } - - defaultsRepo := postgres.NewInstanceDefaultsRepo(h.A.DB) - - var result *datastore.InstanceDefaults - if result, err = defaultsRepo.Update(r.Context(), defaults.UID, &instanceDefaults); err != nil { - _ = render.Render(w, r, util.NewErrorResponse(err.Error(), http.StatusUnprocessableEntity)) - return - } - - _ = render.Render(w, r, util.NewServerResponse("Instance default updated successfully", result, http.StatusOK)) -} - -func (h *Handler) GetInstanceDefaults(w http.ResponseWriter, r *http.Request) { - var instanceDefaults datastore.InstanceDefaults - err := util.ReadJSON(r, &instanceDefaults) - if err != nil { - _ = render.Render(w, r, util.NewErrorResponse(err.Error(), http.StatusBadRequest)) - return - } - - defaults, err := h.retrieveInstanceDefaults(r) - if err != nil { - _ = render.Render(w, r, util.NewErrorResponse(err.Error(), http.StatusBadRequest)) - return - } - - _ = render.Render(w, r, util.NewServerResponse("Instance default fetched successfully", defaults, http.StatusOK)) -} - -func (h *Handler) retrieveInstanceDefaults(r *http.Request) (*datastore.InstanceDefaults, error) { - id := chi.URLParam(r, "configID") - - if util.IsStringEmpty(id) { - id = r.URL.Query().Get("configID") - } - - defaultsRepo := postgres.NewInstanceDefaultsRepo(h.A.DB) - return defaultsRepo.FetchByID(r.Context(), id) -} - -func (h *Handler) GetInstanceDefaultsPaged(w http.ResponseWriter, r *http.Request) { - pageable := m.GetPageableFromContext(r.Context()) - if pageable.PrevCursor == "" { - pageable.NextCursor = "0" - } - - defaultsRepo := postgres.NewInstanceDefaultsRepo(h.A.DB) - - defaults, paginationData, err := defaultsRepo.LoadPaged(r.Context(), pageable) - if err != nil { - log.FromContext(r.Context()).WithError(err).Error("failed to fetch instance defaults") - _ = render.Render(w, r, util.NewServiceErrResponse(err)) - return - } - - _ = render.Render(w, r, util.NewServerResponse("Data fetched successfully", - models.PagedResponse{Content: &defaults, Pagination: &paginationData}, http.StatusOK)) -} diff --git a/api/handlers/organisation.go b/api/handlers/organisation.go index 1eafbf95d5..d6873d5779 100644 --- a/api/handlers/organisation.go +++ b/api/handlers/organisation.go @@ -68,11 +68,12 @@ func (h *Handler) CreateOrganisation(w http.ResponseWriter, r *http.Request) { } co := services.CreateOrganisationService{ - OrgRepo: postgres.NewOrgRepo(h.A.DB), - OrgMemberRepo: postgres.NewOrgMemberRepo(h.A.DB), - NewOrg: &newOrg, - User: user, - Licenser: h.A.Licenser, + OrgRepo: postgres.NewOrgRepo(h.A.DB), + OrgMemberRepo: postgres.NewOrgMemberRepo(h.A.DB), + InstanceOverridesRepo: postgres.NewInstanceOverridesRepo(h.A.DB), + NewOrg: &newOrg, + User: user, + Licenser: h.A.Licenser, } organisation, err := co.Run(r.Context()) @@ -104,10 +105,11 @@ func (h *Handler) UpdateOrganisation(w http.ResponseWriter, r *http.Request) { } us := services.UpdateOrganisationService{ - OrgRepo: postgres.NewOrgRepo(h.A.DB), - OrgMemberRepo: postgres.NewOrgMemberRepo(h.A.DB), - Org: org, - Update: &orgUpdate, + OrgRepo: postgres.NewOrgRepo(h.A.DB), + OrgMemberRepo: postgres.NewOrgMemberRepo(h.A.DB), + InstanceOverridesRepo: postgres.NewInstanceOverridesRepo(h.A.DB), + Org: org, + Update: &orgUpdate, } org, err = us.Run(r.Context()) diff --git a/api/handlers/user.go b/api/handlers/user.go index fddbcc0c7d..eb0a526a2c 100644 --- a/api/handlers/user.go +++ b/api/handlers/user.go @@ -42,13 +42,14 @@ func (h *Handler) RegisterUser(w http.ResponseWriter, r *http.Request) { } rs := services.RegisterUserService{ - UserRepo: postgres.NewUserRepo(h.A.DB), - OrgRepo: postgres.NewOrgRepo(h.A.DB), - OrgMemberRepo: postgres.NewOrgMemberRepo(h.A.DB), - Queue: h.A.Queue, - JWT: jwt.NewJwt(&config.Auth.Jwt, h.A.Cache), - ConfigRepo: postgres.NewConfigRepo(h.A.DB), - Licenser: h.A.Licenser, + UserRepo: postgres.NewUserRepo(h.A.DB), + OrgRepo: postgres.NewOrgRepo(h.A.DB), + OrgMemberRepo: postgres.NewOrgMemberRepo(h.A.DB), + InstanceOverridesRepo: postgres.NewInstanceOverridesRepo(h.A.DB), + Queue: h.A.Queue, + JWT: jwt.NewJwt(&config.Auth.Jwt, h.A.Cache), + ConfigRepo: postgres.NewConfigRepo(h.A.DB), + Licenser: h.A.Licenser, BaseURL: baseUrl, Data: &newUser, diff --git a/api/ingest_cfg_integration_test.go b/api/ingest_cfg_integration_test.go index a1c0fdd00f..6127846a29 100644 --- a/api/ingest_cfg_integration_test.go +++ b/api/ingest_cfg_integration_test.go @@ -41,27 +41,8 @@ func TestIngestCfg_GetInstanceRateLimit(t *testing.T) { ctx := context.Background() - instanceDefaultsRepo := postgres.NewInstanceDefaultsRepo(db) instanceOverridesRepo := postgres.NewInstanceOverridesRepo(db) - t.Run("Default Found", func(t *testing.T) { - _, err := instanceDefaultsRepo.Create(ctx, &datastore.InstanceDefaults{ - UID: "default1", - ScopeType: instance.OrganisationScope, - Key: instance.KeyInstanceIngestRate, - DefaultValue: "{\"value\": 150}", - }) - assert.NoError(t, err) - - cacheKey := fmt.Sprintf("rate_limit:%s:%s:%s", instance.KeyInstanceIngestRate, projectID, organisationID) - err = memoryCache.Delete(ctx, cacheKey) - require.NoError(t, err) - - rateLimit, err := ingestCfg.GetInstanceRateLimitWithCache(ctx) - assert.NoError(t, err) - assert.Equal(t, 150, rateLimit) - }) - t.Run("Override Found", func(t *testing.T) { _, err := instanceOverridesRepo.Create(ctx, &datastore.InstanceOverrides{ UID: "override1", @@ -84,8 +65,6 @@ func TestIngestCfg_GetInstanceRateLimit(t *testing.T) { t.Run("Fallback to Default Rate", func(t *testing.T) { _, err := db.GetDB().ExecContext(ctx, `DELETE FROM convoy.instance_overrides WHERE key = $1`, instance.KeyInstanceIngestRate) assert.NoError(t, err) - _, err = db.GetDB().ExecContext(ctx, `DELETE FROM convoy.instance_defaults WHERE key = $1`, instance.KeyInstanceIngestRate) - assert.NoError(t, err) cacheKey := fmt.Sprintf("rate_limit:%s:%s:%s", instance.KeyInstanceIngestRate, projectID, organisationID) err = memoryCache.Delete(ctx, cacheKey) diff --git a/api/models/models.go b/api/models/models.go index b9604530af..d333559fda 100644 --- a/api/models/models.go +++ b/api/models/models.go @@ -2,6 +2,7 @@ package models import ( "encoding/json" + "github.com/frain-dev/convoy/config" "time" "github.com/frain-dev/convoy/auth" @@ -15,8 +16,20 @@ type PagedResponse struct { } type Organisation struct { - Name string `json:"name" bson:"name"` - CustomDomain string `json:"custom_domain" bson:"custom_domain"` + Name string `json:"name" bson:"name"` + CustomDomain string `json:"custom_domain" bson:"custom_domain"` + Config *Config `json:"config" bson:"config"` +} + +type Config struct { + StaticIP *bool `json:"static_ip" bson:"static_ip"` + EnterpriseSSO *bool `json:"enterprise_sso" bson:"enterprise_sso"` + ProjectConfig *Project `json:"project" bson:"project"` +} + +type Project struct { + RetentionPolicy *config.RetentionPolicyConfiguration `json:"retention_policy" bson:"retention_policy"` + IngestRateLimit *int `json:"ingest_rate_limit" bson:"ingest_rate_limit"` } type OrganisationInvite struct { diff --git a/api/retention_cfg_integration_test.go b/api/retention_cfg_integration_test.go index 7de09a15e0..98a4323f08 100644 --- a/api/retention_cfg_integration_test.go +++ b/api/retention_cfg_integration_test.go @@ -37,26 +37,8 @@ func TestRetentionCfg_GetRetentionPolicy(t *testing.T) { ctx := context.Background() - instanceDefaultsRepo := postgres.NewInstanceDefaultsRepo(db) instanceOverridesRepo := postgres.NewInstanceOverridesRepo(db) - t.Run("Default Found", func(t *testing.T) { - _, err := instanceDefaultsRepo.Create(ctx, &datastore.InstanceDefaults{ - UID: "default2", - ScopeType: instance.OrganisationScope, - Key: instance.KeyRetentionPolicy, - DefaultValue: "{\"policy\": \"36h\", \"enabled\": false}", - }) - assert.NoError(t, err) - - // Fetch the retention policy - retentionPolicy, err := retentionCfg.GetRetentionPolicy(ctx) - assert.NoError(t, err) - d, err := time.ParseDuration("36h") - require.NoError(t, err) - assert.Equal(t, d, retentionPolicy) - }) - t.Run("Override Found", func(t *testing.T) { _, err := instanceOverridesRepo.Create(ctx, &datastore.InstanceOverrides{ UID: "override2", @@ -77,8 +59,6 @@ func TestRetentionCfg_GetRetentionPolicy(t *testing.T) { t.Run("Fallback to Default Policy", func(t *testing.T) { _, err := db.GetDB().ExecContext(ctx, `DELETE FROM convoy.instance_overrides WHERE key = $1`, instance.KeyRetentionPolicy) assert.NoError(t, err) - _, err = db.GetDB().ExecContext(ctx, `DELETE FROM convoy.instance_defaults WHERE key = $1`, instance.KeyRetentionPolicy) - assert.NoError(t, err) retentionPolicy, err := retentionCfg.GetRetentionPolicy(ctx) assert.NoError(t, err) diff --git a/api/testdb/seed.go b/api/testdb/seed.go index 2ba394335e..bb1e2094af 100644 --- a/api/testdb/seed.go +++ b/api/testdb/seed.go @@ -725,7 +725,6 @@ func truncateTables(db database.Database) error { convoy.organisations, convoy.users, convoy.jobs, - convoy.instance_defaults, convoy.instance_overrides ` diff --git a/cmd/bootstrap/bootstrap.go b/cmd/bootstrap/bootstrap.go index 14bac4ebc5..d6bf5a7516 100644 --- a/cmd/bootstrap/bootstrap.go +++ b/cmd/bootstrap/bootstrap.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "github.com/frain-dev/convoy/auth" "time" "github.com/frain-dev/convoy/api/models" @@ -20,11 +21,16 @@ import ( "github.com/spf13/cobra" ) +var ( + ErrInstanceAdminOrRootRequired = errors.New("an instance admin or a root user is required") +) + func AddBootstrapCommand(a *cli.App) *cobra.Command { var firstName string var lastName string var format string var email string + var token string cmd := &cobra.Command{ Use: "bootstrap", @@ -33,7 +39,34 @@ func AddBootstrapCommand(a *cli.App) *cobra.Command { "ShouldBootstrap": "false", }, RunE: func(cmd *cobra.Command, args []string) error { - return runBootstrap(a, format, email, firstName, lastName) + + orgMemberRepo := postgres.NewOrgMemberRepo(a.DB) + + count, err := orgMemberRepo.CountSuperUsers(context.Background()) + if err != nil { + return fmt.Errorf("failed to count org admins: %w", err) + } + + if count > 0 { + // org admin exists + if token == "" { + return fmt.Errorf("an access token required to proceed") + } + authUser, member, err := getInstanceAdminOrRoot(a, token) + if err != nil { + log.WithError(err).Warn("failed to get instance admin or root") + return fmt.Errorf("failed to get instance admin or root: %w", err) + } + if authUser == nil || member == nil { + return ErrInstanceAdminOrRootRequired + } + + if member.Role.Type != auth.RoleRoot && member.Role.Type != auth.RoleInstanceAdmin { + return fmt.Errorf("invalid role %+v", authUser.Role.Type) + } + } + + return runBootstrap(a, format, email, firstName, lastName, auth.RoleOrganisationAdmin) }, } @@ -41,11 +74,12 @@ func AddBootstrapCommand(a *cli.App) *cobra.Command { cmd.Flags().StringVar(&firstName, "first-name", "admin", "Email") cmd.Flags().StringVar(&lastName, "last-name", "admin", "Email") cmd.Flags().StringVar(&format, "format", "json", "Output Format") + cmd.Flags().StringVar(&token, "token", "", "Root Personal Access Token") return cmd } -func runBootstrap(a *cli.App, format string, email string, firstName string, lastName string) error { +func runBootstrap(a *cli.App, format string, email string, firstName string, lastName string, roleType auth.RoleType) error { ok, err := a.Licenser.CreateUser(context.Background()) if err != nil { return err @@ -97,10 +131,13 @@ func runBootstrap(a *cli.App, format string, email string, firstName string, las } co := services.CreateOrganisationService{ - OrgRepo: postgres.NewOrgRepo(a.DB), - OrgMemberRepo: postgres.NewOrgMemberRepo(a.DB), - NewOrg: &models.Organisation{Name: "Default Organisation"}, - User: user, + OrgRepo: postgres.NewOrgRepo(a.DB), + OrgMemberRepo: postgres.NewOrgMemberRepo(a.DB), + InstanceOverridesRepo: postgres.NewInstanceOverridesRepo(a.DB), + NewOrg: &models.Organisation{Name: "Default Organisation"}, + User: user, + Licenser: a.Licenser, + RoleType: roleType, } _, err = co.Run(context.Background()) diff --git a/cmd/bootstrap/provision_instance_admin.go b/cmd/bootstrap/provision_instance_admin.go new file mode 100644 index 0000000000..2a6a9178f7 --- /dev/null +++ b/cmd/bootstrap/provision_instance_admin.go @@ -0,0 +1,129 @@ +package bootstrap + +import ( + "context" + "errors" + "fmt" + "github.com/frain-dev/convoy/auth" + "github.com/frain-dev/convoy/auth/realm_chain" + "github.com/frain-dev/convoy/config" + "github.com/frain-dev/convoy/database/postgres" + "github.com/frain-dev/convoy/datastore" + "github.com/frain-dev/convoy/internal/pkg/cli" + "github.com/frain-dev/convoy/internal/pkg/keys" + "github.com/spf13/cobra" +) + +var ( + ErrRootRequired = errors.New("a root user is required") +) + +func AddProvisionIACommand(a *cli.App) *cobra.Command { + var firstName string + var lastName string + var format string + var email string + var token string + + cmd := &cobra.Command{ + Use: "provision-instance-admin", + Short: "creates a new instance admin user account", + Annotations: map[string]string{ + "ShouldBootstrap": "false", + }, + RunE: func(_ *cobra.Command, args []string) error { + + if token == "" { + return fmt.Errorf("token required") + } + authUser, member, err := getInstanceAdminOrRoot(a, token) + if err != nil { + return err + } + + if member.Role.Type != auth.RoleRoot { + return fmt.Errorf("invalid role %+v", authUser.Role.Type) + } + + if authUser == nil || member == nil { + return ErrRootRequired + } + + return runBootstrap(a, format, email, firstName, lastName, auth.RoleInstanceAdmin) + }, + } + + cmd.Flags().StringVar(&email, "email", "", "Email") + cmd.Flags().StringVar(&firstName, "first-name", "instance-admin", "Email") + cmd.Flags().StringVar(&lastName, "last-name", "admin", "Email") + cmd.Flags().StringVar(&format, "format", "json", "Output Format") + cmd.Flags().StringVar(&token, "token", "", "Root Personal Access Token") + + return cmd +} + +func getInstanceAdminOrRoot(a *cli.App, token string) (*auth.AuthenticatedUser, *datastore.OrganisationMember, error) { + err := initialize(a) + if err != nil { + return nil, nil, err + } + + rc, err := realm_chain.Get() + if err != nil { + return nil, nil, err + } + authUser, err := rc.Authenticate(context.Background(), &auth.Credential{ + Type: auth.CredentialTypeAPIKey, + APIKey: token, + Token: token, + }) + if err != nil { + return nil, nil, fmt.Errorf("authorization failed %w", err) + } + + user, ok := authUser.Metadata.(*datastore.User) + if !ok { + return nil, nil, fmt.Errorf("authorization failed %w", err) + } + + orgMemberRepo := postgres.NewOrgMemberRepo(a.DB) + m, err := orgMemberRepo.FetchAnyInstanceAdminOrRootByUserID(context.Background(), user.UID) + if err != nil { + if errors.Is(err, datastore.ErrOrgMemberNotFound) { + return nil, nil, fmt.Errorf("root user not found %w", err) + } + return nil, nil, err + } + return authUser, m, nil +} + +func initialize(a *cli.App) error { + cfg, err := config.Get() + if err != nil { + return err + } + + km := keys.NewHCPVaultKeyManagerFromConfig(cfg.HCPVault, a.Licenser, a.Cache) + if km.IsSet() { + if _, err := km.GetCurrentKeyFromCache(); err != nil { + if !errors.Is(err, keys.ErrCredentialEncryptionFeatureUnavailable) { + return err + } + km.Unset() + } + } + if err := keys.Set(km); err != nil { + return err + } + + apiKeyRepo := postgres.NewAPIKeyRepo(a.DB) + userRepo := postgres.NewUserRepo(a.DB) + portalLinkRepo := postgres.NewPortalLinkRepo(a.DB) + + err = realm_chain.Init(&cfg.Auth, apiKeyRepo, userRepo, portalLinkRepo, a.Cache) + if err != nil { + a.Logger.WithError(err).Fatal("failed to initialize realm chain") + } + + return nil +} diff --git a/cmd/hooks/hooks.go b/cmd/hooks/hooks.go index 1208be0b5e..7159f2d834 100644 --- a/cmd/hooks/hooks.go +++ b/cmd/hooks/hooks.go @@ -163,7 +163,7 @@ func PreRun(app *cli.App, db *postgres.Postgres) func(cmd *cobra.Command, args [ app.Logger = lo app.Cache = ca - err = ensureRootUser(context.Background(), app, shouldBootstrap(cmd)) + err = ensureRootUser(context.Background(), app) if err != nil { return err } @@ -789,22 +789,14 @@ func shouldBootstrap(cmd *cobra.Command) bool { return false } -func ensureRootUser(ctx context.Context, a *cli.App, bootstrap bool) error { - return ensureUser(ctx, a, bootstrap) -} - -func ensureInstanceAdmin(ctx context.Context, a *cli.App, bootstrap bool) error { - return nil -} - -func ensureUser(ctx context.Context, a *cli.App, bootstrap bool) error { +func ensureRootUser(ctx context.Context, a *cli.App) error { userRepo := postgres.NewUserRepo(a.DB) orgRepo := postgres.NewOrgRepo(a.DB) orgMemberRepo := postgres.NewOrgMemberRepo(a.DB) - count, err := orgMemberRepo.CountInstanceAdmins(ctx) + count, err := orgMemberRepo.CountRootUsers(ctx) if err != nil { - return fmt.Errorf("failed to count instance admins: %w", err) + return fmt.Errorf("failed to count root admins: %w", err) } if count > 0 { @@ -818,26 +810,26 @@ func ensureUser(ctx context.Context, a *cli.App, bootstrap bool) error { return err } - instanceAdmin := &datastore.User{ + root := &datastore.User{ UID: ulid.Make().String(), - FirstName: "instance", - LastName: "admin", - Email: "instance-admin@default.com", + FirstName: "root", + LastName: "root", + Email: "root@default.com", Password: string(p.Hash), EmailVerified: true, CreatedAt: time.Now(), UpdatedAt: time.Now(), } - err = userRepo.CreateUser(ctx, instanceAdmin) + err = userRepo.CreateUser(ctx, root) if err != nil { - return fmt.Errorf("failed to create instance admin - %w", err) + return fmt.Errorf("failed to create root users - %w", err) } org := &datastore.Organisation{ UID: ulid.Make().String(), - OwnerID: instanceAdmin.UID, - Name: "Instance Org", + OwnerID: root.UID, + Name: "Root Org", CreatedAt: time.Now(), UpdatedAt: time.Now(), } @@ -850,8 +842,8 @@ func ensureUser(ctx context.Context, a *cli.App, bootstrap bool) error { member := &datastore.OrganisationMember{ UID: ulid.Make().String(), OrganisationID: org.UID, - UserID: instanceAdmin.UID, - Role: auth.Role{Type: auth.RoleInstanceAdmin}, + UserID: root.UID, + Role: auth.Role{Type: auth.RoleRoot}, CreatedAt: time.Now(), UpdatedAt: time.Now(), } @@ -861,7 +853,7 @@ func ensureUser(ctx context.Context, a *cli.App, bootstrap bool) error { return err } - a.Logger.Infof("Created instance admin with username: %s and password: %s", instanceAdmin.Email, p.Plaintext) + a.Logger.Infof("Created root user with username: %s and password: %s", root.Email, p.Plaintext) return nil } diff --git a/cmd/main.go b/cmd/main.go index 4ce13d1d19..7bc4f5b2d8 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -150,6 +150,7 @@ func main() { c.AddCommand(stream.AddStreamCommand(app)) c.AddCommand(ingest.AddIngestCommand(app)) c.AddCommand(bootstrap.AddBootstrapCommand(app)) + c.AddCommand(bootstrap.AddProvisionIACommand(app)) c.AddCommand(agent.AddAgentCommand(app)) c.AddCommand(ff.AddFeatureFlagsCommand()) c.AddCommand(utils.AddUtilsCommand(app)) diff --git a/database/postgres/instance_defaults.go b/database/postgres/instance_defaults.go deleted file mode 100644 index 9ab27508ce..0000000000 --- a/database/postgres/instance_defaults.go +++ /dev/null @@ -1,269 +0,0 @@ -package postgres - -import ( - "context" - "database/sql" - "encoding/json" - "errors" - "fmt" - "github.com/frain-dev/convoy/config" - "github.com/frain-dev/convoy/database" - "github.com/frain-dev/convoy/datastore" - "github.com/frain-dev/convoy/internal/pkg/instance" - "github.com/oklog/ulid/v2" - "time" -) - -var ( - ErrKeyCannotBeEmpty = errors.New("key cannot be empty") - ErrDefaultValueCannotBeEmpty = errors.New("default_value (plaintext) cannot be empty") - ErrInvalidIngestRate = errors.New("invalid ingest rate json value") - ErrInvalidRetentionPolicy = errors.New("invalid retention policy json value") -) - -type instanceDefaultsRepo struct { - db database.Database -} - -func NewInstanceDefaultsRepo(db database.Database) datastore.InstanceDefaultsRepository { - return &instanceDefaultsRepo{ - db: db, - } -} - -func (i *instanceDefaultsRepo) Create(ctx context.Context, instanceDefault *datastore.InstanceDefaults) (*datastore.InstanceDefaults, error) { - err := validate(instanceDefault) - if err != nil { - return nil, err - } - - encryptionPassphrase := instance.GetEncryptionPassphrase() - - instanceDefault.UID = ulid.Make().String() - - query := ` - INSERT INTO convoy.instance_defaults (id, scope_type, key, default_value_cipher, created_at, updated_at) - VALUES ($1, $2, $3, pgp_sym_encrypt($4::text, $5), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) - ` - _, err = i.db.GetDB().ExecContext(ctx, query, - instanceDefault.UID, - instanceDefault.ScopeType, - instanceDefault.Key, - instanceDefault.DefaultValue, - encryptionPassphrase, - ) - - if err != nil { - return nil, err - } - - d, err := i.FetchByID(ctx, instanceDefault.UID) - if err != nil { - return nil, err - } - - return d, err -} - -func validate(instanceDefault *datastore.InstanceDefaults) error { - validScopeTypes := map[string]bool{ - "organisation": true, - "project": true, - } - if !validScopeTypes[instanceDefault.ScopeType] { - return fmt.Errorf("invalid scope_type: %s, must be 'organisation' or 'project'", instanceDefault.ScopeType) - } - - if instanceDefault.Key == "" { - return ErrKeyCannotBeEmpty - } - validKeys := map[string]bool{ - instance.KeyInstanceIngestRate: true, - instance.KeyRetentionPolicy: true, - } - if !validKeys[instanceDefault.Key] { - return fmt.Errorf("invalid key: %s", instanceDefault.Key) - } - - if instanceDefault.DefaultValue == "" { - return ErrDefaultValueCannotBeEmpty - } - - if instanceDefault.Key == instance.KeyInstanceIngestRate { - var ingestRate instance.IngestRate - err := json.Unmarshal([]byte(instanceDefault.DefaultValue), &ingestRate) - if err != nil || ingestRate.Value == 0 { - return ErrInvalidIngestRate - } - } else { - var retentionPolicy config.RetentionPolicyConfiguration - err := json.Unmarshal([]byte(instanceDefault.DefaultValue), &retentionPolicy) - if err != nil { - return ErrInvalidRetentionPolicy - } - _, err = time.ParseDuration(retentionPolicy.Policy) - if err != nil { - return ErrInvalidRetentionPolicy - } - } - - return nil -} - -func (i *instanceDefaultsRepo) Update(ctx context.Context, id string, instanceDefault *datastore.InstanceDefaults) (*datastore.InstanceDefaults, error) { - err := validate(instanceDefault) - if err != nil { - return nil, err - } - encryptionPassphrase := instance.GetEncryptionPassphrase() - - query := ` - UPDATE convoy.instance_defaults - SET scope_type = $1, - key = $2, - default_value_cipher = pgp_sym_encrypt($3::text, $4), - updated_at = CURRENT_TIMESTAMP - WHERE id = $5 - ` - _, err = i.db.GetDB().ExecContext(ctx, query, - instanceDefault.ScopeType, - instanceDefault.Key, - instanceDefault.DefaultValue, - encryptionPassphrase, - id, - ) - if err != nil { - return nil, err - } - - d, err := i.FetchByID(ctx, id) - if err != nil { - return nil, err - } - - return d, err -} - -func (i *instanceDefaultsRepo) FetchByID(ctx context.Context, id string) (*datastore.InstanceDefaults, error) { - encryptionPassphrase := instance.GetEncryptionPassphrase() - - query := ` - SELECT id, scope_type, key, - pgp_sym_decrypt(default_value_cipher::bytea, $1) AS default_value_cipher, - created_at, updated_at, deleted_at - FROM convoy.instance_defaults - WHERE id = $2 - ` - - row := i.db.GetDB().QueryRowContext(ctx, query, encryptionPassphrase, id) - instanceDefault := &datastore.InstanceDefaults{} - err := row.Scan( - &instanceDefault.UID, - &instanceDefault.ScopeType, - &instanceDefault.Key, - &instanceDefault.DefaultValue, - &instanceDefault.CreatedAt, - &instanceDefault.UpdatedAt, - &instanceDefault.DeletedAt, - ) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return nil, datastore.ErrConfigNotFound - } - return nil, err - } - return instanceDefault, nil -} - -func (i *instanceDefaultsRepo) LoadPaged(ctx context.Context, pageable datastore.Pageable) ([]datastore.InstanceDefaults, datastore.PaginationData, error) { - encryptionPassphrase := instance.GetEncryptionPassphrase() - - var query string - var args []interface{} - - if pageable.PrevCursor != "" { - query = ` - SELECT id, scope_type, key, - pgp_sym_decrypt(default_value_cipher::bytea, $1) AS default_value_cipher, - created_at, updated_at, deleted_at - FROM convoy.instance_defaults - WHERE id < $2 - ORDER BY id DESC - LIMIT $3 - ` - args = append(args, encryptionPassphrase, pageable.PrevCursor, pageable.PerPage+1) - } else if pageable.NextCursor != "" { - query = ` - SELECT id, scope_type, key, - pgp_sym_decrypt(default_value_cipher::bytea, $1) AS default_value_cipher, - created_at, updated_at, deleted_at - FROM convoy.instance_defaults - WHERE id > $2 - ORDER BY id - LIMIT $3 - ` - args = append(args, encryptionPassphrase, pageable.NextCursor, pageable.PerPage+1) - } else { - query = ` - SELECT id, scope_type, key, - pgp_sym_decrypt(default_value_cipher::bytea, $1) AS default_value_cipher, - created_at, updated_at, deleted_at - FROM convoy.instance_defaults - ORDER BY id - LIMIT $2 - ` - args = append(args, encryptionPassphrase, pageable.PerPage+1) - } - - rows, err := i.db.GetDB().QueryContext(ctx, query, args...) - if err != nil { - return nil, datastore.PaginationData{}, err - } - defer rows.Close() - - var instanceDefaults = make([]datastore.InstanceDefaults, 0) - var rowCount int - var firstID, lastID string - - for rows.Next() { - if rowCount == pageable.PerPage { - break - } - - instanceDefault := datastore.InstanceDefaults{} - err := rows.Scan( - &instanceDefault.UID, - &instanceDefault.ScopeType, - &instanceDefault.Key, - &instanceDefault.DefaultValue, - &instanceDefault.CreatedAt, - &instanceDefault.UpdatedAt, - &instanceDefault.DeletedAt, - ) - if err != nil { - return nil, datastore.PaginationData{}, err - } - - if rowCount == 0 { - firstID = instanceDefault.UID - } - lastID = instanceDefault.UID - - instanceDefaults = append(instanceDefaults, instanceDefault) - rowCount++ - } - - hasNextPage := rowCount > pageable.PerPage - hasPreviousPage := pageable.PrevCursor != "" - - paginationData := datastore.PaginationData{ - PrevRowCount: datastore.PrevRowCount{Count: 0}, - PerPage: int64(pageable.PerPage), - HasNextPage: hasNextPage, - HasPreviousPage: hasPreviousPage, - PrevPageCursor: firstID, - NextPageCursor: lastID, - } - - return instanceDefaults, paginationData, nil -} diff --git a/database/postgres/instance_overrides.go b/database/postgres/instance_overrides.go index b0ceb6f4e3..80e2facd3f 100644 --- a/database/postgres/instance_overrides.go +++ b/database/postgres/instance_overrides.go @@ -11,10 +11,15 @@ import ( "github.com/frain-dev/convoy/datastore" "github.com/frain-dev/convoy/internal/pkg/instance" "github.com/oklog/ulid/v2" + "strings" "time" ) var ( + ErrKeyCannotBeEmpty = errors.New("key cannot be empty") + ErrInvalidBool = errors.New("invalid bool json value") + ErrInvalidIngestRate = errors.New("invalid ingest rate json value") + ErrInvalidRetentionPolicy = errors.New("invalid retention policy json value") ErrValueCipherCannotBeEmpty = errors.New("value (plaintext) cannot be empty") ErrScopeIDCannotBeEmpty = errors.New("scope_id cannot be empty") ) @@ -39,23 +44,33 @@ func (i *instanceOverridesRepo) Create(ctx context.Context, instanceOverride *da instanceOverride.UID = ulid.Make().String() + key := encryptionPassphrase + "-" + instanceOverride.UID + query := ` INSERT INTO convoy.instance_overrides (id, scope_type, scope_id, key, value_cipher, created_at, updated_at) VALUES ($1, $2, $3, $4, pgp_sym_encrypt($5::text, $6), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) + ON CONFLICT (scope_type, scope_id, key) + DO UPDATE SET + value_cipher = pgp_sym_encrypt($5::text, CONCAT($7::text, '-', convoy.instance_overrides.id)), + updated_at = CURRENT_TIMESTAMP + RETURNING id; ` - _, err = i.db.GetDB().ExecContext(ctx, query, + + var uid string + err = i.db.GetDB().QueryRowContext(ctx, query, instanceOverride.UID, instanceOverride.ScopeType, instanceOverride.ScopeID, instanceOverride.Key, instanceOverride.Value, + key, encryptionPassphrase, - ) + ).Scan(&uid) if err != nil { return nil, err } - o, err := i.FetchByID(ctx, instanceOverride.UID) + o, err := i.FetchByID(ctx, uid) if err != nil { return nil, err } @@ -78,6 +93,8 @@ func (i *instanceOverridesRepo) validateOverride(ctx context.Context, instanceOv validKeys := map[string]bool{ instance.KeyInstanceIngestRate: true, instance.KeyRetentionPolicy: true, + instance.KeyStaticIP: true, + instance.KeyEnterpriseSSO: true, } if !validKeys[instanceOverride.Key] { return fmt.Errorf("invalid key: %s", instanceOverride.Key) @@ -97,7 +114,7 @@ func (i *instanceOverridesRepo) validateOverride(ctx context.Context, instanceOv if err != nil || ingestRate.Value == 0 { return ErrInvalidIngestRate } - } else { + } else if instanceOverride.Key == instance.KeyRetentionPolicy { var retentionPolicy config.RetentionPolicyConfiguration err := json.Unmarshal([]byte(instanceOverride.Value), &retentionPolicy) if err != nil { @@ -107,6 +124,12 @@ func (i *instanceOverridesRepo) validateOverride(ctx context.Context, instanceOv if err != nil { return ErrInvalidRetentionPolicy } + } else if instanceOverride.Key == instance.KeyStaticIP || instanceOverride.Key == instance.KeyEnterpriseSSO { + var boolean instance.Boolean + err := json.Unmarshal([]byte(instanceOverride.Value), &boolean) + if err != nil { + return ErrInvalidBool + } } if instanceOverride.ScopeType == instance.OrganisationScope { @@ -144,7 +167,7 @@ func (i *instanceOverridesRepo) Update(ctx context.Context, id string, instanceO SET scope_type = $1, scope_id = $2, key = $3, - value_cipher = pgp_sym_encrypt($4::text, $5), + value_cipher = pgp_sym_encrypt($4::text, CONCAT($5::text, '-', id)), updated_at = CURRENT_TIMESTAMP WHERE id = $6 ` @@ -173,7 +196,7 @@ func (i *instanceOverridesRepo) FetchByID(ctx context.Context, id string) (*data query := ` SELECT id, scope_type, scope_id, key, - pgp_sym_decrypt(value_cipher::bytea, $1) AS value_cipher, + pgp_sym_decrypt(value_cipher::bytea, CONCAT($1::text, '-', id)) AS value_cipher, created_at, updated_at, deleted_at FROM convoy.instance_overrides WHERE id = $2 @@ -209,7 +232,7 @@ func (i *instanceOverridesRepo) LoadPaged(ctx context.Context, pageable datastor if pageable.PrevCursor != "" { query = ` SELECT id, scope_type, scope_id, key, - pgp_sym_decrypt(value_cipher::bytea, $1) AS value_cipher, + pgp_sym_decrypt(value_cipher::bytea, CONCAT($1::text, '-', id)) AS value_cipher, created_at, updated_at, deleted_at FROM convoy.instance_overrides WHERE id < $2 @@ -220,7 +243,7 @@ func (i *instanceOverridesRepo) LoadPaged(ctx context.Context, pageable datastor } else if pageable.NextCursor != "" { query = ` SELECT id, scope_type, scope_id, key, - pgp_sym_decrypt(value_cipher::bytea, $1) AS value_cipher, + pgp_sym_decrypt(value_cipher::bytea, CONCAT($1::text, '-', id)) AS value_cipher, created_at, updated_at, deleted_at FROM convoy.instance_overrides WHERE id > $2 @@ -231,7 +254,7 @@ func (i *instanceOverridesRepo) LoadPaged(ctx context.Context, pageable datastor } else { query = ` SELECT id, scope_type, scope_id, key, - pgp_sym_decrypt(value_cipher::bytea, $1) AS value_cipher, + pgp_sym_decrypt(value_cipher::bytea, CONCAT($1::text, '-', id)) AS value_cipher, created_at, updated_at, deleted_at FROM convoy.instance_overrides ORDER BY id @@ -293,3 +316,26 @@ func (i *instanceOverridesRepo) LoadPaged(ctx context.Context, pageable datastor return instanceOverrides, paginationData, nil } + +func (i *instanceOverridesRepo) DeleteUnUpdatedKeys(ctx context.Context, scopeType, scopeID string, keysToUpdate map[string]bool) error { + var query string + var args []interface{} + + query = ` + DELETE FROM convoy.instance_overrides + WHERE scope_type = $1 + AND scope_id = $2` + args = append(args, scopeType, scopeID) + + if len(keysToUpdate) > 0 { + query += ` AND key NOT IN (` + var keyList []string + for key := range keysToUpdate { + keyList = append(keyList, fmt.Sprintf("'%s'", key)) + } + query += fmt.Sprintf("%s)", strings.Join(keyList, ", ")) + } + + _, err := i.db.GetDB().ExecContext(ctx, query, scopeType, scopeID) + return err +} diff --git a/database/postgres/organisation.go b/database/postgres/organisation.go index 2656c814b1..e173932aab 100644 --- a/database/postgres/organisation.go +++ b/database/postgres/organisation.go @@ -3,11 +3,14 @@ package postgres import ( "context" "database/sql" + "encoding/json" "errors" "fmt" - + "github.com/frain-dev/convoy/config" "github.com/frain-dev/convoy/database" "github.com/frain-dev/convoy/datastore" + "github.com/frain-dev/convoy/internal/pkg/instance" + "github.com/frain-dev/convoy/pkg/log" "github.com/jmoiron/sqlx" ) @@ -257,9 +260,84 @@ func (o *orgRepo) FetchOrganisationByID(ctx context.Context, id string) (*datast return nil, err } + err = EnrichOrganisationWithOverrides(ctx, o.db.GetReadDB(), org, instance.GetEncryptionPassphrase()) + if err != nil { + return nil, err + } + return org, nil } +func EnrichOrganisationWithOverrides(ctx context.Context, db *sqlx.DB, org *datastore.Organisation, encryptionKey string) error { + query := ` + SELECT key, pgp_sym_decrypt(value_cipher::bytea, CONCAT($1::text, '-', id)) AS value + FROM convoy.instance_overrides + WHERE scope_type = 'organisation' AND scope_id = $2; + ` + + rows, err := db.QueryContext(ctx, query, encryptionKey, org.UID) + if err != nil { + return fmt.Errorf("failed to fetch overrides: %w", err) + } + defer func(rows *sql.Rows) { + err := rows.Close() + if err != nil { + log.Error("failed to close rows: ", err) + } + }(rows) + + if org.Config == nil { + org.Config = &datastore.InstanceConfig{} + } + if org.Config.ProjectConfig == nil { + org.Config.ProjectConfig = &datastore.ProjectInstanceConfig{} + } + + for rows.Next() { + var key, value string + if err := rows.Scan(&key, &value); err != nil { + return fmt.Errorf("failed to scan override: %w", err) + } + + switch key { + case instance.KeyStaticIP: + var v instance.Boolean + if err := json.Unmarshal([]byte(value), &v); err != nil { + return fmt.Errorf("failed to unmarshal static_ip: %w", err) + } + org.Config.StaticIP = &v.Value + + case instance.KeyEnterpriseSSO: + var v instance.Boolean + if err := json.Unmarshal([]byte(value), &v); err != nil { + return fmt.Errorf("failed to unmarshal enterprise_sso: %w", err) + } + org.Config.EnterpriseSSO = &v.Value + + case instance.KeyRetentionPolicy: + var v config.RetentionPolicyConfiguration + if err := json.Unmarshal([]byte(value), &v); err != nil { + return fmt.Errorf("failed to unmarshal retention_policy: %w", err) + } + org.Config.ProjectConfig.RetentionPolicy = &v + + case instance.KeyInstanceIngestRate: + var v instance.IngestRate + if err := json.Unmarshal([]byte(value), &v); err != nil { + return fmt.Errorf("failed to unmarshal ingest_rate_limit: %w", err) + } + org.Config.ProjectConfig.IngestRateLimit = &v.Value + } + } + + // Check for errors during iteration + if err := rows.Err(); err != nil { + return fmt.Errorf("error iterating over overrides: %w", err) + } + + return nil +} + func (o *orgRepo) FetchOrganisationByAssignedDomain(ctx context.Context, domain string) (*datastore.Organisation, error) { org := &datastore.Organisation{} err := o.db.GetReadDB().QueryRowxContext(ctx, fmt.Sprintf("%s AND assigned_domain = $1", fetchOrganisation), domain).StructScan(org) diff --git a/database/postgres/organisation_member.go b/database/postgres/organisation_member.go index 008969ff77..e5b2f500cc 100644 --- a/database/postgres/organisation_member.go +++ b/database/postgres/organisation_member.go @@ -95,14 +95,25 @@ const ( WHERE o.user_id = $1 AND (o.role_type='instance_admin' OR o.role_type='root') AND o.deleted_at IS NULL LIMIT 1; ` - countInstanceAdmins = ` + countRootUsers = ` SELECT COUNT(*) FROM ( SELECT o.id AS id FROM convoy.organisation_members o LEFT JOIN convoy.users u ON o.user_id = u.id - WHERE o.role_type='instance_admin' AND o.deleted_at IS NULL LIMIT 1 + WHERE o.role_type='root' AND o.deleted_at IS NULL LIMIT 1 + ) ou; + ` + + countSuperUsers = ` + SELECT COUNT(*) FROM ( + SELECT + o.id AS id + FROM convoy.organisation_members o + LEFT JOIN convoy.users u + ON o.user_id = u.id + WHERE o.role_type='organisation_admin' AND o.deleted_at IS NULL LIMIT 1 ) ou; ` @@ -540,9 +551,19 @@ func (o *orgMemberRepo) FetchAnyInstanceAdminOrRootByUserID(ctx context.Context, return member, nil } -func (o *orgMemberRepo) CountInstanceAdmins(ctx context.Context) (int64, error) { +func (o *orgMemberRepo) CountRootUsers(ctx context.Context) (int64, error) { + var count int64 + err := o.db.GetReadDB().GetContext(ctx, &count, countRootUsers) + if err != nil { + return 0, err + } + + return count, nil +} + +func (o *orgMemberRepo) CountSuperUsers(ctx context.Context) (int64, error) { var count int64 - err := o.db.GetReadDB().GetContext(ctx, &count, countInstanceAdmins) + err := o.db.GetReadDB().GetContext(ctx, &count, countSuperUsers) if err != nil { return 0, err } diff --git a/database/postgres/organisation_test.go b/database/postgres/organisation_test.go index a86f59c394..9cb712dd6c 100644 --- a/database/postgres/organisation_test.go +++ b/database/postgres/organisation_test.go @@ -157,6 +157,8 @@ func TestFetchOrganisationByID(t *testing.T) { dbOrg.CreatedAt = time.Time{} dbOrg.UpdatedAt = time.Time{} + dbOrg.Config = nil + require.Equal(t, org, dbOrg) } diff --git a/datastore/models.go b/datastore/models.go index a87199bde5..40a08be9b6 100644 --- a/datastore/models.go +++ b/datastore/models.go @@ -1347,14 +1347,26 @@ type ApiKey struct { } type Organisation struct { - UID string `json:"uid" db:"id"` - OwnerID string `json:"" db:"owner_id"` - Name string `json:"name" db:"name"` - CustomDomain null.String `json:"custom_domain" db:"custom_domain"` - AssignedDomain null.String `json:"assigned_domain" db:"assigned_domain"` - CreatedAt time.Time `json:"created_at,omitempty" db:"created_at,omitempty" swaggertype:"string"` - UpdatedAt time.Time `json:"updated_at,omitempty" db:"updated_at,omitempty" swaggertype:"string"` - DeletedAt null.Time `json:"deleted_at,omitempty" db:"deleted_at" swaggertype:"string"` + UID string `json:"uid" db:"id"` + OwnerID string `json:"" db:"owner_id"` + Name string `json:"name" db:"name"` + CustomDomain null.String `json:"custom_domain" db:"custom_domain"` + AssignedDomain null.String `json:"assigned_domain" db:"assigned_domain"` + CreatedAt time.Time `json:"created_at,omitempty" db:"created_at,omitempty" swaggertype:"string"` + UpdatedAt time.Time `json:"updated_at,omitempty" db:"updated_at,omitempty" swaggertype:"string"` + DeletedAt null.Time `json:"deleted_at,omitempty" db:"deleted_at" swaggertype:"string"` + Config *InstanceConfig `json:"config,omitempty" db:"-"` +} + +type InstanceConfig struct { + StaticIP *bool `json:"static_ip,omitempty" bson:"static_ip"` + EnterpriseSSO *bool `json:"enterprise_sso,omitempty" bson:"enterprise_sso"` + ProjectConfig *ProjectInstanceConfig `json:"project,omitempty" bson:"project"` +} + +type ProjectInstanceConfig struct { + RetentionPolicy *config.RetentionPolicyConfiguration `json:"retention_policy,omitempty" bson:"retention_policy"` + IngestRateLimit *int `json:"ingest_rate_limit,omitempty" bson:"ingest_rate_limit"` } type Configuration struct { @@ -1562,7 +1574,7 @@ type InstanceOverrides struct { ScopeType string `json:"scope_type" db:"scope_type"` ScopeID string `json:"scope_id" db:"scope_id"` Key string `json:"key" db:"key"` - Value string `json:"value_cipher" db:"value_cipher"` + Value string `json:"valuelp" db:"value_cipher"` CreatedAt time.Time `json:"created_at,omitempty" db:"created_at,omitempty" swaggertype:"string"` UpdatedAt time.Time `json:"updated_at,omitempty" db:"updated_at,omitempty" swaggertype:"string"` diff --git a/datastore/repository.go b/datastore/repository.go index 88af4826d3..22d56964b0 100644 --- a/datastore/repository.go +++ b/datastore/repository.go @@ -104,7 +104,8 @@ type OrganisationMemberRepository interface { FetchOrganisationMemberByID(ctx context.Context, memberID string, organisationID string) (*OrganisationMember, error) FetchOrganisationMemberByUserID(ctx context.Context, userID string, organisationID string) (*OrganisationMember, error) FetchAnyInstanceAdminOrRootByUserID(ctx context.Context, userID string) (*OrganisationMember, error) - CountInstanceAdmins(ctx context.Context) (int64, error) + CountRootUsers(ctx context.Context) (int64, error) + CountSuperUsers(ctx context.Context) (int64, error) } type EndpointRepository interface { @@ -230,16 +231,10 @@ type EventTypesRepository interface { FetchAllEventTypes(context.Context, string) ([]ProjectEventType, error) } -type InstanceDefaultsRepository interface { - Create(context.Context, *InstanceDefaults) (*InstanceDefaults, error) - Update(context.Context, string, *InstanceDefaults) (*InstanceDefaults, error) - FetchByID(ctx context.Context, id string) (*InstanceDefaults, error) - LoadPaged(ctx context.Context, pageable Pageable) ([]InstanceDefaults, PaginationData, error) -} - type InstanceOverridesRepository interface { Create(ctx context.Context, record *InstanceOverrides) (*InstanceOverrides, error) Update(ctx context.Context, id string, record *InstanceOverrides) (*InstanceOverrides, error) FetchByID(ctx context.Context, id string) (*InstanceOverrides, error) LoadPaged(ctx context.Context, pageable Pageable) ([]InstanceOverrides, PaginationData, error) + DeleteUnUpdatedKeys(ctx context.Context, scopeType, scopeID string, keysToUpdate map[string]bool) error } diff --git a/internal/pkg/exporter/retention_policies_cfg.go b/internal/pkg/exporter/retention_policies_cfg.go index 88f66bdac3..b766c98806 100644 --- a/internal/pkg/exporter/retention_policies_cfg.go +++ b/internal/pkg/exporter/retention_policies_cfg.go @@ -58,19 +58,6 @@ func (r *RetentionCfg) fetchRetentionPolicyFromDatabase(ctx context.Context, key } } - if !found { - found, err = r.getInstanceDefault(ctx, key, "project", &retentionPolicy) - if err != nil { - return "", err - } - if !found { - found, err = r.getInstanceDefault(ctx, key, "organisation", &retentionPolicy) - if err != nil { - return "", err - } - } - } - // If no value found, fallback to default configuration if !found { retentionPolicy = r.defaultPolicy @@ -89,14 +76,3 @@ func (r *RetentionCfg) getInstanceOverride(ctx context.Context, key, scopeType, } return model != nil && model.Policy != "", nil } - -func (r *RetentionCfg) getInstanceDefault(ctx context.Context, key, scopeType string, model *config.RetentionPolicyConfiguration) (bool, error) { - _, err := instance.FetchDecryptedDefaults(ctx, r.db, key, scopeType, &model) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return false, nil - } - return false, err - } - return model != nil && model.Policy != "", nil -} diff --git a/internal/pkg/instance/instance.go b/internal/pkg/instance/instance.go index dd60dd720b..aa29b6253a 100644 --- a/internal/pkg/instance/instance.go +++ b/internal/pkg/instance/instance.go @@ -6,7 +6,6 @@ import ( "github.com/frain-dev/convoy/config" "github.com/frain-dev/convoy/database" "github.com/frain-dev/convoy/pkg/log" - "github.com/oklog/ulid/v2" "os" ) @@ -16,11 +15,17 @@ const ( ProjectScope = "project" KeyInstanceIngestRate = "InstanceIngestRate" KeyRetentionPolicy = "RetentionPolicy" + KeyStaticIP = "StaticIP" + KeyEnterpriseSSO = "EnterpriseSSO" ) +type Boolean struct { + Value bool `json:"value"` +} + // IngestRate is a Wrapper for InstanceIngestRate int type IngestRate struct { - Value int `json:"value" envconfig:"CONVOY_INSTANCE_INGEST_RATE"` + Value int `json:"value"` } type Defaults struct { @@ -28,56 +33,6 @@ type Defaults struct { RetentionPolicy config.RetentionPolicyConfiguration `json:"retention_policy"` } -func GetInstanceDefaults() *Defaults { - return &Defaults{ - InstanceIngestRate: IngestRate{ - Value: 25, - }, - RetentionPolicy: config.RetentionPolicyConfiguration{ - Policy: "720h", - IsRetentionPolicyEnabled: false, - }, - } -} - -func EncryptAndStoreInstanceDefaults(ctx context.Context, db database.Database, lo log.StdLogger) error { - - encryptionPassphrase := GetEncryptionPassphrase() - defaults := GetInstanceDefaults() - - ingestRateJSON, err := json.Marshal(defaults.InstanceIngestRate) - if err != nil { - lo.WithError(err).Error("error marshaling ingest rate defaults") - return err - } - - retentionJSON, err := json.Marshal(defaults.RetentionPolicy) - if err != nil { - lo.WithError(err).Error("error marshaling retention defaults") - return err - } - - plaintextDefaults := map[string]map[string]string{ - OrganisationScope: { - KeyInstanceIngestRate: string(ingestRateJSON), - KeyRetentionPolicy: string(retentionJSON), - }, - } - - for scopeType, instanceDefaults := range plaintextDefaults { - for key, plaintext := range instanceDefaults { - errI := InsertEncryptedDefault(ctx, db, scopeType, key, plaintext, encryptionPassphrase) - if errI != nil { - lo.WithError(errI).Error("error inserting encrypted default") - return errI - } - } - } - - lo.Info("Encrypted defaults inserted successfully!") - return nil -} - func GetEncryptionPassphrase() string { encryptionPassphrase := os.Getenv("CONVOY_INSTANCE_ENCRYPTION_PASSPHRASE") if encryptionPassphrase == "" { @@ -86,50 +41,13 @@ func GetEncryptionPassphrase() string { return encryptionPassphrase } -func InsertEncryptedDefault(ctx context.Context, db database.Database, scopeType, key, plaintext, encryptionPassphrase string) error { - id := ulid.Make().String() - - query := ` - INSERT INTO convoy.instance_defaults (id, scope_type, key, default_value_cipher, created_at, updated_at) - VALUES ($1, $2, $3, pgp_sym_encrypt($4::text, $5), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) - ` - - _, err := db.GetDB().ExecContext(ctx, query, id, scopeType, key, plaintext, encryptionPassphrase) - return err -} - -func FetchDecryptedDefaults(ctx context.Context, db database.Database, key, scopeType string, model interface{}) (string, error) { - encryptionPassphrase := GetEncryptionPassphrase() - - var decryptedValue string - - query := ` - SELECT pgp_sym_decrypt(default_value_cipher::bytea, $1) - FROM convoy.instance_defaults - WHERE key = $2 AND scope_type = $3 - ` - - err := db.GetReadDB().QueryRowContext(ctx, query, encryptionPassphrase, key, scopeType).Scan(&decryptedValue) - if err != nil { - return "", err - } - - err = json.Unmarshal([]byte(decryptedValue), model) - if err != nil { - log.WithError(err).Error("error unmarshalling decrypted JSON") - return "", err - } - - return decryptedValue, nil -} - func FetchDecryptedOverrides(ctx context.Context, db database.Database, key string, scopeType, scopeId string, model interface{}) (string, error) { encryptionPassphrase := GetEncryptionPassphrase() var decryptedValue string query := ` - SELECT pgp_sym_decrypt(value_cipher::bytea, $1) + SELECT pgp_sym_decrypt(value_cipher::bytea, CONCAT($1::text, '-', id)) FROM convoy.instance_overrides WHERE key = $2 AND scope_type = $3 AND scope_id = $4 ` diff --git a/internal/pkg/pubsub/ingest/ingest_cfg.go b/internal/pkg/pubsub/ingest/ingest_cfg.go index 637e370e2d..f6b1a90eb0 100644 --- a/internal/pkg/pubsub/ingest/ingest_cfg.go +++ b/internal/pkg/pubsub/ingest/ingest_cfg.go @@ -97,20 +97,7 @@ func (i *IngestCfg) fetchRateLimitFromDatabase(ctx context.Context, key, project } } - if !found { - found, err = i.getInstanceDefault(ctx, key, "project", &ingestRate) - if err != nil { - return 0, err - } - if !found { - found, err = i.getInstanceDefault(ctx, key, "organisation", &ingestRate) - if err != nil { - return 0, err - } - } - } - - // Fallback to default rate if no overrides or defaults found + // Fallback to default rate if no overrides found if !found { ingestRate.Value = i.defaultRate } @@ -128,14 +115,3 @@ func (i *IngestCfg) getInstanceOverride(ctx context.Context, key string, scopeTy } return model != nil && model.Value > 0, nil } - -func (i *IngestCfg) getInstanceDefault(ctx context.Context, key string, scopeType string, model *instance.IngestRate) (bool, error) { - _, err := instance.FetchDecryptedDefaults(ctx, i.db, key, scopeType, &model) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return false, nil - } - return false, err - } - return model != nil && model.Value > 0, nil -} diff --git a/mocks/repository.go b/mocks/repository.go index 1e04aadd4f..4a7a12844d 100644 --- a/mocks/repository.go +++ b/mocks/repository.go @@ -1178,6 +1178,36 @@ func (mr *MockOrganisationMemberRepositoryMockRecorder) CountInstanceAdmins(ctx return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountInstanceAdmins", reflect.TypeOf((*MockOrganisationMemberRepository)(nil).CountInstanceAdmins), ctx) } +// CountRootUsers mocks base method. +func (m *MockOrganisationMemberRepository) CountRootUsers(ctx context.Context) (int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CountRootUsers", ctx) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CountRootUsers indicates an expected call of CountRootUsers. +func (mr *MockOrganisationMemberRepositoryMockRecorder) CountRootUsers(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountRootUsers", reflect.TypeOf((*MockOrganisationMemberRepository)(nil).CountRootUsers), ctx) +} + +// CountSuperUsers mocks base method. +func (m *MockOrganisationMemberRepository) CountSuperUsers(ctx context.Context) (int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CountSuperUsers", ctx) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CountSuperUsers indicates an expected call of CountSuperUsers. +func (mr *MockOrganisationMemberRepositoryMockRecorder) CountSuperUsers(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountSuperUsers", reflect.TypeOf((*MockOrganisationMemberRepository)(nil).CountSuperUsers), ctx) +} + // CreateOrganisationMember mocks base method. func (m *MockOrganisationMemberRepository) CreateOrganisationMember(ctx context.Context, member *datastore.OrganisationMember) error { m.ctrl.T.Helper() @@ -3046,3 +3076,17 @@ func (mr *MockInstanceOverridesRepositoryMockRecorder) Update(ctx, id, record an mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockInstanceOverridesRepository)(nil).Update), ctx, id, record) } + +// DeleteUnUpdatedKeys mocks base method. +func (m *MockInstanceOverridesRepository) DeleteUnUpdatedKeys(ctx context.Context, scopeType, scopeID string, keysToUpdate map[string]bool) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteUnUpdatedKeys", ctx, scopeType, scopeID, keysToUpdate) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteUnUpdatedKeys indicates an expected call of DeleteUnUpdatedKeys. +func (mr *MockInstanceOverridesRepositoryMockRecorder) DeleteUnUpdatedKeys(ctx, scopeType, scopeID, keysToUpdate interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUnUpdatedKeys", reflect.TypeOf((*MockInstanceOverridesRepository)(nil).DeleteUnUpdatedKeys), ctx, scopeType, scopeID, keysToUpdate) +} diff --git a/services/create_organisation.go b/services/create_organisation.go index 63e61af80b..b0d23e02e1 100644 --- a/services/create_organisation.go +++ b/services/create_organisation.go @@ -2,8 +2,10 @@ package services import ( "context" + "encoding/json" "errors" "fmt" + "github.com/frain-dev/convoy/internal/pkg/instance" "time" "github.com/frain-dev/convoy/internal/pkg/license" @@ -20,11 +22,13 @@ import ( ) type CreateOrganisationService struct { - OrgRepo datastore.OrganisationRepository - OrgMemberRepo datastore.OrganisationMemberRepository - NewOrg *models.Organisation - User *datastore.User - Licenser license.Licenser + OrgRepo datastore.OrganisationRepository + OrgMemberRepo datastore.OrganisationMemberRepository + InstanceOverridesRepo datastore.InstanceOverridesRepository + NewOrg *models.Organisation + User *datastore.User + Licenser license.Licenser + RoleType auth.RoleType } var ErrOrgLimit = errors.New("your instance has reached it's organisation limit, upgrade to create new organisations") @@ -49,6 +53,10 @@ func (co *CreateOrganisationService) Run(ctx context.Context) (*datastore.Organi return nil, &ServiceError{ErrMsg: "organisation name is required", Err: err} } + if co.RoleType == "" { + co.RoleType = auth.RoleOrganisationAdmin + } + org := &datastore.Organisation{ UID: ulid.Make().String(), OwnerID: co.User.UID, @@ -72,10 +80,120 @@ func (co *CreateOrganisationService) Run(ctx context.Context) (*datastore.Organi log.FromContext(ctx).WithError(err).Error("failed to create organisation") return nil, &ServiceError{ErrMsg: "failed to create organisation", Err: err} } - _, err = NewOrganisationMemberService(co.OrgMemberRepo, co.Licenser).CreateOrganisationMember(ctx, org, co.User, &auth.Role{Type: auth.RoleOrganisationAdmin}) + _, err = NewOrganisationMemberService(co.OrgMemberRepo, co.Licenser).CreateOrganisationMember(ctx, org, co.User, &auth.Role{Type: co.RoleType}) if err != nil { log.FromContext(ctx).WithError(err).Error("failed to create super_user member for organisation owner") } + err = UpdateInstanceConfig(ctx, co.InstanceOverridesRepo, org, co.NewOrg) + if err != nil { + return nil, err + } + return org, nil } + +func UpdateInstanceConfig(ctx context.Context, repo datastore.InstanceOverridesRepository, org *datastore.Organisation, newOrg *models.Organisation) error { + orgCfg := newOrg.Config + if orgCfg != nil { + if org.Config == nil { + org.Config = &datastore.InstanceConfig{} + } + if org.Config.ProjectConfig == nil { + org.Config.ProjectConfig = &datastore.ProjectInstanceConfig{} + } + + keysToUpdate := map[string]bool{} + + org.Config.StaticIP = nil + if orgCfg.StaticIP != nil { + boolean := instance.Boolean{Value: *orgCfg.StaticIP} + bytes, err := json.Marshal(boolean) + if err != nil { + return err + } + + if err := handleOptionalJSONField(ctx, org, repo, instance.KeyStaticIP, string(bytes)); err != nil { + return err + } + org.Config.StaticIP = orgCfg.StaticIP + keysToUpdate[instance.KeyStaticIP] = true + } + + org.Config.EnterpriseSSO = nil + if orgCfg.EnterpriseSSO != nil { + boolean := instance.Boolean{Value: *orgCfg.EnterpriseSSO} + bytes, err := json.Marshal(boolean) + if err != nil { + return err + } + + if err := handleOptionalJSONField(ctx, org, repo, instance.KeyEnterpriseSSO, string(bytes)); err != nil { + return err + } + org.Config.EnterpriseSSO = orgCfg.EnterpriseSSO + keysToUpdate[instance.KeyEnterpriseSSO] = true + } + + org.Config.ProjectConfig.RetentionPolicy = nil + if orgCfg.ProjectConfig != nil && orgCfg.ProjectConfig.RetentionPolicy != nil && orgCfg.ProjectConfig.RetentionPolicy.Policy != "" { + policy := orgCfg.ProjectConfig.RetentionPolicy.Policy + _, err := time.ParseDuration(policy) + if err != nil { + return err + } + bytes, err := json.Marshal(orgCfg.ProjectConfig.RetentionPolicy) + if err != nil { + return err + } + + if err := handleOptionalJSONField(ctx, org, repo, instance.KeyRetentionPolicy, string(bytes)); err != nil { + return err + } + if org.Config.ProjectConfig.RetentionPolicy == nil { + org.Config.ProjectConfig.RetentionPolicy = &config.RetentionPolicyConfiguration{} + } + org.Config.ProjectConfig.RetentionPolicy.Policy = orgCfg.ProjectConfig.RetentionPolicy.Policy + keysToUpdate[instance.KeyRetentionPolicy] = true + } + + org.Config.ProjectConfig.IngestRateLimit = nil + if orgCfg.ProjectConfig != nil && orgCfg.ProjectConfig.IngestRateLimit != nil { + rateLimit := orgCfg.ProjectConfig.IngestRateLimit + if *rateLimit < 1 { + return fmt.Errorf("ingest rate limit must be greater than or equal to 1: %v", rateLimit) + } + boolean := instance.IngestRate{Value: *orgCfg.ProjectConfig.IngestRateLimit} + bytes, err := json.Marshal(boolean) + if err != nil { + return err + } + + if err := handleOptionalJSONField(ctx, org, repo, instance.KeyInstanceIngestRate, string(bytes)); err != nil { + return err + } + org.Config.ProjectConfig.IngestRateLimit = orgCfg.ProjectConfig.IngestRateLimit + keysToUpdate[instance.KeyInstanceIngestRate] = true + } + + err := deleteUnUpdatedKeys(ctx, repo, org.UID, keysToUpdate) + if err != nil { + return err + } + } + return nil +} + +func deleteUnUpdatedKeys(ctx context.Context, repo datastore.InstanceOverridesRepository, scopeID string, keysToUpdate map[string]bool) error { + return repo.DeleteUnUpdatedKeys(ctx, instance.OrganisationScope, scopeID, keysToUpdate) +} + +func handleOptionalJSONField(ctx context.Context, org *datastore.Organisation, repo datastore.InstanceOverridesRepository, key string, jsonValue string) error { + _, err := repo.Create(ctx, &datastore.InstanceOverrides{ + ScopeType: instance.OrganisationScope, + ScopeID: org.UID, + Key: key, + Value: jsonValue, + }) + return err +} diff --git a/services/create_organisation_test.go b/services/create_organisation_test.go index 409f44c5c8..c6327b6090 100644 --- a/services/create_organisation_test.go +++ b/services/create_organisation_test.go @@ -15,11 +15,12 @@ import ( func provideCreateOrganisationService(ctrl *gomock.Controller, newOrg *models.Organisation, user *datastore.User) *CreateOrganisationService { return &CreateOrganisationService{ - OrgRepo: mocks.NewMockOrganisationRepository(ctrl), - OrgMemberRepo: mocks.NewMockOrganisationMemberRepository(ctrl), - Licenser: mocks.NewMockLicenser(ctrl), - NewOrg: newOrg, - User: user, + OrgRepo: mocks.NewMockOrganisationRepository(ctrl), + OrgMemberRepo: mocks.NewMockOrganisationMemberRepository(ctrl), + InstanceOverridesRepo: mocks.NewMockInstanceOverridesRepository(ctrl), + Licenser: mocks.NewMockLicenser(ctrl), + NewOrg: newOrg, + User: user, } } diff --git a/services/login_sso.go b/services/login_sso.go index 22869dac85..fe78452bd1 100644 --- a/services/login_sso.go +++ b/services/login_sso.go @@ -8,6 +8,7 @@ import ( "github.com/frain-dev/convoy/api/models" "github.com/frain-dev/convoy/api/types" "github.com/frain-dev/convoy/auth/realm/jwt" + "github.com/frain-dev/convoy/database/postgres" "github.com/frain-dev/convoy/datastore" "github.com/frain-dev/convoy/internal/pkg/license" "github.com/frain-dev/convoy/pkg/log" @@ -229,11 +230,12 @@ func (u *LoginUserSSOService) RegisterSSOUser(ctx context.Context, a *types.APIO } co := CreateOrganisationService{ - OrgRepo: u.OrgRepo, - OrgMemberRepo: u.OrgMemberRepo, - Licenser: u.Licenser, - NewOrg: &models.Organisation{Name: t.Data.Payload.OrganizationExternalID}, - User: user, + OrgRepo: u.OrgRepo, + OrgMemberRepo: u.OrgMemberRepo, + InstanceOverridesRepo: postgres.NewInstanceOverridesRepo(a.DB), + Licenser: u.Licenser, + NewOrg: &models.Organisation{Name: t.Data.Payload.OrganizationExternalID}, + User: user, } _, err = co.Run(ctx) diff --git a/services/register_user.go b/services/register_user.go index 81300c2aa0..54a228a71f 100644 --- a/services/register_user.go +++ b/services/register_user.go @@ -22,13 +22,14 @@ import ( ) type RegisterUserService struct { - UserRepo datastore.UserRepository - OrgRepo datastore.OrganisationRepository - OrgMemberRepo datastore.OrganisationMemberRepository - Queue queue.Queuer - JWT *jwt.Jwt - ConfigRepo datastore.ConfigurationRepository - Licenser license.Licenser + UserRepo datastore.UserRepository + OrgRepo datastore.OrganisationRepository + OrgMemberRepo datastore.OrganisationMemberRepository + InstanceOverridesRepo datastore.InstanceOverridesRepository + Queue queue.Queuer + JWT *jwt.Jwt + ConfigRepo datastore.ConfigurationRepository + Licenser license.Licenser BaseURL string Data *models.RegisterUser @@ -87,11 +88,12 @@ func (u *RegisterUserService) Run(ctx context.Context) (*datastore.User, *jwt.To } co := CreateOrganisationService{ - OrgRepo: u.OrgRepo, - OrgMemberRepo: u.OrgMemberRepo, - Licenser: u.Licenser, - NewOrg: &models.Organisation{Name: u.Data.OrganisationName}, - User: user, + OrgRepo: u.OrgRepo, + OrgMemberRepo: u.OrgMemberRepo, + InstanceOverridesRepo: u.InstanceOverridesRepo, + Licenser: u.Licenser, + NewOrg: &models.Organisation{Name: u.Data.OrganisationName}, + User: user, } _, err = co.Run(ctx) diff --git a/services/register_user_test.go b/services/register_user_test.go index de0f376b43..20bc6f3b90 100644 --- a/services/register_user_test.go +++ b/services/register_user_test.go @@ -21,15 +21,16 @@ func provideRegisterUserService(ctrl *gomock.Controller, t *testing.T, baseUrl s c := mocks.NewMockCache(ctrl) return &RegisterUserService{ - UserRepo: mocks.NewMockUserRepository(ctrl), - OrgRepo: mocks.NewMockOrganisationRepository(ctrl), - OrgMemberRepo: mocks.NewMockOrganisationMemberRepository(ctrl), - Queue: mocks.NewMockQueuer(ctrl), - Licenser: mocks.NewMockLicenser(ctrl), - JWT: jwt.NewJwt(&configuration.Auth.Jwt, c), - ConfigRepo: mocks.NewMockConfigurationRepository(ctrl), - BaseURL: baseUrl, - Data: loginUser, + UserRepo: mocks.NewMockUserRepository(ctrl), + OrgRepo: mocks.NewMockOrganisationRepository(ctrl), + OrgMemberRepo: mocks.NewMockOrganisationMemberRepository(ctrl), + InstanceOverridesRepo: mocks.NewMockInstanceOverridesRepository(ctrl), + Queue: mocks.NewMockQueuer(ctrl), + Licenser: mocks.NewMockLicenser(ctrl), + JWT: jwt.NewJwt(&configuration.Auth.Jwt, c), + ConfigRepo: mocks.NewMockConfigurationRepository(ctrl), + BaseURL: baseUrl, + Data: loginUser, } } diff --git a/services/update_organisation.go b/services/update_organisation.go index 93b6520b9a..73893dcba3 100644 --- a/services/update_organisation.go +++ b/services/update_organisation.go @@ -12,10 +12,11 @@ import ( ) type UpdateOrganisationService struct { - OrgRepo datastore.OrganisationRepository - OrgMemberRepo datastore.OrganisationMemberRepository - Org *datastore.Organisation - Update *models.Organisation + OrgRepo datastore.OrganisationRepository + OrgMemberRepo datastore.OrganisationMemberRepository + InstanceOverridesRepo datastore.InstanceOverridesRepository + Org *datastore.Organisation + Update *models.Organisation } func (os *UpdateOrganisationService) Run(ctx context.Context) (*datastore.Organisation, error) { @@ -50,5 +51,10 @@ func (os *UpdateOrganisationService) Run(ctx context.Context) (*datastore.Organi return nil, &ServiceError{ErrMsg: "failed to update organisation", Err: err} } + err = UpdateInstanceConfig(ctx, os.InstanceOverridesRepo, os.Org, os.Update) + if err != nil { + return nil, err + } + return os.Org, nil } diff --git a/services/update_organisation_test.go b/services/update_organisation_test.go index 7dc17d59f6..10dd18dadf 100644 --- a/services/update_organisation_test.go +++ b/services/update_organisation_test.go @@ -15,10 +15,11 @@ import ( func provideUpdateOrganisationService(ctrl *gomock.Controller, org *datastore.Organisation, update *models.Organisation) *UpdateOrganisationService { return &UpdateOrganisationService{ - OrgRepo: mocks.NewMockOrganisationRepository(ctrl), - OrgMemberRepo: mocks.NewMockOrganisationMemberRepository(ctrl), - Org: org, - Update: update, + OrgRepo: mocks.NewMockOrganisationRepository(ctrl), + OrgMemberRepo: mocks.NewMockOrganisationMemberRepository(ctrl), + InstanceOverridesRepo: mocks.NewMockInstanceOverridesRepository(ctrl), + Org: org, + Update: update, } } diff --git a/sql/1734455884.sql b/sql/1734455884.sql index 70bfd41692..ce2ba75d8c 100644 --- a/sql/1734455884.sql +++ b/sql/1734455884.sql @@ -1,15 +1,4 @@ -- +migrate Up -CREATE TABLE IF NOT EXISTS convoy.instance_defaults ( - id CHAR(26) PRIMARY KEY, - scope_type VARCHAR(50) NOT NULL, - key VARCHAR(255) NOT NULL , - default_value_cipher TEXT, - created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP, - deleted_at TIMESTAMPTZ, - CONSTRAINT unique_defaults_key UNIQUE (scope_type, key) -); - CREATE TABLE IF NOT EXISTS convoy.instance_overrides ( id CHAR(26) PRIMARY KEY, scope_type VARCHAR(50) NOT NULL, @@ -24,4 +13,3 @@ CREATE TABLE IF NOT EXISTS convoy.instance_overrides ( -- +migrate Down DROP TABLE IF EXISTS convoy.instance_overrides; -DROP TABLE IF EXISTS convoy.instance_defaults;