Skip to content

Commit

Permalink
Merge branch 'go-gorm:master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
simoneserra93 authored Mar 14, 2024
2 parents 5541a3e + e4e23d2 commit 5b76cc4
Show file tree
Hide file tree
Showing 13 changed files with 371 additions and 80 deletions.
157 changes: 157 additions & 0 deletions callbacks/helper_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
package callbacks

import (
"reflect"
"testing"

"gorm.io/gorm"
"gorm.io/gorm/clause"
)

func TestLoadOrStoreVisitMap(t *testing.T) {
var vm visitMap
var loaded bool
type testM struct {
Name string
}

t1 := testM{Name: "t1"}
t2 := testM{Name: "t2"}
t3 := testM{Name: "t3"}

vm = make(visitMap)
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); loaded {
t.Fatalf("loaded should be false")
}

if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); !loaded {
t.Fatalf("loaded should be true")
}

// t1 already exist but t2 not
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t1, &t2, &t3})); loaded {
t.Fatalf("loaded should be false")
}

if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t2, &t3})); !loaded {
t.Fatalf("loaded should be true")
}
}

func TestConvertMapToValuesForCreate(t *testing.T) {
testCase := []struct {
name string
input map[string]interface{}
expect clause.Values
}{
{
name: "Test convert string value",
input: map[string]interface{}{
"name": "my name",
},
expect: clause.Values{
Columns: []clause.Column{{Name: "name"}},
Values: [][]interface{}{{"my name"}},
},
},
{
name: "Test convert int value",
input: map[string]interface{}{
"age": 18,
},
expect: clause.Values{
Columns: []clause.Column{{Name: "age"}},
Values: [][]interface{}{{18}},
},
},
{
name: "Test convert float value",
input: map[string]interface{}{
"score": 99.5,
},
expect: clause.Values{
Columns: []clause.Column{{Name: "score"}},
Values: [][]interface{}{{99.5}},
},
},
{
name: "Test convert bool value",
input: map[string]interface{}{
"active": true,
},
expect: clause.Values{
Columns: []clause.Column{{Name: "active"}},
Values: [][]interface{}{{true}},
},
},
}

for _, tc := range testCase {
t.Run(tc.name, func(t *testing.T) {
actual := ConvertMapToValuesForCreate(&gorm.Statement{}, tc.input)
if !reflect.DeepEqual(actual, tc.expect) {
t.Errorf("expect %v got %v", tc.expect, actual)
}
})
}
}

func TestConvertSliceOfMapToValuesForCreate(t *testing.T) {
testCase := []struct {
name string
input []map[string]interface{}
expect clause.Values
}{
{
name: "Test convert slice of string value",
input: []map[string]interface{}{
{"name": "my name"},
},
expect: clause.Values{
Columns: []clause.Column{{Name: "name"}},
Values: [][]interface{}{{"my name"}},
},
},
{
name: "Test convert slice of int value",
input: []map[string]interface{}{
{"age": 18},
},
expect: clause.Values{
Columns: []clause.Column{{Name: "age"}},
Values: [][]interface{}{{18}},
},
},
{
name: "Test convert slice of float value",
input: []map[string]interface{}{
{"score": 99.5},
},
expect: clause.Values{
Columns: []clause.Column{{Name: "score"}},
Values: [][]interface{}{{99.5}},
},
},
{
name: "Test convert slice of bool value",
input: []map[string]interface{}{
{"active": true},
},
expect: clause.Values{
Columns: []clause.Column{{Name: "active"}},
Values: [][]interface{}{{true}},
},
},
}

for _, tc := range testCase {
t.Run(tc.name, func(t *testing.T) {
actual := ConvertSliceOfMapToValuesForCreate(&gorm.Statement{}, tc.input)

if !reflect.DeepEqual(actual, tc.expect) {
t.Errorf("expected %v but got %v", tc.expect, actual)
}
})
}

}
21 changes: 17 additions & 4 deletions callbacks/preload.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,23 @@ func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relati
}
} else if rel := relationships.Relations[name]; rel != nil {
if joined, nestedJoins := isJoined(name); joined {
reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue)
tx := preloadDB(db, reflectValue, reflectValue.Interface())
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
return err
switch rv := db.Statement.ReflectValue; rv.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < rv.Len(); i++ {
reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv.Index(i))
tx := preloadDB(db, reflectValue, reflectValue.Interface())
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
return err
}
}
case reflect.Struct:
reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv)
tx := preloadDB(db, reflectValue, reflectValue.Interface())
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
return err
}
default:
return gorm.ErrInvalidData
}
} else {
tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks})
Expand Down
36 changes: 0 additions & 36 deletions callbacks/visit_map_test.go

