Skip to content

Commit

Permalink
Merge branch 'go-gorm:master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
emaele authored Feb 7, 2024
2 parents f926e37 + 8fb9a31 commit 5541a3e
Show file tree
Hide file tree
Showing 48 changed files with 1,496 additions and 372 deletions.
12 changes: 6 additions & 6 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ jobs:
uses: actions/checkout@v4

- name: go mod package cache
uses: actions/cache@v3
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
Expand Down Expand Up @@ -73,7 +73,7 @@ jobs:
uses: actions/checkout@v4

- name: go mod package cache
uses: actions/cache@v3
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
Expand Down Expand Up @@ -116,7 +116,7 @@ jobs:
uses: actions/checkout@v4

- name: go mod package cache
uses: actions/cache@v3
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
Expand Down Expand Up @@ -159,7 +159,7 @@ jobs:
uses: actions/checkout@v4

- name: go mod package cache
uses: actions/cache@v3
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
Expand Down Expand Up @@ -202,7 +202,7 @@ jobs:
uses: actions/checkout@v4

- name: go mod package cache
uses: actions/cache@v3
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
Expand Down Expand Up @@ -235,7 +235,7 @@ jobs:


- name: go mod package cache
uses: actions/cache@v3
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,4 @@ The fantastic ORM library for Golang, aims to be developer friendly.

© Jinzhu, 2013~time.Now

Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/License)
Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/LICENSE)
70 changes: 55 additions & 15 deletions callbacks/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,53 @@ func Create(config *Config) func(db *gorm.DB) {
}

db.RowsAffected, _ = result.RowsAffected()
if db.RowsAffected != 0 && db.Statement.Schema != nil &&
db.Statement.Schema.PrioritizedPrimaryField != nil &&
db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
insertID, err := result.LastInsertId()
insertOk := err == nil && insertID > 0
if !insertOk {
db.AddError(err)
if db.RowsAffected == 0 {
return
}

var (
pkField *schema.Field
pkFieldName = "@id"
)
if db.Statement.Schema != nil {
if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
return
}
pkField = db.Statement.Schema.PrioritizedPrimaryField
pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName
}

insertID, err := result.LastInsertId()
insertOk := err == nil && insertID > 0
if !insertOk {
db.AddError(err)
return
}

// append @id column with value for auto-increment primary key
// the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1
switch values := db.Statement.Dest.(type) {
case map[string]interface{}:
values[pkFieldName] = insertID
case *map[string]interface{}:
(*values)[pkFieldName] = insertID
case []map[string]interface{}, *[]map[string]interface{}:
mapValues, ok := values.([]map[string]interface{})
if !ok {
if v, ok := values.(*[]map[string]interface{}); ok {
if *v != nil {
mapValues = *v
}
}
}
for _, mapValue := range mapValues {
if mapValue != nil {
mapValue[pkFieldName] = insertID
}
insertID += schema.DefaultAutoIncrementIncrement
}
default:
if pkField == nil {
return
}

Expand All @@ -122,10 +162,10 @@ func Create(config *Config) func(db *gorm.DB) {
break
}

_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv)
_, isZero := pkField.ValueOf(db.Statement.Context, rv)
if isZero {
db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID))
insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
insertID -= pkField.AutoIncrementIncrement
}
}
} else {
Expand All @@ -135,16 +175,16 @@ func Create(config *Config) func(db *gorm.DB) {
break
}

if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv); isZero {
db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID))
insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
if _, isZero := pkField.ValueOf(db.Statement.Context, rv); isZero {
db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
insertID += pkField.AutoIncrementIncrement
}
}
}
case reflect.Struct:
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
_, isZero := pkField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
if isZero {
db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))
db.AddError(pkField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))
}
}
}
Expand Down
74 changes: 64 additions & 10 deletions callbacks/preload.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package callbacks
import (
"fmt"
"reflect"
"sort"
"strings"

"gorm.io/gorm"
Expand Down Expand Up @@ -82,27 +83,80 @@ func embeddedValues(embeddedRelations *schema.Relationships) []string {
return names
}

func preloadEmbedded(tx *gorm.DB, relationships *schema.Relationships, s *schema.Schema, preloads map[string][]interface{}, as []interface{}) error {
if relationships == nil {
return nil
// preloadEntryPoint enters layer by layer. It will call real preload if it finds the right entry point.
// If the current relationship is embedded or joined, current query will be ignored.
//
//nolint:cyclop
func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}) error {
preloadMap := parsePreloadMap(db.Statement.Schema, preloads)

// avoid random traversal of the map
preloadNames := make([]string, 0, len(preloadMap))
for key := range preloadMap {
preloadNames = append(preloadNames, key)
}
sort.Strings(preloadNames)

isJoined := func(name string) (joined bool, nestedJoins []string) {
for _, join := range joins {
if _, ok := relationships.Relations[join]; ok && name == join {
joined = true
continue
}
joinNames := strings.SplitN(join, ".", 2)
if len(joinNames) == 2 {
if _, ok := relationships.Relations[joinNames[0]]; ok && name == joinNames[0] {
joined = true
nestedJoins = append(nestedJoins, joinNames[1])
}
}
}
return joined, nestedJoins
}
preloadMap := parsePreloadMap(s, preloads)
for name := range preloadMap {
if embeddedRelations := relationships.EmbeddedRelations[name]; embeddedRelations != nil {
if err := preloadEmbedded(tx, embeddedRelations, s, preloadMap[name], as); err != nil {

for _, name := range preloadNames {
if relations := relationships.EmbeddedRelations[name]; relations != nil {
if err := preloadEntryPoint(db, joins, relations, preloadMap[name], associationsConds); err != nil {
return err
}
} else if rel := relationships.Relations[name]; rel != nil {
if err := preload(tx, rel, append(preloads[name], as), preloadMap[name]); err != nil {
return err
if joined, nestedJoins := isJoined(name); joined {
reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue)
tx := preloadDB(db, reflectValue, reflectValue.Interface())
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
return err
}
} else {
tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks})
tx.Statement.ReflectValue = db.Statement.ReflectValue
tx.Statement.Unscoped = db.Statement.Unscoped
if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name]); err != nil {
return err
}
}
} else {
return fmt.Errorf("%s: %w (embedded) for schema %s", name, gorm.ErrUnsupportedRelation, s.Name)
return fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)
}
}
return nil
}

