Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor into separate packages & add tests. #4

Merged
merged 1 commit into from
Mar 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/test-startup.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: Go
name: Test systemd

on:
push:
Expand All @@ -18,7 +18,7 @@ jobs:
go-version: "1.21"

- name: Build
run: go build -v ./...
run: go build -v

- name: Install go-mmproxy
run: |
Expand Down
31 changes: 31 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
name: Test

on:
push:
branches: ["main"]
pull_request:
branches: ["main"]

jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- name: Set up Go
uses: actions/setup-go@v4
with:
go-version: "1.21"

- name: Build
run: go build -v

- name: Prepare ip routes
run: |
sudo ip rule add from 127.0.0.1/8 iif lo table 123
sudo ip route add local 0.0.0.0/0 dev lo table 123
sudo ip -6 rule add from ::1/128 iif lo table 123
sudo ip -6 route add local ::/0 dev lo table 123

- name: Test
run: sudo go test -v -timeout 30s ./tests
24 changes: 0 additions & 24 deletions buffers.go

This file was deleted.

28 changes: 28 additions & 0 deletions buffers/buffers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright 2019 Path Network, Inc. All rights reserved.
// Copyright 2024 Konrad Zemek <konrad.zemek@gmail.com>
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package buffers

import (
"math"
"sync"
)

var buffers sync.Pool

func init() {
buffers.New = func() any {
slice := make([]byte, math.MaxUint16)
return &slice
}
}

func Get() []byte {
return *buffers.Get().(*[]byte)
}

func Put(buf []byte) {
buffers.Put(&buf)
}
141 changes: 72 additions & 69 deletions main.go
Original file line number Diff line number Diff line change
@@ -1,62 +1,58 @@
// Copyright 2019 Path Network, Inc. All rights reserved.
// Copyright 2024 Konrad Zemek <konrad.zemek@gmail.com>
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package main

import (
"bufio"
"context"
"flag"
"log/slog"
"net"
"net/netip"
"os"
"syscall"
"time"

"github.com/kzemek/go-mmproxy/tcp"
"github.com/kzemek/go-mmproxy/udp"
"github.com/kzemek/go-mmproxy/utils"
)

type options struct {
Protocol string
ListenAddrStr string
TargetAddr4Str string
TargetAddr6Str string
ListenAddr netip.AddrPort
TargetAddr4 netip.AddrPort
TargetAddr6 netip.AddrPort
Mark int
Verbose int
allowedSubnetsPath string
AllowedSubnets []*net.IPNet
Listeners int
Logger *slog.Logger
udpCloseAfter int
UDPCloseAfter time.Duration
}
var protocolStr string
var listenAddrStr string
var targetAddr4Str string
var targetAddr6Str string
var allowedSubnetsPath string
var udpCloseAfterInt int
var listeners int

var Opts options
var opts utils.Options

