diff --git a/client_test.go b/client_test.go index e09f52da..ab48f1e9 100644 --- a/client_test.go +++ b/client_test.go @@ -33,8 +33,8 @@ func newSession(t *testing.T) (io.ReadWriteCloser, *server) { rs, wc := io.Pipe() rc, ws := io.Pipe() - rws := &logIO{t, "server", pipe{rs, ws}} - rwc := &logIO{t, "client", pipe{rc, wc}} + rws := &logIO{t, "server", &pipe{r: rs, w: ws}} + rwc := &logIO{t, "client", &pipe{r: rc, w: wc}} server := server{ T: t, diff --git a/shared_test.go b/shared_test.go index 0b089516..e3e71ab8 100644 --- a/shared_test.go +++ b/shared_test.go @@ -8,23 +8,38 @@ package amqp import ( "encoding/hex" "io" + "sync" "testing" ) type pipe struct { r *io.PipeReader w *io.PipeWriter + + m sync.Mutex + closed bool } -func (p pipe) Read(b []byte) (int, error) { +func (p *pipe) Read(b []byte) (int, error) { return p.r.Read(b) } -func (p pipe) Write(b []byte) (int, error) { +func (p *pipe) Write(b []byte) (int, error) { + p.m.Lock() + defer p.m.Unlock() + if p.closed { + return 0, io.ErrClosedPipe + } return p.w.Write(b) } -func (p pipe) Close() error { +func (p *pipe) Close() error { + p.m.Lock() + defer p.m.Unlock() + if p.closed { + return nil + } + p.closed = true p.r.Close() p.w.Close() return nil