Skip to content

Commit

Permalink
Merge branch 'go-gorm:master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
simoneserra93 authored Apr 12, 2024
2 parents 5b76cc4 + 1e13fd7 commit c06ae47
Show file tree
Hide file tree
Showing 21 changed files with 328 additions and 95 deletions.
19 changes: 19 additions & 0 deletions callbacks.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,10 +187,18 @@ func (p *processor) Replace(name string, fn func(*DB)) error {

func (p *processor) compile() (err error) {
var callbacks []*callback
removedMap := map[string]bool{}
for _, callback := range p.callbacks {
if callback.match == nil || callback.match(p.db) {
callbacks = append(callbacks, callback)
}
if callback.remove {
removedMap[callback.name] = true
}
}

if len(removedMap) > 0 {
callbacks = removeCallbacks(callbacks, removedMap)
}
p.callbacks = callbacks

Expand Down Expand Up @@ -339,3 +347,14 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {

return
}

func removeCallbacks(cs []*callback, nameMap map[string]bool) []*callback {
callbacks := make([]*callback, 0, len(cs))
for _, callback := range cs {
if nameMap[callback.name] {
continue
}
callbacks = append(callbacks, callback)
}
return callbacks
}
43 changes: 27 additions & 16 deletions callbacks/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,17 @@ func Create(config *Config) func(db *gorm.DB) {
pkField *schema.Field
pkFieldName = "@id"
)

insertID, err := result.LastInsertId()
insertOk := err == nil && insertID > 0

if !insertOk {
if !supportReturning {
db.AddError(err)
}
return
}

if db.Statement.Schema != nil {
if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
return
Expand All @@ -119,13 +130,6 @@ func Create(config *Config) func(db *gorm.DB) {
pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName
}

insertID, err := result.LastInsertId()
insertOk := err == nil && insertID > 0
if !insertOk {
db.AddError(err)
return
}

// append @id column with value for auto-increment primary key
// the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1
switch values := db.Statement.Dest.(type) {
Expand All @@ -142,6 +146,11 @@ func Create(config *Config) func(db *gorm.DB) {
}
}
}

if config.LastInsertIDReversed {
insertID -= int64(len(mapValues)-1) * schema.DefaultAutoIncrementIncrement
}

for _, mapValue := range mapValues {
if mapValue != nil {
mapValue[pkFieldName] = insertID
Expand Down Expand Up @@ -293,13 +302,15 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
}
}

for field, vs := range defaultValueFieldsHavingValue {
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
for idx := range values.Values {
if vs[idx] == nil {
values.Values[idx] = append(values.Values[idx], stmt.Dialector.DefaultValueOf(field))
} else {
values.Values[idx] = append(values.Values[idx], vs[idx])
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
if vs, ok := defaultValueFieldsHavingValue[field]; ok {
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
for idx := range values.Values {
if vs[idx] == nil {
values.Values[idx] = append(values.Values[idx], stmt.DefaultValueOf(field))
} else {
values.Values[idx] = append(values.Values[idx], vs[idx])
}
}
}
}
Expand All @@ -322,7 +333,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
}

for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) && field.DefaultValueInterface == nil {
if rvOfvalue, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero {
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
values.Values[0] = append(values.Values[0], rvOfvalue)
Expand Down Expand Up @@ -351,7 +362,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
case schema.UnixNanosecond:
assignment.Value = curTime.UnixNano()
case schema.UnixMillisecond:
assignment.Value = curTime.UnixNano() / 1e6
assignment.Value = curTime.UnixMilli()
case schema.UnixSecond:
assignment.Value = curTime.Unix()
}
Expand Down
71 changes: 71 additions & 0 deletions callbacks/create_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package callbacks

import (
"reflect"
"sync"
"testing"
"time"

"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
)

var schemaCache = &sync.Map{}

func TestConvertToCreateValues_DestType_Slice(t *testing.T) {
type user struct {
ID int `gorm:"primaryKey"`
Name string
Email string `gorm:"default:(-)"`
Age int `gorm:"default:(-)"`
}

s, err := schema.Parse(&user{}, schemaCache, schema.NamingStrategy{})
if err != nil {
t.Errorf("parse schema error: %v, is not expected", err)
return
}
dest := []*user{
{
ID: 1,
Name: "alice",
Email: "email",
Age: 18,
},
{
ID: 2,
Name: "bob",
Email: "email",
Age: 19,
},
}
stmt := &gorm.Statement{
DB: &gorm.DB{
Config: &gorm.Config{
NowFunc: func() time.Time { return time.Time{} },
},
Statement: &gorm.Statement{
Settings: sync.Map{},
Schema: s,
},
},
ReflectValue: reflect.ValueOf(dest),
Dest: dest,
}

stmt.Schema = s

values := ConvertToCreateValues(stmt)
expected := clause.Values{
// column has value + defaultValue column has value (which should have a stable order)
Columns: []clause.Column{{Name: "name"}, {Name: "email"}, {Name: "age"}, {Name: "id"}},
Values: [][]interface{}{
{"alice", "email", 18, 1},
{"bob", "email", 19, 2},
},
}
if !reflect.DeepEqual(expected, values) {
t.Errorf("expected: %v got %v", expected, values)
}
}
4 changes: 2 additions & 2 deletions callbacks/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
if field.AutoUpdateTime == schema.UnixNanosecond {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()})
} else if field.AutoUpdateTime == schema.UnixMillisecond {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6})
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixMilli()})
} else if field.AutoUpdateTime == schema.UnixSecond {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()})
} else {
Expand Down Expand Up @@ -268,7 +268,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
if field.AutoUpdateTime == schema.UnixNanosecond {
value = stmt.DB.NowFunc().UnixNano()
} else if field.AutoUpdateTime == schema.UnixMillisecond {
value = stmt.DB.NowFunc().UnixNano() / 1e6
value = stmt.DB.NowFunc().UnixMilli()
} else if field.AutoUpdateTime == schema.UnixSecond {
value = stmt.DB.NowFunc().Unix()
} else {
Expand Down
19 changes: 19 additions & 0 deletions logger/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,19 @@ var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeO
// RegEx matches only numeric values
var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`)

func isNumeric(k reflect.Kind) bool {
switch k {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return true
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return true
case reflect.Float32, reflect.Float64:
return true
default:
return false
}
}

// ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability
func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string {
var (
Expand Down Expand Up @@ -110,6 +123,12 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
convertParams(v, idx)
} else if rv.Kind() == reflect.Ptr && !rv.IsZero() {
convertParams(reflect.Indirect(rv).Interface(), idx)
} else if isNumeric(rv.Kind()) {
if rv.CanInt() || rv.CanUint() {
vars[idx] = fmt.Sprintf("%d", rv.Interface())
} else {
vars[idx] = fmt.Sprintf("%.6f", rv.Interface())
}
} else {
for _, t := range convertibleTypes {
if rv.Type().ConvertibleTo(t) {
Expand Down
30 changes: 23 additions & 7 deletions logger/sql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,18 @@ func format(v []byte, escaper string) string {
func TestExplainSQL(t *testing.T) {
type role string
type password []byte
type intType int
type floatType float64
var (
tt = now.MustParse("2020-02-23 11:10:10")
myrole = role("admin")
pwd = password("pass")
jsVal = []byte(`{"Name":"test","Val":"test"}`)
js = JSON(jsVal)
esVal = []byte(`{"Name":"test","Val":"test"}`)
es = ExampleStruct{Name: "test", Val: "test"}
tt = now.MustParse("2020-02-23 11:10:10")
myrole = role("admin")
pwd = password("pass")
jsVal = []byte(`{"Name":"test","Val":"test"}`)
js = JSON(jsVal)
esVal = []byte(`{"Name":"test","Val":"test"}`)
es = ExampleStruct{Name: "test", Val: "test"}
intVal intType = 1
floatVal floatType = 1.23
)

results := []struct {
Expand Down Expand Up @@ -107,6 +111,18 @@ func TestExplainSQL(t *testing.T) {
Vars: []interface{}{"jinzhu", 1, float32(999.99), true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, int_val) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, intVal},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, int_val) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", 1)`,
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, float_val) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, floatVal},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, float_val) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", 1.230000)`,
},
}

