diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 822e0ecee..57c2e5062 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -83,7 +83,7 @@ jobs: - name: Run linter uses: golangci/golangci-lint-action@v2 with: - version: v1.42.1 + version: v1.43.0 docker: name: Docker diff --git a/.golangci.toml b/.golangci.toml index 216d37157..97f84a529 100644 --- a/.golangci.toml +++ b/.golangci.toml @@ -9,4 +9,4 @@ format = "colored-line-number" [linters] enable-all = true -disable = ["gochecknoglobals", "gas", "goerr113", "exhaustivestruct"] +disable = ["ireturn", "varnamelen", "gochecknoglobals", "gas", "goerr113", "exhaustivestruct"] diff --git a/Makefile b/Makefile index 776678909..37cc4e9ac 100644 --- a/Makefile +++ b/Makefile @@ -2,7 +2,7 @@ ROOT_DIR := $(shell dirname $(realpath $(lastword $(MAKEFILE_LIST)))) IMAGE_NAME := mtg APP_NAME := $(IMAGE_NAME) -GOLANGCI_LINT_VERSION := v1.42.1 +GOLANGCI_LINT_VERSION := v1.43.0 VERSION_GO := $(shell go version) VERSION_DATE := $(shell date -Ru) diff --git a/README.md b/README.md index f271d61b6..1b8a433a1 100644 --- a/README.md +++ b/README.md @@ -196,7 +196,7 @@ go get github.com/9seconds/mtg/v2 #### Build from sources ```console -git clone https://github.com:9seconds/mtg.git +git clone https://github.com/9seconds/mtg.git cd mtg make static ``` diff --git a/essentials/conns.go b/essentials/conns.go new file mode 100644 index 000000000..432591aa4 --- /dev/null +++ b/essentials/conns.go @@ -0,0 +1,28 @@ +package essentials + +import ( + "io" + "net" +) + +// CloseableReader is a reader interface that can close its reading end. +type CloseableReader interface { + io.Reader + + CloseRead() error +} + +// CloseableWriter is a writer that can close its writing end. +type CloseableWriter interface { + io.Writer + + CloseWrite() error +} + +// Conn is an extension of net.Conn that can close its ends. This mostly +// implies TCP connections. +type Conn interface { + net.Conn + CloseableReader + CloseableWriter +} diff --git a/essentials/doc.go b/essentials/doc.go new file mode 100644 index 000000000..5021ffa54 --- /dev/null +++ b/essentials/doc.go @@ -0,0 +1,6 @@ +// This is a minimal package that contains _essentials_ of mtglib and its +// complimentary packages. This is mostly required to comply some interfaces +// between mtglib and its internals to avoid circular dependencies. +// +// This package should contain only bare minimum and mostly technical. +package essentials diff --git a/example.config.toml b/example.config.toml index 9a9c96bd8..b4862f679 100644 --- a/example.config.toml +++ b/example.config.toml @@ -30,7 +30,9 @@ concurrency = 8192 # A size of user-space buffer for TCP to use. Since we do 2 connections, # then we have tcp-buffer * (4 + 2) per each connection: read/write for # each connection + 2 copy buffers to pump the data between sockets. -tcp-buffer = "4kb" +# +# Deprecated: this setting is no longer makes any effect. +# tcp-buffer = "4kb" # Sometimes you want to enforce mtg to use some types of # IP connectivity to Telegram. We have 4 modes: @@ -174,6 +176,27 @@ urls = [ # How often do we need to update a blocklist set. update-each = "24h" +# Allowlist is an opposite to a blocklist. Only those IPs that are coming from +# subnets defined in these lists are allowed. All others will be rejected. +# +# If this feature is disabled, then there won't be any check performed by this +# validator. It is possible to combine both blocklist and whitelist. +[defense.allowlist] +# You can enable/disable this feature. +enabled = false +# This is a limiter for concurrency. In order to protect website +# from overloading, we download files in this number of threads. +download-concurrency = 2 +# A list of URLs in FireHOL format (https://iplists.firehol.org/) +# You can provider links here (starts with https:// or http://) or +# path to a local file, but in this case it should be absolute. +urls = [ + # "https://iplists.firehol.org/files/firehol_level1.netset", + # "/local.file" + +] +update-each = "24h" + # statsd statistics integration. [stats.statsd] # enabled/disabled diff --git a/go.mod b/go.mod index 90dbe910f..2ef0dbca3 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.17 require ( github.com/OneOfOne/xxhash v1.2.8 - github.com/alecthomas/kong v0.2.17 + github.com/alecthomas/kong v0.2.19 github.com/alecthomas/units v0.0.0-20210927113745-59d0afb8317a github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 github.com/babolivier/go-doh-client v0.0.0-20201028162107-a76cff4cb8b6 @@ -17,19 +17,21 @@ require ( github.com/panjf2000/ants/v2 v2.4.6 github.com/pelletier/go-toml v1.9.4 github.com/prometheus/client_golang v1.11.0 - github.com/prometheus/common v0.31.1 // indirect + github.com/prometheus/common v0.32.1 // indirect github.com/prometheus/procfs v0.7.3 // indirect - github.com/rs/zerolog v1.25.0 + github.com/rs/zerolog v1.26.0 github.com/smira/go-statsd v1.3.2 github.com/stretchr/objx v0.3.0 // indirect github.com/stretchr/testify v1.7.0 github.com/tylertreat/BoomFilters v0.0.0-20210315201527-1a82519a3e43 - golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 - golang.org/x/net v0.0.0-20211005001312-d4b1ae081e3b - golang.org/x/sys v0.0.0-20211004093028-2c5d950f24ef + golang.org/x/crypto v0.0.0-20211117183948-ae814b36b871 + golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2 // indirect + golang.org/x/sys v0.0.0-20211124211545-fe61309f8881 google.golang.org/protobuf v1.27.1 // indirect ) +require github.com/txthinking/socks5 v0.0.0-20211121111206-e03c1217a50b + require ( github.com/beorn7/perks v1.0.1 // indirect github.com/cenkalti/backoff/v4 v4.1.0 // indirect @@ -38,9 +40,12 @@ require ( github.com/gotd/ige v0.1.5 // indirect github.com/gotd/xor v0.1.1 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect + github.com/patrickmn/go-cache v2.1.0+incompatible // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_model v0.2.0 // indirect + github.com/txthinking/runnergroup v0.0.0-20210608031112-152c7c4432bf // indirect + github.com/txthinking/x v0.0.0-20210326105829-476fab902fbe // indirect go.uber.org/atomic v1.7.0 // indirect go.uber.org/multierr v1.6.0 // indirect go.uber.org/zap v1.16.0 // indirect diff --git a/go.sum b/go.sum index eaf855f97..0fdfd030d 100644 --- a/go.sum +++ b/go.sum @@ -37,8 +37,8 @@ github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym github.com/OneOfOne/xxhash v1.2.8 h1:31czK/TI9sNkxIKfaUfGlU47BAxQ0ztGgd9vPyqimf8= github.com/OneOfOne/xxhash v1.2.8/go.mod h1:eZbhyaAYD41SGSSsnmcpxVoRiQ/MPUTjUdIIOT9Um7Q= github.com/PuerkitoBio/goquery v1.6.1/go.mod h1:GsLWisAFVj4WgDibEWF4pvYnkVQBpKBKeU+7zCJoLcc= -github.com/alecthomas/kong v0.2.17 h1:URDISCI96MIgcIlQyoCAlhOmrSw6pZScBNkctg8r0W0= -github.com/alecthomas/kong v0.2.17/go.mod h1:ka3VZ8GZNPXv9Ov+j4YNLkI8mTuhXyr/0ktSlqIydQQ= +github.com/alecthomas/kong v0.2.19 h1:qBDfByO5XgWUXyNB4D6OOhGh5Z1eNOwWayDPQJFNWdc= +github.com/alecthomas/kong v0.2.19/go.mod h1:ka3VZ8GZNPXv9Ov+j4YNLkI8mTuhXyr/0ktSlqIydQQ= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= @@ -195,6 +195,8 @@ github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRW github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/panjf2000/ants/v2 v2.4.6 h1:drmj9mcygn2gawZ155dRbo+NfXEfAssjZNU1qoIb4gQ= github.com/panjf2000/ants/v2 v2.4.6/go.mod h1:f6F0NZVFsGCp5A7QW/Zj/m92atWwOkY0OIhFxRNFr4A= +github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= +github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= github.com/pelletier/go-toml v1.9.4 h1:tjENF6MfZAg8e4ZmZTeWaWiT2vXtsoO6+iuOjFhECwM= github.com/pelletier/go-toml v1.9.4/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= @@ -217,8 +219,8 @@ github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6T github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo= github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc= -github.com/prometheus/common v0.31.1 h1:d18hG4PkHnNAKNMOmFuXFaiY8Us0nird/2m60uS1AMs= -github.com/prometheus/common v0.31.1/go.mod h1:vu+V0TpY+O6vW9J44gczi3Ap/oXXR10b+M/gUGO4Hls= +github.com/prometheus/common v0.32.1 h1:hWIdL3N2HoUx3B8j3YN9mWor0qhY/NlEKZEaXxuIRh4= +github.com/prometheus/common v0.32.1/go.mod h1:vu+V0TpY+O6vW9J44gczi3Ap/oXXR10b+M/gUGO4Hls= github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= @@ -229,8 +231,8 @@ github.com/quasilyte/go-ruleguard/dsl v0.3.2/go.mod h1:KeCP03KrjuSO0H1kTuZQCWlQP github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= github.com/rs/xid v1.3.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= -github.com/rs/zerolog v1.25.0 h1:Rj7XygbUHKUlDPcVdoLyR91fJBsduXj5fRxyqIQj/II= -github.com/rs/zerolog v1.25.0/go.mod h1:7KHcEGe0QZPOm2IE4Kpb5rTh6n1h2hIgS5OOnu1rUaI= +github.com/rs/zerolog v1.26.0 h1:ORM4ibhEZeTeQlCojCK2kPz1ogAY4bGs4tD+SaAdGaE= +github.com/rs/zerolog v1.26.0/go.mod h1:yBiM87lvSqX8h0Ww4sdzNSkVYZ8dL2xjZJG1lAuGZEo= github.com/sebdah/goldie/v2 v2.5.3/go.mod h1:oZ9fp0+se1eapSRjfYbsV/0Hqhbuu3bJVvKI/NNtssI= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= @@ -250,12 +252,18 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/txthinking/runnergroup v0.0.0-20210608031112-152c7c4432bf h1:7PflaKRtU4np/epFxRXlFhlzLXZzKFrH5/I4so5Ove0= +github.com/txthinking/runnergroup v0.0.0-20210608031112-152c7c4432bf/go.mod h1:CLUSJbazqETbaR+i0YAhXBICV9TrKH93pziccMhmhpM= +github.com/txthinking/socks5 v0.0.0-20211121111206-e03c1217a50b h1:6J/38A0Xmdnjacfie0Udams7OP/GdoExyTipKwuQWjY= +github.com/txthinking/socks5 v0.0.0-20211121111206-e03c1217a50b/go.mod h1:7NloQcrxaZYKURWph5HLxVDlIwMHJXCPkeWPtpftsIg= +github.com/txthinking/x v0.0.0-20210326105829-476fab902fbe h1:gMWxZxBFRAXqoGkwkYlPX2zvyyKNWJpxOxCrjqJkm5A= +github.com/txthinking/x v0.0.0-20210326105829-476fab902fbe/go.mod h1:WgqbSEmUYSjEV3B1qmee/PpP2NYEz4bL9/+mF1ma+s4= github.com/tylertreat/BoomFilters v0.0.0-20210315201527-1a82519a3e43 h1:QEePdg0ty2r0t1+qwfZmQ4OOl/MB2UXIeJSpIZv56lg= github.com/tylertreat/BoomFilters v0.0.0-20210315201527-1a82519a3e43/go.mod h1:OYRfF6eb5wY9VRFkXJH8FFBi3plw2v+giaIu7P054pM= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= -github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +github.com/yuin/goldmark v1.4.0/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= @@ -277,8 +285,8 @@ golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 h1:7I4JAnoQBe7ZtJcBaYHi5UtiO8tQHbUSXxL+pnGRANg= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20211117183948-ae814b36b871 h1:/pEO3GD/ABYAjuakUS6xSEmmlyVS4kxBNkeA9tLJiTI= +golang.org/x/crypto v0.0.0-20211117183948-ae814b36b871/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -342,11 +350,10 @@ golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81R golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20211005001312-d4b1ae081e3b h1:SXy8Ld8oKlcogOvUAh0J5Pm5RKzgYBMMxLxt6n5XW50= -golang.org/x/net v0.0.0-20211005001312-d4b1ae081e3b/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20210805182204-aaa1db679c0d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2 h1:CIJ76btIcR3eFI5EgSo6k1qKw9KJexJuRLI9G7Hp5wE= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -400,13 +407,12 @@ golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20211004093028-2c5d950f24ef h1:fPxZ3Umkct3LZ8gK9nbk+DWDJ9fstZa2grBn+lWVKPs= -golang.org/x/sys v0.0.0-20211004093028-2c5d950f24ef/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211124211545-fe61309f8881 h1:TyHqChC80pFkXWraUUf6RuB5IqFdQieMLwwCJokV2pc= +golang.org/x/sys v0.0.0-20211124211545-fe61309f8881/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -460,8 +466,8 @@ golang.org/x/tools v0.0.0-20200729194436-6467de6f59a7/go.mod h1:njjCfa9FT2d7l9Bc golang.org/x/tools v0.0.0-20200804011535-6c149bb5ef0d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= golang.org/x/tools v0.0.0-20200820010801-b793a1359eac/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= golang.org/x/tools v0.0.0-20200825202427-b303f430e36d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= -golang.org/x/tools v0.1.5 h1:ouewzE6p+/VEB31YYnTbEJdi8pFqKp4P4n85vwo3DHA= -golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/tools v0.1.7 h1:6j8CgantCy3yc8JGBqkDLMKWqZ0RDU2g1HVgacojGWQ= +golang.org/x/tools v0.1.7/go.mod h1:LGqMHiF4EqQNHR1JncWGqT5BVaXmza+X+BDGol+dOxo= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/internal/cli/access.go b/internal/cli/access.go index 23a47ce3e..2c9a0318f 100644 --- a/internal/cli/access.go +++ b/internal/cli/access.go @@ -13,6 +13,7 @@ import ( "strings" "sync" + "github.com/9seconds/mtg/v2/essentials" "github.com/9seconds/mtg/v2/internal/config" "github.com/9seconds/mtg/v2/internal/utils" "github.com/9seconds/mtg/v2/mtglib" @@ -106,7 +107,7 @@ func (a *Access) Run(cli *CLI, version string) error { } func (a *Access) getIP(ntw mtglib.Network, protocol string) net.IP { - client := ntw.MakeHTTPClient(func(ctx context.Context, network, address string) (net.Conn, error) { + client := ntw.MakeHTTPClient(func(ctx context.Context, network, address string) (essentials.Conn, error) { return ntw.DialContext(ctx, protocol, address) // nolint: wrapcheck }) diff --git a/internal/cli/run_proxy.go b/internal/cli/run_proxy.go index a4fd81e5a..a53e1d72b 100644 --- a/internal/cli/run_proxy.go +++ b/internal/cli/run_proxy.go @@ -38,10 +38,9 @@ func makeNetwork(conf *config.Config, version string) (mtglib.Network, error) { tcpTimeout := conf.Network.Timeout.TCP.Get(network.DefaultTimeout) httpTimeout := conf.Network.Timeout.HTTP.Get(network.DefaultHTTPTimeout) dohIP := conf.Network.DOHIP.Get(net.ParseIP(network.DefaultDOHHostname)).String() - bufferSize := conf.TCPBuffer.Get(network.DefaultBufferSize) userAgent := "mtg/" + version - baseDialer, err := network.NewDefaultDialer(tcpTimeout, int(bufferSize)) + baseDialer, err := network.NewDefaultDialer(tcpTimeout, 0) if err != nil { return nil, fmt.Errorf("cannot build a default dialer: %w", err) } @@ -86,15 +85,15 @@ func makeAntiReplayCache(conf *config.Config) mtglib.AntiReplayCache { ) } -func makeIPBlocklist(conf *config.Config, logger mtglib.Logger, ntw mtglib.Network) (mtglib.IPBlocklist, error) { - if !conf.Defense.Blocklist.Enabled.Get(false) { +func makeIPBlocklist(conf config.ListConfig, logger mtglib.Logger, ntw mtglib.Network) (mtglib.IPBlocklist, error) { + if !conf.Enabled.Get(false) { return ipblocklist.NewNoop(), nil } remoteURLs := []string{} localFiles := []string{} - for _, v := range conf.Defense.Blocklist.URLs { + for _, v := range conf.URLs { if v.IsRemote() { remoteURLs = append(remoteURLs, v.String()) } else { @@ -104,7 +103,7 @@ func makeIPBlocklist(conf *config.Config, logger mtglib.Logger, ntw mtglib.Netwo firehol, err := ipblocklist.NewFirehol(logger.Named("ipblockist"), ntw, - conf.Defense.Blocklist.DownloadConcurrency.Get(1), + conf.DownloadConcurrency.Get(1), remoteURLs, localFiles) if err != nil { @@ -153,7 +152,7 @@ func makeEventStream(conf *config.Config, logger mtglib.Logger) (mtglib.EventStr return events.NewNoopStream(), nil } -func runProxy(conf *config.Config, version string) error { +func runProxy(conf *config.Config, version string) error { // nolint: funlen logger := makeLogger(conf) logger.BindJSON("configuration", conf.String()).Debug("configuration") @@ -163,11 +162,22 @@ func runProxy(conf *config.Config, version string) error { return fmt.Errorf("cannot build network: %w", err) } - blocklist, err := makeIPBlocklist(conf, logger, ntw) + blocklist, err := makeIPBlocklist(conf.Defense.Blocklist, logger, ntw) if err != nil { return fmt.Errorf("cannot build ip blocklist: %w", err) } + var whitelist mtglib.IPBlocklist + + if conf.Defense.Allowlist.Enabled.Get(false) { + whlist, err := makeIPBlocklist(conf.Defense.Allowlist, logger, ntw) + if err != nil { + return fmt.Errorf("cannot build ip blocklist: %w", err) + } + + whitelist = whlist + } + eventStream, err := makeEventStream(conf, logger) if err != nil { return fmt.Errorf("cannot build event stream: %w", err) @@ -178,14 +188,15 @@ func runProxy(conf *config.Config, version string) error { Network: ntw, AntiReplayCache: makeAntiReplayCache(conf), IPBlocklist: blocklist, + IPWhitelist: whitelist, EventStream: eventStream, Secret: conf.Secret, - BufferSize: conf.TCPBuffer.Get(mtglib.DefaultBufferSize), DomainFrontingPort: conf.DomainFrontingPort.Get(mtglib.DefaultDomainFrontingPort), PreferIP: conf.PreferIP.Get(mtglib.DefaultPreferIP), AllowFallbackOnUnknownDC: conf.AllowFallbackOnUnknownDC.Get(false), + TolerateTimeSkewness: conf.TolerateTimeSkewness.Value, } proxy, err := mtglib.NewProxy(opts) @@ -193,7 +204,7 @@ func runProxy(conf *config.Config, version string) error { return fmt.Errorf("cannot create a proxy: %w", err) } - listener, err := utils.NewListener(conf.BindTo.Get(""), int(opts.BufferSize)) + listener, err := utils.NewListener(conf.BindTo.Get(""), 0) if err != nil { return fmt.Errorf("cannot start proxy: %w", err) } diff --git a/internal/cli/simple_run.go b/internal/cli/simple_run.go index a0bb25d83..80edec36f 100644 --- a/internal/cli/simple_run.go +++ b/internal/cli/simple_run.go @@ -15,7 +15,7 @@ type SimpleRun struct { Debug bool `kong:"name='debug',short='d',help='Run in debug mode.'"` // nolint: lll Concurrency uint64 `kong:"name='concurrency',short='c',default='8192',help='Max number of concurrent connection to proxy.'"` // nolint: lll - TCPBuffer string `kong:"name='tcp-buffer',short='b',default='4KB',help='Size of TCP buffer to use.'"` // nolint: lll + TCPBuffer string `kong:"name='tcp-buffer',short='b',default='4KB',help='Deprecated and ignored'"` // nolint: lll PreferIP string `kong:"name='prefer-ip',short='i',default='prefer-ipv6',help='IP preference. By default we prefer IPv6 with fallback to IPv4.'"` // nolint: lll DomainFrontingPort uint64 `kong:"name='domain-fronting-port',short='p',default='443',help='A port to access for domain fronting.'"` // nolint: lll DOHIP net.IP `kong:"name='doh-ip',short='n',default='9.9.9.9',help='IP address of DNS-over-HTTP to use.'"` // nolint: lll @@ -38,10 +38,6 @@ func (s *SimpleRun) Run(cli *CLI, version string) error { // nolint: cyclop return fmt.Errorf("incorrect concurrency: %w", err) } - if err := conf.TCPBuffer.Set(s.TCPBuffer); err != nil { - return fmt.Errorf("incorrect tcp-buffer: %w", err) - } - if err := conf.PreferIP.Set(s.PreferIP); err != nil { return fmt.Errorf("incorrect prefer-ip: %w", err) } diff --git a/internal/config/config.go b/internal/config/config.go index c072dadba..9b69a9402 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -8,28 +8,36 @@ import ( "github.com/9seconds/mtg/v2/mtglib" ) +type Optional struct { + Enabled TypeBool `json:"enabled"` +} + +type ListConfig struct { + Optional + + DownloadConcurrency TypeConcurrency `json:"downloadConcurrency"` + URLs []TypeBlocklistURI `json:"urls"` + UpdateEach TypeDuration `json:"updateEach"` +} + type Config struct { Debug TypeBool `json:"debug"` AllowFallbackOnUnknownDC TypeBool `json:"allowFallbackOnUnknownDc"` Secret mtglib.Secret `json:"secret"` BindTo TypeHostPort `json:"bindTo"` - TCPBuffer TypeBytes `json:"tcpBuffer"` PreferIP TypePreferIP `json:"preferIp"` DomainFrontingPort TypePort `json:"domainFrontingPort"` TolerateTimeSkewness TypeDuration `json:"tolerateTimeSkewness"` Concurrency TypeConcurrency `json:"concurrency"` Defense struct { AntiReplay struct { - Enabled TypeBool `json:"enabled"` + Optional + MaxSize TypeBytes `json:"maxSize"` ErrorRate TypeErrorRate `json:"errorRate"` } `json:"antiReplay"` - Blocklist struct { - Enabled TypeBool `json:"enabled"` - DownloadConcurrency TypeConcurrency `json:"downloadConcurrency"` - URLs []TypeBlocklistURI `json:"urls"` - UpdateEach TypeDuration `json:"updateEach"` - } `json:"blocklist"` + Blocklist ListConfig `json:"blocklist"` + Allowlist ListConfig `json:"allowlist"` } `json:"defense"` Network struct { Timeout struct { @@ -42,13 +50,15 @@ type Config struct { } `json:"network"` Stats struct { StatsD struct { - Enabled TypeBool `json:"enabled"` + Optional + Address TypeHostPort `json:"address"` MetricPrefix TypeMetricPrefix `json:"metricPrefix"` TagFormat TypeStatsdTagFormat `json:"tagFormat"` } `json:"statsd"` Prometheus struct { - Enabled TypeBool `json:"enabled"` + Optional + BindTo TypeHostPort `json:"bindTo"` HTTPPath TypeHTTPPath `json:"httpPath"` MetricPrefix TypeMetricPrefix `json:"metricPrefix"` diff --git a/internal/config/parse.go b/internal/config/parse.go index a36471256..591e6bf1e 100644 --- a/internal/config/parse.go +++ b/internal/config/parse.go @@ -13,7 +13,6 @@ type tomlConfig struct { AllowFallbackOnUnknownDC bool `toml:"allow-fallback-on-unknown-dc" json:"allowFallbackOnUnknownDc,omitempty"` Secret string `toml:"secret" json:"secret"` BindTo string `toml:"bind-to" json:"bindTo"` - TCPBuffer string `toml:"tcp-buffer" json:"tcpBuffer,omitempty"` PreferIP string `toml:"prefer-ip" json:"preferIp,omitempty"` DomainFrontingPort uint `toml:"domain-fronting-port" json:"domainFrontingPort,omitempty"` TolerateTimeSkewness string `toml:"tolerate-time-skewness" json:"tolerateTimeSkewness,omitempty"` @@ -30,6 +29,12 @@ type tomlConfig struct { URLs []string `toml:"urls" json:"urls,omitempty"` UpdateEach string `toml:"update-each" json:"updateEach,omitempty"` } `toml:"blocklist" json:"blocklist,omitempty"` + Allowlist struct { + Enabled bool `toml:"enabled" json:"enabled,omitempty"` + DownloadConcurrency uint `toml:"download-concurrency" json:"downloadConcurrency,omitempty"` + URLs []string `toml:"urls" json:"urls,omitempty"` + UpdateEach string `toml:"update-each" json:"updateEach,omitempty"` + } `toml:"allowlist" json:"allowlist,omitempty"` } `toml:"defense" json:"defense,omitempty"` Network struct { Timeout struct { diff --git a/internal/testlib/mtglib_network_mock.go b/internal/testlib/mtglib_network_mock.go index 0f58a2522..97bd69a54 100644 --- a/internal/testlib/mtglib_network_mock.go +++ b/internal/testlib/mtglib_network_mock.go @@ -2,9 +2,9 @@ package testlib import ( "context" - "net" "net/http" + "github.com/9seconds/mtg/v2/essentials" "github.com/stretchr/testify/mock" ) @@ -12,19 +12,19 @@ type MtglibNetworkMock struct { mock.Mock } -func (m *MtglibNetworkMock) Dial(network, address string) (net.Conn, error) { +func (m *MtglibNetworkMock) Dial(network, address string) (essentials.Conn, error) { args := m.Called(network, address) - return args.Get(0).(net.Conn), args.Error(1) // nolint: wrapcheck + return args.Get(0).(essentials.Conn), args.Error(1) // nolint: wrapcheck } -func (m *MtglibNetworkMock) DialContext(ctx context.Context, network, address string) (net.Conn, error) { +func (m *MtglibNetworkMock) DialContext(ctx context.Context, network, address string) (essentials.Conn, error) { args := m.Called(ctx, network, address) - return args.Get(0).(net.Conn), args.Error(1) // nolint: wrapcheck + return args.Get(0).(essentials.Conn), args.Error(1) // nolint: wrapcheck } func (m *MtglibNetworkMock) MakeHTTPClient(dialFunc func(ctx context.Context, - network, address string) (net.Conn, error)) *http.Client { + network, address string) (essentials.Conn, error)) *http.Client { return m.Called(dialFunc).Get(0).(*http.Client) } diff --git a/internal/testlib/net_conn_mock.go b/internal/testlib/net_conn_mock.go index e167665f2..476fe0923 100644 --- a/internal/testlib/net_conn_mock.go +++ b/internal/testlib/net_conn_mock.go @@ -7,42 +7,50 @@ import ( "github.com/stretchr/testify/mock" ) -type NetConnMock struct { +type EssentialsConnMock struct { mock.Mock } -func (n *NetConnMock) Read(b []byte) (int, error) { +func (n *EssentialsConnMock) Read(b []byte) (int, error) { args := n.Called(b) return args.Int(0), args.Error(1) } -func (n *NetConnMock) Write(b []byte) (int, error) { +func (n *EssentialsConnMock) Write(b []byte) (int, error) { args := n.Called(b) return args.Int(0), args.Error(1) } -func (n *NetConnMock) Close() error { +func (n *EssentialsConnMock) Close() error { return n.Called().Error(0) // nolint: wrapcheck } -func (n *NetConnMock) LocalAddr() net.Addr { +func (n *EssentialsConnMock) CloseRead() error { + return n.Called().Error(0) // nolint: wrapcheck +} + +func (n *EssentialsConnMock) CloseWrite() error { + return n.Called().Error(0) // nolint: wrapcheck +} + +func (n *EssentialsConnMock) LocalAddr() net.Addr { return n.Called().Get(0).(net.Addr) } -func (n *NetConnMock) RemoteAddr() net.Addr { +func (n *EssentialsConnMock) RemoteAddr() net.Addr { return n.Called().Get(0).(net.Addr) } -func (n *NetConnMock) SetDeadline(t time.Time) error { +func (n *EssentialsConnMock) SetDeadline(t time.Time) error { return n.Called(t).Error(0) // nolint: wrapcheck } -func (n *NetConnMock) SetReadDeadline(t time.Time) error { +func (n *EssentialsConnMock) SetReadDeadline(t time.Time) error { return n.Called(t).Error(0) // nolint: wrapcheck } -func (n *NetConnMock) SetWriteDeadline(t time.Time) error { +func (n *EssentialsConnMock) SetWriteDeadline(t time.Time) error { return n.Called(t).Error(0) // nolint: wrapcheck } diff --git a/internal/utils/net_listener.go b/internal/utils/net_listener.go index c168e889f..496f51b38 100644 --- a/internal/utils/net_listener.go +++ b/internal/utils/net_listener.go @@ -9,8 +9,6 @@ import ( type Listener struct { net.Listener - - bufferSize int } func (l Listener) Accept() (net.Conn, error) { @@ -19,7 +17,7 @@ func (l Listener) Accept() (net.Conn, error) { return nil, err // nolint: wrapcheck } - if err := network.SetClientSocketOptions(conn, l.bufferSize); err != nil { + if err := network.SetClientSocketOptions(conn, 0); err != nil { conn.Close() return nil, fmt.Errorf("cannot set TCP options: %w", err) @@ -35,7 +33,6 @@ func NewListener(bindTo string, bufferSize int) (net.Listener, error) { } return Listener{ - Listener: base, - bufferSize: bufferSize, + Listener: base, }, nil } diff --git a/ipblocklist/files/http.go b/ipblocklist/files/http.go new file mode 100644 index 000000000..6a60edeb9 --- /dev/null +++ b/ipblocklist/files/http.go @@ -0,0 +1,63 @@ +package files + +import ( + "context" + "fmt" + "io" + "net/http" + "net/url" +) + +type httpFile struct { + http *http.Client + url string +} + +func (h httpFile) Open(ctx context.Context) (io.ReadCloser, error) { + request, err := http.NewRequestWithContext(ctx, http.MethodGet, h.url, nil) + if err != nil { + panic(err) + } + + response, err := h.http.Do(request) + if err != nil { + if response != nil { + io.Copy(io.Discard, response.Body) // nolint: errcheck + response.Body.Close() + } + + return nil, fmt.Errorf("cannot get url %s: %w", h.url, err) + } + + if response.StatusCode >= http.StatusBadRequest { + return nil, fmt.Errorf("unexpected status code %d", response.StatusCode) + } + + return response.Body, nil +} + +func (h httpFile) String() string { + return h.url +} + +func NewHTTP(client *http.Client, endpoint string) (File, error) { + if client == nil { + return nil, ErrBadHTTPClient + } + + parsed, err := url.Parse(endpoint) + if err != nil { + return nil, fmt.Errorf("incorrect url %s: %w", endpoint, err) + } + + switch parsed.Scheme { + case "http", "https": + default: + return nil, fmt.Errorf("unsupported url %s", endpoint) + } + + return httpFile{ + http: client, + url: endpoint, + }, nil +} diff --git a/ipblocklist/files/http_test.go b/ipblocklist/files/http_test.go new file mode 100644 index 000000000..6559969b2 --- /dev/null +++ b/ipblocklist/files/http_test.go @@ -0,0 +1,90 @@ +package files_test + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/9seconds/mtg/v2/ipblocklist/files" + "github.com/stretchr/testify/suite" +) + +type HTTPTestSuite struct { + suite.Suite + + httpClient *http.Client + httpServer *httptest.Server + ctx context.Context + ctxCancel context.CancelFunc +} + +func (suite *HTTPTestSuite) makeFile(path string) (files.File, error) { + return files.NewHTTP(suite.httpClient, suite.httpServer.URL+"/"+path) // nolint: wrapcheck +} + +func (suite *HTTPTestSuite) SetupSuite() { + mux := http.NewServeMux() + + mux.Handle("/", http.FileServer(http.Dir("testdata"))) + + suite.httpServer = httptest.NewServer(mux) + suite.httpClient = suite.httpServer.Client() +} + +func (suite *HTTPTestSuite) SetupTest() { + suite.ctx, suite.ctxCancel = context.WithCancel(context.Background()) +} + +func (suite *HTTPTestSuite) TearDownTest() { + suite.ctxCancel() + suite.httpServer.CloseClientConnections() +} + +func (suite *HTTPTestSuite) TearDownSuite() { + suite.httpServer.Close() +} + +func (suite *HTTPTestSuite) TestBadURL() { + _, err := files.NewHTTP(suite.httpClient, "sdfsdf") + suite.Error(err) +} + +func (suite *HTTPTestSuite) TestBadSchema() { + _, err := files.NewHTTP(suite.httpClient, "gopher://lala") + suite.Error(err) +} + +func (suite *HTTPTestSuite) TestNilHTTPClient() { + _, err := files.NewHTTP(nil, "") + suite.Error(err) +} + +func (suite *HTTPTestSuite) TestAbsentFile() { + file, err := suite.makeFile("absent") + suite.NoError(err) + + _, err = file.Open(suite.ctx) + suite.Error(err) +} + +func (suite *HTTPTestSuite) TestOk() { + file, err := suite.makeFile("readable") + suite.NoError(err) + + readCloser, err := file.Open(suite.ctx) + suite.NoError(err) + + defer readCloser.Close() + + data, err := io.ReadAll(readCloser) + suite.NoError(err) + suite.Equal("Hooray!", strings.TrimSpace(string(data))) +} + +func TestHTTP(t *testing.T) { + t.Parallel() + suite.Run(t, &HTTPTestSuite{}) +} diff --git a/ipblocklist/files/init.go b/ipblocklist/files/init.go new file mode 100644 index 000000000..97570afd9 --- /dev/null +++ b/ipblocklist/files/init.go @@ -0,0 +1,14 @@ +package files + +import ( + "context" + "errors" + "io" +) + +var ErrBadHTTPClient = errors.New("incorrect http client") + +type File interface { + Open(context.Context) (io.ReadCloser, error) + String() string +} diff --git a/ipblocklist/files/local.go b/ipblocklist/files/local.go new file mode 100644 index 000000000..3cd08c709 --- /dev/null +++ b/ipblocklist/files/local.go @@ -0,0 +1,30 @@ +package files + +import ( + "context" + "fmt" + "io" + "os" +) + +type localFile struct { + path string +} + +func (l localFile) Open(ctx context.Context) (io.ReadCloser, error) { + return os.Open(l.path) // nolint: wrapcheck +} + +func (l localFile) String() string { + return l.path +} + +func NewLocal(path string) (File, error) { + if stat, err := os.Stat(path); os.IsNotExist(err) || stat.IsDir() || stat.Mode().Perm()&0o400 == 0 { + return nil, fmt.Errorf("%s is not a readable file", path) + } + + return localFile{ + path: path, + }, nil +} diff --git a/ipblocklist/files/local_test.go b/ipblocklist/files/local_test.go new file mode 100644 index 000000000..f3dba384e --- /dev/null +++ b/ipblocklist/files/local_test.go @@ -0,0 +1,55 @@ +package files_test + +import ( + "context" + "io" + "path/filepath" + "strings" + "testing" + + "github.com/9seconds/mtg/v2/ipblocklist/files" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +type LocalTestSuite struct { + suite.Suite +} + +func (suite *LocalTestSuite) getLocalFile(name string) string { + return filepath.Join("testdata", name) +} + +func (suite *LocalTestSuite) TestIncorrect() { + names := []string{ + "absent", + "directory", + } + + for _, v := range names { + value := v + + suite.T().Run(v, func(t *testing.T) { + _, err := files.NewLocal(suite.getLocalFile(value)) + assert.Error(t, err) + }) + } +} + +func (suite *LocalTestSuite) TestOk() { + file, err := files.NewLocal(suite.getLocalFile("readable")) + suite.NoError(err) + + reader, err := file.Open(context.Background()) + suite.NoError(err) + + data, err := io.ReadAll(reader) + suite.NoError(err) + + suite.Equal("Hooray!", strings.TrimSpace(string(data))) +} + +func TestLocal(t *testing.T) { + t.Parallel() + suite.Run(t, &LocalTestSuite{}) +} diff --git a/ipblocklist/files/testdata/directory/.gitkeep b/ipblocklist/files/testdata/directory/.gitkeep new file mode 100644 index 000000000..e69de29bb diff --git a/ipblocklist/files/testdata/readable b/ipblocklist/files/testdata/readable new file mode 100644 index 000000000..715fcb72e --- /dev/null +++ b/ipblocklist/files/testdata/readable @@ -0,0 +1 @@ +Hooray! diff --git a/ipblocklist/firehol.go b/ipblocklist/firehol.go index 828afe072..41726a71b 100644 --- a/ipblocklist/firehol.go +++ b/ipblocklist/firehol.go @@ -4,16 +4,13 @@ import ( "bufio" "context" "fmt" - "io" "net" - "net/http" - "net/url" - "os" "regexp" "strings" "sync" "time" + "github.com/9seconds/mtg/v2/ipblocklist/files" "github.com/9seconds/mtg/v2/mtglib" "github.com/kentik/patricia" "github.com/kentik/patricia/bool_tree" @@ -41,20 +38,16 @@ var fireholRegexpComment = regexp.MustCompile(`\s*#.*?$`) // 127.0.0.1 # you can specify an IP // 10.0.0.0/8 # or cidr type Firehol struct { - ctx context.Context - ctxCancel context.CancelFunc - logger mtglib.Logger + ctx context.Context + ctxCancel context.CancelFunc + logger mtglib.Logger + updateMutex sync.RWMutex - rwMutex sync.RWMutex + blocklists []files.File - remoteURLs []string - localFiles []string - - httpClient *http.Client workerPool *ants.Pool - - treeV4 *bool_tree.TreeV4 - treeV6 *bool_tree.TreeV6 + treeV4 *bool_tree.TreeV4 + treeV6 *bool_tree.TreeV6 } // Shutdown stop a background update process. @@ -68,8 +61,8 @@ func (f *Firehol) Contains(ip net.IP) bool { return true } - f.rwMutex.RLock() - defer f.rwMutex.RUnlock() + f.updateMutex.RLock() + defer f.updateMutex.RUnlock() if ip4 := ip.To4(); ip4 != nil { return f.containsIPv4(ip4) @@ -98,22 +91,14 @@ func (f *Firehol) Run(updateEach time.Duration) { } }() - if err := f.update(); err != nil { - f.logger.WarningError("cannot update blocklist", err) - } else { - f.logger.Info("blocklist was updated") - } + f.update() for { select { case <-f.ctx.Done(): return case <-ticker.C: - if err := f.update(); err != nil { - f.logger.WarningError("cannot update blocklist", err) - } else { - f.logger.Info("blocklist was updated") - } + f.update() } } } @@ -138,121 +123,53 @@ func (f *Firehol) containsIPv6(addr net.IP) bool { return false } -func (f *Firehol) update() error { // nolint: funlen, cyclop +func (f *Firehol) update() { ctx, cancel := context.WithCancel(f.ctx) defer cancel() wg := &sync.WaitGroup{} - wg.Add(len(f.remoteURLs) + len(f.localFiles)) + wg.Add(len(f.blocklists)) treeMutex := &sync.Mutex{} v4tree := bool_tree.NewTreeV4() v6tree := bool_tree.NewTreeV6() - errorChan := make(chan error, 1) - defer close(errorChan) - - for _, v := range f.localFiles { - go func(filename string) { + for _, v := range f.blocklists { + go func(file files.File) { defer wg.Done() - if err := f.updateLocalFile(ctx, filename, treeMutex, v4tree, v6tree); err != nil { - cancel() - f.logger.BindStr("filename", filename).WarningError("cannot update", err) + logger := f.logger.BindStr("filename", file.String()) - select { - case errorChan <- err: - default: - } - } - }(v) - } + fileContent, err := file.Open(ctx) + if err != nil { + logger.WarningError("update has failed", err) - for _, v := range f.remoteURLs { - value := v - - f.workerPool.Submit(func() { // nolint: errcheck - defer wg.Done() + return + } - if err := f.updateRemoteURL(ctx, value, treeMutex, v4tree, v6tree); err != nil { - cancel() - f.logger.BindStr("url", value).WarningError("cannot update", err) + defer fileContent.Close() - select { - case errorChan <- err: - default: - } + if err := f.updateFromFile(treeMutex, v4tree, v6tree, bufio.NewScanner(fileContent)); err != nil { + logger.WarningError("update has failed", err) } - }) + }(v) } wg.Wait() - select { - case err := <-errorChan: - return fmt.Errorf("cannot update trees: %w", err) - default: - } - - f.rwMutex.Lock() - defer f.rwMutex.Unlock() + f.updateMutex.Lock() + defer f.updateMutex.Unlock() f.treeV4 = v4tree f.treeV6 = v6tree - return nil -} - -func (f *Firehol) updateLocalFile(ctx context.Context, filename string, - mutex sync.Locker, - v4tree *bool_tree.TreeV4, v6tree *bool_tree.TreeV6) error { - filefp, err := os.Open(filename) - if err != nil { - return fmt.Errorf("cannot open file: %w", err) - } - - go func(ctx context.Context, closer io.Closer) { - <-ctx.Done() - closer.Close() - }(ctx, filefp) - - defer filefp.Close() - - return f.updateTrees(mutex, filefp, v4tree, v6tree) + f.logger.Info("blocklist was updated") } -func (f *Firehol) updateRemoteURL(ctx context.Context, url string, - mutex sync.Locker, - v4tree *bool_tree.TreeV4, v6tree *bool_tree.TreeV6) error { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) - if err != nil { - return fmt.Errorf("cannot build a request: %w", err) - } - - resp, err := f.httpClient.Do(req) // nolint: bodyclose - if err != nil { - return fmt.Errorf("cannot request a remote URL %s: %w", url, err) - } - - go func(ctx context.Context, closer io.Closer) { - <-ctx.Done() - closer.Close() - }(ctx, resp.Body) - - defer func(rc io.ReadCloser) { - io.Copy(io.Discard, rc) // nolint: errcheck - rc.Close() - }(resp.Body) - - return f.updateTrees(mutex, resp.Body, v4tree, v6tree) -} - -func (f *Firehol) updateTrees(mutex sync.Locker, - reader io.Reader, +func (f *Firehol) updateFromFile(mutex sync.Locker, v4tree *bool_tree.TreeV4, - v6tree *bool_tree.TreeV6) error { - scanner := bufio.NewScanner(reader) - + v6tree *bool_tree.TreeV6, + scanner *bufio.Scanner) error { for scanner.Scan() { text := scanner.Text() text = fireholRegexpComment.ReplaceAllLiteralString(text, "") @@ -271,7 +188,7 @@ func (f *Firehol) updateTrees(mutex sync.Locker, } if scanner.Err() != nil { - return fmt.Errorf("cannot parse a response: %w", scanner.Err()) + return fmt.Errorf("cannot parse a file: %w", scanner.Err()) } return nil @@ -317,27 +234,36 @@ func (f *Firehol) updateAddToTrees(ip net.IP, cidr uint, // when it is necessary. func NewFirehol(logger mtglib.Logger, network mtglib.Network, downloadConcurrency uint, - remoteURLs []string, + urls []string, localFiles []string) (*Firehol, error) { - for _, v := range remoteURLs { - parsed, err := url.Parse(v) + blocklists := []files.File{} + + for _, v := range localFiles { + file, err := files.NewLocal(v) if err != nil { - return nil, fmt.Errorf("incorrect url %s: %w", v, err) + return nil, fmt.Errorf("cannot create a local file %s: %w", v, err) } - switch parsed.Scheme { - case "http", "https": - default: - return nil, fmt.Errorf("unsupported url %s", v) - } + blocklists = append(blocklists, file) } - for _, v := range localFiles { - if stat, err := os.Stat(v); os.IsNotExist(err) || stat.IsDir() || stat.Mode().Perm()&0o400 == 0 { - return nil, fmt.Errorf("%s is not a readable file", v) + httpClient := network.MakeHTTPClient(nil) + + for _, v := range urls { + file, err := files.NewHTTP(httpClient, v) + if err != nil { + return nil, fmt.Errorf("cannot create a HTTP file %s: %w", v, err) } + + blocklists = append(blocklists, file) } + return NewFireholFromFiles(logger, downloadConcurrency, blocklists) +} + +func NewFireholFromFiles(logger mtglib.Logger, + downloadConcurrency uint, + blocklists []files.File) (*Firehol, error) { if downloadConcurrency == 0 { downloadConcurrency = DefaultFireholDownloadConcurrency } @@ -349,11 +275,9 @@ func NewFirehol(logger mtglib.Logger, network mtglib.Network, ctx: ctx, ctxCancel: cancel, logger: logger.Named("firehol"), - httpClient: network.MakeHTTPClient(nil), treeV4: bool_tree.NewTreeV4(), treeV6: bool_tree.NewTreeV6(), workerPool: workerPool, - remoteURLs: remoteURLs, - localFiles: localFiles, + blocklists: blocklists, }, nil } diff --git a/mtglib/conns.go b/mtglib/conns.go index 8dfb3a45f..129ef52a2 100644 --- a/mtglib/conns.go +++ b/mtglib/conns.go @@ -4,12 +4,13 @@ import ( "bytes" "context" "io" - "net" "sync" + + "github.com/9seconds/mtg/v2/essentials" ) type connTraffic struct { - net.Conn + essentials.Conn streamID string stream EventStream @@ -37,7 +38,7 @@ func (c connTraffic) Write(b []byte) (int, error) { } type connRewind struct { - net.Conn + essentials.Conn active io.Reader buf bytes.Buffer @@ -58,7 +59,7 @@ func (c *connRewind) Rewind() { c.active = io.MultiReader(&c.buf, c.Conn) } -func newConnRewind(conn net.Conn) *connRewind { +func newConnRewind(conn essentials.Conn) *connRewind { rv := &connRewind{ Conn: conn, } diff --git a/mtglib/conns_internal_test.go b/mtglib/conns_internal_test.go index 8149b0d7e..ea46d7352 100644 --- a/mtglib/conns_internal_test.go +++ b/mtglib/conns_internal_test.go @@ -14,7 +14,7 @@ import ( ) type ConnRewindBaseConn struct { - testlib.NetConnMock + testlib.EssentialsConnMock readBuffer bytes.Buffer } @@ -29,13 +29,13 @@ type ConnTrafficTestSuite struct { suite.Suite eventStreamMock *EventStreamMock - connMock *testlib.NetConnMock + connMock *testlib.EssentialsConnMock conn io.ReadWriter } func (suite *ConnTrafficTestSuite) SetupTest() { suite.eventStreamMock = &EventStreamMock{} - suite.connMock = &testlib.NetConnMock{} + suite.connMock = &testlib.EssentialsConnMock{} suite.conn = connTraffic{ Conn: suite.connMock, streamID: "CONNID", diff --git a/mtglib/init.go b/mtglib/init.go index 08a4e2f6f..2f34ff719 100644 --- a/mtglib/init.go +++ b/mtglib/init.go @@ -23,6 +23,8 @@ import ( "net" "net/http" "time" + + "github.com/9seconds/mtg/v2/essentials" ) var ( @@ -61,6 +63,8 @@ const ( DefaultConcurrency = 4096 // DefaultBufferSize is a default size of a copy buffer. + // + // Deprecated: this setting no longer makes any effect. DefaultBufferSize = 16 * 1024 // 16 kib // DefaultDomainFrontingPort is a default port (HTTPS) to connect to in @@ -114,16 +118,16 @@ const ( // 3. Doing HTTP requests (for example, for FireHOL ipblocklist). type Network interface { // Dial establishes context-free TCP connections. - Dial(network, address string) (net.Conn, error) + Dial(network, address string) (essentials.Conn, error) // DialContext dials using a context. This is a preferrable // way of establishing TCP connections. - DialContext(ctx context.Context, network, address string) (net.Conn, error) + DialContext(ctx context.Context, network, address string) (essentials.Conn, error) // MakeHTTPClient build an HTTP client with given dial function. If // nothing is provided, then DialContext of this interface is going // to be used. - MakeHTTPClient(func(ctx context.Context, network, address string) (net.Conn, error)) *http.Client + MakeHTTPClient(func(ctx context.Context, network, address string) (essentials.Conn, error)) *http.Client } // AntiReplayCache is an interface that is used to detect replay attacks diff --git a/mtglib/internal/faketls/conn.go b/mtglib/internal/faketls/conn.go index d7e783767..74c802083 100644 --- a/mtglib/internal/faketls/conn.go +++ b/mtglib/internal/faketls/conn.go @@ -4,13 +4,13 @@ import ( "bytes" "fmt" "math/rand" - "net" + "github.com/9seconds/mtg/v2/essentials" "github.com/9seconds/mtg/v2/mtglib/internal/faketls/record" ) type Conn struct { - net.Conn + essentials.Conn readBuffer bytes.Buffer } diff --git a/mtglib/internal/faketls/conn_test.go b/mtglib/internal/faketls/conn_test.go index affcec0be..e7f311aa8 100644 --- a/mtglib/internal/faketls/conn_test.go +++ b/mtglib/internal/faketls/conn_test.go @@ -15,7 +15,7 @@ import ( ) type ConnMock struct { - testlib.NetConnMock + testlib.EssentialsConnMock readBuffer bytes.Buffer writeBuffer bytes.Buffer diff --git a/mtglib/internal/obfuscated2/client_handshake_test.go b/mtglib/internal/obfuscated2/client_handshake_test.go index 6d4d8c25e..c310e96dd 100644 --- a/mtglib/internal/obfuscated2/client_handshake_test.go +++ b/mtglib/internal/obfuscated2/client_handshake_test.go @@ -42,7 +42,7 @@ func (suite *ClientHandshakeTestSuite) TestOk() { writeData := make([]byte, len(snapshot.Encrypted.Text.data)) readData := make([]byte, len(snapshot.Decrypted.Text.data)) - connMock := &testlib.NetConnMock{} + connMock := &testlib.EssentialsConnMock{} connMock.On("Read", mock.Anything). Once(). Return(len(snapshot.Decrypted.Text.data), nil). diff --git a/mtglib/internal/obfuscated2/conn.go b/mtglib/internal/obfuscated2/conn.go index 511ba62f7..b6ecbf489 100644 --- a/mtglib/internal/obfuscated2/conn.go +++ b/mtglib/internal/obfuscated2/conn.go @@ -2,11 +2,12 @@ package obfuscated2 import ( "crypto/cipher" - "net" + + "github.com/9seconds/mtg/v2/essentials" ) type Conn struct { - net.Conn + essentials.Conn Encryptor cipher.Stream Decryptor cipher.Stream diff --git a/mtglib/internal/obfuscated2/server_handshake_test.go b/mtglib/internal/obfuscated2/server_handshake_test.go index 418af0c6a..89719bb48 100644 --- a/mtglib/internal/obfuscated2/server_handshake_test.go +++ b/mtglib/internal/obfuscated2/server_handshake_test.go @@ -16,7 +16,7 @@ import ( type ServerHandshakeTestSuite struct { suite.Suite - connMock *testlib.NetConnMock + connMock *testlib.EssentialsConnMock proxyConn obfuscated2.Conn encryptor cipher.Stream decryptor cipher.Stream @@ -24,7 +24,7 @@ type ServerHandshakeTestSuite struct { func (suite *ServerHandshakeTestSuite) SetupTest() { buf := &bytes.Buffer{} - suite.connMock = &testlib.NetConnMock{} + suite.connMock = &testlib.EssentialsConnMock{} encryptor, decryptor, err := obfuscated2.ServerHandshake(buf) suite.NoError(err) diff --git a/mtglib/internal/relay/conn.go b/mtglib/internal/relay/conn.go deleted file mode 100644 index cdeaff14b..000000000 --- a/mtglib/internal/relay/conn.go +++ /dev/null @@ -1,19 +0,0 @@ -package relay - -import ( - "fmt" - "net" - "time" -) - -type conn struct { - net.Conn -} - -func (c conn) Read(p []byte) (int, error) { - if err := c.SetReadDeadline(time.Now().Add(getTimeout())); err != nil { - return 0, fmt.Errorf("cannot set read deadline: %w", err) - } - - return c.Conn.Read(p) // nolint: wrapcheck -} diff --git a/mtglib/internal/relay/init.go b/mtglib/internal/relay/init.go index df1f7208b..6278c48f4 100644 --- a/mtglib/internal/relay/init.go +++ b/mtglib/internal/relay/init.go @@ -3,10 +3,9 @@ package relay import "time" const ( - ConnectionTimeToLiveMin = 2 * time.Minute - ConnectionTimeToLiveMax = 10 * time.Minute - TimeoutMin = 20 * time.Second - TimeoutMax = time.Minute + copyBufferSize = 64 * 1024 + writerBufferSize = 128 * 1024 + readTimeout = 10 * time.Millisecond ) type Logger interface { diff --git a/mtglib/internal/relay/pools.go b/mtglib/internal/relay/pools.go index 0f0a34a69..2adff99e3 100644 --- a/mtglib/internal/relay/pools.go +++ b/mtglib/internal/relay/pools.go @@ -1,32 +1,31 @@ package relay -import "sync" - -type eastWest struct { - east []byte - west []byte -} - -var eastWestPool = sync.Pool{ +import ( + "bufio" + "io" + "net" + "sync" +) + +var syncPairPool = sync.Pool{ New: func() interface{} { - return &eastWest{} + return &syncPair{ + writer: bufio.NewWriterSize(nil, writerBufferSize), + copyBuf: make([]byte, copyBufferSize), + } }, } -func acquireEastWest(bufferSize int) *eastWest { - wanted := eastWestPool.Get().(*eastWest) // nolint: forcetypeassert - - if len(wanted.east) != bufferSize { - wanted.east = make([]byte, bufferSize) - } - - if len(wanted.west) != bufferSize { - wanted.west = make([]byte, bufferSize) - } +func acquireSyncPair(reader net.Conn, writer io.Writer) *syncPair { + sp := syncPairPool.Get().(*syncPair) // nolint: forcetypeassert + sp.writer.Reset(writer) + sp.reader = reader - return wanted + return sp } -func releaseEastWest(ew *eastWest) { - eastWestPool.Put(ew) +func releaseSyncPair(sp *syncPair) { + sp.writer.Reset(nil) + sp.reader = nil + syncPairPool.Put(sp) } diff --git a/mtglib/internal/relay/relay.go b/mtglib/internal/relay/relay.go index 223f9bc05..6f9db4f3d 100644 --- a/mtglib/internal/relay/relay.go +++ b/mtglib/internal/relay/relay.go @@ -2,17 +2,18 @@ package relay import ( "context" + "errors" "io" - "net" "sync" + + "github.com/9seconds/mtg/v2/essentials" ) -func Relay(ctx context.Context, log Logger, bufferSize int, - telegramConn net.Conn, clientConn io.ReadWriteCloser) { +func Relay(ctx context.Context, log Logger, telegramConn, clientConn essentials.Conn) { defer telegramConn.Close() defer clientConn.Close() - ctx, cancel := context.WithTimeout(ctx, getConnectionTimeToLive()) + ctx, cancel := context.WithCancel(ctx) defer cancel() go func() { @@ -21,30 +22,35 @@ func Relay(ctx context.Context, log Logger, bufferSize int, clientConn.Close() }() - buffers := acquireEastWest(bufferSize) - defer releaseEastWest(buffers) - - telegramConn = conn{ - Conn: telegramConn, - } - wg := &sync.WaitGroup{} wg.Add(2) // nolint: gomnd - go pump(log, telegramConn, clientConn, wg, buffers.east, "east -> west") + go pump(log, telegramConn, clientConn, wg, "client -> telegram") - pump(log, clientConn, telegramConn, wg, buffers.west, "west -> east") + pump(log, clientConn, telegramConn, wg, "telegram -> client") wg.Wait() } -func pump(log Logger, src io.ReadCloser, dst io.WriteCloser, wg *sync.WaitGroup, - buf []byte, direction string) { - defer wg.Done() - defer src.Close() - defer dst.Close() +func pump(log Logger, src, dst essentials.Conn, wg *sync.WaitGroup, direction string) { + syncer := acquireSyncPair(src, dst) + + defer func() { + syncer.Flush() + releaseSyncPair(syncer) + src.CloseRead() // nolint: errcheck + dst.CloseWrite() // nolint: errcheck + wg.Done() + }() + + n, err := syncer.Sync() - if n, err := io.CopyBuffer(dst, src, buf); err != nil { - log.Printf("cannot pump %s (written %d bytes): %w", direction, n, err) + switch { + case err == nil: + log.Printf("%s has been finished", direction) + case errors.Is(err, io.EOF): + log.Printf("%s has been finished because of EOF. Written %d bytes", direction, n) + default: + log.Printf("%s has been finished (written %d bytes): %v", direction, n, err) } } diff --git a/mtglib/internal/relay/relay_test.go b/mtglib/internal/relay/relay_test.go index d1b53623f..368469f7d 100644 --- a/mtglib/internal/relay/relay_test.go +++ b/mtglib/internal/relay/relay_test.go @@ -17,8 +17,8 @@ type RelayTestSuite struct { loggerMock relay.Logger ctx context.Context ctxCancel context.CancelFunc - telegramConnMock *testlib.NetConnMock - clientConnMock *testlib.NetConnMock + telegramConnMock *testlib.EssentialsConnMock + clientConnMock *testlib.EssentialsConnMock } func (suite *RelayTestSuite) SetupTest() { @@ -26,8 +26,8 @@ func (suite *RelayTestSuite) SetupTest() { suite.ctx = ctx suite.ctxCancel = cancel suite.loggerMock = &loggerMock{} - suite.telegramConnMock = &testlib.NetConnMock{} - suite.clientConnMock = &testlib.NetConnMock{} + suite.telegramConnMock = &testlib.EssentialsConnMock{} + suite.clientConnMock = &testlib.EssentialsConnMock{} } func (suite *RelayTestSuite) TearDownTest() { @@ -37,17 +37,21 @@ func (suite *RelayTestSuite) TearDownTest() { } func (suite *RelayTestSuite) TestExit() { - suite.telegramConnMock.On("SetReadDeadline", mock.Anything).Return(nil) suite.telegramConnMock.On("Close").Return(nil) + suite.telegramConnMock.On("CloseRead").Return(nil).Once() + suite.telegramConnMock.On("CloseWrite").Return(nil).Once() suite.telegramConnMock.On("Read", mock.Anything).Return(10, io.EOF).Once() suite.telegramConnMock.On("Write", mock.Anything).Return(10, io.EOF).Maybe() + suite.telegramConnMock.On("SetReadDeadline", mock.Anything).Return(nil).Maybe() suite.clientConnMock.On("Read", mock.Anything).Return(0, io.EOF).Once() suite.clientConnMock.On("Write", mock.Anything).Return(10, io.EOF).Maybe() suite.clientConnMock.On("Close").Return(nil) + suite.clientConnMock.On("CloseRead").Return(nil).Once() + suite.clientConnMock.On("CloseWrite").Return(nil).Once() + suite.clientConnMock.On("SetReadDeadline", mock.Anything).Return(nil).Maybe() - relay.Relay(suite.ctx, suite.loggerMock, 1024, - suite.telegramConnMock, suite.clientConnMock) + relay.Relay(suite.ctx, suite.loggerMock, suite.telegramConnMock, suite.clientConnMock) } func TestRelay(t *testing.T) { diff --git a/mtglib/internal/relay/sync_pair.go b/mtglib/internal/relay/sync_pair.go new file mode 100644 index 000000000..035b05e39 --- /dev/null +++ b/mtglib/internal/relay/sync_pair.go @@ -0,0 +1,77 @@ +package relay + +import ( + "bufio" + "errors" + "fmt" + "io" + "net" + "os" + "sync" + "time" +) + +type syncPair struct { + writer *bufio.Writer + copyBuf []byte + + mutex sync.Mutex + reader net.Conn +} + +func (s *syncPair) Sync() (int64, error) { + return io.CopyBuffer(s, s, s.copyBuf) // nolint: wrapcheck +} + +func (s *syncPair) Read(p []byte) (int, error) { + n, err := s.readBlocking(p, false) + + // nothing has been delivered for readTimeout time. Let's flush. + if errors.Is(err, os.ErrDeadlineExceeded) { + if err := s.Flush(); err != nil { + return 0, fmt.Errorf("cannot flush writer hand-side: %w", err) + } + + return s.readBlocking(p, true) + } + + return n, err +} + +func (s *syncPair) Write(p []byte) (int, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + n, err := s.writer.Write(p) + + // optimization for a case when we have a small package and want to avoid a + // delay in readTimeout. In that case, we assume that peer has finished to + // sent a data it wants to send so we can flush without waiting for anything + // else. + if err == nil && n < copyBufferSize { + err = s.writer.Flush() + } + + return n, err // nolint: wrapcheck +} + +func (s *syncPair) Flush() error { + s.mutex.Lock() + defer s.mutex.Unlock() + + return s.writer.Flush() // nolint: wrapcheck +} + +func (s *syncPair) readBlocking(p []byte, blocking bool) (int, error) { + var deadline time.Time + + if !blocking { + deadline = time.Now().Add(readTimeout) + } + + if err := s.reader.SetReadDeadline(deadline); err != nil { + return 0, fmt.Errorf("cannot set read deadline: %w", err) + } + + return s.reader.Read(p) // nolint: wrapcheck +} diff --git a/mtglib/internal/relay/timeouts.go b/mtglib/internal/relay/timeouts.go deleted file mode 100644 index 54ec7537c..000000000 --- a/mtglib/internal/relay/timeouts.go +++ /dev/null @@ -1,22 +0,0 @@ -package relay - -import ( - "math/rand" - "time" -) - -func getConnectionTimeToLive() time.Duration { - return getTime(ConnectionTimeToLiveMin, ConnectionTimeToLiveMax) -} - -func getTimeout() time.Duration { - return getTime(TimeoutMin, TimeoutMax) -} - -func getTime(minDuration, maxDuration time.Duration) time.Duration { - minDurationInSeconds := int(minDuration.Seconds()) - maxDurationInSeconds := int(maxDuration.Seconds()) - number := minDurationInSeconds + rand.Intn(maxDurationInSeconds-minDurationInSeconds) - - return time.Duration(number) * time.Second -} diff --git a/mtglib/internal/relay/timeouts_internal_test.go b/mtglib/internal/relay/timeouts_internal_test.go deleted file mode 100644 index 47e49e4d1..000000000 --- a/mtglib/internal/relay/timeouts_internal_test.go +++ /dev/null @@ -1,37 +0,0 @@ -package relay - -import ( - "fmt" - "testing" - - "github.com/stretchr/testify/suite" -) - -type TimeoutsTestSuite struct { - suite.Suite -} - -func (suite *TimeoutsTestSuite) TestGetConnectionTimeToLive() { - for i := 0; i < 100; i++ { - value := getConnectionTimeToLive() - message := fmt.Sprintf("generated value is %v", value) - - suite.GreaterOrEqual(value, ConnectionTimeToLiveMin, message) - suite.LessOrEqual(value, ConnectionTimeToLiveMax, message) - } -} - -func (suite *TimeoutsTestSuite) TestGetTimeout() { - for i := 0; i < 100; i++ { - value := getTimeout() - message := fmt.Sprintf("generated value is %v", value) - - suite.GreaterOrEqual(value, TimeoutMin, message) - suite.LessOrEqual(value, TimeoutMax, message) - } -} - -func TestTimeouts(t *testing.T) { - t.Parallel() - suite.Run(t, &TimeoutsTestSuite{}) -} diff --git a/mtglib/internal/telegram/init.go b/mtglib/internal/telegram/init.go index 469478f0c..448120bcd 100644 --- a/mtglib/internal/telegram/init.go +++ b/mtglib/internal/telegram/init.go @@ -2,7 +2,8 @@ package telegram import ( "context" - "net" + + "github.com/9seconds/mtg/v2/essentials" ) type preferIP uint8 @@ -82,5 +83,5 @@ var ( ) type Dialer interface { - DialContext(ctx context.Context, network, address string) (net.Conn, error) + DialContext(ctx context.Context, network, address string) (essentials.Conn, error) } diff --git a/mtglib/internal/telegram/telegram.go b/mtglib/internal/telegram/telegram.go index 43d486785..395a9b7a6 100644 --- a/mtglib/internal/telegram/telegram.go +++ b/mtglib/internal/telegram/telegram.go @@ -3,8 +3,9 @@ package telegram import ( "context" "fmt" - "net" "strings" + + "github.com/9seconds/mtg/v2/essentials" ) type Telegram struct { @@ -13,7 +14,7 @@ type Telegram struct { pool addressPool } -func (t Telegram) Dial(ctx context.Context, dc int) (net.Conn, error) { +func (t Telegram) Dial(ctx context.Context, dc int) (essentials.Conn, error) { var addresses []tgAddr switch t.preferIP { @@ -28,7 +29,7 @@ func (t Telegram) Dial(ctx context.Context, dc int) (net.Conn, error) { } var ( - conn net.Conn + conn essentials.Conn err error ) diff --git a/mtglib/proxy.go b/mtglib/proxy.go index 330d2f3f3..294d55d2a 100644 --- a/mtglib/proxy.go +++ b/mtglib/proxy.go @@ -9,6 +9,7 @@ import ( "sync" "time" + "github.com/9seconds/mtg/v2/essentials" "github.com/9seconds/mtg/v2/mtglib/internal/faketls" "github.com/9seconds/mtg/v2/mtglib/internal/faketls/record" "github.com/9seconds/mtg/v2/mtglib/internal/obfuscated2" @@ -25,7 +26,6 @@ type Proxy struct { allowFallbackOnUnknownDC bool tolerateTimeSkewness time.Duration - bufferSize int domainFrontingPort int workerPool *ants.PoolWithFunc telegram *telegram.Telegram @@ -33,7 +33,8 @@ type Proxy struct { secret Secret network Network antiReplayCache AntiReplayCache - ipBlocklist IPBlocklist + blocklist IPBlocklist + whitelist IPBlocklist eventStream EventStream logger Logger } @@ -45,7 +46,7 @@ func (p *Proxy) DomainFrontingAddress() string { // ServeConn serves a connection. We do not check IP blocklist and // concurrency limit here. -func (p *Proxy) ServeConn(conn net.Conn) { +func (p *Proxy) ServeConn(conn essentials.Conn) { p.streamWaitGroup.Add(1) defer p.streamWaitGroup.Done() @@ -84,14 +85,13 @@ func (p *Proxy) ServeConn(conn net.Conn) { relay.Relay( ctx, ctx.logger.Named("relay"), - p.bufferSize, ctx.telegramConn, ctx.clientConn, ) } // Serve starts a proxy on a given listener. -func (p *Proxy) Serve(listener net.Listener) error { +func (p *Proxy) Serve(listener net.Listener) error { // nolint: cyclop p.streamWaitGroup.Add(1) defer p.streamWaitGroup.Done() @@ -109,7 +109,15 @@ func (p *Proxy) Serve(listener net.Listener) error { ipAddr := conn.RemoteAddr().(*net.TCPAddr).IP logger := p.logger.BindStr("ip", ipAddr.String()) - if p.ipBlocklist.Contains(ipAddr) { + if p.whitelist != nil && !p.whitelist.Contains(ipAddr) { + conn.Close() + logger.Info("ip was rejected by whitelist") + p.eventStream.Send(p.ctx, NewEventIPBlocklisted(ipAddr)) + + continue + } + + if p.blocklist.Contains(ipAddr) { conn.Close() logger.Info("ip was blacklisted") p.eventStream.Send(p.ctx, NewEventIPBlocklisted(ipAddr)) @@ -267,7 +275,6 @@ func (p *Proxy) doDomainFronting(ctx *streamContext, conn *connRewind) { relay.Relay( ctx, ctx.logger.Named("domain-fronting"), - p.bufferSize, frontConn, conn, ) @@ -291,19 +298,19 @@ func NewProxy(opts ProxyOpts) (*Proxy, error) { secret: opts.Secret, network: opts.Network, antiReplayCache: opts.AntiReplayCache, - ipBlocklist: opts.IPBlocklist, + blocklist: opts.IPBlocklist, + whitelist: opts.IPWhitelist, eventStream: opts.EventStream, logger: opts.getLogger("proxy"), domainFrontingPort: opts.getDomainFrontingPort(), tolerateTimeSkewness: opts.getTolerateTimeSkewness(), - bufferSize: opts.getBufferSize(), allowFallbackOnUnknownDC: opts.AllowFallbackOnUnknownDC, telegram: tg, } pool, err := ants.NewPoolWithFunc(opts.getConcurrency(), func(arg interface{}) { - proxy.ServeConn(arg.(net.Conn)) + proxy.ServeConn(arg.(essentials.Conn)) }, ants.WithLogger(opts.getLogger("ants")), ants.WithNonblocking(true)) diff --git a/mtglib/proxy_opts.go b/mtglib/proxy_opts.go index 8993de728..faad2100d 100644 --- a/mtglib/proxy_opts.go +++ b/mtglib/proxy_opts.go @@ -28,6 +28,11 @@ type ProxyOpts struct { // This is a mandatory setting. IPBlocklist IPBlocklist + // IPWhitelist defines a whitelist of IPs to allow to use proxy. + // + // This is an optional setting, ignored by default (no restrictions). + IPWhitelist IPBlocklist + // EventStream defines an instance of event stream. // // This ia a mandatory setting. @@ -45,6 +50,8 @@ type ProxyOpts struct { // buffers: to and from. // // This is an optional setting. + // + // Deprecated: this setting is no longer makes any effect. BufferSize uint // Concurrency is a size of the worker pool for connection management. @@ -129,14 +136,6 @@ func (p ProxyOpts) valid() error { return nil } -func (p ProxyOpts) getBufferSize() int { - if p.BufferSize < 1 { - return DefaultBufferSize - } - - return int(p.BufferSize) -} - func (p ProxyOpts) getConcurrency() int { if p.Concurrency == 0 { return DefaultConcurrency diff --git a/mtglib/stream_context.go b/mtglib/stream_context.go index 7ee8071f9..e1f231971 100644 --- a/mtglib/stream_context.go +++ b/mtglib/stream_context.go @@ -6,13 +6,15 @@ import ( "encoding/base64" "net" "time" + + "github.com/9seconds/mtg/v2/essentials" ) type streamContext struct { ctx context.Context ctxCancel context.CancelFunc - clientConn net.Conn - telegramConn net.Conn + clientConn essentials.Conn + telegramConn essentials.Conn streamID string dc int logger Logger @@ -50,7 +52,7 @@ func (s *streamContext) ClientIP() net.IP { return s.clientConn.RemoteAddr().(*net.TCPAddr).IP } -func newStreamContext(ctx context.Context, logger Logger, clientConn net.Conn) *streamContext { +func newStreamContext(ctx context.Context, logger Logger, clientConn essentials.Conn) *streamContext { connIDBytes := make([]byte, ConnectionIDBytesLength) if _, err := rand.Read(connIDBytes); err != nil { diff --git a/mtglib/stream_context_internal_test.go b/mtglib/stream_context_internal_test.go index 17f90ea98..52b5d4a08 100644 --- a/mtglib/stream_context_internal_test.go +++ b/mtglib/stream_context_internal_test.go @@ -12,7 +12,7 @@ import ( type StreamContextTestSuite struct { suite.Suite - connMock *testlib.NetConnMock + connMock *testlib.EssentialsConnMock logger NoopLogger ctx *streamContext ctxCancel context.CancelFunc @@ -27,7 +27,7 @@ func (suite *StreamContextTestSuite) SetupTest() { ctx = context.WithValue(ctx, "key", "value") // nolint: golint, revive, staticcheck suite.ctxCancel = cancel - suite.connMock = &testlib.NetConnMock{} + suite.connMock = &testlib.EssentialsConnMock{} addr := &net.TCPAddr{ IP: net.ParseIP("10.0.0.10"), @@ -73,7 +73,7 @@ func (suite *StreamContextTestSuite) TestClientIP() { func (suite *StreamContextTestSuite) TestClose() { suite.connMock.On("Close").Once().Return(nil) - tgConnMock := &testlib.NetConnMock{} + tgConnMock := &testlib.EssentialsConnMock{} tgConnMock.On("Close").Once().Return(nil) suite.ctx.telegramConn = tgConnMock diff --git a/network/circuit_breaker.go b/network/circuit_breaker.go index 4c74a9b93..91e131774 100644 --- a/network/circuit_breaker.go +++ b/network/circuit_breaker.go @@ -2,9 +2,10 @@ package network import ( "context" - "net" "sync/atomic" "time" + + "github.com/9seconds/mtg/v2/essentials" ) const ( @@ -30,12 +31,12 @@ type circuitBreakerDialer struct { resetFailuresTimeout time.Duration } -func (c *circuitBreakerDialer) Dial(network, address string) (net.Conn, error) { +func (c *circuitBreakerDialer) Dial(network, address string) (essentials.Conn, error) { return c.DialContext(context.Background(), network, address) } func (c *circuitBreakerDialer) DialContext(ctx context.Context, - network, address string) (net.Conn, error) { + network, address string) (essentials.Conn, error) { switch atomic.LoadUint32(&c.state) { case circuitBreakerStateClosed: return c.doClosed(ctx, network, address) @@ -47,7 +48,7 @@ func (c *circuitBreakerDialer) DialContext(ctx context.Context, } func (c *circuitBreakerDialer) doClosed(ctx context.Context, - network, address string) (net.Conn, error) { + network, address string) (essentials.Conn, error) { conn, err := c.Dialer.DialContext(ctx, network, address) select { @@ -78,7 +79,8 @@ func (c *circuitBreakerDialer) doClosed(ctx context.Context, return conn, err // nolint: wrapcheck } -func (c *circuitBreakerDialer) doHalfOpened(ctx context.Context, network, address string) (net.Conn, error) { +func (c *circuitBreakerDialer) doHalfOpened(ctx context.Context, + network, address string) (essentials.Conn, error) { if !atomic.CompareAndSwapUint32(&c.halfOpenAttempts, 0, 1) { return nil, ErrCircuitBreakerOpened } diff --git a/network/circuit_breaker_internal_test.go b/network/circuit_breaker_internal_test.go index a26f5946f..d300d684a 100644 --- a/network/circuit_breaker_internal_test.go +++ b/network/circuit_breaker_internal_test.go @@ -21,7 +21,7 @@ type CircuitBreakerTestSuite struct { mutex sync.Mutex ctx context.Context ctxCancel context.CancelFunc - connMock *testlib.NetConnMock + connMock *testlib.EssentialsConnMock baseDialerMock *DialerMock } @@ -29,7 +29,7 @@ func (suite *CircuitBreakerTestSuite) SetupTest() { suite.mutex = sync.Mutex{} suite.ctx, suite.ctxCancel = context.WithCancel(context.Background()) suite.baseDialerMock = &DialerMock{} - suite.connMock = &testlib.NetConnMock{} + suite.connMock = &testlib.EssentialsConnMock{} suite.d = newCircuitBreakerDialer(suite.baseDialerMock, 3, 100*time.Millisecond, 50*time.Millisecond) } diff --git a/network/default.go b/network/default.go index e976daa48..50855d989 100644 --- a/network/default.go +++ b/network/default.go @@ -5,19 +5,19 @@ import ( "fmt" "net" "time" + + "github.com/9seconds/mtg/v2/essentials" ) type defaultDialer struct { net.Dialer - - bufferSize int } -func (d *defaultDialer) Dial(network, address string) (net.Conn, error) { +func (d *defaultDialer) Dial(network, address string) (essentials.Conn, error) { return d.DialContext(context.Background(), network, address) } -func (d *defaultDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { +func (d *defaultDialer) DialContext(ctx context.Context, network, address string) (essentials.Conn, error) { switch network { case "tcp", "tcp4", "tcp6": // nolint: goconst default: @@ -30,13 +30,13 @@ func (d *defaultDialer) DialContext(ctx context.Context, network, address string } // we do not need to call to end user. End users call us. - if err := SetServerSocketOptions(conn, d.bufferSize); err != nil { + if err := SetServerSocketOptions(conn, 0); err != nil { conn.Close() return nil, fmt.Errorf("cannot set socket options: %w", err) } - return conn, nil + return conn.(essentials.Conn), nil } // NewDefaultDialer build a new dialer which dials bypassing proxies @@ -44,26 +44,20 @@ func (d *defaultDialer) DialContext(ctx context.Context, network, address string // // The most default one you can imagine. But it has tunes TCP // connections and setups SO_REUSEPORT. +// +// bufferSize is deprecated and ignored. It is kept here for backward +// compatibility. func NewDefaultDialer(timeout time.Duration, bufferSize int) (Dialer, error) { switch { case timeout < 0: return nil, fmt.Errorf("timeout %v should be positive number", timeout) - case bufferSize < 0: - return nil, fmt.Errorf("buffer size %d should be positive number", bufferSize) - } - - if timeout == 0 { + case timeout == 0: timeout = DefaultTimeout } - if bufferSize == 0 { - bufferSize = DefaultBufferSize - } - return &defaultDialer{ Dialer: net.Dialer{ Timeout: timeout, }, - bufferSize: bufferSize, }, nil } diff --git a/network/default_test.go b/network/default_test.go index a6b24c7ee..5a3802647 100644 --- a/network/default_test.go +++ b/network/default_test.go @@ -30,11 +30,6 @@ func (suite *DefaultDialerTestSuite) TestNegativeTimeout() { suite.Error(err) } -func (suite *DefaultDialerTestSuite) TestNegativeBufferSize() { - _, err := network.NewDefaultDialer(0, -1) - suite.Error(err) -} - func (suite *DefaultDialerTestSuite) TestUnsupportedProtocol() { _, err := suite.d.DialContext(context.Background(), "udp", diff --git a/network/dns_resolver.go b/network/dns_resolver.go index 23a4c3732..424c7ae70 100644 --- a/network/dns_resolver.go +++ b/network/dns_resolver.go @@ -1,6 +1,8 @@ package network import ( + "fmt" + "net" "net/http" "sync" "time" @@ -83,8 +85,13 @@ func (d *dnsResolver) LookupAAAA(hostname string) []string { return ips } -func newDNSResolver(hostname string, httpClient *http.Client) *dnsResolver { - return &dnsResolver{ +func newDNSResolver(hostname string, httpClient *http.Client) (ret *dnsResolver) { + if net.ParseIP(hostname).To4() == nil { + // the hostname is an IPv6 address + hostname = fmt.Sprintf("[%s]", hostname) + } + + ret = &dnsResolver{ resolver: doh.Resolver{ Host: hostname, Class: doh.IN, @@ -92,4 +99,6 @@ func newDNSResolver(hostname string, httpClient *http.Client) *dnsResolver { }, cache: map[string]dnsResolverCacheEntry{}, } + + return } diff --git a/network/init.go b/network/init.go index c12c7a0de..a1426b3ad 100644 --- a/network/init.go +++ b/network/init.go @@ -20,8 +20,9 @@ package network import ( "context" "errors" - "net" "time" + + "github.com/9seconds/mtg/v2/essentials" ) const ( @@ -33,10 +34,16 @@ const ( // request. DefaultHTTPTimeout = 10 * time.Second + // Deprecated: + // // DefaultBufferSize defines a TCP buffer size. Both read and write, so // for real size, please multiply this number by 2. DefaultBufferSize = 16 * 1024 // 16 kib + // DefaultTCPKeepAlivePeriod defines a time period between 2 + // consequitive probes. + DefaultTCPKeepAlivePeriod = 10 * time.Second + // ProxyDialerOpenThreshold is used for load balancing SOCKS5 dialer // only. // @@ -89,6 +96,6 @@ var ( // Dialer defines an interface which is required to bootstrap a network // instance from. type Dialer interface { - Dial(network, address string) (net.Conn, error) - DialContext(ctx context.Context, network, address string) (net.Conn, error) + Dial(network, address string) (essentials.Conn, error) + DialContext(ctx context.Context, network, address string) (essentials.Conn, error) } diff --git a/network/init_internal_test.go b/network/init_internal_test.go index 365a4e563..6e63532fd 100644 --- a/network/init_internal_test.go +++ b/network/init_internal_test.go @@ -2,8 +2,8 @@ package network import ( "context" - "net" + "github.com/9seconds/mtg/v2/essentials" "github.com/stretchr/testify/mock" ) @@ -11,14 +11,14 @@ type DialerMock struct { mock.Mock } -func (d *DialerMock) Dial(network, address string) (net.Conn, error) { +func (d *DialerMock) Dial(network, address string) (essentials.Conn, error) { args := d.Called(network, address) - return args.Get(0).(net.Conn), args.Error(1) // nolint: wrapcheck + return args.Get(0).(essentials.Conn), args.Error(1) // nolint: wrapcheck } -func (d *DialerMock) DialContext(ctx context.Context, network, address string) (net.Conn, error) { +func (d *DialerMock) DialContext(ctx context.Context, network, address string) (essentials.Conn, error) { args := d.Called(ctx, network, address) - return args.Get(0).(net.Conn), args.Error(1) // nolint: wrapcheck + return args.Get(0).(essentials.Conn), args.Error(1) // nolint: wrapcheck } diff --git a/network/init_test.go b/network/init_test.go index f1692e035..6e79a48db 100644 --- a/network/init_test.go +++ b/network/init_test.go @@ -8,6 +8,7 @@ import ( "net/url" "strings" + "github.com/9seconds/mtg/v2/essentials" "github.com/9seconds/mtg/v2/network" socks5 "github.com/armon/go-socks5" "github.com/mccutchen/go-httpbin/httpbin" @@ -18,16 +19,16 @@ type DialerMock struct { mock.Mock } -func (d *DialerMock) Dial(network, address string) (net.Conn, error) { +func (d *DialerMock) Dial(network, address string) (essentials.Conn, error) { args := d.Called(network, address) - return args.Get(0).(net.Conn), args.Error(1) // nolint: wrapcheck + return args.Get(0).(essentials.Conn), args.Error(1) // nolint: wrapcheck } -func (d *DialerMock) DialContext(ctx context.Context, network, address string) (net.Conn, error) { +func (d *DialerMock) DialContext(ctx context.Context, network, address string) (essentials.Conn, error) { args := d.Called(ctx, network, address) - return args.Get(0).(net.Conn), args.Error(1) // nolint: wrapcheck + return args.Get(0).(essentials.Conn), args.Error(1) // nolint: wrapcheck } type HTTPServerTestSuite struct { @@ -53,7 +54,9 @@ func (suite *HTTPServerTestSuite) MakeURL(path string) string { func (suite *HTTPServerTestSuite) MakeHTTPClient(dialer network.Dialer) *http.Client { return &http.Client{ Transport: &http.Transport{ - DialContext: dialer.DialContext, + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return dialer.DialContext(ctx, network, address) // nolint: wrapcheck + }, }, } } diff --git a/network/load_balanced_socks5.go b/network/load_balanced_socks5.go index a52004db5..37cbf28e8 100644 --- a/network/load_balanced_socks5.go +++ b/network/load_balanced_socks5.go @@ -4,19 +4,20 @@ import ( "context" "fmt" "math/rand" - "net" "net/url" + + "github.com/9seconds/mtg/v2/essentials" ) type loadBalancedSocks5Dialer struct { dialers []Dialer } -func (l loadBalancedSocks5Dialer) Dial(network, address string) (net.Conn, error) { +func (l loadBalancedSocks5Dialer) Dial(network, address string) (essentials.Conn, error) { return l.DialContext(context.Background(), network, address) } -func (l loadBalancedSocks5Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { +func (l loadBalancedSocks5Dialer) DialContext(ctx context.Context, network, address string) (essentials.Conn, error) { length := len(l.dialers) start := rand.Intn(length) moved := false diff --git a/network/network.go b/network/network.go index cac18fbef..e42508fc2 100644 --- a/network/network.go +++ b/network/network.go @@ -9,6 +9,7 @@ import ( "sync" "time" + "github.com/9seconds/mtg/v2/essentials" "github.com/9seconds/mtg/v2/mtglib" ) @@ -30,11 +31,11 @@ type network struct { dns *dnsResolver } -func (n *network) Dial(protocol, address string) (net.Conn, error) { +func (n *network) Dial(protocol, address string) (essentials.Conn, error) { return n.DialContext(context.Background(), protocol, address) } -func (n *network) DialContext(ctx context.Context, protocol, address string) (net.Conn, error) { +func (n *network) DialContext(ctx context.Context, protocol, address string) (essentials.Conn, error) { host, port, _ := net.SplitHostPort(address) ips, err := n.dnsResolve(protocol, host) @@ -46,7 +47,8 @@ func (n *network) DialContext(ctx context.Context, protocol, address string) (ne ips[i], ips[j] = ips[j], ips[i] }) - var conn net.Conn + var conn essentials.Conn + for _, v := range ips { conn, err = n.dialer.DialContext(ctx, protocol, net.JoinHostPort(v, port)) @@ -59,7 +61,7 @@ func (n *network) DialContext(ctx context.Context, protocol, address string) (ne } func (n *network) MakeHTTPClient(dialFunc func(ctx context.Context, - network, address string) (net.Conn, error)) *http.Client { + network, address string) (essentials.Conn, error)) *http.Client { if dialFunc == nil { dialFunc = n.DialContext } @@ -144,13 +146,15 @@ func NewNetwork(dialer Dialer, func makeHTTPClient(userAgent string, timeout time.Duration, - dialFunc func(ctx context.Context, network, address string) (net.Conn, error)) *http.Client { + dialFunc func(ctx context.Context, network, address string) (essentials.Conn, error)) *http.Client { return &http.Client{ Timeout: timeout, Transport: networkHTTPTransport{ userAgent: userAgent, next: &http.Transport{ - DialContext: dialFunc, + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return dialFunc(ctx, network, address) + }, }, }, } diff --git a/network/sockopts.go b/network/sockopts.go index 241b0340a..da99d0974 100644 --- a/network/sockopts.go +++ b/network/sockopts.go @@ -7,39 +7,21 @@ import ( // SetClientSocketOptions tunes a TCP socket that represents a connection to // end user (not Telegram service or fronting domain). +// +// bufferSize setting is deprecated and ignored. func SetClientSocketOptions(conn net.Conn, bufferSize int) error { - tcpConn := conn.(*net.TCPConn) // nolint: forcetypeassert - - if err := tcpConn.SetNoDelay(false); err != nil { - return fmt.Errorf("cannot disable TCP_NO_DELAY: %w", err) - } - - return setCommonSocketOptions(tcpConn, bufferSize) + return setCommonSocketOptions(conn.(*net.TCPConn)) } // SetServerSocketOptions tunes a TCP socket that represents a connection to // remote server like Telegram or fronting domain (but not end user). func SetServerSocketOptions(conn net.Conn, bufferSize int) error { - tcpConn := conn.(*net.TCPConn) // nolint: forcetypeassert - - if err := tcpConn.SetNoDelay(true); err != nil { - return fmt.Errorf("cannot enable TCP_NO_DELAY: %w", err) - } - - return setCommonSocketOptions(tcpConn, bufferSize) + return setCommonSocketOptions(conn.(*net.TCPConn)) } -func setCommonSocketOptions(conn *net.TCPConn, bufferSize int) error { - if err := conn.SetReadBuffer(bufferSize); err != nil { - return fmt.Errorf("cannot set read buffer size: %w", err) - } - - if err := conn.SetWriteBuffer(bufferSize); err != nil { - return fmt.Errorf("cannot set write buffer size: %w", err) - } - - if err := conn.SetKeepAlive(false); err != nil { - return fmt.Errorf("cannot disable TCP keepalive probes: %w", err) +func setCommonSocketOptions(conn *net.TCPConn) error { + if err := conn.SetKeepAlivePeriod(DefaultTCPKeepAlivePeriod); err != nil { + return fmt.Errorf("cannot set time period of TCP keepalive probes: %w", err) } if err := conn.SetLinger(tcpLingerTimeout); err != nil { @@ -51,7 +33,7 @@ func setCommonSocketOptions(conn *net.TCPConn, bufferSize int) error { return fmt.Errorf("cannot get underlying raw connection: %w", err) } - if err := setSocketReuseAddrPort(rawConn, bufferSize); err != nil { + if err := setSocketReuseAddrPort(rawConn); err != nil { return fmt.Errorf("cannot setup SO_REUSEADDR/PORT: %w", err) } diff --git a/network/sockopts_unix.go b/network/sockopts_unix.go index b45b95def..b7c5f10f0 100644 --- a/network/sockopts_unix.go +++ b/network/sockopts_unix.go @@ -10,7 +10,7 @@ import ( "golang.org/x/sys/unix" ) -func setSocketReuseAddrPort(conn syscall.RawConn, bufferSize int) error { +func setSocketReuseAddrPort(conn syscall.RawConn) error { var err error conn.Control(func(fd uintptr) { // nolint: errcheck diff --git a/network/sockopts_windows.go b/network/sockopts_windows.go index fbeae8349..32a702a3b 100644 --- a/network/sockopts_windows.go +++ b/network/sockopts_windows.go @@ -5,6 +5,6 @@ package network import "syscall" -func setSocketReuseAddrPort(conn syscall.RawConn, bufferSize int) error { +func setSocketReuseAddrPort(conn syscall.RawConn) error { return nil } diff --git a/network/socks5.go b/network/socks5.go index 43aac4a9b..459f64bc0 100644 --- a/network/socks5.go +++ b/network/socks5.go @@ -1,21 +1,162 @@ package network import ( + "context" "fmt" + "io" + "net" "net/url" - "golang.org/x/net/proxy" + "github.com/9seconds/mtg/v2/essentials" + "github.com/txthinking/socks5" ) +type socks5Dialer struct { + Dialer + + username []byte + password []byte + proxyAddress string +} + +func (s socks5Dialer) Dial(network, address string) (essentials.Conn, error) { + return s.DialContext(context.Background(), network, address) +} + +func (s socks5Dialer) DialContext(ctx context.Context, network, address string) (essentials.Conn, error) { + switch network { + case "tcp", "tcp4", "tcp6": + default: + return nil, fmt.Errorf("%s network type is not supported", network) + } + + conn, err := s.Dialer.DialContext(ctx, network, s.proxyAddress) + if err != nil { + return nil, fmt.Errorf("cannot dial to the proxy: %w", err) + } + + if err := s.handshake(conn); err != nil { + conn.Close() + + return nil, fmt.Errorf("cannot perform a handshake: %w", err) + } + + if err := s.connect(conn, address); err != nil { + conn.Close() + + return nil, fmt.Errorf("cannot connect to a destination host %s: %w", address, err) + } + + return conn, nil +} + +func (s socks5Dialer) handshake(conn io.ReadWriter) error { + authMethod := socks5.MethodUsernamePassword + if len(s.username)+len(s.password) == 0 { + authMethod = socks5.MethodNone + } + + if err := s.handshakeNegotiation(conn, authMethod); err != nil { + return fmt.Errorf("cannot perform negotiation: %w", err) + } + + if authMethod == socks5.MethodNone { + return nil + } + + if err := s.handshakeAuth(conn); err != nil { + return fmt.Errorf("cannot authenticate: %w", err) + } + + return nil +} + +func (s socks5Dialer) handshakeNegotiation(conn io.ReadWriter, authMethod byte) error { + request := socks5.NewNegotiationRequest([]byte{authMethod}) + if _, err := request.WriteTo(conn); err != nil { + return fmt.Errorf("cannot send request: %w", err) + } + + response, err := socks5.NewNegotiationReplyFrom(conn) + if err != nil { + return fmt.Errorf("cannot read response: %w", err) + } + + if response.Method != authMethod { + return fmt.Errorf("%v is unsupported auth method", authMethod) + } + + return nil +} + +func (s socks5Dialer) handshakeAuth(conn io.ReadWriter) error { + request := socks5.NewUserPassNegotiationRequest(s.username, s.password) + + if _, err := request.WriteTo(conn); err != nil { + return fmt.Errorf("cannot send a request: %w", err) + } + + response, err := socks5.NewUserPassNegotiationReplyFrom(conn) + if err != nil { + return fmt.Errorf("cannot read a response: %w", err) + } + + if response.Status != socks5.UserPassStatusSuccess { + return fmt.Errorf("authenticate has failed: %v", response.Status) + } + + return nil +} + +func (s socks5Dialer) connect(conn io.ReadWriter, address string) error { + addrType, host, port, err := socks5.ParseAddress(address) + if err != nil { + return fmt.Errorf("cannot parse address: %w", err) + } + + if addrType == socks5.ATYPDomain { + host = host[1:] + } + + request := socks5.NewRequest(socks5.CmdConnect, addrType, host, port) + + if _, err := request.WriteTo(conn); err != nil { + return fmt.Errorf("cannot send a request: %w", err) + } + + response, err := socks5.NewReplyFrom(conn) + if err != nil { + return fmt.Errorf("cannot read a response: %w", err) + } + + if response.Rep != socks5.RepSuccess { + return fmt.Errorf("unsuccessful request: %v", response.Rep) + } + + return nil +} + // NewSocks5Dialer build a new dialer from a given one (so, in theory // you can chain here). Proxy parameters are passed with URI in a form of: // // socks5://[user:[password]]@host:port func NewSocks5Dialer(baseDialer Dialer, proxyURL *url.URL) (Dialer, error) { - rv, err := proxy.FromURL(proxyURL, baseDialer) - if err != nil { - return nil, fmt.Errorf("cannot initialize socks5 proxy dialer: %w", err) + if _, _, err := net.SplitHostPort(proxyURL.Host); err != nil { + return nil, fmt.Errorf("incorrect url %s", proxyURL.Redacted()) + } + + dialer := socks5Dialer{ + Dialer: baseDialer, + proxyAddress: proxyURL.Host, + } + + if proxyURL.User != nil { + password, isSet := proxyURL.User.Password() + if isSet { + dialer.username = []byte(proxyURL.User.Username()) + dialer.password = []byte(password) + } } - return rv.(Dialer), nil + return dialer, nil } diff --git a/network/socks5_test.go b/network/socks5_test.go index 3460f4d7e..5cb56c384 100644 --- a/network/socks5_test.go +++ b/network/socks5_test.go @@ -55,7 +55,7 @@ func (suite *Socks5TestSuite) TestRequestOk() { suite.Equal(http.StatusOK, resp.StatusCode) } -func TestSocks5TestSuite(t *testing.T) { +func TestSocks5(t *testing.T) { t.Parallel() suite.Run(t, &Socks5TestSuite{}) }