diff --git a/dialer/bandit.go b/dialer/bandit.go index bd26679fe..8ab817609 100644 --- a/dialer/bandit.go +++ b/dialer/bandit.go @@ -19,10 +19,12 @@ import ( // banditDialer is responsible for continually choosing the optimized dialer. type banditDialer struct { - dialers []ProxyDialer - bandit bandit.Bandit - opts *Options - banditRewardsMutex *sync.Mutex + dialers []ProxyDialer + bandit bandit.Bandit + opts *Options + banditRewardsMutex *sync.Mutex + secondsUntilRewardSample time.Duration + secondsUntilSaveBanditRewards time.Duration } type banditMetrics struct { @@ -47,9 +49,11 @@ func NewBandit(opts *Options) (Dialer, error) { var b bandit.Bandit var err error dialer := &banditDialer{ - dialers: dialers, - opts: opts, - banditRewardsMutex: &sync.Mutex{}, + dialers: dialers, + opts: opts, + banditRewardsMutex: &sync.Mutex{}, + secondsUntilRewardSample: secondsForSample * time.Second, + secondsUntilSaveBanditRewards: saveBanditRewardsAfter, } dialerWeights, err := dialer.loadLastBanditRewards() @@ -134,16 +138,17 @@ func (bd *banditDialer) DialContext(ctx context.Context, network, addr string) ( // Tell the dialer to update the bandit with it's throughput after 5 seconds. var dataRecv atomic.Uint64 - dt := newDataTrackingConn(conn, &dataRecv) - time.AfterFunc(secondsForSample*time.Second, func() { - speed := normalizeReceiveSpeed(dataRecv.Load()) - //log.Debugf("Dialer %v received %v bytes in %v seconds, normalized speed: %v", d.Name(), dt.dataRecv, secondsForSample, speed) + var elapsedTimeReading atomic.Int64 + dt := newDataTrackingConn(conn, &dataRecv, &elapsedTimeReading) + time.AfterFunc(bd.secondsUntilRewardSample, func() { + speed := normalizeReceiveSpeed(dataRecv.Load(), elapsedTimeReading.Load()) + // log.Debugf("Dialer %v received %v bytes in %v seconds, normalized speed: %v", d.Name(), dt.dataRecv, secondsForSample, speed) if errUpdatingBanditReward := bd.bandit.Update(chosenArm, speed); errUpdatingBanditReward != nil { - log.Errorf("unable to update bandit: %v", errUpdatingBanditReward) + log.Errorf("unable to update bandit: %v", err) } }) - time.AfterFunc(30*time.Second, func() { + time.AfterFunc(bd.secondsUntilSaveBanditRewards, func() { log.Debugf("saving bandit rewards") metrics := make(map[string]banditMetrics) rewards := bd.bandit.GetRewards() @@ -339,13 +344,15 @@ func differentArm(existingArm, numDialers int) int { const secondsForSample = 6 +const saveBanditRewardsAfter = 30 * time.Second + // A reasonable upper bound for the top expected bytes to receive per second. // Anything over this will be normalized to over 1. const topExpectedBps = 125000 -func normalizeReceiveSpeed(dataRecv uint64) float64 { +func normalizeReceiveSpeed(dataRecv uint64, elapsedTimeReading int64) float64 { // Record the bytes in relation to the top expected speed. - return (float64(dataRecv) / secondsForSample) / topExpectedBps + return (float64(dataRecv) / (float64(elapsedTimeReading) / 1000)) / topExpectedBps } func (bd *banditDialer) Close() { @@ -355,20 +362,24 @@ func (bd *banditDialer) Close() { } } -func newDataTrackingConn(conn net.Conn, dataRecv *atomic.Uint64) *dataTrackingConn { +func newDataTrackingConn(conn net.Conn, dataRecv *atomic.Uint64, elapsedTimeReading *atomic.Int64) *dataTrackingConn { return &dataTrackingConn{ - Conn: conn, - dataRecv: dataRecv, + Conn: conn, + dataRecv: dataRecv, + elapsedTimeReading: elapsedTimeReading, } } type dataTrackingConn struct { net.Conn - dataRecv *atomic.Uint64 + dataRecv *atomic.Uint64 + elapsedTimeReading *atomic.Int64 // elapsedTimeReading store in milliseconds the time the connection took to read data } func (c *dataTrackingConn) Read(b []byte) (int, error) { + startedReading := time.Now() n, err := c.Conn.Read(b) c.dataRecv.Add(uint64(n)) + c.elapsedTimeReading.Add(time.Since(startedReading).Milliseconds()) return n, err } diff --git a/dialer/bandit_test.go b/dialer/bandit_test.go index d47ab2f38..9f5177161 100644 --- a/dialer/bandit_test.go +++ b/dialer/bandit_test.go @@ -17,6 +17,7 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" ) func TestBanditDialer_chooseDialerForDomain(t *testing.T) { @@ -230,7 +231,8 @@ func TestBanditDialer_DialContext(t *testing.T) { func Test_normalizeReceiveSpeed(t *testing.T) { type args struct { - dataRecv uint64 + dataRecv uint64 + elapsedTimeReading int64 } tests := []struct { name string @@ -240,7 +242,8 @@ func Test_normalizeReceiveSpeed(t *testing.T) { { name: "should return 0 if no data received", args: args{ - dataRecv: 0, + dataRecv: 0, + elapsedTimeReading: secondsForSample * 1000, }, want: func(got float64) bool { return got == 0 @@ -249,7 +252,8 @@ func Test_normalizeReceiveSpeed(t *testing.T) { { name: "should return 1 if pretty fast", args: args{ - dataRecv: topExpectedBps * secondsForSample, + dataRecv: topExpectedBps * secondsForSample, + elapsedTimeReading: secondsForSample * 1000, }, want: func(got float64) bool { return got == 1 @@ -258,17 +262,18 @@ func Test_normalizeReceiveSpeed(t *testing.T) { { name: "should return 1 if super fast", args: args{ - dataRecv: topExpectedBps * 50, + dataRecv: topExpectedBps * 50, + elapsedTimeReading: secondsForSample * 1000, }, want: func(got float64) bool { return got > 1 }, }, - { name: "should return <1 if sorta fast", args: args{ - dataRecv: 2000, + dataRecv: 2000, + elapsedTimeReading: secondsForSample * 1000, }, want: func(got float64) bool { return got > 0 && got < 1 @@ -277,7 +282,7 @@ func Test_normalizeReceiveSpeed(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := normalizeReceiveSpeed(tt.args.dataRecv); !tt.want(got) { + if got := normalizeReceiveSpeed(tt.args.dataRecv, tt.args.elapsedTimeReading); !assert.True(t, tt.want(got)) { t.Errorf("unexpected normalizeReceiveSpeed() = %v", got) } }) @@ -453,6 +458,7 @@ type tcpConnDialer struct { client net.Conn server net.Conn name string + dial func() (net.Conn, bool, error) } func (*tcpConnDialer) Ready() <-chan error { @@ -508,6 +514,10 @@ func (t *tcpConnDialer) DialContext(ctx context.Context, network string, addr st if t.shouldFail { return nil, true, io.EOF } + + if t.dial != nil { + return t.dial() + } return &net.TCPConn{}, false, nil } @@ -600,3 +610,61 @@ func (*tcpConnDialer) Trusted() bool { // WriteStats implements Dialer. func (*tcpConnDialer) WriteStats(w io.Writer) { } + +//go:generate mockgen -package=dialer -destination=mocks_test.go net Conn + +func TestBanditDialerIntegration(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + baseDialer := newTcpConnDialer() + message := "hello" + connSleepTime := 200 * time.Millisecond + + baseDialer.(*tcpConnDialer).dial = func() (net.Conn, bool, error) { + conn := NewMockConn(ctrl) + conn.EXPECT().Read(gomock.Any()).DoAndReturn(func(b []byte) (int, error) { + time.Sleep(connSleepTime) + return copy(b, []byte(message)), io.EOF + }).AnyTimes() + return conn, false, nil + } + + banditDir, err := os.MkdirTemp("", "bandit_dial_test") + require.NoError(t, err) + defer os.RemoveAll(banditDir) + + opts := &Options{ + Dialers: []ProxyDialer{baseDialer}, + BanditDir: banditDir, + } + bandit, err := NewBandit(opts) + require.NoError(t, err) + banditDialer := bandit.(*banditDialer) + banditDialer.secondsUntilRewardSample = 1 * time.Second + banditDialer.secondsUntilSaveBanditRewards = 1200 * time.Millisecond + + ctx := context.Background() + banditConn, err := banditDialer.DialContext(ctx, "tcp", "localhost:8080") + require.NoError(t, err) + + got, err := io.ReadAll(banditConn) + assert.NoError(t, err) + assert.Equal(t, message, string(got[:len(message)])) + + // waiting so reward is sampled and bandit rewards are stored + time.Sleep(1400 * time.Millisecond) + + rewards := banditDialer.bandit.GetRewards() + counts := banditDialer.bandit.GetCounts() + + // there's only one dialer + assert.Len(t, counts, 1) + assert.Len(t, rewards, 1) + // since there's only one dialer and one Dial call, we're expecting one count + assert.Equal(t, 1, counts[0]) + assert.InEpsilon(t, normalizeReceiveSpeed(uint64(len(got)), connSleepTime.Milliseconds()), rewards[0], 0.2) + + // check if rewards.csv was written + assert.FileExists(t, filepath.Join(banditDir, "rewards.csv")) +} diff --git a/dialer/mocks_test.go b/dialer/mocks_test.go new file mode 100644 index 000000000..22872dd7c --- /dev/null +++ b/dialer/mocks_test.go @@ -0,0 +1,156 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: net (interfaces: Conn) +// +// Generated by this command: +// +// mockgen -package=dialer -destination=mocks_test.go net Conn +// + +// Package dialer is a generated GoMock package. +package dialer + +import ( + net "net" + reflect "reflect" + time "time" + + gomock "go.uber.org/mock/gomock" +) + +// MockConn is a mock of Conn interface. +type MockConn struct { + ctrl *gomock.Controller + recorder *MockConnMockRecorder + isgomock struct{} +} + +// MockConnMockRecorder is the mock recorder for MockConn. +type MockConnMockRecorder struct { + mock *MockConn +} + +// NewMockConn creates a new mock instance. +func NewMockConn(ctrl *gomock.Controller) *MockConn { + mock := &MockConn{ctrl: ctrl} + mock.recorder = &MockConnMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockConn) EXPECT() *MockConnMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockConn) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockConnMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockConn)(nil).Close)) +} + +// LocalAddr mocks base method. +func (m *MockConn) LocalAddr() net.Addr { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LocalAddr") + ret0, _ := ret[0].(net.Addr) + return ret0 +} + +// LocalAddr indicates an expected call of LocalAddr. +func (mr *MockConnMockRecorder) LocalAddr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockConn)(nil).LocalAddr)) +} + +// Read mocks base method. +func (m *MockConn) Read(b []byte) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Read", b) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Read indicates an expected call of Read. +func (mr *MockConnMockRecorder) Read(b any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockConn)(nil).Read), b) +} + +// RemoteAddr mocks base method. +func (m *MockConn) RemoteAddr() net.Addr { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RemoteAddr") + ret0, _ := ret[0].(net.Addr) + return ret0 +} + +// RemoteAddr indicates an expected call of RemoteAddr. +func (mr *MockConnMockRecorder) RemoteAddr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*MockConn)(nil).RemoteAddr)) +} + +// SetDeadline mocks base method. +func (m *MockConn) SetDeadline(t time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetDeadline", t) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetDeadline indicates an expected call of SetDeadline. +func (mr *MockConnMockRecorder) SetDeadline(t any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDeadline", reflect.TypeOf((*MockConn)(nil).SetDeadline), t) +} + +// SetReadDeadline mocks base method. +func (m *MockConn) SetReadDeadline(t time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetReadDeadline", t) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetReadDeadline indicates an expected call of SetReadDeadline. +func (mr *MockConnMockRecorder) SetReadDeadline(t any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockConn)(nil).SetReadDeadline), t) +} + +// SetWriteDeadline mocks base method. +func (m *MockConn) SetWriteDeadline(t time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetWriteDeadline", t) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetWriteDeadline indicates an expected call of SetWriteDeadline. +func (mr *MockConnMockRecorder) SetWriteDeadline(t any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockConn)(nil).SetWriteDeadline), t) +} + +// Write mocks base method. +func (m *MockConn) Write(b []byte) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Write", b) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Write indicates an expected call of Write. +func (mr *MockConnMockRecorder) Write(b any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockConn)(nil).Write), b) +}