From bf9d7615196f4eddd6a663ade61bfe496eb589f3 Mon Sep 17 00:00:00 2001 From: aixj1984 Date: Mon, 19 Jun 2023 17:59:14 +0800 Subject: [PATCH 01/12] add new feature --- gplus/dao.go | 84 ++++++++++++++++++++++++++-- gplus/query.go | 25 ++++++++- tests/dao_test.go | 128 +++++++++++++++++++++++++++++++++++++++++-- tests/select_test.go | 4 +- tests/update_test.go | 4 +- tests/user.go | 24 ++++---- 6 files changed, 245 insertions(+), 24 deletions(-) diff --git a/gplus/dao.go b/gplus/dao.go index 48bbf4e..4333cc0 100644 --- a/gplus/dao.go +++ b/gplus/dao.go @@ -14,16 +14,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package gplus import ( "database/sql" + "reflect" + "strings" + "time" + "github.com/acmestack/gorm-plus/constants" "gorm.io/gorm" "gorm.io/gorm/schema" "gorm.io/gorm/utils" - "reflect" - "strings" ) var globalDb *gorm.DB @@ -38,13 +41,15 @@ type Page[T any] struct { Size int Total int64 Records []*T + CurTime int64 RecordsMap []T } type Dao[T any] struct{} func (dao Dao[T]) NewQuery() (*QueryCond[T], *T) { - return NewQuery[T]() + q := &QueryCond[T]{} + return q, nil } func NewPage[T any](current, size int) *Page[T] { @@ -173,6 +178,70 @@ func SelectList[T any](q *QueryCond[T], opts ...OptionFunc) ([]*T, *gorm.DB) { return results, resultDb } +// add start + +// SelectByIdGeneric 查询时,转化为其他类型 +// 第一个泛型代表数据库表实体 +// 第二个泛型代表返回记录实体 +func SelectByIdGeneric[T any, R any](id any, opts ...OptionFunc) (*R, *gorm.DB) { + q, _ := NewQuery[T]() + q.Eq(getPkColumnName[T](), id) + var entity R + resultDb := buildCondition(q, opts...) + return &entity, resultDb.First(&entity) +} + +// Pluck 取某列值,不去重 +func Pluck[T any, R any](column string, q *QueryCond[T], opts ...OptionFunc) ([]R, *gorm.DB) { + var results []R + resultDb := buildCondition(q, opts...) + resultDb.Pluck(column, &results) + return results, resultDb +} + +// PluckDistinct 取某列值,去重 +func PluckDistinct[T any, R any](column string, q *QueryCond[T], opts ...OptionFunc) ([]R, *gorm.DB) { + var results []R + resultDb := buildCondition(q, opts...) + resultDb.Distinct(column).Pluck(column, &results) + return results, resultDb +} + +// SelectListBySql 按任意SQL执行,指定返回类型数组 +func SelectListBySql[R any](querySql string, opts ...OptionFunc) ([]*R, *gorm.DB) { + resultDb := getDb(opts...) + var results []*R + resultDb = resultDb.Raw(querySql).Scan(&results) + return results, resultDb +} + +// SelectOneBySql 根据原始的SQL语句,取一个 +func SelectOneBySql[R any](countSql string, opts ...OptionFunc) (R, *gorm.DB) { + resultDb := getDb(opts...) + var result R + resultDb = resultDb.Raw(countSql).Scan(&result) + return result, resultDb +} + +// ExcSql 按任意SQL执行,返回影响的行 +func ExcSql(querySql string, opts ...OptionFunc) *gorm.DB { + resultDb := getDb(opts...) + resultDb = resultDb.Exec(querySql) + return resultDb +} + +// add end + +// SelectListGeneric 根据条件查询多条记录 +// 第一个泛型代表数据库表实体 +// 第二个泛型代表返回记录实体 +func SelectListGeneric[T any, R any](q *QueryCond[T], opts ...OptionFunc) ([]*R, *gorm.DB) { + resultDb := buildCondition(q, opts...) + var results []*R + resultDb.Scan(&results) + return results, resultDb +} + // SelectPage 根据条件分页查询记录 func SelectPage[T any](page *Page[T], q *QueryCond[T], opts ...OptionFunc) (*Page[T], *gorm.DB) { option := getOption(opts) @@ -190,6 +259,7 @@ func SelectPage[T any](page *Page[T], q *QueryCond[T], opts ...OptionFunc) (*Pag var results []*T resultDb.Scopes(paginate(page)).Find(&results) page.Records = results + page.CurTime = time.Now().UnixMilli() return page, resultDb } @@ -202,9 +272,12 @@ func SelectCount[T any](q *QueryCond[T], opts ...OptionFunc) (int64, *gorm.DB) { } // Exists 根据条件判断记录是否存在 -func Exists[T any](q *QueryCond[T], opts ...OptionFunc) (bool, *gorm.DB) { +func Exists[T any](q *QueryCond[T], opts ...OptionFunc) (bool, error) { count, resultDb := SelectCount[T](q, opts...) - return count > 0, resultDb + if resultDb.Error == gorm.ErrRecordNotFound { + return false, nil + } + return count > 0, resultDb.Error } // SelectPageGeneric 根据传入的泛型封装分页记录 @@ -232,6 +305,7 @@ func SelectPageGeneric[T any, R any](page *Page[R], q *QueryCond[T], opts ...Opt resultDb.Scopes(paginate(page)).Scan(&results) page.Records = results } + page.CurTime = time.Now().UnixMilli() return page, resultDb } diff --git a/gplus/query.go b/gplus/query.go index d328d7d..6b95d0b 100644 --- a/gplus/query.go +++ b/gplus/query.go @@ -235,7 +235,7 @@ func (q *QueryCond[T]) OrderByAsc(columns ...any) *QueryCond[T] { return q } -// Having HAVING SQl语句 +// Having SQl语句 func (q *QueryCond[T]) Having(having string, args ...any) *QueryCond[T] { q.havingBuilder.WriteString(having) if len(args) == 1 { @@ -388,3 +388,26 @@ func (q *QueryCond[T]) buildOrder(orderType string, columns ...string) { q.orderBuilder.WriteString(orderType) } } + +// 根据条件,执行方法 +func (q *QueryCond[T]) Case(isTrue bool, handleFunc func()) *QueryCond[T] { + if isTrue { + handleFunc() + } + return q +} + +// 重置查询条件 +func (q *QueryCond[T]) Reset() *QueryCond[T] { + q.selectColumns = q.selectColumns[:0] + q.distinctColumns = q.distinctColumns[:0] + q.queryExpressions = q.queryExpressions[:0] + q.orderBuilder.Reset() + q.groupBuilder.Reset() + q.havingBuilder.Reset() + q.havingArgs = q.havingArgs[:0] + q.queryArgs = q.queryArgs[:0] + q.updateMap = nil + + return q +} diff --git a/tests/dao_test.go b/tests/dao_test.go index 536fc02..2d22fb2 100644 --- a/tests/dao_test.go +++ b/tests/dao_test.go @@ -183,7 +183,7 @@ func TestUpdateZeroById(t *testing.T) { users := getUsers() gplus.InsertBatch[User](users) - updateUser := &User{ID: users[0].ID, Score: 100, Age: 25} + updateUser := &User{Base: Base{ID: users[0].ID}, Score: 100, Age: 25} if res := gplus.UpdateZeroById[User](updateUser); res.Error != nil || res.RowsAffected != 1 { t.Errorf("errors happened when deleteByIds: %v, affected: %v", res.Error, res.RowsAffected) @@ -376,11 +376,11 @@ func TestExists(t *testing.T) { query, model := gplus.NewQuery[User]() query.Eq(&model.Username, users[0].Username) exists, db := gplus.Exists[User](query) - if db.Error != nil { - t.Errorf("errors happened when SelectCount : %v", db.Error) + if db != nil { + t.Errorf("errors happened when SelectCount : %v", db.Error()) } if !exists { - t.Errorf("errors happened when SelectCount : %v", db.Error) + t.Errorf("errors happened when SelectCount : %v", db.Error()) } } @@ -557,6 +557,126 @@ func TestSelectGeneric7(t *testing.T) { } } +func TestCase(t *testing.T) { + deleteOldData() + users := getUsers() + gplus.InsertBatch[User](users) + + query, model := gplus.NewQuery[User]() + query.Case(true, func() { + query.Eq(&model.Username, "afumu1") + }) + count, db := gplus.SelectCount(query) + if db.Error != nil { + t.Errorf("errors happened when SelectCount : %v", db.Error) + } + if count != 1 { + t.Errorf("count expects: %v, got %v", 1, count) + } +} + +func TestPluck(t *testing.T) { + deleteOldData() + users := getUsers() + gplus.InsertBatch[User](users) + + query, _ := gplus.NewQuery[User]() + + usernames, db := gplus.Pluck[User, string]("username", query) + if db.Error != nil { + t.Errorf("errors happened when SelectCount : %v", db.Error.Error()) + } + if len(usernames) == 0 { + t.Errorf("count expects: %v, got %v", len(usernames), 0) + } else { + for _, item := range usernames { + fmt.Printf("pluck list %s\n", item) + } + + } +} + +func TestPluckDistinct(t *testing.T) { + deleteOldData() + users := getUsers() + gplus.InsertBatch[User](users) + + passwords, db := gplus.PluckDistinct[User, string]("password", nil) + if db.Error != nil { + t.Errorf("errors happened when SelectCount : %v", db.Error.Error()) + } + if len(passwords) != 1 { + t.Errorf("count expects: %v, got %v", 1, len(passwords)) + } else { + for _, item := range passwords { + fmt.Printf("pluck list %s\n", item) + } + + } +} + +func TestReset(t *testing.T) { + deleteOldData() + users := getUsers() + gplus.InsertBatch[User](users) + + query, model := gplus.NewQuery[User]() + query.Eq(&model.Username, "afumu1").Or().Eq(&model.Username, "afumu2") + count, db := gplus.SelectCount(query) + if db.Error != nil { + t.Errorf("errors happened when SelectCount : %v", db.Error) + } + if count != 2 { + t.Errorf("count expects: %v, got %v", 2, count) + } + + query.Reset().Eq(&model.Username, "afumu3") + count, db = gplus.SelectCount(query) + if db.Error != nil { + t.Errorf("errors happened when SelectCount : %v", db.Error) + } + if count != 1 { + t.Errorf("count expects: %v, got %v", 1, count) + } + +} + +func TestBySql(t *testing.T) { + deleteOldData() + users := getUsers() + + gplus.InsertBatch[User](users) + + type UserPlus struct { + User + Num int + } + + records, db := gplus.SelectListBySql[UserPlus]("select * , 1 as num from Users") + if db.Error != nil { + t.Errorf("errors happened when SelectCount : %v", db.Error) + } + + if len(records) > 0 { + for _, item := range records { + if item.Num != 1 { + t.Errorf("count expects: %v, got %v", 1, item.Num) + } + } + } else { + t.Errorf("count expects: %v, got %v", len(records), 0) + } + + db = gplus.ExcSql("delete from Users") + + if db.Error != nil { + t.Errorf("errors happened when SelectCount : %v", db.Error) + } + if db.RowsAffected == 0 { + t.Errorf("count expects: %v, got %v", db.RowsAffected, 0) + } +} + func deleteOldData() { q, u := gplus.NewQuery[User]() q.IsNotNull(&u.ID) diff --git a/tests/select_test.go b/tests/select_test.go index acdcf91..31094e4 100644 --- a/tests/select_test.go +++ b/tests/select_test.go @@ -38,7 +38,7 @@ func TestSelectById2Name(t *testing.T) { } func TestSelectById3Name(t *testing.T) { - var expectSql = "SELECT `Users`.`id`,`Users`.`password`,`Users`.`address`,`Users`.`phone`,`Users`.`score`,`Users`.`dept`,`Users`.`created_at`,`Users`.`updated_at` FROM `Users` WHERE id = '1' ORDER BY `Users`.`id` LIMIT 1" + var expectSql = "SELECT `Users`.`id`,`Users`.`created_at`,`Users`.`updated_at`,`Users`.`password`,`Users`.`address`,`Users`.`phone`,`Users`.`score`,`Users`.`dept` FROM `Users` WHERE id = '1' ORDER BY `Users`.`id` LIMIT 1" sessionDb := checkSelectSql(t, expectSql) u := gplus.GetModel[User]() gplus.SelectById[User](1, gplus.Db(sessionDb), gplus.Omit(&u.Username, &u.Age)) @@ -163,7 +163,7 @@ func TestSelectList13Name(t *testing.T) { } func TestSelectList14Name(t *testing.T) { - var expectSql = "SELECT * FROM `Users` WHERE age BETWEEN '18' and '20'" + var expectSql = "SELECT * FROM `Users` WHERE age BETWEEN '18' AND '20'" sessionDb := checkSelectSql(t, expectSql) query, u := gplus.NewQuery[User]() query.Between(&u.Age, 18, 20) diff --git a/tests/update_test.go b/tests/update_test.go index afa88dd..6a9c389 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -27,7 +27,7 @@ import ( func TestUpdateByIdName(t *testing.T) { var expectSql = "UPDATE `Users` SET `score`='100' WHERE `id` = '1'" sessionDb := checkUpdateSql(t, expectSql) - var user = &User{ID: 1, Score: 100} + var user = &User{Base: Base{ID: 1}, Score: 100} u := gplus.GetModel[User]() gplus.UpdateById(user, gplus.Db(sessionDb), gplus.Omit(&u.CreatedAt, &u.UpdatedAt)) } @@ -35,7 +35,7 @@ func TestUpdateByIdName(t *testing.T) { func TestUpdateZeroByIdName(t *testing.T) { var expectSql = "UPDATE `Users` SET `username`='',`password`='',`address`='',`age`='0',`phone`='',`score`='100',`dept`='' WHERE `id` = '1'" sessionDb := checkUpdateSql(t, expectSql) - var user = &User{ID: 1, Score: 100} + var user = &User{Base: Base{ID: 1}, Score: 100} u := gplus.GetModel[User]() gplus.UpdateZeroById(user, gplus.Db(sessionDb), gplus.Omit(&u.CreatedAt, &u.UpdatedAt)) } diff --git a/tests/user.go b/tests/user.go index 3019bbf..aa3560b 100644 --- a/tests/user.go +++ b/tests/user.go @@ -21,17 +21,21 @@ import ( "time" ) -type User struct { +type Base struct { ID int64 - Username string - Password string - Address string - Age int - Phone string - Score int - Dept string - CreatedAt time.Time - UpdatedAt time.Time + CreatedAt time.Time `json:"createdAt"` + UpdatedAt time.Time `json:"updatedAt"` +} + +type User struct { + Base + Username string `gorm:"column:username"` + Password string + Address string + Age int + Phone string + Score int + Dept string } func (User) TableName() string { From 2b3f2afc04c2271eefe90909b42ab60a0ce2b50a Mon Sep 17 00:00:00 2001 From: aixj1984 Date: Mon, 19 Jun 2023 19:19:33 +0800 Subject: [PATCH 02/12] fix --- gplus/dao.go | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/gplus/dao.go b/gplus/dao.go index 4333cc0..cf557a0 100644 --- a/gplus/dao.go +++ b/gplus/dao.go @@ -37,19 +37,18 @@ func Init(db *gorm.DB) { } type Page[T any] struct { - Current int - Size int - Total int64 - Records []*T - CurTime int64 - RecordsMap []T + Current int `json:"page"` // 页码 + Size int `json:"pageSize"` // 每页大小 + Total int64 `json:"total"` + Records []*T `json:"list"` + CurTime int64 `json:"curTime"` // 当前时间,毫秒 + RecordsMap []T `json:"listMap"` } type Dao[T any] struct{} func (dao Dao[T]) NewQuery() (*QueryCond[T], *T) { - q := &QueryCond[T]{} - return q, nil + return NewQuery[T]() } func NewPage[T any](current, size int) *Page[T] { From f3bda2556beca71185f45175a504fc0c2cb4e4ad Mon Sep 17 00:00:00 2001 From: aixj1984 Date: Mon, 19 Jun 2023 19:24:32 +0800 Subject: [PATCH 03/12] fix --- README.md | 6 +++--- go.mod | 2 +- gplus/dao.go | 2 +- gplus/function.go | 2 +- gplus/option.go | 2 +- gplus/query.go | 2 +- tests/dao_test.go | 2 +- tests/delete_test.go | 2 +- tests/insert_test.go | 2 +- tests/select_test.go | 2 +- tests/update_test.go | 2 +- 11 files changed, 13 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 5f97cd1..622e8a8 100644 --- a/README.md +++ b/README.md @@ -53,7 +53,7 @@ VALUES 下载Gorm-Plus ```SQL - go get github.com/acmestack/gorm-plus + go get github.com/aixj1984/gorm-plus ``` @@ -62,7 +62,7 @@ VALUES package main import ( - "github.com/acmestack/gorm-plus/gplus" + "github.com/aixj1984/gorm-plus/gplus" "gorm.io/driver/mysql" "gorm.io/gorm" "gorm.io/gorm/logger" @@ -135,5 +135,5 @@ func main() { 然而,`Gorm-Plus`的强大功能远不止于此。 -更多文档请查看: [https://github.com/acmestack/gorm-plus/wiki](https://github.com/acmestack/gorm-plus/wiki) +更多文档请查看: [https://github.com/aixj1984/gorm-plus/wiki](https://github.com/aixj1984/gorm-plus/wiki) diff --git a/go.mod b/go.mod index 5318a1b..2212a5f 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/acmestack/gorm-plus +module github.com/aixj1984/gorm-plus go 1.18 diff --git a/gplus/dao.go b/gplus/dao.go index cf557a0..2730be5 100644 --- a/gplus/dao.go +++ b/gplus/dao.go @@ -23,7 +23,7 @@ import ( "strings" "time" - "github.com/acmestack/gorm-plus/constants" + "github.com/aixj1984/gorm-plus/constants" "gorm.io/gorm" "gorm.io/gorm/schema" "gorm.io/gorm/utils" diff --git a/gplus/function.go b/gplus/function.go index 2a26f5d..2d4da9a 100644 --- a/gplus/function.go +++ b/gplus/function.go @@ -18,7 +18,7 @@ package gplus import ( - "github.com/acmestack/gorm-plus/constants" + "github.com/aixj1984/gorm-plus/constants" "strings" ) diff --git a/gplus/option.go b/gplus/option.go index 1a89ee7..fbe0efb 100644 --- a/gplus/option.go +++ b/gplus/option.go @@ -56,7 +56,7 @@ func Omit(columns ...any) OptionFunc { } } -// IgnoreTotal 分页查询忽略总数 issue: https://github.com/acmestack/gorm-plus/issues/37 +// IgnoreTotal 分页查询忽略总数 issue: https://github.com/aixj1984/gorm-plus/issues/37 func IgnoreTotal() OptionFunc { return func(o *Option) { o.IgnoreTotal = true diff --git a/gplus/query.go b/gplus/query.go index 6b95d0b..9c4321a 100644 --- a/gplus/query.go +++ b/gplus/query.go @@ -19,7 +19,7 @@ package gplus import ( "fmt" - "github.com/acmestack/gorm-plus/constants" + "github.com/aixj1984/gorm-plus/constants" "reflect" "strings" ) diff --git a/tests/dao_test.go b/tests/dao_test.go index 2d22fb2..8923e63 100644 --- a/tests/dao_test.go +++ b/tests/dao_test.go @@ -20,7 +20,7 @@ package tests import ( "errors" "fmt" - "github.com/acmestack/gorm-plus/gplus" + "github.com/aixj1984/gorm-plus/gplus" "gorm.io/driver/mysql" "gorm.io/gorm" "gorm.io/gorm/logger" diff --git a/tests/delete_test.go b/tests/delete_test.go index cc1318b..f1f50b1 100644 --- a/tests/delete_test.go +++ b/tests/delete_test.go @@ -18,7 +18,7 @@ package tests import ( - "github.com/acmestack/gorm-plus/gplus" + "github.com/aixj1984/gorm-plus/gplus" "gorm.io/gorm" "strings" "testing" diff --git a/tests/insert_test.go b/tests/insert_test.go index 95b6959..b547ca6 100644 --- a/tests/insert_test.go +++ b/tests/insert_test.go @@ -18,7 +18,7 @@ package tests import ( - "github.com/acmestack/gorm-plus/gplus" + "github.com/aixj1984/gorm-plus/gplus" "gorm.io/gorm" "strings" "testing" diff --git a/tests/select_test.go b/tests/select_test.go index 31094e4..d0cee58 100644 --- a/tests/select_test.go +++ b/tests/select_test.go @@ -18,7 +18,7 @@ package tests import ( - "github.com/acmestack/gorm-plus/gplus" + "github.com/aixj1984/gorm-plus/gplus" "gorm.io/gorm" "strings" "testing" diff --git a/tests/update_test.go b/tests/update_test.go index 6a9c389..42ee52d 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -18,7 +18,7 @@ package tests import ( - "github.com/acmestack/gorm-plus/gplus" + "github.com/aixj1984/gorm-plus/gplus" "gorm.io/gorm" "strings" "testing" From b8859f211cfbeaad99495b811ab63e22c9834994 Mon Sep 17 00:00:00 2001 From: aixj1984 Date: Mon, 19 Jun 2023 19:51:48 +0800 Subject: [PATCH 04/12] fix --- gplus/query.go | 12 ++++++++++++ tests/dao_test.go | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/gplus/query.go b/gplus/query.go index 9c4321a..0dae08c 100644 --- a/gplus/query.go +++ b/gplus/query.go @@ -389,6 +389,18 @@ func (q *QueryCond[T]) buildOrder(orderType string, columns ...string) { } } +// 执行执行增加条件 +func (q *QueryCond[T]) AddStrCond(cond string) *QueryCond[T] { + if len(q.queryExpressions) > 0 { + sk := sqlKeyword{keyword: constants.And} + q.queryExpressions = append(q.queryExpressions, &sk) + } + condSk := sqlKeyword{keyword: cond} + q.queryExpressions = append(q.queryExpressions, &condSk) + q.last = &condSk + return q +} + // 根据条件,执行方法 func (q *QueryCond[T]) Case(isTrue bool, handleFunc func()) *QueryCond[T] { if isTrue { diff --git a/tests/dao_test.go b/tests/dao_test.go index 8923e63..b7a9123 100644 --- a/tests/dao_test.go +++ b/tests/dao_test.go @@ -641,6 +641,42 @@ func TestReset(t *testing.T) { } +func TestQueryBuilder(t *testing.T) { + deleteOldData() + users := getUsers() + gplus.InsertBatch[User](users) + + query, _ := gplus.NewQuery[User]() + + query.AddStrCond(fmt.Sprintf(" username = '%s' ", "afumu1")) + + count, db := gplus.SelectCount(query) + if db.Error != nil { + t.Errorf("errors happened when SelectCount : %v", db.Error) + } + if count != 1 { + t.Errorf("count expects: %v, got %v", 1, count) + } +} + +func TestExist(t *testing.T) { + deleteOldData() + users := getUsers() + gplus.InsertBatch[User](users) + + query, _ := gplus.NewQuery[User]() + + query.AddStrCond(fmt.Sprintf(" username = '%s' ", "afumu1")) + + exist, dbErr := gplus.Exists(query) + if dbErr != nil { + t.Errorf("errors happened when SelectCount : %v", dbErr.Error()) + } + if !exist { + t.Errorf("count expects: %v, got %v", true, exist) + } +} + func TestBySql(t *testing.T) { deleteOldData() users := getUsers() From 1e59a7ae0e46cb47beb7ebc997f489ab26059fab Mon Sep 17 00:00:00 2001 From: aixj1984 Date: Tue, 20 Jun 2023 09:55:34 +0800 Subject: [PATCH 05/12] add test case --- tests/update_test.go | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/update_test.go b/tests/update_test.go index 42ee52d..a5357bf 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -113,6 +113,24 @@ func TestUpdate7Name(t *testing.T) { gplus.Update(query, gplus.Db(sessionDb), gplus.Omit(&u.CreatedAt, &u.UpdatedAt)) } +func TestUpdateRest(t *testing.T) { + var expectSql = "UPDATE `Users` SET `address`='shanghai',`score`='100' WHERE username = 'afumu' OR age = '18'" + sessionDb := checkUpdateSql(t, expectSql) + query, u := gplus.NewQuery[User]() + query.Eq(&u.Username, "afumu").Or().Eq(&u.Age, 18). + Set(&u.Score, 100). + Set(&u.Address, "shanghai") + gplus.Update(query, gplus.Db(sessionDb), gplus.Omit(&u.CreatedAt, &u.UpdatedAt)) + + expectSql = "UPDATE `Users` SET `address`='shanghai',`score`='100' WHERE username = 'afumu' AND age = '18'" + sessionDb = checkUpdateSql(t, expectSql) + query.Reset() + query.Eq(&u.Username, "afumu").Eq(&u.Age, 18). + Set(&u.Score, 100). + Set(&u.Address, "shanghai") + gplus.Update(query, gplus.Db(sessionDb), gplus.Omit(&u.CreatedAt, &u.UpdatedAt)) +} + func checkUpdateSql(t *testing.T, expect string) *gorm.DB { expect = strings.TrimSpace(expect) sessionDb := gormDb.Session(&gorm.Session{DryRun: true}) From 186685bab22a8931eb8315f882bba65fb02c630b Mon Sep 17 00:00:00 2001 From: aixj1984 Date: Mon, 24 Jul 2023 14:46:25 +0800 Subject: [PATCH 06/12] sync master --- README.md | 78 ++++++ gplus/dao.go | 8 +- gplus/query.go | 29 +- gplus/{sqlSegment.go => segment.go} | 0 gplus/tool.go | 421 ++++++++++++++++++++++++++++ tests/dao_test.go | 2 +- tests/delete_test.go | 28 +- tests/insert_test.go | 4 +- tests/select_test.go | 100 ++++--- tests/tool_test.go | 283 +++++++++++++++++++ tests/update_test.go | 22 +- tests/user.go | 6 +- tests/utils.go | 11 +- 13 files changed, 914 insertions(+), 78 deletions(-) rename gplus/{sqlSegment.go => segment.go} (100%) create mode 100644 gplus/tool.go create mode 100644 tests/tool_test.go diff --git a/README.md b/README.md index 622e8a8..14895fe 100644 --- a/README.md +++ b/README.md @@ -125,6 +125,84 @@ func main() { ``` +## 搜索工具 + +只需要下面一行代码即可完成单表的所有查询功能 + +```Bash +gplus.SelectList(gplus.BuildQuery[User](queryParams)) +``` + + + +例子: + +```Bash +func main() { + http.HandleFunc("/", handleRequest) + http.ListenAndServe(":8080", nil) +} + +func handleRequest(w http.ResponseWriter, r *http.Request) { + queryParams := r.URL.Query() + list, _ := gplus.SelectList(gplus.BuildQuery[User](queryParams)) + marshal, _ := json.Marshal(list) + w.Write(marshal) +} +``` + +假设我们要查询username为zhangsan的用户 + +```Bash +http://localhost:8080?q=username=zhangsan +``` + + + +假设我们要查询username姓zhang的用户 + +```Bash +http://localhost:8080?q=username~>=zhang +``` + + + +假设我们要查询age大于20的用户 + +```Bash +http://localhost:8080?q=age>20 +``` + + + +假设我们要查询username等于zhagnsan,password等于123456的用户 + +```Bash +http://localhost:8080?q=username=zhangsan&q=password=123456 +``` + + + +假设我们要查询username等于zhagnsan,password等于123456的用户 + +```Bash +http://localhost:8080?q=username=zhangsan&q=password=123456 +``` + + + +假设我们要查询username等于zhagnsan,或者usename等于lisi的用户 + +可以增加一个分组和gcond的条件查询来实现 + +```Bash +http://localhost:8080?q=A.username=zhangsan&q=B.username=lisi&gcond=A|B +``` + + + +所有的单表查询我们都只需要一行代码即可。 + ## 总结 diff --git a/gplus/dao.go b/gplus/dao.go index 2730be5..caa55cb 100644 --- a/gplus/dao.go +++ b/gplus/dao.go @@ -152,7 +152,7 @@ func SelectById[T any](id any, opts ...OptionFunc) (*T, *gorm.DB) { q.Eq(getPkColumnName[T](), id) var entity T resultDb := buildCondition(q, opts...) - return &entity, resultDb.First(&entity) + return &entity, resultDb.Take(&entity) } // SelectByIds 根据 ID 查询多条记录 @@ -166,7 +166,7 @@ func SelectByIds[T any](ids any, opts ...OptionFunc) ([]*T, *gorm.DB) { func SelectOne[T any](q *QueryCond[T], opts ...OptionFunc) (*T, *gorm.DB) { var entity T resultDb := buildCondition(q, opts...) - return &entity, resultDb.First(&entity) + return &entity, resultDb.Take(&entity) } // SelectList 根据条件查询多条记录 @@ -352,6 +352,10 @@ func buildCondition[T any](q *QueryCond[T], opts ...OptionFunc) *gorm.DB { resultDb.Select(q.selectColumns) } + if len(q.omitColumns) > 0 { + resultDb.Omit(q.omitColumns...) + } + expressions := q.queryExpressions if len(expressions) > 0 { var sqlBuilder strings.Builder diff --git a/gplus/query.go b/gplus/query.go index 0dae08c..0520029 100644 --- a/gplus/query.go +++ b/gplus/query.go @@ -26,6 +26,7 @@ import ( type QueryCond[T any] struct { selectColumns []string + omitColumns []string distinctColumns []string queryExpressions []any orderBuilder strings.Builder @@ -37,6 +38,7 @@ type QueryCond[T any] struct { limit *int offset int updateMap map[string]any + columnTypeMap map[string]reflect.Type } func (q *QueryCond[T]) getSqlSegment() string { @@ -46,7 +48,6 @@ func (q *QueryCond[T]) getSqlSegment() string { // NewQuery 构建查询条件 func NewQuery[T any]() (*QueryCond[T], *T) { q := &QueryCond[T]{} - modelTypeStr := reflect.TypeOf((*T)(nil)).Elem().String() if model, ok := modelInstanceCache.Load(modelTypeStr); ok { m, isReal := model.(*T) @@ -150,6 +151,13 @@ func (q *QueryCond[T]) LikeLeft(column any, val any) *QueryCond[T] { return q } +// NotLikeLeft 非左模糊 NOT LIKE '%值' +func (q *QueryCond[T]) NotLikeLeft(column any, val any) *QueryCond[T] { + s := fmt.Sprintf("%v", val) + q.addExpression(q.buildSqlSegment(column, constants.Not+" "+constants.Like, "%"+s)...) + return q +} + // LikeRight 右模糊 LIKE '值%' func (q *QueryCond[T]) LikeRight(column any, val any) *QueryCond[T] { s := fmt.Sprintf("%v", val) @@ -157,6 +165,13 @@ func (q *QueryCond[T]) LikeRight(column any, val any) *QueryCond[T] { return q } +// NotLikeRight 非右模糊 NOT LIKE '值%' +func (q *QueryCond[T]) NotLikeRight(column any, val any) *QueryCond[T] { + s := fmt.Sprintf("%v", val) + q.addExpression(q.buildSqlSegment(column, constants.Not+" "+constants.Like, s+"%")...) + return q +} + // IsNull 是否为空 字段 IS NULL func (q *QueryCond[T]) IsNull(column any) *QueryCond[T] { q.addExpression(q.buildSqlSegment(column, constants.IsNull, "")...) @@ -235,7 +250,7 @@ func (q *QueryCond[T]) OrderByAsc(columns ...any) *QueryCond[T] { return q } -// Having SQl语句 +// Having HAVING SQl语句 func (q *QueryCond[T]) Having(having string, args ...any) *QueryCond[T] { q.havingBuilder.WriteString(having) if len(args) == 1 { @@ -284,6 +299,15 @@ func (q *QueryCond[T]) Select(columns ...any) *QueryCond[T] { return q } +// Omit 忽略字段 +func (q *QueryCond[T]) Omit(columns ...any) *QueryCond[T] { + for _, v := range columns { + columnName := getColumnName(v) + q.omitColumns = append(q.omitColumns, columnName) + } + return q +} + // Set 设置更新的字段 func (q *QueryCond[T]) Set(column any, val any) *QueryCond[T] { columnName := getColumnName(column) @@ -412,6 +436,7 @@ func (q *QueryCond[T]) Case(isTrue bool, handleFunc func()) *QueryCond[T] { // 重置查询条件 func (q *QueryCond[T]) Reset() *QueryCond[T] { q.selectColumns = q.selectColumns[:0] + q.omitColumns = q.omitColumns[:0] q.distinctColumns = q.distinctColumns[:0] q.queryExpressions = q.queryExpressions[:0] q.orderBuilder.Reset() diff --git a/gplus/sqlSegment.go b/gplus/segment.go similarity index 100% rename from gplus/sqlSegment.go rename to gplus/segment.go diff --git a/gplus/tool.go b/gplus/tool.go new file mode 100644 index 0000000..5e7917c --- /dev/null +++ b/gplus/tool.go @@ -0,0 +1,421 @@ +/* + * Licensed to the AcmeStack under one or more contributor license + * agreements. See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package gplus + +import ( + "fmt" + "net/url" + "reflect" + "strconv" + "strings" + "sync" +) + +type Condition struct { + Group string + ColumnName string + Op string + ColumnValue any +} + +var columnTypeCache sync.Map + +var operators = []string{"!~<=", "!~>=", "~<=", "~>=", "!?=", "!^=", "!~=", "?=", "^=", "~=", "!=", ">=", "<=", "=", ">", "<"} +var builders = map[string]func(query *QueryCond[any], name string, value any){ + "!~<=": notLikeLeft, + "!~>=": notLikeRight, + "~<=": LikeLeft, + "~>=": LikeRight, + "!?=": notIn, + "!^=": notBetween, + "!~=": notLike, + "?=": in, + "^=": between, + "~=": like, + "!=": ne, + ">=": ge, + "<=": le, + "=": eq, + ">": gt, + "<": lt, +} + +func BuildQuery[T any](queryParams url.Values) *QueryCond[T] { + + columnCondMap, conditionMap, gcond := parseParams(queryParams) + + parentQuery := buildParentQuery[T](conditionMap) + + queryCondMap := buildQueryCondMap[T](columnCondMap) + + // 如果没有分组条件,直接返回默认的查询条件 + if len(gcond) == 0 { + if q, ok := queryCondMap["default"]; ok { + q.orderBuilder = parentQuery.orderBuilder + q.selectColumns = parentQuery.selectColumns + q.omitColumns = parentQuery.omitColumns + return q + } + + // 如果没有分组条件,但是有分组设置,返回第一个查询条件。主要为了兼容只有一个分组但是没有设置条件的情况。 + if len(queryCondMap) == 1 { + for _, q := range queryCondMap { + q.orderBuilder = parentQuery.orderBuilder + q.selectColumns = parentQuery.selectColumns + q.omitColumns = parentQuery.omitColumns + return q + } + } + } + + return buildGroupQuery[T](gcond, queryCondMap, parentQuery) +} + +func parseParams(queryParams url.Values) (map[string][]*Condition, map[string]string, string) { + var gcond string + var columnCondMap = make(map[string][]*Condition) + var conditionMap = make(map[string]string) + for key, values := range queryParams { + switch key { + case "q": + columnCondMap = buildColumnCondMap(values) + case "sort": + if len(values) > 0 { + conditionMap["sort"] = values[len(values)-1] + } + case "select": + if len(values) > 0 { + conditionMap["select"] = values[len(values)-1] + } + case "omit": + if len(values) > 0 { + conditionMap["omit"] = values[len(values)-1] + } + case "gcond": + gcond = values[0] + } + } + return columnCondMap, conditionMap, gcond +} + +// buildColumnCondMap 根据url参数构建字段条件 +func buildColumnCondMap(values []string) map[string][]*Condition { + var maps = make(map[string][]*Condition) + for _, value := range values { + currentOperator := getCurrentOp(value) + params := strings.SplitN(value, currentOperator, 2) + if len(params) == 2 { + condition := &Condition{} + groups := strings.Split(params[0], ".") + var groupName string + var columnName string + // 如果不包含组,默认分为同一个组 + if len(groups) == 1 { + groupName = "default" + columnName = groups[0] + } else if len(groups) == 2 { + groupName = groups[0] + columnName = groups[1] + } + condition.Group = groupName + condition.ColumnName = columnName + condition.Op = currentOperator + condition.ColumnValue = params[1] + conditions, ok := maps[groupName] + if ok { + conditions = append(conditions, condition) + } else { + conditions = []*Condition{condition} + } + maps[groupName] = conditions + } + } + return maps +} + +func getCurrentOp(value string) string { + var currentOperator string + for _, op := range operators { + if strings.Contains(value, op) { + currentOperator = op + break + } + } + return currentOperator +} + +func buildQueryCondMap[T any](columnCondMap map[string][]*Condition) map[string]*QueryCond[T] { + var queryCondMap = make(map[string]*QueryCond[T]) + columnTypeMap := getColumnTypeMap[T]() + for key, conditions := range columnCondMap { + query := &QueryCond[any]{} + query.columnTypeMap = columnTypeMap + for _, condition := range conditions { + name := condition.ColumnName + op := condition.Op + value := condition.ColumnValue + builders[op](query, name, value) + } + newQuery, _ := NewQuery[T]() + newQuery.queryExpressions = append(newQuery.queryExpressions, query.queryExpressions...) + queryCondMap[key] = newQuery + } + return queryCondMap +} + +func buildParentQuery[T any](conditionMap map[string]string) *QueryCond[T] { + parentQuery, _ := NewQuery[T]() + for key, value := range conditionMap { + if key == "sort" { + orderColumns := strings.Split(value, ",") + for _, column := range orderColumns { + if strings.HasPrefix(column, "-") { + newValue := strings.TrimLeft(column, "-") + parentQuery.OrderByDesc(newValue) + } else { + parentQuery.OrderByAsc(column) + } + } + } else if key == "select" { + selectColumns := strings.Split(value, ",") + for _, column := range selectColumns { + parentQuery.Select(column) + } + } else if key == "omit" { + omitColumns := strings.Split(value, ",") + for _, column := range omitColumns { + parentQuery.Omit(column) + } + } + } + return parentQuery +} + +func buildGroupQuery[T any](gcond string, queryMaps map[string]*QueryCond[T], query *QueryCond[T]) *QueryCond[T] { + var tempQuerys []*QueryCond[T] + tempQuerys = append(tempQuerys, query) + for i, char := range gcond { + str := string(char) + tempQuery := tempQuerys[len(tempQuerys)-1] + // 如果是 左括号 开头,则代表需要嵌套查询 + if str == "(" && i != len(gcond)-1 { + if i != 0 && string(gcond[i-1]) == "|" { + tempQuery.Or(func(q *QueryCond[T]) { + paramQuery, isOk := queryMaps[string(gcond[i+1])] + if isOk { + q.queryExpressions = append(q.queryExpressions, paramQuery.queryExpressions...) + tempQuerys = append(tempQuerys, q) + } + }) + continue + } else { + tempQuery.And(func(q *QueryCond[T]) { + paramQuery, isOk := queryMaps[string(gcond[i+1])] + if isOk { + q.queryExpressions = append(q.queryExpressions, paramQuery.queryExpressions...) + tempQuerys = append(tempQuerys, q) + } + }) + } + continue + } + + // 如果当前为 | ,而且不是最后一个字符,而且下一个字符不是 ( ,则为 or + if str == "|" && i != len(gcond)-1 { + paramQuery, isOk := queryMaps[string(gcond[i+1])] + if isOk { + tempQuery.Or().queryExpressions = append(tempQuery.queryExpressions, paramQuery.queryExpressions...) + tempQuery.last = paramQuery.queryExpressions[len(paramQuery.queryExpressions)-1] + } + continue + } + + if str == "*" && i != len(gcond)-1 { + paramQuery, isOk := queryMaps[string(gcond[i+1])] + if isOk { + tempQuery.And() + tempQuery.queryExpressions = append(tempQuery.queryExpressions, paramQuery.queryExpressions...) + tempQuery.last = paramQuery.queryExpressions[len(paramQuery.queryExpressions)-1] + } + continue + } + + if str == ")" { + // 删除最后一个query对象 + tempQuerys = tempQuerys[:len(tempQuerys)-1] + continue + } + + // 如果上面的条件不满足,而且是第一个的话,那么就直接添加条件 + if i == 0 { + paramQuery, isOk := queryMaps[string(gcond[i])] + if isOk { + tempQuery.queryExpressions = append(tempQuery.queryExpressions, paramQuery.queryExpressions...) + tempQuery.last = paramQuery.queryExpressions[len(paramQuery.queryExpressions)-1] + } + } + } + return query +} + +func getColumnTypeMap[T any]() map[string]reflect.Type { + modelTypeStr := reflect.TypeOf((*T)(nil)).Elem().String() + if model, ok := columnTypeCache.Load(modelTypeStr); ok { + if columnNameMap, isOk := model.(map[string]reflect.Type); isOk { + return columnNameMap + } + } + var columnTypeMap = make(map[string]reflect.Type) + typeOf := reflect.TypeOf((*T)(nil)).Elem() + for i := 0; i < typeOf.NumField(); i++ { + field := typeOf.Field(i) + if field.Anonymous { + nestedFields := getSubFieldColumnTypeMap(field) + for key, value := range nestedFields { + columnTypeMap[key] = value + } + } + columnName := parseColumnName(field) + columnTypeMap[columnName] = field.Type + } + columnTypeCache.Store(modelTypeStr, columnTypeMap) + return columnTypeMap +} + +func getSubFieldColumnTypeMap(field reflect.StructField) map[string]reflect.Type { + columnTypeMap := make(map[string]reflect.Type) + modelType := field.Type + if modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + for j := 0; j < modelType.NumField(); j++ { + subField := modelType.Field(j) + if subField.Anonymous { + nestedFields := getSubFieldColumnTypeMap(subField) + for key, value := range nestedFields { + columnTypeMap[key] = value + } + } else { + columnName := parseColumnName(subField) + columnTypeMap[columnName] = subField.Type + } + } + return columnTypeMap +} + +func notLikeLeft(query *QueryCond[any], name string, value any) { + query.NotLikeLeft(name, convert(query.columnTypeMap, name, value)) +} + +func notLikeRight(query *QueryCond[any], name string, value any) { + query.NotLikeRight(name, convert(query.columnTypeMap, name, value)) +} + +func LikeLeft(query *QueryCond[any], name string, value any) { + query.LikeLeft(name, convert(query.columnTypeMap, name, value)) +} + +func LikeRight(query *QueryCond[any], name string, value any) { + query.LikeRight(name, convert(query.columnTypeMap, name, value)) +} + +func notIn(query *QueryCond[any], name string, value any) { + values := strings.Split(fmt.Sprintf("%s", value), ",") + var queryValues []any + for _, v := range values { + queryValues = append(queryValues, convert(query.columnTypeMap, name, v)) + } + query.NotIn(name, queryValues) +} + +func notBetween(query *QueryCond[any], name string, value any) { + values := strings.Split(fmt.Sprintf("%s", value), ",") + if len(values) == 2 { + query.NotBetween(name, convert(query.columnTypeMap, name, values[0]), convert(query.columnTypeMap, name, values[1])) + } +} + +func notLike(query *QueryCond[any], name string, value any) { + query.NotLike(name, convert(query.columnTypeMap, name, value)) +} + +func in(query *QueryCond[any], name string, value any) { + values := strings.Split(fmt.Sprintf("%s", value), ",") + var queryValues []any + for _, v := range values { + queryValues = append(queryValues, convert(query.columnTypeMap, name, v)) + } + query.In(name, queryValues) +} + +func between(query *QueryCond[any], name string, value any) { + values := strings.Split(fmt.Sprintf("%s", value), ",") + if len(values) == 2 { + query.Between(name, convert(query.columnTypeMap, name, values[0]), convert(query.columnTypeMap, name, values[1])) + } +} + +func like(query *QueryCond[any], name string, value any) { + query.Like(name, convert(query.columnTypeMap, name, value)) +} + +func ne(query *QueryCond[any], name string, value any) { + if strings.ToLower(fmt.Sprintf("%s", value)) == "null" { + query.IsNotNull(name) + } else { + query.Ne(name, convert(query.columnTypeMap, name, value)) + } +} + +func ge(query *QueryCond[any], name string, value any) { + query.Ge(name, convert(query.columnTypeMap, name, value)) +} + +func le(query *QueryCond[any], name string, value any) { + query.Le(name, convert(query.columnTypeMap, name, value)) +} + +func eq(query *QueryCond[any], name string, value any) { + if strings.ToLower(fmt.Sprintf("%s", value)) == "null" { + query.IsNull(name) + } else { + query.Eq(name, convert(query.columnTypeMap, name, value)) + } +} + +func gt(query *QueryCond[any], name string, value any) { + query.Gt(name, convert(query.columnTypeMap, name, value)) +} + +func lt(query *QueryCond[any], name string, value any) { + query.Lt(name, convert(query.columnTypeMap, name, value)) +} + +func convert(columnTypeMap map[string]reflect.Type, name string, value any) any { + columnType, ok := columnTypeMap[name] + if ok { + switch columnType.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + atoi, err := strconv.Atoi(fmt.Sprintf("%s", value)) + if err == nil { + value = atoi + } + } + } + return value +} diff --git a/tests/dao_test.go b/tests/dao_test.go index b7a9123..c48ba8b 100644 --- a/tests/dao_test.go +++ b/tests/dao_test.go @@ -183,7 +183,7 @@ func TestUpdateZeroById(t *testing.T) { users := getUsers() gplus.InsertBatch[User](users) - updateUser := &User{Base: Base{ID: users[0].ID}, Score: 100, Age: 25} + updateUser := &User{Base: Base{ID: users[0].ID, CreatedAt: users[0].CreatedAt}, Score: 100, Age: 25} if res := gplus.UpdateZeroById[User](updateUser); res.Error != nil || res.RowsAffected != 1 { t.Errorf("errors happened when deleteByIds: %v, affected: %v", res.Error, res.RowsAffected) diff --git a/tests/delete_test.go b/tests/delete_test.go index f1f50b1..16fc15c 100644 --- a/tests/delete_test.go +++ b/tests/delete_test.go @@ -25,18 +25,18 @@ import ( ) func TestDeleteByIdName(t *testing.T) { - var expectSql = "DELETE FROM `Users` WHERE `id` = '1'" + var expectSql = "DELETE FROM `Users` WHERE `id` = 1" sessionDb := checkDeleteSql(t, expectSql) gplus.DeleteById[User](1, gplus.Db(sessionDb)) } func TestDeleteByIdsName(t *testing.T) { - var expectSql = "DELETE FROM `Users` WHERE `id` IN ('1','2')" + var expectSql = "DELETE FROM `Users` WHERE `id` IN (1,2)" sessionDb := checkDeleteSql(t, expectSql) gplus.DeleteById[User]([]int{1, 2}, gplus.Db(sessionDb)) } -func TestDelete1Name(t *testing.T) { +func TestDeleteEq(t *testing.T) { var expectSql = "DELETE FROM `Users` WHERE username = 'afumu'" sessionDb := checkDeleteSql(t, expectSql) query, u := gplus.NewQuery[User]() @@ -53,7 +53,7 @@ func TestDelete2Name(t *testing.T) { } func TestDelete3Name(t *testing.T) { - var expectSql = "DELETE FROM `Users` WHERE username = 'afumu' OR ( username = 'afumu2' AND score = '12' )" + var expectSql = "DELETE FROM `Users` WHERE username = 'afumu' OR ( username = 'afumu2' AND score = 12 )" sessionDb := checkDeleteSql(t, expectSql) query, u := gplus.NewQuery[User]() query.Eq(&u.Username, "afumu").Or(func(q *gplus.QueryCond[User]) { @@ -63,7 +63,7 @@ func TestDelete3Name(t *testing.T) { } func TestDelete4Name(t *testing.T) { - var expectSql = "DELETE FROM `Users` WHERE username = 'afumu' AND ( username = 'afumu2' AND score = '12' )" + var expectSql = "DELETE FROM `Users` WHERE username = 'afumu' AND ( username = 'afumu2' AND score = 12 )" sessionDb := checkDeleteSql(t, expectSql) query, u := gplus.NewQuery[User]() query.Eq(&u.Username, "afumu").And(func(q *gplus.QueryCond[User]) { @@ -81,7 +81,7 @@ func TestDelete5Name(t *testing.T) { } func TestDelete6Name(t *testing.T) { - var expectSql = "DELETE FROM `Users` WHERE username = 'afumu' AND score = '60'" + var expectSql = "DELETE FROM `Users` WHERE username = 'afumu' AND score = 60" sessionDb := checkDeleteSql(t, expectSql) query, u := gplus.NewQuery[User]() query.Eq(&u.Username, "afumu").And().Eq(&u.Score, 60) @@ -89,7 +89,7 @@ func TestDelete6Name(t *testing.T) { } func TestDelete7Name(t *testing.T) { - var expectSql = "DELETE FROM `Users` WHERE score > '60'" + var expectSql = "DELETE FROM `Users` WHERE score > 60" sessionDb := checkDeleteSql(t, expectSql) query, u := gplus.NewQuery[User]() query.Gt(&u.Score, 60) @@ -97,7 +97,7 @@ func TestDelete7Name(t *testing.T) { } func TestDelete8Name(t *testing.T) { - var expectSql = "DELETE FROM `Users` WHERE score > '60'" + var expectSql = "DELETE FROM `Users` WHERE score > 60" sessionDb := checkDeleteSql(t, expectSql) query, u := gplus.NewQuery[User]() query.Gt(&u.Score, 60) @@ -105,7 +105,7 @@ func TestDelete8Name(t *testing.T) { } func TestDelete9Name(t *testing.T) { - var expectSql = "DELETE FROM `Users` WHERE score >= '60'" + var expectSql = "DELETE FROM `Users` WHERE score >= 60" sessionDb := checkDeleteSql(t, expectSql) query, u := gplus.NewQuery[User]() query.Ge(&u.Score, 60) @@ -113,7 +113,7 @@ func TestDelete9Name(t *testing.T) { } func TestDelete10Name(t *testing.T) { - var expectSql = "DELETE FROM `Users` WHERE score < '60'" + var expectSql = "DELETE FROM `Users` WHERE score < 60" sessionDb := checkDeleteSql(t, expectSql) query, u := gplus.NewQuery[User]() query.Lt(&u.Score, 60) @@ -121,7 +121,7 @@ func TestDelete10Name(t *testing.T) { } func TestDelete11Name(t *testing.T) { - var expectSql = "DELETE FROM `Users` WHERE score <= '60'" + var expectSql = "DELETE FROM `Users` WHERE score <= 60" sessionDb := checkDeleteSql(t, expectSql) query, u := gplus.NewQuery[User]() query.Le(&u.Score, 60) @@ -185,7 +185,7 @@ func TestDelete18Name(t *testing.T) { } func TestDelete20Name(t *testing.T) { - var expectSql = "DELETE FROM `Users` WHERE score BETWEEN '60' AND '80'" + var expectSql = "DELETE FROM `Users` WHERE score BETWEEN 60 AND 80" sessionDb := checkDeleteSql(t, expectSql) query, u := gplus.NewQuery[User]() query.Between(&u.Score, 60, 80) @@ -193,7 +193,7 @@ func TestDelete20Name(t *testing.T) { } func TestDelete21Name(t *testing.T) { - var expectSql = "DELETE FROM `Users` WHERE score NOT BETWEEN '60' AND '80'" + var expectSql = "DELETE FROM `Users` WHERE score NOT BETWEEN 60 AND 80" sessionDb := checkDeleteSql(t, expectSql) query, u := gplus.NewQuery[User]() query.NotBetween(&u.Score, 60, 80) @@ -201,7 +201,7 @@ func TestDelete21Name(t *testing.T) { } func TestDelete22Name(t *testing.T) { - var expectSql = "DELETE FROM `Users` WHERE score NOT BETWEEN '60' AND '80'" + var expectSql = "DELETE FROM `Users` WHERE score NOT BETWEEN 60 AND 80" sessionDb := checkDeleteSql(t, expectSql) query, u := gplus.NewQuery[User]() query.NotBetween(&u.Score, 60, 80) diff --git a/tests/insert_test.go b/tests/insert_test.go index b547ca6..914239e 100644 --- a/tests/insert_test.go +++ b/tests/insert_test.go @@ -25,7 +25,7 @@ import ( ) func TestInsert1Name(t *testing.T) { - var expectSql = "INSERT INTO `Users` (`username`,`password`,`address`,`age`,`phone`,`score`,`dept`) VALUES ('afumu','123456','','18','','12','研发部门')" + var expectSql = "INSERT INTO `Users` (`username`,`password`,`address`,`age`,`phone`,`score`,`dept`) VALUES ('afumu','123456','',18,'',12,'研发部门')" user := &User{Username: "afumu", Password: "123456", Age: 18, Score: 12, Dept: "研发部门"} u := gplus.GetModel[User]() sessionDb := checkInsertSql(t, expectSql) @@ -33,7 +33,7 @@ func TestInsert1Name(t *testing.T) { } func TestInsert2Name(t *testing.T) { - var expectSql = "INSERT INTO `Users` (`username`,`password`,`address`,`age`,`phone`,`score`) VALUES ('afumu','123456','','18','','12')" + var expectSql = "INSERT INTO `Users` (`username`,`password`,`address`,`age`,`phone`,`score`) VALUES ('afumu','123456','',18,'',12)" user := &User{Username: "afumu", Password: "123456", Age: 18, Score: 12, Dept: "研发部门"} u := gplus.GetModel[User]() sessionDb := checkInsertSql(t, expectSql) diff --git a/tests/select_test.go b/tests/select_test.go index d0cee58..677598a 100644 --- a/tests/select_test.go +++ b/tests/select_test.go @@ -24,41 +24,41 @@ import ( "testing" ) -func TestSelectById1Name(t *testing.T) { - var expectSql = "SELECT * FROM `Users` WHERE id = '1' ORDER BY `Users`.`id` LIMIT 1" +func TestSelectByIdName(t *testing.T) { + var expectSql = "SELECT * FROM `Users` WHERE id = 1 LIMIT 1" sessionDb := checkSelectSql(t, expectSql) gplus.SelectById[User](1, gplus.Db(sessionDb)) } -func TestSelectById2Name(t *testing.T) { - var expectSql = "SELECT `username`,`age` FROM `Users` WHERE id = '1' ORDER BY `Users`.`id` LIMIT 1" +func TestSelectByIdSelect(t *testing.T) { + var expectSql = "SELECT `username`,`age` FROM `Users` WHERE id = 1 LIMIT 1" sessionDb := checkSelectSql(t, expectSql) u := gplus.GetModel[User]() gplus.SelectById[User](1, gplus.Db(sessionDb), gplus.Select(&u.Username, &u.Age)) } -func TestSelectById3Name(t *testing.T) { - var expectSql = "SELECT `Users`.`id`,`Users`.`created_at`,`Users`.`updated_at`,`Users`.`password`,`Users`.`address`,`Users`.`phone`,`Users`.`score`,`Users`.`dept` FROM `Users` WHERE id = '1' ORDER BY `Users`.`id` LIMIT 1" +func TestSelectByIdOmit(t *testing.T) { + var expectSql = "SELECT `Users`.`id`,`Users`.`created_at`,`Users`.`updated_at`,`Users`.`password`,`Users`.`address`,`Users`.`phone`,`Users`.`score`,`Users`.`dept` FROM `Users` WHERE id = 1 LIMIT 1" sessionDb := checkSelectSql(t, expectSql) u := gplus.GetModel[User]() gplus.SelectById[User](1, gplus.Db(sessionDb), gplus.Omit(&u.Username, &u.Age)) } -func TestSelectByIdsName(t *testing.T) { - var expectSql = "SELECT * FROM `Users` WHERE id IN ('1','2')" +func TestSelectByIdsIn(t *testing.T) { + var expectSql = "SELECT * FROM `Users` WHERE id IN (1,2)" sessionDb := checkSelectSql(t, expectSql) gplus.SelectByIds[User]([]int{1, 2}, gplus.Db(sessionDb)) } func TestSelectOneName(t *testing.T) { - var expectSql = "SELECT * FROM `Users` WHERE username = 'afumu' ORDER BY `Users`.`id` LIMIT 1" + var expectSql = "SELECT * FROM `Users` WHERE username = 'afumu' LIMIT 1" sessionDb := checkSelectSql(t, expectSql) query, u := gplus.NewQuery[User]() query.Eq(&u.Username, "afumu") gplus.SelectOne[User](query, gplus.Db(sessionDb)) } -func TestSelectList1Name(t *testing.T) { +func TestSelectListEq(t *testing.T) { var expectSql = " SELECT * FROM `Users` WHERE username = 'afumu'" sessionDb := checkSelectSql(t, expectSql) query, u := gplus.NewQuery[User]() @@ -66,7 +66,7 @@ func TestSelectList1Name(t *testing.T) { gplus.SelectList[User](query, gplus.Db(sessionDb)) } -func TestSelectList2Name(t *testing.T) { +func TestSelectListNe(t *testing.T) { var expectSql = " SELECT * FROM `Users` WHERE username <> 'afumu'" sessionDb := checkSelectSql(t, expectSql) query, u := gplus.NewQuery[User]() @@ -74,39 +74,39 @@ func TestSelectList2Name(t *testing.T) { gplus.SelectList[User](query, gplus.Db(sessionDb)) } -func TestSelectList3Name(t *testing.T) { - var expectSql = "SELECT * FROM `Users` WHERE age > '20'" +func TestSelectListGt(t *testing.T) { + var expectSql = "SELECT * FROM `Users` WHERE age > 20" sessionDb := checkSelectSql(t, expectSql) query, u := gplus.NewQuery[User]() query.Gt(&u.Age, 20) gplus.SelectList[User](query, gplus.Db(sessionDb)) } -func TestSelectList4Name(t *testing.T) { - var expectSql = "SELECT * FROM `Users` WHERE age >= '20'" +func TestSelectListGe(t *testing.T) { + var expectSql = "SELECT * FROM `Users` WHERE age >= 20" sessionDb := checkSelectSql(t, expectSql) query, u := gplus.NewQuery[User]() query.Ge(&u.Age, 20) gplus.SelectList[User](query, gplus.Db(sessionDb)) } -func TestSelectList5Name(t *testing.T) { - var expectSql = "SELECT * FROM `Users` WHERE age < '20'" +func TestSelectListLt(t *testing.T) { + var expectSql = "SELECT * FROM `Users` WHERE age < 20" sessionDb := checkSelectSql(t, expectSql) query, u := gplus.NewQuery[User]() query.Lt(&u.Age, 20) gplus.SelectList[User](query, gplus.Db(sessionDb)) } -func TestSelectList6Name(t *testing.T) { - var expectSql = "SELECT * FROM `Users` WHERE age <= '20'" +func TestSelectListLe(t *testing.T) { + var expectSql = "SELECT * FROM `Users` WHERE age <= 20" sessionDb := checkSelectSql(t, expectSql) query, u := gplus.NewQuery[User]() query.Le(&u.Age, 20) gplus.SelectList[User](query, gplus.Db(sessionDb)) } -func TestSelectList7Name(t *testing.T) { +func TestSelectListLike(t *testing.T) { var expectSql = "SELECT * FROM `Users` WHERE username LIKE '%zhang%'" sessionDb := checkSelectSql(t, expectSql) query, u := gplus.NewQuery[User]() @@ -114,7 +114,7 @@ func TestSelectList7Name(t *testing.T) { gplus.SelectList[User](query, gplus.Db(sessionDb)) } -func TestSelectList8Name(t *testing.T) { +func TestSelectListLeftLike(t *testing.T) { var expectSql = "SELECT * FROM `Users` WHERE username LIKE '%zhang'" sessionDb := checkSelectSql(t, expectSql) query, u := gplus.NewQuery[User]() @@ -122,7 +122,15 @@ func TestSelectList8Name(t *testing.T) { gplus.SelectList[User](query, gplus.Db(sessionDb)) } -func TestSelectList9Name(t *testing.T) { +func TestSelectListNotLeftLike(t *testing.T) { + var expectSql = "SELECT * FROM `Users` WHERE username NOT LIKE '%zhang'" + sessionDb := checkSelectSql(t, expectSql) + query, u := gplus.NewQuery[User]() + query.NotLikeLeft(&u.Username, "zhang") + gplus.SelectList[User](query, gplus.Db(sessionDb)) +} + +func TestSelectListRightLike(t *testing.T) { var expectSql = "SELECT * FROM `Users` WHERE username LIKE 'zhang%'" sessionDb := checkSelectSql(t, expectSql) query, u := gplus.NewQuery[User]() @@ -130,7 +138,15 @@ func TestSelectList9Name(t *testing.T) { gplus.SelectList[User](query, gplus.Db(sessionDb)) } -func TestSelectList10Name(t *testing.T) { +func TestSelectListNotRightLike(t *testing.T) { + var expectSql = "SELECT * FROM `Users` WHERE username NOT LIKE 'zhang%'" + sessionDb := checkSelectSql(t, expectSql) + query, u := gplus.NewQuery[User]() + query.NotLikeRight(&u.Username, "zhang") + gplus.SelectList[User](query, gplus.Db(sessionDb)) +} + +func TestSelectListIsNull(t *testing.T) { var expectSql = "SELECT * FROM `Users` WHERE username IS NULL" sessionDb := checkSelectSql(t, expectSql) query, u := gplus.NewQuery[User]() @@ -138,7 +154,7 @@ func TestSelectList10Name(t *testing.T) { gplus.SelectList[User](query, gplus.Db(sessionDb)) } -func TestSelectList11Name(t *testing.T) { +func TestSelectListIsNotNull(t *testing.T) { var expectSql = "SELECT * FROM `Users` WHERE username IS NOT NULL" sessionDb := checkSelectSql(t, expectSql) query, u := gplus.NewQuery[User]() @@ -146,7 +162,7 @@ func TestSelectList11Name(t *testing.T) { gplus.SelectList[User](query, gplus.Db(sessionDb)) } -func TestSelectList12Name(t *testing.T) { +func TestSelectListIn(t *testing.T) { var expectSql = "SELECT * FROM `Users` WHERE username IN ('afumu','zhangsan')" sessionDb := checkSelectSql(t, expectSql) query, u := gplus.NewQuery[User]() @@ -154,7 +170,7 @@ func TestSelectList12Name(t *testing.T) { gplus.SelectList[User](query, gplus.Db(sessionDb)) } -func TestSelectList13Name(t *testing.T) { +func TestSelectListNotIn(t *testing.T) { var expectSql = "SELECT * FROM `Users` WHERE username NOT IN ('afumu','zhangsan')" sessionDb := checkSelectSql(t, expectSql) query, u := gplus.NewQuery[User]() @@ -162,40 +178,40 @@ func TestSelectList13Name(t *testing.T) { gplus.SelectList[User](query, gplus.Db(sessionDb)) } -func TestSelectList14Name(t *testing.T) { - var expectSql = "SELECT * FROM `Users` WHERE age BETWEEN '18' AND '20'" +func TestSelectListBetween(t *testing.T) { + var expectSql = "SELECT * FROM `Users` WHERE age BETWEEN 18 AND 20" sessionDb := checkSelectSql(t, expectSql) query, u := gplus.NewQuery[User]() query.Between(&u.Age, 18, 20) gplus.SelectList[User](query, gplus.Db(sessionDb)) } -func TestSelectList17Name(t *testing.T) { - var expectSql = "SELECT * FROM `Users` WHERE age NOT BETWEEN '18' AND '20'" +func TestSelectListNotBetween(t *testing.T) { + var expectSql = "SELECT * FROM `Users` WHERE age NOT BETWEEN 18 AND 20" sessionDb := checkSelectSql(t, expectSql) query, u := gplus.NewQuery[User]() query.NotBetween(&u.Age, 18, 20) gplus.SelectList[User](query, gplus.Db(sessionDb)) } -func TestSelectList20Name(t *testing.T) { - var expectSql = "SELECT * FROM `Users` WHERE username = 'afumu' AND age = '20'" +func TestSelectListAnd(t *testing.T) { + var expectSql = "SELECT * FROM `Users` WHERE username = 'afumu' AND age = 20" sessionDb := checkSelectSql(t, expectSql) query, u := gplus.NewQuery[User]() query.Eq(&u.Username, "afumu").And().Eq(&u.Age, 20) gplus.SelectList[User](query, gplus.Db(sessionDb)) } -func TestSelectList21Name(t *testing.T) { - var expectSql = "SELECT * FROM `Users` WHERE username = 'afumu' OR age = '20'" +func TestSelectListOr(t *testing.T) { + var expectSql = "SELECT * FROM `Users` WHERE username = 'afumu' OR age = 20" sessionDb := checkSelectSql(t, expectSql) query, u := gplus.NewQuery[User]() query.Eq(&u.Username, "afumu").Or().Eq(&u.Age, 20) gplus.SelectList[User](query, gplus.Db(sessionDb)) } -func TestSelectList22Name(t *testing.T) { - var expectSql = "SELECT * FROM `Users` WHERE username = 'afumu' OR ( username = 'zhangsan' AND age = '30' )" +func TestSelectListOrNest(t *testing.T) { + var expectSql = "SELECT * FROM `Users` WHERE username = 'afumu' OR ( username = 'zhangsan' AND age = 30 )" sessionDb := checkSelectSql(t, expectSql) query, u := gplus.NewQuery[User]() query.Eq(&u.Username, "afumu").Or(func(q *gplus.QueryCond[User]) { @@ -204,8 +220,8 @@ func TestSelectList22Name(t *testing.T) { gplus.SelectList[User](query, gplus.Db(sessionDb)) } -func TestSelectList23Name(t *testing.T) { - var expectSql = "SELECT * FROM `Users` WHERE username = 'afumu' AND ( username = 'zhangsan' OR age = '30' )" +func TestSelectListAndNest(t *testing.T) { + var expectSql = "SELECT * FROM `Users` WHERE username = 'afumu' AND ( username = 'zhangsan' OR age = 30 )" sessionDb := checkSelectSql(t, expectSql) query, u := gplus.NewQuery[User]() query.Eq(&u.Username, "afumu").And(func(q *gplus.QueryCond[User]) { @@ -214,8 +230,8 @@ func TestSelectList23Name(t *testing.T) { gplus.SelectList[User](query, gplus.Db(sessionDb)) } -func TestSelectList24Name(t *testing.T) { - var expectSql = "SELECT * FROM `Users` WHERE ( username = 'afumu' AND ( password = '123456' OR score = '60' ) OR dept = '开发' ) AND address = '北京' " +func TestSelectListAndOrNest(t *testing.T) { + var expectSql = "SELECT * FROM `Users` WHERE ( username = 'afumu' AND ( password = '123456' OR score = 60 ) OR dept = '开发' ) AND address = '北京' " sessionDb := checkSelectSql(t, expectSql) query, u := gplus.NewQuery[User]() query.And(func(q *gplus.QueryCond[User]) { @@ -235,7 +251,7 @@ func TestSelectListOrder(t *testing.T) { } func TestSelectListQueryModel(t *testing.T) { - var expectSql = "SELECT username AS name,`age` FROM `Users` WHERE username = 'afumu' AND ( address = '北京' OR age = '20' ) " + var expectSql = "SELECT username AS name,`age` FROM `Users` WHERE username = 'afumu' AND ( address = '北京' OR age = 20 ) " sessionDb := checkSelectSql(t, expectSql) type UserVo struct { Name string @@ -249,7 +265,7 @@ func TestSelectListQueryModel(t *testing.T) { } func TestSelectListQueryModelSum(t *testing.T) { - var expectSql = "SELECT `username`,SUM(age) AS total FROM `Users` GROUP BY `username` HAVING SUM(age) NOT BETWEEN '333' AND '1000'" + var expectSql = "SELECT `username`,SUM(age) AS total FROM `Users` GROUP BY `username` HAVING SUM(age) NOT BETWEEN 333 AND 1000" sessionDb := checkSelectSql(t, expectSql) type UserVo struct { Username string diff --git a/tests/tool_test.go b/tests/tool_test.go new file mode 100644 index 0000000..86c3369 --- /dev/null +++ b/tests/tool_test.go @@ -0,0 +1,283 @@ +/* + * Licensed to the AcmeStack under one or more contributor license + * agreements. See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package tests + +import ( + "github.com/aixj1984/gorm-plus/gplus" + "net/url" + "testing" +) + +func TestQueryById(t *testing.T) { + values := url.Values{} + values["q"] = []string{"id=1"} + query := gplus.BuildQuery[User](values) + var expectSql = "SELECT * FROM `Users` WHERE id = 1" + sessionDb := checkSelectSql(t, expectSql) + gplus.SelectList[User](query, gplus.Db(sessionDb)) +} + +func TestQueryByIdSelect(t *testing.T) { + values := url.Values{} + values["q"] = []string{"id=1"} + values["select"] = []string{"username,age"} + query := gplus.BuildQuery[User](values) + var expectSql = "SELECT `username`,`age` FROM `Users` WHERE id = 1" + sessionDb := checkSelectSql(t, expectSql) + gplus.SelectList[User](query, gplus.Db(sessionDb)) +} + +func TestQueryByIdOmit(t *testing.T) { + values := url.Values{} + values["q"] = []string{"id=1"} + values["omit"] = []string{"username,age"} + query := gplus.BuildQuery[User](values) + var expectSql = "SELECT `Users`.`id`,`Users`.`created_at`,`Users`.`updated_at`,`Users`.`password`,`Users`.`address`,`Users`.`phone`,`Users`.`score`,`Users`.`dept` FROM `Users` WHERE id = 1" + sessionDb := checkSelectSql(t, expectSql) + gplus.SelectList[User](query, gplus.Db(sessionDb)) +} + +func TestQueryByIdSortAsc(t *testing.T) { + values := url.Values{} + values["sort"] = []string{"age"} + query := gplus.BuildQuery[User](values) + var expectSql = "SELECT * FROM `Users` ORDER BY age ASC" + sessionDb := checkSelectSql(t, expectSql) + gplus.SelectList[User](query, gplus.Db(sessionDb)) +} + +func TestQueryByIdSortDesc(t *testing.T) { + values := url.Values{} + values["sort"] = []string{"-age"} + query := gplus.BuildQuery[User](values) + var expectSql = "SELECT * FROM `Users` ORDER BY age DESC" + sessionDb := checkSelectSql(t, expectSql) + gplus.SelectList[User](query, gplus.Db(sessionDb)) +} + +func TestQueryByIdsIn(t *testing.T) { + values := url.Values{} + values["q"] = []string{"id?=1,2"} + query := gplus.BuildQuery[User](values) + var expectSql = "SELECT * FROM `Users` WHERE id IN (1,2)" + sessionDb := checkSelectSql(t, expectSql) + gplus.SelectList[User](query, gplus.Db(sessionDb)) +} + +func TestQueryByEqUsername(t *testing.T) { + values := url.Values{} + values["q"] = []string{"username=afumu"} + query := gplus.BuildQuery[User](values) + var expectSql = "SELECT * FROM `Users` WHERE username = 'afumu'" + sessionDb := checkSelectSql(t, expectSql) + gplus.SelectList[User](query, gplus.Db(sessionDb)) +} + +func TestQueryByNeUsername(t *testing.T) { + values := url.Values{} + values["q"] = []string{"username!=afumu"} + query := gplus.BuildQuery[User](values) + var expectSql = "SELECT * FROM `Users` WHERE username <> 'afumu'" + sessionDb := checkSelectSql(t, expectSql) + gplus.SelectList[User](query, gplus.Db(sessionDb)) +} + +func TestQueryByGtAge(t *testing.T) { + values := url.Values{} + values["q"] = []string{"age>20"} + query := gplus.BuildQuery[User](values) + var expectSql = "SELECT * FROM `Users` WHERE age > 20" + sessionDb := checkSelectSql(t, expectSql) + gplus.SelectList[User](query, gplus.Db(sessionDb)) +} + +func TestQueryByGeAge(t *testing.T) { + values := url.Values{} + values["q"] = []string{"age>=20"} + query := gplus.BuildQuery[User](values) + var expectSql = "SELECT * FROM `Users` WHERE age >= 20" + sessionDb := checkSelectSql(t, expectSql) + gplus.SelectList[User](query, gplus.Db(sessionDb)) +} + +func TestQueryByLtAge(t *testing.T) { + values := url.Values{} + values["q"] = []string{"age<20"} + query := gplus.BuildQuery[User](values) + var expectSql = "SELECT * FROM `Users` WHERE age < 20" + sessionDb := checkSelectSql(t, expectSql) + gplus.SelectList[User](query, gplus.Db(sessionDb)) +} + +func TestQueryByLeAge(t *testing.T) { + values := url.Values{} + values["q"] = []string{"age<=20"} + query := gplus.BuildQuery[User](values) + var expectSql = "SELECT * FROM `Users` WHERE age <= 20" + sessionDb := checkSelectSql(t, expectSql) + gplus.SelectList[User](query, gplus.Db(sessionDb)) +} + +func TestQueryByLike(t *testing.T) { + values := url.Values{} + values["q"] = []string{"username~=afumu"} + query := gplus.BuildQuery[User](values) + var expectSql = "SELECT * FROM `Users` WHERE username LIKE '%afumu%'" + sessionDb := checkSelectSql(t, expectSql) + gplus.SelectList[User](query, gplus.Db(sessionDb)) +} + +func TestQueryByLeftLike(t *testing.T) { + values := url.Values{} + values["q"] = []string{"username~<=afumu"} + query := gplus.BuildQuery[User](values) + var expectSql = "SELECT * FROM `Users` WHERE username LIKE '%afumu'" + sessionDb := checkSelectSql(t, expectSql) + gplus.SelectList[User](query, gplus.Db(sessionDb)) +} + +func TestQueryByNotLeftLike(t *testing.T) { + values := url.Values{} + values["q"] = []string{"username!~<=afumu"} + query := gplus.BuildQuery[User](values) + var expectSql = "SELECT * FROM `Users` WHERE username NOT LIKE '%afumu'" + sessionDb := checkSelectSql(t, expectSql) + gplus.SelectList[User](query, gplus.Db(sessionDb)) +} + +func TestQueryByRightLike(t *testing.T) { + values := url.Values{} + values["q"] = []string{"username~>=afumu"} + query := gplus.BuildQuery[User](values) + var expectSql = "SELECT * FROM `Users` WHERE username LIKE 'afumu%'" + sessionDb := checkSelectSql(t, expectSql) + gplus.SelectList[User](query, gplus.Db(sessionDb)) +} + +func TestQueryByNotRightLike(t *testing.T) { + values := url.Values{} + values["q"] = []string{"username!~>=afumu"} + query := gplus.BuildQuery[User](values) + var expectSql = "SELECT * FROM `Users` WHERE username NOT LIKE 'afumu%'" + sessionDb := checkSelectSql(t, expectSql) + gplus.SelectList[User](query, gplus.Db(sessionDb)) +} + +func TestQueryByIsNull(t *testing.T) { + values := url.Values{} + values["q"] = []string{"username=null"} + query := gplus.BuildQuery[User](values) + var expectSql = "SELECT * FROM `Users` WHERE username IS NULL" + sessionDb := checkSelectSql(t, expectSql) + gplus.SelectList[User](query, gplus.Db(sessionDb)) +} + +func TestQueryByIsNotNull(t *testing.T) { + values := url.Values{} + values["q"] = []string{"username!=null"} + query := gplus.BuildQuery[User](values) + var expectSql = "SELECT * FROM `Users` WHERE username IS NOT NULL" + sessionDb := checkSelectSql(t, expectSql) + gplus.SelectList[User](query, gplus.Db(sessionDb)) +} + +func TestQueryByIn(t *testing.T) { + values := url.Values{} + values["q"] = []string{"username?=afumu,zhangsan"} + query := gplus.BuildQuery[User](values) + var expectSql = "SELECT * FROM `Users` WHERE username IN ('afumu','zhangsan')" + sessionDb := checkSelectSql(t, expectSql) + gplus.SelectList[User](query, gplus.Db(sessionDb)) +} + +func TestQueryByNotIn(t *testing.T) { + values := url.Values{} + values["q"] = []string{"username!?=afumu,zhangsan"} + query := gplus.BuildQuery[User](values) + var expectSql = "SELECT * FROM `Users` WHERE username NOT IN ('afumu','zhangsan')" + sessionDb := checkSelectSql(t, expectSql) + gplus.SelectList[User](query, gplus.Db(sessionDb)) +} + +func TestQueryByBetween(t *testing.T) { + values := url.Values{} + values["q"] = []string{"age^=20,30"} + query := gplus.BuildQuery[User](values) + var expectSql = "SELECT * FROM `Users` WHERE age BETWEEN 20 AND 30" + sessionDb := checkSelectSql(t, expectSql) + gplus.SelectList[User](query, gplus.Db(sessionDb)) +} + +func TestQueryByNotBetween(t *testing.T) { + values := url.Values{} + values["q"] = []string{"age!^=20,30"} + query := gplus.BuildQuery[User](values) + var expectSql = "SELECT * FROM `Users` WHERE age NOT BETWEEN 20 AND 30" + sessionDb := checkSelectSql(t, expectSql) + gplus.SelectList[User](query, gplus.Db(sessionDb)) +} + +func TestQueryByAnd(t *testing.T) { + values := url.Values{} + values["q"] = []string{"useranme=afumu", "age=20"} + query := gplus.BuildQuery[User](values) + var expectSql = "SELECT * FROM `Users` WHERE useranme = 'afumu' AND age = 20" + sessionDb := checkSelectSql(t, expectSql) + gplus.SelectList[User](query, gplus.Db(sessionDb)) +} + +func TestQueryByGroupOnlyOne(t *testing.T) { + values := url.Values{} + values["q"] = []string{"A.useranme=afumu", "A.age=20"} + query := gplus.BuildQuery[User](values) + var expectSql = "SELECT * FROM `Users` WHERE useranme = 'afumu' AND age = 20" + sessionDb := checkSelectSql(t, expectSql) + gplus.SelectList[User](query, gplus.Db(sessionDb)) +} + +func TestQueryByGroupAAndB(t *testing.T) { + values := url.Values{} + values["q"] = []string{"A.useranme=afumu", "A.password=123456", "B.age=20", "B.score=90"} + values["gcond"] = []string{"A*B"} + query := gplus.BuildQuery[User](values) + var expectSql = "SELECT * FROM `Users` WHERE useranme = 'afumu' AND password = '123456' AND age = 20 AND score = 90" + sessionDb := checkSelectSql(t, expectSql) + gplus.SelectList[User](query, gplus.Db(sessionDb)) +} + +func TestQueryByGroupAOrB(t *testing.T) { + values := url.Values{} + values["q"] = []string{"A.useranme=afumu", "A.password=123456", "B.age=20", "B.score=90"} + values["gcond"] = []string{"A|B"} + query := gplus.BuildQuery[User](values) + var expectSql = "SELECT * FROM `Users` WHERE useranme = 'afumu' AND password = '123456' OR age = 20 AND score = 90" + sessionDb := checkSelectSql(t, expectSql) + gplus.SelectList[User](query, gplus.Db(sessionDb)) +} + +func TestQueryByGroupNest(t *testing.T) { + values := url.Values{} + values["q"] = []string{ + "A.useranme=afumu", "B.password=12345", "C.score=60", + "D.dept=开发", "F.address=北京", + } + values["gcond"] = []string{"(A*(B|C)|D)*F"} + query := gplus.BuildQuery[User](values) + var expectSql = "SELECT * FROM `Users` WHERE ( useranme = 'afumu' AND ( password = '12345' OR score = 60 ) OR dept = '开发' ) AND address = '北京'" + sessionDb := checkSelectSql(t, expectSql) + gplus.SelectList[User](query, gplus.Db(sessionDb)) +} diff --git a/tests/update_test.go b/tests/update_test.go index a5357bf..abf06b1 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -25,7 +25,7 @@ import ( ) func TestUpdateByIdName(t *testing.T) { - var expectSql = "UPDATE `Users` SET `score`='100' WHERE `id` = '1'" + var expectSql = "UPDATE `Users` SET `score`=100 WHERE `id` = 1" sessionDb := checkUpdateSql(t, expectSql) var user = &User{Base: Base{ID: 1}, Score: 100} u := gplus.GetModel[User]() @@ -33,7 +33,7 @@ func TestUpdateByIdName(t *testing.T) { } func TestUpdateZeroByIdName(t *testing.T) { - var expectSql = "UPDATE `Users` SET `username`='',`password`='',`address`='',`age`='0',`phone`='',`score`='100',`dept`='' WHERE `id` = '1'" + var expectSql = "UPDATE `Users` SET `username`='',`password`='',`address`='',`age`=0,`phone`='',`score`=100,`dept`='' WHERE `id` = 1" sessionDb := checkUpdateSql(t, expectSql) var user = &User{Base: Base{ID: 1}, Score: 100} u := gplus.GetModel[User]() @@ -41,7 +41,7 @@ func TestUpdateZeroByIdName(t *testing.T) { } func TestUpdate1Name(t *testing.T) { - var expectSql = "UPDATE `Users` SET `score`='100' WHERE id = '1'" + var expectSql = "UPDATE `Users` SET `score`=100 WHERE id = 1" sessionDb := checkUpdateSql(t, expectSql) query, u := gplus.NewQuery[User]() query.Eq(&u.ID, 1).Set(&u.Score, 100) @@ -49,7 +49,7 @@ func TestUpdate1Name(t *testing.T) { } func TestUpdate2Name(t *testing.T) { - var expectSql = "UPDATE `Users` SET `address`='shanghai',`score`='100' WHERE username = 'afumu' AND age = '18'" + var expectSql = "UPDATE `Users` SET `address`='shanghai',`score`=100 WHERE username = 'afumu' AND age = 18" sessionDb := checkUpdateSql(t, expectSql) query, u := gplus.NewQuery[User]() query.Eq(&u.Username, "afumu").Eq(&u.Age, 18). @@ -59,7 +59,7 @@ func TestUpdate2Name(t *testing.T) { } func TestUpdate3Name(t *testing.T) { - var expectSql = "UPDATE `Users` SET `address`='shanghai',`score`='100' WHERE username = 'afumu' OR age = '18'" + var expectSql = "UPDATE `Users` SET `address`='shanghai',`score`=100 WHERE username = 'afumu' OR age = 18" sessionDb := checkUpdateSql(t, expectSql) query, u := gplus.NewQuery[User]() query.Eq(&u.Username, "afumu").Or().Eq(&u.Age, 18). @@ -69,7 +69,7 @@ func TestUpdate3Name(t *testing.T) { } func TestUpdate4Name(t *testing.T) { - var expectSql = "UPDATE `Users` SET `address`='shanghai',`score`='100' WHERE username = 'afumu' OR ( age = '18' AND score = '100' )" + var expectSql = "UPDATE `Users` SET `address`='shanghai',`score`=100 WHERE username = 'afumu' OR ( age = 18 AND score = 100 )" sessionDb := checkUpdateSql(t, expectSql) query, u := gplus.NewQuery[User]() query.Eq(&u.Username, "afumu").Or(func(q *gplus.QueryCond[User]) { @@ -81,7 +81,7 @@ func TestUpdate4Name(t *testing.T) { } func TestUpdate5Name(t *testing.T) { - var expectSql = "UPDATE `Users` SET `address`='shanghai',`score`='100' WHERE username = 'afumu' AND ( age = '18' OR score = '100' )" + var expectSql = "UPDATE `Users` SET `address`='shanghai',`score`=100 WHERE username = 'afumu' AND ( age = 18 OR score = 100 )" sessionDb := checkUpdateSql(t, expectSql) query, u := gplus.NewQuery[User]() query.Eq(&u.Username, "afumu"). @@ -94,7 +94,7 @@ func TestUpdate5Name(t *testing.T) { } func TestUpdate6Name(t *testing.T) { - var expectSql = "UPDATE `Users` SET `address`='shanghai',`score`='100' WHERE username <> 'afumu'" + var expectSql = "UPDATE `Users` SET `address`='shanghai',`score`=100 WHERE username <> 'afumu'" sessionDb := checkUpdateSql(t, expectSql) query, u := gplus.NewQuery[User]() query.Ne(&u.Username, "afumu"). @@ -104,7 +104,7 @@ func TestUpdate6Name(t *testing.T) { } func TestUpdate7Name(t *testing.T) { - var expectSql = "UPDATE `Users` SET `address`='shanghai',`score`='100' WHERE username IS NULL" + var expectSql = "UPDATE `Users` SET `address`='shanghai',`score`=100 WHERE username IS NULL" sessionDb := checkUpdateSql(t, expectSql) query, u := gplus.NewQuery[User]() query.IsNull(&u.Username). @@ -114,7 +114,7 @@ func TestUpdate7Name(t *testing.T) { } func TestUpdateRest(t *testing.T) { - var expectSql = "UPDATE `Users` SET `address`='shanghai',`score`='100' WHERE username = 'afumu' OR age = '18'" + var expectSql = "UPDATE `Users` SET `address`='shanghai',`score`=100 WHERE username = 'afumu' OR age = 18" sessionDb := checkUpdateSql(t, expectSql) query, u := gplus.NewQuery[User]() query.Eq(&u.Username, "afumu").Or().Eq(&u.Age, 18). @@ -122,7 +122,7 @@ func TestUpdateRest(t *testing.T) { Set(&u.Address, "shanghai") gplus.Update(query, gplus.Db(sessionDb), gplus.Omit(&u.CreatedAt, &u.UpdatedAt)) - expectSql = "UPDATE `Users` SET `address`='shanghai',`score`='100' WHERE username = 'afumu' AND age = '18'" + expectSql = "UPDATE `Users` SET `address`='shanghai',`score`=100 WHERE username = 'afumu' AND age = 18" sessionDb = checkUpdateSql(t, expectSql) query.Reset() query.Eq(&u.Username, "afumu").Eq(&u.Age, 18). diff --git a/tests/user.go b/tests/user.go index aa3560b..b89f8c7 100644 --- a/tests/user.go +++ b/tests/user.go @@ -23,13 +23,13 @@ import ( type Base struct { ID int64 - CreatedAt time.Time `json:"createdAt"` - UpdatedAt time.Time `json:"updatedAt"` + CreatedAt time.Time + UpdatedAt time.Time } type User struct { Base - Username string `gorm:"column:username"` + Username string Password string Address string Age int diff --git a/tests/utils.go b/tests/utils.go index 2f78c63..876ccd7 100644 --- a/tests/utils.go +++ b/tests/utils.go @@ -143,7 +143,16 @@ func Now() *time.Time { func buildSql(db *gorm.DB) string { sql := db.Statement.SQL.String() for _, value := range db.Statement.Vars { - sql = strings.Replace(sql, "?", fmt.Sprintf("'%v'", value), 1) + sql = strings.Replace(sql, "?", convert(value), 1) } return sql } + +func convert(value any) string { + columnType := reflect.TypeOf(value) + switch columnType.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return fmt.Sprintf("%v", value) + } + return fmt.Sprintf("'%v'", value) +} From b9f937b327e0d462cd86bda5e300963da2f632c1 Mon Sep 17 00:00:00 2001 From: aixj1984 Date: Wed, 13 Sep 2023 20:15:29 +0800 Subject: [PATCH 07/12] sync --- gplus/dao.go | 8 +++ gplus/query.go | 124 ++++++++++++++++++++++++++++++++++++++++++++++ tests/dao_test.go | 12 +++++ 3 files changed, 144 insertions(+) diff --git a/gplus/dao.go b/gplus/dao.go index caa55cb..88c2af3 100644 --- a/gplus/dao.go +++ b/gplus/dao.go @@ -266,6 +266,8 @@ func SelectPage[T any](page *Page[T], q *QueryCond[T], opts ...OptionFunc) (*Pag func SelectCount[T any](q *QueryCond[T], opts ...OptionFunc) (int64, *gorm.DB) { var count int64 resultDb := buildCondition(q, opts...) + //fix 查询有设置Select并且数量只有一个且有设置别名,生成sql不对问题 + resultDb.Statement.Selects = nil resultDb.Count(&count) return count, resultDb } @@ -322,6 +324,12 @@ func Begin(opts ...*sql.TxOptions) *gorm.DB { return db.Begin(opts...) } +// 事务 +func Tx(txFunc func(tx *gorm.DB) error, opts ...OptionFunc) error { + db := getDb(opts...) + return db.Transaction(txFunc) +} + func paginate[T any](p *Page[T]) func(db *gorm.DB) *gorm.DB { page := p.Current pageSize := p.Size diff --git a/gplus/query.go b/gplus/query.go index 0520029..134f6f8 100644 --- a/gplus/query.go +++ b/gplus/query.go @@ -318,6 +318,130 @@ func (q *QueryCond[T]) Set(column any, val any) *QueryCond[T] { return q } +/* +* 自定义条件 + */ + +// AndCond 拼接 AND +func (q *QueryCond[T]) AndCond(cond bool, fn ...func(q *QueryCond[T])) *QueryCond[T] { + if cond { + return q.And(fn...) + } + return q +} + +// OrCond 拼接 OR +func (q *QueryCond[T]) OrCond(cond bool, fn ...func(q *QueryCond[T])) *QueryCond[T] { + if cond { + return q.Or(fn...) + } + return q +} + +// EqCond 等于 = +func (q *QueryCond[T]) EqCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.Eq(column, val) + } + return q +} + +// NeCond 不等于 != +func (q *QueryCond[T]) NeCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.Ne(column, val) + } + return q +} + +// GtCond 大于 > +func (q *QueryCond[T]) GtCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.Gt(column, val) + } + return q +} + +// GeCond 大于等于 >= +func (q *QueryCond[T]) GeCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.Ge(column, val) + } + return q +} + +// LtCond 小于 < +func (q *QueryCond[T]) LtCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.Lt(column, val) + } + return q +} + +// LeCond 小于等于 <= +func (q *QueryCond[T]) LeCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.Le(column, val) + } + return q +} + +// LikeCond 模糊 LIKE '%值%' +func (q *QueryCond[T]) LikeCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.Like(column, val) + } + return q +} + +// NotLikeCond 非模糊 NOT LIKE '%值%' +func (q *QueryCond[T]) NotLikeCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.NotLike(column, val) + } + return q +} + +// LikeLeftCond 左模糊 LIKE '%值' +func (q *QueryCond[T]) LikeLeftCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.LikeLeft(column, val) + } + return q +} + +// NotLikeLeftCond 非左模糊 NOT LIKE '%值' +func (q *QueryCond[T]) NotLikeLeftCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.NotLike(column, val) + } + return q +} + +// LikeRightCond 右模糊 LIKE '值%' +func (q *QueryCond[T]) LikeRightCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.LikeRight(column, val) + } + return q +} + +// NotLikeRightCond 非右模糊 NOT LIKE '值%' +func (q *QueryCond[T]) NotLikeRightCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.NotLikeRight(column, val) + } + return q +} + +// InCond 字段 IN (值1, 值2, ...) +func (q *QueryCond[T]) InCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.In(column, val) + } + return q +} + func (q *QueryCond[T]) addExpression(sqlSegments ...SqlSegment) { if len(sqlSegments) == 1 { q.handleSingle(sqlSegments[0]) diff --git a/tests/dao_test.go b/tests/dao_test.go index c48ba8b..e448640 100644 --- a/tests/dao_test.go +++ b/tests/dao_test.go @@ -557,6 +557,18 @@ func TestSelectGeneric7(t *testing.T) { } } +func TestTx(t *testing.T) { + deleteOldData() + users := getUsers() + err := gplus.Tx(func(tx *gorm.DB) error { + err := gplus.InsertBatch[User](users, gplus.Db(tx)).Error + return err + }) + if err != nil { + t.Errorf(err.Error()) + } +} + func TestCase(t *testing.T) { deleteOldData() users := getUsers() From 83828958b57ded38f3e6f19de01bbd579e1547e8 Mon Sep 17 00:00:00 2001 From: aixj1984 Date: Fri, 15 Sep 2023 10:07:54 +0800 Subject: [PATCH 08/12] sync buf fix --- gplus/dao.go | 91 +++++++++++++++++++++- gplus/query.go | 208 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 298 insertions(+), 1 deletion(-) diff --git a/gplus/dao.go b/gplus/dao.go index 88c2af3..514597f 100644 --- a/gplus/dao.go +++ b/gplus/dao.go @@ -19,6 +19,7 @@ package gplus import ( "database/sql" + "fmt" "reflect" "strings" "time" @@ -55,6 +56,29 @@ func NewPage[T any](current, size int) *Page[T] { return &Page[T]{Current: current, Size: size} } +type Comparable interface { + ~int | ~int8 | ~int16 | ~int32 | ~int64 | ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~float32 | ~float64 | time.Time +} + +type StreamingPage[T any, V Comparable] struct { + ColumnName any `json:"columnName"` // 进行分页的列字段名称 + StartValue V `json:"startValue"` // 分页起始值 + Limit int `json:"limit"` // 页大小 + Forward bool `json:"forward"` // 上下页翻页标识 + Total int64 `json:"total"` // 总记录数 + Records []*T `json:"records"` // 查询记录 + RecordsMap []T `json:"recordsMap"` // 查询记录Map +} + +func NewStreamingPage[T any, V Comparable](columnName any, startValue V, limit int) *StreamingPage[T, V] { + return &StreamingPage[T, V]{ + ColumnName: columnName, + StartValue: startValue, + Limit: limit, + Forward: true, + } +} + // Insert 插入一条记录 func Insert[T any](entity *T, opts ...OptionFunc) *gorm.DB { db := getDb(opts...) @@ -262,6 +286,26 @@ func SelectPage[T any](page *Page[T], q *QueryCond[T], opts ...OptionFunc) (*Pag return page, resultDb } +// SelectStreamingPage 根据条件分页查询记录 +func SelectStreamingPage[T any, V Comparable](page *StreamingPage[T, V], q *QueryCond[T], opts ...OptionFunc) (*StreamingPage[T, V], *gorm.DB) { + option := getOption(opts) + + // 如果需要分页忽略总数,不查询总数 + if !option.IgnoreTotal { + total, countDb := SelectCount[T](q, opts...) + if countDb.Error != nil { + return page, countDb + } + page.Total = total + } + + resultDb := buildCondition(q, opts...) + var results []*T + resultDb.Scopes(streamingPaginate(page)).Find(&results) + page.Records = results + return page, resultDb +} + // SelectCount 根据条件查询记录数量 func SelectCount[T any](q *QueryCond[T], opts ...OptionFunc) (int64, *gorm.DB) { var count int64 @@ -310,6 +354,34 @@ func SelectPageGeneric[T any, R any](page *Page[R], q *QueryCond[T], opts ...Opt return page, resultDb } +// SelectStreamingPageGeneric 根据传入的泛型封装分页记录 +// 第一个泛型代表数据库表实体 +// 第二个泛型代表返回记录实体 +func SelectStreamingPageGeneric[T any, R any, V Comparable](page *StreamingPage[R, V], q *QueryCond[T], opts ...OptionFunc) (*StreamingPage[R, V], *gorm.DB) { + option := getOption(opts) + // 如果需要分页忽略总数,不查询总数 + if !option.IgnoreTotal { + total, countDb := SelectCount[T](q, opts...) + if countDb.Error != nil { + return page, countDb + } + page.Total = total + } + resultDb := buildCondition(q, opts...) + var r R + switch any(r).(type) { + case map[string]any: + var results []R + resultDb.Scopes(streamingPaginate(page)).Scan(&results) + page.RecordsMap = results + default: + var results []*R + resultDb.Scopes(streamingPaginate(page)).Scan(&results) + page.Records = results + } + return page, resultDb +} + // SelectGeneric 根据传入的泛型封装记录 // 第一个泛型代表数据库表实体 // 第二个泛型代表返回记录实体 @@ -324,12 +396,13 @@ func Begin(opts ...*sql.TxOptions) *gorm.DB { return db.Begin(opts...) } -// 事务 +// Tx 事务 func Tx(txFunc func(tx *gorm.DB) error, opts ...OptionFunc) error { db := getDb(opts...) return db.Transaction(txFunc) } +// paginate offset分页 func paginate[T any](p *Page[T]) func(db *gorm.DB) *gorm.DB { page := p.Current pageSize := p.Size @@ -345,6 +418,22 @@ func paginate[T any](p *Page[T]) func(db *gorm.DB) *gorm.DB { } } +// streamingPaginate 流式分页,根据自增ID、雪花ID、时间等数值类型或者时间类型分页 +// Tips: 相比于 offset 分页性能更好,走的是 range,缺点是没办法跳页查询 +func streamingPaginate[T any, V Comparable](p *StreamingPage[T, V]) func(db *gorm.DB) *gorm.DB { + column := getColumnName(p.ColumnName) + startValue := p.StartValue + limit := p.Limit + return func(db *gorm.DB) *gorm.DB { + // 下一页 + if p.Forward { + return db.Where(fmt.Sprintf("%v > ?", column), startValue).Limit(limit) + } + // 上一页 + return db.Where(fmt.Sprintf("%v < ?", column), startValue).Order(fmt.Sprintf("%v DESC", column)).Limit(limit) + } +} + func buildCondition[T any](q *QueryCond[T], opts ...OptionFunc) *gorm.DB { db := getDb(opts...) resultDb := db.Model(new(T)) diff --git a/gplus/query.go b/gplus/query.go index 134f6f8..bf19ec9 100644 --- a/gplus/query.go +++ b/gplus/query.go @@ -442,6 +442,214 @@ func (q *QueryCond[T]) InCond(cond bool, column any, val any) *QueryCond[T] { return q } +// AndEqCond 并且等于 = +func (q *QueryCond[T]) AndEqCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.And().Eq(column, val) + } + return q +} + +// AndNeCond 并且不等于 != +func (q *QueryCond[T]) AndNeCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.And().Ne(column, val) + } + return q +} + +// AndGtCond 并且大于 > +func (q *QueryCond[T]) AndGtCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.And().Gt(column, val) + } + return q +} + +// AndGeCond 并且大于等于 >= +func (q *QueryCond[T]) AndGeCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.And().Ge(column, val) + } + return q +} + +// AndLtCond 并且小于 < +func (q *QueryCond[T]) AndLtCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.And().Lt(column, val) + } + return q +} + +// AndLeCond 并且小于等于 <= +func (q *QueryCond[T]) AndLeCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.And().Le(column, val) + } + return q +} + +// AndLikeCond 并且模糊 LIKE '%值%' +func (q *QueryCond[T]) AndLikeCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.And().Like(column, val) + } + return q +} + +// AndNotLikeCond 并且非模糊 NOT LIKE '%值%' +func (q *QueryCond[T]) AndNotLikeCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.And().NotLike(column, val) + } + return q +} + +// AndLikeLeftCond 并且左模糊 LIKE '%值' +func (q *QueryCond[T]) AndLikeLeftCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.And().LikeLeft(column, val) + } + return q +} + +// AndNotLikeLeftCond 并且非左模糊 NOT LIKE '%值' +func (q *QueryCond[T]) AndNotLikeLeftCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.And().NotLikeLeft(column, val) + } + return q +} + +// AndLikeRightCond 并且右模糊 LIKE '值%' +func (q *QueryCond[T]) AndLikeRightCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.And().LikeRight(column, val) + } + return q +} + +// AndNotLikeRightCond 并且非右模糊 NOT LIKE '值%' +func (q *QueryCond[T]) AndNotLikeRightCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.And().NotLikeRight(column, val) + } + return q +} + +// AndInCond 并且字段 IN (值1, 值2, ...) +func (q *QueryCond[T]) AndInCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.And().In(column, val) + } + return q +} + +// OrEqCond 或者等于 = +func (q *QueryCond[T]) OrEqCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.Or().Eq(column, val) + } + return q +} + +// OrNeCond 或者不等于 != +func (q *QueryCond[T]) OrNeCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.Or().Ne(column, val) + } + return q +} + +// OrGtCond 或者大于 > +func (q *QueryCond[T]) OrGtCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.Or().Gt(column, val) + } + return q +} + +// OrGeCond 或者大于等于 >= +func (q *QueryCond[T]) OrGeCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.Or().Ge(column, val) + } + return q +} + +// OrLtCond 或者小于 < +func (q *QueryCond[T]) OrLtCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.Or().Lt(column, val) + } + return q +} + +// OrLeCond 或者小于等于 <= +func (q *QueryCond[T]) OrLeCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.Or().Le(column, val) + } + return q +} + +// OrLikeCond 或者模糊 LIKE '%值%' +func (q *QueryCond[T]) OrLikeCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.Or().Like(column, val) + } + return q +} + +// OrNotLikeCond 或者非模糊 NOT LIKE '%值%' +func (q *QueryCond[T]) OrNotLikeCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.Or().NotLike(column, val) + } + return q +} + +// OrLikeLeftCond 或者左模糊 LIKE '%值' +func (q *QueryCond[T]) OrLikeLeftCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.Or().LikeLeft(column, val) + } + return q +} + +// OrNotLikeLeftCond 或者非左模糊 NOT LIKE '%值' +func (q *QueryCond[T]) OrNotLikeLeftCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.Or().NotLikeLeft(column, val) + } + return q +} + +// OrLikeRightCond 或者右模糊 LIKE '值%' +func (q *QueryCond[T]) OrLikeRightCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.Or().LikeRight(column, val) + } + return q +} + +// OrNotLikeRightCond 或者非右模糊 NOT LIKE '值%' +func (q *QueryCond[T]) OrNotLikeRightCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.Or().NotLikeRight(column, val) + } + return q +} + +// OrInCond 或者字段 IN (值1, 值2, ...) +func (q *QueryCond[T]) OrInCond(cond bool, column any, val any) *QueryCond[T] { + if cond { + return q.Or().In(column, val) + } + return q +} + func (q *QueryCond[T]) addExpression(sqlSegments ...SqlSegment) { if len(sqlSegments) == 1 { q.handleSingle(sqlSegments[0]) From 72f3e231cd38ce39b3b27237c0e12f3d0728aaf1 Mon Sep 17 00:00:00 2001 From: aixj1984 Date: Fri, 15 Sep 2023 10:18:13 +0800 Subject: [PATCH 09/12] add str condition --- gplus/query.go | 16 ++++++++++++++-- tests/dao_test.go | 5 +++-- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/gplus/query.go b/gplus/query.go index bf19ec9..6066f6a 100644 --- a/gplus/query.go +++ b/gplus/query.go @@ -745,8 +745,8 @@ func (q *QueryCond[T]) buildOrder(orderType string, columns ...string) { } } -// 执行执行增加条件 -func (q *QueryCond[T]) AddStrCond(cond string) *QueryCond[T] { +// 执行增加AND条件 +func (q *QueryCond[T]) AddAndStrCond(cond string) *QueryCond[T] { if len(q.queryExpressions) > 0 { sk := sqlKeyword{keyword: constants.And} q.queryExpressions = append(q.queryExpressions, &sk) @@ -757,6 +757,18 @@ func (q *QueryCond[T]) AddStrCond(cond string) *QueryCond[T] { return q } +// 执行增加OR条件 +func (q *QueryCond[T]) AddOrStrCond(cond string) *QueryCond[T] { + if len(q.queryExpressions) > 0 { + sk := sqlKeyword{keyword: constants.Or} + q.queryExpressions = append(q.queryExpressions, &sk) + } + condSk := sqlKeyword{keyword: cond} + q.queryExpressions = append(q.queryExpressions, &condSk) + q.last = &condSk + return q +} + // 根据条件,执行方法 func (q *QueryCond[T]) Case(isTrue bool, handleFunc func()) *QueryCond[T] { if isTrue { diff --git a/tests/dao_test.go b/tests/dao_test.go index e448640..f21419a 100644 --- a/tests/dao_test.go +++ b/tests/dao_test.go @@ -34,6 +34,7 @@ var gormDb *gorm.DB func init() { dsn := "root:123456@tcp(127.0.0.1:3306)/test?charset=utf8mb4&parseTime=True&loc=Local" + var err error gormDb, err = gorm.Open(mysql.Open(dsn), &gorm.Config{ Logger: logger.Default.LogMode(logger.Info), @@ -660,7 +661,7 @@ func TestQueryBuilder(t *testing.T) { query, _ := gplus.NewQuery[User]() - query.AddStrCond(fmt.Sprintf(" username = '%s' ", "afumu1")) + query.AddAndStrCond(fmt.Sprintf(" username = '%s' ", "afumu1")) count, db := gplus.SelectCount(query) if db.Error != nil { @@ -678,7 +679,7 @@ func TestExist(t *testing.T) { query, _ := gplus.NewQuery[User]() - query.AddStrCond(fmt.Sprintf(" username = '%s' ", "afumu1")) + query.AddAndStrCond(fmt.Sprintf(" username = '%s' ", "afumu1")) exist, dbErr := gplus.Exists(query) if dbErr != nil { From 202a0eee2da0ea62499c828e540cb8da35559c4a Mon Sep 17 00:00:00 2001 From: aixj1984 Date: Fri, 15 Sep 2023 11:50:47 +0800 Subject: [PATCH 10/12] fix --- gplus/dao.go | 4 ++++ tests/dao_test.go | 1 - 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/gplus/dao.go b/gplus/dao.go index 514597f..3929e6f 100644 --- a/gplus/dao.go +++ b/gplus/dao.go @@ -501,6 +501,10 @@ func buildSqlAndArgs[T any](expressions []any, sqlBuilder *strings.Builder, quer queryArgs = append(queryArgs, segment.value) } case *QueryCond[T]: + // 当子条件不存在查询表达式时,无需进行递归处理 + if len(segment.queryExpressions) == 0 { + continue + } sqlBuilder.WriteString(constants.LeftBracket + " ") // 递归处理条件 queryArgs = buildSqlAndArgs[T](segment.queryExpressions, sqlBuilder, queryArgs) diff --git a/tests/dao_test.go b/tests/dao_test.go index f21419a..d675d6a 100644 --- a/tests/dao_test.go +++ b/tests/dao_test.go @@ -34,7 +34,6 @@ var gormDb *gorm.DB func init() { dsn := "root:123456@tcp(127.0.0.1:3306)/test?charset=utf8mb4&parseTime=True&loc=Local" - var err error gormDb, err = gorm.Open(mysql.Open(dsn), &gorm.Config{ Logger: logger.Default.LogMode(logger.Info), From 13c7967932137dcf30758971edfd3499f700a832 Mon Sep 17 00:00:00 2001 From: xiongjun ai Date: Wed, 1 Nov 2023 15:05:18 +0800 Subject: [PATCH 11/12] fix bug --- README.md | 78 ------ gplus/dao.go | 95 ------- gplus/query.go | 333 +--------------------- gplus/{segment.go => sqlSegment.go} | 0 gplus/tool.go | 421 ---------------------------- tests/dao_test.go | 25 +- tests/delete_test.go | 5 +- tests/insert_test.go | 15 +- tests/select_test.go | 5 +- tests/tool_test.go | 283 ------------------- 10 files changed, 22 insertions(+), 1238 deletions(-) rename gplus/{segment.go => sqlSegment.go} (100%) delete mode 100644 gplus/tool.go delete mode 100644 tests/tool_test.go diff --git a/README.md b/README.md index 14895fe..622e8a8 100644 --- a/README.md +++ b/README.md @@ -125,84 +125,6 @@ func main() { ``` -## 搜索工具 - -只需要下面一行代码即可完成单表的所有查询功能 - -```Bash -gplus.SelectList(gplus.BuildQuery[User](queryParams)) -``` - - - -例子: - -```Bash -func main() { - http.HandleFunc("/", handleRequest) - http.ListenAndServe(":8080", nil) -} - -func handleRequest(w http.ResponseWriter, r *http.Request) { - queryParams := r.URL.Query() - list, _ := gplus.SelectList(gplus.BuildQuery[User](queryParams)) - marshal, _ := json.Marshal(list) - w.Write(marshal) -} -``` - -假设我们要查询username为zhangsan的用户 - -```Bash -http://localhost:8080?q=username=zhangsan -``` - - - -假设我们要查询username姓zhang的用户 - -```Bash -http://localhost:8080?q=username~>=zhang -``` - - - -假设我们要查询age大于20的用户 - -```Bash -http://localhost:8080?q=age>20 -``` - - - -假设我们要查询username等于zhagnsan,password等于123456的用户 - -```Bash -http://localhost:8080?q=username=zhangsan&q=password=123456 -``` - - - -假设我们要查询username等于zhagnsan,password等于123456的用户 - -```Bash -http://localhost:8080?q=username=zhangsan&q=password=123456 -``` - - - -假设我们要查询username等于zhagnsan,或者usename等于lisi的用户 - -可以增加一个分组和gcond的条件查询来实现 - -```Bash -http://localhost:8080?q=A.username=zhangsan&q=B.username=lisi&gcond=A|B -``` - - - -所有的单表查询我们都只需要一行代码即可。 - ## 总结 diff --git a/gplus/dao.go b/gplus/dao.go index 3929e6f..a202709 100644 --- a/gplus/dao.go +++ b/gplus/dao.go @@ -19,7 +19,6 @@ package gplus import ( "database/sql" - "fmt" "reflect" "strings" "time" @@ -56,29 +55,6 @@ func NewPage[T any](current, size int) *Page[T] { return &Page[T]{Current: current, Size: size} } -type Comparable interface { - ~int | ~int8 | ~int16 | ~int32 | ~int64 | ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~float32 | ~float64 | time.Time -} - -type StreamingPage[T any, V Comparable] struct { - ColumnName any `json:"columnName"` // 进行分页的列字段名称 - StartValue V `json:"startValue"` // 分页起始值 - Limit int `json:"limit"` // 页大小 - Forward bool `json:"forward"` // 上下页翻页标识 - Total int64 `json:"total"` // 总记录数 - Records []*T `json:"records"` // 查询记录 - RecordsMap []T `json:"recordsMap"` // 查询记录Map -} - -func NewStreamingPage[T any, V Comparable](columnName any, startValue V, limit int) *StreamingPage[T, V] { - return &StreamingPage[T, V]{ - ColumnName: columnName, - StartValue: startValue, - Limit: limit, - Forward: true, - } -} - // Insert 插入一条记录 func Insert[T any](entity *T, opts ...OptionFunc) *gorm.DB { db := getDb(opts...) @@ -286,26 +262,6 @@ func SelectPage[T any](page *Page[T], q *QueryCond[T], opts ...OptionFunc) (*Pag return page, resultDb } -// SelectStreamingPage 根据条件分页查询记录 -func SelectStreamingPage[T any, V Comparable](page *StreamingPage[T, V], q *QueryCond[T], opts ...OptionFunc) (*StreamingPage[T, V], *gorm.DB) { - option := getOption(opts) - - // 如果需要分页忽略总数,不查询总数 - if !option.IgnoreTotal { - total, countDb := SelectCount[T](q, opts...) - if countDb.Error != nil { - return page, countDb - } - page.Total = total - } - - resultDb := buildCondition(q, opts...) - var results []*T - resultDb.Scopes(streamingPaginate(page)).Find(&results) - page.Records = results - return page, resultDb -} - // SelectCount 根据条件查询记录数量 func SelectCount[T any](q *QueryCond[T], opts ...OptionFunc) (int64, *gorm.DB) { var count int64 @@ -354,34 +310,6 @@ func SelectPageGeneric[T any, R any](page *Page[R], q *QueryCond[T], opts ...Opt return page, resultDb } -// SelectStreamingPageGeneric 根据传入的泛型封装分页记录 -// 第一个泛型代表数据库表实体 -// 第二个泛型代表返回记录实体 -func SelectStreamingPageGeneric[T any, R any, V Comparable](page *StreamingPage[R, V], q *QueryCond[T], opts ...OptionFunc) (*StreamingPage[R, V], *gorm.DB) { - option := getOption(opts) - // 如果需要分页忽略总数,不查询总数 - if !option.IgnoreTotal { - total, countDb := SelectCount[T](q, opts...) - if countDb.Error != nil { - return page, countDb - } - page.Total = total - } - resultDb := buildCondition(q, opts...) - var r R - switch any(r).(type) { - case map[string]any: - var results []R - resultDb.Scopes(streamingPaginate(page)).Scan(&results) - page.RecordsMap = results - default: - var results []*R - resultDb.Scopes(streamingPaginate(page)).Scan(&results) - page.Records = results - } - return page, resultDb -} - // SelectGeneric 根据传入的泛型封装记录 // 第一个泛型代表数据库表实体 // 第二个泛型代表返回记录实体 @@ -396,13 +324,6 @@ func Begin(opts ...*sql.TxOptions) *gorm.DB { return db.Begin(opts...) } -// Tx 事务 -func Tx(txFunc func(tx *gorm.DB) error, opts ...OptionFunc) error { - db := getDb(opts...) - return db.Transaction(txFunc) -} - -// paginate offset分页 func paginate[T any](p *Page[T]) func(db *gorm.DB) *gorm.DB { page := p.Current pageSize := p.Size @@ -418,22 +339,6 @@ func paginate[T any](p *Page[T]) func(db *gorm.DB) *gorm.DB { } } -// streamingPaginate 流式分页,根据自增ID、雪花ID、时间等数值类型或者时间类型分页 -// Tips: 相比于 offset 分页性能更好,走的是 range,缺点是没办法跳页查询 -func streamingPaginate[T any, V Comparable](p *StreamingPage[T, V]) func(db *gorm.DB) *gorm.DB { - column := getColumnName(p.ColumnName) - startValue := p.StartValue - limit := p.Limit - return func(db *gorm.DB) *gorm.DB { - // 下一页 - if p.Forward { - return db.Where(fmt.Sprintf("%v > ?", column), startValue).Limit(limit) - } - // 上一页 - return db.Where(fmt.Sprintf("%v < ?", column), startValue).Order(fmt.Sprintf("%v DESC", column)).Limit(limit) - } -} - func buildCondition[T any](q *QueryCond[T], opts ...OptionFunc) *gorm.DB { db := getDb(opts...) resultDb := db.Model(new(T)) diff --git a/gplus/query.go b/gplus/query.go index 6066f6a..cf630ae 100644 --- a/gplus/query.go +++ b/gplus/query.go @@ -48,6 +48,7 @@ func (q *QueryCond[T]) getSqlSegment() string { // NewQuery 构建查询条件 func NewQuery[T any]() (*QueryCond[T], *T) { q := &QueryCond[T]{} + modelTypeStr := reflect.TypeOf((*T)(nil)).Elem().String() if model, ok := modelInstanceCache.Load(modelTypeStr); ok { m, isReal := model.(*T) @@ -318,338 +319,6 @@ func (q *QueryCond[T]) Set(column any, val any) *QueryCond[T] { return q } -/* -* 自定义条件 - */ - -// AndCond 拼接 AND -func (q *QueryCond[T]) AndCond(cond bool, fn ...func(q *QueryCond[T])) *QueryCond[T] { - if cond { - return q.And(fn...) - } - return q -} - -// OrCond 拼接 OR -func (q *QueryCond[T]) OrCond(cond bool, fn ...func(q *QueryCond[T])) *QueryCond[T] { - if cond { - return q.Or(fn...) - } - return q -} - -// EqCond 等于 = -func (q *QueryCond[T]) EqCond(cond bool, column any, val any) *QueryCond[T] { - if cond { - return q.Eq(column, val) - } - return q -} - -// NeCond 不等于 != -func (q *QueryCond[T]) NeCond(cond bool, column any, val any) *QueryCond[T] { - if cond { - return q.Ne(column, val) - } - return q -} - -// GtCond 大于 > -func (q *QueryCond[T]) GtCond(cond bool, column any, val any) *QueryCond[T] { - if cond { - return q.Gt(column, val) - } - return q -} - -// GeCond 大于等于 >= -func (q *QueryCond[T]) GeCond(cond bool, column any, val any) *QueryCond[T] { - if cond { - return q.Ge(column, val) - } - return q -} - -// LtCond 小于 < -func (q *QueryCond[T]) LtCond(cond bool, column any, val any) *QueryCond[T] { - if cond { - return q.Lt(column, val) - } - return q -} - -// LeCond 小于等于 <= -func (q *QueryCond[T]) LeCond(cond bool, column any, val any) *QueryCond[T] { - if cond { - return q.Le(column, val) - } - return q -} - -// LikeCond 模糊 LIKE '%值%' -func (q *QueryCond[T]) LikeCond(cond bool, column any, val any) *QueryCond[T] { - if cond { - return q.Like(column, val) - } - return q -} - -// NotLikeCond 非模糊 NOT LIKE '%值%' -func (q *QueryCond[T]) NotLikeCond(cond bool, column any, val any) *QueryCond[T] { - if cond { - return q.NotLike(column, val) - } - return q -} - -// LikeLeftCond 左模糊 LIKE '%值' -func (q *QueryCond[T]) LikeLeftCond(cond bool, column any, val any) *QueryCond[T] { - if cond { - return q.LikeLeft(column, val) - } - return q -} - -// NotLikeLeftCond 非左模糊 NOT LIKE '%值' -func (q *QueryCond[T]) NotLikeLeftCond(cond bool, column any, val any) *QueryCond[T] { - if cond { - return q.NotLike(column, val) - } - return q -} - -// LikeRightCond 右模糊 LIKE '值%' -func (q *QueryCond[T]) LikeRightCond(cond bool, column any, val any) *QueryCond[T] { - if cond { - return q.LikeRight(column, val) - } - return q -} - -// NotLikeRightCond 非右模糊 NOT LIKE '值%' -func (q *QueryCond[T]) NotLikeRightCond(cond bool, column any, val any) *QueryCond[T] { - if cond { - return q.NotLikeRight(column, val) - } - return q -} - -// InCond 字段 IN (值1, 值2, ...) -func (q *QueryCond[T]) InCond(cond bool, column any, val any) *QueryCond[T] { - if cond { - return q.In(column, val) - } - return q -} - -// AndEqCond 并且等于 = -func (q *QueryCond[T]) AndEqCond(cond bool, column any, val any) *QueryCond[T] { - if cond { - return q.And().Eq(column, val) - } - return q -} - -// AndNeCond 并且不等于 != -func (q *QueryCond[T]) AndNeCond(cond bool, column any, val any) *QueryCond[T] { - if cond { - return q.And().Ne(column, val) - } - return q -} - -// AndGtCond 并且大于 > -func (q *QueryCond[T]) AndGtCond(cond bool, column any, val any) *QueryCond[T] { - if cond { - return q.And().Gt(column, val) - } - return q -} - -// AndGeCond 并且大于等于 >= -func (q *QueryCond[T]) AndGeCond(cond bool, column any, val any) *QueryCond[T] { - if cond { - return q.And().Ge(column, val) - } - return q -} - -// AndLtCond 并且小于 < -func (q *QueryCond[T]) AndLtCond(cond bool, column any, val any) *QueryCond[T] { - if cond { - return q.And().Lt(column, val) - } - return q -} - -// AndLeCond 并且小于等于 <= -func (q *QueryCond[T]) AndLeCond(cond bool, column any, val any) *QueryCond[T] { - if cond { - return q.And().Le(column, val) - } - return q -} - -// AndLikeCond 并且模糊 LIKE '%值%' -func (q *QueryCond[T]) AndLikeCond(cond bool, column any, val any) *QueryCond[T] { - if cond { - return q.And().Like(column, val) - } - return q -} - -// AndNotLikeCond 并且非模糊 NOT LIKE '%值%' -func (q *QueryCond[T]) AndNotLikeCond(cond bool, column any, val any) *QueryCond[T] { - if cond { - return q.And().NotLike(column, val) - } - return q -} - -// AndLikeLeftCond 并且左模糊 LIKE '%值' -func (q *QueryCond[T]) AndLikeLeftCond(cond bool, column any, val any) *QueryCond[T] { - if cond { - return q.And().LikeLeft(column, val) - } - return q -} - -// AndNotLikeLeftCond 并且非左模糊 NOT LIKE '%值' -func (q *QueryCond[T]) AndNotLikeLeftCond(cond bool, column any, val any) *QueryCond[T] { - if cond { - return q.And().NotLikeLeft(column, val) - } - return q -} - -// AndLikeRightCond 并且右模糊 LIKE '值%' -func (q *QueryCond[T]) AndLikeRightCond(cond bool, column any, val any) *QueryCond[T] { - if cond { - return q.And().LikeRight(column, val) - } - return q -} - -// AndNotLikeRightCond 并且非右模糊 NOT LIKE '值%' -func (q *QueryCond[T]) AndNotLikeRightCond(cond bool, column any, val any) *QueryCond[T] { - if cond { - return q.And().NotLikeRight(column, val) - } - return q -} - -// AndInCond 并且字段 IN (值1, 值2, ...) -func (q *QueryCond[T]) AndInCond(cond bool, column any, val any) *QueryCond[T] { - if cond { - return q.And().In(column, val) - } - return q -} - -// OrEqCond 或者等于 = -func (q *QueryCond[T]) OrEqCond(cond bool, column any, val any) *QueryCond[T] { - if cond { - return q.Or().Eq(column, val) - } - return q -} - -// OrNeCond 或者不等于 != -func (q *QueryCond[T]) OrNeCond(cond bool, column any, val any) *QueryCond[T] { - if cond { - return q.Or().Ne(column, val) - } - return q -} - -// OrGtCond 或者大于 > -func (q *QueryCond[T]) OrGtCond(cond bool, column any, val any) *QueryCond[T] { - if cond { - return q.Or().Gt(column, val) - } - return q -} - -// OrGeCond 或者大于等于 >= -func (q *QueryCond[T]) OrGeCond(cond bool, column any, val any) *QueryCond[T] { - if cond { - return q.Or().Ge(column, val) - } - return q -} - -// OrLtCond 或者小于 < -func (q *QueryCond[T]) OrLtCond(cond bool, column any, val any) *QueryCond[T] { - if cond { - return q.Or().Lt(column, val) - } - return q -} - -// OrLeCond 或者小于等于 <= -func (q *QueryCond[T]) OrLeCond(cond bool, column any, val any) *QueryCond[T] { - if cond { - return q.Or().Le(column, val) - } - return q -} - -// OrLikeCond 或者模糊 LIKE '%值%' -func (q *QueryCond[T]) OrLikeCond(cond bool, column any, val any) *QueryCond[T] { - if cond { - return q.Or().Like(column, val) - } - return q -} - -// OrNotLikeCond 或者非模糊 NOT LIKE '%值%' -func (q *QueryCond[T]) OrNotLikeCond(cond bool, column any, val any) *QueryCond[T] { - if cond { - return q.Or().NotLike(column, val) - } - return q -} - -// OrLikeLeftCond 或者左模糊 LIKE '%值' -func (q *QueryCond[T]) OrLikeLeftCond(cond bool, column any, val any) *QueryCond[T] { - if cond { - return q.Or().LikeLeft(column, val) - } - return q -} - -// OrNotLikeLeftCond 或者非左模糊 NOT LIKE '%值' -func (q *QueryCond[T]) OrNotLikeLeftCond(cond bool, column any, val any) *QueryCond[T] { - if cond { - return q.Or().NotLikeLeft(column, val) - } - return q -} - -// OrLikeRightCond 或者右模糊 LIKE '值%' -func (q *QueryCond[T]) OrLikeRightCond(cond bool, column any, val any) *QueryCond[T] { - if cond { - return q.Or().LikeRight(column, val) - } - return q -} - -// OrNotLikeRightCond 或者非右模糊 NOT LIKE '值%' -func (q *QueryCond[T]) OrNotLikeRightCond(cond bool, column any, val any) *QueryCond[T] { - if cond { - return q.Or().NotLikeRight(column, val) - } - return q -} - -// OrInCond 或者字段 IN (值1, 值2, ...) -func (q *QueryCond[T]) OrInCond(cond bool, column any, val any) *QueryCond[T] { - if cond { - return q.Or().In(column, val) - } - return q -} - func (q *QueryCond[T]) addExpression(sqlSegments ...SqlSegment) { if len(sqlSegments) == 1 { q.handleSingle(sqlSegments[0]) diff --git a/gplus/segment.go b/gplus/sqlSegment.go similarity index 100% rename from gplus/segment.go rename to gplus/sqlSegment.go diff --git a/gplus/tool.go b/gplus/tool.go deleted file mode 100644 index 5e7917c..0000000 --- a/gplus/tool.go +++ /dev/null @@ -1,421 +0,0 @@ -/* - * Licensed to the AcmeStack under one or more contributor license - * agreements. See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package gplus - -import ( - "fmt" - "net/url" - "reflect" - "strconv" - "strings" - "sync" -) - -type Condition struct { - Group string - ColumnName string - Op string - ColumnValue any -} - -var columnTypeCache sync.Map - -var operators = []string{"!~<=", "!~>=", "~<=", "~>=", "!?=", "!^=", "!~=", "?=", "^=", "~=", "!=", ">=", "<=", "=", ">", "<"} -var builders = map[string]func(query *QueryCond[any], name string, value any){ - "!~<=": notLikeLeft, - "!~>=": notLikeRight, - "~<=": LikeLeft, - "~>=": LikeRight, - "!?=": notIn, - "!^=": notBetween, - "!~=": notLike, - "?=": in, - "^=": between, - "~=": like, - "!=": ne, - ">=": ge, - "<=": le, - "=": eq, - ">": gt, - "<": lt, -} - -func BuildQuery[T any](queryParams url.Values) *QueryCond[T] { - - columnCondMap, conditionMap, gcond := parseParams(queryParams) - - parentQuery := buildParentQuery[T](conditionMap) - - queryCondMap := buildQueryCondMap[T](columnCondMap) - - // 如果没有分组条件,直接返回默认的查询条件 - if len(gcond) == 0 { - if q, ok := queryCondMap["default"]; ok { - q.orderBuilder = parentQuery.orderBuilder - q.selectColumns = parentQuery.selectColumns - q.omitColumns = parentQuery.omitColumns - return q - } - - // 如果没有分组条件,但是有分组设置,返回第一个查询条件。主要为了兼容只有一个分组但是没有设置条件的情况。 - if len(queryCondMap) == 1 { - for _, q := range queryCondMap { - q.orderBuilder = parentQuery.orderBuilder - q.selectColumns = parentQuery.selectColumns - q.omitColumns = parentQuery.omitColumns - return q - } - } - } - - return buildGroupQuery[T](gcond, queryCondMap, parentQuery) -} - -func parseParams(queryParams url.Values) (map[string][]*Condition, map[string]string, string) { - var gcond string - var columnCondMap = make(map[string][]*Condition) - var conditionMap = make(map[string]string) - for key, values := range queryParams { - switch key { - case "q": - columnCondMap = buildColumnCondMap(values) - case "sort": - if len(values) > 0 { - conditionMap["sort"] = values[len(values)-1] - } - case "select": - if len(values) > 0 { - conditionMap["select"] = values[len(values)-1] - } - case "omit": - if len(values) > 0 { - conditionMap["omit"] = values[len(values)-1] - } - case "gcond": - gcond = values[0] - } - } - return columnCondMap, conditionMap, gcond -} - -// buildColumnCondMap 根据url参数构建字段条件 -func buildColumnCondMap(values []string) map[string][]*Condition { - var maps = make(map[string][]*Condition) - for _, value := range values { - currentOperator := getCurrentOp(value) - params := strings.SplitN(value, currentOperator, 2) - if len(params) == 2 { - condition := &Condition{} - groups := strings.Split(params[0], ".") - var groupName string - var columnName string - // 如果不包含组,默认分为同一个组 - if len(groups) == 1 { - groupName = "default" - columnName = groups[0] - } else if len(groups) == 2 { - groupName = groups[0] - columnName = groups[1] - } - condition.Group = groupName - condition.ColumnName = columnName - condition.Op = currentOperator - condition.ColumnValue = params[1] - conditions, ok := maps[groupName] - if ok { - conditions = append(conditions, condition) - } else { - conditions = []*Condition{condition} - } - maps[groupName] = conditions - } - } - return maps -} - -func getCurrentOp(value string) string { - var currentOperator string - for _, op := range operators { - if strings.Contains(value, op) { - currentOperator = op - break - } - } - return currentOperator -} - -func buildQueryCondMap[T any](columnCondMap map[string][]*Condition) map[string]*QueryCond[T] { - var queryCondMap = make(map[string]*QueryCond[T]) - columnTypeMap := getColumnTypeMap[T]() - for key, conditions := range columnCondMap { - query := &QueryCond[any]{} - query.columnTypeMap = columnTypeMap - for _, condition := range conditions { - name := condition.ColumnName - op := condition.Op - value := condition.ColumnValue - builders[op](query, name, value) - } - newQuery, _ := NewQuery[T]() - newQuery.queryExpressions = append(newQuery.queryExpressions, query.queryExpressions...) - queryCondMap[key] = newQuery - } - return queryCondMap -} - -func buildParentQuery[T any](conditionMap map[string]string) *QueryCond[T] { - parentQuery, _ := NewQuery[T]() - for key, value := range conditionMap { - if key == "sort" { - orderColumns := strings.Split(value, ",") - for _, column := range orderColumns { - if strings.HasPrefix(column, "-") { - newValue := strings.TrimLeft(column, "-") - parentQuery.OrderByDesc(newValue) - } else { - parentQuery.OrderByAsc(column) - } - } - } else if key == "select" { - selectColumns := strings.Split(value, ",") - for _, column := range selectColumns { - parentQuery.Select(column) - } - } else if key == "omit" { - omitColumns := strings.Split(value, ",") - for _, column := range omitColumns { - parentQuery.Omit(column) - } - } - } - return parentQuery -} - -func buildGroupQuery[T any](gcond string, queryMaps map[string]*QueryCond[T], query *QueryCond[T]) *QueryCond[T] { - var tempQuerys []*QueryCond[T] - tempQuerys = append(tempQuerys, query) - for i, char := range gcond { - str := string(char) - tempQuery := tempQuerys[len(tempQuerys)-1] - // 如果是 左括号 开头,则代表需要嵌套查询 - if str == "(" && i != len(gcond)-1 { - if i != 0 && string(gcond[i-1]) == "|" { - tempQuery.Or(func(q *QueryCond[T]) { - paramQuery, isOk := queryMaps[string(gcond[i+1])] - if isOk { - q.queryExpressions = append(q.queryExpressions, paramQuery.queryExpressions...) - tempQuerys = append(tempQuerys, q) - } - }) - continue - } else { - tempQuery.And(func(q *QueryCond[T]) { - paramQuery, isOk := queryMaps[string(gcond[i+1])] - if isOk { - q.queryExpressions = append(q.queryExpressions, paramQuery.queryExpressions...) - tempQuerys = append(tempQuerys, q) - } - }) - } - continue - } - - // 如果当前为 | ,而且不是最后一个字符,而且下一个字符不是 ( ,则为 or - if str == "|" && i != len(gcond)-1 { - paramQuery, isOk := queryMaps[string(gcond[i+1])] - if isOk { - tempQuery.Or().queryExpressions = append(tempQuery.queryExpressions, paramQuery.queryExpressions...) - tempQuery.last = paramQuery.queryExpressions[len(paramQuery.queryExpressions)-1] - } - continue - } - - if str == "*" && i != len(gcond)-1 { - paramQuery, isOk := queryMaps[string(gcond[i+1])] - if isOk { - tempQuery.And() - tempQuery.queryExpressions = append(tempQuery.queryExpressions, paramQuery.queryExpressions...) - tempQuery.last = paramQuery.queryExpressions[len(paramQuery.queryExpressions)-1] - } - continue - } - - if str == ")" { - // 删除最后一个query对象 - tempQuerys = tempQuerys[:len(tempQuerys)-1] - continue - } - - // 如果上面的条件不满足,而且是第一个的话,那么就直接添加条件 - if i == 0 { - paramQuery, isOk := queryMaps[string(gcond[i])] - if isOk { - tempQuery.queryExpressions = append(tempQuery.queryExpressions, paramQuery.queryExpressions...) - tempQuery.last = paramQuery.queryExpressions[len(paramQuery.queryExpressions)-1] - } - } - } - return query -} - -func getColumnTypeMap[T any]() map[string]reflect.Type { - modelTypeStr := reflect.TypeOf((*T)(nil)).Elem().String() - if model, ok := columnTypeCache.Load(modelTypeStr); ok { - if columnNameMap, isOk := model.(map[string]reflect.Type); isOk { - return columnNameMap - } - } - var columnTypeMap = make(map[string]reflect.Type) - typeOf := reflect.TypeOf((*T)(nil)).Elem() - for i := 0; i < typeOf.NumField(); i++ { - field := typeOf.Field(i) - if field.Anonymous { - nestedFields := getSubFieldColumnTypeMap(field) - for key, value := range nestedFields { - columnTypeMap[key] = value - } - } - columnName := parseColumnName(field) - columnTypeMap[columnName] = field.Type - } - columnTypeCache.Store(modelTypeStr, columnTypeMap) - return columnTypeMap -} - -func getSubFieldColumnTypeMap(field reflect.StructField) map[string]reflect.Type { - columnTypeMap := make(map[string]reflect.Type) - modelType := field.Type - if modelType.Kind() == reflect.Ptr { - modelType = modelType.Elem() - } - for j := 0; j < modelType.NumField(); j++ { - subField := modelType.Field(j) - if subField.Anonymous { - nestedFields := getSubFieldColumnTypeMap(subField) - for key, value := range nestedFields { - columnTypeMap[key] = value - } - } else { - columnName := parseColumnName(subField) - columnTypeMap[columnName] = subField.Type - } - } - return columnTypeMap -} - -func notLikeLeft(query *QueryCond[any], name string, value any) { - query.NotLikeLeft(name, convert(query.columnTypeMap, name, value)) -} - -func notLikeRight(query *QueryCond[any], name string, value any) { - query.NotLikeRight(name, convert(query.columnTypeMap, name, value)) -} - -func LikeLeft(query *QueryCond[any], name string, value any) { - query.LikeLeft(name, convert(query.columnTypeMap, name, value)) -} - -func LikeRight(query *QueryCond[any], name string, value any) { - query.LikeRight(name, convert(query.columnTypeMap, name, value)) -} - -func notIn(query *QueryCond[any], name string, value any) { - values := strings.Split(fmt.Sprintf("%s", value), ",") - var queryValues []any - for _, v := range values { - queryValues = append(queryValues, convert(query.columnTypeMap, name, v)) - } - query.NotIn(name, queryValues) -} - -func notBetween(query *QueryCond[any], name string, value any) { - values := strings.Split(fmt.Sprintf("%s", value), ",") - if len(values) == 2 { - query.NotBetween(name, convert(query.columnTypeMap, name, values[0]), convert(query.columnTypeMap, name, values[1])) - } -} - -func notLike(query *QueryCond[any], name string, value any) { - query.NotLike(name, convert(query.columnTypeMap, name, value)) -} - -func in(query *QueryCond[any], name string, value any) { - values := strings.Split(fmt.Sprintf("%s", value), ",") - var queryValues []any - for _, v := range values { - queryValues = append(queryValues, convert(query.columnTypeMap, name, v)) - } - query.In(name, queryValues) -} - -func between(query *QueryCond[any], name string, value any) { - values := strings.Split(fmt.Sprintf("%s", value), ",") - if len(values) == 2 { - query.Between(name, convert(query.columnTypeMap, name, values[0]), convert(query.columnTypeMap, name, values[1])) - } -} - -func like(query *QueryCond[any], name string, value any) { - query.Like(name, convert(query.columnTypeMap, name, value)) -} - -func ne(query *QueryCond[any], name string, value any) { - if strings.ToLower(fmt.Sprintf("%s", value)) == "null" { - query.IsNotNull(name) - } else { - query.Ne(name, convert(query.columnTypeMap, name, value)) - } -} - -func ge(query *QueryCond[any], name string, value any) { - query.Ge(name, convert(query.columnTypeMap, name, value)) -} - -func le(query *QueryCond[any], name string, value any) { - query.Le(name, convert(query.columnTypeMap, name, value)) -} - -func eq(query *QueryCond[any], name string, value any) { - if strings.ToLower(fmt.Sprintf("%s", value)) == "null" { - query.IsNull(name) - } else { - query.Eq(name, convert(query.columnTypeMap, name, value)) - } -} - -func gt(query *QueryCond[any], name string, value any) { - query.Gt(name, convert(query.columnTypeMap, name, value)) -} - -func lt(query *QueryCond[any], name string, value any) { - query.Lt(name, convert(query.columnTypeMap, name, value)) -} - -func convert(columnTypeMap map[string]reflect.Type, name string, value any) any { - columnType, ok := columnTypeMap[name] - if ok { - switch columnType.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - atoi, err := strconv.Atoi(fmt.Sprintf("%s", value)) - if err == nil { - value = atoi - } - } - } - return value -} diff --git a/tests/dao_test.go b/tests/dao_test.go index d675d6a..71a95ce 100644 --- a/tests/dao_test.go +++ b/tests/dao_test.go @@ -20,20 +20,21 @@ package tests import ( "errors" "fmt" - "github.com/aixj1984/gorm-plus/gplus" - "gorm.io/driver/mysql" - "gorm.io/gorm" - "gorm.io/gorm/logger" "reflect" "sort" "strconv" "testing" + + "github.com/aixj1984/gorm-plus/gplus" + "gorm.io/driver/mysql" + "gorm.io/gorm" + "gorm.io/gorm/logger" ) var gormDb *gorm.DB func init() { - dsn := "root:123456@tcp(127.0.0.1:3306)/test?charset=utf8mb4&parseTime=True&loc=Local" + dsn := "root:my-secret-pw@tcp(127.0.0.1:3306)/test_db?charset=utf8mb4&parseTime=True&loc=Local" var err error gormDb, err = gorm.Open(mysql.Open(dsn), &gorm.Config{ Logger: logger.Default.LogMode(logger.Info), @@ -183,7 +184,7 @@ func TestUpdateZeroById(t *testing.T) { users := getUsers() gplus.InsertBatch[User](users) - updateUser := &User{Base: Base{ID: users[0].ID, CreatedAt: users[0].CreatedAt}, Score: 100, Age: 25} + updateUser := &User{Base: Base{ID: users[0].ID}, Score: 100, Age: 25} if res := gplus.UpdateZeroById[User](updateUser); res.Error != nil || res.RowsAffected != 1 { t.Errorf("errors happened when deleteByIds: %v, affected: %v", res.Error, res.RowsAffected) @@ -557,18 +558,6 @@ func TestSelectGeneric7(t *testing.T) { } } -func TestTx(t *testing.T) { - deleteOldData() - users := getUsers() - err := gplus.Tx(func(tx *gorm.DB) error { - err := gplus.InsertBatch[User](users, gplus.Db(tx)).Error - return err - }) - if err != nil { - t.Errorf(err.Error()) - } -} - func TestCase(t *testing.T) { deleteOldData() users := getUsers() diff --git a/tests/delete_test.go b/tests/delete_test.go index 16fc15c..7f9a078 100644 --- a/tests/delete_test.go +++ b/tests/delete_test.go @@ -18,10 +18,11 @@ package tests import ( - "github.com/aixj1984/gorm-plus/gplus" - "gorm.io/gorm" "strings" "testing" + + "github.com/aixj1984/gorm-plus/gplus" + "gorm.io/gorm" ) func TestDeleteByIdName(t *testing.T) { diff --git a/tests/insert_test.go b/tests/insert_test.go index 914239e..a9d43a9 100644 --- a/tests/insert_test.go +++ b/tests/insert_test.go @@ -18,14 +18,15 @@ package tests import ( - "github.com/aixj1984/gorm-plus/gplus" - "gorm.io/gorm" "strings" "testing" + + "github.com/aixj1984/gorm-plus/gplus" + "gorm.io/gorm" ) func TestInsert1Name(t *testing.T) { - var expectSql = "INSERT INTO `Users` (`username`,`password`,`address`,`age`,`phone`,`score`,`dept`) VALUES ('afumu','123456','',18,'',12,'研发部门')" + var expectSql = "INSERT INTO `Users` (`username`,`password`,`address`,`age`,`phone`,`score`,`dept`) VALUES ('afumu','123456','',18,'',12,'研发部门') RETURNING `id`" user := &User{Username: "afumu", Password: "123456", Age: 18, Score: 12, Dept: "研发部门"} u := gplus.GetModel[User]() sessionDb := checkInsertSql(t, expectSql) @@ -33,7 +34,7 @@ func TestInsert1Name(t *testing.T) { } func TestInsert2Name(t *testing.T) { - var expectSql = "INSERT INTO `Users` (`username`,`password`,`address`,`age`,`phone`,`score`) VALUES ('afumu','123456','',18,'',12)" + var expectSql = "INSERT INTO `Users` (`username`,`password`,`address`,`age`,`phone`,`score`) VALUES ('afumu','123456','',18,'',12) RETURNING `id`" user := &User{Username: "afumu", Password: "123456", Age: 18, Score: 12, Dept: "研发部门"} u := gplus.GetModel[User]() sessionDb := checkInsertSql(t, expectSql) @@ -41,7 +42,7 @@ func TestInsert2Name(t *testing.T) { } func TestInsert3Name(t *testing.T) { - var expectSql = "INSERT INTO `Users` (`username`,`password`) VALUES ('afumu','123456')" + var expectSql = "INSERT INTO `Users` (`username`,`password`) VALUES ('afumu','123456') RETURNING `id`" user := &User{Username: "afumu", Password: "123456", Age: 18, Score: 12, Dept: "研发部门"} u := gplus.GetModel[User]() sessionDb := checkInsertSql(t, expectSql) @@ -49,7 +50,7 @@ func TestInsert3Name(t *testing.T) { } func TestInsertBatchName(t *testing.T) { - var expectSql = "INSERT INTO `Users` (`username`,`password`) VALUES ('afumu','123456'),('afumu','123456')" + var expectSql = "INSERT INTO `Users` (`username`,`password`) VALUES ('afumu','123456'),('afumu','123456') RETURNING `id`" user := &User{Username: "afumu", Password: "123456", Age: 18, Score: 12, Dept: "研发部门"} user2 := &User{Username: "afumu", Password: "123456", Age: 18, Score: 12, Dept: "研发部门"} sessionDb := checkInsertSql(t, expectSql) @@ -58,7 +59,7 @@ func TestInsertBatchName(t *testing.T) { } func TestInsertBatchSizeName(t *testing.T) { - var expectSql = "INSERT INTO `Users` (`username`,`password`) VALUES ('afumu','123456'),('afumu','123456')" + var expectSql = "INSERT INTO `Users` (`username`,`password`) VALUES ('afumu','123456'),('afumu','123456') RETURNING `id`" user := &User{Username: "afumu", Password: "123456", Age: 18, Score: 12, Dept: "研发部门"} user2 := &User{Username: "afumu", Password: "123456", Age: 18, Score: 12, Dept: "研发部门"} user3 := &User{Username: "afumu", Password: "123456", Age: 18, Score: 12, Dept: "研发部门"} diff --git a/tests/select_test.go b/tests/select_test.go index 677598a..7ae3fee 100644 --- a/tests/select_test.go +++ b/tests/select_test.go @@ -18,10 +18,11 @@ package tests import ( - "github.com/aixj1984/gorm-plus/gplus" - "gorm.io/gorm" "strings" "testing" + + "github.com/aixj1984/gorm-plus/gplus" + "gorm.io/gorm" ) func TestSelectByIdName(t *testing.T) { diff --git a/tests/tool_test.go b/tests/tool_test.go deleted file mode 100644 index 86c3369..0000000 --- a/tests/tool_test.go +++ /dev/null @@ -1,283 +0,0 @@ -/* - * Licensed to the AcmeStack under one or more contributor license - * agreements. See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package tests - -import ( - "github.com/aixj1984/gorm-plus/gplus" - "net/url" - "testing" -) - -func TestQueryById(t *testing.T) { - values := url.Values{} - values["q"] = []string{"id=1"} - query := gplus.BuildQuery[User](values) - var expectSql = "SELECT * FROM `Users` WHERE id = 1" - sessionDb := checkSelectSql(t, expectSql) - gplus.SelectList[User](query, gplus.Db(sessionDb)) -} - -func TestQueryByIdSelect(t *testing.T) { - values := url.Values{} - values["q"] = []string{"id=1"} - values["select"] = []string{"username,age"} - query := gplus.BuildQuery[User](values) - var expectSql = "SELECT `username`,`age` FROM `Users` WHERE id = 1" - sessionDb := checkSelectSql(t, expectSql) - gplus.SelectList[User](query, gplus.Db(sessionDb)) -} - -func TestQueryByIdOmit(t *testing.T) { - values := url.Values{} - values["q"] = []string{"id=1"} - values["omit"] = []string{"username,age"} - query := gplus.BuildQuery[User](values) - var expectSql = "SELECT `Users`.`id`,`Users`.`created_at`,`Users`.`updated_at`,`Users`.`password`,`Users`.`address`,`Users`.`phone`,`Users`.`score`,`Users`.`dept` FROM `Users` WHERE id = 1" - sessionDb := checkSelectSql(t, expectSql) - gplus.SelectList[User](query, gplus.Db(sessionDb)) -} - -func TestQueryByIdSortAsc(t *testing.T) { - values := url.Values{} - values["sort"] = []string{"age"} - query := gplus.BuildQuery[User](values) - var expectSql = "SELECT * FROM `Users` ORDER BY age ASC" - sessionDb := checkSelectSql(t, expectSql) - gplus.SelectList[User](query, gplus.Db(sessionDb)) -} - -func TestQueryByIdSortDesc(t *testing.T) { - values := url.Values{} - values["sort"] = []string{"-age"} - query := gplus.BuildQuery[User](values) - var expectSql = "SELECT * FROM `Users` ORDER BY age DESC" - sessionDb := checkSelectSql(t, expectSql) - gplus.SelectList[User](query, gplus.Db(sessionDb)) -} - -func TestQueryByIdsIn(t *testing.T) { - values := url.Values{} - values["q"] = []string{"id?=1,2"} - query := gplus.BuildQuery[User](values) - var expectSql = "SELECT * FROM `Users` WHERE id IN (1,2)" - sessionDb := checkSelectSql(t, expectSql) - gplus.SelectList[User](query, gplus.Db(sessionDb)) -} - -func TestQueryByEqUsername(t *testing.T) { - values := url.Values{} - values["q"] = []string{"username=afumu"} - query := gplus.BuildQuery[User](values) - var expectSql = "SELECT * FROM `Users` WHERE username = 'afumu'" - sessionDb := checkSelectSql(t, expectSql) - gplus.SelectList[User](query, gplus.Db(sessionDb)) -} - -func TestQueryByNeUsername(t *testing.T) { - values := url.Values{} - values["q"] = []string{"username!=afumu"} - query := gplus.BuildQuery[User](values) - var expectSql = "SELECT * FROM `Users` WHERE username <> 'afumu'" - sessionDb := checkSelectSql(t, expectSql) - gplus.SelectList[User](query, gplus.Db(sessionDb)) -} - -func TestQueryByGtAge(t *testing.T) { - values := url.Values{} - values["q"] = []string{"age>20"} - query := gplus.BuildQuery[User](values) - var expectSql = "SELECT * FROM `Users` WHERE age > 20" - sessionDb := checkSelectSql(t, expectSql) - gplus.SelectList[User](query, gplus.Db(sessionDb)) -} - -func TestQueryByGeAge(t *testing.T) { - values := url.Values{} - values["q"] = []string{"age>=20"} - query := gplus.BuildQuery[User](values) - var expectSql = "SELECT * FROM `Users` WHERE age >= 20" - sessionDb := checkSelectSql(t, expectSql) - gplus.SelectList[User](query, gplus.Db(sessionDb)) -} - -func TestQueryByLtAge(t *testing.T) { - values := url.Values{} - values["q"] = []string{"age<20"} - query := gplus.BuildQuery[User](values) - var expectSql = "SELECT * FROM `Users` WHERE age < 20" - sessionDb := checkSelectSql(t, expectSql) - gplus.SelectList[User](query, gplus.Db(sessionDb)) -} - -func TestQueryByLeAge(t *testing.T) { - values := url.Values{} - values["q"] = []string{"age<=20"} - query := gplus.BuildQuery[User](values) - var expectSql = "SELECT * FROM `Users` WHERE age <= 20" - sessionDb := checkSelectSql(t, expectSql) - gplus.SelectList[User](query, gplus.Db(sessionDb)) -} - -func TestQueryByLike(t *testing.T) { - values := url.Values{} - values["q"] = []string{"username~=afumu"} - query := gplus.BuildQuery[User](values) - var expectSql = "SELECT * FROM `Users` WHERE username LIKE '%afumu%'" - sessionDb := checkSelectSql(t, expectSql) - gplus.SelectList[User](query, gplus.Db(sessionDb)) -} - -func TestQueryByLeftLike(t *testing.T) { - values := url.Values{} - values["q"] = []string{"username~<=afumu"} - query := gplus.BuildQuery[User](values) - var expectSql = "SELECT * FROM `Users` WHERE username LIKE '%afumu'" - sessionDb := checkSelectSql(t, expectSql) - gplus.SelectList[User](query, gplus.Db(sessionDb)) -} - -func TestQueryByNotLeftLike(t *testing.T) { - values := url.Values{} - values["q"] = []string{"username!~<=afumu"} - query := gplus.BuildQuery[User](values) - var expectSql = "SELECT * FROM `Users` WHERE username NOT LIKE '%afumu'" - sessionDb := checkSelectSql(t, expectSql) - gplus.SelectList[User](query, gplus.Db(sessionDb)) -} - -func TestQueryByRightLike(t *testing.T) { - values := url.Values{} - values["q"] = []string{"username~>=afumu"} - query := gplus.BuildQuery[User](values) - var expectSql = "SELECT * FROM `Users` WHERE username LIKE 'afumu%'" - sessionDb := checkSelectSql(t, expectSql) - gplus.SelectList[User](query, gplus.Db(sessionDb)) -} - -func TestQueryByNotRightLike(t *testing.T) { - values := url.Values{} - values["q"] = []string{"username!~>=afumu"} - query := gplus.BuildQuery[User](values) - var expectSql = "SELECT * FROM `Users` WHERE username NOT LIKE 'afumu%'" - sessionDb := checkSelectSql(t, expectSql) - gplus.SelectList[User](query, gplus.Db(sessionDb)) -} - -func TestQueryByIsNull(t *testing.T) { - values := url.Values{} - values["q"] = []string{"username=null"} - query := gplus.BuildQuery[User](values) - var expectSql = "SELECT * FROM `Users` WHERE username IS NULL" - sessionDb := checkSelectSql(t, expectSql) - gplus.SelectList[User](query, gplus.Db(sessionDb)) -} - -func TestQueryByIsNotNull(t *testing.T) { - values := url.Values{} - values["q"] = []string{"username!=null"} - query := gplus.BuildQuery[User](values) - var expectSql = "SELECT * FROM `Users` WHERE username IS NOT NULL" - sessionDb := checkSelectSql(t, expectSql) - gplus.SelectList[User](query, gplus.Db(sessionDb)) -} - -func TestQueryByIn(t *testing.T) { - values := url.Values{} - values["q"] = []string{"username?=afumu,zhangsan"} - query := gplus.BuildQuery[User](values) - var expectSql = "SELECT * FROM `Users` WHERE username IN ('afumu','zhangsan')" - sessionDb := checkSelectSql(t, expectSql) - gplus.SelectList[User](query, gplus.Db(sessionDb)) -} - -func TestQueryByNotIn(t *testing.T) { - values := url.Values{} - values["q"] = []string{"username!?=afumu,zhangsan"} - query := gplus.BuildQuery[User](values) - var expectSql = "SELECT * FROM `Users` WHERE username NOT IN ('afumu','zhangsan')" - sessionDb := checkSelectSql(t, expectSql) - gplus.SelectList[User](query, gplus.Db(sessionDb)) -} - -func TestQueryByBetween(t *testing.T) { - values := url.Values{} - values["q"] = []string{"age^=20,30"} - query := gplus.BuildQuery[User](values) - var expectSql = "SELECT * FROM `Users` WHERE age BETWEEN 20 AND 30" - sessionDb := checkSelectSql(t, expectSql) - gplus.SelectList[User](query, gplus.Db(sessionDb)) -} - -func TestQueryByNotBetween(t *testing.T) { - values := url.Values{} - values["q"] = []string{"age!^=20,30"} - query := gplus.BuildQuery[User](values) - var expectSql = "SELECT * FROM `Users` WHERE age NOT BETWEEN 20 AND 30" - sessionDb := checkSelectSql(t, expectSql) - gplus.SelectList[User](query, gplus.Db(sessionDb)) -} - -func TestQueryByAnd(t *testing.T) { - values := url.Values{} - values["q"] = []string{"useranme=afumu", "age=20"} - query := gplus.BuildQuery[User](values) - var expectSql = "SELECT * FROM `Users` WHERE useranme = 'afumu' AND age = 20" - sessionDb := checkSelectSql(t, expectSql) - gplus.SelectList[User](query, gplus.Db(sessionDb)) -} - -func TestQueryByGroupOnlyOne(t *testing.T) { - values := url.Values{} - values["q"] = []string{"A.useranme=afumu", "A.age=20"} - query := gplus.BuildQuery[User](values) - var expectSql = "SELECT * FROM `Users` WHERE useranme = 'afumu' AND age = 20" - sessionDb := checkSelectSql(t, expectSql) - gplus.SelectList[User](query, gplus.Db(sessionDb)) -} - -func TestQueryByGroupAAndB(t *testing.T) { - values := url.Values{} - values["q"] = []string{"A.useranme=afumu", "A.password=123456", "B.age=20", "B.score=90"} - values["gcond"] = []string{"A*B"} - query := gplus.BuildQuery[User](values) - var expectSql = "SELECT * FROM `Users` WHERE useranme = 'afumu' AND password = '123456' AND age = 20 AND score = 90" - sessionDb := checkSelectSql(t, expectSql) - gplus.SelectList[User](query, gplus.Db(sessionDb)) -} - -func TestQueryByGroupAOrB(t *testing.T) { - values := url.Values{} - values["q"] = []string{"A.useranme=afumu", "A.password=123456", "B.age=20", "B.score=90"} - values["gcond"] = []string{"A|B"} - query := gplus.BuildQuery[User](values) - var expectSql = "SELECT * FROM `Users` WHERE useranme = 'afumu' AND password = '123456' OR age = 20 AND score = 90" - sessionDb := checkSelectSql(t, expectSql) - gplus.SelectList[User](query, gplus.Db(sessionDb)) -} - -func TestQueryByGroupNest(t *testing.T) { - values := url.Values{} - values["q"] = []string{ - "A.useranme=afumu", "B.password=12345", "C.score=60", - "D.dept=开发", "F.address=北京", - } - values["gcond"] = []string{"(A*(B|C)|D)*F"} - query := gplus.BuildQuery[User](values) - var expectSql = "SELECT * FROM `Users` WHERE ( useranme = 'afumu' AND ( password = '12345' OR score = 60 ) OR dept = '开发' ) AND address = '北京'" - sessionDb := checkSelectSql(t, expectSql) - gplus.SelectList[User](query, gplus.Db(sessionDb)) -} From aa63c64dc8f6613ed76d7cae0cd078906198261f Mon Sep 17 00:00:00 2001 From: JohnAi Date: Wed, 1 Nov 2023 15:15:31 +0800 Subject: [PATCH 12/12] update test db --- tests/dao_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/dao_test.go b/tests/dao_test.go index 71a95ce..d7fb0c9 100644 --- a/tests/dao_test.go +++ b/tests/dao_test.go @@ -34,7 +34,7 @@ import ( var gormDb *gorm.DB func init() { - dsn := "root:my-secret-pw@tcp(127.0.0.1:3306)/test_db?charset=utf8mb4&parseTime=True&loc=Local" + dsn := "root:123456@tcp(127.0.0.1:3306)/test?charset=utf8mb4&parseTime=True&loc=Local" var err error gormDb, err = gorm.Open(mysql.Open(dsn), &gorm.Config{ Logger: logger.Default.LogMode(logger.Info),