diff --git a/core/capabilities/compute/compute_test.go b/core/capabilities/compute/compute_test.go index c4146b7408e..3e5f501fa61 100644 --- a/core/capabilities/compute/compute_test.go +++ b/core/capabilities/compute/compute_test.go @@ -14,6 +14,7 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/capabilities" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/wasmtest" "github.com/smartcontractkit/chainlink/v2/core/logger" + "github.com/smartcontractkit/chainlink/v2/core/utils/matches" cappkg "github.com/smartcontractkit/chainlink-common/pkg/capabilities" "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" @@ -188,6 +189,7 @@ func TestComputeFetch(t *testing.T) { th := setup(t, defaultConfig) th.connector.EXPECT().DonID().Return("don-id") + th.connector.EXPECT().AwaitConnection(matches.AnyContext, "gateway1").Return(nil) th.connector.EXPECT().GatewayIDs().Return([]string{"gateway1", "gateway2"}) msgID := strings.Join([]string{ diff --git a/core/capabilities/webapi/outgoing_connector_handler.go b/core/capabilities/webapi/outgoing_connector_handler.go index 5ea497cd87d..a9ff9ee3aae 100644 --- a/core/capabilities/webapi/outgoing_connector_handler.go +++ b/core/capabilities/webapi/outgoing_connector_handler.go @@ -96,8 +96,15 @@ func (c *OutgoingConnectorHandler) HandleSingleNodeRequest(ctx context.Context, } sort.Strings(gatewayIDs) - err = c.gc.SignAndSendToGateway(ctx, gatewayIDs[0], body) - if err != nil { + selectedGateway := gatewayIDs[0] + + l.Infow("selected gateway, awaiting connection", "gatewayID", selectedGateway) + + if err := c.gc.AwaitConnection(ctx, selectedGateway); err != nil { + return nil, errors.Wrap(err, "await connection canceled") + } + + if err := c.gc.SignAndSendToGateway(ctx, selectedGateway, body); err != nil { return nil, errors.Wrap(err, "failed to send request to gateway") } diff --git a/core/capabilities/webapi/outgoing_connector_handler_test.go b/core/capabilities/webapi/outgoing_connector_handler_test.go index 2090edc6aea..4a8c425d4f1 100644 --- a/core/capabilities/webapi/outgoing_connector_handler_test.go +++ b/core/capabilities/webapi/outgoing_connector_handler_test.go @@ -10,6 +10,7 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" "github.com/smartcontractkit/chainlink/v2/core/logger" + "github.com/smartcontractkit/chainlink/v2/core/utils/matches" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/api" gcmocks "github.com/smartcontractkit/chainlink/v2/core/services/gateway/connector/mocks" @@ -36,6 +37,7 @@ func TestHandleSingleNodeRequest(t *testing.T) { msgID := "msgID" testURL := "http://localhost:8080" connector.EXPECT().DonID().Return("donID") + connector.EXPECT().AwaitConnection(matches.AnyContext, "gateway1").Return(nil) connector.EXPECT().GatewayIDs().Return([]string{"gateway1"}) // build the expected body with the default timeout @@ -82,6 +84,7 @@ func TestHandleSingleNodeRequest(t *testing.T) { msgID := "msgID" testURL := "http://localhost:8080" connector.EXPECT().DonID().Return("donID") + connector.EXPECT().AwaitConnection(matches.AnyContext, "gateway1").Return(nil) connector.EXPECT().GatewayIDs().Return([]string{"gateway1"}) // build the expected body with the defined timeout diff --git a/core/capabilities/webapi/target/target_test.go b/core/capabilities/webapi/target/target_test.go index f51cdcd0d70..1af9a107054 100644 --- a/core/capabilities/webapi/target/target_test.go +++ b/core/capabilities/webapi/target/target_test.go @@ -194,7 +194,7 @@ func TestCapability_Execute(t *testing.T) { require.NoError(t, err) gatewayResp := gatewayResponse(t, msgID) - + th.connector.EXPECT().AwaitConnection(mock.Anything, "gateway1").Return(nil) th.connector.On("SignAndSendToGateway", mock.Anything, "gateway1", mock.Anything).Return(nil).Run(func(args mock.Arguments) { th.connectorHandler.HandleGatewayMessage(ctx, "gateway1", gatewayResp) }).Once() diff --git a/core/services/gateway/connector/connector.go b/core/services/gateway/connector/connector.go index a8d356478e9..cab123d4ce5 100644 --- a/core/services/gateway/connector/connector.go +++ b/core/services/gateway/connector/connector.go @@ -28,13 +28,14 @@ type GatewayConnector interface { AddHandler(methods []string, handler GatewayConnectorHandler) error // SendToGateway takes a signed message as argument and sends it to the specified gateway - SendToGateway(ctx context.Context, gatewayId string, msg *api.Message) error + SendToGateway(ctx context.Context, gatewayID string, msg *api.Message) error // SignAndSendToGateway signs the message and sends the message to the specified gateway SignAndSendToGateway(ctx context.Context, gatewayID string, msg *api.MessageBody) error // GatewayIDs returns the list of Gateway IDs GatewayIDs() []string // DonID returns the DON ID DonID() string + AwaitConnection(ctx context.Context, gatewayID string) error } // Signer implementation needs to be provided by a GatewayConnector user (node) @@ -78,12 +79,30 @@ func (c *gatewayConnector) HealthReport() map[string]error { func (c *gatewayConnector) Name() string { return c.lggr.Name() } type gatewayState struct { + // signal channel is closed once the gateway is connected + signalCh chan struct{} + conn network.WSConnectionWrapper config ConnectorGatewayConfig url *url.URL wsClient network.WebSocketClient } +// A gatewayState is connected when the signal channel is closed +func (gs *gatewayState) signal() { + close(gs.signalCh) +} + +// awaitConn blocks until the gateway is connected or the context is done +func (gs *gatewayState) awaitConn(ctx context.Context) error { + select { + case <-ctx.Done(): + return fmt.Errorf("await connection failed: %w", ctx.Err()) + case <-gs.signalCh: + return nil + } +} + func NewGatewayConnector(config *ConnectorConfig, signer Signer, clock clockwork.Clock, lggr logger.Logger) (GatewayConnector, error) { if config == nil || signer == nil || clock == nil || lggr == nil { return nil, errors.New("nil dependency") @@ -125,6 +144,7 @@ func NewGatewayConnector(config *ConnectorConfig, signer Signer, clock clockwork config: gw, url: parsedURL, wsClient: network.NewWebSocketClient(config.WsClientConfig, connector, lggr), + signalCh: make(chan struct{}), } gateways[gw.Id] = gateway urlToId[gw.URL] = gw.Id @@ -150,17 +170,25 @@ func (c *gatewayConnector) AddHandler(methods []string, handler GatewayConnector return nil } -func (c *gatewayConnector) SendToGateway(ctx context.Context, gatewayId string, msg *api.Message) error { +func (c *gatewayConnector) AwaitConnection(ctx context.Context, gatewayID string) error { + gateway, ok := c.gateways[gatewayID] + if !ok { + return fmt.Errorf("invalid Gateway ID %s", gatewayID) + } + return gateway.awaitConn(ctx) +} + +func (c *gatewayConnector) SendToGateway(ctx context.Context, gatewayID string, msg *api.Message) error { data, err := c.codec.EncodeResponse(msg) if err != nil { - return fmt.Errorf("error encoding response for gateway %s: %v", gatewayId, err) + return fmt.Errorf("error encoding response for gateway %s: %w", gatewayID, err) } - gateway, ok := c.gateways[gatewayId] + gateway, ok := c.gateways[gatewayID] if !ok { - return fmt.Errorf("invalid Gateway ID %s", gatewayId) + return fmt.Errorf("invalid Gateway ID %s", gatewayID) } if gateway.conn == nil { - return fmt.Errorf("connector not started") + return errors.New("connector not started") } return gateway.conn.Write(ctx, websocket.BinaryMessage, data) } @@ -242,10 +270,15 @@ func (c *gatewayConnector) reconnectLoop(gatewayState *gatewayState) { } else { c.lggr.Infow("connected successfully", "url", gatewayState.url) closeCh := gatewayState.conn.Reset(conn) + gatewayState.signal() <-closeCh c.lggr.Infow("connection closed", "url", gatewayState.url) + // reset backoff redialBackoff = utils.NewRedialBackoff() + + // reset signal channel + gatewayState.signalCh = make(chan struct{}) } select { case <-c.shutdownCh: diff --git a/core/services/gateway/connector/mocks/gateway_connector.go b/core/services/gateway/connector/mocks/gateway_connector.go index 183fc949cd5..ba5c2213b5f 100644 --- a/core/services/gateway/connector/mocks/gateway_connector.go +++ b/core/services/gateway/connector/mocks/gateway_connector.go @@ -73,6 +73,53 @@ func (_c *GatewayConnector_AddHandler_Call) RunAndReturn(run func([]string, conn return _c } +// AwaitConnection provides a mock function with given fields: ctx, gatewayID +func (_m *GatewayConnector) AwaitConnection(ctx context.Context, gatewayID string) error { + ret := _m.Called(ctx, gatewayID) + + if len(ret) == 0 { + panic("no return value specified for AwaitConnection") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, gatewayID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// GatewayConnector_AwaitConnection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AwaitConnection' +type GatewayConnector_AwaitConnection_Call struct { + *mock.Call +} + +// AwaitConnection is a helper method to define mock.On call +// - ctx context.Context +// - gatewayID string +func (_e *GatewayConnector_Expecter) AwaitConnection(ctx interface{}, gatewayID interface{}) *GatewayConnector_AwaitConnection_Call { + return &GatewayConnector_AwaitConnection_Call{Call: _e.mock.On("AwaitConnection", ctx, gatewayID)} +} + +func (_c *GatewayConnector_AwaitConnection_Call) Run(run func(ctx context.Context, gatewayID string)) *GatewayConnector_AwaitConnection_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *GatewayConnector_AwaitConnection_Call) Return(_a0 error) *GatewayConnector_AwaitConnection_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *GatewayConnector_AwaitConnection_Call) RunAndReturn(run func(context.Context, string) error) *GatewayConnector_AwaitConnection_Call { + _c.Call.Return(run) + return _c +} + // ChallengeResponse provides a mock function with given fields: _a0, challenge func (_m *GatewayConnector) ChallengeResponse(_a0 *url.URL, challenge []byte) ([]byte, error) { ret := _m.Called(_a0, challenge) @@ -464,9 +511,9 @@ func (_c *GatewayConnector_Ready_Call) RunAndReturn(run func() error) *GatewayCo return _c } -// SendToGateway provides a mock function with given fields: ctx, gatewayId, msg -func (_m *GatewayConnector) SendToGateway(ctx context.Context, gatewayId string, msg *api.Message) error { - ret := _m.Called(ctx, gatewayId, msg) +// SendToGateway provides a mock function with given fields: ctx, gatewayID, msg +func (_m *GatewayConnector) SendToGateway(ctx context.Context, gatewayID string, msg *api.Message) error { + ret := _m.Called(ctx, gatewayID, msg) if len(ret) == 0 { panic("no return value specified for SendToGateway") @@ -474,7 +521,7 @@ func (_m *GatewayConnector) SendToGateway(ctx context.Context, gatewayId string, var r0 error if rf, ok := ret.Get(0).(func(context.Context, string, *api.Message) error); ok { - r0 = rf(ctx, gatewayId, msg) + r0 = rf(ctx, gatewayID, msg) } else { r0 = ret.Error(0) } @@ -489,13 +536,13 @@ type GatewayConnector_SendToGateway_Call struct { // SendToGateway is a helper method to define mock.On call // - ctx context.Context -// - gatewayId string +// - gatewayID string // - msg *api.Message -func (_e *GatewayConnector_Expecter) SendToGateway(ctx interface{}, gatewayId interface{}, msg interface{}) *GatewayConnector_SendToGateway_Call { - return &GatewayConnector_SendToGateway_Call{Call: _e.mock.On("SendToGateway", ctx, gatewayId, msg)} +func (_e *GatewayConnector_Expecter) SendToGateway(ctx interface{}, gatewayID interface{}, msg interface{}) *GatewayConnector_SendToGateway_Call { + return &GatewayConnector_SendToGateway_Call{Call: _e.mock.On("SendToGateway", ctx, gatewayID, msg)} } -func (_c *GatewayConnector_SendToGateway_Call) Run(run func(ctx context.Context, gatewayId string, msg *api.Message)) *GatewayConnector_SendToGateway_Call { +func (_c *GatewayConnector_SendToGateway_Call) Run(run func(ctx context.Context, gatewayID string, msg *api.Message)) *GatewayConnector_SendToGateway_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(string), args[2].(*api.Message)) }) diff --git a/core/services/workflows/syncer/fetcher_test.go b/core/services/workflows/syncer/fetcher_test.go index 8e3e58fba0d..ee59d22608a 100644 --- a/core/services/workflows/syncer/fetcher_test.go +++ b/core/services/workflows/syncer/fetcher_test.go @@ -15,6 +15,7 @@ import ( gcmocks "github.com/smartcontractkit/chainlink/v2/core/services/gateway/connector/mocks" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/capabilities" ghcapabilities "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/capabilities" + "github.com/smartcontractkit/chainlink/v2/core/utils/matches" ) type wrapper struct { @@ -48,6 +49,7 @@ func TestNewFetcherService(t *testing.T) { fetcher.och.HandleGatewayMessage(ctx, "gateway1", gatewayResp) }).Return(nil).Times(1) connector.EXPECT().DonID().Return("don-id") + connector.EXPECT().AwaitConnection(matches.AnyContext, "gateway1").Return(nil) connector.EXPECT().GatewayIDs().Return([]string{"gateway1", "gateway2"}) payload, err := fetcher.Fetch(ctx, url)