Skip to content

Commit

Permalink
Refactor chain ID logic in plugin to be chain agnostic (#15213)
Browse files Browse the repository at this point in the history
* Refactor chain ID logic in plugin to be chain agnostic

* fix script error

* update mod

* update dependency

* revert changes on mod

* remove old comment
  • Loading branch information
huangzhen1997 authored Nov 20, 2024
1 parent ae830b0 commit dee0d6a
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 34 deletions.
5 changes: 5 additions & 0 deletions .changeset/light-trains-chew.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"chainlink": minor
---

Refactor chain ID logic in plugin to be chain agnostic #added
70 changes: 36 additions & 34 deletions core/capabilities/ccip/oraclecreator/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ import (
"github.com/smartcontractkit/chainlink/v2/core/services/job"
"github.com/smartcontractkit/chainlink/v2/core/services/keystore/keys/ocr2key"
"github.com/smartcontractkit/chainlink/v2/core/services/ocrcommon"
"github.com/smartcontractkit/chainlink/v2/core/services/relay"
evmrelaytypes "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm/types"
"github.com/smartcontractkit/chainlink/v2/core/services/synchronization"
"github.com/smartcontractkit/chainlink/v2/core/services/telemetry"
Expand Down Expand Up @@ -117,15 +116,17 @@ func (i *pluginOracleCreator) Type() cctypes.OracleType {
// Create implements types.OracleCreator.
func (i *pluginOracleCreator) Create(ctx context.Context, donID uint32, config cctypes.OCR3ConfigWithMeta) (cctypes.CCIPOracle, error) {
pluginType := cctypes.PluginType(config.Config.PluginType)
chainSelector := uint64(config.Config.ChainSelector)
destChainFamily, err := chainsel.GetSelectorFamily(chainSelector)
if err != nil {
return nil, fmt.Errorf("failed to get chain family from selector %d: %w", config.Config.ChainSelector, err)
}

// Assuming that the chain selector is referring to an evm chain for now.
// TODO: add an api that returns chain family.
destChainID, err := chainsel.ChainIdFromSelector(uint64(config.Config.ChainSelector))
destChainID, err := chainsel.GetChainIDFromSelector(chainSelector)
if err != nil {
return nil, fmt.Errorf("failed to get chain ID from selector %d: %w", config.Config.ChainSelector, err)
return nil, fmt.Errorf("failed to get chain ID from selector %d: %w", chainSelector, err)
}
destChainFamily := relay.NetworkEVM
destRelayID := types.NewRelayID(destChainFamily, fmt.Sprintf("%d", destChainID))
destRelayID := types.NewRelayID(destChainFamily, destChainID)

configTracker := ocrimpls.NewConfigTracker(config)
publicConfig, err := configTracker.PublicConfig()
Expand All @@ -139,6 +140,7 @@ func (i *pluginOracleCreator) Create(ctx context.Context, donID uint32, config c
pluginType,
config,
publicConfig,
destChainFamily,
)
if err != nil {
return nil, fmt.Errorf("failed to create readers and writers: %w", err)
Expand Down Expand Up @@ -293,10 +295,11 @@ func (i *pluginOracleCreator) createFactoryAndTransmitter(

func (i *pluginOracleCreator) createReadersAndWriters(
ctx context.Context,
destChainID uint64,
destChainID string,
pluginType cctypes.PluginType,
config cctypes.OCR3ConfigWithMeta,
publicCfg ocr3confighelper.PublicConfig,
chainFamily string,
) (
map[cciptypes.ChainSelector]types.ContractReader,
map[cciptypes.ChainSelector]types.ChainWriter,
Expand Down Expand Up @@ -324,17 +327,14 @@ func (i *pluginOracleCreator) createReadersAndWriters(
contractReaders := make(map[cciptypes.ChainSelector]types.ContractReader)
chainWriters := make(map[cciptypes.ChainSelector]types.ChainWriter)
for relayID, relayer := range i.relayers {
chainID, ok := new(big.Int).SetString(relayID.ChainID, 10)
if !ok {
return nil, nil, fmt.Errorf("error parsing chain ID, expected big int: %s", relayID.ChainID)
}
chainID := relayID.ChainID

chainSelector, err1 := i.getChainSelector(chainID.Uint64())
chainSelector, err1 := i.getChainSelector(chainID, chainFamily)
if err1 != nil {
return nil, nil, fmt.Errorf("failed to get chain selector from chain ID %s: %w", chainID.String(), err1)
return nil, nil, fmt.Errorf("failed to get chain selector from chain ID %s: %w", chainID, err1)
}

chainReaderConfig, err1 := getChainReaderConfig(i.lggr, chainID.Uint64(), destChainID, homeChainID, ofc, chainSelector)
chainReaderConfig, err1 := getChainReaderConfig(i.lggr, chainID, destChainID, homeChainID, ofc, chainSelector)
if err1 != nil {
return nil, nil, fmt.Errorf("failed to get chain reader config: %w", err1)
}
Expand All @@ -344,7 +344,7 @@ func (i *pluginOracleCreator) createReadersAndWriters(
return nil, nil, err1
}

if chainID.Uint64() == destChainID {
if chainID == destChainID {
offrampAddressHex := common.BytesToAddress(config.Config.OfframpAddress).Hex()
err2 := cr.Bind(ctx, []types.BoundContract{
{
Expand All @@ -353,26 +353,27 @@ func (i *pluginOracleCreator) createReadersAndWriters(
},
})
if err2 != nil {
return nil, nil, fmt.Errorf("failed to bind chain reader for dest chain %s's offramp at %s: %w", chainID.String(), offrampAddressHex, err)
return nil, nil, fmt.Errorf("failed to bind chain reader for dest chain %s's offramp at %s: %w", chainID, offrampAddressHex, err)
}
}

if err2 := cr.Start(ctx); err2 != nil {
return nil, nil, fmt.Errorf("failed to start contract reader for chain %s: %w", chainID.String(), err2)
return nil, nil, fmt.Errorf("failed to start contract reader for chain %s: %w", chainID, err2)
}

cw, err1 := createChainWriter(
ctx,
chainID,
relayer,
i.transmitters,
execBatchGasLimit)
execBatchGasLimit,
chainFamily)
if err1 != nil {
return nil, nil, err1
}

if err4 := cw.Start(ctx); err4 != nil {
return nil, nil, fmt.Errorf("failed to start chain writer for chain %s: %w", chainID.String(), err4)
return nil, nil, fmt.Errorf("failed to start chain writer for chain %s: %w", chainID, err4)
}

contractReaders[chainSelector] = cr
Expand Down Expand Up @@ -411,27 +412,27 @@ func decodeAndValidateOffchainConfig(
return ofc, nil
}

func (i *pluginOracleCreator) getChainSelector(chainID uint64) (cciptypes.ChainSelector, error) {
chainSelector, ok := chainsel.EvmChainIdToChainSelector()[chainID]
if !ok {
return 0, fmt.Errorf("failed to get chain selector from chain ID %d", chainID)
func (i *pluginOracleCreator) getChainSelector(chainID string, chainFamily string) (cciptypes.ChainSelector, error) {
chainDetails, err := chainsel.GetChainDetailsByChainIDAndFamily(chainID, chainFamily)
if err != nil {
return 0, fmt.Errorf("failed to get chain selector from chain ID %s and family %s", chainID, chainFamily)
}
return cciptypes.ChainSelector(chainSelector), nil
return cciptypes.ChainSelector(chainDetails.ChainSelector), nil
}

func (i *pluginOracleCreator) getChainID(chainSelector cciptypes.ChainSelector) (uint64, error) {
chainID, err := chainsel.ChainIdFromSelector(uint64(chainSelector))
func (i *pluginOracleCreator) getChainID(chainSelector cciptypes.ChainSelector) (string, error) {
chainID, err := chainsel.GetChainIDFromSelector(uint64(chainSelector))
if err != nil {
return 0, fmt.Errorf("failed to get chain ID from chain selector %d: %w", chainSelector, err)
return "", fmt.Errorf("failed to get chain ID from chain selector %d: %w", chainSelector, err)
}
return chainID, nil
}

func getChainReaderConfig(
lggr logger.Logger,
chainID uint64,
destChainID uint64,
homeChainID uint64,
chainID string,
destChainID string,
homeChainID string,
ofc offChainConfig,
chainSelector cciptypes.ChainSelector,
) ([]byte, error) {
Expand Down Expand Up @@ -475,13 +476,14 @@ func isUSDCEnabled(ofc offChainConfig) bool {

func createChainWriter(
ctx context.Context,
chainID *big.Int,
chainID string,
relayer loop.Relayer,
transmitters map[types.RelayID][]string,
execBatchGasLimit uint64,
chainFamily string,
) (types.ChainWriter, error) {
var fromAddress common.Address
transmitter, ok := transmitters[types.NewRelayID(relay.NetworkEVM, chainID.String())]
transmitter, ok := transmitters[types.NewRelayID(chainFamily, chainID)]
if ok {
// TODO: remove EVM-specific stuff
fromAddress = common.HexToAddress(transmitter[0])
Expand All @@ -503,7 +505,7 @@ func createChainWriter(

cw, err := relayer.NewChainWriter(ctx, chainWriterConfig)
if err != nil {
return nil, fmt.Errorf("failed to create chain writer for chain %s: %w", chainID.String(), err)
return nil, fmt.Errorf("failed to create chain writer for chain %s: %w", chainID, err)
}

return cw, nil
Expand Down

0 comments on commit dee0d6a

Please sign in to comment.