From 79d7e0ef450e8cf4b9c50a7ce8277c716af53374 Mon Sep 17 00:00:00 2001 From: lysShub Date: Mon, 12 Aug 2024 03:27:23 +0800 Subject: [PATCH] fix async Recv always return n==0 at the first time (#8) * fix async Recv always return n==0 at the first time * fix Close * fix Close 2 * optimize test * fix Close 3 --- handle.go | 22 ++--- handle_test.go | 230 ++++++++++++++++++++++++++++++++++++++----------- 2 files changed, 186 insertions(+), 66 deletions(-) diff --git a/handle.go b/handle.go index a060e79..7de943b 100644 --- a/handle.go +++ b/handle.go @@ -26,31 +26,27 @@ func (d *Handle) Close() error { if fd != invalid { r1, _, e := syscall.SyscallN( procClose.Addr(), - d.handle.Load(), + fd, ) if r1 == 0 { return handleError(e) + } else { + return nil } } - return nil + return ErrClosed{} } func (d *Handle) Priority() int16 { return d.priority } // Recv recv ip packet, probable return 0. func (d *Handle) Recv(ip []byte, addr *Address) (int, error) { var recvLen uint32 - var dataPtr, recvLenPtr uintptr - if len(ip) > 0 { - dataPtr = uintptr(unsafe.Pointer(unsafe.SliceData(ip))) - recvLenPtr = uintptr(unsafe.Pointer(&recvLen)) - } - r1, _, e := syscall.SyscallN( procRecv.Addr(), d.handle.Load(), - dataPtr, + uintptr(unsafe.Pointer(unsafe.SliceData(ip))), uintptr(len(ip)), - recvLenPtr, + uintptr(unsafe.Pointer(&recvLen)), uintptr(unsafe.Pointer(addr)), ) if r1 == 0 { @@ -64,15 +60,11 @@ func (d *Handle) Recv(ip []byte, addr *Address) (int, error) { // notice: recvLen not work, use windows.GetOverlappedResult func (d *Handle) RecvEx(ip []byte, addr *Address, recvLen *uint32, ol *windows.Overlapped) error { // todo: support batch recv - var ipPtr uintptr - if len(ip) > 0 { - ipPtr = uintptr(unsafe.Pointer(unsafe.SliceData(ip))) - } r1, _, e := syscall.SyscallN( procRecvEx.Addr(), d.handle.Load(), - ipPtr, // pPacket + uintptr(unsafe.Pointer(unsafe.SliceData(ip))), // pPacket uintptr(len(ip)), // packetLen uintptr(unsafe.Pointer(recvLen)), // pRecvLen NOTICE: not work uintptr(0), // flags 0 diff --git a/handle_test.go b/handle_test.go index 4d86e46..4d3649e 100644 --- a/handle_test.go +++ b/handle_test.go @@ -17,6 +17,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" + "golang.org/x/sys/windows" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -142,31 +143,39 @@ func Test_Recv_Error(t *testing.T) { require.NoError(t, err) defer d.Close() - { - go func() { - time.Sleep(time.Second) - require.NoError(t, d.Close()) - }() + eg, _ := errgroup.WithContext(context.Background()) + eg.Go(func() error { + time.Sleep(time.Second) + require.NoError(t, d.Close()) + return nil + }) + eg.Go(func() error { _, err = d.Recv(make([]byte, 1536), nil) require.True(t, errors.Is(err, ErrClosed{}), err) - } + return nil + }) + eg.Wait() }) t.Run("recv/close/close", func(t *testing.T) { - d, err := Open("false", Network, 0, 0) + d, err := Open("false", Network, 0, RecvOnly|Sniff) require.NoError(t, err) defer d.Close() - { - go func() { - time.Sleep(time.Second) - require.NoError(t, d.Close()) - }() + eg, _ := errgroup.WithContext(context.Background()) + eg.Go(func() error { + time.Sleep(time.Second) + require.NoError(t, d.Close()) + return nil + }) + eg.Go(func() error { _, err = d.Recv(make([]byte, 1536), nil) require.True(t, errors.Is(err, ErrClosed{}), err) require.True(t, errors.Is(d.Close(), ErrClosed{}), err) - } + return nil + }) + eg.Wait() }) t.Run("recv/close/close/recv", func(t *testing.T) { @@ -174,11 +183,13 @@ func Test_Recv_Error(t *testing.T) { require.NoError(t, err) defer d.Close() - { - go func() { - time.Sleep(time.Second) - require.NoError(t, d.Close()) - }() + eg, _ := errgroup.WithContext(context.Background()) + eg.Go(func() error { + time.Sleep(time.Second) + require.NoError(t, d.Close()) + return nil + }) + eg.Go(func() error { { _, err = d.Recv(make([]byte, 1536), nil) require.True(t, errors.Is(err, ErrClosed{}), err) @@ -190,7 +201,9 @@ func Test_Recv_Error(t *testing.T) { _, err = d.Recv(make([]byte, 1536), nil) require.True(t, errors.Is(err, ErrClosed{}), err) } - } + return nil + }) + eg.Wait() }) t.Run("shutdown/recv", func(t *testing.T) { @@ -209,14 +222,19 @@ func Test_Recv_Error(t *testing.T) { require.NoError(t, err) defer d.Close() - go func() { + eg, _ := errgroup.WithContext(context.Background()) + eg.Go(func() error { time.Sleep(time.Second) require.NoError(t, d.Shutdown(Both)) - }() - - n, err := d.Recv(make([]byte, 1536), nil) - require.True(t, errors.Is(err, ErrShutdown{}), err) - require.Zero(t, n) + return nil + }) + eg.Go(func() error { + n, err := d.Recv(make([]byte, 1536), nil) + require.True(t, errors.Is(err, ErrShutdown{}), err) + require.Zero(t, n) + return nil + }) + eg.Wait() }) t.Run("recv/shutdown/shutdown", func(t *testing.T) { @@ -224,16 +242,21 @@ func Test_Recv_Error(t *testing.T) { require.NoError(t, err) defer d.Close() - go func() { + eg, _ := errgroup.WithContext(context.Background()) + eg.Go(func() error { time.Sleep(time.Second) require.NoError(t, d.Shutdown(Both)) - }() - - n, err := d.Recv(make([]byte, 1536), nil) - require.True(t, errors.Is(err, ErrShutdown{}), err) - require.Zero(t, n) + return nil + }) + eg.Go(func() error { + n, err := d.Recv(make([]byte, 1536), nil) + require.True(t, errors.Is(err, ErrShutdown{}), err) + require.Zero(t, n) - require.NoError(t, d.Shutdown(Both)) + require.NoError(t, d.Shutdown(Both)) + return nil + }) + eg.Wait() }) t.Run("recv/shutdown/shutdown/recv", func(t *testing.T) { @@ -241,24 +264,29 @@ func Test_Recv_Error(t *testing.T) { require.NoError(t, err) defer d.Close() - go func() { + eg, _ := errgroup.WithContext(context.Background()) + eg.Go(func() error { time.Sleep(time.Second) require.NoError(t, d.Shutdown(Both)) - }() - - { - n, err := d.Recv(make([]byte, 1536), nil) - require.True(t, errors.Is(err, ErrShutdown{}), err) - require.Zero(t, n) - } - { - require.NoError(t, d.Shutdown(Both)) - } - { - n, err := d.Recv(make([]byte, 1536), nil) - require.True(t, errors.Is(err, ErrShutdown{}), err) - require.Zero(t, n) - } + return nil + }) + eg.Go(func() error { + { + n, err := d.Recv(make([]byte, 1536), nil) + require.True(t, errors.Is(err, ErrShutdown{}), err) + require.Zero(t, n) + } + { + require.NoError(t, d.Shutdown(Both)) + } + { + n, err := d.Recv(make([]byte, 1536), nil) + require.True(t, errors.Is(err, ErrShutdown{}), err) + require.Zero(t, n) + } + return nil + }) + eg.Wait() }) t.Run("close/recv", func(t *testing.T) { @@ -278,7 +306,7 @@ func Test_Recv_Error(t *testing.T) { func Test_Recv(t *testing.T) { MustLoad(DLL) - t.Run("Recv/network/loopback", func(t *testing.T) { + t.Run("network/loopback", func(t *testing.T) { var ( saddr = netip.AddrPortFrom(locIP, randPort()) caddr = netip.AddrPortFrom(locIP, randPort()) @@ -325,7 +353,108 @@ func Test_Recv(t *testing.T) { } }) - t.Run("Recv/socket", func(t *testing.T) { + t.Run("network/async-recv", func(t *testing.T) { + /* + old version will fail: + + func (d *Handle) Recv(ip []byte, addr *Address) (int, error) { + var recvLen uint32 + var dataPtr, recvLenPtr uintptr + if len(ip) > 0 { + dataPtr = uintptr(unsafe.Pointer(unsafe.SliceData(ip))) + recvLenPtr = uintptr(unsafe.Pointer(&recvLen)) + } + + r1, _, e := syscall.SyscallN( + procRecv.Addr(), + d.handle.Load(), + dataPtr, + uintptr(len(ip)), + recvLenPtr, + uintptr(unsafe.Pointer(addr)), + ) + if r1 == 0 { + return 0, handleError(e) + } + + return int(recvLen), nil + } + */ + + d, err := Open("inbound", Network, 0, ReadOnly|Sniff) + require.NoError(t, err) + defer d.Close() + + eg, _ := errgroup.WithContext(context.Background()) + eg.Go(func() error { + var b = make([]byte, 1536) + n, err := d.Recv(b, nil) + require.NoError(t, err) + require.NotZero(t, n) + + return nil + }) + eg.Go(func() error { + time.Sleep(time.Second) + resp, err := http.Get("http://bing.com") + require.NoError(t, err) + defer resp.Body.Close() + + return nil + }) + eg.Wait() + }) + + t.Run("network/empty", func(t *testing.T) { + d, err := Open("inbound", Network, 0, ReadOnly|Sniff) + require.NoError(t, err) + defer d.Close() + + eg, _ := errgroup.WithContext(context.Background()) + eg.Go(func() error { + var b = make([]byte, 0) + n, err := d.Recv(b, nil) + require.Error(t, windows.ERROR_INSUFFICIENT_BUFFER, err) + require.Zero(t, n) + + return nil + }) + eg.Go(func() error { + time.Sleep(time.Second) + resp, err := http.Get("http://bing.com") + require.NoError(t, err) + defer resp.Body.Close() + + return nil + }) + eg.Wait() + }) + + t.Run("network/nil", func(t *testing.T) { + d, err := Open("inbound", Network, 0, ReadOnly|Sniff) + require.NoError(t, err) + defer d.Close() + + eg, _ := errgroup.WithContext(context.Background()) + eg.Go(func() error { + n, err := d.Recv(nil, nil) + require.Error(t, windows.ERROR_INSUFFICIENT_BUFFER, err) + require.Zero(t, n) + + return nil + }) + eg.Go(func() error { + time.Sleep(time.Second) + resp, err := http.Get("http://bing.com") + require.NoError(t, err) + defer resp.Body.Close() + + return nil + }) + eg.Wait() + }) + + t.Run("socket/normal", func(t *testing.T) { d, err := Open("udp and remoteAddr=8.8.8.8", Socket, 0, Sniff|ReadOnly) require.NoError(t, err) defer d.Close() @@ -348,7 +477,6 @@ func Test_Recv(t *testing.T) { sa := addr.Socket() require.Equal(t, netip.AddrFrom4([4]byte{8, 8, 8, 8}), sa.RemoteAddr()) }) - } func Test_Send(t *testing.T) {