diff --git a/go.mod b/go.mod index dc28c18..4f2bdd2 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,8 @@ require ( github.com/go-ole/go-ole v1.3.0 github.com/sagernet/gvisor v0.0.0-20240428053021-e691de28565f github.com/sagernet/netlink v0.0.0-20240523065131-45e60152f9ba - github.com/sagernet/sing v0.4.0 + github.com/sagernet/nftables v0.3.0-beta.2 + github.com/sagernet/sing v0.5.0-alpha.8 go4.org/netipx v0.0.0-20231129151722-fdeea329fbba golang.org/x/net v0.25.0 golang.org/x/sys v0.20.0 @@ -15,6 +16,11 @@ require ( require ( github.com/google/btree v1.1.2 // indirect - github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 // indirect + github.com/google/go-cmp v0.5.9 // indirect + github.com/josharian/native v1.1.0 // indirect + github.com/mdlayher/netlink v1.7.2 // indirect + github.com/mdlayher/socket v0.4.1 // indirect + github.com/vishvananda/netns v0.0.4 // indirect + golang.org/x/sync v0.1.0 // indirect golang.org/x/time v0.5.0 // indirect ) diff --git a/go.sum b/go.sum index 6211ef2..0f08ed7 100644 --- a/go.sum +++ b/go.sum @@ -5,21 +5,32 @@ github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE= github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78= github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU= github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA= +github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= +github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g= +github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= +github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U= +github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/sagernet/gvisor v0.0.0-20240428053021-e691de28565f h1:NkhuupzH5ch7b/Y/6ZHJWrnNLoiNnSJaow6DPb8VW2I= github.com/sagernet/gvisor v0.0.0-20240428053021-e691de28565f/go.mod h1:KXmw+ouSJNOsuRpg4wgwwCQuunrGz4yoAqQjsLjc6N0= github.com/sagernet/netlink v0.0.0-20240523065131-45e60152f9ba h1:EY5AS7CCtfmARNv2zXUOrsEMPFDGYxaw65JzA2p51Vk= github.com/sagernet/netlink v0.0.0-20240523065131-45e60152f9ba/go.mod h1:xLnfdiJbSp8rNqYEdIW/6eDO4mVoogml14Bh2hSiFpM= -github.com/sagernet/sing v0.4.0 h1:sCLSqLHOptgFvzQO9FfaYMl4PONePZkclMznpeKhdHc= -github.com/sagernet/sing v0.4.0/go.mod h1:Xh4KO9nGdvm4K/LVg9Xn9jSxJdqe9KcXbAzNC1S2qfw= +github.com/sagernet/nftables v0.3.0-beta.2 h1:yKqMl4Dpb6nKxAmlE6fXjJRlLO2c1f2wyNFBg4hBr8w= +github.com/sagernet/nftables v0.3.0-beta.2/go.mod h1:OQXAjvjNGGFxaTgVCSTRIhYB5/llyVDeapVoENYBDS8= +github.com/sagernet/sing v0.5.0-alpha.8 h1:2KtzBvKP6hwknsi/G6H4vRgR4it31HQ6quLb0Woze7c= +github.com/sagernet/sing v0.5.0-alpha.8/go.mod h1:Xh4KO9nGdvm4K/LVg9Xn9jSxJdqe9KcXbAzNC1S2qfw= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 h1:gga7acRE695APm9hlsSMoOoE65U4/TcqNj90mc69Rlg= -github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= +github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= +github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBseWJUpBw5I82+2U4M= go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y= golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= -golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= diff --git a/redirect.go b/redirect.go new file mode 100644 index 0000000..d61db97 --- /dev/null +++ b/redirect.go @@ -0,0 +1,22 @@ +package tun + +import ( + "context" + + "github.com/sagernet/sing/common/logger" +) + +type AutoRedirect interface { + Start() error + Close() error +} + +type AutoRedirectOptions struct { + TunOptions *Options + Context context.Context + Handler Handler + Logger logger.Logger + TableName string + DisableNFTables bool + CustomRedirectPort func() int +} diff --git a/redirect_iptables.go b/redirect_iptables.go new file mode 100644 index 0000000..fd1bfa0 --- /dev/null +++ b/redirect_iptables.go @@ -0,0 +1,227 @@ +//go:build linux + +package tun + +import ( + "net/netip" + "os/exec" + "strings" + + E "github.com/sagernet/sing/common/exceptions" + F "github.com/sagernet/sing/common/format" + + "golang.org/x/sys/unix" +) + +func (r *autoRedirect) iptablesPathForFamily(family int) string { + if family == unix.AF_INET { + return r.iptablesPath + } else { + return r.ip6tablesPath + } +} + +func (r *autoRedirect) setupIPTables(family int) error { + tableNameOutput := r.tableName + "-output" + tableNameForward := r.tableName + "-forward" + tableNamePreRouteing := r.tableName + "-prerouting" + iptablesPath := r.iptablesPathForFamily(family) + redirectPort := r.redirectPort() + // OUTPUT + err := r.runShell(iptablesPath, "-t nat -N", tableNameOutput) + if err != nil { + return err + } + err = r.runShell(iptablesPath, "-t nat -A", tableNameOutput, + "-p tcp -o", r.tunOptions.Name, + "-j REDIRECT --to-ports", redirectPort) + if err != nil { + return err + } + err = r.runShell(iptablesPath, "-t nat -I OUTPUT -j", tableNameOutput) + if err != nil { + return err + } + if r.androidSu { + return nil + } + // FORWARD + err = r.runShell(iptablesPath, "-N", tableNameForward) + if err != nil { + return err + } + err = r.runShell(iptablesPath, "-A", tableNameForward, + "-i", r.tunOptions.Name, "-j", "ACCEPT") + if err != nil { + return err + } + err = r.runShell(iptablesPath, "-A", tableNameForward, + "-o", r.tunOptions.Name, "-j", "ACCEPT") + if err != nil { + return err + } + err = r.runShell(iptablesPath, "-I FORWARD -j", tableNameForward) + if err != nil { + return err + } + // PREROUTING + err = r.runShell(iptablesPath, "-t nat -N", tableNamePreRouteing) + if err != nil { + return err + } + var ( + routeAddress []netip.Prefix + routeExcludeAddress []netip.Prefix + ) + if family == unix.AF_INET { + routeAddress = r.tunOptions.Inet4RouteAddress + routeExcludeAddress = r.tunOptions.Inet4RouteExcludeAddress + } else { + routeAddress = r.tunOptions.Inet6RouteAddress + routeExcludeAddress = r.tunOptions.Inet6RouteExcludeAddress + } + if len(routeAddress) > 0 && (len(r.tunOptions.IncludeInterface) > 0 || len(r.tunOptions.IncludeUID) > 0) { + return E.New("`*_route_address` is conflict with `include_interface` or `include_uid`") + } + err = r.runShell(iptablesPath, "-t nat -A", tableNamePreRouteing, + "-i", r.tunOptions.Name, "-j RETURN") + if err != nil { + return err + } + for _, address := range routeExcludeAddress { + err = r.runShell(iptablesPath, "-t nat -A", tableNamePreRouteing, + "-d", address.String(), "-j RETURN") + if err != nil { + return err + } + } + for _, name := range r.tunOptions.ExcludeInterface { + err = r.runShell(iptablesPath, "-t nat -A", tableNamePreRouteing, + "-i", name, "-j RETURN") + if err != nil { + return err + } + } + for _, uid := range r.tunOptions.ExcludeUID { + err = r.runShell(iptablesPath, "-t nat -A", tableNamePreRouteing, + "-m owner --uid-owner", uid, "-j RETURN") + if err != nil { + return err + } + } + var dnsServerAddress netip.Addr + if family == unix.AF_INET { + dnsServerAddress = r.tunOptions.Inet4Address[0].Addr().Next() + } else { + dnsServerAddress = r.tunOptions.Inet6Address[0].Addr().Next() + } + if len(routeAddress) > 0 { + for _, address := range routeAddress { + err = r.runShell(iptablesPath, "-t nat -A", tableNamePreRouteing, + "-d", address.String(), "-p udp --dport 53 -j DNAT --to", dnsServerAddress) + if err != nil { + return err + } + } + } else if len(r.tunOptions.IncludeInterface) > 0 || len(r.tunOptions.IncludeUID) > 0 { + for _, name := range r.tunOptions.IncludeInterface { + err = r.runShell(iptablesPath, "-t nat -A", tableNamePreRouteing, + "-i", name, "-p udp --dport 53 -j DNAT --to", dnsServerAddress) + if err != nil { + return err + } + } + for _, uidRange := range r.tunOptions.IncludeUID { + for uid := uidRange.Start; uid <= uidRange.End; uid++ { + err = r.runShell(iptablesPath, "-t nat -A", tableNamePreRouteing, + "-m owner --uid-owner", uid, "-p udp --dport 53 -j DNAT --to", dnsServerAddress) + if err != nil { + return err + } + } + } + } else { + err = r.runShell(iptablesPath, "-t nat -A", tableNamePreRouteing, + "-p udp --dport 53 -j DNAT --to", dnsServerAddress) + if err != nil { + return err + } + } + + err = r.runShell(iptablesPath, "-t nat -A", tableNamePreRouteing, "-m addrtype --dst-type LOCAL -j RETURN") + if err != nil { + return err + } + + if len(routeAddress) > 0 { + for _, address := range routeAddress { + err = r.runShell(iptablesPath, "-t nat -A", tableNamePreRouteing, + "-d", address.String(), "-p tcp -j REDIRECT --to-ports", redirectPort) + if err != nil { + return err + } + } + } else if len(r.tunOptions.IncludeInterface) > 0 || len(r.tunOptions.IncludeUID) > 0 { + for _, name := range r.tunOptions.IncludeInterface { + err = r.runShell(iptablesPath, "-t nat -A", tableNamePreRouteing, + "-i", name, "-p tcp -j REDIRECT --to-ports", redirectPort) + if err != nil { + return err + } + } + for _, uidRange := range r.tunOptions.IncludeUID { + for uid := uidRange.Start; uid <= uidRange.End; uid++ { + err = r.runShell(iptablesPath, "-t nat -A", tableNamePreRouteing, + "-m owner --uid-owner", uid, "-p tcp -j REDIRECT --to-ports", redirectPort) + if err != nil { + return err + } + } + } + } else { + err = r.runShell(iptablesPath, "-t nat -A", tableNamePreRouteing, + "-p tcp -j REDIRECT --to-ports", redirectPort) + if err != nil { + return err + } + } + err = r.runShell(iptablesPath, "-t nat -I PREROUTING -j", tableNamePreRouteing) + if err != nil { + return err + } + return nil +} + +func (r *autoRedirect) cleanupIPTables(family int) { + tableNameOutput := r.tableName + "-output" + tableNameForward := r.tableName + "-forward" + tableNamePreRouteing := r.tableName + "-prerouting" + iptablesPath := r.iptablesPathForFamily(family) + _ = r.runShell(iptablesPath, "-t nat -D OUTPUT -j", tableNameOutput) + _ = r.runShell(iptablesPath, "-t nat -F", tableNameOutput) + _ = r.runShell(iptablesPath, "-t nat -X", tableNameOutput) + if !r.androidSu { + _ = r.runShell(iptablesPath, "-D FORWARD -j", tableNameForward) + _ = r.runShell(iptablesPath, "-F", tableNameForward) + _ = r.runShell(iptablesPath, "-X", tableNameForward) + _ = r.runShell(iptablesPath, "-t nat -D PREROUTING -j", tableNamePreRouteing) + _ = r.runShell(iptablesPath, "-t nat -F", tableNamePreRouteing) + _ = r.runShell(iptablesPath, "-t nat -X", tableNamePreRouteing) + } +} + +func (r *autoRedirect) runShell(commands ...any) error { + commandStr := strings.Join(F.MapToString(commands), " ") + var command *exec.Cmd + if r.androidSu { + command = exec.Command(r.suPath, "-c", commandStr) + } else { + commandArray := strings.Split(commandStr, " ") + command = exec.Command(commandArray[0], commandArray[1:]...) + } + combinedOutput, err := command.CombinedOutput() + if err != nil { + return E.Extend(err, F.ToString(commandStr, ": ", string(combinedOutput))) + } + return nil +} diff --git a/redirect_linux.go b/redirect_linux.go new file mode 100644 index 0000000..fe42b1e --- /dev/null +++ b/redirect_linux.go @@ -0,0 +1,185 @@ +package tun + +import ( + "context" + "github.com/sagernet/nftables" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" + "net/netip" + "os" + "os/exec" + "runtime" + + "golang.org/x/sys/unix" +) + +type autoRedirect struct { + tunOptions *Options + ctx context.Context + handler Handler + logger logger.Logger + tableName string + customRedirectPortFunc func() int + customRedirectPort int + redirectServer *redirectServer + enableIPv4 bool + enableIPv6 bool + iptablesPath string + ip6tablesPath string + useNFTables bool + androidSu bool + suPath string +} + +func NewAutoRedirect(options AutoRedirectOptions) (AutoRedirect, error) { + r := &autoRedirect{ + tunOptions: options.TunOptions, + ctx: options.Context, + handler: options.Handler, + logger: options.Logger, + useNFTables: runtime.GOOS != "android" && !options.DisableNFTables, + customRedirectPortFunc: options.CustomRedirectPort, + } + var err error + if runtime.GOOS == "android" { + r.enableIPv4 = true + r.iptablesPath = "/system/bin/iptables" + userId := os.Getuid() + if userId != 0 { + r.androidSu = true + for _, suPath := range []string{ + "su", + "/system/bin/su", + } { + r.suPath, err = exec.LookPath(suPath) + if err == nil { + break + } + } + if err != nil { + return nil, E.Extend(E.Cause(err, "root permission is required for auto redirect"), os.Getenv("PATH")) + } + } + } else { + if r.useNFTables { + err = r.initializeNFTables() + if err != nil && err != os.ErrInvalid { + r.logger.Debug("device has no nftables support: ", err) + } + } + if len(r.tunOptions.Inet4Address) > 0 { + r.enableIPv4 = true + if !r.useNFTables { + r.iptablesPath, err = exec.LookPath("iptables") + if err != nil { + return nil, E.Cause(err, "iptables is required") + } + } + } + if len(r.tunOptions.Inet6Address) > 0 { + r.enableIPv6 = true + if !r.useNFTables { + r.ip6tablesPath, err = exec.LookPath("ip6tables") + if err != nil { + if !r.enableIPv4 { + return nil, E.Cause(err, "ip6tables is required") + } else { + r.enableIPv6 = false + r.logger.Error("device has no ip6tables nat support: ", err) + } + } + } + } + } + return r, nil +} + +func (r *autoRedirect) Start() error { + if r.customRedirectPortFunc != nil { + r.customRedirectPort = r.customRedirectPortFunc() + } + if r.customRedirectPort == 0 { + var listenAddr netip.Addr + if runtime.GOOS == "android" { + listenAddr = netip.AddrFrom4([4]byte{127, 0, 0, 1}) + } else if r.enableIPv6 { + listenAddr = netip.IPv6Unspecified() + } else { + listenAddr = netip.IPv4Unspecified() + } + server := newRedirectServer(r.ctx, r.handler, r.logger, listenAddr) + err := server.Start() + if err != nil { + return E.Cause(err, "start redirect server") + } + r.redirectServer = server + } + return r.setupTables() +} + +func (r *autoRedirect) Close() error { + r.cleanupTables() + return common.Close( + common.PtrOrNil(r.redirectServer), + ) +} + +func (r *autoRedirect) initializeNFTables() error { + nft, err := nftables.New() + if err != nil { + return err + } + defer nft.CloseLasting() + _, err = nft.ListTablesOfFamily(unix.AF_INET) + if err != nil { + return err + } + r.useNFTables = true + return nil +} + +func (r *autoRedirect) redirectPort() uint16 { + if r.customRedirectPort > 0 { + return uint16(r.customRedirectPort) + } + return M.AddrPortFromNet(r.redirectServer.listener.Addr()).Port() +} + +func (r *autoRedirect) setupTables() error { + var setupTables func(int) error + if r.useNFTables { + setupTables = r.setupNFTables + } else { + setupTables = r.setupIPTables + } + if r.enableIPv4 { + err := setupTables(unix.AF_INET) + if err != nil { + return err + } + } + if r.enableIPv6 { + err := setupTables(unix.AF_INET6) + if err != nil { + return err + } + } + return nil +} + +func (r *autoRedirect) cleanupTables() { + var cleanupTables func(int) + if r.useNFTables { + cleanupTables = r.cleanupNFTables + } else { + cleanupTables = r.cleanupIPTables + } + if r.enableIPv4 { + cleanupTables(unix.AF_INET) + } + if r.enableIPv6 { + cleanupTables(unix.AF_INET6) + } +} diff --git a/redirect_nftables.go b/redirect_nftables.go new file mode 100644 index 0000000..6b3d290 --- /dev/null +++ b/redirect_nftables.go @@ -0,0 +1,231 @@ +//go:build linux + +package tun + +import ( + "net/netip" + + "github.com/sagernet/nftables" + "github.com/sagernet/nftables/binaryutil" + "github.com/sagernet/nftables/expr" + F "github.com/sagernet/sing/common/format" + + "golang.org/x/sys/unix" +) + +const ( + nftablesChainOutput = "output" + nftablesChainForward = "forward" + nftablesChainPreRouting = "prerouting" +) + +func nftablesFamily(family int) nftables.TableFamily { + switch family { + case unix.AF_INET: + return nftables.TableFamilyIPv4 + case unix.AF_INET6: + return nftables.TableFamilyIPv6 + default: + panic(F.ToString("unknown family ", family)) + } +} + +func (r *autoRedirect) setupNFTables(family int) error { + nft, err := nftables.New() + if err != nil { + return err + } + defer nft.CloseLasting() + + redirectPort := r.redirectPort() + + table := nft.AddTable(&nftables.Table{ + Name: r.tableName, + Family: nftablesFamily(family), + }) + + chainOutput := nft.AddChain(&nftables.Chain{ + Name: nftablesChainOutput, + Table: table, + Hooknum: nftables.ChainHookOutput, + Priority: nftables.ChainPriorityMangle, + Type: nftables.ChainTypeNAT, + }) + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chainOutput, + Exprs: nftablesRuleIfName(expr.MetaKeyOIFNAME, r.tunOptions.Name, nftablesRuleRedirectToPorts(redirectPort)...), + }) + + chainForward := nft.AddChain(&nftables.Chain{ + Name: nftablesChainForward, + Table: table, + Hooknum: nftables.ChainHookForward, + Priority: nftables.ChainPriorityMangle, + }) + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chainForward, + Exprs: nftablesRuleIfName(expr.MetaKeyIIFNAME, r.tunOptions.Name, &expr.Verdict{ + Kind: expr.VerdictAccept, + }), + }) + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chainForward, + Exprs: nftablesRuleIfName(expr.MetaKeyOIFNAME, r.tunOptions.Name, &expr.Verdict{ + Kind: expr.VerdictAccept, + }), + }) + + chainPreRouting := nft.AddChain(&nftables.Chain{ + Name: nftablesChainPreRouting, + Table: table, + Hooknum: nftables.ChainHookPrerouting, + Priority: nftables.ChainPriorityMangle, + Type: nftables.ChainTypeNAT, + }) + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chainPreRouting, + Exprs: nftablesRuleIfName(expr.MetaKeyIIFNAME, r.tunOptions.Name, &expr.Verdict{ + Kind: expr.VerdictReturn, + }), + }) + var ( + routeAddress []netip.Prefix + routeExcludeAddress []netip.Prefix + ) + if table.Family == nftables.TableFamilyIPv4 { + routeAddress = r.tunOptions.Inet4RouteAddress + routeExcludeAddress = r.tunOptions.Inet4RouteExcludeAddress + } else { + routeAddress = r.tunOptions.Inet6RouteAddress + routeExcludeAddress = r.tunOptions.Inet6RouteExcludeAddress + } + for _, address := range routeExcludeAddress { + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chainPreRouting, + Exprs: nftablesRuleDestinationAddress(address, &expr.Verdict{ + Kind: expr.VerdictReturn, + }), + }) + } + for _, name := range r.tunOptions.ExcludeInterface { + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chainPreRouting, + Exprs: nftablesRuleIfName(expr.MetaKeyIIFNAME, name, &expr.Verdict{ + Kind: expr.VerdictReturn, + }), + }) + } + for _, uidRange := range r.tunOptions.ExcludeUID { + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chainPreRouting, + Exprs: nftablesRuleMetaUInt32Range(expr.MetaKeySKUID, uidRange, &expr.Verdict{ + Kind: expr.VerdictReturn, + }), + }) + } + + var routeExprs []expr.Any + if len(routeAddress) > 0 { + for _, address := range routeAddress { + routeExprs = append(routeExprs, nftablesRuleDestinationAddress(address)...) + } + } + + var dnsServerAddress netip.Addr + if table.Family == nftables.TableFamilyIPv4 { + dnsServerAddress = r.tunOptions.Inet4Address[0].Addr().Next() + } else { + dnsServerAddress = r.tunOptions.Inet6Address[0].Addr().Next() + } + + if len(r.tunOptions.IncludeInterface) > 0 || len(r.tunOptions.IncludeUID) > 0 { + for _, name := range r.tunOptions.IncludeInterface { + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chainPreRouting, + Exprs: nftablesRuleIfName(expr.MetaKeyIIFNAME, name, append(routeExprs, nftablesRuleHijackDNS(table.Family, dnsServerAddress)...)...), + }) + } + for _, uidRange := range r.tunOptions.IncludeUID { + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chainPreRouting, + Exprs: nftablesRuleMetaUInt32Range(expr.MetaKeySKUID, uidRange, append(routeExprs, nftablesRuleHijackDNS(table.Family, dnsServerAddress)...)...), + }) + } + } else { + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chainPreRouting, + Exprs: append(routeExprs, nftablesRuleHijackDNS(table.Family, dnsServerAddress)...), + }) + } + + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chainPreRouting, + Exprs: []expr.Any{ + &expr.Fib{ + Register: 1, + FlagDADDR: true, + ResultADDRTYPE: true, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL), + }, + &expr.Verdict{ + Kind: expr.VerdictReturn, + }, + }, + }) + + if len(r.tunOptions.IncludeInterface) > 0 || len(r.tunOptions.IncludeUID) > 0 { + for _, name := range r.tunOptions.IncludeInterface { + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chainPreRouting, + Exprs: nftablesRuleIfName(expr.MetaKeyIIFNAME, name, append(routeExprs, nftablesRuleRedirectToPorts(redirectPort)...)...), + }) + } + for _, uidRange := range r.tunOptions.IncludeUID { + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chainPreRouting, + Exprs: nftablesRuleMetaUInt32Range(expr.MetaKeySKUID, uidRange, append(routeExprs, nftablesRuleRedirectToPorts(redirectPort)...)...), + }) + } + } else { + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chainPreRouting, + Exprs: append(routeExprs, nftablesRuleRedirectToPorts(redirectPort)...), + }) + } + return nft.Flush() +} + +func (r *autoRedirect) cleanupNFTables(family int) { + conn, err := nftables.New() + if err != nil { + return + } + conn.FlushTable(&nftables.Table{ + Name: r.tableName, + Family: nftablesFamily(family), + }) + conn.DelTable(&nftables.Table{ + Name: r.tableName, + Family: nftablesFamily(family), + }) + _ = conn.Flush() + _ = conn.CloseLasting() +} diff --git a/redirect_nftables_expr.go b/redirect_nftables_expr.go new file mode 100644 index 0000000..bf2d46f --- /dev/null +++ b/redirect_nftables_expr.go @@ -0,0 +1,153 @@ +//go:build linux + +package tun + +import ( + "net" + "net/netip" + + "github.com/sagernet/nftables" + "github.com/sagernet/nftables/binaryutil" + "github.com/sagernet/nftables/expr" + "github.com/sagernet/sing/common/ranges" + + "golang.org/x/sys/unix" +) + +func nftablesIfname(n string) []byte { + b := make([]byte, 16) + copy(b, n+"\x00") + return b +} + +func nftablesRuleIfName(key expr.MetaKey, value string, exprs ...expr.Any) []expr.Any { + newExprs := []expr.Any{ + &expr.Meta{Key: key, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: nftablesIfname(value), + }, + } + newExprs = append(newExprs, exprs...) + return newExprs +} + +func nftablesRuleMetaUInt32Range(key expr.MetaKey, uidRange ranges.Range[uint32], exprs ...expr.Any) []expr.Any { + newExprs := []expr.Any{ + &expr.Meta{Key: key, Register: 1}, + &expr.Range{ + Op: expr.CmpOpEq, + Register: 1, + FromData: binaryutil.BigEndian.PutUint32(uidRange.Start), + ToData: binaryutil.BigEndian.PutUint32(uidRange.End), + }, + } + newExprs = append(newExprs, exprs...) + return newExprs +} + +func nftablesRuleDestinationAddress(address netip.Prefix, exprs ...expr.Any) []expr.Any { + var newExprs []expr.Any + if address.Addr().Is4() { + newExprs = append(newExprs, &expr.Payload{ + OperationType: expr.PayloadLoad, + DestRegister: 1, + SourceRegister: 0, + Base: expr.PayloadBaseNetworkHeader, + Offset: 16, + Len: 4, + }, &expr.Bitwise{ + SourceRegister: 1, + DestRegister: 1, + Len: 4, + Xor: make([]byte, 4), + Mask: net.CIDRMask(address.Bits(), 32), + }) + } else { + newExprs = append(newExprs, &expr.Payload{ + OperationType: expr.PayloadLoad, + DestRegister: 1, + SourceRegister: 0, + Base: expr.PayloadBaseNetworkHeader, + Offset: 24, + Len: 16, + }, &expr.Bitwise{ + SourceRegister: 1, + DestRegister: 1, + Len: 16, + Xor: make([]byte, 16), + Mask: net.CIDRMask(address.Bits(), 128), + }) + } + newExprs = append(newExprs, &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: address.Masked().Addr().AsSlice(), + }) + newExprs = append(newExprs, exprs...) + return newExprs +} + +func nftablesRuleHijackDNS(family nftables.TableFamily, dnsServerAddress netip.Addr) []expr.Any { + return []expr.Any{ + &expr.Meta{ + Key: expr.MetaKeyL4PROTO, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte{unix.IPPROTO_UDP}, + }, + &expr.Payload{ + OperationType: expr.PayloadLoad, + DestRegister: 1, + SourceRegister: 0, + Base: expr.PayloadBaseTransportHeader, + Offset: 2, + Len: 2, + }, &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: binaryutil.BigEndian.PutUint16(53), + }, &expr.Immediate{ + Register: 1, + Data: dnsServerAddress.AsSlice(), + }, &expr.NAT{ + Type: expr.NATTypeDestNAT, + Family: uint32(family), + RegAddrMin: 1, + }, + } +} + +const ( + NF_NAT_RANGE_MAP_IPS = 1 << iota + NF_NAT_RANGE_PROTO_SPECIFIED + NF_NAT_RANGE_PROTO_RANDOM + NF_NAT_RANGE_PERSISTENT + NF_NAT_RANGE_PROTO_RANDOM_FULLY + NF_NAT_RANGE_PROTO_OFFSET +) + +func nftablesRuleRedirectToPorts(redirectPort uint16) []expr.Any { + return []expr.Any{ + &expr.Meta{ + Key: expr.MetaKeyL4PROTO, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte{unix.IPPROTO_TCP}, + }, + &expr.Immediate{ + Register: 1, + Data: binaryutil.BigEndian.PutUint16(redirectPort), + }, &expr.Redir{ + RegisterProtoMin: 1, + Flags: NF_NAT_RANGE_PROTO_SPECIFIED, + }, + } +} diff --git a/redirect_server.go b/redirect_server.go new file mode 100644 index 0000000..1915bcb --- /dev/null +++ b/redirect_server.go @@ -0,0 +1,88 @@ +//go:build linux + +package tun + +import ( + "context" + "errors" + "net" + "net/netip" + "time" + + "github.com/sagernet/sing/common/atomic" + "github.com/sagernet/sing/common/control" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" +) + +const ProtocolRedirect = "redirect" + +type redirectServer struct { + ctx context.Context + handler Handler + logger logger.Logger + listenAddr netip.Addr + listener *net.TCPListener + inShutdown atomic.Bool +} + +func newRedirectServer(ctx context.Context, handler Handler, logger logger.Logger, listenAddr netip.Addr) *redirectServer { + return &redirectServer{ + ctx: ctx, + handler: handler, + logger: logger, + listenAddr: listenAddr, + } +} + +func (s *redirectServer) Start() error { + var listenConfig net.ListenConfig + // listenConfig.KeepAlive = C.TCPKeepAliveInitial + listenConfig.KeepAlive = 10 * time.Minute + listener, err := listenConfig.Listen(s.ctx, M.NetworkFromNetAddr("tcp", s.listenAddr), M.SocksaddrFrom(s.listenAddr, 0).String()) + if err != nil { + return err + } + s.listener = listener.(*net.TCPListener) + go s.loopIn() + return nil +} + +func (s *redirectServer) Close() error { + s.inShutdown.Store(true) + return s.listener.Close() +} + +func (s *redirectServer) loopIn() { + for { + conn, err := s.listener.AcceptTCP() + if err != nil { + var netError net.Error + //goland:noinspection GoDeprecation + //nolint:staticcheck + if errors.As(err, &netError) && netError.Temporary() { + s.logger.Error(err) + continue + } + if s.inShutdown.Load() && E.IsClosed(err) { + return + } + s.listener.Close() + s.logger.Error("serve error: ", err) + continue + } + var metadata M.Metadata + metadata.Protocol = ProtocolRedirect + metadata.Source = M.SocksaddrFromNet(conn.RemoteAddr()) + destination, err := control.GetOriginalDestination(conn) + if err != nil { + _ = conn.SetLinger(0) + _ = conn.Close() + s.logger.Error("process connection from ", metadata.Source, ": invalid connection: ", err) + continue + } + metadata.Destination = M.SocksaddrFromNetIP(destination) + go s.handler.NewConnection(s.ctx, conn, metadata) + } +} diff --git a/redirect_stub.go b/redirect_stub.go new file mode 100644 index 0000000..040ef12 --- /dev/null +++ b/redirect_stub.go @@ -0,0 +1,11 @@ +//go:build !linux + +package tun + +import ( + "os" +) + +func NewAutoRedirect(options AutoRedirectOptions) (AutoRedirect, error) { + return nil, os.ErrInvalid +}