From 38722b98fb24cc04862af460f69314e24b6e9bb8 Mon Sep 17 00:00:00 2001 From: Reptile <43194093+gator-boi@users.noreply.github.com> Date: Mon, 26 Feb 2024 11:41:22 -0600 Subject: [PATCH] evm: reorganize WormholeTransceiver and NttManager (#196) * evm: separate Transceiver and NttManger into separate directories * evm: move natspec to ITransceiver interface * evm: clean up WormholeTransceiver * evm: remove internal admin functions * evm: clean up NttManager * evm: add _checkImmutables to WormholeTransceiverState * fix build issues on rebase * evm: address pr feedback --------- Co-authored-by: gator-boi Co-authored-by: Rahul Maganti --- evm/src/{ => NttManager}/NttManager.sol | 689 +++++------------- evm/src/NttManager/NttManagerState.sol | 363 +++++++++ .../{ => NttManager}/TransceiverRegistry.sol | 6 +- evm/src/{ => Transceiver}/Transceiver.sol | 57 +- .../WormholeTransceiver.sol | 233 ++++++ .../WormholeTransceiverState.sol | 278 +++++++ evm/src/WormholeTransceiver.sol | 459 ------------ evm/src/interfaces/INttManager.sol | 159 ++-- evm/src/interfaces/INttManagerState.sol | 115 +++ evm/src/interfaces/ITransceiver.sol | 13 + evm/src/interfaces/IWormholeTransceiver.sol | 45 +- .../interfaces/IWormholeTransceiverState.sol | 66 ++ evm/test/IntegrationRelayer.t.sol | 19 +- evm/test/IntegrationStandalone.t.sol | 14 +- evm/test/NttManager.t.sol | 17 +- evm/test/Ownership.t.sol | 2 +- evm/test/RateLimit.t.sol | 4 +- evm/test/TransceiverStructs.t.sol | 2 +- evm/test/Upgrades.t.sol | 33 +- evm/test/libraries/NttManagerHelpers.sol | 2 +- evm/test/libraries/TransceiverHelpers.sol | 2 +- evm/test/mocks/DummyTransceiver.sol | 2 +- evm/test/mocks/MockNttManager.sol | 2 +- evm/test/mocks/MockTransceivers.sol | 2 +- 24 files changed, 1429 insertions(+), 1155 deletions(-) rename evm/src/{ => NttManager}/NttManager.sol (54%) create mode 100644 evm/src/NttManager/NttManagerState.sol rename evm/src/{ => NttManager}/TransceiverRegistry.sol (97%) rename evm/src/{ => Transceiver}/Transceiver.sol (75%) create mode 100644 evm/src/Transceiver/WormholeTransceiver/WormholeTransceiver.sol create mode 100644 evm/src/Transceiver/WormholeTransceiver/WormholeTransceiverState.sol delete mode 100644 evm/src/WormholeTransceiver.sol create mode 100644 evm/src/interfaces/INttManagerState.sol create mode 100644 evm/src/interfaces/IWormholeTransceiverState.sol diff --git a/evm/src/NttManager.sol b/evm/src/NttManager/NttManager.sol similarity index 54% rename from evm/src/NttManager.sol rename to evm/src/NttManager/NttManager.sol index 4e5c69746..153c16713 100644 --- a/evm/src/NttManager.sol +++ b/evm/src/NttManager/NttManager.sol @@ -8,224 +8,51 @@ import "openzeppelin-contracts/contracts/token/ERC20/extensions/ERC20Burnable.so import "wormhole-solidity-sdk/Utils.sol"; import "wormhole-solidity-sdk/libraries/BytesParsing.sol"; -import "./libraries/external/OwnableUpgradeable.sol"; -import "./libraries/external/ReentrancyGuardUpgradeable.sol"; -import "./libraries/TransceiverStructs.sol"; -import "./libraries/TransceiverHelpers.sol"; -import "./libraries/RateLimiter.sol"; -import "./interfaces/INttManager.sol"; -import "./interfaces/INttManagerEvents.sol"; -import "./interfaces/INTTToken.sol"; -import "./interfaces/ITransceiver.sol"; -import "./TransceiverRegistry.sol"; -import "./NttTrimmer.sol"; -import "./libraries/PausableOwnable.sol"; -import "./libraries/Implementation.sol"; - -contract NttManager is - INttManager, - INttManagerEvents, - TransceiverRegistry, - RateLimiter, - NttTrimmer, - ReentrancyGuardUpgradeable, - PausableOwnable, - Implementation -{ - using BytesParsing for bytes; - using SafeERC20 for IERC20; - - error RefundFailed(uint256 refundAmount); - error CannotRenounceNttManagerOwnership(address owner); - error UnexpectedOwner(address expectedOwner, address owner); - error TransceiverAlreadyAttestedToMessage(bytes32 nttManagerMessageHash); - - address public immutable token; - address immutable deployer; - Mode public immutable mode; - uint16 public immutable chainId; - uint256 immutable evmChainId; - - enum Mode { - LOCKING, - BURNING - } - - // @dev Information about attestations for a given message. - struct AttestationInfo { - // whether this message has been executed - bool executed; - // bitmap of transceivers that have attested to this message (NOTE: might contain disabled transceivers) - uint64 attestedTransceivers; - } - - struct _Sequence { - uint64 num; - } - - struct _Threshold { - uint8 num; - } - - /// =============== STORAGE =============================================== - - bytes32 private constant MESSAGE_ATTESTATIONS_SLOT = - bytes32(uint256(keccak256("ntt.messageAttestations")) - 1); - - bytes32 private constant MESSAGE_SEQUENCE_SLOT = - bytes32(uint256(keccak256("ntt.messageSequence")) - 1); - - bytes32 private constant PEERS_SLOT = bytes32(uint256(keccak256("ntt.peers")) - 1); - - bytes32 private constant THRESHOLD_SLOT = bytes32(uint256(keccak256("ntt.threshold")) - 1); - - /// =============== GETTERS/SETTERS ======================================== - - function _getThresholdStorage() private pure returns (_Threshold storage $) { - uint256 slot = uint256(THRESHOLD_SLOT); - assembly ("memory-safe") { - $.slot := slot - } - } - - function _getMessageAttestationsStorage() - internal - pure - returns (mapping(bytes32 => AttestationInfo) storage $) - { - uint256 slot = uint256(MESSAGE_ATTESTATIONS_SLOT); - assembly ("memory-safe") { - $.slot := slot - } - } - - function _getMessageSequenceStorage() internal pure returns (_Sequence storage $) { - uint256 slot = uint256(MESSAGE_SEQUENCE_SLOT); - assembly ("memory-safe") { - $.slot := slot - } - } - - function _getPeersStorage() internal pure returns (mapping(uint16 => bytes32) storage $) { - uint256 slot = uint256(PEERS_SLOT); - assembly ("memory-safe") { - $.slot := slot - } - } - - function setThreshold(uint8 threshold) external onlyOwner { - if (threshold == 0) { - revert ZeroThreshold(); - } - - _Threshold storage _threshold = _getThresholdStorage(); - uint8 oldThreshold = _threshold.num; - - _threshold.num = threshold; - _checkThresholdInvariants(); - - emit ThresholdChanged(oldThreshold, threshold); - } - - function getMode() public view returns (uint8) { - return uint8(mode); - } +import "../libraries/RateLimiter.sol"; - /// @notice Returns the number of Transceivers that must attest to a msgId for - /// it to be considered valid and acted upon. - function getThreshold() public view returns (uint8) { - return _getThresholdStorage().num; - } - - function setTransceiver(address transceiver) external onlyOwner { - _setTransceiver(transceiver); - - _Threshold storage _threshold = _getThresholdStorage(); - // We do not automatically increase the threshold here. - // Automatically increasing the threshold can result in a scenario - // where in-flight messages can't be redeemed. - // For example: Assume there is 1 Transceiver and the threshold is 1. - // If we were to add a new Transceiver, the threshold would increase to 2. - // However, all messages that are either in-flight or that are sent on - // a source chain that does not yet have 2 Transceivers will only have been - // sent from a single transceiver, so they would never be able to get - // redeemed. - // Instead, we leave it up to the owner to manually update the threshold - // after some period of time, ideally once all chains have the new Transceiver - // and transfers that were sent via the old configuration are all complete. - // However if the threshold is 0 (the initial case) we do increment to 1. - if (_threshold.num == 0) { - _threshold.num = 1; - } +import "../interfaces/INttManager.sol"; +import "../interfaces/INttManagerEvents.sol"; +import "../interfaces/INTTToken.sol"; +import "../interfaces/ITransceiver.sol"; - emit TransceiverAdded(transceiver, _getNumTransceiversStorage().enabled, _threshold.num); - } +import {NttManagerState} from "./NttManagerState.sol"; - function removeTransceiver(address transceiver) external onlyOwner { - _removeTransceiver(transceiver); - - _Threshold storage _threshold = _getThresholdStorage(); - uint8 numEnabledTransceivers = _getNumTransceiversStorage().enabled; - - if (numEnabledTransceivers < _threshold.num) { - _threshold.num = numEnabledTransceivers; - } - - emit TransceiverRemoved(transceiver, _threshold.num); - } +contract NttManager is INttManager, NttManagerState { + using BytesParsing for bytes; + using SafeERC20 for IERC20; constructor( address _token, Mode _mode, uint16 _chainId, uint64 _rateLimitDuration - ) RateLimiter(_rateLimitDuration) NttTrimmer(_token) { - token = _token; - mode = _mode; - chainId = _chainId; - evmChainId = block.chainid; - // save the deployer (check this on initialization) - deployer = msg.sender; - } + ) NttManagerState(_token, _mode, _chainId, _rateLimitDuration) {} - function __NttManager_init() internal onlyInitializing { - // check if the owner is the deployer of this contract - if (msg.sender != deployer) { - revert UnexpectedOwner(deployer, msg.sender); - } - __PausedOwnable_init(msg.sender, msg.sender); - __ReentrancyGuard_init(); - } - - function _initialize() internal virtual override { - __NttManager_init(); - _checkThresholdInvariants(); - _checkTransceiversInvariants(); - } + // ==================== External Interface =============================================== - function _migrate() internal virtual override { - _checkThresholdInvariants(); - _checkTransceiversInvariants(); - } - - /// =============== ADMIN =============================================== - function upgrade(address newImplementation) external onlyOwner { - _upgrade(newImplementation); + /// @inheritdoc INttManager + function transfer( + uint256 amount, + uint16 recipientChain, + bytes32 recipient + ) external payable nonReentrant whenNotPaused returns (uint64) { + return _transferEntryPoint(amount, recipientChain, recipient, false, new bytes(1)); } - /// @dev Transfer ownership of the NttManager contract and all Transceiver contracts to a new owner. - function transferOwnership(address newOwner) public override onlyOwner { - super.transferOwnership(newOwner); - // loop through all the registered transceivers and set the new owner of each transceiver to the newOwner - address[] storage _registeredTransceivers = _getRegisteredTransceiversStorage(); - _checkRegisteredTransceiversInvariants(); - - for (uint256 i = 0; i < _registeredTransceivers.length; i++) { - ITransceiver(_registeredTransceivers[i]).transferTransceiverOwnership(newOwner); - } + /// @inheritdoc INttManager + function transfer( + uint256 amount, + uint16 recipientChain, + bytes32 recipient, + bool shouldQueue, + bytes memory transceiverInstructions + ) external payable nonReentrant whenNotPaused returns (uint64) { + return _transferEntryPoint( + amount, recipientChain, recipient, shouldQueue, transceiverInstructions + ); } - /// @dev This method should return an array of delivery prices corresponding to each transceiver. + /// @inheritdoc INttManager function quoteDeliveryPrice( uint16 recipientChain, TransceiverStructs.TransceiverInstruction[] memory transceiverInstructions, @@ -248,74 +75,113 @@ contract NttManager is return (priceQuotes, totalPriceQuote); } - function _sendMessageToTransceivers( - uint16 recipientChain, - uint256[] memory priceQuotes, - TransceiverStructs.TransceiverInstruction[] memory transceiverInstructions, - address[] memory enabledTransceivers, - bytes memory nttManagerMessage - ) internal { - uint256 numEnabledTransceivers = enabledTransceivers.length; - mapping(address => TransceiverInfo) storage transceiverInfos = _getTransceiverInfosStorage(); - // call into transceiver contracts to send the message - for (uint256 i = 0; i < numEnabledTransceivers; i++) { - address transceiverAddr = enabledTransceivers[i]; - // send it to the recipient nttManager based on the chain - ITransceiver(transceiverAddr).sendMessage{value: priceQuotes[i]}( - recipientChain, - transceiverInstructions[transceiverInfos[transceiverAddr].index], - nttManagerMessage, - getPeer(recipientChain) - ); + /// @inheritdoc INttManager + function attestationReceived( + uint16 sourceChainId, + bytes32 sourceNttManagerAddress, + TransceiverStructs.NttManagerMessage memory payload + ) external onlyTransceiver { + _verifyPeer(sourceChainId, sourceNttManagerAddress); + + bytes32 nttManagerMessageHash = + TransceiverStructs.nttManagerMessageDigest(sourceChainId, payload); + + // set the attested flag for this transceiver. + // NOTE: Attestation is idempotent (bitwise or 1), but we revert + // anyway to ensure that the client does not continue to initiate calls + // to receive the same message through the same transceiver. + if ( + transceiverAttestedToMessage( + nttManagerMessageHash, _getTransceiverInfosStorage()[msg.sender].index + ) + ) { + revert TransceiverAlreadyAttestedToMessage(nttManagerMessageHash); } - } + _setTransceiverAttestedToMessage(nttManagerMessageHash, msg.sender); - function isMessageApproved(bytes32 digest) public view returns (bool) { - uint8 threshold = getThreshold(); - return messageAttestations(digest) >= threshold && threshold > 0; + if (isMessageApproved(nttManagerMessageHash)) { + executeMsg(sourceChainId, sourceNttManagerAddress, payload); + } } - function _setTransceiverAttestedToMessage(bytes32 digest, uint8 index) internal { - _getMessageAttestationsStorage()[digest].attestedTransceivers |= uint64(1 << index); - } + /// @inheritdoc INttManager + function executeMsg( + uint16 sourceChainId, + bytes32 sourceNttManagerAddress, + TransceiverStructs.NttManagerMessage memory message + ) public { + // verify chain has not forked + checkFork(evmChainId); - function _setTransceiverAttestedToMessage(bytes32 digest, address transceiver) internal { - _setTransceiverAttestedToMessage(digest, _getTransceiverInfosStorage()[transceiver].index); + bytes32 digest = TransceiverStructs.nttManagerMessageDigest(sourceChainId, message); - emit MessageAttestedTo( - digest, transceiver, _getTransceiverInfosStorage()[transceiver].index - ); - } + if (!isMessageApproved(digest)) { + revert MessageNotApproved(digest); + } - /* - * @dev pause the Transceiver. - */ - function pause() public onlyOwnerOrPauser { - _pause(); - } + bool msgAlreadyExecuted = _replayProtect(digest); + if (msgAlreadyExecuted) { + // end execution early to mitigate the possibility of race conditions from transceivers + // attempting to deliver the same message when (threshold < number of transceiver messages) + // notify client (off-chain process) so they don't attempt redundant msg delivery + emit MessageAlreadyExecuted(sourceNttManagerAddress, digest); + return; + } - /// @dev Returns the bitmap of attestations from enabled transceivers for a given message. - function _getMessageAttestations(bytes32 digest) internal view returns (uint64) { - uint64 enabledTransceiverBitmap = _getEnabledTransceiversBitmap(); - return - _getMessageAttestationsStorage()[digest].attestedTransceivers & enabledTransceiverBitmap; - } + TransceiverStructs.NativeTokenTransfer memory nativeTokenTransfer = + TransceiverStructs.parseNativeTokenTransfer(message.payload); - function _getEnabledTransceiverAttestedToMessage( - bytes32 digest, - uint8 index - ) internal view returns (bool) { - return _getMessageAttestations(digest) & uint64(1 << index) != 0; - } + // verify that the destination chain is valid + if (nativeTokenTransfer.toChain != chainId) { + revert InvalidTargetChain(nativeTokenTransfer.toChain, chainId); + } + + TrimmedAmount memory nativeTransferAmount = _nttFixDecimals(nativeTokenTransfer.amount); + + address transferRecipient = fromWormholeFormat(nativeTokenTransfer.to); - function setOutboundLimit(uint256 limit) external onlyOwner { - _setOutboundLimit(_nttTrimmer(limit)); + { + // Check inbound rate limits + bool isRateLimited = _isInboundAmountRateLimited(nativeTransferAmount, sourceChainId); + if (isRateLimited) { + // queue up the transfer + _enqueueInboundTransfer(digest, nativeTransferAmount, transferRecipient); + + // end execution early + return; + } + } + + // consume the amount for the inbound rate limit + _consumeInboundAmount(nativeTransferAmount, sourceChainId); + // When receiving a transfer, we refill the outbound rate limit + // by the same amount (we call this "backflow") + _backfillOutboundAmount(nativeTransferAmount); + + _mintOrUnlockToRecipient(digest, transferRecipient, nativeTransferAmount); } - function setInboundLimit(uint256 limit, uint16 chainId_) external onlyOwner { - _setInboundLimit(_nttTrimmer(limit), chainId_); + /// @inheritdoc INttManager + function completeInboundQueuedTransfer(bytes32 digest) external nonReentrant whenNotPaused { + // find the message in the queue + InboundQueuedTransfer memory queuedTransfer = getInboundQueuedTransfer(digest); + if (queuedTransfer.txTimestamp == 0) { + revert InboundQueuedTransferNotFound(digest); + } + + // check that > RATE_LIMIT_DURATION has elapsed + if (block.timestamp - queuedTransfer.txTimestamp < rateLimitDuration) { + revert InboundQueuedTransferStillQueued(digest, queuedTransfer.txTimestamp); + } + + // remove transfer from the queue + delete _getInboundQueueStorage()[digest]; + + // run it through the mint/unlock logic + _mintOrUnlockToRecipient(digest, queuedTransfer.recipient, queuedTransfer.amount); } + /// @inheritdoc INttManager function completeOutboundQueuedTransfer(uint64 messageSequence) external payable @@ -348,54 +214,33 @@ contract NttManager is ); } - /// @dev Refunds the remaining amount back to the sender. - function refundToSender(uint256 refundAmount) internal { - // refund the price quote back to sender - (bool refundSuccessful,) = payable(msg.sender).call{value: refundAmount}(""); - - // check success - if (!refundSuccessful) { - revert RefundFailed(refundAmount); - } - } - - /// @dev Returns trimmed amount and checks for dust - function trimTransferAmount(uint256 amount) internal view returns (TrimmedAmount memory) { - TrimmedAmount memory trimmedAmount; - { - trimmedAmount = _nttTrimmer(amount); - // don't deposit dust that can not be bridged due to the decimal shift - uint256 newAmount = _nttUntrim(trimmedAmount); - if (amount != newAmount) { - revert TransferAmountHasDust(amount, amount - newAmount); - } - } - - return trimmedAmount; + /// @inheritdoc INttManager + function tokenDecimals() public view override(INttManager, RateLimiter) returns (uint8) { + return tokenDecimals_; } - /// @dev Simple quality of life transfer method that doesn't deal with queuing or passing transceiver instructions. - function transfer( - uint256 amount, - uint16 recipientChain, - bytes32 recipient - ) external payable nonReentrant whenNotPaused returns (uint64) { - return _transferEntryPoint(amount, recipientChain, recipient, false, new bytes(1)); - } + // ==================== Internal Business Logic ========================================= - /// @notice Called by the user to send the token cross-chain. - /// This function will either lock or burn the sender's tokens. - /// Finally, this function will call into the Transceiver contracts to send a message with the incrementing sequence number and the token transfer payload. - function transfer( - uint256 amount, + function _sendMessageToTransceivers( uint16 recipientChain, - bytes32 recipient, - bool shouldQueue, - bytes memory transceiverInstructions - ) external payable nonReentrant whenNotPaused returns (uint64) { - return _transferEntryPoint( - amount, recipientChain, recipient, shouldQueue, transceiverInstructions - ); + uint256[] memory priceQuotes, + TransceiverStructs.TransceiverInstruction[] memory transceiverInstructions, + address[] memory enabledTransceivers, + bytes memory nttManagerMessage + ) internal { + uint256 numEnabledTransceivers = enabledTransceivers.length; + mapping(address => TransceiverInfo) storage transceiverInfos = _getTransceiverInfosStorage(); + // call into transceiver contracts to send the message + for (uint256 i = 0; i < numEnabledTransceivers; i++) { + address transceiverAddr = enabledTransceivers[i]; + // send it to the recipient nttManager based on the chain + ITransceiver(transceiverAddr).sendMessage{value: priceQuotes[i]}( + recipientChain, + transceiverInstructions[transceiverInfos[transceiverAddr].index], + nttManagerMessage, + getPeer(recipientChain) + ); + } } function _transferEntryPoint( @@ -419,13 +264,13 @@ contract NttManager is { // use transferFrom to pull tokens from the user and lock them // query own token balance before transfer - uint256 balanceBefore = getTokenBalanceOf(token, address(this)); + uint256 balanceBefore = _getTokenBalanceOf(token, address(this)); // transfer tokens IERC20(token).safeTransferFrom(msg.sender, address(this), amount); // query own token balance after transfer - uint256 balanceAfter = getTokenBalanceOf(token, address(this)); + uint256 balanceAfter = _getTokenBalanceOf(token, address(this)); // correct amount for potential transfer fees amount = balanceAfter - balanceBefore; @@ -433,7 +278,7 @@ contract NttManager is } else if (mode == Mode.BURNING) { { // query sender's token balance before burn - uint256 balanceBefore = getTokenBalanceOf(token, msg.sender); + uint256 balanceBefore = _getTokenBalanceOf(token, msg.sender); // call the token's burn function to burn the sender's token // NOTE: We don't account for burn fees in this code path. @@ -450,7 +295,7 @@ contract NttManager is ERC20Burnable(token).burnFrom(msg.sender, amount); // query sender's token balance after transfer - uint256 balanceAfter = getTokenBalanceOf(token, msg.sender); + uint256 balanceAfter = _getTokenBalanceOf(token, msg.sender); uint256 balanceDiff = balanceBefore - balanceAfter; if (balanceDiff != amount) { @@ -463,7 +308,7 @@ contract NttManager is } // trim amount after burning to ensure transfer amount matches (amount - fee) - TrimmedAmount memory trimmedAmount = trimTransferAmount(amount); + TrimmedAmount memory trimmedAmount = _trimTransferAmount(amount); // get the sequence for this transfer uint64 sequence = _useMessageSequence(); @@ -491,7 +336,7 @@ contract NttManager is ); // refund price quote back to sender - refundToSender(msg.value); + _refundToSender(msg.value); // return the sequence in the queue return sequence; @@ -534,7 +379,7 @@ contract NttManager is // refund user extra excess value from msg.value uint256 excessValue = msg.value - totalPriceQuote; if (excessValue > 0) { - refundToSender(excessValue); + _refundToSender(excessValue); } } @@ -563,104 +408,6 @@ contract NttManager is return sequence; } - /// @dev Verify that the peer address saved for `sourceChainId` matches the `peerAddress`. - function _verifyPeer(uint16 sourceChainId, bytes32 peerAddress) internal view { - if (getPeer(sourceChainId) != peerAddress) { - revert InvalidPeer(sourceChainId, peerAddress); - } - } - - // @dev Mark a message as executed. - // This function will retuns `true` if the message has already been executed. - function _replayProtect(bytes32 digest) internal returns (bool) { - // check if this message has already been executed - if (isMessageExecuted(digest)) { - return true; - } - - // mark this message as executed - _getMessageAttestationsStorage()[digest].executed = true; - - return false; - } - - /// @dev Called after a message has been sufficiently verified to execute the command in the message. - /// This function will decode the payload as an NttManagerMessage to extract the sequence, msgType, and other parameters. - function executeMsg( - uint16 sourceChainId, - bytes32 sourceNttManagerAddress, - TransceiverStructs.NttManagerMessage memory message - ) public { - // verify chain has not forked - checkFork(evmChainId); - - bytes32 digest = TransceiverStructs.nttManagerMessageDigest(sourceChainId, message); - - if (!isMessageApproved(digest)) { - revert MessageNotApproved(digest); - } - - bool msgAlreadyExecuted = _replayProtect(digest); - if (msgAlreadyExecuted) { - // end execution early to mitigate the possibility of race conditions from transceivers - // attempting to deliver the same message when (threshold < number of transceiver messages) - // notify client (off-chain process) so they don't attempt redundant msg delivery - emit MessageAlreadyExecuted(sourceNttManagerAddress, digest); - return; - } - - TransceiverStructs.NativeTokenTransfer memory nativeTokenTransfer = - TransceiverStructs.parseNativeTokenTransfer(message.payload); - - // verify that the destination chain is valid - if (nativeTokenTransfer.toChain != chainId) { - revert InvalidTargetChain(nativeTokenTransfer.toChain, chainId); - } - - TrimmedAmount memory nativeTransferAmount = _nttFixDecimals(nativeTokenTransfer.amount); - - address transferRecipient = fromWormholeFormat(nativeTokenTransfer.to); - - { - // Check inbound rate limits - bool isRateLimited = _isInboundAmountRateLimited(nativeTransferAmount, sourceChainId); - if (isRateLimited) { - // queue up the transfer - _enqueueInboundTransfer(digest, nativeTransferAmount, transferRecipient); - - // end execution early - return; - } - } - - // consume the amount for the inbound rate limit - _consumeInboundAmount(nativeTransferAmount, sourceChainId); - // When receiving a transfer, we refill the outbound rate limit - // by the same amount (we call this "backflow") - _backfillOutboundAmount(nativeTransferAmount); - - _mintOrUnlockToRecipient(digest, transferRecipient, nativeTransferAmount); - } - - function completeInboundQueuedTransfer(bytes32 digest) external nonReentrant whenNotPaused { - // find the message in the queue - InboundQueuedTransfer memory queuedTransfer = getInboundQueuedTransfer(digest); - if (queuedTransfer.txTimestamp == 0) { - revert InboundQueuedTransferNotFound(digest); - } - - // check that > RATE_LIMIT_DURATION has elapsed - if (block.timestamp - queuedTransfer.txTimestamp < rateLimitDuration) { - revert InboundQueuedTransferStillQueued(digest, queuedTransfer.txTimestamp); - } - - // remove transfer from the queue - delete _getInboundQueueStorage()[digest]; - - // run it through the mint/unlock logic - _mintOrUnlockToRecipient(digest, queuedTransfer.recipient, queuedTransfer.amount); - } - function _mintOrUnlockToRecipient( bytes32 digest, address recipient, @@ -683,16 +430,33 @@ contract NttManager is } } - function nextMessageSequence() external view returns (uint64) { - return _getMessageSequenceStorage().num; + // ==================== Internal Helpers =============================================== + + function _refundToSender(uint256 refundAmount) internal { + // refund the price quote back to sender + (bool refundSuccessful,) = payable(msg.sender).call{value: refundAmount}(""); + + // check success + if (!refundSuccessful) { + revert RefundFailed(refundAmount); + } } - function _useMessageSequence() internal returns (uint64 currentSequence) { - currentSequence = _getMessageSequenceStorage().num; - _getMessageSequenceStorage().num++; + function _trimTransferAmount(uint256 amount) internal view returns (TrimmedAmount memory) { + TrimmedAmount memory trimmedAmount; + { + trimmedAmount = _nttTrimmer(amount); + // don't deposit dust that can not be bridged due to the decimal shift + uint256 newAmount = _nttUntrim(trimmedAmount); + if (amount != newAmount) { + revert TransferAmountHasDust(amount, amount - newAmount); + } + } + + return trimmedAmount; } - function getTokenBalanceOf( + function _getTokenBalanceOf( address tokenAddr, address accountAddr ) internal view returns (uint256) { @@ -700,105 +464,4 @@ contract NttManager is tokenAddr.staticcall(abi.encodeWithSelector(IERC20.balanceOf.selector, accountAddr)); return abi.decode(queriedBalance, (uint256)); } - - function isMessageExecuted(bytes32 digest) public view returns (bool) { - return _getMessageAttestationsStorage()[digest].executed; - } - - function getPeer(uint16 chainId_) public view returns (bytes32) { - return _getPeersStorage()[chainId_]; - } - - /// @notice this sets the corresponding peer. - /// @dev The nttManager that executes the message sets the source nttManager as the peer. - function setPeer(uint16 peerChainId, bytes32 peerContract) public onlyOwner { - if (peerChainId == 0) { - revert InvalidPeerChainIdZero(); - } - if (peerContract == bytes32(0)) { - revert InvalidPeerZeroAddress(); - } - - bytes32 oldPeerContract = _getPeersStorage()[peerChainId]; - - _getPeersStorage()[peerChainId] = peerContract; - - emit PeerUpdated(peerChainId, oldPeerContract, peerContract); - } - - function transceiverAttestedToMessage(bytes32 digest, uint8 index) public view returns (bool) { - return - _getMessageAttestationsStorage()[digest].attestedTransceivers & uint64(1 << index) == 1; - } - - function attestationReceived( - uint16 sourceChainId, - bytes32 sourceNttManagerAddress, - TransceiverStructs.NttManagerMessage memory payload - ) external onlyTransceiver { - _verifyPeer(sourceChainId, sourceNttManagerAddress); - - bytes32 nttManagerMessageHash = - TransceiverStructs.nttManagerMessageDigest(sourceChainId, payload); - - // set the attested flag for this transceiver. - // NOTE: Attestation is idempotent (bitwise or 1), but we revert - // anyway to ensure that the client does not continue to initiate calls - // to receive the same message through the same transceiver. - if ( - transceiverAttestedToMessage( - nttManagerMessageHash, _getTransceiverInfosStorage()[msg.sender].index - ) - ) { - revert TransceiverAlreadyAttestedToMessage(nttManagerMessageHash); - } - _setTransceiverAttestedToMessage(nttManagerMessageHash, msg.sender); - - if (isMessageApproved(nttManagerMessageHash)) { - executeMsg(sourceChainId, sourceNttManagerAddress, payload); - } - } - - // @dev Count the number of attestations from enabled transceivers for a given message. - function messageAttestations(bytes32 digest) public view returns (uint8 count) { - return countSetBits(_getMessageAttestations(digest)); - } - - function tokenDecimals() public view override(INttManager, RateLimiter) returns (uint8) { - return tokenDecimals_; - } - - /// ============== INVARIANTS ============================================= - - /// @dev When we add new immutables, this function should be updated - function _checkImmutables() internal view override { - assert(this.token() == token); - assert(this.mode() == mode); - assert(this.chainId() == chainId); - assert(this.rateLimitDuration() == rateLimitDuration); - } - - function _checkRegisteredTransceiversInvariants() internal view { - if (_getRegisteredTransceiversStorage().length != _getNumTransceiversStorage().registered) { - revert RetrievedIncorrectRegisteredTransceivers( - _getRegisteredTransceiversStorage().length, _getNumTransceiversStorage().registered - ); - } - } - - function _checkThresholdInvariants() internal view { - uint8 threshold = _getThresholdStorage().num; - _NumTransceivers memory numTransceivers = _getNumTransceiversStorage(); - - // invariant: threshold <= enabledTransceivers.length - if (threshold > numTransceivers.enabled) { - revert ThresholdTooHigh(threshold, numTransceivers.enabled); - } - - if (numTransceivers.registered > 0) { - if (threshold == 0) { - revert ZeroThreshold(); - } - } - } } diff --git a/evm/src/NttManager/NttManagerState.sol b/evm/src/NttManager/NttManagerState.sol new file mode 100644 index 000000000..9cc042da4 --- /dev/null +++ b/evm/src/NttManager/NttManagerState.sol @@ -0,0 +1,363 @@ +// SPDX-License-Identifier: Apache 2 +pragma solidity >=0.8.8 <0.9.0; + +import "openzeppelin-contracts/contracts/token/ERC20/IERC20.sol"; +import "openzeppelin-contracts/contracts/token/ERC20/utils/SafeERC20.sol"; +import "openzeppelin-contracts/contracts/token/ERC20/extensions/ERC20Burnable.sol"; + +import "wormhole-solidity-sdk/Utils.sol"; +import "wormhole-solidity-sdk/libraries/BytesParsing.sol"; + +import "../libraries/external/OwnableUpgradeable.sol"; +import "../libraries/external/ReentrancyGuardUpgradeable.sol"; +import "../libraries/TransceiverStructs.sol"; +import "../libraries/TransceiverHelpers.sol"; +import "../libraries/RateLimiter.sol"; +import "../libraries/PausableOwnable.sol"; +import "../libraries/Implementation.sol"; + +import "../interfaces/INttManager.sol"; +import "../interfaces/INttManagerState.sol"; +import "../interfaces/INttManagerEvents.sol"; +import "../interfaces/INTTToken.sol"; +import "../interfaces/ITransceiver.sol"; + +import "./TransceiverRegistry.sol"; +import "./../NttTrimmer.sol"; + +abstract contract NttManagerState is + INttManagerState, + INttManagerEvents, + RateLimiter, + NttTrimmer, + TransceiverRegistry, + PausableOwnable, + ReentrancyGuardUpgradeable, + Implementation +{ + // =============== Immutables ============================================================ + + address public immutable token; + address immutable deployer; + INttManager.Mode public immutable mode; + uint16 public immutable chainId; + uint256 immutable evmChainId; + + // =============== Setup ================================================================= + + constructor( + address _token, + INttManager.Mode _mode, + uint16 _chainId, + uint64 _rateLimitDuration + ) RateLimiter(_rateLimitDuration) NttTrimmer(_token) { + token = _token; + mode = _mode; + chainId = _chainId; + evmChainId = block.chainid; + // save the deployer (check this on initialization) + deployer = msg.sender; + } + + function __NttManager_init() internal onlyInitializing { + // check if the owner is the deployer of this contract + if (msg.sender != deployer) { + revert UnexpectedDeployer(deployer, msg.sender); + } + __PausedOwnable_init(msg.sender, msg.sender); + __ReentrancyGuard_init(); + } + + function _initialize() internal virtual override { + __NttManager_init(); + _checkThresholdInvariants(); + _checkTransceiversInvariants(); + } + + function _migrate() internal virtual override { + _checkThresholdInvariants(); + _checkTransceiversInvariants(); + } + + // =============== Storage ============================================================== + + bytes32 private constant MESSAGE_ATTESTATIONS_SLOT = + bytes32(uint256(keccak256("ntt.messageAttestations")) - 1); + + bytes32 private constant MESSAGE_SEQUENCE_SLOT = + bytes32(uint256(keccak256("ntt.messageSequence")) - 1); + + bytes32 private constant PEERS_SLOT = bytes32(uint256(keccak256("ntt.peers")) - 1); + + bytes32 private constant THRESHOLD_SLOT = bytes32(uint256(keccak256("ntt.threshold")) - 1); + + // =============== Storage Getters/Setters ============================================== + + function _getThresholdStorage() private pure returns (INttManager._Threshold storage $) { + uint256 slot = uint256(THRESHOLD_SLOT); + assembly ("memory-safe") { + $.slot := slot + } + } + + function _getMessageAttestationsStorage() + internal + pure + returns (mapping(bytes32 => INttManager.AttestationInfo) storage $) + { + uint256 slot = uint256(MESSAGE_ATTESTATIONS_SLOT); + assembly ("memory-safe") { + $.slot := slot + } + } + + function _getMessageSequenceStorage() internal pure returns (INttManager._Sequence storage $) { + uint256 slot = uint256(MESSAGE_SEQUENCE_SLOT); + assembly ("memory-safe") { + $.slot := slot + } + } + + function _getPeersStorage() internal pure returns (mapping(uint16 => bytes32) storage $) { + uint256 slot = uint256(PEERS_SLOT); + assembly ("memory-safe") { + $.slot := slot + } + } + + // =============== Public Getters ======================================================== + + /// @inheritdoc INttManagerState + function getMode() public view returns (uint8) { + return uint8(mode); + } + + /// @inheritdoc INttManagerState + function getThreshold() public view returns (uint8) { + return _getThresholdStorage().num; + } + + /// @inheritdoc INttManagerState + function isMessageApproved(bytes32 digest) public view returns (bool) { + uint8 threshold = getThreshold(); + return messageAttestations(digest) >= threshold && threshold > 0; + } + + /// @inheritdoc INttManagerState + function nextMessageSequence() external view returns (uint64) { + return _getMessageSequenceStorage().num; + } + + /// @inheritdoc INttManagerState + function isMessageExecuted(bytes32 digest) public view returns (bool) { + return _getMessageAttestationsStorage()[digest].executed; + } + + /// @inheritdoc INttManagerState + function getPeer(uint16 chainId_) public view returns (bytes32) { + return _getPeersStorage()[chainId_]; + } + + /// @inheritdoc INttManagerState + function transceiverAttestedToMessage(bytes32 digest, uint8 index) public view returns (bool) { + return + _getMessageAttestationsStorage()[digest].attestedTransceivers & uint64(1 << index) == 1; + } + + /// @inheritdoc INttManagerState + function messageAttestations(bytes32 digest) public view returns (uint8 count) { + return countSetBits(_getMessageAttestations(digest)); + } + + // =============== ADMIN ============================================================== + + /// @inheritdoc INttManagerState + function upgrade(address newImplementation) external onlyOwner { + _upgrade(newImplementation); + } + + /// @inheritdoc INttManagerState + function pause() public onlyOwnerOrPauser { + _pause(); + } + + /// @notice Transfer ownership of the Manager contract and all Endpoint contracts to a new owner. + function transferOwnership(address newOwner) public override onlyOwner { + super.transferOwnership(newOwner); + // loop through all the registered transceivers and set the new owner of each transceiver to the newOwner + address[] storage _registeredTransceivers = _getRegisteredTransceiversStorage(); + _checkRegisteredTransceiversInvariants(); + + for (uint256 i = 0; i < _registeredTransceivers.length; i++) { + ITransceiver(_registeredTransceivers[i]).transferTransceiverOwnership(newOwner); + } + } + + /// @inheritdoc INttManagerState + function setTransceiver(address transceiver) external onlyOwner { + _setTransceiver(transceiver); + + INttManager._Threshold storage _threshold = _getThresholdStorage(); + // We do not automatically increase the threshold here. + // Automatically increasing the threshold can result in a scenario + // where in-flight messages can't be redeemed. + // For example: Assume there is 1 Transceiver and the threshold is 1. + // If we were to add a new Transceiver, the threshold would increase to 2. + // However, all messages that are either in-flight or that are sent on + // a source chain that does not yet have 2 Transceivers will only have been + // sent from a single transceiver, so they would never be able to get + // redeemed. + // Instead, we leave it up to the owner to manually update the threshold + // after some period of time, ideally once all chains have the new Transceiver + // and transfers that were sent via the old configuration are all complete. + // However if the threshold is 0 (the initial case) we do increment to 1. + if (_threshold.num == 0) { + _threshold.num = 1; + } + + emit TransceiverAdded(transceiver, _getNumTransceiversStorage().enabled, _threshold.num); + } + + /// @inheritdoc INttManagerState + function removeTransceiver(address transceiver) external onlyOwner { + _removeTransceiver(transceiver); + + INttManager._Threshold storage _threshold = _getThresholdStorage(); + uint8 numEnabledTransceivers = _getNumTransceiversStorage().enabled; + + if (numEnabledTransceivers < _threshold.num) { + _threshold.num = numEnabledTransceivers; + } + + emit TransceiverRemoved(transceiver, _threshold.num); + } + + /// @inheritdoc INttManagerState + function setThreshold(uint8 threshold) external onlyOwner { + if (threshold == 0) { + revert ZeroThreshold(); + } + + INttManager._Threshold storage _threshold = _getThresholdStorage(); + uint8 oldThreshold = _threshold.num; + + _threshold.num = threshold; + _checkThresholdInvariants(); + + emit ThresholdChanged(oldThreshold, threshold); + } + + /// @inheritdoc INttManagerState + function setPeer(uint16 peerChainId, bytes32 peerContract) public onlyOwner { + if (peerChainId == 0) { + revert InvalidPeerChainIdZero(); + } + if (peerContract == bytes32(0)) { + revert InvalidPeerZeroAddress(); + } + + bytes32 oldPeerContract = _getPeersStorage()[peerChainId]; + + _getPeersStorage()[peerChainId] = peerContract; + + emit PeerUpdated(peerChainId, oldPeerContract, peerContract); + } + + /// @inheritdoc INttManagerState + function setOutboundLimit(uint256 limit) external onlyOwner { + _setOutboundLimit(_nttTrimmer(limit)); + } + + /// @inheritdoc INttManagerState + function setInboundLimit(uint256 limit, uint16 chainId_) external onlyOwner { + _setInboundLimit(_nttTrimmer(limit), chainId_); + } + + // =============== Internal ============================================================== + + function _setTransceiverAttestedToMessage(bytes32 digest, uint8 index) internal { + _getMessageAttestationsStorage()[digest].attestedTransceivers |= uint64(1 << index); + } + + function _setTransceiverAttestedToMessage(bytes32 digest, address transceiver) internal { + _setTransceiverAttestedToMessage(digest, _getTransceiverInfosStorage()[transceiver].index); + + emit MessageAttestedTo( + digest, transceiver, _getTransceiverInfosStorage()[transceiver].index + ); + } + + /// @dev Returns the bitmap of attestations from enabled transceivers for a given message. + function _getMessageAttestations(bytes32 digest) internal view returns (uint64) { + uint64 enabledTransceiverBitmap = _getEnabledTransceiversBitmap(); + return + _getMessageAttestationsStorage()[digest].attestedTransceivers & enabledTransceiverBitmap; + } + + function _getEnabledTransceiverAttestedToMessage( + bytes32 digest, + uint8 index + ) internal view returns (bool) { + return _getMessageAttestations(digest) & uint64(1 << index) != 0; + } + + /// @dev Verify that the peer address saved for `sourceChainId` matches the `peerAddress`. + function _verifyPeer(uint16 sourceChainId, bytes32 peerAddress) internal view { + if (getPeer(sourceChainId) != peerAddress) { + revert InvalidPeer(sourceChainId, peerAddress); + } + } + + // @dev Mark a message as executed. + // This function will retuns `true` if the message has already been executed. + function _replayProtect(bytes32 digest) internal returns (bool) { + // check if this message has already been executed + if (isMessageExecuted(digest)) { + return true; + } + + // mark this message as executed + _getMessageAttestationsStorage()[digest].executed = true; + + return false; + } + + function _useMessageSequence() internal returns (uint64 currentSequence) { + currentSequence = _getMessageSequenceStorage().num; + _getMessageSequenceStorage().num++; + } + + /// ============== Invariants ============================================= + + /// @dev When we add new immutables, this function should be updated + function _checkImmutables() internal view override { + assert(this.token() == token); + assert(this.mode() == mode); + assert(this.chainId() == chainId); + assert(this.rateLimitDuration() == rateLimitDuration); + } + + function _checkRegisteredTransceiversInvariants() internal view { + if (_getRegisteredTransceiversStorage().length != _getNumTransceiversStorage().registered) { + revert RetrievedIncorrectRegisteredTransceivers( + _getRegisteredTransceiversStorage().length, _getNumTransceiversStorage().registered + ); + } + } + + function _checkThresholdInvariants() internal view { + uint8 threshold = _getThresholdStorage().num; + _NumTransceivers memory numTransceivers = _getNumTransceiversStorage(); + + // invariant: threshold <= enabledTransceivers.length + if (threshold > numTransceivers.enabled) { + revert ThresholdTooHigh(threshold, numTransceivers.enabled); + } + + if (numTransceivers.registered > 0) { + if (threshold == 0) { + revert ZeroThreshold(); + } + } + } +} diff --git a/evm/src/TransceiverRegistry.sol b/evm/src/NttManager/TransceiverRegistry.sol similarity index 97% rename from evm/src/TransceiverRegistry.sol rename to evm/src/NttManager/TransceiverRegistry.sol index e7e35e56b..f7fce914c 100644 --- a/evm/src/TransceiverRegistry.sol +++ b/evm/src/NttManager/TransceiverRegistry.sol @@ -50,7 +50,7 @@ abstract contract TransceiverRegistry { _; } - /// =============== STORAGE =============================================== + // =============== Storage =============================================== bytes32 private constant TRANSCEIVER_INFOS_SLOT = bytes32(uint256(keccak256("ntt.transceiverInfos")) - 1); @@ -110,7 +110,7 @@ abstract contract TransceiverRegistry { } } - /// =============== GETTERS/SETTERS ======================================== + // =============== Storage Getters/Setters ======================================== function _setTransceiver(address transceiver) internal returns (uint8 index) { mapping(address => TransceiverInfo) storage transceiverInfos = _getTransceiverInfosStorage(); @@ -213,7 +213,7 @@ abstract contract TransceiverRegistry { result = _getEnabledTransceiversStorage(); } - /// ============== INVARIANTS ============================================= + // ============== Invariants ============================================= /// @dev Check that the transceiver nttManager is in a valid state. /// Checking these invariants is somewhat costly, but we only need to do it diff --git a/evm/src/Transceiver.sol b/evm/src/Transceiver/Transceiver.sol similarity index 75% rename from evm/src/Transceiver.sol rename to evm/src/Transceiver/Transceiver.sol index 79706bd1b..c355ccd65 100644 --- a/evm/src/Transceiver.sol +++ b/evm/src/Transceiver/Transceiver.sol @@ -1,14 +1,16 @@ // SPDX-License-Identifier: Apache 2 pragma solidity >=0.8.8 <0.9.0; -import "./libraries/TransceiverStructs.sol"; -import "./libraries/PausableOwnable.sol"; -import "./interfaces/INttManager.sol"; -import "./interfaces/ITransceiver.sol"; -import "./libraries/external/ReentrancyGuardUpgradeable.sol"; -import "./libraries/Implementation.sol"; import "wormhole-solidity-sdk/Utils.sol"; +import "../libraries/TransceiverStructs.sol"; +import "../libraries/PausableOwnable.sol"; +import "../libraries/external/ReentrancyGuardUpgradeable.sol"; +import "../libraries/Implementation.sol"; + +import "../interfaces/INttManager.sol"; +import "../interfaces/ITransceiver.sol"; + abstract contract Transceiver is ITransceiver, PausableOwnable, @@ -16,7 +18,8 @@ abstract contract Transceiver is Implementation { /// @dev updating bridgeNttManager requires a new Transceiver deployment. - /// Projects should implement their own governance to remove the old Transceiver contract address and then add the new one. + /// Projects should implement their own governance to remove the old Transceiver + /// contract address and then add the new one. address public immutable nttManager; address public immutable nttManagerToken; @@ -59,9 +62,9 @@ abstract contract Transceiver is function _migrate() internal virtual override {} - /// @dev This method checks that the the referecnes to the nttManager and its corresponding function are correct - /// When new immutable variables are added, this function should be updated. - function _checkImmutables() internal view override { + // @define This method checks that the the referecnes to the nttManager and its corresponding function + // are correct When new immutable variables are added, this function should be updated. + function _checkImmutables() internal view virtual override { assert(this.nttManager() == nttManager); assert(this.nttManagerToken() == nttManagerToken); } @@ -77,13 +80,16 @@ abstract contract Transceiver is } /// =============== TRANSCEIVING LOGIC =============================================== - /** - * @dev send a message to another chain. - * @param recipientChain The chain id of the recipient. - * @param instruction An additional Instruction provided by the Transceiver to be - * executed on the recipient chain. - * @param nttManagerMessage A message to be sent to the nttManager on the recipient chain. - */ + + /// @inheritdoc ITransceiver + function quoteDeliveryPrice( + uint16 targetChain, + TransceiverStructs.TransceiverInstruction memory instruction + ) external view returns (uint256) { + return _quoteDeliveryPrice(targetChain, instruction); + } + + /// @inheritdoc ITransceiver function sendMessage( uint16 recipientChain, TransceiverStructs.TransceiverInstruction memory instruction, @@ -100,6 +106,8 @@ abstract contract Transceiver is ); } + /// ============================= INTERNAL ========================================= + function _sendMessage( uint16 recipientChain, uint256 deliveryPayment, @@ -109,11 +117,9 @@ abstract contract Transceiver is bytes memory nttManagerMessage ) internal virtual; - // @dev This method is called by the BridgeNttManager contract to send a cross-chain message. - // Forwards the VAA payload to the transceiver nttManager contract. - // @param sourceChainId The chain id of the sender. - // @param sourceNttManagerAddress The address of the sender's nttManager contract. - // @param payload The VAA payload. + // @define This method is called by the BridgeNttManager contract to send a cross-chain message. + // @reverts if: + // - `recipientNttManagerAddress` does not match the address of this manager contract function _deliverToNttManager( uint16 sourceChainId, bytes32 sourceNttManagerAddress, @@ -128,13 +134,6 @@ abstract contract Transceiver is INttManager(nttManager).attestationReceived(sourceChainId, sourceNttManagerAddress, payload); } - function quoteDeliveryPrice( - uint16 targetChain, - TransceiverStructs.TransceiverInstruction memory instruction - ) external view returns (uint256) { - return _quoteDeliveryPrice(targetChain, instruction); - } - function _quoteDeliveryPrice( uint16 targetChain, TransceiverStructs.TransceiverInstruction memory transceiverInstruction diff --git a/evm/src/Transceiver/WormholeTransceiver/WormholeTransceiver.sol b/evm/src/Transceiver/WormholeTransceiver/WormholeTransceiver.sol new file mode 100644 index 000000000..1a83fcd51 --- /dev/null +++ b/evm/src/Transceiver/WormholeTransceiver/WormholeTransceiver.sol @@ -0,0 +1,233 @@ +// SPDX-License-Identifier: Apache 2 +pragma solidity >=0.8.8 <0.9.0; + +import "wormhole-solidity-sdk/WormholeRelayerSDK.sol"; +import "wormhole-solidity-sdk/libraries/BytesParsing.sol"; +import "wormhole-solidity-sdk/interfaces/IWormhole.sol"; + +import "../../libraries/TransceiverHelpers.sol"; +import "../../libraries/TransceiverStructs.sol"; + +import "../../interfaces/IWormholeTransceiver.sol"; +import "../../interfaces/ISpecialRelayer.sol"; +import "../../interfaces/INttManager.sol"; + +import "./WormholeTransceiverState.sol"; + +contract WormholeTransceiver is + IWormholeTransceiver, + IWormholeReceiver, + WormholeTransceiverState +{ + using BytesParsing for bytes; + + constructor( + address nttManager, + address wormholeCoreBridge, + address wormholeRelayerAddr, + address specialRelayerAddr, + uint8 _consistencyLevel + ) + WormholeTransceiverState( + nttManager, + wormholeCoreBridge, + wormholeRelayerAddr, + specialRelayerAddr, + _consistencyLevel + ) + {} + + // ==================== External Interface =============================================== + + /// @inheritdoc IWormholeTransceiver + function receiveMessage(bytes memory encodedMessage) external { + uint16 sourceChainId; + bytes memory payload; + (sourceChainId, payload) = _verifyMessage(encodedMessage); + + // parse the encoded Transceiver payload + TransceiverStructs.TransceiverMessage memory parsedTransceiverMessage; + TransceiverStructs.NttManagerMessage memory parsedNttManagerMessage; + (parsedTransceiverMessage, parsedNttManagerMessage) = TransceiverStructs + .parseTransceiverAndNttManagerMessage(WH_TRANSCEIVER_PAYLOAD_PREFIX, payload); + + _deliverToNttManager( + sourceChainId, + parsedTransceiverMessage.sourceNttManagerAddress, + parsedTransceiverMessage.recipientNttManagerAddress, + parsedNttManagerMessage + ); + } + + /// @inheritdoc IWormholeReceiver + function receiveWormholeMessages( + bytes memory payload, + bytes[] memory additionalMessages, + bytes32 sourceAddress, + uint16 sourceChain, + bytes32 deliveryHash + ) external payable onlyRelayer { + if (getWormholePeer(sourceChain) != sourceAddress) { + revert InvalidWormholePeer(sourceChain, sourceAddress); + } + + // VAA replay protection: + // - Note that this VAA is for the AR delivery, not for the raw message emitted by the source + // - chain Transceiver contract. The VAAs received by this entrypoint are different than the + // - VAA received by the receiveMessage entrypoint. + if (isVAAConsumed(deliveryHash)) { + revert TransferAlreadyCompleted(deliveryHash); + } + _setVAAConsumed(deliveryHash); + + // We don't honor additional messages in this handler. + if (additionalMessages.length > 0) { + revert UnexpectedAdditionalMessages(); + } + + // emit `ReceivedRelayedMessage` event + emit ReceivedRelayedMessage(deliveryHash, sourceChain, sourceAddress); + + // parse the encoded Transceiver payload + TransceiverStructs.TransceiverMessage memory parsedTransceiverMessage; + TransceiverStructs.NttManagerMessage memory parsedNttManagerMessage; + (parsedTransceiverMessage, parsedNttManagerMessage) = TransceiverStructs + .parseTransceiverAndNttManagerMessage(WH_TRANSCEIVER_PAYLOAD_PREFIX, payload); + + _deliverToNttManager( + sourceChain, + parsedTransceiverMessage.sourceNttManagerAddress, + parsedTransceiverMessage.recipientNttManagerAddress, + parsedNttManagerMessage + ); + } + + /// @inheritdoc IWormholeTransceiver + function parseWormholeTransceiverInstruction(bytes memory encoded) + public + pure + returns (WormholeTransceiverInstruction memory instruction) + { + // If the user doesn't pass in any transceiver instructions then the default is false + if (encoded.length == 0) { + instruction.shouldSkipRelayerSend = false; + return instruction; + } + + uint256 offset = 0; + (instruction.shouldSkipRelayerSend, offset) = encoded.asBoolUnchecked(offset); + encoded.checkLength(offset); + } + + /// @inheritdoc IWormholeTransceiver + function encodeWormholeTransceiverInstruction(WormholeTransceiverInstruction memory instruction) + public + pure + returns (bytes memory) + { + return abi.encodePacked(instruction.shouldSkipRelayerSend); + } + + // ==================== Internal ======================================================== + + function _quoteDeliveryPrice( + uint16 targetChain, + TransceiverStructs.TransceiverInstruction memory instruction + ) internal view override returns (uint256 nativePriceQuote) { + // Check the special instruction up front to see if we should skip sending via a relayer + WormholeTransceiverInstruction memory weIns = + parseWormholeTransceiverInstruction(instruction.payload); + if (weIns.shouldSkipRelayerSend) { + return 0; + } + + if (_checkInvalidRelayingConfig(targetChain)) { + revert InvalidRelayingConfig(targetChain); + } + + if (_shouldRelayViaStandardRelaying(targetChain)) { + (uint256 cost,) = wormholeRelayer.quoteEVMDeliveryPrice(targetChain, 0, GAS_LIMIT); + return cost; + } else if (isSpecialRelayingEnabled(targetChain)) { + uint256 cost = specialRelayer.quoteDeliveryPrice(getNttManagerToken(), targetChain, 0); + return cost; + } else { + return 0; + } + } + + function _sendMessage( + uint16 recipientChain, + uint256 deliveryPayment, + address caller, + bytes32 recipientNttManagerAddress, + TransceiverStructs.TransceiverInstruction memory instruction, + bytes memory nttManagerMessage + ) internal override { + ( + TransceiverStructs.TransceiverMessage memory transceiverMessage, + bytes memory encodedTransceiverPayload + ) = TransceiverStructs.buildAndEncodeTransceiverMessage( + WH_TRANSCEIVER_PAYLOAD_PREFIX, + toWormholeFormat(caller), + recipientNttManagerAddress, + nttManagerMessage, + new bytes(0) + ); + + WormholeTransceiverInstruction memory weIns = + parseWormholeTransceiverInstruction(instruction.payload); + + if (!weIns.shouldSkipRelayerSend && _shouldRelayViaStandardRelaying(recipientChain)) { + wormholeRelayer.sendPayloadToEvm{value: deliveryPayment}( + recipientChain, + fromWormholeFormat(getWormholePeer(recipientChain)), + encodedTransceiverPayload, + 0, + GAS_LIMIT + ); + } else if (!weIns.shouldSkipRelayerSend && isSpecialRelayingEnabled(recipientChain)) { + uint64 sequence = + wormhole.publishMessage(0, encodedTransceiverPayload, consistencyLevel); + specialRelayer.requestDelivery{value: deliveryPayment}( + getNttManagerToken(), recipientChain, 0, sequence + ); + } else { + wormhole.publishMessage(0, encodedTransceiverPayload, consistencyLevel); + } + + emit SendTransceiverMessage(recipientChain, transceiverMessage); + } + + function _verifyMessage(bytes memory encodedMessage) internal returns (uint16, bytes memory) { + // verify VAA against Wormhole Core Bridge contract + (IWormhole.VM memory vm, bool valid, string memory reason) = + wormhole.parseAndVerifyVM(encodedMessage); + + // ensure that the VAA is valid + if (!valid) { + revert InvalidVaa(reason); + } + + // ensure that the message came from a registered peer contract + if (!_verifyBridgeVM(vm)) { + revert InvalidWormholePeer(vm.emitterChainId, vm.emitterAddress); + } + + // save the VAA hash in storage to protect against replay attacks. + if (isVAAConsumed(vm.hash)) { + revert TransferAlreadyCompleted(vm.hash); + } + _setVAAConsumed(vm.hash); + + // emit `ReceivedMessage` event + emit ReceivedMessage(vm.hash, vm.emitterChainId, vm.emitterAddress, vm.sequence); + + return (vm.emitterChainId, vm.payload); + } + + function _verifyBridgeVM(IWormhole.VM memory vm) internal view returns (bool) { + checkFork(wormholeTransceiver_evmChainId); + return getWormholePeer(vm.emitterChainId) == vm.emitterAddress; + } +} diff --git a/evm/src/Transceiver/WormholeTransceiver/WormholeTransceiverState.sol b/evm/src/Transceiver/WormholeTransceiver/WormholeTransceiverState.sol new file mode 100644 index 000000000..3c72cb2f7 --- /dev/null +++ b/evm/src/Transceiver/WormholeTransceiver/WormholeTransceiverState.sol @@ -0,0 +1,278 @@ +// SPDX-License-Identifier: Apache 2 +pragma solidity >=0.8.8 <0.9.0; + +import "wormhole-solidity-sdk/WormholeRelayerSDK.sol"; +import "wormhole-solidity-sdk/libraries/BytesParsing.sol"; +import "wormhole-solidity-sdk/interfaces/IWormhole.sol"; + +import "../../libraries/TransceiverHelpers.sol"; +import "../../libraries/TransceiverStructs.sol"; + +import "../../interfaces/IWormholeTransceiver.sol"; +import "../../interfaces/IWormholeTransceiverState.sol"; +import "../../interfaces/ISpecialRelayer.sol"; +import "../../interfaces/INttManager.sol"; + +import "../Transceiver.sol"; + +abstract contract WormholeTransceiverState is IWormholeTransceiverState, Transceiver { + using BytesParsing for bytes; + + // ==================== Immutables =============================================== + uint8 public immutable consistencyLevel; + IWormhole public immutable wormhole; + IWormholeRelayer public immutable wormholeRelayer; + ISpecialRelayer public immutable specialRelayer; + uint256 immutable wormholeTransceiver_evmChainId; + + // ==================== Constants ================================================ + uint256 public constant GAS_LIMIT = 500000; + + /// @dev Prefix for all TransceiverMessage payloads + /// This is 0x99'E''W''H' + /// @notice Magic string (constant value set by messaging provider) that idenfies the payload as an transceiver-emitted payload. + /// Note that this is not a security critical field. It's meant to be used by messaging providers to identify which messages are Transceiver-related. + bytes4 constant WH_TRANSCEIVER_PAYLOAD_PREFIX = 0x9945FF10; + + /// @dev Prefix for all Wormhole transceiver initialisation payloads + /// This is bytes4(keccak256("WormholeTransceiverInit")) + bytes4 constant WH_TRANSCEIVER_INIT_PREFIX = 0x9c23bd3b; + + /// @dev Prefix for all Wormhole peer registration payloads + /// This is bytes4(keccak256("WormholePeerRegistration")) + bytes4 constant WH_PEER_REGISTRATION_PREFIX = 0x18fc67c2; + + constructor( + address nttManager, + address wormholeCoreBridge, + address wormholeRelayerAddr, + address specialRelayerAddr, + uint8 _consistencyLevel + ) Transceiver(nttManager) { + wormhole = IWormhole(wormholeCoreBridge); + wormholeRelayer = IWormholeRelayer(wormholeRelayerAddr); + specialRelayer = ISpecialRelayer(specialRelayerAddr); + wormholeTransceiver_evmChainId = block.chainid; + consistencyLevel = _consistencyLevel; + } + + enum RelayingType { + Standard, + Special, + Manual + } + + function _initialize() internal override { + super._initialize(); + _initializeTransceiver(); + } + + function _initializeTransceiver() internal { + TransceiverStructs.TransceiverInit memory init = TransceiverStructs.TransceiverInit({ + transceiverIdentifier: WH_TRANSCEIVER_INIT_PREFIX, + nttManagerAddress: toWormholeFormat(nttManager), + nttManagerMode: INttManager(nttManager).getMode(), + tokenAddress: toWormholeFormat(nttManagerToken), + tokenDecimals: INttManager(nttManager).tokenDecimals() + }); + wormhole.publishMessage(0, TransceiverStructs.encodeTransceiverInit(init), consistencyLevel); + } + + function _checkImmutables() internal view override { + super._checkImmutables(); + assert(this.wormhole() == wormhole); + assert(this.wormholeRelayer() == wormholeRelayer); + assert(this.specialRelayer() == specialRelayer); + assert(this.consistencyLevel() == consistencyLevel); + } + + // =============== Storage =============================================== + + bytes32 private constant WORMHOLE_CONSUMED_VAAS_SLOT = + bytes32(uint256(keccak256("whTransceiver.consumedVAAs")) - 1); + + bytes32 private constant WORMHOLE_PEERS_SLOT = + bytes32(uint256(keccak256("whTransceiver.peers")) - 1); + + bytes32 private constant WORMHOLE_RELAYING_ENABLED_CHAINS_SLOT = + bytes32(uint256(keccak256("whTransceiver.relayingEnabledChains")) - 1); + + bytes32 private constant SPECIAL_RELAYING_ENABLED_CHAINS_SLOT = + bytes32(uint256(keccak256("whTransceiver.specialRelayingEnabledChains")) - 1); + + bytes32 private constant WORMHOLE_EVM_CHAIN_IDS = + bytes32(uint256(keccak256("whTransceiver.evmChainIds")) - 1); + + // =============== Storage Setters/Getters ======================================== + + function _getWormholeConsumedVAAsStorage() + internal + pure + returns (mapping(bytes32 => bool) storage $) + { + uint256 slot = uint256(WORMHOLE_CONSUMED_VAAS_SLOT); + assembly ("memory-safe") { + $.slot := slot + } + } + + function _getWormholePeersStorage() + internal + pure + returns (mapping(uint16 => bytes32) storage $) + { + uint256 slot = uint256(WORMHOLE_PEERS_SLOT); + assembly ("memory-safe") { + $.slot := slot + } + } + + function _getWormholeRelayingEnabledChainsStorage() + internal + pure + returns (mapping(uint16 => uint256) storage $) + { + uint256 slot = uint256(WORMHOLE_RELAYING_ENABLED_CHAINS_SLOT); + assembly ("memory-safe") { + $.slot := slot + } + } + + function _getSpecialRelayingEnabledChainsStorage() + internal + pure + returns (mapping(uint16 => uint256) storage $) + { + uint256 slot = uint256(SPECIAL_RELAYING_ENABLED_CHAINS_SLOT); + assembly ("memory-safe") { + $.slot := slot + } + } + + function _getWormholeEvmChainIdsStorage() + internal + pure + returns (mapping(uint16 => uint256) storage $) + { + uint256 slot = uint256(WORMHOLE_EVM_CHAIN_IDS); + assembly ("memory-safe") { + $.slot := slot + } + } + + // =============== Public Getters ====================================================== + + /// @inheritdoc IWormholeTransceiverState + function isVAAConsumed(bytes32 hash) public view returns (bool) { + return _getWormholeConsumedVAAsStorage()[hash]; + } + + /// @inheritdoc IWormholeTransceiverState + function getWormholePeer(uint16 chainId) public view returns (bytes32) { + return _getWormholePeersStorage()[chainId]; + } + + /// @inheritdoc IWormholeTransceiverState + function isWormholeRelayingEnabled(uint16 chainId) public view returns (bool) { + return toBool(_getWormholeRelayingEnabledChainsStorage()[chainId]); + } + + /// @inheritdoc IWormholeTransceiverState + function isSpecialRelayingEnabled(uint16 chainId) public view returns (bool) { + return toBool(_getSpecialRelayingEnabledChainsStorage()[chainId]); + } + + /// @inheritdoc IWormholeTransceiverState + function isWormholeEvmChain(uint16 chainId) public view returns (bool) { + return toBool(_getWormholeEvmChainIdsStorage()[chainId]); + } + + // =============== Admin =============================================================== + + /// @inheritdoc IWormholeTransceiverState + function setWormholePeer(uint16 peerChainId, bytes32 peerContract) external onlyOwner { + if (peerChainId == 0) { + revert InvalidWormholeChainIdZero(); + } + if (peerContract == bytes32(0)) { + revert InvalidWormholePeerZeroAddress(); + } + + bytes32 oldPeerContract = _getWormholePeersStorage()[peerChainId]; + + // We don't want to allow updating a peer since this adds complexity in the accountant + // If the owner makes a mistake with peer registration they should deploy a new Wormhole + // transceiver and register this new transceiver with the NttManager + if (oldPeerContract != bytes32(0)) { + revert PeerAlreadySet(peerChainId, oldPeerContract); + } + + _getWormholePeersStorage()[peerChainId] = peerContract; + + // Publish a message for this transceiver registration + TransceiverStructs.TransceiverRegistration memory registration = TransceiverStructs + .TransceiverRegistration({ + transceiverIdentifier: WH_PEER_REGISTRATION_PREFIX, + transceiverChainId: peerChainId, + transceiverAddress: peerContract + }); + wormhole.publishMessage( + 0, TransceiverStructs.encodeTransceiverRegistration(registration), consistencyLevel + ); + + emit SetWormholePeer(peerChainId, peerContract); + } + + /// @inheritdoc IWormholeTransceiverState + function setIsWormholeEvmChain(uint16 chainId) external onlyOwner { + if (chainId == 0) { + revert InvalidWormholeChainIdZero(); + } + _getWormholeEvmChainIdsStorage()[chainId] = TRUE; + + emit SetIsWormholeEvmChain(chainId); + } + + /// @inheritdoc IWormholeTransceiverState + function setIsWormholeRelayingEnabled(uint16 chainId, bool isEnabled) external onlyOwner { + if (chainId == 0) { + revert InvalidWormholeChainIdZero(); + } + _getWormholeRelayingEnabledChainsStorage()[chainId] = toWord(isEnabled); + + emit SetIsWormholeRelayingEnabled(chainId, isEnabled); + } + + /// @inheritdoc IWormholeTransceiverState + function setIsSpecialRelayingEnabled(uint16 chainId, bool isEnabled) external onlyOwner { + if (chainId == 0) { + revert InvalidWormholeChainIdZero(); + } + _getSpecialRelayingEnabledChainsStorage()[chainId] = toWord(isEnabled); + + emit SetIsSpecialRelayingEnabled(chainId, isEnabled); + } + + // ============= Internal =============================================================== + + function _checkInvalidRelayingConfig(uint16 chainId) internal view returns (bool) { + return isWormholeRelayingEnabled(chainId) && !isWormholeEvmChain(chainId); + } + + function _shouldRelayViaStandardRelaying(uint16 chainId) internal view returns (bool) { + return isWormholeRelayingEnabled(chainId) && isWormholeEvmChain(chainId); + } + + function _setVAAConsumed(bytes32 hash) internal { + _getWormholeConsumedVAAsStorage()[hash] = true; + } + + // =============== MODIFIERS =============================================== + + modifier onlyRelayer() { + if (msg.sender != address(wormholeRelayer)) { + revert CallerNotRelayer(msg.sender); + } + _; + } +} diff --git a/evm/src/WormholeTransceiver.sol b/evm/src/WormholeTransceiver.sol deleted file mode 100644 index f2f887e5f..000000000 --- a/evm/src/WormholeTransceiver.sol +++ /dev/null @@ -1,459 +0,0 @@ -// SPDX-License-Identifier: Apache 2 -pragma solidity >=0.8.8 <0.9.0; - -import "wormhole-solidity-sdk/WormholeRelayerSDK.sol"; -import "wormhole-solidity-sdk/libraries/BytesParsing.sol"; -import "wormhole-solidity-sdk/interfaces/IWormhole.sol"; - -import "./libraries/TransceiverHelpers.sol"; -import "./libraries/TransceiverStructs.sol"; -import "./interfaces/IWormholeTransceiver.sol"; -import "./interfaces/ISpecialRelayer.sol"; -import "./interfaces/INttManager.sol"; -import "./Transceiver.sol"; - -contract WormholeTransceiver is Transceiver, IWormholeTransceiver, IWormholeReceiver { - using BytesParsing for bytes; - - uint256 public constant GAS_LIMIT = 500000; - uint8 public immutable consistencyLevel; - - /// @dev Prefix for all TransceiverMessage payloads - /// This is 0x99'E''W''H' - /// @notice Magic string (constant value set by messaging provider) that idenfies the payload as an transceiver-emitted payload. - /// Note that this is not a security critical field. It's meant to be used by messaging providers to identify which messages are Transceiver-related. - bytes4 constant WH_TRANSCEIVER_PAYLOAD_PREFIX = 0x9945FF10; - - /// @dev Prefix for all Wormhole transceiver initialisation payloads - /// This is bytes4(keccak256("WormholeTransceiverInit")) - bytes4 constant WH_TRANSCEIVER_INIT_PREFIX = 0x9c23bd3b; - - /// @dev Prefix for all Wormhole peer registration payloads - /// This is bytes4(keccak256("WormholePeerRegistration")) - bytes4 constant WH_PEER_REGISTRATION_PREFIX = 0x18fc67c2; - - IWormhole public immutable wormhole; - IWormholeRelayer public immutable wormholeRelayer; - ISpecialRelayer public immutable specialRelayer; - uint256 public immutable wormholeTransceiver_evmChainId; - - struct WormholeTransceiverInstruction { - bool shouldSkipRelayerSend; - } - - enum RelayingType { - Standard, - Special, - Manual - } - - /// =============== STORAGE =============================================== - - bytes32 private constant WORMHOLE_CONSUMED_VAAS_SLOT = - bytes32(uint256(keccak256("whTransceiver.consumedVAAs")) - 1); - - bytes32 private constant WORMHOLE_PEERS_SLOT = - bytes32(uint256(keccak256("whTransceiver.peers")) - 1); - - bytes32 private constant WORMHOLE_RELAYING_ENABLED_CHAINS_SLOT = - bytes32(uint256(keccak256("whTransceiver.relayingEnabledChains")) - 1); - - bytes32 private constant SPECIAL_RELAYING_ENABLED_CHAINS_SLOT = - bytes32(uint256(keccak256("whTransceiver.specialRelayingEnabledChains")) - 1); - - bytes32 private constant WORMHOLE_EVM_CHAIN_IDS = - bytes32(uint256(keccak256("whTransceiver.evmChainIds")) - 1); - - /// =============== GETTERS/SETTERS ======================================== - - function _getWormholeConsumedVAAsStorage() - internal - pure - returns (mapping(bytes32 => bool) storage $) - { - uint256 slot = uint256(WORMHOLE_CONSUMED_VAAS_SLOT); - assembly ("memory-safe") { - $.slot := slot - } - } - - function _getWormholePeersStorage() - internal - pure - returns (mapping(uint16 => bytes32) storage $) - { - uint256 slot = uint256(WORMHOLE_PEERS_SLOT); - assembly ("memory-safe") { - $.slot := slot - } - } - - function _getWormholeRelayingEnabledChainsStorage() - internal - pure - returns (mapping(uint16 => uint256) storage $) - { - uint256 slot = uint256(WORMHOLE_RELAYING_ENABLED_CHAINS_SLOT); - assembly ("memory-safe") { - $.slot := slot - } - } - - function _getSpecialRelayingEnabledChainsStorage() - internal - pure - returns (mapping(uint16 => uint256) storage $) - { - uint256 slot = uint256(SPECIAL_RELAYING_ENABLED_CHAINS_SLOT); - assembly ("memory-safe") { - $.slot := slot - } - } - - function _getWormholeEvmChainIdsStorage() - internal - pure - returns (mapping(uint16 => uint256) storage $) - { - uint256 slot = uint256(WORMHOLE_EVM_CHAIN_IDS); - assembly ("memory-safe") { - $.slot := slot - } - } - - modifier onlyRelayer() { - if (msg.sender != address(wormholeRelayer)) { - revert CallerNotRelayer(msg.sender); - } - _; - } - - constructor( - address nttManager, - address wormholeCoreBridge, - address wormholeRelayerAddr, - address specialRelayerAddr, - uint8 _consistencyLevel - ) Transceiver(nttManager) { - wormhole = IWormhole(wormholeCoreBridge); - wormholeRelayer = IWormholeRelayer(wormholeRelayerAddr); - specialRelayer = ISpecialRelayer(specialRelayerAddr); - wormholeTransceiver_evmChainId = block.chainid; - consistencyLevel = _consistencyLevel; - } - - function _initialize() internal override { - super._initialize(); - _initializeTransceiver(); - } - - function _initializeTransceiver() internal { - TransceiverStructs.TransceiverInit memory init = TransceiverStructs.TransceiverInit({ - transceiverIdentifier: WH_TRANSCEIVER_INIT_PREFIX, - nttManagerAddress: toWormholeFormat(nttManager), - nttManagerMode: INttManager(nttManager).getMode(), - tokenAddress: toWormholeFormat(nttManagerToken), - tokenDecimals: INttManager(nttManager).tokenDecimals() - }); - wormhole.publishMessage(0, TransceiverStructs.encodeTransceiverInit(init), consistencyLevel); - } - - function _checkInvalidRelayingConfig(uint16 chainId) internal view returns (bool) { - return isWormholeRelayingEnabled(chainId) && !isWormholeEvmChain(chainId); - } - - function _shouldRelayViaStandardRelaying(uint16 chainId) internal view returns (bool) { - return isWormholeRelayingEnabled(chainId) && isWormholeEvmChain(chainId); - } - - function _quoteDeliveryPrice( - uint16 targetChain, - TransceiverStructs.TransceiverInstruction memory instruction - ) internal view override returns (uint256 nativePriceQuote) { - // Check the special instruction up front to see if we should skip sending via a relayer - WormholeTransceiverInstruction memory weIns = - parseWormholeTransceiverInstruction(instruction.payload); - if (weIns.shouldSkipRelayerSend) { - return 0; - } - - if (_checkInvalidRelayingConfig(targetChain)) { - revert InvalidRelayingConfig(targetChain); - } - - if (_shouldRelayViaStandardRelaying(targetChain)) { - (uint256 cost,) = wormholeRelayer.quoteEVMDeliveryPrice(targetChain, 0, GAS_LIMIT); - return cost; - } else if (isSpecialRelayingEnabled(targetChain)) { - uint256 cost = specialRelayer.quoteDeliveryPrice(getNttManagerToken(), targetChain, 0); - return cost; - } else { - return 0; - } - } - - function _sendMessage( - uint16 recipientChain, - uint256 deliveryPayment, - address caller, - bytes32 recipientNttManagerAddress, - TransceiverStructs.TransceiverInstruction memory instruction, - bytes memory nttManagerMessage - ) internal override { - ( - TransceiverStructs.TransceiverMessage memory transceiverMessage, - bytes memory encodedTransceiverPayload - ) = TransceiverStructs.buildAndEncodeTransceiverMessage( - WH_TRANSCEIVER_PAYLOAD_PREFIX, - toWormholeFormat(caller), - recipientNttManagerAddress, - nttManagerMessage, - new bytes(0) - ); - - WormholeTransceiverInstruction memory weIns = - parseWormholeTransceiverInstruction(instruction.payload); - - if (!weIns.shouldSkipRelayerSend && _shouldRelayViaStandardRelaying(recipientChain)) { - wormholeRelayer.sendPayloadToEvm{value: deliveryPayment}( - recipientChain, - fromWormholeFormat(getWormholePeer(recipientChain)), - encodedTransceiverPayload, - 0, - GAS_LIMIT - ); - - emit RelayingInfo(uint8(RelayingType.Standard), deliveryPayment); - } else if (!weIns.shouldSkipRelayerSend && isSpecialRelayingEnabled(recipientChain)) { - uint64 sequence = - wormhole.publishMessage(0, encodedTransceiverPayload, consistencyLevel); - specialRelayer.requestDelivery{value: deliveryPayment}( - getNttManagerToken(), recipientChain, 0, sequence - ); - - emit RelayingInfo(uint8(RelayingType.Special), deliveryPayment); - } else { - wormhole.publishMessage(0, encodedTransceiverPayload, consistencyLevel); - emit RelayingInfo(uint8(RelayingType.Manual), deliveryPayment); - } - - emit SendTransceiverMessage(recipientChain, transceiverMessage); - } - - function receiveWormholeMessages( - bytes memory payload, - bytes[] memory additionalMessages, - bytes32 sourceAddress, - uint16 sourceChain, - bytes32 deliveryHash - ) external payable onlyRelayer { - if (getWormholePeer(sourceChain) != sourceAddress) { - revert InvalidWormholePeer(sourceChain, sourceAddress); - } - - // VAA replay protection - // Note that this VAA is for the AR delivery, not for the raw message emitted by the source chain Transceiver contract. - // The VAAs received by this entrypoint are different than the VAA received by the receiveMessage entrypoint. - if (isVAAConsumed(deliveryHash)) { - revert TransferAlreadyCompleted(deliveryHash); - } - _setVAAConsumed(deliveryHash); - - // We don't honor additional message in this handler. - if (additionalMessages.length > 0) { - revert UnexpectedAdditionalMessages(); - } - - // emit `ReceivedRelayedMessage` event - emit ReceivedRelayedMessage(deliveryHash, sourceChain, sourceAddress); - - // parse the encoded Transceiver payload - TransceiverStructs.TransceiverMessage memory parsedTransceiverMessage; - TransceiverStructs.NttManagerMessage memory parsedNttManagerMessage; - (parsedTransceiverMessage, parsedNttManagerMessage) = TransceiverStructs - .parseTransceiverAndNttManagerMessage(WH_TRANSCEIVER_PAYLOAD_PREFIX, payload); - - _deliverToNttManager( - sourceChain, - parsedTransceiverMessage.sourceNttManagerAddress, - parsedTransceiverMessage.recipientNttManagerAddress, - parsedNttManagerMessage - ); - } - - /// @notice Receive an attested message from the verification layer - /// This function should verify the encodedVm and then deliver the attestation to the transceiver nttManager contract. - function receiveMessage(bytes memory encodedMessage) external { - uint16 sourceChainId; - bytes memory payload; - (sourceChainId, payload) = _verifyMessage(encodedMessage); - - // parse the encoded Transceiver payload - TransceiverStructs.TransceiverMessage memory parsedTransceiverMessage; - TransceiverStructs.NttManagerMessage memory parsedNttManagerMessage; - (parsedTransceiverMessage, parsedNttManagerMessage) = TransceiverStructs - .parseTransceiverAndNttManagerMessage(WH_TRANSCEIVER_PAYLOAD_PREFIX, payload); - - _deliverToNttManager( - sourceChainId, - parsedTransceiverMessage.sourceNttManagerAddress, - parsedTransceiverMessage.recipientNttManagerAddress, - parsedNttManagerMessage - ); - } - - function _verifyMessage(bytes memory encodedMessage) internal returns (uint16, bytes memory) { - // verify VAA against Wormhole Core Bridge contract - (IWormhole.VM memory vm, bool valid, string memory reason) = - wormhole.parseAndVerifyVM(encodedMessage); - - // ensure that the VAA is valid - if (!valid) { - revert InvalidVaa(reason); - } - - // ensure that the message came from a registered peer contract - if (!_verifyBridgeVM(vm)) { - revert InvalidWormholePeer(vm.emitterChainId, vm.emitterAddress); - } - - // save the VAA hash in storage to protect against replay attacks. - if (isVAAConsumed(vm.hash)) { - revert TransferAlreadyCompleted(vm.hash); - } - _setVAAConsumed(vm.hash); - - // emit `ReceivedMessage` event - emit ReceivedMessage(vm.hash, vm.emitterChainId, vm.emitterAddress, vm.sequence); - - return (vm.emitterChainId, vm.payload); - } - - function _verifyBridgeVM(IWormhole.VM memory vm) internal view returns (bool) { - checkFork(wormholeTransceiver_evmChainId); - return getWormholePeer(vm.emitterChainId) == vm.emitterAddress; - } - - function isVAAConsumed(bytes32 hash) public view returns (bool) { - return _getWormholeConsumedVAAsStorage()[hash]; - } - - function _setVAAConsumed(bytes32 hash) internal { - _getWormholeConsumedVAAsStorage()[hash] = true; - } - - /// @notice Get the corresponding Transceiver contract on other chains that have been registered via governance. - /// This design should be extendable to other chains, so each Transceiver would be potentially concerned with Transceivers on multiple other chains - /// Note that peers are registered under wormhole chainID values - function getWormholePeer(uint16 chainId) public view returns (bytes32) { - return _getWormholePeersStorage()[chainId]; - } - - function setWormholePeer(uint16 peerChainId, bytes32 peerContract) external onlyOwner { - _setWormholePeer(peerChainId, peerContract); - } - - function _setWormholePeer(uint16 chainId, bytes32 peerContract) internal { - if (chainId == 0) { - revert InvalidWormholeChainIdZero(); - } - if (peerContract == bytes32(0)) { - revert InvalidWormholePeerZeroAddress(); - } - - bytes32 oldPeerContract = _getWormholePeersStorage()[chainId]; - - // We don't want to allow updating a peer since this adds complexity in the accountant - // If the owner makes a mistake with peer registration they should deploy a new Wormhole - // transceiver and register this new transceiver with the NttManager - if (oldPeerContract != bytes32(0)) { - revert PeerAlreadySet(chainId, oldPeerContract); - } - - _getWormholePeersStorage()[chainId] = peerContract; - - // Publish a message for this transceiver registration - TransceiverStructs.TransceiverRegistration memory registration = TransceiverStructs - .TransceiverRegistration({ - transceiverIdentifier: WH_PEER_REGISTRATION_PREFIX, - transceiverChainId: chainId, - transceiverAddress: peerContract - }); - wormhole.publishMessage( - 0, TransceiverStructs.encodeTransceiverRegistration(registration), consistencyLevel - ); - - emit SetWormholePeer(chainId, peerContract); - } - - function isWormholeRelayingEnabled(uint16 chainId) public view returns (bool) { - return toBool(_getWormholeRelayingEnabledChainsStorage()[chainId]); - } - - function setIsWormholeRelayingEnabled(uint16 chainId, bool isEnabled) external onlyOwner { - _setIsWormholeRelayingEnabled(chainId, isEnabled); - } - - function _setIsWormholeRelayingEnabled(uint16 chainId, bool isEnabled) internal { - if (chainId == 0) { - revert InvalidWormholeChainIdZero(); - } - _getWormholeRelayingEnabledChainsStorage()[chainId] = toWord(isEnabled); - - emit SetIsWormholeRelayingEnabled(chainId, isEnabled); - } - - function isSpecialRelayingEnabled(uint16 chainId) public view returns (bool) { - return toBool(_getSpecialRelayingEnabledChainsStorage()[chainId]); - } - - function _setIsSpecialRelayingEnabled(uint16 chainId, bool isEnabled) internal { - if (chainId == 0) { - revert InvalidWormholeChainIdZero(); - } - _getSpecialRelayingEnabledChainsStorage()[chainId] = toWord(isEnabled); - - emit SetIsSpecialRelayingEnabled(chainId, isEnabled); - } - - function isWormholeEvmChain(uint16 chainId) public view returns (bool) { - return toBool(_getWormholeEvmChainIdsStorage()[chainId]); - } - - function setIsWormholeEvmChain(uint16 chainId) external onlyOwner { - _setIsWormholeEvmChain(chainId); - } - - function _setIsWormholeEvmChain(uint16 chainId) internal { - if (chainId == 0) { - revert InvalidWormholeChainIdZero(); - } - _getWormholeEvmChainIdsStorage()[chainId] = TRUE; - - emit SetIsWormholeEvmChain(chainId); - } - - function parseWormholeTransceiverInstruction(bytes memory encoded) - public - pure - returns (WormholeTransceiverInstruction memory instruction) - { - // If the user doesn't pass in any transceiver instructions then the default is false - if (encoded.length == 0) { - instruction.shouldSkipRelayerSend = false; - return instruction; - } - - uint256 offset = 0; - (instruction.shouldSkipRelayerSend, offset) = encoded.asBoolUnchecked(offset); - encoded.checkLength(offset); - } - - function encodeWormholeTransceiverInstruction(WormholeTransceiverInstruction memory instruction) - public - pure - returns (bytes memory) - { - return abi.encodePacked(instruction.shouldSkipRelayerSend); - } -} diff --git a/evm/src/interfaces/INttManager.sol b/evm/src/interfaces/INttManager.sol index 98af8e3bb..f3136a059 100644 --- a/evm/src/interfaces/INttManager.sol +++ b/evm/src/interfaces/INttManager.sol @@ -4,7 +4,30 @@ pragma solidity >=0.8.8 <0.9.0; import "../libraries/TrimmedAmount.sol"; import "../libraries/TransceiverStructs.sol"; -interface INttManager { +import "./INttManagerState.sol"; + +interface INttManager is INttManagerState { + enum Mode { + LOCKING, + BURNING + } + + // @dev Information about attestations for a given message. + struct AttestationInfo { + // whether this message has been executed + bool executed; + // bitmap of transceivers that have attested to this message (NOTE: might contain disabled transceivers) + uint64 attestedTransceivers; + } + + struct _Sequence { + uint64 num; + } + + struct _Threshold { + uint8 num; + } + /// @notice payment for a transfer is too low. /// @param requiredPayment The required payment. /// @param providedPayment The provided payment. @@ -16,43 +39,41 @@ interface INttManager { //// @param amount The amount to transfer. error TransferAmountHasDust(uint256 amount, uint256 dust); + /// @notice The mode is invalid. It is neither in LOCKING or BURNING mode. + /// @param mode The mode. + error InvalidMode(uint8 mode); + + error RefundFailed(uint256 refundAmount); + error TransceiverAlreadyAttestedToMessage(bytes32 nttManagerMessageHash); error MessageNotApproved(bytes32 msgHash); error InvalidTargetChain(uint16 targetChain, uint16 thisChain); error ZeroAmount(); error InvalidRecipient(); error BurnAmountDifferentThanBalanceDiff(uint256 burnAmount, uint256 balanceDiff); - /// @notice The mode is invalid. It is neither in LOCKING or BURNING mode. - /// @param mode The mode. - error InvalidMode(uint8 mode); + /// @notice Transfer a given amount to a recipient on a given chain. This function is called + /// by the user to send the token cross-chain. This function will either lock or burn the + /// sender's tokens. Finally, this function will call into registered `Endpoint` contracts + /// to send a message with the incrementing sequence number and the token transfer payload. + /// @param amount The amount to transfer. + /// @param recipientChain The chain ID for the destination. + /// @param recipient The recipient address. + function transfer( + uint256 amount, + uint16 recipientChain, + bytes32 recipient + ) external payable returns (uint64 msgId); - /// @notice the peer for the chain does not match the configuration. - /// @param chainId ChainId of the source chain. - /// @param peerAddress Address of the peer nttManager contract. - error InvalidPeer(uint16 chainId, bytes32 peerAddress); - error InvalidPeerChainIdZero(); - - /// @notice Peer cannot be the zero address. - error InvalidPeerZeroAddress(); - - /// @notice The number of thresholds should not be zero. - error ZeroThreshold(); - - /// @notice The threshold for transceiver attestations is too high. - /// @param threshold The threshold. - /// @param transceivers The number of transceivers. - error ThresholdTooHigh(uint256 threshold, uint256 transceivers); - error RetrievedIncorrectRegisteredTransceivers(uint256 retrieved, uint256 registered); - - // @notice transfer a given amount to a recipient on a given chain. - // @dev transfers are queued if the outbound limit is hit - // and must be completed by the client. - // - // @param amount The amount to transfer. - // @param recipientChain The chain to transfer to. - // @param recipient The recipient address. - // @param shouldQueue Whether the transfer should be queued if the outbound limit is hit. - // @param encodedInstructions Additional instructions to be forwarded to the recipient chain. + /// @notice Transfer a given amount to a recipient on a given chain. This function is called + /// by the user to send the token cross-chain. This function will either lock or burn the + /// sender's tokens. Finally, this function will call into registered `Endpoint` contracts + /// to send a message with the incrementing sequence number and the token transfer payload. + /// @dev Transfers are queued if the outbound limit is hit and must be completed by the client. + /// @param amount The amount to transfer. + /// @param recipientChain The chain ID for the destination. + /// @param recipient The recipient address. + /// @param shouldQueue Whether the transfer should be queued if the outbound limit is hit. + /// @param encodedInstructions Additional instructions to be forwarded to the recipient chain. function transfer( uint256 amount, uint16 recipientChain, @@ -61,23 +82,8 @@ interface INttManager { bytes memory encodedInstructions ) external payable returns (uint64 msgId); - function getPeer(uint16 chainId_) external view returns (bytes32); - - function setPeer(uint16 peerChainId, bytes32 peerContract) external; - - /// @notice Check if a message has been approved. The message should have at least - /// the minimum threshold of attestations fron distinct transceivers. - /// - /// @param digest The digest of the message. - /// @return Whether the message has been approved. - function isMessageApproved(bytes32 digest) external view returns (bool); - - function isMessageExecuted(bytes32 digest) external view returns (bool); - - /// @notice Complete an outbound trasnfer that's been queued. - /// @dev This method is called by the client to complete an - /// outbound transfer that's been queued. - /// + /// @notice Complete an outbound transfer that's been queued. + /// @dev This method is called by the client to complete an outbound transfer that's been queued. /// @param queueSequence The sequence of the message in the queue. /// @return msgSequence The sequence of the message. function completeOutboundQueuedTransfer(uint64 queueSequence) @@ -85,40 +91,23 @@ interface INttManager { payable returns (uint64 msgSequence); - // @notice Complete an inbound queued transfer. - // @param digest The digest of the message to complete. + /// @notice Complete an inbound queued transfer. + /// @param digest The digest of the message to complete. function completeInboundQueuedTransfer(bytes32 digest) external; - // @notice Set the outbound transfer limit for a given chain. - // @param limit The new limit. - function setOutboundLimit(uint256 limit) external; - - // @notice Set the inbound transfer limit for a given chain. - // @param limit The new limit. - // @param chainId The chain to set the limit for. - function setInboundLimit(uint256 limit, uint16 chainId) external; - - // @notice Fetch the delivery price for a given recipient chain transfer. - // @param recipientChain The chain to transfer to. - // @param transceiverInstructions An additional instruction the transceiver can forward to - // the recipient chain. - // @param enabledTransceivers The transceivers that are enabled for the transfer. - // @return The delivery prices associated with each transceiver, and the sum - // of these prices. + /// @notice Fetch the delivery price for a given recipient chain transfer. + /// @param recipientChain The chain ID of the transfer destination. + /// @return - The delivery prices associated with each endpoint and the total price. function quoteDeliveryPrice( uint16 recipientChain, TransceiverStructs.TransceiverInstruction[] memory transceiverInstructions, address[] memory enabledTransceivers ) external view returns (uint256[] memory, uint256); - function nextMessageSequence() external view returns (uint64); - - function token() external view returns (address); - - /// @notice Called by an Transceiver contract to deliver a verified attestation. - /// @dev This function enforces attestation threshold and replay logic for messages. - /// Once all validations are complete, this function calls _executeMsg to execute - /// the command specified by the message. + /// @notice Called by an Endpoint contract to deliver a verified attestation. + /// @dev This function enforces attestation threshold and replay logic for messages. Once all + /// validations are complete, this function calls `executeMsg` to execute the command specified + /// by the message. /// @param sourceChainId The chain id of the sender. /// @param sourceNttManagerAddress The address of the sender's nttManager contract. /// @param payload The VAA payload. @@ -128,15 +117,19 @@ interface INttManager { TransceiverStructs.NttManagerMessage memory payload ) external; - /// @notice upgrade to a new nttManager implementation. - /// @dev This is upgraded via a proxy. - /// - /// @param newImplementation The address of the new implementation. - function upgrade(address newImplementation) external; - - /// @notice Returns the mode (locking or burning) of the NttManager. - /// @return mode A uint8 corresponding to the mode - function getMode() external view returns (uint8); + /// @notice Called after a message has been sufficiently verified to execute the command in the message. + /// This function will decode the payload as an NttManagerMessage to extract the sequence, msgType, + /// and other parameters. + /// @dev This function is exposed as a fallback for when an `Transceiver` is deregistered + /// when a message is in flight. + /// @param sourceChainId The chain id of the sender. + /// @param sourceNttManagerAddress The address of the sender's nttManager contract. + /// @param message The message to execute. + function executeMsg( + uint16 sourceChainId, + bytes32 sourceNttManagerAddress, + TransceiverStructs.NttManagerMessage memory message + ) external; /// @notice Returns the number of decimals of the token managed by the NttManager. /// @return decimals The number of decimals of the token. diff --git a/evm/src/interfaces/INttManagerState.sol b/evm/src/interfaces/INttManagerState.sol new file mode 100644 index 000000000..a789ba965 --- /dev/null +++ b/evm/src/interfaces/INttManagerState.sol @@ -0,0 +1,115 @@ +// SPDX-License-Identifier: Apache 2 +pragma solidity >=0.8.8 <0.9.0; + +import "../libraries/TrimmedAmount.sol"; +import "../libraries/TransceiverStructs.sol"; + +import "./INttManagerState.sol"; + +interface INttManagerState { + /// @notice The caller is not the deployer. + error UnexpectedDeployer(address expectedOwner, address owner); + + /// @notice Peer for the chain does not match the configuration. + /// @param chainId ChainId of the source chain. + /// @param peerAddress Address of the peer nttManager contract. + error InvalidPeer(uint16 chainId, bytes32 peerAddress); + + /// @notice Peer chain ID cannot be zero. + error InvalidPeerChainIdZero(); + + /// @notice Peer cannot be the zero address. + error InvalidPeerZeroAddress(); + + /// @notice The number of thresholds should not be zero. + error ZeroThreshold(); + + /// @notice The threshold for transceiver attestations is too high. + /// @param threshold The threshold. + /// @param transceivers The number of transceivers. + error ThresholdTooHigh(uint256 threshold, uint256 transceivers); + error RetrievedIncorrectRegisteredTransceivers(uint256 retrieved, uint256 registered); + + /// @notice Sets the transceiver for the given chain. + /// @param transceiver The address of the transceiver. + /// @dev This method can only be executed by the `owner`. + function setTransceiver(address transceiver) external; + + /// @notice Removes the transceiver for the given chain. + /// @param transceiver The address of the transceiver. + /// @dev This method can only be executed by the `owner`. + function removeTransceiver(address transceiver) external; + + /// @notice Sets the threshold for the number of attestations required for a message + /// to be considered valid. + /// @param threshold The new threshold. + /// @dev This method can only be executed by the `owner`. + function setThreshold(uint8 threshold) external; + + /// @notice Returns registered peer contract for a given chain. + /// @param chainId_ chain ID. + function getPeer(uint16 chainId_) external view returns (bytes32); + + /// @notice Sets the corresponding peer. + /// @dev The nttManager that executes the message sets the source nttManager as the peer. + /// @param peerChainId The chain ID of the peer. + /// @param peerContract The address of the peer nttManager contract. + function setPeer(uint16 peerChainId, bytes32 peerContract) external; + + /// @notice Checks if a message has been approved. The message should have at least + /// the minimum threshold of attestations from distinct endpoints. + /// @param digest The digest of the message. + /// @return - Boolean indicating if message has been approved. + function isMessageApproved(bytes32 digest) external view returns (bool); + + /// @notice Checks if a message has been executed. + /// @param digest The digest of the message. + /// @return - Boolean indicating if message has been executed. + function isMessageExecuted(bytes32 digest) external view returns (bool); + + /// @notice Sets the outbound transfer limit for a given chain. + /// @dev This method can only be executed by the `owner`. + /// @param limit The new outbound limit. + function setOutboundLimit(uint256 limit) external; + + /// @notice Sets the inbound transfer limit for a given chain. + /// @dev This method can only be executed by the `owner`. + /// @param limit The new limit. + /// @param chainId The chain to set the limit for. + function setInboundLimit(uint256 limit, uint16 chainId) external; + + /// @notice Returns the next message sequence. + function nextMessageSequence() external view returns (uint64); + + /// @notice Upgrades to a new manager implementation. + /// @dev This is upgraded via a proxy, and can only be executed + /// by the `owner`. + /// @param newImplementation The address of the new implementation. + function upgrade(address newImplementation) external; + + /// @notice Pauses the manager. + function pause() external; + + /// @notice Returns the mode (locking or burning) of the NttManager. + /// @return mode A uint8 corresponding to the mode + function getMode() external view returns (uint8); + + /// @notice Returns the number of Transceivers that must attest to a msgId for + /// it to be considered valid and acted upon. + function getThreshold() external view returns (uint8); + + /// @notice Returns a boolean indicating if the transceiver has attested to the message. + function transceiverAttestedToMessage( + bytes32 digest, + uint8 index + ) external view returns (bool); + + /// @notice Returns the number of attestations for a given message. + function messageAttestations(bytes32 digest) external view returns (uint8 count); + + /// @notice Returns of the address of the token managed by this contract. + function token() external view returns (address); + + /// @notice Returns the chain ID. + function chainId() external view returns (uint16); +} diff --git a/evm/src/interfaces/ITransceiver.sol b/evm/src/interfaces/ITransceiver.sol index a6ed8435d..8e55d725a 100644 --- a/evm/src/interfaces/ITransceiver.sol +++ b/evm/src/interfaces/ITransceiver.sol @@ -11,11 +11,22 @@ interface ITransceiver { bytes32 recipientNttManagerAddress, bytes32 expectedRecipientNttManagerAddress ); + /// @notice Fetch the delivery price for a given recipient chain transfer. + /// @param recipientChain The Wormhole chain ID of the target chain. + /// @param instruction An additional Instruction provided by the Transceiver to be + /// executed on the recipient chain. + /// @return deliveryPrice The cost of delivering a message to the recipient chain, + /// in this chain's native token. function quoteDeliveryPrice( uint16 recipientChain, TransceiverStructs.TransceiverInstruction memory instruction ) external view returns (uint256); + /// @dev Send a message to another chain. + /// @param recipientChain The Wormhole chain ID of the recipient. + /// @param instruction An additional Instruction provided by the Transceiver to be + /// executed on the recipient chain. + /// @param nttManagerMessage A message to be sent to the nttManager on the recipient chain. function sendMessage( uint16 recipientChain, TransceiverStructs.TransceiverInstruction memory instruction, @@ -23,7 +34,9 @@ interface ITransceiver { bytes32 recipientNttManagerAddress ) external payable; + /// @notice Upgrades the transceiver to a new implementation. function upgrade(address newImplementation) external; + /// @notice Transfers the ownership of the transceiver to a new address. function transferTransceiverOwnership(address newOwner) external; } diff --git a/evm/src/interfaces/IWormholeTransceiver.sol b/evm/src/interfaces/IWormholeTransceiver.sol index b3492ffde..c9f6a62b2 100644 --- a/evm/src/interfaces/IWormholeTransceiver.sol +++ b/evm/src/interfaces/IWormholeTransceiver.sol @@ -3,35 +3,44 @@ pragma solidity >=0.8.8 <0.9.0; import "../libraries/TransceiverStructs.sol"; -interface IWormholeTransceiver { +import "./IWormholeTransceiverState.sol"; + +interface IWormholeTransceiver is IWormholeTransceiverState { + struct WormholeTransceiverInstruction { + bool shouldSkipRelayerSend; + } + event ReceivedRelayedMessage(bytes32 digest, uint16 emitterChainId, bytes32 emitterAddress); event ReceivedMessage( bytes32 digest, uint16 emitterChainId, bytes32 emitterAddress, uint64 sequence ); - event SendTransceiverMessage( uint16 recipientChain, TransceiverStructs.TransceiverMessage message ); - event RelayingInfo(uint8 relayingType, uint256 deliveryPayment); - event SetWormholePeer(uint16 chainId, bytes32 peerContract); - event SetIsWormholeRelayingEnabled(uint16 chainId, bool isRelayingEnabled); - event SetIsSpecialRelayingEnabled(uint16 chainId, bool isRelayingEnabled); - event SetIsWormholeEvmChain(uint16 chainId); error InvalidRelayingConfig(uint16 chainId); - error CallerNotRelayer(address caller); - error UnexpectedAdditionalMessages(); - error InvalidVaa(string reason); error InvalidWormholePeer(uint16 chainId, bytes32 peerAddress); - error PeerAlreadySet(uint16 chainId, bytes32 peerAddress); error TransferAlreadyCompleted(bytes32 vaaHash); - error InvalidWormholePeerZeroAddress(); - error InvalidWormholeChainIdZero(); + /// @notice Receive an attested message from the verification layer. This function should verify + /// the `encodedVm` and then deliver the attestation to the transceiver NttManager contract. + /// @param encodedMessage The attested message. function receiveMessage(bytes memory encodedMessage) external; - function isVAAConsumed(bytes32 hash) external view returns (bool); - function getWormholePeer(uint16 chainId) external view returns (bytes32); - function isWormholeRelayingEnabled(uint16 chainId) external view returns (bool); - function isSpecialRelayingEnabled(uint16 chainId) external view returns (bool); - function isWormholeEvmChain(uint16 chainId) external view returns (bool); + + /// @notice Parses the encoded instruction and returns the instruction struct. This instruction + /// is specific to the WormholeTransceiver contract. + /// @param encoded The encoded instruction. + /// @return instruction The parsed `WormholeTransceiverInstruction`. + function parseWormholeTransceiverInstruction(bytes memory encoded) + external + pure + returns (WormholeTransceiverInstruction memory instruction); + + /// @notice Encodes the `WormholeTransceiverInstruction` into a byte array. + /// @param instruction The `WormholeTransceiverInstruction` to encode. + /// @return encoded The encoded instruction. + function encodeWormholeTransceiverInstruction(WormholeTransceiverInstruction memory instruction) + external + pure + returns (bytes memory); } diff --git a/evm/src/interfaces/IWormholeTransceiverState.sol b/evm/src/interfaces/IWormholeTransceiverState.sol new file mode 100644 index 000000000..04bdec017 --- /dev/null +++ b/evm/src/interfaces/IWormholeTransceiverState.sol @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: Apache 2 +pragma solidity >=0.8.8 <0.9.0; + +import "../libraries/TransceiverStructs.sol"; + +interface IWormholeTransceiverState { + event RelayingInfo(uint8 relayingType, uint256 deliveryPayment); + event SetWormholePeer(uint16 chainId, bytes32 peerContract); + event SetIsWormholeRelayingEnabled(uint16 chainId, bool isRelayingEnabled); + event SetIsSpecialRelayingEnabled(uint16 chainId, bool isRelayingEnabled); + event SetIsWormholeEvmChain(uint16 chainId); + + error UnexpectedAdditionalMessages(); + error InvalidVaa(string reason); + error PeerAlreadySet(uint16 chainId, bytes32 peerAddress); + error InvalidWormholePeerZeroAddress(); + error InvalidWormholeChainIdZero(); + error CallerNotRelayer(address caller); + + /// @notice Get the corresponding Transceiver contract on other chains that have been registered + /// via governance. This design should be extendable to other chains, so each Transceiver would + /// be potentially concerned with Transceivers on multiple other chains. + /// @dev that peers are registered under Wormhole chain ID values. + /// @param chainId The Wormhole chain ID of the peer to get. + /// @return peerContract The address of the peer contract on the given chain. + function getWormholePeer(uint16 chainId) external view returns (bytes32); + + /// @notice Returns a boolean indicating whether the given VAA hash has been consumed. + /// @param hash The VAA hash to check. + function isVAAConsumed(bytes32 hash) external view returns (bool); + + /// @notice Returns a boolean indicating whether Wormhole relaying is enabled for the given chain. + /// @param chainId The Wormhole chain ID to check. + function isWormholeRelayingEnabled(uint16 chainId) external view returns (bool); + + /// @notice Returns a boolean indicating whether special relaying is enabled for the given chain. + /// @param chainId The Wormhole chain ID to check. + function isSpecialRelayingEnabled(uint16 chainId) external view returns (bool); + + /// @notice Returns a boolean indicating whether the given chain is EVM compatible. + /// @param chainId The Wormhole chain ID to check. + function isWormholeEvmChain(uint16 chainId) external view returns (bool); + + /// @notice Set the Wormhole peer contract for the given chain. + /// @dev This function is only callable by the `owner`. + /// @param chainId The Wormhole chain ID of the peer to set. + /// @param peerContract The address of the peer contract on the given chain. + function setWormholePeer(uint16 chainId, bytes32 peerContract) external; + + /// @notice Set whether the chain is EVM compatible. + /// @dev This function is only callable by the `owner`. + /// @param chainId The Wormhole chain ID to set. + function setIsWormholeEvmChain(uint16 chainId) external; + + /// @notice Set whether Wormhole relaying is enabled for the given chain. + /// @dev This function is only callable by the `owner`. + /// @param chainId The Wormhole chain ID to set. + /// @param isRelayingEnabled A boolean indicating whether relaying is enabled. + function setIsWormholeRelayingEnabled(uint16 chainId, bool isRelayingEnabled) external; + + /// @notice Set whether special relaying is enabled for the given chain. + /// @dev This function is only callable by the `owner`. + /// @param chainId The Wormhole chain ID to set. + /// @param isRelayingEnabled A boolean indicating whether special relaying is enabled. + function setIsSpecialRelayingEnabled(uint16 chainId, bool isRelayingEnabled) external; +} diff --git a/evm/test/IntegrationRelayer.t.sol b/evm/test/IntegrationRelayer.t.sol index 609a81631..2e19cf350 100755 --- a/evm/test/IntegrationRelayer.t.sol +++ b/evm/test/IntegrationRelayer.t.sol @@ -4,16 +4,17 @@ pragma solidity 0.8.19; import "forge-std/Test.sol"; import "forge-std/console.sol"; -import "../src/NttManager.sol"; -import "../src/Transceiver.sol"; +import "../src/NttManager/NttManager.sol"; +import "../src/Transceiver/Transceiver.sol"; import "../src/interfaces/INttManager.sol"; import "../src/interfaces/IRateLimiter.sol"; import "../src/interfaces/INttManagerEvents.sol"; import "../src/interfaces/IRateLimiterEvents.sol"; import "../src/interfaces/IWormholeTransceiver.sol"; +import "../src/interfaces/IWormholeTransceiverState.sol"; import {Utils} from "./libraries/Utils.sol"; import {DummyToken, DummyTokenMintAndBurn} from "./mocks/DummyToken.sol"; -import {WormholeTransceiver} from "../src/WormholeTransceiver.sol"; +import {WormholeTransceiver} from "../src/Transceiver/WormholeTransceiver/WormholeTransceiver.sol"; import "../src/libraries/TransceiverStructs.sol"; import "./mocks/MockNttManager.sol"; import "./mocks/MockTransceivers.sol"; @@ -35,7 +36,7 @@ contract TestEndToEndRelayerBase is Test { returns (TransceiverStructs.TransceiverInstruction memory) { WormholeTransceiver.WormholeTransceiverInstruction memory instruction = - WormholeTransceiver.WormholeTransceiverInstruction(relayer_off); + IWormholeTransceiver.WormholeTransceiverInstruction(relayer_off); bytes memory encodedInstructionWormhole; // Source fork has id 0 and corresponds to chain 1 @@ -97,7 +98,7 @@ contract TestEndToEndRelayer is DummyToken t1 = new DummyToken(); NttManager implementation = - new MockNttManagerContract(address(t1), NttManager.Mode.LOCKING, chainId1, 1 days); + new MockNttManagerContract(address(t1), INttManager.Mode.LOCKING, chainId1, 1 days); nttManagerChain1 = MockNttManagerContract(address(new ERC1967Proxy(address(implementation), ""))); @@ -129,7 +130,7 @@ contract TestEndToEndRelayer is // Chain 2 setup DummyToken t2 = new DummyTokenMintAndBurn(); NttManager implementationChain2 = - new MockNttManagerContract(address(t2), NttManager.Mode.BURNING, chainId2, 1 days); + new MockNttManagerContract(address(t2), INttManager.Mode.BURNING, chainId2, 1 days); nttManagerChain2 = MockNttManagerContract(address(new ERC1967Proxy(address(implementationChain2), ""))); @@ -440,7 +441,7 @@ contract TestRelayerEndToEndManual is vm.chainId(chainId1); DummyToken t1 = new DummyToken(); NttManager implementation = - new MockNttManagerContract(address(t1), NttManager.Mode.LOCKING, chainId1, 1 days); + new MockNttManagerContract(address(t1), INttManager.Mode.LOCKING, chainId1, 1 days); nttManagerChain1 = MockNttManagerContract(address(new ERC1967Proxy(address(implementation), ""))); @@ -466,7 +467,7 @@ contract TestRelayerEndToEndManual is vm.chainId(chainId2); DummyToken t2 = new DummyTokenMintAndBurn(); NttManager implementationChain2 = - new MockNttManagerContract(address(t2), NttManager.Mode.BURNING, chainId2, 1 days); + new MockNttManagerContract(address(t2), INttManager.Mode.BURNING, chainId2, 1 days); nttManagerChain2 = MockNttManagerContract(address(new ERC1967Proxy(address(implementationChain2), ""))); @@ -560,7 +561,7 @@ contract TestRelayerEndToEndManual is nttManagerChain2.setPeer(chainId1, bytes32(uint256(uint160(address(nttManagerChain1))))); vm.prank(userD); vm.expectRevert( - abi.encodeWithSelector(IWormholeTransceiver.CallerNotRelayer.selector, userD) + abi.encodeWithSelector(IWormholeTransceiverState.CallerNotRelayer.selector, userD) ); wormholeTransceiverChain2.receiveWormholeMessages( vaa.payload, diff --git a/evm/test/IntegrationStandalone.t.sol b/evm/test/IntegrationStandalone.t.sol index 3c8e7c5a0..5ccbab41a 100755 --- a/evm/test/IntegrationStandalone.t.sol +++ b/evm/test/IntegrationStandalone.t.sol @@ -4,8 +4,8 @@ pragma solidity >=0.8.8 <0.9.0; import "forge-std/Test.sol"; import "forge-std/console.sol"; -import "../src/NttManager.sol"; -import "../src/Transceiver.sol"; +import "../src/NttManager/NttManager.sol"; +import "../src/Transceiver/Transceiver.sol"; import "../src/interfaces/INttManager.sol"; import "../src/interfaces/IRateLimiter.sol"; import "../src/interfaces/INttManagerEvents.sol"; @@ -13,7 +13,7 @@ import "../src/interfaces/IRateLimiterEvents.sol"; import {Utils} from "./libraries/Utils.sol"; import {DummyToken, DummyTokenMintAndBurn} from "./NttManager.t.sol"; import "../src/interfaces/IWormholeTransceiver.sol"; -import {WormholeTransceiver} from "../src/WormholeTransceiver.sol"; +import {WormholeTransceiver} from "../src/Transceiver/WormholeTransceiver/WormholeTransceiver.sol"; import "../src/libraries/TransceiverStructs.sol"; import "./mocks/MockNttManager.sol"; import "./mocks/MockTransceivers.sol"; @@ -62,7 +62,7 @@ contract TestEndToEndBase is Test, INttManagerEvents, IRateLimiterEvents { vm.chainId(chainId1); DummyToken t1 = new DummyToken(); NttManager implementation = - new MockNttManagerContract(address(t1), NttManager.Mode.LOCKING, chainId1, 1 days); + new MockNttManagerContract(address(t1), INttManager.Mode.LOCKING, chainId1, 1 days); nttManagerChain1 = MockNttManagerContract(address(new ERC1967Proxy(address(implementation), ""))); @@ -88,7 +88,7 @@ contract TestEndToEndBase is Test, INttManagerEvents, IRateLimiterEvents { vm.chainId(chainId2); DummyToken t2 = new DummyTokenMintAndBurn(); NttManager implementationChain2 = - new MockNttManagerContract(address(t2), NttManager.Mode.BURNING, chainId2, 1 days); + new MockNttManagerContract(address(t2), INttManager.Mode.BURNING, chainId2, 1 days); nttManagerChain2 = MockNttManagerContract(address(new ERC1967Proxy(address(implementationChain2), ""))); @@ -584,7 +584,7 @@ contract TestEndToEndBase is Test, INttManagerEvents, IRateLimiterEvents { function encodeTransceiverInstruction(bool relayer_off) public view returns (bytes memory) { WormholeTransceiver.WormholeTransceiverInstruction memory instruction = - WormholeTransceiver.WormholeTransceiverInstruction(relayer_off); + IWormholeTransceiver.WormholeTransceiverInstruction(relayer_off); bytes memory encodedInstructionWormhole = wormholeTransceiverChain1.encodeWormholeTransceiverInstruction(instruction); TransceiverStructs.TransceiverInstruction memory TransceiverInstruction = TransceiverStructs @@ -598,7 +598,7 @@ contract TestEndToEndBase is Test, INttManagerEvents, IRateLimiterEvents { // Encode an instruction for each of the relayers function encodeTransceiverInstructions(bool relayer_off) public view returns (bytes memory) { WormholeTransceiver.WormholeTransceiverInstruction memory instruction = - WormholeTransceiver.WormholeTransceiverInstruction(relayer_off); + IWormholeTransceiver.WormholeTransceiverInstruction(relayer_off); bytes memory encodedInstructionWormhole = wormholeTransceiverChain1.encodeWormholeTransceiverInstruction(instruction); diff --git a/evm/test/NttManager.t.sol b/evm/test/NttManager.t.sol index ed247f25b..7ee503b85 100644 --- a/evm/test/NttManager.t.sol +++ b/evm/test/NttManager.t.sol @@ -4,7 +4,7 @@ pragma solidity >=0.8.8 <0.9.0; import "forge-std/Test.sol"; import "forge-std/console.sol"; -import "../src/NttManager.sol"; +import "../src/NttManager/NttManager.sol"; import "../src/interfaces/INttManager.sol"; import "../src/interfaces/IRateLimiter.sol"; import "../src/interfaces/INttManagerEvents.sol"; @@ -48,10 +48,10 @@ contract TestNttManager is Test, INttManagerEvents, IRateLimiterEvents { DummyToken t = new DummyToken(); NttManager implementation = - new MockNttManagerContract(address(t), NttManager.Mode.LOCKING, chainId, 1 days); + new MockNttManagerContract(address(t), INttManager.Mode.LOCKING, chainId, 1 days); NttManager otherImplementation = - new MockNttManagerContract(address(t), NttManager.Mode.LOCKING, chainId, 1 days); + new MockNttManagerContract(address(t), INttManager.Mode.LOCKING, chainId, 1 days); nttManager = MockNttManagerContract(address(new ERC1967Proxy(address(implementation), ""))); nttManager.initialize(); @@ -148,7 +148,7 @@ contract TestNttManager is Test, INttManagerEvents, IRateLimiterEvents { // a convenience check, not a security one) DummyToken t = new DummyToken(); NttManager altNttManager = - new MockNttManagerContract(address(t), NttManager.Mode.LOCKING, chainId, 1 days); + new MockNttManagerContract(address(t), INttManager.Mode.LOCKING, chainId, 1 days); DummyTransceiver e = new DummyTransceiver(address(altNttManager)); nttManager.setTransceiver(address(e)); } @@ -417,7 +417,7 @@ contract TestNttManager is Test, INttManagerEvents, IRateLimiterEvents { function test_noAutomaticSlot() public { DummyToken t = new DummyToken(); MockNttManagerContract c = - new MockNttManagerContract(address(t), NttManager.Mode.LOCKING, 1, 1 days); + new MockNttManagerContract(address(t), INttManager.Mode.LOCKING, 1, 1 days); assertEq(c.lastSlot(), 0x0); } @@ -426,7 +426,7 @@ contract TestNttManager is Test, INttManagerEvents, IRateLimiterEvents { vm.startStateDiffRecording(); - new MockNttManagerContract(address(t), NttManager.Mode.LOCKING, 1, 1 days); + new MockNttManagerContract(address(t), INttManager.Mode.LOCKING, 1, 1 days); Utils.assertSafeUpgradeableConstructor(vm.stopAndReturnStateDiff()); } @@ -526,8 +526,9 @@ contract TestNttManager is Test, INttManagerEvents, IRateLimiterEvents { assertEq(token.balanceOf(address(user_B)), transferAmount.untrim(token.decimals())); // Step 2 (upgrade to a new nttManager) - MockNttManagerContract newNttManager = - new MockNttManagerContract(nttManager.token(), NttManager.Mode.LOCKING, chainId, 1 days); + MockNttManagerContract newNttManager = new MockNttManagerContract( + nttManager.token(), INttManager.Mode.LOCKING, chainId, 1 days + ); nttManagerOther.upgrade(address(newNttManager)); TransceiverHelpersLib.attestTransceiversHelper( diff --git a/evm/test/Ownership.t.sol b/evm/test/Ownership.t.sol index 00baa0226..5e34895cc 100644 --- a/evm/test/Ownership.t.sol +++ b/evm/test/Ownership.t.sol @@ -15,7 +15,7 @@ contract OwnershipTests is Test { function setUp() public { DummyToken t = new DummyToken(); NttManager implementation = - new MockNttManagerContract(address(t), NttManager.Mode.LOCKING, chainId, 1 days); + new MockNttManagerContract(address(t), INttManager.Mode.LOCKING, chainId, 1 days); nttManager = MockNttManagerContract(address(new ERC1967Proxy(address(implementation), ""))); nttManager.initialize(); diff --git a/evm/test/RateLimit.t.sol b/evm/test/RateLimit.t.sol index 6af6785e0..484dc01bf 100644 --- a/evm/test/RateLimit.t.sol +++ b/evm/test/RateLimit.t.sol @@ -2,7 +2,7 @@ import "forge-std/Test.sol"; import "../src/interfaces/IRateLimiterEvents.sol"; -import "../src/NttManager.sol"; +import "../src/NttManager/NttManager.sol"; import "./mocks/DummyTransceiver.sol"; import "./mocks/DummyToken.sol"; import "./mocks/MockNttManager.sol"; @@ -36,7 +36,7 @@ contract TestRateLimit is Test, IRateLimiterEvents { DummyToken t = new DummyToken(); NttManager implementation = - new MockNttManagerContract(address(t), NttManager.Mode.LOCKING, chainId, 1 days); + new MockNttManagerContract(address(t), INttManager.Mode.LOCKING, chainId, 1 days); nttManager = MockNttManagerContract(address(new ERC1967Proxy(address(implementation), ""))); nttManager.initialize(); diff --git a/evm/test/TransceiverStructs.t.sol b/evm/test/TransceiverStructs.t.sol index d2596e72b..348fe2a5c 100644 --- a/evm/test/TransceiverStructs.t.sol +++ b/evm/test/TransceiverStructs.t.sol @@ -4,7 +4,7 @@ pragma solidity >=0.8.8 <0.9.0; import "forge-std/Test.sol"; import "../src/libraries/TransceiverStructs.sol"; -import "../src/WormholeTransceiver.sol"; +import "../src/Transceiver/WormholeTransceiver/WormholeTransceiver.sol"; contract TestTransceiverStructs is Test { using TrimmedAmountLib for uint256; diff --git a/evm/test/Upgrades.t.sol b/evm/test/Upgrades.t.sol index ecaf82301..16b456618 100644 --- a/evm/test/Upgrades.t.sol +++ b/evm/test/Upgrades.t.sol @@ -4,7 +4,7 @@ pragma solidity >=0.8.8 <0.9.0; import "forge-std/Test.sol"; import "forge-std/console.sol"; -import "../src/NttManager.sol"; +import "../src/NttManager/NttManager.sol"; import "../src/interfaces/INttManager.sol"; import "../src/interfaces/IRateLimiter.sol"; import "../src/interfaces/INttManagerEvents.sol"; @@ -14,8 +14,7 @@ import "../src/libraries/external/Initializable.sol"; import "../src/libraries/Implementation.sol"; import {Utils} from "./libraries/Utils.sol"; import {DummyToken, DummyTokenMintAndBurn} from "./NttManager.t.sol"; -import {WormholeTransceiver} from "../src/WormholeTransceiver.sol"; -import {WormholeTransceiver} from "../src/WormholeTransceiver.sol"; +import {WormholeTransceiver} from "../src/Transceiver/WormholeTransceiver/WormholeTransceiver.sol"; import "../src/libraries/TransceiverStructs.sol"; import "./mocks/MockNttManager.sol"; import "./mocks/MockTransceivers.sol"; @@ -63,7 +62,7 @@ contract TestUpgrades is Test, INttManagerEvents, IRateLimiterEvents { vm.chainId(chainId1); DummyToken t1 = new DummyToken(); NttManager implementation = - new MockNttManagerContract(address(t1), NttManager.Mode.LOCKING, chainId1, 1 days); + new MockNttManagerContract(address(t1), INttManager.Mode.LOCKING, chainId1, 1 days); nttManagerChain1 = MockNttManagerContract(address(new ERC1967Proxy(address(implementation), ""))); @@ -89,7 +88,7 @@ contract TestUpgrades is Test, INttManagerEvents, IRateLimiterEvents { vm.chainId(chainId2); DummyToken t2 = new DummyTokenMintAndBurn(); NttManager implementationChain2 = - new MockNttManagerContract(address(t2), NttManager.Mode.BURNING, chainId2, 1 days); + new MockNttManagerContract(address(t2), INttManager.Mode.BURNING, chainId2, 1 days); nttManagerChain2 = MockNttManagerContract(address(new ERC1967Proxy(address(implementationChain2), ""))); @@ -130,7 +129,7 @@ contract TestUpgrades is Test, INttManagerEvents, IRateLimiterEvents { function test_basicUpgradeNttManager() public { // Basic call to upgrade with the same contact as ewll NttManager newImplementation = new MockNttManagerContract( - address(nttManagerChain1.token()), NttManager.Mode.LOCKING, chainId1, 1 days + address(nttManagerChain1.token()), INttManager.Mode.LOCKING, chainId1, 1 days ); nttManagerChain1.upgrade(address(newImplementation)); @@ -156,13 +155,13 @@ contract TestUpgrades is Test, INttManagerEvents, IRateLimiterEvents { function test_doubleUpgradeNttManager() public { // Basic call to upgrade with the same contact as ewll NttManager newImplementation = new MockNttManagerContract( - address(nttManagerChain1.token()), NttManager.Mode.LOCKING, chainId1, 1 days + address(nttManagerChain1.token()), INttManager.Mode.LOCKING, chainId1, 1 days ); nttManagerChain1.upgrade(address(newImplementation)); basicFunctionality(); newImplementation = new MockNttManagerContract( - address(nttManagerChain1.token()), NttManager.Mode.LOCKING, chainId1, 1 days + address(nttManagerChain1.token()), INttManager.Mode.LOCKING, chainId1, 1 days ); nttManagerChain1.upgrade(address(newImplementation)); @@ -192,7 +191,7 @@ contract TestUpgrades is Test, INttManagerEvents, IRateLimiterEvents { function test_storageSlotNttManager() public { // Basic call to upgrade with the same contact as ewll NttManager newImplementation = new MockNttManagerStorageLayoutChange( - address(nttManagerChain1.token()), NttManager.Mode.LOCKING, chainId1, 1 days + address(nttManagerChain1.token()), INttManager.Mode.LOCKING, chainId1, 1 days ); nttManagerChain1.upgrade(address(newImplementation)); @@ -228,7 +227,7 @@ contract TestUpgrades is Test, INttManagerEvents, IRateLimiterEvents { function test_callMigrateNttManager() public { // Basic call to upgrade with the same contact as ewll NttManager newImplementation = new MockNttManagerMigrateBasic( - address(nttManagerChain1.token()), NttManager.Mode.LOCKING, chainId1, 1 days + address(nttManagerChain1.token()), INttManager.Mode.LOCKING, chainId1, 1 days ); vm.expectRevert("Proper migrate called"); @@ -259,7 +258,7 @@ contract TestUpgrades is Test, INttManagerEvents, IRateLimiterEvents { // Basic call to upgrade with the same contact as ewll NttManager newImplementation = new MockNttManagerImmutableCheck( - address(tnew), NttManager.Mode.LOCKING, chainId1, 1 days + address(tnew), INttManager.Mode.LOCKING, chainId1, 1 days ); vm.expectRevert(); // Reverts with a panic on the assert. So, no way to tell WHY this happened. @@ -296,7 +295,7 @@ contract TestUpgrades is Test, INttManagerEvents, IRateLimiterEvents { // Basic call to upgrade with the same contact as ewll NttManager newImplementation = new MockNttManagerImmutableRemoveCheck( - address(tnew), NttManager.Mode.LOCKING, chainId1, 1 days + address(tnew), INttManager.Mode.LOCKING, chainId1, 1 days ); // Allow an upgrade, since we enabled the ability to edit the immutables within the code @@ -334,7 +333,7 @@ contract TestUpgrades is Test, INttManagerEvents, IRateLimiterEvents { // Basic call to upgrade so that we can get the real implementation. NttManager newImplementation = new MockNttManagerContract( - address(nttManagerChain1.token()), NttManager.Mode.LOCKING, chainId1, 1 days + address(nttManagerChain1.token()), INttManager.Mode.LOCKING, chainId1, 1 days ); nttManagerChain1.upgrade(address(newImplementation)); @@ -539,7 +538,7 @@ contract TestUpgrades is Test, INttManagerEvents, IRateLimiterEvents { function encodeTransceiverInstruction(bool relayer_off) public view returns (bytes memory) { WormholeTransceiver.WormholeTransceiverInstruction memory instruction = - WormholeTransceiver.WormholeTransceiverInstruction(relayer_off); + IWormholeTransceiver.WormholeTransceiverInstruction(relayer_off); bytes memory encodedInstructionWormhole = wormholeTransceiverChain1.encodeWormholeTransceiverInstruction(instruction); TransceiverStructs.TransceiverInstruction memory TransceiverInstruction = TransceiverStructs @@ -578,7 +577,7 @@ contract TestInitialize is Test { vm.chainId(chainId1); DummyToken t1 = new DummyToken(); NttManager implementation = - new MockNttManagerContract(address(t1), NttManager.Mode.LOCKING, chainId1, 1 days); + new MockNttManagerContract(address(t1), INttManager.Mode.LOCKING, chainId1, 1 days); nttManagerChain1 = MockNttManagerContract(address(new ERC1967Proxy(address(implementation), ""))); @@ -598,7 +597,7 @@ contract TestInitialize is Test { vm.chainId(chainId1); DummyToken t1 = new DummyToken(); NttManager implementation = - new MockNttManagerContract(address(t1), NttManager.Mode.LOCKING, chainId1, 1 days); + new MockNttManagerContract(address(t1), INttManager.Mode.LOCKING, chainId1, 1 days); nttManagerChain1 = MockNttManagerContract(address(new ERC1967Proxy(address(implementation), ""))); @@ -606,7 +605,7 @@ contract TestInitialize is Test { // Attempt to initialize the contract from a non-deployer account. vm.prank(userA); vm.expectRevert( - abi.encodeWithSignature("UnexpectedOwner(address,address)", address(this), userA) + abi.encodeWithSignature("UnexpectedDeployer(address,address)", address(this), userA) ); nttManagerChain1.initialize(); } diff --git a/evm/test/libraries/NttManagerHelpers.sol b/evm/test/libraries/NttManagerHelpers.sol index bf2277769..bc424c81f 100644 --- a/evm/test/libraries/NttManagerHelpers.sol +++ b/evm/test/libraries/NttManagerHelpers.sol @@ -3,7 +3,7 @@ pragma solidity >=0.8.8 <0.9.0; import "../../src/libraries/TrimmedAmount.sol"; -import "../../src/NttManager.sol"; +import "../../src/NttManager/NttManager.sol"; library NttManagerHelpersLib { uint16 constant SENDING_CHAIN_ID = 1; diff --git a/evm/test/libraries/TransceiverHelpers.sol b/evm/test/libraries/TransceiverHelpers.sol index f29fef580..2d90b940f 100644 --- a/evm/test/libraries/TransceiverHelpers.sol +++ b/evm/test/libraries/TransceiverHelpers.sol @@ -5,7 +5,7 @@ pragma solidity >=0.8.8 <0.9.0; import "./NttManagerHelpers.sol"; import "../mocks/DummyTransceiver.sol"; import "../mocks/DummyToken.sol"; -import "../../src/NttManager.sol"; +import "../../src/NttManager/NttManager.sol"; import "../../src/libraries/TrimmedAmount.sol"; library TransceiverHelpersLib { diff --git a/evm/test/mocks/DummyTransceiver.sol b/evm/test/mocks/DummyTransceiver.sol index 91c80a654..b5d040712 100644 --- a/evm/test/mocks/DummyTransceiver.sol +++ b/evm/test/mocks/DummyTransceiver.sol @@ -3,7 +3,7 @@ pragma solidity >=0.8.8 <0.9.0; import "forge-std/Test.sol"; -import "../../src/Transceiver.sol"; +import "../../src/Transceiver/Transceiver.sol"; import "../interfaces/ITransceiverReceiver.sol"; contract DummyTransceiver is Transceiver, ITransceiverReceiver { diff --git a/evm/test/mocks/MockNttManager.sol b/evm/test/mocks/MockNttManager.sol index bc1d2882c..807c91858 100644 --- a/evm/test/mocks/MockNttManager.sol +++ b/evm/test/mocks/MockNttManager.sol @@ -2,7 +2,7 @@ pragma solidity >=0.8.8 <0.9.0; -import "../../src/NttManager.sol"; +import "../../src/NttManager/NttManager.sol"; contract MockNttManagerContract is NttManager { constructor( diff --git a/evm/test/mocks/MockTransceivers.sol b/evm/test/mocks/MockTransceivers.sol index 17f60f1d9..aced4509d 100644 --- a/evm/test/mocks/MockTransceivers.sol +++ b/evm/test/mocks/MockTransceivers.sol @@ -2,7 +2,7 @@ pragma solidity >=0.8.8 <0.9.0; -import "../../src/WormholeTransceiver.sol"; +import "../../src/Transceiver/WormholeTransceiver/WormholeTransceiver.sol"; contract MockWormholeTransceiverContract is WormholeTransceiver { constructor(