diff --git a/pkg/service/backup/backupspec/location.go b/pkg/service/backup/backupspec/location.go index fe10d9a38..5586a0e41 100644 --- a/pkg/service/backup/backupspec/location.go +++ b/pkg/service/backup/backupspec/location.go @@ -102,13 +102,19 @@ func NewLocation(location string) (l Location, err error) { } func (l Location) String() string { - p := l.Provider.String() + ":" + l.Path + p := l.StringWithoutDC() if l.DC != "" { p = l.DC + ":" + p } return p } +// StringWithoutDC returns Location string representation +// that lacks DC information. +func (l Location) StringWithoutDC() string { + return l.Provider.String() + ":" + l.Path +} + // Datacenter returns location's datacenter. func (l Location) Datacenter() string { return l.DC diff --git a/pkg/service/restore/batch.go b/pkg/service/restore/batch.go index 3af2469d4..d7f8ee7c9 100644 --- a/pkg/service/restore/batch.go +++ b/pkg/service/restore/batch.go @@ -3,6 +3,7 @@ package restore import ( + "context" "slices" "sync" @@ -10,21 +11,54 @@ import ( . "github.com/scylladb/scylla-manager/v3/pkg/service/backup/backupspec" ) +// batchDispatcher is a tool for batching SSTables from +// Workload across different hosts during restore. +// It follows a few rules: +// +// - it dispatches batches from the RemoteDirWorkload with the biggest +// initial size first +// +// - it aims to optimize batch size according to the batchSize param +// +// - it selects the biggest SSTables from RemoteDirWorkload first, +// so that a batch contains SSTables of similar size (improved shard utilization) +// +// - it supports batch retry - a failed batch can be retried by other +// hosts (see wait description for more information) +// +// - it supports host retry - a host that failed to restore a batch can still +// restore other batches (see hostFailedDC description for more information). type batchDispatcher struct { - mu sync.Mutex - workload []LocationWorkload - batchSize int + // Guards all exported methods + mu sync.Mutex + // When there are no more batches to be restored, + // but some already dispatched batches are still + // being processed, idle hosts wait on the wait chan. + // They should wait, as in case a currently processed + // batch fails to be restored, they can be woken up + // by batchDispatcher, and retry restoring the returned + // batch on their own. + wait chan struct{} + + // Const workload defined during indexing + workload Workload + // Mutable workloadProgress updated as batches are dispatched + workloadProgress workloadProgress + // For batchSize X, batches contain X*node_shard_cnt SSTables. + // We always multiply batchSize by node_shard_cnt in order to + // utilize all shards more equally. + // For batchSize 0, batches contain N*node_shard_cnt SSTables + // of total size up to 5% of node expected workload + // (expectedShardWorkload*node_shard_cnt). 
+ batchSize int + // Equals total_backup_size/($\sum_{node} shard_cnt(node)$) expectedShardWorkload int64 - hostShardCnt map[string]uint - locationHosts map[Location][]string + // Stores host shard count + hostShardCnt map[string]uint } -func newBatchDispatcher(workload []LocationWorkload, batchSize int, hostShardCnt map[string]uint, locationHosts map[Location][]string) *batchDispatcher { - sortWorkloadBySizeDesc(workload) - var size int64 - for _, t := range workload { - size += t.Size - } +func newBatchDispatcher(workload Workload, batchSize int, hostShardCnt map[string]uint, locationHosts map[Location][]string) *batchDispatcher { + sortWorkload(workload) var shards uint for _, sh := range hostShardCnt { shards += sh @@ -34,14 +68,84 @@ func newBatchDispatcher(workload []LocationWorkload, batchSize int, hostShardCnt } return &batchDispatcher{ mu: sync.Mutex{}, + wait: make(chan struct{}), workload: workload, + workloadProgress: newWorkloadProgress(workload, locationHosts), batchSize: batchSize, - expectedShardWorkload: size / int64(shards), + expectedShardWorkload: workload.TotalSize / int64(shards), hostShardCnt: hostShardCnt, - locationHosts: locationHosts, } } +// Describes current state of SSTables that are yet to be batched. +type workloadProgress struct { + // Bytes that are yet to be restored from given backed up DC. + // They are decreased after a successful batch restoration. + dcBytesToBeRestored map[string]int64 + // Marks which host failed to restore batches from which DCs. + // When host failed to restore a batch from one backed up DC, + // it can still restore other batches coming from different + // DCs. This is a host re-try mechanism aiming to help with #3871. + hostFailedDC map[string][]string + // Stores which hosts have access to restore which DCs. + // It assumes that the whole DC is backed up to a single + // backup location. + hostDCAccess map[string][]string + // SSTables grouped by RemoteSSTableDir that are yet to + // be batched. They are removed on batch dispatch, but can + // be re-added when batch failed to be restored. + // workloadProgress.remoteDir and Workload.RemoteDir have + // corresponding indexes. + remoteDir []remoteSSTableDirProgress +} + +// Describes current state of SSTables from given RemoteSSTableDir +// that are yet to be batched. +type remoteSSTableDirProgress struct { + RemainingSize int64 + RemainingSSTables []RemoteSSTable +} + +func newWorkloadProgress(workload Workload, locationHosts map[Location][]string) workloadProgress { + dcBytes := make(map[string]int64) + locationDC := make(map[string][]string) + p := make([]remoteSSTableDirProgress, len(workload.RemoteDir)) + for i, rdw := range workload.RemoteDir { + dcBytes[rdw.DC] += rdw.Size + locationDC[rdw.Location.StringWithoutDC()] = append(locationDC[rdw.Location.StringWithoutDC()], rdw.DC) + p[i] = remoteSSTableDirProgress{ + RemainingSize: rdw.Size, + RemainingSSTables: rdw.SSTables, + } + } + hostDCAccess := make(map[string][]string) + for loc, hosts := range locationHosts { + for _, h := range hosts { + hostDCAccess[h] = append(hostDCAccess[h], locationDC[loc.StringWithoutDC()]...) + } + } + return workloadProgress{ + dcBytesToBeRestored: dcBytes, + hostFailedDC: make(map[string][]string), + hostDCAccess: hostDCAccess, + remoteDir: p, + } +} + +// Checks if given host finished restoring all that it could. 
+func (wp workloadProgress) isDone(host string) bool { + failed := wp.hostFailedDC[host] + for _, dc := range wp.hostDCAccess[host] { + // Host isn't done when there is still some data to be restored + // from a DC that it has access to, and it didn't previously fail + // to restore data from this DC. + if !slices.Contains(failed, dc) && wp.dcBytesToBeRestored[dc] != 0 { + return false + } + } + return true +} + type batch struct { TableName *ManifestInfo @@ -89,107 +193,106 @@ func (b batch) IDs() []string { return ids } -// ValidateAllDispatched returns error if not all sstables were dispatched. -func (b *batchDispatcher) ValidateAllDispatched() error { - for _, lw := range b.workload { - if lw.Size != 0 { - for _, tw := range lw.Tables { - if tw.Size != 0 { - for _, dw := range tw.RemoteDirs { - if dw.Size != 0 || len(dw.SSTables) != 0 { - return errors.Errorf("expected all data to be restored, missing sstable ids from location %s table %s.%s: %v (%d bytes)", - dw.Location, dw.Keyspace, dw.Table, dw.SSTables, dw.Size) - } - } - return errors.Errorf("expected all data to be restored, missinng table from location %s: %s.%s (%d bytes)", - tw.Location, tw.Keyspace, tw.Table, tw.Size) - } - } - return errors.Errorf("expected all data to be restored, missinng location: %s (%d bytes)", - lw.Location, lw.Size) - } - } - return nil -} - -// DispatchBatch batch to be restored or false when there is no more work to do. -func (b *batchDispatcher) DispatchBatch(host string) (batch, bool) { - b.mu.Lock() - defer b.mu.Unlock() +// ValidateAllDispatched returns error if not all SSTables were dispatched. +func (bd *batchDispatcher) ValidateAllDispatched() error { + bd.mu.Lock() + defer bd.mu.Unlock() - l := b.chooseLocation(host) - if l == nil { - return batch{}, false - } - t := b.chooseTable(l) - if t == nil { - return batch{}, false + for i, rdp := range bd.workloadProgress.remoteDir { + if rdp.RemainingSize != 0 || len(rdp.RemainingSSTables) != 0 { + rdw := bd.workload.RemoteDir[i] + return errors.Errorf("failed to restore sstables from location %s table %s.%s (%d bytes). See logs for more info", + rdw.Location, rdw.Keyspace, rdw.Table, rdw.Size) + } } - dir := b.chooseRemoteDir(t) - if dir == nil { - return batch{}, false + for dc, bytes := range bd.workloadProgress.dcBytesToBeRestored { + if bytes != 0 { + return errors.Errorf("expected all data from DC %q to be restored (missing %d bytes): "+ + "internal progress calculation error", dc, bytes) + } } - return b.createBatch(l, t, dir, host) + return nil } -// Returns location for which batch should be created. -func (b *batchDispatcher) chooseLocation(host string) *LocationWorkload { - for i := range b.workload { - if b.workload[i].Size == 0 { - continue +// DispatchBatch returns batch to be restored or false when there is no more work to do. +// This method might hang and wait for sstables that might come from batches that +// failed to be restored (see batchDispatcher.wait description for more information). +// Because of that, it's important to call ReportSuccess or ReportFailure after +// each dispatched batch was attempted to be restored. 
+func (bd *batchDispatcher) DispatchBatch(ctx context.Context, host string) (batch, bool) { + for { + if ctx.Err() != nil { + return batch{}, false } - if slices.Contains(b.locationHosts[b.workload[i].Location], host) { - return &b.workload[i] + bd.mu.Lock() + // Check if there is anything to do for this host + if bd.workloadProgress.isDone(host) { + bd.mu.Unlock() + return batch{}, false + } + // Try to dispatch batch + b, ok := bd.dispatchBatch(host) + wait := bd.wait + bd.mu.Unlock() + if ok { + return b, true + } + // Wait for SSTables that might return after failure + select { + case <-ctx.Done(): + case <-wait: + } } - return nil } -// Returns table for which batch should be created. -func (b *batchDispatcher) chooseTable(location *LocationWorkload) *TableWorkload { - for i := range location.Tables { - if location.Tables[i].Size == 0 { +func (bd *batchDispatcher) dispatchBatch(host string) (batch, bool) { + dirIdx := -1 + for i := range bd.workloadProgress.remoteDir { + rdw := bd.workload.RemoteDir[i] + // Skip empty dir + if bd.workloadProgress.remoteDir[i].RemainingSize == 0 { continue } - return &location.Tables[i] - } - return nil -} - -// Return remote dir for which batch should be created. -func (b *batchDispatcher) chooseRemoteDir(table *TableWorkload) *RemoteDirWorkload { - for i := range table.RemoteDirs { - if table.RemoteDirs[i].Size == 0 { + // Skip dir from already failed DC + if slices.Contains(bd.workloadProgress.hostFailedDC[host], rdw.DC) { + continue + } + // Skip dir from location without access + if !slices.Contains(bd.workloadProgress.hostDCAccess[host], rdw.DC) { continue } - return &table.RemoteDirs[i] + dirIdx = i + break } - return nil + if dirIdx < 0 { + return batch{}, false + } + return bd.createBatch(dirIdx, host) } -// Returns batch and updates RemoteDirWorkload and its parents. -func (b *batchDispatcher) createBatch(l *LocationWorkload, t *TableWorkload, dir *RemoteDirWorkload, host string) (batch, bool) { - shardCnt := b.hostShardCnt[host] +// Returns batch from given RemoteSSTableDir and updates workloadProgress. +func (bd *batchDispatcher) createBatch(dirIdx int, host string) (batch, bool) { + rdp := &bd.workloadProgress.remoteDir[dirIdx] + shardCnt := bd.hostShardCnt[host] if shardCnt == 0 { shardCnt = 1 } - var i int var size int64 - if b.batchSize == maxBatchSize { + if bd.batchSize == maxBatchSize { // Create batch containing multiple of node shard count sstables // and size up to 5% of expected node workload. - expectedNodeWorkload := b.expectedShardWorkload * int64(shardCnt) + expectedNodeWorkload := bd.expectedShardWorkload * int64(shardCnt) sizeLimit := expectedNodeWorkload / 20 for { for j := 0; j < int(shardCnt); j++ { - if i >= len(dir.SSTables) { + if i >= len(rdp.RemainingSSTables) { break } - size += dir.SSTables[i].Size + size += rdp.RemainingSSTables[i].Size i++ } - if i >= len(dir.SSTables) { + if i >= len(rdp.RemainingSSTables) { break } if size > sizeLimit { @@ -198,9 +301,9 @@ func (b *batchDispatcher) createBatch(l *LocationWorkload, t *TableWorkload, dir } } else { // Create batch containing node_shard_count*batch_size sstables. 
- i = min(b.batchSize*int(shardCnt), len(dir.SSTables)) + i = min(bd.batchSize*int(shardCnt), len(rdp.RemainingSSTables)) for j := 0; j < i; j++ { - size += dir.SSTables[j].Size + size += rdp.RemainingSSTables[j].Size } } @@ -209,44 +312,85 @@ func (b *batchDispatcher) createBatch(l *LocationWorkload, t *TableWorkload, dir } // Extend batch if it was to leave less than // 1 sstable per shard for the next one. - if len(dir.SSTables)-i < int(shardCnt) { - for ; i < len(dir.SSTables); i++ { - size += dir.SSTables[i].Size + if len(rdp.RemainingSSTables)-i < int(shardCnt) { + for ; i < len(rdp.RemainingSSTables); i++ { + size += rdp.RemainingSSTables[i].Size } } - sstables := dir.SSTables[:i] - dir.SSTables = dir.SSTables[i:] + sstables := rdp.RemainingSSTables[:i] + rdp.RemainingSSTables = rdp.RemainingSSTables[i:] + rdw := bd.workload.RemoteDir[dirIdx] - dir.Size -= size - t.Size -= size - l.Size -= size + rdp.RemainingSize -= size return batch{ - TableName: dir.TableName, - ManifestInfo: dir.ManifestInfo, - RemoteSSTableDir: dir.RemoteSSTableDir, + TableName: rdw.TableName, + ManifestInfo: rdw.ManifestInfo, + RemoteSSTableDir: rdw.RemoteSSTableDir, Size: size, SSTables: sstables, }, true } -func sortWorkloadBySizeDesc(workload []LocationWorkload) { - slices.SortFunc(workload, func(a, b LocationWorkload) int { +// ReportSuccess notifies batchDispatcher that given batch was restored successfully. +func (bd *batchDispatcher) ReportSuccess(b batch) { + bd.mu.Lock() + defer bd.mu.Unlock() + + dcBytes := bd.workloadProgress.dcBytesToBeRestored + dcBytes[b.DC] -= b.Size + // Mark batching as finished due to successful restore + if dcBytes[b.DC] == 0 { + bd.wakeUpWaiting() + } +} + +// ReportFailure notifies batchDispatcher that given batch failed to be restored. +func (bd *batchDispatcher) ReportFailure(host string, b batch) error { + bd.mu.Lock() + defer bd.mu.Unlock() + + // Mark failed DC for host + bd.workloadProgress.hostFailedDC[host] = append(bd.workloadProgress.hostFailedDC[host], b.DC) + + dirIdx := -1 + for i := range bd.workload.RemoteDir { + if bd.workload.RemoteDir[i].RemoteSSTableDir == b.RemoteSSTableDir { + dirIdx = i + break + } + } + if dirIdx < 0 { + return errors.Errorf("unknown remote sstable dir %s", b.RemoteSSTableDir) + } + + rdp := &bd.workloadProgress.remoteDir[dirIdx] + rdp.RemainingSSTables = append(b.SSTables, rdp.RemainingSSTables...) + rdp.RemainingSize += b.Size + + bd.wakeUpWaiting() + return nil +} + +func (bd *batchDispatcher) wakeUpWaiting() { + close(bd.wait) + bd.wait = make(chan struct{}) +} + +func sortWorkload(workload Workload) { + // Order remote sstable dirs by table size, then by their size (decreasing). 
+ slices.SortFunc(workload.RemoteDir, func(a, b RemoteDirWorkload) int { + ats := workload.TableSize[a.TableName] + bts := workload.TableSize[b.TableName] + if ats != bts { + return int(bts - ats) + } return int(b.Size - a.Size) }) - for _, loc := range workload { - slices.SortFunc(loc.Tables, func(a, b TableWorkload) int { + // Order sstables by their size (decreasing) + for _, rdw := range workload.RemoteDir { + slices.SortFunc(rdw.SSTables, func(a, b RemoteSSTable) int { return int(b.Size - a.Size) }) - for _, tab := range loc.Tables { - slices.SortFunc(tab.RemoteDirs, func(a, b RemoteDirWorkload) int { - return int(b.Size - a.Size) - }) - for _, dir := range tab.RemoteDirs { - slices.SortFunc(dir.SSTables, func(a, b RemoteSSTable) int { - return int(b.Size - a.Size) - }) - } - } } } diff --git a/pkg/service/restore/batch_test.go b/pkg/service/restore/batch_test.go index cf87a8a51..9f206716e 100644 --- a/pkg/service/restore/batch_test.go +++ b/pkg/service/restore/batch_test.go @@ -17,76 +17,93 @@ func TestBatchDispatcher(t *testing.T) { Provider: "s3", Path: "l2", } - workload := []LocationWorkload{ + + rawWorkload := []RemoteDirWorkload{ + { + ManifestInfo: &backupspec.ManifestInfo{ + Location: l1, + DC: "dc1", + }, + TableName: TableName{ + Keyspace: "ks1", + Table: "t1", + }, + RemoteSSTableDir: "a", + Size: 20, + SSTables: []RemoteSSTable{ + {Size: 5}, + {Size: 15}, + }, + }, { - Location: l1, - Size: 170, - Tables: []TableWorkload{ - { - Size: 60, - RemoteDirs: []RemoteDirWorkload{ - { - RemoteSSTableDir: "a", - Size: 20, - SSTables: []RemoteSSTable{ - {Size: 5}, - {Size: 15}, - }, - }, - { - RemoteSSTableDir: "e", - Size: 10, - SSTables: []RemoteSSTable{ - {Size: 2}, - {Size: 4}, - {Size: 4}, - }, - }, - { - RemoteSSTableDir: "b", - Size: 30, - SSTables: []RemoteSSTable{ - {Size: 10}, - {Size: 20}, - }, - }, - }, - }, - { - Size: 110, - RemoteDirs: []RemoteDirWorkload{ - { - RemoteSSTableDir: "c", - Size: 110, - SSTables: []RemoteSSTable{ - {Size: 50}, - {Size: 60}, - }, - }, - }, - }, + ManifestInfo: &backupspec.ManifestInfo{ + Location: l1, + DC: "dc1", + }, + TableName: TableName{ + Keyspace: "ks1", + Table: "t1", + }, + RemoteSSTableDir: "e", + Size: 10, + SSTables: []RemoteSSTable{ + {Size: 2}, + {Size: 4}, + {Size: 4}, }, }, { - Location: l2, - Size: 200, - Tables: []TableWorkload{ - { - Size: 200, - RemoteDirs: []RemoteDirWorkload{ - { - RemoteSSTableDir: "d", - Size: 200, - SSTables: []RemoteSSTable{ - {Size: 110}, - {Size: 90}, - }, - }, - }, - }, + ManifestInfo: &backupspec.ManifestInfo{ + Location: l1, + DC: "dc2", + }, + TableName: TableName{ + Keyspace: "ks1", + Table: "t1", + }, + RemoteSSTableDir: "b", + Size: 30, + SSTables: []RemoteSSTable{ + {Size: 10}, + {Size: 20}, + }, + }, + { + ManifestInfo: &backupspec.ManifestInfo{ + Location: l1, + DC: "dc1", + }, + TableName: TableName{ + Keyspace: "ks1", + Table: "t2", + }, + RemoteSSTableDir: "c", + Size: 110, + SSTables: []RemoteSSTable{ + {Size: 50}, + {Size: 60}, + }, + }, + { + ManifestInfo: &backupspec.ManifestInfo{ + Location: l2, + DC: "dc3", + }, + TableName: TableName{ + Keyspace: "ks1", + Table: "t2", + }, + RemoteSSTableDir: "d", + Size: 200, + SSTables: []RemoteSSTable{ + {Size: 110}, + {Size: 90}, }, }, } + + workload := aggregateWorkload(rawWorkload) + locationHosts := map[backupspec.Location][]string{ l1: {"h1", "h2"}, l2: {"h3"}, @@ -105,34 +122,43 @@ func TestBatchDispatcher(t *testing.T) { dir string size int64 count int + err bool }{ {host: "h1", ok: true, dir: "c", size: 60, count: 1}, - {host: 
"h1", ok: true, dir: "c", size: 50, count: 1}, + {host: "h1", ok: true, dir: "c", size: 50, count: 1, err: true}, + {host: "h1", ok: true, dir: "b", size: 20, count: 1}, // host retry in different dc + {host: "h2", ok: true, dir: "c", size: 50, count: 1}, // batch retry + {host: "h1", ok: true, dir: "b", size: 10, count: 1, err: true}, + {host: "h1"}, // already failed in all dcs + {host: "h2", ok: true, dir: "b", size: 10, count: 1}, // batch retry {host: "h2", ok: true, dir: "b", size: 30, count: 2}, {host: "h3", ok: true, dir: "d", size: 200, count: 2}, - {host: "h3", ok: false}, + {host: "h3"}, {host: "h2", ok: true, dir: "a", size: 20, count: 2}, {host: "h2", ok: true, dir: "e", size: 10, count: 3}, // batch extended with leftovers < shard_cnt - {host: "h1", ok: false}, - {host: "h2", ok: false}, + {host: "h1"}, + {host: "h2"}, } for _, step := range scenario { - b, ok := bd.DispatchBatch(step.host) + // use dispatchBatch instead of DispatchBatch because + // we don't want to hang here. + b, ok := bd.dispatchBatch(step.host) if ok != step.ok { - t.Fatalf("Step: %+v, expected ok=%v, got ok=%v", step, step.ok, ok) + t.Errorf("Expected %v, got %#v", step, b) } if ok == false { - continue - } - if b.RemoteSSTableDir != step.dir { - t.Fatalf("Step: %+v, expected dir=%v, got dir=%v", step, step.dir, b.RemoteSSTableDir) + return } - if b.Size != step.size { - t.Fatalf("Step: %+v, expected size=%v, got size=%v", step, step.size, b.Size) + if b.RemoteSSTableDir != step.dir || b.Size != step.size || len(b.SSTables) != step.count { + t.Errorf("Expected %v, got %#v", step, b) } - if len(b.SSTables) != step.count { - t.Fatalf("Step: %+v, expected count=%v, got count=%v", step, step.count, len(b.SSTables)) + if step.err { + if err := bd.ReportFailure(step.host, b); err != nil { + t.Fatal(err) + } + } else { + bd.ReportSuccess(b) } } diff --git a/pkg/service/restore/index.go b/pkg/service/restore/index.go index dd7b7b72c..32f91dee9 100644 --- a/pkg/service/restore/index.go +++ b/pkg/service/restore/index.go @@ -12,23 +12,12 @@ import ( "github.com/scylladb/scylla-manager/v3/pkg/sstable" ) -// LocationWorkload represents aggregated restore workload -// in given backup location. -type LocationWorkload struct { - Location - - Size int64 - Tables []TableWorkload -} - -// TableWorkload represents restore workload -// from many manifests for given table in given backup location. -type TableWorkload struct { - Location - TableName - - Size int64 - RemoteDirs []RemoteDirWorkload +// Workload represents total restore workload. +type Workload struct { + TotalSize int64 + LocationSize map[Location]int64 + TableSize map[TableName]int64 + RemoteDir []RemoteDirWorkload } // RemoteDirWorkload represents restore workload @@ -56,32 +45,32 @@ type SSTable struct { } // IndexWorkload returns sstables to be restored aggregated by location, table and remote sstable dir. -func (w *tablesWorker) IndexWorkload(ctx context.Context, locations []Location) ([]LocationWorkload, error) { - var workload []LocationWorkload +func (w *tablesWorker) IndexWorkload(ctx context.Context, locations []Location) (Workload, error) { + var rawWorkload []RemoteDirWorkload for _, l := range locations { lw, err := w.indexLocationWorkload(ctx, l) if err != nil { - return nil, errors.Wrapf(err, "index workload in %s", l) + return Workload{}, errors.Wrapf(err, "index workload in %s", l) } - workload = append(workload, lw) + rawWorkload = append(rawWorkload, lw...) 
} + workload := aggregateWorkload(rawWorkload) + w.logWorkloadInfo(ctx, workload) return workload, nil } -func (w *tablesWorker) indexLocationWorkload(ctx context.Context, location Location) (LocationWorkload, error) { +func (w *tablesWorker) indexLocationWorkload(ctx context.Context, location Location) ([]RemoteDirWorkload, error) { rawWorkload, err := w.createRemoteDirWorkloads(ctx, location) if err != nil { - return LocationWorkload{}, errors.Wrap(err, "create remote dir workloads") + return nil, errors.Wrap(err, "create remote dir workloads") } if w.target.Continue { rawWorkload, err = w.filterPreviouslyRestoredSStables(ctx, rawWorkload) if err != nil { - return LocationWorkload{}, errors.Wrap(err, "filter already restored sstables") + return nil, errors.Wrap(err, "filter already restored sstables") } } - workload := aggregateLocationWorkload(rawWorkload) - w.logWorkloadInfo(ctx, workload) - return workload, nil + return rawWorkload, nil } func (w *tablesWorker) createRemoteDirWorkloads(ctx context.Context, location Location) ([]RemoteDirWorkload, error) { @@ -179,26 +168,22 @@ func (w *tablesWorker) filterPreviouslyRestoredSStables(ctx context.Context, raw return filtered, nil } -func (w *tablesWorker) initMetrics(workload []LocationWorkload) { +func (w *tablesWorker) initMetrics(workload Workload) { // For now, the only persistent across task runs metrics are progress and remaining_bytes. // The rest: state, view_build_status, batch_size are calculated from scratch. w.metrics.ResetClusterMetrics(w.run.ClusterID) // Init remaining bytes - for _, wl := range workload { - for _, twl := range wl.Tables { - for _, rdwl := range twl.RemoteDirs { - w.metrics.SetRemainingBytes(metrics.RestoreBytesLabels{ - ClusterID: rdwl.ClusterID.String(), - SnapshotTag: rdwl.SnapshotTag, - Location: rdwl.Location.String(), - DC: rdwl.DC, - Node: rdwl.NodeID, - Keyspace: rdwl.Keyspace, - Table: rdwl.Table, - }, rdwl.Size) - } - } + for _, rdw := range workload.RemoteDir { + w.metrics.SetRemainingBytes(metrics.RestoreBytesLabels{ + ClusterID: rdw.ClusterID.String(), + SnapshotTag: rdw.SnapshotTag, + Location: rdw.Location.String(), + DC: rdw.DC, + Node: rdw.NodeID, + Keyspace: rdw.Keyspace, + Table: rdw.Table, + }, rdw.Size) } // Init progress @@ -206,87 +191,59 @@ func (w *tablesWorker) initMetrics(workload []LocationWorkload) { for _, u := range w.run.Units { totalSize += u.Size } - var workloadSize int64 - for _, wl := range workload { - workloadSize += wl.Size - } w.metrics.SetProgress(metrics.RestoreProgressLabels{ ClusterID: w.run.ClusterID.String(), SnapshotTag: w.run.SnapshotTag, - }, float64(totalSize-workloadSize)/float64(totalSize)*100) + }, float64(totalSize-workload.TotalSize)/float64(totalSize)*100) } -func (w *tablesWorker) logWorkloadInfo(ctx context.Context, workload LocationWorkload) { - if workload.Size == 0 { - return +func (w *tablesWorker) logWorkloadInfo(ctx context.Context, workload Workload) { + for loc, size := range workload.LocationSize { + w.logger.Info(ctx, "Location workload", + "location", loc, + "size", size) + } + for tab, size := range workload.TableSize { + w.logger.Info(ctx, "Table workload", + "table", tab, + "size", size) } - var locMax, locCnt int64 - for _, twl := range workload.Tables { - if twl.Size == 0 { + for _, rdw := range workload.RemoteDir { + cnt := int64(len(rdw.SSTables)) + if cnt == 0 { + w.logger.Info(ctx, "Empty remote dir workload", "path", rdw.RemoteSSTableDir) continue } - var tabMax, tabCnt int64 - for _, rdwl := range twl.RemoteDirs { - if 
rdwl.Size == 0 { - continue - } - var dirMax int64 - for _, sst := range rdwl.SSTables { - dirMax = max(dirMax, sst.Size) - } - dirCnt := int64(len(rdwl.SSTables)) - w.logger.Info(ctx, "Remote sstable dir workload info", - "path", rdwl.RemoteSSTableDir, - "max size", dirMax, - "average size", rdwl.Size/dirCnt, - "count", dirCnt) - tabCnt += dirCnt - tabMax = max(tabMax, dirMax) - } - w.logger.Info(ctx, "Table workload info", - "keyspace", twl.Keyspace, - "table", twl.Table, - "max size", tabMax, - "average size", twl.Size/tabCnt, - "count", tabCnt) - locCnt += tabCnt - locMax = max(locMax, tabMax) - } - w.logger.Info(ctx, "Location workload info", - "location", workload.Location.String(), - "max size", locMax, - "average size", workload.Size/locCnt, - "count", locCnt) -} -func aggregateLocationWorkload(rawWorkload []RemoteDirWorkload) LocationWorkload { - remoteDirWorkloads := make(map[TableName][]RemoteDirWorkload) - for _, rw := range rawWorkload { - remoteDirWorkloads[rw.TableName] = append(remoteDirWorkloads[rw.TableName], rw) - } - - var tableWorkloads []TableWorkload - for _, tw := range remoteDirWorkloads { - var size int64 - for _, rdw := range tw { - size += rdw.Size + var maxSST int64 + for _, sst := range rdw.SSTables { + maxSST = max(maxSST, sst.Size) } - tableWorkloads = append(tableWorkloads, TableWorkload{ - Location: tw[0].Location, - TableName: tw[0].TableName, - Size: size, - RemoteDirs: tw, - }) + w.logger.Info(ctx, "Remote sstable dir workload info", + "path", rdw.RemoteSSTableDir, + "total size", rdw.Size, + "max size", maxSST, + "average size", rdw.Size/cnt, + "count", cnt) } +} - var size int64 - for _, tw := range tableWorkloads { - size += tw.Size +func aggregateWorkload(rawWorkload []RemoteDirWorkload) Workload { + var ( + totalSize int64 + locationSize = make(map[Location]int64) + tableSize = make(map[TableName]int64) + ) + for _, rdw := range rawWorkload { + totalSize += rdw.Size + locationSize[rdw.Location] += rdw.Size + tableSize[rdw.TableName] += rdw.Size } - return LocationWorkload{ - Location: tableWorkloads[0].Location, - Size: size, - Tables: tableWorkloads, + return Workload{ + TotalSize: totalSize, + LocationSize: locationSize, + TableSize: tableSize, + RemoteDir: rawWorkload, } } diff --git a/pkg/service/restore/model.go b/pkg/service/restore/model.go index da616b515..b8d4bf19d 100644 --- a/pkg/service/restore/model.go +++ b/pkg/service/restore/model.go @@ -296,6 +296,10 @@ type TableName struct { Table string } +func (t TableName) String() string { + return t.Keyspace + "." + t.Table +} + // HostInfo represents host with rclone download config. type HostInfo struct { Host string diff --git a/pkg/service/restore/restore_integration_test.go b/pkg/service/restore/restore_integration_test.go index 318377cd5..7503e1d76 100644 --- a/pkg/service/restore/restore_integration_test.go +++ b/pkg/service/restore/restore_integration_test.go @@ -9,6 +9,7 @@ import ( "context" "encoding/json" "fmt" + "maps" "net/http" "strings" "sync/atomic" @@ -16,6 +17,8 @@ import ( "time" "github.com/pkg/errors" + "github.com/scylladb/go-log" + "github.com/scylladb/scylla-manager/v3/pkg/scyllaclient" "github.com/scylladb/scylla-manager/v3/pkg/service/backup" . "github.com/scylladb/scylla-manager/v3/pkg/service/backup/backupspec" . 
"github.com/scylladb/scylla-manager/v3/pkg/testutils" @@ -25,6 +28,7 @@ import ( "github.com/scylladb/scylla-manager/v3/pkg/util/maputil" "github.com/scylladb/scylla-manager/v3/pkg/util/query" "github.com/scylladb/scylla-manager/v3/pkg/util/uuid" + "go.uber.org/zap/zapcore" ) func TestRestoreTablesUserIntegration(t *testing.T) { @@ -691,3 +695,266 @@ func TestRestoreTablesPreparationIntegration(t *testing.T) { Print("Validate table contents") h.validateIdenticalTables(t, []table{{ks: ks, tab: tab}}) } + +func TestRestoreTablesBatchRetryIntegration(t *testing.T) { + h := newTestHelper(t, ManagedSecondClusterHosts(), ManagedClusterHosts()) + // Ensure no built-in retries + clientCfg := scyllaclient.TestConfig(ManagedClusterHosts(), AgentAuthToken()) + clientCfg.Backoff.MaxRetries = 0 + h.dstCluster.Client = newTestClient(t, h.dstCluster.Hrt, log.NewDevelopmentWithLevel(zapcore.InfoLevel).Named("client"), &clientCfg) + + Print("Keyspace setup") + ksStmt := "CREATE KEYSPACE %q WITH replication = {'class': 'NetworkTopologyStrategy', 'dc1': %d}" + ks := randomizedName("batch_retry_1_") + ExecStmt(t, h.srcCluster.rootSession, fmt.Sprintf(ksStmt, ks, 1)) + ExecStmt(t, h.dstCluster.rootSession, fmt.Sprintf(ksStmt, ks, 1)) + + Print("Table setup") + tabStmt := "CREATE TABLE %q.%q (id int PRIMARY KEY, data int)" + tab1 := randomizedName("tab_1_") + tab2 := randomizedName("tab_2_") + tab3 := randomizedName("tab_3_") + ExecStmt(t, h.srcCluster.rootSession, fmt.Sprintf(tabStmt, ks, tab1)) + ExecStmt(t, h.srcCluster.rootSession, fmt.Sprintf(tabStmt, ks, tab2)) + ExecStmt(t, h.srcCluster.rootSession, fmt.Sprintf(tabStmt, ks, tab3)) + ExecStmt(t, h.dstCluster.rootSession, fmt.Sprintf(tabStmt, ks, tab1)) + ExecStmt(t, h.dstCluster.rootSession, fmt.Sprintf(tabStmt, ks, tab2)) + ExecStmt(t, h.dstCluster.rootSession, fmt.Sprintf(tabStmt, ks, tab3)) + + Print("Fill setup") + fillTable(t, h.srcCluster.rootSession, 100, ks, tab1, tab2, tab3) + + Print("Run backup") + loc := []Location{testLocation("batch-retry", "")} + S3InitBucket(t, loc[0].Path) + ksFilter := []string{ks} + tag := h.runBackup(t, map[string]any{ + "location": loc, + "keyspace": ksFilter, + "batch_size": 100, + }) + + downloadErr := errors.New("fake download error") + lasErr := errors.New("fake las error") + props := map[string]any{ + "location": loc, + "keyspace": ksFilter, + "snapshot_tag": tag, + "restore_tables": true, + } + + t.Run("batch retry finished with success", func(t *testing.T) { + Print("Inject errors to some download and las calls") + downloadCnt := atomic.Int64{} + lasCnt := atomic.Int64{} + h.dstCluster.Hrt.SetInterceptor(httpx.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + // For this setup, we have 6 remote sstable dirs and 6 workers. + // We inject 2 errors during download and 3 errors during LAS. + // This means that only a single node will be restoring at the end. + // Huge batch size and 3 LAS errors guarantee total 9 calls to LAS. + // The last failed call to LAS (cnt=8) waits a bit so that we test + // that batch dispatcher correctly reuses and releases nodes waiting + // for failed sstables to come back to the batch dispatcher. 
+ if strings.HasPrefix(req.URL.Path, "/agent/rclone/sync/copypaths") { + if cnt := downloadCnt.Add(1); cnt == 1 || cnt == 3 { + t.Log("Fake download error ", cnt) + return nil, downloadErr + } + } + if strings.HasPrefix(req.URL.Path, "/storage_service/sstables/") { + cnt := lasCnt.Add(1) + if cnt == 8 { + time.Sleep(15 * time.Second) + } + if cnt == 1 || cnt == 5 || cnt == 8 { + t.Log("Fake LAS error ", cnt) + return nil, lasErr + } + } + return nil, nil + })) + + Print("Run restore") + grantRestoreTablesPermissions(t, h.dstCluster.rootSession, ksFilter, h.dstUser) + h.runRestore(t, props) + + Print("Validate success") + if cnt := lasCnt.Add(0); cnt < 9 { + t.Fatalf("Expected at least 9 calls to LAS, got %d", cnt) + } + validateTableContent[int, int](t, h.srcCluster.rootSession, h.dstCluster.rootSession, ks, tab1, "id", "data") + validateTableContent[int, int](t, h.srcCluster.rootSession, h.dstCluster.rootSession, ks, tab2, "id", "data") + validateTableContent[int, int](t, h.srcCluster.rootSession, h.dstCluster.rootSession, ks, tab3, "id", "data") + }) + + t.Run("restore with injected failures only", func(t *testing.T) { + Print("Inject errors to all download and las calls") + reachedDataStage := atomic.Bool{} + reachedDataStageChan := make(chan struct{}) + h.dstCluster.Hrt.SetInterceptor(httpx.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + if strings.HasPrefix(req.URL.Path, "/agent/rclone/sync/copypaths") { + if reachedDataStage.CompareAndSwap(false, true) { + close(reachedDataStageChan) + } + return nil, downloadErr + } + if strings.HasPrefix(req.URL.Path, "/storage_service/sstables/") { + return nil, lasErr + } + return nil, nil + })) + + Print("Run restore") + grantRestoreTablesPermissions(t, h.dstCluster.rootSession, ksFilter, h.dstUser) + h.dstCluster.TaskID = uuid.NewTime() + h.dstCluster.RunID = uuid.NewTime() + rawProps, err := json.Marshal(props) + if err != nil { + t.Fatal(errors.Wrap(err, "marshal properties")) + } + res := make(chan error) + go func() { + res <- h.dstRestoreSvc.Restore(context.Background(), h.dstCluster.ClusterID, h.dstCluster.TaskID, h.dstCluster.RunID, rawProps) + }() + + Print("Wait for data stage") + select { + case <-reachedDataStageChan: + case err := <-res: + t.Fatalf("Restore finished before reaching data stage with: %s", err) + } + + Print("Validate restore failure and that it does not hang") + select { + case err := <-res: + if err == nil { + t.Fatalf("Expected restore to end with error") + } + case <-time.NewTimer(time.Minute).C: + t.Fatal("Restore hanged") + } + }) + + t.Run("paused restore with slow calls to download and las", func(t *testing.T) { + Print("Make download and las calls slow") + reachedDataStage := atomic.Bool{} + reachedDataStageChan := make(chan struct{}) + h.dstCluster.Hrt.SetInterceptor(httpx.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + if strings.HasPrefix(req.URL.Path, "/agent/rclone/sync/copypaths") || + strings.HasPrefix(req.URL.Path, "/storage_service/sstables/") { + if reachedDataStage.CompareAndSwap(false, true) { + close(reachedDataStageChan) + } + time.Sleep(time.Second) + return nil, nil + } + return nil, nil + })) + + Print("Run restore") + grantRestoreTablesPermissions(t, h.dstCluster.rootSession, ksFilter, h.dstUser) + h.dstCluster.TaskID = uuid.NewTime() + h.dstCluster.RunID = uuid.NewTime() + rawProps, err := json.Marshal(props) + if err != nil { + t.Fatal(errors.Wrap(err, "marshal properties")) + } + ctx, cancel := context.WithCancel(context.Background()) + res := 
make(chan error) + go func() { + res <- h.dstRestoreSvc.Restore(ctx, h.dstCluster.ClusterID, h.dstCluster.TaskID, h.dstCluster.RunID, rawProps) + }() + + Print("Wait for data stage") + select { + case <-reachedDataStageChan: + cancel() + case err := <-res: + t.Fatalf("Restore finished before reaching data stage with: %s", err) + } + + Print("Validate restore was paused in time") + select { + case err := <-res: + if !errors.Is(err, context.Canceled) { + t.Fatalf("Expected restore to end with context cancelled, got %q", err) + } + case <-time.NewTimer(2 * time.Second).C: + t.Fatal("Restore wasn't paused in time") + } + }) +} + +func TestRestoreTablesMultiLocationIntegration(t *testing.T) { + // Since we need multi-dc clusters for multi-dc backup/restore + // we will use the same cluster as both src and dst. + h := newTestHelper(t, ManagedClusterHosts(), ManagedClusterHosts()) + + Print("Keyspace setup") + ksStmt := "CREATE KEYSPACE %q WITH replication = {'class': 'NetworkTopologyStrategy', 'dc1': 1, 'dc2': 1}" + ks := randomizedName("multi_location_") + ExecStmt(t, h.srcCluster.rootSession, fmt.Sprintf(ksStmt, ks)) + + Print("Table setup") + tabStmt := "CREATE TABLE %q.%q (id int PRIMARY KEY, data int)" + tab := randomizedName("tab_") + ExecStmt(t, h.srcCluster.rootSession, fmt.Sprintf(tabStmt, ks, tab)) + + Print("Fill setup") + fillTable(t, h.srcCluster.rootSession, 100, ks, tab) + + Print("Save filled table into map") + srcM := selectTableAsMap[int, int](t, h.srcCluster.rootSession, ks, tab, "id", "data") + + Print("Run backup") + loc := []Location{ + testLocation("multi-location-1", "dc1"), + testLocation("multi-location-2", "dc2"), + } + S3InitBucket(t, loc[0].Path) + S3InitBucket(t, loc[1].Path) + ksFilter := []string{ks} + tag := h.runBackup(t, map[string]any{ + "location": loc, + "keyspace": ksFilter, + "batch_size": 100, + }) + + Print("Truncate backed up table") + truncateStmt := "TRUNCATE TABLE %q.%q" + ExecStmt(t, h.srcCluster.rootSession, fmt.Sprintf(truncateStmt, ks, tab)) + + // Reverse dcs - just for fun + loc[0].DC = "dc2" + loc[1].DC = "dc1" + + Print("Run restore") + grantRestoreTablesPermissions(t, h.dstCluster.rootSession, ksFilter, h.dstUser) + res := make(chan struct{}) + go func() { + h.runRestore(t, map[string]any{ + "location": loc, + "keyspace": ksFilter, + // Test if batching does not hang with + // limited parallel and location access. 
+ "parallel": 1, + "snapshot_tag": tag, + "restore_tables": true, + }) + close(res) + }() + + select { + case <-res: + case <-time.NewTimer(2 * time.Minute).C: + t.Fatal("Restore hanged") + } + + Print("Save restored table into map") + dstM := selectTableAsMap[int, int](t, h.dstCluster.rootSession, ks, tab, "id", "data") + + Print("Validate success") + if !maps.Equal(srcM, dstM) { + t.Fatalf("tables have different contents\nsrc:\n%v\ndst:\n%v", srcM, dstM) + } +} diff --git a/pkg/service/restore/tables_worker.go b/pkg/service/restore/tables_worker.go index e35409f70..4ce143517 100644 --- a/pkg/service/restore/tables_worker.go +++ b/pkg/service/restore/tables_worker.go @@ -16,6 +16,7 @@ import ( "github.com/scylladb/scylla-manager/v3/pkg/util/parallel" "github.com/scylladb/scylla-manager/v3/pkg/util/query" "github.com/scylladb/scylla-manager/v3/pkg/util/uuid" + "go.uber.org/multierr" ) type tablesWorker struct { @@ -201,7 +202,7 @@ func (w *tablesWorker) stageRestoreData(ctx context.Context) error { bd := newBatchDispatcher(workload, w.target.BatchSize, hostToShard, w.target.locationHosts) - f := func(n int) (err error) { + f := func(n int) error { host := hosts[n] dc, err := w.client.HostDatacenter(ctx, host) if err != nil { @@ -210,8 +211,11 @@ func (w *tablesWorker) stageRestoreData(ctx context.Context) error { hi := w.hostInfo(host, dc, hostToShard[host]) w.logger.Info(ctx, "Host info", "host", hi.Host, "transfers", hi.Transfers, "rate limit", hi.RateLimit) for { + if ctx.Err() != nil { + return ctx.Err() + } // Download and stream in parallel - b, ok := bd.DispatchBatch(hi.Host) + b, ok := bd.DispatchBatch(ctx, hi.Host) if !ok { w.logger.Info(ctx, "No more batches to restore", "host", hi.Host) return nil @@ -227,11 +231,20 @@ func (w *tablesWorker) stageRestoreData(ctx context.Context) error { pr, err := w.newRunProgress(ctx, hi, b) if err != nil { - return errors.Wrap(err, "create new run progress") + err = multierr.Append(errors.Wrap(err, "create new run progress"), bd.ReportFailure(hi.Host, b)) + w.logger.Error(ctx, "Failed to create new run progress", + "host", hi.Host, + "error", err) + continue } if err := w.restoreBatch(ctx, b, pr); err != nil { - return errors.Wrap(err, "restore batch") + err = multierr.Append(errors.Wrap(err, "restore batch"), bd.ReportFailure(hi.Host, b)) + w.logger.Error(ctx, "Failed to restore batch", + "host", hi.Host, + "error", err) + continue } + bd.ReportSuccess(b) w.decreaseRemainingBytesMetric(b) } } @@ -245,6 +258,9 @@ func (w *tablesWorker) stageRestoreData(ctx context.Context) error { err = parallel.Run(len(hosts), w.target.Parallel, f, notify) if err == nil { + if ctx.Err() != nil { + return ctx.Err() + } return bd.ValidateAllDispatched() } return err