Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow to link multiple webACLs to the same rulegroup #15

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 19 additions & 13 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,20 @@ type bouncerConfig struct {
}

type AclConfig struct {
WebACLName string `yaml:"web_acl_name"`
RuleGroupName string `yaml:"rule_group_name"`
Region string `yaml:"region"`
Scope string `yaml:"scope"`
IpsetPrefix string `yaml:"ipset_prefix"`
FallbackAction string `yaml:"fallback_action"`
AWSProfile string `yaml:"aws_profile"`
IPHeader string `yaml:"ip_header"`
IPHeaderPosition string `yaml:"ip_header_position"`
Capacity int `yaml:"capacity"`
CloudWatchEnabled bool `yaml:"cloudwatch_enabled"`
CloudWatchMetricName string `yaml:"cloudwatch_metric_name"`
SampleRequests bool `yaml:"sample_requests"`
WebACLName string `yaml:"web_acl_name"`
WebACLNames []string `yaml:"web_acl_names"`
RuleGroupName string `yaml:"rule_group_name"`
Region string `yaml:"region"`
Scope string `yaml:"scope"`
IpsetPrefix string `yaml:"ipset_prefix"`
FallbackAction string `yaml:"fallback_action"`
AWSProfile string `yaml:"aws_profile"`
IPHeader string `yaml:"ip_header"`
IPHeaderPosition string `yaml:"ip_header_position"`
Capacity int `yaml:"capacity"`
CloudWatchEnabled bool `yaml:"cloudwatch_enabled"`
CloudWatchMetricName string `yaml:"cloudwatch_metric_name"`
SampleRequests bool `yaml:"sample_requests"`
}

