From 006905dcd50670b1aa6d1dcb4b695a63876e5f49 Mon Sep 17 00:00:00 2001 From: Moses Narrow Date: Mon, 16 Dec 2024 14:12:31 -0600 Subject: [PATCH] add skywire-utilities libraries & fix CI errors --- Makefile | 5 +- go.mod | 8 +- pkg/skyenv/skyenv.go | 181 +++++-- .../pkg/buildinfo/buildinfo.go | 81 ++++ pkg/skywire-utilities/pkg/cipher/cipher.go | 266 +++++++++++ .../pkg/cipher/cipher_test.go | 102 ++++ .../pkg/cipher/utils_pubkey.go | 24 + pkg/skywire-utilities/pkg/cmdutil/catch.go | 38 ++ .../pkg/cmdutil/catch_test.go | 62 +++ pkg/skywire-utilities/pkg/cmdutil/cmd_name.go | 12 + .../pkg/cmdutil/service_flags.go | 242 ++++++++++ .../pkg/cmdutil/service_flags_test.go | 70 +++ .../pkg/cmdutil/signal_context.go | 35 ++ .../pkg/cmdutil/signal_unix.go | 15 + .../pkg/cmdutil/signal_windows.go | 15 + .../pkg/cmdutil/sysloghook_unix.go | 46 ++ .../pkg/cmdutil/sysloghook_windows.go | 36 ++ pkg/skywire-utilities/pkg/geo/geo.go | 93 ++++ pkg/skywire-utilities/pkg/httpauth/auth.go | 120 +++++ pkg/skywire-utilities/pkg/httpauth/handler.go | 221 +++++++++ .../pkg/httpauth/handler_test.go | 287 +++++++++++ .../pkg/httpauth/memory_store.go | 62 +++ .../pkg/httpauth/nonce-storer.go | 43 ++ .../pkg/httpauth/redis-store.go | 89 ++++ .../pkg/httputil/dmsghttp.go | 21 + pkg/skywire-utilities/pkg/httputil/error.go | 61 +++ pkg/skywire-utilities/pkg/httputil/health.go | 45 ++ .../pkg/httputil/httputil.go | 105 ++++ pkg/skywire-utilities/pkg/httputil/log.go | 57 +++ .../pkg/logging/formatter.go | 449 ++++++++++++++++++ pkg/skywire-utilities/pkg/logging/hooks.go | 45 ++ pkg/skywire-utilities/pkg/logging/logger.go | 69 +++ pkg/skywire-utilities/pkg/logging/logging.go | 85 ++++ pkg/skywire-utilities/pkg/metricsutil/http.go | 39 ++ .../request_duration_middleware.go | 26 + .../requests_in_flight_count_middleware.go | 30 ++ .../pkg/metricsutil/status_response_writer.go | 37 ++ .../victoria_metrics_int_gauge_wrapper.go | 46 ++ .../victoria_metrics_uint_gauge_wrapper.go | 46 ++ pkg/skywire-utilities/pkg/netutil/copy.go | 37 ++ pkg/skywire-utilities/pkg/netutil/net.go | 176 +++++++ .../pkg/netutil/net_darwin.go | 28 ++ .../pkg/netutil/net_linux.go | 28 ++ pkg/skywire-utilities/pkg/netutil/net_test.go | 18 + .../pkg/netutil/net_windows.go | 58 +++ pkg/skywire-utilities/pkg/netutil/porter.go | 200 ++++++++ pkg/skywire-utilities/pkg/netutil/retrier.go | 101 ++++ .../pkg/netutil/retrier_test.go | 64 +++ .../pkg/networkmonitor/networkmonitor.go | 21 + pkg/skywire-utilities/pkg/skyenv/values.go | 51 ++ .../pkg/storeconfig/storeconfig.go | 41 ++ pkg/skywire-utilities/pkg/tcpproxy/http.go | 24 + pkg/visor/api.go | 14 +- pkg/visor/logstore/logstore.go | 6 +- pkg/visor/visorconfig/v1.go | 6 +- pkg/visor/visorconfig/values.go | 183 +++++-- 56 files changed, 4253 insertions(+), 117 deletions(-) create mode 100644 pkg/skywire-utilities/pkg/buildinfo/buildinfo.go create mode 100644 pkg/skywire-utilities/pkg/cipher/cipher.go create mode 100644 pkg/skywire-utilities/pkg/cipher/cipher_test.go create mode 100644 pkg/skywire-utilities/pkg/cipher/utils_pubkey.go create mode 100644 pkg/skywire-utilities/pkg/cmdutil/catch.go create mode 100644 pkg/skywire-utilities/pkg/cmdutil/catch_test.go create mode 100644 pkg/skywire-utilities/pkg/cmdutil/cmd_name.go create mode 100644 pkg/skywire-utilities/pkg/cmdutil/service_flags.go create mode 100644 pkg/skywire-utilities/pkg/cmdutil/service_flags_test.go create mode 100644 pkg/skywire-utilities/pkg/cmdutil/signal_context.go create mode 100644 pkg/skywire-utilities/pkg/cmdutil/signal_unix.go create mode 100644 pkg/skywire-utilities/pkg/cmdutil/signal_windows.go create mode 100644 pkg/skywire-utilities/pkg/cmdutil/sysloghook_unix.go create mode 100644 pkg/skywire-utilities/pkg/cmdutil/sysloghook_windows.go create mode 100644 pkg/skywire-utilities/pkg/geo/geo.go create mode 100644 pkg/skywire-utilities/pkg/httpauth/auth.go create mode 100644 pkg/skywire-utilities/pkg/httpauth/handler.go create mode 100644 pkg/skywire-utilities/pkg/httpauth/handler_test.go create mode 100644 pkg/skywire-utilities/pkg/httpauth/memory_store.go create mode 100644 pkg/skywire-utilities/pkg/httpauth/nonce-storer.go create mode 100644 pkg/skywire-utilities/pkg/httpauth/redis-store.go create mode 100644 pkg/skywire-utilities/pkg/httputil/dmsghttp.go create mode 100644 pkg/skywire-utilities/pkg/httputil/error.go create mode 100644 pkg/skywire-utilities/pkg/httputil/health.go create mode 100644 pkg/skywire-utilities/pkg/httputil/httputil.go create mode 100644 pkg/skywire-utilities/pkg/httputil/log.go create mode 100644 pkg/skywire-utilities/pkg/logging/formatter.go create mode 100644 pkg/skywire-utilities/pkg/logging/hooks.go create mode 100644 pkg/skywire-utilities/pkg/logging/logger.go create mode 100644 pkg/skywire-utilities/pkg/logging/logging.go create mode 100644 pkg/skywire-utilities/pkg/metricsutil/http.go create mode 100644 pkg/skywire-utilities/pkg/metricsutil/request_duration_middleware.go create mode 100644 pkg/skywire-utilities/pkg/metricsutil/requests_in_flight_count_middleware.go create mode 100644 pkg/skywire-utilities/pkg/metricsutil/status_response_writer.go create mode 100644 pkg/skywire-utilities/pkg/metricsutil/victoria_metrics_int_gauge_wrapper.go create mode 100644 pkg/skywire-utilities/pkg/metricsutil/victoria_metrics_uint_gauge_wrapper.go create mode 100644 pkg/skywire-utilities/pkg/netutil/copy.go create mode 100644 pkg/skywire-utilities/pkg/netutil/net.go create mode 100644 pkg/skywire-utilities/pkg/netutil/net_darwin.go create mode 100644 pkg/skywire-utilities/pkg/netutil/net_linux.go create mode 100644 pkg/skywire-utilities/pkg/netutil/net_test.go create mode 100644 pkg/skywire-utilities/pkg/netutil/net_windows.go create mode 100644 pkg/skywire-utilities/pkg/netutil/porter.go create mode 100644 pkg/skywire-utilities/pkg/netutil/retrier.go create mode 100644 pkg/skywire-utilities/pkg/netutil/retrier_test.go create mode 100644 pkg/skywire-utilities/pkg/networkmonitor/networkmonitor.go create mode 100644 pkg/skywire-utilities/pkg/skyenv/values.go create mode 100644 pkg/skywire-utilities/pkg/storeconfig/storeconfig.go create mode 100644 pkg/skywire-utilities/pkg/tcpproxy/http.go diff --git a/Makefile b/Makefile index 133f9aead0..fa16cf702e 100644 --- a/Makefile +++ b/Makefile @@ -182,11 +182,15 @@ lint: ## Run linters. Use make install-linters first ${OPTS} golangci-lint run -c .golangci.yml ./cmd/... ${OPTS} golangci-lint run -c .golangci.yml ./pkg/... ${OPTS} golangci-lint run -c .golangci.yml ./... + +gocyclo: ## Run gocyclo gocyclo -over 14 . lint-windows: ## Run linters. Use make install-linters-windows first powershell 'golangci-lint --version' powershell 'golangci-lint run -c .golangci.yml ./...' + +gocyclo-windows: ## Run gocyclo on windows powershell 'gocyclo -over 14 .' test: ## Run tests @@ -214,7 +218,6 @@ tidy: ## Tidies and vendors dependencies. format: tidy ## Formats the code. Must have goimports and goimports-reviser installed (use make install-linters). ${OPTS} goimports -w -local ${PROJECT_BASE} ./pkg ./cmd ./internal find . -type f -name '*.go' -not -path "./.git/*" -not -path "./vendor/*" -exec goimports-reviser -project-name ${PROJECT_BASE} {} \; - gocyclo -over 14 . format-windows: tidy ## Formats the code. Must have goimports and goimports-reviser installed (use make install-linters). powershell 'Get-ChildItem -Directory | where Name -NotMatch vendor | % { Get-ChildItem $$_ -Recurse -Include *.go } | % {goimports -w -local ${PROJECT_BASE} $$_ }' diff --git a/go.mod b/go.mod index 0d79e246b2..74bc433fa9 100644 --- a/go.mod +++ b/go.mod @@ -17,6 +17,7 @@ require ( github.com/gen2brain/dlgs v0.0.0-20220603100644-40c77870fa8d github.com/gin-gonic/gin v1.10.0 github.com/go-chi/chi/v5 v5.1.0 + github.com/go-redis/redis/v8 v8.11.5 github.com/gocarina/gocsv v0.0.0-20240520201108-78e41c74b4b1 github.com/google/uuid v1.6.0 github.com/gorilla/securecookie v1.1.2 @@ -25,8 +26,11 @@ require ( github.com/ivanpirog/coloredcobra v1.0.1 github.com/james-barrow/golang-ipc v1.2.4 github.com/jaypipes/ghw v0.13.0 + github.com/json-iterator/go v1.1.12 github.com/lib/pq v1.10.9 + github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d github.com/orandin/lumberjackrus v1.0.1 + github.com/pires/go-proxyproto v0.8.0 github.com/pterm/pterm v0.12.79 github.com/robert-nix/ansihtml v1.0.1 github.com/sirupsen/logrus v1.9.3 @@ -94,7 +98,6 @@ require ( github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.22.1 // indirect - github.com/go-redis/redis/v8 v8.11.5 // indirect github.com/go-task/slim-sprig/v3 v3.0.0 // indirect github.com/goccy/go-json v0.10.3 // indirect github.com/godbus/dbus/v5 v5.1.0 // indirect @@ -113,7 +116,6 @@ require ( github.com/jaypipes/pcidb v1.0.1 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect - github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.8 // indirect github.com/klauspost/reedsolomon v1.12.4 // indirect github.com/kyokomi/emoji/v2 v2.2.13 // indirect @@ -124,14 +126,12 @@ require ( github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-runewidth v0.0.16 // indirect - github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d // indirect github.com/mitchellh/go-homedir v1.1.0 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/onsi/ginkgo/v2 v2.20.2 // indirect github.com/pelletier/go-toml/v2 v2.2.3 // indirect - github.com/pires/go-proxyproto v0.8.0 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/quic-go/quic-go v0.48.0 // indirect diff --git a/pkg/skyenv/skyenv.go b/pkg/skyenv/skyenv.go index 518a592a5c..fd2eb067ef 100644 --- a/pkg/skyenv/skyenv.go +++ b/pkg/skyenv/skyenv.go @@ -18,87 +18,170 @@ const ( // Dmsg port constants. // TODO(evanlinjin): Define these properly. These are currently random. - DmsgCtrlPort uint16 = 7 // DmsgCtrlPort Listening port for dmsgctrl protocol (similar to TCP Echo Protocol). //nolint - DmsgSetupPort uint16 = 36 // DmsgSetupPort Listening port of a setup node. - DmsgHypervisorPort uint16 = 46 // DmsgHypervisorPort Listening port of a hypervisor for incoming RPC visor connections over dmsg. - DmsgTransportSetupPort uint16 = 47 // DmsgTransportSetupPort Listening port for transport setup RPC over dmsg. - DmsgAwaitSetupPort uint16 = 136 // DmsgAwaitSetupPort Listening port of a visor for setup operations. + // DmsgCtrlPort Listening port for dmsgctrl protocol (similar to TCP Echo Protocol). //nolint + DmsgCtrlPort uint16 = 7 + + // DmsgSetupPort Listening port of a setup node. + DmsgSetupPort uint16 = 36 + + // DmsgHypervisorPort Listening port of a hypervisor for incoming RPC visor connections over dmsg. + DmsgHypervisorPort uint16 = 46 + + // DmsgTransportSetupPort Listening port for transport setup RPC over dmsg. + DmsgTransportSetupPort uint16 = 47 + + // DmsgAwaitSetupPort Listening port of a visor for setup operations. + DmsgAwaitSetupPort uint16 = 136 // Transport port constants. - TransportPort uint16 = 45 // TransportPort Listening port of a visor for incoming transports. - PublicAutoconnect = true // PublicAutoconnect ... + // TransportPort Listening port of a visor for incoming transports. + TransportPort uint16 = 45 + + // PublicAutoconnect determines if the visor automatically creates stcpr transports to public visors + PublicAutoconnect = true // Dmsgpty constants. - DmsgPtyPort uint16 = 22 // DmsgPtyPort ... - DmsgPtyCLINet = "unix" // DmsgPtyCLINet ... + // DmsgPtyPort is the dmsg port to listen on for dmsgpty connections + DmsgPtyPort uint16 = 22 + + // DmsgPtyCLINet is the type of cli net used by dmsgpty + DmsgPtyCLINet = "unix" // Skywire-TCP constants. - STCPAddr = ":7777" // STCPAddr ... + // STCPAddr is the address to listen for stcpr or stcp transports + STCPAddr = ":7777" // Default skywire app constants. - SkychatName = "skychat" // SkychatName ... - SkychatPort uint16 = 1 // SkychatPort ... - SkychatAddr = ":8001" // SkychatAddr ... - PingTestName = "pingtest" // PingTestName ... - PingTestPort uint16 = 2 // PingTestPort ... - SkysocksName = "skysocks" // SkysocksName ... - SkysocksPort uint16 = 3 // SkysocksPort ... + // SkychatName is the name of the skychat app + SkychatName = "skychat" + + // SkychatPort is the dmsg port used by skychat + SkychatPort uint16 = 1 + + // SkychatAddr is the non-dmsg port used to access the skychat app on localhost + SkychatAddr = ":8001" + + // PingTestName is the namew of the ping test + PingTestName = "pingtest" - SkysocksClientName = "skysocks-client" // SkysocksClientName ... - SkysocksClientPort uint16 = 13 // SkysocksClientPort ... - SkysocksClientAddr = ":1080" // SkysocksClientAddr ... + // PingTestPort is the port to user for ping tests + PingTestPort uint16 = 2 - VPNServerName = "vpn-server" // VPNServerName ... - VPNServerPort uint16 = 44 // VPNServerPort ... + // SkysocksName is the name of the skysocks app + SkysocksName = "skysocks" - VPNClientName = "vpn-client" // VPNClientName ... + // SkysocksPort is the skysocks port on dmsg + SkysocksPort uint16 = 3 + + // SkysocksClientName is the skysocks-client app name + SkysocksClientName = "skysocks-client" + + // SkysocksClientPort is the skysocks-client app dmsg port + SkysocksClientPort uint16 = 13 + + // SkysocksClientAddr is the default port the socks5 proxy client serves on + SkysocksClientAddr = ":1080" + + // VPNServerName is the name of the vpn server app + VPNServerName = "vpn-server" + + // VPNServerPort is the vpn server dmsg port + VPNServerPort uint16 = 44 + + // VPNClientName is the name of the vpn client app + VPNClientName = "vpn-client" // TODO(darkrengarius): this one's not needed for the app to run but lack of it causes errors - VPNClientPort uint16 = 43 // VPNClientPort ... - ExampleServerName = "example-server-app" // ExampleServerName ... - ExampleServerPort uint16 = 45 // ExampleServerPort ... - ExampleClientName = "example-client-app" // ExampleClientName ... - ExampleClientPort uint16 = 46 // ExampleClientPort ... - SkyForwardingServerName = "sky-forwarding" // SkyForwardingServerName ... - SkyForwardingServerPort uint16 = 47 // SkyForwardingServerPort ... - SkyPingName = "sky-ping" // SkyPingName ... - SkyPingPort uint16 = 48 // SkyPingPort ... + // VPNClientPort over dmsg + VPNClientPort uint16 = 43 + + // ExampleServerName is the name of the example server app + ExampleServerName = "example-server-app" + + // ExampleServerPort is dmsg port of example server app + ExampleServerPort uint16 = 45 + + // ExampleClientName is the name of the example client app + ExampleClientName = "example-client-app" + + // ExampleClientPort dmsg port of example client app + ExampleClientPort uint16 = 46 + + // SkyForwardingServerName name of sky forwarding server app + SkyForwardingServerName = "sky-forwarding" + + // SkyForwardingServerPort dmsg port of skyfwd server app + SkyForwardingServerPort uint16 = 47 + + // SkyPingName is the name of the sky ping + SkyPingName = "sky-ping" + + // SkyPingPort dmsg port of sky ping + SkyPingPort uint16 = 48 // RPC constants. - RPCAddr = "localhost:3435" // RPCAddr ... - RPCTimeout = 20 * time.Second // RPCTimeout ... - TransportRPCTimeout = 1 * time.Minute // TransportRPCTimeout ... - UpdateRPCTimeout = 6 * time.Hour // UpdateRPCTimeout update requires huge timeout + // RPCAddr for skywire-cli to access skywire-visor + RPCAddr = "localhost:3435" + + // RPCTimeout timeout of rpc requests + RPCTimeout = 20 * time.Second + + // TransportRPCTimeout timeout of transport rpc + TransportRPCTimeout = 1 * time.Minute + + // UpdateRPCTimeout update requires huge timeout - NOTE: this is likely unused + UpdateRPCTimeout = 6 * time.Hour // Default skywire app server and discovery constants - AppSrvAddr = "localhost:5505" // AppSrvAddr ... - ServiceDiscUpdateInterval = time.Minute // ServiceDiscUpdateInterval ... - AppBinPath = "./" // AppBinPath ... - LogLevel = "info" // LogLevel ... + // AppSrvAddr address of app server + AppSrvAddr = "localhost:5505" + + // ServiceDiscUpdateInterval update interval for apps in service discovery + ServiceDiscUpdateInterval = time.Minute + + // AppBinPath is the default path for the apps + AppBinPath = "./" + + // LogLevel is the default log level of the visor + LogLevel = "info" // Routing constants - TpLogStore = "transport_logs" // TpLogStore ... - Custom = "custom" // Custom ... + // TpLogStore is where tp logs are stored + TpLogStore = "transport_logs" - // LocalPath constants + // Custom path to serve files from dmsghttp log server over dmsg + Custom = "custom" + + // LocalPath where the visor writes files to LocalPath = "./local" // Default hypervisor constants - HypervisorDB = ".skycoin/hypervisor/users.db" //HypervisorDB ... - EnableAuth = false // EnableAuth ... - PackageEnableAuth = true // PackageEnableAuth ... - EnableTLS = false // EnableTLS ... - TLSKey = "./ssl/key.pem" // TLSKey ... - TLSCert = "./ssl/cert.pem" // TLSCert ... + //HypervisorDB stores the password to access the hypervisor + HypervisorDB = ".skycoin/hypervisor/users.db" + + // EnableAuth enables auth on the hypervisor UI + EnableAuth = false + + // PackageEnableAuth is the default auth for package-based installations for hypervisor UI + PackageEnableAuth = true + + // EnableTLS enables tls for accessing hypervisor ui + EnableTLS = false + + // TLSKey for access to hvui + TLSKey = "./ssl/key.pem" + + // TLSCert for access to hvui + TLSCert = "./ssl/cert.pem" // IPCShutdownMessageType sends IPC shutdown message type IPCShutdownMessageType = 68 diff --git a/pkg/skywire-utilities/pkg/buildinfo/buildinfo.go b/pkg/skywire-utilities/pkg/buildinfo/buildinfo.go new file mode 100644 index 0000000000..a807e64226 --- /dev/null +++ b/pkg/skywire-utilities/pkg/buildinfo/buildinfo.go @@ -0,0 +1,81 @@ +// Package buildinfo pkg/buildinfo/buildinfo.go +package buildinfo + +import ( + "encoding/json" + "fmt" + "io" +) + +const unknown = "unknown" + +//$ go build -mod=vendor -ldflags="-X 'github.com/skycoin/skywire-utilities/pkg/buildinfo.version=$(git describe)' -X 'github.com/skycoin/skywire-utilities/pkg/buildinfo.date=$(date -u "+%Y-%m-%dT%H:%M:%SZ")' -X 'github.com/skycoin/skywire-utilities/pkg/buildinfo.commit=$(git rev-list -1 HEAD)'" . + +var ( + version = unknown + commit = unknown + date = unknown +) + +// $ go build -ldflags="-X 'github.com/skycoin/skywire-utilities/pkg/buildinfo.golist=$(go list -m -json -mod=mod github.com/skycoin/@)' -X 'github.com/skycoin/skywire-utilities/pkg/buildinfo.date=$(date -u "+%Y-%m-%dT%H:%M:%SZ")'" . +var golist string + +// ModuleInfo represents the JSON structure returned by `go list -m -json`. +type ModuleInfo struct { + Version string `json:"Version"` + Origin struct { + Hash string `json:"Hash"` + } `json:"Origin"` +} + +func init() { + if golist != "" { + var mInfo ModuleInfo + if err := json.Unmarshal([]byte(golist), &mInfo); err == nil { + if mInfo.Version != "" && version == unknown { + version = mInfo.Version + } + if mInfo.Origin.Hash != "" && commit == unknown { + commit = mInfo.Origin.Hash + } + } + } +} + +// Version returns version from the parsed module info. +func Version() string { + return version +} + +// Commit returns commit hash from the parsed module info. +func Commit() string { + return commit +} + +// Date returns date of build in RFC3339 format. +func Date() string { + return date +} + +// Get returns build info summary. +func Get() *Info { + return &Info{ + Version: Version(), + Commit: Commit(), + Date: Date(), + } +} + +// Info is build info summary. +type Info struct { + Version string `json:"version"` + Commit string `json:"commit"` + Date string `json:"date"` +} + +// WriteTo writes build info summary to io.Writer. +func (info *Info) WriteTo(w io.Writer) (int64, error) { + msg := fmt.Sprintf("Version %q built on %q against commit %q\n", info.Version, info.Date, info.Commit) + n, err := w.Write([]byte(msg)) + return int64(n), err +} diff --git a/pkg/skywire-utilities/pkg/cipher/cipher.go b/pkg/skywire-utilities/pkg/cipher/cipher.go new file mode 100644 index 0000000000..9b140badfb --- /dev/null +++ b/pkg/skywire-utilities/pkg/cipher/cipher.go @@ -0,0 +1,266 @@ +// Package cipher implements common golang encoding interfaces for +// github.com/skycoin/skycoin/src/cipher +package cipher + +import ( + "bytes" + "fmt" + "math/big" + "strings" + + "github.com/skycoin/skycoin/src/cipher" +) + +func init() { + cipher.DebugLevel2 = false // DebugLevel2 causes ECDH to be really slow +} + +// GenerateKeyPair creates key pair +func GenerateKeyPair() (PubKey, SecKey) { + pk, sk := cipher.GenerateKeyPair() + return PubKey(pk), SecKey(sk) +} + +// GenerateDeterministicKeyPair generates deterministic key pair +func GenerateDeterministicKeyPair(seed []byte) (PubKey, SecKey, error) { + pk, sk, err := cipher.GenerateDeterministicKeyPair(seed) + return PubKey(pk), SecKey(sk), err +} + +// NewPubKey converts []byte to a PubKey +func NewPubKey(b []byte) (PubKey, error) { + pk, err := cipher.NewPubKey(b) + return PubKey(pk), err +} + +// PubKey is a wrapper type for cipher.PubKey that implements common +// golang interfaces. +type PubKey cipher.PubKey + +// Hex returns a hex encoded PubKey string +func (pk PubKey) Hex() string { + return cipher.PubKey(pk).Hex() +} + +// Null returns true if PubKey is the null PubKey +func (pk PubKey) Null() bool { + return cipher.PubKey(pk).Null() +} + +// String implements fmt.Stringer for PubKey. Returns Hex representation. +func (pk PubKey) String() string { + return pk.Hex() +} + +// Big returns the big.Int representation of the public key. +func (pk PubKey) Big() *big.Int { + return new(big.Int).SetBytes(pk[:]) +} + +// Set implements pflag.Value for PubKey. +func (pk *PubKey) Set(s string) error { + cPK, err := cipher.PubKeyFromHex(s) + if err != nil { + return err + } + *pk = PubKey(cPK) + return nil +} + +// Type implements pflag.Value for PubKey. +func (pk PubKey) Type() string { + return "cipher.PubKey" +} + +// MarshalText implements encoding.TextMarshaler. +func (pk PubKey) MarshalText() ([]byte, error) { + return []byte(pk.Hex()), nil +} + +// UnmarshalText implements encoding.TextUnmarshaler. +func (pk *PubKey) UnmarshalText(data []byte) error { + if bytes.Count(data, []byte("0")) == len(data) { + return nil + } + + dPK, err := cipher.PubKeyFromHex(string(data)) + if err == nil { + *pk = PubKey(dPK) + } + return err +} + +// MarshalBinary implements encoding.BinaryMarshaler. +func (pk PubKey) MarshalBinary() ([]byte, error) { + return pk[:], nil +} + +// UnmarshalBinary implements encoding.BinaryUnmarshaler. +func (pk *PubKey) UnmarshalBinary(data []byte) error { + dPK, err := cipher.NewPubKey(data) + if err == nil { + *pk = PubKey(dPK) + } + return err +} + +// PubKeys represents a slice of PubKeys. +type PubKeys []PubKey + +// String implements stringer for PubKeys. +func (p PubKeys) String() string { + res := "public keys:\n" + for _, pk := range p { + res += fmt.Sprintf("\t%s\n", pk) + } + return res +} + +// Set implements pflag.Value for PubKeys. +func (p *PubKeys) Set(list string) error { + *p = PubKeys{} + for _, s := range strings.Split(list, ",") { + var pk PubKey + if err := pk.Set(strings.TrimSpace(s)); err != nil { + return err + } + *p = append(*p, pk) + } + return nil +} + +// Type implements pflag.Value for PubKeys. +func (p PubKeys) Type() string { + return "cipher.PubKeys" +} + +// SecKey is a wrapper type for cipher.SecKey that implements common +// golang interfaces. +type SecKey cipher.SecKey + +// Hex returns a hex encoded SecKey string +func (sk SecKey) Hex() string { + return cipher.SecKey(sk).Hex() +} + +// Null returns true if SecKey is the null SecKey. +func (sk SecKey) Null() bool { + return cipher.SecKey(sk).Null() +} + +// String implements fmt.Stringer for SecKey. Returns Hex representation. +func (sk SecKey) String() string { + return sk.Hex() +} + +// Set implements pflag.Value for SecKey. +func (sk *SecKey) Set(s string) error { + cSK, err := cipher.SecKeyFromHex(s) + if err != nil { + return err + } + *sk = SecKey(cSK) + return nil +} + +// Type implements pflag.Value for SecKey. +func (sk *SecKey) Type() string { + return "cipher.SecKey" +} + +// MarshalText implements encoding.TextMarshaler. +func (sk SecKey) MarshalText() ([]byte, error) { + return []byte(sk.Hex()), nil +} + +// UnmarshalText implements encoding.TextUnmarshaler. +func (sk *SecKey) UnmarshalText(data []byte) error { + dSK, err := cipher.SecKeyFromHex(string(data)) + if err == nil { + *sk = SecKey(dSK) + } + return err +} + +// MarshalBinary implements encoding.BinaryMarshaler. +func (sk SecKey) MarshalBinary() ([]byte, error) { + return sk[:], nil +} + +// UnmarshalBinary implements encoding.BinaryUnmarshaler. +func (sk *SecKey) UnmarshalBinary(data []byte) error { + dSK, err := cipher.NewSecKey(data) + if err == nil { + *sk = SecKey(dSK) + } + return err +} + +// PubKey recovers the public key for a secret key +func (sk SecKey) PubKey() (PubKey, error) { + pk, err := cipher.PubKeyFromSecKey(cipher.SecKey(sk)) + return PubKey(pk), err +} + +// Sig is a wrapper type for cipher.Sig that implements common golang interfaces. +type Sig cipher.Sig + +// Hex returns a hex encoded Sig string +func (sig Sig) Hex() string { + return cipher.Sig(sig).Hex() +} + +// String implements fmt.Stringer for Sig. Returns Hex representation. +func (sig Sig) String() string { + return sig.Hex() +} + +// Null returns true if Sig is a null Sig +func (sig Sig) Null() bool { + return sig == Sig{} +} + +// MarshalText implements encoding.TextMarshaler. +func (sig Sig) MarshalText() ([]byte, error) { + return []byte(sig.Hex()), nil +} + +// UnmarshalText implements encoding.TextUnmarshaler. +func (sig *Sig) UnmarshalText(data []byte) error { + dSig, err := cipher.SigFromHex(string(data)) + if err == nil { + *sig = Sig(dSig) + } + return err +} + +// SignPayload creates Sig for payload using SHA256 +func SignPayload(payload []byte, sec SecKey) (Sig, error) { + sig, err := cipher.SignHash(cipher.SumSHA256(payload), cipher.SecKey(sec)) + return Sig(sig), err +} + +// VerifyPubKeySignedPayload verifies that SHA256 hash of the payload was signed by PubKey +func VerifyPubKeySignedPayload(pubkey PubKey, sig Sig, payload []byte) error { + return cipher.VerifyPubKeySignedHash(cipher.PubKey(pubkey), cipher.Sig(sig), cipher.SumSHA256(payload)) +} + +// RandByte returns rand N bytes +func RandByte(n int) []byte { + return cipher.RandByte(n) +} + +// SHA256 is a wrapper type for cipher.SHA256 that implements common +// golang interfaces. +type SHA256 cipher.SHA256 + +// SHA256FromBytes converts []byte to SHA256 +func SHA256FromBytes(b []byte) (SHA256, error) { + h, err := cipher.SHA256FromBytes(b) + return SHA256(h), err +} + +// SumSHA256 sum sha256 +func SumSHA256(b []byte) SHA256 { + return SHA256(cipher.SumSHA256(b)) +} diff --git a/pkg/skywire-utilities/pkg/cipher/cipher_test.go b/pkg/skywire-utilities/pkg/cipher/cipher_test.go new file mode 100644 index 0000000000..9cd85ee83b --- /dev/null +++ b/pkg/skywire-utilities/pkg/cipher/cipher_test.go @@ -0,0 +1,102 @@ +// Package buildinfo pkg/cipher/cipher_test.go +package cipher + +import ( + "log" + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/skycoin/skywire-utilities/pkg/logging" +) + +func TestMain(m *testing.M) { + loggingLevel, ok := os.LookupEnv("TEST_LOGGING_LEVEL") + if ok { + lvl, err := logging.LevelFromString(loggingLevel) + if err != nil { + log.Fatal(err) + } + logging.SetLevel(lvl) + } else { + logging.Disable() + } + + os.Exit(m.Run()) +} + +func TestPubKeyString(t *testing.T) { + p, _ := GenerateKeyPair() + require.Equal(t, p.Hex(), p.String()) +} + +func TestPubKeyTextMarshaller(t *testing.T) { + p, _ := GenerateKeyPair() + h, err := p.MarshalText() + require.NoError(t, err) + + var p2 PubKey + err = p2.UnmarshalText(h) + require.NoError(t, err) + require.Equal(t, p, p2) +} + +func TestPubKeyBinaryMarshaller(t *testing.T) { + p, _ := GenerateKeyPair() + b, err := p.MarshalBinary() + require.NoError(t, err) + + var p2 PubKey + err = p2.UnmarshalBinary(b) + require.NoError(t, err) + require.Equal(t, p, p2) +} + +func TestSecKeyString(t *testing.T) { + _, s := GenerateKeyPair() + require.Equal(t, s.Hex(), s.String()) +} + +func TestSecKeyTextMarshaller(t *testing.T) { + _, s := GenerateKeyPair() + h, err := s.MarshalText() + require.NoError(t, err) + + var s2 SecKey + err = s2.UnmarshalText(h) + require.NoError(t, err) + require.Equal(t, s, s2) +} + +func TestSecKeyBinaryMarshaller(t *testing.T) { + _, s := GenerateKeyPair() + b, err := s.MarshalBinary() + require.NoError(t, err) + + var s2 SecKey + err = s2.UnmarshalBinary(b) + require.NoError(t, err) + require.Equal(t, s, s2) +} + +func TestSigString(t *testing.T) { + _, sk := GenerateKeyPair() + sig, err := SignPayload([]byte("foo"), sk) + require.NoError(t, err) + assert.Equal(t, sig.Hex(), sig.String()) +} + +func TestSigTextMarshaller(t *testing.T) { + _, sk := GenerateKeyPair() + sig, err := SignPayload([]byte("foo"), sk) + require.NoError(t, err) + h, err := sig.MarshalText() + require.NoError(t, err) + + var sig2 Sig + err = sig2.UnmarshalText(h) + require.NoError(t, err) + assert.Equal(t, sig, sig2) +} diff --git a/pkg/skywire-utilities/pkg/cipher/utils_pubkey.go b/pkg/skywire-utilities/pkg/cipher/utils_pubkey.go new file mode 100644 index 0000000000..aa694719b3 --- /dev/null +++ b/pkg/skywire-utilities/pkg/cipher/utils_pubkey.go @@ -0,0 +1,24 @@ +// Package cipher pkg/cipher/ustils_pubkey.go +package cipher + +// SamePubKeys returns true when the provided public key slices have the same keys. +// The slices do not need to be in the same order. +// It is assumed that there are no duplicate elements within the slices. +func SamePubKeys(pks1, pks2 []PubKey) bool { + if len(pks1) != len(pks2) { + return false + } + + m := make(map[PubKey]struct{}, len(pks1)) + for _, pk := range pks1 { + m[pk] = struct{}{} + } + + for _, pk := range pks2 { + if _, ok := m[pk]; !ok { + return false + } + } + + return true +} diff --git a/pkg/skywire-utilities/pkg/cmdutil/catch.go b/pkg/skywire-utilities/pkg/cmdutil/catch.go new file mode 100644 index 0000000000..baa140de31 --- /dev/null +++ b/pkg/skywire-utilities/pkg/cmdutil/catch.go @@ -0,0 +1,38 @@ +// Package cmdutil pkg/cmdutil/catch.go +package cmdutil + +import ( + "fmt" + "os" + "strings" + + "github.com/sirupsen/logrus" +) + +// Catch panics on any non-nil error. +func Catch(v ...interface{}) { + CatchWithMsg("", v...) +} + +// CatchWithMsg panics on any non-nil error with the provided message (if any). +func CatchWithMsg(msg string, v ...interface{}) { + for _, val := range v { + if err, ok := val.(error); ok && err != nil { + if msg == "" { + panic(err) + } + msg = strings.TrimSuffix(strings.TrimSpace(msg), ":") + panic(fmt.Errorf("%s: %v", msg, err)) + } + } +} + +// CatchWithLog calls Fatal() on any non-nil error. +func CatchWithLog(log logrus.FieldLogger, msg string, v ...interface{}) { + for _, val := range v { + if err, ok := val.(error); ok && err != nil { + log.WithError(err).Fatal(msg) + os.Exit(1) + } + } +} diff --git a/pkg/skywire-utilities/pkg/cmdutil/catch_test.go b/pkg/skywire-utilities/pkg/cmdutil/catch_test.go new file mode 100644 index 0000000000..4f096a5fac --- /dev/null +++ b/pkg/skywire-utilities/pkg/cmdutil/catch_test.go @@ -0,0 +1,62 @@ +// Package cmdutil pkg/cmdutil/catch_test.go +package cmdutil + +import ( + "errors" + "math/rand" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/skycoin/skywire-utilities/pkg/cipher" +) + +func TestCatch(t *testing.T) { + fn := func(ok bool) (int, error) { + if ok { + return rand.Int(), nil //nolint:gosec + } + return 0, errors.New("not okay") + } + + t.Run("should_not_panic", func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + t.Errorf("The code paniced: %v", r) + } + }() + Catch(fn(true)) + }) + + t.Run("should_panic", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("The code did not panic") + } + }() + Catch(fn(false)) + }) + + t.Run("in_order", func(t *testing.T) { + const rounds = 5 + + expected := cipher.RandByte(rounds) + actual := make([]byte, 0, rounds) + + addFn := func(i int) error { //nolint:unparam + actual = append(actual, expected[i]) + return nil + } + + Catch( + addFn(0), + addFn(1), + addFn(2), + addFn(3), + addFn(4)) + + for i, exp := range expected { + require.Equal(t, exp, actual[i]) + } + }) +} diff --git a/pkg/skywire-utilities/pkg/cmdutil/cmd_name.go b/pkg/skywire-utilities/pkg/cmdutil/cmd_name.go new file mode 100644 index 0000000000..68cf026a47 --- /dev/null +++ b/pkg/skywire-utilities/pkg/cmdutil/cmd_name.go @@ -0,0 +1,12 @@ +// Package cmdutil pkg/cmdutil/cmd_name.go +package cmdutil + +import ( + "os" + "path" +) + +// RootCmdName returns the root command name. +func RootCmdName() string { + return path.Base(os.Args[0]) +} diff --git a/pkg/skywire-utilities/pkg/cmdutil/service_flags.go b/pkg/skywire-utilities/pkg/cmdutil/service_flags.go new file mode 100644 index 0000000000..c8a8d7c98d --- /dev/null +++ b/pkg/skywire-utilities/pkg/cmdutil/service_flags.go @@ -0,0 +1,242 @@ +// Package cmdutil pkg/cmdutil/service_flags.go +package cmdutil + +import ( + "errors" + "fmt" + "io" + "os" + "strings" + "unicode" + + jsoniter "github.com/json-iterator/go" + "github.com/spf13/cobra" + + "github.com/skycoin/skywire-utilities/pkg/logging" +) + +// Associated errors. +var ( + ErrTagCannotBeEmpty = errors.New("tag cannot be empty") + ErrTagHasInvalidChars = errors.New("tag can only contain alphanumeric values and underscore") + ErrTagHasMisplacedUnderscores = errors.New("tag cannot start or end with an underscore or have two underscores back-to-back") + ErrInvalidLogString = errors.New("failed to convert string to log level") + ErrInvalidSyslogNet = errors.New("network type is unsupported for syslog") +) + +var json = jsoniter.ConfigFastest + +const ( + stdinConfig = "stdin" +) + +// ServiceFlags represents common flags which are shared across services. +type ServiceFlags struct { + MetricsAddr string + Syslog string + SyslogNet string + LogLevel string + Tag string + Config string + Stdin bool + + // state + checkDone bool + loggerDone bool + + logger *logging.Logger +} + +// Init initiates the service flags. +// The following are performed: +// - Ensure 'defaultTag' is provided and valid. +// - Set "library" defaults. +// - Set "exec" defaults - provided by 'defaultTag' and 'defaultConf'. +// - Add flags to 'rootCmd'. +func (sf *ServiceFlags) Init(rootCmd *cobra.Command, defaultTag, defaultConf string) { + if err := ValidTag(defaultTag); err != nil { + panic(err) + } + + // "library" defaults + if sf.SyslogNet == "" { + // TODO (evanlinjin): Consider using tcp as syslog udp is legacy. + sf.SyslogNet = "udp" + } + if sf.LogLevel == "" { + sf.LogLevel = "debug" + } + + // "exec" defaults + if defaultTag != "" { + sf.Tag = defaultTag + } + if defaultConf != "" { + sf.Config = defaultConf + } + + // flags + rootCmd.Flags().StringVarP(&sf.MetricsAddr, "metrics", "m", sf.MetricsAddr, "address to serve metrics API from") + rootCmd.Flags().StringVar(&sf.Syslog, "syslog", sf.Syslog, "address in which to dial to syslog server") + rootCmd.Flags().StringVar(&sf.SyslogNet, "syslog-net", sf.SyslogNet, "network in which to dial to syslog server") + rootCmd.Flags().StringVar(&sf.LogLevel, "syslog-lvl", sf.LogLevel, "minimum log level to report") + rootCmd.Flags().StringVar(&sf.Tag, "tag", sf.Tag, "tag used for logging and metrics") + + // only enable config flags if 'defaultConf' is set + if defaultConf != "" { + rootCmd.Flags().StringVarP(&sf.Config, "config", "c", sf.Config, "location of config file (STDIN to read from standard input)") + rootCmd.Flags().BoolVar(&sf.Stdin, "stdin", sf.Stdin, "whether to read config via stdin") + } +} + +// Check checks service flags. +func (sf *ServiceFlags) Check() error { + if alreadyDone(&sf.checkDone) { + return nil + } + + if sf.Syslog != "" { + switch sf.SyslogNet { + case "tcp", "udp", "unix": + default: + return fmt.Errorf("%w: %s", ErrInvalidSyslogNet, sf.SyslogNet) + } + } + + if _, _, err := LevelFromString(sf.LogLevel); err != nil { + return fmt.Errorf("%w: %s", ErrInvalidLogString, sf.LogLevel) + } + + if err := ValidTag(sf.Tag); err != nil { + return fmt.Errorf("%w: %s", err, sf.Tag) + } + + return nil +} + +// Logger returns the logger as specified by the service flags. +func (sf *ServiceFlags) Logger() *logging.Logger { + if alreadyDone(&sf.loggerDone) { + return sf.logger + } + + log := logging.MustGetLogger(sf.Tag) + sf.logger = log + + logLvl, sysLvl, err := LevelFromString(sf.LogLevel) + if err != nil { + panic(err) // should not happen as we have already checked earlier on + } + logging.SetLevel(logLvl) + + if sf.Syslog != "" { + sf.sysLogHook(log, sysLvl) + } + + return log +} + +// ParseConfig parses config from service tags. +// If checkArgs is set, we additionally parse os.Args to find a config path. +func (sf *ServiceFlags) ParseConfig(args []string, checkArgs bool, v interface{}, genDefaultFunc func() (io.ReadCloser, error)) error { + r, err := sf.obtainConfigReader(args, checkArgs, genDefaultFunc) + if err != nil { + return err + } + defer func() { + if err = r.Close(); err != nil { + sf.logger.WithError(err).Warn("Failed to close config source.") + } + }() + + b, err := io.ReadAll(r) + if err != nil { + return fmt.Errorf("failed to read from config source: %w", err) + } + + if err = json.Unmarshal(b, v); err != nil { + return fmt.Errorf("failed to decode config file: %w", err) + } + + j, err := json.MarshalIndent(v, "", " ") + if err != nil { + panic(err) // should not happen + } + sf.logger.Infof("Read config: %s", string(j)) + + return nil +} + +func (sf *ServiceFlags) obtainConfigReader(args []string, checkArgs bool, genDefaultFunc func() (io.ReadCloser, error)) (io.ReadCloser, error) { + switch { + case sf.Stdin || strings.ToLower(sf.Config) == stdinConfig: + stdin := io.NopCloser(os.Stdin) // ensure stdin is not closed + return stdin, nil + + case checkArgs: + if len(args) == 1 { + return genDefaultFunc() + } + + for i, arg := range args { + if strings.HasSuffix(arg, ".json") && i > 0 && !strings.HasPrefix(args[i-1], "-") { + var f io.ReadCloser + var err error + f, err = os.Open(arg) //nolint:gosec + if err != nil { + return nil, fmt.Errorf("failed to open config file: %w", err) + } + return f, nil + } + } + + case sf.Config != "": + f, err := os.Open(sf.Config) + if err != nil { + return nil, fmt.Errorf("failed to open config file: %w", err) + } + return f, nil + + } + + return nil, errors.New("no config location specified") +} + +// ValidTag returns an error if the tag is invalid. +func ValidTag(tag string) error { + if tag == "" { + return ErrTagCannotBeEmpty + } + + // check: valid characters + for _, c := range tag { + ranges := []*unicode.RangeTable{unicode.Letter, unicode.Number} + if unicode.IsOneOf(ranges, c) || c == '_' { + continue + } + return ErrTagHasInvalidChars + } + + // check: correct positioning of characters + for i, c := range tag { + if i == 0 || i == len(tag)-1 { + if c == '_' { + return ErrTagHasMisplacedUnderscores + } + continue + } + if c == '_' && (tag[i-1] == '_' || tag[i+1] == '_') { + return ErrTagHasMisplacedUnderscores + } + } + + return nil +} + +func alreadyDone(done *bool) bool { + if *done { + return true + } + *done = true + return false +} diff --git a/pkg/skywire-utilities/pkg/cmdutil/service_flags_test.go b/pkg/skywire-utilities/pkg/cmdutil/service_flags_test.go new file mode 100644 index 0000000000..46786ef6c8 --- /dev/null +++ b/pkg/skywire-utilities/pkg/cmdutil/service_flags_test.go @@ -0,0 +1,70 @@ +// Package cmdutil pkg/cmdutil/service_flags_test.go +package cmdutil + +import ( + "fmt" + "testing" + + "github.com/spf13/cobra" + "github.com/stretchr/testify/require" +) + +func TestServiceFlags_Init(t *testing.T) { + t.Run("panic_on_empty_tag", func(t *testing.T) { + defer func() { + r := recover() + require.NotNil(t, r) + + err, ok := r.(error) + require.True(t, ok) + require.EqualError(t, err, ErrTagCannotBeEmpty.Error()) + }() + + var sf ServiceFlags + sf.Init(&cobra.Command{}, "", "config.json") + }) + + t.Run("panic_on_invalid_tag", func(t *testing.T) { + type testCase struct { + tag string + err error + } + + testCases := []testCase{ + {tag: "abcdefghijklmnopqrstuvwxyz", err: nil}, + {tag: "ABCDEFGHIJKLMNOPQRSTUVWXYZ", err: nil}, + {tag: "0123456789", err: nil}, + {tag: "aA12bB34cC56", err: nil}, + {tag: "a_a", err: nil}, + {tag: "ab_12_34_D34_d4_1f34", err: nil}, + {tag: "aGJKN-#fn3", err: ErrTagHasInvalidChars}, + {tag: "_abc123", err: ErrTagHasMisplacedUnderscores}, + {tag: "AF3g_", err: ErrTagHasMisplacedUnderscores}, + {tag: "A__231v", err: ErrTagHasMisplacedUnderscores}, + {tag: "a32f__b", err: ErrTagHasMisplacedUnderscores}, + {tag: "B3s__21fg", err: ErrTagHasMisplacedUnderscores}, + } + + for i, tc := range testCases { + i, tc := i, tc + t.Run(fmt.Sprintf("%d:%s", i, tc.tag), func(t *testing.T) { + defer func() { + if tc.err == nil { + require.Nil(t, recover()) + return + } + + r := recover() + require.NotNil(t, r) + + err, ok := r.(error) + require.True(t, ok) + require.EqualError(t, err, tc.err.Error()) + }() + + var sf ServiceFlags + sf.Init(&cobra.Command{}, tc.tag, "config.json") + }) + } + }) +} diff --git a/pkg/skywire-utilities/pkg/cmdutil/signal_context.go b/pkg/skywire-utilities/pkg/cmdutil/signal_context.go new file mode 100644 index 0000000000..343c522d80 --- /dev/null +++ b/pkg/skywire-utilities/pkg/cmdutil/signal_context.go @@ -0,0 +1,35 @@ +// Package cmdutil pkg/cmdutil/signal_context.go +package cmdutil + +import ( + "context" + "os" + "os/signal" + + "github.com/sirupsen/logrus" +) + +// SignalContext returns a context that cancels on given syscall signals. +func SignalContext(ctx context.Context, log logrus.FieldLogger) (context.Context, context.CancelFunc) { + if log == nil { + log = logrus.New() + } + + ctx, cancel := context.WithCancel(ctx) + + ch := make(chan os.Signal, 1) + listenSigs := listenSignals() + signal.Notify(ch, listenSigs...) + + go func() { + select { + case sig := <-ch: + log.WithField("signal", sig). + Info("Closing with received signal.") + case <-ctx.Done(): + } + cancel() + }() + + return ctx, cancel +} diff --git a/pkg/skywire-utilities/pkg/cmdutil/signal_unix.go b/pkg/skywire-utilities/pkg/cmdutil/signal_unix.go new file mode 100644 index 0000000000..473d3bf210 --- /dev/null +++ b/pkg/skywire-utilities/pkg/cmdutil/signal_unix.go @@ -0,0 +1,15 @@ +//go:build !windows +// +build !windows + +// Package cmdutil pkg/cmdutil/signal_unix.go +package cmdutil + +import ( + "os" + + "golang.org/x/sys/unix" +) + +func listenSignals() []os.Signal { + return []os.Signal{unix.SIGINT, unix.SIGTERM, unix.SIGQUIT} +} diff --git a/pkg/skywire-utilities/pkg/cmdutil/signal_windows.go b/pkg/skywire-utilities/pkg/cmdutil/signal_windows.go new file mode 100644 index 0000000000..df05d8c4ac --- /dev/null +++ b/pkg/skywire-utilities/pkg/cmdutil/signal_windows.go @@ -0,0 +1,15 @@ +//go:build windows +// +build windows + +// Package cmdutil pkg/cmdutil/signal_windows.go +package cmdutil + +import ( + "os" + + "golang.org/x/sys/windows" +) + +func listenSignals() []os.Signal { + return []os.Signal{os.Interrupt, windows.SIGINT, windows.SIGTERM, windows.SIGQUIT} +} diff --git a/pkg/skywire-utilities/pkg/cmdutil/sysloghook_unix.go b/pkg/skywire-utilities/pkg/cmdutil/sysloghook_unix.go new file mode 100644 index 0000000000..0a90281097 --- /dev/null +++ b/pkg/skywire-utilities/pkg/cmdutil/sysloghook_unix.go @@ -0,0 +1,46 @@ +//go:build !windows +// +build !windows + +// Package cmdutil pkg/cmdutil/sysloghook_unix.go +package cmdutil + +import ( + "log/syslog" + "strings" + + "github.com/sirupsen/logrus" + logrussyslog "github.com/sirupsen/logrus/hooks/syslog" + + "github.com/skycoin/skywire-utilities/pkg/logging" +) + +func (sf *ServiceFlags) sysLogHook(log *logging.Logger, sysLvl int) { + hook, err := logrussyslog.NewSyslogHook(sf.SyslogNet, sf.Syslog, syslog.Priority(sysLvl), sf.Tag) + if err != nil { + log.WithError(err). + WithField("net", sf.SyslogNet). + WithField("addr", sf.Syslog). + Fatal("Failed to connect to syslog daemon.") + } + logging.AddHook(hook) +} + +// LevelFromString returns a logrus.Level and syslog.Priority from a string identifier. +func LevelFromString(s string) (logrus.Level, int, error) { + switch strings.ToLower(s) { + case "debug": + return logrus.DebugLevel, int(syslog.LOG_DEBUG), nil + case "info", "notice": + return logrus.InfoLevel, int(syslog.LOG_INFO), nil + case "warn", "warning": + return logrus.WarnLevel, int(syslog.LOG_WARNING), nil + case "error": + return logrus.ErrorLevel, int(syslog.LOG_ERR), nil + case "fatal", "critical": + return logrus.FatalLevel, int(syslog.LOG_CRIT), nil + case "panic": + return logrus.PanicLevel, int(syslog.LOG_EMERG), nil + default: + return logrus.DebugLevel, int(syslog.LOG_DEBUG), ErrInvalidLogString + } +} diff --git a/pkg/skywire-utilities/pkg/cmdutil/sysloghook_windows.go b/pkg/skywire-utilities/pkg/cmdutil/sysloghook_windows.go new file mode 100644 index 0000000000..83ddf4f8d5 --- /dev/null +++ b/pkg/skywire-utilities/pkg/cmdutil/sysloghook_windows.go @@ -0,0 +1,36 @@ +//go:build windows +// +build windows + +// Package cmdutil pkg/cmdutil/sysloghook_windows.go +package cmdutil + +import ( + "strings" + + "github.com/sirupsen/logrus" + + "github.com/skycoin/skywire-utilities/pkg/logging" +) + +func (sf *ServiceFlags) sysLogHook(_ *logging.Logger, _ int) { +} + +// LevelFromString returns a logrus.Level and syslog.Priority from a string identifier. +func LevelFromString(s string) (logrus.Level, int, error) { + switch strings.ToLower(s) { + case "debug": + return logrus.DebugLevel, 0, nil + case "info", "notice": + return logrus.InfoLevel, 0, nil + case "warn", "warning": + return logrus.WarnLevel, 0, nil + case "error": + return logrus.ErrorLevel, 0, nil + case "fatal", "critical": + return logrus.FatalLevel, 0, nil + case "panic": + return logrus.PanicLevel, 0, nil + default: + return logrus.DebugLevel, 0, ErrInvalidLogString + } +} diff --git a/pkg/skywire-utilities/pkg/geo/geo.go b/pkg/skywire-utilities/pkg/geo/geo.go new file mode 100644 index 0000000000..031f06f64b --- /dev/null +++ b/pkg/skywire-utilities/pkg/geo/geo.go @@ -0,0 +1,93 @@ +// Package geo pkg/geo/geo.go +package geo + +import ( + "encoding/json" + "errors" + "fmt" + "math" + "net" + "net/http" + + "github.com/sirupsen/logrus" + + "github.com/skycoin/skywire-utilities/pkg/logging" + "github.com/skycoin/skywire-utilities/pkg/netutil" +) + +// Errors associated with geo calls. +var ( + ErrIPIsNotPublic = errors.New("ip address is not public") + ErrCannotObtainLocFromIP = errors.New("cannot obtain location from IP") +) + +const ( + reqURL = "http://ip.skycoin.com/?ip=%s" +) + +// LocationData represents a geolocation point. +type LocationData struct { + Lat float64 `json:"lat,omitempty"` + Lon float64 `json:"lon,omitempty"` + Country string `json:"country,omitempty"` + Region string `json:"region,omitempty"` +} + +// LocationDetails represents a function that obtains geolocation from a given IP. +type LocationDetails func(ip net.IP) (*LocationData, error) + +// MakeIPDetails returns a GeoFunc. +func MakeIPDetails(log logrus.FieldLogger, _ string) LocationDetails { + // Just in case. + if log == nil { + log = logging.MustGetLogger("geo") + } + + return func(ip net.IP) (*LocationData, error) { + // Check if IP is public IP. + if !netutil.IsPublicIP(ip) { + return nil, ErrIPIsNotPublic + } + + // Get Geo from IP. + var ( + resp *http.Response + err error + ) + + resp, err = http.Get(fmt.Sprintf(reqURL, ip.String())) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() //nolint:errcheck + + // Get body. + j := struct { + CountryCode string `json:"country_code"` + Region string `json:"region_code"` + Lat float64 `json:"latitude"` + Lon float64 `json:"longitude"` + }{} + if err := json.NewDecoder(resp.Body).Decode(&j); err != nil { + return nil, err + } + if j.CountryCode == "" && j.Region == "" && j.Lat == 0 && j.Lon == 0 { + return nil, fmt.Errorf("call to ip.skycoin.com returned empty: %s", ErrCannotObtainLocFromIP) + } + + // Prepare output. + out := LocationData{ + Lat: roundTwoDigits(j.Lat), + Lon: roundTwoDigits(j.Lon), + Country: j.CountryCode, + Region: j.Region, + } + log.WithField("geo", out).Info() + + return &out, nil + } +} + +func roundTwoDigits(value float64) float64 { + return math.Round(value*100) / 100 +} diff --git a/pkg/skywire-utilities/pkg/httpauth/auth.go b/pkg/skywire-utilities/pkg/httpauth/auth.go new file mode 100644 index 0000000000..2f2d99fd4f --- /dev/null +++ b/pkg/skywire-utilities/pkg/httpauth/auth.go @@ -0,0 +1,120 @@ +// Package httpauth pkg/httpauth/auth.go +package httpauth + +import ( + "bytes" + "errors" + "fmt" + "io" + "net/http" + "strconv" + + "github.com/skycoin/skywire-utilities/pkg/cipher" +) + +// Auth holds authentication mandatory values +type Auth struct { + Key cipher.PubKey + Nonce Nonce + Sig cipher.Sig +} + +// AuthFromHeaders attempts to extract auth from request header +func AuthFromHeaders(hdr http.Header, shouldVerifyAuth bool) (*Auth, error) { + a := &Auth{} + v := hdr.Get("SW-Public") + + if v == "" { + return nil, errors.New("SW-Public missing") + } + + key := cipher.PubKey{} + if err := key.UnmarshalText([]byte(v)); err != nil { + return nil, fmt.Errorf("error parsing SW-Public: %w", err) + } + + a.Key = key + + if shouldVerifyAuth { + if v = hdr.Get("SW-Sig"); v == "" { + return nil, errors.New("SW-Sig missing") + } + + sig := cipher.Sig{} + if err := sig.UnmarshalText([]byte(v)); err != nil { + return nil, fmt.Errorf("error parsing SW-Sig:'%s': %w", v, err) + } + + a.Sig = sig + } + + nonceStr := hdr.Get("SW-Nonce") + if nonceStr == "" { + return nil, errors.New("SW-Nonce missing") + } + + nonceUint, err := strconv.ParseUint(nonceStr, 10, 64) + if err != nil { + if numErr, ok := err.(*strconv.NumError); ok { + return nil, fmt.Errorf("error parsing SW-Nonce: %w", numErr.Err) + } + + return nil, fmt.Errorf("error parsing SW-Nonce: %w", err) + } + + a.Nonce = Nonce(nonceUint) + + return a, nil +} + +// Verify verifies signature of a payload. +func (a *Auth) Verify(in []byte) error { + return Verify(in, a.Nonce, a.Key, a.Sig) +} + +// verifyAuth verifies Request's signature. +func verifyAuth(store NonceStore, r *http.Request, auth *Auth) error { + cur, err := store.Nonce(r.Context(), auth.Key) + if err != nil { + return err + } + + if auth.Nonce != cur { + fmt.Printf("SW-Nonce mismatch, want %q, got %q, key=%q, sig=%q\n", + cur.String(), auth.Nonce.String(), auth.Key.String(), auth.Sig.String()) + + return errors.New("SW-Nonce does not match") + } + + var buf bytes.Buffer + body := io.TeeReader(r.Body, &buf) + + payload, err := io.ReadAll(body) + if err != nil { + return err + } + + // close the original body cause it will be replaced + if err := r.Body.Close(); err != nil { + return err + } + + r.Body = io.NopCloser(&buf) + + return auth.Verify(payload) +} + +// PayloadWithNonce returns the concatenation of payload and nonce. +func PayloadWithNonce(payload []byte, nonce Nonce) []byte { + return []byte(fmt.Sprintf("%s%d", string(payload), nonce)) +} + +// Sign signs the Hash of payload and nonce +func Sign(payload []byte, nonce Nonce, sec cipher.SecKey) (cipher.Sig, error) { + return cipher.SignPayload(PayloadWithNonce(payload, nonce), sec) +} + +// Verify verifies the signature of the hash of payload and nonce +func Verify(payload []byte, nonce Nonce, pub cipher.PubKey, sig cipher.Sig) error { + return cipher.VerifyPubKeySignedPayload(pub, sig, PayloadWithNonce(payload, nonce)) +} diff --git a/pkg/skywire-utilities/pkg/httpauth/handler.go b/pkg/skywire-utilities/pkg/httpauth/handler.go new file mode 100644 index 0000000000..2d6a992a27 --- /dev/null +++ b/pkg/skywire-utilities/pkg/httpauth/handler.go @@ -0,0 +1,221 @@ +// Package httpauth pkg/httpauth/handler.go +package httpauth + +import ( + "bufio" + "context" + "errors" + "net" + "net/http" + "net/url" + "strings" + + "github.com/skycoin/skywire-utilities/pkg/cipher" + "github.com/skycoin/skywire-utilities/pkg/httputil" + "github.com/skycoin/skywire-utilities/pkg/logging" +) + +var ( + logger = logging.MustGetLogger("Auth") + + // ContextAuthKey stores authenticated PubKey in Context . + ContextAuthKey = struct{}{} + // LogAuthKey stores authentication PubKey in log entry + LogAuthKey = "PK" +) + +// HTTPResponse represents the http response struct +type HTTPResponse struct { + Error *HTTPError `json:"error,omitempty"` + Data interface{} `json:"data,omitempty"` +} + +// HTTPError is included in an HTTPResponse +type HTTPError struct { + Message string `json:"message"` + Code int `json:"code"` +} + +// implements http.ResponseWriter +type statusWriter struct { + http.ResponseWriter + http.Hijacker + status int +} + +func (w *statusWriter) WriteHeader(status int) { + w.status = status + w.ResponseWriter.WriteHeader(status) +} + +func (w *statusWriter) Write(b []byte) (int, error) { + if w.status == 0 { + w.status = 200 + } + + n, err := w.ResponseWriter.Write(b) + + return n, err +} + +func (w *statusWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hijacker, ok := w.ResponseWriter.(http.Hijacker) + if !ok { + return nil, nil, errors.New("http.ResponseWriter does not implement http.Hijacker") + } + + return hijacker.Hijack() +} + +// WithAuth wraps a http.Handler and adds authentication logic. +// The original http.Handler is responsible for setting the status code. +// The middleware logic should only increment the security nonce if the status code +// from the original http.Handler is of 2xx value (representing success). +// Any http.Handler that is wrapped with this function will have available the authenticated +// public key from it's context, stored in the value ContextAuthKey. +func WithAuth(store NonceStore, original http.Handler, shouldVerifyAuth bool) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + auth, err := AuthFromHeaders(r.Header, shouldVerifyAuth) + if err != nil { + httputil.WriteJSON(w, r, http.StatusUnauthorized, + NewHTTPErrorResponse(http.StatusUnauthorized, + err.Error())) + return + } + + if shouldVerifyAuth { + err = verifyAuth(store, r, auth) + if err != nil { + httputil.WriteJSON(w, r, http.StatusUnauthorized, + NewHTTPErrorResponse(http.StatusUnauthorized, + err.Error())) + return + } + } + + sw := statusWriter{ResponseWriter: w} + httputil.LogEntrySetField(r, LogAuthKey, auth.Key) + original.ServeHTTP(&sw, r.WithContext(context.WithValue( + r.Context(), ContextAuthKey, auth.Key))) //nolint + + if sw.status == http.StatusOK { + _, err := store.IncrementNonce(r.Context(), auth.Key) + if err != nil { + logger.Error(err) + } + } + }) +} + +func makeMiddleware(store NonceStore, shouldVerifyAuth bool) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return WithAuth(store, next, shouldVerifyAuth) + } +} + +// MakeMiddleware is a convenience function that calls WithAuth. +func MakeMiddleware(store NonceStore) func(next http.Handler) http.Handler { + return makeMiddleware(store, true) +} + +// PKFromCtx is a convenience function to obtain PK from ctx. +func PKFromCtx(ctx context.Context) cipher.PubKey { + pk, _ := ctx.Value(ContextAuthKey).(cipher.PubKey) + return pk +} + +// MakeLoadTestingMiddleware is the same as `MakeMiddleware` but omits auth checks to simplify load testing. +func MakeLoadTestingMiddleware(store NonceStore) func(next http.Handler) http.Handler { + return makeMiddleware(store, false) +} + +// NextNonceResponse represents a ServeHTTP response for json encoding +type NextNonceResponse struct { + Edge cipher.PubKey `json:"edge"` + NextNonce Nonce `json:"next_nonce"` +} + +// NonceHandler provides server-side logic for Skywire-related RESTFUL authorization and authentication. +type NonceHandler struct { + Store NonceStore +} + +// ServeHTTP implements http Handler +// Use this in endpoint: +// mux.Handle("/security/nonces/", &NonceHandler{Store}) +func (as *NonceHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + remotePK, err := retrievePkFromURL(r.URL) + if err != nil { + httputil.WriteJSON(w, r, http.StatusBadRequest, + NewHTTPErrorResponse(http.StatusBadRequest, + err.Error())) + + return + } + + var nilPK cipher.PubKey + + if remotePK == nilPK { + httputil.WriteJSON(w, r, http.StatusBadRequest, + NewHTTPErrorResponse(http.StatusBadRequest, + "Invalid public key")) + + return + } + + nonce, err := as.Store.Nonce(r.Context(), remotePK) + if err != nil { + httputil.WriteJSON(w, r, http.StatusInternalServerError, + NewHTTPErrorResponse(http.StatusInternalServerError, + err.Error())) + + return + } + + httputil.WriteJSON(w, r, http.StatusOK, NextNonceResponse{Edge: remotePK, NextNonce: nonce}) +} + +// NewHTTPErrorResponse returns an HTTPResponse with the Error field populated +func NewHTTPErrorResponse(code int, msg string) HTTPResponse { + if msg == "" { + msg = http.StatusText(code) + } + + return HTTPResponse{ + Error: &HTTPError{ + Code: code, + Message: msg, + }, + } +} + +// retrievePkFromURL returns the id used on endpoints of the form path/:pk +// it doesn't checks if the endpoint has this form and can fail with other +// endpoint forms +func retrievePkFromURL(url *url.URL) (cipher.PubKey, error) { + splitPath := strings.Split(url.EscapedPath(), "/") + v := splitPath[len(splitPath)-1] + pk := cipher.PubKey{} + err := pk.UnmarshalText([]byte(v)) + return pk, err +} + +// GetRemoteAddr gets the remote address from the request +// in case of dmsghttp the RemoteAddress is a pk so it gets the RemoteAddr +// from the header instead +func GetRemoteAddr(r *http.Request) string { + var pk cipher.PubKey + + // remove the port incase of an IP or a PK + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + host = r.RemoteAddr + } + + err = pk.Set(host) + if err == nil { + return r.Header.Get("SW-PublicIP") + } + + return host +} diff --git a/pkg/skywire-utilities/pkg/httpauth/handler_test.go b/pkg/skywire-utilities/pkg/httpauth/handler_test.go new file mode 100644 index 0000000000..80aa050599 --- /dev/null +++ b/pkg/skywire-utilities/pkg/httpauth/handler_test.go @@ -0,0 +1,287 @@ +// Package httpauth pkg/httpauth/handler_test.go +package httpauth + +import ( + "bytes" + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/skycoin/skywire-utilities/pkg/cipher" + "github.com/skycoin/skywire-utilities/pkg/httputil" + "github.com/skycoin/skywire-utilities/pkg/storeconfig" +) + +var testPubKey, testSec = cipher.GenerateKeyPair() + +// validHeaders returns a valid set of headers +func validHeaders(t *testing.T, payload []byte) http.Header { + nonce := Nonce(0) + sig, err := Sign(payload, nonce, testSec) + require.NoError(t, err) + + hdr := http.Header{} + hdr.Set("SW-Public", testPubKey.Hex()) + hdr.Set("SW-Sig", sig.Hex()) + hdr.Set("SW-Nonce", nonce.String()) + + return hdr +} + +func validHeadersWithNonce(t *testing.T, nonce Nonce, payload []byte) http.Header { + sig, err := Sign(payload, nonce, testSec) + require.NoError(t, err) + + hdr := http.Header{} + hdr.Set("SW-Public", testPubKey.Hex()) + hdr.Set("SW-Sig", sig.Hex()) + hdr.Set("SW-Nonce", nonce.String()) + + return hdr +} + +func invalidHeaders(t *testing.T, payload []byte) http.Header { + _, invalidSec := cipher.GenerateKeyPair() + nonce := Nonce(0) + sig, err := Sign(payload, nonce, invalidSec) + require.NoError(t, err) + + hdr := http.Header{} + hdr.Set("SW-Public", testPubKey.Hex()) + hdr.Set("SW-Sig", sig.Hex()) + hdr.Set("SW-Nonce", nonce.String()) + + return hdr +} + +func TestServer_Wrap(t *testing.T) { + storeConfig := storeconfig.Config{Type: storeconfig.Memory} + ctx := context.TODO() + mock, err := NewNonceStore(ctx, storeConfig, "") + require.NoError(t, err) + + t.Run("Without headers", func(t *testing.T) { + defer func() { + storeConfig := storeconfig.Config{Type: storeconfig.Memory} + nmock, err := NewNonceStore(ctx, storeConfig, "") + require.NoError(t, err) + + mock = nmock + }() + w := httptest.NewRecorder() + r := httptest.NewRequest("POST", "/foo", nil) + + m := chi.NewRouter() + + m.Use(MakeMiddleware(mock)) + m.Post("/foo", func(writer http.ResponseWriter, request *http.Request) { + httputil.WriteJSON(writer, request, http.StatusOK, "") + }) + + m.ServeHTTP(w, r) + + assert.Equal(t, http.StatusUnauthorized, w.Code, w.Body.String()) + nonce, err := mock.Nonce(context.TODO(), testPubKey) + require.NoError(t, err) + + assert.Equal(t, Nonce(0), nonce) + }) + + t.Run("Context has verified pubkey", func(t *testing.T) { + defer func() { + storeConfig := storeconfig.Config{Type: storeconfig.Memory} + nmock, err := NewNonceStore(ctx, storeConfig, "") + require.NoError(t, err) + + mock = nmock + }() + w := httptest.NewRecorder() + r := httptest.NewRequest("POST", "/foo", bytes.NewReader([]byte("hi"))) + r.Header = validHeaders(t, []byte("hi")) + + handler := func(writer http.ResponseWriter, request *http.Request) { + pk, ok := request.Context().Value(ContextAuthKey).(cipher.PubKey) + if !ok { + httputil.WriteJSON(writer, request, http.StatusBadRequest, "") + } + if pk != testPubKey { + httputil.WriteJSON(writer, request, http.StatusBadRequest, "") + } + httputil.WriteJSON(writer, request, http.StatusOK, "") + } + + m := chi.NewRouter() + m.Use(MakeMiddleware(mock)) + m.Post("/foo", handler) + m.ServeHTTP(w, r) + + assert.Equal(t, http.StatusOK, w.Code, w.Body.String()) + _, err := mock.Nonce(context.TODO(), testPubKey) + require.NoError(t, err) + }) + + t.Run("Valid", func(t *testing.T) { + defer func() { + storeConfig := storeconfig.Config{Type: storeconfig.Memory} + nmock, err := NewNonceStore(ctx, storeConfig, "") + require.NoError(t, err) + + mock = nmock + }() + w := httptest.NewRecorder() + r := httptest.NewRequest("POST", "/foo", bytes.NewReader([]byte("hi"))) + r.Header = validHeaders(t, []byte("hi")) + + handler := func(writer http.ResponseWriter, request *http.Request) { + httputil.WriteJSON(writer, request, http.StatusOK, "") + } + + m := chi.NewRouter() + m.Use(MakeMiddleware(mock)) + m.Post("/foo", handler) + m.ServeHTTP(w, r) + + assert.Equal(t, http.StatusOK, w.Code, w.Body.String()) + nonce, err := mock.Nonce(context.TODO(), testPubKey) + require.NoError(t, err) + + assert.Equal(t, Nonce(1), nonce) + }) + + t.Run("Valid with nonzero nonce", func(t *testing.T) { + _, err := mock.IncrementNonce(context.TODO(), testPubKey) + require.NoError(t, err) + defer func() { + storeConfig := storeconfig.Config{Type: storeconfig.Memory} + nmock, err := NewNonceStore(ctx, storeConfig, "") + require.NoError(t, err) + + mock = nmock + }() + + w := httptest.NewRecorder() + r := httptest.NewRequest("POST", "/foo", bytes.NewReader([]byte("foo"))) + r.Header = validHeadersWithNonce(t, Nonce(1), []byte("foo")) + + handler := func(writer http.ResponseWriter, request *http.Request) { + httputil.WriteJSON(writer, request, http.StatusOK, "") + } + + m := chi.NewRouter() + m.Use(MakeMiddleware(mock)) + m.Post("/foo", handler) + m.ServeHTTP(w, r) + + assert.Equal(t, http.StatusOK, w.Code, w.Body.String()) + }) + + t.Run("Invalid with nonzero nonce", func(t *testing.T) { + _, err := mock.IncrementNonce(context.TODO(), testPubKey) + require.NoError(t, err) + defer func() { + storeConfig := storeconfig.Config{Type: storeconfig.Memory} + nmock, err := NewNonceStore(ctx, storeConfig, "") + require.NoError(t, err) + + mock = nmock + }() + + w := httptest.NewRecorder() + r := httptest.NewRequest("POST", "/foo", nil) + r.Header = validHeadersWithNonce(t, Nonce(3), nil) + + handler := func(writer http.ResponseWriter, request *http.Request) { + httputil.WriteJSON(writer, request, http.StatusOK, "") + } + + m := chi.NewRouter() + m.Use(MakeMiddleware(mock)) + m.Post("/foo", handler) + m.ServeHTTP(w, r) + + assert.Equal(t, http.StatusUnauthorized, w.Code, w.Body.String()) + }) + + t.Run("Invalid signature", func(t *testing.T) { + defer func() { + storeConfig := storeconfig.Config{Type: storeconfig.Memory} + nmock, err := NewNonceStore(ctx, storeConfig, "") + require.NoError(t, err) + + mock = nmock + }() + w := httptest.NewRecorder() + r := httptest.NewRequest("POST", "/foo", nil) + r.Header = invalidHeaders(t, nil) + + handler := func(writer http.ResponseWriter, request *http.Request) { + httputil.WriteJSON(writer, request, http.StatusOK, "") + } + + m := chi.NewRouter() + m.Use(MakeMiddleware(mock)) + m.Post("/foo", handler) + m.ServeHTTP(w, r) + + assert.Equal(t, http.StatusUnauthorized, w.Code, w.Body.String()) + }) +} + +func TestAuthFormat(t *testing.T) { + headers := []string{"SW-Public", "SW-Sig", "SW-Nonce"} + for _, header := range headers { + header := header + t.Run(header+"-IsMissing", func(t *testing.T) { + hdr := validHeaders(t, nil) + hdr.Del(header) + + _, err := AuthFromHeaders(hdr, true) + assert.Error(t, err) + assert.Contains(t, err.Error(), header) + }) + } + + t.Run("NonceFormat", func(t *testing.T) { + nonces := []string{"not_a_number", "-1", "0x0"} + hdr := validHeaders(t, nil) + for _, n := range nonces { + hdr.Set("SW-Nonce", n) + _, err := AuthFromHeaders(hdr, true) + assert.Error(t, err) + assert.Contains(t, err.Error(), "SW-Nonce: invalid syntax") + } + }) +} + +func TestAuthSignatureVerification(t *testing.T) { + nonce := Nonce(0xdeadbeef) + payload := []byte("dead beed") + + sig, err := Sign(payload, nonce, testSec) + require.NoError(t, err) + + auth := &Auth{ + Key: testPubKey, + Nonce: nonce, + Sig: sig, + } + + assert.NoError(t, auth.Verify(payload)) + assert.Error(t, auth.Verify([]byte("other payload")), "Validate should return an error for this payload") +} + +func TestSignatureVerification(t *testing.T) { + pub, sec := cipher.GenerateKeyPair() + payload := []byte("payload to sign") + nonce := Nonce(0xff) + + sig, err := Sign(payload, nonce, sec) + require.NoError(t, err) + require.NoError(t, Verify(payload, nonce, pub, sig)) + require.Error(t, Verify(payload, nonce+1, pub, sig)) +} diff --git a/pkg/skywire-utilities/pkg/httpauth/memory_store.go b/pkg/skywire-utilities/pkg/httpauth/memory_store.go new file mode 100644 index 0000000000..b0b4151dc0 --- /dev/null +++ b/pkg/skywire-utilities/pkg/httpauth/memory_store.go @@ -0,0 +1,62 @@ +// Package httpauth pkg/httpauth/memory_store.go +package httpauth + +import ( + "context" + "sync" + + "github.com/skycoin/skywire-utilities/pkg/cipher" +) + +type memStore struct { + nonces map[cipher.PubKey]Nonce + + err error + mu sync.Mutex +} + +func newMemoryStore() *memStore { + return &memStore{ + nonces: make(map[cipher.PubKey]Nonce), + } +} + +func (s *memStore) SetError(err error) { + s.mu.Lock() + s.err = err + s.mu.Unlock() +} + +func (s *memStore) Nonce(_ context.Context, pk cipher.PubKey) (Nonce, error) { + s.mu.Lock() + defer s.mu.Unlock() + + if s.err != nil { + return 0, s.err + } + + return s.nonces[pk], nil +} + +func (s *memStore) IncrementNonce(_ context.Context, pk cipher.PubKey) (Nonce, error) { + s.mu.Lock() + defer s.mu.Unlock() + + if s.err != nil { + return 0, s.err + } + + s.nonces[pk]++ + return s.nonces[pk], nil +} + +func (s *memStore) Count(_ context.Context) (n int, err error) { + s.mu.Lock() + defer s.mu.Unlock() + + if s.err != nil { + return 0, s.err + } + + return len(s.nonces), nil +} diff --git a/pkg/skywire-utilities/pkg/httpauth/nonce-storer.go b/pkg/skywire-utilities/pkg/httpauth/nonce-storer.go new file mode 100644 index 0000000000..e2f8fe7bb7 --- /dev/null +++ b/pkg/skywire-utilities/pkg/httpauth/nonce-storer.go @@ -0,0 +1,43 @@ +// Package httpauth pkg/httpauth/nonce-storer.go +package httpauth + +import ( + "context" + "fmt" + + "github.com/skycoin/skywire-utilities/pkg/cipher" + "github.com/skycoin/skywire-utilities/pkg/storeconfig" +) + +// Nonce is used to sign requests in order to avoid replay attack +type Nonce uint64 + +func (n Nonce) String() string { return fmt.Sprintf("%d", n) } + +// NonceStore stores Incrementing Security Nonces. +type NonceStore interface { + + // IncrementNonce increments the nonce associated with the specified remote entity. + // It returns the next expected nonce after it has been incremented and returns error on failure. + IncrementNonce(ctx context.Context, remotePK cipher.PubKey) (nonce Nonce, err error) + + // Nonce obtains the next expected nonce for a given remote entity (represented by public key). + // It returns error on failure. + Nonce(ctx context.Context, remotePK cipher.PubKey) (nonce Nonce, err error) + + // Count obtains the number of entries stored in the underlying database. + Count(ctx context.Context) (n int, err error) +} + +// NewNonceStore returns a new nonce storer of the given kind that connects to given Store's url. +// Nonce count should not be shared between services, so it should be stored in a unique key for every service. +func NewNonceStore(ctx context.Context, config storeconfig.Config, prefix string) (NonceStore, error) { + switch config.Type { + case storeconfig.Redis: + return newRedisStore(ctx, config.URL, config.Password, prefix) + case storeconfig.Memory: + return newMemoryStore(), nil + } + + return nil, fmt.Errorf("kind has to be either redis or memory") +} diff --git a/pkg/skywire-utilities/pkg/httpauth/redis-store.go b/pkg/skywire-utilities/pkg/httpauth/redis-store.go new file mode 100644 index 0000000000..f5525b5d67 --- /dev/null +++ b/pkg/skywire-utilities/pkg/httpauth/redis-store.go @@ -0,0 +1,89 @@ +// Package httpauth pkg/httpauth/redis-store.go +package httpauth + +import ( + "context" + "fmt" + "log" + "strconv" + "time" + + "github.com/go-redis/redis/v8" + + "github.com/skycoin/skywire-utilities/pkg/cipher" +) + +type redisStore struct { + client *redis.Client + prefix string +} + +func newRedisStore(ctx context.Context, addr, password, prefix string) (*redisStore, error) { + opt, err := redis.ParseURL(addr) + if err != nil { + return nil, fmt.Errorf("addr: %w", err) + } + + opt.Password = password + opt.ReadTimeout = time.Minute + opt.WriteTimeout = 5 * time.Second + opt.PoolTimeout = 10 * time.Second + opt.IdleCheckFrequency = 5 * time.Second + opt.PoolSize = 200 + + if prefix != "" { + prefix += ":" + } + + redisCl := redis.NewClient(opt) + if err := redisCl.Ping(ctx).Err(); err != nil { + log.Fatalf("Failed to connect to Redis cluster: %v", err) + } + + store := &redisStore{ + client: redisCl, + prefix: prefix, + } + + return store, nil +} + +func (s *redisStore) key(v string) string { + return s.prefix + v +} + +func (s *redisStore) Nonce(ctx context.Context, remotePK cipher.PubKey) (Nonce, error) { + nonce, err := s.client.Get(ctx, s.key(fmt.Sprintf("nonces:%s", remotePK))).Result() + if err != nil { + return 0, nil + } + + n, err := strconv.Atoi(nonce) + if err != nil { + return 0, fmt.Errorf("malformed nonce: %s", nonce) + } + return Nonce(n), nil //nolint +} + +func (s *redisStore) IncrementNonce(ctx context.Context, remotePK cipher.PubKey) (Nonce, error) { + nonce, err := s.client.Incr(ctx, s.key(fmt.Sprintf("nonces:%s", remotePK))).Result() + if err != nil { + return 0, fmt.Errorf("redis: %w", err) + } + + _, err = s.client.SAdd(ctx, s.key("nonces"), remotePK).Result() + if err != nil { + return 0, fmt.Errorf("redis: %w", err) + } + + return Nonce(nonce), nil //nolint +} + +func (s *redisStore) Count(ctx context.Context) (n int, err error) { + size, err := s.client.SCard(ctx, s.key("nonces")).Result() + if err != nil { + return 0, fmt.Errorf("redis: %w", err) + } + + return int(size), nil +} diff --git a/pkg/skywire-utilities/pkg/httputil/dmsghttp.go b/pkg/skywire-utilities/pkg/httputil/dmsghttp.go new file mode 100644 index 0000000000..c2a1b59446 --- /dev/null +++ b/pkg/skywire-utilities/pkg/httputil/dmsghttp.go @@ -0,0 +1,21 @@ +// Package httputil pkg/httputil/dmsghttp.go +package httputil + +// DMSGHTTPConf is struct of /dmsghttp endpoint of config bootstrap +type DMSGHTTPConf struct { + DMSGServers []DMSGServersConf `json:"dmsg_servers"` + DMSGDiscovery string `json:"dmsg_discovery"` + TranspordDiscovery string `json:"transport_discovery"` + AddressResolver string `json:"address_resolver"` + RouteFinder string `json:"route_finder"` + UptimeTracker string `json:"uptime_tracker"` + ServiceDiscovery string `json:"service_discovery"` +} + +// DMSGServersConf is struct of dmsg servers list on /dmsghttp endpoint +type DMSGServersConf struct { + Static string `json:"static"` + Server struct { + Address string `json:"address"` + } `json:"server"` +} diff --git a/pkg/skywire-utilities/pkg/httputil/error.go b/pkg/skywire-utilities/pkg/httputil/error.go new file mode 100644 index 0000000000..38e5c94165 --- /dev/null +++ b/pkg/skywire-utilities/pkg/httputil/error.go @@ -0,0 +1,61 @@ +// Package httputil pkg/httputil/error.go +package httputil + +import ( + "bytes" + "fmt" + "io" + "net/http" +) + +// HTTPError represents an http error associated with a server response. +type HTTPError struct { + Status int + Body string +} + +// Error is the object returned to the client when there's an error. +type Error struct { + Error string `json:"error"` +} + +// ErrorFromResp creates an HTTPError from a given server response. +func ErrorFromResp(resp *http.Response) error { + status := resp.StatusCode + if status >= 200 && status < 300 { + return nil + } + msg, err := io.ReadAll(resp.Body) + if err != nil && len(msg) == 0 { + msg = []byte(fmt.Sprintf("failed to read HTTP response body: %v", err)) + } + return &HTTPError{Status: status, Body: string(bytes.TrimSpace(msg))} +} + +// Error returns the error message. +func (e *HTTPError) Error() string { + return fmt.Sprintf("%d %s: %v", e.Status, http.StatusText(e.Status), e.Body) +} + +// Timeout implements net.Error +func (e *HTTPError) Timeout() bool { + switch e.Status { + case http.StatusGatewayTimeout, http.StatusRequestTimeout: + return true + default: + return false + } +} + +// Temporary implements net.Error +func (e *HTTPError) Temporary() bool { + if e.Timeout() { + return true + } + switch e.Status { + case http.StatusServiceUnavailable, http.StatusTooManyRequests: + return true + default: + return false + } +} diff --git a/pkg/skywire-utilities/pkg/httputil/health.go b/pkg/skywire-utilities/pkg/httputil/health.go new file mode 100644 index 0000000000..015071d7e3 --- /dev/null +++ b/pkg/skywire-utilities/pkg/httputil/health.go @@ -0,0 +1,45 @@ +// Package httputil pkg/httputil/health.go +package httputil + +import ( + "context" + "net/http" + "time" + + "github.com/skycoin/skywire-utilities/pkg/buildinfo" +) + +var path = "/health" + +// HealthCheckResponse is struct of /health endpoint +type HealthCheckResponse struct { + BuildInfo *buildinfo.Info `json:"build_info,omitempty"` + StartedAt time.Time `json:"started_at"` + DmsgAddr string `json:"dmsg_address,omitempty"` + DmsgServers []string `json:"dmsg_servers,omitempty"` +} + +// GetServiceHealth gets the response from the given service url +func GetServiceHealth(_ context.Context, url string) (health *HealthCheckResponse, err error) { + resp, err := http.Get(url + path) + if err != nil { + return nil, err + } + if resp != nil { + defer func() { + if cErr := resp.Body.Close(); cErr != nil && err == nil { + err = cErr + } + }() + } + if resp.StatusCode != http.StatusOK { + var hErr HTTPError + if err = json.NewDecoder(resp.Body).Decode(&hErr); err != nil { + return nil, err + } + return nil, &hErr + } + err = json.NewDecoder(resp.Body).Decode(&health) + + return health, nil +} diff --git a/pkg/skywire-utilities/pkg/httputil/httputil.go b/pkg/skywire-utilities/pkg/httputil/httputil.go new file mode 100644 index 0000000000..f63b08baaf --- /dev/null +++ b/pkg/skywire-utilities/pkg/httputil/httputil.go @@ -0,0 +1,105 @@ +// Package httputil pkg/httputil/httputil.go +package httputil + +import ( + "context" + "fmt" + "net/http" + "strconv" + "strings" + + "github.com/go-chi/chi/v5/middleware" + jsoniter "github.com/json-iterator/go" + "github.com/sirupsen/logrus" + + "github.com/skycoin/skywire-utilities/pkg/logging" +) + +var json = jsoniter.ConfigFastest + +var log = logging.MustGetLogger("httputil") + +// WriteJSON writes a json object on a http.ResponseWriter with the given code, +// panics on marshaling error +func WriteJSON(w http.ResponseWriter, r *http.Request, code int, v interface{}) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(code) + enc := json.NewEncoder(w) + pretty, err := BoolFromQuery(r, "pretty", false) + if err != nil { + log.WithError(err).Warn("Failed to get bool from query") + } + if pretty { + enc.SetIndent("", " ") + } + if err, ok := v.(error); ok { + v = map[string]interface{}{"error": err.Error()} + } + if err := json.NewEncoder(w).Encode(v); err != nil { + panic(err) + } +} + +// ReadJSON reads the request body to a json object. +func ReadJSON(r *http.Request, v interface{}) error { + dec := json.NewDecoder(r.Body) + dec.DisallowUnknownFields() + return dec.Decode(v) +} + +// BoolFromQuery obtains a boolean from a query entry. +func BoolFromQuery(r *http.Request, key string, defaultVal bool) (bool, error) { + switch q := r.URL.Query().Get(key); q { + case "true", "on", "1": + return true, nil + case "false", "off", "0": + return false, nil + case "": + return defaultVal, nil + default: + return false, fmt.Errorf("invalid '%s' query value of '%s'", key, q) + } +} + +// SplitRPCAddr returns host and port and whatever error results from parsing the rpc address interface +func SplitRPCAddr(rpcAddr string) (host string, port uint16, err error) { + addrToken := strings.Split(rpcAddr, ":") + uint64port, err := strconv.ParseUint(addrToken[1], 10, 16) + if err != nil { + return + } + + return addrToken[0], uint16(uint64port), nil +} + +type ctxKeyLogger int + +// LoggerKey defines logger HTTP context key. +const LoggerKey ctxKeyLogger = -1 + +// GetLogger returns logger from HTTP context. +func GetLogger(r *http.Request) logrus.FieldLogger { + if log, ok := r.Context().Value(LoggerKey).(logrus.FieldLogger); ok && log != nil { + return log + } + + return logging.NewMasterLogger() +} + +// todo: investigate if it's used throughout the services (didn't work properly for UT) +// remove and use structured logging + +// SetLoggerMiddleware sets logger to context of HTTP requests. +func SetLoggerMiddleware(log logrus.FieldLogger) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + if reqID := middleware.GetReqID(ctx); reqID != "" && log != nil { + ctx = context.WithValue(ctx, LoggerKey, log.WithField("RequestID", reqID)) + } + next.ServeHTTP(w, r.WithContext(ctx)) + } + return http.HandlerFunc(fn) + } +} diff --git a/pkg/skywire-utilities/pkg/httputil/log.go b/pkg/skywire-utilities/pkg/httputil/log.go new file mode 100644 index 0000000000..c6d1900432 --- /dev/null +++ b/pkg/skywire-utilities/pkg/httputil/log.go @@ -0,0 +1,57 @@ +// Package httputil pkg/httputil/log.go +package httputil + +import ( + "context" + "net/http" + "time" + + "github.com/go-chi/chi/v5/middleware" + "github.com/sirupsen/logrus" +) + +type structuredLogger struct { + logger logrus.FieldLogger +} + +// NewLogMiddleware creates a new instance of logging middleware. This will allow +// adding log fields in the handler and any further middleware. At the end of request, this +// log entry will be printed at Info level via passed logger +func NewLogMiddleware(logger logrus.FieldLogger) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + sl := &structuredLogger{logger} + start := time.Now() + var requestID string + if reqID := r.Context().Value(middleware.RequestIDKey); reqID != nil { + requestID = reqID.(string) + } + ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor) + newContext := context.WithValue(r.Context(), middleware.LogEntryCtxKey, sl) + next.ServeHTTP(ww, r.WithContext(newContext)) + latency := time.Since(start) + fields := logrus.Fields{ + "status": ww.Status(), + "took": latency, + "remote": r.RemoteAddr, + "request": r.RequestURI, + "method": r.Method, + } + if requestID != "" { + fields["request_id"] = requestID + } + sl.logger.WithFields(fields).Info() + + } + return http.HandlerFunc(fn) + } +} + +// LogEntrySetField adds new key-value pair to current (request scoped) log entry. This pair will be +// printed along with all other pairs when the request is served. +// This requires log middleware from this package to be installed in the chain +func LogEntrySetField(r *http.Request, key string, value interface{}) { + if sl, ok := r.Context().Value(middleware.LogEntryCtxKey).(*structuredLogger); ok { + sl.logger = sl.logger.WithField(key, value) + } +} diff --git a/pkg/skywire-utilities/pkg/logging/formatter.go b/pkg/skywire-utilities/pkg/logging/formatter.go new file mode 100644 index 0000000000..60b998045e --- /dev/null +++ b/pkg/skywire-utilities/pkg/logging/formatter.go @@ -0,0 +1,449 @@ +// Package logging pkg/logging/formatter.go +package logging + +import ( + "bytes" + "fmt" + "io" + "os" + "sort" + "strings" + "sync" + "time" + + "github.com/mgutz/ansi" + "github.com/sirupsen/logrus" + "golang.org/x/term" +) + +const defaultTimestampFormat = time.RFC3339 + +var ( + baseTimestamp = time.Now() + defaultColorScheme = &ColorScheme{ + InfoLevelStyle: "green", + WarnLevelStyle: "yellow", + ErrorLevelStyle: "red", + FatalLevelStyle: "red", + PanicLevelStyle: "red", + DebugLevelStyle: "blue", + TraceLevelStyle: "black", + PrefixStyle: "cyan", + TimestampStyle: "black+h", + CallContextStyle: "black+h", + CriticalStyle: "magenta+h", + } + noColorsColorScheme = &compiledColorScheme{ + InfoLevelColor: ansi.ColorFunc(""), + WarnLevelColor: ansi.ColorFunc(""), + ErrorLevelColor: ansi.ColorFunc(""), + FatalLevelColor: ansi.ColorFunc(""), + PanicLevelColor: ansi.ColorFunc(""), + DebugLevelColor: ansi.ColorFunc(""), + TraceLevelColor: ansi.ColorFunc(""), + PrefixColor: ansi.ColorFunc(""), + TimestampColor: ansi.ColorFunc(""), + CallContextColor: ansi.ColorFunc(""), + CriticalColor: ansi.ColorFunc(""), + } + defaultCompiledColorScheme = compileColorScheme(defaultColorScheme) +) + +func miniTS() int { + return int(time.Since(baseTimestamp) / time.Second) +} + +// ColorScheme configures the logging output colors +type ColorScheme struct { + InfoLevelStyle string + WarnLevelStyle string + ErrorLevelStyle string + FatalLevelStyle string + PanicLevelStyle string + DebugLevelStyle string + TraceLevelStyle string + PrefixStyle string + TimestampStyle string + CallContextStyle string + CriticalStyle string +} + +type compiledColorScheme struct { + InfoLevelColor func(string) string + WarnLevelColor func(string) string + ErrorLevelColor func(string) string + FatalLevelColor func(string) string + PanicLevelColor func(string) string + DebugLevelColor func(string) string + TraceLevelColor func(string) string + PrefixColor func(string) string + TimestampColor func(string) string + CallContextColor func(string) string + CriticalColor func(string) string +} + +// TextFormatter formats log output +type TextFormatter struct { + // Set to true to bypass checking for a TTY before outputting colors. + ForceColors bool + + // Force disabling colors. For a TTY colors are enabled by default. + DisableColors bool + + // Force formatted layout, even for non-TTY output. + ForceFormatting bool + + // Disable timestamp logging. useful when output is redirected to logging + // system that already adds timestamps. + DisableTimestamp bool + + // Disable the conversion of the log levels to uppercase + DisableUppercase bool + + // Enable logging the full timestamp when a TTY is attached instead of just + // the time passed since beginning of execution. + FullTimestamp bool + + // Timestamp format to use for display when a full timestamp is printed. + TimestampFormat string + + // The fields are sorted by default for a consistent output. For applications + // that log extremely frequently and don't use the JSON formatter this may not + // be desired. + DisableSorting bool + + // Wrap empty fields in quotes if true. + QuoteEmptyFields bool + + // Can be set to the override the default quoting character " + // with something else. For example: ', or `. + QuoteCharacter string + + // Pad msg field with spaces on the right for display. + // The value for this parameter will be the size of padding. + // Its default value is zero, which means no padding will be applied for msg. + SpacePadding int + + // Always use quotes for string values (except for empty fields) + AlwaysQuoteStrings bool + + // Color scheme to use. + colorScheme *compiledColorScheme + + // Whether the logger's out is to a terminal. + isTerminal bool + + sync.Once +} + +func getCompiledColor(main string, fallback string) func(string) string { + var style string + if main != "" { + style = main + } else { + style = fallback + } + return ansi.ColorFunc(style) +} + +func compileColorScheme(s *ColorScheme) *compiledColorScheme { + return &compiledColorScheme{ + InfoLevelColor: getCompiledColor(s.InfoLevelStyle, defaultColorScheme.InfoLevelStyle), + WarnLevelColor: getCompiledColor(s.WarnLevelStyle, defaultColorScheme.WarnLevelStyle), + ErrorLevelColor: getCompiledColor(s.ErrorLevelStyle, defaultColorScheme.ErrorLevelStyle), + FatalLevelColor: getCompiledColor(s.FatalLevelStyle, defaultColorScheme.FatalLevelStyle), + PanicLevelColor: getCompiledColor(s.PanicLevelStyle, defaultColorScheme.PanicLevelStyle), + DebugLevelColor: getCompiledColor(s.DebugLevelStyle, defaultColorScheme.DebugLevelStyle), + TraceLevelColor: getCompiledColor(s.TraceLevelStyle, defaultColorScheme.TraceLevelStyle), + PrefixColor: getCompiledColor(s.PrefixStyle, defaultColorScheme.PrefixStyle), + TimestampColor: getCompiledColor(s.TimestampStyle, defaultColorScheme.TimestampStyle), + CallContextColor: getCompiledColor(s.CallContextStyle, defaultColorScheme.CallContextStyle), + CriticalColor: getCompiledColor(s.CriticalStyle, defaultColorScheme.CriticalStyle), + } +} + +func (f *TextFormatter) init(entry *logrus.Entry) { + if len(f.QuoteCharacter) == 0 { + f.QuoteCharacter = "\"" + } + if entry.Logger != nil { + f.isTerminal = f.checkIfTerminal(entry.Logger.Out) + } +} + +func (f *TextFormatter) checkIfTerminal(w io.Writer) bool { + switch v := w.(type) { + case *os.File: + return term.IsTerminal(int(v.Fd())) + default: + return false + } +} + +// SetColorScheme sets the TextFormatter's color scheme configuration +func (f *TextFormatter) SetColorScheme(colorScheme *ColorScheme) { + f.colorScheme = compileColorScheme(colorScheme) +} + +// Format formats a logrus.Entry +func (f *TextFormatter) Format(entry *logrus.Entry) ([]byte, error) { + var b *bytes.Buffer + keys := make([]string, 0, len(entry.Data)) + for k := range entry.Data { + keys = append(keys, k) + } + lastKeyIdx := len(keys) - 1 + + if !f.DisableSorting { + sort.Strings(keys) + } + if entry.Buffer != nil { + b = entry.Buffer + } else { + b = &bytes.Buffer{} + } + + f.Do(func() { f.init(entry) }) + + isFormatted := f.ForceFormatting || f.isTerminal + + timestampFormat := f.TimestampFormat + if timestampFormat == "" { + timestampFormat = defaultTimestampFormat + } + if isFormatted { + isColored := (f.ForceColors || f.isTerminal) && !f.DisableColors + var colorScheme *compiledColorScheme + if isColored { + if f.colorScheme == nil { + colorScheme = defaultCompiledColorScheme + } else { + colorScheme = f.colorScheme + } + } else { + colorScheme = noColorsColorScheme + } + f.printColored(b, entry, keys, timestampFormat, colorScheme) + } else { + if !f.DisableTimestamp { + f.appendKeyValue(b, "time", entry.Time.Format(timestampFormat), true) + } + f.appendKeyValue(b, "level", entry.Level.String(), true) + if entry.Message != "" { + f.appendKeyValue(b, "msg", entry.Message, lastKeyIdx >= 0) + } + for i, key := range keys { + f.appendKeyValue(b, key, entry.Data[key], lastKeyIdx != i) + } + } + + b.WriteByte('\n') //nolint:gosec + return b.Bytes(), nil +} + +func (f *TextFormatter) printColored(b *bytes.Buffer, entry *logrus.Entry, keys []string, timestampFormat string, colorScheme *compiledColorScheme) { + var levelColor func(string) string + var levelText string + switch entry.Level { + case logrus.InfoLevel: + levelColor = colorScheme.InfoLevelColor + case logrus.WarnLevel: + levelColor = colorScheme.WarnLevelColor + case logrus.ErrorLevel: + levelColor = colorScheme.ErrorLevelColor + case logrus.FatalLevel: + levelColor = colorScheme.FatalLevelColor + case logrus.PanicLevel: + levelColor = colorScheme.PanicLevelColor + case logrus.TraceLevel: + levelColor = colorScheme.TraceLevelColor + default: + levelColor = colorScheme.DebugLevelColor + } + + priority, ok := entry.Data[logPriorityKey] + hasPriority := ok && priority == logPriorityCritical + + if entry.Level != logrus.WarnLevel { + levelText = entry.Level.String() + } else { + levelText = "warn" + } + + if !f.DisableUppercase { + levelText = strings.ToUpper(levelText) + } + + level := levelColor(levelText) + message := entry.Message + prefix := "" + + prefixText := extractPrefix(entry) + if prefixText != "" { + prefixText = " " + prefixText + ":" + prefix = colorScheme.PrefixColor(prefixText) + } + + messageFormat := "%s" + if f.SpacePadding != 0 { + messageFormat = fmt.Sprintf("%%-%ds", f.SpacePadding) + } + if message != "" { + messageFormat = " " + messageFormat + } + + callContextParts := []string{} + if ifile, ok := entry.Data["file"]; ok { + if sfile, ok := ifile.(string); ok && sfile != "" { + callContextParts = append(callContextParts, sfile) + } + } + if ifunc, ok := entry.Data["func"]; ok { + if sfunc, ok := ifunc.(string); ok && sfunc != "" { + callContextParts = append(callContextParts, sfunc) + } + } + if iline, ok := entry.Data["line"]; ok { + sline := "" + switch iline := iline.(type) { + case string: + sline = iline + case int, uint, int32, int64, uint32, uint64: + sline = fmt.Sprint(iline) + } + if sline != "" { + callContextParts = append(callContextParts, fmt.Sprint(sline)) + } + } + callContextText := strings.Join(callContextParts, ":") + callContext := colorScheme.CallContextColor(callContextText) + if callContext != "" { + callContext = " " + callContext + } + + if f.DisableTimestamp { + if hasPriority { + str := fmt.Sprintf("%s%s%s"+messageFormat, levelText, callContextText, prefixText, message) + fmt.Fprint(b, colorScheme.CriticalColor(str)) + } else { + fmt.Fprintf(b, "%s%s%s"+messageFormat, level, callContext, prefix, message) + } + } else { + var timestamp string + if !f.FullTimestamp { + timestamp = fmt.Sprintf("[%04d]", miniTS()) + } else { + timestamp = fmt.Sprintf("[%s]", entry.Time.Format(timestampFormat)) + } + + coloredTimestamp := colorScheme.TimestampColor(timestamp) + + if hasPriority { + str := fmt.Sprintf("%s %s%s%s"+messageFormat, timestamp, levelText, callContextText, prefixText, message) + fmt.Fprint(b, colorScheme.CriticalColor(str)) + } else { + fmt.Fprintf(b, "%s %s%s%s"+messageFormat, coloredTimestamp, level, callContext, prefix, message) + } + } + + for _, k := range keys { + if k != "prefix" && k != "file" && k != "func" && k != "line" && k != logPriorityKey && k != logModuleKey { + v := entry.Data[k] + fmt.Fprintf(b, " %s", f.formatKeyValue(levelColor(k), v)) + } + } +} + +func (f *TextFormatter) needsQuoting(text string) bool { + if len(text) == 0 { + return f.QuoteEmptyFields + } + + if f.AlwaysQuoteStrings { + return true + } + + for _, ch := range text { + if !((ch >= 'a' && ch <= 'z') || + (ch >= 'A' && ch <= 'Z') || + (ch >= '0' && ch <= '9') || + ch == '-' || ch == '.') { + return true + } + } + + return false +} + +func extractPrefix(e *logrus.Entry) string { + var module string + if iModule, ok := e.Data[logModuleKey]; ok { + module, _ = iModule.(string) + } + + var priority string + if iPriority, ok := e.Data[logPriorityKey]; ok { + priority, _ = iPriority.(string) + } + + switch { + case priority == "": + return fmt.Sprintf("[%s]", module) + case module == "": + return fmt.Sprintf("[%s]", priority) + default: + return fmt.Sprintf("[%s:%s]", module, priority) + } +} + +func (f *TextFormatter) formatKeyValue(key string, value interface{}) string { + return fmt.Sprintf("%s=%s", key, f.formatValue(value)) +} + +func (f *TextFormatter) formatValue(value interface{}) string { + switch value := value.(type) { + case string: + if f.needsQuoting(value) { + return fmt.Sprintf("%s%+v%s", f.QuoteCharacter, value, f.QuoteCharacter) + } + return value + case error: + errmsg := value.Error() + if f.needsQuoting(errmsg) { + return fmt.Sprintf("%s%+v%s", f.QuoteCharacter, errmsg, f.QuoteCharacter) + } + return errmsg + default: + return fmt.Sprintf("%+v", value) + } +} + +func (f *TextFormatter) appendKeyValue(b *bytes.Buffer, key string, value interface{}, appendSpace bool) { + b.WriteString(key) //nolint:gosec + b.WriteByte('=') //nolint:gosec + f.appendValue(b, value) //nolint:gosec + + if appendSpace { + b.WriteByte(' ') //nolint:gosec + } +} + +func (f *TextFormatter) appendValue(b *bytes.Buffer, value interface{}) { + switch value := value.(type) { + case string: + if f.needsQuoting(value) { + fmt.Fprintf(b, "%s%+v%s", f.QuoteCharacter, value, f.QuoteCharacter) + } else { + b.WriteString(value) //nolint:gosec + } + case error: + errmsg := value.Error() + if f.needsQuoting(errmsg) { + fmt.Fprintf(b, "%s%+v%s", f.QuoteCharacter, errmsg, f.QuoteCharacter) + } else { + b.WriteString(errmsg) //nolint:gosec + } + default: + fmt.Fprint(b, value) + } +} diff --git a/pkg/skywire-utilities/pkg/logging/hooks.go b/pkg/skywire-utilities/pkg/logging/hooks.go new file mode 100644 index 0000000000..8a74fbf2b6 --- /dev/null +++ b/pkg/skywire-utilities/pkg/logging/hooks.go @@ -0,0 +1,45 @@ +// Package logging pkg/logging/hooks.go +package logging + +import ( + "io" + + "github.com/sirupsen/logrus" +) + +// WriteHook is a logrus.Hook that logs to an io.Writer +type WriteHook struct { + w io.Writer + formatter logrus.Formatter +} + +// NewWriteHook returns a new WriteHook +func NewWriteHook(w io.Writer) *WriteHook { + return &WriteHook{ + w: w, + formatter: &TextFormatter{ + DisableColors: true, + FullTimestamp: true, + AlwaysQuoteStrings: true, + QuoteEmptyFields: true, + ForceFormatting: true, + }, + } +} + +// Levels returns Levels accepted by the WriteHook. +// All logrus.Levels are returned. +func (f *WriteHook) Levels() []logrus.Level { + return logrus.AllLevels +} + +// Fire writes a logrus.Entry to the file +func (f *WriteHook) Fire(e *logrus.Entry) error { + b, err := f.formatter.Format(e) + if err != nil { + return err + } + + _, err = f.w.Write(b) + return err +} diff --git a/pkg/skywire-utilities/pkg/logging/logger.go b/pkg/skywire-utilities/pkg/logging/logger.go new file mode 100644 index 0000000000..5f2f3f13a4 --- /dev/null +++ b/pkg/skywire-utilities/pkg/logging/logger.go @@ -0,0 +1,69 @@ +// Package logging pkg/logging/logger.go +package logging + +import ( + "os" + "time" + + "github.com/sirupsen/logrus" +) + +// Logger wraps logrus.FieldLogger +type Logger struct { + logrus.FieldLogger +} + +// Critical adds special critical-level fields for specially highlighted logging, +// since logrus lacks a distinct critical field and does not have configurable log levels +func (logger *Logger) Critical() logrus.FieldLogger { + return logger.WithField(logPriorityKey, logPriorityCritical) +} + +// WithTime overrides time, used by logger. +func (logger *Logger) WithTime(t time.Time) *logrus.Entry { + return logger.WithFields(logrus.Fields{}).WithTime(t) +} + +// MasterLogger wraps logrus.Logger and is able to create new package-aware loggers +type MasterLogger struct { + *logrus.Logger +} + +// NewMasterLogger creates a new package-aware logger with formatting string +func NewMasterLogger() *MasterLogger { + hooks := make(logrus.LevelHooks) + + return &MasterLogger{ + Logger: &logrus.Logger{ + Out: os.Stdout, + Formatter: &TextFormatter{ + FullTimestamp: true, + AlwaysQuoteStrings: true, + QuoteEmptyFields: true, + ForceFormatting: true, + DisableColors: false, + ForceColors: false, + TimestampFormat: "2006-01-02T15:04:05.999999999Z07:00", + }, + Hooks: hooks, + Level: logrus.DebugLevel, + }, + } +} + +// PackageLogger instantiates a package-aware logger +func (logger *MasterLogger) PackageLogger(moduleName string) *Logger { + return &Logger{ + FieldLogger: logger.WithField(logModuleKey, moduleName), + } +} + +// EnableColors enables colored logging +func (logger *MasterLogger) EnableColors() { + logger.Formatter.(*TextFormatter).DisableColors = false +} + +// DisableColors disables colored logging +func (logger *MasterLogger) DisableColors() { + logger.Formatter.(*TextFormatter).DisableColors = true +} diff --git a/pkg/skywire-utilities/pkg/logging/logging.go b/pkg/skywire-utilities/pkg/logging/logging.go new file mode 100644 index 0000000000..cd0d2a352e --- /dev/null +++ b/pkg/skywire-utilities/pkg/logging/logging.go @@ -0,0 +1,85 @@ +/* +Package logging provides application logging utilities +*/ +package logging + +import ( + "errors" + "io" + "strings" + + "github.com/sirupsen/logrus" +) + +var log = NewMasterLogger() + +const ( + // logModuleKey is the key used for the module name data entry + logModuleKey = "_module" + // logPriorityKey is the log entry key for priority log statements + logPriorityKey = "_priority" + // logPriorityCritical is the log entry value for priority log statements + logPriorityCritical = "CRITICAL" +) + +// LevelFromString returns a logrus.Level from a string identifier +func LevelFromString(s string) (logrus.Level, error) { + switch strings.ToLower(s) { + case "debug": + return logrus.DebugLevel, nil + case "info", "notice": + return logrus.InfoLevel, nil + case "warn", "warning": + return logrus.WarnLevel, nil + case "error": + return logrus.ErrorLevel, nil + case "fatal", "critical": + return logrus.FatalLevel, nil + case "panic": + return logrus.PanicLevel, nil + case "trace": + return logrus.TraceLevel, nil + default: + return logrus.DebugLevel, errors.New("could not convert string to log level") + } +} + +// MustGetLogger returns a package-aware logger from the master logger +func MustGetLogger(module string) *Logger { + return log.PackageLogger(module) +} + +// AddHook adds a hook to the global logger +func AddHook(hook logrus.Hook) { + log.AddHook(hook) +} + +// EnableColors enables colored logging +func EnableColors() { + log.EnableColors() +} + +// DisableColors disables colored logging +func DisableColors() { + log.DisableColors() +} + +// SetLevel sets the logger's minimum log level +func SetLevel(level logrus.Level) { + log.SetLevel(level) +} + +// GetLevel returns the logger level +func GetLevel() logrus.Level { + return log.GetLevel() +} + +// SetOutputTo sets the logger's output to an io.Writer +func SetOutputTo(w io.Writer) { + log.Out = w +} + +// Disable disables the logger completely +func Disable() { + log.Out = io.Discard +} diff --git a/pkg/skywire-utilities/pkg/metricsutil/http.go b/pkg/skywire-utilities/pkg/metricsutil/http.go new file mode 100644 index 0000000000..6957dc2909 --- /dev/null +++ b/pkg/skywire-utilities/pkg/metricsutil/http.go @@ -0,0 +1,39 @@ +// Package metricsutil pkg/metricsutil/http.go +package metricsutil + +import ( + "net/http" + + "github.com/VictoriaMetrics/metrics" + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" + "github.com/sirupsen/logrus" +) + +// AddMetricsHandler adds a prometheus-format Handle at '/metrics' to the provided serve mux. +func AddMetricsHandler(mux *chi.Mux) { + mux.HandleFunc("/metrics", func(w http.ResponseWriter, _ *http.Request) { + metrics.WritePrometheus(w, true) + }) +} + +// ServeHTTPMetrics starts serving metrics on a given `addr`. +func ServeHTTPMetrics(log logrus.FieldLogger, addr string) { + if addr == "" { + return + } + + r := chi.NewRouter() + + r.Use(middleware.RequestID) + r.Use(middleware.RealIP) + r.Use(middleware.Logger) + r.Use(middleware.Recoverer) + + AddMetricsHandler(r) + + log.WithField("addr", addr).Info("Serving metrics.") + go func() { + log.Fatal(http.ListenAndServe(addr, r)) //nolint + }() +} diff --git a/pkg/skywire-utilities/pkg/metricsutil/request_duration_middleware.go b/pkg/skywire-utilities/pkg/metricsutil/request_duration_middleware.go new file mode 100644 index 0000000000..050c7878cc --- /dev/null +++ b/pkg/skywire-utilities/pkg/metricsutil/request_duration_middleware.go @@ -0,0 +1,26 @@ +// Package metricsutil pkg/metricsutil/request_duration_middleware.go +package metricsutil + +import ( + "fmt" + "net/http" + "time" + + "github.com/VictoriaMetrics/metrics" +) + +// RequestDurationMiddleware is a request duration tracking middleware. +func RequestDurationMiddleware(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + srw := NewStatusResponseWriter(w) + + reqStart := time.Now() + next.ServeHTTP(srw, r) + reqDuration := time.Since(reqStart) + + hName := fmt.Sprintf("vm_request_duration{code=\"%d\", method=\"%s\"}", srw.StatusCode(), r.Method) + metrics.GetOrCreateHistogram(hName).Update(reqDuration.Seconds()) + } + + return http.HandlerFunc(fn) +} diff --git a/pkg/skywire-utilities/pkg/metricsutil/requests_in_flight_count_middleware.go b/pkg/skywire-utilities/pkg/metricsutil/requests_in_flight_count_middleware.go new file mode 100644 index 0000000000..204edae914 --- /dev/null +++ b/pkg/skywire-utilities/pkg/metricsutil/requests_in_flight_count_middleware.go @@ -0,0 +1,30 @@ +// Package metricsutil pkg/metricsutil/requests_in_flight_count_middleware.go +package metricsutil + +import ( + "net/http" +) + +// RequestsInFlightCountMiddleware is a middleware to track current requests-in-flight count. +type RequestsInFlightCountMiddleware struct { + reqsInFlightGauge *VictoriaMetricsIntGaugeWrapper +} + +// NewRequestsInFlightCountMiddleware constructs `RequestsInFlightCountMiddleware`. +func NewRequestsInFlightCountMiddleware() *RequestsInFlightCountMiddleware { + return &RequestsInFlightCountMiddleware{ + reqsInFlightGauge: NewVictoriaMetricsIntGauge("vm_request_ongoing_count"), + } +} + +// Handle adds to the requests count during request serving. +func (m *RequestsInFlightCountMiddleware) Handle(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + m.reqsInFlightGauge.Inc() + defer m.reqsInFlightGauge.Dec() + + next.ServeHTTP(w, r) + } + + return http.HandlerFunc(fn) +} diff --git a/pkg/skywire-utilities/pkg/metricsutil/status_response_writer.go b/pkg/skywire-utilities/pkg/metricsutil/status_response_writer.go new file mode 100644 index 0000000000..a2fd762ef3 --- /dev/null +++ b/pkg/skywire-utilities/pkg/metricsutil/status_response_writer.go @@ -0,0 +1,37 @@ +// Package metricsutil pkg/metricsutil/status_response_writer.go +package metricsutil + +import ( + "net/http" +) + +// StatusResponseWriter wraps `http.ResponseWriter` but stores status code +// on call to `WriteHeader`. +type StatusResponseWriter struct { + http.ResponseWriter + statusCode int +} + +// NewStatusResponseWriter wraps `http.ResponseWriter` constructing `StatusResponseWriter`. +func NewStatusResponseWriter(w http.ResponseWriter) *StatusResponseWriter { + return &StatusResponseWriter{ + ResponseWriter: w, + } +} + +// WriteHeader implements `http.ResponseWriter` storing the written status code. +func (w *StatusResponseWriter) WriteHeader(statusCode int) { + w.statusCode = statusCode + w.ResponseWriter.WriteHeader(statusCode) +} + +// StatusCode gets status code from the writer. +func (w *StatusResponseWriter) StatusCode() int { + if w.statusCode == 0 { + // this is case when `WriteHeader` wasn't called explicitly, + // so we consider it 200 + return http.StatusOK + } + + return w.statusCode +} diff --git a/pkg/skywire-utilities/pkg/metricsutil/victoria_metrics_int_gauge_wrapper.go b/pkg/skywire-utilities/pkg/metricsutil/victoria_metrics_int_gauge_wrapper.go new file mode 100644 index 0000000000..27f46b0bb9 --- /dev/null +++ b/pkg/skywire-utilities/pkg/metricsutil/victoria_metrics_int_gauge_wrapper.go @@ -0,0 +1,46 @@ +// Package metricsutil pkg/metricsutil/victoria_metrics_int_gauge_wrapper.go +package metricsutil + +import ( + "sync/atomic" + + "github.com/VictoriaMetrics/metrics" +) + +// VictoriaMetricsIntGaugeWrapper wraps Victoria Metrics int gauge encapsulating all the +// needed logic to control the value. +type VictoriaMetricsIntGaugeWrapper struct { + val int64 + gauge *metrics.Gauge +} + +// NewVictoriaMetricsIntGauge constructs new wrapper for Victoria Metric int gauge with +// the name `name`. +func NewVictoriaMetricsIntGauge(name string) *VictoriaMetricsIntGaugeWrapper { + var w VictoriaMetricsIntGaugeWrapper + w.gauge = metrics.GetOrCreateGauge(name, func() float64 { + return float64(w.Val()) + }) + + return &w +} + +// Inc increments gauge value. +func (w *VictoriaMetricsIntGaugeWrapper) Inc() { + atomic.AddInt64(&w.val, 1) +} + +// Dec decrements gauge value. +func (w *VictoriaMetricsIntGaugeWrapper) Dec() { + atomic.AddInt64(&w.val, -1) +} + +// Set sets gauge value. +func (w *VictoriaMetricsIntGaugeWrapper) Set(val int64) { + atomic.StoreInt64(&w.val, val) +} + +// Val gets gauge value. +func (w *VictoriaMetricsIntGaugeWrapper) Val() int64 { + return atomic.LoadInt64(&w.val) +} diff --git a/pkg/skywire-utilities/pkg/metricsutil/victoria_metrics_uint_gauge_wrapper.go b/pkg/skywire-utilities/pkg/metricsutil/victoria_metrics_uint_gauge_wrapper.go new file mode 100644 index 0000000000..42254f1ce9 --- /dev/null +++ b/pkg/skywire-utilities/pkg/metricsutil/victoria_metrics_uint_gauge_wrapper.go @@ -0,0 +1,46 @@ +// Package metricsutil pkg/metricsutil/victoria_metrics_uint_gauge_wrapper.go +package metricsutil + +import ( + "sync/atomic" + + "github.com/VictoriaMetrics/metrics" +) + +// VictoriaMetricsUintGaugeWrapper wraps Victoria Metrics int gauge encapsulating all the +// needed logic to control the value. +type VictoriaMetricsUintGaugeWrapper struct { + val uint64 + gauge *metrics.Gauge +} + +// NewVictoriaMetricsUintGauge constructs new wrapper for Victoria Metric int gauge with +// the name `name`. +func NewVictoriaMetricsUintGauge(name string) *VictoriaMetricsUintGaugeWrapper { + var w VictoriaMetricsUintGaugeWrapper + w.gauge = metrics.GetOrCreateGauge(name, func() float64 { + return float64(w.Val()) + }) + + return &w +} + +// Inc increments gauge value. +func (w *VictoriaMetricsUintGaugeWrapper) Inc() { + atomic.AddUint64(&w.val, 1) +} + +// Dec decrements gauge value. +func (w *VictoriaMetricsUintGaugeWrapper) Dec() { + atomic.AddUint64(&w.val, ^uint64(0)) +} + +// Set sets gauge value. +func (w *VictoriaMetricsUintGaugeWrapper) Set(val uint64) { + atomic.StoreUint64(&w.val, val) +} + +// Val gets gauge value. +func (w *VictoriaMetricsUintGaugeWrapper) Val() uint64 { + return atomic.LoadUint64(&w.val) +} diff --git a/pkg/skywire-utilities/pkg/netutil/copy.go b/pkg/skywire-utilities/pkg/netutil/copy.go new file mode 100644 index 0000000000..0fd7dbb43e --- /dev/null +++ b/pkg/skywire-utilities/pkg/netutil/copy.go @@ -0,0 +1,37 @@ +// Package netutil pkg/netutil/copy.go +package netutil + +import ( + "io" +) + +// CopyReadWriteCloser copies reads and writes between two connections. +// It returns when a connection returns an error. +func CopyReadWriteCloser(conn1, conn2 io.ReadWriteCloser) error { + errCh1 := make(chan error, 1) + go func() { + _, err := io.Copy(conn2, conn1) + errCh1 <- err + close(errCh1) + }() + + errCh2 := make(chan error, 1) + go func() { + _, err := io.Copy(conn1, conn2) + errCh2 <- err + close(errCh2) + }() + + select { + case err := <-errCh1: + _ = conn1.Close() //nolint:errcheck + _ = conn2.Close() //nolint:errcheck + <-errCh2 + return err + case err := <-errCh2: + _ = conn2.Close() //nolint:errcheck + _ = conn1.Close() //nolint:errcheck + <-errCh1 + return err + } +} diff --git a/pkg/skywire-utilities/pkg/netutil/net.go b/pkg/skywire-utilities/pkg/netutil/net.go new file mode 100644 index 0000000000..fe5e56893b --- /dev/null +++ b/pkg/skywire-utilities/pkg/netutil/net.go @@ -0,0 +1,176 @@ +// Package netutil pkg/netutil/net.go +package netutil + +import ( + "fmt" + "io" + "net" + "net/http" +) + +// LocalNetworkInterfaceIPs gets IPs of all local interfaces. +func LocalNetworkInterfaceIPs() ([]net.IP, error) { + ips, _, err := localNetworkInterfaceIPs("") + return ips, err +} + +// NetworkInterfaceIPs gets IPs of network interface with name `name`. +func NetworkInterfaceIPs(name string) ([]net.IP, error) { + _, ifcIPs, err := localNetworkInterfaceIPs(name) + return ifcIPs, err +} + +// localNetworkInterfaceIPs gets IPs of all local interfaces. Separately returns list of IPs +// of interface `ifcName`. +func localNetworkInterfaceIPs(ifcName string) ([]net.IP, []net.IP, error) { + var ifcIPs []net.IP + + ifaces, err := net.Interfaces() + if err != nil { + return nil, nil, fmt.Errorf("error getting network interfaces: %w", err) + } + + var ips []net.IP + for _, iface := range ifaces { + if iface.Flags&net.FlagUp == 0 { + continue // interface down + } + if iface.Flags&net.FlagLoopback != 0 { + continue // loopback interface + } + + addrs, err := iface.Addrs() + if err != nil { + return nil, nil, fmt.Errorf("error getting addresses for interface %s: %w", iface.Name, err) + } + + for _, addr := range addrs { + var ip net.IP + switch v := addr.(type) { + case *net.IPNet: + ip = v.IP + case *net.IPAddr: + ip = v.IP + } + + if ip == nil || ip.IsLoopback() { + continue + } + + ip = ip.To4() + if ip == nil { + continue // not an ipv4 address + } + + ips = append(ips, ip) + + if ifcName != "" && iface.Name == ifcName { + ifcIPs = append(ifcIPs, ip) + } + } + } + + return ips, ifcIPs, nil +} + +// IsPublicIP returns true if the provided IP is public. +// Obtained from: https://stackoverflow.com/questions/41670155/get-public-ip-in-golang +func IsPublicIP(IP net.IP) bool { + if IP.IsLoopback() || IP.IsLinkLocalMulticast() || IP.IsLinkLocalUnicast() { + return false + } + if ip4 := IP.To4(); ip4 != nil { + switch { + case ip4[0] == 10: + return false + case ip4[0] == 172 && ip4[1] >= 16 && ip4[1] <= 31: + return false + case ip4[0] == 192 && ip4[1] == 168: + return false + default: + return true + } + } + return false +} + +// DefaultNetworkInterfaceIPs returns IP addresses for the default network interface +func DefaultNetworkInterfaceIPs() ([]net.IP, error) { + networkIfc, err := DefaultNetworkInterface() + if err != nil { + return nil, fmt.Errorf("failed to get default network interface: %w", err) + } + localIPs, err := NetworkInterfaceIPs(networkIfc) + if err != nil { + return nil, fmt.Errorf("failed to get IPs of %s: %w", networkIfc, err) + } + return localIPs, nil +} + +// HasPublicIP returns true if this machine has at least one +// publically available IP address +func HasPublicIP() (bool, error) { + localIPs, err := LocalNetworkInterfaceIPs() + if err != nil { + return false, err + } + for _, IP := range localIPs { + if IsPublicIP(IP) { + return true, nil + } + } + return false, nil +} + +// ExtractPort returns port of the given UDP or TCP address +func ExtractPort(addr net.Addr) (uint16, error) { + switch address := addr.(type) { + case *net.TCPAddr: + return uint16(address.Port), nil //nolint + case *net.UDPAddr: + return uint16(address.Port), nil //nolint + default: + return 0, fmt.Errorf("extract port: invalid address: %s", addr.String()) + } +} + +// LocalAddresses returns a list of all local addresses +func LocalAddresses() ([]string, error) { + result := make([]string, 0) + + addresses, err := net.InterfaceAddrs() + if err != nil { + return nil, err + } + + for _, addr := range addresses { + switch v := addr.(type) { + case *net.IPNet: + if v.IP.IsGlobalUnicast() || v.IP.IsLoopback() { + result = append(result, v.IP.String()) + } + case *net.IPAddr: + if v.IP.IsGlobalUnicast() || v.IP.IsLoopback() { + result = append(result, v.IP.String()) + } + } + } + + return result, nil +} + +// LocalProtocol check a condition to use dmsghttp or direct url +func LocalProtocol() bool { + resp, err := http.Get("https://ipinfo.io/country") + if err != nil { + return false + } + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return false + } + if string(respBody)[:2] == "CN" { + return true + } + return false +} diff --git a/pkg/skywire-utilities/pkg/netutil/net_darwin.go b/pkg/skywire-utilities/pkg/netutil/net_darwin.go new file mode 100644 index 0000000000..aaf2bd554f --- /dev/null +++ b/pkg/skywire-utilities/pkg/netutil/net_darwin.go @@ -0,0 +1,28 @@ +//go:build darwin +// +build darwin + +// Package netutil pkg/netutil/net_darwin.go +package netutil + +import ( + "bytes" + "fmt" + "os/exec" +) + +const ( + defaultNetworkInterfaceCMD = "route -n get default | awk 'FNR == 5 {print $2}'" +) + +// DefaultNetworkInterface fetches default network interface name. +func DefaultNetworkInterface() (string, error) { + outputBytes, err := exec.Command("sh", "-c", defaultNetworkInterfaceCMD).Output() + if err != nil { + return "", fmt.Errorf("error running command %s: %w", defaultNetworkInterfaceCMD, err) + } + + // just in case + outputBytes = bytes.TrimRight(outputBytes, "\n") + + return string(outputBytes), nil +} diff --git a/pkg/skywire-utilities/pkg/netutil/net_linux.go b/pkg/skywire-utilities/pkg/netutil/net_linux.go new file mode 100644 index 0000000000..a913df824a --- /dev/null +++ b/pkg/skywire-utilities/pkg/netutil/net_linux.go @@ -0,0 +1,28 @@ +//go:build linux +// +build linux + +// Package netutil pkg/netutil/net_linux.go +package netutil + +import ( + "bytes" + "fmt" + "os/exec" +) + +const ( + defaultNetworkInterfaceCMD = "ip r | awk '$1 == \"default\" {print $5}'" +) + +// DefaultNetworkInterface fetches default network interface name. +func DefaultNetworkInterface() (string, error) { + outputBytes, err := exec.Command("sh", "-c", defaultNetworkInterfaceCMD).Output() + if err != nil { + return "", fmt.Errorf("error running command %s: %w", defaultNetworkInterfaceCMD, err) + } + + // just in case + outputBytes = bytes.TrimRight(outputBytes, "\n") + + return string(outputBytes), nil +} diff --git a/pkg/skywire-utilities/pkg/netutil/net_test.go b/pkg/skywire-utilities/pkg/netutil/net_test.go new file mode 100644 index 0000000000..0874ac0992 --- /dev/null +++ b/pkg/skywire-utilities/pkg/netutil/net_test.go @@ -0,0 +1,18 @@ +// Package netutil pkg/netutil/net_test.go +package netutil_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/skycoin/skywire-utilities/pkg/netutil" +) + +func TestDefaultNetworkInterfaceIPs(t *testing.T) { + req := require.New(t) + + ifaceIPs, err := netutil.DefaultNetworkInterfaceIPs() + req.NoError(err) + t.Logf("interface IP: %v", ifaceIPs) +} diff --git a/pkg/skywire-utilities/pkg/netutil/net_windows.go b/pkg/skywire-utilities/pkg/netutil/net_windows.go new file mode 100644 index 0000000000..0217fbb78f --- /dev/null +++ b/pkg/skywire-utilities/pkg/netutil/net_windows.go @@ -0,0 +1,58 @@ +//go:build windows +// +build windows + +// Package netutil pkg/netutil/net_windows.go +package netutil + +import ( + "errors" + "fmt" + "net" + "os/exec" + "regexp" + "strings" +) + +const ( + defaultNetworkInterfaceCMD = `netsh int ip show config | findstr /r "IP Address.*([0-9]{1,3}\.|){4}"` +) + +// DefaultNetworkInterface fetches default network interface name. +func DefaultNetworkInterface() (string, error) { + cmd := exec.Command("powershell", defaultNetworkInterfaceCMD) + output, err := cmd.Output() + if err != nil { + return "", err + } + // parse output + splitLines := strings.Split(string(output), "\n") + var ips []string + + if len(splitLines) > 0 { + re := regexp.MustCompile(`\s+`) + for _, line := range splitLines { + ipAddr := re.Split(strings.TrimSpace(line), -1) + + if len(ipAddr) > 2 { + ip := net.ParseIP(ipAddr[2]) + if ip != nil && !ip.IsLoopback() { + ips = append(ips, ipAddr[2]) + } + } + } + } + + if len(ips) == 0 { + return "", errors.New("no active ip found") + } + + // get default network interface based on its ip + findInterfaceCmd := fmt.Sprintf("Get-NetIpAddress -IPAddress '%s' | %%{$_.InterfaceAlias}", ips[0]) + cmd = exec.Command("powershell", findInterfaceCmd) // nolint:gosec + output, err = cmd.Output() + if err != nil { + return "", fmt.Errorf("unable to get default interface: %v", err) + } + + return strings.TrimSpace(string(output)), nil +} diff --git a/pkg/skywire-utilities/pkg/netutil/porter.go b/pkg/skywire-utilities/pkg/netutil/porter.go new file mode 100644 index 0000000000..031d7853a5 --- /dev/null +++ b/pkg/skywire-utilities/pkg/netutil/porter.go @@ -0,0 +1,200 @@ +// Package netutil pkg/netutil/porter.go +package netutil + +import ( + "context" + "io" + "sync" + + "github.com/sirupsen/logrus" +) + +const ( + // PorterMinEphemeral is the default minimum ephemeral port. + PorterMinEphemeral = uint16(49152) +) + +// PorterValue associates a port value alongside it's children. +type PorterValue struct { + Value interface{} + Children map[uint16]interface{} +} + +// Porter reserves ports. +type Porter struct { + sync.RWMutex + eph uint16 // current ephemeral value + minEph uint16 // minimal ephemeral port value + ports map[uint16]PorterValue +} + +// NewPorter creates a new Porter with a given minimum ephemeral port value. +func NewPorter(minEph uint16) *Porter { + ports := make(map[uint16]PorterValue) + ports[0] = PorterValue{} // port 0 is invalid + + return &Porter{ + eph: minEph, + minEph: minEph, + ports: ports, + } +} + +// Reserve a given port. +// It returns a boolean informing whether the port is reserved, and a function to clear the reservation. +func (p *Porter) Reserve(port uint16, v interface{}) (bool, func()) { + p.Lock() + defer p.Unlock() + + if _, ok := p.ports[port]; ok { + return false, nil + } + p.ports[port] = PorterValue{ + Value: v, + } + return true, p.makePortFreer(port) +} + +// ReserveChild reserves a child. +func (p *Porter) ReserveChild(port, subPort uint16, v interface{}) (bool, func()) { + p.Lock() + defer p.Unlock() + + pv, ok := p.ports[port] + if !ok { + return false, nil + } + if pv.Children == nil { + pv.Children = make(map[uint16]interface{}, 1) + } else if _, ok := pv.Children[subPort]; ok { + return false, nil + } + + pv.Children[subPort] = v + p.ports[port] = pv + return true, p.makeChildFreer(port, subPort) +} + +// ReserveEphemeral reserves a new ephemeral port. +// It returns the reserved ephemeral port, a function to clear the reservation and an error (if any). +func (p *Porter) ReserveEphemeral(ctx context.Context, v interface{}) (uint16, func(), error) { + p.Lock() + defer p.Unlock() + + for { + p.eph++ + if p.eph < p.minEph { + p.eph = p.minEph + } + if _, ok := p.ports[p.eph]; ok { + select { + case <-ctx.Done(): + return 0, nil, ctx.Err() + default: + continue + } + } + p.ports[p.eph] = PorterValue{Value: v} + return p.eph, p.makePortFreer(p.eph), nil + } +} + +// PortValue returns the value stored under a given port. +func (p *Porter) PortValue(port uint16) (interface{}, bool) { + p.RLock() + defer p.RUnlock() + + v, ok := p.ports[port] + return v.Value, ok +} + +// RangePortValues ranges all ports that are currently reserved. +func (p *Porter) RangePortValues(fn func(port uint16, v interface{}) (next bool)) { + p.RLock() + defer p.RUnlock() + + for port, v := range p.ports { + if next := fn(port, v.Value); !next { + return + } + } +} + +// RangePortValuesAndChildren ranges port values and it's contained children. +func (p *Porter) RangePortValuesAndChildren(fn func(port uint16, v PorterValue) (next bool)) { + p.RLock() + defer p.RUnlock() + + for port, v := range p.ports { + if next := fn(port, v); !next { + return + } + } +} + +// This returns a function that frees a given port (if there are no children). +// It is ensured that the function's action is only performed once. +func (p *Porter) makePortFreer(port uint16) func() { + once := new(sync.Once) + + action := func() { + p.Lock() + defer p.Unlock() + + // If port still has children, only clear the port value. + if v, ok := p.ports[port]; ok && len(v.Children) > 0 { + v.Value = nil + p.ports[port] = v + return + } + + delete(p.ports, port) + } + + return func() { once.Do(action) } +} + +func (p *Porter) makeChildFreer(port, subPort uint16) func() { + once := new(sync.Once) + + action := func() { + p.Lock() + defer p.Unlock() + + if v, ok := p.ports[port]; ok && v.Children != nil { + delete(v.Children, subPort) + + // Also delete the ensure port entry if port value is nil and there is no more children. + if v.Value == nil && len(v.Children) == 0 { + delete(p.ports, port) + } + } + } + + return func() { once.Do(action) } +} + +// CloseAll closes all contained variables that implement io.Closer +func (p *Porter) CloseAll(log logrus.FieldLogger) { + if log == nil { + log = logrus.New() + } + + wg := new(sync.WaitGroup) + p.Lock() + for _, v := range p.ports { + if c, ok := v.Value.(io.Closer); ok { + + wg.Add(1) + go func(c io.Closer) { + if err := c.Close(); err != nil { + log.WithError(err). + Debug("On (*netutil.Porter).CloseAll(), closing contained value resulted in error.") + } + wg.Done() + }(c) + } + } + p.Unlock() + wg.Wait() +} diff --git a/pkg/skywire-utilities/pkg/netutil/retrier.go b/pkg/skywire-utilities/pkg/netutil/retrier.go new file mode 100644 index 0000000000..a3648fac8f --- /dev/null +++ b/pkg/skywire-utilities/pkg/netutil/retrier.go @@ -0,0 +1,101 @@ +// Package netutil pkg/netutil/retrier.go +package netutil + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/sirupsen/logrus" +) + +// Package errors +var ( + ErrMaximumRetriesReached = errors.New("maximum retries attempted without success") +) + +// Default values for retrier. +const ( + DefaultInitBackoff = time.Second + DefaultMaxBackoff = time.Second * 20 + DefaultTries = int64(0) + DefaultFactor = float64(1.3) +) + +// RetryFunc is a function used as argument of (*Retrier).Do(), which will retry on error unless it is whitelisted +type RetryFunc func() error + +// Retrier holds a configuration for how retries should be performed +type Retrier struct { + initBO time.Duration // initial backoff duration + maxBO time.Duration // maximum backoff duration + tries int64 // number of times the given function is to be retried until success, if 0 it will be retried forever until success + factor float64 // multiplier for the backoff duration that is applied on every retry + errWl map[error]struct{} // list of errors which will always trigger retirer to return + log logrus.FieldLogger +} + +// NewRetrier returns a retrier that is ready to call Do() method +func NewRetrier(log logrus.FieldLogger, initBO, maxBO time.Duration, tries int64, factor float64) *Retrier { + if log != nil { + log = log.WithField("func", "retrier") + } + return &Retrier{ + initBO: initBO, + maxBO: maxBO, + tries: tries, + factor: factor, + errWl: make(map[error]struct{}), + log: log, + } +} + +// NewDefaultRetrier creates a retrier with default values. +func NewDefaultRetrier(log logrus.FieldLogger) *Retrier { + return NewRetrier(log, DefaultInitBackoff, DefaultMaxBackoff, DefaultTries, DefaultFactor) +} + +// WithErrWhitelist sets a list of errors into the retrier, if the RetryFunc provided to Do() fails with one of them it will return inmediatelly with such error. Calling +// this function is not thread-safe, and is advised to only use it when initializing the Retrier +func (r *Retrier) WithErrWhitelist(errors ...error) *Retrier { + for _, err := range errors { + r.errWl[err] = struct{}{} + } + return r +} + +// Do takes a RetryFunc and attempts to execute it. +// If it fails with an error it will be retried a maximum of given times with an initBO +// until it returns nil or an error that is whitelisted +func (r *Retrier) Do(ctx context.Context, f RetryFunc) error { + bo := r.initBO + + t := time.NewTimer(bo) + defer t.Stop() + + for i := int64(0); r.tries == 0 || i < r.tries; i++ { + if err := f(); err != nil { + if _, ok := r.errWl[err]; ok { + return err + } + if newBO := time.Duration(float64(bo) * r.factor); r.maxBO == 0 || newBO <= r.maxBO { + bo = newBO + } + select { + case <-t.C: + if r.log != nil { + r.log.WithError(err).WithField("current_backoff", bo).Warn("Retrying...") + } else { + fmt.Printf("func = retrier, current_backoff = %v Retrying...\n", bo) + } + t.Reset(bo) + continue + case <-ctx.Done(): + return ctx.Err() + } + } + return nil + } + return ErrMaximumRetriesReached +} diff --git a/pkg/skywire-utilities/pkg/netutil/retrier_test.go b/pkg/skywire-utilities/pkg/netutil/retrier_test.go new file mode 100644 index 0000000000..14f830c972 --- /dev/null +++ b/pkg/skywire-utilities/pkg/netutil/retrier_test.go @@ -0,0 +1,64 @@ +// Package netutil pkg/netutil/retrier_test.go +package netutil + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" +) + +func TestRetrier_Do(t *testing.T) { + r := NewRetrier(logrus.New(), time.Millisecond*100, 0, 3, 2) + c := 0 + threshold := 2 + f := func() error { + c++ + if c >= threshold { + return nil + } + + return errors.New("foo") + } + + t.Run("should retry", func(t *testing.T) { + c = 0 + + err := r.Do(context.TODO(), f) + require.NoError(t, err) + }) + + t.Run("if retry reaches max number of times should error", func(t *testing.T) { + c = 0 + threshold = 4 + defer func() { + threshold = 2 + }() + + err := r.Do(context.TODO(), f) + require.Error(t, err) + }) + + t.Run("should return whitelisted errors if any instead of retry", func(t *testing.T) { + bar := errors.New("bar") + wR := NewRetrier(logrus.New(), 50*time.Millisecond, 0, 1, 2).WithErrWhitelist(bar) + barF := func() error { + return bar + } + + err := wR.Do(context.TODO(), barF) + require.EqualError(t, err, bar.Error()) + }) + + t.Run("if times is 0, should retry until success", func(t *testing.T) { + c = 0 + loopR := NewRetrier(logrus.New(), 50*time.Millisecond, 0, 0, 1) + err := loopR.Do(context.TODO(), f) + require.NoError(t, err) + + require.Equal(t, threshold, c) + }) +} diff --git a/pkg/skywire-utilities/pkg/networkmonitor/networkmonitor.go b/pkg/skywire-utilities/pkg/networkmonitor/networkmonitor.go new file mode 100644 index 0000000000..fc246166fa --- /dev/null +++ b/pkg/skywire-utilities/pkg/networkmonitor/networkmonitor.go @@ -0,0 +1,21 @@ +// Package networkmonitor pkg/networkmonitor/networkmonitor.go +package networkmonitor + +// WhitelistPKs store whitelisted keys of network monitor +type WhitelistPKs map[string]struct{} + +// GetWhitelistPKs returns the stuct WhitelistPKs +func GetWhitelistPKs() WhitelistPKs { + return make(WhitelistPKs) +} + +// Set sets the whitelist with the given pk in the struct +func (wl WhitelistPKs) Set(nmPkString string) { + wl[nmPkString] = struct{}{} +} + +// Get gets the pk from the whitelist +func (wl WhitelistPKs) Get(nmPkString string) bool { + _, ok := wl[nmPkString] + return ok +} diff --git a/pkg/skywire-utilities/pkg/skyenv/values.go b/pkg/skywire-utilities/pkg/skyenv/values.go new file mode 100644 index 0000000000..e46632cee1 --- /dev/null +++ b/pkg/skywire-utilities/pkg/skyenv/values.go @@ -0,0 +1,51 @@ +// Package skyenv pkg/skyenv/values.go +package skyenv + +// Constants for new default services. +const ( + ServiceConfAddr = "http://conf.skywire.skycoin.com" + TpDiscAddr = "http://tpd.skywire.skycoin.com" + DmsgDiscAddr = "http://dmsgd.skywire.skycoin.com" + ServiceDiscAddr = "http://sd.skycoin.com" + RouteFinderAddr = "http://rf.skywire.skycoin.com" + UptimeTrackerAddr = "http://ut.skywire.skycoin.com" + AddressResolverAddr = "http://ar.skywire.skycoin.com" + RouteSetupPKs = "0324579f003e6b4048bae2def4365e634d8e0e3054a20fc7af49daf2a179658557,024fbd3997d4260f731b01abcfce60b8967a6d4c6a11d1008812810ea1437ce438,03b87c282f6e9f70d97aeea90b07cf09864a235ef718725632d067873431dd1015" + TPSetupPKs = "03530b786c670fc7f5ab9021478c7ec9cd06a03f3ea1416c50c4a8889ef5bba80e,03271c0de223b80400d9bd4b7722b536a245eb6c9c3176781ee41e7bac8f9bad21,03a792e6d960c88c6fb2184ee4f16714c58b55f0746840617a19f7dd6e021699d9,0313efedc579f57f05d4f5bc3fbf0261f31e51cdcfde7e568169acf92c78868926,025c7bbf23e3441a36d7e8a1e9d717921e2a49a2ce035680fec4808a048d244c8a,030eb6967f6e23e81db0d214f925fc5ce3371e1b059fb8379ae3eb1edfc95e0b46,02e582c0a5e5563aad47f561b272e4c3a9f7ac716258b58e58eb50afd83c286a7f,02ddc6c749d6ed067bb68df19c9bcb1a58b7587464043b1707398ffa26a9746b26,03aa0b1c4e23616872058c11c6efba777c130a85eaf909945d697399a1eb08426d,03adb2c924987d8deef04d02bd95236c5ae172fe5dfe7273e0461d96bf4bc220be" + NetworkMonitorPKs = "0380ea88f0ad0aa4d93c330ba5f97aabca1d892190b94db69eee140b549d2817dd,0283bddb4357e2c4de0d470032cd809966aec65ce57e1188143ab32c7b589b38b6,02f4e33b75307267229b0c3d679d08dd23374333f558288cfcb114311a52199358,02090f03cb26c71779b8327067e2e37314d2db3e31dfe4f8f3cdd8e088a98eb7ec,03ff8dc39ed8d84be17a15b6a243edbcef1a5fd425209243fd7a9a28f0d23ddbea,02b9aa8276907db6f6ea8626d5d26aa6e119dd89d88bb222ce868376c5367d7b4c" + SurveyWhitelistPKs = "0327e2cf1d2e516ecbfdbd616a87489cc92a73af97335d5c8c29eafb5d8882264a,03abbb3eff140cf3dce468b3fa5a28c80fa02c6703d7b952be6faaf2050990ebf4,02b5ee5333aa6b7f5fc623b7d5f35f505cb7f974e98a70751cf41962f84c8c4637,03714c8bdaee0fb48f47babbc47c33e1880752b6620317c9d56b30f3b0ff58a9c3,020d35bbaf0a5abc8ec0ba33cde219fde734c63e7202098e1f9a6cf9daaeee55a9,027f7dec979482f418f01dfabddbd750ad036c579a16422125dd9a313eaa59c8e1,031d4cf1b7ab4c789b56c769f2888e4a61c778dfa5fe7e5cd0217fc41660b2eb65" + RewardSystemPKs = "036a70e6956061778e1883e928c1236189db14dfd446df23d83e45c321b330c91f" +) + +// Constants for testing deployment. +const ( + TestServiceConfAddr = "http://conf.skywire.dev" + TestTpDiscAddr = "http://tpd.skywire.dev" + TestDmsgDiscAddr = "http://dmsgd.skywire.dev" + TestServiceDiscAddr = "http://sd.skywire.dev" + TestRouteFinderAddr = "http://rf.skywire.dev" + TestUptimeTrackerAddr = "http://ut.skywire.dev" + TestAddressResolverAddr = "http://ar.skywire.dev" + TestRouteSetupPKs = "0324579f003e6b4048bae2def4365e634d8e0e3054a20fc7af49daf2a179658557,024fbd3997d4260f731b01abcfce60b8967a6d4c6a11d1008812810ea1437ce438,03b87c282f6e9f70d97aeea90b07cf09864a235ef718725632d067873431dd1015" + TestTPSetupPKs = "03530b786c670fc7f5ab9021478c7ec9cd06a03f3ea1416c50c4a8889ef5bba80e,03271c0de223b80400d9bd4b7722b536a245eb6c9c3176781ee41e7bac8f9bad21,03a792e6d960c88c6fb2184ee4f16714c58b55f0746840617a19f7dd6e021699d9,0313efedc579f57f05d4f5bc3fbf0261f31e51cdcfde7e568169acf92c78868926,025c7bbf23e3441a36d7e8a1e9d717921e2a49a2ce035680fec4808a048d244c8a,030eb6967f6e23e81db0d214f925fc5ce3371e1b059fb8379ae3eb1edfc95e0b46,02e582c0a5e5563aad47f561b272e4c3a9f7ac716258b58e58eb50afd83c286a7f,02ddc6c749d6ed067bb68df19c9bcb1a58b7587464043b1707398ffa26a9746b26,03aa0b1c4e23616872058c11c6efba777c130a85eaf909945d697399a1eb08426d,03adb2c924987d8deef04d02bd95236c5ae172fe5dfe7273e0461d96bf4bc220be" + TestNetworkMonitorPKs = "0380ea88f0ad0aa4d93c330ba5f97aabca1d892190b94db69eee140b549d2817dd,0283bddb4357e2c4de0d470032cd809966aec65ce57e1188143ab32c7b589b38b6,02f4e33b75307267229b0c3d679d08dd23374333f558288cfcb114311a52199358,02090f03cb26c71779b8327067e2e37314d2db3e31dfe4f8f3cdd8e088a98eb7ec,03ff8dc39ed8d84be17a15b6a243edbcef1a5fd425209243fd7a9a28f0d23ddbea,02b9aa8276907db6f6ea8626d5d26aa6e119dd89d88bb222ce868376c5367d7b4c" + TestSurveyWhitelistPKs = "0327e2cf1d2e516ecbfdbd616a87489cc92a73af97335d5c8c29eafb5d8882264a,03abbb3eff140cf3dce468b3fa5a28c80fa02c6703d7b952be6faaf2050990ebf4,02b5ee5333aa6b7f5fc623b7d5f35f505cb7f974e98a70751cf41962f84c8c4637,03714c8bdaee0fb48f47babbc47c33e1880752b6620317c9d56b30f3b0ff58a9c3,020d35bbaf0a5abc8ec0ba33cde219fde734c63e7202098e1f9a6cf9daaeee55a9,027f7dec979482f418f01dfabddbd750ad036c579a16422125dd9a313eaa59c8e1,031d4cf1b7ab4c789b56c769f2888e4a61c778dfa5fe7e5cd0217fc41660b2eb65" + TestRewardSystemPKs = "036a70e6956061778e1883e928c1236189db14dfd446df23d83e45c321b330c91f" +) + +// GetStunServers gives back default Stun Servers +func GetStunServers() []string { + return []string{ + "139.162.30.112:3478", + "192.53.118.31:3478", + "192.53.118.61:3478", + "170.187.228.44:3478", + "170.187.228.178:3478", + "139.162.30.129:3478", + "192.53.118.134:3478", + "192.53.118.209:3478", + } +} + +// DNSServer is value for DNS Server Address +const DNSServer = "1.1.1.1" diff --git a/pkg/skywire-utilities/pkg/storeconfig/storeconfig.go b/pkg/skywire-utilities/pkg/storeconfig/storeconfig.go new file mode 100644 index 0000000000..8d178191c6 --- /dev/null +++ b/pkg/skywire-utilities/pkg/storeconfig/storeconfig.go @@ -0,0 +1,41 @@ +// Package storeconfig pkg/storeconfig/storeconfig.go +package storeconfig + +import ( + "os" +) + +// Type is a config type. +type Type int + +// Type may be either in-memory or Redis. +const ( + Memory Type = iota + Redis +) + +// Config defines a store configuration. +type Config struct { + Type Type + URL string `json:"url"` + Password string `json:"password"` + PoolSize int `json:"pool_size"` +} + +const redisPasswordEnvName = "REDIS_PASSWORD" + +const ( + pgUser = "PG_USER" + pgPassword = "PG_PASSWORD" + pgDatabase = "PG_DATABASE" +) + +// RedisPassword returns Redis password which is read from an environment variable. +func RedisPassword() string { + return os.Getenv(redisPasswordEnvName) +} + +// PostgresCredential return prostgres credential needed on services +func PostgresCredential() (string, string, string) { + return os.Getenv(pgUser), os.Getenv(pgPassword), os.Getenv(pgDatabase) +} diff --git a/pkg/skywire-utilities/pkg/tcpproxy/http.go b/pkg/skywire-utilities/pkg/tcpproxy/http.go new file mode 100644 index 0000000000..51ce30c923 --- /dev/null +++ b/pkg/skywire-utilities/pkg/tcpproxy/http.go @@ -0,0 +1,24 @@ +// Package tcpproxy pkg/tcpproxy/tcpproxy.go +package tcpproxy + +import ( + "net" + "net/http" + + proxyproto "github.com/pires/go-proxyproto" +) + +// ListenAndServe starts http server with tcp proxy support +func ListenAndServe(addr string, handler http.Handler) error { + srv := &http.Server{Addr: addr, Handler: handler} //nolint + if addr == "" { + addr = ":http" + } + ln, err := net.Listen("tcp", addr) + if err != nil { + return err + } + proxyListener := &proxyproto.Listener{Listener: ln} + defer proxyListener.Close() // nolint:errcheck + return srv.Serve(proxyListener) +} diff --git a/pkg/visor/api.go b/pkg/visor/api.go index 8e63206da8..2004b21a7f 100644 --- a/pkg/visor/api.go +++ b/pkg/visor/api.go @@ -1393,19 +1393,19 @@ func (v *Visor) TestVisor(conf PingConfig) ([]TestResult, error) { result = append(result, TestResult{PK: conf.PK.String(), Max: fmt.Sprint(0), Min: fmt.Sprint(0), Mean: fmt.Sprint(0), Status: "Failed"}) continue } - var max, min, mean, sumLatency time.Duration - min = time.Duration(10000000000) + var maxx, minn, mean, sumLatency time.Duration + minn = time.Duration(10000000000) for _, latency := range latencies { - if latency > max { - max = latency + if latency > maxx { + maxx = latency } - if latency < min { - min = latency + if latency < minn { + minn = latency } sumLatency += latency } mean = sumLatency / time.Duration(len(latencies)) - result = append(result, TestResult{PK: conf.PK.String(), Max: fmt.Sprint(max), Min: fmt.Sprint(min), Mean: fmt.Sprint(mean), Status: "Success"}) + result = append(result, TestResult{PK: conf.PK.String(), Max: fmt.Sprint(maxx), Min: fmt.Sprint(minn), Mean: fmt.Sprint(mean), Status: "Success"}) v.StopPing(conf.PK) //nolint } return result, nil diff --git a/pkg/visor/logstore/logstore.go b/pkg/visor/logstore/logstore.go index 87fa134a59..2b3e408aa6 100644 --- a/pkg/visor/logstore/logstore.go +++ b/pkg/visor/logstore/logstore.go @@ -24,10 +24,10 @@ type Store interface { // overwriting the oldest entry when over the capacity // returned hook should be registered in logrus master logger to // store log entries -func MakeStore(max int) (Store, logrus.Hook) { - entries := make([]string, max) +func MakeStore(maxx int) (Store, logrus.Hook) { + entries := make([]string, maxx) formatter := &logrus.JSONFormatter{} - store := &store{cap: int64(max), entries: entries, formatter: formatter} + store := &store{cap: int64(maxx), entries: entries, formatter: formatter} return store, store } diff --git a/pkg/visor/visorconfig/v1.go b/pkg/visor/visorconfig/v1.go index 17ee7c2a52..93f0701b3f 100644 --- a/pkg/visor/visorconfig/v1.go +++ b/pkg/visor/visorconfig/v1.go @@ -288,9 +288,9 @@ func (v1 *V1) AddAppConfig(launch *launcher.AppLauncher, appName, binaryName str } var randomNumber int for { - min := 10 - max := 99 - randomNumber = rand.Intn(max-min+1) + min //nolint: gosec + minn := 10 + maxx := 99 + randomNumber = rand.Intn(maxx-minn+1) + minn //nolint: gosec if _, ok := busyPorts[routing.Port(randomNumber)]; !ok { //nolint: gosec break } diff --git a/pkg/visor/visorconfig/values.go b/pkg/visor/visorconfig/values.go index ec467e0738..f583bb4336 100644 --- a/pkg/visor/visorconfig/values.go +++ b/pkg/visor/visorconfig/values.go @@ -36,92 +36,175 @@ var ( // Dmsg port constants. // TODO(evanlinjin): Define these properly. These are currently random. - DmsgCtrlPort = skyenv.DmsgCtrlPort // DmsgCtrlPort Listening port for dmsgctrl protocol (similar to TCP Echo Protocol). - DmsgSetupPort = skyenv.DmsgSetupPort // DmsgSetupPort Listening port of a setup node. - DmsgHypervisorPort = skyenv.DmsgHypervisorPort // DmsgHypervisorPort Listening port of a hypervisor for incoming RPC visor connections over dmsg. - DmsgTransportSetupPort = skyenv.DmsgTransportSetupPort // DmsgTransportSetupPort Listening port for transport setup RPC over dmsg. - DmsgHTTPPort = dmsg.DefaultDmsgHTTPPort // DmsgHTTPPort Listening port for dmsghttp logserver. - DmsgAwaitSetupPort = skyenv.DmsgAwaitSetupPort // DmsgAwaitSetupPort Listening port of a visor for setup operations. + // DmsgCtrlPort Listening port for dmsgctrl protocol (similar to TCP Echo Protocol). + DmsgCtrlPort = skyenv.DmsgCtrlPort + + // DmsgSetupPort Listening port of a setup node. + DmsgSetupPort = skyenv.DmsgSetupPort + + // DmsgHypervisorPort Listening port of a hypervisor for incoming RPC visor connections over dmsg. + DmsgHypervisorPort = skyenv.DmsgHypervisorPort + + // DmsgTransportSetupPort Listening port for transport setup RPC over dmsg. + DmsgTransportSetupPort = skyenv.DmsgTransportSetupPort + + // DmsgHTTPPort Listening port for dmsghttp logserver. + DmsgHTTPPort = dmsg.DefaultDmsgHTTPPort + + // DmsgAwaitSetupPort Listening port of a visor for setup operations. + DmsgAwaitSetupPort = skyenv.DmsgAwaitSetupPort // Transport port constants. - TransportPort = skyenv.TransportPort // TransportPort Listening port of a visor for incoming transports. - PublicAutoconnect = skyenv.PublicAutoconnect // PublicAutoconnect ... + // TransportPort Listening port of a visor for incoming transports. + TransportPort = skyenv.TransportPort + + // PublicAutoconnect ... + PublicAutoconnect = skyenv.PublicAutoconnect // Dmsgpty constants. - DmsgPtyPort = skyenv.DmsgPtyPort // DmsgPtyPort ... - DmsgPtyCLINet = skyenv.DmsgPtyCLINet // DmsgPtyCLINet ... + // DmsgPtyPort ... + DmsgPtyPort = skyenv.DmsgPtyPort + + // DmsgPtyCLINet ... + DmsgPtyCLINet = skyenv.DmsgPtyCLINet // Skywire-TCP constants. - STCPAddr = skyenv.STCPAddr // STCPAddr ... + // STCPAddr ... + STCPAddr = skyenv.STCPAddr // Default skywire app constants. - SkychatName = skyenv.SkychatName // SkychatName ... - SkychatPort = skyenv.SkychatPort // SkychatPort ... - SkychatAddr = skyenv.SkychatAddr // SkychatAddr ... + // SkychatName ... + SkychatName = skyenv.SkychatName + + // SkychatPort ... + SkychatPort = skyenv.SkychatPort + + // SkychatAddr ... + SkychatAddr = skyenv.SkychatAddr - PingTestName = skyenv.PingTestName // PingTestName ... - PingTestPort = skyenv.PingTestPort // PingTestPort ... + // PingTestName ... + PingTestName = skyenv.PingTestName - SkysocksName = skyenv.SkysocksName // SkysocksName ... - SkysocksPort = skyenv.SkysocksPort // SkysocksPort ... + // PingTestPort ... + PingTestPort = skyenv.PingTestPort - SkysocksClientName = skyenv.SkysocksClientName // SkysocksClientName ... - SkysocksClientPort = skyenv.SkysocksClientPort // SkysocksClientPort ... - SkysocksClientAddr = skyenv.SkysocksClientAddr // SkysocksClientAddr ... + // SkysocksName ... + SkysocksName = skyenv.SkysocksName - VPNServerName = skyenv.VPNServerName // VPNServerName ... - VPNServerPort = skyenv.VPNServerPort // VPNServerPort ... + // SkysocksPort ... + SkysocksPort = skyenv.SkysocksPort - VPNClientName = skyenv.VPNClientName // VPNClientName ... + // SkysocksClientName ... + SkysocksClientName = skyenv.SkysocksClientName + + // SkysocksClientPort ... + SkysocksClientPort = skyenv.SkysocksClientPort + + // SkysocksClientAddr ... + SkysocksClientAddr = skyenv.SkysocksClientAddr + + // VPNServerName ... + VPNServerName = skyenv.VPNServerName + + // VPNServerPort ... + VPNServerPort = skyenv.VPNServerPort + + // VPNClientName ... + VPNClientName = skyenv.VPNClientName // TODO(darkrengarius): this one's not needed for the app to run but lack of it causes errors - VPNClientPort = skyenv.VPNClientPort // VPNClientPort ... + // VPNClientPort ... + VPNClientPort = skyenv.VPNClientPort + + // ExampleServerName ... + ExampleServerName = skyenv.ExampleServerName + + // ExampleServerPort ... + ExampleServerPort = skyenv.ExampleServerPort + + // ExampleClientName ... + ExampleClientName = skyenv.ExampleClientName - ExampleServerName = skyenv.ExampleServerName // ExampleServerName ... - ExampleServerPort = skyenv.ExampleServerPort // ExampleServerPort ... - ExampleClientName = skyenv.ExampleClientName // ExampleClientName ... - ExampleClientPort = skyenv.ExampleClientPort // ExampleClientPort ... - SkyForwardingServerName = skyenv.SkyForwardingServerName // SkyForwardingServerName ... - SkyForwardingServerPort = skyenv.SkyForwardingServerPort // SkyForwardingServerPort ... - SkyPingName = skyenv.SkyPingName // SkyPingName ... - SkyPingPort = skyenv.SkyPingPort // SkyPingPort ... + // ExampleClientPort ... + ExampleClientPort = skyenv.ExampleClientPort + + // SkyForwardingServerName ... + SkyForwardingServerName = skyenv.SkyForwardingServerName + + // SkyForwardingServerPort ... + SkyForwardingServerPort = skyenv.SkyForwardingServerPort + + // SkyPingName ... + SkyPingName = skyenv.SkyPingName + + // SkyPingPort ... + SkyPingPort = skyenv.SkyPingPort // RPC constants. - RPCAddr = skyenv.RPCAddr // RPCAddr ... - RPCTimeout = skyenv.RPCTimeout // RPCTimeout ... - TransportRPCTimeout = skyenv.TransportRPCTimeout // TransportRPCTimeout ... - UpdateRPCTimeout = skyenv.UpdateRPCTimeout // UpdateRPCTimeout ... + // RPCAddr ... + RPCAddr = skyenv.RPCAddr + + // RPCTimeout ... + RPCTimeout = skyenv.RPCTimeout + + // TransportRPCTimeout ... + TransportRPCTimeout = skyenv.TransportRPCTimeout + + // UpdateRPCTimeout ... + UpdateRPCTimeout = skyenv.UpdateRPCTimeout // Default skywire app server and discovery constants - AppSrvAddr = skyenv.AppSrvAddr // AppSrvAddr ... - ServiceDiscUpdateInterval = skyenv.ServiceDiscUpdateInterval // ServiceDiscUpdateInterval ... - AppBinPath = skyenv.AppBinPath // AppBinPath ... - LogLevel = skyenv.LogLevel // LogLevel ... + // AppSrvAddr ... + AppSrvAddr = skyenv.AppSrvAddr + + // ServiceDiscUpdateInterval ... + ServiceDiscUpdateInterval = skyenv.ServiceDiscUpdateInterval + + // AppBinPath ... + AppBinPath = skyenv.AppBinPath + + // LogLevel ... + LogLevel = skyenv.LogLevel // Routing constants - TpLogStore = skyenv.TpLogStore // TpLogStore ... - Custom = skyenv.Custom // Custom ... + // TpLogStore ... + TpLogStore = skyenv.TpLogStore + + // Custom ... + Custom = skyenv.Custom // Local constants - LocalPath = skyenv.LocalPath // LocalPath ... + // LocalPath ... + LocalPath = skyenv.LocalPath // Default hypervisor constants - HypervisorDB = skyenv.HypervisorDB // HypervisorDB ... - EnableAuth = skyenv.EnableAuth // EnableAuth ... - PackageEnableAuth = skyenv.PackageEnableAuth // PackageEnableAuth ... - EnableTLS = skyenv.EnableTLS // EnableTLS ... - TLSKey = skyenv.TLSKey // TLSKey ... - TLSCert = skyenv.TLSCert // TLSCert ... + // HypervisorDB ... + HypervisorDB = skyenv.HypervisorDB + + // EnableAuth ... + EnableAuth = skyenv.EnableAuth + + // PackageEnableAuth ... + PackageEnableAuth = skyenv.PackageEnableAuth + + // EnableTLS ... + EnableTLS = skyenv.EnableTLS + + // TLSKey ... + TLSKey = skyenv.TLSKey + + // TLSCert ... + TLSCert = skyenv.TLSCert // IPCShutdownMessageType sends IPC shutdown message type IPCShutdownMessageType = skyenv.IPCShutdownMessageType