Skip to content

Commit

Permalink
server: only enable udp route control on unspecified socket
Browse files Browse the repository at this point in the history
more nil checks

optimize cmcUDPConn interface
  • Loading branch information
urlesistiana committed Oct 30, 2022
1 parent 43be6ee commit 760a660
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 83 deletions.
34 changes: 17 additions & 17 deletions pkg/server/udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ import (

// cmcUDPConn can read and write cmsg.
type cmcUDPConn interface {
readFrom(b []byte) (n int, cm any, src net.Addr, err error)
writeTo(b []byte, cm any, dst net.Addr) (n int, err error)
readFrom(b []byte) (n int, dst net.IP, IfIndex int, src net.Addr, err error)
writeTo(b []byte, src net.IP, IfIndex int, dst net.Addr) (n int, err error)
}

func (s *Server) ServeUDP(c net.PacketConn) error {
Expand All @@ -61,28 +61,28 @@ func (s *Server) ServeUDP(c net.PacketConn) error {
var cmc cmcUDPConn
var err error
uc, ok := c.(*net.UDPConn)
if ok {
cmc, err = newUDPConn(uc)
if ok && uc.LocalAddr().(*net.UDPAddr).IP.IsUnspecified() {
cmc, err = newCmc(uc)
if err != nil {
return fmt.Errorf("failed to control socket cmsg, %w", err)
}
} else {
cmc = newDummyUDPConn(c)
cmc = newDummyCmc(c)
}

for {
n, cm, clientNetAddr, err := cmc.readFrom(rb)
n, localAddr, ifIndex, remoteAddr, err := cmc.readFrom(rb)
if err != nil {
if s.Closed() {
return ErrServerClosed
}
return fmt.Errorf("unexpected read err: %w", err)
}
clientAddr := utils.GetAddrFromAddr(clientNetAddr)
clientAddr := utils.GetAddrFromAddr(remoteAddr)

q := new(dns.Msg)
if err := q.Unpack(rb[:n]); err != nil {
s.opts.Logger.Warn("invalid msg", zap.Error(err), zap.Binary("msg", rb[:n]), zap.Stringer("from", clientNetAddr))
s.opts.Logger.Warn("invalid msg", zap.Error(err), zap.Binary("msg", rb[:n]), zap.Stringer("from", remoteAddr))
continue
}

Expand All @@ -105,8 +105,8 @@ func (s *Server) ServeUDP(c net.PacketConn) error {
return
}
defer buf.Release()
if _, err := cmc.writeTo(b, cm, clientNetAddr); err != nil {
s.opts.Logger.Warn("failed to write response", zap.Stringer("client", clientNetAddr), zap.Error(err))
if _, err := cmc.writeTo(b, localAddr, ifIndex, remoteAddr); err != nil {
s.opts.Logger.Warn("failed to write response", zap.Stringer("client", remoteAddr), zap.Error(err))
}
}
}()
Expand All @@ -124,22 +124,22 @@ func getUDPSize(m *dns.Msg) int {
return int(s)
}

// newDummyUDPConn returns a dummyWrapper.
func newDummyUDPConn(c net.PacketConn) cmcUDPConn {
return dummyWrapper{c: c}
// newDummyCmc returns a dummyCmcWrapper.
func newDummyCmc(c net.PacketConn) cmcUDPConn {
return dummyCmcWrapper{c: c}
}

// dummyWrapper is just a wrapper that implements cmcUDPConn but does not
// dummyCmcWrapper is just a wrapper that implements cmcUDPConn but does not
// write or read any control msg.
type dummyWrapper struct {
type dummyCmcWrapper struct {
c net.PacketConn
}

func (w dummyWrapper) readFrom(b []byte) (n int, cm any, src net.Addr, err error) {
func (w dummyCmcWrapper) readFrom(b []byte) (n int, dst net.IP, IfIndex int, src net.Addr, err error) {
n, src, err = w.c.ReadFrom(b)
return
}

func (w dummyWrapper) writeTo(b []byte, cm any, dst net.Addr) (n int, err error) {
func (w dummyCmcWrapper) writeTo(b []byte, src net.IP, IfIndex int, dst net.Addr) (n int, err error) {
return w.c.WriteTo(b, dst)
}
132 changes: 68 additions & 64 deletions pkg/server/udp_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,102 +30,106 @@ import (
"os"
)

type protocol int

const (
invalid protocol = iota
v4
v6
)

type ipv4PacketConn struct {
type ipv4cmc struct {
c *ipv4.PacketConn
}

func (i ipv4PacketConn) readFrom(b []byte) (n int, cm any, src net.Addr, err error) {
return i.c.ReadFrom(b)
func newIpv4cmc(c *ipv4.PacketConn) *ipv4cmc {
return &ipv4cmc{c: c}
}

func (i ipv4PacketConn) writeTo(b []byte, cm any, dst net.Addr) (n int, err error) {
cm4 := cm.(*ipv4.ControlMessage)
cm4.Src = cm4.Dst
cm4.Dst = nil
return i.c.WriteTo(b, cm4, dst)
func (i *ipv4cmc) readFrom(b []byte) (n int, dst net.IP, IfIndex int, src net.Addr, err error) {
n, cm, src, err := i.c.ReadFrom(b)
if cm != nil {
dst, IfIndex = cm.Dst, cm.IfIndex
}
return
}

type ipv6PacketConn struct {
func (i *ipv4cmc) writeTo(b []byte, src net.IP, IfIndex int, dst net.Addr) (n int, err error) {
cm := &ipv4.ControlMessage{
Src: src,
IfIndex: IfIndex,
}
return i.c.WriteTo(b, cm, dst)
}

type ipv6cmc struct {
c4 *ipv4.PacketConn // ipv4 entrypoint for sending ipv4 packages.
c6 *ipv6.PacketConn
}

func (i ipv6PacketConn) readFrom(b []byte) (n int, cm any, src net.Addr, err error) {
return i.c6.ReadFrom(b)
func newIpv6PacketConn(c4 *ipv4.PacketConn, c6 *ipv6.PacketConn) *ipv6cmc {
return &ipv6cmc{c4: c4, c6: c6}
}

func (i ipv6PacketConn) writeTo(b []byte, cm any, dst net.Addr) (n int, err error) {
cm6 := cm.(*ipv6.ControlMessage)
cm6.Src = cm6.Dst
cm6.Dst = nil

// If src is ipv4, use IP_PKTINFO instead of IPV6_PKTINFO.
// Otherwise, sendmsg will raise "invalid argument" error.
// No official doc found.
if src4 := cm6.Src.To4(); src4 != nil {
return i.c4.WriteTo(b, &ipv4.ControlMessage{
Src: src4,
IfIndex: cm6.IfIndex,
}, dst)
} else {
return i.c6.WriteTo(b, cm6, dst)
func (i *ipv6cmc) readFrom(b []byte) (n int, dst net.IP, IfIndex int, src net.Addr, err error) {
n, cm, src, err := i.c6.ReadFrom(b)
if cm != nil {
dst, IfIndex = cm.Dst, cm.IfIndex
}
return
}

func newUDPConn(c *net.UDPConn) (cmcUDPConn, error) {
p, err := getSocketIPProtocol(c)
if err != nil {
return nil, fmt.Errorf("failed to get socket ip protocol, %w", err)
}
switch p {
case v4:
c := ipv4.NewPacketConn(c)
if err := c.SetControlMessage(ipv4.FlagDst|ipv4.FlagInterface, true); err != nil {
return nil, fmt.Errorf("failed to set ipv4 cmsg flags, %w", err)
}
return ipv4PacketConn{c: c}, nil
case v6:
c6 := ipv6.NewPacketConn(c)
if err := c6.SetControlMessage(ipv6.FlagDst|ipv6.FlagInterface, true); err != nil {
return nil, fmt.Errorf("failed to set ipv6 cmsg flags, %w", err)
func (i *ipv6cmc) writeTo(b []byte, src net.IP, IfIndex int, dst net.Addr) (n int, err error) {
if src != nil {
// If src is ipv4, use IP_PKTINFO instead of IPV6_PKTINFO.
// Otherwise, sendmsg will raise "invalid argument" error.
// No official doc found.
if src4 := src.To4(); src4 != nil {
cm4 := &ipv4.ControlMessage{
Src: src4,
IfIndex: IfIndex,
}
return i.c4.WriteTo(b, cm4, dst)
}
return ipv6PacketConn{c6: c6, c4: ipv4.NewPacketConn(c)}, nil
default:
return nil, fmt.Errorf("unknow protocol %d", p)
}
cm6 := &ipv6.ControlMessage{
Src: src,
IfIndex: IfIndex,
}
return i.c6.WriteTo(b, cm6, dst)
}

func getSocketIPProtocol(c *net.UDPConn) (protocol, error) {
func newCmc(c *net.UDPConn) (cmcUDPConn, error) {
sc, err := c.SyscallConn()
if err != nil {
return 0, err
return nil, err
}
proto := invalid
var syscallErr error
if controlErr := sc.Control(func(fd uintptr) {

var controlErr error
var cmc cmcUDPConn

if err := sc.Control(func(fd uintptr) {
v, err := unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_DOMAIN)
if err != nil {
syscallErr = os.NewSyscallError("failed to get SO_PROTOCOL", err)
controlErr = os.NewSyscallError("failed to get SO_PROTOCOL", err)
return
}
switch v {
case unix.AF_INET:
proto = v4
c4 := ipv4.NewPacketConn(c)
if err := c4.SetControlMessage(ipv4.FlagDst|ipv4.FlagInterface, true); err != nil {
controlErr = fmt.Errorf("failed to set ipv4 cmsg flags, %w", err)
}
cmc = newIpv4cmc(c4)
return
case unix.AF_INET6:
proto = v6
c6 := ipv6.NewPacketConn(c)
if err := c6.SetControlMessage(ipv6.FlagDst|ipv6.FlagInterface, true); err != nil {
controlErr = fmt.Errorf("failed to set ipv6 cmsg flags, %w", err)
}
cmc = newIpv6PacketConn(ipv4.NewPacketConn(c), c6)
return
default:
syscallErr = fmt.Errorf("socket protocol %d is not supported", v)
controlErr = fmt.Errorf("socket protocol %d is not supported", v)
}
}); err != nil {
return 0, fmt.Errorf("control fd err, %w", controlErr)
return nil, fmt.Errorf("control fd err, %w", controlErr)
}

if controlErr != nil {
return nil, fmt.Errorf("failed to set up socket, %w", controlErr)
}
return proto, syscallErr
return cmc, nil
}
4 changes: 2 additions & 2 deletions pkg/server/udp_others.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,6 @@ package server

import "net"

func newUDPConn(c *net.UDPConn) (cmcUDPConn, error) {
return newDummyUDPConn(c), nil
func newCmc(c *net.UDPConn) (cmcUDPConn, error) {
return newDummyCmc(c), nil
}

0 comments on commit 760a660

Please sign in to comment.