for idx, r := range results {
Expand Down
19 changes: 13 additions & 6 deletions migrator/migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"reflect"
"regexp"
"strconv"
"strings"
"time"

Expand Down Expand Up @@ -518,12 +519,18 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
} else if !dvNotNull && currentDefaultNotNull {
// null -> default value
alterColumn = true
} else if (field.GORMDataType != schema.Time && dv != field.DefaultValue) ||
(field.GORMDataType == schema.Time && !strings.EqualFold(strings.TrimSuffix(dv, "()"), strings.TrimSuffix(field.DefaultValue, "()"))) {
// default value not equal
// not both null
if currentDefaultNotNull || dvNotNull {
alterColumn = true
} else if currentDefaultNotNull || dvNotNull {
switch field.GORMDataType {
case schema.Time:
if !strings.EqualFold(strings.TrimSuffix(dv, "()"), strings.TrimSuffix(field.DefaultValue, "()")) {
alterColumn = true
}
case schema.Bool:
v1, _ := strconv.ParseBool(dv)
v2, _ := strconv.ParseBool(field.DefaultValue)
alterColumn = v1 != v2
default:
alterColumn = dv != field.DefaultValue
}
}
}
Expand Down
26 changes: 22 additions & 4 deletions prepare_stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package gorm
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"reflect"
"sync"
)
Expand Down Expand Up @@ -147,7 +149,7 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ..
stmt, err := db.prepare(ctx, db.ConnPool, false, query)
if err == nil {
result, err = stmt.ExecContext(ctx, args...)
if err != nil {
if errors.Is(err, driver.ErrBadConn) {
db.Mux.Lock()
defer db.Mux.Unlock()
go stmt.Close()
Expand All @@ -161,7 +163,7 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args .
stmt, err := db.prepare(ctx, db.ConnPool, false, query)
if err == nil {
rows, err = stmt.QueryContext(ctx, args...)
if err != nil {
if errors.Is(err, driver.ErrBadConn) {
db.Mux.Lock()
defer db.Mux.Unlock()

Expand All @@ -180,6 +182,14 @@ func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, arg
return &sql.Row{}
}

func (db *PreparedStmtDB) Ping() error {
conn, err := db.GetDBConn()
if err != nil {
return err
}
return conn.Ping()
}

type PreparedStmtTX struct {
Tx
PreparedStmtDB *PreparedStmtDB
Expand Down Expand Up @@ -207,7 +217,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
if err == nil {
result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...)
if err != nil {
if errors.Is(err, driver.ErrBadConn) {
tx.PreparedStmtDB.Mux.Lock()
defer tx.PreparedStmtDB.Mux.Unlock()

Expand All @@ -222,7 +232,7 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args .
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
if err == nil {
rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...)
if err != nil {
if errors.Is(err, driver.ErrBadConn) {
tx.PreparedStmtDB.Mux.Lock()
defer tx.PreparedStmtDB.Mux.Unlock()

Expand All @@ -240,3 +250,11 @@ func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, arg
}
return &sql.Row{}
}

func (tx *PreparedStmtTX) Ping() error {
conn, err := tx.GetDBConn()
if err != nil {
return err
}
return conn.Ping()
}
Loading

0 comments on commit c06ae47

Please sign in to comment.