diff --git a/pkg/agent/clientset.go b/pkg/agent/clientset.go index 635bae64d..0bec82809 100644 --- a/pkg/agent/clientset.go +++ b/pkg/agent/clientset.go @@ -28,6 +28,7 @@ import ( "k8s.io/apimachinery/pkg/util/wait" "k8s.io/klog/v2" "sigs.k8s.io/apiserver-network-proxy/pkg/agent/metrics" + "sigs.k8s.io/apiserver-network-proxy/pkg/servercounter" ) // ClientSet consists of clients connected to each instance of an HA proxy server. @@ -36,9 +37,9 @@ type ClientSet struct { clients map[string]*Client // map between serverID and the client // connects to this server. - agentID string // ID of this agent - address string // proxy server address. Assuming HA proxy server - serverCount int // number of proxy server instances, should be 1 + agentID string // ID of this agent + address string // proxy server address. Assuming HA proxy server + serverCounter servercounter.ServerCounter // counts number of proxy servers // unless it is an HA server. Initialized when the ClientSet creates // the first client. When syncForever is set, it will be the most recently seen. syncInterval time.Duration // The interval by which the agent @@ -63,6 +64,8 @@ type ClientSet struct { warnOnChannelLimit bool syncForever bool // Continue syncing (support dynamic server count). + + respectReceivedServerCount bool // Respect server count received from proxy server rather than relying on the agent's own server counter } func (cs *ClientSet) ClientsCount() int { @@ -145,19 +148,21 @@ type ClientSetConfig struct { func (cc *ClientSetConfig) NewAgentClientSet(drainCh, stopCh <-chan struct{}) *ClientSet { return &ClientSet{ - clients: make(map[string]*Client), - agentID: cc.AgentID, - agentIdentifiers: cc.AgentIdentifiers, - address: cc.Address, - syncInterval: cc.SyncInterval, - probeInterval: cc.ProbeInterval, - syncIntervalCap: cc.SyncIntervalCap, - dialOptions: cc.DialOptions, - serviceAccountTokenPath: cc.ServiceAccountTokenPath, - warnOnChannelLimit: cc.WarnOnChannelLimit, - syncForever: cc.SyncForever, - drainCh: drainCh, - stopCh: stopCh, + clients: make(map[string]*Client), + agentID: cc.AgentID, + agentIdentifiers: cc.AgentIdentifiers, + address: cc.Address, + syncInterval: cc.SyncInterval, + probeInterval: cc.ProbeInterval, + syncIntervalCap: cc.SyncIntervalCap, + dialOptions: cc.DialOptions, + serviceAccountTokenPath: cc.ServiceAccountTokenPath, + warnOnChannelLimit: cc.WarnOnChannelLimit, + syncForever: cc.SyncForever, + drainCh: drainCh, + stopCh: stopCh, + respectReceivedServerCount: true, + serverCounter: servercounter.StaticServerCounter(0), } } @@ -184,8 +189,9 @@ func (cs *ClientSet) sync() { if err := cs.connectOnce(); err != nil { if dse, ok := err.(*DuplicateServerError); ok { clientsCount := cs.ClientsCount() - klog.V(4).InfoS("duplicate server", "serverID", dse.ServerID, "serverCount", cs.serverCount, "clientsCount", clientsCount) - if cs.serverCount != 0 && clientsCount >= cs.serverCount { + serverCount := cs.serverCounter.CountServers() + klog.V(4).InfoS("duplicate server", "serverID", dse.ServerID, "serverCount", serverCount, "clientsCount", clientsCount) + if serverCount != 0 && clientsCount >= serverCount { duration = backoff.Step() } else { backoff = cs.resetBackoff() @@ -209,19 +215,24 @@ func (cs *ClientSet) sync() { } func (cs *ClientSet) connectOnce() error { - if !cs.syncForever && cs.serverCount != 0 && cs.ClientsCount() >= cs.serverCount { + agentServerCount := cs.serverCounter.CountServers() + if !cs.syncForever && agentServerCount != 0 && cs.ClientsCount() >= agentServerCount { return nil } - c, serverCount, err := cs.newAgentClient() + c, newServerCount, err := cs.newAgentClient() if err != nil { return err } - if cs.serverCount != 0 && cs.serverCount != serverCount { + if agentServerCount != 0 && agentServerCount != newServerCount { klog.V(2).InfoS("Server count change suggestion by server", - "current", cs.serverCount, "serverID", c.serverID, "actual", serverCount) - + "current", agentServerCount, "serverID", c.serverID, "actual", newServerCount) + if cs.respectReceivedServerCount { + cs.serverCounter = servercounter.StaticServerCounter(newServerCount) + klog.V(2).Infof("respecting server count change suggestion, new count: %v", newServerCount) + } else { + klog.V(2).Infof("ignoring server count change suggestion") + } } - cs.serverCount = serverCount if err := cs.AddClient(c.serverID, c); err != nil { c.Close() return err diff --git a/pkg/server/server.go b/pkg/server/server.go index 793cd606c..32a2a0353 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -37,9 +37,9 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/kubernetes" "k8s.io/klog/v2" - commonmetrics "sigs.k8s.io/apiserver-network-proxy/konnectivity-client/pkg/common/metrics" "sigs.k8s.io/apiserver-network-proxy/konnectivity-client/proto/client" + "sigs.k8s.io/apiserver-network-proxy/pkg/servercounter" "sigs.k8s.io/apiserver-network-proxy/pkg/server/metrics" "sigs.k8s.io/apiserver-network-proxy/pkg/util" @@ -210,8 +210,8 @@ type ProxyServer struct { PendingDial *PendingDialManager - serverID string // unique ID of this server - serverCount int // Number of proxy server instances, should be 1 unless it is a HA server. + serverID string // unique ID of this server + serverCounter servercounter.ServerCounter // provides number of proxy servers // agent authentication AgentAuthenticationOptions *AgentTokenAuthenticationOptions @@ -395,7 +395,7 @@ func NewProxyServer(serverID string, proxyStrategies []ProxyStrategy, serverCoun established: make(map[string](map[int64]*ProxyClientConnection)), PendingDial: NewPendingDialManager(), serverID: serverID, - serverCount: serverCount, + serverCounter: servercounter.StaticServerCounter(serverCount), BackendManagers: bms, AgentAuthenticationOptions: agentAuthenticationOptions, // use the first backend-manager as the Readiness Manager @@ -441,7 +441,7 @@ func (s *ProxyServer) Proxy(stream client.ProxyService_ProxyServer) error { }() labels := runpprof.Labels( - "serverCount", strconv.Itoa(s.serverCount), + "serverCount", strconv.Itoa(s.serverCounter.CountServers()), "userAgent", strings.Join(userAgent, ", "), ) // Start goroutine to receive packets from frontend and push to recvCh @@ -722,7 +722,7 @@ func (s *ProxyServer) Connect(stream agent.AgentService_ConnectServer) error { klog.V(5).InfoS("Connect request from agent", "agentID", agentID, "serverID", s.serverID) labels := runpprof.Labels( - "serverCount", strconv.Itoa(s.serverCount), + "serverCount", strconv.Itoa(s.serverCounter.CountServers()), "agentID", agentID, ) ctx := runpprof.WithLabels(context.Background(), labels) @@ -735,7 +735,7 @@ func (s *ProxyServer) Connect(stream agent.AgentService_ConnectServer) error { } } - h := metadata.Pairs(header.ServerID, s.serverID, header.ServerCount, strconv.Itoa(s.serverCount)) + h := metadata.Pairs(header.ServerID, s.serverID, header.ServerCount, strconv.Itoa(s.serverCounter.CountServers())) if err := stream.SendHeader(h); err != nil { klog.ErrorS(err, "Failed to send server count back to agent", "agentID", agentID) return err diff --git a/pkg/servercounter/cache.go b/pkg/servercounter/cache.go new file mode 100644 index 000000000..9ab0f9f18 --- /dev/null +++ b/pkg/servercounter/cache.go @@ -0,0 +1,43 @@ +package servercounter + +import ( + "time" + + "k8s.io/klog/v2" +) + +// A CachedServerCounter wraps a ServerCounter to cache its server count value. +// Cache refreshes occur when CountServers() is called after a user-configurable +// cache expiration duration. +type CachedServerCounter struct { + inner ServerCounter + cachedCount int + expiration time.Duration + lastRefresh time.Time +} + +// CountServers returns the last cached server count and updates the cached count +// if it has expired since the last call. +func (csc *CachedServerCounter) CountServers() int { + // Refresh the cache if expiry time has passed since last call. + if time.Now().Sub(csc.lastRefresh) >= csc.expiration { + newCount := csc.inner.CountServers() + if newCount != csc.cachedCount { + klog.Infof("updated cached server count from %v to %v", csc.cachedCount, newCount) + } + csc.lastRefresh = time.Now() + } + + return csc.cachedCount +} + +// NewCachedServerCounter creates a new CachedServerCounter with a given expiration +// time wrapping the provided ServerCounter. +func NewCachedServerCounter(serverCounter ServerCounter, expiration time.Duration) *CachedServerCounter { + return &CachedServerCounter{ + inner: serverCounter, + cachedCount: serverCounter.CountServers(), + expiration: expiration, + lastRefresh: time.Now(), + } +} diff --git a/pkg/servercounter/cache_test.go b/pkg/servercounter/cache_test.go new file mode 100644 index 000000000..58b584675 --- /dev/null +++ b/pkg/servercounter/cache_test.go @@ -0,0 +1,52 @@ +package servercounter + +import ( + "testing" + "time" +) + +type MockServerCounter struct { + NumCalls int + Count int +} + +func (msc *MockServerCounter) CountServers() int { + msc.NumCalls += 1 + return msc.Count +} + +func TestCachedServerCounter(t *testing.T) { + initialCount := 3 + mockCounter := &MockServerCounter{NumCalls: 0, Count: initialCount} + + cacheExpiry := time.Second // This can be tuned down to make this test run faster, just don't make it too small! + cachedCounter := NewCachedServerCounter(mockCounter, cacheExpiry) + + if mockCounter.NumCalls != 1 { + t.Errorf("inner server counter have been called once during cached counter creation, got %v calls isntead", mockCounter.NumCalls) + } + + // Updates in the underlying ServerCounter should not matter until the cache expires. + callAttemptsWithoutRefresh := 5 + for i := 0; i < callAttemptsWithoutRefresh; i++ { + mockCounter.Count = 100 + i + time.Sleep(cacheExpiry / time.Duration(callAttemptsWithoutRefresh) / time.Duration(2)) + got := cachedCounter.CountServers() + if got != initialCount { + t.Errorf("server count should not have been updated yet; wanted: %v, got %v", initialCount, got) + } else if mockCounter.NumCalls != 1 { + t.Errorf("inner server counter should not have been called yet; expected 1 call, got %v instead", mockCounter.NumCalls) + } + } + + // Once the cache expires, the cached count should update by calling the underlying + // ServerCounter. + mockCounter.Count = 5 + time.Sleep(cacheExpiry) + got := cachedCounter.CountServers() + if got != initialCount { + t.Errorf("server count should have been updated yet; wanted: %v, got %v", mockCounter.Count, got) + } else if mockCounter.NumCalls != 2 { + t.Errorf("inner server counter should have been called one more time; expected 2 calls, got %v instead", mockCounter.NumCalls) + } +} diff --git a/pkg/servercounter/counter.go b/pkg/servercounter/counter.go new file mode 100644 index 000000000..b1a98593f --- /dev/null +++ b/pkg/servercounter/counter.go @@ -0,0 +1,14 @@ +package servercounter + +// A ServerCounter counts the number of available proxy servers. +type ServerCounter interface { + CountServers() int +} + +// A StaticServerCounter stores a static server count. +type StaticServerCounter int + +// CountServers returns the current (static) proxy server count. +func (sc StaticServerCounter) CountServers() int { + return int(sc) +}