diff --git a/callbacks/helper_test.go b/callbacks/helper_test.go new file mode 100644 index 000000000..08f94e202 --- /dev/null +++ b/callbacks/helper_test.go @@ -0,0 +1,157 @@ +package callbacks + +import ( + "reflect" + "testing" + + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +func TestLoadOrStoreVisitMap(t *testing.T) { + var vm visitMap + var loaded bool + type testM struct { + Name string + } + + t1 := testM{Name: "t1"} + t2 := testM{Name: "t2"} + t3 := testM{Name: "t3"} + + vm = make(visitMap) + if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); loaded { + t.Fatalf("loaded should be false") + } + + if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); !loaded { + t.Fatalf("loaded should be true") + } + + // t1 already exist but t2 not + if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t1, &t2, &t3})); loaded { + t.Fatalf("loaded should be false") + } + + if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t2, &t3})); !loaded { + t.Fatalf("loaded should be true") + } +} + +func TestConvertMapToValuesForCreate(t *testing.T) { + testCase := []struct { + name string + input map[string]interface{} + expect clause.Values + }{ + { + name: "Test convert string value", + input: map[string]interface{}{ + "name": "my name", + }, + expect: clause.Values{ + Columns: []clause.Column{{Name: "name"}}, + Values: [][]interface{}{{"my name"}}, + }, + }, + { + name: "Test convert int value", + input: map[string]interface{}{ + "age": 18, + }, + expect: clause.Values{ + Columns: []clause.Column{{Name: "age"}}, + Values: [][]interface{}{{18}}, + }, + }, + { + name: "Test convert float value", + input: map[string]interface{}{ + "score": 99.5, + }, + expect: clause.Values{ + Columns: []clause.Column{{Name: "score"}}, + Values: [][]interface{}{{99.5}}, + }, + }, + { + name: "Test convert bool value", + input: map[string]interface{}{ + "active": true, + }, + expect: clause.Values{ + Columns: []clause.Column{{Name: "active"}}, + Values: [][]interface{}{{true}}, + }, + }, + } + + for _, tc := range testCase { + t.Run(tc.name, func(t *testing.T) { + actual := ConvertMapToValuesForCreate(&gorm.Statement{}, tc.input) + if !reflect.DeepEqual(actual, tc.expect) { + t.Errorf("expect %v got %v", tc.expect, actual) + } + }) + } +} + +func TestConvertSliceOfMapToValuesForCreate(t *testing.T) { + testCase := []struct { + name string + input []map[string]interface{} + expect clause.Values + }{ + { + name: "Test convert slice of string value", + input: []map[string]interface{}{ + {"name": "my name"}, + }, + expect: clause.Values{ + Columns: []clause.Column{{Name: "name"}}, + Values: [][]interface{}{{"my name"}}, + }, + }, + { + name: "Test convert slice of int value", + input: []map[string]interface{}{ + {"age": 18}, + }, + expect: clause.Values{ + Columns: []clause.Column{{Name: "age"}}, + Values: [][]interface{}{{18}}, + }, + }, + { + name: "Test convert slice of float value", + input: []map[string]interface{}{ + {"score": 99.5}, + }, + expect: clause.Values{ + Columns: []clause.Column{{Name: "score"}}, + Values: [][]interface{}{{99.5}}, + }, + }, + { + name: "Test convert slice of bool value", + input: []map[string]interface{}{ + {"active": true}, + }, + expect: clause.Values{ + Columns: []clause.Column{{Name: "active"}}, + Values: [][]interface{}{{true}}, + }, + }, + } + + for _, tc := range testCase { + t.Run(tc.name, func(t *testing.T) { + actual := ConvertSliceOfMapToValuesForCreate(&gorm.Statement{}, tc.input) + + if !reflect.DeepEqual(actual, tc.expect) { + t.Errorf("expected %v but got %v", tc.expect, actual) + } + }) + } + +} diff --git a/callbacks/preload.go b/callbacks/preload.go index 25ecfe761..cf7a0d2ba 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -121,10 +121,23 @@ func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relati } } else if rel := relationships.Relations[name]; rel != nil { if joined, nestedJoins := isJoined(name); joined { - reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue) - tx := preloadDB(db, reflectValue, reflectValue.Interface()) - if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil { - return err + switch rv := db.Statement.ReflectValue; rv.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i < rv.Len(); i++ { + reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv.Index(i)) + tx := preloadDB(db, reflectValue, reflectValue.Interface()) + if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil { + return err + } + } + case reflect.Struct: + reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv) + tx := preloadDB(db, reflectValue, reflectValue.Interface()) + if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil { + return err + } + default: + return gorm.ErrInvalidData } } else { tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}) diff --git a/callbacks/visit_map_test.go b/callbacks/visit_map_test.go deleted file mode 100644 index b1fb86dbe..000000000 --- a/callbacks/visit_map_test.go +++ /dev/null @@ -1,36 +0,0 @@ -package callbacks - -import ( - "reflect" - "testing" -) - -func TestLoadOrStoreVisitMap(t *testing.T) { - var vm visitMap - var loaded bool - type testM struct { - Name string - } - - t1 := testM{Name: "t1"} - t2 := testM{Name: "t2"} - t3 := testM{Name: "t3"} - - vm = make(visitMap) - if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); loaded { - t.Fatalf("loaded should be false") - } - - if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); !loaded { - t.Fatalf("loaded should be true") - } - - // t1 already exist but t2 not - if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t1, &t2, &t3})); loaded { - t.Fatalf("loaded should be false") - } - - if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t2, &t3})); !loaded { - t.Fatalf("loaded should be true") - } -} diff --git a/clause/where.go b/clause/where.go index 46d0b3193..9ac78578e 100644 --- a/clause/where.go +++ b/clause/where.go @@ -21,11 +21,11 @@ func (where Where) Name() string { // Build build where clause func (where Where) Build(builder Builder) { - if len(where.Exprs) == 1 { - if andCondition, ok := where.Exprs[0].(AndConditions); ok { - where.Exprs = andCondition.Exprs - } - } + if len(where.Exprs) == 1 { + if andCondition, ok := where.Exprs[0].(AndConditions); ok { + where.Exprs = andCondition.Exprs + } + } // Switch position if the first query expression is a single Or condition for idx, expr := range where.Exprs { @@ -166,19 +166,58 @@ type NotConditions struct { } func (not NotConditions) Build(builder Builder) { - if len(not.Exprs) > 1 { - builder.WriteByte('(') + anyNegationBuilder := false + for _, c := range not.Exprs { + if _, ok := c.(NegationExpressionBuilder); ok { + anyNegationBuilder = true + break + } } - for idx, c := range not.Exprs { - if idx > 0 { - builder.WriteString(AndWithSpace) + if anyNegationBuilder { + if len(not.Exprs) > 1 { + builder.WriteByte('(') } - if negationBuilder, ok := c.(NegationExpressionBuilder); ok { - negationBuilder.NegationBuild(builder) - } else { - builder.WriteString("NOT ") + for idx, c := range not.Exprs { + if idx > 0 { + builder.WriteString(AndWithSpace) + } + + if negationBuilder, ok := c.(NegationExpressionBuilder); ok { + negationBuilder.NegationBuild(builder) + } else { + builder.WriteString("NOT ") + e, wrapInParentheses := c.(Expr) + if wrapInParentheses { + sql := strings.ToUpper(e.SQL) + if wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace); wrapInParentheses { + builder.WriteByte('(') + } + } + + c.Build(builder) + + if wrapInParentheses { + builder.WriteByte(')') + } + } + } + + if len(not.Exprs) > 1 { + builder.WriteByte(')') + } + } else { + builder.WriteString("NOT ") + if len(not.Exprs) > 1 { + builder.WriteByte('(') + } + + for idx, c := range not.Exprs { + if idx > 0 { + builder.WriteString(AndWithSpace) + } + e, wrapInParentheses := c.(Expr) if wrapInParentheses { sql := strings.ToUpper(e.SQL) @@ -193,9 +232,9 @@ func (not NotConditions) Build(builder Builder) { builder.WriteByte(')') } } - } - if len(not.Exprs) > 1 { - builder.WriteByte(')') + if len(not.Exprs) > 1 { + builder.WriteByte(')') + } } } diff --git a/clause/where_test.go b/clause/where_test.go index aa9d06ebf..7d5aca1ff 100644 --- a/clause/where_test.go +++ b/clause/where_test.go @@ -105,6 +105,14 @@ func TestWhere(t *testing.T) { "SELECT * FROM `users` WHERE (`users`.`id` <> ? AND NOT `score` <= ?)", []interface{}{"1", 100}, }, + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{ + Exprs: []clause.Expression{clause.Not(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}}, + clause.Expr{SQL: "`age` <= ?", Vars: []interface{}{60}})}, + }}, + "SELECT * FROM `users` WHERE NOT (`score` <= ? AND `age` <= ?)", + []interface{}{100, 60}, + }, } for idx, result := range results { diff --git a/scan.go b/scan.go index 736db4d3a..54cd6769d 100644 --- a/scan.go +++ b/scan.go @@ -274,12 +274,14 @@ func Scan(rows Rows, db *DB, mode ScanMode) { if !update || reflectValue.Len() == 0 { update = false - // if the slice cap is externally initialized, the externally initialized slice is directly used here - if reflectValue.Cap() == 0 { - db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) - } else if !isArrayKind { - reflectValue.SetLen(0) - db.Statement.ReflectValue.Set(reflectValue) + if !isArrayKind { + // if the slice cap is externally initialized, the externally initialized slice is directly used here + if reflectValue.Cap() == 0 { + db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) + } else { + reflectValue.SetLen(0) + db.Statement.ReflectValue.Set(reflectValue) + } } } diff --git a/tests/create_test.go b/tests/create_test.go index d9b54b7f6..5e97a542d 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -582,44 +582,44 @@ func TestCreateWithAutoIncrementCompositeKey(t *testing.T) { } func TestCreateOnConflictWithDefaultNull(t *testing.T) { - type OnConfilctUser struct { + type OnConflictUser struct { ID string Name string `gorm:"default:null"` Email string Mobile string `gorm:"default:'133xxxx'"` } - err := DB.Migrator().DropTable(&OnConfilctUser{}) + err := DB.Migrator().DropTable(&OnConflictUser{}) AssertEqual(t, err, nil) - err = DB.AutoMigrate(&OnConfilctUser{}) + err = DB.AutoMigrate(&OnConflictUser{}) AssertEqual(t, err, nil) - u := OnConfilctUser{ - ID: "on-confilct-user-id", - Name: "on-confilct-user-name", - Email: "on-confilct-user-email", - Mobile: "on-confilct-user-mobile", + u := OnConflictUser{ + ID: "on-conflict-user-id", + Name: "on-conflict-user-name", + Email: "on-conflict-user-email", + Mobile: "on-conflict-user-mobile", } err = DB.Create(&u).Error AssertEqual(t, err, nil) - u.Name = "on-confilct-user-name-2" - u.Email = "on-confilct-user-email-2" + u.Name = "on-conflict-user-name-2" + u.Email = "on-conflict-user-email-2" u.Mobile = "" err = DB.Clauses(clause.OnConflict{UpdateAll: true}).Create(&u).Error AssertEqual(t, err, nil) - var u2 OnConfilctUser + var u2 OnConflictUser err = DB.Where("id = ?", u.ID).First(&u2).Error AssertEqual(t, err, nil) - AssertEqual(t, u2.Name, "on-confilct-user-name-2") - AssertEqual(t, u2.Email, "on-confilct-user-email-2") + AssertEqual(t, u2.Name, "on-conflict-user-name-2") + AssertEqual(t, u2.Email, "on-conflict-user-email-2") AssertEqual(t, u2.Mobile, "133xxxx") } func TestCreateFromMapWithoutPK(t *testing.T) { if !isMysql() { - t.Skipf("This test case skipped, because of only supportting for mysql") + t.Skipf("This test case skipped, because of only supporting for mysql") } // case 1: one record, create from map[string]interface{} diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 837d92c1b..b25b9da64 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -1413,10 +1413,10 @@ func TestMigrateSameEmbeddedFieldName(t *testing.T) { err = DB.Table("game_users").AutoMigrate(&GameUser1{}) AssertEqual(t, nil, err) - _, err = findColumnType(&GameUser{}, "stat_ab_ground_destory_count") + _, err = findColumnType(&GameUser{}, "stat_ab_ground_destroy_count") AssertEqual(t, nil, err) - _, err = findColumnType(&GameUser{}, "rate_ground_rb_ground_destory_count") + _, err = findColumnType(&GameUser{}, "rate_ground_rb_ground_destroy_count") AssertEqual(t, nil, err) } diff --git a/tests/preload_test.go b/tests/preload_test.go index 26b08d7de..14f94139d 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -8,6 +8,8 @@ import ( "sync" "testing" + "github.com/stretchr/testify/require" + "gorm.io/gorm" "gorm.io/gorm/clause" . "gorm.io/gorm/utils/tests" @@ -362,6 +364,14 @@ func TestNestedPreloadWithNestedJoin(t *testing.T) { t.Errorf("failed to find value, got err: %v", err) } AssertEqual(t, find2, value) + + var finds []Value + err = DB.Joins("Nested.Join").Joins("Nested").Preload("Nested.Preloads").Find(&finds).Error + if err != nil { + t.Errorf("failed to find value, got err: %v", err) + } + require.Len(t, finds, 1) + AssertEqual(t, finds[0], value) } func TestEmbedPreload(t *testing.T) { diff --git a/tests/query_test.go b/tests/query_test.go index cadf7164e..e780e3bfe 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -554,6 +554,11 @@ func TestNot(t *testing.T) { if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*users.*..*name.* <> .+ AND .*users.*..*age.* <> .+").MatchString(result.Statement.SQL.String()) { t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) } + + result = dryDB.Not(DB.Where("manager IS NULL").Where("age >= ?", 20)).Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE NOT \\(manager IS NULL AND age >= .+\\) AND .users.\\..deleted_at. IS NULL").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } } func TestNotWithAllFields(t *testing.T) { diff --git a/tests/table_test.go b/tests/table_test.go index fa569d320..0d44a15be 100644 --- a/tests/table_test.go +++ b/tests/table_test.go @@ -2,8 +2,10 @@ package tests_test import ( "regexp" + "sync" "testing" + "gorm.io/driver/postgres" "gorm.io/gorm" "gorm.io/gorm/schema" "gorm.io/gorm/utils/tests" @@ -172,3 +174,88 @@ func TestTableWithNamer(t *testing.T) { t.Errorf("Table with namer, got %v", sql) } } + +func TestPostgresTableWithIdentifierLength(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + return + } + + type LongString struct { + ThisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString string `gorm:"unique"` + } + + t.Run("default", func(t *testing.T) { + db, _ := gorm.Open(postgres.Open(postgresDSN), &gorm.Config{}) + user, err := schema.Parse(&LongString{}, &sync.Map{}, db.Config.NamingStrategy) + if err != nil { + t.Fatalf("failed to parse user unique, got error %v", err) + } + + constraints := user.ParseUniqueConstraints() + if len(constraints) != 1 { + t.Fatalf("failed to find unique constraint, got %v", constraints) + } + + for key := range constraints { + if len(key) != 63 { + t.Errorf("failed to find unique constraint, got %v", constraints) + } + } + }) + + t.Run("naming strategy", func(t *testing.T) { + db, _ := gorm.Open(postgres.Open(postgresDSN), &gorm.Config{ + NamingStrategy: schema.NamingStrategy{}, + }) + + user, err := schema.Parse(&LongString{}, &sync.Map{}, db.Config.NamingStrategy) + if err != nil { + t.Fatalf("failed to parse user unique, got error %v", err) + } + + constraints := user.ParseUniqueConstraints() + if len(constraints) != 1 { + t.Fatalf("failed to find unique constraint, got %v", constraints) + } + + for key := range constraints { + if len(key) != 63 { + t.Errorf("failed to find unique constraint, got %v", constraints) + } + } + }) + + t.Run("namer", func(t *testing.T) { + uname := "custom_unique_name" + db, _ := gorm.Open(postgres.Open(postgresDSN), &gorm.Config{ + NamingStrategy: mockUniqueNamingStrategy{ + UName: uname, + }, + }) + + user, err := schema.Parse(&LongString{}, &sync.Map{}, db.Config.NamingStrategy) + if err != nil { + t.Fatalf("failed to parse user unique, got error %v", err) + } + + constraints := user.ParseUniqueConstraints() + if len(constraints) != 1 { + t.Fatalf("failed to find unique constraint, got %v", constraints) + } + + for key := range constraints { + if key != uname { + t.Errorf("failed to find unique constraint, got %v", constraints) + } + } + }) +} + +type mockUniqueNamingStrategy struct { + UName string + schema.NamingStrategy +} + +func (a mockUniqueNamingStrategy) UniqueName(table, column string) string { + return a.UName +} diff --git a/utils/utils.go b/utils/utils.go index a4d8ac250..347a331fb 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -74,7 +74,11 @@ func ToStringKey(values ...interface{}) string { case uint: results[idx] = strconv.FormatUint(uint64(v), 10) default: - results[idx] = fmt.Sprint(reflect.Indirect(reflect.ValueOf(v)).Interface()) + results[idx] = "nil" + vv := reflect.ValueOf(v) + if vv.IsValid() && !vv.IsZero() { + results[idx] = fmt.Sprint(reflect.Indirect(vv).Interface()) + } } } diff --git a/utils/utils_test.go b/utils/utils_test.go index d0486822c..8ff42af8d 100644 --- a/utils/utils_test.go +++ b/utils/utils_test.go @@ -48,8 +48,10 @@ func TestToStringKey(t *testing.T) { }{ {[]interface{}{"a"}, "a"}, {[]interface{}{1, 2, 3}, "1_2_3"}, + {[]interface{}{1, nil, 3}, "1_nil_3"}, {[]interface{}{[]interface{}{1, 2, 3}}, "[1 2 3]"}, {[]interface{}{[]interface{}{"1", "2", "3"}}, "[1 2 3]"}, + {[]interface{}{[]interface{}{"1", nil, "3"}}, "[1 3]"}, } for _, c := range cases { if key := ToStringKey(c.values...); key != c.key {