diff --git a/p2p/transport/tcp/metrics.go b/p2p/transport/tcp/metrics.go index 213ee2200a..50820d870c 100644 --- a/p2p/transport/tcp/metrics.go +++ b/p2p/transport/tcp/metrics.go @@ -24,7 +24,7 @@ var ( const collectFrequency = 10 * time.Second -var collector *aggregatingCollector +var defaultCollector *aggregatingCollector var initMetricsOnce sync.Once @@ -34,8 +34,8 @@ func initMetrics() { bytesSentDesc = prometheus.NewDesc("tcp_sent_bytes", "TCP bytes sent", nil, nil) bytesRcvdDesc = prometheus.NewDesc("tcp_rcvd_bytes", "TCP bytes received", nil, nil) - collector = newAggregatingCollector() - prometheus.MustRegister(collector) + defaultCollector = newAggregatingCollector() + prometheus.MustRegister(defaultCollector) const direction = "direction" @@ -196,7 +196,7 @@ func (c *aggregatingCollector) Collect(metrics chan<- prometheus.Metric) { func (c *aggregatingCollector) ClosedConn(conn *tracingConn, direction string) { c.mutex.Lock() - collector.removeConn(conn.id) + c.removeConn(conn.id) c.mutex.Unlock() closedConns.WithLabelValues(direction).Inc() } @@ -204,6 +204,8 @@ func (c *aggregatingCollector) ClosedConn(conn *tracingConn, direction string) { type tracingConn struct { id uint64 + collector *aggregatingCollector + startTime time.Time isClient bool @@ -213,7 +215,8 @@ type tracingConn struct { closeErr error } -func newTracingConn(c manet.Conn, isClient bool) (*tracingConn, error) { +// newTracingConn wraps a manet.Conn with a tracingConn. A nil collector will use the default collector. +func newTracingConn(c manet.Conn, collector *aggregatingCollector, isClient bool) (*tracingConn, error) { initMetricsOnce.Do(func() { initMetrics() }) conn, err := tcp.NewConn(c) if err != nil { @@ -224,8 +227,12 @@ func newTracingConn(c manet.Conn, isClient bool) (*tracingConn, error) { isClient: isClient, Conn: c, tcpConn: conn, + collector: collector, + } + if tc.collector == nil { + tc.collector = defaultCollector } - tc.id = collector.AddConn(tc) + tc.id = tc.collector.AddConn(tc) newConns.WithLabelValues(tc.getDirection()).Inc() return tc, nil } @@ -239,7 +246,7 @@ func (c *tracingConn) getDirection() string { func (c *tracingConn) Close() error { c.closeOnce.Do(func() { - collector.ClosedConn(c, c.getDirection()) + c.collector.ClosedConn(c, c.getDirection()) c.closeErr = c.Conn.Close() }) return c.closeErr @@ -258,10 +265,12 @@ func (c *tracingConn) getTCPInfo() (*tcpinfo.Info, error) { type tracingListener struct { manet.Listener + collector *aggregatingCollector } -func newTracingListener(l manet.Listener) *tracingListener { - return &tracingListener{Listener: l} +// newTracingListener wraps a manet.Listener with a tracingListener. A nil collector will use the default collector. +func newTracingListener(l manet.Listener, collector *aggregatingCollector) *tracingListener { + return &tracingListener{Listener: l, collector: collector} } func (l *tracingListener) Accept() (manet.Conn, error) { @@ -269,5 +278,5 @@ func (l *tracingListener) Accept() (manet.Conn, error) { if err != nil { return nil, err } - return newTracingConn(conn, false) + return newTracingConn(conn, l.collector, false) } diff --git a/p2p/transport/tcp/metrics_none.go b/p2p/transport/tcp/metrics_none.go index 8538b30c89..cbee982070 100644 --- a/p2p/transport/tcp/metrics_none.go +++ b/p2p/transport/tcp/metrics_none.go @@ -6,5 +6,9 @@ package tcp import manet "github.com/multiformats/go-multiaddr/net" -func newTracingConn(c manet.Conn, _ bool) (manet.Conn, error) { return c, nil } -func newTracingListener(l manet.Listener) manet.Listener { return l } +type aggregatingCollector struct{} + +func newTracingConn(c manet.Conn, collector *aggregatingCollector, isClient bool) (manet.Conn, error) { + return c, nil +} +func newTracingListener(l manet.Listener, collector *aggregatingCollector) manet.Listener { return l } diff --git a/p2p/transport/tcp/tcp.go b/p2p/transport/tcp/tcp.go index d52bb96019..a6f56be9ff 100644 --- a/p2p/transport/tcp/tcp.go +++ b/p2p/transport/tcp/tcp.go @@ -128,6 +128,8 @@ type TcpTransport struct { rcmgr network.ResourceManager reuse reuseport.Transport + + metricsCollector *aggregatingCollector } var _ transport.Transport = &TcpTransport{} @@ -212,7 +214,7 @@ func (t *TcpTransport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p c := conn if t.enableMetrics { var err error - c, err = newTracingConn(conn, true) + c, err = newTracingConn(conn, t.metricsCollector, true) if err != nil { return nil, err } @@ -250,7 +252,7 @@ func (t *TcpTransport) Listen(laddr ma.Multiaddr) (transport.Listener, error) { return nil, err } if t.enableMetrics { - list = newTracingListener(&tcpListener{list, 0}) + list = newTracingListener(&tcpListener{list, 0}, t.metricsCollector) } return t.upgrader.UpgradeListener(t, list), nil }