Skip to content

Commit

Permalink
fix: fix SelectCount error #68 (#69)
Browse files Browse the repository at this point in the history
  • Loading branch information
bingtianyiyan authored Sep 13, 2023
1 parent bb6423a commit eb9dcc4
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 0 deletions.
8 changes: 8 additions & 0 deletions gplus/dao.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,8 @@ func SelectPage[T any](page *Page[T], q *QueryCond[T], opts ...OptionFunc) (*Pag
func SelectCount[T any](q *QueryCond[T], opts ...OptionFunc) (int64, *gorm.DB) {
var count int64
resultDb := buildCondition(q, opts...)
//fix 查询有设置Select并且数量只有一个且有设置别名,生成sql不对问题
resultDb.Statement.Selects = nil
resultDb.Count(&count)
return count, resultDb
}
Expand Down Expand Up @@ -249,6 +251,12 @@ func Begin(opts ...*sql.TxOptions) *gorm.DB {
return db.Begin(opts...)
}

// 事务
func Tx(txFunc func(tx *gorm.DB) error, opts ...OptionFunc) error {
db := getDb(opts...)
return db.Transaction(txFunc)
}

func paginate[T any](p *Page[T]) func(db *gorm.DB) *gorm.DB {
page := p.Current
pageSize := p.Size
Expand Down
124 changes: 124 additions & 0 deletions gplus/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,130 @@ func (q *QueryCond[T]) Set(column any, val any) *QueryCond[T] {
return q
}

/*
* 自定义条件
*/

// AndCond 拼接 AND
func (q *QueryCond[T]) AndCond(cond bool, fn ...func(q *QueryCond[T])) *QueryCond[T] {
if cond {
return q.And(fn...)
}
return q
}

// OrCond 拼接 OR
func (q *QueryCond[T]) OrCond(cond bool, fn ...func(q *QueryCond[T])) *QueryCond[T] {
if cond {
return q.Or(fn...)
}
return q
}

// EqCond 等于 =
func (q *QueryCond[T]) EqCond(cond bool, column any, val any) *QueryCond[T] {
if cond {
return q.Eq(column, val)
}
return q
}

// NeCond 不等于 !=
func (q *QueryCond[T]) NeCond(cond bool, column any, val any) *QueryCond[T] {
if cond {
return q.Ne(column, val)
}
return q
}

// GtCond 大于 >
func (q *QueryCond[T]) GtCond(cond bool, column any, val any) *QueryCond[T] {
if cond {
return q.Gt(column, val)
}
return q
}

// GeCond 大于等于 >=
func (q *QueryCond[T]) GeCond(cond bool, column any, val any) *QueryCond[T] {
if cond {
return q.Ge(column, val)
}
return q
}

// LtCond 小于 <
func (q *QueryCond[T]) LtCond(cond bool, column any, val any) *QueryCond[T] {
if cond {
return q.Lt(column, val)
}
return q
}

// LeCond 小于等于 <=
func (q *QueryCond[T]) LeCond(cond bool, column any, val any) *QueryCond[T] {
if cond {
return q.Le(column, val)
}
return q
}

// LikeCond 模糊 LIKE '%值%'
func (q *QueryCond[T]) LikeCond(cond bool, column any, val any) *QueryCond[T] {
if cond {
return q.Like(column, val)
}
return q
}

// NotLikeCond 非模糊 NOT LIKE '%值%'
func (q *QueryCond[T]) NotLikeCond(cond bool, column any, val any) *QueryCond[T] {
if cond {
return q.NotLike(column, val)
}
return q
}

// LikeLeftCond 左模糊 LIKE '%值'
func (q *QueryCond[T]) LikeLeftCond(cond bool, column any, val any) *QueryCond[T] {
if cond {
return q.LikeLeft(column, val)
}
return q
}

// NotLikeLeftCond 非左模糊 NOT LIKE '%值'
func (q *QueryCond[T]) NotLikeLeftCond(cond bool, column any, val any) *QueryCond[T] {
if cond {
return q.NotLike(column, val)
}
return q
}

// LikeRightCond 右模糊 LIKE '值%'
func (q *QueryCond[T]) LikeRightCond(cond bool, column any, val any) *QueryCond[T] {
if cond {
return q.LikeRight(column, val)
}
return q
}

// NotLikeRightCond 非右模糊 NOT LIKE '值%'
func (q *QueryCond[T]) NotLikeRightCond(cond bool, column any, val any) *QueryCond[T] {
if cond {
return q.NotLikeRight(column, val)
}
return q
}

// InCond 字段 IN (值1, 值2, ...)
func (q *QueryCond[T]) InCond(cond bool, column any, val any) *QueryCond[T] {
if cond {
return q.In(column, val)
}
return q
}

func (q *QueryCond[T]) addExpression(sqlSegments ...SqlSegment) {
if len(sqlSegments) == 1 {
q.handleSingle(sqlSegments[0])
Expand Down
12 changes: 12 additions & 0 deletions tests/dao_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,18 @@ func TestSelectGeneric7(t *testing.T) {
}
}

func TestTx(t *testing.T) {
deleteOldData()
users := getUsers()
err := gplus.Tx(func(tx *gorm.DB) error {
err := gplus.InsertBatch[User](users, gplus.Db(tx)).Error
return err
})
if err != nil {
t.Errorf(err.Error())
}
}

func deleteOldData() {
q, u := gplus.NewQuery[User]()
q.IsNotNull(&u.ID)
Expand Down

0 comments on commit eb9dcc4

Please sign in to comment.