diff --git a/README.md b/README.md index 6ef758b..7ecb6be 100644 --- a/README.md +++ b/README.md @@ -317,108 +317,6 @@ log.Printf("error:%v\n", resultDb.Error) log.Printf("RowsAffected:%v\n", resultDb.RowsAffected) ``` - - -### 通用查询 - -gplus提供通用的CRUD操作,只需要在struct中嵌入` gplus.CommonDao`即可使用gplus提供的所有CRUD操作。 - -```Go -type StudentDao struct { - gplus.CommonDao[Student] -} - -var studentDao = &StudentDao{} -``` - - - -#### 根据ID查询 - -```Go - student, resultDb := studentDao.GetById(2) - log.Printf("error:%+v", resultDb.Error) - log.Printf("RowsAffected:%+v", resultDb.RowsAffected) - log.Printf("student:%+v", student) -``` - - - -#### 根据条件查询一条数据 - -```Go - query, model := gplus.NewQuery[Student]() - query.Eq(&model.Name, "zhangsan1") - student, resultDb := studentDao.GetOne(query) - log.Printf("error:%+v", resultDb.Error) - log.Printf("RowsAffected:%+v", resultDb.RowsAffected) - log.Printf("student:%+v", student) -``` - - - -#### 查询列表所有数据 - -```Go - students, resultDb := studentDao.ListAll() - log.Printf("error:%+v", resultDb.Error) - fmt.Println("RowsAffected:", resultDb.RowsAffected) - for _, student := range students { - log.Printf("student:%+v", student) - } -``` - - - -#### 根据条件查询数据 - -```Go - query, model := gplus.NewQuery[Student]() - query.Eq(&model.Name, "zhangsan1") - students, resultDb := studentDao.List(query) - log.Printf("error:%+v", resultDb.Error) - fmt.Println("RowsAffected:", resultDb.RowsAffected) - for _, student := range students { - log.Printf("student:%+v", student) - } -``` - - - -#### 分页查询所有数据 - -```Go - page := gplus.NewPage[Student](1, 2) - page, resultDb := studentDao.PageAll(page) - log.Printf("error:%+v", resultDb.Error) - fmt.Println("RowsAffected:", resultDb.RowsAffected) - for _, student := range page.Records { - log.Printf("student:%+v", student) - } -``` - - - -#### 分页条件查询数据 - -```Go - page := gplus.NewPage[Student](1, 2) - query, model := gplus.NewQuery[Student]() - query.Eq(&model.Name, "zhangsan1") - page, resultDb := studentDao.Page(page, query) - log.Printf("error:%+v", resultDb.Error) - fmt.Println("RowsAffected:", resultDb.RowsAffected) - for _, student := range page.Records { - log.Printf("student:%+v", student) - } -``` - - - - - - - ### 高级查询 #### 条件构造器 @@ -458,7 +356,42 @@ gorm-plus 提供了强大的条件构造器,通过构造器能够组合不同的 log.Printf("studentResult:%+v\n", studentResult) ``` +#### Query泛型简化 + +如果不希望每次创建Query对象的时候携带上泛型,我们可以提供一个全局的泛型Dao。 +```Go +var dao gplus.Dao[Student] +func main() { + query, model := dao.NewQuery() + query.Eq(&model.Name, "zhangsan") + list, resultDb := gplus.SelectList(query) + fmt.Println(resultDb.RowsAffected) + for _, v := range list { + marshal, _ := json.Marshal(v) + fmt.Println(string(marshal)) + } +} +``` + +我们也可以把`gplus.Dao`组合到我们自己定义的Dao对象中 + +```Go +type StudentDao struct { + gplus.Dao[Student] +} +var studentDao StudentDao +func main() { + query, model := studentDao.NewQuery() + query.Eq(&model.Name, "zhangsan") + list, resultDb := gplus.SelectList(query) + fmt.Println(resultDb.RowsAffected) + for _, v := range list { + marshal, _ := json.Marshal(v) + fmt.Println(string(marshal)) + } +} +``` #### 查询指定字段 diff --git a/gplus/common_dao.go b/gplus/common_dao.go deleted file mode 100644 index 996251c..0000000 --- a/gplus/common_dao.go +++ /dev/null @@ -1,100 +0,0 @@ -/* - * Licensed to the AcmeStack under one or more contributor license - * agreements. See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package gplus - -import ( - "gorm.io/gorm" -) - -type CommonDao[T any] struct{} - -func NewCommonDao[T any]() *CommonDao[T] { - return &CommonDao[T]{} -} - -func (service CommonDao[T]) Db() *gorm.DB { - return gormDb -} - -func (service CommonDao[T]) Save(entity *T) *gorm.DB { - return Insert[T](entity) -} - -func (service CommonDao[T]) SaveBatch(entities []*T) *gorm.DB { - return InsertBatch[T](entities) -} - -func (service CommonDao[T]) SaveBatchSize(entities []*T, batchSize int) *gorm.DB { - return InsertBatchSize[T](entities, batchSize) -} - -func (service CommonDao[T]) RemoveById(id any) *gorm.DB { - return DeleteById[T](id) -} - -func (service CommonDao[T]) RemoveByIds(ids any) *gorm.DB { - return DeleteByIds[T](ids) -} - -func (service CommonDao[T]) Remove(q *Query[T]) *gorm.DB { - return Delete[T](q) -} - -func (service CommonDao[T]) UpdateById(entity *T) *gorm.DB { - return UpdateById[T](entity) -} - -func (service CommonDao[T]) Update(q *Query[T]) *gorm.DB { - return Update[T](q) -} - -func (service CommonDao[T]) GetById(id any) (*T, *gorm.DB) { - return SelectById[T](id) -} - -func (service CommonDao[T]) GetOne(q *Query[T]) (*T, *gorm.DB) { - return SelectOne[T](q) -} - -func (service CommonDao[T]) ListAll() ([]*T, *gorm.DB) { - return SelectList[T](nil) -} - -func (service CommonDao[T]) List(q *Query[T]) ([]*T, *gorm.DB) { - return SelectList[T](q) -} - -func (service CommonDao[T]) ListByIds(ids any) ([]*T, *gorm.DB) { - return SelectByIds[T](ids) -} - -func (service CommonDao[T]) PageAll(page *Page[T]) (*Page[T], *gorm.DB) { - return SelectPage[T](page, nil) -} - -func (service CommonDao[T]) Page(page *Page[T], q *Query[T]) (*Page[T], *gorm.DB) { - return SelectPage[T](page, q) -} - -func (service CommonDao[T]) CountAll() (int64, *gorm.DB) { - return SelectCount[T](nil) -} - -func (service CommonDao[T]) Count(q *Query[T]) (int64, *gorm.DB) { - return SelectCount[T](q) -} diff --git a/gplus/base_dao.go b/gplus/dao.go similarity index 62% rename from gplus/base_dao.go rename to gplus/dao.go index 56a3277..c2dc7b5 100644 --- a/gplus/base_dao.go +++ b/gplus/dao.go @@ -25,11 +25,11 @@ import ( "reflect" ) -var gormDb *gorm.DB +var globalDb *gorm.DB var defaultBatchSize = 1000 func Init(db *gorm.DB) { - gormDb = db + globalDb = db } type Page[T any] struct { @@ -39,123 +39,138 @@ type Page[T any] struct { Records []*T } +type Dao[T any] struct{} + +func (dao Dao[T]) NewQuery() (*Query[T], *T) { + q := &Query[T]{} + return q, q.buildColumnNameMap() +} + func NewPage[T any](current, size int) *Page[T] { return &Page[T]{Current: current, Size: size} } -func Insert[T any](entity *T) *gorm.DB { - resultDb := gormDb.Create(entity) +func Insert[T any](entity *T, dbs ...*gorm.DB) *gorm.DB { + db := getDb(dbs...) + resultDb := db.Create(entity) return resultDb } -func InsertBatch[T any](entities []*T) *gorm.DB { +func InsertBatch[T any](entities []*T, dbs ...*gorm.DB) *gorm.DB { + db := getDb(dbs...) if len(entities) == 0 { - return gormDb + return db } - resultDb := gormDb.CreateInBatches(entities, defaultBatchSize) + resultDb := db.CreateInBatches(entities, defaultBatchSize) return resultDb } -func InsertBatchSize[T any](entities []*T, batchSize int) *gorm.DB { +func InsertBatchSize[T any](entities []*T, batchSize int, dbs ...*gorm.DB) *gorm.DB { + db := getDb(dbs...) if len(entities) == 0 { - return gormDb + return db } if batchSize <= 0 { batchSize = defaultBatchSize } - resultDb := gormDb.CreateInBatches(entities, batchSize) + resultDb := db.CreateInBatches(entities, batchSize) return resultDb } -func DeleteById[T any](id any) *gorm.DB { +func DeleteById[T any](id any, dbs ...*gorm.DB) *gorm.DB { + db := getDb(dbs...) var entity T - resultDb := gormDb.Where(getPkColumnName[T](), id).Delete(&entity) + resultDb := db.Where(getPkColumnName[T](), id).Delete(&entity) return resultDb } -func DeleteByIds[T any](ids any) *gorm.DB { +func DeleteByIds[T any](ids any, dbs ...*gorm.DB) *gorm.DB { q, _ := NewQuery[T]() q.In(getPkColumnName[T](), ids) - resultDb := Delete[T](q) + resultDb := Delete[T](q, dbs...) return resultDb } -func Delete[T any](q *Query[T]) *gorm.DB { +func Delete[T any](q *Query[T], dbs ...*gorm.DB) *gorm.DB { + db := getDb(dbs...) var entity T - resultDb := gormDb.Where(q.QueryBuilder.String(), q.QueryArgs...).Delete(&entity) + resultDb := db.Where(q.QueryBuilder.String(), q.QueryArgs...).Delete(&entity) return resultDb } -func DeleteByMap[T any](q *Query[T]) *gorm.DB { +func DeleteByMap[T any](q *Query[T], dbs ...*gorm.DB) *gorm.DB { + db := getDb(dbs...) for k, v := range q.ConditionMap { columnName := q.getColumnName(k) q.Eq(columnName, v) } var entity T - resultDb := gormDb.Where(q.QueryBuilder.String(), q.QueryArgs...).Delete(&entity) + resultDb := db.Where(q.QueryBuilder.String(), q.QueryArgs...).Delete(&entity) return resultDb } -func UpdateById[T any](entity *T) *gorm.DB { - resultDb := gormDb.Model(entity).Updates(entity) +func UpdateById[T any](entity *T, dbs ...*gorm.DB) *gorm.DB { + db := getDb(dbs...) + resultDb := db.Model(entity).Updates(entity) return resultDb } -func Update[T any](q *Query[T]) *gorm.DB { - resultDb := gormDb.Model(new(T)).Where(q.QueryBuilder.String(), q.QueryArgs...).Updates(&q.UpdateMap) +func Update[T any](q *Query[T], dbs ...*gorm.DB) *gorm.DB { + db := getDb(dbs...) + resultDb := db.Model(new(T)).Where(q.QueryBuilder.String(), q.QueryArgs...).Updates(&q.UpdateMap) return resultDb } -func SelectById[T any](id any) (*T, *gorm.DB) { +func SelectById[T any](id any, dbs ...*gorm.DB) (*T, *gorm.DB) { q, _ := NewQuery[T]() q.Eq(getPkColumnName[T](), id) var entity T - resultDb := buildCondition(q) + resultDb := buildCondition(q, dbs...) return &entity, resultDb.Limit(1).Find(&entity) } -func SelectByIds[T any](ids any) ([]*T, *gorm.DB) { +func SelectByIds[T any](ids any, dbs ...*gorm.DB) ([]*T, *gorm.DB) { q, _ := NewQuery[T]() q.In(getPkColumnName[T](), ids) - return SelectList[T](q) + return SelectList[T](q, dbs...) } -func SelectOne[T any](q *Query[T]) (*T, *gorm.DB) { +func SelectOne[T any](q *Query[T], dbs ...*gorm.DB) (*T, *gorm.DB) { var entity T - resultDb := buildCondition(q) + resultDb := buildCondition(q, dbs...) return &entity, resultDb.Limit(1).Find(&entity) } -func SelectList[T any](q *Query[T]) ([]*T, *gorm.DB) { - resultDb := buildCondition(q) +func SelectList[T any](q *Query[T], dbs ...*gorm.DB) ([]*T, *gorm.DB) { + resultDb := buildCondition(q, dbs...) var results []*T resultDb.Find(&results) return results, resultDb } -func SelectListModel[T any, R any](q *Query[T]) ([]*R, *gorm.DB) { - resultDb := buildCondition(q) +func SelectListModel[T any, R any](q *Query[T], dbs ...*gorm.DB) ([]*R, *gorm.DB) { + resultDb := buildCondition(q, dbs...) var results []*R resultDb.Scan(&results) return results, resultDb } -func SelectListByMap[T any](q *Query[T]) ([]*T, *gorm.DB) { - resultDb := buildCondition(q) +func SelectListByMap[T any](q *Query[T], dbs ...*gorm.DB) ([]*T, *gorm.DB) { + resultDb := buildCondition(q, dbs...) var results []*T resultDb.Find(&results) return results, resultDb } -func SelectListMaps[T any](q *Query[T]) ([]map[string]any, *gorm.DB) { - resultDb := buildCondition(q) +func SelectListMaps[T any](q *Query[T], dbs ...*gorm.DB) ([]map[string]any, *gorm.DB) { + resultDb := buildCondition(q, dbs...) var results []map[string]any resultDb.Find(&results) return results, resultDb } -func SelectPage[T any](page *Page[T], q *Query[T]) (*Page[T], *gorm.DB) { - total, countDb := SelectCount[T](q) +func SelectPage[T any](page *Page[T], q *Query[T], dbs ...*gorm.DB) (*Page[T], *gorm.DB) { + total, countDb := SelectCount[T](q, dbs...) if countDb.Error != nil { return page, countDb } @@ -167,8 +182,8 @@ func SelectPage[T any](page *Page[T], q *Query[T]) (*Page[T], *gorm.DB) { return page, resultDb } -func SelectPageModel[T any, R any](page *Page[R], q *Query[T]) (*Page[R], *gorm.DB) { - total, countDb := SelectCount[T](q) +func SelectPageModel[T any, R any](page *Page[R], q *Query[T], dbs ...*gorm.DB) (*Page[R], *gorm.DB) { + total, countDb := SelectCount[T](q, dbs...) if countDb.Error != nil { return page, countDb } @@ -180,8 +195,8 @@ func SelectPageModel[T any, R any](page *Page[R], q *Query[T]) (*Page[R], *gorm. return page, resultDb } -func SelectPageMaps[T any](page *Page[map[string]any], q *Query[T]) (*Page[map[string]any], *gorm.DB) { - total, countDb := SelectCount[T](q) +func SelectPageMaps[T any](page *Page[map[string]any], q *Query[T], dbs ...*gorm.DB) (*Page[map[string]any], *gorm.DB) { + total, countDb := SelectCount[T](q, dbs...) if countDb.Error != nil { return page, countDb } @@ -195,9 +210,9 @@ func SelectPageMaps[T any](page *Page[map[string]any], q *Query[T]) (*Page[map[s return page, resultDb } -func SelectCount[T any](q *Query[T]) (int64, *gorm.DB) { +func SelectCount[T any](q *Query[T], dbs ...*gorm.DB) (int64, *gorm.DB) { var count int64 - resultDb := buildCondition(q) + resultDb := buildCondition(q, dbs...) resultDb.Count(&count) return count, resultDb } @@ -217,8 +232,9 @@ func paginate[T any](p *Page[T]) func(db *gorm.DB) *gorm.DB { } } -func buildCondition[T any](q *Query[T]) *gorm.DB { - resultDb := gormDb.Model(new(T)) +func buildCondition[T any](q *Query[T], dbs ...*gorm.DB) *gorm.DB { + db := getDb(dbs...) + resultDb := db.Model(new(T)) if q != nil { if len(q.DistinctColumns) > 0 { resultDb.Distinct(q.DistinctColumns) @@ -291,3 +307,11 @@ func getPkColumnName[T any]() string { } return columnName } + +func getDb(dbs ...*gorm.DB) *gorm.DB { + if len(dbs) > 0 { + db := dbs[0] + return db + } + return globalDb +} diff --git a/tests/base_dao_test.go b/tests/base_dao_test.go index 228782a..2ce5fad 100644 --- a/tests/base_dao_test.go +++ b/tests/base_dao_test.go @@ -43,13 +43,9 @@ func init() { gplus.Init(gormDb) } -type Test1 struct { - gorm.Model - Code string - Price uint -} - func TestInsert(t *testing.T) { + var u User + gormDb.AutoMigrate(u) user := &User{Username: "user1", Password: "123456", Age: 18, Score: 100, Dept: "财务部门"} result := gplus.Insert(user) if result.Error != nil {