diff --git a/cmd/server/app/options/options.go b/cmd/server/app/options/options.go index f887653c0..36718b552 100644 --- a/cmd/server/app/options/options.go +++ b/cmd/server/app/options/options.go @@ -19,7 +19,6 @@ package options import ( "fmt" "os" - "strings" "time" "github.com/google/uuid" @@ -90,10 +89,10 @@ type ProxyRunOptions struct { // Proxy strategies used by the server. // NOTE the order of the strategies matters. e.g., for list - // "destHost,destCIDR", the server will try to find a backend associating + // "destHost,destCIDR,default", the server will try to find a backend associating // to the destination host first, if not found, it will try to find a // backend within the destCIDR. if it still can't find any backend, - // it will use the default backend manager to choose a random backend. + // it will choose a random backend. ProxyStrategies string // Cipher suites used by the server. @@ -135,7 +134,7 @@ func (o *ProxyRunOptions) Flags() *pflag.FlagSet { flags.Float32Var(&o.KubeconfigQPS, "kubeconfig-qps", o.KubeconfigQPS, "Maximum client QPS (proxy server uses this client to authenticate agent tokens).") flags.IntVar(&o.KubeconfigBurst, "kubeconfig-burst", o.KubeconfigBurst, "Maximum client burst (proxy server uses this client to authenticate agent tokens).") flags.StringVar(&o.AuthenticationAudience, "authentication-audience", o.AuthenticationAudience, "Expected agent's token authentication audience (used with agent-namespace, agent-service-account, kubeconfig).") - flags.StringVar(&o.ProxyStrategies, "proxy-strategies", o.ProxyStrategies, "The list of proxy strategies used by the server to pick a backend/tunnel, available strategies are: default, destHost.") + flags.StringVar(&o.ProxyStrategies, "proxy-strategies", o.ProxyStrategies, "The list of proxy strategies used by the server to pick an agent/tunnel, available strategies are: default, destHost, defaultRoute.") flags.StringSliceVar(&o.CipherSuites, "cipher-suites", o.CipherSuites, "The comma separated list of allowed cipher suites. Has no effect on TLS1.3. Empty means allow default list.") flags.Bool("warn-on-channel-limit", true, "This behavior is now thread safe and always on. This flag will be removed in a future release.") @@ -292,17 +291,11 @@ func (o *ProxyRunOptions) Validate() error { } // validate the proxy strategies - if o.ProxyStrategies != "" { - pss := strings.Split(o.ProxyStrategies, ",") - for _, ps := range pss { - switch ps { - case string(server.ProxyStrategyDestHost): - case string(server.ProxyStrategyDefault): - case string(server.ProxyStrategyDefaultRoute): - default: - return fmt.Errorf("unknown proxy strategy: %s, available strategy are: default, destHost, defaultRoute", ps) - } - } + if len(o.ProxyStrategies) == 0 { + return fmt.Errorf("ProxyStrategies cannot be empty") + } + if _, err := server.ParseProxyStrategies(o.ProxyStrategies); err != nil { + return fmt.Errorf("invalid proxy strategies: %v", err) } // validate the cipher suites diff --git a/cmd/server/app/options/options_test.go b/cmd/server/app/options/options_test.go index f7e4d63ea..709f4e267 100644 --- a/cmd/server/app/options/options_test.go +++ b/cmd/server/app/options/options_test.go @@ -145,6 +145,16 @@ func TestValidate(t *testing.T) { value: "TLS_AES_256_GCM_SHA384,TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", expected: nil, }, + "Empty proxy strategies": { + field: "ProxyStrategies", + value: "", + expected: fmt.Errorf("ProxyStrategies cannot be empty"), + }, + "Invalid proxy strategies": { + field: "ProxyStrategies", + value: "invalid", + expected: fmt.Errorf("invalid proxy strategies: unknown proxy strategy: invalid"), + }, } { t.Run(desc, func(t *testing.T) { testServerOptions := NewProxyRunOptions() diff --git a/cmd/server/app/server.go b/cmd/server/app/server.go index 4f2ce21b5..ff49a090e 100644 --- a/cmd/server/app/server.go +++ b/cmd/server/app/server.go @@ -128,7 +128,7 @@ func (p *Proxy) Run(o *options.ProxyRunOptions, stopCh <-chan struct{}) error { AuthenticationAudience: o.AuthenticationAudience, } klog.V(1).Infoln("Starting frontend server for client connections.") - ps, err := server.GenProxyStrategiesFromStr(o.ProxyStrategies) + ps, err := server.ParseProxyStrategies(o.ProxyStrategies) if err != nil { return err } diff --git a/pkg/server/backend_manager.go b/pkg/server/backend_manager.go index 56c266ae7..12dd229e4 100644 --- a/pkg/server/backend_manager.go +++ b/pkg/server/backend_manager.go @@ -36,40 +36,66 @@ import ( "sigs.k8s.io/apiserver-network-proxy/proto/header" ) -type ProxyStrategy string +type ProxyStrategy int const ( // With this strategy the Proxy Server will randomly pick a backend from // the current healthy backends to establish the tunnel over which to // forward requests. - ProxyStrategyDefault ProxyStrategy = "default" + ProxyStrategyDefault ProxyStrategy = iota + 1 // With this strategy the Proxy Server will pick a backend that has the same // associated host as the request.Host to establish the tunnel. - ProxyStrategyDestHost ProxyStrategy = "destHost" - + ProxyStrategyDestHost // ProxyStrategyDefaultRoute will only forward traffic to agents that have explicity advertised // they serve the default route through an agent identifier. Typically used in combination with destHost - ProxyStrategyDefaultRoute ProxyStrategy = "defaultRoute" + ProxyStrategyDefaultRoute ) +func (ps ProxyStrategy) String() string { + switch ps { + case ProxyStrategyDefault: + return "default" + case ProxyStrategyDestHost: + return "destHost" + case ProxyStrategyDefaultRoute: + return "defaultRoute" + } + panic(fmt.Sprintf("unhandled ProxyStrategy: %d", ps)) +} + +func ParseProxyStrategy(s string) (ProxyStrategy, error) { + switch s { + case ProxyStrategyDefault.String(): + return ProxyStrategyDefault, nil + case ProxyStrategyDestHost.String(): + return ProxyStrategyDestHost, nil + case ProxyStrategyDefaultRoute.String(): + return ProxyStrategyDefaultRoute, nil + default: + return 0, fmt.Errorf("unknown proxy strategy: %s", s) + } +} + // GenProxyStrategiesFromStr generates the list of proxy strategies from the // comma-seperated string, i.e., destHost. -func GenProxyStrategiesFromStr(proxyStrategies string) ([]ProxyStrategy, error) { - var ps []ProxyStrategy +func ParseProxyStrategies(proxyStrategies string) ([]ProxyStrategy, error) { + var result []ProxyStrategy + strs := strings.Split(proxyStrategies, ",") for _, s := range strs { - switch s { - case string(ProxyStrategyDestHost): - ps = append(ps, ProxyStrategyDestHost) - case string(ProxyStrategyDefault): - ps = append(ps, ProxyStrategyDefault) - case string(ProxyStrategyDefaultRoute): - ps = append(ps, ProxyStrategyDefaultRoute) - default: - return nil, fmt.Errorf("Unknown proxy strategy %s", s) + if len(s) == 0 { + continue } + ps, err := ParseProxyStrategy(s) + if err != nil { + return nil, err + } + result = append(result, ps) + } + if len(result) == 0 { + return nil, fmt.Errorf("proxy strategies cannot be empty") } - return ps, nil + return result, nil } // Backend abstracts a connected Konnectivity agent. diff --git a/pkg/server/backend_manager_test.go b/pkg/server/backend_manager_test.go index 4a8853194..107103cba 100644 --- a/pkg/server/backend_manager_test.go +++ b/pkg/server/backend_manager_test.go @@ -18,9 +18,11 @@ package server import ( "context" + "fmt" "reflect" "testing" + "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" "google.golang.org/grpc/metadata" @@ -383,3 +385,123 @@ func TestDestHostBackendManager_WithDuplicateIdents(t *testing.T) { t.Errorf("expected %v, got %v", e, a) } } + +func TestProxyStrategy(t *testing.T) { + for desc, tc := range map[string]struct { + input ProxyStrategy + want string + wantPanic string + }{ + "default": { + input: ProxyStrategyDefault, + want: "default", + }, + "destHost": { + input: ProxyStrategyDestHost, + want: "destHost", + }, + "defaultRoute": { + input: ProxyStrategyDefaultRoute, + want: "defaultRoute", + }, + "unrecognized": { + input: ProxyStrategy(0), + wantPanic: "unhandled ProxyStrategy: 0", + }, + } { + t.Run(desc, func(t *testing.T) { + if tc.wantPanic != "" { + assert.PanicsWithValue(t, tc.wantPanic, func() { + _ = tc.input.String() + }) + } else { + got := tc.input.String() + if got != tc.want { + t.Errorf("ProxyStrategy.String(): got %v, want %v", got, tc.want) + } + } + }) + } +} + +func TestParseProxyStrategy(t *testing.T) { + for desc, tc := range map[string]struct { + input string + want ProxyStrategy + wantErr error + }{ + "empty": { + input: "", + wantErr: fmt.Errorf("unknown proxy strategy: "), + }, + "unrecognized": { + input: "unrecognized", + wantErr: fmt.Errorf("unknown proxy strategy: unrecognized"), + }, + "default": { + input: "default", + want: ProxyStrategyDefault, + }, + "destHost": { + input: "destHost", + want: ProxyStrategyDestHost, + }, + "defaultRoute": { + input: "defaultRoute", + want: ProxyStrategyDefaultRoute, + }, + } { + t.Run(desc, func(t *testing.T) { + got, err := ParseProxyStrategy(tc.input) + assert.Equal(t, tc.wantErr, err, "ParseProxyStrategy(%s): got error %q, want %v", tc.input, err, tc.wantErr) + if got != tc.want { + t.Errorf("ParseProxyStrategy(%s): got %v, want %v", tc.input, got, tc.want) + } + }) + } +} + +func TestParseProxyStrategies(t *testing.T) { + for desc, tc := range map[string]struct { + input string + want []ProxyStrategy + wantErr error + }{ + "empty": { + input: "", + wantErr: fmt.Errorf("proxy strategies cannot be empty"), + }, + "unrecognized": { + input: "unrecognized", + wantErr: fmt.Errorf("unknown proxy strategy: unrecognized"), + }, + "default": { + input: "default", + want: []ProxyStrategy{ProxyStrategyDefault}, + }, + "destHost": { + input: "destHost", + want: []ProxyStrategy{ProxyStrategyDestHost}, + }, + "defaultRoute": { + input: "defaultRoute", + want: []ProxyStrategy{ProxyStrategyDefaultRoute}, + }, + "duplicate": { + input: "destHost,defaultRoute,defaultRoute,default", + want: []ProxyStrategy{ProxyStrategyDestHost, ProxyStrategyDefaultRoute, ProxyStrategyDefaultRoute, ProxyStrategyDefault}, + }, + "multiple": { + input: "destHost,defaultRoute,default", + want: []ProxyStrategy{ProxyStrategyDestHost, ProxyStrategyDefaultRoute, ProxyStrategyDefault}, + }, + } { + t.Run(desc, func(t *testing.T) { + got, err := ParseProxyStrategies(tc.input) + assert.Equal(t, tc.wantErr, err, "ParseProxyStrategies(%s): got error %q, want %v", tc.input, err, tc.wantErr) + if !reflect.DeepEqual(got, tc.want) { + t.Errorf("ParseProxyStrategies(%s): got %v, want %v", tc.input, got, tc.want) + } + }) + } +}