diff --git a/.github/workflows/release-on-tag.yml b/.github/workflows/release-on-tag.yml new file mode 100644 index 0000000000..3868a8f562 --- /dev/null +++ b/.github/workflows/release-on-tag.yml @@ -0,0 +1,23 @@ +name: Create Release on Tag + +on: + push: + tags: + - '*' + +jobs: + create_release: + runs-on: ubuntu-latest + + steps: + - name: Create Release + uses: actions/create-release@v1 + with: + tag_name: ${{ github.ref_name }} + release_name: ${{ github.ref_name }} + body: | + Release ${{ github.ref_name }} of GORM. + draft: false + prerelease: false + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index af471d20bf..24eab55abc 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -16,7 +16,7 @@ jobs: sqlite: strategy: matrix: - go: ['1.21', '1.20', '1.19'] + go: ['1.22', '1.21', '1.20'] platform: [ubuntu-latest] # can not run in windows OS runs-on: ${{ matrix.platform }} @@ -41,8 +41,8 @@ jobs: mysql: strategy: matrix: - dbversion: ['mysql:latest', 'mysql:5.7'] - go: ['1.21', '1.20', '1.19'] + dbversion: ['mysql/mysql-server:latest', 'mysql:5.7'] + go: ['1.22', '1.21', '1.20'] platform: [ubuntu-latest] runs-on: ${{ matrix.platform }} @@ -85,7 +85,7 @@ jobs: strategy: matrix: dbversion: [ 'mariadb:latest' ] - go: ['1.21', '1.20', '1.19'] + go: ['1.22', '1.21', '1.20'] platform: [ ubuntu-latest ] runs-on: ${{ matrix.platform }} @@ -127,8 +127,8 @@ jobs: postgres: strategy: matrix: - dbversion: ['postgres:latest', 'postgres:13', 'postgres:12', 'postgres:11', 'postgres:10'] - go: ['1.21', '1.20', '1.19'] + dbversion: ['postgres:latest', 'postgres:15', 'postgres:14', 'postgres:13'] + go: ['1.22', '1.21', '1.20'] platform: [ubuntu-latest] # can not run in macOS and Windows runs-on: ${{ matrix.platform }} @@ -170,23 +170,21 @@ jobs: sqlserver: strategy: matrix: - go: ['1.21', '1.20', '1.19'] + go: ['1.22', '1.21', '1.20'] platform: [ubuntu-latest] # can not run test in macOS and windows runs-on: ${{ matrix.platform }} services: mssql: - image: mcmoe/mssqldocker:latest + image: mcr.microsoft.com/mssql/server:2022-latest env: + TZ: Asia/Shanghai ACCEPT_EULA: Y - SA_PASSWORD: LoremIpsum86 - MSSQL_DB: gorm - MSSQL_USER: gorm - MSSQL_PASSWORD: LoremIpsum86 + MSSQL_SA_PASSWORD: LoremIpsum86 ports: - 9930:1433 options: >- - --health-cmd="/opt/mssql-tools/bin/sqlcmd -S localhost -U sa -P LoremIpsum86 -l 30 -Q \"SELECT 1\" || exit 1" + --health-cmd="/opt/mssql-tools18/bin/sqlcmd -S localhost -U sa -P ${MSSQL_SA_PASSWORD} -N -C -l 30 -Q \"SELECT 1\" || exit 1" --health-start-period 10s --health-interval 10s --health-timeout 5s @@ -208,13 +206,13 @@ jobs: key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} - name: Tests - run: GITHUB_ACTION=true GORM_DIALECT=sqlserver GORM_DSN="sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" ./tests/tests_all.sh + run: GITHUB_ACTION=true GORM_DIALECT=sqlserver GORM_DSN="sqlserver://sa:LoremIpsum86@localhost:9930?database=master" ./tests/tests_all.sh tidb: strategy: matrix: dbversion: [ 'v6.5.0' ] - go: ['1.21', '1.20', '1.19'] + go: ['1.22', '1.21', '1.20'] platform: [ ubuntu-latest ] runs-on: ${{ matrix.platform }} diff --git a/association.go b/association.go index 7c93ebea0d..e3f51d173b 100644 --- a/association.go +++ b/association.go @@ -396,6 +396,10 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } } case reflect.Struct: + if !rv.CanAddr() { + association.Error = ErrInvalidValue + return + } association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Addr().Interface()) if association.Relationship.Field.FieldType.Kind() == reflect.Struct { @@ -433,6 +437,10 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ appendToFieldValues(reflect.Indirect(rv.Index(i)).Addr()) } case reflect.Struct: + if !rv.CanAddr() { + association.Error = ErrInvalidValue + return + } appendToFieldValues(rv.Addr()) } @@ -510,6 +518,9 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ for i := 0; i < reflectValue.Len(); i++ { appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear) + if association.Error != nil { + return + } // TODO support save slice data, sql with case? association.Error = associationDB.Updates(reflectValue.Index(i).Addr().Interface()).Error @@ -531,6 +542,9 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ for idx, value := range values { rv := reflect.Indirect(reflect.ValueOf(value)) appendToRelations(reflectValue, rv, clear && idx == 0) + if association.Error != nil { + return + } } if len(values) > 0 { diff --git a/callbacks/preload.go b/callbacks/preload.go index 112343fa56..fd8214bb26 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -125,13 +125,15 @@ func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relati case reflect.Slice, reflect.Array: if rv.Len() > 0 { reflectValue := rel.FieldSchema.MakeSlice().Elem() - reflectValue.SetLen(rv.Len()) for i := 0; i < rv.Len(); i++ { frv := rel.Field.ReflectValueOf(db.Statement.Context, rv.Index(i)) if frv.Kind() != reflect.Ptr { - reflectValue.Index(i).Set(frv.Addr()) + reflectValue = reflect.Append(reflectValue, frv.Addr()) } else { - reflectValue.Index(i).Set(frv) + if frv.IsNil() { + continue + } + reflectValue = reflect.Append(reflectValue, frv) } } @@ -140,7 +142,7 @@ func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relati return err } } - case reflect.Struct: + case reflect.Struct, reflect.Pointer: reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv) tx := preloadDB(db, reflectValue, reflectValue.Interface()) if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil { diff --git a/callbacks/query.go b/callbacks/query.go index 2a82eaba16..bbf238a9fd 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -286,7 +286,11 @@ func Preload(db *gorm.DB) { func AfterQuery(db *gorm.DB) { // clear the joins after query because preload need it - db.Statement.Joins = nil + if v, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok { + fromClause := db.Statement.Clauses["FROM"] + fromClause.Expression = clause.From{Tables: v.Tables, Joins: utils.RTrimSlice(v.Joins, len(db.Statement.Joins))} // keep the original From Joins + db.Statement.Clauses["FROM"] = fromClause + } if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 { callMethod(db, func(value interface{}, tx *gorm.DB) bool { if i, ok := value.(AfterFindInterface); ok { diff --git a/chainable_api.go b/chainable_api.go index 3337060321..8953413d5f 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -185,6 +185,13 @@ func (db *DB) Omit(columns ...string) (tx *DB) { return } +// MapColumns modify the column names in the query results to facilitate align to the corresponding structural fields +func (db *DB) MapColumns(m map[string]string) (tx *DB) { + tx = db.getInstance() + tx.Statement.ColumnMapping = m + return +} + // Where add conditions // // See the [docs] for details on the various formats that where clauses can take. By default, where clauses chain with AND. @@ -299,10 +306,16 @@ func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) { // // db.Order("name DESC") // db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true}) +// db.Order(clause.OrderBy{Columns: []clause.OrderByColumn{ +// {Column: clause.Column{Name: "name"}, Desc: true}, +// {Column: clause.Column{Name: "age"}, Desc: true}, +// }}) func (db *DB) Order(value interface{}) (tx *DB) { tx = db.getInstance() switch v := value.(type) { + case clause.OrderBy: + tx.Statement.AddClause(v) case clause.OrderByColumn: tx.Statement.AddClause(clause.OrderBy{ Columns: []clause.OrderByColumn{v}, diff --git a/finisher_api.go b/finisher_api.go index f97571ed04..6802945cc1 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -4,6 +4,7 @@ import ( "database/sql" "errors" "fmt" + "hash/maphash" "reflect" "strings" @@ -623,14 +624,15 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { // nested transaction if !db.DisableNestedTransaction { - err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error + spID := new(maphash.Hash).Sum64() + err = db.SavePoint(fmt.Sprintf("sp%d", spID)).Error if err != nil { return } defer func() { // Make sure to rollback when panic, Block error or Commit error if panicked || err != nil { - db.RollbackTo(fmt.Sprintf("sp%p", fc)) + db.RollbackTo(fmt.Sprintf("sp%d", spID)) } }() } diff --git a/go.mod b/go.mod index deb61b747c..9b30a0e7c2 100644 --- a/go.mod +++ b/go.mod @@ -5,4 +5,5 @@ go 1.18 require ( github.com/jinzhu/inflection v1.0.0 github.com/jinzhu/now v1.1.5 + golang.org/x/text v0.14.0 ) diff --git a/go.sum b/go.sum index bd6104c9b5..e3e29009d6 100644 --- a/go.sum +++ b/go.sum @@ -2,3 +2,5 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= diff --git a/gorm.go b/gorm.go index 1de30f2629..0c8c549170 100644 --- a/gorm.go +++ b/gorm.go @@ -52,6 +52,9 @@ type Config struct { TranslateError bool // SkipOutputStatement disable OUTPUT clause when inserting a row SkipOutputStatement bool + // PropagateUnscoped propagate Unscoped to every other nested statement + PropagateUnscoped bool + // ClauseBuilders clause builder ClauseBuilders map[string]clause.ClauseBuilder // ConnPool db conn pool @@ -111,6 +114,7 @@ type Session struct { DisableNestedTransaction bool AllowGlobalUpdate bool FullSaveAssociations bool + PropagateUnscoped bool QueryFields bool SkipOutputStatement bool Context context.Context @@ -248,6 +252,10 @@ func (db *DB) Session(config *Session) *DB { txConfig.FullSaveAssociations = true } + if config.PropagateUnscoped { + txConfig.PropagateUnscoped = true + } + if config.Context != nil || config.PrepareStmt || config.SkipHooks { tx.Statement = tx.Statement.clone() tx.Statement.DB = tx @@ -424,6 +432,9 @@ func (db *DB) getInstance() *DB { Vars: make([]interface{}, 0, 8), SkipHooks: db.Statement.SkipHooks, } + if db.Config.PropagateUnscoped { + tx.Statement.Unscoped = db.Statement.Unscoped + } } else { // with clone statement tx.Statement = db.Statement.clone() diff --git a/prepare_stmt.go b/prepare_stmt.go index 4d533885e4..094bb4775d 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -17,18 +17,16 @@ type Stmt struct { } type PreparedStmtDB struct { - Stmts map[string]*Stmt - PreparedSQL []string - Mux *sync.RWMutex + Stmts map[string]*Stmt + Mux *sync.RWMutex ConnPool } func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB { return &PreparedStmtDB{ - ConnPool: connPool, - Stmts: make(map[string]*Stmt), - Mux: &sync.RWMutex{}, - PreparedSQL: make([]string, 0, 100), + ConnPool: connPool, + Stmts: make(map[string]*Stmt), + Mux: &sync.RWMutex{}, } } @@ -48,12 +46,17 @@ func (db *PreparedStmtDB) Close() { db.Mux.Lock() defer db.Mux.Unlock() - for _, query := range db.PreparedSQL { - if stmt, ok := db.Stmts[query]; ok { - delete(db.Stmts, query) - go stmt.Close() - } + for _, stmt := range db.Stmts { + go func(s *Stmt) { + // make sure the stmt must finish preparation first + <-s.prepared + if s.Stmt != nil { + _ = s.Close() + } + }(stmt) } + // setting db.Stmts to nil to avoid further using + db.Stmts = nil } func (sdb *PreparedStmtDB) Reset() { @@ -61,9 +64,14 @@ func (sdb *PreparedStmtDB) Reset() { defer sdb.Mux.Unlock() for _, stmt := range sdb.Stmts { - go stmt.Close() + go func(s *Stmt) { + // make sure the stmt must finish preparation first + <-s.prepared + if s.Stmt != nil { + _ = s.Close() + } + }(stmt) } - sdb.PreparedSQL = make([]string, 0, 100) sdb.Stmts = make(map[string]*Stmt) } @@ -93,7 +101,12 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact return *stmt, nil } - + // check db.Stmts first to avoid Segmentation Fault(setting value to nil map) + // which cause by calling Close and executing SQL concurrently + if db.Stmts == nil { + db.Mux.Unlock() + return Stmt{}, ErrInvalidDB + } // cache preparing stmt first cacheStmt := Stmt{Transaction: isTransaction, prepared: make(chan struct{})} db.Stmts[query] = &cacheStmt @@ -118,7 +131,6 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact db.Mux.Lock() cacheStmt.Stmt = stmt - db.PreparedSQL = append(db.PreparedSQL, query) db.Mux.Unlock() return cacheStmt, nil diff --git a/scan.go b/scan.go index 89b46c0a2f..d852c2c9f9 100644 --- a/scan.go +++ b/scan.go @@ -131,6 +131,15 @@ func Scan(rows Rows, db *DB, mode ScanMode) { onConflictDonothing = mode&ScanOnConflictDoNothing != 0 ) + if len(db.Statement.ColumnMapping) > 0 { + for i, column := range columns { + v, ok := db.Statement.ColumnMapping[column] + if ok { + columns[i] = v + } + } + } + db.RowsAffected = 0 switch dest := db.Statement.Dest.(type) { @@ -331,6 +340,9 @@ func Scan(rows Rows, db *DB, mode ScanMode) { } case reflect.Struct, reflect.Ptr: if initialized || rows.Next() { + if mode == ScanInitialized && reflectValue.Kind() == reflect.Struct { + db.Statement.ReflectValue.Set(reflect.Zero(reflectValue.Type())) + } db.scanIntoStruct(rows, reflectValue, values, fields, joinFields) } default: diff --git a/schema/naming.go b/schema/naming.go index e6fb81b2b3..6248bde8d5 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -8,6 +8,8 @@ import ( "unicode/utf8" "github.com/jinzhu/inflection" + "golang.org/x/text/cases" + "golang.org/x/text/language" ) // Namer namer interface @@ -121,7 +123,7 @@ var ( func init() { commonInitialismsForReplacer := make([]string, 0, len(commonInitialisms)) for _, initialism := range commonInitialisms { - commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism))) + commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, cases.Title(language.Und).String(initialism)) } commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...) } @@ -186,9 +188,9 @@ func (ns NamingStrategy) toDBName(name string) string { } func (ns NamingStrategy) toSchemaName(name string) string { - result := strings.ReplaceAll(strings.Title(strings.ReplaceAll(name, "_", " ")), " ", "") + result := strings.ReplaceAll(cases.Title(language.Und, cases.NoLower).String(strings.ReplaceAll(name, "_", " ")), " ", "") for _, initialism := range commonInitialisms { - result = regexp.MustCompile(strings.Title(strings.ToLower(initialism))+"([A-Z]|$|_)").ReplaceAllString(result, initialism+"$1") + result = regexp.MustCompile(cases.Title(language.Und, cases.NoLower).String(strings.ToLower(initialism))+"([A-Z]|$|_)").ReplaceAllString(result, initialism+"$1") } return result } diff --git a/schema/relationship.go b/schema/relationship.go index c11918a5e4..32676b399e 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -7,6 +7,9 @@ import ( "strings" "github.com/jinzhu/inflection" + "golang.org/x/text/cases" + "golang.org/x/text/language" + "gorm.io/gorm/clause" ) @@ -301,9 +304,9 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel } for idx, ownField := range ownForeignFields { - joinFieldName := strings.Title(schema.Name) + ownField.Name + joinFieldName := cases.Title(language.Und, cases.NoLower).String(schema.Name) + ownField.Name if len(joinForeignKeys) > idx { - joinFieldName = strings.Title(joinForeignKeys[idx]) + joinFieldName = cases.Title(language.Und, cases.NoLower).String(joinForeignKeys[idx]) } ownFieldsMap[joinFieldName] = ownField @@ -318,7 +321,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel } for idx, relField := range refForeignFields { - joinFieldName := strings.Title(relation.FieldSchema.Name) + relField.Name + joinFieldName := cases.Title(language.Und, cases.NoLower).String(relation.FieldSchema.Name) + relField.Name if _, ok := ownFieldsMap[joinFieldName]; ok { if field.Name != relation.FieldSchema.Name { @@ -329,7 +332,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel } if len(joinReferences) > idx { - joinFieldName = strings.Title(joinReferences[idx]) + joinFieldName = cases.Title(language.Und, cases.NoLower).String(joinReferences[idx]) } referFieldsMap[joinFieldName] = relField @@ -347,7 +350,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel } joinTableFields = append(joinTableFields, reflect.StructField{ - Name: strings.Title(schema.Name) + field.Name, + Name: cases.Title(language.Und, cases.NoLower).String(schema.Name) + field.Name, Type: schema.ModelType, Tag: `gorm:"-"`, }) diff --git a/schema/schema.go b/schema/schema.go index 3e7459ce74..db2367975d 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -67,9 +67,10 @@ func (schema Schema) String() string { } func (schema Schema) MakeSlice() reflect.Value { - slice := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(schema.ModelType)), 0, 20) + slice := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(schema.ModelType)), 0, 20) results := reflect.New(slice.Type()) results.Elem().Set(slice) + return results } @@ -337,7 +338,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded { for _, field := range schema.Fields { - if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) { + if field.DataType == "" && field.GORMDataType == "" && (field.Creatable || field.Updatable || field.Readable) { if schema.parseRelation(field); schema.err != nil { return schema, schema.err } else { diff --git a/schema/schema_test.go b/schema/schema_test.go index 45e152e903..a7115f60ac 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -19,6 +19,22 @@ func TestParseSchema(t *testing.T) { checkUserSchema(t, user) } +func TestParseSchemaWithMap(t *testing.T) { + type User struct { + tests.User + Attrs map[string]string `gorm:"type:Map(String,String);"` + } + + user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse user with map, got error %v", err) + } + + if field := user.FieldsByName["Attrs"]; field.DataType != "Map(String,String)" { + t.Errorf("failed to parse user field Attrs") + } +} + func TestParseSchemaWithPointerFields(t *testing.T) { user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) if err != nil { diff --git a/schema/serializer.go b/schema/serializer.go index f500521ef7..0fafbcba07 100644 --- a/schema/serializer.go +++ b/schema/serializer.go @@ -84,7 +84,10 @@ func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, case string: bytes = []byte(v) default: - return fmt.Errorf("failed to unmarshal JSONB value: %#v", dbValue) + bytes, err = json.Marshal(v) + if err != nil { + return err + } } if len(bytes) > 0 { diff --git a/statement.go b/statement.go index ae79aa3218..39e05d093b 100644 --- a/statement.go +++ b/statement.go @@ -30,8 +30,9 @@ type Statement struct { Clauses map[string]clause.Clause BuildClauses []string Distinct bool - Selects []string // selected columns - Omits []string // omit columns + Selects []string // selected columns + Omits []string // omit columns + ColumnMapping map[string]string // map columns Joins []join Preloads map[string][]interface{} Settings sync.Map @@ -513,6 +514,7 @@ func (stmt *Statement) clone() *Statement { Distinct: stmt.Distinct, Selects: stmt.Selects, Omits: stmt.Omits, + ColumnMapping: stmt.ColumnMapping, Preloads: map[string][]interface{}{}, ConnPool: stmt.ConnPool, Schema: stmt.Schema, diff --git a/tests/associations_has_many_test.go b/tests/associations_has_many_test.go index b8e8ff5efc..db397eb78f 100644 --- a/tests/associations_has_many_test.go +++ b/tests/associations_has_many_test.go @@ -554,3 +554,15 @@ func TestHasManyAssociationUnscoped(t *testing.T) { t.Errorf("expected %d contents, got %d", 0, len(contents)) } } + +func TestHasManyAssociationReplaceWithNonValidValue(t *testing.T) { + user := User{Name: "jinzhu", Languages: []Language{{Name: "EN"}}} + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + if err := DB.Model(&user).Association("Languages").Replace(Language{Name: "DE"}, Language{Name: "FR"}); err == nil { + t.Error("expected association error to be not nil") + } +} diff --git a/tests/associations_has_one_test.go b/tests/associations_has_one_test.go index a2c0750904..78290ce90b 100644 --- a/tests/associations_has_one_test.go +++ b/tests/associations_has_one_test.go @@ -255,3 +255,15 @@ func TestPolymorphicHasOneAssociationForSlice(t *testing.T) { DB.Model(&pets).Association("Toy").Clear() AssertAssociationCount(t, pets, "Toy", 0, "After Clear") } + +func TestHasOneAssociationReplaceWithNonValidValue(t *testing.T) { + user := User{Name: "jinzhu", Account: Account{Number: "1"}} + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + if err := DB.Model(&user).Association("Languages").Replace(Account{Number: "2"}); err == nil { + t.Error("expected association error to be not nil") + } +} diff --git a/tests/docker-compose.yml b/tests/compose.yml similarity index 69% rename from tests/docker-compose.yml rename to tests/compose.yml index 8abd4d0f7a..66f2daee7f 100644 --- a/tests/docker-compose.yml +++ b/tests/compose.yml @@ -1,10 +1,8 @@ -version: '3' - services: mysql: image: 'mysql/mysql-server:latest' ports: - - "9910:3306" + - "127.0.0.1:9910:3306" environment: - MYSQL_DATABASE=gorm - MYSQL_USER=gorm @@ -13,25 +11,22 @@ services: postgres: image: 'postgres:latest' ports: - - "9920:5432" + - "127.0.0.1:9920:5432" environment: - TZ=Asia/Shanghai - POSTGRES_DB=gorm - POSTGRES_USER=gorm - POSTGRES_PASSWORD=gorm mssql: - image: '${MSSQL_IMAGE:-mcmoe/mssqldocker}:latest' + image: '${MSSQL_IMAGE}:2022-latest' ports: - - "9930:1433" + - "127.0.0.1:9930:1433" environment: - TZ=Asia/Shanghai - ACCEPT_EULA=Y - - SA_PASSWORD=LoremIpsum86 - - MSSQL_DB=gorm - - MSSQL_USER=gorm - - MSSQL_PASSWORD=LoremIpsum86 + - MSSQL_SA_PASSWORD=LoremIpsum86 tidb: image: 'pingcap/tidb:v6.5.0' ports: - - "9940:4000" + - "127.0.0.1:9940:4000" command: /tidb-server -store unistore -path "" -lease 0s > tidb.log 2>&1 & diff --git a/tests/go.mod b/tests/go.mod index 10fa7ec81f..ba8a84a868 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -7,11 +7,11 @@ require ( github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.9 github.com/stretchr/testify v1.9.0 - gorm.io/driver/mysql v1.5.6 - gorm.io/driver/postgres v1.5.7 - gorm.io/driver/sqlite v1.5.5 + gorm.io/driver/mysql v1.5.7 + gorm.io/driver/postgres v1.5.9 + gorm.io/driver/sqlite v1.5.6 gorm.io/driver/sqlserver v1.5.3 - gorm.io/gorm v1.25.9 + gorm.io/gorm v1.25.12 ) require ( @@ -21,16 +21,16 @@ require ( github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect - github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect - github.com/jackc/pgx/v5 v5.5.5 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/pgx/v5 v5.7.1 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/kr/text v0.2.0 // indirect - github.com/mattn/go-sqlite3 v1.14.22 // indirect - github.com/microsoft/go-mssqldb v1.7.1 // indirect + github.com/mattn/go-sqlite3 v1.14.23 // indirect + github.com/microsoft/go-mssqldb v1.7.2 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect - golang.org/x/crypto v0.22.0 // indirect - golang.org/x/text v0.14.0 // indirect + golang.org/x/crypto v0.27.0 // indirect + golang.org/x/text v0.18.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/tests/hooks_test.go b/tests/hooks_test.go index 0753dd0b13..04f62bde21 100644 --- a/tests/hooks_test.go +++ b/tests/hooks_test.go @@ -2,6 +2,8 @@ package tests_test import ( "errors" + "log" + "os" "reflect" "strings" "testing" @@ -566,3 +568,44 @@ func TestUpdateCallbacks(t *testing.T) { t.Fatalf("before update should not be called") } } + +type Product6 struct { + gorm.Model + Name string + Item *ProductItem2 +} + +type ProductItem2 struct { + gorm.Model + Product6ID uint +} + +func (p *Product6) BeforeDelete(tx *gorm.DB) error { + if err := tx.Delete(&p.Item).Error; err != nil { + return err + } + return nil +} + +func TestPropagateUnscoped(t *testing.T) { + _DB, err := OpenTestConnection(&gorm.Config{ + PropagateUnscoped: true, + }) + if err != nil { + log.Printf("failed to connect database, got error %v", err) + os.Exit(1) + } + + _DB.Migrator().DropTable(&Product6{}, &ProductItem2{}) + _DB.AutoMigrate(&Product6{}, &ProductItem2{}) + + p := Product6{ + Name: "unique_code", + Item: &ProductItem2{}, + } + _DB.Model(&Product6{}).Create(&p) + + if err := _DB.Unscoped().Delete(&p).Error; err != nil { + t.Fatalf("unscoped did not propagate") + } +} diff --git a/tests/joins_test.go b/tests/joins_test.go index 786fc37e6a..497f814672 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -1,10 +1,12 @@ package tests_test import ( + "fmt" "regexp" "sort" "testing" + "github.com/stretchr/testify/assert" "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) @@ -184,14 +186,12 @@ func TestJoinCount(t *testing.T) { DB.Create(&user) query := DB.Model(&User{}).Joins("Company") - // Bug happens when .Count is called on a query. - // Removing the below two lines or downgrading to gorm v1.20.12 will make this test pass. + var total int64 query.Count(&total) var result User - // Incorrectly generates a 'SELECT *' query which causes companies.id to overwrite users.id if err := query.First(&result, user.ID).Error; err != nil { t.Fatalf("Failed, got error: %v", err) } @@ -199,6 +199,10 @@ func TestJoinCount(t *testing.T) { if result.ID != user.ID { t.Fatalf("result's id, %d, doesn't match user's id, %d", result.ID, user.ID) } + // should find company + if result.Company.ID != *user.CompanyID { + t.Fatalf("result's id, %d, doesn't match user's company id, %d", result.Company.ID, *user.CompanyID) + } } func TestJoinWithSoftDeleted(t *testing.T) { @@ -400,3 +404,75 @@ func TestNestedJoins(t *testing.T) { CheckPet(t, *user.Manager.NamedPet, *users2[idx].Manager.NamedPet) } } + +func TestJoinsPreload_Issue7013(t *testing.T) { + manager := &User{Name: "Manager"} + DB.Create(manager) + + var userIDs []uint + for i := 0; i < 21; i++ { + user := &User{Name: fmt.Sprintf("User%d", i), ManagerID: &manager.ID} + DB.Create(user) + userIDs = append(userIDs, user.ID) + } + + var entries []User + assert.NotPanics(t, func() { + assert.NoError(t, + DB.Debug().Preload("Manager.Team"). + Joins("Manager.Company"). + Find(&entries).Error) + }) +} + +func TestJoinsPreload_Issue7013_RelationEmpty(t *testing.T) { + type ( + Furniture struct { + gorm.Model + OwnerID *uint + } + + Owner struct { + gorm.Model + Furnitures []Furniture + CompanyID *uint + Company Company + } + + Building struct { + gorm.Model + Name string + OwnerID *uint + Owner Owner + } + ) + + DB.Migrator().DropTable(&Building{}, &Owner{}, &Furniture{}) + DB.Migrator().AutoMigrate(&Building{}, &Owner{}, &Furniture{}) + + home := &Building{Name: "relation_empty"} + DB.Create(home) + + var entries []Building + assert.NotPanics(t, func() { + assert.NoError(t, + DB.Debug().Preload("Owner.Furnitures"). + Joins("Owner.Company"). + Find(&entries).Error) + }) + + AssertEqual(t, entries, []Building{{Model: home.Model, Name: "relation_empty", Owner: Owner{Company: Company{}}}}) +} + +func TestJoinsPreload_Issue7013_NoEntries(t *testing.T) { + var entries []User + assert.NotPanics(t, func() { + assert.NoError(t, + DB.Debug().Preload("Manager.Team"). + Joins("Manager.Company"). + Where("1 <> 1"). + Find(&entries).Error) + }) + + AssertEqual(t, len(entries), 0) +} diff --git a/tests/preload_test.go b/tests/preload_test.go index 6e0e91bac6..f798b5f41a 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -440,6 +440,58 @@ func TestMergeNestedPreloadWithNestedJoin(t *testing.T) { } } +func TestNestedPreloadWithPointerJoin(t *testing.T) { + type ( + Preload struct { + ID uint + Value string + JoinID uint + } + Join struct { + ID uint + Value string + Preload Preload + NestedID uint + } + Nested struct { + ID uint + Join Join + ValueID uint + } + Value struct { + ID uint + Name string + Nested *Nested + } + ) + + DB.Migrator().DropTable(&Preload{}, &Join{}, &Nested{}, &Value{}) + DB.Migrator().AutoMigrate(&Preload{}, &Join{}, &Nested{}, &Value{}) + + value := Value{ + Name: "value", + Nested: &Nested{ + Join: Join{ + Value: "j1", + Preload: Preload{ + Value: "p1", + }, + }, + }, + } + + if err := DB.Create(&value).Error; err != nil { + t.Errorf("failed to create value, got err: %v", err) + } + + var find1 Value + err := DB.Table("values").Joins("Nested").Joins("Nested.Join").Preload("Nested.Join.Preload").First(&find1).Error + if err != nil { + t.Errorf("failed to find value, got err: %v", err) + } + AssertEqual(t, find1, value) +} + func TestEmbedPreload(t *testing.T) { type Country struct { ID int `gorm:"primaryKey"` diff --git a/tests/prepared_stmt_test.go b/tests/prepared_stmt_test.go index b86bc3d64f..20a4f7308c 100644 --- a/tests/prepared_stmt_test.go +++ b/tests/prepared_stmt_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "sync" + "sync/atomic" "testing" "time" @@ -167,3 +168,149 @@ func TestPreparedStmtReset(t *testing.T) { t.Fatalf("prepared stmt should be empty") } } + +func isUsingClosedConnError(err error) bool { + // https://github.com/golang/go/blob/e705a2d16e4ece77e08e80c168382cdb02890f5b/src/database/sql/sql.go#L2717 + return err.Error() == "sql: statement is closed" +} + +// TestPreparedStmtConcurrentReset test calling reset and executing SQL concurrently +// this test making sure that the gorm would not get a Segmentation Fault, and the only error cause by this is using a closed Stmt +func TestPreparedStmtConcurrentReset(t *testing.T) { + name := "prepared_stmt_concurrent_reset" + user := *GetUser(name, Config{}) + createTx := DB.Session(&gorm.Session{}).Create(&user) + if createTx.Error != nil { + t.Fatalf("failed to prepare record due to %s, test cannot be continue", createTx.Error) + } + + // create a new connection to keep away from other tests + tx, err := OpenTestConnection(&gorm.Config{PrepareStmt: true}) + if err != nil { + t.Fatalf("failed to open test connection due to %s", err) + } + pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB) + if !ok { + t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode") + } + + loopCount := 100 + var wg sync.WaitGroup + var unexpectedError bool + writerFinish := make(chan struct{}) + + wg.Add(1) + go func(id uint) { + defer wg.Done() + defer close(writerFinish) + + for j := 0; j < loopCount; j++ { + var tmp User + err := tx.Session(&gorm.Session{}).First(&tmp, id).Error + if err == nil || isUsingClosedConnError(err) { + continue + } + t.Errorf("failed to read user of id %d due to %s, there should not be error", id, err) + unexpectedError = true + break + } + }(user.ID) + + wg.Add(1) + go func() { + defer wg.Done() + <-writerFinish + pdb.Reset() + }() + + wg.Wait() + + if unexpectedError { + t.Fatalf("should is a unexpected error") + } +} + +// TestPreparedStmtConcurrentClose test calling close and executing SQL concurrently +// for example: one goroutine found error and just close the database, and others are executing SQL +// this test making sure that the gorm would not get a Segmentation Fault, +// and the only error cause by this is using a closed Stmt or gorm.ErrInvalidDB +// and all of the goroutine must got gorm.ErrInvalidDB after database close +func TestPreparedStmtConcurrentClose(t *testing.T) { + name := "prepared_stmt_concurrent_close" + user := *GetUser(name, Config{}) + createTx := DB.Session(&gorm.Session{}).Create(&user) + if createTx.Error != nil { + t.Fatalf("failed to prepare record due to %s, test cannot be continue", createTx.Error) + } + + // create a new connection to keep away from other tests + tx, err := OpenTestConnection(&gorm.Config{PrepareStmt: true}) + if err != nil { + t.Fatalf("failed to open test connection due to %s", err) + } + pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB) + if !ok { + t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode") + } + + loopCount := 100 + var wg sync.WaitGroup + var lastErr error + closeValid := make(chan struct{}, loopCount) + closeStartIdx := loopCount / 2 // close the database at the middle of the execution + var lastRunIndex int + var closeFinishedAt int64 + + wg.Add(1) + go func(id uint) { + defer wg.Done() + defer close(closeValid) + for lastRunIndex = 1; lastRunIndex <= loopCount; lastRunIndex++ { + if lastRunIndex == closeStartIdx { + closeValid <- struct{}{} + } + var tmp User + now := time.Now().UnixNano() + err := tx.Session(&gorm.Session{}).First(&tmp, id).Error + if err == nil { + closeFinishedAt := atomic.LoadInt64(&closeFinishedAt) + if (closeFinishedAt != 0) && (now > closeFinishedAt) { + lastErr = errors.New("must got error after database closed") + break + } + continue + } + lastErr = err + break + } + }(user.ID) + + wg.Add(1) + go func() { + defer wg.Done() + for range closeValid { + for i := 0; i < loopCount; i++ { + pdb.Close() // the Close method must can be call multiple times + atomic.CompareAndSwapInt64(&closeFinishedAt, 0, time.Now().UnixNano()) + } + } + }() + + wg.Wait() + var tmp User + err = tx.Session(&gorm.Session{}).First(&tmp, user.ID).Error + if err != gorm.ErrInvalidDB { + t.Fatalf("must got a gorm.ErrInvalidDB while execution after db close, got %+v instead", err) + } + + // must be error + if lastErr != gorm.ErrInvalidDB && !isUsingClosedConnError(lastErr) { + t.Fatalf("exp error gorm.ErrInvalidDB, got %+v instead", lastErr) + } + if lastRunIndex >= loopCount || lastRunIndex < closeStartIdx { + t.Fatalf("exp loop times between (closeStartIdx %d <=) and (< loopCount %d), got %d instead", closeStartIdx, loopCount, lastRunIndex) + } + if pdb.Stmts != nil { + t.Fatalf("stmts must be nil") + } +} diff --git a/tests/query_test.go b/tests/query_test.go index 79f7182bbb..566763c515 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -860,6 +860,28 @@ func TestOmitWithAllFields(t *testing.T) { } } +func TestMapColumns(t *testing.T) { + user := User{Name: "MapColumnsUser", Age: 12} + DB.Save(&user) + + type result struct { + Name string + Nickname string + Age uint + } + var res result + DB.Table("users").Where("name = ?", user.Name).MapColumns(map[string]string{"name": "nickname"}).Scan(&res) + if res.Nickname != user.Name { + t.Errorf("Expected res.Nickname to be %s, but got %s", user.Name, res.Nickname) + } + if res.Name != "" { + t.Errorf("Expected res.Name to be empty, but got %s", res.Name) + } + if res.Age != user.Age { + t.Errorf("Expected res.Age to be %d, but got %d", user.Age, res.Age) + } +} + func TestPluckWithSelect(t *testing.T) { users := []User{ {Name: "pluck_with_select_1", Age: 25}, @@ -1194,7 +1216,6 @@ func TestSubQueryWithRaw(t *testing.T) { Where("age >= ? and name in (?)", 20, []string{"subquery_raw_1", "subquery_raw_3"}). Group("name"), ).Count(&count).Error - if err != nil { t.Errorf("Expected to get no errors, but got %v", err) } @@ -1210,7 +1231,6 @@ func TestSubQueryWithRaw(t *testing.T) { Not("age <= ?", 10).Not("name IN (?)", []string{"subquery_raw_1", "subquery_raw_3"}). Group("name"), ).Count(&count).Error - if err != nil { t.Errorf("Expected to get no errors, but got %v", err) } diff --git a/tests/scan_test.go b/tests/scan_test.go index 6f2e9f54dd..f7def90949 100644 --- a/tests/scan_test.go +++ b/tests/scan_test.go @@ -5,6 +5,7 @@ import ( "sort" "strings" "testing" + "time" "gorm.io/gorm" . "gorm.io/gorm/utils/tests" @@ -126,7 +127,7 @@ func TestScanRows(t *testing.T) { rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows() if err != nil { - t.Errorf("Not error should happen, got %v", err) + t.Errorf("No error should happen, got %v", err) } type Result struct { @@ -148,7 +149,7 @@ func TestScanRows(t *testing.T) { }) if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) { - t.Errorf("Should find expected results") + t.Errorf("Should find expected results, got %+v", results) } var ages int @@ -158,7 +159,105 @@ func TestScanRows(t *testing.T) { var name string if err := DB.Table("users").Where("name = ?", user2.Name).Select("name").Scan(&name).Error; err != nil || name != user2.Name { - t.Fatalf("failed to scan ages, got error %v, ages: %v", err, name) + t.Fatalf("failed to scan name, got error %v, name: %v", err, name) + } +} + +func TestScanRowsNullValuesScanToFieldDefault(t *testing.T) { + DB.Save(&User{}) + + rows, err := DB.Table("users"). + Select(` + NULL AS bool_field, + NULL AS int_field, + NULL AS int8_field, + NULL AS int16_field, + NULL AS int32_field, + NULL AS int64_field, + NULL AS uint_field, + NULL AS uint8_field, + NULL AS uint16_field, + NULL AS uint32_field, + NULL AS uint64_field, + NULL AS float32_field, + NULL AS float64_field, + NULL AS string_field, + NULL AS time_field, + NULL AS time_ptr_field, + NULL AS embedded_int_field, + NULL AS nested_embedded_int_field, + NULL AS embedded_ptr_int_field + `).Rows() + if err != nil { + t.Errorf("No error should happen, got %v", err) + } + + type NestedEmbeddedStruct struct { + NestedEmbeddedIntField int + NestedEmbeddedIntFieldWithDefault int `gorm:"default:2"` + } + + type EmbeddedStruct struct { + EmbeddedIntField int + NestedEmbeddedStruct `gorm:"embedded"` + } + + type EmbeddedPtrStruct struct { + EmbeddedPtrIntField int + *NestedEmbeddedStruct `gorm:"embedded"` + } + + type Result struct { + BoolField bool + IntField int + Int8Field int8 + Int16Field int16 + Int32Field int32 + Int64Field int64 + UIntField uint + UInt8Field uint8 + UInt16Field uint16 + UInt32Field uint32 + UInt64Field uint64 + Float32Field float32 + Float64Field float64 + StringField string + TimeField time.Time + TimePtrField *time.Time + EmbeddedStruct `gorm:"embedded"` + *EmbeddedPtrStruct `gorm:"embedded"` + } + + currTime := time.Now() + reusedVar := Result{ + BoolField: true, + IntField: 1, + Int8Field: 1, + Int16Field: 1, + Int32Field: 1, + Int64Field: 1, + UIntField: 1, + UInt8Field: 1, + UInt16Field: 1, + UInt32Field: 1, + UInt64Field: 1, + Float32Field: 1.1, + Float64Field: 1.1, + StringField: "hello", + TimeField: currTime, + TimePtrField: &currTime, + EmbeddedStruct: EmbeddedStruct{EmbeddedIntField: 1, NestedEmbeddedStruct: NestedEmbeddedStruct{NestedEmbeddedIntField: 1, NestedEmbeddedIntFieldWithDefault: 2}}, + EmbeddedPtrStruct: &EmbeddedPtrStruct{EmbeddedPtrIntField: 1, NestedEmbeddedStruct: &NestedEmbeddedStruct{NestedEmbeddedIntField: 1, NestedEmbeddedIntFieldWithDefault: 2}}, + } + + for rows.Next() { + if err := DB.ScanRows(rows, &reusedVar); err != nil { + t.Errorf("should get no error, but got %v", err) + } + } + + if !reflect.DeepEqual(reusedVar, Result{}) { + t.Errorf("Should find zero values in struct fields, got %+v\n", reusedVar) } } diff --git a/tests/tests_all.sh b/tests/tests_all.sh index ee9e767541..67c6938eaf 100755 --- a/tests/tests_all.sh +++ b/tests/tests_all.sh @@ -21,13 +21,13 @@ if [[ -z $GITHUB_ACTION ]]; then then cd tests if [[ $(uname -a) == *" arm64" ]]; then - MSSQL_IMAGE=mcr.microsoft.com/azure-sql-edge docker-compose start || true + MSSQL_IMAGE=mcr.microsoft.com/azure-sql-edge docker compose up -d || true go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest || true SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF DB_ID('gorm') IS NULL CREATE DATABASE gorm" > /dev/null || true SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF SUSER_ID (N'gorm') IS NULL CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86';" > /dev/null || true SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "IF USER_ID (N'gorm') IS NULL CREATE USER gorm FROM LOGIN gorm; ALTER SERVER ROLE sysadmin ADD MEMBER [gorm];" > /dev/null || true else - docker-compose start + MSSQL_IMAGE=mcr.microsoft.com/mssql/server docker compose up -d fi cd .. fi diff --git a/tests/tests_test.go b/tests/tests_test.go index a127734edc..e84162cd3d 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -20,7 +20,7 @@ var DB *gorm.DB var ( mysqlDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local" postgresDSN = "user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai" - sqlserverDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" + sqlserverDSN = "sqlserver://sa:LoremIpsum86@localhost:9930?database=master" tidbDSN = "root:@tcp(localhost:9940)/test?charset=utf8&parseTime=True&loc=Local" ) diff --git a/tests/transaction_test.go b/tests/transaction_test.go index d2cbc9a95f..9f0f067c8b 100644 --- a/tests/transaction_test.go +++ b/tests/transaction_test.go @@ -297,6 +297,74 @@ func TestNestedTransactionWithBlock(t *testing.T) { } } +func TestDeeplyNestedTransactionWithBlockAndWrappedCallback(t *testing.T) { + transaction := func(ctx context.Context, db *gorm.DB, callback func(ctx context.Context, db *gorm.DB) error) error { + return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + return callback(ctx, tx) + }) + } + var ( + user = *GetUser("transaction-nested", Config{}) + user1 = *GetUser("transaction-nested-1", Config{}) + user2 = *GetUser("transaction-nested-2", Config{}) + ) + + if err := transaction(context.Background(), DB, func(ctx context.Context, tx *gorm.DB) error { + tx.Create(&user) + + if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + if err := transaction(ctx, tx, func(ctx context.Context, tx1 *gorm.DB) error { + tx1.Create(&user1) + + if err := tx1.First(&User{}, "name = ?", user1.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + if err := transaction(ctx, tx1, func(ctx context.Context, tx2 *gorm.DB) error { + tx2.Create(&user2) + + if err := tx2.First(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + return errors.New("inner rollback") + }); err == nil { + t.Fatalf("nested transaction has no error") + } + + return errors.New("rollback") + }); err == nil { + t.Fatalf("nested transaction should returns error") + } + + if err := tx.First(&User{}, "name = ?", user1.Name).Error; err == nil { + t.Fatalf("Should not find rollbacked record") + } + + if err := tx.First(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + return nil + }); err != nil { + t.Fatalf("no error should return, but got %v", err) + } + + if err := DB.First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + if err := DB.First(&User{}, "name = ?", user1.Name).Error; err == nil { + t.Fatalf("Should not find rollbacked parent record") + } + + if err := DB.First(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Fatalf("Should not find rollbacked nested record") + } +} + func TestDisabledNestedTransaction(t *testing.T) { var ( user = *GetUser("transaction-nested", Config{}) diff --git a/utils/utils.go b/utils/utils.go index b8d30b35b9..fc615d73b6 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -166,3 +166,14 @@ func SplitNestedRelationName(name string) []string { func JoinNestedRelationNames(relationNames []string) string { return strings.Join(relationNames, nestedRelationSplit) } + +// RTrimSlice Right trims the given slice by given length +func RTrimSlice[T any](v []T, trimLen int) []T { + if trimLen >= len(v) { // trimLen greater than slice len means fully sliced + return v[:0] + } + if trimLen < 0 { // negative trimLen is ignored + return v[:] + } + return v[:len(v)-trimLen] +} diff --git a/utils/utils_test.go b/utils/utils_test.go index 8ff42af8d1..089cc4c8e9 100644 --- a/utils/utils_test.go +++ b/utils/utils_test.go @@ -138,3 +138,64 @@ func TestToString(t *testing.T) { }) } } + +func TestRTrimSlice(t *testing.T) { + tests := []struct { + name string + input []int + trimLen int + expected []int + }{ + { + name: "Trim two elements from end", + input: []int{1, 2, 3, 4, 5}, + trimLen: 2, + expected: []int{1, 2, 3}, + }, + { + name: "Trim entire slice", + input: []int{1, 2, 3}, + trimLen: 3, + expected: []int{}, + }, + { + name: "Trim length greater than slice length", + input: []int{1, 2, 3}, + trimLen: 5, + expected: []int{}, + }, + { + name: "Zero trim length", + input: []int{1, 2, 3}, + trimLen: 0, + expected: []int{1, 2, 3}, + }, + { + name: "Trim one element from end", + input: []int{1, 2, 3}, + trimLen: 1, + expected: []int{1, 2}, + }, + { + name: "Empty slice", + input: []int{}, + trimLen: 2, + expected: []int{}, + }, + { + name: "Negative trim length (should be treated as zero)", + input: []int{1, 2, 3}, + trimLen: -1, + expected: []int{1, 2, 3}, + }, + } + + for _, testcase := range tests { + t.Run(testcase.name, func(t *testing.T) { + result := RTrimSlice(testcase.input, testcase.trimLen) + if !AssertEqual(result, testcase.expected) { + t.Errorf("RTrimSlice(%v, %d) = %v; want %v", testcase.input, testcase.trimLen, result, testcase.expected) + } + }) + } +}