Skip to content

Commit

Permalink
fix async Recv always return n==0 at the first time (#8)
Browse files Browse the repository at this point in the history
* fix async Recv always return n==0 at the first time

* fix Close

* fix Close 2

* optimize test

* fix Close 3
  • Loading branch information
lysShub authored Aug 11, 2024
1 parent 01b7432 commit 79d7e0e
Show file tree
Hide file tree
Showing 2 changed files with 186 additions and 66 deletions.
22 changes: 7 additions & 15 deletions handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,31 +26,27 @@ func (d *Handle) Close() error {
if fd != invalid {
r1, _, e := syscall.SyscallN(
procClose.Addr(),
d.handle.Load(),
fd,
)
if r1 == 0 {
return handleError(e)
} else {
return nil
}
}
return nil
return ErrClosed{}
}
func (d *Handle) Priority() int16 { return d.priority }

// Recv recv ip packet, probable return 0.
func (d *Handle) Recv(ip []byte, addr *Address) (int, error) {
var recvLen uint32
var dataPtr, recvLenPtr uintptr
if len(ip) > 0 {
dataPtr = uintptr(unsafe.Pointer(unsafe.SliceData(ip)))
recvLenPtr = uintptr(unsafe.Pointer(&recvLen))
}

r1, _, e := syscall.SyscallN(
procRecv.Addr(),
d.handle.Load(),
dataPtr,
uintptr(unsafe.Pointer(unsafe.SliceData(ip))),
uintptr(len(ip)),
recvLenPtr,
uintptr(unsafe.Pointer(&recvLen)),
uintptr(unsafe.Pointer(addr)),
)
if r1 == 0 {
Expand All @@ -64,15 +60,11 @@ func (d *Handle) Recv(ip []byte, addr *Address) (int, error) {
// notice: recvLen not work, use windows.GetOverlappedResult
func (d *Handle) RecvEx(ip []byte, addr *Address, recvLen *uint32, ol *windows.Overlapped) error {
// todo: support batch recv
var ipPtr uintptr
if len(ip) > 0 {
ipPtr = uintptr(unsafe.Pointer(unsafe.SliceData(ip)))
}

r1, _, e := syscall.SyscallN(
procRecvEx.Addr(),
d.handle.Load(),
ipPtr, // pPacket
uintptr(unsafe.Pointer(unsafe.SliceData(ip))), // pPacket
uintptr(len(ip)), // packetLen
uintptr(unsafe.Pointer(recvLen)), // pRecvLen NOTICE: not work
uintptr(0), // flags 0
Expand Down
230 changes: 179 additions & 51 deletions handle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
"golang.org/x/sys/windows"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
Expand Down Expand Up @@ -142,43 +143,53 @@ func Test_Recv_Error(t *testing.T) {
require.NoError(t, err)
defer d.Close()

{
go func() {
time.Sleep(time.Second)
require.NoError(t, d.Close())
}()
eg, _ := errgroup.WithContext(context.Background())
eg.Go(func() error {
time.Sleep(time.Second)
require.NoError(t, d.Close())
return nil
})
eg.Go(func() error {
_, err = d.Recv(make([]byte, 1536), nil)
require.True(t, errors.Is(err, ErrClosed{}), err)
}
return nil
})
eg.Wait()
})

t.Run("recv/close/close", func(t *testing.T) {
d, err := Open("false", Network, 0, 0)
d, err := Open("false", Network, 0, RecvOnly|Sniff)
require.NoError(t, err)
defer d.Close()

{
go func() {
time.Sleep(time.Second)
require.NoError(t, d.Close())
}()
eg, _ := errgroup.WithContext(context.Background())
eg.Go(func() error {
time.Sleep(time.Second)
require.NoError(t, d.Close())
return nil
})
eg.Go(func() error {
_, err = d.Recv(make([]byte, 1536), nil)
require.True(t, errors.Is(err, ErrClosed{}), err)

require.True(t, errors.Is(d.Close(), ErrClosed{}), err)
}
return nil
})
eg.Wait()
})

t.Run("recv/close/close/recv", func(t *testing.T) {
d, err := Open("false", Network, 0, 0)
require.NoError(t, err)
defer d.Close()

{
go func() {
time.Sleep(time.Second)
require.NoError(t, d.Close())
}()
eg, _ := errgroup.WithContext(context.Background())
eg.Go(func() error {
time.Sleep(time.Second)
require.NoError(t, d.Close())
return nil
})
eg.Go(func() error {
{
_, err = d.Recv(make([]byte, 1536), nil)
require.True(t, errors.Is(err, ErrClosed{}), err)
Expand All @@ -190,7 +201,9 @@ func Test_Recv_Error(t *testing.T) {
_, err = d.Recv(make([]byte, 1536), nil)
require.True(t, errors.Is(err, ErrClosed{}), err)
}
}
return nil
})
eg.Wait()
})

t.Run("shutdown/recv", func(t *testing.T) {
Expand All @@ -209,56 +222,71 @@ func Test_Recv_Error(t *testing.T) {
require.NoError(t, err)
defer d.Close()

go func() {
eg, _ := errgroup.WithContext(context.Background())
eg.Go(func() error {
time.Sleep(time.Second)
require.NoError(t, d.Shutdown(Both))
}()

n, err := d.Recv(make([]byte, 1536), nil)
require.True(t, errors.Is(err, ErrShutdown{}), err)
require.Zero(t, n)
return nil
})
eg.Go(func() error {
n, err := d.Recv(make([]byte, 1536), nil)
require.True(t, errors.Is(err, ErrShutdown{}), err)
require.Zero(t, n)
return nil
})
eg.Wait()
})

t.Run("recv/shutdown/shutdown", func(t *testing.T) {
d, err := Open("false", Network, 0, 0)
require.NoError(t, err)
defer d.Close()

go func() {
eg, _ := errgroup.WithContext(context.Background())
eg.Go(func() error {
time.Sleep(time.Second)
require.NoError(t, d.Shutdown(Both))
}()

n, err := d.Recv(make([]byte, 1536), nil)
require.True(t, errors.Is(err, ErrShutdown{}), err)
require.Zero(t, n)
return nil
})
eg.Go(func() error {
n, err := d.Recv(make([]byte, 1536), nil)
require.True(t, errors.Is(err, ErrShutdown{}), err)
require.Zero(t, n)

require.NoError(t, d.Shutdown(Both))
require.NoError(t, d.Shutdown(Both))
return nil
})
eg.Wait()
})

t.Run("recv/shutdown/shutdown/recv", func(t *testing.T) {
d, err := Open("false", Network, 0, 0)
require.NoError(t, err)
defer d.Close()

go func() {
eg, _ := errgroup.WithContext(context.Background())
eg.Go(func() error {
time.Sleep(time.Second)
require.NoError(t, d.Shutdown(Both))
}()

{
n, err := d.Recv(make([]byte, 1536), nil)
require.True(t, errors.Is(err, ErrShutdown{}), err)
require.Zero(t, n)
}
{
require.NoError(t, d.Shutdown(Both))
}
{
n, err := d.Recv(make([]byte, 1536), nil)
require.True(t, errors.Is(err, ErrShutdown{}), err)
require.Zero(t, n)
}
return nil
})
eg.Go(func() error {
{
n, err := d.Recv(make([]byte, 1536), nil)
require.True(t, errors.Is(err, ErrShutdown{}), err)
require.Zero(t, n)
}
{
require.NoError(t, d.Shutdown(Both))
}
{
n, err := d.Recv(make([]byte, 1536), nil)
require.True(t, errors.Is(err, ErrShutdown{}), err)
require.Zero(t, n)
}
return nil
})
eg.Wait()
})

t.Run("close/recv", func(t *testing.T) {
Expand All @@ -278,7 +306,7 @@ func Test_Recv_Error(t *testing.T) {
func Test_Recv(t *testing.T) {
MustLoad(DLL)

t.Run("Recv/network/loopback", func(t *testing.T) {
t.Run("network/loopback", func(t *testing.T) {
var (
saddr = netip.AddrPortFrom(locIP, randPort())
caddr = netip.AddrPortFrom(locIP, randPort())
Expand Down Expand Up @@ -325,7 +353,108 @@ func Test_Recv(t *testing.T) {
}
})

t.Run("Recv/socket", func(t *testing.T) {
t.Run("network/async-recv", func(t *testing.T) {
/*
old version will fail:
func (d *Handle) Recv(ip []byte, addr *Address) (int, error) {
var recvLen uint32
var dataPtr, recvLenPtr uintptr
if len(ip) > 0 {
dataPtr = uintptr(unsafe.Pointer(unsafe.SliceData(ip)))
recvLenPtr = uintptr(unsafe.Pointer(&recvLen))
}
r1, _, e := syscall.SyscallN(
procRecv.Addr(),
d.handle.Load(),
dataPtr,
uintptr(len(ip)),
recvLenPtr,
uintptr(unsafe.Pointer(addr)),
)
if r1 == 0 {
return 0, handleError(e)
}
return int(recvLen), nil
}
*/

d, err := Open("inbound", Network, 0, ReadOnly|Sniff)
require.NoError(t, err)
defer d.Close()

eg, _ := errgroup.WithContext(context.Background())
eg.Go(func() error {
var b = make([]byte, 1536)
n, err := d.Recv(b, nil)
require.NoError(t, err)
require.NotZero(t, n)

return nil
})
eg.Go(func() error {
time.Sleep(time.Second)
resp, err := http.Get("http://bing.com")
require.NoError(t, err)
defer resp.Body.Close()

return nil
})
eg.Wait()
})

t.Run("network/empty", func(t *testing.T) {
d, err := Open("inbound", Network, 0, ReadOnly|Sniff)
require.NoError(t, err)
defer d.Close()

eg, _ := errgroup.WithContext(context.Background())
eg.Go(func() error {
var b = make([]byte, 0)
n, err := d.Recv(b, nil)
require.Error(t, windows.ERROR_INSUFFICIENT_BUFFER, err)
require.Zero(t, n)

return nil
})
eg.Go(func() error {
time.Sleep(time.Second)
resp, err := http.Get("http://bing.com")
require.NoError(t, err)
defer resp.Body.Close()

return nil
})
eg.Wait()
})

t.Run("network/nil", func(t *testing.T) {
d, err := Open("inbound", Network, 0, ReadOnly|Sniff)
require.NoError(t, err)
defer d.Close()

eg, _ := errgroup.WithContext(context.Background())
eg.Go(func() error {
n, err := d.Recv(nil, nil)
require.Error(t, windows.ERROR_INSUFFICIENT_BUFFER, err)
require.Zero(t, n)

return nil
})
eg.Go(func() error {
time.Sleep(time.Second)
resp, err := http.Get("http://bing.com")
require.NoError(t, err)
defer resp.Body.Close()

return nil
})
eg.Wait()
})

t.Run("socket/normal", func(t *testing.T) {
d, err := Open("udp and remoteAddr=8.8.8.8", Socket, 0, Sniff|ReadOnly)
require.NoError(t, err)
defer d.Close()
Expand All @@ -348,7 +477,6 @@ func Test_Recv(t *testing.T) {
sa := addr.Socket()
require.Equal(t, netip.AddrFrom4([4]byte{8, 8, 8, 8}), sa.RemoteAddr())
})

}

func Test_Send(t *testing.T) {
Expand Down

0 comments on commit 79d7e0e

Please sign in to comment.