diff --git a/common/geodata/decode_test.go b/common/geodata/decode_test.go index 2697e19f7..f210aee7a 100644 --- a/common/geodata/decode_test.go +++ b/common/geodata/decode_test.go @@ -6,18 +6,19 @@ import ( "io/fs" "os" "path/filepath" + "runtime" "testing" "github.com/p4gefau1t/trojan-go/common" "github.com/p4gefau1t/trojan-go/common/geodata" ) -const ( - geoipURL = "https://raw.githubusercontent.com/v2fly/geoip/release/geoip.dat" - geositeURL = "https://raw.githubusercontent.com/v2fly/domain-list-community/release/dlc.dat" -) - func init() { + const ( + geoipURL = "https://raw.githubusercontent.com/v2fly/geoip/release/geoip.dat" + geositeURL = "https://raw.githubusercontent.com/v2fly/domain-list-community/release/dlc.dat" + ) + wd, err := os.Getwd() common.Must(err) @@ -66,3 +67,39 @@ func TestDecodeGeoSite(t *testing.T) { t.Errorf("failed to load geosite:test, expected: %v, got: %v", expected, result) } } + +func BenchmarkLoadGeoIP(b *testing.B) { + m1 := runtime.MemStats{} + m2 := runtime.MemStats{} + + loader := geodata.GetGeodataLoader() + + runtime.ReadMemStats(&m1) + cn, _ := loader.LoadGeoIP("cn") + private, _ := loader.LoadGeoIP("private") + runtime.KeepAlive(cn) + runtime.KeepAlive(private) + runtime.ReadMemStats(&m2) + + b.ReportMetric(float64(m2.Alloc-m1.Alloc)/1024, "KiB(GeoIP-Alloc)") + b.ReportMetric(float64(m2.TotalAlloc-m1.TotalAlloc)/1024/1024, "MiB(GeoIP-TotalAlloc)") +} + +func BenchmarkLoadGeoSite(b *testing.B) { + m3 := runtime.MemStats{} + m4 := runtime.MemStats{} + + loader := geodata.GetGeodataLoader() + + runtime.ReadMemStats(&m3) + cn, _ := loader.LoadGeoSite("cn") + notcn, _ := loader.LoadGeoSite("geolocation-!cn") + private, _ := loader.LoadGeoSite("private") + runtime.KeepAlive(cn) + runtime.KeepAlive(notcn) + runtime.KeepAlive(private) + runtime.ReadMemStats(&m4) + + b.ReportMetric(float64(m4.Alloc-m3.Alloc)/1024/1024, "MiB(GeoSite-Alloc)") + b.ReportMetric(float64(m4.TotalAlloc-m3.TotalAlloc)/1024/1024, "MiB(GeoSite-TotalAlloc)") +} diff --git a/common/geodata/load.go b/common/geodata/load.go deleted file mode 100644 index 16ca509dc..000000000 --- a/common/geodata/load.go +++ /dev/null @@ -1,36 +0,0 @@ -package geodata - -import ( - "runtime" - - v2router "github.com/v2fly/v2ray-core/v4/app/router" -) - -var geoipcache GeoIPCache = make(map[string]*v2router.GeoIP) -var geositecache GeoSiteCache = make(map[string]*v2router.GeoSite) - -func LoadIP(filename, country string) ([]*v2router.CIDR, error) { - geoip, err := geoipcache.Unmarshal(filename, country) - if err != nil { - return nil, err - } - runtime.GC() - return geoip.Cidr, nil -} - -func LoadGeoIP(country string) ([]*v2router.CIDR, error) { - return LoadIP("geoip.dat", country) -} - -func LoadSite(filename, list string) ([]*v2router.Domain, error) { - geosite, err := geositecache.Unmarshal(filename, list) - if err != nil { - return nil, err - } - runtime.GC() - return geosite.Domain, nil -} - -func LoadGeoSite(list string) ([]*v2router.Domain, error) { - return LoadSite("geosite.dat", list) -} diff --git a/common/geodata/loader.go b/common/geodata/loader.go new file mode 100644 index 000000000..62838de53 --- /dev/null +++ b/common/geodata/loader.go @@ -0,0 +1,52 @@ +package geodata + +import ( + "runtime" + + v2router "github.com/v2fly/v2ray-core/v4/app/router" +) + +type geodataLoader interface { + LoadIP(filename, country string) ([]*v2router.CIDR, error) + LoadSite(filename, list string) ([]*v2router.Domain, error) + LoadGeoIP(country string) ([]*v2router.CIDR, error) + LoadGeoSite(list string) ([]*v2router.Domain, error) +} + +func GetGeodataLoader() geodataLoader { + return &geodataCache{ + make(map[string]*v2router.GeoIP), + make(map[string]*v2router.GeoSite), + } +} + +type geodataCache struct { + GeoIPCache + GeoSiteCache +} + +func (g *geodataCache) LoadIP(filename, country string) ([]*v2router.CIDR, error) { + geoip, err := g.GeoIPCache.Unmarshal(filename, country) + if err != nil { + return nil, err + } + runtime.GC() + return geoip.Cidr, nil +} + +func (g *geodataCache) LoadSite(filename, list string) ([]*v2router.Domain, error) { + geosite, err := g.GeoSiteCache.Unmarshal(filename, list) + if err != nil { + return nil, err + } + runtime.GC() + return geosite.Domain, nil +} + +func (g *geodataCache) LoadGeoIP(country string) ([]*v2router.CIDR, error) { + return g.LoadIP("geoip.dat", country) +} + +func (g *geodataCache) LoadGeoSite(list string) ([]*v2router.Domain, error) { + return g.LoadSite("geosite.dat", list) +} diff --git a/tunnel/router/client.go b/tunnel/router/client.go index 01876f267..265ecf497 100644 --- a/tunnel/router/client.go +++ b/tunnel/router/client.go @@ -4,6 +4,7 @@ import ( "context" "net" "regexp" + "runtime" "strconv" "strings" @@ -264,6 +265,11 @@ func loadCode(cfg *Config, prefix string) []codeInfo { } func NewClient(ctx context.Context, underlay tunnel.Client) (*Client, error) { + m1 := runtime.MemStats{} + m2 := runtime.MemStats{} + m3 := runtime.MemStats{} + m4 := runtime.MemStats{} + cfg := config.FromContext(ctx, Name).(*Config) var cancel context.CancelFunc ctx, cancel = context.WithCancel(ctx) @@ -304,10 +310,14 @@ func NewClient(ctx context.Context, underlay tunnel.Client) (*Client, error) { return nil, common.NewError("unknown strategy: " + cfg.Router.DomainStrategy) } + runtime.ReadMemStats(&m1) + + geodataLoader := geodata.GetGeodataLoader() + ipCode := loadCode(cfg, "geoip:") for _, c := range ipCode { code := c.code - cidrs, err := geodata.LoadGeoIP(code) + cidrs, err := geodataLoader.LoadGeoIP(code) if err != nil { log.Error(err) } else { @@ -316,6 +326,8 @@ func NewClient(ctx context.Context, underlay tunnel.Client) (*Client, error) { } } + runtime.ReadMemStats(&m2) + siteCode := loadCode(cfg, "geosite:") for _, c := range siteCode { code := c.code @@ -334,7 +346,7 @@ func NewClient(ctx context.Context, underlay tunnel.Client) (*Client, error) { continue } - domainList, err := geodata.LoadGeoSite(code) + domainList, err := geodataLoader.LoadGeoSite(code) if err != nil { log.Error(err) } else { @@ -360,6 +372,8 @@ func NewClient(ctx context.Context, underlay tunnel.Client) (*Client, error) { } } + runtime.ReadMemStats(&m3) + domainInfo := loadCode(cfg, "domain:") for _, info := range domainInfo { client.domains[info.strategy] = append(client.domains[info.strategy], &v2router.Domain{ @@ -433,5 +447,13 @@ func NewClient(ctx context.Context, underlay tunnel.Client) (*Client, error) { } log.Info("router client created") + + runtime.ReadMemStats(&m4) + + log.Debugf("GeoIP rules -> Alloc: %s; TotalAlloc: %s", common.HumanFriendlyTraffic(m2.Alloc-m1.Alloc), common.HumanFriendlyTraffic(m2.TotalAlloc-m1.TotalAlloc)) + log.Debugf("GeoSite rules -> Alloc: %s; TotalAlloc: %s", common.HumanFriendlyTraffic(m3.Alloc-m2.Alloc), common.HumanFriendlyTraffic(m3.TotalAlloc-m2.TotalAlloc)) + log.Debugf("Plaintext rules -> Alloc: %s; TotalAlloc: %s", common.HumanFriendlyTraffic(m4.Alloc-m3.Alloc), common.HumanFriendlyTraffic(m4.TotalAlloc-m3.TotalAlloc)) + log.Debugf("Total(router) -> Alloc: %s; TotalAlloc: %s", common.HumanFriendlyTraffic(m4.Alloc-m1.Alloc), common.HumanFriendlyTraffic(m4.TotalAlloc-m1.TotalAlloc)) + return client, nil }