diff --git a/.github/workflows/test-startup.yml b/.github/workflows/test-startup.yml index a386b2b..62bf71b 100644 --- a/.github/workflows/test-startup.yml +++ b/.github/workflows/test-startup.yml @@ -1,4 +1,4 @@ -name: Go +name: Test systemd on: push: @@ -18,7 +18,7 @@ jobs: go-version: "1.21" - name: Build - run: go build -v ./... + run: go build -v - name: Install go-mmproxy run: | diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..2049f6b --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,31 @@ +name: Test + +on: + push: + branches: ["main"] + pull_request: + branches: ["main"] + +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v4 + with: + go-version: "1.21" + + - name: Build + run: go build -v + + - name: Prepare ip routes + run: | + sudo ip rule add from 127.0.0.1/8 iif lo table 123 + sudo ip route add local 0.0.0.0/0 dev lo table 123 + sudo ip -6 rule add from ::1/128 iif lo table 123 + sudo ip -6 route add local ::/0 dev lo table 123 + + - name: Test + run: sudo go test -v -timeout 30s ./tests diff --git a/buffers.go b/buffers.go deleted file mode 100644 index d51bf71..0000000 --- a/buffers.go +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright 2019 Path Network, Inc. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package main - -import ( - "math" - "sync" -) - -var buffers sync.Pool - -func init() { - buffers.New = func() any { return make([]byte, math.MaxUint16) } -} - -func GetBuffer() []byte { - return buffers.Get().([]byte) -} - -func PutBuffer(buf []byte) { - buffers.Put(buf) // nolint:staticcheck -} diff --git a/buffers/buffers.go b/buffers/buffers.go new file mode 100644 index 0000000..86a782c --- /dev/null +++ b/buffers/buffers.go @@ -0,0 +1,28 @@ +// Copyright 2019 Path Network, Inc. All rights reserved. +// Copyright 2024 Konrad Zemek +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package buffers + +import ( + "math" + "sync" +) + +var buffers sync.Pool + +func init() { + buffers.New = func() any { + slice := make([]byte, math.MaxUint16) + return &slice + } +} + +func Get() []byte { + return *buffers.Get().(*[]byte) +} + +func Put(buf []byte) { + buffers.Put(&buf) +} diff --git a/main.go b/main.go index 6fc7ffd..f34d35f 100644 --- a/main.go +++ b/main.go @@ -1,4 +1,5 @@ // Copyright 2019 Path Network, Inc. All rights reserved. +// Copyright 2024 Konrad Zemek // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. @@ -6,6 +7,7 @@ package main import ( "bufio" + "context" "flag" "log/slog" "net" @@ -13,50 +15,44 @@ import ( "os" "syscall" "time" + + "github.com/kzemek/go-mmproxy/tcp" + "github.com/kzemek/go-mmproxy/udp" + "github.com/kzemek/go-mmproxy/utils" ) -type options struct { - Protocol string - ListenAddrStr string - TargetAddr4Str string - TargetAddr6Str string - ListenAddr netip.AddrPort - TargetAddr4 netip.AddrPort - TargetAddr6 netip.AddrPort - Mark int - Verbose int - allowedSubnetsPath string - AllowedSubnets []*net.IPNet - Listeners int - Logger *slog.Logger - udpCloseAfter int - UDPCloseAfter time.Duration -} +var protocolStr string +var listenAddrStr string +var targetAddr4Str string +var targetAddr6Str string +var allowedSubnetsPath string +var udpCloseAfterInt int +var listeners int -var Opts options +var opts utils.Options func init() { - flag.StringVar(&Opts.Protocol, "p", "tcp", "Protocol that will be proxied: tcp, udp") - flag.StringVar(&Opts.ListenAddrStr, "l", "0.0.0.0:8443", "Address the proxy listens on") - flag.StringVar(&Opts.TargetAddr4Str, "4", "127.0.0.1:443", "Address to which IPv4 traffic will be forwarded to") - flag.StringVar(&Opts.TargetAddr6Str, "6", "[::1]:443", "Address to which IPv6 traffic will be forwarded to") - flag.IntVar(&Opts.Mark, "mark", 0, "The mark that will be set on outbound packets") - flag.IntVar(&Opts.Verbose, "v", 0, `0 - no logging of individual connections + flag.StringVar(&protocolStr, "p", "tcp", "Protocol that will be proxied: tcp, udp") + flag.StringVar(&listenAddrStr, "l", "0.0.0.0:8443", "Address the proxy listens on") + flag.StringVar(&targetAddr4Str, "4", "127.0.0.1:443", "Address to which IPv4 traffic will be forwarded to") + flag.StringVar(&targetAddr6Str, "6", "[::1]:443", "Address to which IPv6 traffic will be forwarded to") + flag.IntVar(&opts.Mark, "mark", 0, "The mark that will be set on outbound packets") + flag.IntVar(&opts.Verbose, "v", 0, `0 - no logging of individual connections 1 - log errors occurring in individual connections 2 - log all state changes of individual connections`) - flag.StringVar(&Opts.allowedSubnetsPath, "allowed-subnets", "", + flag.StringVar(&allowedSubnetsPath, "allowed-subnets", "", "Path to a file that contains allowed subnets of the proxy servers") - flag.IntVar(&Opts.Listeners, "listeners", 1, + flag.IntVar(&listeners, "listeners", 1, "Number of listener sockets that will be opened for the listen address (Linux 3.9+)") - flag.IntVar(&Opts.udpCloseAfter, "close-after", 60, "Number of seconds after which UDP socket will be cleaned up") + flag.IntVar(&udpCloseAfterInt, "close-after", 60, "Number of seconds after which UDP socket will be cleaned up") } -func listen(listenerNum int, errors chan<- error) { - logger := Opts.Logger.With(slog.Int("listenerNum", listenerNum), - slog.String("protocol", Opts.Protocol), slog.String("listenAdr", Opts.ListenAddr.String())) +func listen(ctx context.Context, listenerNum int, parentLogger *slog.Logger, listenErrors chan<- error) { + logger := parentLogger.With(slog.Int("listenerNum", listenerNum), + slog.String("protocol", protocolStr), slog.String("listenAdr", opts.ListenAddr.String())) listenConfig := net.ListenConfig{} - if Opts.Listeners > 1 { + if listeners > 1 { listenConfig.Control = func(network, address string, c syscall.RawConn) error { return c.Control(func(fd uintptr) { soReusePort := 15 @@ -67,15 +63,15 @@ func listen(listenerNum int, errors chan<- error) { } } - if Opts.Protocol == "tcp" { - tcpListen(&listenConfig, logger, errors) + if opts.Protocol == utils.TCP { + tcp.Listen(ctx, &listenConfig, &opts, logger, listenErrors) } else { - udpListen(&listenConfig, logger, errors) + udp.Listen(ctx, &listenConfig, &opts, logger, listenErrors) } } -func loadAllowedSubnets() error { - file, err := os.Open(Opts.allowedSubnetsPath) +func loadAllowedSubnets(logger *slog.Logger) error { + file, err := os.Open(allowedSubnetsPath) if err != nil { return err } @@ -84,12 +80,12 @@ func loadAllowedSubnets() error { scanner := bufio.NewScanner(file) for scanner.Scan() { - _, ipNet, err := net.ParseCIDR(scanner.Text()) + ipNet, err := netip.ParsePrefix(scanner.Text()) if err != nil { return err } - Opts.AllowedSubnets = append(Opts.AllowedSubnets, ipNet) - Opts.Logger.Info("allowed subnet", slog.String("subnet", ipNet.String())) + opts.AllowedSubnets = append(opts.AllowedSubnets, ipNet) + logger.Info("allowed subnet", slog.String("subnet", ipNet.String())) } return nil @@ -98,72 +94,79 @@ func loadAllowedSubnets() error { func main() { flag.Parse() lvl := slog.LevelInfo - if Opts.Verbose > 0 { + if opts.Verbose > 0 { lvl = slog.LevelDebug } - Opts.Logger = slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: lvl})) - if Opts.allowedSubnetsPath != "" { - if err := loadAllowedSubnets(); err != nil { - Opts.Logger.Error("failed to load allowed subnets file", "path", Opts.allowedSubnetsPath, "error", err) + logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: lvl})) + + if allowedSubnetsPath != "" { + if err := loadAllowedSubnets(logger); err != nil { + logger.Error("failed to load allowed subnets file", "path", allowedSubnetsPath, "error", err) } } - if Opts.Protocol != "tcp" && Opts.Protocol != "udp" { - Opts.Logger.Error("--protocol has to be one of udp, tcp", slog.String("protocol", Opts.Protocol)) + if protocolStr == "tcp" { + opts.Protocol = utils.TCP + } else if protocolStr == "udp" { + opts.Protocol = utils.UDP + } else { + logger.Error("--protocol has to be one of udp, tcp", slog.String("protocol", protocolStr)) os.Exit(1) } - if Opts.Mark < 0 { - Opts.Logger.Error("--mark has to be >= 0", slog.Int("mark", Opts.Mark)) + if opts.Mark < 0 { + logger.Error("--mark has to be >= 0", slog.Int("mark", opts.Mark)) os.Exit(1) } - if Opts.Verbose < 0 { - Opts.Logger.Error("-v has to be >= 0", slog.Int("verbose", Opts.Verbose)) + if opts.Verbose < 0 { + logger.Error("-v has to be >= 0", slog.Int("verbose", opts.Verbose)) os.Exit(1) } - if Opts.Listeners < 1 { - Opts.Logger.Error("--listeners has to be >= 1") + if listeners < 1 { + logger.Error("--listeners has to be >= 1") os.Exit(1) } var err error - if Opts.ListenAddr, err = parseHostPort(Opts.ListenAddrStr); err != nil { - Opts.Logger.Error("listen address is malformed", "error", err) + if opts.ListenAddr, err = utils.ParseHostPort(listenAddrStr); err != nil { + logger.Error("listen address is malformed", "error", err) os.Exit(1) } - if Opts.TargetAddr4, err = netip.ParseAddrPort(Opts.TargetAddr4Str); err != nil { - Opts.Logger.Error("ipv4 target address is malformed", "error", err) + if opts.TargetAddr4, err = netip.ParseAddrPort(targetAddr4Str); err != nil { + logger.Error("ipv4 target address is malformed", "error", err) os.Exit(1) } - if !Opts.TargetAddr4.Addr().Is4() { - Opts.Logger.Error("ipv4 target address is not IPv4") + if !opts.TargetAddr4.Addr().Is4() { + logger.Error("ipv4 target address is not IPv4") os.Exit(1) } - if Opts.TargetAddr6, err = netip.ParseAddrPort(Opts.TargetAddr6Str); err != nil { - Opts.Logger.Error("ipv6 target address is malformed", "error", err) + if opts.TargetAddr6, err = netip.ParseAddrPort(targetAddr6Str); err != nil { + logger.Error("ipv6 target address is malformed", "error", err) os.Exit(1) } - if !Opts.TargetAddr6.Addr().Is6() { - Opts.Logger.Error("ipv6 target address is not IPv6") + if !opts.TargetAddr6.Addr().Is6() { + logger.Error("ipv6 target address is not IPv6") os.Exit(1) } - if Opts.udpCloseAfter < 0 { - Opts.Logger.Error("--close-after has to be >= 0", slog.Int("close-after", Opts.udpCloseAfter)) + if udpCloseAfterInt < 0 { + logger.Error("--close-after has to be >= 0", slog.Int("close-after", udpCloseAfterInt)) os.Exit(1) } - Opts.UDPCloseAfter = time.Duration(Opts.udpCloseAfter) * time.Second + opts.UDPCloseAfter = time.Duration(udpCloseAfterInt) * time.Second - listenErrors := make(chan error, Opts.Listeners) - for i := 0; i < Opts.Listeners; i++ { - go listen(i, listenErrors) + listenErrors := make(chan error, listeners) + ctxs := make([]context.Context, listeners) + for i := range ctxs { + ctxs[i] = context.Background() + go listen(ctxs[i], i, logger, listenErrors) } - for i := 0; i < Opts.Listeners; i++ { + for range ctxs { <-listenErrors } } diff --git a/proxyprotocol.go b/proxyprotocol.go deleted file mode 100644 index 1547ec1..0000000 --- a/proxyprotocol.go +++ /dev/null @@ -1,139 +0,0 @@ -// Copyright 2019 Path Network, Inc. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package main - -import ( - "bytes" - "encoding/binary" - "fmt" - "net" - "strings" -) - -func readRemoteAddrPROXYv2(ctrlBuf []byte, protocol Protocol) (net.Addr, net.Addr, []byte, error) { - if (ctrlBuf[12] >> 4) != 2 { - return nil, nil, nil, fmt.Errorf("unknown protocol version %d", ctrlBuf[12]>>4) - } - - if ctrlBuf[12]&0xF > 1 { - return nil, nil, nil, fmt.Errorf("unknown command %d", ctrlBuf[12]&0xF) - } - - if ctrlBuf[12]&0xF == 1 && ((protocol == TCP && ctrlBuf[13] != 0x11 && ctrlBuf[13] != 0x21) || - (protocol == UDP && ctrlBuf[13] != 0x12 && ctrlBuf[13] != 0x22)) { - return nil, nil, nil, fmt.Errorf("invalid family/protocol %d/%d", ctrlBuf[13]>>4, ctrlBuf[13]&0xF) - } - - var dataLen uint16 - reader := bytes.NewReader(ctrlBuf[14:16]) - if err := binary.Read(reader, binary.BigEndian, &dataLen); err != nil { - return nil, nil, nil, fmt.Errorf("failed to decode address data length: %w", err) - } - - if len(ctrlBuf) < 16+int(dataLen) { - return nil, nil, nil, fmt.Errorf("incomplete PROXY header") - } - - if ctrlBuf[12]&0xF == 0 { // LOCAL - return nil, nil, ctrlBuf[16+dataLen:], nil - } - - var sport, dport uint16 - if ctrlBuf[13]>>4 == 0x1 { // IPv4 - reader = bytes.NewReader(ctrlBuf[24:]) - } else { - reader = bytes.NewReader(ctrlBuf[48:]) - } - if err := binary.Read(reader, binary.BigEndian, &sport); err != nil { - return nil, nil, nil, fmt.Errorf("failed to decode source port: %w", err) - } - if err := binary.Read(reader, binary.BigEndian, &dport); err != nil { - return nil, nil, nil, fmt.Errorf("failed to decode destination port: %w", err) - } - - var srcIP, dstIP net.IP - if ctrlBuf[13]>>4 == 0x1 { // IPv4 - srcIP = net.IPv4(ctrlBuf[16], ctrlBuf[17], ctrlBuf[18], ctrlBuf[19]) - dstIP = net.IPv4(ctrlBuf[20], ctrlBuf[21], ctrlBuf[22], ctrlBuf[23]) - } else { - srcIP = ctrlBuf[16:32] - dstIP = ctrlBuf[32:48] - } - - if ctrlBuf[13]&0xF == 0x1 { // TCP - return &net.TCPAddr{IP: srcIP, Port: int(sport)}, - &net.TCPAddr{IP: dstIP, Port: int(dport)}, - ctrlBuf[16+dataLen:], nil - } - - return &net.UDPAddr{IP: srcIP, Port: int(sport)}, - &net.UDPAddr{IP: dstIP, Port: int(dport)}, - ctrlBuf[16+dataLen:], nil -} - -func readRemoteAddrPROXYv1(ctrlBuf []byte) (net.Addr, net.Addr, []byte, error) { - str := string(ctrlBuf) - if idx := strings.Index(str, "\r\n"); idx >= 0 { - var headerProtocol, src, dst string - var sport, dport int - n, err := fmt.Sscanf(str, "PROXY %s", &headerProtocol) - if err != nil { - return nil, nil, nil, err - } - if n != 1 { - return nil, nil, nil, fmt.Errorf("failed to decode elements") - } - if headerProtocol == "UNKNOWN" { - return nil, nil, ctrlBuf[idx+2:], nil - } - if headerProtocol != "TCP4" && headerProtocol != "TCP6" { - return nil, nil, nil, fmt.Errorf("unknown protocol %s", headerProtocol) - } - - n, err = fmt.Sscanf(str, "PROXY %s %s %s %d %d", &headerProtocol, &src, &dst, &sport, &dport) - if err != nil { - return nil, nil, nil, err - } - if n != 5 { - return nil, nil, nil, fmt.Errorf("failed to decode elements") - } - srcIP := net.ParseIP(src) - if srcIP == nil { - return nil, nil, nil, fmt.Errorf("failed to parse source IP address %s", src) - } - dstIP := net.ParseIP(dst) - if dstIP == nil { - return nil, nil, nil, fmt.Errorf("failed to parse destination IP address %s", dst) - } - return &net.TCPAddr{IP: srcIP, Port: sport}, - &net.TCPAddr{IP: dstIP, Port: dport}, - ctrlBuf[idx+2:], nil - } - - return nil, nil, nil, fmt.Errorf("did not find \\r\\n in first data segment") -} - -var proxyv2header = []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A} - -func proxyReadRemoteAddr(buf []byte, protocol Protocol) (net.Addr, net.Addr, []byte, error) { - if len(buf) >= 16 && bytes.Equal(buf[:12], proxyv2header) { - saddr, daddr, rest, err := readRemoteAddrPROXYv2(buf, protocol) - if err != nil { - return nil, nil, nil, fmt.Errorf("failed to parse PROXY v2 header: %w", err) - } - return saddr, daddr, rest, err - } - - // PROXYv1 only works with TCP - if protocol == TCP && len(buf) >= 8 && bytes.Equal(buf[:5], []byte("PROXY")) { - saddr, daddr, rest, err := readRemoteAddrPROXYv1(buf) - if err != nil { - return nil, nil, nil, fmt.Errorf("failed to parse PROXY v1 header: %w", err) - } - return saddr, daddr, rest, err - } - - return nil, nil, nil, fmt.Errorf("PROXY header missing") -} diff --git a/proxyprotocol/proxyprotocol.go b/proxyprotocol/proxyprotocol.go new file mode 100644 index 0000000..3a1ad5c --- /dev/null +++ b/proxyprotocol/proxyprotocol.go @@ -0,0 +1,175 @@ +// Copyright 2019 Path Network, Inc. All rights reserved. +// Copyright 2024 Konrad Zemek +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package proxyprotocol + +import ( + "bytes" + "encoding/binary" + "fmt" + "net/netip" + "strings" + + "github.com/kzemek/go-mmproxy/utils" +) + +func readRemoteAddrPROXYv2(ctrlBuf []byte, protocol utils.Protocol) (saddr, daddr netip.AddrPort, data []byte, resultErr error) { + if (ctrlBuf[12] >> 4) != 2 { + resultErr = fmt.Errorf("unknown protocol version %d", ctrlBuf[12]>>4) + return + } + + if ctrlBuf[12]&0xF > 1 { + resultErr = fmt.Errorf("unknown command %d", ctrlBuf[12]&0xF) + return + } + + if ctrlBuf[12]&0xF == 1 && ((protocol == utils.TCP && ctrlBuf[13] != 0x11 && ctrlBuf[13] != 0x21) || + (protocol == utils.UDP && ctrlBuf[13] != 0x12 && ctrlBuf[13] != 0x22)) { + resultErr = fmt.Errorf("invalid family/protocol %d/%d", ctrlBuf[13]>>4, ctrlBuf[13]&0xF) + return + } + + var dataLen uint16 + reader := bytes.NewReader(ctrlBuf[14:16]) + if err := binary.Read(reader, binary.BigEndian, &dataLen); err != nil { + resultErr = fmt.Errorf("failed to decode address data length: %w", err) + return + } + + if len(ctrlBuf) < 16+int(dataLen) { + resultErr = fmt.Errorf("incomplete PROXY header") + return + } + + if ctrlBuf[12]&0xF == 0 { // LOCAL + data = ctrlBuf[16+dataLen:] + return + } + + var sport, dport uint16 + if ctrlBuf[13]>>4 == 0x1 { // IPv4 + reader = bytes.NewReader(ctrlBuf[24:]) + } else { + reader = bytes.NewReader(ctrlBuf[48:]) + } + if err := binary.Read(reader, binary.BigEndian, &sport); err != nil { + resultErr = fmt.Errorf("failed to decode source port: %w", err) + return + } + if sport == 0 { + resultErr = fmt.Errorf("invalid source port %d", sport) + return + } + if err := binary.Read(reader, binary.BigEndian, &dport); err != nil { + resultErr = fmt.Errorf("failed to decode destination port: %w", err) + return + } + if dport == 0 { + resultErr = fmt.Errorf("invalid destination port %d", sport) + return + } + + var srcIP, dstIP netip.Addr + if ctrlBuf[13]>>4 == 0x1 { // IPv4 + srcIP, _ = netip.AddrFromSlice(ctrlBuf[16:20]) + dstIP, _ = netip.AddrFromSlice(ctrlBuf[20:24]) + } else { + srcIP, _ = netip.AddrFromSlice(ctrlBuf[16:32]) + dstIP, _ = netip.AddrFromSlice(ctrlBuf[32:48]) + } + + saddr = netip.AddrPortFrom(srcIP, sport) + daddr = netip.AddrPortFrom(dstIP, dport) + data = ctrlBuf[16+dataLen:] + return +} + +func readRemoteAddrPROXYv1(ctrlBuf []byte) (saddr, daddr netip.AddrPort, data []byte, resultErr error) { + str := string(ctrlBuf) + idx := strings.Index(str, "\r\n") + if idx < 0 { + resultErr = fmt.Errorf("did not find \\r\\n in first data segment") + return + } + + var headerProtocol string + n, err := fmt.Sscanf(str, "PROXY %s", &headerProtocol) + if err != nil { + resultErr = err + return + } + if n != 1 { + resultErr = fmt.Errorf("failed to decode elements") + return + } + if headerProtocol == "UNKNOWN" { + data = ctrlBuf[idx+2:] + return + } + if headerProtocol != "TCP4" && headerProtocol != "TCP6" { + resultErr = fmt.Errorf("unknown protocol %s", headerProtocol) + return + } + + var src, dst string + var sport, dport int + n, err = fmt.Sscanf(str, "PROXY %s %s %s %d %d", &headerProtocol, &src, &dst, &sport, &dport) + if err != nil { + resultErr = err + return + } + if n != 5 { + resultErr = fmt.Errorf("failed to decode elements") + return + } + if sport <= 0 || sport > 65535 { + resultErr = fmt.Errorf("invalid source port %d", sport) + return + } + if dport <= 0 || dport > 65535 { + resultErr = fmt.Errorf("invalid destination port %d", sport) + return + } + srcIP, err := netip.ParseAddr(src) + if err != nil { + resultErr = fmt.Errorf("failed to parse source IP address %s: %w", src, err) + return + } + dstIP, err := netip.ParseAddr(dst) + if err != nil { + resultErr = fmt.Errorf("failed to parse destination IP address %s: %w", dst, err) + return + } + + saddr = netip.AddrPortFrom(srcIP, uint16(sport)) + daddr = netip.AddrPortFrom(dstIP, uint16(dport)) + data = ctrlBuf[idx+2:] + return +} + +var proxyv2header = []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A} + +func ReadRemoteAddr(buf []byte, protocol utils.Protocol) (saddr, daddr netip.AddrPort, rest []byte, err error) { + if len(buf) >= 16 && bytes.Equal(buf[:12], proxyv2header) { + saddr, daddr, rest, err = readRemoteAddrPROXYv2(buf, protocol) + if err != nil { + err = fmt.Errorf("failed to parse PROXY v2 header: %w", err) + } + return + } + + // PROXYv1 only works with TCP + if protocol == utils.TCP && len(buf) >= 8 && bytes.Equal(buf[:5], []byte("PROXY")) { + saddr, daddr, rest, err = readRemoteAddrPROXYv1(buf) + if err != nil { + err = fmt.Errorf("failed to parse PROXY v1 header: %w", err) + } + return + } + + err = fmt.Errorf("PROXY header missing") + return +} diff --git a/tcp.go b/tcp/tcp.go similarity index 62% rename from tcp.go rename to tcp/tcp.go index abf76e4..316d253 100644 --- a/tcp.go +++ b/tcp/tcp.go @@ -1,8 +1,9 @@ // Copyright 2019 Path Network, Inc. All rights reserved. +// Copyright 2024 Konrad Zemek // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package main +package tcp import ( "context" @@ -10,31 +11,38 @@ import ( "log/slog" "net" "net/netip" + + "github.com/kzemek/go-mmproxy/buffers" + "github.com/kzemek/go-mmproxy/proxyprotocol" + "github.com/kzemek/go-mmproxy/utils" ) -func tcpCopyData(dst net.Conn, src net.Conn, ch chan<- error) { +func copyData(dst net.Conn, src net.Conn, ch chan<- error) { _, err := io.Copy(dst, src) ch <- err } -func tcpHandleConnection(conn net.Conn, logger *slog.Logger) { +func handleConnection(conn net.Conn, opts *utils.Options, logger *slog.Logger) { defer conn.Close() - logger = logger.With(slog.String("remoteAddr", conn.RemoteAddr().String()), + + remoteAddr := netip.MustParseAddrPort(conn.RemoteAddr().String()) + + logger = logger.With(slog.String("remoteAddr", remoteAddr.String()), slog.String("localAddr", conn.LocalAddr().String())) - if !checkOriginAllowed(conn.RemoteAddr().(*net.TCPAddr).IP) { + if !utils.CheckOriginAllowed(remoteAddr.Addr(), opts.AllowedSubnets) { logger.Debug("connection origin not in allowed subnets", slog.Bool("dropConnection", true)) return } - if Opts.Verbose > 1 { + if opts.Verbose > 1 { logger.Debug("new connection") } - buffer := GetBuffer() + buffer := buffers.Get() defer func() { if buffer != nil { - PutBuffer(buffer) + buffers.Put(buffer) } }() @@ -44,33 +52,36 @@ func tcpHandleConnection(conn net.Conn, logger *slog.Logger) { return } - saddr, _, restBytes, err := proxyReadRemoteAddr(buffer[:n], TCP) + saddr, _, restBytes, err := proxyprotocol.ReadRemoteAddr(buffer[:n], utils.TCP) if err != nil { logger.Debug("failed to parse PROXY header", "error", err, slog.Bool("dropConnection", true)) return } - targetAddr := Opts.TargetAddr6 - if saddr == nil { - if netip.MustParseAddrPort(conn.RemoteAddr().String()).Addr().Is4() { - targetAddr = Opts.TargetAddr4 + targetAddr := opts.TargetAddr6 + if saddr.IsValid() { + if saddr.Addr().Is4() { + targetAddr = opts.TargetAddr4 + } + } else { + if remoteAddr.Addr().Is4() { + targetAddr = opts.TargetAddr4 } - } else if netip.MustParseAddrPort(saddr.String()).Addr().Is4() { - targetAddr = Opts.TargetAddr4 } clientAddr := "UNKNOWN" - if saddr != nil { + if saddr.IsValid() { clientAddr = saddr.String() } logger = logger.With(slog.String("clientAddr", clientAddr), slog.String("targetAddr", targetAddr.String())) - if Opts.Verbose > 1 { + if opts.Verbose > 1 { logger.Debug("successfully parsed PROXY header") } - dialer := net.Dialer{LocalAddr: saddr} - if saddr != nil { - dialer.Control = dialUpstreamControl(saddr.(*net.TCPAddr).Port) + dialer := net.Dialer{} + if saddr.IsValid() { + dialer.LocalAddr = net.TCPAddrFromAddrPort(saddr) + dialer.Control = utils.DialUpstreamControl(saddr.Port(), opts.Protocol, opts.Mark) } upstreamConn, err := dialer.Dial("tcp", targetAddr.String()) if err != nil { @@ -79,19 +90,19 @@ func tcpHandleConnection(conn net.Conn, logger *slog.Logger) { } defer upstreamConn.Close() - if Opts.Verbose > 1 { + if opts.Verbose > 1 { logger.Debug("successfully established upstream connection") } if err := conn.(*net.TCPConn).SetNoDelay(true); err != nil { logger.Debug("failed to set nodelay on downstream connection", "error", err, slog.Bool("dropConnection", true)) - } else if Opts.Verbose > 1 { + } else if opts.Verbose > 1 { logger.Debug("successfully set NoDelay on downstream connection") } if err := upstreamConn.(*net.TCPConn).SetNoDelay(true); err != nil { logger.Debug("failed to set nodelay on upstream connection", "error", err, slog.Bool("dropConnection", true)) - } else if Opts.Verbose > 1 { + } else if opts.Verbose > 1 { logger.Debug("successfully set NoDelay on upstream connection") } @@ -105,24 +116,23 @@ func tcpHandleConnection(conn net.Conn, logger *slog.Logger) { restBytes = restBytes[n:] } - PutBuffer(buffer) + buffers.Put(buffer) buffer = nil outErr := make(chan error, 2) - go tcpCopyData(upstreamConn, conn, outErr) - go tcpCopyData(conn, upstreamConn, outErr) + go copyData(upstreamConn, conn, outErr) + go copyData(conn, upstreamConn, outErr) err = <-outErr if err != nil { logger.Debug("connection broken", "error", err, slog.Bool("dropConnection", true)) - } else if Opts.Verbose > 1 { + } else if opts.Verbose > 1 { logger.Debug("connection closing") } } -func tcpListen(listenConfig *net.ListenConfig, logger *slog.Logger, errors chan<- error) { - ctx := context.Background() - ln, err := listenConfig.Listen(ctx, "tcp", Opts.ListenAddr.String()) +func Listen(ctx context.Context, listenConfig *net.ListenConfig, opts *utils.Options, logger *slog.Logger, errors chan<- error) { + ln, err := listenConfig.Listen(ctx, "tcp", opts.ListenAddr.String()) if err != nil { logger.Error("failed to bind listener", "error", err) errors <- err @@ -139,6 +149,6 @@ func tcpListen(listenConfig *net.ListenConfig, logger *slog.Logger, errors chan< return } - go tcpHandleConnection(conn, logger) + go handleConnection(conn, opts, logger) } } diff --git a/tests/buffers_test.go b/tests/buffers_test.go new file mode 100644 index 0000000..b332624 --- /dev/null +++ b/tests/buffers_test.go @@ -0,0 +1,33 @@ +// Copyright 2024 Konrad Zemek +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tests + +import ( + "testing" + + "github.com/kzemek/go-mmproxy/buffers" +) + +func TestGetGetsAPutBuffer(t *testing.T) { + buf1 := buffers.Get() + buf2 := buffers.Get() + + for i := range buf1 { + buf1[i] = 127 + } + + buffers.Put(buf1) + + buf3 := buffers.Get() + + for i := range buf3 { + if buf3[i] != 127 { + t.Errorf("Expected to retrieve previously stored buffer") + } + } + + buffers.Put(buf3) + buffers.Put(buf2) +} diff --git a/tests/proxyprotocol_test.go b/tests/proxyprotocol_test.go new file mode 100644 index 0000000..70544d6 --- /dev/null +++ b/tests/proxyprotocol_test.go @@ -0,0 +1,147 @@ +// Copyright 2024 Konrad Zemek +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tests + +import ( + "net/netip" + "reflect" + "testing" + + "github.com/kzemek/go-mmproxy/proxyprotocol" + "github.com/kzemek/go-mmproxy/utils" +) + +func TestProxyProtocolV1(t *testing.T) { + buf := []byte("PROXY TCP4 192.168.0.1 192.168.0.11 56324 443\r\nmoredata") + + saddr, daddr, rest, err := proxyprotocol.ReadRemoteAddr(buf, utils.TCP) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if saddr.String() != "192.168.0.1:56324" { + t.Errorf("Unexpected source address: %v", saddr) + } + + if daddr.String() != "192.168.0.11:443" { + t.Errorf("Unexpected destination address: %v", daddr) + } + + if !reflect.DeepEqual(rest, []byte("moredata")) { + t.Errorf("Unexpected rest: %v", rest) + } +} + +func TestProxyProtocolV1_nontcp(t *testing.T) { + buf := []byte("PROXY UDP4 192.168.0.1 192.168.0.11 56324 443\r\nmoredata") + + saddr, daddr, rest, err := proxyprotocol.ReadRemoteAddr(buf, utils.TCP) + if err == nil { + t.Errorf("Error was expected, yet returned %v %v %v", saddr, daddr, rest) + } +} + +func TestProxyProtocolV1_Unknown(t *testing.T) { + buf := []byte("PROXY UNKNOWN\r\nmoredata") + + saddr, daddr, rest, err := proxyprotocol.ReadRemoteAddr(buf, utils.TCP) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if saddr.IsValid() { + t.Errorf("Unexpected source address: %v", saddr) + } + + if daddr.IsValid() { + t.Errorf("Unexpected destination address: %v", daddr) + } + + if !reflect.DeepEqual(rest, []byte("moredata")) { + t.Errorf("Unexpected rest: %v", rest) + } +} + +func TestProxyProtocolV1_UnknownWithAddrs(t *testing.T) { + buf := []byte("PROXY UNKNOWN ffff::1 ffff::1 1234 1234\r\nmoredata") + + saddr, daddr, rest, err := proxyprotocol.ReadRemoteAddr(buf, utils.TCP) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if saddr.IsValid() { + t.Errorf("Unexpected source address: %v", saddr) + } + + if daddr.IsValid() { + t.Errorf("Unexpected destination address: %v", daddr) + } + + if !reflect.DeepEqual(rest, []byte("moredata")) { + t.Errorf("Unexpected rest: %v", rest) + } +} + +func TestProxyProtocolV2(t *testing.T) { + buf := []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A} + buf = append(buf, 0x21) // PROXY + buf = append(buf, 0x11) // TCP4 + buf = append(buf, 0x00, 0x0C) // 12 bytes + buf = append(buf, 192, 168, 0, 1) // saddr + buf = append(buf, 192, 168, 0, 11) // daddr + buf = append(buf, 0xDC, 0x04) // sport 56324 + buf = append(buf, 0x01, 0xBB) // dport 443 + buf = append(buf, []byte("moredata")...) + + saddr, daddr, rest, err := proxyprotocol.ReadRemoteAddr(buf, utils.TCP) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if saddr.String() != "192.168.0.1:56324" { + t.Errorf("Unexpected source address: %v", saddr) + } + + if daddr.String() != "192.168.0.11:443" { + t.Errorf("Unexpected destination address: %v", daddr) + } + + if !reflect.DeepEqual(rest, []byte("moredata")) { + t.Errorf("Unexpected rest: %v", rest) + } +} + +func TestProxyProtocolV2_udp6(t *testing.T) { + expectedSaddr := netip.MustParseAddr("2001:db8::1") + expectedDaddr := netip.MustParseAddr("2001:db8::2") + + buf := []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A} + buf = append(buf, 0x21) // PROXY + buf = append(buf, 0x22) // UDP6 + buf = append(buf, 0x00, 0x24) // 36 bytes + buf = append(buf, expectedSaddr.AsSlice()...) + buf = append(buf, expectedDaddr.AsSlice()...) + buf = append(buf, 0xDC, 0x04) // sport 56324 + buf = append(buf, 0x01, 0xBB) // dport 443 + buf = append(buf, []byte("moredata")...) + + saddr, daddr, rest, err := proxyprotocol.ReadRemoteAddr(buf, utils.UDP) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if saddr != netip.AddrPortFrom(expectedSaddr, 56324) { + t.Errorf("Unexpected source address: %v", saddr) + } + + if daddr != netip.AddrPortFrom(expectedDaddr, 443) { + t.Errorf("Unexpected destination address: %v", daddr) + } + + if !reflect.DeepEqual(rest, []byte("moredata")) { + t.Errorf("Unexpected rest: %v", rest) + } +} diff --git a/tests/tcp_test.go b/tests/tcp_test.go new file mode 100644 index 0000000..1a5d3ce --- /dev/null +++ b/tests/tcp_test.go @@ -0,0 +1,202 @@ +// Copyright 2024 Konrad Zemek +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tests + +import ( + "context" + "log/slog" + "net" + "net/netip" + "os" + "reflect" + "testing" + "time" + + "github.com/kzemek/go-mmproxy/tcp" + "github.com/kzemek/go-mmproxy/utils" +) + +type listenResult struct { + data []byte + saddr netip.AddrPort +} + +func runServer(t *testing.T, addr string, receivedData chan<- listenResult) { + server, err := net.Listen("tcp", addr) + if err != nil { + t.Fatalf("Failed to listen on server: %v", err) + } + defer server.Close() + + conn, err := server.Accept() + if err != nil { + t.Fatalf("Failed to accept connection: %v", err) + } + + buf := make([]byte, 1024) + n, err := conn.Read(buf) + if err != nil { + t.Fatalf("Failed to read data: %v", err) + } + + receivedData <- listenResult{ + data: buf[:n], + saddr: netip.MustParseAddrPort(conn.RemoteAddr().String()), + } +} + +func TestListen(t *testing.T) { + opts := utils.Options{ + Protocol: utils.TCP, + ListenAddr: netip.MustParseAddrPort("0.0.0.0:12345"), + TargetAddr4: netip.MustParseAddrPort("127.0.0.1:54321"), + TargetAddr6: netip.MustParseAddrPort("[::1]:54321"), + Mark: 0, + AllowedSubnets: nil, + Verbose: 2, + } + + lvl := slog.LevelInfo + if opts.Verbose > 0 { + lvl = slog.LevelDebug + } + + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: lvl})) + + listenConfig := net.ListenConfig{} + errors := make(chan error, 1) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go tcp.Listen(ctx, &listenConfig, &opts, logger, errors) + + receivedData4 := make(chan listenResult, 1) + go runServer(t, "127.0.0.1:54321", receivedData4) + + time.Sleep(1 * time.Second) + + conn, err := net.Dial("tcp", "127.0.0.1:12345") + if err != nil { + t.Fatalf("Failed to connect to server: %v", err) + } + defer conn.Close() + + conn.Write([]byte("PROXY TCP4 192.168.0.1 192.168.0.11 56324 443\r\nmoredata")) + result := <-receivedData4 + + if !reflect.DeepEqual(result.data, []byte("moredata")) { + t.Errorf("Unexpected data: %v", result.data) + } + + if result.saddr.String() != "192.168.0.1:56324" { + t.Errorf("Unexpected source address: %v", result.saddr) + } +} + +func TestListen_unknown(t *testing.T) { + opts := utils.Options{ + Protocol: utils.TCP, + ListenAddr: netip.MustParseAddrPort("0.0.0.0:12346"), + TargetAddr4: netip.MustParseAddrPort("127.0.0.1:54322"), + TargetAddr6: netip.MustParseAddrPort("[::1]:54322"), + Mark: 0, + AllowedSubnets: nil, + Verbose: 2, + } + + lvl := slog.LevelInfo + if opts.Verbose > 0 { + lvl = slog.LevelDebug + } + + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: lvl})) + + listenConfig := net.ListenConfig{} + errors := make(chan error, 1) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go tcp.Listen(ctx, &listenConfig, &opts, logger, errors) + + receivedData4 := make(chan listenResult, 1) + go runServer(t, "127.0.0.1:54322", receivedData4) + + time.Sleep(1 * time.Second) + + conn, err := net.Dial("tcp", "127.0.0.1:12346") + if err != nil { + t.Fatalf("Failed to connect to server: %v", err) + } + defer conn.Close() + + conn.Write([]byte("PROXY UNKNOWN\r\nmoredata")) + result := <-receivedData4 + + if !reflect.DeepEqual(result.data, []byte("moredata")) { + t.Errorf("Unexpected data: %v", result.data) + } + + if result.saddr.Addr().String() != "127.0.0.1" { + t.Errorf("Unexpected source address: %v", result.saddr) + } +} + +func TestListen_proxyV2(t *testing.T) { + opts := utils.Options{ + Protocol: utils.TCP, + ListenAddr: netip.MustParseAddrPort("0.0.0.0:12347"), + TargetAddr4: netip.MustParseAddrPort("127.0.0.1:54323"), + TargetAddr6: netip.MustParseAddrPort("[::1]:54323"), + Mark: 0, + AllowedSubnets: nil, + Verbose: 2, + } + + lvl := slog.LevelInfo + if opts.Verbose > 0 { + lvl = slog.LevelDebug + } + + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: lvl})) + + listenConfig := net.ListenConfig{} + errors := make(chan error, 1) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go tcp.Listen(ctx, &listenConfig, &opts, logger, errors) + + receivedData4 := make(chan listenResult, 1) + go runServer(t, "127.0.0.1:54323", receivedData4) + + time.Sleep(1 * time.Second) + + conn, err := net.Dial("tcp", "127.0.0.1:12347") + if err != nil { + t.Fatalf("Failed to connect to server: %v", err) + } + defer conn.Close() + + buf := []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A} + buf = append(buf, 0x21) // PROXY + buf = append(buf, 0x11) // TCP4 + buf = append(buf, 0x00, 0x0C) // 12 bytes + buf = append(buf, 192, 168, 0, 1) // saddr + buf = append(buf, 192, 168, 0, 11) // daddr + buf = append(buf, 0xDC, 0x04) // sport 56324 + buf = append(buf, 0x01, 0xBB) // dport 443 + buf = append(buf, []byte("moredata")...) + + conn.Write(buf) + result := <-receivedData4 + + if !reflect.DeepEqual(result.data, []byte("moredata")) { + t.Errorf("Unexpected data: %v", result.data) + } + + if result.saddr.String() != "192.168.0.1:56324" { + t.Errorf("Unexpected source address: %v", result.saddr) + } +} diff --git a/tests/udp_test.go b/tests/udp_test.go new file mode 100644 index 0000000..0ece86b --- /dev/null +++ b/tests/udp_test.go @@ -0,0 +1,96 @@ +// Copyright 2024 Konrad Zemek +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tests + +import ( + "context" + "log/slog" + "net" + "net/netip" + "os" + "reflect" + "testing" + "time" + + "github.com/kzemek/go-mmproxy/udp" + "github.com/kzemek/go-mmproxy/utils" +) + +func runUDPServer(t *testing.T, addr string, receivedData chan<- listenResult) { + conn, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addr))) + if err != nil { + t.Fatalf("Failed to listen on server: %v", err) + } + defer conn.Close() + + buf := make([]byte, 1024) + n, from, err := conn.ReadFrom(buf) + if err != nil { + t.Fatalf("Failed to read data: %v", err) + } + + receivedData <- listenResult{ + data: buf[:n], + saddr: netip.MustParseAddrPort(from.String()), + } +} + +func TestListenUDP(t *testing.T) { + opts := utils.Options{ + Protocol: utils.UDP, + ListenAddr: netip.MustParseAddrPort("0.0.0.0:12347"), + TargetAddr4: netip.MustParseAddrPort("127.0.0.1:54323"), + TargetAddr6: netip.MustParseAddrPort("[::1]:54323"), + Mark: 0, + AllowedSubnets: nil, + Verbose: 2, + } + + lvl := slog.LevelInfo + if opts.Verbose > 0 { + lvl = slog.LevelDebug + } + + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: lvl})) + + listenConfig := net.ListenConfig{} + errors := make(chan error, 1) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go udp.Listen(ctx, &listenConfig, &opts, logger, errors) + + receivedData4 := make(chan listenResult, 1) + go runUDPServer(t, "127.0.0.1:54323", receivedData4) + + time.Sleep(1 * time.Second) + + conn, err := net.Dial("udp", "127.0.0.1:12347") + if err != nil { + t.Fatalf("Failed to connect to server: %v", err) + } + defer conn.Close() + + buf := []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A} + buf = append(buf, 0x21) // PROXY + buf = append(buf, 0x12) // UDP4 + buf = append(buf, 0x00, 0x0C) // 12 bytes + buf = append(buf, 192, 168, 0, 1) // saddr + buf = append(buf, 192, 168, 0, 11) // daddr + buf = append(buf, 0xDC, 0x04) // sport 56324 + buf = append(buf, 0x01, 0xBB) // dport 443 + buf = append(buf, []byte("moredata")...) + + conn.Write(buf) + result := <-receivedData4 + + if !reflect.DeepEqual(result.data, []byte("moredata")) { + t.Errorf("Unexpected data: %v", result.data) + } + + if result.saddr.String() != "192.168.0.1:56324" { + t.Errorf("Unexpected source address: %v", result.saddr) + } +} diff --git a/udp.go b/udp/udp.go similarity index 53% rename from udp.go rename to udp/udp.go index 9475e1b..c5816a4 100644 --- a/udp.go +++ b/udp/udp.go @@ -1,8 +1,9 @@ // Copyright 2019 Path Network, Inc. All rights reserved. +// Copyright 2024 Konrad Zemek // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package main +package udp import ( "context" @@ -13,33 +14,33 @@ import ( "sync/atomic" "syscall" "time" + + "github.com/kzemek/go-mmproxy/buffers" + "github.com/kzemek/go-mmproxy/proxyprotocol" + "github.com/kzemek/go-mmproxy/utils" ) -type udpConnection struct { +type connection struct { lastActivity *int64 - clientAddr *net.UDPAddr - downstreamAddr *net.UDPAddr + clientAddr netip.AddrPort + downstreamAddr netip.AddrPort upstream *net.UDPConn logger *slog.Logger } -func udpCloseAfterInactivity(conn *udpConnection, socketClosures chan<- string) { +func closeAfterInactivity(conn *connection, closeAfter time.Duration, socketClosures chan<- netip.AddrPort) { for { lastActivity := atomic.LoadInt64(conn.lastActivity) - <-time.After(Opts.UDPCloseAfter) + <-time.After(closeAfter) if atomic.LoadInt64(conn.lastActivity) == lastActivity { break } } conn.upstream.Close() - if conn.clientAddr != nil { - socketClosures <- conn.clientAddr.String() - } else { - socketClosures <- "" - } + socketClosures <- conn.clientAddr } -func udpCopyFromUpstream(downstream net.PacketConn, conn *udpConnection) { +func copyFromUpstream(downstream net.PacketConn, conn *connection) { rawConn, err := conn.upstream.SyscallConn() if err != nil { conn.logger.Error("failed to retrieve raw connection from upstream socket", "error", err) @@ -49,8 +50,8 @@ func udpCopyFromUpstream(downstream net.PacketConn, conn *udpConnection) { var syscallErr error err = rawConn.Read(func(fd uintptr) bool { - buf := GetBuffer() - defer PutBuffer(buf) + buf := buffers.Get() + defer buffers.Put(buf) for { n, _, serr := syscall.Recvfrom(int(fd), buf, syscall.MSG_DONTWAIT) @@ -67,7 +68,7 @@ func udpCopyFromUpstream(downstream net.PacketConn, conn *udpConnection) { atomic.AddInt64(conn.lastActivity, 1) - if _, serr := downstream.WriteTo(buf[:n], conn.downstreamAddr); serr != nil { + if _, serr := downstream.WriteTo(buf[:n], net.UDPAddrFromAddrPort(conn.downstreamAddr)); serr != nil { syscallErr = serr return true } @@ -82,30 +83,33 @@ func udpCopyFromUpstream(downstream net.PacketConn, conn *udpConnection) { } } -func udpGetSocketFromMap(downstream net.PacketConn, downstreamAddr, saddr net.Addr, logger *slog.Logger, - connMap map[string]*udpConnection, socketClosures chan<- string) (*udpConnection, error) { - connKey := "" - if saddr != nil { - connKey = saddr.String() - } - if conn := connMap[connKey]; conn != nil { +func getSocketFromMap(downstream net.PacketConn, opts *utils.Options, downstreamAddr, saddr netip.AddrPort, logger *slog.Logger, + connMap map[netip.AddrPort]*connection, socketClosures chan<- netip.AddrPort) (*connection, error) { + if conn := connMap[saddr]; conn != nil { atomic.AddInt64(conn.lastActivity, 1) return conn, nil } - targetAddr := Opts.TargetAddr6 - if netip.MustParseAddr(downstreamAddr.String()).Is4() { - targetAddr = Opts.TargetAddr4 + targetAddr := opts.TargetAddr6 + if saddr.IsValid() { + if saddr.Addr().Is4() { + targetAddr = opts.TargetAddr4 + } + } else { + if downstreamAddr.Addr().Is4() { + targetAddr = opts.TargetAddr4 + } } logger = logger.With(slog.String("downstreamAddr", downstreamAddr.String()), slog.String("targetAddr", targetAddr.String())) - dialer := net.Dialer{LocalAddr: saddr} - if saddr != nil { + dialer := net.Dialer{} + if saddr.IsValid() { logger = logger.With(slog.String("clientAddr", saddr.String())) - dialer.Control = dialUpstreamControl(saddr.(*net.UDPAddr).Port) + dialer.LocalAddr = net.UDPAddrFromAddrPort(saddr) + dialer.Control = utils.DialUpstreamControl(saddr.Port(), opts.Protocol, opts.Mark) } - if Opts.Verbose > 1 { + if opts.Verbose > 1 { logger.Debug("new connection") } @@ -115,24 +119,21 @@ func udpGetSocketFromMap(downstream net.PacketConn, downstreamAddr, saddr net.Ad return nil, err } - udpConn := &udpConnection{upstream: conn.(*net.UDPConn), + udpConn := &connection{upstream: conn.(*net.UDPConn), logger: logger, lastActivity: new(int64), - downstreamAddr: downstreamAddr.(*net.UDPAddr)} - if saddr != nil { - udpConn.clientAddr = saddr.(*net.UDPAddr) - } + clientAddr: saddr, + downstreamAddr: downstreamAddr} - go udpCopyFromUpstream(downstream, udpConn) - go udpCloseAfterInactivity(udpConn, socketClosures) + go copyFromUpstream(downstream, udpConn) + go closeAfterInactivity(udpConn, opts.UDPCloseAfter, socketClosures) - connMap[connKey] = udpConn + connMap[saddr] = udpConn return udpConn, nil } -func udpListen(listenConfig *net.ListenConfig, logger *slog.Logger, errors chan<- error) { - ctx := context.Background() - ln, err := listenConfig.ListenPacket(ctx, "udp", Opts.ListenAddr.String()) +func Listen(ctx context.Context, listenConfig *net.ListenConfig, opts *utils.Options, logger *slog.Logger, errors chan<- error) { + ln, err := listenConfig.ListenPacket(ctx, "udp", opts.ListenAddr.String()) if err != nil { logger.Error("failed to bind listener", "error", err) errors <- err @@ -141,25 +142,27 @@ func udpListen(listenConfig *net.ListenConfig, logger *slog.Logger, errors chan< logger.Info("listening") - socketClosures := make(chan string, 1024) - connectionMap := make(map[string]*udpConnection) + socketClosures := make(chan netip.AddrPort, 1024) + connectionMap := make(map[netip.AddrPort]*connection) - buffer := GetBuffer() - defer PutBuffer(buffer) + buffer := buffers.Get() + defer buffers.Put(buffer) for { - n, remoteAddr, err := ln.ReadFrom(buffer) + n, remoteAddrNet, err := ln.ReadFrom(buffer) if err != nil { logger.Error("failed to read from socket", "error", err) continue } - if !checkOriginAllowed(remoteAddr.(*net.UDPAddr).IP) { + remoteAddr := netip.MustParseAddrPort(remoteAddrNet.String()) + + if !utils.CheckOriginAllowed(remoteAddr.Addr(), opts.AllowedSubnets) { logger.Debug("packet origin not in allowed subnets", slog.String("remoteAddr", remoteAddr.String())) continue } - saddr, _, restBytes, err := proxyReadRemoteAddr(buffer[:n], UDP) + saddr, _, restBytes, err := proxyprotocol.ReadRemoteAddr(buffer[:n], utils.UDP) if err != nil { logger.Debug("failed to parse PROXY header", "error", err, slog.String("remoteAddr", remoteAddr.String())) continue @@ -178,7 +181,7 @@ func udpListen(listenConfig *net.ListenConfig, logger *slog.Logger, errors chan< } } - conn, err := udpGetSocketFromMap(ln, remoteAddr, saddr, logger, connectionMap, socketClosures) + conn, err := getSocketFromMap(ln, opts, remoteAddr, saddr, logger, connectionMap, socketClosures) if err != nil { continue } diff --git a/utils.go b/utils/utils.go similarity index 74% rename from utils.go rename to utils/utils.go index 36cbb95..56b88c6 100644 --- a/utils.go +++ b/utils/utils.go @@ -1,8 +1,9 @@ // Copyright 2019 Path Network, Inc. All rights reserved. +// Copyright 2024 Konrad Zemek // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package main +package utils import ( "fmt" @@ -10,6 +11,7 @@ import ( "net/netip" "strconv" "syscall" + "time" ) type Protocol int @@ -19,12 +21,23 @@ const ( UDP ) -func checkOriginAllowed(remoteIP net.IP) bool { - if len(Opts.AllowedSubnets) == 0 { +type Options struct { + Protocol Protocol + ListenAddr netip.AddrPort + TargetAddr4 netip.AddrPort + TargetAddr6 netip.AddrPort + Mark int + Verbose int + AllowedSubnets []netip.Prefix + UDPCloseAfter time.Duration +} + +func CheckOriginAllowed(remoteIP netip.Addr, allowedSubnets []netip.Prefix) bool { + if len(allowedSubnets) == 0 { return true } - for _, ipNet := range Opts.AllowedSubnets { + for _, ipNet := range allowedSubnets { if ipNet.Contains(remoteIP) { return true } @@ -32,7 +45,7 @@ func checkOriginAllowed(remoteIP net.IP) bool { return false } -func parseHostPort(hostport string) (netip.AddrPort, error) { +func ParseHostPort(hostport string) (netip.AddrPort, error) { host, portStr, err := net.SplitHostPort(hostport) if err != nil { return netip.AddrPort{}, fmt.Errorf("failed to parse host and port: %w", err) @@ -55,11 +68,11 @@ func parseHostPort(hostport string) (netip.AddrPort, error) { return netip.AddrPortFrom(ip, uint16(port)), nil } -func dialUpstreamControl(sport int) func(string, string, syscall.RawConn) error { +func DialUpstreamControl(sport uint16, protocol Protocol, mark int) func(string, string, syscall.RawConn) error { return func(network, address string, c syscall.RawConn) error { var syscallErr error err := c.Control(func(fd uintptr) { - if Opts.Protocol == "tcp" { + if protocol == TCP { syscallErr = syscall.SetsockoptInt(int(fd), syscall.IPPROTO_TCP, syscall.TCP_SYNCNT, 2) if syscallErr != nil { syscallErr = fmt.Errorf("setsockopt(IPPROTO_TCP, TCP_SYNCTNT, 2): %w", syscallErr) @@ -83,15 +96,15 @@ func dialUpstreamControl(sport int) func(string, string, syscall.RawConn) error ipBindAddressNoPort := 24 syscallErr = syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, ipBindAddressNoPort, 1) if syscallErr != nil { - syscallErr = fmt.Errorf("setsockopt(SOL_SOCKET, IPPROTO_IP, %d): %w", Opts.Mark, syscallErr) + syscallErr = fmt.Errorf("setsockopt(IPPROTO_IP, IP_BIND_ADDRESS_NO_PORT, 1): %w", syscallErr) return } } - if Opts.Mark != 0 { - syscallErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, Opts.Mark) + if mark != 0 { + syscallErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, mark) if syscallErr != nil { - syscallErr = fmt.Errorf("setsockopt(SOL_SOCK, SO_MARK, %d): %w", Opts.Mark, syscallErr) + syscallErr = fmt.Errorf("setsockopt(SOL_SOCK, SO_MARK, %d): %w", mark, syscallErr) return } }