Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better dbump_pg #10

Merged
merged 2 commits into from
Sep 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dbump_pg/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module github.com/cristalhq/dbump/dbump_pg
go 1.17

require (
github.com/cristalhq/dbump v0.3.0
github.com/cristalhq/dbump v0.14.0
github.com/lib/pq v1.10.6
)

Expand Down
114 changes: 90 additions & 24 deletions dbump_pg/pg.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,70 +3,136 @@ package dbump_pg
import (
"context"
"database/sql"
"errors"
"fmt"
"hash/fnv"

"github.com/cristalhq/dbump"
)

// to prevent multiple migrations running at the same time
const lockNum int64 = 707_707_707

var _ dbump.Migrator = &Migrator{}

// Migrator to migrate Postgres.
type Migrator struct {
conn *sql.DB
cfg Config
}

// Config for the migrator.
type Config struct {
// Schema for the dbump version table. Default is empty which means "public" schema.
Schema string
// Table for the dbump version table. Default is empty which means "_dbump_log" table.
Table string

// [schema.]table
tableName string
// to prevent multiple migrations running at the same time
lockNum int64
}

// NewMigrator instantiates new Migrator.
// Takes std *sql.DB.
func NewMigrator(conn *sql.DB) *Migrator {
func NewMigrator(conn *sql.DB, cfg Config) *Migrator {
if cfg.Schema == "" {
cfg.Schema = "public"
}
if cfg.Table == "" {
cfg.Table = "_dbump_log"
}

cfg.tableName = cfg.Schema + "." + cfg.Table
cfg.lockNum = hashTableName(cfg.tableName)

return &Migrator{
conn: conn,
cfg: cfg,
}
}

// Init migrator.
func (pg *Migrator) Init(ctx context.Context) error {
query := `CREATE TABLE IF NOT EXISTS _dbump_schema_version (
version BIGINT NOT NULL PRIMARY KEY,
var query string
if pg.cfg.Schema != "" {
query = fmt.Sprintf(`CREATE SCHEMA IF NOT EXISTS %s;`, pg.cfg.Schema)
}

query += fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s (
version BIGINT NOT NULL,
created_at TIMESTAMP WITH TIME ZONE NOT NULL
);`
);`, pg.cfg.tableName)

_, err := pg.conn.ExecContext(ctx, query)
return err
}

// Drop is a method from Migrator interface.
func (pg *Migrator) Drop(ctx context.Context) error {
query := fmt.Sprintf(`DROP TABLE IF EXISTS %s;`, pg.cfg.tableName)
_, err := pg.conn.ExecContext(ctx, query)
return err
}

// LockDB is a method for Migrator interface.
// LockDB is a method from Migrator interface.
func (pg *Migrator) LockDB(ctx context.Context) error {
_, err := pg.conn.ExecContext(ctx, "SELECT pg_advisory_lock($1);", lockNum)
_, err := pg.conn.ExecContext(ctx, "SELECT pg_advisory_lock($1);", pg.cfg.lockNum)
return err
}

// UnlockDB is a method for Migrator interface.
// UnlockDB is a method from Migrator interface.
func (pg *Migrator) UnlockDB(ctx context.Context) error {
_, err := pg.conn.ExecContext(ctx, "SELECT pg_advisory_unlock($1);", lockNum)
_, err := pg.conn.ExecContext(ctx, "SELECT pg_advisory_unlock($1);", pg.cfg.lockNum)
return err
}

// Version is a method for Migrator interface.
func (pg *Migrator) Version(ctx context.Context) (version int, err error) {
query := "SELECT COUNT(*) FROM _dbump_schema_version;"
query := fmt.Sprintf("SELECT version FROM %s ORDER BY created_at DESC LIMIT 1;", pg.cfg.tableName)
row := pg.conn.QueryRowContext(ctx, query)
err = row.Scan(&version)
if err != nil && errors.Is(err, sql.ErrNoRows) {
return 0, nil
}
return version, err
}

// SetVersion is a method for Migrator interface.
func (pg *Migrator) SetVersion(ctx context.Context, version int) error {
query := `INSERT INTO _dbump_schema_version (version, created_at)
VALUES ($1, NOW())
ON CONFLICT (version) DO UPDATE
SET created_at = NOW();`
_, err := pg.conn.ExecContext(ctx, query, version)
return err
// DoStep is a method for Migrator interface.
func (pg *Migrator) DoStep(ctx context.Context, step dbump.Step) error {
if step.DisableTx {
if _, err := pg.conn.ExecContext(ctx, step.Query); err != nil {
return err
}
query := fmt.Sprintf("INSERT INTO %s (version, created_at) VALUES ($1, NOW());", pg.cfg.tableName)
_, err := pg.conn.ExecContext(ctx, query, step.Version)
return err
}

return pg.beginFunc(ctx, func(tx *sql.Tx) error {
if _, err := tx.ExecContext(ctx, step.Query); err != nil {
return err
}
query := fmt.Sprintf("INSERT INTO %s (version, created_at) VALUES ($1, NOW());", pg.cfg.tableName)
_, err := tx.ExecContext(ctx, query, step.Version)
return err
})
}

// Exec is a method for Migrator interface.
func (pg *Migrator) Exec(ctx context.Context, query string, args ...interface{}) error {
_, err := pg.conn.ExecContext(ctx, query, args...)
return err
func (pg *Migrator) beginFunc(ctx context.Context, f func(*sql.Tx) error) (err error) {
tx, err := pg.conn.BeginTx(ctx, nil)
if err != nil {
return err
}
defer tx.Rollback()

if err := f(tx); err != nil {
return err
}

return tx.Commit()
}

func hashTableName(s string) int64 {
h := fnv.New64()
h.Write([]byte(s))
return int64(h.Sum64())
}
149 changes: 98 additions & 51 deletions dbump_pg/pg_test.go
Original file line number Diff line number Diff line change
@@ -1,85 +1,132 @@
package dbump_pg_test
package dbump_pg

import (
"context"
"database/sql"
"fmt"
"os"
"reflect"
"testing"

"github.com/cristalhq/dbump"
"github.com/cristalhq/dbump/dbump_pg"

"github.com/cristalhq/dbump/tests"
_ "github.com/lib/pq"
)

var sqldb *sql.DB

func init() {
var (
host = os.Getenv("DBUMP_PG_HOST")
port = os.Getenv("DBUMP_PG_PORT")
username = os.Getenv("DBUMP_PG_USER")
password = os.Getenv("DBUMP_PG_PASS")
db = os.Getenv("DBUMP_PG_DB")
sslmode = os.Getenv("DBUMP_PG_SSL")
)

if host == "" {
host = "localhost"
}
if port == "" {
port = "5432"
}
if username == "" {
username = "postgres"
}
if password == "" {
password = "postgres"
}
if db == "" {
db = "postgres"
}
if sslmode == "" {
sslmode = "disable"
}
host := envOrDef("DBUMP_PG_HOST", "localhost")
port := envOrDef("DBUMP_PG_PORT", "5432")
username := envOrDef("DBUMP_PG_USER", "postgres")
password := envOrDef("DBUMP_PG_PASS", "postgres")
db := envOrDef("DBUMP_PG_DB", "postgres")
sslmode := envOrDef("DBUMP_PG_SSL", "disable")

dsn := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s",
host, port, username, password, db, sslmode)

var err error
sqldb, err = sql.Open("postgres", dsn)
if err != nil {
panic(err)
panic(fmt.Sprintf("dbump_pgx: cannot connect to container: %s", err))
}
}

func TestPG_Simple(t *testing.T) {
m := dbump_pg.NewMigrator(sqldb)
l := dbump.NewSliceLoader([]*dbump.Migration{
func TestNonDefaultSchemaTable(t *testing.T) {
testCases := []struct {
name string
schema string
table string
wantTableName string
wantLockNum int64
}{
{
ID: 1,
Apply: "SELECT 1;",
Rollback: "SELECT 1;",
name: "all empty",
schema: "",
table: "",
wantTableName: "public._dbump_log",
wantLockNum: 1542931740578198266,
},
{
ID: 2,
Apply: "SELECT 1;",
Rollback: "SELECT 1;",
name: "schema set",
schema: "test_schema",
table: "",
wantTableName: "test_schema._dbump_log",
wantLockNum: 1417388815471108263,
},
{
ID: 3,
Apply: "SELECT 1;",
Rollback: "SELECT 1;",
name: "table set",
schema: "",
table: "test_table",
wantTableName: "public.test_table",
wantLockNum: 8592189678091584965,
},
})
{
name: "schema and table set",
schema: "test_schema",
table: "test_table",
wantTableName: "test_schema.test_table",
wantLockNum: 4631047095544292572,
},
}

for _, tc := range testCases {
m := NewMigrator(sqldb, Config{
Schema: tc.schema,
Table: tc.table,
})
mustEqual(t, m.cfg.tableName, tc.wantTableName)
mustEqual(t, m.cfg.lockNum, tc.wantLockNum)
}
}

func TestMigrate_ApplyAll(t *testing.T) {
newSuite().ApplyAll(t)
}

errRun := dbump.Run(context.Background(), m, l)
failIfErr(t, errRun)
func TestMigrate_ApplyOne(t *testing.T) {
newSuite().ApplyOne(t)
}

func failIfErr(tb testing.TB, err error) {
func TestMigrate_ApplyAllWhenFull(t *testing.T) {
newSuite().ApplyAllWhenFull(t)
}

func TestMigrate_RevertOne(t *testing.T) {
newSuite().RevertOne(t)
}

func TestMigrate_RevertAllWhenEmpty(t *testing.T) {
newSuite().RevertAllWhenEmpty(t)
}

func TestMigrate_RevertAll(t *testing.T) {
newSuite().RevertAll(t)
}

func TestMigrate_Redo(t *testing.T) {
newSuite().Redo(t)
}

func newSuite() *tests.MigratorSuite {
m := NewMigrator(sqldb, Config{})
suite := tests.NewMigratorSuite(m)
suite.ApplyTmpl = "CREATE TABLE public.%[1]s_%[2]d (id INT);"
suite.RevertTmpl = "DROP TABLE public.%[1]s_%[2]d;"
suite.CleanMigTmpl = "DROP TABLE IF EXISTS public.%[1]s_%[2]d;"
suite.CleanTest = "TRUNCATE TABLE _dbump_log;"
return suite
}

func mustEqual(tb testing.TB, got, want interface{}) {
tb.Helper()
if err != nil {
tb.Fatal(err)
if !reflect.DeepEqual(got, want) {
tb.Fatalf("\nhave %+v\nwant %+v", got, want)
}
}

func envOrDef(env, def string) string {
if val := os.Getenv(env); val != "" {
return val
}
return def
}