diff --git a/pkg/service/restore/helper_integration_test.go b/pkg/service/restore/helper_integration_test.go
index 8bf63f69a..c6f16e2de 100644
--- a/pkg/service/restore/helper_integration_test.go
+++ b/pkg/service/restore/helper_integration_test.go
@@ -21,6 +21,7 @@ import (
 	"github.com/scylladb/gocqlx/v2/qb"
 	"github.com/scylladb/scylla-manager/v3/pkg/service/cluster"
 	"github.com/scylladb/scylla-manager/v3/pkg/util/version"
+	"go.uber.org/multierr"
 	"go.uber.org/zap/zapcore"
 
 	"github.com/scylladb/scylla-manager/v3/pkg/metrics"
@@ -476,3 +477,63 @@ func checkAnyConstraint(t *testing.T, client *scyllaclient.Client, constraints .
 	}
 	return false
 }
+
+func createTable(t *testing.T, session gocqlx.Session, keyspace string, tables ...string) {
+	for _, tab := range tables {
+		ExecStmt(t, session, fmt.Sprintf("CREATE TABLE %q.%q (id int PRIMARY KEY, data int)", keyspace, tab))
+	}
+}
+
+func fillTable(t *testing.T, session gocqlx.Session, rowCnt int, keyspace string, tables ...string) {
+	for _, tab := range tables {
+		stmt := fmt.Sprintf("INSERT INTO %q.%q (id, data) VALUES (?, ?)", keyspace, tab)
+		q := session.Query(stmt, []string{"id", "data"})
+
+		for i := 0; i < rowCnt; i++ {
+			if err := q.Bind(i, i).Exec(); err != nil {
+				t.Fatal(err)
+			}
+		}
+
+		q.Release()
+	}
+}
+
+func runPausedRestore(t *testing.T, restore func(ctx context.Context) error, pauseInterval time.Duration, minPauseCnt int) (err error) {
+	t.Helper()
+
+	ticker := time.NewTicker(pauseInterval)
+	ctx, cancel := context.WithCancel(context.Background())
+	res := make(chan error)
+	pauseCnt := 0
+	defer func() {
+		t.Logf("Restore was paused %d times", pauseCnt)
+		if pauseCnt < minPauseCnt {
+			err = multierr.Append(err, errors.Errorf("expected to pause at least %d times, got %d", minPauseCnt, pauseCnt))
+		}
+	}()
+
+	go func() {
+		res <- restore(ctx)
+	}()
+	for {
+		select {
+		case err := <-res:
+			cancel()
+			return err
+		case <-ticker.C:
+			t.Log("Pause restore")
+			cancel()
+			err := <-res
+			if err == nil || !errors.Is(err, context.Canceled) {
+				return err
+			}
+
+			pauseCnt++
+			ctx, cancel = context.WithCancel(context.Background())
+			go func() {
+				res <- restore(ctx)
+			}()
+		}
+	}
+}
diff --git a/pkg/service/restore/restore_integration_test.go b/pkg/service/restore/restore_integration_test.go
index efbd588df..71cce69af 100644
--- a/pkg/service/restore/restore_integration_test.go
+++ b/pkg/service/restore/restore_integration_test.go
@@ -7,17 +7,21 @@ package restore_test
 
 import (
 	"context"
+	"encoding/json"
 	"fmt"
 	"strings"
 	"testing"
+	"time"
 
 	"github.com/pkg/errors"
+	"github.com/scylladb/scylla-manager/v3/pkg/service/backup"
 	. "github.com/scylladb/scylla-manager/v3/pkg/service/backup/backupspec"
 	. "github.com/scylladb/scylla-manager/v3/pkg/testutils"
 	. "github.com/scylladb/scylla-manager/v3/pkg/testutils/db"
"github.com/scylladb/scylla-manager/v3/pkg/testutils/testconfig" "github.com/scylladb/scylla-manager/v3/pkg/util/maputil" "github.com/scylladb/scylla-manager/v3/pkg/util/query" + "github.com/scylladb/scylla-manager/v3/pkg/util/uuid" ) func TestRestoreTablesUserIntegration(t *testing.T) { @@ -334,3 +338,119 @@ func TestRestoreTablesVnodeToTabletsIntegration(t *testing.T) { validateTableContent[int, int](t, h.srcCluster.rootSession, h.dstCluster.rootSession, ks, tab, c1, c2) } + +func TestRestoreTablesPausedIntegration(t *testing.T) { + h := newTestHelper(t, ManagedSecondClusterHosts(), ManagedClusterHosts()) + + // Setup: + // ks1: tab, mv, si + // ks2: tab1, tab2, mv1 + + Print("Keyspace setup") + ksStmt := "CREATE KEYSPACE %q WITH replication = {'class': 'NetworkTopologyStrategy', 'dc1': %d}" + ks1 := randomizedName("paused_1_") + ExecStmt(t, h.srcCluster.rootSession, fmt.Sprintf(ksStmt, ks1, 1)) + ExecStmt(t, h.dstCluster.rootSession, fmt.Sprintf(ksStmt, ks1, 1)) + ks2 := randomizedName("paused_2_") + ExecStmt(t, h.srcCluster.rootSession, fmt.Sprintf(ksStmt, ks2, 1)) + ExecStmt(t, h.dstCluster.rootSession, fmt.Sprintf(ksStmt, ks2, 1)) + + Print("Table setup") + tab := randomizedName("tab_") + createTable(t, h.srcCluster.rootSession, ks1, tab) + createTable(t, h.dstCluster.rootSession, ks1, tab) + tab1 := randomizedName("tab_1_") + createTable(t, h.srcCluster.rootSession, ks2, tab1) + createTable(t, h.dstCluster.rootSession, ks2, tab1) + tab2 := randomizedName("tab_2_") + createTable(t, h.srcCluster.rootSession, ks2, tab2) + createTable(t, h.dstCluster.rootSession, ks2, tab2) + + Print("View setup") + mv := randomizedName("mv_") + CreateMaterializedView(t, h.srcCluster.rootSession, ks1, tab, mv) + CreateMaterializedView(t, h.dstCluster.rootSession, ks1, tab, mv) + si := randomizedName("si_") + CreateSecondaryIndex(t, h.srcCluster.rootSession, ks1, tab, si) + CreateSecondaryIndex(t, h.dstCluster.rootSession, ks1, tab, si) + mv1 := randomizedName("mv_1_") + CreateMaterializedView(t, h.srcCluster.rootSession, ks2, tab1, mv1) + CreateMaterializedView(t, h.dstCluster.rootSession, ks2, tab1, mv1) + + Print("Fill setup") + fillTable(t, h.srcCluster.rootSession, 100, ks1, tab) + fillTable(t, h.srcCluster.rootSession, 100, ks2, tab1, tab2) + + units := []backup.Unit{ + { + Keyspace: ks1, + Tables: []string{tab, mv, si + "_index"}, + AllTables: true, + }, + { + Keyspace: ks2, + Tables: []string{tab1, tab2, mv1}, + AllTables: true, + }, + } + + Print("Run backup") + loc := []Location{testLocation("paused", "")} + S3InitBucket(t, loc[0].Path) + + // Starting from SM 3.3.1, SM does not allow to back up views, + // but backed up views should still be tested as older backups might + // contain them. That's why here we manually force backup target + // to contain the views. 
+	ctx := context.Background()
+	h.srcCluster.TaskID = uuid.NewTime()
+	h.srcCluster.RunID = uuid.NewTime()
+
+	rawProps, err := json.Marshal(map[string]any{"location": loc})
+	if err != nil {
+		t.Fatal(errors.Wrap(err, "marshal properties"))
+	}
+
+	target, err := h.srcBackupSvc.GetTarget(ctx, h.srcCluster.ClusterID, rawProps)
+	if err != nil {
+		t.Fatal(errors.Wrap(err, "generate target"))
+	}
+	target.Units = units
+
+	err = h.srcBackupSvc.Backup(ctx, h.srcCluster.ClusterID, h.srcCluster.TaskID, h.srcCluster.RunID, target)
+	if err != nil {
+		t.Fatal(errors.Wrap(err, "run backup"))
+	}
+
+	pr, err := h.srcBackupSvc.GetProgress(ctx, h.srcCluster.ClusterID, h.srcCluster.TaskID, h.srcCluster.RunID)
+	if err != nil {
+		t.Fatal(errors.Wrap(err, "get progress"))
+	}
+	tag := pr.SnapshotTag
+
+	Print("Run restore tables")
+	grantRestoreTablesPermissions(t, h.dstCluster.rootSession, []string{ks1, ks2}, h.dstUser)
+	props := map[string]any{
+		"location":       loc,
+		"keyspace":       []string{ks1, ks2},
+		"snapshot_tag":   tag,
+		"restore_tables": true,
+	}
+	err = runPausedRestore(t, func(ctx context.Context) error {
+		h.dstCluster.RunID = uuid.NewTime()
+		rawProps, err := json.Marshal(props)
+		if err != nil {
+			return err
+		}
+		return h.dstRestoreSvc.Restore(ctx, h.dstCluster.ClusterID, h.dstCluster.TaskID, h.dstCluster.RunID, rawProps)
+	}, 45*time.Second, 2)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	for _, u := range units {
+		for _, tb := range u.Tables {
+			validateTableContent[int, int](t, h.srcCluster.rootSession, h.dstCluster.rootSession, u.Keyspace, tb, "id", "data")
+		}
+	}
+}