diff --git a/stack_gvisor.go b/stack_gvisor.go index 523360b..6c5c27f 100644 --- a/stack_gvisor.go +++ b/stack_gvisor.go @@ -23,6 +23,7 @@ import ( E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" ) const WithGVisor = true @@ -79,7 +80,7 @@ func (t *GVisor) Start() error { tcpForwarder := tcp.NewForwarder(ipStack, 0, 1024, func(r *tcp.ForwarderRequest) { source := M.SocksaddrFrom(AddrFromAddress(r.ID().RemoteAddress), r.ID().RemotePort) destination := M.SocksaddrFrom(AddrFromAddress(r.ID().LocalAddress), r.ID().LocalPort) - pErr := t.handler.PrepareConnection(source, destination) + pErr := t.handler.PrepareConnection(N.NetworkTCP, source, destination) if pErr != nil { r.Complete(gWriteUnreachable(t.stack, r.Packet(), err) == os.ErrInvalid) return @@ -96,28 +97,21 @@ func (t *GVisor) Start() error { ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket) if !t.endpointIndependentNat { udpForwarder := udp.NewForwarder(ipStack, func(r *udp.ForwarderRequest) { + source := M.SocksaddrFrom(AddrFromAddress(r.ID().RemoteAddress), r.ID().RemotePort) + destination := M.SocksaddrFrom(AddrFromAddress(r.ID().LocalAddress), r.ID().LocalPort) + pErr := t.handler.PrepareConnection(N.NetworkUDP, source, destination) + if pErr != nil { + gWriteUnreachable(t.stack, r.Packet(), err) + r.Packet().DecRef() + return + } var wq waiter.Queue endpoint, err := r.CreateEndpoint(&wq) if err != nil { return } - udpConn := gonet.NewUDPConn(&wq, endpoint) - lAddr := udpConn.RemoteAddr() - rAddr := udpConn.LocalAddr() - if lAddr == nil || rAddr == nil { - endpoint.Abort() - return - } - source := M.SocksaddrFromNet(lAddr) - destination := M.SocksaddrFromNet(rAddr) - pErr := t.handler.PrepareConnection(source, destination) - if pErr != nil { - gWriteUnreachable(t.stack, r.Packet(), pErr) - r.Packet().DecRef() - return - } go func() { - ctx, conn := canceler.NewPacketConn(t.ctx, bufio.NewUnbindPacketConnWithAddr(udpConn, destination), time.Duration(t.udpTimeout)*time.Second) + ctx, conn := canceler.NewPacketConn(t.ctx, bufio.NewUnbindPacketConnWithAddr(gonet.NewUDPConn(&wq, endpoint), destination), t.udpTimeout) t.handler.NewPacketConnectionEx(ctx, conn, source, destination, nil) }() }) diff --git a/stack_gvisor_lazy.go b/stack_gvisor_lazy.go index 26f9244..16abdac 100644 --- a/stack_gvisor_lazy.go +++ b/stack_gvisor_lazy.go @@ -199,15 +199,15 @@ func gWriteUnreachable(gStack *stack.Stack, packet *stack.PacketBuffer, err erro return nil } else if errors.Is(err, syscall.ENETUNREACH) { if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber { - return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPNetProhibited) + return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPNetUnreachable) } else { return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPNoRoute) } } else if errors.Is(err, syscall.EHOSTUNREACH) { if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber { - return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPHostProhibited) + return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPHostUnreachable) } else { - return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPNoRoute) + return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPAddrUnreachable) } } else if errors.Is(err, syscall.ECONNREFUSED) { if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber { diff --git a/stack_gvisor_udp.go b/stack_gvisor_udp.go index 0b0fedc..dd0c8a0 100644 --- a/stack_gvisor_udp.go +++ b/stack_gvisor_udp.go @@ -57,7 +57,7 @@ func (f *UDPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.Pac func rangeIterate(r stack.Range, fn func(*buffer.View)) func (f *UDPForwarder) PreparePacketConnection(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc) { - pErr := f.handler.PrepareConnection(source, destination) + pErr := f.handler.PrepareConnection(N.NetworkUDP, source, destination) if pErr != nil { gWriteUnreachable(f.stack, userData.(*stack.PacketBuffer), pErr) return false, nil, nil, nil diff --git a/stack_mixed.go b/stack_mixed.go index d4e0607..8388cb9 100644 --- a/stack_mixed.go +++ b/stack_mixed.go @@ -3,8 +3,6 @@ package tun import ( - "time" - "github.com/sagernet/gvisor/pkg/buffer" "github.com/sagernet/gvisor/pkg/tcpip" "github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet" @@ -18,6 +16,7 @@ import ( "github.com/sagernet/sing/common/canceler" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" ) type Mixed struct { @@ -51,23 +50,22 @@ func (m *Mixed) Start() error { return err } if !m.endpointIndependentNat { - udpForwarder := udp.NewForwarder(ipStack, func(request *udp.ForwarderRequest) { - var wq waiter.Queue - endpoint, err := request.CreateEndpoint(&wq) - if err != nil { + udpForwarder := udp.NewForwarder(ipStack, func(r *udp.ForwarderRequest) { + source := M.SocksaddrFrom(AddrFromAddress(r.ID().RemoteAddress), r.ID().RemotePort) + destination := M.SocksaddrFrom(AddrFromAddress(r.ID().LocalAddress), r.ID().LocalPort) + pErr := m.handler.PrepareConnection(N.NetworkUDP, source, destination) + if pErr != nil { + gWriteUnreachable(m.stack, r.Packet(), err) + r.Packet().DecRef() return } - udpConn := gonet.NewUDPConn(&wq, endpoint) - lAddr := udpConn.RemoteAddr() - rAddr := udpConn.LocalAddr() - if lAddr == nil || rAddr == nil { - endpoint.Abort() + var wq waiter.Queue + endpoint, err := r.CreateEndpoint(&wq) + if err != nil { return } go func() { - source := M.SocksaddrFromNet(lAddr) - destination := M.SocksaddrFromNet(rAddr) - ctx, conn := canceler.NewPacketConn(m.ctx, bufio.NewUnbindPacketConnWithAddr(udpConn, destination), time.Duration(m.udpTimeout)*time.Second) + ctx, conn := canceler.NewPacketConn(m.ctx, bufio.NewUnbindPacketConnWithAddr(gonet.NewUDPConn(&wq, endpoint), destination), m.udpTimeout) m.handler.NewPacketConnectionEx(ctx, conn, source, destination, nil) }() }) @@ -229,7 +227,7 @@ func (m *Mixed) processIPv6(ipHdr header.IPv6) (writeBack bool, err error) { } switch ipHdr.TransportProtocol() { case header.TCPProtocolNumber: - err = m.processIPv6TCP(ipHdr, ipHdr.Payload()) + writeBack, err = m.processIPv6TCP(ipHdr, ipHdr.Payload()) case header.UDPProtocolNumber: writeBack = false pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ diff --git a/stack_system.go b/stack_system.go index 9c4382d..08f2ba3 100644 --- a/stack_system.go +++ b/stack_system.go @@ -2,6 +2,7 @@ package tun import ( "context" + "errors" "net" "net/netip" "syscall" @@ -258,10 +259,10 @@ func (s *System) processPacket(packet []byte) bool { writeBack bool err error ) - switch ipVersion := packet[0] >> 4; ipVersion { - case 4: + switch ipVersion := header.IPVersion(packet); ipVersion { + case header.IPv4Version: writeBack, err = s.processIPv4(packet) - case 6: + case header.IPv6Version: writeBack, err = s.processIPv6(packet) default: err = E.New("ip: unknown version: ", ipVersion) @@ -306,11 +307,11 @@ func (s *System) acceptLoop(listener net.Listener) { } func (s *System) processIPv4(ipHdr header.IPv4) (writeBack bool, err error) { - writeBack = true destination := ipHdr.DestinationAddr() if destination == s.broadcastAddr || !destination.IsGlobalUnicast() { return } + writeBack = true switch ipHdr.TransportProtocol() { case header.TCPProtocolNumber: writeBack, err = s.processIPv4TCP(ipHdr, ipHdr.Payload()) @@ -324,13 +325,13 @@ func (s *System) processIPv4(ipHdr header.IPv4) (writeBack bool, err error) { } func (s *System) processIPv6(ipHdr header.IPv6) (writeBack bool, err error) { - writeBack = true if !ipHdr.DestinationAddr().IsGlobalUnicast() { return } + writeBack = true switch ipHdr.TransportProtocol() { case header.TCPProtocolNumber: - err = s.processIPv6TCP(ipHdr, ipHdr.Payload()) + writeBack, err = s.processIPv6TCP(ipHdr, ipHdr.Payload()) case header.UDPProtocolNumber: err = s.processIPv6UDP(ipHdr, ipHdr.Payload()) case header.ICMPv6ProtocolNumber: @@ -343,7 +344,7 @@ func (s *System) processIPv4TCP(ipHdr header.IPv4, tcpHdr header.TCP) (bool, err source := netip.AddrPortFrom(ipHdr.SourceAddr(), tcpHdr.SourcePort()) destination := netip.AddrPortFrom(ipHdr.DestinationAddr(), tcpHdr.DestinationPort()) if !destination.Addr().IsGlobalUnicast() { - return true, nil + return false, nil } else if source.Addr() == s.inet4ServerAddress && source.Port() == s.tcpPort { session := s.tcpNat.LookupBack(destination.Port()) if session == nil { @@ -356,8 +357,17 @@ func (s *System) processIPv4TCP(ipHdr header.IPv4, tcpHdr header.TCP) (bool, err } else { natPort, err := s.tcpNat.Lookup(source, destination, s.handler) if err != nil { - // TODO: implement ICMP port unreachable - return false, nil + if errors.Is(err, ErrDrop) { + return false, nil + } else if errors.Is(err, syscall.ENETUNREACH) { + return false, s.rejectIPv4WithICMP(ipHdr, header.ICMPv4NetUnreachable) + } else if errors.Is(err, syscall.EHOSTUNREACH) { + return false, s.rejectIPv4WithICMP(ipHdr, header.ICMPv4HostUnreachable) + } else if errors.Is(err, syscall.ECONNREFUSED) { + return false, s.rejectIPv4WithICMP(ipHdr, header.ICMPv4PortUnreachable) + } else { + return false, s.resetIPv4TCP(ipHdr, tcpHdr) + } } ipHdr.SetSourceAddr(s.inet4Address) tcpHdr.SetSourcePort(natPort) @@ -377,33 +387,84 @@ func (s *System) processIPv4TCP(ipHdr header.IPv4, tcpHdr header.TCP) (bool, err return true, nil } -func (s *System) resetIPv4TCP(packet header.IPv4, header header.TCP) error { - return nil +func (s *System) resetIPv4TCP(origIPHdr header.IPv4, origTCPHdr header.TCP) error { + frontHeadroom := s.frontHeadroom + PacketOffset + newPacket := buf.NewSize(frontHeadroom + header.IPv4MinimumSize + header.TCPMinimumSize) + defer newPacket.Release() + newPacket.Resize(frontHeadroom, header.IPv4MinimumSize+header.TCPMinimumSize) + ipHdr := header.IPv4(newPacket.Bytes()) + ipHdr.Encode(&header.IPv4Fields{ + TotalLength: uint16(newPacket.Len()), + Protocol: uint8(header.TCPProtocolNumber), + SrcAddr: origIPHdr.DestinationAddr(), + DstAddr: origIPHdr.SourceAddr(), + }) + tcpHdr := header.TCP(ipHdr.Payload()) + fields := header.TCPFields{ + SrcPort: origTCPHdr.DestinationPort(), + DstPort: origTCPHdr.SourcePort(), + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagRst, + } + if origTCPHdr.Flags()&header.TCPFlagAck != 0 { + fields.SeqNum = origTCPHdr.AckNumber() + } else { + fields.Flags |= header.TCPFlagAck + ackNum := origTCPHdr.SequenceNumber() + uint32(len(origTCPHdr.Payload())) + if origTCPHdr.Flags()&header.TCPFlagSyn != 0 { + ackNum++ + } + if origTCPHdr.Flags()&header.TCPFlagFin != 0 { + ackNum++ + } + fields.AckNum = ackNum + } + tcpHdr.Encode(&fields) + if !s.txChecksumOffload { + tcpHdr.SetChecksum(^tcpHdr.CalculateChecksum(header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), header.TCPMinimumSize))) + } + ipHdr.SetChecksum(0) + ipHdr.SetChecksum(^ipHdr.CalculateChecksum()) + if PacketOffset > 0 { + newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET + } else { + newPacket.Advance(-s.frontHeadroom) + } + return common.Error(s.tun.Write(newPacket.Bytes())) } -func (s *System) processIPv6TCP(ipHdr header.IPv6, tcpHdr header.TCP) error { +func (s *System) processIPv6TCP(ipHdr header.IPv6, tcpHdr header.TCP) (bool, error) { source := netip.AddrPortFrom(ipHdr.SourceAddr(), tcpHdr.SourcePort()) destination := netip.AddrPortFrom(ipHdr.DestinationAddr(), tcpHdr.DestinationPort()) if !destination.Addr().IsGlobalUnicast() { - return nil + return false, nil } else if source.Addr() == s.inet6ServerAddress && source.Port() == s.tcpPort6 { session := s.tcpNat.LookupBack(destination.Port()) if session == nil { - return E.New("ipv6: tcp: session not found: ", destination.Port()) + return false, E.New("ipv6: tcp: session not found: ", destination.Port()) } ipHdr.SetSourceAddr(session.Destination.Addr()) tcpHdr.SetSourcePort(session.Destination.Port()) - ipHdr.SetSourceAddr(session.Source.Addr()) + ipHdr.SetDestinationAddr(session.Source.Addr()) tcpHdr.SetDestinationPort(session.Source.Port()) } else { natPort, err := s.tcpNat.Lookup(source, destination, s.handler) if err != nil { - // TODO: implement ICMP port unreachable - return nil + if errors.Is(err, ErrDrop) { + return false, nil + } else if errors.Is(err, syscall.ENETUNREACH) { + return false, s.rejectIPv6WithICMP(ipHdr, header.ICMPv6NetworkUnreachable) + } else if errors.Is(err, syscall.EHOSTUNREACH) { + return false, s.rejectIPv6WithICMP(ipHdr, header.ICMPv6AddressUnreachable) + } else if errors.Is(err, syscall.ECONNREFUSED) { + return false, s.rejectIPv6WithICMP(ipHdr, header.ICMPv6PortUnreachable) + } else { + return false, s.resetIPv6TCP(ipHdr, tcpHdr) + } } ipHdr.SetSourceAddr(s.inet6Address) tcpHdr.SetSourcePort(natPort) - ipHdr.SetSourceAddr(s.inet6ServerAddress) + ipHdr.SetDestinationAddr(s.inet6ServerAddress) tcpHdr.SetDestinationPort(s.tcpPort6) } if !s.txChecksumOffload { @@ -414,7 +475,51 @@ func (s *System) processIPv6TCP(ipHdr header.IPv6, tcpHdr header.TCP) error { } else { tcpHdr.SetChecksum(0) } - return nil + return true, nil +} + +func (s *System) resetIPv6TCP(origIPHdr header.IPv6, origTCPHdr header.TCP) error { + frontHeadroom := s.frontHeadroom + PacketOffset + newPacket := buf.NewSize(frontHeadroom + header.IPv6MinimumSize + header.TCPMinimumSize) + defer newPacket.Release() + newPacket.Resize(frontHeadroom, header.IPv6MinimumSize+header.TCPMinimumSize) + ipHdr := header.IPv6(newPacket.Bytes()) + ipHdr.Encode(&header.IPv6Fields{ + PayloadLength: uint16(header.TCPMinimumSize), + TransportProtocol: header.TCPProtocolNumber, + SrcAddr: origIPHdr.DestinationAddr(), + DstAddr: origIPHdr.SourceAddr(), + }) + tcpHdr := header.TCP(ipHdr.Payload()) + fields := header.TCPFields{ + SrcPort: origTCPHdr.DestinationPort(), + DstPort: origTCPHdr.SourcePort(), + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagRst, + } + if origTCPHdr.Flags()&header.TCPFlagAck != 0 { + fields.SeqNum = origTCPHdr.AckNumber() + } else { + fields.Flags |= header.TCPFlagAck + ackNum := origTCPHdr.SequenceNumber() + uint32(len(origTCPHdr.Payload())) + if origTCPHdr.Flags()&header.TCPFlagSyn != 0 { + ackNum++ + } + if origTCPHdr.Flags()&header.TCPFlagFin != 0 { + ackNum++ + } + fields.AckNum = ackNum + } + tcpHdr.Encode(&fields) + if !s.txChecksumOffload { + tcpHdr.SetChecksum(^tcpHdr.CalculateChecksum(header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), header.TCPMinimumSize))) + } + if PacketOffset > 0 { + newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET6 + } else { + newPacket.Advance(-s.frontHeadroom) + } + return common.Error(s.tun.Write(newPacket.Bytes())) } func (s *System) processIPv4UDP(ipHdr header.IPv4, udpHdr header.UDP) error { @@ -444,9 +549,28 @@ func (s *System) processIPv6UDP(ipHdr header.IPv6, udpHdr header.UDP) error { } func (s *System) preparePacketConnection(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc) { - pErr := s.handler.PrepareConnection(source, destination) + pErr := s.handler.PrepareConnection(N.NetworkUDP, source, destination) if pErr != nil { - // TODO: implement ICMP port unreachable + if errors.Is(pErr, ErrDrop) { + } else if source.IsIPv4() { + ipHdr := userData.(header.IPv4) + if errors.Is(pErr, syscall.ENETUNREACH) { + s.rejectIPv4WithICMP(ipHdr, header.ICMPv4NetUnreachable) + } else if errors.Is(pErr, syscall.EHOSTUNREACH) { + s.rejectIPv4WithICMP(ipHdr, header.ICMPv4HostUnreachable) + } else { + s.rejectIPv4WithICMP(ipHdr, header.ICMPv4PortUnreachable) + } + } else { + ipHdr := userData.(header.IPv6) + if errors.Is(pErr, syscall.ENETUNREACH) { + s.rejectIPv6WithICMP(ipHdr, header.ICMPv6NetworkUnreachable) + } else if errors.Is(pErr, syscall.EHOSTUNREACH) { + s.rejectIPv6WithICMP(ipHdr, header.ICMPv6AddressUnreachable) + } else { + s.rejectIPv6WithICMP(ipHdr, header.ICMPv6PortUnreachable) + } + } return false, nil, nil, nil } var writer N.PacketWriter @@ -492,6 +616,45 @@ func (s *System) processIPv4ICMP(ipHdr header.IPv4, icmpHdr header.ICMPv4) error return nil } +func (s *System) rejectIPv4WithICMP(ipHdr header.IPv4, code header.ICMPv4Code) error { + frontHeadroom := s.frontHeadroom + PacketOffset + mtu := s.mtu + const maxIPData = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize + if mtu > maxIPData { + mtu = maxIPData + } + available := mtu - header.ICMPv4MinimumSize + if available < len(ipHdr)+header.ICMPv4MinimumErrorPayloadSize { + return nil + } + payload := ipHdr + if len(payload) > available { + payload = payload[:available] + } + newPacket := buf.NewSize(frontHeadroom + header.IPv4MinimumSize + header.ICMPv4MinimumSize + len(payload)) + defer newPacket.Release() + newPacket.Resize(frontHeadroom, header.IPv4MinimumSize+header.ICMPv4MinimumSize+len(payload)) + newIPHdr := header.IPv4(newPacket.Bytes()) + newIPHdr.Encode(&header.IPv4Fields{ + TotalLength: uint16(newPacket.Len()), + Protocol: uint8(header.ICMPv4ProtocolNumber), + SrcAddr: ipHdr.DestinationAddr(), + DstAddr: ipHdr.SourceAddr(), + }) + newIPHdr.SetChecksum(^newIPHdr.CalculateChecksum()) + icmpHdr := header.ICMPv4(newIPHdr.Payload()) + icmpHdr.SetType(header.ICMPv4DstUnreachable) + icmpHdr.SetCode(code) + icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(payload, 0))) + copy(icmpHdr.Payload(), payload) + if PacketOffset > 0 { + newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET + } else { + newPacket.Advance(-s.frontHeadroom) + } + return common.Error(s.tun.Write(newPacket.Bytes())) +} + func (s *System) processIPv6ICMP(ipHdr header.IPv6, icmpHdr header.ICMPv6) error { if icmpHdr.Type() != header.ICMPv6EchoRequest || icmpHdr.Code() != 0 { return nil @@ -508,56 +671,49 @@ func (s *System) processIPv6ICMP(ipHdr header.IPv6, icmpHdr header.ICMPv6) error return nil } -/*func (s *System) WritePacket4(buffer *buf.Buffer, source netip.AddrPort, destination netip.AddrPort) error { - packet := buf.Get(header.IPv4MinimumSize + header.UDPMinimumSize + buffer.Len()) - ipHdr := header.IPv4(packet) - ipHdr.Encode(&header.IPv4Fields{ - TotalLength: uint16(len(packet)), - Protocol: uint8(header.UDPProtocolNumber), - SrcAddr: source.Addr(), - DstAddr: destination.Addr(), - }) - ipHdr.SetHeaderLength(header.IPv4MinimumSize) - udpHdr := header.UDP(ipHdr.Payload()) - udpHdr.Encode(&header.UDPFields{ - SrcPort: source.Port(), - DstPort: destination.Port(), - Length: uint16(header.UDPMinimumSize + buffer.Len()), - }) - copy(udpHdr.Payload(), buffer.Bytes()) - if !s.txChecksumOffload { - ... - } else { - udpHdr.SetChecksum(0) +func (s *System) rejectIPv6WithICMP(ipHdr header.IPv6, code header.ICMPv6Code) error { + frontHeadroom := s.frontHeadroom + PacketOffset + mtu := s.mtu + const maxIPv6Data = header.IPv6MinimumMTU - header.IPv6FixedHeaderSize + if mtu > maxIPv6Data { + mtu = maxIPv6Data + } + available := mtu - header.ICMPv6ErrorHeaderSize + if available < header.IPv6MinimumSize { + return nil } - ipHdr.SetChecksum(0) - ipHdr.SetChecksum(^ipHdr.CalculateChecksum()) - return common.Error(s.tun.Write(packet)) -} - -func (s *System) WritePacket6(buffer *buf.Buffer, source netip.AddrPort, destination netip.AddrPort) error { - packet := buf.Get(header.IPv6MinimumSize + header.UDPMinimumSize + buffer.Len()) - ipHdr := header.IPv6(packet) - ipHdr.Encode(&header.IPv6Fields{ - PayloadLength: uint16(header.UDPMinimumSize + buffer.Len()), - TransportProtocol: header.UDPProtocolNumber, - SrcAddr: source.Addr(), - DstAddr: destination.Addr(), - }) - udpHdr := header.UDP(ipHdr.Payload()) - udpHdr.Encode(&header.UDPFields{ - SrcPort: source.Port(), - DstPort: destination.Port(), - Length: uint16(header.UDPMinimumSize + buffer.Len()), + payload := ipHdr + if len(payload) > available { + payload = payload[:available] + } + newPacket := buf.NewSize(frontHeadroom + header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + len(payload)) + defer newPacket.Release() + newPacket.Resize(frontHeadroom, header.IPv6MinimumSize+header.ICMPv6DstUnreachableMinimumSize+len(payload)) + newIPHdr := header.IPv6(newPacket.Bytes()) + newIPHdr.Encode(&header.IPv6Fields{ + PayloadLength: uint16(header.ICMPv6DstUnreachableMinimumSize + len(payload)), + TransportProtocol: header.ICMPv6ProtocolNumber, + SrcAddr: ipHdr.DestinationAddr(), + DstAddr: ipHdr.SourceAddr(), }) - copy(udpHdr.Payload(), buffer.Bytes()) - if !s.txChecksumOffload { - ... + icmpHdr := header.ICMPv6(newIPHdr.Payload()) + icmpHdr.SetType(header.ICMPv6DstUnreachable) + icmpHdr.SetCode(code) + icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmpHdr[:header.ICMPv6DstUnreachableMinimumSize], + Src: newIPHdr.SourceAddress(), + Dst: newIPHdr.DestinationAddress(), + PayloadCsum: checksum.Checksum(payload, 0), + PayloadLen: len(payload), + })) + copy(icmpHdr.Payload(), payload) + if PacketOffset > 0 { + newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET6 } else { - udpHdr.SetChecksum(0) + newPacket.Advance(-s.frontHeadroom) } - return common.Error(s.tun.Write(packet)) -}*/ + return common.Error(s.tun.Write(newPacket.Bytes())) +} type systemUDPPacketWriter4 struct { tun Tun diff --git a/stack_system_nat.go b/stack_system_nat.go index 6e7e7ef..1d0216e 100644 --- a/stack_system_nat.go +++ b/stack_system_nat.go @@ -7,6 +7,7 @@ import ( "time" M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" ) type TCPNat struct { @@ -77,7 +78,7 @@ func (n *TCPNat) Lookup(source netip.AddrPort, destination netip.AddrPort, handl if loaded { return port, nil } - pErr := handler.PrepareConnection(M.SocksaddrFromNetIP(source), M.SocksaddrFromNetIP(destination)) + pErr := handler.PrepareConnection(N.NetworkTCP, M.SocksaddrFromNetIP(source), M.SocksaddrFromNetIP(destination)) if pErr != nil { return 0, pErr } diff --git a/tun.go b/tun.go index 68ba7c1..ad6672d 100644 --- a/tun.go +++ b/tun.go @@ -16,7 +16,7 @@ import ( ) type Handler interface { - PrepareConnection(source M.Socksaddr, destination M.Socksaddr) error + PrepareConnection(network string, source M.Socksaddr, destination M.Socksaddr) error N.TCPConnectionHandlerEx N.UDPConnectionHandlerEx }