From 61a13555983aac7cf93568f8a01e88dea8eecf7f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 8 Nov 2024 10:29:20 +0800 Subject: [PATCH] Re-implement lazy conns --- go.mod | 4 +- go.sum | 4 +- stack_gvisor.go | 33 ++------ stack_gvisor_lazy.go | 190 +++++++++++++++++++++++++++++++++++++++++++ stack_gvisor_udp.go | 28 ++++++- 5 files changed, 227 insertions(+), 32 deletions(-) create mode 100644 stack_gvisor_lazy.go diff --git a/go.mod b/go.mod index eadd50b..6797583 100644 --- a/go.mod +++ b/go.mod @@ -4,11 +4,12 @@ go 1.20 require ( github.com/go-ole/go-ole v1.3.0 + github.com/google/btree v1.1.3 github.com/sagernet/fswatch v0.1.1 github.com/sagernet/gvisor v0.0.0-20241021032506-a4324256e4a3 github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a github.com/sagernet/nftables v0.3.0-beta.4 - github.com/sagernet/sing v0.5.1-0.20241105104305-c80c8f907c56 + github.com/sagernet/sing v0.5.1-0.20241108022204-8fe04d136965 go4.org/netipx v0.0.0-20231129151722-fdeea329fbba golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 golang.org/x/net v0.26.0 @@ -17,7 +18,6 @@ require ( require ( github.com/fsnotify/fsnotify v1.7.0 // indirect - github.com/google/btree v1.1.3 // indirect github.com/google/go-cmp v0.6.0 // indirect github.com/josharian/native v1.1.0 // indirect github.com/mdlayher/netlink v1.7.2 // indirect diff --git a/go.sum b/go.sum index 0e0e2ca..97cd937 100644 --- a/go.sum +++ b/go.sum @@ -22,8 +22,8 @@ github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a h1:ObwtHN2VpqE0ZN github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a/go.mod h1:xLnfdiJbSp8rNqYEdIW/6eDO4mVoogml14Bh2hSiFpM= github.com/sagernet/nftables v0.3.0-beta.4 h1:kbULlAwAC3jvdGAC1P5Fa3GSxVwQJibNenDW2zaXr8I= github.com/sagernet/nftables v0.3.0-beta.4/go.mod h1:OQXAjvjNGGFxaTgVCSTRIhYB5/llyVDeapVoENYBDS8= -github.com/sagernet/sing v0.5.1-0.20241105104305-c80c8f907c56 h1:g+JCKxY8a+0L7GXjtS+t6uvJB3ibqKwyM/LJfFQM9/A= -github.com/sagernet/sing v0.5.1-0.20241105104305-c80c8f907c56/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= +github.com/sagernet/sing v0.5.1-0.20241108022204-8fe04d136965 h1:J7931AWKG7qOWyBO2reN6rFnszKWfYCtsJWVdDVksjc= +github.com/sagernet/sing v0.5.1-0.20241108022204-8fe04d136965/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= diff --git a/stack_gvisor.go b/stack_gvisor.go index 89983f1..69ff900 100644 --- a/stack_gvisor.go +++ b/stack_gvisor.go @@ -17,7 +17,6 @@ import ( "github.com/sagernet/gvisor/pkg/tcpip/transport/icmp" "github.com/sagernet/gvisor/pkg/tcpip/transport/tcp" "github.com/sagernet/gvisor/pkg/tcpip/transport/udp" - "github.com/sagernet/gvisor/pkg/waiter" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" @@ -81,32 +80,14 @@ func (t *GVisor) Start() error { r.Complete(pErr != ErrDrop) return } - var ( - wq waiter.Queue - endpoint tcpip.Endpoint - tErr tcpip.Error - ) - handshakeCtx, cancel := context.WithCancel(context.Background()) - go func() { - select { - case <-t.ctx.Done(): - wq.Notify(wq.Events()) - case <-handshakeCtx.Done(): - } - }() - endpoint, tErr = r.CreateEndpoint(&wq) - cancel() - if tErr != nil { - r.Complete(true) - return + conn := &gLazyConn{ + parentCtx: t.ctx, + stack: t.stack, + request: r, + localAddr: source.TCPAddr(), + remoteAddr: destination.TCPAddr(), } - r.Complete(false) - endpoint.SocketOptions().SetKeepAlive(true) - keepAliveIdle := tcpip.KeepaliveIdleOption(15 * time.Second) - endpoint.SetSockOpt(&keepAliveIdle) - keepAliveInterval := tcpip.KeepaliveIntervalOption(15 * time.Second) - endpoint.SetSockOpt(&keepAliveInterval) - go t.handler.NewConnectionEx(t.ctx, gonet.NewTCPConn(&wq, endpoint), source, destination, nil) + go t.handler.NewConnectionEx(t.ctx, conn, source, destination, nil) }) ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket) ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout).HandlePacket) diff --git a/stack_gvisor_lazy.go b/stack_gvisor_lazy.go new file mode 100644 index 0000000..8195844 --- /dev/null +++ b/stack_gvisor_lazy.go @@ -0,0 +1,190 @@ +//go:build with_gvisor + +package tun + +import ( + "context" + "net" + "time" + + "github.com/sagernet/gvisor/pkg/tcpip" + "github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet" + "github.com/sagernet/gvisor/pkg/tcpip/stack" + "github.com/sagernet/gvisor/pkg/tcpip/transport/tcp" + "github.com/sagernet/gvisor/pkg/waiter" + "github.com/sagernet/sing/common" +) + +type gLazyConn struct { + tcpConn *gonet.TCPConn + parentCtx context.Context + stack *stack.Stack + request *tcp.ForwarderRequest + localAddr net.Addr + remoteAddr net.Addr + handshakeDone bool + handshakeErr error +} + +func (c *gLazyConn) HandshakeContext(ctx context.Context) error { + if c.handshakeDone { + return nil + } + defer func() { + c.handshakeDone = true + }() + var ( + wq waiter.Queue + endpoint tcpip.Endpoint + ) + handshakeCtx, cancel := context.WithCancel(ctx) + go func() { + select { + case <-c.parentCtx.Done(): + wq.Notify(wq.Events()) + case <-handshakeCtx.Done(): + } + }() + endpoint, err := c.request.CreateEndpoint(&wq) + cancel() + if err != nil { + gErr := gonet.TranslateNetstackError(err) + c.handshakeErr = gErr + c.request.Complete(true) + return gErr + } + c.request.Complete(false) + endpoint.SocketOptions().SetKeepAlive(true) + endpoint.SetSockOpt(common.Ptr(tcpip.KeepaliveIdleOption(15 * time.Second))) + endpoint.SetSockOpt(common.Ptr(tcpip.KeepaliveIntervalOption(15 * time.Second))) + tcpConn := gonet.NewTCPConn(&wq, endpoint) + c.tcpConn = tcpConn + return nil +} + +func (c *gLazyConn) HandshakeFailure(err error) error { + if c.handshakeDone { + return nil + } + c.request.Complete(err != ErrDrop) + c.handshakeDone = true + c.handshakeErr = err + return nil +} + +func (c *gLazyConn) HandshakeSuccess() error { + return c.HandshakeContext(context.Background()) +} + +func (c *gLazyConn) Read(b []byte) (n int, err error) { + if !c.handshakeDone { + err = c.HandshakeContext(context.Background()) + if err != nil { + return + } + } else if c.handshakeErr != nil { + return 0, c.handshakeErr + } + return c.tcpConn.Read(b) +} + +func (c *gLazyConn) Write(b []byte) (n int, err error) { + if !c.handshakeDone { + err = c.HandshakeContext(context.Background()) + if err != nil { + return + } + } else if c.handshakeErr != nil { + return 0, c.handshakeErr + } + return c.tcpConn.Write(b) +} + +func (c *gLazyConn) LocalAddr() net.Addr { + return c.localAddr +} + +func (c *gLazyConn) RemoteAddr() net.Addr { + return c.remoteAddr +} + +func (c *gLazyConn) SetDeadline(t time.Time) error { + if !c.handshakeDone { + err := c.HandshakeContext(context.Background()) + if err != nil { + return err + } + } else if c.handshakeErr != nil { + return c.handshakeErr + } + return c.tcpConn.SetDeadline(t) +} + +func (c *gLazyConn) SetReadDeadline(t time.Time) error { + if !c.handshakeDone { + err := c.HandshakeContext(context.Background()) + if err != nil { + return err + } + } else if c.handshakeErr != nil { + return c.handshakeErr + } + return c.tcpConn.SetReadDeadline(t) +} + +func (c *gLazyConn) SetWriteDeadline(t time.Time) error { + if !c.handshakeDone { + err := c.HandshakeContext(context.Background()) + if err != nil { + return err + } + } else if c.handshakeErr != nil { + return c.handshakeErr + } + return c.tcpConn.SetWriteDeadline(t) +} + +func (c *gLazyConn) Close() error { + if !c.handshakeDone { + c.request.Complete(true) + c.handshakeErr = net.ErrClosed + return nil + } else if c.handshakeErr != nil { + return nil + } + return c.tcpConn.Close() +} + +func (c *gLazyConn) CloseRead() error { + if !c.handshakeDone { + c.request.Complete(true) + c.handshakeErr = net.ErrClosed + return nil + } else if c.handshakeErr != nil { + return nil + } + return c.tcpConn.CloseRead() +} + +func (c *gLazyConn) CloseWrite() error { + if !c.handshakeDone { + c.request.Complete(true) + c.handshakeErr = net.ErrClosed + return nil + } else if c.handshakeErr != nil { + return nil + } + return c.tcpConn.CloseRead() +} + +func (c *gLazyConn) ReaderReplaceable() bool { + return c.handshakeDone && c.handshakeErr == nil +} + +func (c *gLazyConn) WriterReplaceable() bool { + return c.handshakeDone && c.handshakeErr == nil +} + +func (c *gLazyConn) Upstream() any { + return c.tcpConn +} diff --git a/stack_gvisor_udp.go b/stack_gvisor_udp.go index 22e7e09..243e5b2 100644 --- a/stack_gvisor_udp.go +++ b/stack_gvisor_udp.go @@ -60,7 +60,7 @@ func (f *UDPForwarder) PreparePacketConnection(source M.Socksaddr, destination M pErr := f.handler.PrepareConnection(N.NetworkUDP, source, destination) if pErr != nil { if pErr != ErrDrop { - gWriteUnreachable(f.stack, userData.(*stack.PacketBuffer), pErr) + gWriteUnreachable(f.stack, userData.(*stack.PacketBuffer)) } return false, nil, nil, nil } @@ -72,6 +72,7 @@ func (f *UDPForwarder) PreparePacketConnection(source M.Socksaddr, destination M } writer := &UDPBackWriter{ stack: f.stack, + packet: userData.(*stack.PacketBuffer).IncRef(), source: AddressFromAddr(source.Addr), sourcePort: source.Port, sourceNetwork: sourceNetwork, @@ -82,11 +83,34 @@ func (f *UDPForwarder) PreparePacketConnection(source M.Socksaddr, destination M type UDPBackWriter struct { access sync.Mutex stack *stack.Stack + packet *stack.PacketBuffer source tcpip.Address sourcePort uint16 sourceNetwork tcpip.NetworkProtocolNumber } +func (w *UDPBackWriter) HandshakeSuccess() error { + w.access.Lock() + defer w.access.Unlock() + if w.packet != nil { + w.packet.DecRef() + w.packet = nil + } + return nil +} + +func (w *UDPBackWriter) HandshakeFailure(err error) error { + w.access.Lock() + defer w.access.Unlock() + if w.packet != nil { + wErr := gWriteUnreachable(w.stack, w.packet) + w.packet.DecRef() + w.packet = nil + return wErr + } + return nil +} + func (w *UDPBackWriter) WritePacket(packetBuffer *buf.Buffer, destination M.Socksaddr) error { if !destination.IsIP() { return E.Cause(os.ErrInvalid, "invalid destination") @@ -150,7 +174,7 @@ func (w *UDPBackWriter) WritePacket(packetBuffer *buf.Buffer, destination M.Sock return nil } -func gWriteUnreachable(gStack *stack.Stack, packet *stack.PacketBuffer, err error) error { +func gWriteUnreachable(gStack *stack.Stack, packet *stack.PacketBuffer) error { if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber { return gonet.TranslateNetstackError(gStack.NetworkProtocolInstance(header.IPv4ProtocolNumber).(stack.RejectIPv4WithHandler).SendRejectionError(packet, stack.RejectIPv4WithICMPPortUnreachable, true)) } else {