From 6ebfd2ba3fa52855bfb9f85ac3f7fb5faf02bf7e Mon Sep 17 00:00:00 2001 From: Joseph Anttila Hall Date: Fri, 12 Apr 2024 09:55:58 -0700 Subject: [PATCH] Replace Backend interface with a struct. Comparisons (like those in DefaultBackendStorage) should be by pointer value. --- pkg/server/backend_manager.go | 54 +++++++++------------ pkg/server/backend_manager_test.go | 24 ++++----- pkg/server/default_route_backend_manager.go | 6 +-- pkg/server/desthost_backend_manager.go | 6 +-- pkg/server/server.go | 20 ++++---- pkg/server/server_test.go | 14 +++--- 6 files changed, 57 insertions(+), 67 deletions(-) diff --git a/pkg/server/backend_manager.go b/pkg/server/backend_manager.go index 3339153fe..cada4961b 100644 --- a/pkg/server/backend_manager.go +++ b/pkg/server/backend_manager.go @@ -77,17 +77,7 @@ func GenProxyStrategiesFromStr(proxyStrategies string) ([]ProxyStrategy, error) // In the only currently supported case (gRPC), it wraps an // agent.AgentService_ConnectServer, provides synchronization and // emits common stream metrics. -type Backend interface { - Send(p *client.Packet) error - Recv() (*client.Packet, error) - Context() context.Context - GetAgentID() string - GetAgentIdentifiers() header.Identifiers -} - -var _ Backend = &backend{} - -type backend struct { +type Backend struct { sendLock sync.Mutex recvLock sync.Mutex conn agent.AgentService_ConnectServer @@ -97,7 +87,7 @@ type backend struct { idents header.Identifiers } -func (b *backend) Send(p *client.Packet) error { +func (b *Backend) Send(p *client.Packet) error { b.sendLock.Lock() defer b.sendLock.Unlock() @@ -110,7 +100,7 @@ func (b *backend) Send(p *client.Packet) error { return err } -func (b *backend) Recv() (*client.Packet, error) { +func (b *Backend) Recv() (*client.Packet, error) { b.recvLock.Lock() defer b.recvLock.Unlock() @@ -126,16 +116,16 @@ func (b *backend) Recv() (*client.Packet, error) { return pkt, nil } -func (b *backend) Context() context.Context { +func (b *Backend) Context() context.Context { // TODO: does Context require lock protection? return b.conn.Context() } -func (b *backend) GetAgentID() string { +func (b *Backend) GetAgentID() string { return b.id } -func (b *backend) GetAgentIdentifiers() header.Identifiers { +func (b *Backend) GetAgentIdentifiers() header.Identifiers { return b.idents } @@ -168,7 +158,7 @@ func getAgentIdentifiers(conn agent.AgentService_ConnectServer) (header.Identifi return header.GenAgentIdentifiers(agentIdent[0]) } -func NewBackend(conn agent.AgentService_ConnectServer) (Backend, error) { +func NewBackend(conn agent.AgentService_ConnectServer) (*Backend, error) { agentID, err := getAgentID(conn) if err != nil { return nil, err @@ -177,16 +167,16 @@ func NewBackend(conn agent.AgentService_ConnectServer) (Backend, error) { if err != nil { return nil, err } - return &backend{conn: conn, id: agentID, idents: agentIdentifiers}, nil + return &Backend{conn: conn, id: agentID, idents: agentIdentifiers}, nil } // BackendStorage is an interface to manage the storage of the backend // connections, i.e., get, add and remove type BackendStorage interface { // addBackend adds a backend. - addBackend(identifier string, idType header.IdentifierType, backend Backend) + addBackend(identifier string, idType header.IdentifierType, backend *Backend) // removeBackend removes a backend. - removeBackend(identifier string, idType header.IdentifierType, backend Backend) + removeBackend(identifier string, idType header.IdentifierType, backend *Backend) // NumBackends returns the number of backends. NumBackends() int } @@ -199,11 +189,11 @@ type BackendManager interface { // context instead of a request-scoped context, as the backend manager will // pick a backend for every tunnel session and each tunnel session may // contains multiple requests. - Backend(ctx context.Context) (Backend, error) + Backend(ctx context.Context) (*Backend, error) // AddBackend adds a backend. - AddBackend(backend Backend) + AddBackend(backend *Backend) // RemoveBackend adds a backend. - RemoveBackend(backend Backend) + RemoveBackend(backend *Backend) BackendStorage ReadinessManager } @@ -215,18 +205,18 @@ type DefaultBackendManager struct { *DefaultBackendStorage } -func (dbm *DefaultBackendManager) Backend(_ context.Context) (Backend, error) { +func (dbm *DefaultBackendManager) Backend(_ context.Context) (*Backend, error) { klog.V(5).InfoS("Get a random backend through the DefaultBackendManager") return dbm.DefaultBackendStorage.GetRandomBackend() } -func (dbm *DefaultBackendManager) AddBackend(backend Backend) { +func (dbm *DefaultBackendManager) AddBackend(backend *Backend) { agentID := backend.GetAgentID() klog.V(5).InfoS("Add the agent to DefaultBackendManager", "agentID", agentID) dbm.addBackend(agentID, header.UID, backend) } -func (dbm *DefaultBackendManager) RemoveBackend(backend Backend) { +func (dbm *DefaultBackendManager) RemoveBackend(backend *Backend) { agentID := backend.GetAgentID() klog.V(5).InfoS("Remove the agent from the DefaultBackendManager", "agentID", agentID) dbm.removeBackend(agentID, header.UID, backend) @@ -242,7 +232,7 @@ type DefaultBackendStorage struct { // // TODO: fix documentation. This is not always agentID, e.g. in // the case of DestHostBackendManager. - backends map[string][]Backend + backends map[string][]*Backend // agentID is tracked in this slice to enable randomly picking an // agentID in the Backend() method. There is no reliable way to // randomly pick a key from a map (in this case, the backends) in @@ -272,7 +262,7 @@ func NewDefaultBackendStorage(idTypes []header.IdentifierType) *DefaultBackendSt // no agent ever successfully connects. metrics.Metrics.SetBackendCount(0) return &DefaultBackendStorage{ - backends: make(map[string][]Backend), + backends: make(map[string][]*Backend), random: rand.New(rand.NewSource(time.Now().UnixNano())), idTypes: idTypes, } /* #nosec G404 */ @@ -283,7 +273,7 @@ func containIDType(idTypes []header.IdentifierType, idType header.IdentifierType } // addBackend adds a backend. -func (s *DefaultBackendStorage) addBackend(identifier string, idType header.IdentifierType, backend Backend) { +func (s *DefaultBackendStorage) addBackend(identifier string, idType header.IdentifierType, backend *Backend) { if !containIDType(s.idTypes, idType) { klog.V(4).InfoS("fail to add backend", "backend", identifier, "error", &ErrWrongIDType{idType, s.idTypes}) return @@ -302,7 +292,7 @@ func (s *DefaultBackendStorage) addBackend(identifier string, idType header.Iden s.backends[identifier] = append(s.backends[identifier], backend) return } - s.backends[identifier] = []Backend{backend} + s.backends[identifier] = []*Backend{backend} metrics.Metrics.SetBackendCount(len(s.backends)) s.agentIDs = append(s.agentIDs, identifier) if idType == header.DefaultRoute { @@ -311,7 +301,7 @@ func (s *DefaultBackendStorage) addBackend(identifier string, idType header.Iden } // removeBackend removes a backend. -func (s *DefaultBackendStorage) removeBackend(identifier string, idType header.IdentifierType, backend Backend) { +func (s *DefaultBackendStorage) removeBackend(identifier string, idType header.IdentifierType, backend *Backend) { if !containIDType(s.idTypes, idType) { klog.ErrorS(&ErrWrongIDType{idType, s.idTypes}, "fail to remove backend") return @@ -390,7 +380,7 @@ func ignoreNotFound(err error) error { } // GetRandomBackend returns a random backend connection from all connected agents. -func (s *DefaultBackendStorage) GetRandomBackend() (Backend, error) { +func (s *DefaultBackendStorage) GetRandomBackend() (*Backend, error) { s.mu.Lock() defer s.mu.Unlock() if len(s.backends) == 0 { diff --git a/pkg/server/backend_manager_test.go b/pkg/server/backend_manager_test.go index 1c90cae8f..14de3b2ac 100644 --- a/pkg/server/backend_manager_test.go +++ b/pkg/server/backend_manager_test.go @@ -119,7 +119,7 @@ func TestDefaultBackendManager_AddRemoveBackends(t *testing.T) { p.AddBackend(backend1) p.RemoveBackend(backend1) - expectedBackends := make(map[string][]Backend) + expectedBackends := make(map[string][]*Backend) expectedAgentIDs := []string{} expectedDefaultRouteAgentIDs := []string(nil) if e, a := expectedBackends, p.backends; !reflect.DeepEqual(e, a) { @@ -143,7 +143,7 @@ func TestDefaultBackendManager_AddRemoveBackends(t *testing.T) { p.RemoveBackend(backend22) p.RemoveBackend(backend2) p.RemoveBackend(backend1) - expectedBackends = map[string][]Backend{ + expectedBackends = map[string][]*Backend{ "agent1": {backend12}, "agent3": {backend3}, } @@ -174,7 +174,7 @@ func TestDefaultRouteBackendManager_AddRemoveBackends(t *testing.T) { p.AddBackend(backend1) p.RemoveBackend(backend1) - expectedBackends := make(map[string][]Backend) + expectedBackends := make(map[string][]*Backend) expectedAgentIDs := []string{} expectedDefaultRouteAgentIDs := []string{} if e, a := expectedBackends, p.backends; !reflect.DeepEqual(e, a) { @@ -199,7 +199,7 @@ func TestDefaultRouteBackendManager_AddRemoveBackends(t *testing.T) { p.RemoveBackend(backend2) p.RemoveBackend(backend1) - expectedBackends = map[string][]Backend{ + expectedBackends = map[string][]*Backend{ "agent1": {backend12}, "agent3": {backend3}, } @@ -231,7 +231,7 @@ func TestDestHostBackendManager_AddRemoveBackends(t *testing.T) { p.AddBackend(backend1) p.RemoveBackend(backend1) - expectedBackends := make(map[string][]Backend) + expectedBackends := make(map[string][]*Backend) expectedAgentIDs := []string{} expectedDefaultRouteAgentIDs := []string(nil) if e, a := expectedBackends, p.backends; !reflect.DeepEqual(e, a) { @@ -247,7 +247,7 @@ func TestDestHostBackendManager_AddRemoveBackends(t *testing.T) { p = NewDestHostBackendManager() p.AddBackend(backend1) - expectedBackends = map[string][]Backend{ + expectedBackends = map[string][]*Backend{ "localhost": {backend1}, "1.2.3.4": {backend1}, "9878::7675:1292:9183:7562": {backend1}, @@ -273,7 +273,7 @@ func TestDestHostBackendManager_AddRemoveBackends(t *testing.T) { p.AddBackend(backend2) p.AddBackend(backend3) - expectedBackends = map[string][]Backend{ + expectedBackends = map[string][]*Backend{ "localhost": {backend1}, "node1.mydomain.com": {backend1}, "node2.mydomain.com": {backend3}, @@ -306,7 +306,7 @@ func TestDestHostBackendManager_AddRemoveBackends(t *testing.T) { p.RemoveBackend(backend2) p.RemoveBackend(backend1) - expectedBackends = map[string][]Backend{ + expectedBackends = map[string][]*Backend{ "node2.mydomain.com": {backend3}, "5.6.7.8": {backend3}, "::": {backend3}, @@ -328,7 +328,7 @@ func TestDestHostBackendManager_AddRemoveBackends(t *testing.T) { } p.RemoveBackend(backend3) - expectedBackends = map[string][]Backend{} + expectedBackends = map[string][]*Backend{} expectedAgentIDs = []string{} if e, a := expectedBackends, p.backends; !reflect.DeepEqual(e, a) { @@ -356,7 +356,7 @@ func TestDestHostBackendManager_WithDuplicateIdents(t *testing.T) { p.AddBackend(backend2) p.AddBackend(backend3) - expectedBackends := map[string][]Backend{ + expectedBackends := map[string][]*Backend{ "localhost": {backend1, backend2, backend3}, "1.2.3.4": {backend1, backend2}, "5.6.7.8": {backend3}, @@ -389,7 +389,7 @@ func TestDestHostBackendManager_WithDuplicateIdents(t *testing.T) { p.RemoveBackend(backend1) p.RemoveBackend(backend3) - expectedBackends = map[string][]Backend{ + expectedBackends = map[string][]*Backend{ "localhost": {backend2}, "1.2.3.4": {backend2}, "9878::7675:1292:9183:7562": {backend2}, @@ -413,7 +413,7 @@ func TestDestHostBackendManager_WithDuplicateIdents(t *testing.T) { } p.RemoveBackend(backend2) - expectedBackends = map[string][]Backend{} + expectedBackends = map[string][]*Backend{} expectedAgentIDs = []string{} if e, a := expectedBackends, p.backends; !reflect.DeepEqual(e, a) { diff --git a/pkg/server/default_route_backend_manager.go b/pkg/server/default_route_backend_manager.go index 4bd18749d..f4480eab6 100644 --- a/pkg/server/default_route_backend_manager.go +++ b/pkg/server/default_route_backend_manager.go @@ -36,11 +36,11 @@ func NewDefaultRouteBackendManager() *DefaultRouteBackendManager { } // Backend tries to get a backend that advertises default route, with random selection. -func (dibm *DefaultRouteBackendManager) Backend(_ context.Context) (Backend, error) { +func (dibm *DefaultRouteBackendManager) Backend(_ context.Context) (*Backend, error) { return dibm.GetRandomBackend() } -func (dibm *DefaultRouteBackendManager) AddBackend(backend Backend) { +func (dibm *DefaultRouteBackendManager) AddBackend(backend *Backend) { agentID := backend.GetAgentID() agentIdentifiers := backend.GetAgentIdentifiers() if agentIdentifiers.DefaultRoute { @@ -49,7 +49,7 @@ func (dibm *DefaultRouteBackendManager) AddBackend(backend Backend) { } } -func (dibm *DefaultRouteBackendManager) RemoveBackend(backend Backend) { +func (dibm *DefaultRouteBackendManager) RemoveBackend(backend *Backend) { agentID := backend.GetAgentID() agentIdentifiers := backend.GetAgentIdentifiers() if agentIdentifiers.DefaultRoute { diff --git a/pkg/server/desthost_backend_manager.go b/pkg/server/desthost_backend_manager.go index 2857659c5..8c914e4ba 100644 --- a/pkg/server/desthost_backend_manager.go +++ b/pkg/server/desthost_backend_manager.go @@ -35,7 +35,7 @@ func NewDestHostBackendManager() *DestHostBackendManager { []header.IdentifierType{header.IPv4, header.IPv6, header.Host})} } -func (dibm *DestHostBackendManager) AddBackend(backend Backend) { +func (dibm *DestHostBackendManager) AddBackend(backend *Backend) { agentIdentifiers := backend.GetAgentIdentifiers() for _, ipv4 := range agentIdentifiers.IPv4 { klog.V(5).InfoS("Add the agent to DestHostBackendManager", "agent address", ipv4) @@ -51,7 +51,7 @@ func (dibm *DestHostBackendManager) AddBackend(backend Backend) { } } -func (dibm *DestHostBackendManager) RemoveBackend(backend Backend) { +func (dibm *DestHostBackendManager) RemoveBackend(backend *Backend) { agentIdentifiers := backend.GetAgentIdentifiers() for _, ipv4 := range agentIdentifiers.IPv4 { klog.V(5).InfoS("Remove the agent from the DestHostBackendManager", "agentHost", ipv4) @@ -68,7 +68,7 @@ func (dibm *DestHostBackendManager) RemoveBackend(backend Backend) { } // Backend tries to get a backend associating to the request destination host. -func (dibm *DestHostBackendManager) Backend(ctx context.Context) (Backend, error) { +func (dibm *DestHostBackendManager) Backend(ctx context.Context) (*Backend, error) { dibm.mu.RLock() defer dibm.mu.RUnlock() if len(dibm.backends) == 0 { diff --git a/pkg/server/server.go b/pkg/server/server.go index 94a76feec..9be2cf20b 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -102,7 +102,7 @@ type ProxyClientConnection struct { connectID int64 agentID string start time.Time - backend Backend + backend *Backend dialAddress string // cached for logging } @@ -245,7 +245,7 @@ func genContext(proxyStrategies []ProxyStrategy, reqHost string) context.Context return ctx } -func (s *ProxyServer) getBackend(reqHost string) (Backend, error) { +func (s *ProxyServer) getBackend(reqHost string) (*Backend, error) { ctx := genContext(s.proxyStrategies, reqHost) for _, bm := range s.BackendManagers { be, err := bm.Backend(ctx) @@ -261,14 +261,14 @@ func (s *ProxyServer) getBackend(reqHost string) (Backend, error) { return nil, &ErrNotFound{} } -func (s *ProxyServer) addBackend(backend Backend) { +func (s *ProxyServer) addBackend(backend *Backend) { // TODO: refactor BackendStorage to acquire lock once, not up to 3 times. for _, bm := range s.BackendManagers { bm.AddBackend(backend) } } -func (s *ProxyServer) removeBackend(backend Backend) { +func (s *ProxyServer) removeBackend(backend *Backend) { for _, bm := range s.BackendManagers { bm.RemoveBackend(backend) } @@ -318,7 +318,7 @@ func (s *ProxyServer) getFrontend(agentID string, connID int64) (*ProxyClientCon return conn, nil } -func (s *ProxyServer) removeEstablishedForBackendConn(agentID string, backend Backend) ([]*ProxyClientConnection, error) { +func (s *ProxyServer) removeEstablishedForBackendConn(agentID string, backend *Backend) ([]*ProxyClientConnection, error) { var ret []*ProxyClientConnection if backend == nil { return ret, nil @@ -492,7 +492,7 @@ func (s *ProxyServer) serveRecvFrontend(frontend *GrpcFrontend, recvCh <-chan *c // backend from the BackendManger then. // TODO: either add agentID to protocol (DATA, CLOSE_RSP, etc) or replace {agentID, // connectionID} with a simpler key (#462). - var backend Backend + var backend *Backend var err error defer func() { @@ -759,7 +759,7 @@ func (s *ProxyServer) Connect(stream agent.AgentService_ConnectServer) error { return <-stopCh } -func (s *ProxyServer) readBackendToChannel(backend Backend, recvCh chan *client.Packet, stopCh chan error) { +func (s *ProxyServer) readBackendToChannel(backend *Backend, recvCh chan *client.Packet, stopCh chan error) { agentID := backend.GetAgentID() for { in, err := backend.Recv() @@ -792,7 +792,7 @@ func (s *ProxyServer) readBackendToChannel(backend Backend, recvCh chan *client. } // route the packet back to the correct client -func (s *ProxyServer) serveRecvBackend(backend Backend, agentID string, recvCh <-chan *client.Packet) { +func (s *ProxyServer) serveRecvBackend(backend *Backend, agentID string, recvCh <-chan *client.Packet) { defer func() { // Drain recvCh to ensure that readBackendToChannel is not blocked on a channel write. // This should never happen, as termination of this function should only be initiated by closing recvCh. @@ -948,7 +948,7 @@ func (s *ProxyServer) serveRecvBackend(backend Backend, agentID string, recvCh < klog.V(5).InfoS("Close backend of agent", "agentID", agentID) } -func (s *ProxyServer) sendBackendClose(backend Backend, connectID int64, random int64, reason string) { +func (s *ProxyServer) sendBackendClose(backend *Backend, connectID int64, random int64, reason string) { agentID := backend.GetAgentID() pkt := &client.Packet{ Type: client.PacketType_CLOSE_REQ, @@ -963,7 +963,7 @@ func (s *ProxyServer) sendBackendClose(backend Backend, connectID int64, random } } -func (s *ProxyServer) sendBackendDialClose(backend Backend, random int64, reason string) { +func (s *ProxyServer) sendBackendDialClose(backend *Backend, random int64, reason string) { agentID := backend.GetAgentID() pkt := &client.Packet{ Type: client.PacketType_DIAL_CLS, diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index b96fdfe13..6d39bb268 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -540,9 +540,9 @@ func TestEstablishedConnsMetric(t *testing.T) { } func TestRemoveEstablishedForBackendConn(t *testing.T) { - backend1 := &backend{} - backend2 := &backend{} - backend3 := &backend{} + backend1 := &Backend{} + backend2 := &Backend{} + backend3 := &Backend{} agent1ConnID1 := &ProxyClientConnection{backend: backend1} agent1ConnID2 := &ProxyClientConnection{backend: backend1} agent2ConnID1 := &ProxyClientConnection{backend: backend2} @@ -571,9 +571,9 @@ func TestRemoveEstablishedForBackendConn(t *testing.T) { func TestRemoveEstablishedForStream(t *testing.T) { streamUID := "target-uuid" - backend1 := &backend{} - backend2 := &backend{} - backend3 := &backend{} + backend1 := &Backend{} + backend2 := &Backend{} + backend3 := &Backend{} agent1ConnID1 := &ProxyClientConnection{backend: backend1, frontend: &GrpcFrontend{streamUID: streamUID}} agent1ConnID2 := &ProxyClientConnection{backend: backend1} agent2ConnID1 := &ProxyClientConnection{backend: backend2, frontend: &GrpcFrontend{streamUID: streamUID}} @@ -613,7 +613,7 @@ func prepareFrontendConn(ctrl *gomock.Controller) *agentmock.MockAgentService_Co return frontendConn } -func prepareAgentConnMD(t testing.TB, ctrl *gomock.Controller, proxyServer *ProxyServer) (*agentmock.MockAgentService_ConnectServer, Backend) { +func prepareAgentConnMD(t testing.TB, ctrl *gomock.Controller, proxyServer *ProxyServer) (*agentmock.MockAgentService_ConnectServer, *Backend) { t.Helper() // prepare the the connection to agent of proxy-server agentConn := agentmock.NewMockAgentService_ConnectServer(ctrl)