From 079bd3e5657177fa6e5d6fc5c40e476eaa950953 Mon Sep 17 00:00:00 2001 From: Srdjan S Date: Sat, 16 Nov 2024 15:56:56 +0100 Subject: [PATCH] Use plain channels to send data between streams --- p2p/transport/memory/conn.go | 45 ++++---- p2p/transport/memory/listener.go | 5 +- p2p/transport/memory/stream.go | 144 +++++++++++++++---------- p2p/transport/memory/stream_test.go | 12 +-- p2p/transport/memory/transport.go | 35 +++--- p2p/transport/memory/transport_test.go | 70 ++++++++++++ 6 files changed, 201 insertions(+), 110 deletions(-) create mode 100644 p2p/transport/memory/transport_test.go diff --git a/p2p/transport/memory/conn.go b/p2p/transport/memory/conn.go index d864e93316..515fb43625 100644 --- a/p2p/transport/memory/conn.go +++ b/p2p/transport/memory/conn.go @@ -2,7 +2,6 @@ package memory import ( "context" - "io" "sync" "sync/atomic" @@ -14,7 +13,7 @@ import ( ) type conn struct { - id int32 + id int64 transport *transport scope network.ConnManagementScope @@ -26,21 +25,19 @@ type conn struct { remotePubKey ic.PubKey remoteMultiaddr ma.Multiaddr - isClosed atomic.Bool - closeOnce sync.Once - mu sync.Mutex - streamC chan *stream + closed atomic.Bool + closeOnce sync.Once - nextStreamID atomic.Int32 - streams map[int32]network.MuxedStream + streamC chan *stream + streams map[int64]network.MuxedStream } var _ tpt.CapableConn = &conn{} func newConnection( - id int32, + t *transport, s *stream, localPeer peer.ID, localMultiaddr ma.Multiaddr, @@ -49,40 +46,36 @@ func newConnection( remoteMultiaddr ma.Multiaddr, ) *conn { c := &conn{ - id: id, + id: connCounter.Add(1), + transport: t, localPeer: localPeer, localMultiaddr: localMultiaddr, remotePubKey: remotePubKey, remotePeerID: remotePeer, remoteMultiaddr: remoteMultiaddr, streamC: make(chan *stream, 1), - streams: make(map[int32]network.MuxedStream), + streams: make(map[int64]network.MuxedStream), } - streamID := c.nextStreamID.Add(1) - c.addStream(streamID, s) - + c.addStream(s.id, s) return c } func (c *conn) Close() error { - c.closeOnce.Do(func() { - c.isClosed.Store(true) - c.transport.removeConn(c) - }) + c.closed.Store(true) + for _, s := range c.streams { + s.Close() + } return nil } func (c *conn) IsClosed() bool { - return c.isClosed.Load() + return c.closed.Load() } func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) { - ra, wb := io.Pipe() - rb, wa := io.Pipe() - inConnId, outConnId := c.nextStreamID.Add(1), c.nextStreamID.Add(1) - inStream, outStream := newStream(inConnId, ra, wb), newStream(outConnId, rb, wa) + inStream, outStream := newStreamPair() c.streamC <- inStream return outStream, nil @@ -90,7 +83,7 @@ func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) { func (c *conn) AcceptStream() (network.MuxedStream, error) { in := <-c.streamC - id := c.nextStreamID.Add(1) + id := streamCounter.Add(1) c.addStream(id, in) return in, nil } @@ -122,14 +115,14 @@ func (c *conn) ConnState() network.ConnectionState { return network.ConnectionState{Transport: "memory"} } -func (c *conn) addStream(id int32, stream network.MuxedStream) { +func (c *conn) addStream(id int64, stream network.MuxedStream) { c.mu.Lock() defer c.mu.Unlock() c.streams[id] = stream } -func (c *conn) removeStream(id int32) { +func (c *conn) removeStream(id int64) { c.mu.Lock() defer c.mu.Unlock() diff --git a/p2p/transport/memory/listener.go b/p2p/transport/memory/listener.go index 39e8acfb29..54417e2a8b 100644 --- a/p2p/transport/memory/listener.go +++ b/p2p/transport/memory/listener.go @@ -21,7 +21,7 @@ type listener struct { mu sync.Mutex connCh chan *conn - connections map[int32]*conn + connections map[int64]*conn } func (l *listener) Multiaddr() ma.Multiaddr { @@ -36,7 +36,7 @@ func newListener(t *transport, laddr ma.Multiaddr) *listener { cancel: cancel, laddr: laddr, connCh: make(chan *conn, listenerQueueSize), - connections: make(map[int32]*conn), + connections: make(map[int64]*conn), } } @@ -53,6 +53,7 @@ func (l *listener) Accept() (tpt.CapableConn, error) { l.mu.Lock() defer l.mu.Unlock() + c.transport = l.t l.connections[c.id] = c return c, nil } diff --git a/p2p/transport/memory/stream.go b/p2p/transport/memory/stream.go index 101ae516da..66d8879f88 100644 --- a/p2p/transport/memory/stream.go +++ b/p2p/transport/memory/stream.go @@ -1,106 +1,132 @@ package memory import ( + "errors" "io" + "net" "sync/atomic" "time" "github.com/libp2p/go-libp2p/core/network" ) +// stream implements network.Stream type stream struct { - id int32 + id int64 - r *io.PipeReader - w *io.PipeWriter - writeC chan []byte + write chan byte + read chan byte - readCloseC chan struct{} - writeCloseC chan struct{} + reset chan struct{} + closeRead chan struct{} + closeWrite chan struct{} + closed atomic.Bool +} + +var ErrClosed = errors.New("stream closed") - closed atomic.Bool +func newStreamPair() (*stream, *stream) { + ra, rb := make(chan byte, 4096), make(chan byte, 4096) + wa, wb := rb, ra + + in := newStream(rb, wb, network.DirInbound) + out := newStream(ra, wa, network.DirOutbound) + return in, out } -func newStream(id int32, r *io.PipeReader, w *io.PipeWriter) *stream { +func newStream(r, w chan byte, _ network.Direction) *stream { s := &stream{ - id: id, - r: r, - w: w, - writeC: make(chan []byte, 1), - readCloseC: make(chan struct{}, 1), - writeCloseC: make(chan struct{}, 1), + id: streamCounter.Add(1), + read: r, + write: w, + reset: make(chan struct{}, 1), + closeRead: make(chan struct{}, 1), + closeWrite: make(chan struct{}, 1), } - go func() { - for { - select { - case b := <-s.writeC: - if _, err := w.Write(b); err != nil { - return - } - case <-s.writeCloseC: - return - } - } - }() - return s } -func (s *stream) Read(b []byte) (int, error) { - return s.r.Read(b) -} - -func (s *stream) Write(b []byte) (int, error) { +// How to handle errors with writes? +func (s *stream) Write(p []byte) (n int, err error) { if s.closed.Load() { - return 0, network.ErrReset + return 0, ErrClosed } - select { - case <-s.writeCloseC: - return 0, network.ErrReset - case s.writeC <- b: - return len(b), nil + for i := 0; i < len(p); i++ { + select { + case <-s.reset: + err = network.ErrReset + return + case <-s.closeWrite: + err = ErrClosed + return + case s.write <- p[i]: + n = i + default: + err = io.ErrClosedPipe + } } + + return n + 1, err } -func (s *stream) Reset() error { - if err := s.CloseWrite(); err != nil { - return err +func (s *stream) Read(p []byte) (n int, err error) { + if s.closed.Load() { + return 0, ErrClosed } - if err := s.CloseRead(); err != nil { - return err + + for n = 0; n < len(p); n++ { + select { + case <-s.reset: + err = network.ErrReset + return + case <-s.closeRead: + err = ErrClosed + return + case b, ok := <-s.read: + if !ok { + err = io.EOF + return + } + p[n] = b + default: + err = io.EOF + return + } } - return nil + + return } -func (s *stream) Close() error { - s.CloseRead() - s.CloseWrite() +func (s *stream) CloseWrite() error { + s.closeWrite <- struct{}{} return nil } func (s *stream) CloseRead() error { - return s.r.CloseWithError(network.ErrReset) + s.closeRead <- struct{}{} + return nil } -func (s *stream) CloseWrite() error { - select { - case s.writeCloseC <- struct{}{}: - default: - } - +func (s *stream) Close() error { s.closed.Store(true) return nil } -func (s *stream) SetDeadline(_ time.Time) error { +func (s *stream) Reset() error { + s.reset <- struct{}{} return nil } -func (s *stream) SetReadDeadline(_ time.Time) error { - return nil +func (s *stream) SetDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "pipe", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} } -func (s *stream) SetWriteDeadline(_ time.Time) error { - return nil + +func (s *stream) SetReadDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "pipe", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (s *stream) SetWriteDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "pipe", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} } diff --git a/p2p/transport/memory/stream_test.go b/p2p/transport/memory/stream_test.go index 33c3cbdc64..cd5149c685 100644 --- a/p2p/transport/memory/stream_test.go +++ b/p2p/transport/memory/stream_test.go @@ -8,22 +8,17 @@ import ( ) func TestStreamSimpleReadWriteClose(t *testing.T) { - //client, server := getDetachedDataChannels(t) - ra, wb := io.Pipe() - rb, wa := io.Pipe() - - clientStr := newStream(0, ra, wa) - serverStr := newStream(1, rb, wb) + clientStr, serverStr := newStreamPair() // send a foobar from the client n, err := clientStr.Write([]byte("foobar")) require.NoError(t, err) require.Equal(t, 6, n) require.NoError(t, clientStr.CloseWrite()) + // writing after closing should error _, err = clientStr.Write([]byte("foobar")) require.Error(t, err) - //require.False(t, clientDone.Load()) // now read all the data on the server side b, err := io.ReadAll(serverStr) @@ -33,7 +28,6 @@ func TestStreamSimpleReadWriteClose(t *testing.T) { n, err = serverStr.Read(make([]byte, 10)) require.Zero(t, n) require.ErrorIs(t, err, io.EOF) - //require.False(t, serverDone.Load()) // send something back _, err = serverStr.Write([]byte("lorem ipsum")) @@ -49,8 +43,6 @@ func TestStreamSimpleReadWriteClose(t *testing.T) { // stream is only cleaned up on calling Close or Reset clientStr.Close() serverStr.Close() - //require.Eventually(t, func() bool { return clientDone.Load() }, 5*time.Second, 100*time.Millisecond) // Need to call Close for cleanup. Otherwise the FIN_ACK is never read require.NoError(t, serverStr.Close()) - //require.Eventually(t, func() bool { return serverDone.Load() }, 5*time.Second, 100*time.Millisecond) } diff --git a/p2p/transport/memory/transport.go b/p2p/transport/memory/transport.go index 5016e3a7dd..a13c737437 100644 --- a/p2p/transport/memory/transport.go +++ b/p2p/transport/memory/transport.go @@ -3,7 +3,6 @@ package memory import ( "context" "errors" - "io" "sync" "sync/atomic" @@ -13,6 +12,14 @@ import ( "github.com/libp2p/go-libp2p/core/pnet" tpt "github.com/libp2p/go-libp2p/core/transport" ma "github.com/multiformats/go-multiaddr" + mafmt "github.com/multiformats/go-multiaddr-fmt" +) + +var ( + connCounter atomic.Int64 + streamCounter atomic.Int64 + listenerCounter atomic.Int64 + dialMatcher = mafmt.Base(ma.P_MEMORY) ) type hub struct { @@ -84,8 +91,7 @@ type transport struct { mu sync.RWMutex - connID atomic.Int32 - connections map[int32]*conn + connections map[int64]*conn } func NewTransport(privKey ic.PrivKey, psk pnet.PSK, rcmgr network.ResourceManager) (tpt.Transport, error) { @@ -105,7 +111,7 @@ func NewTransport(privKey ic.PrivKey, psk pnet.PSK, rcmgr network.ResourceManage localPeerID: id, localPrivKey: privKey, localPubKey: privKey.GetPublic(), - connections: make(map[int32]*conn), + connections: make(map[int64]*conn), }, nil } @@ -141,19 +147,16 @@ func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, rpid return nil, errors.New("failed to get remote public key") } - ra, wb := io.Pipe() - rb, wa := io.Pipe() - inConnId, outConnId := t.connID.Add(1), t.connID.Add(1) - inStream, outStream := newStream(0, ra, wb), newStream(0, rb, wa) - - l.connCh <- newConnection(inConnId, inStream, rpid, raddr, t.localPubKey, t.localPeerID, nil) + inStream, outStream := newStreamPair() + inConn := newConnection(t, outStream, t.localPeerID, nil, remotePubKey, rpid, raddr) + outConn := newConnection(nil, inStream, rpid, raddr, t.localPubKey, t.localPeerID, nil) + l.connCh <- outConn - return newConnection(outConnId, outStream, t.localPeerID, nil, remotePubKey, rpid, raddr), nil + return inConn, nil } func (t *transport) CanDial(addr ma.Multiaddr) bool { - _, exists := memhub.getListener(addr.String()) - return exists + return dialMatcher.Matches(addr) } func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) { @@ -184,6 +187,12 @@ func (t *transport) String() string { func (t *transport) Close() error { // TODO: Go trough all listeners and close them memhub.close() + + for _, c := range t.connections { + c.Close() + delete(t.connections, c.id) + } + return nil } diff --git a/p2p/transport/memory/transport_test.go b/p2p/transport/memory/transport_test.go new file mode 100644 index 0000000000..f83f0d1280 --- /dev/null +++ b/p2p/transport/memory/transport_test.go @@ -0,0 +1,70 @@ +package memory + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "io" + "testing" + + ic "github.com/libp2p/go-libp2p/core/crypto" + tpt "github.com/libp2p/go-libp2p/core/transport" + + ma "github.com/multiformats/go-multiaddr" + "github.com/stretchr/testify/require" +) + +func getTransport(t *testing.T) tpt.Transport { + t.Helper() + rsaKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + key, err := ic.UnmarshalRsaPrivateKey(x509.MarshalPKCS1PrivateKey(rsaKey)) + require.NoError(t, err) + tr, err := NewTransport(key, nil, nil) + require.NoError(t, err) + return tr +} + +func TestMemoryProtocol(t *testing.T) { + tr := getTransport(t) + defer tr.(io.Closer).Close() + + protocols := tr.Protocols() + if len(protocols) > 1 { + t.Fatalf("expected at most one protocol, got %v", protocols) + } + + if protocols[0] != ma.P_MEMORY { + t.Fatalf("expected the supported protocol to be memory, got %d", protocols[0]) + } +} + +func TestCanDial(t *testing.T) { + tr := getTransport(t) + defer tr.(io.Closer).Close() + + invalid := []string{ + "/ip4/127.0.0.1/udp/1234", + "/ip4/5.5.5.5/tcp/1234", + "/dns/google.com/udp/443/quic-v1", + "/ip4/127.0.0.1/udp/1234/quic", + } + valid := []string{ + "/memory/1234", + "/memory/1337123", + } + for _, s := range invalid { + invalidAddr, err := ma.NewMultiaddr(s) + require.NoError(t, err) + if tr.CanDial(invalidAddr) { + t.Errorf("didn't expect to be able to dial a non-memory address (%s)", invalidAddr) + } + } + for _, s := range valid { + validAddr, err := ma.NewMultiaddr(s) + require.NoError(t, err) + if !tr.CanDial(validAddr) { + t.Errorf("expected to be able to dial memory address (%s)", validAddr) + } + } +}