var validActions = []string{"ban", "captcha", "count"}
Expand Down Expand Up @@ -85,6 +86,8 @@ func getConfigFromEnv(config *bouncerConfig) {
switch k2 {
case "WEB_ACL_NAME":
acl.WebACLName = value
case "WEB_ACL_NAMES":
acl.WebACLNames = strings.Split(value, ",")
case "RULE_GROUP_NAME":
acl.RuleGroupName = value
case "REGION":
Expand Down Expand Up @@ -269,6 +272,9 @@ func newConfig(configPath string) (bouncerConfig, error) {
return bouncerConfig{}, fmt.Errorf("waf_config is required")
}
for _, c := range config.WebACLConfig {
if c.WebACLName != "" && c.WebACLNames != nil {
return bouncerConfig{}, fmt.Errorf("waf_config must contain either web_acl_name or web_acl_names")
}
if c.FallbackAction == "" {
return bouncerConfig{}, fmt.Errorf("fallback_action is required")
}
Expand Down
6 changes: 5 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"os"
"os/signal"
"runtime/debug"
"strings"
"syscall"

Expand All @@ -30,6 +31,10 @@ var wafInstances []*WAF = make([]*WAF, 0)
var t *tomb.Tomb = &tomb.Tomb{}

func cleanup() {
if r := recover(); r != nil {
log.Errorf("panic: %s", r)
log.Errorf("%s", debug.Stack())
}
for _, waf := range wafInstances {
waf.logger.Infof("Cleaning up ressources")
err := waf.Cleanup()
Expand Down Expand Up @@ -191,7 +196,6 @@ func main() {
}

go signalHandler()

t.Go(func() error {
bouncer.Run()
return fmt.Errorf("stream api init failed")
Expand Down
80 changes: 49 additions & 31 deletions waf.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type WAF struct {
ipsetManager *IPSetManager
visibilityConfig *wafv2.VisibilityConfig
lock sync.Mutex
acls []string
}

type IpSet struct {
Expand Down Expand Up @@ -426,10 +427,6 @@ func (w *WAF) CleanupAcl(acl *wafv2.WebACL, token *string) error {
log.Debugf("RuleGroup %s not found, nothing to do", w.config.RuleGroupName)
}

if err != nil {
return errors.Wrapf(err, "Failed to list IPSets")
}

w.ipsetManager.DeleteSets()

return nil
Expand All @@ -443,11 +440,18 @@ func (w *WAF) Cleanup() error {
if err != nil {
return errors.Wrapf(err, "Failed to list WAF resources")
}
acl, token, err := w.GetWebACL(w.config.WebACLName, w.aclsInfo[w.config.WebACLName].Id)
if err != nil {
return errors.Wrapf(err, "Failed to get WebACL")
for _, acl := range w.acls {
aclDetails, token, err := w.GetWebACL(acl, w.aclsInfo[acl].Id)
if err != nil {
log.Errorf("Failed to get ACL %s: %s", acl, err)
continue
}
err = w.CleanupAcl(aclDetails, token)
if err != nil {
log.Errorf("Failed to cleanup ACL %s: %s", *aclDetails.Name, err)
}
}
return w.CleanupAcl(acl, token)
return nil
}

func (w *WAF) ListRessources() (map[string]Acl, map[string]IpSet, map[string]RuleGroup, error) {
Expand Down Expand Up @@ -484,24 +488,31 @@ func (w *WAF) Init() error {
w.logger.Tracef("Found %d RuleGroups", len(w.ruleGroupsInfos))
w.logger.Tracef("RuleGroups: %+v", w.ruleGroupsInfos)

if _, ok := w.aclsInfo[w.config.WebACLName]; !ok {
return fmt.Errorf("WebACL %s does not exist in region %s", w.config.WebACLName, w.config.Region)
}
w.ipsetManager = NewIPSetManager(w.config.IpsetPrefix, w.config.Scope, w.client, w.logger)

acl, token, err := w.GetWebACL(w.config.WebACLName, w.aclsInfo[w.config.WebACLName].Id)
for _, acl := range w.acls {
w.logger.Tracef("Adding ACL %s", acl)
if _, ok := w.aclsInfo[acl]; !ok {
return fmt.Errorf("WebACL %s does not exist in region %s", acl, w.config.Region)
}

if err != nil {
return errors.Wrap(err, "Failed to get WebACL")
}
aclDetails, token, err := w.GetWebACL(acl, w.aclsInfo[acl].Id)

w.ipsetManager = NewIPSetManager(w.config.IpsetPrefix, w.config.Scope, w.client, w.logger)
if err != nil {
w.logger.Error(err)
return errors.Wrap(err, "Failed to get WebACL")
}

err = w.CleanupAcl(acl, token)
err = w.CleanupAcl(aclDetails, token)

if err != nil {
return errors.Wrap(err, "Failed to cleanup")
if err != nil {
w.logger.Error(err)
return errors.Wrap(err, "Failed to cleanup")
}
}

w.logger.Info("Cleanup done")

w.aclsInfo, w.setsInfos, w.ruleGroupsInfos, err = w.ListRessources()

if err != nil {
Expand All @@ -514,22 +525,20 @@ func (w *WAF) Init() error {
return errors.Wrapf(err, "Failed to create RuleGroup %s", w.config.RuleGroupName)
}

acl, lockTocken, err := w.GetWebACL(w.config.WebACLName, w.aclsInfo[w.config.WebACLName].Id)
for _, acl := range w.acls {

if err != nil {
return errors.Wrapf(err, "Failed to get WebACL %s", w.config.WebACLName)
}
aclDetails, lockTocken, err := w.GetWebACL(acl, w.aclsInfo[acl].Id)

err = w.AddRuleGroupToACL(acl, lockTocken)
if err != nil {
return errors.Wrapf(err, "Failed to get WebACL %s", acl)
}

if err != nil {
return errors.Wrapf(err, "Failed to add RuleGroup %s to WebACL %s", w.config.RuleGroupName, w.config.WebACLName)
}
err = w.AddRuleGroupToACL(aclDetails, lockTocken)

if err != nil {
return fmt.Errorf("failed to list ressources: %s", err)
if err != nil {
return errors.Wrapf(err, "Failed to add RuleGroup %s to WebACL %s", w.config.RuleGroupName, w.config.WebACLName)
}
}

return nil
}

Expand Down Expand Up @@ -717,14 +726,22 @@ func (w *WAF) Dump() {

func NewWaf(config AclConfig) (*WAF, error) {
var s *session.Session
var acls []string
if config.Scope == "CLOUDFRONT" {
config.Region = "us-east-1"
}

if config.WebACLName != "" {
acls = append(acls, config.WebACLName)
}

if config.WebACLNames != nil {
acls = append(acls, config.WebACLNames...)
}

logger := log.WithFields(log.Fields{
"region": config.Region,
"scope": config.Scope,
"acl": config.WebACLName,
})

w := &WAF{
Expand All @@ -733,6 +750,7 @@ func NewWaf(config AclConfig) (*WAF, error) {
ruleGroupsInfos: make(map[string]RuleGroup),
logger: logger,
decisionsChan: make(chan Decisions),
acls: acls,
}

if config.AWSProfile == "" {
Expand Down