This file was deleted.

73 changes: 56 additions & 17 deletions clause/where.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ func (where Where) Name() string {

// Build build where clause
func (where Where) Build(builder Builder) {
if len(where.Exprs) == 1 {
if andCondition, ok := where.Exprs[0].(AndConditions); ok {
where.Exprs = andCondition.Exprs
}
}
if len(where.Exprs) == 1 {
if andCondition, ok := where.Exprs[0].(AndConditions); ok {
where.Exprs = andCondition.Exprs
}
}

// Switch position if the first query expression is a single Or condition
for idx, expr := range where.Exprs {
Expand Down Expand Up @@ -166,19 +166,58 @@ type NotConditions struct {
}

func (not NotConditions) Build(builder Builder) {
if len(not.Exprs) > 1 {
builder.WriteByte('(')
anyNegationBuilder := false
for _, c := range not.Exprs {
if _, ok := c.(NegationExpressionBuilder); ok {
anyNegationBuilder = true
break
}
}

for idx, c := range not.Exprs {
if idx > 0 {
builder.WriteString(AndWithSpace)
if anyNegationBuilder {
if len(not.Exprs) > 1 {
builder.WriteByte('(')
}

if negationBuilder, ok := c.(NegationExpressionBuilder); ok {
negationBuilder.NegationBuild(builder)
} else {
builder.WriteString("NOT ")
for idx, c := range not.Exprs {
if idx > 0 {
builder.WriteString(AndWithSpace)
}

if negationBuilder, ok := c.(NegationExpressionBuilder); ok {
negationBuilder.NegationBuild(builder)
} else {
builder.WriteString("NOT ")
e, wrapInParentheses := c.(Expr)
if wrapInParentheses {
sql := strings.ToUpper(e.SQL)
if wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace); wrapInParentheses {
builder.WriteByte('(')
}
}

c.Build(builder)

if wrapInParentheses {
builder.WriteByte(')')
}
}
}

if len(not.Exprs) > 1 {
builder.WriteByte(')')
}
} else {
builder.WriteString("NOT ")
if len(not.Exprs) > 1 {
builder.WriteByte('(')
}

for idx, c := range not.Exprs {
if idx > 0 {
builder.WriteString(AndWithSpace)
}

e, wrapInParentheses := c.(Expr)
if wrapInParentheses {
sql := strings.ToUpper(e.SQL)
Expand All @@ -193,9 +232,9 @@ func (not NotConditions) Build(builder Builder) {
builder.WriteByte(')')
}
}
}

if len(not.Exprs) > 1 {
builder.WriteByte(')')
if len(not.Exprs) > 1 {
builder.WriteByte(')')
}
}
}
8 changes: 8 additions & 0 deletions clause/where_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,14 @@ func TestWhere(t *testing.T) {
"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND NOT `score` <= ?)",
[]interface{}{"1", 100},
},
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
Exprs: []clause.Expression{clause.Not(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}},
clause.Expr{SQL: "`age` <= ?", Vars: []interface{}{60}})},
}},
"SELECT * FROM `users` WHERE NOT (`score` <= ? AND `age` <= ?)",
[]interface{}{100, 60},
},
}

for idx, result := range results {
Expand Down
14 changes: 8 additions & 6 deletions scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,12 +274,14 @@ func Scan(rows Rows, db *DB, mode ScanMode) {

if !update || reflectValue.Len() == 0 {
update = false
// if the slice cap is externally initialized, the externally initialized slice is directly used here
if reflectValue.Cap() == 0 {
db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
} else if !isArrayKind {
reflectValue.SetLen(0)
db.Statement.ReflectValue.Set(reflectValue)
if !isArrayKind {
// if the slice cap is externally initialized, the externally initialized slice is directly used here
if reflectValue.Cap() == 0 {
db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
} else {
reflectValue.SetLen(0)
db.Statement.ReflectValue.Set(reflectValue)
}
}
}

Expand Down
Loading

0 comments on commit 5b76cc4

Please sign in to comment.