diff --git a/evm/src/Transceiver/WormholeTransceiver/WormholeTransceiver.sol b/evm/src/Transceiver/WormholeTransceiver/WormholeTransceiver.sol new file mode 100644 index 000000000..fecbde3c3 --- /dev/null +++ b/evm/src/Transceiver/WormholeTransceiver/WormholeTransceiver.sol @@ -0,0 +1,235 @@ +// SPDX-License-Identifier: Apache 2 +pragma solidity >=0.8.8 <0.9.0; + +import "wormhole-solidity-sdk/WormholeRelayerSDK.sol"; +import "wormhole-solidity-sdk/libraries/BytesParsing.sol"; +import "wormhole-solidity-sdk/interfaces/IWormhole.sol"; + +import "../../libraries/TransceiverHelpers.sol"; +import "../../libraries/TransceiverStructs.sol"; + +import "../../interfaces/IWormholeTransceiver.sol"; +import "../../interfaces/ISpecialRelayer.sol"; +import "../../interfaces/INttManager.sol"; + +import "./WormholeTransceiverState.sol"; + +contract WormholeTransceiver is + IWormholeTransceiver, + IWormholeReceiver, + WormholeTransceiverState +{ + using BytesParsing for bytes; + + constructor( + address nttManager, + address wormholeCoreBridge, + address wormholeRelayerAddr, + address specialRelayerAddr, + uint8 _consistencyLevel + ) + WormholeTransceiverState( + nttManager, + wormholeCoreBridge, + wormholeRelayerAddr, + specialRelayerAddr, + _consistencyLevel + ) + {} + + + // ==================== External Interface =============================================== + + /// @inheritdoc IWormholeTransceiver + function receiveMessage(bytes memory encodedMessage) external { + uint16 sourceChainId; + bytes memory payload; + (sourceChainId, payload) = _verifyMessage(encodedMessage); + + // parse the encoded Transceiver payload + TransceiverStructs.TransceiverMessage memory parsedTransceiverMessage; + TransceiverStructs.NttManagerMessage memory parsedNttManagerMessage; + (parsedTransceiverMessage, parsedNttManagerMessage) = TransceiverStructs + .parseTransceiverAndNttManagerMessage(WH_TRANSCEIVER_PAYLOAD_PREFIX, payload); + + _deliverToNttManager( + sourceChainId, + parsedTransceiverMessage.sourceNttManagerAddress, + parsedTransceiverMessage.recipientNttManagerAddress, + parsedNttManagerMessage + ); + } + + /// @inheritdoc IWormholeReceiver + function receiveWormholeMessages( + bytes memory payload, + bytes[] memory additionalMessages, + bytes32 sourceAddress, + uint16 sourceChain, + bytes32 deliveryHash + ) external payable onlyRelayer { + if (getWormholePeer(sourceChain) != sourceAddress) { + revert InvalidWormholePeer(sourceChain, sourceAddress); + } + + // VAA replay protection: + // - Note that this VAA is for the AR delivery, not for the raw message emitted by the source + // - chain Transceiver contract. The VAAs received by this entrypoint are different than the + // - VAA received by the receiveMessage entrypoint. + if (isVAAConsumed(deliveryHash)) { + revert TransferAlreadyCompleted(deliveryHash); + } + _setVAAConsumed(deliveryHash); + + // We don't honor additional messages in this handler. + if (additionalMessages.length > 0) { + revert UnexpectedAdditionalMessages(); + } + + // emit `ReceivedRelayedMessage` event + emit ReceivedRelayedMessage(deliveryHash, sourceChain, sourceAddress); + + // parse the encoded Transceiver payload + TransceiverStructs.TransceiverMessage memory parsedTransceiverMessage; + TransceiverStructs.NttManagerMessage memory parsedNttManagerMessage; + (parsedTransceiverMessage, parsedNttManagerMessage) = TransceiverStructs + .parseTransceiverAndNttManagerMessage(WH_TRANSCEIVER_PAYLOAD_PREFIX, payload); + + _deliverToNttManager( + sourceChain, + parsedTransceiverMessage.sourceNttManagerAddress, + parsedTransceiverMessage.recipientNttManagerAddress, + parsedNttManagerMessage + ); + } + + /// @inheritdoc IWormholeTransceiver + function parseWormholeTransceiverInstruction(bytes memory encoded) + public + pure + returns (WormholeTransceiverInstruction memory instruction) + { + // If the user doesn't pass in any transceiver instructions then the default is false + if (encoded.length == 0) { + instruction.shouldSkipRelayerSend = false; + return instruction; + } + + uint256 offset = 0; + (instruction.shouldSkipRelayerSend, offset) = encoded.asBoolUnchecked(offset); + encoded.checkLength(offset); + } + + /// @inheritdoc IWormholeTransceiver + function encodeWormholeTransceiverInstruction(WormholeTransceiverInstruction memory instruction) + public + pure + returns (bytes memory) + { + return abi.encodePacked(instruction.shouldSkipRelayerSend); + } + + + // ==================== Internal ======================================================== + + function _quoteDeliveryPrice( + uint16 targetChain, + TransceiverStructs.TransceiverInstruction memory instruction + ) internal view override returns (uint256 nativePriceQuote) { + // Check the special instruction up front to see if we should skip sending via a relayer + WormholeTransceiverInstruction memory weIns = + parseWormholeTransceiverInstruction(instruction.payload); + if (weIns.shouldSkipRelayerSend) { + return 0; + } + + if (_checkInvalidRelayingConfig(targetChain)) { + revert InvalidRelayingConfig(targetChain); + } + + if (_shouldRelayViaStandardRelaying(targetChain)) { + (uint256 cost,) = wormholeRelayer.quoteEVMDeliveryPrice(targetChain, 0, GAS_LIMIT); + return cost; + } else if (isSpecialRelayingEnabled(targetChain)) { + uint256 cost = specialRelayer.quoteDeliveryPrice(getNttManagerToken(), targetChain, 0); + return cost; + } else { + return 0; + } + } + + function _sendMessage( + uint16 recipientChain, + uint256 deliveryPayment, + address caller, + bytes32 recipientNttManagerAddress, + TransceiverStructs.TransceiverInstruction memory instruction, + bytes memory nttManagerMessage + ) internal override { + ( + TransceiverStructs.TransceiverMessage memory transceiverMessage, + bytes memory encodedTransceiverPayload + ) = TransceiverStructs.buildAndEncodeTransceiverMessage( + WH_TRANSCEIVER_PAYLOAD_PREFIX, + toWormholeFormat(caller), + recipientNttManagerAddress, + nttManagerMessage, + new bytes(0) + ); + + WormholeTransceiverInstruction memory weIns = + parseWormholeTransceiverInstruction(instruction.payload); + + if (!weIns.shouldSkipRelayerSend && _shouldRelayViaStandardRelaying(recipientChain)) { + wormholeRelayer.sendPayloadToEvm{value: deliveryPayment}( + recipientChain, + fromWormholeFormat(getWormholePeer(recipientChain)), + encodedTransceiverPayload, + 0, + GAS_LIMIT + ); + } else if (!weIns.shouldSkipRelayerSend && isSpecialRelayingEnabled(recipientChain)) { + uint64 sequence = + wormhole.publishMessage(0, encodedTransceiverPayload, consistencyLevel); + specialRelayer.requestDelivery{value: deliveryPayment}( + getNttManagerToken(), recipientChain, 0, sequence + ); + } else { + wormhole.publishMessage(0, encodedTransceiverPayload, consistencyLevel); + } + + emit SendTransceiverMessage(recipientChain, transceiverMessage); + } + + function _verifyMessage(bytes memory encodedMessage) internal returns (uint16, bytes memory) { + // verify VAA against Wormhole Core Bridge contract + (IWormhole.VM memory vm, bool valid, string memory reason) = + wormhole.parseAndVerifyVM(encodedMessage); + + // ensure that the VAA is valid + if (!valid) { + revert InvalidVaa(reason); + } + + // ensure that the message came from a registered peer contract + if (!_verifyBridgeVM(vm)) { + revert InvalidWormholePeer(vm.emitterChainId, vm.emitterAddress); + } + + // save the VAA hash in storage to protect against replay attacks. + if (isVAAConsumed(vm.hash)) { + revert TransferAlreadyCompleted(vm.hash); + } + _setVAAConsumed(vm.hash); + + // emit `ReceivedMessage` event + emit ReceivedMessage(vm.hash, vm.emitterChainId, vm.emitterAddress, vm.sequence); + + return (vm.emitterChainId, vm.payload); + } + + function _verifyBridgeVM(IWormhole.VM memory vm) internal view returns (bool) { + checkFork(wormholeTransceiver_evmChainId); + return getWormholePeer(vm.emitterChainId) == vm.emitterAddress; + } +} diff --git a/evm/src/Transceiver/WormholeTransceiver.sol b/evm/src/Transceiver/WormholeTransceiver/WormholeTransceiverState.sol similarity index 51% rename from evm/src/Transceiver/WormholeTransceiver.sol rename to evm/src/Transceiver/WormholeTransceiver/WormholeTransceiverState.sol index 1488756fb..15e47bd15 100644 --- a/evm/src/Transceiver/WormholeTransceiver.sol +++ b/evm/src/Transceiver/WormholeTransceiver/WormholeTransceiverState.sol @@ -5,20 +5,29 @@ import "wormhole-solidity-sdk/WormholeRelayerSDK.sol"; import "wormhole-solidity-sdk/libraries/BytesParsing.sol"; import "wormhole-solidity-sdk/interfaces/IWormhole.sol"; -import "../libraries/TransceiverHelpers.sol"; -import "../libraries/TransceiverStructs.sol"; +import "../../libraries/TransceiverHelpers.sol"; +import "../../libraries/TransceiverStructs.sol"; -import "../interfaces/IWormholeTransceiver.sol"; -import "../interfaces/ISpecialRelayer.sol"; -import "../interfaces/INttManager.sol"; +import "../../interfaces/IWormholeTransceiver.sol"; +import "../../interfaces/IWormholeTransceiverState.sol"; +import "../../interfaces/ISpecialRelayer.sol"; +import "../../interfaces/INttManager.sol"; -import "./Transceiver.sol"; +import "../Transceiver.sol"; -contract WormholeTransceiver is Transceiver, IWormholeTransceiver, IWormholeReceiver { +abstract contract WormholeTransceiverState is IWormholeTransceiverState, Transceiver { using BytesParsing for bytes; - uint256 public constant GAS_LIMIT = 500000; + // ==================== Immutables =============================================== uint8 public immutable consistencyLevel; + IWormhole public immutable wormhole; + IWormholeRelayer public immutable wormholeRelayer; + ISpecialRelayer public immutable specialRelayer; + uint256 public immutable wormholeTransceiver_evmChainId; + + + // ==================== Constants ================================================ + uint256 public constant GAS_LIMIT = 500000; /// @dev Prefix for all TransceiverMessage payloads /// This is 0x99'E''W''H' @@ -34,16 +43,39 @@ contract WormholeTransceiver is Transceiver, IWormholeTransceiver, IWormholeRece /// This is bytes4(keccak256("WormholePeerRegistration")) bytes4 constant WH_PEER_REGISTRATION_PREFIX = 0x18fc67c2; - IWormhole public immutable wormhole; - IWormholeRelayer public immutable wormholeRelayer; - ISpecialRelayer public immutable specialRelayer; - uint256 public immutable wormholeTransceiver_evmChainId; - struct WormholeTransceiverInstruction { - bool shouldSkipRelayerSend; + constructor( + address nttManager, + address wormholeCoreBridge, + address wormholeRelayerAddr, + address specialRelayerAddr, + uint8 _consistencyLevel + ) Transceiver(nttManager) { + wormhole = IWormhole(wormholeCoreBridge); + wormholeRelayer = IWormholeRelayer(wormholeRelayerAddr); + specialRelayer = ISpecialRelayer(specialRelayerAddr); + wormholeTransceiver_evmChainId = block.chainid; + consistencyLevel = _consistencyLevel; } - /// =============== STORAGE =============================================== + function _initialize() internal override { + super._initialize(); + _initializeTransceiver(); + } + + function _initializeTransceiver() internal { + TransceiverStructs.TransceiverInit memory init = TransceiverStructs.TransceiverInit({ + transceiverIdentifier: WH_TRANSCEIVER_INIT_PREFIX, + nttManagerAddress: toWormholeFormat(nttManager), + nttManagerMode: INttManager(nttManager).getMode(), + tokenAddress: toWormholeFormat(nttManagerToken), + tokenDecimals: INttManager(nttManager).tokenDecimals() + }); + wormhole.publishMessage(0, TransceiverStructs.encodeTransceiverInit(init), consistencyLevel); + } + + + // =============== Storage =============================================== bytes32 private constant WORMHOLE_CONSUMED_VAAS_SLOT = bytes32(uint256(keccak256("whTransceiver.consumedVAAs")) - 1); @@ -60,7 +92,8 @@ contract WormholeTransceiver is Transceiver, IWormholeTransceiver, IWormholeRece bytes32 private constant WORMHOLE_EVM_CHAIN_IDS = bytes32(uint256(keccak256("whTransceiver.evmChainIds")) - 1); - /// =============== GETTERS/SETTERS ======================================== + + // =============== Storage Setters/Getters ======================================== function _getWormholeConsumedVAAsStorage() internal @@ -117,233 +150,75 @@ contract WormholeTransceiver is Transceiver, IWormholeTransceiver, IWormholeRece } } - modifier onlyRelayer() { - if (msg.sender != address(wormholeRelayer)) { - revert CallerNotRelayer(msg.sender); - } - _; - } - - constructor( - address nttManager, - address wormholeCoreBridge, - address wormholeRelayerAddr, - address specialRelayerAddr, - uint8 _consistencyLevel - ) Transceiver(nttManager) { - wormhole = IWormhole(wormholeCoreBridge); - wormholeRelayer = IWormholeRelayer(wormholeRelayerAddr); - specialRelayer = ISpecialRelayer(specialRelayerAddr); - wormholeTransceiver_evmChainId = block.chainid; - consistencyLevel = _consistencyLevel; - } + // =============== Public Getters ====================================================== - function _initialize() internal override { - super._initialize(); - _initializeTransceiver(); + /// @inheritdoc IWormholeTransceiverState + function isVAAConsumed(bytes32 hash) public view returns (bool) { + return _getWormholeConsumedVAAsStorage()[hash]; } - function _initializeTransceiver() internal { - TransceiverStructs.TransceiverInit memory init = TransceiverStructs.TransceiverInit({ - transceiverIdentifier: WH_TRANSCEIVER_INIT_PREFIX, - nttManagerAddress: toWormholeFormat(nttManager), - nttManagerMode: INttManager(nttManager).getMode(), - tokenAddress: toWormholeFormat(nttManagerToken), - tokenDecimals: INttManager(nttManager).tokenDecimals() - }); - wormhole.publishMessage(0, TransceiverStructs.encodeTransceiverInit(init), consistencyLevel); + /// @inheritdoc IWormholeTransceiverState + function getWormholePeer(uint16 chainId) public view returns (bytes32) { + return _getWormholePeersStorage()[chainId]; } - function _checkInvalidRelayingConfig(uint16 chainId) internal view returns (bool) { - return isWormholeRelayingEnabled(chainId) && !isWormholeEvmChain(chainId); + /// @inheritdoc IWormholeTransceiverState + function isWormholeRelayingEnabled(uint16 chainId) public view returns (bool) { + return toBool(_getWormholeRelayingEnabledChainsStorage()[chainId]); } - function _shouldRelayViaStandardRelaying(uint16 chainId) internal view returns (bool) { - return isWormholeRelayingEnabled(chainId) && isWormholeEvmChain(chainId); + /// @inheritdoc IWormholeTransceiverState + function isSpecialRelayingEnabled(uint16 chainId) public view returns (bool) { + return toBool(_getSpecialRelayingEnabledChainsStorage()[chainId]); } - function _quoteDeliveryPrice( - uint16 targetChain, - TransceiverStructs.TransceiverInstruction memory instruction - ) internal view override returns (uint256 nativePriceQuote) { - // Check the special instruction up front to see if we should skip sending via a relayer - WormholeTransceiverInstruction memory weIns = - parseWormholeTransceiverInstruction(instruction.payload); - if (weIns.shouldSkipRelayerSend) { - return 0; - } - - if (_checkInvalidRelayingConfig(targetChain)) { - revert InvalidRelayingConfig(targetChain); - } - - if (_shouldRelayViaStandardRelaying(targetChain)) { - (uint256 cost,) = wormholeRelayer.quoteEVMDeliveryPrice(targetChain, 0, GAS_LIMIT); - return cost; - } else if (isSpecialRelayingEnabled(targetChain)) { - uint256 cost = specialRelayer.quoteDeliveryPrice(getNttManagerToken(), targetChain, 0); - return cost; - } else { - return 0; - } + /// @inheritdoc IWormholeTransceiverState + function isWormholeEvmChain(uint16 chainId) public view returns (bool) { + return toBool(_getWormholeEvmChainIdsStorage()[chainId]); } - function _sendMessage( - uint16 recipientChain, - uint256 deliveryPayment, - address caller, - bytes32 recipientNttManagerAddress, - TransceiverStructs.TransceiverInstruction memory instruction, - bytes memory nttManagerMessage - ) internal override { - ( - TransceiverStructs.TransceiverMessage memory transceiverMessage, - bytes memory encodedTransceiverPayload - ) = TransceiverStructs.buildAndEncodeTransceiverMessage( - WH_TRANSCEIVER_PAYLOAD_PREFIX, - toWormholeFormat(caller), - recipientNttManagerAddress, - nttManagerMessage, - new bytes(0) - ); - WormholeTransceiverInstruction memory weIns = - parseWormholeTransceiverInstruction(instruction.payload); - - if (!weIns.shouldSkipRelayerSend && _shouldRelayViaStandardRelaying(recipientChain)) { - wormholeRelayer.sendPayloadToEvm{value: deliveryPayment}( - recipientChain, - fromWormholeFormat(getWormholePeer(recipientChain)), - encodedTransceiverPayload, - 0, - GAS_LIMIT - ); - } else if (!weIns.shouldSkipRelayerSend && isSpecialRelayingEnabled(recipientChain)) { - uint64 sequence = - wormhole.publishMessage(0, encodedTransceiverPayload, consistencyLevel); - specialRelayer.requestDelivery{value: deliveryPayment}( - getNttManagerToken(), recipientChain, 0, sequence - ); - } else { - wormhole.publishMessage(0, encodedTransceiverPayload, consistencyLevel); - } + // =============== Admin =============================================================== - emit SendTransceiverMessage(recipientChain, transceiverMessage); + /// @inheritdoc IWormholeTransceiverState + function setWormholePeer(uint16 peerChainId, bytes32 peerContract) external onlyOwner { + _setWormholePeer(peerChainId, peerContract); } - function receiveWormholeMessages( - bytes memory payload, - bytes[] memory additionalMessages, - bytes32 sourceAddress, - uint16 sourceChain, - bytes32 deliveryHash - ) external payable onlyRelayer { - if (getWormholePeer(sourceChain) != sourceAddress) { - revert InvalidWormholePeer(sourceChain, sourceAddress); - } - - // VAA replay protection - // Note that this VAA is for the AR delivery, not for the raw message emitted by the source chain Transceiver contract. - // The VAAs received by this entrypoint are different than the VAA received by the receiveMessage entrypoint. - if (isVAAConsumed(deliveryHash)) { - revert TransferAlreadyCompleted(deliveryHash); - } - _setVAAConsumed(deliveryHash); - - // We don't honor additional message in this handler. - if (additionalMessages.length > 0) { - revert UnexpectedAdditionalMessages(); - } - - // emit `ReceivedRelayedMessage` event - emit ReceivedRelayedMessage(deliveryHash, sourceChain, sourceAddress); - - // parse the encoded Transceiver payload - TransceiverStructs.TransceiverMessage memory parsedTransceiverMessage; - TransceiverStructs.NttManagerMessage memory parsedNttManagerMessage; - (parsedTransceiverMessage, parsedNttManagerMessage) = TransceiverStructs - .parseTransceiverAndNttManagerMessage(WH_TRANSCEIVER_PAYLOAD_PREFIX, payload); - - _deliverToNttManager( - sourceChain, - parsedTransceiverMessage.sourceNttManagerAddress, - parsedTransceiverMessage.recipientNttManagerAddress, - parsedNttManagerMessage - ); + /// @inheritdoc IWormholeTransceiverState + function setIsWormholeEvmChain(uint16 chainId) external onlyOwner { + _setIsWormholeEvmChain(chainId); } - /// @notice Receive an attested message from the verification layer - /// This function should verify the encodedVm and then deliver the attestation to the transceiver nttManager contract. - function receiveMessage(bytes memory encodedMessage) external { - uint16 sourceChainId; - bytes memory payload; - (sourceChainId, payload) = _verifyMessage(encodedMessage); - - // parse the encoded Transceiver payload - TransceiverStructs.TransceiverMessage memory parsedTransceiverMessage; - TransceiverStructs.NttManagerMessage memory parsedNttManagerMessage; - (parsedTransceiverMessage, parsedNttManagerMessage) = TransceiverStructs - .parseTransceiverAndNttManagerMessage(WH_TRANSCEIVER_PAYLOAD_PREFIX, payload); - - _deliverToNttManager( - sourceChainId, - parsedTransceiverMessage.sourceNttManagerAddress, - parsedTransceiverMessage.recipientNttManagerAddress, - parsedNttManagerMessage - ); + /// @inheritdoc IWormholeTransceiverState + function setIsWormholeRelayingEnabled(uint16 chainId, bool isEnabled) external onlyOwner { + _setIsWormholeRelayingEnabled(chainId, isEnabled); } - function _verifyMessage(bytes memory encodedMessage) internal returns (uint16, bytes memory) { - // verify VAA against Wormhole Core Bridge contract - (IWormhole.VM memory vm, bool valid, string memory reason) = - wormhole.parseAndVerifyVM(encodedMessage); - // ensure that the VAA is valid - if (!valid) { - revert InvalidVaa(reason); - } - - // ensure that the message came from a registered peer contract - if (!_verifyBridgeVM(vm)) { - revert InvalidWormholePeer(vm.emitterChainId, vm.emitterAddress); - } + // ============= Internal =============================================================== - // save the VAA hash in storage to protect against replay attacks. - if (isVAAConsumed(vm.hash)) { - revert TransferAlreadyCompleted(vm.hash); + function _setIsWormholeEvmChain(uint16 chainId) internal { + if (chainId == 0) { + revert InvalidWormholeChainIdZero(); } - _setVAAConsumed(vm.hash); - - // emit `ReceivedMessage` event - emit ReceivedMessage(vm.hash, vm.emitterChainId, vm.emitterAddress, vm.sequence); + _getWormholeEvmChainIdsStorage()[chainId] = TRUE; - return (vm.emitterChainId, vm.payload); + emit SetIsWormholeEvmChain(chainId); } - function _verifyBridgeVM(IWormhole.VM memory vm) internal view returns (bool) { - checkFork(wormholeTransceiver_evmChainId); - return getWormholePeer(vm.emitterChainId) == vm.emitterAddress; + function _checkInvalidRelayingConfig(uint16 chainId) internal view returns (bool) { + return isWormholeRelayingEnabled(chainId) && !isWormholeEvmChain(chainId); } - function isVAAConsumed(bytes32 hash) public view returns (bool) { - return _getWormholeConsumedVAAsStorage()[hash]; + function _shouldRelayViaStandardRelaying(uint16 chainId) internal view returns (bool) { + return isWormholeRelayingEnabled(chainId) && isWormholeEvmChain(chainId); } function _setVAAConsumed(bytes32 hash) internal { _getWormholeConsumedVAAsStorage()[hash] = true; } - /// @notice Get the corresponding Transceiver contract on other chains that have been registered via governance. - /// This design should be extendable to other chains, so each Transceiver would be potentially concerned with Transceivers on multiple other chains - /// Note that peers are registered under wormhole chainID values - function getWormholePeer(uint16 chainId) public view returns (bytes32) { - return _getWormholePeersStorage()[chainId]; - } - - function setWormholePeer(uint16 peerChainId, bytes32 peerContract) external onlyOwner { - _setWormholePeer(peerChainId, peerContract); - } - function _setWormholePeer(uint16 chainId, bytes32 peerContract) internal { if (chainId == 0) { revert InvalidWormholeChainIdZero(); @@ -377,14 +252,6 @@ contract WormholeTransceiver is Transceiver, IWormholeTransceiver, IWormholeRece emit SetWormholePeer(chainId, peerContract); } - function isWormholeRelayingEnabled(uint16 chainId) public view returns (bool) { - return toBool(_getWormholeRelayingEnabledChainsStorage()[chainId]); - } - - function setIsWormholeRelayingEnabled(uint16 chainId, bool isEnabled) external onlyOwner { - _setIsWormholeRelayingEnabled(chainId, isEnabled); - } - function _setIsWormholeRelayingEnabled(uint16 chainId, bool isEnabled) internal { if (chainId == 0) { revert InvalidWormholeChainIdZero(); @@ -394,10 +261,6 @@ contract WormholeTransceiver is Transceiver, IWormholeTransceiver, IWormholeRece emit SetIsWormholeRelayingEnabled(chainId, isEnabled); } - function isSpecialRelayingEnabled(uint16 chainId) public view returns (bool) { - return toBool(_getSpecialRelayingEnabledChainsStorage()[chainId]); - } - function _setIsSpecialRelayingEnabled(uint16 chainId, bool isEnabled) internal { if (chainId == 0) { revert InvalidWormholeChainIdZero(); @@ -407,44 +270,13 @@ contract WormholeTransceiver is Transceiver, IWormholeTransceiver, IWormholeRece emit SetIsSpecialRelayingEnabled(chainId, isEnabled); } - function isWormholeEvmChain(uint16 chainId) public view returns (bool) { - return toBool(_getWormholeEvmChainIdsStorage()[chainId]); - } - function setIsWormholeEvmChain(uint16 chainId) external onlyOwner { - _setIsWormholeEvmChain(chainId); - } + // =============== MODIFIERS =============================================== - function _setIsWormholeEvmChain(uint16 chainId) internal { - if (chainId == 0) { - revert InvalidWormholeChainIdZero(); - } - _getWormholeEvmChainIdsStorage()[chainId] = TRUE; - - emit SetIsWormholeEvmChain(chainId); - } - - function parseWormholeTransceiverInstruction(bytes memory encoded) - public - pure - returns (WormholeTransceiverInstruction memory instruction) - { - // If the user doesn't pass in any transceiver instructions then the default is false - if (encoded.length == 0) { - instruction.shouldSkipRelayerSend = false; - return instruction; + modifier onlyRelayer() { + if (msg.sender != address(wormholeRelayer)) { + revert CallerNotRelayer(msg.sender); } - - uint256 offset = 0; - (instruction.shouldSkipRelayerSend, offset) = encoded.asBoolUnchecked(offset); - encoded.checkLength(offset); - } - - function encodeWormholeTransceiverInstruction(WormholeTransceiverInstruction memory instruction) - public - pure - returns (bytes memory) - { - return abi.encodePacked(instruction.shouldSkipRelayerSend); + _; } } diff --git a/evm/src/interfaces/IWormholeTransceiver.sol b/evm/src/interfaces/IWormholeTransceiver.sol index a2a959c49..75cd3c67c 100644 --- a/evm/src/interfaces/IWormholeTransceiver.sol +++ b/evm/src/interfaces/IWormholeTransceiver.sol @@ -3,7 +3,13 @@ pragma solidity >=0.8.8 <0.9.0; import "../libraries/TransceiverStructs.sol"; -interface IWormholeTransceiver { +import "./IWormholeTransceiverState.sol"; + +interface IWormholeTransceiver is IWormholeTransceiverState { + struct WormholeTransceiverInstruction { + bool shouldSkipRelayerSend; + } + event ReceivedRelayedMessage(bytes32 digest, uint16 emitterChainId, bytes32 emitterAddress); event ReceivedMessage( bytes32 digest, uint16 emitterChainId, bytes32 emitterAddress, uint64 sequence @@ -12,25 +18,30 @@ interface IWormholeTransceiver { event SendTransceiverMessage( uint16 recipientChain, TransceiverStructs.TransceiverMessage message ); - event SetWormholePeer(uint16 chainId, bytes32 peerContract); - event SetIsWormholeRelayingEnabled(uint16 chainId, bool isRelayingEnabled); - event SetIsSpecialRelayingEnabled(uint16 chainId, bool isRelayingEnabled); - event SetIsWormholeEvmChain(uint16 chainId); error InvalidRelayingConfig(uint16 chainId); - error CallerNotRelayer(address caller); - error UnexpectedAdditionalMessages(); - error InvalidVaa(string reason); error InvalidWormholePeer(uint16 chainId, bytes32 peerAddress); - error PeerAlreadySet(uint16 chainId, bytes32 peerAddress); error TransferAlreadyCompleted(bytes32 vaaHash); - error InvalidWormholePeerZeroAddress(); - error InvalidWormholeChainIdZero(); + /// @notice Receive an attested message from the verification layer. This function should verify + /// the `encodedVm` and then deliver the attestation to the transceiver NttManager contract. + /// @param encodedMessage The attested message. function receiveMessage(bytes memory encodedMessage) external; - function isVAAConsumed(bytes32 hash) external view returns (bool); - function getWormholePeer(uint16 chainId) external view returns (bytes32); - function isWormholeRelayingEnabled(uint16 chainId) external view returns (bool); - function isSpecialRelayingEnabled(uint16 chainId) external view returns (bool); - function isWormholeEvmChain(uint16 chainId) external view returns (bool); + + /// @notice Parses the encoded instruction and returns the instruction struct. This instruction + /// is specific to the WormholeTransceiver contract. + /// @param encoded The encoded instruction. + /// @return instruction The parsed `WormholeTransceiverInstruction`. + function parseWormholeTransceiverInstruction(bytes memory encoded) + external + pure + returns (WormholeTransceiverInstruction memory instruction); + + /// @notice Encodes the `WormholeTransceiverInstruction` into a byte array. + /// @param instruction The `WormholeTransceiverInstruction` to encode. + /// @return encoded The encoded instruction. + function encodeWormholeTransceiverInstruction(WormholeTransceiverInstruction memory instruction) + external + pure + returns (bytes memory); } diff --git a/evm/src/interfaces/IWormholeTransceiverState.sol b/evm/src/interfaces/IWormholeTransceiverState.sol new file mode 100644 index 000000000..578c2d3a8 --- /dev/null +++ b/evm/src/interfaces/IWormholeTransceiverState.sol @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: Apache 2 +pragma solidity >=0.8.8 <0.9.0; + +import "../libraries/TransceiverStructs.sol"; + +interface IWormholeTransceiverState { + event SetWormholePeer(uint16 chainId, bytes32 peerContract); + event SetIsWormholeRelayingEnabled(uint16 chainId, bool isRelayingEnabled); + event SetIsSpecialRelayingEnabled(uint16 chainId, bool isRelayingEnabled); + event SetIsWormholeEvmChain(uint16 chainId); + + error UnexpectedAdditionalMessages(); + error InvalidVaa(string reason); + error PeerAlreadySet(uint16 chainId, bytes32 peerAddress); + error InvalidWormholePeerZeroAddress(); + error InvalidWormholeChainIdZero(); + error CallerNotRelayer(address caller); + + /// @notice Get the corresponding Transceiver contract on other chains that have been registered + /// via governance. This design should be extendable to other chains, so each Transceiver would + /// be potentially concerned with Transceivers on multiple other chains. + /// @dev that peers are registered under Wormhole chain ID values. + /// @param chainId The Wormhole chain ID of the peer to get. + /// @return peerContract The address of the peer contract on the given chain. + function getWormholePeer(uint16 chainId) external view returns (bytes32); + + /// @notice Returns a boolean indicating whether the given VAA hash has been consumed. + /// @param hash The VAA hash to check. + function isVAAConsumed(bytes32 hash) external view returns (bool); + + /// @notice Returns a boolean indicating whether Wormhole relaying is enabled for the given chain. + /// @param chainId The Wormhole chain ID to check. + function isWormholeRelayingEnabled(uint16 chainId) external view returns (bool); + + /// @notice Returns a boolean indicating whether special relaying is enabled for the given chain. + /// @param chainId The Wormhole chain ID to check. + function isSpecialRelayingEnabled(uint16 chainId) external view returns (bool); + + /// @notice Returns a boolean indicating whether the given chain is EVM compatible. + /// @param chainId The Wormhole chain ID to check. + function isWormholeEvmChain(uint16 chainId) external view returns (bool); + + /// @notice Set the Wormhole peer contract for the given chain. + /// @dev This function is only callable by the `owner`. + /// @param chainId The Wormhole chain ID of the peer to set. + /// @param peerContract The address of the peer contract on the given chain. + function setWormholePeer(uint16 chainId, bytes32 peerContract) external; + + /// @notice Set whether the chain is EVM compatible. + /// @dev This function is only callable by the `owner`. + /// @param chainId The Wormhole chain ID to set. + function setIsWormholeEvmChain(uint16 chainId) external; + + /// @notice Set whether Wormhole relaying is enabled for the given chain. + /// @dev This function is only callable by the `owner`. + /// @param chainId The Wormhole chain ID to set. + /// @param isRelayingEnabled A boolean indicating whether relaying is enabled. + function setIsWormholeRelayingEnabled(uint16 chainId, bool isRelayingEnabled) external; +} diff --git a/evm/test/IntegrationRelayer.t.sol b/evm/test/IntegrationRelayer.t.sol index bfaf1a623..8a77692f6 100755 --- a/evm/test/IntegrationRelayer.t.sol +++ b/evm/test/IntegrationRelayer.t.sol @@ -11,9 +11,10 @@ import "../src/interfaces/IRateLimiter.sol"; import "../src/interfaces/INttManagerEvents.sol"; import "../src/interfaces/IRateLimiterEvents.sol"; import "../src/interfaces/IWormholeTransceiver.sol"; +import "../src/interfaces/IWormholeTransceiverState.sol"; import {Utils} from "./libraries/Utils.sol"; import {DummyToken, DummyTokenMintAndBurn} from "./mocks/DummyToken.sol"; -import {WormholeTransceiver} from "../src/Transceiver/WormholeTransceiver.sol"; +import {WormholeTransceiver} from "../src/Transceiver/WormholeTransceiver/WormholeTransceiver.sol"; import "../src/libraries/TransceiverStructs.sol"; import "./mocks/MockNttManager.sol"; import "./mocks/MockTransceivers.sol"; @@ -35,7 +36,7 @@ contract TestEndToEndRelayerBase is Test { returns (TransceiverStructs.TransceiverInstruction memory) { WormholeTransceiver.WormholeTransceiverInstruction memory instruction = - WormholeTransceiver.WormholeTransceiverInstruction(relayer_off); + IWormholeTransceiver.WormholeTransceiverInstruction(relayer_off); bytes memory encodedInstructionWormhole; // Source fork has id 0 and corresponds to chain 1 @@ -560,7 +561,7 @@ contract TestRelayerEndToEndManual is nttManagerChain2.setPeer(chainId1, bytes32(uint256(uint160(address(nttManagerChain1))))); vm.prank(userD); vm.expectRevert( - abi.encodeWithSelector(IWormholeTransceiver.CallerNotRelayer.selector, userD) + abi.encodeWithSelector(IWormholeTransceiverState.CallerNotRelayer.selector, userD) ); wormholeTransceiverChain2.receiveWormholeMessages( vaa.payload, diff --git a/evm/test/IntegrationStandalone.t.sol b/evm/test/IntegrationStandalone.t.sol index 252686cd6..680675039 100755 --- a/evm/test/IntegrationStandalone.t.sol +++ b/evm/test/IntegrationStandalone.t.sol @@ -13,7 +13,7 @@ import "../src/interfaces/IRateLimiterEvents.sol"; import {Utils} from "./libraries/Utils.sol"; import {DummyToken, DummyTokenMintAndBurn} from "./NttManager.t.sol"; import "../src/interfaces/IWormholeTransceiver.sol"; -import {WormholeTransceiver} from "../src/Transceiver/WormholeTransceiver.sol"; +import {WormholeTransceiver} from "../src/Transceiver/WormholeTransceiver/WormholeTransceiver.sol"; import "../src/libraries/TransceiverStructs.sol"; import "./mocks/MockNttManager.sol"; import "./mocks/MockTransceivers.sol"; @@ -584,7 +584,7 @@ contract TestEndToEndBase is Test, INttManagerEvents, IRateLimiterEvents { function encodeTransceiverInstruction(bool relayer_off) public view returns (bytes memory) { WormholeTransceiver.WormholeTransceiverInstruction memory instruction = - WormholeTransceiver.WormholeTransceiverInstruction(relayer_off); + IWormholeTransceiver.WormholeTransceiverInstruction(relayer_off); bytes memory encodedInstructionWormhole = wormholeTransceiverChain1.encodeWormholeTransceiverInstruction(instruction); TransceiverStructs.TransceiverInstruction memory TransceiverInstruction = TransceiverStructs @@ -598,7 +598,7 @@ contract TestEndToEndBase is Test, INttManagerEvents, IRateLimiterEvents { // Encode an instruction for each of the relayers function encodeTransceiverInstructions(bool relayer_off) public view returns (bytes memory) { WormholeTransceiver.WormholeTransceiverInstruction memory instruction = - WormholeTransceiver.WormholeTransceiverInstruction(relayer_off); + IWormholeTransceiver.WormholeTransceiverInstruction(relayer_off); bytes memory encodedInstructionWormhole = wormholeTransceiverChain1.encodeWormholeTransceiverInstruction(instruction); diff --git a/evm/test/TransceiverStructs.t.sol b/evm/test/TransceiverStructs.t.sol index 92eeee399..15171966f 100644 --- a/evm/test/TransceiverStructs.t.sol +++ b/evm/test/TransceiverStructs.t.sol @@ -4,7 +4,7 @@ pragma solidity >=0.8.8 <0.9.0; import "forge-std/Test.sol"; import "../src/libraries/TransceiverStructs.sol"; -import "../src/Transceiver/WormholeTransceiver.sol"; +import "../src/Transceiver/WormholeTransceiver/WormholeTransceiver.sol"; contract TestTransceiverStructs is Test { using NormalizedAmountLib for uint256; diff --git a/evm/test/Upgrades.t.sol b/evm/test/Upgrades.t.sol index e5f3c2ead..9a47cf0c7 100644 --- a/evm/test/Upgrades.t.sol +++ b/evm/test/Upgrades.t.sol @@ -14,7 +14,7 @@ import "../src/libraries/external/Initializable.sol"; import "../src/libraries/Implementation.sol"; import {Utils} from "./libraries/Utils.sol"; import {DummyToken, DummyTokenMintAndBurn} from "./NttManager.t.sol"; -import {WormholeTransceiver} from "../src/Transceiver/WormholeTransceiver.sol"; +import {WormholeTransceiver} from "../src/Transceiver/WormholeTransceiver/WormholeTransceiver.sol"; import "../src/libraries/TransceiverStructs.sol"; import "./mocks/MockNttManager.sol"; import "./mocks/MockTransceivers.sol"; @@ -538,7 +538,7 @@ contract TestUpgrades is Test, INttManagerEvents, IRateLimiterEvents { function encodeTransceiverInstruction(bool relayer_off) public view returns (bytes memory) { WormholeTransceiver.WormholeTransceiverInstruction memory instruction = - WormholeTransceiver.WormholeTransceiverInstruction(relayer_off); + IWormholeTransceiver.WormholeTransceiverInstruction(relayer_off); bytes memory encodedInstructionWormhole = wormholeTransceiverChain1.encodeWormholeTransceiverInstruction(instruction); TransceiverStructs.TransceiverInstruction memory TransceiverInstruction = TransceiverStructs diff --git a/evm/test/mocks/MockTransceivers.sol b/evm/test/mocks/MockTransceivers.sol index 50e888a01..aced4509d 100644 --- a/evm/test/mocks/MockTransceivers.sol +++ b/evm/test/mocks/MockTransceivers.sol @@ -2,7 +2,7 @@ pragma solidity >=0.8.8 <0.9.0; -import "../../src/Transceiver/WormholeTransceiver.sol"; +import "../../src/Transceiver/WormholeTransceiver/WormholeTransceiver.sol"; contract MockWormholeTransceiverContract is WormholeTransceiver { constructor(