-
Notifications
You must be signed in to change notification settings - Fork 1
/
db.go
145 lines (117 loc) · 3.08 KB
/
db.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
package mgrt
import (
"database/sql"
"errors"
"strconv"
"strings"
"sync"
_ "github.com/go-sql-driver/mysql"
_ "github.com/jackc/pgx/v4/stdlib"
)
// DB is a thin abstraction over the *sql.DB struct from the stdlib. It pairs
// an open connection with the dialect-specific hooks mgrt needs for a given
// kind of database. Values passed to Register describe a database type; the
// values returned by Open additionally carry the opened connection.
type DB struct {
	// The embedded connection. Open sets this to the handle returned by
	// sql.Open; it is presumably nil on the values handed to Register.
	*sql.DB

	// Type is the type of database being connected to. Open passes it to
	// sql.Open as the driver name (e.g. "mysql", or "pgx" for the built-in
	// postgresql entry).
	Type string

	// Init is called by Open with the freshly opened connection, to prepare
	// the database for performing revisions (the built-in entries create the
	// mgrt_revisions table).
	Init func(*sql.DB) error

	// Parameterize rewrites a query into the placeholder dialect of this
	// database before it is executed. The built-in mysql entry returns the
	// query unchanged; the built-in postgresql entry turns each "?" into
	// "$1", "$2", ... in order.
	Parameterize func(string) string
}
var (
	// dbMu guards dbs. Register takes it for writing and Open for reading.
	dbMu sync.RWMutex

	// dbs maps a database type name, as given to Register and Open, to the
	// *DB registered for it.
	dbs = make(map[string]*DB)

	// mysqlInit creates the revisions table on MySQL. MySQL requires an
	// explicit length on VARCHAR columns, and SQL is a reserved word there,
	// so the sql column must be quoted with backticks. The backticks cannot
	// appear inside a Go raw string, hence the interpreted string literal.
	// IF NOT EXISTS keeps re-running Init harmless.
	mysqlInit = "CREATE TABLE IF NOT EXISTS mgrt_revisions (\n" +
		"\tid VARCHAR(255) NOT NULL UNIQUE,\n" +
		"\tauthor VARCHAR(255) NOT NULL,\n" +
		"\tcomment TEXT NOT NULL,\n" +
		"\t`sql` TEXT NOT NULL,\n" +
		"\tperformed_at INT NOT NULL\n" +
		");"

	// postgresInit creates the revisions table on PostgreSQL, where an
	// unbounded VARCHAR is valid and sql is not a reserved word. IF NOT
	// EXISTS keeps re-running Init harmless.
	postgresInit = `CREATE TABLE IF NOT EXISTS mgrt_revisions (
	id VARCHAR NOT NULL UNIQUE,
	author VARCHAR NOT NULL,
	comment TEXT NOT NULL,
	sql TEXT NOT NULL,
	performed_at INT NOT NULL
);`
)
func init() {
Register("mysql", &DB{
Type: "mysql",
Init: initMysql,
Parameterize: parameterizeMysql,
})
Register("postgresql", &DB{
Type: "pgx",
Init: initPostgresql,
Parameterize: parameterizePostgresql,
})
}
// initMysql runs the MySQL table-creation statement against db. An error
// whose message mentions "already exists" is treated as success, so calling
// this against an already-initialized database is not an error.
func initMysql(db *sql.DB) error {
	_, err := db.Exec(mysqlInit)
	if err == nil || strings.Contains(err.Error(), "already exists") {
		return nil
	}
	return err
}
// initPostgresql runs the PostgreSQL table-creation statement against db.
// An error whose message mentions "already exists" is treated as success, so
// calling this against an already-initialized database is not an error.
func initPostgresql(db *sql.DB) error {
	_, err := db.Exec(postgresInit)
	switch {
	case err == nil:
		return nil
	case strings.Contains(err.Error(), "already exists"):
		// The table is left over from an earlier run; nothing to do.
		return nil
	default:
		return err
	}
}
// parameterizeMysql returns query unchanged. The "?" placeholders are
// presumably already what the mysql driver expects, so no rewriting is done.
func parameterizeMysql(query string) string {
	return query
}
// parameterizePostgresql rewrites every "?" in s into PostgreSQL's numbered
// placeholder form, "$1", "$2", ..., numbered left to right. Every '?' byte
// is replaced, including any inside quoted literals. All other bytes are
// copied through verbatim.
func parameterizePostgresql(s string) string {
	var out strings.Builder
	out.Grow(len(s))

	next := 1
	for i := 0; i < len(s); i++ {
		if s[i] != '?' {
			out.WriteByte(s[i])
			continue
		}
		out.WriteByte('$')
		out.WriteString(strconv.Itoa(next))
		next++
	}
	return out.String()
}
// Register makes db available to Open under the name typ. It panics if db is
// nil, or if something has already been registered under typ.
func Register(typ string, db *DB) {
	if db == nil {
		panic("mgrt: nil database registered")
	}

	dbMu.Lock()
	defer dbMu.Unlock()

	if _, taken := dbs[typ]; taken {
		panic("mgrt: database already registered for " + typ)
	}
	dbs[typ] = db
}
// Open looks up the database registered under typ, calls sql.Open with the
// registered Type as the driver name and the given dsn, and passes the
// resulting connection to the registered Init function (when one is set).
//
// Each call returns a fresh *DB. The registered value is used as a template
// and is never modified, so repeated or concurrent calls to Open do not
// overwrite each other's connections. If Init fails, the new connection is
// closed and Init's error is returned.
func Open(typ, dsn string) (*DB, error) {
	// Hold the registry lock only for the lookup, not across Init, which
	// talks to the database.
	dbMu.RLock()
	tmpl, ok := dbs[typ]
	dbMu.RUnlock()

	if !ok {
		return nil, errors.New("unknown database type " + typ)
	}

	sqldb, err := sql.Open(tmpl.Type, dsn)
	if err != nil {
		return nil, err
	}

	if tmpl.Init != nil {
		if err := tmpl.Init(sqldb); err != nil {
			// Best effort: the Init error is the one worth reporting.
			_ = sqldb.Close()
			return nil, err
		}
	}

	// Copy the template so the shared registry entry is never mutated.
	conn := *tmpl
	conn.DB = sqldb
	return &conn, nil
}