func init() {
flag.StringVar(&Opts.Protocol, "p", "tcp", "Protocol that will be proxied: tcp, udp")
flag.StringVar(&Opts.ListenAddrStr, "l", "0.0.0.0:8443", "Address the proxy listens on")
flag.StringVar(&Opts.TargetAddr4Str, "4", "127.0.0.1:443", "Address to which IPv4 traffic will be forwarded to")
flag.StringVar(&Opts.TargetAddr6Str, "6", "[::1]:443", "Address to which IPv6 traffic will be forwarded to")
flag.IntVar(&Opts.Mark, "mark", 0, "The mark that will be set on outbound packets")
flag.IntVar(&Opts.Verbose, "v", 0, `0 - no logging of individual connections
flag.StringVar(&protocolStr, "p", "tcp", "Protocol that will be proxied: tcp, udp")
flag.StringVar(&listenAddrStr, "l", "0.0.0.0:8443", "Address the proxy listens on")
flag.StringVar(&targetAddr4Str, "4", "127.0.0.1:443", "Address to which IPv4 traffic will be forwarded to")
flag.StringVar(&targetAddr6Str, "6", "[::1]:443", "Address to which IPv6 traffic will be forwarded to")
flag.IntVar(&opts.Mark, "mark", 0, "The mark that will be set on outbound packets")
flag.IntVar(&opts.Verbose, "v", 0, `0 - no logging of individual connections
1 - log errors occurring in individual connections
2 - log all state changes of individual connections`)
flag.StringVar(&Opts.allowedSubnetsPath, "allowed-subnets", "",
flag.StringVar(&allowedSubnetsPath, "allowed-subnets", "",
"Path to a file that contains allowed subnets of the proxy servers")
flag.IntVar(&Opts.Listeners, "listeners", 1,
flag.IntVar(&listeners, "listeners", 1,
"Number of listener sockets that will be opened for the listen address (Linux 3.9+)")
flag.IntVar(&Opts.udpCloseAfter, "close-after", 60, "Number of seconds after which UDP socket will be cleaned up")
flag.IntVar(&udpCloseAfterInt, "close-after", 60, "Number of seconds after which UDP socket will be cleaned up")
}

func listen(listenerNum int, errors chan<- error) {
logger := Opts.Logger.With(slog.Int("listenerNum", listenerNum),
slog.String("protocol", Opts.Protocol), slog.String("listenAdr", Opts.ListenAddr.String()))
func listen(ctx context.Context, listenerNum int, parentLogger *slog.Logger, listenErrors chan<- error) {
logger := parentLogger.With(slog.Int("listenerNum", listenerNum),
slog.String("protocol", protocolStr), slog.String("listenAdr", opts.ListenAddr.String()))

listenConfig := net.ListenConfig{}
if Opts.Listeners > 1 {
if listeners > 1 {
listenConfig.Control = func(network, address string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) {
soReusePort := 15
Expand All @@ -67,15 +63,15 @@ func listen(listenerNum int, errors chan<- error) {
}
}

if Opts.Protocol == "tcp" {
tcpListen(&listenConfig, logger, errors)
if opts.Protocol == utils.TCP {
tcp.Listen(ctx, &listenConfig, &opts, logger, listenErrors)
} else {
udpListen(&listenConfig, logger, errors)
udp.Listen(ctx, &listenConfig, &opts, logger, listenErrors)
}
}

func loadAllowedSubnets() error {
file, err := os.Open(Opts.allowedSubnetsPath)
func loadAllowedSubnets(logger *slog.Logger) error {
file, err := os.Open(allowedSubnetsPath)
if err != nil {
return err
}
Expand All @@ -84,12 +80,12 @@ func loadAllowedSubnets() error {

scanner := bufio.NewScanner(file)
for scanner.Scan() {
_, ipNet, err := net.ParseCIDR(scanner.Text())
ipNet, err := netip.ParsePrefix(scanner.Text())
if err != nil {
return err
}
Opts.AllowedSubnets = append(Opts.AllowedSubnets, ipNet)
Opts.Logger.Info("allowed subnet", slog.String("subnet", ipNet.String()))
opts.AllowedSubnets = append(opts.AllowedSubnets, ipNet)
logger.Info("allowed subnet", slog.String("subnet", ipNet.String()))
}

return nil
Expand All @@ -98,72 +94,79 @@ func loadAllowedSubnets() error {
func main() {
flag.Parse()
lvl := slog.LevelInfo
if Opts.Verbose > 0 {
if opts.Verbose > 0 {
lvl = slog.LevelDebug
}
Opts.Logger = slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: lvl}))

if Opts.allowedSubnetsPath != "" {
if err := loadAllowedSubnets(); err != nil {
Opts.Logger.Error("failed to load allowed subnets file", "path", Opts.allowedSubnetsPath, "error", err)
logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: lvl}))

if allowedSubnetsPath != "" {
if err := loadAllowedSubnets(logger); err != nil {
logger.Error("failed to load allowed subnets file", "path", allowedSubnetsPath, "error", err)
}
}

if Opts.Protocol != "tcp" && Opts.Protocol != "udp" {
Opts.Logger.Error("--protocol has to be one of udp, tcp", slog.String("protocol", Opts.Protocol))
if protocolStr == "tcp" {
opts.Protocol = utils.TCP
} else if protocolStr == "udp" {
opts.Protocol = utils.UDP
} else {
logger.Error("--protocol has to be one of udp, tcp", slog.String("protocol", protocolStr))
os.Exit(1)
}

if Opts.Mark < 0 {
Opts.Logger.Error("--mark has to be >= 0", slog.Int("mark", Opts.Mark))
if opts.Mark < 0 {
logger.Error("--mark has to be >= 0", slog.Int("mark", opts.Mark))
os.Exit(1)
}

if Opts.Verbose < 0 {
Opts.Logger.Error("-v has to be >= 0", slog.Int("verbose", Opts.Verbose))
if opts.Verbose < 0 {
logger.Error("-v has to be >= 0", slog.Int("verbose", opts.Verbose))
os.Exit(1)
}

if Opts.Listeners < 1 {
Opts.Logger.Error("--listeners has to be >= 1")
if listeners < 1 {
logger.Error("--listeners has to be >= 1")
os.Exit(1)
}

var err error
if Opts.ListenAddr, err = parseHostPort(Opts.ListenAddrStr); err != nil {
Opts.Logger.Error("listen address is malformed", "error", err)
if opts.ListenAddr, err = utils.ParseHostPort(listenAddrStr); err != nil {
logger.Error("listen address is malformed", "error", err)
os.Exit(1)
}

if Opts.TargetAddr4, err = netip.ParseAddrPort(Opts.TargetAddr4Str); err != nil {
Opts.Logger.Error("ipv4 target address is malformed", "error", err)
if opts.TargetAddr4, err = netip.ParseAddrPort(targetAddr4Str); err != nil {
logger.Error("ipv4 target address is malformed", "error", err)
os.Exit(1)
}
if !Opts.TargetAddr4.Addr().Is4() {
Opts.Logger.Error("ipv4 target address is not IPv4")
if !opts.TargetAddr4.Addr().Is4() {
logger.Error("ipv4 target address is not IPv4")
os.Exit(1)
}

if Opts.TargetAddr6, err = netip.ParseAddrPort(Opts.TargetAddr6Str); err != nil {
Opts.Logger.Error("ipv6 target address is malformed", "error", err)
if opts.TargetAddr6, err = netip.ParseAddrPort(targetAddr6Str); err != nil {
logger.Error("ipv6 target address is malformed", "error", err)
os.Exit(1)
}
if !Opts.TargetAddr6.Addr().Is6() {
Opts.Logger.Error("ipv6 target address is not IPv6")
if !opts.TargetAddr6.Addr().Is6() {
logger.Error("ipv6 target address is not IPv6")
os.Exit(1)
}

if Opts.udpCloseAfter < 0 {
Opts.Logger.Error("--close-after has to be >= 0", slog.Int("close-after", Opts.udpCloseAfter))
if udpCloseAfterInt < 0 {
logger.Error("--close-after has to be >= 0", slog.Int("close-after", udpCloseAfterInt))
os.Exit(1)
}
Opts.UDPCloseAfter = time.Duration(Opts.udpCloseAfter) * time.Second
opts.UDPCloseAfter = time.Duration(udpCloseAfterInt) * time.Second

listenErrors := make(chan error, Opts.Listeners)
for i := 0; i < Opts.Listeners; i++ {
go listen(i, listenErrors)
listenErrors := make(chan error, listeners)
ctxs := make([]context.Context, listeners)
for i := range ctxs {
ctxs[i] = context.Background()
go listen(ctxs[i], i, logger, listenErrors)
}
for i := 0; i < Opts.Listeners; i++ {
for range ctxs {
<-listenErrors
}
}
Loading
Loading