Skip to content

Commit

Permalink
Optimize code (#6)
Browse files Browse the repository at this point in the history
* optimize code

* optimize code
  • Loading branch information
lysShub authored Jul 5, 2024
1 parent e4fdb79 commit 3552ece
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 30 deletions.
34 changes: 17 additions & 17 deletions address.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,17 @@ type Address struct {
reserved3 [64]byte
}

func (a *Address) Network() *DateNetwork {
return (*DateNetwork)(unsafe.Pointer(&a.reserved3[0]))
func (a *Address) Network() *AddrNetwork {
return (*AddrNetwork)(unsafe.Pointer(&a.reserved3[0]))
}
func (a *Address) Flow() *DataFlow {
return (*DataFlow)(unsafe.Pointer(&a.reserved3[0]))
func (a *Address) Flow() *AddrFlow {
return (*AddrFlow)(unsafe.Pointer(&a.reserved3[0]))
}
func (a *Address) Socket() *DataSocket {
return (*DataSocket)(unsafe.Pointer(&a.reserved3[0]))
func (a *Address) Socket() *AddrSocket {
return (*AddrSocket)(unsafe.Pointer(&a.reserved3[0]))
}
func (a *Address) Reflect() *DataReflect {
return (*DataReflect)(unsafe.Pointer(&a.reserved3[0]))
func (a *Address) Reflect() *AddrReflect {
return (*AddrReflect)(unsafe.Pointer(&a.reserved3[0]))
}

type Flags uint8
Expand Down Expand Up @@ -135,12 +135,12 @@ func (f *Flags) SetUDPChecksum(sum bool) {
}
}

type DateNetwork struct {
type AddrNetwork struct {
IfIdx uint32 // Packet's interface index.
SubIfIdx uint32 // Packet's sub-interface index.
}

type DataFlow struct {
type AddrFlow struct {
EndpointId uint64 // Endpoint ID.
ParentEndpointId uint64 // Parent endpoint ID.
ProcessId uint32 // Process ID.
Expand All @@ -151,7 +151,7 @@ type DataFlow struct {
Protocol Proto // Protocol.
}

func (d *DataFlow) LocalAddr() netip.Addr {
func (d *AddrFlow) LocalAddr() netip.Addr {
var ip = make([]byte, 0, 16)
for i := 3; i >= 0; i-- {
ip = binary.BigEndian.AppendUint32(ip, d.localAddr[i])
Expand All @@ -164,11 +164,11 @@ func (d *DataFlow) LocalAddr() netip.Addr {
return addr
}

func (d *DataFlow) LocalAddrPort() netip.AddrPort {
func (d *AddrFlow) LocalAddrPort() netip.AddrPort {
return netip.AddrPortFrom(d.LocalAddr(), d.LocalPort)
}

func (d *DataFlow) RemoteAddr() netip.Addr {
func (d *AddrFlow) RemoteAddr() netip.Addr {
var ip = make([]byte, 0, 16)
for i := 3; i >= 0; i-- {
ip = binary.BigEndian.AppendUint32(ip, d.remoteAddr[i])
Expand All @@ -181,16 +181,16 @@ func (d *DataFlow) RemoteAddr() netip.Addr {
return addr
}

func (d *DataFlow) RemoteAddrPort() netip.AddrPort {
func (d *AddrFlow) RemoteAddrPort() netip.AddrPort {
return netip.AddrPortFrom(d.RemoteAddr(), d.RemotePort)
}

type DataSocket = DataFlow
type AddrSocket = AddrFlow

type DataReflect struct {
type AddrReflect struct {
Timestamp int64 // Handle open time.
ProcessId uint32 // Handle process ID.
Layer Layer // Handle layer.
Flags uint64 // Handle flags.
Flags Flag // Handle flags.
Priority int16 // Handle priority.
}
4 changes: 2 additions & 2 deletions const.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ const (
type Flag uint64

const (
Sniff Flag = 0x0001 // copy data, like pcap
Sniff Flag = 0x0001 // copy ip packet, like pcap
Drop Flag = 0x0002
RecvOnly Flag = 0x0004
ReadOnly Flag = RecvOnly
SendOnly Flag = 0x0008
WriteOnly Flag = SendOnly
NoInstall Flag = 0x0010
Fragments Flag = 0x0020
Fragments Flag = 0x0020 // can recv ip MF segment
)

type Event uint8
Expand Down
4 changes: 4 additions & 0 deletions divert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ func Test_Gofmt(t *testing.T) {
}

func Test_Load_DLL(t *testing.T) {
// go test -list ".*" ./...
// and go test with -run flag
t.Skip("require independent test.")

t.Run("reset-mem", func(t *testing.T) {
MustLoad("test.dll")
MustLoad(DLL)
Expand Down
13 changes: 9 additions & 4 deletions dll/mem.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
type MemLazyDll struct {
Data []byte

mu sync.Mutex
mu sync.RWMutex
dll *memmod.Module
}

Expand Down Expand Up @@ -51,7 +51,9 @@ func (d *MemLazyDll) NewProc(name string) LazyProc {
return &MemLazyProc{Name: name, l: d}
}
func (d *MemLazyDll) Loaded() bool {
return atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&d.dll))) != nil
d.mu.RLock()
defer d.mu.RUnlock()
return d.dll != nil
}

type MemLazyProc struct {
Expand Down Expand Up @@ -87,8 +89,11 @@ func (p *MemLazyProc) Find() error {
}
atomic.StoreUintptr(&p.proc, proc)
}

}
return nil
}
func (p *MemLazyProc) mustFind() {}
func (p *MemLazyProc) mustFind() {
if err := p.Find(); err != nil {
panic(err)
}
}
19 changes: 18 additions & 1 deletion dll/sys.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package dll
import (
"sync/atomic"

"github.com/pkg/errors"
"golang.org/x/sys/windows"
)

Expand All @@ -14,7 +15,10 @@ type SysLazyDll struct {
var _ LazyDll = (*SysLazyDll)(nil)

func (l *SysLazyDll) NewProc(name string) LazyProc {
return l.LazyDLL.NewProc(name)
return &sysLazyProcWraper{
LazyProc: l.LazyDLL.NewProc(name),
dll: l.LazyDLL.Name,
}
}
func (l *SysLazyDll) Load() error {
if !l.loaded.Load() {
Expand All @@ -27,3 +31,16 @@ func (l *SysLazyDll) Load() error {
return nil
}
func (l *SysLazyDll) Loaded() bool { return l.loaded.Load() }

type sysLazyProcWraper struct {
*windows.LazyProc
dll string
}

func (p *sysLazyProcWraper) Find() error {
err := p.LazyProc.Find()
if err != nil {
return errors.WithMessage(err, p.dll)
}
return nil
}
18 changes: 12 additions & 6 deletions handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package divert

import (
"sync/atomic"
"syscall"
"unsafe"

Expand All @@ -19,12 +20,17 @@ type Handle struct {
}

func (d *Handle) Close() error {
r1, _, e := syscall.SyscallN(
procClose.Addr(),
d.handle,
)
if r1 == 0 {
return handleError(e)
const invalid = uintptr(windows.InvalidHandle)

fd := atomic.SwapUintptr(&d.handle, invalid)
if fd != invalid {
r1, _, e := syscall.SyscallN(
procClose.Addr(),
d.handle,
)
if r1 == 0 {
return handleError(e)
}
}
return nil
}
Expand Down

0 comments on commit 3552ece

Please sign in to comment.