diff --git a/client.go b/client.go index 1d036ae..e3fb145 100644 --- a/client.go +++ b/client.go @@ -3,6 +3,7 @@ package amqp import ( "fmt" "context" + "time" "sync" "github.com/streadway/amqp" @@ -12,6 +13,11 @@ import ( "github.com/luraproject/lura/proxy" ) +const ( + retryInterval = 3 * time.Second + maxRetries = 15 +) + func NewBackendFactory(ctx context.Context, logger logging.Logger, bf proxy.BackendFactory) proxy.BackendFactory { f := backendFactory{ logger: logger, @@ -48,8 +54,28 @@ func (f backendFactory) New(remote *config.Backend) proxy.Proxy { return f.bf(remote) } -func (f backendFactory) newChannel(path string) (*amqp.Channel, closer, error) { - conn, err := amqp.Dial(path) +func (f backendFactory) newChannel(ctx context.Context, path string) (*amqp.Channel, closer, error) { + var ( + retries int + conn *amqp.Connection + err error + ) + for { + conn, err = amqp.Dial(path) + if err == nil { + break + } + retries += 1 + f.logger.Error(fmt.Sprintf("AMQP: connection attempt #%d: %s", retries, err.Error())) + if retries > maxRetries { + break + } + select { + case <-time.After(retryInterval): + case <-ctx.Done(): + break + } + } if err != nil { return nil, nopCloser, err } diff --git a/consumer.go b/consumer.go index ae09090..b15fcd9 100644 --- a/consumer.go +++ b/consumer.go @@ -47,7 +47,7 @@ func (f backendFactory) initConsumer(ctx context.Context, remote *config.Backend return consumerBackend(remote, msgs), nil } - ch, close, err := f.newChannel(dns) + ch, close, err := f.newChannel(ctx, dns) if err != nil { f.logger.Error(fmt.Sprintf("AMQP: getting the channel for %s/%s: %s", dns, cfg.Name, err.Error())) return proxy.NoopProxy, err diff --git a/producer.go b/producer.go index 8251861..6784015 100644 --- a/producer.go +++ b/producer.go @@ -54,7 +54,7 @@ func (f backendFactory) initProducer(ctx context.Context, remote *config.Backend return proxy.NoopProxy, err } - ch, close, err := f.newChannel(dns) + ch, close, err := f.newChannel(ctx, dns) if err != nil { f.logger.Error(fmt.Sprintf("AMQP: getting the channel for %s/%s: %s", dns, cfg.Name, err.Error())) return proxy.NoopProxy, err diff --git a/rpc.go b/rpc.go index 0f3dac1..d9efc30 100644 --- a/rpc.go +++ b/rpc.go @@ -64,7 +64,7 @@ func (c *rpcChannel) Load() *amqp.Channel { } func (f backendFactory) initRpcChannel(ctx context.Context, cfg *rpcCfg, dns string, rch *rpcChannel) error { - ch, close, err := f.newChannel(dns) + ch, close, err := f.newChannel(ctx, dns) if err != nil { f.logger.Error(fmt.Sprintf("AMQP: getting the channel for %s/%s: %s", dns, cfg.Name, err.Error())) return err @@ -107,11 +107,8 @@ func (f backendFactory) initRpcChannel(ctx context.Context, cfg *rpcCfg, dns str if !ok { break } - for { - time.Sleep(time.Duration(3) * time.Second) - if err := f.initRpcChannel(ctx, cfg, dns, rch); err == nil { - break - } + if err := f.initRpcChannel(ctx, cfg, dns, rch); err == nil { + break } case reply, ok := <-replies: if !ok {