diff --git a/main.go b/main.go index f8e3582..3e7bf21 100644 --- a/main.go +++ b/main.go @@ -50,7 +50,7 @@ func init() { func listen(ctx context.Context, listenerNum int, parentLogger *slog.Logger, listenErrors chan<- error) { logger := parentLogger.With(slog.Int("listenerNum", listenerNum), - slog.String("protocol", protocolStr), slog.String("listenAdr", opts.ListenAddr.String())) + slog.String("protocol", protocolStr), slog.String("listenAddr", opts.ListenAddr.String())) listenConfig := net.ListenConfig{} if listeners > 1 { @@ -132,12 +132,12 @@ func main() { } var err error - if opts.ListenAddr, err = utils.ParseHostPort(listenAddrStr); err != nil { + if opts.ListenAddr, err = utils.ParseHostPort(listenAddrStr, 0); err != nil { logger.Error("listen address is malformed", "error", err) os.Exit(1) } - if opts.TargetAddr4, err = netip.ParseAddrPort(targetAddr4Str); err != nil { + if opts.TargetAddr4, err = utils.ParseHostPort(targetAddr4Str, 4); err != nil { logger.Error("ipv4 target address is malformed", "error", err) os.Exit(1) } @@ -146,7 +146,7 @@ func main() { os.Exit(1) } - if opts.TargetAddr6, err = netip.ParseAddrPort(targetAddr6Str); err != nil { + if opts.TargetAddr6, err = utils.ParseHostPort(targetAddr6Str, 6); err != nil { logger.Error("ipv6 target address is malformed", "error", err) os.Exit(1) } diff --git a/utils/utils.go b/utils/utils.go index ee8982a..9294325 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -46,7 +46,7 @@ func CheckOriginAllowed(remoteIP netip.Addr, allowedSubnets []netip.Prefix) bool return false } -func ParseHostPort(hostport string) (netip.AddrPort, error) { +func ParseHostPort(hostport string, ipVersion int) (netip.AddrPort, error) { host, portStr, err := net.SplitHostPort(hostport) if err != nil { return netip.AddrPort{}, fmt.Errorf("failed to parse host and port: %w", err) @@ -56,7 +56,16 @@ func ParseHostPort(hostport string) (netip.AddrPort, error) { if err != nil { return netip.AddrPort{}, fmt.Errorf("failed to lookup IP addresses: %w", err) } - if len(ips) == 0 { + + filteredIPs := make([]netip.Addr, 0, len(ips)) + for _, stdip := range ips { + ip := netip.MustParseAddr(stdip.String()) + if ipVersion == 0 || (ip.Is4() && ipVersion == 4) || (ip.Is6() && ipVersion == 6) { + filteredIPs = append(filteredIPs, ip) + } + } + + if len(filteredIPs) == 0 { return netip.AddrPort{}, fmt.Errorf("no IP addresses found") } @@ -65,8 +74,7 @@ func ParseHostPort(hostport string) (netip.AddrPort, error) { return netip.AddrPort{}, fmt.Errorf("failed to parse port: %w", err) } - ip, _ := netip.AddrFromSlice(ips[0]) - return netip.AddrPortFrom(ip, uint16(port)), nil + return netip.AddrPortFrom(filteredIPs[0], uint16(port)), nil } func DialUpstreamControl(sport uint16, protocol Protocol, mark int) func(string, string, syscall.RawConn) error {