Skip to content

Commit

Permalink
Make DatastoreConfig an interface (#324)
Browse files Browse the repository at this point in the history
  • Loading branch information
kkajla12 authored May 17, 2024
1 parent 6390296 commit 87cce88
Show file tree
Hide file tree
Showing 12 changed files with 75 additions and 128 deletions.
30 changes: 15 additions & 15 deletions cmd/warrant/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,16 @@ type ServiceEnv struct {
Datastore database.Database
}

func (env ServiceEnv) DB() database.Database {
func (env *ServiceEnv) DB() database.Database {
return env.Datastore
}

func (env *ServiceEnv) InitDB(cfg config.Config) error {
ctx, cancelFunc := context.WithTimeout(context.Background(), 10*time.Second)
defer cancelFunc()

if cfg.GetDatastore().MySQL.Hostname != "" || cfg.GetDatastore().MySQL.DSN != "" {
db := database.NewMySQL(*cfg.GetDatastore().MySQL)
if cfg.GetDatastore().GetMySQL().Hostname != "" || cfg.GetDatastore().GetMySQL().DSN != "" {
db := database.NewMySQL(*cfg.GetDatastore().GetMySQL())
err := db.Connect(ctx)
if err != nil {
return err
Expand All @@ -74,8 +74,8 @@ func (env *ServiceEnv) InitDB(cfg config.Config) error {
return nil
}

if cfg.GetDatastore().Postgres.Hostname != "" {
db := database.NewPostgres(*cfg.GetDatastore().Postgres)
if cfg.GetDatastore().GetPostgres().Hostname != "" {
db := database.NewPostgres(*cfg.GetDatastore().GetPostgres())
err := db.Connect(ctx)
if err != nil {
return err
Expand All @@ -92,8 +92,8 @@ func (env *ServiceEnv) InitDB(cfg config.Config) error {
return nil
}

if cfg.GetDatastore().SQLite.Database != "" {
db := database.NewSQLite(*cfg.GetDatastore().SQLite)
if cfg.GetDatastore().GetSQLite().Database != "" {
db := database.NewSQLite(*cfg.GetDatastore().GetSQLite())
err := db.Connect(ctx)
if err != nil {
return err
Expand All @@ -113,8 +113,8 @@ func (env *ServiceEnv) InitDB(cfg config.Config) error {
return errors.New("invalid database configuration provided")
}

func NewServiceEnv() ServiceEnv {
return ServiceEnv{
func NewServiceEnv() *ServiceEnv {
return &ServiceEnv{
Datastore: nil,
}
}
Expand Down Expand Up @@ -155,22 +155,22 @@ func main() {
querySvc := query.NewService(svcEnv, objectTypeSvc, warrantSvc, objectSvc)

// Init feature service
featureSvc := feature.NewService(&svcEnv, objectSvc)
featureSvc := feature.NewService(svcEnv, objectSvc)

// Init permission service
permissionSvc := permission.NewService(&svcEnv, objectSvc)
permissionSvc := permission.NewService(svcEnv, objectSvc)

// Init pricing tier service
pricingTierSvc := pricingtier.NewService(&svcEnv, objectSvc)
pricingTierSvc := pricingtier.NewService(svcEnv, objectSvc)

// Init role service
roleSvc := role.NewService(&svcEnv, objectSvc)
roleSvc := role.NewService(svcEnv, objectSvc)

// Init tenant service
tenantSvc := tenant.NewService(&svcEnv, objectSvc)
tenantSvc := tenant.NewService(svcEnv, objectSvc)

// Init user service
userSvc := user.NewService(&svcEnv, objectSvc)
userSvc := user.NewService(svcEnv, objectSvc)

svcs := []service.Service{
checkSvc,
Expand Down
10 changes: 5 additions & 5 deletions pkg/authz/objecttype/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ type Model interface {
}

type ObjectType struct {
ID int64 `mysql:"id" postgres:"id" sqlite:"id"`
TypeId string `mysql:"typeId" postgres:"type_id" sqlite:"typeId"`
ID int64 `mysql:"id" postgres:"id" sqlite:"id"`
TypeId string `mysql:"typeId" postgres:"type_id" sqlite:"typeId"`
Definition string `mysql:"definition" postgres:"definition" sqlite:"definition"`
CreatedAt time.Time `mysql:"createdAt" postgres:"created_at" sqlite:"createdAt"`
UpdatedAt time.Time `mysql:"updatedAt" postgres:"updated_at" sqlite:"updatedAt"`
DeletedAt *time.Time `mysql:"deletedAt" postgres:"deleted_at" sqlite:"deletedAt"`
CreatedAt time.Time `mysql:"createdAt" postgres:"created_at" sqlite:"createdAt"`
UpdatedAt time.Time `mysql:"updatedAt" postgres:"updated_at" sqlite:"updatedAt"`
DeletedAt *time.Time `mysql:"deletedAt" postgres:"deleted_at" sqlite:"deletedAt"`
}

func (objectType ObjectType) GetID() int64 {
Expand Down
4 changes: 0 additions & 4 deletions pkg/authz/warrant/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,5 @@ func buildFilterOptions(r *http.Request) *FilterParams {
filterOptions.SubjectRelation = queryParams.Get("subjectRelation")
}

if queryParams.Has("policy") {
filterOptions.Policy = Policy(queryParams.Get("policy"))
}

return &filterOptions
}
6 changes: 0 additions & 6 deletions pkg/authz/warrant/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ type FilterParams struct {
SubjectType string `json:"subjectType,omitempty"`
SubjectId string `json:"subjectId,omitempty"`
SubjectRelation string `json:"subjectRelation,omitempty"`
Policy Policy `json:"policy,omitempty"`
}

func (fp FilterParams) String() string {
Expand Down Expand Up @@ -58,10 +57,6 @@ func (fp FilterParams) String() string {
s = fmt.Sprintf("%s&subjectRelation=%s", s, fp.SubjectRelation)
}

if fp.Policy != "" {
s = fmt.Sprintf("%s&policy=%s", s, fp.Policy)
}

return strings.TrimPrefix(s, "&")
}

Expand All @@ -79,7 +74,6 @@ func (parser WarrantListParamParser) GetSupportedSortBys() []string {
}

func (parser WarrantListParamParser) ParseValue(val string, sortBy string) (interface{}, error) {
// TODO: add support for more sortBy columns
switch sortBy {
case "createdAt":
value, err := time.Parse(time.RFC3339, val)
Expand Down
22 changes: 11 additions & 11 deletions pkg/authz/warrant/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,18 @@ type Model interface {
}

type Warrant struct {
ID int64 `mysql:"id" postgres:"id" sqlite:"id"`
ObjectType string `mysql:"objectType" postgres:"object_type" sqlite:"objectType"`
ObjectId string `mysql:"objectId" postgres:"object_id" sqlite:"objectId"`
Relation string `mysql:"relation" postgres:"relation" sqlite:"relation"`
SubjectType string `mysql:"subjectType" postgres:"subject_type" sqlite:"subjectType"`
SubjectId string `mysql:"subjectId" postgres:"subject_id" sqlite:"subjectId"`
ID int64 `mysql:"id" postgres:"id" sqlite:"id"`
ObjectType string `mysql:"objectType" postgres:"object_type" sqlite:"objectType"`
ObjectId string `mysql:"objectId" postgres:"object_id" sqlite:"objectId"`
Relation string `mysql:"relation" postgres:"relation" sqlite:"relation"`
SubjectType string `mysql:"subjectType" postgres:"subject_type" sqlite:"subjectType"`
SubjectId string `mysql:"subjectId" postgres:"subject_id" sqlite:"subjectId"`
SubjectRelation string `mysql:"subjectRelation" postgres:"subject_relation" sqlite:"subjectRelation"`
Policy Policy `mysql:"policy" postgres:"policy" sqlite:"policy"`
PolicyHash string `mysql:"policyHash" postgres:"policy_hash" sqlite:"policyHash"`
CreatedAt time.Time `mysql:"createdAt" postgres:"created_at" sqlite:"createdAt"`
UpdatedAt time.Time `mysql:"updatedAt" postgres:"updated_at" sqlite:"updatedAt"`
DeletedAt *time.Time `mysql:"deletedAt" postgres:"deleted_at" sqlite:"deletedAt"`
Policy Policy `mysql:"policy" postgres:"policy" sqlite:"policy"`
PolicyHash string `mysql:"policyHash" postgres:"policy_hash" sqlite:"policyHash"`
CreatedAt time.Time `mysql:"createdAt" postgres:"created_at" sqlite:"createdAt"`
UpdatedAt time.Time `mysql:"updatedAt" postgres:"updated_at" sqlite:"updatedAt"`
DeletedAt *time.Time `mysql:"deletedAt" postgres:"deleted_at" sqlite:"deletedAt"`
}

func (warrant Warrant) GetID() int64 {
Expand Down
5 changes: 0 additions & 5 deletions pkg/authz/warrant/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,11 +229,6 @@ func (repo MySQLRepository) List(ctx context.Context, filterParams FilterParams,
replacements = append(replacements, filterParams.SubjectRelation)
}

if filterParams.Policy != "" {
query = fmt.Sprintf("%s AND policyHash = ?", query)
replacements = append(replacements, filterParams.Policy.Hash())
}

if listParams.NextCursor != nil {
comparisonOp := "<"
if listParams.SortOrder == service.SortOrderAsc {
Expand Down
5 changes: 0 additions & 5 deletions pkg/authz/warrant/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,11 +234,6 @@ func (repo PostgresRepository) List(ctx context.Context, filterParams FilterPara
replacements = append(replacements, filterParams.SubjectRelation)
}

if filterParams.Policy != "" {
query = fmt.Sprintf("%s AND policy_hash = ?", query)
replacements = append(replacements, filterParams.Policy.Hash())
}

if listParams.NextCursor != nil {
comparisonOp := "<"
if listParams.SortOrder == service.SortOrderAsc {
Expand Down
5 changes: 0 additions & 5 deletions pkg/authz/warrant/sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,11 +238,6 @@ func (repo SQLiteRepository) List(ctx context.Context, filterParams FilterParams
replacements = append(replacements, filterParams.SubjectRelation)
}

if filterParams.Policy != "" {
query = fmt.Sprintf("%s AND policyHash = ?", query)
replacements = append(replacements, filterParams.Policy.Hash())
}

if listParams.NextCursor != nil {
comparisonOp := "<"
if listParams.SortOrder == service.SortOrderAsc {
Expand Down
38 changes: 28 additions & 10 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,17 @@ type Config interface {
GetLogLevel() int8
GetEnableAccessLog() bool
GetAutoMigrate() bool
GetDatastore() *DatastoreConfig
GetDatastore() DatastoreConfig
}

type WarrantConfig struct {
Port int `mapstructure:"port"`
LogLevel int8 `mapstructure:"logLevel"`
EnableAccessLog bool `mapstructure:"enableAccessLog"`
AutoMigrate bool `mapstructure:"autoMigrate"`
Datastore *DatastoreConfig `mapstructure:"datastore"`
Authentication *AuthConfig `mapstructure:"authentication"`
Check *CheckConfig `mapstructure:"check"`
Port int `mapstructure:"port"`
LogLevel int8 `mapstructure:"logLevel"`
EnableAccessLog bool `mapstructure:"enableAccessLog"`
AutoMigrate bool `mapstructure:"autoMigrate"`
Datastore *WarrantDatastoreConfig `mapstructure:"datastore"`
Authentication *AuthConfig `mapstructure:"authentication"`
Check *CheckConfig `mapstructure:"check"`
}

func (warrantConfig WarrantConfig) GetPort() int {
Expand All @@ -70,7 +70,7 @@ func (warrantConfig WarrantConfig) GetAutoMigrate() bool {
return warrantConfig.AutoMigrate
}

func (warrantConfig WarrantConfig) GetDatastore() *DatastoreConfig {
func (warrantConfig WarrantConfig) GetDatastore() DatastoreConfig {
return warrantConfig.Datastore
}

Expand All @@ -82,12 +82,30 @@ func (warrantConfig WarrantConfig) GetCheck() *CheckConfig {
return warrantConfig.Check
}

type DatastoreConfig struct {
type DatastoreConfig interface {
GetMySQL() *MySQLConfig
GetPostgres() *PostgresConfig
GetSQLite() *SQLiteConfig
}

type WarrantDatastoreConfig struct {
MySQL *MySQLConfig `mapstructure:"mysql"`
Postgres *PostgresConfig `mapstructure:"postgres"`
SQLite *SQLiteConfig `mapstructure:"sqlite"`
}

func (warrantDatastoreConfig WarrantDatastoreConfig) GetMySQL() *MySQLConfig {
return warrantDatastoreConfig.MySQL
}

func (warrantDatastoreConfig WarrantDatastoreConfig) GetPostgres() *PostgresConfig {
return warrantDatastoreConfig.Postgres
}

func (warrantDatastoreConfig WarrantDatastoreConfig) GetSQLite() *SQLiteConfig {
return warrantDatastoreConfig.SQLite
}

type MySQLConfig struct {
Username string `mapstructure:"username"`
Password string `mapstructure:"password"`
Expand Down
32 changes: 16 additions & 16 deletions pkg/database/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ func (q SqlTx) ExecContext(ctx context.Context, query string, args ...interface{
query = q.Tx.Rebind(query)
result, err := q.Tx.ExecContext(ctx, query, args...)
if err != nil {
switch err {
case sql.ErrNoRows:
switch {
case errors.Is(err, sql.ErrNoRows):
return result, err
default:
return result, errors.Wrap(err, "SqlTx error")
Expand All @@ -78,8 +78,8 @@ func (q SqlTx) GetContext(ctx context.Context, dest interface{}, query string, a
query = q.Tx.Rebind(query)
err := q.Tx.GetContext(ctx, dest, query, args...)
if err != nil {
switch err {
case sql.ErrNoRows:
switch {
case errors.Is(err, sql.ErrNoRows):
return err
default:
return errors.Wrap(err, "SqlTx error")
Expand All @@ -92,8 +92,8 @@ func (q SqlTx) NamedExecContext(ctx context.Context, query string, arg interface
query = q.Tx.Rebind(query)
result, err := q.Tx.NamedExecContext(ctx, query, arg)
if err != nil {
switch err {
case sql.ErrNoRows:
switch {
case errors.Is(err, sql.ErrNoRows):
return result, err
default:
return result, errors.Wrap(err, "SqlTx error")
Expand All @@ -119,8 +119,8 @@ func (q SqlTx) SelectContext(ctx context.Context, dest interface{}, query string
query = q.Tx.Rebind(query)
err := q.Tx.SelectContext(ctx, dest, query, args...)
if err != nil {
switch err {
case sql.ErrNoRows:
switch {
case errors.Is(err, sql.ErrNoRows):
return err
default:
return errors.Wrap(err, "SqlTx error")
Expand Down Expand Up @@ -203,8 +203,8 @@ func (ds SQL) ExecContext(ctx context.Context, query string, args ...interface{}

result, err := queryable.ExecContext(ctx, query, args...)
if err != nil {
switch err {
case sql.ErrNoRows:
switch {
case errors.Is(err, sql.ErrNoRows):
return result, err
default:
return result, errors.Wrap(err, "Error when calling sql ExecContext")
Expand All @@ -222,8 +222,8 @@ func (ds SQL) GetContext(ctx context.Context, dest interface{}, query string, ar

err := queryable.GetContext(ctx, dest, query, args...)
if err != nil {
switch err {
case sql.ErrNoRows:
switch {
case errors.Is(err, sql.ErrNoRows):
return err
default:
return errors.Wrap(err, "Error when calling sql GetContext")
Expand All @@ -241,8 +241,8 @@ func (ds SQL) NamedExecContext(ctx context.Context, query string, arg interface{

result, err := queryable.NamedExecContext(ctx, query, arg)
if err != nil {
switch err {
case sql.ErrNoRows:
switch {
case errors.Is(err, sql.ErrNoRows):
return result, err
default:
return result, errors.Wrap(err, "Error when calling sql NamedExecContext")
Expand Down Expand Up @@ -278,8 +278,8 @@ func (ds SQL) SelectContext(ctx context.Context, dest interface{}, query string,

err := queryable.SelectContext(ctx, dest, query, args...)
if err != nil {
switch err {
case sql.ErrNoRows:
switch {
case errors.Is(err, sql.ErrNoRows):
return err
default:
return errors.Wrap(err, "Error when calling sql SelectContext")
Expand Down
22 changes: 0 additions & 22 deletions tests/v1/warrants-list.json
Original file line number Diff line number Diff line change
Expand Up @@ -295,28 +295,6 @@
]
}
},
{
"name": "listWarrantsFilterByPolicy",
"request": {
"method": "GET",
"url": "/v1/warrants?policy=tenant%20%3D%3D%20%22tenant-a%22%20%26%26%20organization%20%3D%3D%20%22org-a%22"
},
"expectedResponse": {
"statusCode": 200,
"body": [
{
"objectType": "role",
"objectId": "senior-accountant",
"relation": "member",
"subject": {
"objectType": "user",
"objectId": "user-a"
},
"policy": "tenant == \"tenant-a\" \u0026\u0026 organization == \"org-a\""
}
]
}
},
{
"name": "removeRoleSeniorAccountantFromUserAWithPolicy",
"request": {
Expand Down
Loading

0 comments on commit 87cce88

Please sign in to comment.