From eb1789eb8b079ae209e6a30a69cd2d7a7910c15f Mon Sep 17 00:00:00 2001 From: Shinnosuke Sawada-Dazai Date: Thu, 31 Oct 2024 15:42:17 +0900 Subject: [PATCH] Make configv1.Config generic type Signed-off-by: Shinnosuke Sawada-Dazai --- pkg/app/pipedv1/controller/planner.go | 4 +- pkg/configv1/analysis_template.go | 4 +- pkg/configv1/application_test.go | 12 +-- pkg/configv1/config.go | 108 +++++++------------------- pkg/configv1/control_plane_test.go | 10 +-- pkg/configv1/event_watcher.go | 4 +- pkg/configv1/piped_test.go | 4 +- 7 files changed, 45 insertions(+), 101 deletions(-) diff --git a/pkg/app/pipedv1/controller/planner.go b/pkg/app/pipedv1/controller/planner.go index bb7f118d73..fa398d2cc1 100644 --- a/pkg/app/pipedv1/controller/planner.go +++ b/pkg/app/pipedv1/controller/planner.go @@ -272,12 +272,12 @@ func (p *planner) buildPlan(ctx context.Context, runningDS, targetDS *deployment } } - cfg, err := config.DecodeYAML(targetDS.GetApplicationConfig()) + cfg, err := config.DecodeYAML[*config.GenericApplicationSpec](targetDS.GetApplicationConfig()) if err != nil { p.logger.Error("unable to parse application config", zap.Error(err)) return nil, err } - spec := cfg.ApplicationSpec + spec := cfg.Spec // In case the strategy has been decided by trigger. // For example: user triggered the deployment via web console. diff --git a/pkg/configv1/analysis_template.go b/pkg/configv1/analysis_template.go index d00be027e4..557c8fd0ca 100644 --- a/pkg/configv1/analysis_template.go +++ b/pkg/configv1/analysis_template.go @@ -47,12 +47,12 @@ func LoadAnalysisTemplate(repoRoot string) (*AnalysisTemplateSpec, error) { continue } path := filepath.Join(dir, f.Name()) - cfg, err := LoadFromYAML(path) + cfg, err := LoadFromYAML[*AnalysisTemplateSpec](path) if err != nil { return nil, fmt.Errorf("failed to load config file %s: %w", path, err) } if cfg.Kind == KindAnalysisTemplate { - return cfg.AnalysisTemplateSpec, nil + return cfg.Spec, nil } } return nil, ErrNotFound diff --git a/pkg/configv1/application_test.go b/pkg/configv1/application_test.go index 440aa6b11a..15eb2f17bb 100644 --- a/pkg/configv1/application_test.go +++ b/pkg/configv1/application_test.go @@ -409,12 +409,12 @@ func TestGenericTriggerConfiguration(t *testing.T) { } for _, tc := range testcases { t.Run(tc.fileName, func(t *testing.T) { - cfg, err := LoadFromYAML(tc.fileName) + cfg, err := LoadFromYAML[*GenericApplicationSpec](tc.fileName) require.Equal(t, tc.expectedError, err) if err == nil { assert.Equal(t, tc.expectedKind, cfg.Kind) assert.Equal(t, tc.expectedAPIVersion, cfg.APIVersion) - assert.Equal(t, tc.expectedSpec, cfg.spec) + assert.Equal(t, tc.expectedSpec, cfg.Spec) } }) } @@ -494,12 +494,12 @@ func TestTrueByDefaultBoolConfiguration(t *testing.T) { } for _, tc := range testcases { t.Run(tc.fileName, func(t *testing.T) { - cfg, err := LoadFromYAML(tc.fileName) + cfg, err := LoadFromYAML[*GenericApplicationSpec](tc.fileName) require.Equal(t, tc.expectedError, err) if err == nil { assert.Equal(t, tc.expectedKind, cfg.Kind) assert.Equal(t, tc.expectedAPIVersion, cfg.APIVersion) - assert.Equal(t, tc.expectedSpec, cfg.spec) + assert.Equal(t, tc.expectedSpec, cfg.Spec) } }) } @@ -555,12 +555,12 @@ func TestGenericPostSyncConfiguration(t *testing.T) { } for _, tc := range testcases { t.Run(tc.fileName, func(t *testing.T) { - cfg, err := LoadFromYAML(tc.fileName) + cfg, err := LoadFromYAML[*GenericApplicationSpec](tc.fileName) require.Equal(t, tc.expectedError, err) if err == nil { assert.Equal(t, tc.expectedKind, cfg.Kind) assert.Equal(t, tc.expectedAPIVersion, cfg.APIVersion) - assert.Equal(t, tc.expectedSpec, cfg.spec) + assert.Equal(t, tc.expectedSpec, cfg.Spec) } }) } diff --git a/pkg/configv1/config.go b/pkg/configv1/config.go index f455fbd1e6..700f405433 100644 --- a/pkg/configv1/config.go +++ b/pkg/configv1/config.go @@ -15,7 +15,6 @@ package config import ( - "bytes" "encoding/json" "errors" "fmt" @@ -79,126 +78,71 @@ var ( ErrNotFound = errors.New("not found") ) +// Spec[T] represents both of follows +// - the type is pointer type of T +// - the type has Validate method +type Spec[T any] interface { + *T + Validate() error +} + // Config represents configuration data load from file. // The spec is depend on the kind of configuration. -type Config struct { +type Config[T Spec[RT], RT any] struct { Kind Kind APIVersion string - spec interface{} - - ApplicationSpec *GenericApplicationSpec - - PipedSpec *PipedSpec - ControlPlaneSpec *ControlPlaneSpec - AnalysisTemplateSpec *AnalysisTemplateSpec - EventWatcherSpec *EventWatcherSpec -} - -type genericConfig struct { - Kind Kind `json:"kind"` - APIVersion string `json:"apiVersion,omitempty"` - Spec json.RawMessage `json:"spec"` -} - -func (c *Config) init(kind Kind, apiVersion string) error { - c.Kind = kind - c.APIVersion = apiVersion - - switch kind { - case KindApplication, KindKubernetesApp, KindTerraformApp, KindCloudRunApp, KindLambdaApp, KindECSApp: - c.ApplicationSpec = &GenericApplicationSpec{} - c.spec = c.ApplicationSpec - - case KindPiped: - c.PipedSpec = &PipedSpec{} - c.spec = c.PipedSpec - - case KindControlPlane: - c.ControlPlaneSpec = &ControlPlaneSpec{} - c.spec = c.ControlPlaneSpec - - case KindAnalysisTemplate: - c.AnalysisTemplateSpec = &AnalysisTemplateSpec{} - c.spec = c.AnalysisTemplateSpec - - case KindEventWatcher: - c.EventWatcherSpec = &EventWatcherSpec{} - c.spec = c.EventWatcherSpec - - default: - return fmt.Errorf("unsupported kind: %s", c.Kind) - } - return nil + Spec T } -// UnmarshalJSON customizes the way to unmarshal json data into Config struct. -// Firstly, this unmarshal to a generic config and then unmarshal the spec -// which depend on the kind of configuration. -func (c *Config) UnmarshalJSON(data []byte) error { - var ( - err error - gc = genericConfig{} - ) - dec := json.NewDecoder(bytes.NewReader(data)) - dec.DisallowUnknownFields() - if err := dec.Decode(&gc); err != nil { - return err - } - if err = c.init(gc.Kind, gc.APIVersion); err != nil { +func (c *Config[T, RT]) UnmarshalJSON(data []byte) error { + // Define a type alias Config[T, RT] to avoid infinite recursion. + type alias Config[T, RT] + a := alias{} + if err := json.Unmarshal(data, &a); err != nil { return err } + *c = Config[T, RT](a) - if len(gc.Spec) > 0 { - dec := json.NewDecoder(bytes.NewReader(gc.Spec)) - dec.DisallowUnknownFields() - err = dec.Decode(c.spec) + // Set default values. + if c.Spec == nil { + c.Spec = new(RT) } - return err -} -type validator interface { - Validate() error + return nil } // Validate validates the value of all fields. -func (c *Config) Validate() error { +func (c *Config[T, RT]) Validate() error { if c.APIVersion != VersionV1Beta1 { return fmt.Errorf("unsupported version: %s", c.APIVersion) } if c.Kind == "" { return fmt.Errorf("kind is required") } - if c.spec == nil { - return fmt.Errorf("spec is required") - } - spec, ok := c.spec.(validator) - if !ok { - return fmt.Errorf("spec must have Validate function") - } - if err := spec.Validate(); err != nil { + if err := c.Spec.Validate(); err != nil { return err } return nil } // LoadFromYAML reads and decodes a yaml file to construct the Config. -func LoadFromYAML(file string) (*Config, error) { +func LoadFromYAML[T Spec[RT], RT any](file string) (*Config[T, RT], error) { data, err := os.ReadFile(file) if err != nil { return nil, err } - return DecodeYAML(data) + return DecodeYAML[T, RT](data) } // DecodeYAML unmarshals config YAML data to config struct. // It also validates the configuration after decoding. -func DecodeYAML(data []byte) (*Config, error) { +func DecodeYAML[T Spec[RT], RT any](data []byte) (*Config[T, RT], error) { js, err := yaml.YAMLToJSON(data) if err != nil { return nil, err } - c := &Config{} + c := &Config[T, RT]{} if err := json.Unmarshal(js, c); err != nil { return nil, err } diff --git a/pkg/configv1/control_plane_test.go b/pkg/configv1/control_plane_test.go index 333c2ec88e..019d8b7e95 100644 --- a/pkg/configv1/control_plane_test.go +++ b/pkg/configv1/control_plane_test.go @@ -96,20 +96,20 @@ func TestControlPlaneConfig(t *testing.T) { } for _, tc := range testcases { t.Run(tc.fileName, func(t *testing.T) { - cfg, err := LoadFromYAML(tc.fileName) + cfg, err := LoadFromYAML[*ControlPlaneSpec](tc.fileName) require.Equal(t, tc.expectedError, err) if err == nil { assert.Equal(t, tc.expectedKind, cfg.Kind) assert.Equal(t, tc.expectedAPIVersion, cfg.APIVersion) require.Equal(t, 1, len(tc.expectedSpec.SharedSSOConfigs)) - require.Equal(t, 1, len(cfg.ControlPlaneSpec.SharedSSOConfigs)) + require.Equal(t, 1, len(cfg.Spec.SharedSSOConfigs)) // Why don't we use assert.Equal to compare? // https://github.com/stretchr/testify/issues/758 - assert.True(t, proto.Equal(&tc.expectedSpec.SharedSSOConfigs[0].ProjectSSOConfig, &cfg.ControlPlaneSpec.SharedSSOConfigs[0].ProjectSSOConfig)) + assert.True(t, proto.Equal(&tc.expectedSpec.SharedSSOConfigs[0].ProjectSSOConfig, &cfg.Spec.SharedSSOConfigs[0].ProjectSSOConfig)) tc.expectedSpec.SharedSSOConfigs = nil - cfg.ControlPlaneSpec.SharedSSOConfigs = nil - assert.Equal(t, tc.expectedSpec, cfg.ControlPlaneSpec) + cfg.Spec.SharedSSOConfigs = nil + assert.Equal(t, tc.expectedSpec, cfg.Spec) } }) } diff --git a/pkg/configv1/event_watcher.go b/pkg/configv1/event_watcher.go index 282df6d3ed..3bbcad5b04 100644 --- a/pkg/configv1/event_watcher.go +++ b/pkg/configv1/event_watcher.go @@ -132,12 +132,12 @@ func LoadEventWatcher(repoRoot string, includePatterns, excludePatterns []string } for _, f := range filtered { path := filepath.Join(dir, f) - cfg, err := LoadFromYAML(path) + cfg, err := LoadFromYAML[*EventWatcherSpec](path) if err != nil { return nil, fmt.Errorf("failed to load config file %s: %w", path, err) } if cfg.Kind == KindEventWatcher { - spec.Events = append(spec.Events, cfg.EventWatcherSpec.Events...) + spec.Events = append(spec.Events, cfg.Spec.Events...) } } diff --git a/pkg/configv1/piped_test.go b/pkg/configv1/piped_test.go index a5f4f2694c..96f7b54d4e 100644 --- a/pkg/configv1/piped_test.go +++ b/pkg/configv1/piped_test.go @@ -366,12 +366,12 @@ func TestPipedConfig(t *testing.T) { } for _, tc := range testcases { t.Run(tc.fileName, func(t *testing.T) { - cfg, err := LoadFromYAML(tc.fileName) + cfg, err := LoadFromYAML[*PipedSpec](tc.fileName) require.Equal(t, tc.expectedError, err) if err == nil { assert.Equal(t, tc.expectedKind, cfg.Kind) assert.Equal(t, tc.expectedAPIVersion, cfg.APIVersion) - assert.Equal(t, tc.expectedSpec, cfg.spec) + assert.Equal(t, tc.expectedSpec, cfg.Spec) } }) }