Skip to content

Commit

Permalink
Add config validation for all components (#3952)
Browse files Browse the repository at this point in the history
* Add config validation to all armada components

* Make invalid config error and remove required tag from unrequired fields

---------

Co-authored-by: Eleanor Pratt <Eleanor.Pratt@gresearch.co.uk>
  • Loading branch information
eleanorpratt and Eleanor Pratt committed Sep 18, 2024
1 parent f45aefa commit 549c5b9
Show file tree
Hide file tree
Showing 15 changed files with 123 additions and 51 deletions.
2 changes: 1 addition & 1 deletion cmd/scheduler/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func loadConfig() (schedulerconfig.Configuration, error) {
common.LoadConfig(&config, "./config/scheduler", viper.GetStringSlice(CustomConfigLocation))
err := config.Validate()
if err != nil {
commonconfig.LogValidationErrors(err)
return config, commonconfig.FormatValidationErrors(err)
}
return config, err
}
8 changes: 8 additions & 0 deletions internal/binoculars/configuration/validation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package configuration

import "github.com/go-playground/validator/v10"

func (c BinocularsConfig) Validate() error {
validate := validator.New()
return validate.Struct(c)
}
2 changes: 1 addition & 1 deletion internal/common/config/pulsar.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ type PulsarConfig struct {
// Maximum allowed message size in bytes
MaxAllowedMessageSize uint
// Timeout when sending messages asynchronously
SendTimeout time.Duration `validate:"required"`
SendTimeout time.Duration
// Backoff from polling when Pulsar returns an error
BackoffTime time.Duration
// Number of pulsar messages that will be queued by the pulsar consumer.
Expand Down
29 changes: 17 additions & 12 deletions internal/common/config/validation.go
Original file line number Diff line number Diff line change
@@ -1,25 +1,30 @@
package config

import (
"errors"
"fmt"
"strings"

"github.com/go-playground/validator/v10"
log "github.com/sirupsen/logrus"
)

func LogValidationErrors(err error) {
if err != nil {
for _, err := range err.(validator.ValidationErrors) {
fieldName := stripPrefix(err.Namespace())
tag := err.Tag()
switch tag {
case "required":
log.Errorf("ConfigError: Field %s is required but was not found", fieldName)
default:
log.Errorf("ConfigError: Field %s has invalid value %s: %s", fieldName, err.Value(), tag)
}
type Config interface {
Validate() error
}

func FormatValidationErrors(err error) error {
var validationErrors error
for _, err := range err.(validator.ValidationErrors) {
fieldName := stripPrefix(err.Namespace())
tag := err.Tag()
switch tag {
case "required":
validationErrors = errors.Join(validationErrors, fmt.Errorf("ConfigError: Field %s is required but was not found", fieldName))
default:
validationErrors = errors.Join(validationErrors, fmt.Errorf("ConfigError: Field %s has invalid value %s: %s", fieldName, err.Value(), tag))
}
}
return validationErrors
}

func stripPrefix(s string) string {
Expand Down
2 changes: 1 addition & 1 deletion internal/common/grpc/configuration/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
)

type GrpcConfig struct {
Port int `validate:"required"`
Port int
KeepaliveParams keepalive.ServerParameters
KeepaliveEnforcementPolicy keepalive.EnforcementPolicy
Tls TlsConfig
Expand Down
7 changes: 6 additions & 1 deletion internal/common/startup.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func BindCommandlineArguments() {
}

// TODO Move code relating to config out of common into a new package internal/serverconfig
func LoadConfig(config any, defaultPath string, overrideConfigs []string) *viper.Viper {
func LoadConfig(config commonconfig.Config, defaultPath string, overrideConfigs []string) *viper.Viper {
v := viper.NewWithOptions(viper.KeyDelimiter("::"))
v.SetConfigName(baseConfigFileName)
v.AddConfigPath(defaultPath)
Expand Down Expand Up @@ -89,6 +89,11 @@ func LoadConfig(config any, defaultPath string, overrideConfigs []string) *viper
log.Debugf("Unset keys: %v", metadata.Unset)
}

if err := config.Validate(); err != nil {
log.Error(commonconfig.FormatValidationErrors(err))
os.Exit(-1)
}

return v
}

Expand Down
10 changes: 10 additions & 0 deletions internal/eventingester/configuration/validation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package configuration

import (
"github.com/go-playground/validator/v10"
)

func (c EventIngesterConfiguration) Validate() error {
validate := validator.New()
return validate.Struct(c)
}
8 changes: 8 additions & 0 deletions internal/executor/configuration/validation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package configuration

import "github.com/go-playground/validator/v10"

func (c ExecutorConfiguration) Validate() error {
validate := validator.New()
return validate.Struct(c)
}
8 changes: 8 additions & 0 deletions internal/lookoutingesterv2/configuration/validation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package configuration

import "github.com/go-playground/validator/v10"

func (c LookoutIngesterV2Configuration) Validate() error {
validate := validator.New()
return validate.Struct(c)
}
8 changes: 8 additions & 0 deletions internal/lookoutv2/configuration/validation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package configuration

import "github.com/go-playground/validator/v10"

func (c LookoutV2Config) Validate() error {
validate := validator.New()
return validate.Struct(c)
}
35 changes: 0 additions & 35 deletions internal/scheduler/configuration/configuration.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
package configuration

import (
"fmt"
"time"

"github.com/go-playground/validator/v10"
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"

Expand Down Expand Up @@ -59,12 +57,6 @@ type Configuration struct {
QueueRefreshPeriod time.Duration `validate:"required"`
}

func (c Configuration) Validate() error {
validate := validator.New()
validate.RegisterStructValidation(SchedulingConfigValidation, SchedulingConfig{})
return validate.Struct(c)
}

type LeaderConfig struct {
// Valid modes are "standalone" or "kubernetes"
Mode string `validate:"required"`
Expand Down Expand Up @@ -251,33 +243,6 @@ const (
UnknownWellKnownNodeTypeErrorMessage = "priority class refers to unknown well-known node type"
)

func SchedulingConfigValidation(sl validator.StructLevel) {
c := sl.Current().Interface().(SchedulingConfig)

wellKnownNodeTypes := make(map[string]bool)
for i, wellKnownNodeType := range c.WellKnownNodeTypes {
if wellKnownNodeTypes[wellKnownNodeType.Name] {
fieldName := fmt.Sprintf("WellKnownNodeTypes[%d].Name", i)
sl.ReportError(wellKnownNodeType.Name, fieldName, "", DuplicateWellKnownNodeTypeErrorMessage, "")
}
wellKnownNodeTypes[wellKnownNodeType.Name] = true
}

for priorityClassName, priorityClass := range c.PriorityClasses {
if len(priorityClass.AwayNodeTypes) > 0 && !priorityClass.Preemptible {
fieldName := fmt.Sprintf("Preemption.PriorityClasses[%s].Preemptible", priorityClassName)
sl.ReportError(priorityClass.Preemptible, fieldName, "", AwayNodeTypesWithoutPreemptionErrorMessage, "")
}

for i, awayNodeType := range priorityClass.AwayNodeTypes {
if !wellKnownNodeTypes[awayNodeType.WellKnownNodeTypeName] {
fieldName := fmt.Sprintf("Preemption.PriorityClasses[%s].AwayNodeTypes[%d].WellKnownNodeTypeName", priorityClassName, i)
sl.ReportError(awayNodeType.WellKnownNodeTypeName, fieldName, "", UnknownWellKnownNodeTypeErrorMessage, "")
}
}
}
}

// ResourceType represents a resource the scheduler indexes for efficient lookup.
type ResourceType struct {
// Resource name, e.g., "cpu", "memory", or "nvidia.com/gpu".
Expand Down
40 changes: 40 additions & 0 deletions internal/scheduler/configuration/validation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package configuration

import (
"fmt"

"github.com/go-playground/validator/v10"
)

func (c Configuration) Validate() error {
validate := validator.New()
validate.RegisterStructValidation(SchedulingConfigValidation, SchedulingConfig{})
return validate.Struct(c)
}

func SchedulingConfigValidation(sl validator.StructLevel) {
c := sl.Current().Interface().(SchedulingConfig)

wellKnownNodeTypes := make(map[string]bool)
for i, wellKnownNodeType := range c.WellKnownNodeTypes {
if wellKnownNodeTypes[wellKnownNodeType.Name] {
fieldName := fmt.Sprintf("WellKnownNodeTypes[%d].Name", i)
sl.ReportError(wellKnownNodeType.Name, fieldName, "", DuplicateWellKnownNodeTypeErrorMessage, "")
}
wellKnownNodeTypes[wellKnownNodeType.Name] = true
}

for priorityClassName, priorityClass := range c.PriorityClasses {
if len(priorityClass.AwayNodeTypes) > 0 && !priorityClass.Preemptible {
fieldName := fmt.Sprintf("Preemption.PriorityClasses[%s].Preemptible", priorityClassName)
sl.ReportError(priorityClass.Preemptible, fieldName, "", AwayNodeTypesWithoutPreemptionErrorMessage, "")
}

for i, awayNodeType := range priorityClass.AwayNodeTypes {
if !wellKnownNodeTypes[awayNodeType.WellKnownNodeTypeName] {
fieldName := fmt.Sprintf("Preemption.PriorityClasses[%s].AwayNodeTypes[%d].WellKnownNodeTypeName", priorityClassName, i)
sl.ReportError(awayNodeType.WellKnownNodeTypeName, fieldName, "", UnknownWellKnownNodeTypeErrorMessage, "")
}
}
}
}
7 changes: 7 additions & 0 deletions internal/scheduleringester/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package scheduleringester
import (
"time"

"github.com/go-playground/validator/v10"

commonconfig "github.com/armadaproject/armada/internal/common/config"
profilingconfig "github.com/armadaproject/armada/internal/common/profiling/configuration"
"github.com/armadaproject/armada/internal/server/configuration"
Expand All @@ -24,3 +26,8 @@ type Configuration struct {
// If non-nil, configures pprof profiling
Profiling *profilingconfig.ProfilingConfig
}

func (c Configuration) Validate() error {
validate := validator.New()
return validate.Struct(c)
}
8 changes: 8 additions & 0 deletions internal/server/configuration/validation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package configuration

import "github.com/go-playground/validator/v10"

func (c ArmadaConfig) Validate() error {
validate := validator.New()
return validate.Struct(c)
}

0 comments on commit 549c5b9

Please sign in to comment.