Skip to content

Commit

Permalink
Use plain channels to send data between streams
Browse files Browse the repository at this point in the history
  • Loading branch information
pyropy committed Nov 16, 2024
1 parent e2a5865 commit 079bd3e
Show file tree
Hide file tree
Showing 6 changed files with 201 additions and 110 deletions.
45 changes: 19 additions & 26 deletions p2p/transport/memory/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package memory

import (
"context"
"io"
"sync"
"sync/atomic"

Expand All @@ -14,7 +13,7 @@ import (
)

type conn struct {
id int32
id int64

transport *transport
scope network.ConnManagementScope
Expand All @@ -26,21 +25,19 @@ type conn struct {
remotePubKey ic.PubKey
remoteMultiaddr ma.Multiaddr

isClosed atomic.Bool
closeOnce sync.Once

mu sync.Mutex

streamC chan *stream
closed atomic.Bool
closeOnce sync.Once

nextStreamID atomic.Int32
streams map[int32]network.MuxedStream
streamC chan *stream
streams map[int64]network.MuxedStream
}

var _ tpt.CapableConn = &conn{}

func newConnection(
id int32,
t *transport,
s *stream,
localPeer peer.ID,
localMultiaddr ma.Multiaddr,
Expand All @@ -49,48 +46,44 @@ func newConnection(
remoteMultiaddr ma.Multiaddr,
) *conn {
c := &conn{
id: id,
id: connCounter.Add(1),
transport: t,
localPeer: localPeer,
localMultiaddr: localMultiaddr,
remotePubKey: remotePubKey,
remotePeerID: remotePeer,
remoteMultiaddr: remoteMultiaddr,
streamC: make(chan *stream, 1),
streams: make(map[int32]network.MuxedStream),
streams: make(map[int64]network.MuxedStream),
}

streamID := c.nextStreamID.Add(1)
c.addStream(streamID, s)

c.addStream(s.id, s)
return c
}

func (c *conn) Close() error {
c.closeOnce.Do(func() {
c.isClosed.Store(true)
c.transport.removeConn(c)
})
c.closed.Store(true)
for _, s := range c.streams {
s.Close()
}

return nil
}

func (c *conn) IsClosed() bool {
return c.isClosed.Load()
return c.closed.Load()
}

func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) {
ra, wb := io.Pipe()
rb, wa := io.Pipe()
inConnId, outConnId := c.nextStreamID.Add(1), c.nextStreamID.Add(1)
inStream, outStream := newStream(inConnId, ra, wb), newStream(outConnId, rb, wa)
inStream, outStream := newStreamPair()

c.streamC <- inStream
return outStream, nil
}

func (c *conn) AcceptStream() (network.MuxedStream, error) {
in := <-c.streamC
id := c.nextStreamID.Add(1)
id := streamCounter.Add(1)
c.addStream(id, in)
return in, nil
}
Expand Down Expand Up @@ -122,14 +115,14 @@ func (c *conn) ConnState() network.ConnectionState {
return network.ConnectionState{Transport: "memory"}
}

func (c *conn) addStream(id int32, stream network.MuxedStream) {
func (c *conn) addStream(id int64, stream network.MuxedStream) {
c.mu.Lock()
defer c.mu.Unlock()

c.streams[id] = stream
}

func (c *conn) removeStream(id int32) {
func (c *conn) removeStream(id int64) {
c.mu.Lock()
defer c.mu.Unlock()

Expand Down
5 changes: 3 additions & 2 deletions p2p/transport/memory/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ type listener struct {

mu sync.Mutex
connCh chan *conn
connections map[int32]*conn
connections map[int64]*conn
}

func (l *listener) Multiaddr() ma.Multiaddr {
Expand All @@ -36,7 +36,7 @@ func newListener(t *transport, laddr ma.Multiaddr) *listener {
cancel: cancel,
laddr: laddr,
connCh: make(chan *conn, listenerQueueSize),
connections: make(map[int32]*conn),
connections: make(map[int64]*conn),
}
}

Expand All @@ -53,6 +53,7 @@ func (l *listener) Accept() (tpt.CapableConn, error) {
l.mu.Lock()
defer l.mu.Unlock()

c.transport = l.t
l.connections[c.id] = c
return c, nil
}
Expand Down
144 changes: 85 additions & 59 deletions p2p/transport/memory/stream.go
Original file line number Diff line number Diff line change
@@ -1,106 +1,132 @@
package memory

import (
"errors"
"io"
"net"
"sync/atomic"
"time"

"github.com/libp2p/go-libp2p/core/network"
)

// stream implements network.Stream
type stream struct {
id int32
id int64

r *io.PipeReader
w *io.PipeWriter
writeC chan []byte
write chan byte
read chan byte

readCloseC chan struct{}
writeCloseC chan struct{}
reset chan struct{}
closeRead chan struct{}
closeWrite chan struct{}
closed atomic.Bool
}

var ErrClosed = errors.New("stream closed")

closed atomic.Bool
func newStreamPair() (*stream, *stream) {
ra, rb := make(chan byte, 4096), make(chan byte, 4096)
wa, wb := rb, ra

in := newStream(rb, wb, network.DirInbound)
out := newStream(ra, wa, network.DirOutbound)
return in, out
}

func newStream(id int32, r *io.PipeReader, w *io.PipeWriter) *stream {
func newStream(r, w chan byte, _ network.Direction) *stream {
s := &stream{
id: id,
r: r,
w: w,
writeC: make(chan []byte, 1),
readCloseC: make(chan struct{}, 1),
writeCloseC: make(chan struct{}, 1),
id: streamCounter.Add(1),
read: r,
write: w,
reset: make(chan struct{}, 1),
closeRead: make(chan struct{}, 1),
closeWrite: make(chan struct{}, 1),
}

go func() {
for {
select {
case b := <-s.writeC:
if _, err := w.Write(b); err != nil {
return
}
case <-s.writeCloseC:
return
}
}
}()

return s
}

func (s *stream) Read(b []byte) (int, error) {
return s.r.Read(b)
}

func (s *stream) Write(b []byte) (int, error) {
// How to handle errors with writes?
func (s *stream) Write(p []byte) (n int, err error) {
if s.closed.Load() {
return 0, network.ErrReset
return 0, ErrClosed
}

select {
case <-s.writeCloseC:
return 0, network.ErrReset
case s.writeC <- b:
return len(b), nil
for i := 0; i < len(p); i++ {
select {
case <-s.reset:
err = network.ErrReset
return
case <-s.closeWrite:
err = ErrClosed
return
case s.write <- p[i]:
n = i
default:
err = io.ErrClosedPipe
}
}

return n + 1, err
}

func (s *stream) Reset() error {
if err := s.CloseWrite(); err != nil {
return err
func (s *stream) Read(p []byte) (n int, err error) {
if s.closed.Load() {
return 0, ErrClosed
}
if err := s.CloseRead(); err != nil {
return err

for n = 0; n < len(p); n++ {
select {
case <-s.reset:
err = network.ErrReset
return
case <-s.closeRead:
err = ErrClosed
return
case b, ok := <-s.read:
if !ok {
err = io.EOF
return
}
p[n] = b
default:
err = io.EOF
return
}
}
return nil

return
}

func (s *stream) Close() error {
s.CloseRead()
s.CloseWrite()
func (s *stream) CloseWrite() error {
s.closeWrite <- struct{}{}
return nil
}

func (s *stream) CloseRead() error {
return s.r.CloseWithError(network.ErrReset)
s.closeRead <- struct{}{}
return nil
}

func (s *stream) CloseWrite() error {
select {
case s.writeCloseC <- struct{}{}:
default:
}

func (s *stream) Close() error {
s.closed.Store(true)
return nil
}

func (s *stream) SetDeadline(_ time.Time) error {
func (s *stream) Reset() error {
s.reset <- struct{}{}
return nil
}

func (s *stream) SetReadDeadline(_ time.Time) error {
return nil
func (s *stream) SetDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "pipe", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}
func (s *stream) SetWriteDeadline(_ time.Time) error {
return nil

func (s *stream) SetReadDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "pipe", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}

func (s *stream) SetWriteDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "pipe", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}
12 changes: 2 additions & 10 deletions p2p/transport/memory/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,17 @@ import (
)

func TestStreamSimpleReadWriteClose(t *testing.T) {
//client, server := getDetachedDataChannels(t)
ra, wb := io.Pipe()
rb, wa := io.Pipe()

clientStr := newStream(0, ra, wa)
serverStr := newStream(1, rb, wb)
clientStr, serverStr := newStreamPair()

// send a foobar from the client
n, err := clientStr.Write([]byte("foobar"))
require.NoError(t, err)
require.Equal(t, 6, n)
require.NoError(t, clientStr.CloseWrite())

// writing after closing should error
_, err = clientStr.Write([]byte("foobar"))
require.Error(t, err)
//require.False(t, clientDone.Load())

// now read all the data on the server side
b, err := io.ReadAll(serverStr)
Expand All @@ -33,7 +28,6 @@ func TestStreamSimpleReadWriteClose(t *testing.T) {
n, err = serverStr.Read(make([]byte, 10))
require.Zero(t, n)
require.ErrorIs(t, err, io.EOF)
//require.False(t, serverDone.Load())

// send something back
_, err = serverStr.Write([]byte("lorem ipsum"))
Expand All @@ -49,8 +43,6 @@ func TestStreamSimpleReadWriteClose(t *testing.T) {
// stream is only cleaned up on calling Close or Reset
clientStr.Close()
serverStr.Close()
//require.Eventually(t, func() bool { return clientDone.Load() }, 5*time.Second, 100*time.Millisecond)
// Need to call Close for cleanup. Otherwise the FIN_ACK is never read
require.NoError(t, serverStr.Close())
//require.Eventually(t, func() bool { return serverDone.Load() }, 5*time.Second, 100*time.Millisecond)
}
Loading

0 comments on commit 079bd3e

Please sign in to comment.