diff --git a/client_test.go b/client_test.go index e09f52da..9e6b830a 100644 --- a/client_test.go +++ b/client_test.go @@ -9,22 +9,31 @@ import ( "bytes" "io" "reflect" + "sync" "testing" "time" ) type server struct { *testing.T - r reader // framer <- client - w writer // framer -> client - S io.ReadWriteCloser // Server IO - C io.ReadWriteCloser // Client IO + destructor sync.Once + r reader // framer <- client + w writer // framer -> client + S io.ReadWriteCloser // Server IO + C io.ReadWriteCloser // Client IO // captured client frames start connectionStartOk tune connectionTuneOk } +func (srv *server) close() { + srv.destructor.Do(func() { + srv.C.Close() + srv.S.Close() + }) +} + func defaultConfig() Config { return Config{SASL: []Authentication{&PlainAuth{"guest", "guest"}}, Vhost: "/"} } @@ -33,8 +42,8 @@ func newSession(t *testing.T) (io.ReadWriteCloser, *server) { rs, wc := io.Pipe() rc, ws := io.Pipe() - rws := &logIO{t, "server", pipe{rs, ws}} - rwc := &logIO{t, "client", pipe{rc, wc}} + rws := &logIO{t: t, prefix: "server", proxy: &pipe{r: rs, w: ws}} + rwc := &logIO{t: t, prefix: "client", proxy: &pipe{r: rc, w: wc}} server := server{ T: t, @@ -175,13 +184,16 @@ func (t *server) channelOpen(id int) { func TestDefaultClientProperties(t *testing.T) { rwc, srv := newSession(t) + defer srv.close() go func() { + defer srv.close() srv.connectionOpen() - rwc.Close() }() - if c, err := Open(rwc, defaultConfig()); err != nil { + c, err := Open(rwc, defaultConfig()) + defer c.Close() + if err != nil { t.Fatalf("could not create connection: %v (%s)", c, err) } @@ -196,6 +208,7 @@ func TestDefaultClientProperties(t *testing.T) { func TestCustomClientProperties(t *testing.T) { rwc, srv := newSession(t) + defer srv.close() config := defaultConfig() config.Properties = Table{ @@ -204,11 +217,13 @@ func TestCustomClientProperties(t *testing.T) { } go func() { + defer srv.close() srv.connectionOpen() - rwc.Close() }() - if c, err := Open(rwc, config); err != nil { + c, err := Open(rwc, config) + defer c.Close() + if err != nil { t.Fatalf("could not create connection: %v (%s)", c, err) } @@ -223,27 +238,31 @@ func TestCustomClientProperties(t *testing.T) { func TestOpen(t *testing.T) { rwc, srv := newSession(t) + defer srv.close() go func() { + defer srv.close() srv.connectionOpen() - rwc.Close() }() - if c, err := Open(rwc, defaultConfig()); err != nil { + c, err := Open(rwc, defaultConfig()) + defer c.Close() + if err != nil { t.Fatalf("could not create connection: %v (%s)", c, err) } } func TestChannelOpen(t *testing.T) { rwc, srv := newSession(t) + defer srv.close() go func() { + defer srv.close() srv.connectionOpen() srv.channelOpen(1) - - rwc.Close() }() c, err := Open(rwc, defaultConfig()) + defer c.Close() if err != nil { t.Fatalf("could not create connection: %v (%s)", c, err) } @@ -256,8 +275,10 @@ func TestChannelOpen(t *testing.T) { func TestOpenFailedSASLUnsupportedMechanisms(t *testing.T) { rwc, srv := newSession(t) + defer srv.close() go func() { + defer srv.close() srv.expectAMQP() srv.send(0, &connectionStart{ VersionMajor: 0, @@ -268,6 +289,7 @@ func TestOpenFailedSASLUnsupportedMechanisms(t *testing.T) { }() c, err := Open(rwc, defaultConfig()) + defer c.Close() if err != ErrSASL { t.Fatalf("expected ErrSASL got: %+v on %+v", err, c) } @@ -275,15 +297,17 @@ func TestOpenFailedSASLUnsupportedMechanisms(t *testing.T) { func TestOpenFailedCredentials(t *testing.T) { rwc, srv := newSession(t) + defer srv.close() go func() { + // kill/timeout the connection indicating bad auth + defer srv.close() srv.expectAMQP() srv.connectionStart() - // Now kill/timeout the connection indicating bad auth - rwc.Close() }() c, err := Open(rwc, defaultConfig()) + defer c.Close() if err != ErrCredentials { t.Fatalf("expected ErrCredentials got: %+v on %+v", err, c) } @@ -291,18 +315,19 @@ func TestOpenFailedCredentials(t *testing.T) { func TestOpenFailedVhost(t *testing.T) { rwc, srv := newSession(t) + defer srv.close() go func() { + // kill/timeout the connection on bad Vhost + defer srv.close() srv.expectAMQP() srv.connectionStart() srv.connectionTune() srv.recv(0, &connectionOpen{}) - - // Now kill/timeout the connection on bad Vhost - rwc.Close() }() c, err := Open(rwc, defaultConfig()) + defer c.Close() if err != ErrVhost { t.Fatalf("expected ErrVhost got: %+v on %+v", err, c) } @@ -310,9 +335,10 @@ func TestOpenFailedVhost(t *testing.T) { func TestConfirmMultipleOrdersDeliveryTags(t *testing.T) { rwc, srv := newSession(t) - defer rwc.Close() + defer srv.close() go func() { + defer srv.close() srv.connectionOpen() srv.channelOpen(1) @@ -343,6 +369,7 @@ func TestConfirmMultipleOrdersDeliveryTags(t *testing.T) { }() c, err := Open(rwc, defaultConfig()) + defer c.Close() if err != nil { t.Fatalf("could not create connection: %v (%s)", c, err) } @@ -387,8 +414,10 @@ func TestConfirmMultipleOrdersDeliveryTags(t *testing.T) { func TestNotifyClosesReusedPublisherConfirmChan(t *testing.T) { rwc, srv := newSession(t) + defer srv.close() go func() { + defer srv.close() srv.connectionOpen() srv.channelOpen(1) @@ -400,6 +429,7 @@ func TestNotifyClosesReusedPublisherConfirmChan(t *testing.T) { }() c, err := Open(rwc, defaultConfig()) + defer c.Close() if err != nil { t.Fatalf("could not create connection: %v (%s)", c, err) } @@ -423,8 +453,10 @@ func TestNotifyClosesReusedPublisherConfirmChan(t *testing.T) { func TestNotifyClosesAllChansAfterConnectionClose(t *testing.T) { rwc, srv := newSession(t) + defer srv.close() go func() { + defer srv.close() srv.connectionOpen() srv.channelOpen(1) @@ -433,6 +465,7 @@ func TestNotifyClosesAllChansAfterConnectionClose(t *testing.T) { }() c, err := Open(rwc, defaultConfig()) + defer c.Close() if err != nil { t.Fatalf("could not create connection: %v (%s)", c, err) } @@ -488,7 +521,7 @@ func TestNotifyClosesAllChansAfterConnectionClose(t *testing.T) { // Should not panic when sending bodies split at different boundaries func TestPublishBodySliceIssue74(t *testing.T) { rwc, srv := newSession(t) - defer rwc.Close() + defer srv.close() const frameSize = 100 const publishings = frameSize * 3 @@ -497,6 +530,7 @@ func TestPublishBodySliceIssue74(t *testing.T) { base := make([]byte, publishings) go func() { + defer srv.close() srv.connectionOpen() srv.channelOpen(1) @@ -511,6 +545,7 @@ func TestPublishBodySliceIssue74(t *testing.T) { cfg.FrameSize = frameSize c, err := Open(rwc, cfg) + defer c.Close() if err != nil { t.Fatalf("could not create connection: %v (%s)", c, err) } @@ -530,13 +565,14 @@ func TestPublishBodySliceIssue74(t *testing.T) { // Should not panic when server and client have frame_size of 0 func TestPublishZeroFrameSizeIssue161(t *testing.T) { rwc, srv := newSession(t) - defer rwc.Close() + defer srv.close() const frameSize = 0 const publishings = 1 done := make(chan bool) go func() { + defer srv.close() srv.connectionOpen() srv.channelOpen(1) @@ -551,6 +587,7 @@ func TestPublishZeroFrameSizeIssue161(t *testing.T) { cfg.FrameSize = frameSize c, err := Open(rwc, cfg) + defer c.Close() // override the tuned framesize with a hard 0, as would happen when rabbit is configured with 0 c.Config.FrameSize = frameSize @@ -573,7 +610,7 @@ func TestPublishZeroFrameSizeIssue161(t *testing.T) { func TestPublishAndShutdownDeadlockIssue84(t *testing.T) { rwc, srv := newSession(t) - defer rwc.Close() + defer srv.close() go func() { srv.connectionOpen() @@ -584,6 +621,7 @@ func TestPublishAndShutdownDeadlockIssue84(t *testing.T) { }() c, err := Open(rwc, defaultConfig()) + defer c.Close() if err != nil { t.Fatalf("couldn't create connection: %v (%s)", c, err) } @@ -604,18 +642,20 @@ func TestPublishAndShutdownDeadlockIssue84(t *testing.T) { func TestChannelCloseRace(t *testing.T) { rwc, srv := newSession(t) + defer srv.close() done := make(chan bool) go func() { + defer srv.close() srv.connectionOpen() srv.channelOpen(1) - rwc.Close() done <- true }() c, err := Open(rwc, defaultConfig()) + defer c.Close() if err != nil { t.Fatalf("could not create connection: %v (%s)", c, err) } diff --git a/connection.go b/connection.go index e897b087..78c86b40 100644 --- a/connection.go +++ b/connection.go @@ -76,6 +76,7 @@ type Connection struct { conn io.ReadWriteCloser rpc chan message + readDone chan struct{} writer *writer sends chan time.Time // timestamps of each frame sent deadlines chan readDeadliner // heartbeater updates read deadlines @@ -213,6 +214,7 @@ to use your own custom transport. func Open(conn io.ReadWriteCloser, config Config) (*Connection, error) { me := &Connection{ conn: conn, + readDone: make(chan struct{}), writer: &writer{bufio.NewWriter(conn)}, channels: make(map[uint16]*Channel), rpc: make(chan message), @@ -220,7 +222,7 @@ func Open(conn io.ReadWriteCloser, config Config) (*Connection, error) { errors: make(chan *Error, 1), deadlines: make(chan readDeadliner, 1), } - go me.reader(conn) + go me.reader(conn, me.NotifyClose(make(chan *Error, 1))) return me, me.open(config) } @@ -296,6 +298,9 @@ including the underlying io, Channels, Notify listeners and Channel consumers will also be closed. */ func (me *Connection) Close() error { + defer func() { + <-me.readDone + }() defer me.shutdown(nil) return me.call( &connectionClose{ @@ -353,49 +358,55 @@ func (me *Connection) shutdown(err *Error) { atomic.StoreInt32(&me.closed, 1) me.destructor.Do(func() { - me.m.Lock() - defer me.m.Unlock() + me.destruct(err) + }) +} - if err != nil { - for _, c := range me.closes { - c <- err - } - } +func (me *Connection) destruct(err *Error) { + me.m.Lock() + defer me.m.Unlock() - for _, ch := range me.channels { - ch.shutdown(err) - me.releaseChannel(ch.id) + if err != nil { + for _, c := range me.closes { + c <- err } + } - if err != nil { - me.errors <- err - } + for _, ch := range me.channels { + ch.shutdown(err) + me.releaseChannel(ch.id) + } + + if err != nil { + me.errors <- err + } - me.conn.Close() + me.conn.Close() - for _, c := range me.closes { - close(c) - } + for _, c := range me.closes { + close(c) + } + me.closes = nil - for _, c := range me.blocks { - close(c) - } + for _, c := range me.blocks { + close(c) + } + me.blocks = nil - me.noNotify = true - }) + me.noNotify = true } // All methods sent to the connection channel should be synchronous so we // can handle them directly without a framing component -func (me *Connection) demux(f frame) { +func (me *Connection) demux(f frame, done chan *Error) { if f.channel() == 0 { - me.dispatch0(f) + me.dispatch0(f, done) } else { me.dispatchN(f) } } -func (me *Connection) dispatch0(f frame) { +func (me *Connection) dispatch0(f frame, done chan *Error) { switch mf := f.(type) { case *methodFrame: switch m := mf.Method.(type) { @@ -408,15 +419,18 @@ func (me *Connection) dispatch0(f frame) { me.shutdown(newError(m.ReplyCode, m.ReplyText)) case *connectionBlocked: - for _, c := range me.blocks { + for _, c := range me.blocksCopy() { c <- Blocking{Active: true, Reason: m.Reason} } case *connectionUnblocked: - for _, c := range me.blocks { + for _, c := range me.blocksCopy() { c <- Blocking{Active: false} } default: - me.rpc <- m + select { + case me.rpc <- m: + case <-done: + } } case *heartbeatFrame: // kthx - all reads reset our deadline. so we can drop this @@ -467,10 +481,18 @@ func (me *Connection) dispatchClosed(f frame) { } } +func (me *Connection) blocksCopy() []chan Blocking { + me.m.Lock() + blocks := append(([]chan Blocking)(nil), me.blocks...) + me.m.Unlock() + return blocks +} + // Reads each frame off the IO and hand off to the connection object that // will demux the streams and dispatch to one of the opened channels or // handle on channel 0 (the connection channel). -func (me *Connection) reader(r io.Reader) { +func (me *Connection) reader(r io.Reader, done chan *Error) { + defer close(me.readDone) buf := bufio.NewReader(r) frames := &reader{buf} conn, haveDeadliner := r.(readDeadliner) @@ -483,10 +505,13 @@ func (me *Connection) reader(r io.Reader) { return } - me.demux(frame) + me.demux(frame, done) if haveDeadliner { - me.deadlines <- conn + select { + case me.deadlines <- conn: + case <-done: + } } } } diff --git a/integration_test.go b/integration_test.go index 25e6b922..50892c9b 100644 --- a/integration_test.go +++ b/integration_test.go @@ -1708,7 +1708,7 @@ func integrationURLFromEnv() string { func loggedConnection(t *testing.T, conn *Connection, name string) *Connection { if name != "" { - conn.conn = &logIO{t, name, conn.conn} + conn.conn = &logIO{t: t, prefix: name, proxy: conn.conn} } return conn } diff --git a/shared_test.go b/shared_test.go index 0b089516..768a6bce 100644 --- a/shared_test.go +++ b/shared_test.go @@ -16,15 +16,15 @@ type pipe struct { w *io.PipeWriter } -func (p pipe) Read(b []byte) (int, error) { +func (p *pipe) Read(b []byte) (int, error) { return p.r.Read(b) } -func (p pipe) Write(b []byte) (int, error) { +func (p *pipe) Write(b []byte) (int, error) { return p.w.Write(b) } -func (p pipe) Close() error { +func (p *pipe) Close() error { p.r.Close() p.w.Close() return nil