diff --git a/cmd/root.go b/cmd/root.go index 4ccfe3a..756513a 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -204,7 +204,7 @@ func (c *cliConfig) fillIO(config *engine.Config) error { if err != nil { return configError{Field: "io", Err: err} } - config.IOs = []io.PacketIO{nfio} + config.IO = nfio return nil } @@ -247,12 +247,7 @@ func runMain(cmd *cobra.Command, args []string) { if err != nil { logger.Fatal("failed to parse config", zap.Error(err)) } - defer func() { - // Make sure to close all IOs on exit - for _, i := range engineConfig.IOs { - _ = i.Close() - } - }() + defer engineConfig.IO.Close() // Make sure to close IO on exit // Ruleset rawRs, err := ruleset.ExprRulesFromYAML(args[0]) @@ -260,9 +255,10 @@ func runMain(cmd *cobra.Command, args []string) { logger.Fatal("failed to load rules", zap.Error(err)) } rsConfig := &ruleset.BuiltinConfig{ - Logger: &rulesetLogger{}, - GeoSiteFilename: config.Ruleset.GeoSite, - GeoIpFilename: config.Ruleset.GeoIp, + Logger: &rulesetLogger{}, + GeoSiteFilename: config.Ruleset.GeoSite, + GeoIpFilename: config.Ruleset.GeoIp, + ProtectedDialContext: engineConfig.IO.ProtectedDialContext, } rs, err := ruleset.CompileExprRules(rawRs, analyzers, modifiers, rsConfig) if err != nil { diff --git a/engine/engine.go b/engine/engine.go index e8c1bfd..7c93e0a 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -15,7 +15,7 @@ var _ Engine = (*engine)(nil) type engine struct { logger Logger - ioList []io.PacketIO + io io.PacketIO workers []*worker } @@ -42,7 +42,7 @@ func NewEngine(config Config) (Engine, error) { } return &engine{ logger: config.Logger, - ioList: config.IOs, + io: config.IO, workers: workers, }, nil } @@ -58,27 +58,24 @@ func (e *engine) UpdateRuleset(r ruleset.Ruleset) error { func (e *engine) Run(ctx context.Context) error { ioCtx, ioCancel := context.WithCancel(ctx) - defer ioCancel() // Stop workers & IOs + defer ioCancel() // Stop workers & IO // Start workers for _, w := range e.workers { go w.Run(ioCtx) } - // Register callbacks - errChan := make(chan error, len(e.ioList)) - for _, i := range e.ioList { - ioEntry := i // Make sure dispatch() uses the correct ioEntry - err := ioEntry.Register(ioCtx, func(p io.Packet, err error) bool { - if err != nil { - errChan <- err - return false - } - return e.dispatch(ioEntry, p) - }) + // Register IO callback + errChan := make(chan error, 1) + err := e.io.Register(ioCtx, func(p io.Packet, err error) bool { if err != nil { - return err + errChan <- err + return false } + return e.dispatch(p) + }) + if err != nil { + return err } // Block until IO errors or context is cancelled @@ -91,8 +88,7 @@ func (e *engine) Run(ctx context.Context) error { } // dispatch dispatches a packet to a worker. -// This must be safe for concurrent use, as it may be called from multiple IOs. -func (e *engine) dispatch(ioEntry io.PacketIO, p io.Packet) bool { +func (e *engine) dispatch(p io.Packet) bool { data := p.Data() ipVersion := data[0] >> 4 var layerType gopacket.LayerType @@ -102,7 +98,7 @@ func (e *engine) dispatch(ioEntry io.PacketIO, p io.Packet) bool { layerType = layers.LayerTypeIPv6 } else { // Unsupported network layer - _ = ioEntry.SetVerdict(p, io.VerdictAcceptStream, nil) + _ = e.io.SetVerdict(p, io.VerdictAcceptStream, nil) return true } // Load balance by stream ID @@ -112,7 +108,7 @@ func (e *engine) dispatch(ioEntry io.PacketIO, p io.Packet) bool { StreamID: p.StreamID(), Packet: packet, SetVerdict: func(v io.Verdict, b []byte) error { - return ioEntry.SetVerdict(p, v, b) + return e.io.SetVerdict(p, v, b) }, }) return true diff --git a/engine/interface.go b/engine/interface.go index 1ad26e3..fe25de5 100644 --- a/engine/interface.go +++ b/engine/interface.go @@ -18,7 +18,7 @@ type Engine interface { // Config is the configuration for the engine. type Config struct { Logger Logger - IOs []io.PacketIO + IO io.PacketIO Ruleset ruleset.Ruleset Workers int // Number of workers. Zero or negative means auto (number of CPU cores). diff --git a/go.mod b/go.mod index 75e54ef..70e653a 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.21 require ( github.com/bwmarrin/snowflake v0.3.0 github.com/coreos/go-iptables v0.7.0 - github.com/expr-lang/expr v1.15.7 + github.com/expr-lang/expr v1.16.3 github.com/florianl/go-nfqueue v1.3.2-0.20231218173729-f2bdeb033acf github.com/google/gopacket v1.1.20-0.20220810144506-32ee38206866 github.com/hashicorp/golang-lru/v2 v2.0.7 diff --git a/go.sum b/go.sum index 9e35205..e75b774 100644 --- a/go.sum +++ b/go.sum @@ -7,8 +7,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/expr-lang/expr v1.15.7 h1:BK0JcWUkoW6nrbLBo6xCKhz4BvH5DSOOu1Gx5lucyZo= -github.com/expr-lang/expr v1.15.7/go.mod h1:uCkhfG+x7fcZ5A5sXHKuQ07jGZRl6J0FCAaf2k4PtVQ= +github.com/expr-lang/expr v1.16.3 h1:NLldf786GffptcXNxxJx5dQ+FzeWDKChBDqOOwyK8to= +github.com/expr-lang/expr v1.16.3/go.mod h1:uCkhfG+x7fcZ5A5sXHKuQ07jGZRl6J0FCAaf2k4PtVQ= github.com/florianl/go-nfqueue v1.3.2-0.20231218173729-f2bdeb033acf h1:NqGS3vTHzVENbIfd87cXZwdpO6MB2R1PjHMJLi4Z3ow= github.com/florianl/go-nfqueue v1.3.2-0.20231218173729-f2bdeb033acf/go.mod h1:eSnAor2YCfMCVYrVNEhkLGN/r1L+J4uDjc0EUy0tfq4= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= diff --git a/io/interface.go b/io/interface.go index 6f25df1..35aa886 100644 --- a/io/interface.go +++ b/io/interface.go @@ -2,6 +2,7 @@ package io import ( "context" + "net" ) type Verdict int @@ -29,7 +30,6 @@ type Packet interface { // PacketCallback is called for each packet received. // Return false to "unregister" and stop receiving packets. -// It must be safe for concurrent use. type PacketCallback func(Packet, error) bool type PacketIO interface { @@ -39,6 +39,10 @@ type PacketIO interface { Register(context.Context, PacketCallback) error // SetVerdict sets the verdict for a packet. SetVerdict(Packet, Verdict, []byte) error + // ProtectedDialContext is like net.DialContext, but the connection is "protected" + // in the sense that the packets sent/received through the connection must bypass + // the packet IO and not be processed by the callback. + ProtectedDialContext(ctx context.Context, network, address string) (net.Conn, error) // Close closes the packet IO. Close() error } diff --git a/io/nfqueue.go b/io/nfqueue.go index 499ff2c..543f247 100644 --- a/io/nfqueue.go +++ b/io/nfqueue.go @@ -5,9 +5,11 @@ import ( "encoding/binary" "errors" "fmt" + "net" "os/exec" "strconv" "strings" + "syscall" "github.com/coreos/go-iptables/iptables" "github.com/florianl/go-nfqueue" @@ -50,6 +52,7 @@ func generateNftRules(local, rst bool) (*nftTableSpec, error) { } for i := range table.Chains { c := &table.Chains[i] + c.Rules = append(c.Rules, "meta mark $ACCEPT_CTMARK ct mark set $ACCEPT_CTMARK") // Bypass protected connections c.Rules = append(c.Rules, "ct mark $ACCEPT_CTMARK counter accept") if rst { c.Rules = append(c.Rules, "ip protocol tcp ct mark $DROP_CTMARK counter reject with tcp reset") @@ -72,6 +75,8 @@ func generateIptRules(local, rst bool) ([]iptRule, error) { } rules := make([]iptRule, 0, 4*len(chains)) for _, chain := range chains { + // Bypass protected connections + rules = append(rules, iptRule{"filter", chain, []string{"-m", "mark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "CONNMARK", "--set-mark", strconv.Itoa(nfqueueConnMarkAccept)}}) rules = append(rules, iptRule{"filter", chain, []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "ACCEPT"}}) if rst { rules = append(rules, iptRule{"filter", chain, []string{"-p", "tcp", "-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "REJECT", "--reject-with", "tcp-reset"}}) @@ -96,6 +101,8 @@ type nfqueuePacketIO struct { // iptables not nil = use iptables instead of nftables ipt4 *iptables.IPTables ipt6 *iptables.IPTables + + protectedDialer *net.Dialer } type NFQueuePacketIOConfig struct { @@ -153,6 +160,18 @@ func NewNFQueuePacketIO(config NFQueuePacketIOConfig) (PacketIO, error) { rst: config.RST, ipt4: ipt4, ipt6: ipt6, + protectedDialer: &net.Dialer{ + Control: func(network, address string, c syscall.RawConn) error { + var err error + cErr := c.Control(func(fd uintptr) { + err = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, nfqueueConnMarkAccept) + }) + if cErr != nil { + return cErr + } + return err + }, + }, }, nil } @@ -239,6 +258,10 @@ func (n *nfqueuePacketIO) SetVerdict(p Packet, v Verdict, newPacket []byte) erro } } +func (n *nfqueuePacketIO) ProtectedDialContext(ctx context.Context, network, address string) (net.Conn, error) { + return n.protectedDialer.DialContext(ctx, network, address) +} + func (n *nfqueuePacketIO) Close() error { if n.rSet { if n.ipt4 != nil { diff --git a/ruleset/builtins/geo/geo_matcher.go b/ruleset/builtins/geo/geo_matcher.go index 2032f90..1bb0f30 100644 --- a/ruleset/builtins/geo/geo_matcher.go +++ b/ruleset/builtins/geo/geo_matcher.go @@ -14,14 +14,12 @@ type GeoMatcher struct { ipMatcherLock sync.Mutex } -func NewGeoMatcher(geoSiteFilename, geoIpFilename string) (*GeoMatcher, error) { - geoLoader := NewDefaultGeoLoader(geoSiteFilename, geoIpFilename) - +func NewGeoMatcher(geoSiteFilename, geoIpFilename string) *GeoMatcher { return &GeoMatcher{ - geoLoader: geoLoader, + geoLoader: NewDefaultGeoLoader(geoSiteFilename, geoIpFilename), geoSiteMatcher: make(map[string]hostMatcher), geoIpMatcher: make(map[string]hostMatcher), - }, nil + } } func (g *GeoMatcher) MatchGeoIp(ip, condition string) bool { diff --git a/ruleset/expr.go b/ruleset/expr.go index 7de924c..868a115 100644 --- a/ruleset/expr.go +++ b/ruleset/expr.go @@ -1,11 +1,15 @@ package ruleset import ( + "context" "fmt" "net" "os" "reflect" "strings" + "time" + + "github.com/expr-lang/expr/builtin" "github.com/expr-lang/expr" "github.com/expr-lang/expr/ast" @@ -55,10 +59,9 @@ type compiledExprRule struct { var _ Ruleset = (*exprRuleset)(nil) type exprRuleset struct { - Rules []compiledExprRule - Ans []analyzer.Analyzer - Logger Logger - GeoMatcher *geo.GeoMatcher + Rules []compiledExprRule + Ans []analyzer.Analyzer + Logger Logger } func (r *exprRuleset) Analyzers(info StreamInfo) []analyzer.Analyzer { @@ -100,10 +103,7 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier fullAnMap := analyzersToMap(ans) fullModMap := modifiersToMap(mods) depAnMap := make(map[string]analyzer.Analyzer) - geoMatcher, err := geo.NewGeoMatcher(config.GeoSiteFilename, config.GeoIpFilename) - if err != nil { - return nil, err - } + funcMap := buildFunctionMap(config) // Compile all rules and build a map of analyzers that are used by the rules. for _, rule := range rules { if rule.Action == "" && !rule.Log { @@ -118,13 +118,19 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier action = &a } visitor := &idVisitor{Variables: make(map[string]bool), Identifiers: make(map[string]bool)} - patcher := &idPatcher{} + patcher := &idPatcher{FuncMap: funcMap} program, err := expr.Compile(rule.Expr, func(c *conf.Config) { c.Strict = false c.Expect = reflect.Bool c.Visitors = append(c.Visitors, visitor, patcher) - registerBuiltinFunctions(c.Functions, geoMatcher) + for name, f := range funcMap { + c.Functions[name] = &builtin.Function{ + Name: name, + Func: f.Func, + Types: f.Types, + } + } }, ) if err != nil { @@ -138,24 +144,15 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier if isBuiltInAnalyzer(name) || visitor.Variables[name] { continue } - // Check if it's one of the built-in functions, and if so, - // skip it as an analyzer & do initialization if necessary. - switch name { - case "geoip": - if err := geoMatcher.LoadGeoIP(); err != nil { - return nil, fmt.Errorf("rule %q failed to load geoip: %w", rule.Name, err) - } - case "geosite": - if err := geoMatcher.LoadGeoSite(); err != nil { - return nil, fmt.Errorf("rule %q failed to load geosite: %w", rule.Name, err) - } - case "cidr": - // No initialization needed for CIDR. - default: - a, ok := fullAnMap[name] - if !ok { - return nil, fmt.Errorf("rule %q uses unknown analyzer %q", rule.Name, name) + if f, ok := funcMap[name]; ok { + // Built-in function, initialize if necessary + if f.InitFunc != nil { + if err := f.InitFunc(); err != nil { + return nil, fmt.Errorf("rule %q failed to initialize function %q: %w", rule.Name, name, err) + } } + } else if a, ok := fullAnMap[name]; ok { + // Analyzer, add to dependency map depAnMap[name] = a } } @@ -184,37 +181,12 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier depAns = append(depAns, a) } return &exprRuleset{ - Rules: compiledRules, - Ans: depAns, - Logger: config.Logger, - GeoMatcher: geoMatcher, + Rules: compiledRules, + Ans: depAns, + Logger: config.Logger, }, nil } -func registerBuiltinFunctions(funcMap map[string]*ast.Function, geoMatcher *geo.GeoMatcher) { - funcMap["geoip"] = &ast.Function{ - Name: "geoip", - Func: func(params ...any) (any, error) { - return geoMatcher.MatchGeoIp(params[0].(string), params[1].(string)), nil - }, - Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoIp)}, - } - funcMap["geosite"] = &ast.Function{ - Name: "geosite", - Func: func(params ...any) (any, error) { - return geoMatcher.MatchGeoSite(params[0].(string), params[1].(string)), nil - }, - Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoSite)}, - } - funcMap["cidr"] = &ast.Function{ - Name: "cidr", - Func: func(params ...any) (any, error) { - return builtins.MatchCIDR(params[0].(string), params[1].(*net.IPNet)), nil - }, - Types: []reflect.Type{reflect.TypeOf((func(string, string) bool)(nil)), reflect.TypeOf(builtins.MatchCIDR)}, - } -} - func streamInfoToExprEnv(info StreamInfo) map[string]interface{} { m := map[string]interface{}{ "id": info.ID, @@ -299,29 +271,109 @@ func (v *idVisitor) Visit(node *ast.Node) { // idPatcher patches the AST during expr compilation, replacing certain values with // their internal representations for better runtime performance. type idPatcher struct { - Err error + FuncMap map[string]*Function + Err error } func (p *idPatcher) Visit(node *ast.Node) { switch (*node).(type) { case *ast.CallNode: callNode := (*node).(*ast.CallNode) - if callNode.Func == nil { + if callNode.Callee == nil { // Ignore invalid call nodes return } - switch callNode.Func.Name { - case "cidr": - cidrStringNode, ok := callNode.Arguments[1].(*ast.StringNode) - if !ok { - return - } - cidr, err := builtins.CompileCIDR(cidrStringNode.Value) - if err != nil { - p.Err = err - return + if f, ok := p.FuncMap[callNode.Callee.String()]; ok { + if f.PatchFunc != nil { + if err := f.PatchFunc(&callNode.Arguments); err != nil { + p.Err = err + return + } } - callNode.Arguments[1] = &ast.ConstantNode{Value: cidr} } } } + +type Function struct { + InitFunc func() error + PatchFunc func(args *[]ast.Node) error + Func func(params ...any) (any, error) + Types []reflect.Type +} + +func buildFunctionMap(config *BuiltinConfig) map[string]*Function { + geoMatcher := geo.NewGeoMatcher(config.GeoSiteFilename, config.GeoIpFilename) + return map[string]*Function{ + "geoip": { + InitFunc: geoMatcher.LoadGeoIP, + PatchFunc: nil, + Func: func(params ...any) (any, error) { + return geoMatcher.MatchGeoIp(params[0].(string), params[1].(string)), nil + }, + Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoIp)}, + }, + "geosite": { + InitFunc: geoMatcher.LoadGeoSite, + PatchFunc: nil, + Func: func(params ...any) (any, error) { + return geoMatcher.MatchGeoSite(params[0].(string), params[1].(string)), nil + }, + Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoSite)}, + }, + "cidr": { + InitFunc: nil, + PatchFunc: func(args *[]ast.Node) error { + cidrStringNode, ok := (*args)[1].(*ast.StringNode) + if !ok { + return fmt.Errorf("cidr: invalid argument type") + } + cidr, err := builtins.CompileCIDR(cidrStringNode.Value) + if err != nil { + return err + } + (*args)[1] = &ast.ConstantNode{Value: cidr} + return nil + }, + Func: func(params ...any) (any, error) { + return builtins.MatchCIDR(params[0].(string), params[1].(*net.IPNet)), nil + }, + Types: []reflect.Type{reflect.TypeOf(builtins.MatchCIDR)}, + }, + "lookup": { + InitFunc: nil, + PatchFunc: func(args *[]ast.Node) error { + var serverStr *ast.StringNode + if len(*args) > 1 { + // Has the optional server argument + var ok bool + serverStr, ok = (*args)[1].(*ast.StringNode) + if !ok { + return fmt.Errorf("lookup: invalid argument type") + } + } + r := &net.Resolver{ + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + if serverStr != nil { + address = serverStr.Value + } + return config.ProtectedDialContext(ctx, network, address) + }, + } + if len(*args) > 1 { + (*args)[1] = &ast.ConstantNode{Value: r} + } else { + *args = append(*args, &ast.ConstantNode{Value: r}) + } + return nil + }, + Func: func(params ...any) (any, error) { + ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second) + defer cancel() + return params[1].(*net.Resolver).LookupHost(ctx, params[0].(string)) + }, + Types: []reflect.Type{ + reflect.TypeOf((func(string, *net.Resolver) []string)(nil)), + }, + }, + } +} diff --git a/ruleset/interface.go b/ruleset/interface.go index 60af75d..535c2a4 100644 --- a/ruleset/interface.go +++ b/ruleset/interface.go @@ -1,6 +1,7 @@ package ruleset import ( + "context" "net" "strconv" @@ -100,7 +101,8 @@ type Logger interface { } type BuiltinConfig struct { - Logger Logger - GeoSiteFilename string - GeoIpFilename string + Logger Logger + GeoSiteFilename string + GeoIpFilename string + ProtectedDialContext func(ctx context.Context, network, address string) (net.Conn, error) }