-
Notifications
You must be signed in to change notification settings - Fork 20
/
copyist.go
268 lines (235 loc) · 9.33 KB
/
copyist.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
// Copyright 2020 The Cockroach Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
// implied. See the License for the specific language governing
// permissions and limitations under the License.
package copyist
import (
"database/sql"
"errors"
"flag"
"fmt"
"io"
"os"
"path"
"runtime"
"strings"
"github.com/jmoiron/sqlx"
)
// testingT captures the narrow slice of *testing.T behavior that copyist
// relies on: asking for the current test's name and failing it fatally.
// Depending on this small interface instead of *testing.T directly lets
// copyist's own tests substitute a lightweight fake.
type testingT interface {
	Name() string
	Fatalf(format string, args ...interface{})
}
// recordFlag instructs copyist to record all calls to the registered driver, if
// true. Otherwise, it plays back previously recorded calls.
//
// The flag defaults to false (playback). When "-record" is not passed
// explicitly on the command line, IsRecording overwrites this value based on
// whether the COPYIST_RECORD environment variable is non-empty, so the default
// only matters for code that reads the flag before IsRecording has run.
var recordFlag = flag.Bool("record", false, "record sql database accesses")

// visitedRecording is set to true once IsRecording has resolved the effective
// value of recordFlag, so that the flag/environment lookup happens only once.
var visitedRecording bool
// IsRecording reports whether copyist is currently in recording mode.
//
// If "-record" was passed explicitly on the command line, its value wins.
// Otherwise recording is enabled exactly when the COPYIST_RECORD environment
// variable is non-empty, and that decision is written back into the flag. The
// resolution runs only on the first call; later calls return the cached
// result.
func IsRecording() bool {
	if visitedRecording {
		return *recordFlag
	}

	// flag.Visit walks only the flags that were actually set, which is the
	// only way to tell an explicit "-record" apart from its default value.
	explicit := false
	flag.Visit(func(fl *flag.Flag) {
		if fl.Name == "record" {
			explicit = true
		}
	})

	if !explicit {
		*recordFlag = os.Getenv("COPYIST_RECORD") != ""
	}

	visitedRecording = true
	return *recordFlag
}
// MaxRecordingSize caps the size, in bytes, of any single recording when it is
// rendered in its text format. It defaults to 1 MiB.
var MaxRecordingSize = 1 << 20
// SessionInitCallback is the signature of a hook that copyist runs once per
// session, for each driver, while in recording mode. Its job is to put the
// database into a clean, well-known starting state before the test runs.
type SessionInitCallback func()
var (
	// sessionInit is the optional hook run at the start of every new session;
	// nil means no hook has been installed.
	sessionInit SessionInitCallback

	// registered maps each wrapped driver's name to the proxy driver that
	// Register created for it. It stays nil until the first Register call.
	registered map[string]*proxyDriver
)
// Register constructs a proxy driver that wraps the "real" driver of the given
// name. Depending on the value of the "record" command-line flag, the
// constructed proxy will either record calls to the wrapped driver, or else
// play back calls that were previously recorded. Register must be called before
// copyist.Open can be called, typically in an init() method. Note that the
// wrapped driver is lazily fetched from the `sql` package, so if a driver of
// that name does not exist, an error will not be raised until a connection is
// opened for the first time.
//
// The Register method takes the name of the SQL driver to be wrapped (e.g.
// "postgres"). Below is an example of how copyist.Register should be invoked.
//
//	copyist.Register("postgres")
//
// Register can only be called once for a given driver name; a second call for
// the same name panics. In addition, the same copyist driver must be used
// with playback as was used during recording.
func Register(driverName string) {
	// Looking up a key in a nil map is safe, so the duplicate check does not
	// need to wait for the map to be allocated.
	if _, ok := registered[driverName]; ok {
		panic(fmt.Errorf("Register called twice for driver %s", driverName))
	}
	if registered == nil {
		registered = make(map[string]*proxyDriver)
	}

	copyistDriver := &proxyDriver{driverName: driverName}
	registered[driverName] = copyistDriver

	// sqlx uses a default list of driver names to determine how to represent
	// parameters in prepared queries. For example, postgres uses $1, mysql
	// uses ?, sqlserver uses @, and so on. But since copyist defines a custom
	// driver name, sqlx falls back to the default ?, which won't work with some
	// databases. Register the copyist driver name with sqlx and tell it to use
	// the bind type of the underlying driver rather than the default ?.
	proxyName := copyistDriverName(driverName)
	sqlx.BindDriver(proxyName, sqlx.BindType(driverName))

	// Register the copyist driver with the `sql` package under its prefixed
	// name.
	sql.Register(proxyName, copyistDriver)
}
// SetSessionInit installs the hook that copyist invokes at the start of each
// session, replacing any hook installed earlier. Use it to reset the test
// database to a clean, well-known state.
//
// NOTE: The hook only fires in "recording" mode. Playback never touches the
// real database, so there is nothing to initialize then.
func SetSessionInit(callback SessionInitCallback) {
	// A later call simply overwrites the earlier hook; passing nil removes it.
	sessionInit = callback
}
// Open starts a recording or playback session, as selected by the "record"
// command-line flag. When recording, every call made through a registered
// driver is captured and, once the session is closed, saved to a copyist
// recording file that lives next to the calling test file. When playing back,
// calls are answered from that same file. A typical usage looks like:
//
//	func init() {
//	  copyist.Register("postgres")
//	}
//
//	func TestMyStuff(t *testing.T) {
//	  defer copyist.Open(t).Close()
//	  ...
//	}
//
// Open begins the session, and the deferred Close ends it and writes the
// recording under the testdata/ directory beside the test file:
//
//	mystuff_test.go
//	testdata/
//	  mystuff_test.copyist
//
// Each test or sub-test that must run independently needs its own session.
// The recording is stored under the test's name (t.Name()).
func Open(t testingT) io.Closer {
	if registered == nil {
		panic(errors.New("Register was not called"))
	}

	// findTestFile skips a fixed number of stack frames to reach Open's
	// caller, so it must be invoked directly from here.
	testFile := findTestFile()

	// Turn ".../mystuff_test.go" into ".../testdata/mystuff_test.copyist".
	stem := path.Base(testFile)
	stem = stem[:len(stem)-len(path.Ext(stem))]
	pathName := path.Join(path.Dir(testFile), "testdata", stem+".copyist")

	return OpenNamed(t, pathName, t.Name())
}
// OpenNamed behaves like Open, except that the caller chooses both the output
// file and the recording name instead of having them derived. pathName names
// the file that holds the recordings, replacing the default "_test.copyist"
// file under testdata/. recordingName is the key used for this recording
// inside that file, replacing testing.T.Name().
func OpenNamed(t testingT, pathName, recordingName string) io.Closer {
	if registered == nil {
		panic(errors.New("Register was not called"))
	}
	src := fileSource{PathName: pathName}
	return OpenSource(t, src, recordingName)
}
// OpenSource behaves like Open, except that the caller supplies both the
// Source used to load and persist recordings and the recording name. The
// Source replaces the default "_test.copyist" file under testdata/, and
// recordingName replaces testing.T.Name() as the recording's key.
//
// The returned io.Closer is meant to be deferred. When it runs, it:
//   - turns a *sessionError panic into t.Fatalf;
//   - re-panics any other panic unchanged;
//   - reports a pending verification error through t.Fatalf;
//   - closes the current session and clears it.
func OpenSource(t testingT, source Source, recordingName string) io.Closer {
	if registered == nil {
		panic(errors.New("Register was not called"))
	}

	// The new session becomes the package-wide current session until the
	// returned closer runs.
	currentSession = newSession(source, recordingName)

	return closer(func(recovered interface{}) error {
		switch recovered.(type) {
		case nil:
			// No panic in flight; fall through to normal shutdown.
		case *sessionError:
			// copyist's own session failures become fatal test errors rather
			// than propagating as a panic.
			t.Fatalf("%v\n", recovered)
		default:
			// Panics unrelated to copyist are re-raised as-is.
			panic(recovered)
		}

		// NOTE(review): with a real *testing.T, Fatalf does not return, so on
		// either failure path the session below is never closed or cleared —
		// confirm this is intentional (e.g. to avoid saving a bad recording).
		if verr := currentSession.verificationErr; verr != nil {
			t.Fatalf("%+v\n", verr.error)
		}

		currentSession.Close()
		currentSession = nil
		return nil
	})
}
// findTestFile walks up the call stack from the caller of its caller (i.e.
// from whoever called Open), inspecting at most ten frames. It returns the
// outermost of those frames whose file name ends in "_test.go", and panics if
// none of them does.
func findTestFile() string {
	const maxDepth = 10

	// Frame 0 is findTestFile itself and frame 1 is its direct caller (Open),
	// so the search starts at frame 2.
	const firstFrame = 2

	found := ""
	for depth := firstFrame; depth < firstFrame+maxDepth; depth++ {
		_, file, _, ok := runtime.Caller(depth)
		if !ok {
			// The stack is shallower than maxDepth; no deeper frame exists.
			break
		}
		if strings.HasSuffix(file, "_test.go") {
			found = file
		}
	}

	if found == "" {
		panic(fmt.Errorf("Open was not called directly or indirectly from a test file"))
	}
	return found
}
// copyistDriverName returns the name under which copyist's proxy for the given
// wrapped driver is registered: the wrapped name prefixed with "copyist_".
func copyistDriverName(driverName string) string {
	const prefix = "copyist_"
	return prefix + driverName
}
// clearPooledConnections asks every registered proxy driver to drop its pooled
// connection, so that connection reuse cannot make behavior nondeterministic.
// The proxyDriver documentation explains the pooling in detail.
func clearPooledConnections() {
	for name := range registered {
		registered[name].clearPooledConnection()
	}
}
// closer adapts a plain function to io.Closer. When Close runs, it hands the
// function whatever recover() yields (nil when no panic is in flight) and
// returns the function's result.
type closer func(r interface{}) error

// Close implements io.Closer. recover() must be called directly here, not in
// the wrapped function, because it only stops a panic when invoked directly by
// the deferred call.
func (fn closer) Close() error {
	recovered := recover()
	return fn(recovered)
}