diff --git a/callbacks.go b/callbacks.go index 195d17203d..50b5b0e937 100644 --- a/callbacks.go +++ b/callbacks.go @@ -187,10 +187,18 @@ func (p *processor) Replace(name string, fn func(*DB)) error { func (p *processor) compile() (err error) { var callbacks []*callback + removedMap := map[string]bool{} for _, callback := range p.callbacks { if callback.match == nil || callback.match(p.db) { callbacks = append(callbacks, callback) } + if callback.remove { + removedMap[callback.name] = true + } + } + + if len(removedMap) > 0 { + callbacks = removeCallbacks(callbacks, removedMap) } p.callbacks = callbacks @@ -339,3 +347,14 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { return } + +func removeCallbacks(cs []*callback, nameMap map[string]bool) []*callback { + callbacks := make([]*callback, 0, len(cs)) + for _, callback := range cs { + if nameMap[callback.name] { + continue + } + callbacks = append(callbacks, callback) + } + return callbacks +} diff --git a/callbacks/create.go b/callbacks/create.go index b1488b0822..8b7846b633 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -111,6 +111,17 @@ func Create(config *Config) func(db *gorm.DB) { pkField *schema.Field pkFieldName = "@id" ) + + insertID, err := result.LastInsertId() + insertOk := err == nil && insertID > 0 + + if !insertOk { + if !supportReturning { + db.AddError(err) + } + return + } + if db.Statement.Schema != nil { if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { return @@ -119,13 +130,6 @@ func Create(config *Config) func(db *gorm.DB) { pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName } - insertID, err := result.LastInsertId() - insertOk := err == nil && insertID > 0 - if !insertOk { - db.AddError(err) - return - } - // append @id column with value for auto-increment primary key // the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1 switch values := db.Statement.Dest.(type) { @@ -142,6 +146,11 @@ func Create(config *Config) func(db *gorm.DB) { } } } + + if config.LastInsertIDReversed { + insertID -= int64(len(mapValues)-1) * schema.DefaultAutoIncrementIncrement + } + for _, mapValue := range mapValues { if mapValue != nil { mapValue[pkFieldName] = insertID @@ -293,13 +302,15 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { } } - for field, vs := range defaultValueFieldsHavingValue { - values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) - for idx := range values.Values { - if vs[idx] == nil { - values.Values[idx] = append(values.Values[idx], stmt.Dialector.DefaultValueOf(field)) - } else { - values.Values[idx] = append(values.Values[idx], vs[idx]) + for _, field := range stmt.Schema.FieldsWithDefaultDBValue { + if vs, ok := defaultValueFieldsHavingValue[field]; ok { + values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) + for idx := range values.Values { + if vs[idx] == nil { + values.Values[idx] = append(values.Values[idx], stmt.DefaultValueOf(field)) + } else { + values.Values[idx] = append(values.Values[idx], vs[idx]) + } } } } @@ -322,7 +333,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { } for _, field := range stmt.Schema.FieldsWithDefaultDBValue { - if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) && field.DefaultValueInterface == nil { if rvOfvalue, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero { values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) values.Values[0] = append(values.Values[0], rvOfvalue) @@ -351,7 +362,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { case schema.UnixNanosecond: assignment.Value = curTime.UnixNano() case schema.UnixMillisecond: - assignment.Value = curTime.UnixNano() / 1e6 + assignment.Value = curTime.UnixMilli() case schema.UnixSecond: assignment.Value = curTime.Unix() } diff --git a/callbacks/create_test.go b/callbacks/create_test.go new file mode 100644 index 0000000000..da6b172bd5 --- /dev/null +++ b/callbacks/create_test.go @@ -0,0 +1,71 @@ +package callbacks + +import ( + "reflect" + "sync" + "testing" + "time" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" +) + +var schemaCache = &sync.Map{} + +func TestConvertToCreateValues_DestType_Slice(t *testing.T) { + type user struct { + ID int `gorm:"primaryKey"` + Name string + Email string `gorm:"default:(-)"` + Age int `gorm:"default:(-)"` + } + + s, err := schema.Parse(&user{}, schemaCache, schema.NamingStrategy{}) + if err != nil { + t.Errorf("parse schema error: %v, is not expected", err) + return + } + dest := []*user{ + { + ID: 1, + Name: "alice", + Email: "email", + Age: 18, + }, + { + ID: 2, + Name: "bob", + Email: "email", + Age: 19, + }, + } + stmt := &gorm.Statement{ + DB: &gorm.DB{ + Config: &gorm.Config{ + NowFunc: func() time.Time { return time.Time{} }, + }, + Statement: &gorm.Statement{ + Settings: sync.Map{}, + Schema: s, + }, + }, + ReflectValue: reflect.ValueOf(dest), + Dest: dest, + } + + stmt.Schema = s + + values := ConvertToCreateValues(stmt) + expected := clause.Values{ + // column has value + defaultValue column has value (which should have a stable order) + Columns: []clause.Column{{Name: "name"}, {Name: "email"}, {Name: "age"}, {Name: "id"}}, + Values: [][]interface{}{ + {"alice", "email", 18, 1}, + {"bob", "email", 19, 2}, + }, + } + if !reflect.DeepEqual(expected, values) { + t.Errorf("expected: %v got %v", expected, values) + } +} diff --git a/callbacks/update.go b/callbacks/update.go index ff075dcf28..7cde7f6196 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -234,7 +234,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { if field.AutoUpdateTime == schema.UnixNanosecond { set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()}) } else if field.AutoUpdateTime == schema.UnixMillisecond { - set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6}) + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixMilli()}) } else if field.AutoUpdateTime == schema.UnixSecond { set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()}) } else { @@ -268,7 +268,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { if field.AutoUpdateTime == schema.UnixNanosecond { value = stmt.DB.NowFunc().UnixNano() } else if field.AutoUpdateTime == schema.UnixMillisecond { - value = stmt.DB.NowFunc().UnixNano() / 1e6 + value = stmt.DB.NowFunc().UnixMilli() } else if field.AutoUpdateTime == schema.UnixSecond { value = stmt.DB.NowFunc().Unix() } else { diff --git a/logger/sql.go b/logger/sql.go index 8ce8d8b17c..ad4787956b 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -34,6 +34,19 @@ var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeO // RegEx matches only numeric values var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`) +func isNumeric(k reflect.Kind) bool { + switch k { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return true + case reflect.Float32, reflect.Float64: + return true + default: + return false + } +} + // ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string { var ( @@ -110,6 +123,12 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a convertParams(v, idx) } else if rv.Kind() == reflect.Ptr && !rv.IsZero() { convertParams(reflect.Indirect(rv).Interface(), idx) + } else if isNumeric(rv.Kind()) { + if rv.CanInt() || rv.CanUint() { + vars[idx] = fmt.Sprintf("%d", rv.Interface()) + } else { + vars[idx] = fmt.Sprintf("%.6f", rv.Interface()) + } } else { for _, t := range convertibleTypes { if rv.Type().ConvertibleTo(t) { diff --git a/logger/sql_test.go b/logger/sql_test.go index 036ef3a4a9..9002a7eb03 100644 --- a/logger/sql_test.go +++ b/logger/sql_test.go @@ -37,14 +37,18 @@ func format(v []byte, escaper string) string { func TestExplainSQL(t *testing.T) { type role string type password []byte + type intType int + type floatType float64 var ( - tt = now.MustParse("2020-02-23 11:10:10") - myrole = role("admin") - pwd = password("pass") - jsVal = []byte(`{"Name":"test","Val":"test"}`) - js = JSON(jsVal) - esVal = []byte(`{"Name":"test","Val":"test"}`) - es = ExampleStruct{Name: "test", Val: "test"} + tt = now.MustParse("2020-02-23 11:10:10") + myrole = role("admin") + pwd = password("pass") + jsVal = []byte(`{"Name":"test","Val":"test"}`) + js = JSON(jsVal) + esVal = []byte(`{"Name":"test","Val":"test"}`) + es = ExampleStruct{Name: "test", Val: "test"} + intVal intType = 1 + floatVal floatType = 1.23 ) results := []struct { @@ -107,6 +111,18 @@ func TestExplainSQL(t *testing.T) { Vars: []interface{}{"jinzhu", 1, float32(999.99), true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es}, Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), }, + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, int_val) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + NumericRegexp: nil, + Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, intVal}, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, int_val) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", 1)`, + }, + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, float_val) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + NumericRegexp: nil, + Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, floatVal}, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, float_val) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", 1.230000)`, + }, } for idx, r := range results { diff --git a/migrator/migrator.go b/migrator/migrator.go index ae82f76931..acce5df216 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -7,6 +7,7 @@ import ( "fmt" "reflect" "regexp" + "strconv" "strings" "time" @@ -518,12 +519,18 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy } else if !dvNotNull && currentDefaultNotNull { // null -> default value alterColumn = true - } else if (field.GORMDataType != schema.Time && dv != field.DefaultValue) || - (field.GORMDataType == schema.Time && !strings.EqualFold(strings.TrimSuffix(dv, "()"), strings.TrimSuffix(field.DefaultValue, "()"))) { - // default value not equal - // not both null - if currentDefaultNotNull || dvNotNull { - alterColumn = true + } else if currentDefaultNotNull || dvNotNull { + switch field.GORMDataType { + case schema.Time: + if !strings.EqualFold(strings.TrimSuffix(dv, "()"), strings.TrimSuffix(field.DefaultValue, "()")) { + alterColumn = true + } + case schema.Bool: + v1, _ := strconv.ParseBool(dv) + v2, _ := strconv.ParseBool(field.DefaultValue) + alterColumn = v1 != v2 + default: + alterColumn = dv != field.DefaultValue } } } diff --git a/prepare_stmt.go b/prepare_stmt.go index aa944624c8..4d533885e4 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -3,6 +3,8 @@ package gorm import ( "context" "database/sql" + "database/sql/driver" + "errors" "reflect" "sync" ) @@ -147,7 +149,7 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args .. stmt, err := db.prepare(ctx, db.ConnPool, false, query) if err == nil { result, err = stmt.ExecContext(ctx, args...) - if err != nil { + if errors.Is(err, driver.ErrBadConn) { db.Mux.Lock() defer db.Mux.Unlock() go stmt.Close() @@ -161,7 +163,7 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args . stmt, err := db.prepare(ctx, db.ConnPool, false, query) if err == nil { rows, err = stmt.QueryContext(ctx, args...) - if err != nil { + if errors.Is(err, driver.ErrBadConn) { db.Mux.Lock() defer db.Mux.Unlock() @@ -180,6 +182,14 @@ func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, arg return &sql.Row{} } +func (db *PreparedStmtDB) Ping() error { + conn, err := db.GetDBConn() + if err != nil { + return err + } + return conn.Ping() +} + type PreparedStmtTX struct { Tx PreparedStmtDB *PreparedStmtDB @@ -207,7 +217,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args .. stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) if err == nil { result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...) - if err != nil { + if errors.Is(err, driver.ErrBadConn) { tx.PreparedStmtDB.Mux.Lock() defer tx.PreparedStmtDB.Mux.Unlock() @@ -222,7 +232,7 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args . stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) if err == nil { rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...) - if err != nil { + if errors.Is(err, driver.ErrBadConn) { tx.PreparedStmtDB.Mux.Lock() defer tx.PreparedStmtDB.Mux.Unlock() @@ -240,3 +250,11 @@ func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, arg } return &sql.Row{} } + +func (tx *PreparedStmtTX) Ping() error { + conn, err := tx.GetDBConn() + if err != nil { + return err + } + return conn.Ping() +} diff --git a/scan.go b/scan.go index 54cd6769d4..415b9f0d74 100644 --- a/scan.go +++ b/scan.go @@ -274,7 +274,9 @@ func Scan(rows Rows, db *DB, mode ScanMode) { if !update || reflectValue.Len() == 0 { update = false - if !isArrayKind { + if isArrayKind { + db.Statement.ReflectValue.Set(reflect.Zero(reflectValue.Type())) + } else { // if the slice cap is externally initialized, the externally initialized slice is directly used here if reflectValue.Cap() == 0 { db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) diff --git a/schema/constraint.go b/schema/constraint.go index 5f6beb89c5..80a743a835 100644 --- a/schema/constraint.go +++ b/schema/constraint.go @@ -8,7 +8,7 @@ import ( ) // reg match english letters and midline -var regEnLetterAndMidline = regexp.MustCompile("^[A-Za-z-_]+$") +var regEnLetterAndMidline = regexp.MustCompile(`^[\w-]+$`) type CheckConstraint struct { Name string diff --git a/schema/field.go b/schema/field.go index 91e4c0abfd..ca2e11482c 100644 --- a/schema/field.go +++ b/schema/field.go @@ -664,7 +664,7 @@ func (field *Field) setupValuerAndSetter() { if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { field.ReflectValueOf(ctx, value).SetInt(data.UnixNano()) } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { - field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6) + field.ReflectValueOf(ctx, value).SetInt(data.UnixMilli()) } else { field.ReflectValueOf(ctx, value).SetInt(data.Unix()) } @@ -673,7 +673,7 @@ func (field *Field) setupValuerAndSetter() { if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { field.ReflectValueOf(ctx, value).SetInt(data.UnixNano()) } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { - field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6) + field.ReflectValueOf(ctx, value).SetInt(data.UnixMilli()) } else { field.ReflectValueOf(ctx, value).SetInt(data.Unix()) } @@ -738,7 +738,7 @@ func (field *Field) setupValuerAndSetter() { if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano())) } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { - field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano() / 1e6)) + field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixMilli())) } else { field.ReflectValueOf(ctx, value).SetUint(uint64(data.Unix())) } diff --git a/schema/serializer.go b/schema/serializer.go index 397edff034..f500521ef7 100644 --- a/schema/serializer.go +++ b/schema/serializer.go @@ -126,12 +126,12 @@ func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect rv := reflect.ValueOf(fieldValue) switch v := fieldValue.(type) { case int64, int, uint, uint64, int32, uint32, int16, uint16: - result = time.Unix(reflect.Indirect(rv).Int(), 0) + result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC() case *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16: if rv.IsZero() { return nil, nil } - result = time.Unix(reflect.Indirect(rv).Int(), 0) + result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC() default: err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v) } diff --git a/tests/callbacks_test.go b/tests/callbacks_test.go index 4479da4c04..f77209f135 100644 --- a/tests/callbacks_test.go +++ b/tests/callbacks_test.go @@ -91,7 +91,7 @@ func TestCallbacks(t *testing.T) { }, { callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}, {h: c2, remove: true}}, - results: []string{"c1", "c5", "c3", "c4"}, + results: []string{"c1", "c3", "c4", "c5"}, }, { callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}}, @@ -206,3 +206,49 @@ func TestPluginCallbacks(t *testing.T) { t.Errorf("callbacks tests failed, got %v", msg) } } + +func TestCallbacksGet(t *testing.T) { + db, _ := gorm.Open(nil, nil) + createCallback := db.Callback().Create() + + createCallback.Before("*").Register("c1", c1) + if cb := createCallback.Get("c1"); reflect.DeepEqual(cb, c1) { + t.Errorf("callbacks tests failed, got: %p, want: %p", cb, c1) + } + + createCallback.Remove("c1") + if cb := createCallback.Get("c2"); cb != nil { + t.Errorf("callbacks test failed. got: %p, want: nil", cb) + } +} + +func TestCallbacksRemove(t *testing.T) { + db, _ := gorm.Open(nil, nil) + createCallback := db.Callback().Create() + + createCallback.Before("*").Register("c1", c1) + createCallback.After("*").Register("c2", c2) + createCallback.Before("c4").Register("c3", c3) + createCallback.After("c2").Register("c4", c4) + + // callbacks: []string{"c1", "c3", "c4", "c2"} + createCallback.Remove("c1") + if ok, msg := assertCallbacks(createCallback, []string{"c3", "c4", "c2"}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } + + createCallback.Remove("c4") + if ok, msg := assertCallbacks(createCallback, []string{"c3", "c2"}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } + + createCallback.Remove("c2") + if ok, msg := assertCallbacks(createCallback, []string{"c3"}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } + + createCallback.Remove("c3") + if ok, msg := assertCallbacks(createCallback, []string{}); !ok { + t.Errorf("callbacks tests failed, got %v", msg) + } +} diff --git a/tests/create_test.go b/tests/create_test.go index 5e97a542dd..abb82472a7 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -713,18 +713,16 @@ func TestCreateFromMapWithoutPK(t *testing.T) { } func TestCreateFromMapWithTable(t *testing.T) { - if !isMysql() { - t.Skipf("This test case skipped, because of only supportting for mysql") - } - tableDB := DB.Table("`users`") + tableDB := DB.Table("users") + supportLastInsertID := isMysql() || isSqlite() // case 1: create from map[string]interface{} - record := map[string]interface{}{"`name`": "create_from_map_with_table", "`age`": 18} + record := map[string]interface{}{"name": "create_from_map_with_table", "age": 18} if err := tableDB.Create(record).Error; err != nil { t.Fatalf("failed to create data from map with table, got error: %v", err) } - if _, ok := record["@id"]; !ok { + if _, ok := record["@id"]; !ok && supportLastInsertID { t.Fatal("failed to create data from map with table, returning map has no key '@id'") } @@ -733,8 +731,8 @@ func TestCreateFromMapWithTable(t *testing.T) { t.Fatalf("failed to create from map, got error %v", err) } - if int64(res["id"].(uint64)) != record["@id"] { - t.Fatal("failed to create data from map with table, @id != id") + if _, ok := record["@id"]; ok && fmt.Sprint(res["id"]) != fmt.Sprint(record["@id"]) { + t.Fatalf("failed to create data from map with table, @id != id, got %v, expect %v", res["id"], record["@id"]) } // case 2: create from *map[string]interface{} @@ -743,7 +741,7 @@ func TestCreateFromMapWithTable(t *testing.T) { if err := tableDB2.Create(&record1).Error; err != nil { t.Fatalf("failed to create data from map, got error: %v", err) } - if _, ok := record1["@id"]; !ok { + if _, ok := record1["@id"]; !ok && supportLastInsertID { t.Fatal("failed to create data from map with table, returning map has no key '@id'") } @@ -752,7 +750,7 @@ func TestCreateFromMapWithTable(t *testing.T) { t.Fatalf("failed to create from map, got error %v", err) } - if int64(res1["id"].(uint64)) != record1["@id"] { + if _, ok := record1["@id"]; ok && fmt.Sprint(res1["id"]) != fmt.Sprint(record1["@id"]) { t.Fatal("failed to create data from map with table, @id != id") } @@ -767,11 +765,11 @@ func TestCreateFromMapWithTable(t *testing.T) { t.Fatalf("failed to create data from slice of map, got error: %v", err) } - if _, ok := records[0]["@id"]; !ok { + if _, ok := records[0]["@id"]; !ok && supportLastInsertID { t.Fatal("failed to create data from map with table, returning map has no key '@id'") } - if _, ok := records[1]["@id"]; !ok { + if _, ok := records[1]["@id"]; !ok && supportLastInsertID { t.Fatal("failed to create data from map with table, returning map has no key '@id'") } @@ -785,11 +783,11 @@ func TestCreateFromMapWithTable(t *testing.T) { t.Fatalf("failed to query data after create from slice of map, got error %v", err) } - if int64(res2["id"].(uint64)) != records[0]["@id"] { - t.Fatal("failed to create data from map with table, @id != id") + if _, ok := records[0]["@id"]; ok && fmt.Sprint(res2["id"]) != fmt.Sprint(records[0]["@id"]) { + t.Errorf("failed to create data from map with table, @id != id, got %v, expect %v", res2["id"], records[0]["@id"]) } - if int64(res3["id"].(uint64)) != records[1]["@id"] { - t.Fatal("failed to create data from map with table, @id != id") + if _, ok := records[1]["id"]; ok && fmt.Sprint(res3["id"]) != fmt.Sprint(records[1]["@id"]) { + t.Errorf("failed to create data from map with table, @id != id") } } diff --git a/tests/default_value_test.go b/tests/default_value_test.go index 918f0796d4..71d6deb25b 100644 --- a/tests/default_value_test.go +++ b/tests/default_value_test.go @@ -38,4 +38,22 @@ func TestDefaultValue(t *testing.T) { } else if result.Name != "foo" || result.Name2 != "foo" || result.Name3 != "" || result.Age != 18 || !result.Enabled || result.Created.Format("20060102") != "20000102" { t.Fatalf("Failed to find created data with default data, got %+v", result) } + + type Harumph2 struct { + ID int `gorm:"default:0"` + Email string `gorm:"not null;index:,unique"` + Name string `gorm:"notNull;default:foo"` + Name2 string `gorm:"size:233;not null;default:'foo'"` + Name3 string `gorm:"size:233;notNull;default:''"` + Age int `gorm:"default:18"` + Created time.Time `gorm:"default:2000-01-02"` + Enabled bool `gorm:"default:true"` + } + + harumph2 := Harumph2{ID: 2, Email: "hello2@gorm.io"} + if err := DB.Table("harumphs").Create(&harumph2).Error; err != nil { + t.Fatalf("Failed to create data with default value, got error: %v", err) + } else if harumph2.ID != 2 || harumph2.Name != "foo" || harumph2.Name2 != "foo" || harumph2.Name3 != "" || harumph2.Age != 18 || !harumph2.Enabled || harumph2.Created.Format("20060102") != "20000102" { + t.Fatalf("Failed to create data with default value, got: %+v", harumph2) + } } diff --git a/tests/docker-compose.yml b/tests/docker-compose.yml index 866a4d62df..8abd4d0f7a 100644 --- a/tests/docker-compose.yml +++ b/tests/docker-compose.yml @@ -24,6 +24,7 @@ services: ports: - "9930:1433" environment: + - TZ=Asia/Shanghai - ACCEPT_EULA=Y - SA_PASSWORD=LoremIpsum86 - MSSQL_DB=gorm diff --git a/tests/go.mod b/tests/go.mod index 136667b70b..d58469c45a 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,29 +6,30 @@ require ( github.com/google/uuid v1.6.0 github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.9 - github.com/stretchr/testify v1.8.4 - gorm.io/driver/mysql v1.5.4 - gorm.io/driver/postgres v1.5.6 + github.com/stretchr/testify v1.9.0 + gorm.io/driver/mysql v1.5.6 + gorm.io/driver/postgres v1.5.7 gorm.io/driver/sqlite v1.5.5 gorm.io/driver/sqlserver v1.5.3 - gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde + gorm.io/gorm v1.25.9 ) require ( + filippo.io/edwards25519 v1.1.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/go-sql-driver/mysql v1.7.1 // indirect + github.com/go-sql-driver/mysql v1.8.1 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect - github.com/jackc/pgx/v5 v5.5.3 // indirect + github.com/jackc/pgx/v5 v5.5.5 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/kr/text v0.2.0 // indirect github.com/mattn/go-sqlite3 v1.14.22 // indirect - github.com/microsoft/go-mssqldb v1.6.0 // indirect + github.com/microsoft/go-mssqldb v1.7.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect - golang.org/x/crypto v0.18.0 // indirect + golang.org/x/crypto v0.22.0 // indirect golang.org/x/text v0.14.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/tests/helper_test.go b/tests/helper_test.go index feb67f9e1d..dc250b7c3a 100644 --- a/tests/helper_test.go +++ b/tests/helper_test.go @@ -281,6 +281,10 @@ func isMysql() bool { return os.Getenv("GORM_DIALECT") == "mysql" } +func isSqlite() bool { + return os.Getenv("GORM_DIALECT") == "sqlite" +} + func db(unscoped bool) *gorm.DB { if unscoped { return DB.Unscoped() diff --git a/tests/migrate_test.go b/tests/migrate_test.go index b25b9da64a..d955c8d7f2 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -7,6 +7,7 @@ import ( "math/rand" "os" "reflect" + "strconv" "strings" "testing" "time" @@ -1420,7 +1421,7 @@ func TestMigrateSameEmbeddedFieldName(t *testing.T) { AssertEqual(t, nil, err) } -func TestMigrateDefaultNullString(t *testing.T) { +func TestMigrateWithDefaultValue(t *testing.T) { if DB.Dialector.Name() == "sqlserver" { // sqlserver driver treats NULL and 'NULL' the same t.Skip("skip sqlserver") @@ -1434,6 +1435,7 @@ func TestMigrateDefaultNullString(t *testing.T) { type NullStringModel struct { ID uint Content string `gorm:"default:'null'"` + Active bool `gorm:"default:false"` } tableName := "null_string_model" @@ -1454,6 +1456,14 @@ func TestMigrateDefaultNullString(t *testing.T) { AssertEqual(t, defVal, "null") AssertEqual(t, ok, true) + columnType2, err := findColumnType(tableName, "active") + AssertEqual(t, err, nil) + + defVal, ok = columnType2.DefaultValue() + bv, _ := strconv.ParseBool(defVal) + AssertEqual(t, bv, false) + AssertEqual(t, ok, true) + // default 'null' -> 'null' session := DB.Session(&gorm.Session{Logger: Tracer{ Logger: DB.Config.Logger, diff --git a/tests/prepared_stmt_test.go b/tests/prepared_stmt_test.go index b234c8bf21..b86bc3d64f 100644 --- a/tests/prepared_stmt_test.go +++ b/tests/prepared_stmt_test.go @@ -126,33 +126,6 @@ func TestPreparedStmtDeadlock(t *testing.T) { AssertEqual(t, sqlDB.Stats().InUse, 0) } -func TestPreparedStmtError(t *testing.T) { - tx, err := OpenTestConnection(&gorm.Config{}) - AssertEqual(t, err, nil) - - sqlDB, _ := tx.DB() - sqlDB.SetMaxOpenConns(1) - - tx = tx.Session(&gorm.Session{PrepareStmt: true}) - - wg := sync.WaitGroup{} - for i := 0; i < 10; i++ { - wg.Add(1) - go func() { - // err prepare - tag := Tag{Locale: "zh"} - tx.Table("users").Find(&tag) - wg.Done() - }() - } - wg.Wait() - - conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB) - AssertEqual(t, ok, true) - AssertEqual(t, len(conn.Stmts), 0) - AssertEqual(t, sqlDB.Stats().InUse, 0) -} - func TestPreparedStmtInTransaction(t *testing.T) { user := User{Name: "jinzhu"} diff --git a/tests/query_test.go b/tests/query_test.go index e780e3bfe2..c0259a14a3 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -1409,3 +1409,22 @@ func TestQueryError(t *testing.T) { }, Value: 1}).Scan(&p2).Error AssertEqual(t, err, gorm.ErrModelValueRequired) } + +func TestQueryScanToArray(t *testing.T) { + err := DB.Create(&User{Name: "testname1", Age: 10}).Error + if err != nil { + t.Fatal(err) + } + + users := [2]*User{{Name: "1"}, {Name: "2"}} + err = DB.Model(&User{}).Where("name = ?", "testname1").Find(&users).Error + if err != nil { + t.Fatal(err) + } + if users[0] == nil || users[0].Name != "testname1" { + t.Error("users[0] not covere") + } + if users[1] != nil { + t.Error("users[1] should be empty") + } +}