-
Notifications
You must be signed in to change notification settings - Fork 0
/
core.go
131 lines (118 loc) · 3.01 KB
/
core.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
package orm
import (
	"context"
	"database/sql"
	"fmt"

	"orm/internal/valuer"
	"orm/model"
)
// core is a small bundle of the dependencies that every CRUD operation needs,
// grouped together so they can be passed around as one value.
type core struct {
	// dbName names the database; its use is not visible in this file — TODO confirm.
	dbName string
	// r is the model metadata registry (presumably the source of
	// QueryContext.Meta — verify against callers).
	r model.Registry
	// dialect is the SQL dialect; it is not referenced by the helpers in this file.
	dialect Dialect
	// ms are the middlewares wrapped around every query/exec handler;
	// ms[0] ends up as the outermost wrapper.
	ms []Middleware
	// valCreator builds the per-row value accessor used to scan columns into a
	// freshly allocated *T. The implementation (e.g. reflect- or unsafe-based)
	// can be swapped here without touching the query helpers.
	valCreator valuer.BasicTypeCreator
}
// getMultiHandler is the terminal (innermost) handler for multi-row queries.
// It builds the SQL via qc.Query, executes it on sess, and scans every returned
// row into a freshly allocated *T using the value accessor produced by
// c.valCreator.
//
// On success Result holds a []*T (possibly empty) and Err is nil. Any error from
// building the query, executing it, scanning a row, or iterating the result set
// is reported through Err with a nil Result.
func getMultiHandler[T any](ctx context.Context, c core,
	sess session, qc *QueryContext) *QueryResult {
	q, err := qc.Query()
	if err != nil {
		return &QueryResult{Err: err}
	}
	rows, err := sess.queryContext(ctx, q.SQL, q.Args...)
	if err != nil {
		return &QueryResult{Err: err}
	}
	// Always release the rows so the underlying pooled connection is returned,
	// including when scanning fails part-way through the result set.
	defer rows.Close()

	res := make([]*T, 0, 16)
	for rows.Next() {
		tp := new(T)
		// valCreator decides how columns are mapped onto tp (reflect- or
		// unsafe-based, depending on how core was configured).
		val := c.valCreator.NewBasicTypeValue(tp, qc.Meta)
		if err = val.SetColumns(rows); err != nil {
			return &QueryResult{Err: err}
		}
		res = append(res, tp)
	}
	// rows.Next returns false both at the normal end of the set and when the
	// driver hits an error; only rows.Err distinguishes the two.
	if err = rows.Err(); err != nil {
		return &QueryResult{Err: err}
	}
	return &QueryResult{Result: res}
}
// getMulti runs a multi-row query through the middleware chain configured on c.
// getMultiHandler sits at the core of the chain; the middlewares are wrapped
// from the last one to the first, so c.ms[0] is the outermost layer and is
// invoked first.
func getMulti[T any](ctx context.Context, c core,
	sess session, qc *QueryContext) *QueryResult {
	chain := HandleFunc(func(ctx context.Context, qc *QueryContext) *QueryResult {
		return getMultiHandler[T](ctx, c, sess, qc)
	})
	mws := c.ms
	for n := len(mws); n > 0; n-- {
		chain = mws[n-1](chain)
	}
	return chain(ctx, qc)
}
// getHandler is the terminal (innermost) handler for single-row queries.
// It builds the SQL via qc.Query, executes it on sess, and scans the first
// returned row into a freshly allocated *T.
//
// Results:
//   - query build or execution error: Err set, Result nil;
//   - error while fetching the first row: that error, Result nil;
//   - no rows at all: Err is ErrNoRows, Result nil;
//   - otherwise Result is the *T and Err is whatever SetColumns returned
//     (Result is still set to the *T when SetColumns fails, as before).
func getHandler[T any](ctx context.Context, c core,
	sess session, qc *QueryContext) *QueryResult {
	q, err := qc.Query()
	if err != nil {
		return &QueryResult{Err: err}
	}
	// A multi-row query is used (rather than a single-row API) so that row
	// scanning shares the same SetColumns path as getMultiHandler.
	rows, err := sess.queryContext(ctx, q.SQL, q.Args...)
	if err != nil {
		return &QueryResult{Err: err}
	}
	// Release the rows (and the pooled connection) on every return path;
	// only the first row is read, so the set is never drained naturally.
	defer rows.Close()

	if !rows.Next() {
		// Next also returns false on a driver error; report that instead of
		// masking it as "no rows".
		if err = rows.Err(); err != nil {
			return &QueryResult{Err: err}
		}
		return &QueryResult{Err: ErrNoRows}
	}

	tp := new(T)
	// valCreator decides how columns are mapped onto tp (reflect- or
	// unsafe-based, depending on how core was configured).
	val := c.valCreator.NewBasicTypeValue(tp, qc.Meta)
	err = val.SetColumns(rows)
	return &QueryResult{Result: tp, Err: err}
}
// get runs a single-row query through the middleware chain configured on c.
// getHandler sits at the core of the chain; the middlewares are wrapped from
// the last one to the first, so c.ms[0] is the outermost layer and is invoked
// first.
func get[T any](ctx context.Context, c core,
	sess session, qc *QueryContext) *QueryResult {
	chain := HandleFunc(func(ctx context.Context, qc *QueryContext) *QueryResult {
		return getHandler[T](ctx, c, sess, qc)
	})
	mws := c.ms
	for n := len(mws); n > 0; n-- {
		chain = mws[n-1](chain)
	}
	return chain(ctx, qc)
}
// exec runs a non-query statement (e.g. INSERT/UPDATE/DELETE) through the
// middleware chain configured on c and converts the outcome into a Result.
//
// The innermost handler builds the SQL via qc.Query and executes it with
// sess.execContext. Middlewares are wrapped from the last to the first, so
// c.ms[0] is the outermost layer. The type parameter T is not used in the body.
//
// If a middleware leaves a value in QueryResult.Result that is not an
// sql.Result, an error is returned instead of panicking on the type assertion.
func exec[T any](ctx context.Context, c core,
	sess session, qc *QueryContext) Result {
	var handler HandleFunc = func(ctx context.Context, qc *QueryContext) *QueryResult {
		q, err := qc.Query()
		if err != nil {
			return &QueryResult{Err: err}
		}
		res, err := sess.execContext(ctx, q.SQL, q.Args...)
		return &QueryResult{Result: res, Err: err}
	}
	ms := c.ms
	for i := len(ms) - 1; i >= 0; i-- {
		handler = ms[i](handler)
	}
	qr := handler(ctx, qc)

	var res sql.Result
	if qr.Result != nil {
		r, ok := qr.Result.(sql.Result)
		if !ok {
			// Prefer the error already reported by the chain; otherwise explain
			// why there is no usable result.
			err := qr.Err
			if err == nil {
				err = fmt.Errorf("orm: exec result has unexpected type %T", qr.Result)
			}
			return Result{err: err}
		}
		res = r
	}
	return Result{err: qr.Err, res: res}
}