diff --git a/go/cmd/vtctldclient/command/keyspace_routing_rules.go b/go/cmd/vtctldclient/command/keyspace_routing_rules.go index 7d1134d3abf..68aaa35b8bb 100644 --- a/go/cmd/vtctldclient/command/keyspace_routing_rules.go +++ b/go/cmd/vtctldclient/command/keyspace_routing_rules.go @@ -82,7 +82,7 @@ func commandApplyKeyspaceRoutingRules(cmd *cobra.Command, args []string) error { } krr := &vschemapb.KeyspaceRoutingRules{} - if err := json2.Unmarshal(rulesBytes, &krr); err != nil { + if err := json2.UnmarshalPB(rulesBytes, krr); err != nil { return err } diff --git a/go/cmd/vtctldclient/command/routing_rules.go b/go/cmd/vtctldclient/command/routing_rules.go index 0ffee0c2c24..8a228589098 100644 --- a/go/cmd/vtctldclient/command/routing_rules.go +++ b/go/cmd/vtctldclient/command/routing_rules.go @@ -82,7 +82,7 @@ func commandApplyRoutingRules(cmd *cobra.Command, args []string) error { } rr := &vschemapb.RoutingRules{} - if err := json2.Unmarshal(rulesBytes, &rr); err != nil { + if err := json2.UnmarshalPB(rulesBytes, rr); err != nil { return err } diff --git a/go/cmd/vtctldclient/command/shard_routing_rules.go b/go/cmd/vtctldclient/command/shard_routing_rules.go index 10ce7e81747..2214269d0f3 100644 --- a/go/cmd/vtctldclient/command/shard_routing_rules.go +++ b/go/cmd/vtctldclient/command/shard_routing_rules.go @@ -87,7 +87,7 @@ func commandApplyShardRoutingRules(cmd *cobra.Command, args []string) error { } srr := &vschemapb.ShardRoutingRules{} - if err := json2.Unmarshal(rulesBytes, &srr); err != nil { + if err := json2.UnmarshalPB(rulesBytes, srr); err != nil { return err } // Round-trip so when we display the result it's readable. diff --git a/go/json2/unmarshal.go b/go/json2/unmarshal.go index e382b8ad47a..e2034fa71c9 100644 --- a/go/json2/unmarshal.go +++ b/go/json2/unmarshal.go @@ -33,8 +33,7 @@ var carriageReturn = []byte("\n") // efficient and should not be used for high QPS operations. func Unmarshal(data []byte, v any) error { if pb, ok := v.(proto.Message); ok { - opts := protojson.UnmarshalOptions{DiscardUnknown: true} - return annotate(data, opts.Unmarshal(data, pb)) + return UnmarshalPB(data, pb) } return annotate(data, json.Unmarshal(data, v)) } @@ -53,3 +52,9 @@ func annotate(data []byte, err error) error { return fmt.Errorf("line: %d, position %d: %v", line, pos, err) } + +// UnmarshalPB is similar to Unmarshal but specifically for proto.Message to add type safety. +func UnmarshalPB(data []byte, pb proto.Message) error { + opts := protojson.UnmarshalOptions{DiscardUnknown: true} + return annotate(data, opts.Unmarshal(data, pb)) +} diff --git a/go/json2/unmarshal_test.go b/go/json2/unmarshal_test.go index ff18a29def8..1ba3368d5ca 100644 --- a/go/json2/unmarshal_test.go +++ b/go/json2/unmarshal_test.go @@ -91,3 +91,14 @@ func TestAnnotate(t *testing.T) { require.Equal(t, tcase.err, err, "annotate(%s, %v) error", string(tcase.data), tcase.err) } } + +func TestUnmarshalPB(t *testing.T) { + want := &emptypb.Empty{} + json, err := protojson.Marshal(want) + require.NoError(t, err) + + var got emptypb.Empty + err = UnmarshalPB(json, &got) + require.NoError(t, err) + require.Equal(t, want, &got) +} diff --git a/go/test/endtoend/vreplication/vreplication_vtctldclient_cli_test.go b/go/test/endtoend/vreplication/vreplication_vtctldclient_cli_test.go index bca51512a3c..4a3f16a1cc9 100644 --- a/go/test/endtoend/vreplication/vreplication_vtctldclient_cli_test.go +++ b/go/test/endtoend/vreplication/vreplication_vtctldclient_cli_test.go @@ -19,6 +19,7 @@ package vreplication import ( "encoding/json" "fmt" + "os" "slices" "strings" "testing" @@ -27,6 +28,7 @@ import ( "golang.org/x/exp/maps" "google.golang.org/protobuf/encoding/protojson" + "vitess.io/vitess/go/json2" "vitess.io/vitess/go/test/endtoend/cluster" binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata" @@ -61,6 +63,9 @@ func TestVtctldclientCLI(t *testing.T) { workflowName := "wf1" targetTabs := setupMinimalCustomerKeyspace(t) + t.Run("RoutingRulesApply", func(t *testing.T) { + testRoutingRulesApplyCommands(t) + }) t.Run("WorkflowList", func(t *testing.T) { testWorkflowList(t, sourceKeyspaceName, targetKeyspaceName) }) @@ -438,3 +443,138 @@ func validateMoveTablesWorkflow(t *testing.T, workflows []*vtctldatapb.Workflow) require.Equal(t, binlogdatapb.OnDDLAction_STOP, bls.OnDdl) require.True(t, bls.StopAfterCopy) } + +// Test that routing rules can be applied using the vtctldclient CLI for all types of routing rules. +func testRoutingRulesApplyCommands(t *testing.T) { + var rulesBytes []byte + var err error + var validateRules func(want, got string) + + ruleTypes := []string{"RoutingRules", "ShardRoutingRules", "KeyspaceRoutingRules"} + for _, typ := range ruleTypes { + switch typ { + case "RoutingRules": + rr := &vschemapb.RoutingRules{ + Rules: []*vschemapb.RoutingRule{ + { + FromTable: "from1", + ToTables: []string{"to1", "to2"}, + }, + }, + } + rulesBytes, err = json2.MarshalPB(rr) + require.NoError(t, err) + validateRules = func(want, got string) { + var wantRules = &vschemapb.RoutingRules{} + require.NoError(t, json2.UnmarshalPB([]byte(want), wantRules)) + var gotRules = &vschemapb.RoutingRules{} + require.NoError(t, json2.UnmarshalPB([]byte(got), gotRules)) + require.EqualValues(t, wantRules, gotRules) + } + case "ShardRoutingRules": + srr := &vschemapb.ShardRoutingRules{ + Rules: []*vschemapb.ShardRoutingRule{ + { + FromKeyspace: "from1", + ToKeyspace: "to1", + Shard: "-80", + }, + }, + } + rulesBytes, err = json2.MarshalPB(srr) + require.NoError(t, err) + validateRules = func(want, got string) { + var wantRules = &vschemapb.ShardRoutingRules{} + require.NoError(t, json2.UnmarshalPB([]byte(want), wantRules)) + var gotRules = &vschemapb.ShardRoutingRules{} + require.NoError(t, json2.UnmarshalPB([]byte(got), gotRules)) + require.EqualValues(t, wantRules, gotRules) + } + case "KeyspaceRoutingRules": + krr := &vschemapb.KeyspaceRoutingRules{ + Rules: []*vschemapb.KeyspaceRoutingRule{ + { + FromKeyspace: "from1", + ToKeyspace: "to1", + }, + }, + } + rulesBytes, err = json2.MarshalPB(krr) + require.NoError(t, err) + validateRules = func(want, got string) { + var wantRules = &vschemapb.KeyspaceRoutingRules{} + require.NoError(t, json2.UnmarshalPB([]byte(want), wantRules)) + var gotRules = &vschemapb.KeyspaceRoutingRules{} + require.NoError(t, json2.UnmarshalPB([]byte(got), gotRules)) + require.EqualValues(t, wantRules, gotRules) + } + default: + require.FailNow(t, "Unknown type %s", typ) + } + testOneRoutingRulesCommand(t, typ, string(rulesBytes), validateRules) + } + +} + +// For a given routing rules type, test that the rules can be applied using the vtctldclient CLI. +// We test both inline and file-based rules. +// The test also validates that both camelCase and snake_case key names work correctly. +func testOneRoutingRulesCommand(t *testing.T, typ string, rules string, validateRules func(want, got string)) { + type routingRulesTest struct { + name string + rules string + useFile bool // if true, use a file to pass the rules + } + tests := []routingRulesTest{ + { + name: "inline", + rules: rules, + }, + { + name: "file", + rules: rules, + useFile: true, + }, + { + name: "empty", // finally, cleanup rules + rules: "{}", + }, + } + for _, tt := range tests { + t.Run(typ+"/"+tt.name, func(t *testing.T) { + wantRules := tt.rules + // The input rules are in camelCase, since they are the output of json2.MarshalPB + // The first iteration uses the output of routing rule Gets which are in snake_case. + for _, keyCase := range []string{"camelCase", "snake_case"} { + t.Run(keyCase, func(t *testing.T) { + var args []string + apply := fmt.Sprintf("Apply%s", typ) + get := fmt.Sprintf("Get%s", typ) + args = append(args, apply) + if tt.useFile { + tmpFile, err := os.CreateTemp("", fmt.Sprintf("%s_rules.json", tt.name)) + require.NoError(t, err) + defer os.Remove(tmpFile.Name()) + _, err = tmpFile.WriteString(wantRules) + require.NoError(t, err) + args = append(args, "--rules-file", tmpFile.Name()) + } else { + args = append(args, "--rules", wantRules) + } + var output string + var err error + if output, err = vc.VtctldClient.ExecuteCommandWithOutput(args...); err != nil { + require.FailNowf(t, "failed action", apply, "%v: %s", err, output) + } + if output, err = vc.VtctldClient.ExecuteCommandWithOutput(get); err != nil { + require.FailNowf(t, "failed action", get, "%v: %s", err, output) + } + validateRules(wantRules, output) + // output of GetRoutingRules is in snake_case and we use it for the next iteration which + // tests applying rules with snake_case keys. + wantRules = output + }) + } + }) + } +}