Skip to content

Commit

Permalink
overwrite default connection for model
Browse files Browse the repository at this point in the history
  • Loading branch information
kkumar-gcc committed Jul 29, 2023
1 parent 5c6fdac commit aedcab4
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 17 deletions.
11 changes: 5 additions & 6 deletions contracts/database/orm/mocks/Association.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 16 additions & 0 deletions contracts/database/orm/mocks/Orm.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

76 changes: 65 additions & 11 deletions contracts/database/orm/mocks/Transaction.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions contracts/database/orm/orm.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ type Orm interface {
DB() (*sql.DB, error)
Query() Query
Factory() Factory
Model(value any) Orm
Observe(model any, observer Observer)
Transaction(txFunc func(tx Transaction) error) error
WithContext(ctx context.Context) Orm
Expand Down Expand Up @@ -84,6 +85,10 @@ type Association interface {
Count() int64
}

type Model interface {
Connection() string
}

//go:generate mockery --name=Cursor
type Cursor interface {
Scan(value any) error
Expand Down
16 changes: 16 additions & 0 deletions database/orm.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,22 @@ func (r *OrmImpl) Factory() ormcontract.Factory {
return NewFactoryImpl(r.query)
}

func (r *OrmImpl) Model(value any) ormcontract.Orm {
model, ok := value.(ormcontract.Model)
if !ok {
return r
}

// Check if the model has a connection specified
if conn := model.Connection(); conn != "" {
r.query = r.query.Model(value)
return r.Connection(conn)
}

// If the model doesn't have a connection specified, return the current OrmImpl instance
return r
}

func (r *OrmImpl) Observe(model any, observer ormcontract.Observer) {
orm.Observers = append(orm.Observers, orm.Observer{
Model: model,
Expand Down
8 changes: 8 additions & 0 deletions database/orm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ type User struct {
Avatar string
}

func (u *User) Connection() string {
return "postgresql"
}

type OrmSuite struct {
suite.Suite
orm *OrmImpl
Expand Down Expand Up @@ -109,6 +113,10 @@ func (s *OrmSuite) TestConnection() {
}
}

func (s *OrmSuite) TestModel() {
s.NotNil(s.orm.Model(&User{}))
}

func (s *OrmSuite) TestDB() {
db, err := s.orm.DB()
s.NotNil(db)
Expand Down

0 comments on commit aedcab4

Please sign in to comment.