diff --git a/go/cmd/vtctldclient/command/vreplication/vdiff/vdiff.go b/go/cmd/vtctldclient/command/vreplication/vdiff/vdiff.go index 6ad5abd818f..b2aefeee03b 100644 --- a/go/cmd/vtctldclient/command/vreplication/vdiff/vdiff.go +++ b/go/cmd/vtctldclient/command/vreplication/vdiff/vdiff.go @@ -198,7 +198,7 @@ vtctldclient --server localhost:15999 vdiff --workflow commerce2customer --targe for _, shard := range resumeOptions.TargetShards { if !key.IsValidKeyRange(shard) { - return fmt.Errorf("invalid target shard provided: %s", shard) + return fmt.Errorf("invalid target shard provided: %q", shard) } } @@ -250,7 +250,7 @@ vtctldclient --server localhost:15999 vdiff --workflow commerce2customer --targe for _, shard := range stopOptions.TargetShards { if !key.IsValidKeyRange(shard) { - return fmt.Errorf("invalid target shard provided: %s", shard) + return fmt.Errorf("invalid target shard provided: %q", shard) } } diff --git a/go/vt/vtctl/workflow/server.go b/go/vt/vtctl/workflow/server.go index dc45310494e..afffc0dba63 100644 --- a/go/vt/vtctl/workflow/server.go +++ b/go/vt/vtctl/workflow/server.go @@ -1927,10 +1927,8 @@ func (s *Server) VDiffResume(ctx context.Context, req *vtctldatapb.VDiffResumeRe } if len(targetShards) > 0 { - for key, target := range ts.targets { - if !slices.Contains(targetShards, target.GetShard().ShardName()) { - delete(ts.targets, key) - } + if err := applyTargetShards(ts, targetShards); err != nil { + return nil, err } } @@ -2012,10 +2010,8 @@ func (s *Server) VDiffStop(ctx context.Context, req *vtctldatapb.VDiffStopReques } if len(targetShards) > 0 { - for key, target := range ts.targets { - if !slices.Contains(targetShards, target.GetShard().ShardName()) { - delete(ts.targets, key) - } + if err := applyTargetShards(ts, targetShards); err != nil { + return nil, err } } diff --git a/go/vt/vtctl/workflow/utils.go b/go/vt/vtctl/workflow/utils.go index 374d96396f2..cd2ee6ba69f 100644 --- a/go/vt/vtctl/workflow/utils.go +++ b/go/vt/vtctl/workflow/utils.go @@ -967,3 +967,34 @@ func defaultErrorHandler(logger logutil.Logger, message string, err error) (*[]s logger.Error(werr) return nil, werr } + +// applyTargetShards applies the targetShards, coming from a command, to the trafficSwitcher. +// It will return an error if the targetShards list contains a shard that does not exist in +// the target keyspace. +// It will then remove any target shards from the trafficSwitcher that are not in the +// targetShards list. +func applyTargetShards(ts *trafficSwitcher, targetShards []string) error { + if ts == nil { + return nil + } + if ts.targets == nil { + return vterrors.Errorf(vtrpcpb.Code_FAILED_PRECONDITION, "no targets found for workflow %s", ts.workflow) + } + tsm := make(map[string]struct{}, len(targetShards)) + for _, targetShard := range targetShards { + if _, ok := ts.targets[targetShard]; !ok { + return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "specified target shard %s not a valid target for workflow %s", + targetShard, ts.workflow) + } + tsm[targetShard] = struct{}{} + } + for key, target := range ts.targets { + if target == nil || target.GetShard() == nil { // Should never happen + return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid target found for workflow %s", ts.workflow) + } + if _, ok := tsm[target.GetShard().ShardName()]; !ok { + delete(ts.targets, key) + } + } + return nil +}