Skip to content

Commit

Permalink
Better dbump_pg (#10)
Browse files Browse the repository at this point in the history
Co-authored-by: Oleg Kovalov <oleg@hey.com>
  • Loading branch information
zhezhel and cristaloleg authored Sep 23, 2023
1 parent 4611380 commit 53e63dc
Show file tree
Hide file tree
Showing 3 changed files with 189 additions and 76 deletions.
2 changes: 1 addition & 1 deletion dbump_pg/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module github.com/cristalhq/dbump/dbump_pg
go 1.17

require (
github.com/cristalhq/dbump v0.3.0
github.com/cristalhq/dbump v0.14.0
github.com/lib/pq v1.10.6
)

Expand Down
114 changes: 90 additions & 24 deletions dbump_pg/pg.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,70 +3,136 @@ package dbump_pg
import (
"context"
"database/sql"
"errors"
"fmt"
"hash/fnv"

"github.com/cristalhq/dbump"
)

// to prevent multiple migrations running at the same time
const lockNum int64 = 707_707_707

var _ dbump.Migrator = &Migrator{}

// Migrator to migrate Postgres.
type Migrator struct {
conn *sql.DB
cfg Config
}

// Config for the migrator.
type Config struct {
// Schema for the dbump version table. Default is empty which means "public" schema.
Schema string
// Table for the dbump version table. Default is empty which means "_dbump_log" table.
Table string

// [schema.]table
tableName string
// to prevent multiple migrations running at the same time
lockNum int64
}

// NewMigrator instantiates new Migrator.
// Takes std *sql.DB.
func NewMigrator(conn *sql.DB) *Migrator {
func NewMigrator(conn *sql.DB, cfg Config) *Migrator {
if cfg.Schema == "" {
cfg.Schema = "public"
}
if cfg.Table == "" {
cfg.Table = "_dbump_log"
}

cfg.tableName = cfg.Schema + "." + cfg.Table
cfg.lockNum = hashTableName(cfg.tableName)

return &Migrator{
conn: conn,
cfg: cfg,
}
}

// Init migrator.
func (pg *Migrator) Init(ctx context.Context) error {
query := `CREATE TABLE IF NOT EXISTS _dbump_schema_version (
version BIGINT NOT NULL PRIMARY KEY,
var query string
if pg.cfg.Schema != "" {
query = fmt.Sprintf(`CREATE SCHEMA IF NOT EXISTS %s;`, pg.cfg.Schema)
}

query += fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s (
version BIGINT NOT NULL,
created_at TIMESTAMP WITH TIME ZONE NOT NULL
);`
);`, pg.cfg.tableName)

_, err := pg.conn.ExecContext(ctx, query)
return err
}

// Drop is a method from Migrator interface.
func (pg *Migrator) Drop(ctx context.Context) error {
query := fmt.Sprintf(`DROP TABLE IF EXISTS %s;`, pg.cfg.tableName)
_, err := pg.conn.ExecContext(ctx, query)
return err
}

// LockDB is a method for Migrator interface.
// LockDB is a method from Migrator interface.
func (pg *Migrator) LockDB(ctx context.Context) error {
_, err := pg.conn.ExecContext(ctx, "SELECT pg_advisory_lock($1);", lockNum)
_, err := pg.conn.ExecContext(ctx, "SELECT pg_advisory_lock($1);", pg.cfg.lockNum)
return err
}

// UnlockDB is a method for Migrator interface.
// UnlockDB is a method from Migrator interface.
func (pg *Migrator) UnlockDB(ctx context.Context) error {
_, err := pg.conn.ExecContext(ctx, "SELECT pg_advisory_unlock($1);", lockNum)
_, err := pg.conn.ExecContext(ctx, "SELECT pg_advisory_unlock($1);", pg.cfg.lockNum)
return err
}

// Version is a method for Migrator interface.
func (pg *Migrator) Version(ctx context.Context) (version int, err error) {
query := "SELECT COUNT(*) FROM _dbump_schema_version;"
query := fmt.Sprintf("SELECT version FROM %s ORDER BY created_at DESC LIMIT 1;", pg.cfg.tableName)
row := pg.conn.QueryRowContext(ctx, query)
err = row.Scan(&version)
if err != nil && errors.Is(err, sql.ErrNoRows) {
return 0, nil
}
return version, err
}

// SetVersion is a method for Migrator interface.
func (pg *Migrator) SetVersion(ctx context.Context, version int) error {
query := `INSERT INTO _dbump_schema_version (version, created_at)
VALUES ($1, NOW())
ON CONFLICT (version) DO UPDATE
SET created_at = NOW();`
_, err := pg.conn.ExecContext(ctx, query, version)
return err
// DoStep is a method for Migrator interface.
func (pg *Migrator) DoStep(ctx context.Context, step dbump.Step) error {
if step.DisableTx {
if _, err := pg.conn.ExecContext(ctx, step.Query); err != nil {
return err
}
query := fmt.Sprintf("INSERT INTO %s (version, created_at) VALUES ($1, NOW());", pg.cfg.tableName)
_, err := pg.conn.ExecContext(ctx, query, step.Version)
return err
}

return pg.beginFunc(ctx, func(tx *sql.Tx) error {
if _, err := tx.ExecContext(ctx, step.Query); err != nil {
return err
}
query := fmt.Sprintf("INSERT INTO %s (version, created_at) VALUES ($1, NOW());", pg.cfg.tableName)
_, err := tx.ExecContext(ctx, query, step.Version)
return err
})
}

// Exec is a method for Migrator interface.
func (pg *Migrator) Exec(ctx context.Context, query string, args ...interface{}) error {
_, err := pg.conn.ExecContext(ctx, query, args...)
return err
func (pg *Migrator) beginFunc(ctx context.Context, f func(*sql.Tx) error) (err error) {
tx, err := pg.conn.BeginTx(ctx, nil)
if err != nil {
return err
}
defer tx.Rollback()

if err := f(tx); err != nil {
return err
}

return tx.Commit()
}

func hashTableName(s string) int64 {
h := fnv.New64()
h.Write([]byte(s))
return int64(h.Sum64())
}
149 changes: 98 additions & 51 deletions dbump_pg/pg_test.go
Original file line number Diff line number Diff line change
@@ -1,85 +1,132 @@
package dbump_pg_test
package dbump_pg

import (
"context"
"database/sql"
"fmt"
"os"
"reflect"
"testing"

"github.com/cristalhq/dbump"
"github.com/cristalhq/dbump/dbump_pg"

"github.com/cristalhq/dbump/tests"
_ "github.com/lib/pq"
)

var sqldb *sql.DB

func init() {
var (
host = os.Getenv("DBUMP_PG_HOST")
port = os.Getenv("DBUMP_PG_PORT")
username = os.Getenv("DBUMP_PG_USER")
password = os.Getenv("DBUMP_PG_PASS")
db = os.Getenv("DBUMP_PG_DB")
sslmode = os.Getenv("DBUMP_PG_SSL")
)

if host == "" {
host = "localhost"
}
if port == "" {
port = "5432"
}
if username == "" {
username = "postgres"
}
if password == "" {
password = "postgres"
}
if db == "" {
db = "postgres"
}
if sslmode == "" {
sslmode = "disable"
}
host := envOrDef("DBUMP_PG_HOST", "localhost")
port := envOrDef("DBUMP_PG_PORT", "5432")
username := envOrDef("DBUMP_PG_USER", "postgres")
password := envOrDef("DBUMP_PG_PASS", "postgres")
db := envOrDef("DBUMP_PG_DB", "postgres")
sslmode := envOrDef("DBUMP_PG_SSL", "disable")

dsn := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s",
host, port, username, password, db, sslmode)

var err error
sqldb, err = sql.Open("postgres", dsn)
if err != nil {
panic(err)
panic(fmt.Sprintf("dbump_pgx: cannot connect to container: %s", err))
}
}

func TestPG_Simple(t *testing.T) {
m := dbump_pg.NewMigrator(sqldb)
l := dbump.NewSliceLoader([]*dbump.Migration{
func TestNonDefaultSchemaTable(t *testing.T) {
testCases := []struct {
name string
schema string
table string
wantTableName string
wantLockNum int64
}{
{
ID: 1,
Apply: "SELECT 1;",
Rollback: "SELECT 1;",
name: "all empty",
schema: "",
table: "",
wantTableName: "public._dbump_log",
wantLockNum: 1542931740578198266,
},
{
ID: 2,
Apply: "SELECT 1;",
Rollback: "SELECT 1;",
name: "schema set",
schema: "test_schema",
table: "",
wantTableName: "test_schema._dbump_log",
wantLockNum: 1417388815471108263,
},
{
ID: 3,
Apply: "SELECT 1;",
Rollback: "SELECT 1;",
name: "table set",
schema: "",
table: "test_table",
wantTableName: "public.test_table",
wantLockNum: 8592189678091584965,
},
})
{
name: "schema and table set",
schema: "test_schema",
table: "test_table",
wantTableName: "test_schema.test_table",
wantLockNum: 4631047095544292572,
},
}

for _, tc := range testCases {
m := NewMigrator(sqldb, Config{
Schema: tc.schema,
Table: tc.table,
})
mustEqual(t, m.cfg.tableName, tc.wantTableName)
mustEqual(t, m.cfg.lockNum, tc.wantLockNum)
}
}

func TestMigrate_ApplyAll(t *testing.T) {
newSuite().ApplyAll(t)
}

errRun := dbump.Run(context.Background(), m, l)
failIfErr(t, errRun)
func TestMigrate_ApplyOne(t *testing.T) {
newSuite().ApplyOne(t)
}

func failIfErr(tb testing.TB, err error) {
func TestMigrate_ApplyAllWhenFull(t *testing.T) {
newSuite().ApplyAllWhenFull(t)
}

func TestMigrate_RevertOne(t *testing.T) {
newSuite().RevertOne(t)
}

func TestMigrate_RevertAllWhenEmpty(t *testing.T) {
newSuite().RevertAllWhenEmpty(t)
}

func TestMigrate_RevertAll(t *testing.T) {
newSuite().RevertAll(t)
}

func TestMigrate_Redo(t *testing.T) {
newSuite().Redo(t)
}

func newSuite() *tests.MigratorSuite {
m := NewMigrator(sqldb, Config{})
suite := tests.NewMigratorSuite(m)
suite.ApplyTmpl = "CREATE TABLE public.%[1]s_%[2]d (id INT);"
suite.RevertTmpl = "DROP TABLE public.%[1]s_%[2]d;"
suite.CleanMigTmpl = "DROP TABLE IF EXISTS public.%[1]s_%[2]d;"
suite.CleanTest = "TRUNCATE TABLE _dbump_log;"
return suite
}

func mustEqual(tb testing.TB, got, want interface{}) {
tb.Helper()
if err != nil {
tb.Fatal(err)
if !reflect.DeepEqual(got, want) {
tb.Fatalf("\nhave %+v\nwant %+v", got, want)
}
}

func envOrDef(env, def string) string {
if val := os.Getenv(env); val != "" {
return val
}
return def
}

0 comments on commit 53e63dc

Please sign in to comment.