diff --git a/p2p/transport/webrtc/stream.go b/p2p/transport/webrtc/stream.go index 7e873f5634..f777f5b8fd 100644 --- a/p2p/transport/webrtc/stream.go +++ b/p2p/transport/webrtc/stream.go @@ -123,7 +123,23 @@ func newStream( } } - s.maybeDeclareStreamDone() + if s.isDone() { + // onDone removes the stream from the connection and requires the connection lock. + // This callback(onBufferedAmountLow) is executing in the sctp readLoop goroutine. + // If the connection is closed concurrently, the closing goroutine will acquire + // the connection lock and wait for sctp readLoop to exit, the sctp readLoop will + // wait for the connection lock before exiting, causing a deadlock. + // Run this in a different goroutine to avoid the deadlock. + go func() { + s.mx.Lock() + defer s.mx.Unlock() + // TODO: we should be closing the underlying datachannel, but this resets the stream + // See https://github.com/libp2p/specs/issues/575 for details. + // _ = s.dataChannel.Close() + // TODO: write for the spawned reader to return + s.onDone() + }() + } select { case s.writeAvailable <- struct{}{}: @@ -188,9 +204,7 @@ func (s *stream) processIncomingFlag(flag *pb.Message_Flag) { // this is used to force reset a stream func (s *stream) maybeDeclareStreamDone() { - if (s.sendState == sendStateReset || s.sendState == sendStateDataSent) && - (s.receiveState == receiveStateReset || s.receiveState == receiveStateDataRead) && - len(s.controlMsgQueue) == 0 { + if s.isDone() { _ = s.SetReadDeadline(time.Now().Add(-1 * time.Hour)) // pion ignores zero times // TODO: we should be closing the underlying datachannel, but this resets the stream // See https://github.com/libp2p/specs/issues/575 for details. @@ -200,6 +214,12 @@ func (s *stream) maybeDeclareStreamDone() { } } +func (s *stream) isDone() bool { + return (s.sendState == sendStateReset || s.sendState == sendStateDataSent) && + (s.receiveState == receiveStateReset || s.receiveState == receiveStateDataRead) && + len(s.controlMsgQueue) == 0 +} + func (s *stream) setCloseError(e error) { s.mx.Lock() defer s.mx.Unlock() diff --git a/p2p/transport/webrtc/stream_write.go b/p2p/transport/webrtc/stream_write.go index c7eb3bf7a4..698af9c4d6 100644 --- a/p2p/transport/webrtc/stream_write.go +++ b/p2p/transport/webrtc/stream_write.go @@ -35,6 +35,10 @@ func (s *stream) Write(b []byte) (int, error) { s.readLoopOnce.Do(s.spawnControlMessageReader) } + if !s.writeDeadline.IsZero() && time.Now().After(s.writeDeadline) { + return 0, os.ErrDeadlineExceeded + } + var writeDeadlineTimer *time.Timer defer func() { if writeDeadlineTimer != nil { diff --git a/p2p/transport/webrtc/transport_test.go b/p2p/transport/webrtc/transport_test.go index 19d39ba282..7f4df94fc1 100644 --- a/p2p/transport/webrtc/transport_test.go +++ b/p2p/transport/webrtc/transport_test.go @@ -402,10 +402,16 @@ func TestTransportWebRTC_Deadline(t *testing.T) { stream, err := conn.OpenStream(context.Background()) require.NoError(t, err) - stream.SetWriteDeadline(time.Now().Add(200 * time.Millisecond)) + stream.SetWriteDeadline(time.Now().Add(100 * time.Millisecond)) largeBuffer := make([]byte, 2*1024*1024) _, err = stream.Write(largeBuffer) require.ErrorIs(t, err, os.ErrDeadlineExceeded) + + stream.SetWriteDeadline(time.Now().Add(-200 * time.Millisecond)) + smallBuffer := make([]byte, 1024) + _, err = stream.Write(smallBuffer) + require.ErrorIs(t, err, os.ErrDeadlineExceeded) + }) }