diff --git a/go/vt/vtorc/logic/tablet_discovery.go b/go/vt/vtorc/logic/tablet_discovery.go index 51fe5e7c2b9..169057d3c3d 100644 --- a/go/vt/vtorc/logic/tablet_discovery.go +++ b/go/vt/vtorc/logic/tablet_discovery.go @@ -303,17 +303,23 @@ func LockShard(ctx context.Context, instanceKey inst.InstanceKey) (context.Conte // tabletUndoDemotePrimary calls the said RPC for the given tablet. func tabletUndoDemotePrimary(ctx context.Context, tablet *topodatapb.Tablet, semiSync bool) error { - return tmc.UndoDemotePrimary(ctx, tablet, semiSync) + tmcCtx, tmcCancel := context.WithTimeout(ctx, topo.RemoteOperationTimeout) + defer tmcCancel() + return tmc.UndoDemotePrimary(tmcCtx, tablet, semiSync) } // setReadOnly calls the said RPC for the given tablet func setReadOnly(ctx context.Context, tablet *topodatapb.Tablet) error { - return tmc.SetReadOnly(ctx, tablet) + tmcCtx, tmcCancel := context.WithTimeout(ctx, topo.RemoteOperationTimeout) + defer tmcCancel() + return tmc.SetReadOnly(tmcCtx, tablet) } // setReplicationSource calls the said RPC with the parameters provided func setReplicationSource(ctx context.Context, replica *topodatapb.Tablet, primary *topodatapb.Tablet, semiSync bool) error { - return tmc.SetReplicationSource(ctx, replica, primary.Alias, 0, "", true, semiSync) + tmcCtx, tmcCancel := context.WithTimeout(ctx, topo.RemoteOperationTimeout) + defer tmcCancel() + return tmc.SetReplicationSource(tmcCtx, replica, primary.Alias, 0, "", true, semiSync) } // shardPrimary finds the primary of the given keyspace-shard by reading the vtorc backend diff --git a/go/vt/vtorc/logic/tablet_discovery_test.go b/go/vt/vtorc/logic/tablet_discovery_test.go index d43cebefc0f..7790fa997d9 100644 --- a/go/vt/vtorc/logic/tablet_discovery_test.go +++ b/go/vt/vtorc/logic/tablet_discovery_test.go @@ -18,8 +18,10 @@ package logic import ( "context" + "fmt" "sync/atomic" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/assert" @@ -30,7 +32,9 @@ import ( topodatapb "vitess.io/vitess/go/vt/proto/topodata" "vitess.io/vitess/go/vt/proto/vttime" + "vitess.io/vitess/go/vt/topo" "vitess.io/vitess/go/vt/topo/memorytopo" + "vitess.io/vitess/go/vt/vtctl/grpcvtctldserver/testutil" "vitess.io/vitess/go/vt/vtorc/db" "vitess.io/vitess/go/vt/vtorc/inst" ) @@ -315,3 +319,131 @@ func verifyTabletCount(t *testing.T, countWanted int) { require.NoError(t, err) require.Equal(t, countWanted, totalTablets) } + +func TestSetReadOnly(t *testing.T) { + tests := []struct { + name string + tablet *topodatapb.Tablet + tmc *testutil.TabletManagerClient + remoteOpTimeout time.Duration + errShouldContain string + }{ + { + name: "Success", + tablet: tab100, + tmc: &testutil.TabletManagerClient{ + SetReadOnlyResults: map[string]error{ + "zone-1-0000000100": nil, + }, + }, + }, { + name: "Failure", + tablet: tab100, + tmc: &testutil.TabletManagerClient{ + SetReadOnlyResults: map[string]error{ + "zone-1-0000000100": fmt.Errorf("testing error"), + }, + }, + errShouldContain: "testing error", + }, { + name: "Timeout", + tablet: tab100, + remoteOpTimeout: 100 * time.Millisecond, + tmc: &testutil.TabletManagerClient{ + SetReadOnlyResults: map[string]error{ + "zone-1-0000000100": nil, + }, + SetReadOnlyDelays: map[string]time.Duration{ + "zone-1-0000000100": 200 * time.Millisecond, + }, + }, + errShouldContain: "context deadline exceeded", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + oldTmc := tmc + oldRemoteOpTimeout := topo.RemoteOperationTimeout + defer func() { + tmc = oldTmc + topo.RemoteOperationTimeout = oldRemoteOpTimeout + }() + + tmc = tt.tmc + if tt.remoteOpTimeout != 0 { + topo.RemoteOperationTimeout = tt.remoteOpTimeout + } + + err := setReadOnly(context.Background(), tt.tablet) + if tt.errShouldContain == "" { + require.NoError(t, err) + return + } + require.ErrorContains(t, err, tt.errShouldContain) + }) + } +} + +func TestTabletUndoDemotePrimary(t *testing.T) { + tests := []struct { + name string + tablet *topodatapb.Tablet + tmc *testutil.TabletManagerClient + remoteOpTimeout time.Duration + errShouldContain string + }{ + { + name: "Success", + tablet: tab100, + tmc: &testutil.TabletManagerClient{ + UndoDemotePrimaryResults: map[string]error{ + "zone-1-0000000100": nil, + }, + }, + }, { + name: "Failure", + tablet: tab100, + tmc: &testutil.TabletManagerClient{ + UndoDemotePrimaryResults: map[string]error{ + "zone-1-0000000100": fmt.Errorf("testing error"), + }, + }, + errShouldContain: "testing error", + }, { + name: "Timeout", + tablet: tab100, + remoteOpTimeout: 100 * time.Millisecond, + tmc: &testutil.TabletManagerClient{ + UndoDemotePrimaryResults: map[string]error{ + "zone-1-0000000100": nil, + }, + UndoDemotePrimaryDelays: map[string]time.Duration{ + "zone-1-0000000100": 200 * time.Millisecond, + }, + }, + errShouldContain: "context deadline exceeded", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + oldTmc := tmc + oldRemoteOpTimeout := topo.RemoteOperationTimeout + defer func() { + tmc = oldTmc + topo.RemoteOperationTimeout = oldRemoteOpTimeout + }() + + tmc = tt.tmc + if tt.remoteOpTimeout != 0 { + topo.RemoteOperationTimeout = tt.remoteOpTimeout + } + + err := tabletUndoDemotePrimary(context.Background(), tt.tablet, false) + if tt.errShouldContain == "" { + require.NoError(t, err) + return + } + require.ErrorContains(t, err, tt.errShouldContain) + }) + } +}