Skip to content

Commit

Permalink
Bug: Fixing a deadlock that would occur if the connection was closed …
Browse files Browse the repository at this point in the history
…before its streams (#178)

Signed-off-by: Shivansh Vij <shivanshvij@loopholelabs.io>
  • Loading branch information
ShivanshVij authored Sep 6, 2024
1 parent 3ba46ee commit a83a18c
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 4 deletions.
2 changes: 1 addition & 1 deletion async.go
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ func (c *Async) close() error {
c.stale = c.incoming.Drain()
c.staleMu.Unlock()
for _, stream := range c.streams {
_ = stream.Close()
_ = stream.closeSend(false)
}
c.streamsMu.Unlock()
c.Lock()
Expand Down
12 changes: 9 additions & 3 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ func (s *Stream) Conn() *Async {

// Close will close the stream and prevent any further reads or writes.
func (s *Stream) Close() error {
return s.closeSend(true)
}

func (s *Stream) closeSend(lock bool) error {
s.staleMu.Lock()
if s.closed.CompareAndSwap(false, true) {
s.queue.Close()
Expand All @@ -106,9 +110,11 @@ func (s *Stream) Close() error {
err := s.conn.writePacket(p, true)
packet.Put(p)

s.conn.streamsMu.Lock()
delete(s.conn.streams, s.id)
s.conn.streamsMu.Unlock()
if lock {
s.conn.streamsMu.Lock()
delete(s.conn.streams, s.id)
s.conn.streamsMu.Unlock()
}

return err
}
Expand Down
53 changes: 53 additions & 0 deletions stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,3 +213,56 @@ func TestNewStreamDualCreate(t *testing.T) {
err = readerConn.Close()
assert.NoError(t, err)
}

func TestStreamConnClose(t *testing.T) {
t.Parallel()

const packetSize = 512

emptyLogger := logging.Test(t, logging.Noop, t.Name())

reader, writer := net.Pipe()

readerConn := NewAsync(reader, emptyLogger, func(_ *Stream) {})
writerConn := NewAsync(writer, emptyLogger, func(_ *Stream) {})

writerStream := writerConn.NewStream(0)
readerStream := readerConn.NewStream(0)

data := make([]byte, packetSize)
_, err := rand.Read(data)
require.NoError(t, err)

p := packet.Get()
p.Metadata.Id = 64
p.Metadata.Operation = 32
p.Metadata.ContentLength = uint32(packetSize)
p.Content.Write(data)

err = writerStream.WritePacket(p)
require.NoError(t, err)

packet.Put(p)

err = writerConn.Close()
require.NoError(t, err)

time.Sleep(DefaultDeadline)

err = writerStream.Close()
require.ErrorIs(t, err, StreamClosed)

p, err = readerStream.ReadPacket()
require.NoError(t, err)
require.NotNil(t, p.Metadata)
assert.Equal(t, readerStream.ID(), p.Metadata.Id)
assert.Equal(t, STREAM, p.Metadata.Operation)
assert.Equal(t, uint32(packetSize), p.Metadata.ContentLength)
assert.Equal(t, data, p.Content.Bytes())

_, err = readerStream.ReadPacket()
require.ErrorIs(t, err, StreamClosed)

err = readerConn.Close()
assert.NoError(t, err)
}

0 comments on commit a83a18c

Please sign in to comment.