Skip to content

Commit

Permalink
EVM: Rework per-chain transceivers
Browse files Browse the repository at this point in the history
  • Loading branch information
bruce-riley committed Sep 24, 2024
1 parent 507dcbf commit 30cedd1
Show file tree
Hide file tree
Showing 12 changed files with 771 additions and 908 deletions.
101 changes: 50 additions & 51 deletions evm/src/NttManager/ManagerBase.sol
Original file line number Diff line number Diff line change
Expand Up @@ -64,24 +64,11 @@ abstract contract ManagerBase is
bytes32 private constant MESSAGE_SEQUENCE_SLOT =
bytes32(uint256(keccak256("ntt.messageSequence")) - 1);

bytes32 private constant THRESHOLD_SLOT = bytes32(uint256(keccak256("ntt.threshold")) - 1);
bytes32 internal constant THRESHOLD_SLOT = bytes32(uint256(keccak256("ntt.threshold")) - 1);

// =============== Storage Getters/Setters ==============================================

function _getThresholdStorage() private pure returns (_Threshold storage $) {
uint256 slot = uint256(THRESHOLD_SLOT);
assembly ("memory-safe") {
$.slot := slot
}
}

function _getThresholdStoragePerChain()
private
pure
returns (mapping(uint16 => _Threshold) storage $)
{
// TODO: this is safe (reusing the storage slot, because the mapping
// doesn't write into the slot itself) buy maybe we shouldn't?
function _getThresholdStorage() internal pure returns (_Threshold storage $) {
uint256 slot = uint256(THRESHOLD_SLOT);
assembly ("memory-safe") {
$.slot := slot
Expand Down Expand Up @@ -135,7 +122,7 @@ abstract contract ManagerBase is
uint256 totalPriceQuote = 0;
for (uint256 i = 0; i < numEnabledTransceivers; i++) {
address transceiverAddr = enabledTransceivers[i];
if (!_isTransceiverEnabledForChain(transceiverAddr, recipientChain)) {
if (!_isSendTransceiverEnabledForChain(transceiverAddr, recipientChain)) {
continue;
}
uint8 registeredTransceiverIndex = transceiverInfos[transceiverAddr].index;
Expand Down Expand Up @@ -217,7 +204,7 @@ abstract contract ManagerBase is
// call into transceiver contracts to send the message
for (uint256 i = 0; i < numEnabledTransceivers; i++) {
address transceiverAddr = enabledTransceivers[i];
if (!_isTransceiverEnabledForChain(transceiverAddr, recipientChain)) {
if (!_isSendTransceiverEnabledForChain(transceiverAddr, recipientChain)) {
continue;
}

Expand Down Expand Up @@ -304,22 +291,19 @@ abstract contract ManagerBase is
return _getThresholdStorage().num;
}

function getThreshold(
uint16 forChainId
) public view returns (uint8) {
uint8 threshold = _getThresholdStoragePerChain()[forChainId].num;
if (threshold == 0) {
return _getThresholdStorage().num;
}
return threshold;
/// @inheritdoc IManagerBase
function getPerChainThreshold(
uint16 // forChainId
) public view virtual returns (uint8) {
return _getThresholdStorage().num;
}

/// @inheritdoc IManagerBase
function isMessageApproved(
bytes32 digest
) public view returns (bool) {
uint16 sourceChainId = _getMessageAttestationsStorage()[digest].sourceChainId;
uint8 threshold = getThreshold(sourceChainId);
uint8 threshold = getPerChainThreshold(sourceChainId);
return messageAttestations(digest) >= threshold && threshold > 0;
}

Expand Down Expand Up @@ -379,15 +363,6 @@ abstract contract ManagerBase is
ITransceiver(_registeredTransceivers[i]).transferTransceiverOwnership(newOwner);
}
}

/// @inheritdoc IManagerBase
function enableTransceiverForChain(
address transceiver,
uint16 forChainId
) external onlyOwner {
_enableTransceiverForChain(transceiver, forChainId);
emit TransceiverEnabledForChain(transceiver, forChainId);
}

/// @inheritdoc IManagerBase
function setTransceiver(
Expand Down Expand Up @@ -436,6 +411,22 @@ abstract contract ManagerBase is
_checkThresholdInvariants();
}

/// @inheritdoc IManagerBase
function enableSendTransceiverForChain(
address, // transceiver,
uint16 // forChainId
) external virtual onlyOwner {
revert NotImplemented();
}

/// @inheritdoc IManagerBase
function enableRecvTransceiverForChain(
address, // transceiver,
uint16 // forChainId
) external virtual onlyOwner {
revert NotImplemented();
}

/// @inheritdoc IManagerBase
function setThreshold(
uint8 threshold
Expand All @@ -452,20 +443,13 @@ abstract contract ManagerBase is

emit ThresholdChanged(oldThreshold, threshold);
}

/// @inheritdoc IManagerBase
function setThresholdPerChain(uint16 forChainId, uint8 threshold) external onlyOwner {
if (threshold == 0) {
revert ZeroThreshold();
}

mapping(uint16 => _Threshold) storage _threshold = _getThresholdStoragePerChain();
uint8 oldThreshold = _threshold[forChainId].num;

_threshold[forChainId].num = threshold;
_checkThresholdInvariants(_threshold[forChainId].num);

emit PerChainThresholdChanged(forChainId, oldThreshold, threshold);
/// @inheritdoc IManagerBase
function setPerChainThreshold(
uint16, // forChainId,
uint8 // threshold
) external virtual onlyOwner {
revert NotImplemented();
}

// =============== Internal ==============================================================
Expand All @@ -488,7 +472,7 @@ abstract contract ManagerBase is
) internal view returns (uint64) {
uint64 enabledTransceiverBitmap = _getEnabledTransceiversBitmap();
uint16 sourceChainId = _getMessageAttestationsStorage()[digest].sourceChainId;
uint64 enabledTransceiversForChain = _getEnabledTransceiversBitmapForChain(sourceChainId);
uint64 enabledTransceiversForChain = _getEnabledRecvTransceiversForChain(sourceChainId);
return _getMessageAttestationsStorage()[digest].attestedTransceivers
& enabledTransceiverBitmap & enabledTransceiversForChain;
}
Expand Down Expand Up @@ -521,6 +505,19 @@ abstract contract ManagerBase is
_getMessageSequenceStorage().num++;
}

function _isSendTransceiverEnabledForChain(
address, // transceiver,
uint16 // chainId
) internal view virtual returns (bool) {
return true;
}

function _getEnabledRecvTransceiversForChain(
uint16 // forChainId
) internal view virtual returns (uint64 bitmap) {
return type(uint64).max;
}

/// ============== Invariants =============================================

/// @dev When we add new immutables, this function should be updated
Expand All @@ -541,8 +538,10 @@ abstract contract ManagerBase is
function _checkThresholdInvariants() internal view {
_checkThresholdInvariants(_getThresholdStorage().num);
}

function _checkThresholdInvariants(uint8 threshold) internal pure {

function _checkThresholdInvariants(
uint8 threshold
) internal pure {
_NumTransceivers memory numTransceivers = _getNumTransceiversStorage();

// invariant: threshold <= enabledTransceivers.length
Expand Down
6 changes: 3 additions & 3 deletions evm/src/NttManager/NttManagerNoRateLimiting.sol
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// SPDX-License-Identifier: Apache 2
pragma solidity >=0.8.8 <0.9.0;

import "./NttManager.sol";
import "./NttManagerWithPerChainTransceivers.sol";

/// @title NttManagerNoRateLimiting
/// @author Wormhole Project Contributors.
Expand All @@ -10,12 +10,12 @@ import "./NttManager.sol";
/// free up code space.
///
/// @dev All of the developer notes from `NttManager` apply here.
contract NttManagerNoRateLimiting is NttManager {
contract NttManagerNoRateLimiting is NttManagerWithPerChainTransceivers {
constructor(
address _token,
Mode _mode,
uint16 _chainId
) NttManager(_token, _mode, _chainId, 0, true) {}
) NttManagerWithPerChainTransceivers(_token, _mode, _chainId, 0, true) {}

// ==================== Override RateLimiter functions =========================

Expand Down
150 changes: 150 additions & 0 deletions evm/src/NttManager/NttManagerWithPerChainTransceivers.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
// SPDX-License-Identifier: Apache 2
pragma solidity >=0.8.8 <0.9.0;

import "./NttManager.sol";

/// @title NttManagerNoRateLimiting
/// @author Wormhole Project Contributors.
/// @notice The NttManagerNoRateLimiting abstract contract is an implementation of
/// NttManager that allows configuring different transceivers and thresholds
/// for each chain. Note that you can configure a different set of send and
/// receive transceivers for each chain, and if you don't specifically enable
/// any transceivers for a chain, then all transceivers will be used for it.
///
/// @dev All of the developer notes from `NttManager` apply here.
abstract contract NttManagerWithPerChainTransceivers is NttManager {
constructor(
address _token,
Mode _mode,
uint16 _chainId,
uint64 _rateLimitDuration,
bool _skipRateLimiting
) NttManager(_token, _mode, _chainId, _rateLimitDuration, _skipRateLimiting) {}

bytes32 private constant SEND_TRANSCEIVER_BITMAP_SLOT =
bytes32(uint256(keccak256("nttpct.sendTransceiverBitmap")) - 1);

bytes32 private constant RECV_TRANSCEIVER_BITMAP_SLOT =
bytes32(uint256(keccak256("nttpct.recvTransceiverBitmap")) - 1);

// ==================== Override / implementation of transceiver stuff =========================

/// @inheritdoc IManagerBase
function enableSendTransceiverForChain(
address transceiver,
uint16 forChainId
) external override(ManagerBase, IManagerBase) onlyOwner {
_enableTranceiverForChain(transceiver, forChainId, SEND_TRANSCEIVER_BITMAP_SLOT);
}

/// @inheritdoc IManagerBase
function enableRecvTransceiverForChain(
address transceiver,
uint16 forChainId
) external override(ManagerBase, IManagerBase) onlyOwner {
_enableTranceiverForChain(transceiver, forChainId, RECV_TRANSCEIVER_BITMAP_SLOT);
}

function _enableTranceiverForChain(
address transceiver,
uint16 forChainId,
bytes32 tag
) internal onlyOwner {
if (transceiver == address(0)) {
revert InvalidTransceiverZeroAddress();
}

mapping(address => TransceiverInfo) storage transceiverInfos = _getTransceiverInfosStorage();
if (!transceiverInfos[transceiver].registered) {
revert NonRegisteredTransceiver(transceiver);
}

uint8 index = _getTransceiverInfosStorage()[transceiver].index;
mapping(uint16 => _EnabledTransceiverBitmap) storage _bitmaps =
_getPerChainTransceiverBitmapStorage(tag);
_bitmaps[forChainId].bitmap |= uint64(1 << index);

emit TransceiverEnabledForChain(transceiver, forChainId);
}

function _isSendTransceiverEnabledForChain(
address transceiver,
uint16 forChainId
) internal view override returns (bool) {
uint64 bitmap =
_getPerChainTransceiverBitmapStorage(SEND_TRANSCEIVER_BITMAP_SLOT)[forChainId].bitmap;
if (bitmap == 0) {
// NOTE: this makes it backwards compatible -- if the bitmap is not
// set, it's assumed the corridor uses all transceivers.
bitmap = type(uint64).max;
}
uint8 index = _getTransceiverInfosStorage()[transceiver].index;
return (bitmap & uint64(1 << index)) != 0;
}

function _getEnabledRecvTransceiversForChain(
uint16 forChainId
) internal view override returns (uint64 bitmap) {
bitmap =
_getPerChainTransceiverBitmapStorage(RECV_TRANSCEIVER_BITMAP_SLOT)[forChainId].bitmap;
if (bitmap == 0) {
// NOTE: this makes it backwards compatible -- if the bitmap is not
// set, it's assumed the corridor uses all transceivers.
bitmap = type(uint64).max;
}
}

function _getPerChainTransceiverBitmapStorage(
bytes32 tag
) internal pure returns (mapping(uint16 => _EnabledTransceiverBitmap) storage $) {
// TODO: this is safe (reusing the storage slot, because the mapping
// doesn't write into the slot itself) buy maybe we shouldn't?
uint256 slot = uint256(tag);
assembly ("memory-safe") {
$.slot := slot
}
}

// ==================== Override / implementation of threshold stuff =========================

/// @inheritdoc IManagerBase
function setPerChainThreshold(
uint16 forChainId,
uint8 threshold
) external override(ManagerBase, IManagerBase) onlyOwner {
if (threshold == 0) {
revert ZeroThreshold();
}

mapping(uint16 => _Threshold) storage _threshold = _getThresholdStoragePerChain();
uint8 oldThreshold = _threshold[forChainId].num;

_threshold[forChainId].num = threshold;
_checkThresholdInvariants(_threshold[forChainId].num);

emit PerChainThresholdChanged(forChainId, oldThreshold, threshold);
}

function getPerChainThreshold(
uint16 forChainId
) public view override(ManagerBase, IManagerBase) returns (uint8) {
uint8 threshold = _getThresholdStoragePerChain()[forChainId].num;
if (threshold == 0) {
return _getThresholdStorage().num;
}
return threshold;
}

function _getThresholdStoragePerChain()
private
pure
returns (mapping(uint16 => _Threshold) storage $)
{
// TODO: this is safe (reusing the storage slot, because the mapping
// doesn't write into the slot itself) buy maybe we shouldn't?
uint256 slot = uint256(THRESHOLD_SLOT);
assembly ("memory-safe") {
$.slot := slot
}
}
}
Loading

0 comments on commit 30cedd1

Please sign in to comment.