Skip to content

Commit

Permalink
Re-implement lazy conns
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Nov 8, 2024
1 parent 5a91eb9 commit 61a1355
Show file tree
Hide file tree
Showing 5 changed files with 227 additions and 32 deletions.
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ go 1.20

require (
github.com/go-ole/go-ole v1.3.0
github.com/google/btree v1.1.3
github.com/sagernet/fswatch v0.1.1
github.com/sagernet/gvisor v0.0.0-20241021032506-a4324256e4a3
github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a
github.com/sagernet/nftables v0.3.0-beta.4
github.com/sagernet/sing v0.5.1-0.20241105104305-c80c8f907c56
github.com/sagernet/sing v0.5.1-0.20241108022204-8fe04d136965
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba
golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8
golang.org/x/net v0.26.0
Expand All @@ -17,7 +18,6 @@ require (

require (
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/google/btree v1.1.3 // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/josharian/native v1.1.0 // indirect
github.com/mdlayher/netlink v1.7.2 // indirect
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a h1:ObwtHN2VpqE0ZN
github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a/go.mod h1:xLnfdiJbSp8rNqYEdIW/6eDO4mVoogml14Bh2hSiFpM=
github.com/sagernet/nftables v0.3.0-beta.4 h1:kbULlAwAC3jvdGAC1P5Fa3GSxVwQJibNenDW2zaXr8I=
github.com/sagernet/nftables v0.3.0-beta.4/go.mod h1:OQXAjvjNGGFxaTgVCSTRIhYB5/llyVDeapVoENYBDS8=
github.com/sagernet/sing v0.5.1-0.20241105104305-c80c8f907c56 h1:g+JCKxY8a+0L7GXjtS+t6uvJB3ibqKwyM/LJfFQM9/A=
github.com/sagernet/sing v0.5.1-0.20241105104305-c80c8f907c56/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
github.com/sagernet/sing v0.5.1-0.20241108022204-8fe04d136965 h1:J7931AWKG7qOWyBO2reN6rFnszKWfYCtsJWVdDVksjc=
github.com/sagernet/sing v0.5.1-0.20241108022204-8fe04d136965/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
Expand Down
33 changes: 7 additions & 26 deletions stack_gvisor.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import (
"github.com/sagernet/gvisor/pkg/tcpip/transport/icmp"
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
"github.com/sagernet/gvisor/pkg/waiter"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
Expand Down Expand Up @@ -81,32 +80,14 @@ func (t *GVisor) Start() error {
r.Complete(pErr != ErrDrop)
return
}
var (
wq waiter.Queue
endpoint tcpip.Endpoint
tErr tcpip.Error
)
handshakeCtx, cancel := context.WithCancel(context.Background())
go func() {
select {
case <-t.ctx.Done():
wq.Notify(wq.Events())
case <-handshakeCtx.Done():
}
}()
endpoint, tErr = r.CreateEndpoint(&wq)
cancel()
if tErr != nil {
r.Complete(true)
return
conn := &gLazyConn{
parentCtx: t.ctx,
stack: t.stack,
request: r,
localAddr: source.TCPAddr(),
remoteAddr: destination.TCPAddr(),
}
r.Complete(false)
endpoint.SocketOptions().SetKeepAlive(true)
keepAliveIdle := tcpip.KeepaliveIdleOption(15 * time.Second)
endpoint.SetSockOpt(&keepAliveIdle)
keepAliveInterval := tcpip.KeepaliveIntervalOption(15 * time.Second)
endpoint.SetSockOpt(&keepAliveInterval)
go t.handler.NewConnectionEx(t.ctx, gonet.NewTCPConn(&wq, endpoint), source, destination, nil)
go t.handler.NewConnectionEx(t.ctx, conn, source, destination, nil)
})
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout).HandlePacket)
Expand Down
190 changes: 190 additions & 0 deletions stack_gvisor_lazy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
//go:build with_gvisor

package tun

import (
"context"
"net"
"time"

"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
"github.com/sagernet/gvisor/pkg/waiter"
"github.com/sagernet/sing/common"
)

type gLazyConn struct {
tcpConn *gonet.TCPConn
parentCtx context.Context
stack *stack.Stack
request *tcp.ForwarderRequest
localAddr net.Addr
remoteAddr net.Addr
handshakeDone bool
handshakeErr error
}

func (c *gLazyConn) HandshakeContext(ctx context.Context) error {
if c.handshakeDone {
return nil
}
defer func() {
c.handshakeDone = true
}()
var (
wq waiter.Queue
endpoint tcpip.Endpoint
)
handshakeCtx, cancel := context.WithCancel(ctx)
go func() {
select {
case <-c.parentCtx.Done():
wq.Notify(wq.Events())
case <-handshakeCtx.Done():
}
}()
endpoint, err := c.request.CreateEndpoint(&wq)
cancel()
if err != nil {
gErr := gonet.TranslateNetstackError(err)
c.handshakeErr = gErr
c.request.Complete(true)
return gErr
}
c.request.Complete(false)
endpoint.SocketOptions().SetKeepAlive(true)
endpoint.SetSockOpt(common.Ptr(tcpip.KeepaliveIdleOption(15 * time.Second)))
endpoint.SetSockOpt(common.Ptr(tcpip.KeepaliveIntervalOption(15 * time.Second)))
tcpConn := gonet.NewTCPConn(&wq, endpoint)
c.tcpConn = tcpConn
return nil
}

func (c *gLazyConn) HandshakeFailure(err error) error {
if c.handshakeDone {
return nil
}
c.request.Complete(err != ErrDrop)
c.handshakeDone = true
c.handshakeErr = err
return nil
}

func (c *gLazyConn) HandshakeSuccess() error {
return c.HandshakeContext(context.Background())
}

func (c *gLazyConn) Read(b []byte) (n int, err error) {
if !c.handshakeDone {
err = c.HandshakeContext(context.Background())
if err != nil {
return
}
} else if c.handshakeErr != nil {
return 0, c.handshakeErr
}
return c.tcpConn.Read(b)
}

func (c *gLazyConn) Write(b []byte) (n int, err error) {
if !c.handshakeDone {
err = c.HandshakeContext(context.Background())
if err != nil {
return
}
} else if c.handshakeErr != nil {
return 0, c.handshakeErr
}
return c.tcpConn.Write(b)
}

func (c *gLazyConn) LocalAddr() net.Addr {
return c.localAddr
}

func (c *gLazyConn) RemoteAddr() net.Addr {
return c.remoteAddr
}

func (c *gLazyConn) SetDeadline(t time.Time) error {
if !c.handshakeDone {
err := c.HandshakeContext(context.Background())
if err != nil {
return err
}
} else if c.handshakeErr != nil {
return c.handshakeErr
}
return c.tcpConn.SetDeadline(t)
}

func (c *gLazyConn) SetReadDeadline(t time.Time) error {
if !c.handshakeDone {
err := c.HandshakeContext(context.Background())
if err != nil {
return err
}
} else if c.handshakeErr != nil {
return c.handshakeErr
}
return c.tcpConn.SetReadDeadline(t)
}

func (c *gLazyConn) SetWriteDeadline(t time.Time) error {
if !c.handshakeDone {
err := c.HandshakeContext(context.Background())
if err != nil {
return err
}
} else if c.handshakeErr != nil {
return c.handshakeErr
}
return c.tcpConn.SetWriteDeadline(t)
}

func (c *gLazyConn) Close() error {
if !c.handshakeDone {
c.request.Complete(true)
c.handshakeErr = net.ErrClosed
return nil
} else if c.handshakeErr != nil {
return nil
}
return c.tcpConn.Close()
}

func (c *gLazyConn) CloseRead() error {
if !c.handshakeDone {
c.request.Complete(true)
c.handshakeErr = net.ErrClosed
return nil
} else if c.handshakeErr != nil {
return nil
}
return c.tcpConn.CloseRead()
}

func (c *gLazyConn) CloseWrite() error {
if !c.handshakeDone {
c.request.Complete(true)
c.handshakeErr = net.ErrClosed
return nil
} else if c.handshakeErr != nil {
return nil
}
return c.tcpConn.CloseRead()
}

func (c *gLazyConn) ReaderReplaceable() bool {
return c.handshakeDone && c.handshakeErr == nil
}

func (c *gLazyConn) WriterReplaceable() bool {
return c.handshakeDone && c.handshakeErr == nil
}

func (c *gLazyConn) Upstream() any {
return c.tcpConn
}
28 changes: 26 additions & 2 deletions stack_gvisor_udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func (f *UDPForwarder) PreparePacketConnection(source M.Socksaddr, destination M
pErr := f.handler.PrepareConnection(N.NetworkUDP, source, destination)
if pErr != nil {
if pErr != ErrDrop {
gWriteUnreachable(f.stack, userData.(*stack.PacketBuffer), pErr)
gWriteUnreachable(f.stack, userData.(*stack.PacketBuffer))
}
return false, nil, nil, nil
}
Expand All @@ -72,6 +72,7 @@ func (f *UDPForwarder) PreparePacketConnection(source M.Socksaddr, destination M
}
writer := &UDPBackWriter{
stack: f.stack,
packet: userData.(*stack.PacketBuffer).IncRef(),
source: AddressFromAddr(source.Addr),
sourcePort: source.Port,
sourceNetwork: sourceNetwork,
Expand All @@ -82,11 +83,34 @@ func (f *UDPForwarder) PreparePacketConnection(source M.Socksaddr, destination M
type UDPBackWriter struct {
access sync.Mutex
stack *stack.Stack
packet *stack.PacketBuffer
source tcpip.Address
sourcePort uint16
sourceNetwork tcpip.NetworkProtocolNumber
}

func (w *UDPBackWriter) HandshakeSuccess() error {
w.access.Lock()
defer w.access.Unlock()
if w.packet != nil {
w.packet.DecRef()
w.packet = nil
}
return nil
}

func (w *UDPBackWriter) HandshakeFailure(err error) error {
w.access.Lock()
defer w.access.Unlock()
if w.packet != nil {
wErr := gWriteUnreachable(w.stack, w.packet)
w.packet.DecRef()
w.packet = nil
return wErr
}
return nil
}

func (w *UDPBackWriter) WritePacket(packetBuffer *buf.Buffer, destination M.Socksaddr) error {
if !destination.IsIP() {
return E.Cause(os.ErrInvalid, "invalid destination")
Expand Down Expand Up @@ -150,7 +174,7 @@ func (w *UDPBackWriter) WritePacket(packetBuffer *buf.Buffer, destination M.Sock
return nil
}

func gWriteUnreachable(gStack *stack.Stack, packet *stack.PacketBuffer, err error) error {
func gWriteUnreachable(gStack *stack.Stack, packet *stack.PacketBuffer) error {
if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber {
return gonet.TranslateNetstackError(gStack.NetworkProtocolInstance(header.IPv4ProtocolNumber).(stack.RejectIPv4WithHandler).SendRejectionError(packet, stack.RejectIPv4WithICMPPortUnreachable, true))
} else {
Expand Down

0 comments on commit 61a1355

Please sign in to comment.