Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support PublishWithContext #500

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 75 additions & 31 deletions channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package amqp

import (
"context"
"reflect"
"sync"
"sync/atomic"
Expand All @@ -26,8 +27,8 @@ should be discarded and a new channel established.
*/
type Channel struct {
destructor sync.Once
m sync.Mutex // struct field mutex
confirmM sync.Mutex // publisher confirms state mutex
sema chan struct{} // struct field mutex
confirmM sync.Mutex // publisher confirms state mutex
notifyM sync.RWMutex

connection *Connection
Expand Down Expand Up @@ -84,15 +85,16 @@ func newChannel(c *Connection, id uint16) *Channel {
confirms: newConfirms(),
recv: (*Channel).recvMethod,
errors: make(chan *Error, 1),
sema: make(chan struct{}, 1),
}
}

// shutdown is called by Connection after the channel has been removed from the
// connection registry.
func (ch *Channel) shutdown(e *Error) {
ch.destructor.Do(func() {
ch.m.Lock()
defer ch.m.Unlock()
ch.sema <- struct{}{}
defer func() { <-ch.sema }()

// Grab an exclusive lock for the notify channels
ch.notifyM.Lock()
Expand Down Expand Up @@ -152,13 +154,13 @@ func (ch *Channel) shutdown(e *Error) {
//
// After the channel has been closed, send calls Channel.sendClosed(), ensuring
// only 'channel.close' is sent to the server.
func (ch *Channel) send(msg message) (err error) {
func (ch *Channel) send(ctx context.Context, msg message) (err error) {
// If the channel is closed, use Channel.sendClosed()
if atomic.LoadInt32(&ch.closed) == 1 {
return ch.sendClosed(msg)
return ch.sendClosed(ctx, msg)
}

return ch.sendOpen(msg)
return ch.sendOpen(ctx, msg)
}

func (ch *Channel) open() error {
Expand All @@ -168,7 +170,7 @@ func (ch *Channel) open() error {
// Performs a request/response call for when the message is not NoWait and is
// specified as Synchronous.
func (ch *Channel) call(req message, res ...message) error {
if err := ch.send(req); err != nil {
if err := ch.send(context.Background(), req); err != nil {
return err
}

Expand Down Expand Up @@ -203,11 +205,11 @@ func (ch *Channel) call(req message, res ...message) error {
return nil
}

func (ch *Channel) sendClosed(msg message) (err error) {
func (ch *Channel) sendClosed(ctx context.Context, msg message) (err error) {
// After a 'channel.close' is sent or received the only valid response is
// channel.close-ok
if _, ok := msg.(*channelCloseOk); ok {
return ch.connection.send(&methodFrame{
return ch.connection.send(ctx, &methodFrame{
ChannelId: ch.id,
Method: msg,
})
Expand All @@ -216,7 +218,7 @@ func (ch *Channel) sendClosed(msg message) (err error) {
return ErrClosed
}

func (ch *Channel) sendOpen(msg message) (err error) {
func (ch *Channel) sendOpen(ctx context.Context, msg message) (err error) {
if content, ok := msg.(messageWithContent); ok {
props, body := content.getContent()
class, _ := content.id()
Expand All @@ -230,14 +232,14 @@ func (ch *Channel) sendOpen(msg message) (err error) {
size = len(body)
}

if err = ch.connection.send(&methodFrame{
if err = ch.connection.send(ctx, &methodFrame{
ChannelId: ch.id,
Method: content,
}); err != nil {
return
}

if err = ch.connection.send(&headerFrame{
if err = ch.connection.send(ctx, &headerFrame{
ChannelId: ch.id,
ClassId: class,
Size: uint64(len(body)),
Expand All @@ -252,15 +254,15 @@ func (ch *Channel) sendOpen(msg message) (err error) {
j = len(body)
}

if err = ch.connection.send(&bodyFrame{
if err = ch.connection.send(ctx, &bodyFrame{
ChannelId: ch.id,
Body: body[i:j],
}); err != nil {
return
}
}
} else {
err = ch.connection.send(&methodFrame{
err = ch.connection.send(ctx, &methodFrame{
ChannelId: ch.id,
Method: msg,
})
Expand All @@ -277,9 +279,9 @@ func (ch *Channel) dispatch(msg message) {
// lock before sending connection.close-ok
// to avoid unexpected interleaving with basic.publish frames if
// publishing is happening concurrently
ch.m.Lock()
ch.send(&channelCloseOk{})
ch.m.Unlock()
ch.sema <- struct{}{}
ch.send(context.Background(), &channelCloseOk{})
<-ch.sema
ch.connection.closeChannel(ch, newError(m.ReplyCode, m.ReplyText))

case *channelFlow:
Expand All @@ -288,7 +290,7 @@ func (ch *Channel) dispatch(msg message) {
c <- m.Active
}
ch.notifyM.RUnlock()
ch.send(&channelFlowOk{Active: m.Active})
ch.send(context.Background(), &channelFlowOk{Active: m.Active})

case *basicCancel:
ch.notifyM.RLock()
Expand Down Expand Up @@ -1324,14 +1326,56 @@ internal counter for DeliveryTags with the first confirmation starts at 1.

*/
func (ch *Channel) Publish(exchange, key string, mandatory, immediate bool, msg Publishing) error {
return ch.PublishWithContext(context.Background(), exchange, key, mandatory, immediate, msg)
}

/*
PublishWithContext sends a Publishing from the client to an exchange on the server.
This method uses the context for closing the Publishing.

When you want a single message to be delivered to a single queue, you can
publish to the default exchange with the routingKey of the queue name. This is
because every declared queue gets an implicit route to the default exchange.

Since publishings are asynchronous, any undeliverable message will get returned
by the server. Add a listener with Channel.NotifyReturn to handle any
undeliverable message when calling publish with either the mandatory or
immediate parameters as true.

Publishings can be undeliverable when the mandatory flag is true and no queue is
bound that matches the routing key, or when the immediate flag is true and no
consumer on the matched queue is ready to accept the delivery.

This can return an error when the channel, connection or socket is closed. The
error or lack of an error does not indicate whether the server has received this
publishing.

It is possible for publishing to not reach the broker if the underlying socket
is shut down without pending publishing packets being flushed from the kernel
buffers. The easy way of making it probable that all publishings reach the
server is to always call Connection.Close before terminating your publishing
application. The way to ensure that all publishings reach the server is to add
a listener to Channel.NotifyPublish and put the channel in confirm mode with
Channel.Confirm. Publishing delivery tags and their corresponding
confirmations start at 1. Exit when all publishings are confirmed.

When Publish does not return an error and the channel is in confirm mode, the
internal counter for DeliveryTags with the first confirmation starts at 1.

*/
func (ch *Channel) PublishWithContext(ctx context.Context, exchange, key string, mandatory, immediate bool, msg Publishing) error {
if err := msg.Headers.Validate(); err != nil {
return err
}

ch.m.Lock()
defer ch.m.Unlock()
select {
case ch.sema <- struct{}{}:
case <-ctx.Done():
return ctx.Err()
}
defer func() { <-ch.sema }()

if err := ch.send(&basicPublish{
if err := ch.send(ctx, &basicPublish{
Exchange: exchange,
RoutingKey: key,
Mandatory: mandatory,
Expand Down Expand Up @@ -1548,10 +1592,10 @@ is true.
See also Delivery.Ack
*/
func (ch *Channel) Ack(tag uint64, multiple bool) error {
ch.m.Lock()
defer ch.m.Unlock()
ch.sema <- struct{}{}
defer func() { <-ch.sema }()

return ch.send(&basicAck{
return ch.send(context.Background(), &basicAck{
DeliveryTag: tag,
Multiple: multiple,
})
Expand All @@ -1565,10 +1609,10 @@ it must be redelivered or dropped.
See also Delivery.Nack
*/
func (ch *Channel) Nack(tag uint64, multiple bool, requeue bool) error {
ch.m.Lock()
defer ch.m.Unlock()
ch.sema <- struct{}{}
defer func() { <-ch.sema }()

return ch.send(&basicNack{
return ch.send(context.Background(), &basicNack{
DeliveryTag: tag,
Multiple: multiple,
Requeue: requeue,
Expand All @@ -1583,10 +1627,10 @@ multiple messages, reducing the amount of protocol messages to exchange.
See also Delivery.Reject
*/
func (ch *Channel) Reject(tag uint64, requeue bool) error {
ch.m.Lock()
defer ch.m.Unlock()
ch.sema <- struct{}{}
defer func() { <-ch.sema }()

return ch.send(&basicReject{
return ch.send(context.Background(), &basicReject{
DeliveryTag: tag,
Requeue: requeue,
})
Expand Down
45 changes: 45 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package amqp

import (
"bytes"
"context"
"io"
"reflect"
"testing"
Expand Down Expand Up @@ -714,3 +715,47 @@ func TestLeakClosedConsumersIssue264(t *testing.T) {
t.Fatalf("expected deliveries channel to be closed immediately when the connection is closed so not to leak the bufferDeliveries goroutine")
}
}

func TestPublishWithContext(t *testing.T) {
rwc, srv := newSession(t)
defer rwc.Close()

done := make(chan bool)

go func() {
defer close(done)
srv.connectionOpen()
srv.channelOpen(1)
srv.recv(1, &basicPublish{})
}()

cfg := defaultConfig()

c, err := Open(rwc, cfg)
if err != nil {
t.Fatalf("could not create connection: %v (%s)", c, err)
}

ch, err := c.Channel()
if err != nil {
t.Fatalf("could not open channel: %v (%s)", ch, err)
}

canclledCtx, cancel := context.WithCancel(context.Background())
cancel()
err = ch.PublishWithContext(canclledCtx, "", "q", false, false, Publishing{Body: []byte("anything")})
if err != canclledCtx.Err() {
t.Fatalf("unexpected error during publish with closed context: %v", err)
}

err = ch.PublishWithContext(context.Background(), "", "q", false, false, Publishing{Body: []byte("anything")})
if err != nil {
t.Fatalf("unexpected error during publish with valid context: %v", err)
}

select {
case <-time.After(5 * time.Second):
t.Fatal("timeout")
case <-done:
}
}
Loading