From ae1cd7050377877bd4006fdca127386d599e2521 Mon Sep 17 00:00:00 2001 From: lysShub Date: Sun, 26 May 2024 06:16:56 +0800 Subject: [PATCH] add dll --- divert.go | 163 +++++++++++------------------------------------- divert_test.go | 140 ----------------------------------------- dll.go | 68 -------------------- dll/dll.go | 112 +++++++++++++++++++++++++++++++++ dll/dll_test.go | 21 +++++++ dll/mem.go | 91 +++++++++++++++++++++++++++ dll/sys.go | 15 +++++ handle.go | 45 ++++++------- handle_test.go | 91 +++++++++++---------------- 9 files changed, 330 insertions(+), 416 deletions(-) delete mode 100644 dll.go create mode 100644 dll/dll.go create mode 100644 dll/dll_test.go create mode 100644 dll/mem.go create mode 100644 dll/sys.go diff --git a/divert.go b/divert.go index 64a01cf..af9cb1d 100644 --- a/divert.go +++ b/divert.go @@ -4,15 +4,31 @@ package divert import ( - "sync" "syscall" "unsafe" + "github.com/lysShub/divert-go/dll" "github.com/pkg/errors" "golang.org/x/sys/windows" ) -var global divert +var ( + divert = dll.NewLazyDLL("WinDivert.dll") + + procOpen = divert.NewProc("WinDivertOpen") + procHelperCompileFilter = divert.NewProc("WinDivertHelperCompileFilter") + procHelperEvalFilter = divert.NewProc("WinDivertHelperEvalFilter") + procHelperFormatFilter = divert.NewProc("WinDivertHelperFormatFilter") + + procRecv = divert.NewProc("WinDivertRecv") + procRecvEx = divert.NewProc("WinDivertRecvEx") + procSend = divert.NewProc("WinDivertSend") + procSendEx = divert.NewProc("WinDivertSendEx") + procShutdown = divert.NewProc("WinDivertShutdown") + procClose = divert.NewProc("WinDivertClose") + procSetParam = divert.NewProc("WinDivertSetParam") + procGetParam = divert.NewProc("WinDivertGetParam") +) func MustLoad[T string | Mem](p T) struct{} { err := Load(p) @@ -23,128 +39,22 @@ func MustLoad[T string | Mem](p T) struct{} { } func Load[T string | Mem](p T) error { - global.Lock() - defer global.Unlock() - if global.dll != nil { + if divert.Loaded() { return ErrLoaded{} } - var err error switch p := any(p).(type) { case string: - global.dll, err = loadFileDLL(p) - if err != nil { - return errors.WithStack(err) - } + dll.ResetLazyDll(divert, p) case Mem: - if err = driverInstall(p.Sys); err != nil { + if err := driverInstall(p.Sys); err != nil { return errors.WithStack(err) } - - global.dll, err = loadMemDLL(p.DLL) - if err != nil { - return err - } + dll.ResetLazyDll(divert, p.DLL) default: - return windows.ERROR_INVALID_PARAMETER - } - - err = global.init() - return errors.WithStack(err) -} - -func Release() error { - global.Lock() - defer global.Unlock() - if global.dll == nil { - return nil - } - - err := global.dll.Release() - global.dll = nil - return errors.WithStack(err) -} - -type divert struct { - sync.RWMutex - dll dll - - procOpen uintptr // WinDivertOpen - procHelperCompileFilter uintptr // WinDivertHelperCompileFilter - procHelperEvalFilter uintptr // WinDivertHelperEvalFilter - procHelperFormatFilter uintptr // WinDivertHelperFormatFilter - - procRecv uintptr // WinDivertRecv - procRecvEx uintptr // WinDivertRecvEx - procSend uintptr // WinDivertSend - procSendEx uintptr // WinDivertSendEx - procShutdown uintptr // WinDivertShutdown - procClose uintptr // WinDivertClose - procSetParam uintptr // WinDivertSetParam - procGetParam uintptr // WinDivertGetParam -} - -func (d *divert) init() (err error) { - if d.procOpen, err = d.dll.FindProc("WinDivertOpen"); err != nil { - goto ret - } - if d.procHelperCompileFilter, err = d.dll.FindProc("WinDivertHelperCompileFilter"); err != nil { - goto ret - } - if d.procHelperEvalFilter, err = d.dll.FindProc("WinDivertHelperEvalFilter"); err != nil { - goto ret - } - if d.procHelperFormatFilter, err = d.dll.FindProc("WinDivertHelperFormatFilter"); err != nil { - goto ret + panic("") } - - if d.procRecv, err = d.dll.FindProc("WinDivertRecv"); err != nil { - goto ret - } - if d.procRecvEx, err = d.dll.FindProc("WinDivertRecvEx"); err != nil { - goto ret - } - if d.procSend, err = d.dll.FindProc("WinDivertSend"); err != nil { - goto ret - } - if d.procSendEx, err = d.dll.FindProc("WinDivertSendEx"); err != nil { - goto ret - } - if d.procShutdown, err = d.dll.FindProc("WinDivertShutdown"); err != nil { - goto ret - } - if d.procClose, err = d.dll.FindProc("WinDivertClose"); err != nil { - goto ret - } - if d.procSetParam, err = d.dll.FindProc("WinDivertSetParam"); err != nil { - goto ret - } - if d.procGetParam, err = d.dll.FindProc("WinDivertGetParam"); err != nil { - goto ret - } - -ret: - if err != nil { - d.dll.Release() - d.dll = nil - } - return err -} - -func (d *divert) calln(trap uintptr, args ...uintptr) (r1, r2 uintptr, err error) { - d.RLock() - defer d.RUnlock() - if d.dll == nil || trap == 0 { - return 0, 0, errors.WithStack(ErrNotLoad{}) - } - - var e syscall.Errno - r1, r2, e = syscall.SyscallN(trap, args...) - if e == windows.ERROR_SUCCESS { - return r1, r2, nil - } - - return r1, r2, errors.WithStack(e) + return nil } func Open(filter string, layer Layer, priority int16, flags Flag) (*Handle, error) { @@ -154,22 +64,21 @@ func Open(filter string, layer Layer, priority int16, flags Flag) (*Handle, erro } // flags = flags | NoInstall - r1, _, e := global.calln( - global.procOpen, + r1, _, e := syscall.SyscallN( + procOpen.Addr(), uintptr(unsafe.Pointer(pf)), uintptr(layer), uintptr(priority), uintptr(flags), ) - if r1 == uintptr(windows.InvalidHandle) || e != nil { + if r1 == uintptr(windows.InvalidHandle) || e != 0 { return nil, errors.WithStack(e) } return &Handle{ - handle: r1, - layer: layer, - priority: priority, - ctxPeriod: 100, + handle: r1, + layer: layer, + priority: priority, }, nil } @@ -180,8 +89,8 @@ func HelperCompileFilter(filter string, layer Layer) (string, error) { return "", err } - r1, _, e := global.calln( - global.procHelperCompileFilter, + r1, _, e := syscall.SyscallN( + procHelperCompileFilter.Addr(), uintptr(unsafe.Pointer(pFilter)), // filter uintptr(layer), // layer uintptr(unsafe.Pointer(unsafe.SliceData(buf))), // object @@ -201,8 +110,8 @@ func HelperEvalFilter(filter string, ip []byte, addr *Address) (bool, error) { return false, err } - r1, _, e := global.calln( - global.procHelperEvalFilter, + r1, _, e := syscall.SyscallN( + procHelperEvalFilter.Addr(), uintptr(unsafe.Pointer(pFilter)), // filter uintptr(unsafe.Pointer(unsafe.SliceData(ip))), // pPacket uintptr(len(ip)), // packetLen @@ -221,8 +130,8 @@ func HelperFormatFilter(filter string, layer Layer) (string, error) { return "", err } - r1, _, e := global.calln( - global.procHelperFormatFilter, + r1, _, e := syscall.SyscallN( + procHelperFormatFilter.Addr(), uintptr(unsafe.Pointer(pFilter)), // filter uintptr(layer), // layer uintptr(unsafe.Pointer(unsafe.SliceData(buf))), // buffer diff --git a/divert_test.go b/divert_test.go index b8d3a43..9e25d96 100644 --- a/divert_test.go +++ b/divert_test.go @@ -32,151 +32,11 @@ func Test_Gofmt(t *testing.T) { } func Test_Load_DLL(t *testing.T) { - runLoad(t, "embed", func(t *testing.T) { - e1 := Load(DLL) - require.NoError(t, e1) - require.NoError(t, Release()) - e2 := Load(DLL) - require.NoError(t, e2) - require.NoError(t, Release()) - }) - - runLoad(t, "file", func(t *testing.T) { - e1 := Load(path) - require.NoError(t, e1) - require.NoError(t, Release()) - - e2 := Load(path) - require.NoError(t, e2) - require.NoError(t, Release()) - }) - - runLoad(t, "load-fail", func(t *testing.T) { - err := Load("C:\\Windows\\System32\\ws2_32.dll") - require.NotNil(t, err) - }) - - runLoad(t, "load-fail/open", func(t *testing.T) { - err := Load("C:\\Windows\\System32\\ws2_32.dll") - require.Error(t, err) - - d, err := Open("false", Network, 0, 0) - require.True(t, errors.Is(err, ErrNotLoad{})) - require.Nil(t, d) - }) - - runLoad(t, "load-fail/release", func(t *testing.T) { - err := Load("C:\\Windows\\System32\\ws2_32.dll") - require.NotNil(t, err) - - require.NoError(t, Release()) - }) - - runLoad(t, "load-fail/load", func(t *testing.T) { - e1 := Load("C:\\Windows\\System32\\ws2_32.dll") - require.NotNil(t, e1) - require.NoError(t, Release()) - - e := Load(DLL) - require.NoError(t, e) - require.NoError(t, Release()) - }) - - runLoad(t, "load/load", func(t *testing.T) { - e1 := Load(path) - require.NoError(t, e1) - - e2 := Load(DLL) - require.True(t, errors.Is(e2, ErrLoaded{})) - - require.NoError(t, Release()) - }) - - runLoad(t, "release/release", func(t *testing.T) { - require.NoError(t, Release()) - require.NoError(t, Release()) - }) - - runLoad(t, "load/release/release", func(t *testing.T) { - err := Load(DLL) - require.NoError(t, err) - - require.NoError(t, Release()) - require.NoError(t, Release()) - }) - - runLoad(t, "load/open/release", func(t *testing.T) { - err := Load(DLL) - require.NoError(t, err) - defer Release() - - d1, err := Open("false", Network, 0, 0) - require.NoError(t, err) - require.NoError(t, d1.Close()) - - require.NoError(t, Release()) - - _, err = d1.Recv(nil, nil) - require.True(t, errors.Is(err, ErrNotLoad{})) - }) - - runLoad(t, "open", func(t *testing.T) { - d, err := Open("false", Network, 0, 0) - require.Nil(t, d) - require.True(t, errors.Is(err, ErrNotLoad{})) - }) - - runLoad(t, "load/release/open", func(t *testing.T) { - err := Load(DLL) - require.NoError(t, err) - require.NoError(t, Release()) - - d, err := Open("false", Network, 0, 0) - require.Nil(t, d) - require.True(t, errors.Is(err, ErrNotLoad{})) - }) -} - -func runLoad(t *testing.T, name string, fn func(t *testing.T)) { - t.Run(name, func(t *testing.T) { - fn(t) - Release() - }) -} - -func Test_MustLoad_DLL(t *testing.T) { - runLoad(t, "embed", func(t *testing.T) { - MustLoad(DLL) - Release() - - MustLoad(DLL) - - MustLoad(DLL) - }) - - runLoad(t, "file", func(t *testing.T) { - MustLoad(path) - Release() - - MustLoad(path) - - MustLoad(path) - }) - - runLoad(t, "load-fail", func(t *testing.T) { - defer func() { - e := recover() - require.NotNil(t, e, e) - }() - - MustLoad("C:\\Windows\\System32\\ws2_32.dll") - }) } func Test_Helper(t *testing.T) { require.NoError(t, Load(DLL)) - defer Release() t.Run("format/null", func(t *testing.T) { d, err := Open("false", Network, 0, 0) diff --git a/dll.go b/dll.go deleted file mode 100644 index dc98aa4..0000000 --- a/dll.go +++ /dev/null @@ -1,68 +0,0 @@ -//go:build windows -// +build windows - -package divert - -import ( - "golang.org/x/sys/windows" - "golang.zx2c4.com/wireguard/windows/driver/memmod" -) - -type dll interface { - Release() error - FindProc(string) (uintptr, error) - MustFindProc(string) uintptr -} - -type file windows.DLL - -func loadFileDLL(path string) (dll, error) { - dll, err := windows.LoadDLL(path) - if err != nil { - return nil, err - } - return (*file)(dll), nil -} - -func (d *file) FindProc(name string) (uintptr, error) { - p, err := ((*windows.DLL)(d)).FindProc(name) - if err != nil { - return 0, err - } - return p.Addr(), nil -} -func (d *file) MustFindProc(name string) uintptr { - hdl, err := d.FindProc(name) - if err != nil { - panic(err) - } - return hdl -} -func (d *file) Release() error { - return ((*windows.DLL)(d)).Release() -} - -type mem memmod.Module - -func loadMemDLL(data []byte) (dll, error) { - d, err := memmod.LoadLibrary(data) - if err != nil { - return nil, err - } - return (*mem)(d), nil -} - -func (d *mem) FindProc(name string) (uintptr, error) { - return ((*memmod.Module)(d)).ProcAddressByName(name) -} -func (d *mem) MustFindProc(name string) uintptr { - hdl, err := d.FindProc(name) - if err != nil { - panic(err) - } - return hdl -} -func (d *mem) Release() error { - ((*memmod.Module)(d)).Free() - return nil -} diff --git a/dll/dll.go b/dll/dll.go new file mode 100644 index 0000000..2ec5055 --- /dev/null +++ b/dll/dll.go @@ -0,0 +1,112 @@ +//go:build windows +// +build windows + +package dll + +import ( + "sync" + "sync/atomic" + + "golang.org/x/sys/windows" +) + +type LazyDll interface { + Handle() uintptr + Load() error + NewProc(name string) LazyProc +} + +type LazyProc interface { + Addr() uintptr + Call(a ...uintptr) (r1 uintptr, r2 uintptr, lastErr error) + Find() error +} + +type CommLazyDll struct { + LazyDll + loaded atomic.Bool +} + +func (d *CommLazyDll) Loaded() bool { return d.loaded.Load() } +func (d *CommLazyDll) Load() error { + if !d.loaded.Load() { + err := d.LazyDll.Load() + if err != nil { + return err + } + d.loaded.Store(true) + } + return nil +} + +func NewLazyDLL[T ~string | ~[]byte](dll T) *CommLazyDll { + switch dll := any(dll).(type) { + case string: + return &CommLazyDll{ + LazyDll: &SysLazyDll{LazyDLL: windows.LazyDLL{Name: dll}}, + } + case []byte: + return &CommLazyDll{ + LazyDll: &MemLazyDll{Data: dll}, + } + default: + panic("") + } +} + +// ResetLazyDll reset dll before load +func ResetLazyDll[T ~string | ~[]byte](dll *CommLazyDll, src T) { + if dll.Loaded() { + panic("cant't reset loaded dll") + } + + switch src := any(src).(type) { + case string: + dll.LazyDll = &SysLazyDll{LazyDLL: windows.LazyDLL{Name: src}} + case []byte: + dll.LazyDll = &MemLazyDll{Data: src} + default: + panic("") + } +} + +func (d *CommLazyDll) NewProc(name string) LazyProc { return &CommLazyProc{Name: name, dll: d} } + +type CommLazyProc struct { + Name string + dll *CommLazyDll + + found atomic.Bool + mu sync.RWMutex + proc LazyProc +} + +var _ LazyProc = (*CommLazyProc)(nil) + +func (p *CommLazyProc) Addr() uintptr { + p.mustFind() + return p.proc.Addr() +} +func (p *CommLazyProc) Call(a ...uintptr) (r1 uintptr, r2 uintptr, lastErr error) { + p.mustFind() + return p.proc.Call(a...) +} +func (p *CommLazyProc) Find() error { + if !p.found.Load() { + p.mu.Lock() + defer p.mu.Unlock() + if p.proc != nil { + return nil + } + + p.proc = p.dll.LazyDll.NewProc(p.Name) + p.found.Store(true) + } + return p.proc.Find() +} +func (p *CommLazyProc) mustFind() { + err := p.Find() + if err != nil { + panic(err) + } +} diff --git a/dll/dll_test.go b/dll/dll_test.go new file mode 100644 index 0000000..0688fdb --- /dev/null +++ b/dll/dll_test.go @@ -0,0 +1,21 @@ +//go:build windows +// +build windows + +package dll_test + +import ( + "testing" + + "github.com/lysShub/divert-go/dll" +) + +var ( + test = dll.NewLazyDLL(make([]byte, 3)) + openProc = test.NewProc("WinDivertOpen") +) + +func TestXxx(t *testing.T) { + dll.ResetLazyDll(test, `D:\OneDrive\code\go\divert-go\embed\WinDivert64.dll`) + + openProc.Find() +} diff --git a/dll/mem.go b/dll/mem.go new file mode 100644 index 0000000..9940e72 --- /dev/null +++ b/dll/mem.go @@ -0,0 +1,91 @@ +package dll + +import ( + "sync" + "sync/atomic" + "syscall" + "unsafe" + + "github.com/pkg/errors" + "golang.zx2c4.com/wireguard/windows/driver/memmod" +) + +type MemLazyDll struct { + Data []byte + + mu sync.Mutex + dll *memmod.Module +} + +var _ LazyDll = (*MemLazyDll)(nil) + +func (d *MemLazyDll) Handle() uintptr { + d.mustLoad() + return d.dll.BaseAddr() +} +func (d *MemLazyDll) Load() (err error) { + if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&d.dll))) != nil { + return nil + } + d.mu.Lock() + defer d.mu.Unlock() + if d.dll != nil { + return nil + } + + dll, err := memmod.LoadLibrary(d.Data) + if err != nil { + return errors.WithStack(err) + } + atomic.SwapPointer((*unsafe.Pointer)(unsafe.Pointer(&d.dll)), unsafe.Pointer(dll)) + return nil +} +func (d *MemLazyDll) mustLoad() { + err := d.Load() + if err != nil { + panic(err) + } +} + +func (d *MemLazyDll) NewProc(name string) LazyProc { + return &MemLazyProc{Name: name, l: d} +} + +type MemLazyProc struct { + Name string + + mu sync.Mutex + l *MemLazyDll + proc uintptr +} + +func (p *MemLazyProc) Addr() uintptr { + p.mustFind() + return p.proc +} +func (p *MemLazyProc) Call(a ...uintptr) (r1 uintptr, r2 uintptr, lastErr error) { + p.mustFind() + return syscall.SyscallN(p.Addr(), a...) +} +func (p *MemLazyProc) Find() error { + if atomic.LoadUintptr(&p.proc) == 0 { + p.mu.Lock() + defer p.mu.Unlock() + + if p.proc == 0 { + err := p.l.Load() + if err != nil { + return err + } + + proc, err := p.l.dll.ProcAddressByName(p.Name) + if err != nil { + return err + } + atomic.StoreUintptr(&p.proc, proc) + } + + } + return nil +} +func (p *MemLazyProc) mustFind() {} diff --git a/dll/sys.go b/dll/sys.go new file mode 100644 index 0000000..a4a3fc9 --- /dev/null +++ b/dll/sys.go @@ -0,0 +1,15 @@ +package dll + +import ( + "golang.org/x/sys/windows" +) + +type SysLazyDll struct { + windows.LazyDLL +} + +var _ LazyDll = (*SysLazyDll)(nil) + +func (l *SysLazyDll) NewProc(name string) LazyProc { + return l.LazyDLL.NewProc(name) +} diff --git a/handle.go b/handle.go index 8a570bb..1355914 100644 --- a/handle.go +++ b/handle.go @@ -4,6 +4,7 @@ package divert import ( + "syscall" "unsafe" "github.com/pkg/errors" @@ -13,14 +14,13 @@ import ( type Handle struct { handle uintptr - layer Layer - priority int16 - ctxPeriod uint32 // milliseconds + layer Layer + priority int16 } func (d *Handle) Close() error { - r1, _, e := global.calln( - global.procClose, + r1, _, e := syscall.SyscallN( + procClose.Addr(), d.handle, ) if r1 == 0 { @@ -30,13 +30,6 @@ func (d *Handle) Close() error { } func (d *Handle) Priority() int16 { return d.priority } -func (d *Handle) SetCtxPeriod(milliseconds uint32) { - d.ctxPeriod = milliseconds - if d.ctxPeriod < 5 { - d.ctxPeriod = 5 - } -} - func (d *Handle) Recv(ip []byte, addr *Address) (int, error) { var recvLen uint32 var dataPtr, recvLenPtr uintptr @@ -45,8 +38,8 @@ func (d *Handle) Recv(ip []byte, addr *Address) (int, error) { recvLenPtr = uintptr(unsafe.Pointer(&recvLen)) } - r1, _, e := global.calln( - global.procRecv, + r1, _, e := syscall.SyscallN( + procRecv.Addr(), d.handle, dataPtr, uintptr(len(ip)), @@ -77,8 +70,8 @@ func (d *Handle) recvEx(ip []byte, addr *Address, recvLen *uint32, ol *windows.O ipPtr = uintptr(unsafe.Pointer(unsafe.SliceData(ip))) } - r1, _, e := global.calln( - global.procRecvEx, + r1, _, e := syscall.SyscallN( + procRecvEx.Addr(), d.handle, ipPtr, // pPacket uintptr(len(ip)), // packetLen @@ -100,8 +93,8 @@ func (d *Handle) Send(ip []byte, addr *Address) (int, error) { } var n uint32 - r1, _, e := global.calln( - global.procSend, + r1, _, e := syscall.SyscallN( + procSend.Addr(), d.handle, // handle uintptr(unsafe.Pointer(unsafe.SliceData(ip))), // pPacket uintptr(len(ip)), // packetLen @@ -122,8 +115,8 @@ func (d *Handle) SendEx(ip []byte, flag uint64, addr *Address, ol *windows.Overl var n uint32 // todo: support batch - r1, _, e := global.calln( - global.procSendEx, + r1, _, e := syscall.SyscallN( + procSendEx.Addr(), d.handle, uintptr(unsafe.Pointer(unsafe.SliceData(ip))), // pPacket uintptr(len(ip)), // packetLen @@ -140,8 +133,8 @@ func (d *Handle) SendEx(ip []byte, flag uint64, addr *Address, ol *windows.Overl } func (d *Handle) Shutdown(how Shutdown) error { - r1, _, e := global.calln( - global.procShutdown, + r1, _, e := syscall.SyscallN( + procShutdown.Addr(), d.handle, uintptr(how), ) @@ -152,8 +145,8 @@ func (d *Handle) Shutdown(how Shutdown) error { } func (d *Handle) SetParam(param PARAM, value uint64) error { - r1, _, e := global.calln( - global.procSetParam, + r1, _, e := syscall.SyscallN( + procSetParam.Addr(), d.handle, uintptr(param), uintptr(value), @@ -165,8 +158,8 @@ func (d *Handle) SetParam(param PARAM, value uint64) error { } func (d *Handle) GetParam(param PARAM) (value uint64, err error) { - r1, _, e := global.calln( - global.procGetParam, + r1, _, e := syscall.SyscallN( + procGetParam.Addr(), d.handle, uintptr(param), uintptr(unsafe.Pointer(&value)), diff --git a/handle_test.go b/handle_test.go index 9a38322..9b2669b 100644 --- a/handle_test.go +++ b/handle_test.go @@ -22,14 +22,12 @@ import ( ) func TestXxxxx(t *testing.T) { - require.NoError(t, Load(DLL)) - defer Release() + MustLoad(DLL) } func Test_Address(t *testing.T) { - require.NoError(t, Load(DLL)) - defer Release() + MustLoad(DLL) t.Run("flow", func(t *testing.T) { go func() { @@ -127,8 +125,7 @@ func Test_Address(t *testing.T) { } func Test_Recv_Error(t *testing.T) { - require.NoError(t, Load(DLL)) - defer Release() + MustLoad(DLL) t.Run("close/recv", func(t *testing.T) { d, err := Open("false", Network, 0, 0) @@ -279,8 +276,7 @@ func Test_Recv_Error(t *testing.T) { } func Test_Recv(t *testing.T) { - require.NoError(t, Load(DLL)) - defer Release() + MustLoad(DLL) t.Run("Recv/network/loopback", func(t *testing.T) { var ( @@ -358,8 +354,7 @@ func Test_Recv(t *testing.T) { func Test_Send(t *testing.T) { t.Skip("todo: can't pass github/action, local can pass") - require.NoError(t, Load(DLL)) - defer Release() + MustLoad(DLL) t.Run("inbound", func(t *testing.T) { var ( @@ -503,8 +498,7 @@ func Test_Send(t *testing.T) { } func Test_Auto_Handle_DF(t *testing.T) { - require.NoError(t, Load(DLL)) - defer Release() + MustLoad(DLL) t.Run("recv", func(t *testing.T) { var ( @@ -547,8 +541,7 @@ func Test_Auto_Handle_DF(t *testing.T) { func Test_Recving_Close(t *testing.T) { t.Skip("todo: not support concurrent call") - require.NoError(t, Load(DLL)) - defer Release() + MustLoad(DLL) wg, _ := errgroup.WithContext(context.Background()) @@ -583,8 +576,7 @@ func Test_Recving_Close(t *testing.T) { // CONCLUSION: packet alway be handle by higher priority. func Test_Recv_Priority(t *testing.T) { - require.NoError(t, Load(DLL)) - defer Release() + MustLoad(DLL) t.Run("outbound", func(t *testing.T) { var ( @@ -785,8 +777,7 @@ func Test_Recv_Priority(t *testing.T) { // CONCLUSION: send packet will be handle by equal(random) or lower(always) priority func Test_Send_Priority(t *testing.T) { - require.NoError(t, Load(DLL)) - defer Release() + MustLoad(DLL) t.Run("outbound", func(t *testing.T) { var ( @@ -794,48 +785,44 @@ func Test_Send_Priority(t *testing.T) { dst = netip.AddrPortFrom(netip.MustParseAddr("8.8.8.8"), uint16(randPort())) msg = "hello" ) - var ( filter = fmt.Sprintf( "outbound and udp and localAddr=%s and localPort=%d and remoteAddr=%s and remotePort=%d", src.Addr(), src.Port(), dst.Addr(), dst.Port(), ) - hiPriority, midPriority, loPriority = 4, 2, 1 + loPriority, midPriority, hiPriority int16 = 1, 2, 4 rs atomic.Int32 ) + eg, _ := errgroup.WithContext(context.Background()) - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() - - for _, pri := range []int{hiPriority, midPriority, loPriority} { - go func(p int16) { - d, err := Open(filter, Network, p, Sniff) + for _, pri := range []int16{hiPriority, midPriority, loPriority} { + p := pri + eg.Go(func() error { + d, err := Open(filter, Network, p, 0) require.NoError(t, err) - defer d.Close() + time.AfterFunc(time.Second*3, func() { + println("call") + d.Close() + }) _, err = d.Recv(make([]byte, 1536), nil) if err == nil { rs.Add(int32(p)) } - }(int16(pri)) + return nil + }) } - d, err := Open("false", Network, int16(midPriority), WriteOnly) + d, err := Open("false", Network, midPriority, WriteOnly) require.NoError(t, err) defer d.Close() for rs.Load() == 0 { - select { - case <-ctx.Done(): - return - default: - } _, err := d.Send(buildUDP(t, src, dst, []byte(msg)), outboundAddr) require.NoError(t, err) time.Sleep(time.Second) } - require.Contains(t, []int{loPriority, loPriority + midPriority}, int(rs.Load())) - require.NoError(t, ctx.Err()) + require.Contains(t, []int16{loPriority, midPriority, loPriority + midPriority}, int16(rs.Load())) }) t.Run("inbound", func(t *testing.T) { @@ -844,47 +831,41 @@ func Test_Send_Priority(t *testing.T) { dst = netip.AddrPortFrom(locIP, uint16(randPort())) msg = "hello" ) - var ( filter = fmt.Sprintf( "inbound and udp and localAddr=%s and localPort=%d and remoteAddr=%s and remotePort=%d", dst.Addr(), dst.Port(), src.Addr(), src.Port(), ) - hiPriority, midPriority, loPriority = 4, 2, 1 + loPriority, midPriority, hiPriority int16 = 1, 2, 4 rs atomic.Int32 ) - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() + eg, _ := errgroup.WithContext(context.Background()) - for _, pri := range []int{hiPriority, midPriority, loPriority} { - go func(p int16) { - d, err := Open(filter, Network, p, Sniff) + for _, pri := range []int16{hiPriority, midPriority, loPriority} { + p := pri + eg.Go(func() error { + d, err := Open(filter, Network, p, 0) require.NoError(t, err) - defer d.Close() + time.AfterFunc(time.Second*3, func() { d.Close() }) var b = make([]byte, 1536) var addr Address _, err = d.Recv(b, &addr) - require.NoError(t, err) - require.True(t, !addr.Flags.Outbound()) - rs.Add(int32(p)) - }(int16(pri)) + if err == nil { + rs.Add(int32(p)) + } + return nil + }) } d, err := Open("false", Network, int16(midPriority), WriteOnly) require.NoError(t, err) for rs.Load() == 0 { - select { - case <-ctx.Done(): - return - default: - } _, err := d.Send(buildUDP(t, src, dst, []byte(msg)), inboundAddr) require.NoError(t, err) time.Sleep(time.Second) } - require.Contains(t, []int{loPriority, loPriority + midPriority}, int(rs.Load())) - require.NoError(t, ctx.Err()) + require.Contains(t, []int16{loPriority, loPriority + midPriority}, int16(rs.Load())) }) }