diff --git a/README.md b/README.md index 2cc9619..6cbe948 100644 --- a/README.md +++ b/README.md @@ -57,6 +57,8 @@ Usage of ./go-mmproxy: Path to a file that contains allowed subnets of the proxy servers -close-after int Number of seconds after which UDP socket will be cleaned up (default 60) + -dynamic-destination + Traffic will be forwarded to the destination specified in the PROXY protocol header -l string Address the proxy listens on (default "0.0.0.0:8443") -listeners int diff --git a/main.go b/main.go index f34d35f..f8e3582 100644 --- a/main.go +++ b/main.go @@ -36,6 +36,7 @@ func init() { flag.StringVar(&listenAddrStr, "l", "0.0.0.0:8443", "Address the proxy listens on") flag.StringVar(&targetAddr4Str, "4", "127.0.0.1:443", "Address to which IPv4 traffic will be forwarded to") flag.StringVar(&targetAddr6Str, "6", "[::1]:443", "Address to which IPv6 traffic will be forwarded to") + flag.BoolVar(&opts.DynamicDestination, "dynamic-destination", false, "Traffic will be forwarded to the destination specified in the PROXY protocol header") flag.IntVar(&opts.Mark, "mark", 0, "The mark that will be set on outbound packets") flag.IntVar(&opts.Verbose, "v", 0, `0 - no logging of individual connections 1 - log errors occurring in individual connections @@ -44,7 +45,7 @@ func init() { "Path to a file that contains allowed subnets of the proxy servers") flag.IntVar(&listeners, "listeners", 1, "Number of listener sockets that will be opened for the listen address (Linux 3.9+)") - flag.IntVar(&udpCloseAfterInt, "close-after", 60, "Number of seconds after which UDP socket will be cleaned up") + flag.IntVar(&udpCloseAfterInt, "close-after", 60, "Number of seconds after which UDP socket will be cleaned up on inactivity") } func listen(ctx context.Context, listenerNum int, parentLogger *slog.Logger, listenErrors chan<- error) { diff --git a/tcp/tcp.go b/tcp/tcp.go index 316d253..906a9b2 100644 --- a/tcp/tcp.go +++ b/tcp/tcp.go @@ -52,7 +52,7 @@ func handleConnection(conn net.Conn, opts *utils.Options, logger *slog.Logger) { return } - saddr, _, restBytes, err := proxyprotocol.ReadRemoteAddr(buffer[:n], utils.TCP) + saddr, daddr, restBytes, err := proxyprotocol.ReadRemoteAddr(buffer[:n], utils.TCP) if err != nil { logger.Debug("failed to parse PROXY header", "error", err, slog.Bool("dropConnection", true)) return @@ -60,7 +60,9 @@ func handleConnection(conn net.Conn, opts *utils.Options, logger *slog.Logger) { targetAddr := opts.TargetAddr6 if saddr.IsValid() { - if saddr.Addr().Is4() { + if opts.DynamicDestination && daddr.IsValid() { + targetAddr = daddr + } else if saddr.Addr().Is4() { targetAddr = opts.TargetAddr4 } } else { diff --git a/tests/tcp_test.go b/tests/tcp_test.go index 1a5d3ce..fea0ac6 100644 --- a/tests/tcp_test.go +++ b/tests/tcp_test.go @@ -75,7 +75,7 @@ func TestListen(t *testing.T) { receivedData4 := make(chan listenResult, 1) go runServer(t, "127.0.0.1:54321", receivedData4) - time.Sleep(1 * time.Second) + time.Sleep(100 * time.Millisecond) conn, err := net.Dial("tcp", "127.0.0.1:12345") if err != nil { @@ -123,7 +123,7 @@ func TestListen_unknown(t *testing.T) { receivedData4 := make(chan listenResult, 1) go runServer(t, "127.0.0.1:54322", receivedData4) - time.Sleep(1 * time.Second) + time.Sleep(100 * time.Millisecond) conn, err := net.Dial("tcp", "127.0.0.1:12346") if err != nil { @@ -171,7 +171,7 @@ func TestListen_proxyV2(t *testing.T) { receivedData4 := make(chan listenResult, 1) go runServer(t, "127.0.0.1:54323", receivedData4) - time.Sleep(1 * time.Second) + time.Sleep(100 * time.Millisecond) conn, err := net.Dial("tcp", "127.0.0.1:12347") if err != nil { @@ -200,3 +200,52 @@ func TestListen_proxyV2(t *testing.T) { t.Errorf("Unexpected source address: %v", result.saddr) } } + +func TestTCPListen_DynamicDestination(t *testing.T) { + opts := utils.Options{ + Protocol: utils.TCP, + ListenAddr: netip.MustParseAddrPort("0.0.0.0:12350"), + TargetAddr4: netip.MustParseAddrPort("127.0.0.1:443"), + TargetAddr6: netip.MustParseAddrPort("[::1]:443"), + DynamicDestination: true, + Mark: 0, + AllowedSubnets: nil, + Verbose: 2, + } + + lvl := slog.LevelInfo + if opts.Verbose > 0 { + lvl = slog.LevelDebug + } + + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: lvl})) + + listenConfig := net.ListenConfig{} + errors := make(chan error, 1) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go tcp.Listen(ctx, &listenConfig, &opts, logger, errors) + + receivedData4 := make(chan listenResult, 1) + go runServer(t, "127.0.0.1:56324", receivedData4) + + time.Sleep(100 * time.Millisecond) + + conn, err := net.Dial("tcp", "127.0.0.1:12350") + if err != nil { + t.Fatalf("Failed to connect to server: %v", err) + } + defer conn.Close() + + conn.Write([]byte("PROXY TCP4 192.168.0.1 127.0.0.1 56324 56324\r\nmoredata")) + result := <-receivedData4 + + if !reflect.DeepEqual(result.data, []byte("moredata")) { + t.Errorf("Unexpected data: %v", result.data) + } + + if result.saddr.String() != "192.168.0.1:56324" { + t.Errorf("Unexpected source address: %v", result.saddr) + } +} diff --git a/tests/udp_test.go b/tests/udp_test.go index 0ece86b..cbca211 100644 --- a/tests/udp_test.go +++ b/tests/udp_test.go @@ -65,7 +65,7 @@ func TestListenUDP(t *testing.T) { receivedData4 := make(chan listenResult, 1) go runUDPServer(t, "127.0.0.1:54323", receivedData4) - time.Sleep(1 * time.Second) + time.Sleep(100 * time.Millisecond) conn, err := net.Dial("udp", "127.0.0.1:12347") if err != nil { @@ -94,3 +94,62 @@ func TestListenUDP(t *testing.T) { t.Errorf("Unexpected source address: %v", result.saddr) } } + +func TestListenUDP_DynamicDestination(t *testing.T) { + opts := utils.Options{ + Protocol: utils.UDP, + ListenAddr: netip.MustParseAddrPort("0.0.0.0:12348"), + TargetAddr4: netip.MustParseAddrPort("127.0.0.1:443"), + TargetAddr6: netip.MustParseAddrPort("[::1]:443"), + DynamicDestination: true, + Mark: 0, + AllowedSubnets: nil, + Verbose: 2, + } + + lvl := slog.LevelInfo + if opts.Verbose > 0 { + lvl = slog.LevelDebug + } + + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: lvl})) + + listenConfig := net.ListenConfig{} + errors := make(chan error, 1) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go udp.Listen(ctx, &listenConfig, &opts, logger, errors) + + receivedData4 := make(chan listenResult, 1) + go runUDPServer(t, "127.0.0.1:56324", receivedData4) + + time.Sleep(100 * time.Millisecond) + + conn, err := net.Dial("udp", "127.0.0.1:12348") + if err != nil { + t.Fatalf("Failed to connect to server: %v", err) + } + defer conn.Close() + + buf := []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A} + buf = append(buf, 0x21) // PROXY + buf = append(buf, 0x12) // UDP4 + buf = append(buf, 0x00, 0x0C) // 12 bytes + buf = append(buf, 192, 168, 0, 1) // saddr + buf = append(buf, 127, 0, 0, 1) // daddr + buf = append(buf, 0xDC, 0x04) // sport 56324 + buf = append(buf, 0xDC, 0x04) // sport 56324 + buf = append(buf, []byte("moredata")...) + + conn.Write(buf) + result := <-receivedData4 + + if !reflect.DeepEqual(result.data, []byte("moredata")) { + t.Errorf("Unexpected data: %v", result.data) + } + + if result.saddr.String() != "192.168.0.1:56324" { + t.Errorf("Unexpected source address: %v", result.saddr) + } +} diff --git a/udp/udp.go b/udp/udp.go index c5816a4..bc9b3e3 100644 --- a/udp/udp.go +++ b/udp/udp.go @@ -83,8 +83,8 @@ func copyFromUpstream(downstream net.PacketConn, conn *connection) { } } -func getSocketFromMap(downstream net.PacketConn, opts *utils.Options, downstreamAddr, saddr netip.AddrPort, logger *slog.Logger, - connMap map[netip.AddrPort]*connection, socketClosures chan<- netip.AddrPort) (*connection, error) { +func getSocketFromMap(downstream net.PacketConn, opts *utils.Options, downstreamAddr, saddr, daddr netip.AddrPort, + logger *slog.Logger, connMap map[netip.AddrPort]*connection, socketClosures chan<- netip.AddrPort) (*connection, error) { if conn := connMap[saddr]; conn != nil { atomic.AddInt64(conn.lastActivity, 1) return conn, nil @@ -92,7 +92,9 @@ func getSocketFromMap(downstream net.PacketConn, opts *utils.Options, downstream targetAddr := opts.TargetAddr6 if saddr.IsValid() { - if saddr.Addr().Is4() { + if opts.DynamicDestination && daddr.IsValid() { + targetAddr = daddr + } else if saddr.Addr().Is4() { targetAddr = opts.TargetAddr4 } } else { @@ -162,7 +164,7 @@ func Listen(ctx context.Context, listenConfig *net.ListenConfig, opts *utils.Opt continue } - saddr, _, restBytes, err := proxyprotocol.ReadRemoteAddr(buffer[:n], utils.UDP) + saddr, daddr, restBytes, err := proxyprotocol.ReadRemoteAddr(buffer[:n], utils.UDP) if err != nil { logger.Debug("failed to parse PROXY header", "error", err, slog.String("remoteAddr", remoteAddr.String())) continue @@ -181,7 +183,7 @@ func Listen(ctx context.Context, listenConfig *net.ListenConfig, opts *utils.Opt } } - conn, err := getSocketFromMap(ln, opts, remoteAddr, saddr, logger, connectionMap, socketClosures) + conn, err := getSocketFromMap(ln, opts, remoteAddr, saddr, daddr, logger, connectionMap, socketClosures) if err != nil { continue } diff --git a/utils/utils.go b/utils/utils.go index 56b88c6..ee8982a 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -22,14 +22,15 @@ const ( ) type Options struct { - Protocol Protocol - ListenAddr netip.AddrPort - TargetAddr4 netip.AddrPort - TargetAddr6 netip.AddrPort - Mark int - Verbose int - AllowedSubnets []netip.Prefix - UDPCloseAfter time.Duration + Protocol Protocol + ListenAddr netip.AddrPort + TargetAddr4 netip.AddrPort + TargetAddr6 netip.AddrPort + DynamicDestination bool + Mark int + Verbose int + AllowedSubnets []netip.Prefix + UDPCloseAfter time.Duration } func CheckOriginAllowed(remoteIP netip.Addr, allowedSubnets []netip.Prefix) bool {