From f6f29e35a2f3aecb78afed7343f857f7384dd825 Mon Sep 17 00:00:00 2001 From: sukun Date: Thu, 21 Sep 2023 15:02:45 +0530 Subject: [PATCH] webrtcprivate: fix deadline, limit inflight connection requests --- p2p/transport/webrtcprivate/listener.go | 55 ++- p2p/transport/webrtcprivate/transport.go | 99 ++-- p2p/transport/webrtcprivate/transport_test.go | 423 ++++++++++++++++-- 3 files changed, 483 insertions(+), 94 deletions(-) diff --git a/p2p/transport/webrtcprivate/listener.go b/p2p/transport/webrtcprivate/listener.go index 77f47eec61..26d5fbe6c1 100644 --- a/p2p/transport/webrtcprivate/listener.go +++ b/p2p/transport/webrtcprivate/listener.go @@ -18,10 +18,10 @@ import ( ) type listener struct { - t *transport - webrtcConfig webrtc.Configuration - conns chan tpt.CapableConn - closeC chan struct{} + transport *transport + connQueue chan tpt.CapableConn + closeC chan struct{} + inflightQueue chan struct{} } var _ tpt.Listener = &listener{} @@ -41,7 +41,7 @@ func (n NetAddr) String() string { // Accept implements transport.Listener. func (l *listener) Accept() (tpt.CapableConn, error) { select { - case c := <-l.conns: + case c := <-l.connQueue: return c, nil case <-l.closeC: return nil, tpt.ErrListenerClosed @@ -55,7 +55,7 @@ func (l *listener) Addr() net.Addr { // Close implements transport.Listener. 
func (l *listener) Close() error { - l.t.RemoveListener(l) + l.transport.RemoveListener(l) close(l.closeC) return nil } @@ -66,22 +66,28 @@ func (*listener) Multiaddr() ma.Multiaddr { } func (l *listener) handleIncoming(s network.Stream) { - ctx, cancel := context.WithTimeout(context.Background(), streamTimeout) + select { + case l.inflightQueue <- struct{}{}: + defer func() { <-l.inflightQueue }() + case <-l.closeC: + s.Reset() + return + } + + ctx, cancel := context.WithTimeout(context.Background(), connectTimeout) defer cancel() defer s.Close() - s.SetDeadline(time.Now().Add(streamTimeout)) - scope, err := l.t.rcmgr.OpenConnection(network.DirInbound, false, ma.StringCast("/webrtc")) + s.SetDeadline(time.Now().Add(connectTimeout)) + + scope, err := l.transport.rcmgr.OpenConnection(network.DirInbound, false, ma.StringCast("/webrtc")) if err != nil { s.Reset() log.Debug("failed to create connection scope:", err) return } - settings := webrtc.SettingEngine{} - settings.DetachDataChannels() - api := webrtc.NewAPI(webrtc.WithSettingEngine(settings)) - pc, err := api.NewPeerConnection(l.webrtcConfig) + pc, err := l.transport.NewPeerConnection() if err != nil { s.Reset() log.Debug("error creating a webrtc.PeerConnection:", err) @@ -209,7 +215,7 @@ func (l *listener) handleIncoming(s network.Stream) { readErr <- fmt.Errorf("invalid message: msg.Type expected %s got %s", pb.Message_ICE_CANDIDATE, msg.Type) return } - // Ignore without erroring on empty message. + // Ignore without erroring on empty message. 
// Pion has a case where OnCandidate callback may be called with a nil // candidate if msg.Data == nil || *msg.Data == "" { @@ -233,42 +239,45 @@ func (l *listener) handleIncoming(s network.Stream) { case <-ctx.Done(): pc.Close() s.Reset() - log.Error(ctx.Err()) + log.Debug(ctx.Err()) return case err := <-writeErr: pc.Close() s.Reset() - log.Error(err) + log.Debug(err) return case err := <-readErr: pc.Close() s.Reset() - log.Error(err) + log.Debug(err) return case state := <-connectionState: switch state { default: pc.Close() s.Reset() + log.Debugf("connection setup failed, got state: %s", state) return case webrtc.PeerConnectionStateConnected: conn, _ := libp2pwebrtc.NewWebRTCConnection( network.DirInbound, pc, - l.t, + l.transport, scope, - l.t.host.ID(), + l.transport.host.ID(), ma.StringCast("/webrtc"), s.Conn().RemotePeer(), - l.t.host.Peerstore().PubKey(s.Conn().RemotePeer()), + l.transport.host.Peerstore().PubKey(s.Conn().RemotePeer()), ma.StringCast("/webrtc"), ) + // Close the stream before we wait for the connection to be accepted + s.Close() select { - case l.conns <- conn: - default: + case l.connQueue <- conn: + case <-l.closeC: s.Reset() - log.Debug("incoming conn queue full: dropping conn from %s", s.Conn().RemotePeer()) conn.Close() + log.Debug("listener closed: dropping conn from %s", s.Conn().RemotePeer()) } return } diff --git a/p2p/transport/webrtcprivate/transport.go b/p2p/transport/webrtcprivate/transport.go index 897f8bb78f..68670eb78a 100644 --- a/p2p/transport/webrtcprivate/transport.go +++ b/p2p/transport/webrtcprivate/transport.go @@ -15,33 +15,43 @@ import ( logging "github.com/ipfs/go-log/v2" "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/network" + pionlogger "github.com/pion/logging" + "github.com/libp2p/go-libp2p/core/peer" tpt "github.com/libp2p/go-libp2p/core/transport" libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc" "github.com/libp2p/go-libp2p/p2p/transport/webrtcprivate/pb" 
"github.com/libp2p/go-msgio/pbio" "github.com/pion/webrtc/v3" + "go.uber.org/zap/zapcore" ma "github.com/multiformats/go-multiaddr" mafmt "github.com/multiformats/go-multiaddr-fmt" ) const ( - name = "webrtcprivate" - maxMsgSize = 4096 - streamTimeout = time.Minute - SignalingProtocol = "/webrtc-signaling" + name = "webrtcprivate" + maxMsgSize = 4096 + connectTimeout = time.Minute + SignalingProtocol = "/webrtc-signaling" + disconnectedTimeout = 20 * time.Second + failedTimeout = 30 * time.Second + keepaliveTimeout = 15 * time.Second ) -var log = logging.Logger("webrtcprivate") +var ( + log = logging.Logger("webrtcprivate") + WebRTCAddr = ma.StringCast("/webrtc") +) type transport struct { - host host.Host - rcmgr network.ResourceManager - webrtcConfig webrtc.Configuration + host host.Host + rcmgr network.ResourceManager + webrtcConfig webrtc.Configuration + maxInFlightConnections int - mu sync.Mutex - l *listener + mu sync.Mutex + listener *listener } var _ tpt.Transport = &transport{} @@ -93,9 +103,10 @@ func newTransport(h host.Host) (*transport, error) { } return &transport{ - host: h, - rcmgr: h.Network().ResourceManager(), - webrtcConfig: config, + host: h, + rcmgr: h.Network().ResourceManager(), + webrtcConfig: config, + maxInFlightConnections: 16, }, nil } @@ -113,7 +124,7 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp if err != nil { return nil, fmt.Errorf("failed to open %s stream: %w", SignalingProtocol, err) } - scope, err := t.rcmgr.OpenConnection(network.DirOutbound, false, raddr) + scope, err := t.rcmgr.OpenConnection(network.DirOutbound, true, raddr) if err != nil { log.Debugw("resource manager blocked outgoing connection", "peer", p, "addr", raddr, "error", err) return nil, err @@ -147,12 +158,16 @@ func (t *transport) dialWithScope(ctx context.Context, p peer.ID, scope network. 
defer s.Scope().ReleaseMemory(maxMsgSize) defer s.Close() - s.SetDeadline(time.Now().Add(streamTimeout)) + deadline := time.Now().Add(connectTimeout) + if d, ok := ctx.Deadline(); ok && d.Before(deadline) { + deadline = d + } + s.SetDeadline(deadline) - pc, err := t.connect(ctx, s) + pc, err := t.establishPeerConnection(ctx, s) if err != nil { s.Reset() - return nil, fmt.Errorf("error creating webrtc.PeerConnection: %w", err) + return nil, fmt.Errorf("error establishing webrtc.PeerConnection: %w", err) } return libp2pwebrtc.NewWebRTCConnection( network.DirOutbound, @@ -167,15 +182,11 @@ func (t *transport) dialWithScope(ctx context.Context, p peer.ID, scope network. ) } -func (t *transport) connect(ctx context.Context, s network.Stream) (*webrtc.PeerConnection, error) { - settings := webrtc.SettingEngine{} - settings.DetachDataChannels() - api := webrtc.NewAPI(webrtc.WithSettingEngine(settings)) - pc, err := api.NewPeerConnection(t.webrtcConfig) +func (t *transport) establishPeerConnection(ctx context.Context, s network.Stream) (*webrtc.PeerConnection, error) { + pc, err := t.NewPeerConnection() if err != nil { - return nil, fmt.Errorf("error creating peer connection: %w", err) + return nil, fmt.Errorf("failed to create webrtc.PeerConnection: %w", err) } - // Exchange offer and answer with peer r := pbio.NewDelimitedReader(s, maxMsgSize) w := pbio.NewDelimitedWriter(s) @@ -275,7 +286,7 @@ func (t *transport) connect(ctx context.Context, s network.Stream) (*webrtc.Peer } readErr := make(chan error, 1) - ctx, cancel := context.WithTimeout(ctx, streamTimeout) + ctx, cancel := context.WithTimeout(ctx, connectTimeout) defer cancel() // start a goroutine to read candidates go func() { @@ -342,17 +353,17 @@ func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) { } t.mu.Lock() defer t.mu.Unlock() - if t.l != nil { + if t.listener != nil { return nil, errors.New("already listening on /webrtc") } l := &listener{ - t: t, - webrtcConfig: t.webrtcConfig, - conns: 
make(chan tpt.CapableConn, 8), - closeC: make(chan struct{}), + transport: t, + connQueue: make(chan tpt.CapableConn), + inflightQueue: make(chan struct{}, t.maxInFlightConnections), + closeC: make(chan struct{}), } - t.l = l + t.listener = l t.host.SetStreamHandler(SignalingProtocol, l.handleIncoming) return l, nil } @@ -360,8 +371,8 @@ func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) { func (t *transport) RemoveListener(l *listener) { t.mu.Lock() defer t.mu.Unlock() - if t.l == l { - t.l = nil + if t.listener == l { + t.listener = nil t.host.RemoveStreamHandler(SignalingProtocol) } } @@ -376,6 +387,28 @@ func (*transport) Proxy() bool { return false } +func (t *transport) NewPeerConnection() (*webrtc.PeerConnection, error) { + loggerFactory := pionlogger.NewDefaultLoggerFactory() + logLevel := pionlogger.LogLevelDisabled + switch log.Level() { + case zapcore.DebugLevel: + logLevel = pionlogger.LogLevelDebug + case zapcore.InfoLevel: + logLevel = pionlogger.LogLevelInfo + case zapcore.WarnLevel: + logLevel = pionlogger.LogLevelWarn + case zapcore.ErrorLevel: + logLevel = pionlogger.LogLevelError + } + loggerFactory.DefaultLogLevel = logLevel + s := webrtc.SettingEngine{LoggerFactory: loggerFactory} + s.SetICETimeouts(disconnectedTimeout, failedTimeout, keepaliveTimeout) + s.DetachDataChannels() + s.SetIncludeLoopbackCandidate(true) + api := webrtc.NewAPI(webrtc.WithSettingEngine(s)) + return api.NewPeerConnection(t.webrtcConfig) +} + // getRelayAddr removes /webrtc from addr and returns a circuit v2 only address func getRelayAddr(addr ma.Multiaddr) ma.Multiaddr { first, rest := ma.SplitFunc(addr, func(c ma.Component) bool { diff --git a/p2p/transport/webrtcprivate/transport_test.go b/p2p/transport/webrtcprivate/transport_test.go index ccd9fcde1c..efdf491f0e 100644 --- a/p2p/transport/webrtcprivate/transport_test.go +++ b/p2p/transport/webrtcprivate/transport_test.go @@ -3,7 +3,9 @@ package libp2pwebrtcprivate import ( "context" "fmt" + "os" 
"sync" + "sync/atomic" "testing" "time" @@ -54,7 +56,9 @@ func newWebRTCHost(t *testing.T) *webrtcHost { func newRelayedHost(t *testing.T) *relayedHost { rh := blankhost.NewBlankHost(swarmt.GenSwarm(t)) - _, err := relay.New(rh) + rr := relay.DefaultResources() + rr.MaxCircuits = 100 + _, err := relay.New(rh, relay.WithResources(rr)) require.NoError(t, err) ps := swarmt.GenSwarm(t) @@ -100,6 +104,9 @@ func TestSingleDial(t *testing.T) { n, err := sb.Read(recv) require.NoError(t, err) require.Equal(t, "hello world", string(recv[:n])) + + ca.Close() + cb.Close() } func TestMultipleDials(t *testing.T) { @@ -113,25 +120,39 @@ func TestMultipleDials(t *testing.T) { defer b.Close() l, err := b.T.Listen(ma.StringCast("/webrtc")) - assert.NoError(t, err) + if !assert.NoError(t, err) { + return + } ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() ca, err := a.T.Dial(ctx, b.Addr, b.ID()) - assert.NoError(t, err) + if !assert.NoError(t, err) { + return + } cb, err := l.Accept() - assert.NoError(t, err) + if !assert.NoError(t, err) { + return + } sa, err := ca.OpenStream(ctx) - assert.NoError(t, err) + if !assert.NoError(t, err) { + return + } sb, err := cb.AcceptStream() - assert.NoError(t, err) + if !assert.NoError(t, err) { + return + } sa.Write([]byte("hello world")) recv := make([]byte, 24) n, err := sb.Read(recv) - assert.NoError(t, err) - assert.Equal(t, "hello world", string(recv[:n])) + if !assert.NoError(t, err) { + return + } + if !assert.Equal(t, "hello world", string(recv[:n])) { + return + } wg.Done() }() } @@ -139,42 +160,368 @@ func TestMultipleDials(t *testing.T) { } func TestMultipleDialsAndListeners(t *testing.T) { - var hosts []*webrtcHost - for i := 0; i < 5; i++ { - hosts = append(hosts, newWebRTCHost(t)) + var dialHosts []*webrtcHost + const N = 5 + for i := 0; i < N; i++ { + dialHosts = append(dialHosts, newWebRTCHost(t)) + defer dialHosts[i].Close() + } + + var listenHosts []*relayedHost + for i := 0; i < N; i++ { 
+ listenHosts = append(listenHosts, newRelayedHost(t)) + l, err := listenHosts[i].T.Listen(ma.StringCast("/webrtc")) + require.NoError(t, err) + defer listenHosts[i].Close() + defer l.Close() } var wg sync.WaitGroup - for i := 0; i < 5; i++ { - for j := 0; j < 5; j++ { + dialAndPing := func(h *webrtcHost, raddr ma.Multiaddr, p peer.ID) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + ca, err := h.T.Dial(ctx, raddr, p) + if !assert.NoError(t, err) { + return + } + defer ca.Close() + sa, err := ca.OpenStream(ctx) + if !assert.NoError(t, err) { + return + } + defer sa.Close() + sa.Write([]byte("hello world")) + recv := make([]byte, 24) + n, err := sa.Read(recv) + if !assert.NoError(t, err) { + return + } + if !assert.Equal(t, "hello world", string(recv[:n])) { + return + } + } + + acceptAndPong := func(r *relayedHost) { + cb, err := r.T.listener.Accept() + if !assert.NoError(t, err) { + return + } + + sb, err := cb.AcceptStream() + if !assert.NoError(t, err) { + return + } + defer sb.Close() + + recv := make([]byte, 24) + n, err := sb.Read(recv) + if !assert.NoError(t, err) { + return + } + if !assert.Equal(t, "hello world", string(recv[:n])) { + return + } + sb.Write(recv[:n]) + } + + for i := 0; i < N; i++ { + for j := 0; j < N; j++ { wg.Add(1) - go func(j int) { - b := newRelayedHost(t) - defer b.Close() - - l, err := b.T.Listen(ma.StringCast("/webrtc")) - assert.NoError(t, err) - - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - ca, err := hosts[j].T.Dial(ctx, b.Addr, b.ID()) - assert.NoError(t, err) - - cb, err := l.Accept() - assert.NoError(t, err) - - sa, err := ca.OpenStream(ctx) - assert.NoError(t, err) - sb, err := cb.AcceptStream() - assert.NoError(t, err) - sa.Write([]byte("hello world")) - recv := make([]byte, 24) - n, err := sb.Read(recv) - assert.NoError(t, err) - assert.Equal(t, "hello world", string(recv[:n])) + go func(i, j int) { + go dialAndPing(dialHosts[i], 
listenHosts[j].Addr, listenHosts[j].ID()) + acceptAndPong(listenHosts[j]) wg.Done() - }(j) + }(i, j) } } wg.Wait() } + +func TestDialerCanCreateStreams(t *testing.T) { + a := newWebRTCHost(t) + b := newRelayedHost(t) + listener, err := b.T.Listen(ma.StringCast("/webrtc")) + require.NoError(t, err) + + aC := make(chan bool) + go func() { + defer close(aC) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + conn, err := a.T.Dial(ctx, b.Addr, b.ID()) + if !assert.NoError(t, err) { + return + } + s, err := conn.AcceptStream() + if !assert.NoError(t, err) { + return + } + recv := make([]byte, 24) + n, err := s.Read(recv) + if !assert.NoError(t, err) { + return + } + _, err = s.Write(recv[:n]) + if !assert.NoError(t, err) { + return + } + s.Close() + }() + + bC := make(chan bool) + go func() { + defer close(bC) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + conn, err := listener.Accept() + if !assert.NoError(t, err) { + return + } + s, err := conn.OpenStream(ctx) + if !assert.NoError(t, err) { + return + } + defer s.Close() + + _, err = s.Write([]byte("hello world")) + if !assert.NoError(t, err) { + return + } + + recv := make([]byte, 24) + n, err := s.Read(recv) + if !assert.NoError(t, err) { + return + } + if !assert.Equal(t, "hello world", string(recv[:n])) { + return + } + }() + + select { + case <-aC: + case <-time.After(10 * time.Second): + t.Fatal("timeout") + } + select { + case <-bC: + case <-time.After(10 * time.Second): + t.Fatal("timeout") + } +} + +func TestDialerCanCreateStreamsMultiple(t *testing.T) { + count := 5 + a := newWebRTCHost(t) + b := newRelayedHost(t) + listener, err := b.T.Listen(WebRTCAddr) + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + lconn, err := listener.Accept() + if !assert.NoError(t, err) { + return + } + if !assert.Equal(t, a.ID(), lconn.RemotePeer()) { + return + } + var wg sync.WaitGroup + + for i := 0; i < count; i++ 
{ + stream, err := lconn.AcceptStream() + if !assert.NoError(t, err) { + return + } + wg.Add(1) + go func() { + defer wg.Done() + buf := make([]byte, 100) + n, err := stream.Read(buf) + if !assert.NoError(t, err) { + return + } + if !assert.Equal(t, "test", string(buf[:n])) { + return + } + _, err = stream.Write([]byte("test")) + if !assert.NoError(t, err) { + return + } + }() + } + + wg.Wait() + done <- struct{}{} + }() + + conn, err := a.T.Dial(context.Background(), b.Addr, b.ID()) + require.NoError(t, err) + + for i := 0; i < count; i++ { + idx := i + go func() { + stream, err := conn.OpenStream(context.Background()) + if !assert.NoError(t, err) { + return + } + t.Logf("dialer opened stream: %d", idx) + buf := make([]byte, 100) + _, err = stream.Write([]byte("test")) + if !assert.NoError(t, err) { + return + } + n, err := stream.Read(buf) + if !assert.NoError(t, err) { + return + } + if !assert.Equal(t, "test", string(buf[:n])) { + return + } + }() + } + select { + case <-done: + case <-time.After(20 * time.Second): + t.Fatal("timed out") + } +} + +func TestMaxInflightQueue(t *testing.T) { + b := newRelayedHost(t) + defer b.Close() + count := 3 + b.T.maxInFlightConnections = count + listener, err := b.T.Listen(WebRTCAddr) + require.NoError(t, err) + defer listener.Close() + + var success, failure atomic.Int32 + var wg sync.WaitGroup + for i := 0; i < count+1; i++ { + wg.Add(1) + go func() { + a := newWebRTCHost(t) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _, err := a.T.Dial(ctx, b.Addr, b.ID()) + if err == nil { + success.Add(1) + } else { + failure.Add(1) + } + wg.Done() + }() + } + wg.Wait() + require.Equal(t, 1, int(failure.Load())) + require.Equal(t, count, int(success.Load())) +} + +func TestRemoteReadsAfterClose(t *testing.T) { + b := newRelayedHost(t) + listener, err := b.T.Listen(WebRTCAddr) + require.NoError(t, err) + + a := newWebRTCHost(t) + + done := make(chan error) + go func() { + lconn, err := 
listener.Accept() + if err != nil { + done <- err + return + } + stream, err := lconn.AcceptStream() + if err != nil { + done <- err + return + } + _, err = stream.Write([]byte{1, 2, 3, 4}) + if err != nil { + done <- err + return + } + err = stream.Close() + if err != nil { + done <- err + return + } + close(done) + }() + + conn, err := a.T.Dial(context.Background(), b.Addr, b.ID()) + require.NoError(t, err) + // create a stream + stream, err := conn.OpenStream(context.Background()) + + require.NoError(t, err) + // require write and close to complete + require.NoError(t, <-done) + + stream.SetReadDeadline(time.Now().Add(5 * time.Second)) + + buf := make([]byte, 10) + n, err := stream.Read(buf) + require.NoError(t, err) + require.Equal(t, n, 4) +} + +func TestStreamDeadline(t *testing.T) { + b := newRelayedHost(t) + listener, err := b.T.Listen(WebRTCAddr) + require.NoError(t, err) + a := newWebRTCHost(t) + + t.Run("SetReadDeadline", func(t *testing.T) { + go func() { + lconn, err := listener.Accept() + if !assert.NoError(t, err) { + return + } + _, err = lconn.AcceptStream() + if !assert.NoError(t, err) { + return + } + }() + + conn, err := a.T.Dial(context.Background(), b.Addr, b.ID()) + require.NoError(t, err) + stream, err := conn.OpenStream(context.Background()) + require.NoError(t, err) + + // deadline set to the past + stream.SetReadDeadline(time.Now().Add(-200 * time.Millisecond)) + _, err = stream.Read([]byte{0, 0}) + require.ErrorIs(t, err, os.ErrDeadlineExceeded) + + // future deadline exceeded + stream.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) + _, err = stream.Read([]byte{0, 0}) + require.ErrorIs(t, err, os.ErrDeadlineExceeded) + }) + + t.Run("SetWriteDeadline", func(t *testing.T) { + go func() { + lconn, err := listener.Accept() + if !assert.NoError(t, err) { + return + } + _, err = lconn.AcceptStream() + if !assert.NoError(t, err) { + return + } + }() + + conn, err := a.T.Dial(context.Background(), b.Addr, b.ID()) + require.NoError(t, 
err) + stream, err := conn.OpenStream(context.Background()) + require.NoError(t, err) + + stream.SetWriteDeadline(time.Now().Add(200 * time.Millisecond)) + time.Sleep(201 * time.Millisecond) + largeBuffer := make([]byte, 2*1024*1024) + _, err = stream.Write(largeBuffer) + require.ErrorIs(t, err, os.ErrDeadlineExceeded) + }) +}