Skip to content

Commit

Permalink
add dll
Browse files Browse the repository at this point in the history
  • Loading branch information
lysShub committed May 25, 2024
1 parent 0c7d96c commit ae1cd70
Show file tree
Hide file tree
Showing 9 changed files with 330 additions and 416 deletions.
163 changes: 36 additions & 127 deletions divert.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,31 @@
package divert

import (
"sync"
"syscall"
"unsafe"

"github.com/lysShub/divert-go/dll"
"github.com/pkg/errors"
"golang.org/x/sys/windows"
)

var global divert
var (
divert = dll.NewLazyDLL("WinDivert.dll")

procOpen = divert.NewProc("WinDivertOpen")
procHelperCompileFilter = divert.NewProc("WinDivertHelperCompileFilter")
procHelperEvalFilter = divert.NewProc("WinDivertHelperEvalFilter")
procHelperFormatFilter = divert.NewProc("WinDivertHelperFormatFilter")

procRecv = divert.NewProc("WinDivertRecv")
procRecvEx = divert.NewProc("WinDivertRecvEx")
procSend = divert.NewProc("WinDivertSend")
procSendEx = divert.NewProc("WinDivertSendEx")
procShutdown = divert.NewProc("WinDivertShutdown")
procClose = divert.NewProc("WinDivertClose")
procSetParam = divert.NewProc("WinDivertSetParam")
procGetParam = divert.NewProc("WinDivertGetParam")
)

