Skip to content

Commit

Permalink
refactor: More interface reuse
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Nov 11, 2024
1 parent 1015c09 commit 355e4e8
Show file tree
Hide file tree
Showing 8 changed files with 51 additions and 107 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ require (
github.com/sagernet/gvisor v0.0.0-20241021032506-a4324256e4a3
github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a
github.com/sagernet/nftables v0.3.0-beta.4
github.com/sagernet/sing v0.6.0-alpha.3
github.com/sagernet/sing v0.6.0-alpha.4
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba
golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8
golang.org/x/net v0.26.0
Expand Down
6 changes: 2 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,8 @@ github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a h1:ObwtHN2VpqE0ZN
github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a/go.mod h1:xLnfdiJbSp8rNqYEdIW/6eDO4mVoogml14Bh2hSiFpM=
github.com/sagernet/nftables v0.3.0-beta.4 h1:kbULlAwAC3jvdGAC1P5Fa3GSxVwQJibNenDW2zaXr8I=
github.com/sagernet/nftables v0.3.0-beta.4/go.mod h1:OQXAjvjNGGFxaTgVCSTRIhYB5/llyVDeapVoENYBDS8=
github.com/sagernet/sing v0.5.1-0.20241109034027-099899991126 h1:pLMpV9pEAinrS9R1n1JLcbNesCl369RfvyxnYCPrkbw=
github.com/sagernet/sing v0.5.1-0.20241109034027-099899991126/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
github.com/sagernet/sing v0.6.0-alpha.3 h1:GLp9d6Gbt+Ioeplauuzojz1nY2J6moceVGYIOv/h5gA=
github.com/sagernet/sing v0.6.0-alpha.3/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
github.com/sagernet/sing v0.6.0-alpha.4 h1:h9oshzhaY0ESPC9HERcXtT9MhK7Oyo/IWXVu1uIiw3Y=
github.com/sagernet/sing v0.6.0-alpha.4/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
Expand Down
6 changes: 1 addition & 5 deletions monitor.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package tun

import (
"net/netip"

"github.com/sagernet/sing/common/control"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/x/list"
Expand Down Expand Up @@ -31,9 +29,7 @@ type NetworkUpdateMonitor interface {
type DefaultInterfaceMonitor interface {
Start() error
Close() error
DefaultInterfaceName(destination netip.Addr) string
DefaultInterfaceIndex(destination netip.Addr) int
DefaultInterface(destination netip.Addr) (string, int)
DefaultInterface() *control.Interface
OverrideAndroidVPN() bool
AndroidVPNEnabled() bool
RegisterCallback(callback DefaultInterfaceUpdateCallback) *list.Element[DefaultInterfaceUpdateCallback]
Expand Down
14 changes: 7 additions & 7 deletions monitor_android.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,14 @@ func (m *defaultInterfaceMonitor) checkUpdate() error {
return err
}

oldInterface := m.defaultInterfaceName
oldIndex := m.defaultInterfaceIndex

m.defaultInterfaceName = link.Attrs().Name
m.defaultInterfaceIndex = link.Attrs().Index

oldInterface := m.defaultInterface.Load()
newInterface, err := m.interfaceFinder.ByIndex(link.Attrs().Index)
if err != nil {
return E.Cause(err, "find updated interface: ", link.Attrs().Name)
}
m.defaultInterface.Store(newInterface)
var event int
if oldInterface != m.defaultInterfaceName || oldIndex != m.defaultInterfaceIndex {
if oldInterface == nil || oldInterface.Name != newInterface.Name || oldInterface.Index != newInterface.Index {
event |= EventInterfaceUpdate
}
if oldVPNEnabled != m.androidVPNEnabled {
Expand Down
49 changes: 17 additions & 32 deletions monitor_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"time"

"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/control"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
"github.com/sagernet/sing/common/x/list"
Expand Down Expand Up @@ -106,12 +107,13 @@ func (m *networkUpdateMonitor) Close() error {
}

func (m *defaultInterfaceMonitor) checkUpdate() error {
var (
defaultInterface *net.Interface
err error
)
err := m.interfaceFinder.Update()
if err != nil {
return E.Cause(err, "update interfaces")
}
var defaultInterface *control.Interface
if m.underNetworkExtension {
defaultInterface, err = getDefaultInterfaceBySocket()
defaultInterface, err = m.getDefaultInterfaceBySocket()
if err != nil {
return err
}
Expand Down Expand Up @@ -144,7 +146,7 @@ func (m *defaultInterfaceMonitor) checkUpdate() error {
if ones != 0 {
continue
}
routeInterface, err := net.InterfaceByIndex(routeMessage.Index)
routeInterface, err := m.interfaceFinder.ByIndex(routeMessage.Index)
if err != nil {
return err
}
Expand All @@ -164,18 +166,20 @@ func (m *defaultInterfaceMonitor) checkUpdate() error {
if defaultInterface == nil {
return ErrNoRoute
}
oldInterface := m.defaultInterfaceName
oldIndex := m.defaultInterfaceIndex
m.defaultInterfaceIndex = defaultInterface.Index
m.defaultInterfaceName = defaultInterface.Name
if oldInterface == m.defaultInterfaceName && oldIndex == m.defaultInterfaceIndex {
oldInterface := m.defaultInterface.Load()
newInterface, err := m.interfaceFinder.ByIndex(defaultInterface.Index)
if err != nil {
return E.Cause(err, "find updated interface: ", defaultInterface.Name)
}
m.defaultInterface.Store(newInterface)
if oldInterface != nil && oldInterface.Name == newInterface.Name && oldInterface.Index == newInterface.Index {
return nil
}
m.emit(EventInterfaceUpdate)
return nil
}

func getDefaultInterfaceBySocket() (*net.Interface, error) {
func (m *defaultInterfaceMonitor) getDefaultInterfaceBySocket() (*control.Interface, error) {
socketFd, err := unix.Socket(unix.AF_INET, unix.SOCK_STREAM, 0)
if err != nil {
return nil, E.Cause(err, "create file descriptor")
Expand Down Expand Up @@ -218,24 +222,5 @@ func getDefaultInterfaceBySocket() (*net.Interface, error) {
case <-time.After(time.Second):
return nil, nil
}
interfaces, err := net.Interfaces()
if err != nil {
return nil, E.Cause(err, "net.Interfaces")
}
for _, netInterface := range interfaces {
interfaceAddrs, err := netInterface.Addrs()
if err != nil {
return nil, E.Cause(err, "net.Interfaces.Addrs")
}
for _, interfaceAddr := range interfaceAddrs {
ipNet, isIPNet := interfaceAddr.(*net.IPNet)
if !isIPNet {
continue
}
if ipNet.Contains(selectedAddr.AsSlice()) {
return &netInterface, nil
}
}
}
return nil, E.New("no interface found for address ", selectedAddr)
return m.interfaceFinder.ByAddr(selectedAddr)
}
15 changes: 8 additions & 7 deletions monitor_linux_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package tun

import (
"github.com/sagernet/netlink"
E "github.com/sagernet/sing/common/exceptions"

"golang.org/x/sys/unix"
)
Expand All @@ -24,13 +25,13 @@ func (m *defaultInterfaceMonitor) checkUpdate() error {
return err
}

oldInterface := m.defaultInterfaceName
oldIndex := m.defaultInterfaceIndex

m.defaultInterfaceName = link.Attrs().Name
m.defaultInterfaceIndex = link.Attrs().Index

if oldInterface == m.defaultInterfaceName && oldIndex == m.defaultInterfaceIndex {
oldInterface := m.defaultInterface.Load()
newInterface, err := m.interfaceFinder.ByIndex(link.Attrs().Index)
if err != nil {
return E.Cause(err, "find updated interface: ", link.Attrs().Name)
}
m.defaultInterface.Store(newInterface)
if oldInterface != nil && oldInterface.Name == newInterface.Name && oldInterface.Index == newInterface.Index {
return nil
}
m.emit(EventInterfaceUpdate)
Expand Down
50 changes: 7 additions & 43 deletions monitor_shared.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ package tun

import (
"errors"
"net/netip"
"sync"
"time"

"github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/control"
"github.com/sagernet/sing/common/logger"
"github.com/sagernet/sing/common/x/list"
Expand Down Expand Up @@ -38,8 +38,7 @@ type defaultInterfaceMonitor struct {
interfaceFinder control.InterfaceFinder
overrideAndroidVPN bool
underNetworkExtension bool
defaultInterfaceName string
defaultInterfaceIndex int
defaultInterface atomic.Pointer[control.Interface]
androidVPNEnabled bool
noRoute bool
networkMonitor NetworkUpdateMonitor
Expand All @@ -56,13 +55,12 @@ func NewDefaultInterfaceMonitor(networkMonitor NetworkUpdateMonitor, logger logg
overrideAndroidVPN: options.OverrideAndroidVPN,
underNetworkExtension: options.UnderNetworkExtension,
networkMonitor: networkMonitor,
defaultInterfaceIndex: -1,
logger: logger,
}, nil
}

func (m *defaultInterfaceMonitor) Start() error {
_ = m.checkUpdate()
m.postCheckUpdate()
m.element = m.networkMonitor.RegisterCallback(m.delayCheckUpdate)
return nil
}
Expand All @@ -76,16 +74,11 @@ func (m *defaultInterfaceMonitor) delayCheckUpdate() {
}

func (m *defaultInterfaceMonitor) postCheckUpdate() {
err := m.interfaceFinder.Update()
if err != nil {
m.logger.Error("update interfaces: ", err)
}
err = m.checkUpdate()
err := m.checkUpdate()
if errors.Is(err, ErrNoRoute) {
if !m.noRoute {
m.noRoute = true
m.defaultInterfaceName = ""
m.defaultInterfaceIndex = -1
m.defaultInterface.Store(nil)
m.emit(EventNoRoute)
}
} else if err != nil {
Expand All @@ -102,37 +95,8 @@ func (m *defaultInterfaceMonitor) Close() error {
return nil
}

func (m *defaultInterfaceMonitor) DefaultInterfaceName(destination netip.Addr) string {
for _, address := range m.interfaceFinder.Interfaces() {
for _, prefix := range address.Addresses {
if prefix.Contains(destination) {
return address.Name
}
}
}
return m.defaultInterfaceName
}

func (m *defaultInterfaceMonitor) DefaultInterfaceIndex(destination netip.Addr) int {
for _, address := range m.interfaceFinder.Interfaces() {
for _, prefix := range address.Addresses {
if prefix.Contains(destination) {
return address.Index
}
}
}
return m.defaultInterfaceIndex
}

func (m *defaultInterfaceMonitor) DefaultInterface(destination netip.Addr) (string, int) {
for _, address := range m.interfaceFinder.Interfaces() {
for _, prefix := range address.Addresses {
if prefix.Contains(destination) {
return address.Name, address.Index
}
}
}
return m.defaultInterfaceName, m.defaultInterfaceIndex
func (m *defaultInterfaceMonitor) DefaultInterface() *control.Interface {
return m.defaultInterface.Load()
}

func (m *defaultInterfaceMonitor) OverrideAndroidVPN() bool {
Expand Down
16 changes: 8 additions & 8 deletions monitor_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"sync"

"github.com/sagernet/sing-tun/internal/winipcfg"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
"github.com/sagernet/sing/common/x/list"

Expand Down Expand Up @@ -101,16 +102,15 @@ func (m *defaultInterfaceMonitor) checkUpdate() error {
return ErrNoRoute
}

oldInterface := m.defaultInterfaceName
oldIndex := m.defaultInterfaceIndex

m.defaultInterfaceName = alias
m.defaultInterfaceIndex = index

if oldInterface == m.defaultInterfaceName && oldIndex == m.defaultInterfaceIndex {
oldInterface := m.defaultInterface.Load()
newInterface, err := m.interfaceFinder.ByIndex(index)
if err != nil {
return E.Cause(err, "find updated interface: ", alias)
}
m.defaultInterface.Store(newInterface)
if oldInterface != nil && oldInterface.Name == newInterface.Name && oldInterface.Index == newInterface.Index {
return nil
}

m.emit(EventInterfaceUpdate)
return nil
}

0 comments on commit 355e4e8

Please sign in to comment.