diff --git a/go.mod b/go.mod index adaba41bda..ac901315ef 100644 --- a/go.mod +++ b/go.mod @@ -25,14 +25,14 @@ require ( github.com/sagernet/gvisor v0.0.0-20241123041152-536d05261cff github.com/sagernet/quic-go v0.48.1-beta.1 github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691 - github.com/sagernet/sing v0.6.0-alpha.20 + github.com/sagernet/sing v0.6.0-alpha.24 github.com/sagernet/sing-dns v0.4.0-alpha.3 github.com/sagernet/sing-mux v0.3.0-alpha.1 github.com/sagernet/sing-quic v0.4.0-alpha.4 github.com/sagernet/sing-shadowsocks v0.2.7 github.com/sagernet/sing-shadowsocks2 v0.2.0 github.com/sagernet/sing-shadowtls v0.2.0-alpha.2 - github.com/sagernet/sing-tun v0.6.0-alpha.14 + github.com/sagernet/sing-tun v0.6.0-alpha.16 github.com/sagernet/sing-vmess v0.2.0-beta.1 github.com/sagernet/smux v0.0.0-20231208180855-7041f6ea79e7 github.com/sagernet/utls v1.6.7 diff --git a/go.sum b/go.sum index b828c4fda1..53b96a775a 100644 --- a/go.sum +++ b/go.sum @@ -110,8 +110,8 @@ github.com/sagernet/quic-go v0.48.1-beta.1/go.mod h1:1WgdDIVD1Gybp40JTWketeSfKA/ github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691 h1:5Th31OC6yj8byLGkEnIYp6grlXfo1QYUfiYFGjewIdc= github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691/go.mod h1:B8lp4WkQ1PwNnrVMM6KyuFR20pU8jYBD+A4EhJovEXU= github.com/sagernet/sing v0.2.18/go.mod h1:OL6k2F0vHmEzXz2KW19qQzu172FDgSbUSODylighuVo= -github.com/sagernet/sing v0.6.0-alpha.20 h1:coxvnzeEGSLNNPntUW7l8WUEHPIwqKszZNbU019To9c= -github.com/sagernet/sing v0.6.0-alpha.20/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= +github.com/sagernet/sing v0.6.0-alpha.24 h1:qPc9i0mHADIFNYlWMg7fWWZZ0kBxWHEs8npsAG6KqAo= +github.com/sagernet/sing v0.6.0-alpha.24/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= github.com/sagernet/sing-dns v0.4.0-alpha.3 h1:TcAQdz68Gs28VD9o9zDIW7IS8A9LZDruTPI9g9JbGHA= github.com/sagernet/sing-dns v0.4.0-alpha.3/go.mod h1:9LHcYKg2bGQpbtXrfNbopz8ok/zBK9ljiI2kmFG9JKg= github.com/sagernet/sing-mux v0.3.0-alpha.1 h1:IgNX5bJBpL41gGbp05pdDOvh/b5eUQ6cv9240+Ngipg= @@ -124,10 +124,8 @@ github.com/sagernet/sing-shadowsocks2 v0.2.0 h1:wpZNs6wKnR7mh1wV9OHwOyUr21VkS3wK github.com/sagernet/sing-shadowsocks2 v0.2.0/go.mod h1:RnXS0lExcDAovvDeniJ4IKa2IuChrdipolPYWBv9hWQ= github.com/sagernet/sing-shadowtls v0.2.0-alpha.2 h1:RPrpgAdkP5td0vLfS5ldvYosFjSsZtRPxiyLV6jyKg0= github.com/sagernet/sing-shadowtls v0.2.0-alpha.2/go.mod h1:0j5XlzKxaWRIEjc1uiSKmVoWb0k+L9QgZVb876+thZA= -github.com/sagernet/sing-tun v0.6.0-alpha.14 h1:0nE66HdC6nBSOaUG0CEV5rwB5Te3Gts9buVOPvWrGT4= -github.com/sagernet/sing-tun v0.6.0-alpha.14/go.mod h1:xvZlEl1EGBbQeshv4UXmG7hA3f0ngFjpdCIYk308vfg= -github.com/sagernet/sing-vmess v0.1.13-0.20241123134803-8b806fd4b087 h1:p92kbwAIm5Is8V+fK6IB61AZs/nfWoyxxJeib2Dh2o0= -github.com/sagernet/sing-vmess v0.1.13-0.20241123134803-8b806fd4b087/go.mod h1:fLyE1emIcvQ5DV8reFWnufquZ7MkCSYM5ThodsR9NrQ= +github.com/sagernet/sing-tun v0.6.0-alpha.16 h1:VFB8VoM51ctLeDI3spzUaUcSY1La0T83lFzqEjIpK0M= +github.com/sagernet/sing-tun v0.6.0-alpha.16/go.mod h1:U9seS9Ic25rlhKSIL356h1QxWDnTdW+4nykNV95Eap8= github.com/sagernet/sing-vmess v0.2.0-beta.1 h1:5sXQ23uwNlZuDvygzi0dFtnG0Csm/SNqTjAHXJkpuj4= github.com/sagernet/sing-vmess v0.2.0-beta.1/go.mod h1:fLyE1emIcvQ5DV8reFWnufquZ7MkCSYM5ThodsR9NrQ= github.com/sagernet/smux v0.0.0-20231208180855-7041f6ea79e7 h1:DImB4lELfQhplLTxeq2z31Fpv8CQqqrUwTbrIRumZqQ= diff --git a/protocol/direct/inbound.go b/protocol/direct/inbound.go index 18d961b326..29b91f87dc 100644 --- a/protocol/direct/inbound.go +++ b/protocol/direct/inbound.go @@ -72,10 +72,15 @@ func (i *Inbound) Start(stage adapter.StartStage) error { if stage != adapter.StartStateStart { return nil } - return i.listener.Start() + err := i.listener.Start() + if err != nil { + return err + } + return i.udpNat.Start() } func (i *Inbound) Close() error { + i.udpNat.Close() return i.listener.Close() } diff --git a/protocol/redirect/tproxy.go b/protocol/redirect/tproxy.go index 23d441b7a7..02b7c0c9bf 100644 --- a/protocol/redirect/tproxy.go +++ b/protocol/redirect/tproxy.go @@ -85,10 +85,15 @@ func (t *TProxy) Start(stage adapter.StartStage) error { return E.Cause(err, "configure tproxy UDP listener") } } + err = t.udpNat.Start() + if err != nil { + return err + } return nil } func (t *TProxy) Close() error { + t.udpNat.Close() return t.listener.Close() } diff --git a/protocol/wireguard/endpoint.go b/protocol/wireguard/endpoint.go index dc40b61330..937f84dd9a 100644 --- a/protocol/wireguard/endpoint.go +++ b/protocol/wireguard/endpoint.go @@ -56,12 +56,18 @@ func NewEndpoint(ctx context.Context, router adapter.Router, logger log.ContextL if err != nil { return nil, err } + var udpTimeout time.Duration + if options.UDPTimeout != 0 { + udpTimeout = time.Duration(options.UDPTimeout) + } else { + udpTimeout = C.UDPTimeout + } wgEndpoint, err := wireguard.NewEndpoint(wireguard.EndpointOptions{ Context: ctx, Logger: logger, System: options.System, Handler: ep, - UDPTimeout: time.Duration(options.UDPTimeout), + UDPTimeout: udpTimeout, Dialer: outboundDialer, CreateDialer: func(interfaceName string) N.Dialer { return common.Must1(dialer.NewDefault(service.FromContext[adapter.NetworkManager](ctx), option.DialerOptions{ diff --git a/route/conn.go b/route/conn.go index 4a2192e0da..93ac33e359 100644 --- a/route/conn.go +++ b/route/conn.go @@ -5,6 +5,7 @@ import ( "io" "net" "net/netip" + "sync" "sync/atomic" "time" @@ -18,31 +19,35 @@ import ( "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/x/list" ) var _ adapter.ConnectionManager = (*ConnectionManager)(nil) type ConnectionManager struct { - logger logger.ContextLogger - monitor *ConnectionMonitor + logger logger.ContextLogger + access sync.Mutex + connections list.List[io.Closer] } func NewConnectionManager(logger logger.ContextLogger) *ConnectionManager { return &ConnectionManager{ - logger: logger, - monitor: NewConnectionMonitor(), + logger: logger, } } func (m *ConnectionManager) Start(stage adapter.StartStage) error { - if stage != adapter.StartStateInitialize { - return nil - } - return m.monitor.Start() + return nil } func (m *ConnectionManager) Close() error { - return m.monitor.Close() + m.access.Lock() + defer m.access.Unlock() + for element := m.connections.Front(); element != nil; element = element.Next() { + common.Close(element.Value) + } + m.connections.Init() + return nil } func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { @@ -57,95 +62,32 @@ func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, co remoteConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination) } if err != nil { + err = E.Cause(err, "open outbound connection") N.CloseOnHandshakeFailure(conn, onClose, err) - m.logger.ErrorContext(ctx, "open outbound connection: ", err) + m.logger.ErrorContext(ctx, err) return } err = N.ReportConnHandshakeSuccess(conn, remoteConn) if err != nil { + err = E.Cause(err, "report handshake success") remoteConn.Close() N.CloseOnHandshakeFailure(conn, onClose, err) - m.logger.ErrorContext(ctx, "report handshake success: ", err) + m.logger.ErrorContext(ctx, err) return } + m.access.Lock() + element := m.connections.PushBack(conn) + m.access.Unlock() + onClose = N.AppendClose(onClose, func(it error) { + m.access.Lock() + defer m.access.Unlock() + m.connections.Remove(element) + }) var done atomic.Bool - if ctx.Done() != nil { - onClose = N.AppendClose(onClose, m.monitor.Add(ctx, conn)) - } go m.connectionCopy(ctx, conn, remoteConn, false, &done, onClose) go m.connectionCopy(ctx, remoteConn, conn, true, &done, onClose) } -func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader, destination io.Writer, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) { - originSource := source - originDestination := destination - var readCounters, writeCounters []N.CountFunc - for { - source, readCounters = N.UnwrapCountReader(source, readCounters) - destination, writeCounters = N.UnwrapCountWriter(destination, writeCounters) - if cachedSrc, isCached := source.(N.CachedReader); isCached { - cachedBuffer := cachedSrc.ReadCached() - if cachedBuffer != nil { - dataLen := cachedBuffer.Len() - _, err := destination.Write(cachedBuffer.Bytes()) - cachedBuffer.Release() - if err != nil { - m.logger.ErrorContext(ctx, "connection upload payload: ", err) - if done.Swap(true) { - if onClose != nil { - onClose(err) - } - } - common.Close(originSource, originDestination) - return - } - for _, counter := range readCounters { - counter(int64(dataLen)) - } - for _, counter := range writeCounters { - counter(int64(dataLen)) - } - } - continue - } - break - } - _, err := bufio.CopyWithCounters(destination, source, originSource, readCounters, writeCounters) - if err != nil { - common.Close(originSource, originDestination) - } else if duplexDst, isDuplex := destination.(N.WriteCloser); isDuplex { - err = duplexDst.CloseWrite() - if err != nil { - common.Close(originSource, originDestination) - } - } else { - common.Close(originDestination) - } - if done.Swap(true) { - if onClose != nil { - onClose(err) - } - common.Close(originSource, originDestination) - } - if !direction { - if err == nil { - m.logger.DebugContext(ctx, "connection upload finished") - } else if !E.IsClosedOrCanceled(err) { - m.logger.ErrorContext(ctx, "connection upload closed: ", err) - } else { - m.logger.TraceContext(ctx, "connection upload closed") - } - } else { - if err == nil { - m.logger.DebugContext(ctx, "connection download finished") - } else if !E.IsClosedOrCanceled(err) { - m.logger.ErrorContext(ctx, "connection download closed: ", err) - } else { - m.logger.TraceContext(ctx, "connection download closed") - } - } -} - func (m *ConnectionManager) NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { ctx = adapter.WithContext(ctx, &metadata) var ( @@ -227,58 +169,91 @@ func (m *ConnectionManager) NewPacketConnection(ctx context.Context, this N.Dial ctx, conn = canceler.NewPacketConn(ctx, conn, udpTimeout) } destination := bufio.NewPacketConn(remotePacketConn) + m.access.Lock() + element := m.connections.PushBack(conn) + m.access.Unlock() + onClose = N.AppendClose(onClose, func(it error) { + m.access.Lock() + defer m.access.Unlock() + m.connections.Remove(element) + }) var done atomic.Bool - if ctx.Done() != nil { - onClose = N.AppendClose(onClose, m.monitor.Add(ctx, conn)) - } go m.packetConnectionCopy(ctx, conn, destination, false, &done, onClose) go m.packetConnectionCopy(ctx, destination, conn, true, &done, onClose) } -func (m *ConnectionManager) packetConnectionCopy(ctx context.Context, source N.PacketReader, destination N.PacketWriter, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) { - _, err := bufio.CopyPacket(destination, source) - /*var readCounters, writeCounters []N.CountFunc - var cachedPackets []*N.PacketBuffer +func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader, destination io.Writer, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) { originSource := source + originDestination := destination + var readCounters, writeCounters []N.CountFunc for { - source, readCounters = N.UnwrapCountPacketReader(source, readCounters) - destination, writeCounters = N.UnwrapCountPacketWriter(destination, writeCounters) - if cachedReader, isCached := source.(N.CachedPacketReader); isCached { - packet := cachedReader.ReadCachedPacket() - if packet != nil { - cachedPackets = append(cachedPackets, packet) - continue + source, readCounters = N.UnwrapCountReader(source, readCounters) + destination, writeCounters = N.UnwrapCountWriter(destination, writeCounters) + if cachedSrc, isCached := source.(N.CachedReader); isCached { + cachedBuffer := cachedSrc.ReadCached() + if cachedBuffer != nil { + dataLen := cachedBuffer.Len() + _, err := destination.Write(cachedBuffer.Bytes()) + cachedBuffer.Release() + if err != nil { + if done.Swap(true) { + onClose(err) + } + common.Close(originSource, originDestination) + if !direction { + m.logger.ErrorContext(ctx, "connection upload payload: ", err) + } else { + m.logger.ErrorContext(ctx, "connection download payload: ", err) + } + return + } + for _, counter := range readCounters { + counter(int64(dataLen)) + } + for _, counter := range writeCounters { + counter(int64(dataLen)) + } } + continue } break } - var handled bool - if natConn, isNatConn := source.(udpnat.Conn); isNatConn { - natConn.SetHandler(&udpHijacker{ - ctx: ctx, - logger: m.logger, - source: natConn, - destination: destination, - direction: direction, - readCounters: readCounters, - writeCounters: writeCounters, - done: done, - onClose: onClose, - }) - handled = true - } - if cachedPackets != nil { - _, err := bufio.WritePacketWithPool(originSource, destination, cachedPackets, readCounters, writeCounters) + _, err := bufio.CopyWithCounters(destination, source, originSource, readCounters, writeCounters) + if err != nil { + common.Close(originDestination) + } else if duplexDst, isDuplex := destination.(N.WriteCloser); isDuplex { + err = duplexDst.CloseWrite() if err != nil { - common.Close(source, destination) - m.logger.ErrorContext(ctx, "packet upload payload: ", err) - return + common.Close(originSource, originDestination) } + } else { + common.Close(originDestination) } - if handled { - return + if done.Swap(true) { + onClose(err) + common.Close(originSource, originDestination) + } + if !direction { + if err == nil { + m.logger.DebugContext(ctx, "connection upload finished") + } else if !E.IsClosedOrCanceled(err) { + m.logger.ErrorContext(ctx, "connection upload closed: ", err) + } else { + m.logger.TraceContext(ctx, "connection upload closed") + } + } else { + if err == nil { + m.logger.DebugContext(ctx, "connection download finished") + } else if !E.IsClosedOrCanceled(err) { + m.logger.ErrorContext(ctx, "connection download closed: ", err) + } else { + m.logger.TraceContext(ctx, "connection download closed") + } } - _, err := bufio.CopyPacketWithCounters(destination, source, originSource, readCounters, writeCounters)*/ +} + +func (m *ConnectionManager) packetConnectionCopy(ctx context.Context, source N.PacketReader, destination N.PacketWriter, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) { + _, err := bufio.CopyPacket(destination, source) if !direction { if E.IsClosedOrCanceled(err) { m.logger.TraceContext(ctx, "packet upload closed") @@ -293,58 +268,7 @@ func (m *ConnectionManager) packetConnectionCopy(ctx context.Context, source N.P } } if !done.Swap(true) { - if onClose != nil { - onClose(err) - } + onClose(err) } common.Close(source, destination) } - -/*type udpHijacker struct { - ctx context.Context - logger logger.ContextLogger - source io.Closer - destination N.PacketWriter - direction bool - readCounters []N.CountFunc - writeCounters []N.CountFunc - done *atomic.Bool - onClose N.CloseHandlerFunc -} - -func (u *udpHijacker) NewPacketEx(buffer *buf.Buffer, source M.Socksaddr) { - dataLen := buffer.Len() - for _, counter := range u.readCounters { - counter(int64(dataLen)) - } - err := u.destination.WritePacket(buffer, source) - if err != nil { - common.Close(u.source, u.destination) - u.logger.DebugContext(u.ctx, "packet upload closed: ", err) - return - } - for _, counter := range u.writeCounters { - counter(int64(dataLen)) - } -} - -func (u *udpHijacker) Close() error { - var err error - if !u.done.Swap(true) { - err = common.Close(u.source, u.destination) - if u.onClose != nil { - u.onClose(net.ErrClosed) - } - } - if u.direction { - u.logger.TraceContext(u.ctx, "packet download closed") - } else { - u.logger.TraceContext(u.ctx, "packet upload closed") - } - return err -} - -func (u *udpHijacker) Upstream() any { - return u.destination -} -*/ diff --git a/route/conn_monitor.go b/route/conn_monitor.go deleted file mode 100644 index 9e271b82a0..0000000000 --- a/route/conn_monitor.go +++ /dev/null @@ -1,128 +0,0 @@ -package route - -import ( - "context" - "io" - "reflect" - "sync" - "time" - - N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/common/x/list" -) - -type ConnectionMonitor struct { - access sync.RWMutex - reloadChan chan struct{} - connections list.List[*monitorEntry] -} - -type monitorEntry struct { - ctx context.Context - closer io.Closer -} - -func NewConnectionMonitor() *ConnectionMonitor { - return &ConnectionMonitor{ - reloadChan: make(chan struct{}, 1), - } -} - -func (m *ConnectionMonitor) Add(ctx context.Context, closer io.Closer) N.CloseHandlerFunc { - m.access.Lock() - defer m.access.Unlock() - element := m.connections.PushBack(&monitorEntry{ - ctx: ctx, - closer: closer, - }) - select { - case <-m.reloadChan: - return nil - default: - select { - case m.reloadChan <- struct{}{}: - default: - } - } - return func(it error) { - m.access.Lock() - defer m.access.Unlock() - m.connections.Remove(element) - select { - case <-m.reloadChan: - default: - select { - case m.reloadChan <- struct{}{}: - default: - } - } - } -} - -func (m *ConnectionMonitor) Start() error { - go m.monitor() - return nil -} - -func (m *ConnectionMonitor) Close() error { - m.access.Lock() - defer m.access.Unlock() - close(m.reloadChan) - for element := m.connections.Front(); element != nil; element = element.Next() { - element.Value.closer.Close() - } - return nil -} - -func (m *ConnectionMonitor) monitor() { - var ( - selectCases []reflect.SelectCase - elements []*list.Element[*monitorEntry] - ) - rootCase := reflect.SelectCase{ - Dir: reflect.SelectRecv, - Chan: reflect.ValueOf(m.reloadChan), - } - for { - m.access.RLock() - if m.connections.Len() == 0 { - m.access.RUnlock() - if _, loaded := <-m.reloadChan; !loaded { - return - } else { - continue - } - } - if len(elements) < m.connections.Len() { - elements = make([]*list.Element[*monitorEntry], 0, m.connections.Len()) - } - if len(selectCases) < m.connections.Len()+1 { - selectCases = make([]reflect.SelectCase, 0, m.connections.Len()+1) - } - elements = elements[:0] - selectCases = selectCases[:1] - selectCases[0] = rootCase - for element := m.connections.Front(); element != nil; element = element.Next() { - elements = append(elements, element) - selectCases = append(selectCases, reflect.SelectCase{ - Dir: reflect.SelectRecv, - Chan: reflect.ValueOf(element.Value.ctx.Done()), - }) - } - m.access.RUnlock() - selected, _, loaded := reflect.Select(selectCases) - if selected == 0 { - if !loaded { - return - } else { - time.Sleep(time.Second) - continue - } - } - element := elements[selected-1] - m.access.Lock() - m.connections.Remove(element) - m.access.Unlock() - element.Value.closer.Close() // maybe go close - } -} diff --git a/route/conn_monitor_test.go b/route/conn_monitor_test.go deleted file mode 100644 index a712bddcdb..0000000000 --- a/route/conn_monitor_test.go +++ /dev/null @@ -1,43 +0,0 @@ -package route_test - -import ( - "context" - "sync" - "testing" - "time" - - "github.com/sagernet/sing-box/route" - - "github.com/stretchr/testify/require" -) - -func TestMonitor(t *testing.T) { - t.Parallel() - var closer myCloser - closer.Add(1) - monitor := route.NewConnectionMonitor() - require.NoError(t, monitor.Start()) - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - monitor.Add(ctx, &closer) - done := make(chan struct{}) - go func() { - closer.Wait() - close(done) - }() - select { - case <-done: - case <-time.After(time.Second + 100*time.Millisecond): - t.Fatal("timeout") - } - cancel() - require.NoError(t, monitor.Close()) -} - -type myCloser struct { - sync.WaitGroup -} - -func (c *myCloser) Close() error { - c.Done() - return nil -}