func MustLoad[T string | Mem](p T) struct{} {
err := Load(p)
Expand All @@ -23,128 +39,22 @@ func MustLoad[T string | Mem](p T) struct{} {
}

func Load[T string | Mem](p T) error {
global.Lock()
defer global.Unlock()
if global.dll != nil {
if divert.Loaded() {
return ErrLoaded{}
}

var err error
switch p := any(p).(type) {
case string:
global.dll, err = loadFileDLL(p)
if err != nil {
return errors.WithStack(err)
}
dll.ResetLazyDll(divert, p)
case Mem:
if err = driverInstall(p.Sys); err != nil {
if err := driverInstall(p.Sys); err != nil {
return errors.WithStack(err)
}

global.dll, err = loadMemDLL(p.DLL)
if err != nil {
return err
}
dll.ResetLazyDll(divert, p.DLL)
default:
return windows.ERROR_INVALID_PARAMETER
}

err = global.init()
return errors.WithStack(err)
}

func Release() error {
global.Lock()
defer global.Unlock()
if global.dll == nil {
return nil
}

err := global.dll.Release()
global.dll = nil
return errors.WithStack(err)
}

type divert struct {
sync.RWMutex
dll dll

procOpen uintptr // WinDivertOpen
procHelperCompileFilter uintptr // WinDivertHelperCompileFilter
procHelperEvalFilter uintptr // WinDivertHelperEvalFilter
procHelperFormatFilter uintptr // WinDivertHelperFormatFilter

procRecv uintptr // WinDivertRecv
procRecvEx uintptr // WinDivertRecvEx
procSend uintptr // WinDivertSend
procSendEx uintptr // WinDivertSendEx
procShutdown uintptr // WinDivertShutdown
procClose uintptr // WinDivertClose
procSetParam uintptr // WinDivertSetParam
procGetParam uintptr // WinDivertGetParam
}

func (d *divert) init() (err error) {
if d.procOpen, err = d.dll.FindProc("WinDivertOpen"); err != nil {
goto ret
}
if d.procHelperCompileFilter, err = d.dll.FindProc("WinDivertHelperCompileFilter"); err != nil {
goto ret
}
if d.procHelperEvalFilter, err = d.dll.FindProc("WinDivertHelperEvalFilter"); err != nil {
goto ret
}
if d.procHelperFormatFilter, err = d.dll.FindProc("WinDivertHelperFormatFilter"); err != nil {
goto ret
panic("")
}

if d.procRecv, err = d.dll.FindProc("WinDivertRecv"); err != nil {
goto ret
}
if d.procRecvEx, err = d.dll.FindProc("WinDivertRecvEx"); err != nil {
goto ret
}
if d.procSend, err = d.dll.FindProc("WinDivertSend"); err != nil {
goto ret
}
if d.procSendEx, err = d.dll.FindProc("WinDivertSendEx"); err != nil {
goto ret
}
if d.procShutdown, err = d.dll.FindProc("WinDivertShutdown"); err != nil {
goto ret
}
if d.procClose, err = d.dll.FindProc("WinDivertClose"); err != nil {
goto ret
}
if d.procSetParam, err = d.dll.FindProc("WinDivertSetParam"); err != nil {
goto ret
}
if d.procGetParam, err = d.dll.FindProc("WinDivertGetParam"); err != nil {
goto ret
}

ret:
if err != nil {
d.dll.Release()
d.dll = nil
}
return err
}

func (d *divert) calln(trap uintptr, args ...uintptr) (r1, r2 uintptr, err error) {
d.RLock()
defer d.RUnlock()
if d.dll == nil || trap == 0 {
return 0, 0, errors.WithStack(ErrNotLoad{})
}

var e syscall.Errno
r1, r2, e = syscall.SyscallN(trap, args...)
if e == windows.ERROR_SUCCESS {
return r1, r2, nil
}

return r1, r2, errors.WithStack(e)
return nil
}

func Open(filter string, layer Layer, priority int16, flags Flag) (*Handle, error) {
Expand All @@ -154,22 +64,21 @@ func Open(filter string, layer Layer, priority int16, flags Flag) (*Handle, erro
}

// flags = flags | NoInstall
r1, _, e := global.calln(
global.procOpen,
r1, _, e := syscall.SyscallN(
procOpen.Addr(),
uintptr(unsafe.Pointer(pf)),
uintptr(layer),
uintptr(priority),
uintptr(flags),
)
if r1 == uintptr(windows.InvalidHandle) || e != nil {
if r1 == uintptr(windows.InvalidHandle) || e != 0 {
return nil, errors.WithStack(e)
}

return &Handle{
handle: r1,
layer: layer,
priority: priority,
ctxPeriod: 100,
handle: r1,
layer: layer,
priority: priority,
}, nil
}

Expand All @@ -180,8 +89,8 @@ func HelperCompileFilter(filter string, layer Layer) (string, error) {
return "", err
}

r1, _, e := global.calln(
global.procHelperCompileFilter,
r1, _, e := syscall.SyscallN(
procHelperCompileFilter.Addr(),
uintptr(unsafe.Pointer(pFilter)), // filter
uintptr(layer), // layer
uintptr(unsafe.Pointer(unsafe.SliceData(buf))), // object
Expand All @@ -201,8 +110,8 @@ func HelperEvalFilter(filter string, ip []byte, addr *Address) (bool, error) {
return false, err
}

r1, _, e := global.calln(
global.procHelperEvalFilter,
r1, _, e := syscall.SyscallN(
procHelperEvalFilter.Addr(),
uintptr(unsafe.Pointer(pFilter)), // filter
uintptr(unsafe.Pointer(unsafe.SliceData(ip))), // pPacket
uintptr(len(ip)), // packetLen
Expand All @@ -221,8 +130,8 @@ func HelperFormatFilter(filter string, layer Layer) (string, error) {
return "", err
}

r1, _, e := global.calln(
global.procHelperFormatFilter,
r1, _, e := syscall.SyscallN(
procHelperFormatFilter.Addr(),
uintptr(unsafe.Pointer(pFilter)), // filter
uintptr(layer), // layer
uintptr(unsafe.Pointer(unsafe.SliceData(buf))), // buffer
Expand Down
140 changes: 0 additions & 140 deletions divert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,151 +32,11 @@ func Test_Gofmt(t *testing.T) {
}

func Test_Load_DLL(t *testing.T) {
runLoad(t, "embed", func(t *testing.T) {
e1 := Load(DLL)
require.NoError(t, e1)
require.NoError(t, Release())

e2 := Load(DLL)
require.NoError(t, e2)
require.NoError(t, Release())
})

runLoad(t, "file", func(t *testing.T) {
e1 := Load(path)
require.NoError(t, e1)
require.NoError(t, Release())

e2 := Load(path)
require.NoError(t, e2)
require.NoError(t, Release())
})

runLoad(t, "load-fail", func(t *testing.T) {
err := Load("C:\\Windows\\System32\\ws2_32.dll")
require.NotNil(t, err)
})

runLoad(t, "load-fail/open", func(t *testing.T) {
err := Load("C:\\Windows\\System32\\ws2_32.dll")
require.Error(t, err)

d, err := Open("false", Network, 0, 0)
require.True(t, errors.Is(err, ErrNotLoad{}))
require.Nil(t, d)
})

runLoad(t, "load-fail/release", func(t *testing.T) {
err := Load("C:\\Windows\\System32\\ws2_32.dll")
require.NotNil(t, err)

require.NoError(t, Release())
})

runLoad(t, "load-fail/load", func(t *testing.T) {
e1 := Load("C:\\Windows\\System32\\ws2_32.dll")
require.NotNil(t, e1)
require.NoError(t, Release())

e := Load(DLL)
require.NoError(t, e)
require.NoError(t, Release())
})

runLoad(t, "load/load", func(t *testing.T) {
e1 := Load(path)
require.NoError(t, e1)

e2 := Load(DLL)
require.True(t, errors.Is(e2, ErrLoaded{}))

require.NoError(t, Release())
})

runLoad(t, "release/release", func(t *testing.T) {
require.NoError(t, Release())
require.NoError(t, Release())
})

runLoad(t, "load/release/release", func(t *testing.T) {
err := Load(DLL)
require.NoError(t, err)

require.NoError(t, Release())
require.NoError(t, Release())
})

runLoad(t, "load/open/release", func(t *testing.T) {
err := Load(DLL)
require.NoError(t, err)
defer Release()

d1, err := Open("false", Network, 0, 0)
require.NoError(t, err)
require.NoError(t, d1.Close())

require.NoError(t, Release())

_, err = d1.Recv(nil, nil)
require.True(t, errors.Is(err, ErrNotLoad{}))
})

runLoad(t, "open", func(t *testing.T) {
d, err := Open("false", Network, 0, 0)
require.Nil(t, d)
require.True(t, errors.Is(err, ErrNotLoad{}))
})

runLoad(t, "load/release/open", func(t *testing.T) {
err := Load(DLL)
require.NoError(t, err)
require.NoError(t, Release())

d, err := Open("false", Network, 0, 0)
require.Nil(t, d)
require.True(t, errors.Is(err, ErrNotLoad{}))
})
}

func runLoad(t *testing.T, name string, fn func(t *testing.T)) {
t.Run(name, func(t *testing.T) {
fn(t)
Release()
})
}

func Test_MustLoad_DLL(t *testing.T) {
runLoad(t, "embed", func(t *testing.T) {
MustLoad(DLL)
Release()

MustLoad(DLL)

MustLoad(DLL)
})

runLoad(t, "file", func(t *testing.T) {
MustLoad(path)
Release()

MustLoad(path)

MustLoad(path)
})

runLoad(t, "load-fail", func(t *testing.T) {
defer func() {
e := recover()
require.NotNil(t, e, e)
}()

MustLoad("C:\\Windows\\System32\\ws2_32.dll")
})
}

func Test_Helper(t *testing.T) {
require.NoError(t, Load(DLL))
defer Release()

t.Run("format/null", func(t *testing.T) {
d, err := Open("false", Network, 0, 0)
Expand Down
Loading

0 comments on commit ae1cd70

Please sign in to comment.