From bc44d432f9f6b14ef0807fd3bfd1c8bbb0a53a5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A1rio=20Freitas?= Date: Mon, 28 Mar 2016 14:38:34 +0900 Subject: [PATCH] fix data race in connection shutdown --- .travis.yml | 2 +- channel.go | 8 +++++--- client_test.go | 27 +++++++++++++++++++++++++++ connection.go | 13 +++++++------ 4 files changed, 40 insertions(+), 10 deletions(-) diff --git a/.travis.yml b/.travis.yml index f1c275a2..aff0400a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -11,4 +11,4 @@ services: env: - AMQP_URL=amqp://guest:guest@127.0.0.1:5672/ GOMAXPROCS=2 -script: go test -v -tags integration ./... +script: go test -v -race -tags integration ./... diff --git a/channel.go b/channel.go index 7ac6ec98..fe6433e1 100644 --- a/channel.go +++ b/channel.go @@ -141,9 +141,12 @@ func (me *Channel) open() error { // Performs a request/response call for when the message is not NoWait and is // specified as Synchronous. func (me *Channel) call(req message, res ...message) error { + me.m.Lock() if err := me.send(me, req); err != nil { + me.m.Unlock() return err } + me.m.Unlock() if req.wait() { select { @@ -1476,9 +1479,6 @@ exception could occur if the server does not support this method. */ func (me *Channel) Confirm(noWait bool) error { - me.m.Lock() - defer me.m.Unlock() - if err := me.call( &confirmSelect{Nowait: noWait}, &confirmSelectOk{}, @@ -1486,7 +1486,9 @@ func (me *Channel) Confirm(noWait bool) error { return err } + me.m.Lock() me.confirming = true + me.m.Unlock() return nil } diff --git a/client_test.go b/client_test.go index 23acc974..f70a2c4d 100644 --- a/client_test.go +++ b/client_test.go @@ -601,3 +601,30 @@ func TestPublishAndShutdownDeadlockIssue84(t *testing.T) { } } } + +func TestChannelCloseRace(t *testing.T) { + rwc, srv := newSession(t) + + done := make(chan bool) + + go func() { + srv.connectionOpen() + srv.channelOpen(1) + + rwc.Close() + done <- true + }() + + c, err := Open(rwc, defaultConfig()) + if err != nil { + t.Fatalf("could not create connection: %v (%s)", c, err) + } + + ch, err := c.Channel() + if err != nil { + t.Fatalf("could not open channel: %v (%s)", ch, err) + } + <-done + ch.Close() + c.Close() +} diff --git a/connection.go b/connection.go index ad400797..e66229b3 100644 --- a/connection.go +++ b/connection.go @@ -340,6 +340,9 @@ func (me *Connection) send(f frame) error { func (me *Connection) shutdown(err *Error) { me.destructor.Do(func() { + me.m.Lock() + defer me.m.Unlock() + if err != nil { for _, c := range me.closes { c <- err @@ -347,7 +350,8 @@ func (me *Connection) shutdown(err *Error) { } for _, ch := range me.channels { - me.closeChannel(ch, err) + ch.shutdown(err) + me.releaseChannel(ch.id) } if err != nil { @@ -364,9 +368,7 @@ func (me *Connection) shutdown(err *Error) { close(c) } - me.m.Lock() me.noNotify = true - me.m.Unlock() }) } @@ -553,9 +555,6 @@ func (me *Connection) allocateChannel() (*Channel, error) { // releaseChannel removes a channel from the registry as the final part of the // channel lifecycle func (me *Connection) releaseChannel(id uint16) { - me.m.Lock() - defer me.m.Unlock() - delete(me.channels, id) me.allocator.release(int(id)) } @@ -578,6 +577,8 @@ func (me *Connection) openChannel() (*Channel, error) { // this connection. func (me *Connection) closeChannel(ch *Channel, e *Error) { ch.shutdown(e) + me.m.Lock() + defer me.m.Unlock() me.releaseChannel(ch.id) }