diff --git a/core/network/conn.go b/core/network/conn.go index 3be8cb0d69..06441558b7 100644 --- a/core/network/conn.go +++ b/core/network/conn.go @@ -2,6 +2,7 @@ package network import ( "context" + "fmt" "io" ic "github.com/libp2p/go-libp2p/core/crypto" @@ -11,6 +12,17 @@ import ( ma "github.com/multiformats/go-multiaddr" ) +type ConnErrorCode uint32 + +type ConnError struct { + Remote bool + ErrorCode ConnErrorCode +} + +func (c *ConnError) Error() string { + return fmt.Sprintf("connection closed: code: %d", c.ErrorCode) +} + // Conn is a connection to a remote peer. It multiplexes streams. // Usually there is no need to use a Conn directly, but it may // be useful to get information about the peer on the other side: @@ -24,6 +36,11 @@ type Conn interface { ConnStat ConnScoper + // CloseWithError closes the connection with errCode. The errCode is sent to the + // peer on a best effort basis. For transports that do not support sending error + // codes on connection close, the behavior is identical to calling Close. + CloseWithError(errCode ConnErrorCode) error + // ID returns an identifier that uniquely identifies this Conn within this // host, during this run. Connection IDs may repeat across restarts. ID() string diff --git a/core/network/mux.go b/core/network/mux.go index d12e2ea34b..4f584bd591 100644 --- a/core/network/mux.go +++ b/core/network/mux.go @@ -3,6 +3,7 @@ package network import ( "context" "errors" + "fmt" "io" "net" "time" @@ -11,6 +12,21 @@ import ( // ErrReset is returned when reading or writing on a reset stream. var ErrReset = errors.New("stream reset") +type StreamErrorCode uint32 + +type StreamError struct { + ErrorCode StreamErrorCode + Remote bool +} + +func (s *StreamError) Error() string { + return fmt.Sprintf("stream reset: code: %d", s.ErrorCode) +} + +func (s *StreamError) Is(target error) bool { + return target == ErrReset +} + // MuxedStream is a bidirectional io pipe within a connection. type MuxedStream interface { io.Reader @@ -61,6 +77,13 @@ type MuxedStream interface { SetWriteDeadline(time.Time) error } +type ResetWithErrorer interface { + // ResetWithError closes both ends of the stream with errCode. The errCode is sent + // to the peer on a best effort basis. For transports that do not support sending + // error codes to remote peer, the behavior is identical to calling Reset + ResetWithError(errCode StreamErrorCode) error +} + // MuxedConn represents a connection to a remote peer that has been // extended to support stream multiplexing. // @@ -86,6 +109,12 @@ type MuxedConn interface { AcceptStream() (MuxedStream, error) } +type CloseWithErrorer interface { + // CloseWithError closes the connection with errCode. The errCode is sent + // to the peer. + CloseWithError(errCode ConnErrorCode) error +} + // Multiplexer wraps a net.Conn with a stream multiplexing // implementation and returns a MuxedConn that supports opening // multiple streams over the underlying net.Conn diff --git a/core/network/stream.go b/core/network/stream.go index 62e230034c..f2b6cbcb88 100644 --- a/core/network/stream.go +++ b/core/network/stream.go @@ -27,4 +27,8 @@ type Stream interface { // Scope returns the user's view of this stream's resource scope Scope() StreamScope + + // ResetWithError closes both ends of the stream with errCode. The errCode is sent + // to the peer. + ResetWithError(errCode StreamErrorCode) error } diff --git a/go.mod b/go.mod index c6f3f2b324..596d8ec7eb 100644 --- a/go.mod +++ b/go.mod @@ -32,7 +32,7 @@ require ( github.com/libp2p/go-nat v0.2.0 github.com/libp2p/go-netroute v0.2.1 github.com/libp2p/go-reuseport v0.4.0 - github.com/libp2p/go-yamux/v4 v4.0.1 + github.com/libp2p/go-yamux/v4 v4.0.2-0.20240828193053-e17eaa82d8a7 github.com/libp2p/zeroconf/v2 v2.2.0 github.com/marten-seemann/tcp v0.0.0-20210406111302-dfbc87cc63fd github.com/mikioh/tcpinfo v0.0.0-20190314235526-30a79bb1804b diff --git a/go.sum b/go.sum index 4650eef1f0..d1cd4a5c47 100644 --- a/go.sum +++ b/go.sum @@ -194,8 +194,8 @@ github.com/libp2p/go-netroute v0.2.1 h1:V8kVrpD8GK0Riv15/7VN6RbUQ3URNZVosw7H2v9t github.com/libp2p/go-netroute v0.2.1/go.mod h1:hraioZr0fhBjG0ZRXJJ6Zj2IVEVNx6tDTFQfSmcq7mQ= github.com/libp2p/go-reuseport v0.4.0 h1:nR5KU7hD0WxXCJbmw7r2rhRYruNRl2koHw8fQscQm2s= github.com/libp2p/go-reuseport v0.4.0/go.mod h1:ZtI03j/wO5hZVDFo2jKywN6bYKWLOy8Se6DrI2E1cLU= -github.com/libp2p/go-yamux/v4 v4.0.1 h1:FfDR4S1wj6Bw2Pqbc8Uz7pCxeRBPbwsBbEdfwiCypkQ= -github.com/libp2p/go-yamux/v4 v4.0.1/go.mod h1:NWjl8ZTLOGlozrXSOZ/HlfG++39iKNnM5wwmtQP1YB4= +github.com/libp2p/go-yamux/v4 v4.0.2-0.20240828193053-e17eaa82d8a7 h1:9DQhrYNrteSCiE8EZC1Na9AZNothvTF+NQtbnOjbxzo= +github.com/libp2p/go-yamux/v4 v4.0.2-0.20240828193053-e17eaa82d8a7/go.mod h1:PGP+3py2ZWDKABvqstBZtMnixEHNC7U/odnGylzur5o= github.com/libp2p/zeroconf/v2 v2.2.0 h1:Cup06Jv6u81HLhIj1KasuNM/RHHrJ8T7wOTS4+Tv53Q= github.com/libp2p/zeroconf/v2 v2.2.0/go.mod h1:fuJqLnUwZTshS3U/bMRJ3+ow/v9oid1n0DmyYyNO1Xs= github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI= diff --git a/p2p/muxer/yamux/conn.go b/p2p/muxer/yamux/conn.go index 40c4af4052..4531771842 100644 --- a/p2p/muxer/yamux/conn.go +++ b/p2p/muxer/yamux/conn.go @@ -23,6 +23,10 @@ func (c *conn) Close() error { return c.yamux().Close() } +func (c *conn) CloseWithError(errCode network.ConnErrorCode) error { + return c.yamux().CloseWithError(uint32(errCode)) +} + // IsClosed checks if yamux.Session is in closed state. func (c *conn) IsClosed() bool { return c.yamux().IsClosed() @@ -32,7 +36,7 @@ func (c *conn) IsClosed() bool { func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) { s, err := c.yamux().OpenStream(ctx) if err != nil { - return nil, err + return nil, parseResetError(err) } return (*stream)(s), nil @@ -41,7 +45,7 @@ func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) { // AcceptStream accepts a stream opened by the other side. func (c *conn) AcceptStream() (network.MuxedStream, error) { s, err := c.yamux().AcceptStream() - return (*stream)(s), err + return (*stream)(s), parseResetError(err) } func (c *conn) yamux() *yamux.Session { diff --git a/p2p/muxer/yamux/stream.go b/p2p/muxer/yamux/stream.go index b50bc0bb87..b588c7c2b8 100644 --- a/p2p/muxer/yamux/stream.go +++ b/p2p/muxer/yamux/stream.go @@ -1,6 +1,7 @@ package yamux import ( + "errors" "time" "github.com/libp2p/go-libp2p/core/network" @@ -13,22 +14,29 @@ type stream yamux.Stream var _ network.MuxedStream = &stream{} -func (s *stream) Read(b []byte) (n int, err error) { - n, err = s.yamux().Read(b) - if err == yamux.ErrStreamReset { - err = network.ErrReset +func parseResetError(err error) error { + if err == nil { + return err + } + se := &yamux.StreamError{} + if errors.As(err, &se) { + return &network.StreamError{Remote: se.Remote, ErrorCode: network.StreamErrorCode(se.ErrorCode)} } + ce := &yamux.GoAwayError{} + if errors.As(err, &ce) { + return &network.ConnError{Remote: ce.Remote, ErrorCode: network.ConnErrorCode(ce.ErrorCode)} + } + return err +} - return n, err +func (s *stream) Read(b []byte) (n int, err error) { + n, err = s.yamux().Read(b) + return n, parseResetError(err) } func (s *stream) Write(b []byte) (n int, err error) { n, err = s.yamux().Write(b) - if err == yamux.ErrStreamReset { - err = network.ErrReset - } - - return n, err + return n, parseResetError(err) } func (s *stream) Close() error { @@ -39,6 +47,10 @@ func (s *stream) Reset() error { return s.yamux().Reset() } +func (s *stream) ResetWithError(errCode network.StreamErrorCode) error { + return s.yamux().ResetWithError(uint32(errCode)) +} + func (s *stream) CloseRead() error { return s.yamux().CloseRead() } diff --git a/p2p/net/connmgr/connmgr_test.go b/p2p/net/connmgr/connmgr_test.go index 2c657255f0..5955265f9b 100644 --- a/p2p/net/connmgr/connmgr_test.go +++ b/p2p/net/connmgr/connmgr_test.go @@ -794,6 +794,7 @@ type mockConn struct { } func (m mockConn) Close() error { panic("implement me") } +func (m mockConn) CloseWithError(errCode network.ConnErrorCode) error { panic("implement me") } func (m mockConn) LocalPeer() peer.ID { panic("implement me") } func (m mockConn) RemotePeer() peer.ID { panic("implement me") } func (m mockConn) RemotePublicKey() crypto.PubKey { panic("implement me") } diff --git a/p2p/net/mock/mock_stream.go b/p2p/net/mock/mock_stream.go index c85cca544d..3ba29ddd80 100644 --- a/p2p/net/mock/mock_stream.go +++ b/p2p/net/mock/mock_stream.go @@ -144,6 +144,10 @@ func (s *stream) Reset() error { return nil } +func (s *stream) ResetWithError(errCode network.StreamErrorCode) error { + panic("not implemented") +} + func (s *stream) teardown() { // at this point, no streams are writing. s.conn.removeStream(s) diff --git a/p2p/net/swarm/swarm.go b/p2p/net/swarm/swarm.go index ef1fc2a2b3..12e83d38ef 100644 --- a/p2p/net/swarm/swarm.go +++ b/p2p/net/swarm/swarm.go @@ -840,6 +840,14 @@ func (c connWithMetrics) Close() error { return c.CapableConn.Close() } +func (c connWithMetrics) CloseWithError(errCode network.ConnErrorCode) error { + c.metricsTracer.ClosedConnection(c.dir, time.Since(c.opened), c.ConnState(), c.LocalMultiaddr()) + if ce, ok := c.CapableConn.(network.CloseWithErrorer); ok { + return ce.CloseWithError(errCode) + } + return c.CapableConn.Close() +} + func (c connWithMetrics) Stat() network.ConnStats { if cs, ok := c.CapableConn.(network.ConnStat); ok { return cs.Stat() diff --git a/p2p/net/swarm/swarm_conn.go b/p2p/net/swarm/swarm_conn.go index 5fd41c8d9f..b7cc46fb71 100644 --- a/p2p/net/swarm/swarm_conn.go +++ b/p2p/net/swarm/swarm_conn.go @@ -58,11 +58,20 @@ func (c *Conn) ID() string { // open notifications must finish before we can fire off the close // notifications). func (c *Conn) Close() error { - c.closeOnce.Do(c.doClose) + c.closeOnce.Do(func() { + c.doClose(0) + }) return c.err } -func (c *Conn) doClose() { +func (c *Conn) CloseWithError(errCode network.ConnErrorCode) error { + c.closeOnce.Do(func() { + c.doClose(errCode) + }) + return c.err +} + +func (c *Conn) doClose(errCode network.ConnErrorCode) { c.swarm.removeConn(c) // Prevent new streams from opening. @@ -71,7 +80,15 @@ func (c *Conn) doClose() { c.streams.m = nil c.streams.Unlock() - c.err = c.conn.Close() + if errCode != 0 { + if ce, ok := c.conn.(network.CloseWithErrorer); ok { + c.err = ce.CloseWithError(errCode) + } else { + c.err = c.conn.Close() + } + } else { + c.err = c.conn.Close() + } // Send the connectedness event after closing the connection. // This ensures that both remote connection close and local connection diff --git a/p2p/net/swarm/swarm_stream.go b/p2p/net/swarm/swarm_stream.go index b7846adec2..437921aaff 100644 --- a/p2p/net/swarm/swarm_stream.go +++ b/p2p/net/swarm/swarm_stream.go @@ -91,6 +91,17 @@ func (s *Stream) Reset() error { return err } +func (s *Stream) ResetWithError(errCode network.StreamErrorCode) error { + var err error + if se, ok := s.stream.(network.ResetWithErrorer); ok { + err = se.ResetWithError(errCode) + } else { + err = s.stream.Reset() + } + s.closeAndRemoveStream() + return err +} + func (s *Stream) closeAndRemoveStream() { s.closeMx.Lock() defer s.closeMx.Unlock() diff --git a/p2p/net/upgrader/conn.go b/p2p/net/upgrader/conn.go index 1c23a01aed..18e1e6a931 100644 --- a/p2p/net/upgrader/conn.go +++ b/p2p/net/upgrader/conn.go @@ -63,3 +63,10 @@ func (t *transportConn) ConnState() network.ConnectionState { UsedEarlyMuxerNegotiation: t.usedEarlyMuxerNegotiation, } } + +func (t *transportConn) CloseWithError(errCode network.ConnErrorCode) error { + if ce, ok := t.MuxedConn.(network.CloseWithErrorer); ok { + return ce.CloseWithError(errCode) + } + return t.Close() +} diff --git a/p2p/test/transport/transport_test.go b/p2p/test/transport/transport_test.go index 60f8ca0c06..a7b7edfe19 100644 --- a/p2p/test/transport/transport_test.go +++ b/p2p/test/transport/transport_test.go @@ -35,6 +35,7 @@ import ( "go.uber.org/mock/gomock" ma "github.com/multiformats/go-multiaddr" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -830,3 +831,232 @@ func TestConnClosedWhenRemoteCloses(t *testing.T) { }) } } + +// TestStreamErrorCode tests correctness for resetting stream with errors +func TestStreamErrorCode(t *testing.T) { + for _, tc := range transportsToTest { + t.Run(tc.Name, func(t *testing.T) { + if tc.Name == "WebTransport" { + t.Skipf("skipping: %s, not implemented", tc.Name) + return + } + server := tc.HostGenerator(t, TransportTestCaseOpts{}) + client := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true}) + defer server.Close() + defer client.Close() + + checkError := func(err error, code network.StreamErrorCode, remote bool) { + t.Helper() + if err == nil { + t.Fatal("expected non nil error") + } + se := &network.StreamError{} + if errors.As(err, &se) { + require.Equal(t, se.ErrorCode, code) + require.Equal(t, se.Remote, remote) + return + } + t.Fatal("expected network.StreamError, got:", err) + } + + errCh := make(chan error) + server.SetStreamHandler("/test", func(s network.Stream) { + defer s.Reset() + b := make([]byte, 10) + n, err := s.Read(b) + if !assert.NoError(t, err) { + return + } + _, err = s.Write(b[:n]) + if !assert.NoError(t, err) { + return + } + _, err = s.Read(b) + errCh <- err + + _, err = s.Write(b) + errCh <- err + }) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + client.Peerstore().AddAddrs(server.ID(), server.Addrs(), peerstore.PermanentAddrTTL) + s, err := client.NewStream(ctx, server.ID(), "/test") + require.NoError(t, err) + + _, err = s.Write([]byte("hello")) + require.NoError(t, err) + + buf := make([]byte, 10) + n, err := s.Read(buf) + require.NoError(t, err) + require.Equal(t, []byte("hello"), buf[:n]) + + err = s.ResetWithError(42) + require.NoError(t, err) + + _, err = s.Read(buf) + checkError(err, 42, false) + + _, err = s.Write(buf) + checkError(err, 42, false) + + err = <-errCh // read error + checkError(err, 42, true) + + err = <-errCh // write error + checkError(err, 42, true) + }) + } +} + +// TestStreamErrorCodeConnClosed tests correctness for resetting stream with errors +func TestStreamErrorCodeConnClosed(t *testing.T) { + for _, tc := range transportsToTest { + t.Run(tc.Name, func(t *testing.T) { + if tc.Name == "WebTransport" || tc.Name == "WebRTC" { + t.Skipf("skipping: %s, not implemented", tc.Name) + return + } + server := tc.HostGenerator(t, TransportTestCaseOpts{}) + client := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true}) + defer server.Close() + defer client.Close() + + checkError := func(err error, code network.ConnErrorCode, remote bool) { + t.Helper() + if err == nil { + t.Fatal("expected non nil error") + } + ce := &network.ConnError{} + if errors.As(err, &ce) { + require.Equal(t, code, ce.ErrorCode) + require.Equal(t, remote, ce.Remote) + return + } + t.Fatal("expected network.ConnError, got:", err) + } + + errCh := make(chan error) + server.SetStreamHandler("/test", func(s network.Stream) { + defer s.Reset() + b := make([]byte, 10) + n, err := s.Read(b) + if !assert.NoError(t, err) { + return + } + _, err = s.Write(b[:n]) + if !assert.NoError(t, err) { + return + } + _, err = s.Read(b) + errCh <- err + + _, err = s.Write(b) + errCh <- err + }) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + client.Peerstore().AddAddrs(server.ID(), server.Addrs(), peerstore.PermanentAddrTTL) + s, err := client.NewStream(ctx, server.ID(), "/test") + require.NoError(t, err) + + _, err = s.Write([]byte("hello")) + require.NoError(t, err) + + buf := make([]byte, 10) + n, err := s.Read(buf) + require.NoError(t, err) + require.Equal(t, []byte("hello"), buf[:n]) + + err = s.Conn().CloseWithError(42) + require.NoError(t, err) + + _, err = s.Read(buf) + checkError(err, 42, false) + + _, err = s.Write(buf) + checkError(err, 42, false) + + err = <-errCh + checkError(err, 42, true) + + err = <-errCh + checkError(err, 42, true) + }) + } +} + +// TestConnectionErrorCode tests correctness for resetting stream with errors +func TestConnectionErrorCode(t *testing.T) { + for _, tc := range transportsToTest { + t.Run(tc.Name, func(t *testing.T) { + if tc.Name == "WebTransport" || tc.Name == "WebRTC" { + t.Skipf("skipping: %s, not implemented", tc.Name) + return + } + server := tc.HostGenerator(t, TransportTestCaseOpts{}) + client := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true}) + defer server.Close() + defer client.Close() + + checkError := func(err error, code network.ConnErrorCode, remote bool) { + t.Helper() + if err == nil { + t.Fatal("expected non nil error") + } + ce := &network.ConnError{} + if errors.As(err, &ce) { + require.Equal(t, code, ce.ErrorCode) + require.Equal(t, remote, ce.Remote) + return + } + t.Fatal("expected network.ConnError, got:", err) + } + + errCh := make(chan error) + server.SetStreamHandler("/test", func(s network.Stream) { + defer s.Reset() + b := make([]byte, 10) + n, err := s.Read(b) + if !assert.NoError(t, err) { + return + } + _, err = s.Write(b[:n]) + if !assert.NoError(t, err) { + return + } + + _, err = s.Read(b) + if !assert.Error(t, err) { + return + } + _, err = s.Conn().NewStream(context.Background()) + errCh <- err + }) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + client.Peerstore().AddAddrs(server.ID(), server.Addrs(), peerstore.PermanentAddrTTL) + s, err := client.NewStream(ctx, server.ID(), "/test") + require.NoError(t, err) + + _, err = s.Write([]byte("hello")) + require.NoError(t, err) + + buf := make([]byte, 10) + n, err := s.Read(buf) + require.NoError(t, err) + require.Equal(t, []byte("hello"), buf[:n]) + + err = s.Conn().CloseWithError(42) + require.NoError(t, err) + + str, err := s.Conn().NewStream(context.Background()) + require.Nil(t, str) + checkError(err, 42, false) + + err = <-errCh + checkError(err, 42, true) + + }) + } +} diff --git a/p2p/transport/quic/conn.go b/p2p/transport/quic/conn.go index a2da81eb34..8b381d8eda 100644 --- a/p2p/transport/quic/conn.go +++ b/p2p/transport/quic/conn.go @@ -34,6 +34,13 @@ func (c *conn) Close() error { return c.closeWithError(0, "") } +// CloseWithError closes the connection +// It must be called even if the peer closed the connection in order for +// garbage collection to properly work in this package. +func (c *conn) CloseWithError(errCode network.ConnErrorCode) error { + return c.closeWithError(quic.ApplicationErrorCode(errCode), "") +} + func (c *conn) closeWithError(errCode quic.ApplicationErrorCode, errString string) error { c.transport.removeConn(c.quicConn) err := c.quicConn.CloseWithError(errCode, errString) @@ -53,13 +60,19 @@ func (c *conn) allowWindowIncrease(size uint64) bool { // OpenStream creates a new stream. func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) { qstr, err := c.quicConn.OpenStreamSync(ctx) - return &stream{Stream: qstr}, err + if err != nil { + return nil, parseStreamError(err) + } + return &stream{Stream: qstr}, nil } // AcceptStream accepts a stream opened by the other side. func (c *conn) AcceptStream() (network.MuxedStream, error) { qstr, err := c.quicConn.AcceptStream(context.Background()) - return &stream{Stream: qstr}, err + if err != nil { + return nil, parseStreamError(err) + } + return &stream{Stream: qstr}, nil } // LocalPeer returns our peer ID diff --git a/p2p/transport/quic/conn_test.go b/p2p/transport/quic/conn_test.go index d3e27a7e16..bf3f7b0751 100644 --- a/p2p/transport/quic/conn_test.go +++ b/p2p/transport/quic/conn_test.go @@ -270,6 +270,9 @@ func TestStreams(t *testing.T) { t.Run(tc.Name, func(t *testing.T) { testStreams(t, tc) }) + t.Run(tc.Name, func(t *testing.T) { + testStreamsErrorCode(t, tc) + }) } } @@ -305,6 +308,45 @@ func testStreams(t *testing.T, tc *connTestCase) { require.Equal(t, data, []byte("foobar")) } +func testStreamsErrorCode(t *testing.T, tc *connTestCase) { + serverID, serverKey := createPeer(t) + _, clientKey := createPeer(t) + + serverTransport, err := NewTransport(serverKey, newConnManager(t, tc.Options...), nil, nil, nil) + require.NoError(t, err) + defer serverTransport.(io.Closer).Close() + ln := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic-v1") + defer ln.Close() + + clientTransport, err := NewTransport(clientKey, newConnManager(t, tc.Options...), nil, nil, nil) + require.NoError(t, err) + defer clientTransport.(io.Closer).Close() + conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) + require.NoError(t, err) + defer conn.Close() + serverConn, err := ln.Accept() + require.NoError(t, err) + defer serverConn.Close() + + str, err := conn.OpenStream(context.Background()) + require.NoError(t, err) + err = str.ResetWithError(42) + require.NoError(t, err) + + sstr, err := serverConn.AcceptStream() + require.NoError(t, err) + _, err = io.ReadAll(sstr) + require.Error(t, err) + se := &network.StreamError{} + if errors.As(err, &se) { + require.Equal(t, se.ErrorCode, network.StreamErrorCode(42)) + require.True(t, se.Remote) + } else { + t.Fatalf("expected error to be of network.StreamError type, got %T, %v", err, err) + } + +} + func TestHandshakeFailPeerIDMismatch(t *testing.T) { for _, tc := range connTestCases { t.Run(tc.Name, func(t *testing.T) { diff --git a/p2p/transport/quic/stream.go b/p2p/transport/quic/stream.go index 56f12dade2..57d6577f3f 100644 --- a/p2p/transport/quic/stream.go +++ b/p2p/transport/quic/stream.go @@ -2,6 +2,7 @@ package libp2pquic import ( "errors" + "math" "github.com/libp2p/go-libp2p/core/network" @@ -18,20 +19,41 @@ type stream struct { var _ network.MuxedStream = &stream{} +func parseStreamError(err error) error { + if err == nil { + return err + } + se := &quic.StreamError{} + if errors.As(err, &se) { + code := se.ErrorCode + if code > math.MaxUint32 { + // TODO(sukunrt): do we need this? + code = reset + } + err = &network.StreamError{ + ErrorCode: network.StreamErrorCode(code), + Remote: se.Remote, + } + } + ae := &quic.ApplicationError{} + if errors.As(err, &ae) { + code := ae.ErrorCode + err = &network.ConnError{ + ErrorCode: network.ConnErrorCode(code), + Remote: ae.Remote, + } + } + return err +} + func (s *stream) Read(b []byte) (n int, err error) { n, err = s.Stream.Read(b) - if err != nil && errors.Is(err, &quic.StreamError{}) { - err = network.ErrReset - } - return n, err + return n, parseStreamError(err) } func (s *stream) Write(b []byte) (n int, err error) { n, err = s.Stream.Write(b) - if err != nil && errors.Is(err, &quic.StreamError{}) { - err = network.ErrReset - } - return n, err + return n, parseStreamError(err) } func (s *stream) Reset() error { @@ -40,6 +62,12 @@ func (s *stream) Reset() error { return nil } +func (s *stream) ResetWithError(errCode network.StreamErrorCode) error { + s.Stream.CancelRead(quic.StreamErrorCode(errCode)) + s.Stream.CancelWrite(quic.StreamErrorCode(errCode)) + return nil +} + func (s *stream) Close() error { s.Stream.CancelRead(reset) return s.Stream.Close() diff --git a/p2p/transport/webrtc/connection.go b/p2p/transport/webrtc/connection.go index 2fba37a970..57c0df7e94 100644 --- a/p2p/transport/webrtc/connection.go +++ b/p2p/transport/webrtc/connection.go @@ -132,6 +132,10 @@ func (c *connection) Close() error { return nil } +func (c *connection) CloseWithError(errCode network.ConnErrorCode) error { + return c.Close() +} + // closeWithError is used to Close the connection when the underlying DTLS connection fails func (c *connection) closeWithError(err error) { c.closeOnce.Do(func() { diff --git a/p2p/transport/webrtc/pb/message.pb.go b/p2p/transport/webrtc/pb/message.pb.go index d7d4d583af..6e7b54f2b1 100644 --- a/p2p/transport/webrtc/pb/message.pb.go +++ b/p2p/transport/webrtc/pb/message.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.35.1 -// protoc v5.28.2 +// protoc-gen-go v1.35.2 +// protoc v5.28.3 // source: p2p/transport/webrtc/pb/message.proto package pb @@ -95,8 +95,9 @@ type Message struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - Flag *Message_Flag `protobuf:"varint,1,opt,name=flag,enum=Message_Flag" json:"flag,omitempty"` - Message []byte `protobuf:"bytes,2,opt,name=message" json:"message,omitempty"` + Flag *Message_Flag `protobuf:"varint,1,opt,name=flag,enum=Message_Flag" json:"flag,omitempty"` + Message []byte `protobuf:"bytes,2,opt,name=message" json:"message,omitempty"` + ErrorCode *uint32 `protobuf:"varint,3,opt,name=errorCode" json:"errorCode,omitempty"` } func (x *Message) Reset() { @@ -143,24 +144,32 @@ func (x *Message) GetMessage() []byte { return nil } +func (x *Message) GetErrorCode() uint32 { + if x != nil && x.ErrorCode != nil { + return *x.ErrorCode + } + return 0 +} + var File_p2p_transport_webrtc_pb_message_proto protoreflect.FileDescriptor var file_p2p_transport_webrtc_pb_message_proto_rawDesc = []byte{ 0x0a, 0x25, 0x70, 0x32, 0x70, 0x2f, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x2f, 0x77, 0x65, 0x62, 0x72, 0x74, 0x63, 0x2f, 0x70, 0x62, 0x2f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, - 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x81, 0x01, 0x0a, 0x07, 0x4d, 0x65, 0x73, 0x73, + 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x9f, 0x01, 0x0a, 0x07, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x21, 0x0a, 0x04, 0x66, 0x6c, 0x61, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x0d, 0x2e, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x2e, 0x46, 0x6c, 0x61, 0x67, 0x52, 0x04, 0x66, 0x6c, 0x61, 0x67, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, - 0x22, 0x39, 0x0a, 0x04, 0x46, 0x6c, 0x61, 0x67, 0x12, 0x07, 0x0a, 0x03, 0x46, 0x49, 0x4e, 0x10, - 0x00, 0x12, 0x10, 0x0a, 0x0c, 0x53, 0x54, 0x4f, 0x50, 0x5f, 0x53, 0x45, 0x4e, 0x44, 0x49, 0x4e, - 0x47, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x52, 0x45, 0x53, 0x45, 0x54, 0x10, 0x02, 0x12, 0x0b, - 0x0a, 0x07, 0x46, 0x49, 0x4e, 0x5f, 0x41, 0x43, 0x4b, 0x10, 0x03, 0x42, 0x35, 0x5a, 0x33, 0x67, - 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x6c, 0x69, 0x62, 0x70, 0x32, 0x70, - 0x2f, 0x67, 0x6f, 0x2d, 0x6c, 0x69, 0x62, 0x70, 0x32, 0x70, 0x2f, 0x70, 0x32, 0x70, 0x2f, 0x74, - 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x2f, 0x77, 0x65, 0x62, 0x72, 0x74, 0x63, 0x2f, - 0x70, 0x62, + 0x12, 0x1c, 0x0a, 0x09, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x18, 0x03, 0x20, + 0x01, 0x28, 0x0d, 0x52, 0x09, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x22, 0x39, + 0x0a, 0x04, 0x46, 0x6c, 0x61, 0x67, 0x12, 0x07, 0x0a, 0x03, 0x46, 0x49, 0x4e, 0x10, 0x00, 0x12, + 0x10, 0x0a, 0x0c, 0x53, 0x54, 0x4f, 0x50, 0x5f, 0x53, 0x45, 0x4e, 0x44, 0x49, 0x4e, 0x47, 0x10, + 0x01, 0x12, 0x09, 0x0a, 0x05, 0x52, 0x45, 0x53, 0x45, 0x54, 0x10, 0x02, 0x12, 0x0b, 0x0a, 0x07, + 0x46, 0x49, 0x4e, 0x5f, 0x41, 0x43, 0x4b, 0x10, 0x03, 0x42, 0x35, 0x5a, 0x33, 0x67, 0x69, 0x74, + 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x6c, 0x69, 0x62, 0x70, 0x32, 0x70, 0x2f, 0x67, + 0x6f, 0x2d, 0x6c, 0x69, 0x62, 0x70, 0x32, 0x70, 0x2f, 0x70, 0x32, 0x70, 0x2f, 0x74, 0x72, 0x61, + 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x2f, 0x77, 0x65, 0x62, 0x72, 0x74, 0x63, 0x2f, 0x70, 0x62, } var ( diff --git a/p2p/transport/webrtc/pb/message.proto b/p2p/transport/webrtc/pb/message.proto index aab885b0da..2401f7c4d2 100644 --- a/p2p/transport/webrtc/pb/message.proto +++ b/p2p/transport/webrtc/pb/message.proto @@ -21,4 +21,6 @@ message Message { optional Flag flag=1; optional bytes message = 2; + + optional uint32 errorCode = 3; } diff --git a/p2p/transport/webrtc/stream.go b/p2p/transport/webrtc/stream.go index 56f869f5e1..98da1cde02 100644 --- a/p2p/transport/webrtc/stream.go +++ b/p2p/transport/webrtc/stream.go @@ -69,8 +69,9 @@ type stream struct { // readerMx ensures that only a single goroutine reads from the reader. Read is not threadsafe // But we may need to read from reader for control messages from a different goroutine. - readerMx sync.Mutex - reader pbio.Reader + readerMx sync.Mutex + reader pbio.Reader + readError error // this buffer is limited up to a single message. Reason we need it // is because a reader might read a message midway, and so we need a @@ -82,6 +83,7 @@ type stream struct { writeStateChanged chan struct{} sendState sendState writeDeadline time.Time + writeError error controlMessageReaderOnce sync.Once // controlMessageReaderEndTime is the end time for reading FIN_ACK from the control @@ -146,6 +148,10 @@ func (s *stream) Close() error { } func (s *stream) Reset() error { + return s.ResetWithError(0) +} + +func (s *stream) ResetWithError(errCode network.StreamErrorCode) error { s.mx.Lock() isClosed := s.closeForShutdownErr != nil s.mx.Unlock() @@ -154,8 +160,8 @@ func (s *stream) Reset() error { } defer s.cleanup() - cancelWriteErr := s.cancelWrite() - closeReadErr := s.CloseRead() + cancelWriteErr := s.cancelWrite(errCode) + closeReadErr := s.closeRead(errCode, false) s.setDataChannelReadDeadline(time.Now().Add(-1 * time.Hour)) return errors.Join(closeReadErr, cancelWriteErr) } @@ -175,19 +181,20 @@ func (s *stream) SetDeadline(t time.Time) error { return s.SetWriteDeadline(t) } -// processIncomingFlag process the flag on an incoming message +// processIncomingFlag processes the flag(FIN/RST/etc) on msg. // It needs to be called while the mutex is locked. -func (s *stream) processIncomingFlag(flag *pb.Message_Flag) { - if flag == nil { +func (s *stream) processIncomingFlag(msg *pb.Message) { + if msg.Flag == nil { return } - switch *flag { + switch msg.GetFlag() { case pb.Message_STOP_SENDING: // We must process STOP_SENDING after sending a FIN(sendStateDataSent). Remote peer // may not send a FIN_ACK once it has sent a STOP_SENDING if s.sendState == sendStateSending || s.sendState == sendStateDataSent { s.sendState = sendStateReset + s.writeError = &network.StreamError{Remote: true, ErrorCode: network.StreamErrorCode(msg.GetErrorCode())} } s.notifyWriteStateChanged() case pb.Message_FIN_ACK: @@ -206,6 +213,7 @@ func (s *stream) processIncomingFlag(flag *pb.Message_Flag) { case pb.Message_RESET: if s.receiveState == receiveStateReceiving { s.receiveState = receiveStateReset + s.readError = &network.StreamError{Remote: true, ErrorCode: network.StreamErrorCode(msg.GetErrorCode())} } s.spawnControlMessageReader() } @@ -235,7 +243,7 @@ func (s *stream) spawnControlMessageReader() { s.readerMx.Unlock() if s.nextMessage != nil { - s.processIncomingFlag(s.nextMessage.Flag) + s.processIncomingFlag(s.nextMessage) s.nextMessage = nil } var msg pb.Message @@ -266,7 +274,7 @@ func (s *stream) spawnControlMessageReader() { } return } - s.processIncomingFlag(msg.Flag) + s.processIncomingFlag(&msg) } }() }) diff --git a/p2p/transport/webrtc/stream_read.go b/p2p/transport/webrtc/stream_read.go index 80d99ea91c..826eec3049 100644 --- a/p2p/transport/webrtc/stream_read.go +++ b/p2p/transport/webrtc/stream_read.go @@ -22,7 +22,7 @@ func (s *stream) Read(b []byte) (int, error) { case receiveStateDataRead: return 0, io.EOF case receiveStateReset: - return 0, network.ErrReset + return 0, s.readError } if len(b) == 0 { @@ -52,10 +52,11 @@ func (s *stream) Read(b []byte) (int, error) { // datachannel. For these implementations a stream reset will be observed as an // abrupt closing of the datachannel. s.receiveState = receiveStateReset - return 0, network.ErrReset + s.readError = &network.StreamError{Remote: true} + return 0, s.readError } if s.receiveState == receiveStateReset { - return 0, network.ErrReset + return 0, s.readError } if s.receiveState == receiveStateDataRead { return 0, io.EOF @@ -73,7 +74,7 @@ func (s *stream) Read(b []byte) (int, error) { } // process flags on the message after reading all the data - s.processIncomingFlag(s.nextMessage.Flag) + s.processIncomingFlag(s.nextMessage) s.nextMessage = nil if s.closeForShutdownErr != nil { return read, s.closeForShutdownErr @@ -82,7 +83,7 @@ func (s *stream) Read(b []byte) (int, error) { case receiveStateDataRead: return read, io.EOF case receiveStateReset: - return read, network.ErrReset + return read, s.readError } } } @@ -101,12 +102,17 @@ func (s *stream) setDataChannelReadDeadline(t time.Time) error { } func (s *stream) CloseRead() error { + return s.closeRead(0, false) +} + +func (s *stream) closeRead(errCode network.StreamErrorCode, remote bool) error { s.mx.Lock() defer s.mx.Unlock() var err error if s.receiveState == receiveStateReceiving && s.closeForShutdownErr == nil { err = s.writer.WriteMsg(&pb.Message{Flag: pb.Message_STOP_SENDING.Enum()}) s.receiveState = receiveStateReset + s.readError = &network.StreamError{Remote: remote, ErrorCode: errCode} } s.spawnControlMessageReader() return err diff --git a/p2p/transport/webrtc/stream_write.go b/p2p/transport/webrtc/stream_write.go index 534a8d8e60..01fddac331 100644 --- a/p2p/transport/webrtc/stream_write.go +++ b/p2p/transport/webrtc/stream_write.go @@ -24,7 +24,7 @@ func (s *stream) Write(b []byte) (int, error) { } switch s.sendState { case sendStateReset: - return 0, network.ErrReset + return 0, s.writeError case sendStateDataSent, sendStateDataReceived: return 0, errWriteAfterClose } @@ -48,7 +48,7 @@ func (s *stream) Write(b []byte) (int, error) { } switch s.sendState { case sendStateReset: - return n, network.ErrReset + return n, s.writeError case sendStateDataSent, sendStateDataReceived: return n, errWriteAfterClose } @@ -119,7 +119,7 @@ func (s *stream) availableSendSpace() int { return availableSpace } -func (s *stream) cancelWrite() error { +func (s *stream) cancelWrite(errCode network.StreamErrorCode) error { s.mx.Lock() defer s.mx.Unlock() @@ -129,10 +129,12 @@ func (s *stream) cancelWrite() error { return nil } s.sendState = sendStateReset + s.writeError = &network.StreamError{Remote: false, ErrorCode: errCode} // Remove reference to this stream from data channel s.dataChannel.OnBufferedAmountLow(nil) s.notifyWriteStateChanged() - return s.writer.WriteMsg(&pb.Message{Flag: pb.Message_RESET.Enum()}) + code := uint32(errCode) + return s.writer.WriteMsg(&pb.Message{Flag: pb.Message_RESET.Enum(), ErrorCode: &code}) } func (s *stream) CloseWrite() error { diff --git a/p2p/transport/websocket/conn.go b/p2p/transport/websocket/conn.go index ce51611703..eca2262c58 100644 --- a/p2p/transport/websocket/conn.go +++ b/p2p/transport/websocket/conn.go @@ -203,3 +203,11 @@ func (c *capableConn) ConnState() network.ConnectionState { cs.Transport = "websocket" return cs } + +// CloseWithError implements network.CloseWithErrorer +func (c *capableConn) CloseWithError(errCode network.ConnErrorCode) error { + if ce, ok := c.CapableConn.(network.CloseWithErrorer); ok { + return ce.CloseWithError(errCode) + } + return c.Close() +} diff --git a/p2p/transport/webtransport/conn.go b/p2p/transport/webtransport/conn.go index d914398e0e..3618548d14 100644 --- a/p2p/transport/webtransport/conn.go +++ b/p2p/transport/webtransport/conn.go @@ -78,6 +78,10 @@ func (c *conn) Close() error { return err } +func (c *conn) CloseWithError(errCode network.ConnErrorCode) error { + return c.Close() +} + func (c *conn) IsClosed() bool { return c.session.Context().Err() != nil } func (c *conn) Scope() network.ConnScope { return c.scope } func (c *conn) Transport() tpt.Transport { return c.transport } diff --git a/p2p/transport/webtransport/stream.go b/p2p/transport/webtransport/stream.go index 0849fc9f38..583708edc2 100644 --- a/p2p/transport/webtransport/stream.go +++ b/p2p/transport/webtransport/stream.go @@ -56,6 +56,10 @@ func (s *stream) Reset() error { return nil } +func (s *stream) ResetWithError(errCode network.StreamErrorCode) error { + panic("not implemented") +} + func (s *stream) Close() error { s.Stream.CancelRead(reset) return s.Stream.Close()