From d12143cbdca0c00d1eb5ea7452d077bbeffdc632 Mon Sep 17 00:00:00 2001 From: Stefan Majer Date: Fri, 20 Oct 2023 09:31:37 +0200 Subject: [PATCH] unexport most func --- main.go | 6 +++--- proxyprotocol.go | 17 +++++++++-------- tcp.go | 8 ++++---- udp.go | 8 ++++---- utils.go | 6 +++--- 5 files changed, 23 insertions(+), 22 deletions(-) diff --git a/main.go b/main.go index c978aac..6fc7ffd 100644 --- a/main.go +++ b/main.go @@ -68,9 +68,9 @@ func listen(listenerNum int, errors chan<- error) { } if Opts.Protocol == "tcp" { - TCPListen(&listenConfig, logger, errors) + tcpListen(&listenConfig, logger, errors) } else { - UDPListen(&listenConfig, logger, errors) + udpListen(&listenConfig, logger, errors) } } @@ -130,7 +130,7 @@ func main() { } var err error - if Opts.ListenAddr, err = ParseHostPort(Opts.ListenAddrStr); err != nil { + if Opts.ListenAddr, err = parseHostPort(Opts.ListenAddrStr); err != nil { Opts.Logger.Error("listen address is malformed", "error", err) os.Exit(1) } diff --git a/proxyprotocol.go b/proxyprotocol.go index 7e88642..1547ec1 100644 --- a/proxyprotocol.go +++ b/proxyprotocol.go @@ -29,7 +29,7 @@ func readRemoteAddrPROXYv2(ctrlBuf []byte, protocol Protocol) (net.Addr, net.Add var dataLen uint16 reader := bytes.NewReader(ctrlBuf[14:16]) if err := binary.Read(reader, binary.BigEndian, &dataLen); err != nil { - return nil, nil, nil, fmt.Errorf("failed to decode address data length: %s", err.Error()) + return nil, nil, nil, fmt.Errorf("failed to decode address data length: %w", err) } if len(ctrlBuf) < 16+int(dataLen) { @@ -47,10 +47,10 @@ func readRemoteAddrPROXYv2(ctrlBuf []byte, protocol Protocol) (net.Addr, net.Add reader = bytes.NewReader(ctrlBuf[48:]) } if err := binary.Read(reader, binary.BigEndian, &sport); err != nil { - return nil, nil, nil, fmt.Errorf("failed to decode source port: %s", err.Error()) + return nil, nil, nil, fmt.Errorf("failed to decode source port: %w", err) } if err := binary.Read(reader, binary.BigEndian, &dport); err != nil { - return nil, nil, nil, fmt.Errorf("failed to decode destination port: %s", err.Error()) + return nil, nil, nil, fmt.Errorf("failed to decode destination port: %w", err) } var srcIP, dstIP net.IP @@ -115,12 +115,13 @@ func readRemoteAddrPROXYv1(ctrlBuf []byte) (net.Addr, net.Addr, []byte, error) { return nil, nil, nil, fmt.Errorf("did not find \\r\\n in first data segment") } -func PROXYReadRemoteAddr(buf []byte, protocol Protocol) (net.Addr, net.Addr, []byte, error) { - if len(buf) >= 16 && bytes.Equal(buf[:12], - []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A}) { +var proxyv2header = []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A} + +func proxyReadRemoteAddr(buf []byte, protocol Protocol) (net.Addr, net.Addr, []byte, error) { + if len(buf) >= 16 && bytes.Equal(buf[:12], proxyv2header) { saddr, daddr, rest, err := readRemoteAddrPROXYv2(buf, protocol) if err != nil { - return nil, nil, nil, fmt.Errorf("failed to parse PROXY v2 header: %s", err.Error()) + return nil, nil, nil, fmt.Errorf("failed to parse PROXY v2 header: %w", err) } return saddr, daddr, rest, err } @@ -129,7 +130,7 @@ func PROXYReadRemoteAddr(buf []byte, protocol Protocol) (net.Addr, net.Addr, []b if protocol == TCP && len(buf) >= 8 && bytes.Equal(buf[:5], []byte("PROXY")) { saddr, daddr, rest, err := readRemoteAddrPROXYv1(buf) if err != nil { - return nil, nil, nil, fmt.Errorf("failed to parse PROXY v1 header: %s", err.Error()) + return nil, nil, nil, fmt.Errorf("failed to parse PROXY v1 header: %w", err) } return saddr, daddr, rest, err } diff --git a/tcp.go b/tcp.go index 2a297ca..abf76e4 100644 --- a/tcp.go +++ b/tcp.go @@ -22,7 +22,7 @@ func tcpHandleConnection(conn net.Conn, logger *slog.Logger) { logger = logger.With(slog.String("remoteAddr", conn.RemoteAddr().String()), slog.String("localAddr", conn.LocalAddr().String())) - if !CheckOriginAllowed(conn.RemoteAddr().(*net.TCPAddr).IP) { + if !checkOriginAllowed(conn.RemoteAddr().(*net.TCPAddr).IP) { logger.Debug("connection origin not in allowed subnets", slog.Bool("dropConnection", true)) return } @@ -44,7 +44,7 @@ func tcpHandleConnection(conn net.Conn, logger *slog.Logger) { return } - saddr, _, restBytes, err := PROXYReadRemoteAddr(buffer[:n], TCP) + saddr, _, restBytes, err := proxyReadRemoteAddr(buffer[:n], TCP) if err != nil { logger.Debug("failed to parse PROXY header", "error", err, slog.Bool("dropConnection", true)) return @@ -70,7 +70,7 @@ func tcpHandleConnection(conn net.Conn, logger *slog.Logger) { dialer := net.Dialer{LocalAddr: saddr} if saddr != nil { - dialer.Control = DialUpstreamControl(saddr.(*net.TCPAddr).Port) + dialer.Control = dialUpstreamControl(saddr.(*net.TCPAddr).Port) } upstreamConn, err := dialer.Dial("tcp", targetAddr.String()) if err != nil { @@ -120,7 +120,7 @@ func tcpHandleConnection(conn net.Conn, logger *slog.Logger) { } } -func TCPListen(listenConfig *net.ListenConfig, logger *slog.Logger, errors chan<- error) { +func tcpListen(listenConfig *net.ListenConfig, logger *slog.Logger, errors chan<- error) { ctx := context.Background() ln, err := listenConfig.Listen(ctx, "tcp", Opts.ListenAddr.String()) if err != nil { diff --git a/udp.go b/udp.go index ca13441..9475e1b 100644 --- a/udp.go +++ b/udp.go @@ -102,7 +102,7 @@ func udpGetSocketFromMap(downstream net.PacketConn, downstreamAddr, saddr net.Ad dialer := net.Dialer{LocalAddr: saddr} if saddr != nil { logger = logger.With(slog.String("clientAddr", saddr.String())) - dialer.Control = DialUpstreamControl(saddr.(*net.UDPAddr).Port) + dialer.Control = dialUpstreamControl(saddr.(*net.UDPAddr).Port) } if Opts.Verbose > 1 { @@ -130,7 +130,7 @@ func udpGetSocketFromMap(downstream net.PacketConn, downstreamAddr, saddr net.Ad return udpConn, nil } -func UDPListen(listenConfig *net.ListenConfig, logger *slog.Logger, errors chan<- error) { +func udpListen(listenConfig *net.ListenConfig, logger *slog.Logger, errors chan<- error) { ctx := context.Background() ln, err := listenConfig.ListenPacket(ctx, "udp", Opts.ListenAddr.String()) if err != nil { @@ -154,12 +154,12 @@ func UDPListen(listenConfig *net.ListenConfig, logger *slog.Logger, errors chan< continue } - if !CheckOriginAllowed(remoteAddr.(*net.UDPAddr).IP) { + if !checkOriginAllowed(remoteAddr.(*net.UDPAddr).IP) { logger.Debug("packet origin not in allowed subnets", slog.String("remoteAddr", remoteAddr.String())) continue } - saddr, _, restBytes, err := PROXYReadRemoteAddr(buffer[:n], UDP) + saddr, _, restBytes, err := proxyReadRemoteAddr(buffer[:n], UDP) if err != nil { logger.Debug("failed to parse PROXY header", "error", err, slog.String("remoteAddr", remoteAddr.String())) continue diff --git a/utils.go b/utils.go index 00b806d..36cbb95 100644 --- a/utils.go +++ b/utils.go @@ -19,7 +19,7 @@ const ( UDP ) -func CheckOriginAllowed(remoteIP net.IP) bool { +func checkOriginAllowed(remoteIP net.IP) bool { if len(Opts.AllowedSubnets) == 0 { return true } @@ -32,7 +32,7 @@ func CheckOriginAllowed(remoteIP net.IP) bool { return false } -func ParseHostPort(hostport string) (netip.AddrPort, error) { +func parseHostPort(hostport string) (netip.AddrPort, error) { host, portStr, err := net.SplitHostPort(hostport) if err != nil { return netip.AddrPort{}, fmt.Errorf("failed to parse host and port: %w", err) @@ -55,7 +55,7 @@ func ParseHostPort(hostport string) (netip.AddrPort, error) { return netip.AddrPortFrom(ip, uint16(port)), nil } -func DialUpstreamControl(sport int) func(string, string, syscall.RawConn) error { +func dialUpstreamControl(sport int) func(string, string, syscall.RawConn) error { return func(network, address string, c syscall.RawConn) error { var syscallErr error err := c.Control(func(fd uintptr) {