func preloadDB(db *gorm.DB, reflectValue reflect.Value, dest interface{}) *gorm.DB {
tx := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true})
db.Statement.Settings.Range(func(k, v interface{}) bool {
tx.Statement.Settings.Store(k, v)
return true
})

if err := tx.Statement.Parse(dest); err != nil {
tx.AddError(err)
return tx
}
tx.Statement.ReflectValue = reflectValue
tx.Statement.Unscoped = db.Statement.Unscoped
return tx
}

func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error {
var (
reflectValue = tx.Statement.ReflectValue
Expand Down
35 changes: 9 additions & 26 deletions callbacks/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package callbacks
import (
"fmt"
"reflect"
"sort"
"strings"

"gorm.io/gorm"
Expand Down Expand Up @@ -254,7 +253,6 @@ func BuildQuerySQL(db *gorm.DB) {
}

db.Statement.AddClause(fromClause)
db.Statement.Joins = nil
} else {
db.Statement.AddClauseIfNotExists(clause.From{})
}
Expand All @@ -272,38 +270,23 @@ func Preload(db *gorm.DB) {
return
}

preloadMap := parsePreloadMap(db.Statement.Schema, db.Statement.Preloads)
preloadNames := make([]string, 0, len(preloadMap))
for key := range preloadMap {
preloadNames = append(preloadNames, key)
joins := make([]string, 0, len(db.Statement.Joins))
for _, join := range db.Statement.Joins {
joins = append(joins, join.Name)
}
sort.Strings(preloadNames)

preloadDB := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true})
db.Statement.Settings.Range(func(k, v interface{}) bool {
preloadDB.Statement.Settings.Store(k, v)
return true
})

if err := preloadDB.Statement.Parse(db.Statement.Dest); err != nil {
tx := preloadDB(db, db.Statement.ReflectValue, db.Statement.Dest)
if tx.Error != nil {
return
}
preloadDB.Statement.ReflectValue = db.Statement.ReflectValue
preloadDB.Statement.Unscoped = db.Statement.Unscoped

for _, name := range preloadNames {
if relations := preloadDB.Statement.Schema.Relationships.EmbeddedRelations[name]; relations != nil {
db.AddError(preloadEmbedded(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), relations, db.Statement.Schema, preloadMap[name], db.Statement.Preloads[clause.Associations]))
} else if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil {
db.AddError(preload(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name]))
} else {
db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name))
}
}

db.AddError(preloadEntryPoint(tx, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations]))
}
}

func AfterQuery(db *gorm.DB) {
// clear the joins after query because preload need it
db.Statement.Joins = nil
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 {
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
if i, ok := value.(AfterFindInterface); ok {
Expand Down
27 changes: 3 additions & 24 deletions chainable_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -367,33 +367,12 @@ func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) {
}

func (db *DB) executeScopes() (tx *DB) {
tx = db.getInstance()
scopes := db.Statement.scopes
if len(scopes) == 0 {
return tx
}
tx.Statement.scopes = nil

conditions := make([]clause.Interface, 0, 4)
if cs, ok := tx.Statement.Clauses["WHERE"]; ok && cs.Expression != nil {
conditions = append(conditions, cs.Expression.(clause.Interface))
cs.Expression = nil
tx.Statement.Clauses["WHERE"] = cs
}

db.Statement.scopes = nil
for _, scope := range scopes {
tx = scope(tx)
if cs, ok := tx.Statement.Clauses["WHERE"]; ok && cs.Expression != nil {
conditions = append(conditions, cs.Expression.(clause.Interface))
cs.Expression = nil
tx.Statement.Clauses["WHERE"] = cs
}
}

for _, condition := range conditions {
tx.Statement.AddClause(condition)
db = scope(db)
}
return tx
return db
}

// Preload preload associations with given conditions
Expand Down
6 changes: 2 additions & 4 deletions clause/limit.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package clause

import "strconv"

// Limit limit clause
type Limit struct {
Limit *int
Expand All @@ -17,14 +15,14 @@ func (limit Limit) Name() string {
func (limit Limit) Build(builder Builder) {
if limit.Limit != nil && *limit.Limit >= 0 {
builder.WriteString("LIMIT ")
builder.WriteString(strconv.Itoa(*limit.Limit))
builder.AddVar(builder, *limit.Limit)
}
if limit.Offset > 0 {
if limit.Limit != nil && *limit.Limit >= 0 {
builder.WriteByte(' ')
}
builder.WriteString("OFFSET ")
builder.WriteString(strconv.Itoa(limit.Offset))
builder.AddVar(builder, limit.Offset)
}
}

Expand Down
Loading

0 comments on commit 5541a3e

Please sign in to comment.