diff --git a/config.go b/config.go index fc692ad..714443c 100644 --- a/config.go +++ b/config.go @@ -34,19 +34,20 @@ type bouncerConfig struct { } type AclConfig struct { - WebACLName string `yaml:"web_acl_name"` - RuleGroupName string `yaml:"rule_group_name"` - Region string `yaml:"region"` - Scope string `yaml:"scope"` - IpsetPrefix string `yaml:"ipset_prefix"` - FallbackAction string `yaml:"fallback_action"` - AWSProfile string `yaml:"aws_profile"` - IPHeader string `yaml:"ip_header"` - IPHeaderPosition string `yaml:"ip_header_position"` - Capacity int `yaml:"capacity"` - CloudWatchEnabled bool `yaml:"cloudwatch_enabled"` - CloudWatchMetricName string `yaml:"cloudwatch_metric_name"` - SampleRequests bool `yaml:"sample_requests"` + WebACLName string `yaml:"web_acl_name"` + WebACLNames []string `yaml:"web_acl_names"` + RuleGroupName string `yaml:"rule_group_name"` + Region string `yaml:"region"` + Scope string `yaml:"scope"` + IpsetPrefix string `yaml:"ipset_prefix"` + FallbackAction string `yaml:"fallback_action"` + AWSProfile string `yaml:"aws_profile"` + IPHeader string `yaml:"ip_header"` + IPHeaderPosition string `yaml:"ip_header_position"` + Capacity int `yaml:"capacity"` + CloudWatchEnabled bool `yaml:"cloudwatch_enabled"` + CloudWatchMetricName string `yaml:"cloudwatch_metric_name"` + SampleRequests bool `yaml:"sample_requests"` } var validActions = []string{"ban", "captcha", "count"} @@ -85,6 +86,8 @@ func getConfigFromEnv(config *bouncerConfig) { switch k2 { case "WEB_ACL_NAME": acl.WebACLName = value + case "WEB_ACL_NAMES": + acl.WebACLNames = strings.Split(value, ",") case "RULE_GROUP_NAME": acl.RuleGroupName = value case "REGION": @@ -269,6 +272,9 @@ func newConfig(configPath string) (bouncerConfig, error) { return bouncerConfig{}, fmt.Errorf("waf_config is required") } for _, c := range config.WebACLConfig { + if c.WebACLName != "" && c.WebACLNames != nil { + return bouncerConfig{}, fmt.Errorf("waf_config must contain either web_acl_name or web_acl_names") + } if c.FallbackAction == "" { return bouncerConfig{}, fmt.Errorf("fallback_action is required") } diff --git a/main.go b/main.go index cfdd157..d9c66a0 100644 --- a/main.go +++ b/main.go @@ -5,6 +5,7 @@ import ( "fmt" "os" "os/signal" + "runtime/debug" "strings" "syscall" @@ -30,6 +31,10 @@ var wafInstances []*WAF = make([]*WAF, 0) var t *tomb.Tomb = &tomb.Tomb{} func cleanup() { + if r := recover(); r != nil { + log.Errorf("panic: %s", r) + log.Errorf("%s", debug.Stack()) + } for _, waf := range wafInstances { waf.logger.Infof("Cleaning up ressources") err := waf.Cleanup() @@ -191,7 +196,6 @@ func main() { } go signalHandler() - t.Go(func() error { bouncer.Run() return fmt.Errorf("stream api init failed") diff --git a/waf.go b/waf.go index acd392c..4588755 100644 --- a/waf.go +++ b/waf.go @@ -26,6 +26,7 @@ type WAF struct { ipsetManager *IPSetManager visibilityConfig *wafv2.VisibilityConfig lock sync.Mutex + acls []string } type IpSet struct { @@ -426,10 +427,6 @@ func (w *WAF) CleanupAcl(acl *wafv2.WebACL, token *string) error { log.Debugf("RuleGroup %s not found, nothing to do", w.config.RuleGroupName) } - if err != nil { - return errors.Wrapf(err, "Failed to list IPSets") - } - w.ipsetManager.DeleteSets() return nil @@ -443,11 +440,18 @@ func (w *WAF) Cleanup() error { if err != nil { return errors.Wrapf(err, "Failed to list WAF resources") } - acl, token, err := w.GetWebACL(w.config.WebACLName, w.aclsInfo[w.config.WebACLName].Id) - if err != nil { - return errors.Wrapf(err, "Failed to get WebACL") + for _, acl := range w.acls { + aclDetails, token, err := w.GetWebACL(acl, w.aclsInfo[acl].Id) + if err != nil { + log.Errorf("Failed to get ACL %s: %s", acl, err) + continue + } + err = w.CleanupAcl(aclDetails, token) + if err != nil { + log.Errorf("Failed to cleanup ACL %s: %s", *aclDetails.Name, err) + } } - return w.CleanupAcl(acl, token) + return nil } func (w *WAF) ListRessources() (map[string]Acl, map[string]IpSet, map[string]RuleGroup, error) { @@ -484,24 +488,31 @@ func (w *WAF) Init() error { w.logger.Tracef("Found %d RuleGroups", len(w.ruleGroupsInfos)) w.logger.Tracef("RuleGroups: %+v", w.ruleGroupsInfos) - if _, ok := w.aclsInfo[w.config.WebACLName]; !ok { - return fmt.Errorf("WebACL %s does not exist in region %s", w.config.WebACLName, w.config.Region) - } + w.ipsetManager = NewIPSetManager(w.config.IpsetPrefix, w.config.Scope, w.client, w.logger) - acl, token, err := w.GetWebACL(w.config.WebACLName, w.aclsInfo[w.config.WebACLName].Id) + for _, acl := range w.acls { + w.logger.Tracef("Adding ACL %s", acl) + if _, ok := w.aclsInfo[acl]; !ok { + return fmt.Errorf("WebACL %s does not exist in region %s", acl, w.config.Region) + } - if err != nil { - return errors.Wrap(err, "Failed to get WebACL") - } + aclDetails, token, err := w.GetWebACL(acl, w.aclsInfo[acl].Id) - w.ipsetManager = NewIPSetManager(w.config.IpsetPrefix, w.config.Scope, w.client, w.logger) + if err != nil { + w.logger.Error(err) + return errors.Wrap(err, "Failed to get WebACL") + } - err = w.CleanupAcl(acl, token) + err = w.CleanupAcl(aclDetails, token) - if err != nil { - return errors.Wrap(err, "Failed to cleanup") + if err != nil { + w.logger.Error(err) + return errors.Wrap(err, "Failed to cleanup") + } } + w.logger.Info("Cleanup done") + w.aclsInfo, w.setsInfos, w.ruleGroupsInfos, err = w.ListRessources() if err != nil { @@ -514,22 +525,20 @@ func (w *WAF) Init() error { return errors.Wrapf(err, "Failed to create RuleGroup %s", w.config.RuleGroupName) } - acl, lockTocken, err := w.GetWebACL(w.config.WebACLName, w.aclsInfo[w.config.WebACLName].Id) + for _, acl := range w.acls { - if err != nil { - return errors.Wrapf(err, "Failed to get WebACL %s", w.config.WebACLName) - } + aclDetails, lockTocken, err := w.GetWebACL(acl, w.aclsInfo[acl].Id) - err = w.AddRuleGroupToACL(acl, lockTocken) + if err != nil { + return errors.Wrapf(err, "Failed to get WebACL %s", acl) + } - if err != nil { - return errors.Wrapf(err, "Failed to add RuleGroup %s to WebACL %s", w.config.RuleGroupName, w.config.WebACLName) - } + err = w.AddRuleGroupToACL(aclDetails, lockTocken) - if err != nil { - return fmt.Errorf("failed to list ressources: %s", err) + if err != nil { + return errors.Wrapf(err, "Failed to add RuleGroup %s to WebACL %s", w.config.RuleGroupName, w.config.WebACLName) + } } - return nil } @@ -717,14 +726,22 @@ func (w *WAF) Dump() { func NewWaf(config AclConfig) (*WAF, error) { var s *session.Session + var acls []string if config.Scope == "CLOUDFRONT" { config.Region = "us-east-1" } + if config.WebACLName != "" { + acls = append(acls, config.WebACLName) + } + + if config.WebACLNames != nil { + acls = append(acls, config.WebACLNames...) + } + logger := log.WithFields(log.Fields{ "region": config.Region, "scope": config.Scope, - "acl": config.WebACLName, }) w := &WAF{ @@ -733,6 +750,7 @@ func NewWaf(config AclConfig) (*WAF, error) { ruleGroupsInfos: make(map[string]RuleGroup), logger: logger, decisionsChan: make(chan Decisions), + acls: acls, } if config.AWSProfile == "" {