Skip to content

Commit

Permalink
Allow specifying the targets as hostnames. (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
kzemek authored Mar 24, 2024
1 parent 2e7b7ca commit 750fa0e
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 8 deletions.
8 changes: 4 additions & 4 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func init() {

func listen(ctx context.Context, listenerNum int, parentLogger *slog.Logger, listenErrors chan<- error) {
logger := parentLogger.With(slog.Int("listenerNum", listenerNum),
slog.String("protocol", protocolStr), slog.String("listenAdr", opts.ListenAddr.String()))
slog.String("protocol", protocolStr), slog.String("listenAddr", opts.ListenAddr.String()))

listenConfig := net.ListenConfig{}
if listeners > 1 {
Expand Down Expand Up @@ -132,12 +132,12 @@ func main() {
}

var err error
if opts.ListenAddr, err = utils.ParseHostPort(listenAddrStr); err != nil {
if opts.ListenAddr, err = utils.ParseHostPort(listenAddrStr, 0); err != nil {
logger.Error("listen address is malformed", "error", err)
os.Exit(1)
}

if opts.TargetAddr4, err = netip.ParseAddrPort(targetAddr4Str); err != nil {
if opts.TargetAddr4, err = utils.ParseHostPort(targetAddr4Str, 4); err != nil {
logger.Error("ipv4 target address is malformed", "error", err)
os.Exit(1)
}
Expand All @@ -146,7 +146,7 @@ func main() {
os.Exit(1)
}

if opts.TargetAddr6, err = netip.ParseAddrPort(targetAddr6Str); err != nil {
if opts.TargetAddr6, err = utils.ParseHostPort(targetAddr6Str, 6); err != nil {
logger.Error("ipv6 target address is malformed", "error", err)
os.Exit(1)
}
Expand Down
16 changes: 12 additions & 4 deletions utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func CheckOriginAllowed(remoteIP netip.Addr, allowedSubnets []netip.Prefix) bool
return false
}

func ParseHostPort(hostport string) (netip.AddrPort, error) {
func ParseHostPort(hostport string, ipVersion int) (netip.AddrPort, error) {
host, portStr, err := net.SplitHostPort(hostport)
if err != nil {
return netip.AddrPort{}, fmt.Errorf("failed to parse host and port: %w", err)
Expand All @@ -56,7 +56,16 @@ func ParseHostPort(hostport string) (netip.AddrPort, error) {
if err != nil {
return netip.AddrPort{}, fmt.Errorf("failed to lookup IP addresses: %w", err)
}
if len(ips) == 0 {

filteredIPs := make([]netip.Addr, 0, len(ips))
for _, stdip := range ips {
ip := netip.MustParseAddr(stdip.String())
if ipVersion == 0 || (ip.Is4() && ipVersion == 4) || (ip.Is6() && ipVersion == 6) {
filteredIPs = append(filteredIPs, ip)
}
}

if len(filteredIPs) == 0 {
return netip.AddrPort{}, fmt.Errorf("no IP addresses found")
}

Expand All @@ -65,8 +74,7 @@ func ParseHostPort(hostport string) (netip.AddrPort, error) {
return netip.AddrPort{}, fmt.Errorf("failed to parse port: %w", err)
}

ip, _ := netip.AddrFromSlice(ips[0])
return netip.AddrPortFrom(ip, uint16(port)), nil
return netip.AddrPortFrom(filteredIPs[0], uint16(port)), nil
}

func DialUpstreamControl(sport uint16, protocol Protocol, mark int) func(string, string, syscall.RawConn) error {
Expand Down

0 comments on commit 750fa0e

Please sign in to comment.