From eb9dcc4e8b414f0807ea22a4438cfe9777abcf9f Mon Sep 17 00:00:00 2001 From: bingtianyiyan <48984656+bingtianyiyan@users.noreply.github.com> Date: Wed, 13 Sep 2023 14:02:01 +0800 Subject: [PATCH] fix: fix SelectCount error #68 (#69) --- gplus/dao.go | 8 +++ gplus/query.go | 124 ++++++++++++++++++++++++++++++++++++++++++++++ tests/dao_test.go | 12 +++++ 3 files changed, 144 insertions(+) diff --git a/gplus/dao.go b/gplus/dao.go index 986459a..24e4293 100644 --- a/gplus/dao.go +++ b/gplus/dao.go @@ -197,6 +197,8 @@ func SelectPage[T any](page *Page[T], q *QueryCond[T], opts ...OptionFunc) (*Pag func SelectCount[T any](q *QueryCond[T], opts ...OptionFunc) (int64, *gorm.DB) { var count int64 resultDb := buildCondition(q, opts...) + //fix 查询有设置Select并且数量只有一个且有设置别名,生成sql不对问题 + resultDb.Statement.Selects = nil resultDb.Count(&count) return count, resultDb } @@ -249,6 +251,12 @@ func Begin(opts ...*sql.TxOptions) *gorm.DB { return db.Begin(opts...) } +// 事务 +func Tx(txFunc func(tx *gorm.DB) error, opts ...OptionFunc) error { + db := getDb(opts...) + return db.Transaction(txFunc) +} + func paginate[T any](p *Page[T]) func(db *gorm.DB) *gorm.DB { page := p.Current pageSize := p.Size diff --git a/gplus/query.go b/gplus/query.go index 93a7013..b26a743 100644 --- a/gplus/query.go +++ b/gplus/query.go @@ -318,6 +318,130 @@ func (q *QueryCond[T]) Set(column any, val any) *QueryCond[T] { return q } +/* +* 自定义条件 + */ + +// AndCond 拼接 AND +func (q *QueryCond[T]) AndCond(cond bool, fn ...func(q *QueryCond[T])) *QueryCond[T] { + if cond { + return q.And(fn...) + } + return q +} + +// OrCond 拼接 OR +func (q *QueryCond[T]) OrCond(cond bool, fn ...func(q *QueryCond[T])) *QueryCond[T] { + if cond { + return q.Or(fn...) + } + return q +} + +// EqCond 等于 = +func (q *QueryCond[T]) EqCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.Eq(column, val) + } + return q +} + +// NeCond 不等于 != +func (q *QueryCond[T]) NeCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.Ne(column, val) + } + return q +} + +// GtCond 大于 > +func (q *QueryCond[T]) GtCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.Gt(column, val) + } + return q +} + +// GeCond 大于等于 >= +func (q *QueryCond[T]) GeCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.Ge(column, val) + } + return q +} + +// LtCond 小于 < +func (q *QueryCond[T]) LtCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.Lt(column, val) + } + return q +} + +// LeCond 小于等于 <= +func (q *QueryCond[T]) LeCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.Le(column, val) + } + return q +} + +// LikeCond 模糊 LIKE '%值%' +func (q *QueryCond[T]) LikeCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.Like(column, val) + } + return q +} + +// NotLikeCond 非模糊 NOT LIKE '%值%' +func (q *QueryCond[T]) NotLikeCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.NotLike(column, val) + } + return q +} + +// LikeLeftCond 左模糊 LIKE '%值' +func (q *QueryCond[T]) LikeLeftCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.LikeLeft(column, val) + } + return q +} + +// NotLikeLeftCond 非左模糊 NOT LIKE '%值' +func (q *QueryCond[T]) NotLikeLeftCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.NotLike(column, val) + } + return q +} + +// LikeRightCond 右模糊 LIKE '值%' +func (q *QueryCond[T]) LikeRightCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.LikeRight(column, val) + } + return q +} + +// NotLikeRightCond 非右模糊 NOT LIKE '值%' +func (q *QueryCond[T]) NotLikeRightCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.NotLikeRight(column, val) + } + return q +} + +// InCond 字段 IN (值1, 值2, ...) +func (q *QueryCond[T]) InCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.In(column, val) + } + return q +} + func (q *QueryCond[T]) addExpression(sqlSegments ...SqlSegment) { if len(sqlSegments) == 1 { q.handleSingle(sqlSegments[0]) diff --git a/tests/dao_test.go b/tests/dao_test.go index fd71969..d077f92 100644 --- a/tests/dao_test.go +++ b/tests/dao_test.go @@ -557,6 +557,18 @@ func TestSelectGeneric7(t *testing.T) { } } +func TestTx(t *testing.T) { + deleteOldData() + users := getUsers() + err := gplus.Tx(func(tx *gorm.DB) error { + err := gplus.InsertBatch[User](users, gplus.Db(tx)).Error + return err + }) + if err != nil { + t.Errorf(err.Error()) + } +} + func deleteOldData() { q, u := gplus.NewQuery[User]() q.IsNotNull(&u.ID)