diff --git a/dbump_pg/go.mod b/dbump_pg/go.mod index b7cd76b..f992ec9 100644 --- a/dbump_pg/go.mod +++ b/dbump_pg/go.mod @@ -3,7 +3,7 @@ module github.com/cristalhq/dbump/dbump_pg go 1.17 require ( - github.com/cristalhq/dbump v0.3.0 + github.com/cristalhq/dbump v0.14.0 github.com/lib/pq v1.10.6 ) diff --git a/dbump_pg/pg.go b/dbump_pg/pg.go index e646f1b..5f14d31 100644 --- a/dbump_pg/pg.go +++ b/dbump_pg/pg.go @@ -3,70 +3,136 @@ package dbump_pg import ( "context" "database/sql" + "errors" + "fmt" + "hash/fnv" "github.com/cristalhq/dbump" ) -// to prevent multiple migrations running at the same time -const lockNum int64 = 707_707_707 - var _ dbump.Migrator = &Migrator{} // Migrator to migrate Postgres. type Migrator struct { conn *sql.DB + cfg Config +} + +// Config for the migrator. +type Config struct { + // Schema for the dbump version table. Default is empty which means "public" schema. + Schema string + // Table for the dbump version table. Default is empty which means "_dbump_log" table. + Table string + + // [schema.]table + tableName string + // to prevent multiple migrations running at the same time + lockNum int64 } // NewMigrator instantiates new Migrator. // Takes std *sql.DB. -func NewMigrator(conn *sql.DB) *Migrator { +func NewMigrator(conn *sql.DB, cfg Config) *Migrator { + if cfg.Schema == "" { + cfg.Schema = "public" + } + if cfg.Table == "" { + cfg.Table = "_dbump_log" + } + + cfg.tableName = cfg.Schema + "." + cfg.Table + cfg.lockNum = hashTableName(cfg.tableName) + return &Migrator{ conn: conn, + cfg: cfg, } } // Init migrator. func (pg *Migrator) Init(ctx context.Context) error { - query := `CREATE TABLE IF NOT EXISTS _dbump_schema_version ( - version BIGINT NOT NULL PRIMARY KEY, + var query string + if pg.cfg.Schema != "" { + query = fmt.Sprintf(`CREATE SCHEMA IF NOT EXISTS %s;`, pg.cfg.Schema) + } + + query += fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s ( + version BIGINT NOT NULL, created_at TIMESTAMP WITH TIME ZONE NOT NULL -);` +);`, pg.cfg.tableName) + + _, err := pg.conn.ExecContext(ctx, query) + return err +} + +// Drop is a method from Migrator interface. +func (pg *Migrator) Drop(ctx context.Context) error { + query := fmt.Sprintf(`DROP TABLE IF EXISTS %s;`, pg.cfg.tableName) _, err := pg.conn.ExecContext(ctx, query) return err } -// LockDB is a method for Migrator interface. +// LockDB is a method from Migrator interface. func (pg *Migrator) LockDB(ctx context.Context) error { - _, err := pg.conn.ExecContext(ctx, "SELECT pg_advisory_lock($1);", lockNum) + _, err := pg.conn.ExecContext(ctx, "SELECT pg_advisory_lock($1);", pg.cfg.lockNum) return err } -// UnlockDB is a method for Migrator interface. +// UnlockDB is a method from Migrator interface. func (pg *Migrator) UnlockDB(ctx context.Context) error { - _, err := pg.conn.ExecContext(ctx, "SELECT pg_advisory_unlock($1);", lockNum) + _, err := pg.conn.ExecContext(ctx, "SELECT pg_advisory_unlock($1);", pg.cfg.lockNum) return err } // Version is a method for Migrator interface. func (pg *Migrator) Version(ctx context.Context) (version int, err error) { - query := "SELECT COUNT(*) FROM _dbump_schema_version;" + query := fmt.Sprintf("SELECT version FROM %s ORDER BY created_at DESC LIMIT 1;", pg.cfg.tableName) row := pg.conn.QueryRowContext(ctx, query) err = row.Scan(&version) + if err != nil && errors.Is(err, sql.ErrNoRows) { + return 0, nil + } return version, err } -// SetVersion is a method for Migrator interface. -func (pg *Migrator) SetVersion(ctx context.Context, version int) error { - query := `INSERT INTO _dbump_schema_version (version, created_at) -VALUES ($1, NOW()) -ON CONFLICT (version) DO UPDATE -SET created_at = NOW();` - _, err := pg.conn.ExecContext(ctx, query, version) - return err +// DoStep is a method for Migrator interface. +func (pg *Migrator) DoStep(ctx context.Context, step dbump.Step) error { + if step.DisableTx { + if _, err := pg.conn.ExecContext(ctx, step.Query); err != nil { + return err + } + query := fmt.Sprintf("INSERT INTO %s (version, created_at) VALUES ($1, NOW());", pg.cfg.tableName) + _, err := pg.conn.ExecContext(ctx, query, step.Version) + return err + } + + return pg.beginFunc(ctx, func(tx *sql.Tx) error { + if _, err := tx.ExecContext(ctx, step.Query); err != nil { + return err + } + query := fmt.Sprintf("INSERT INTO %s (version, created_at) VALUES ($1, NOW());", pg.cfg.tableName) + _, err := tx.ExecContext(ctx, query, step.Version) + return err + }) } -// Exec is a method for Migrator interface. -func (pg *Migrator) Exec(ctx context.Context, query string, args ...interface{}) error { - _, err := pg.conn.ExecContext(ctx, query, args...) - return err +func (pg *Migrator) beginFunc(ctx context.Context, f func(*sql.Tx) error) (err error) { + tx, err := pg.conn.BeginTx(ctx, nil) + if err != nil { + return err + } + defer tx.Rollback() + + if err := f(tx); err != nil { + return err + } + + return tx.Commit() +} + +func hashTableName(s string) int64 { + h := fnv.New64() + h.Write([]byte(s)) + return int64(h.Sum64()) } diff --git a/dbump_pg/pg_test.go b/dbump_pg/pg_test.go index cc6fcb8..47011ce 100644 --- a/dbump_pg/pg_test.go +++ b/dbump_pg/pg_test.go @@ -1,85 +1,132 @@ -package dbump_pg_test +package dbump_pg import ( - "context" "database/sql" "fmt" "os" + "reflect" "testing" - "github.com/cristalhq/dbump" - "github.com/cristalhq/dbump/dbump_pg" - + "github.com/cristalhq/dbump/tests" _ "github.com/lib/pq" ) var sqldb *sql.DB func init() { - var ( - host = os.Getenv("DBUMP_PG_HOST") - port = os.Getenv("DBUMP_PG_PORT") - username = os.Getenv("DBUMP_PG_USER") - password = os.Getenv("DBUMP_PG_PASS") - db = os.Getenv("DBUMP_PG_DB") - sslmode = os.Getenv("DBUMP_PG_SSL") - ) - - if host == "" { - host = "localhost" - } - if port == "" { - port = "5432" - } - if username == "" { - username = "postgres" - } - if password == "" { - password = "postgres" - } - if db == "" { - db = "postgres" - } - if sslmode == "" { - sslmode = "disable" - } + host := envOrDef("DBUMP_PG_HOST", "localhost") + port := envOrDef("DBUMP_PG_PORT", "5432") + username := envOrDef("DBUMP_PG_USER", "postgres") + password := envOrDef("DBUMP_PG_PASS", "postgres") + db := envOrDef("DBUMP_PG_DB", "postgres") + sslmode := envOrDef("DBUMP_PG_SSL", "disable") + dsn := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s", host, port, username, password, db, sslmode) var err error sqldb, err = sql.Open("postgres", dsn) if err != nil { - panic(err) + panic(fmt.Sprintf("dbump_pgx: cannot connect to container: %s", err)) } } -func TestPG_Simple(t *testing.T) { - m := dbump_pg.NewMigrator(sqldb) - l := dbump.NewSliceLoader([]*dbump.Migration{ +func TestNonDefaultSchemaTable(t *testing.T) { + testCases := []struct { + name string + schema string + table string + wantTableName string + wantLockNum int64 + }{ { - ID: 1, - Apply: "SELECT 1;", - Rollback: "SELECT 1;", + name: "all empty", + schema: "", + table: "", + wantTableName: "public._dbump_log", + wantLockNum: 1542931740578198266, }, { - ID: 2, - Apply: "SELECT 1;", - Rollback: "SELECT 1;", + name: "schema set", + schema: "test_schema", + table: "", + wantTableName: "test_schema._dbump_log", + wantLockNum: 1417388815471108263, }, { - ID: 3, - Apply: "SELECT 1;", - Rollback: "SELECT 1;", + name: "table set", + schema: "", + table: "test_table", + wantTableName: "public.test_table", + wantLockNum: 8592189678091584965, }, - }) + { + name: "schema and table set", + schema: "test_schema", + table: "test_table", + wantTableName: "test_schema.test_table", + wantLockNum: 4631047095544292572, + }, + } + + for _, tc := range testCases { + m := NewMigrator(sqldb, Config{ + Schema: tc.schema, + Table: tc.table, + }) + mustEqual(t, m.cfg.tableName, tc.wantTableName) + mustEqual(t, m.cfg.lockNum, tc.wantLockNum) + } +} + +func TestMigrate_ApplyAll(t *testing.T) { + newSuite().ApplyAll(t) +} - errRun := dbump.Run(context.Background(), m, l) - failIfErr(t, errRun) +func TestMigrate_ApplyOne(t *testing.T) { + newSuite().ApplyOne(t) } -func failIfErr(tb testing.TB, err error) { +func TestMigrate_ApplyAllWhenFull(t *testing.T) { + newSuite().ApplyAllWhenFull(t) +} + +func TestMigrate_RevertOne(t *testing.T) { + newSuite().RevertOne(t) +} + +func TestMigrate_RevertAllWhenEmpty(t *testing.T) { + newSuite().RevertAllWhenEmpty(t) +} + +func TestMigrate_RevertAll(t *testing.T) { + newSuite().RevertAll(t) +} + +func TestMigrate_Redo(t *testing.T) { + newSuite().Redo(t) +} + +func newSuite() *tests.MigratorSuite { + m := NewMigrator(sqldb, Config{}) + suite := tests.NewMigratorSuite(m) + suite.ApplyTmpl = "CREATE TABLE public.%[1]s_%[2]d (id INT);" + suite.RevertTmpl = "DROP TABLE public.%[1]s_%[2]d;" + suite.CleanMigTmpl = "DROP TABLE IF EXISTS public.%[1]s_%[2]d;" + suite.CleanTest = "TRUNCATE TABLE _dbump_log;" + return suite +} + +func mustEqual(tb testing.TB, got, want interface{}) { tb.Helper() - if err != nil { - tb.Fatal(err) + if !reflect.DeepEqual(got, want) { + tb.Fatalf("\nhave %+v\nwant %+v", got, want) + } +} + +func envOrDef(env, def string) string { + if val := os.Getenv(env); val != "" { + return val } + return def }