From bd2c5cd5355bac0443cb9ec19f4925faf0f18cfc Mon Sep 17 00:00:00 2001 From: CJ42 Date: Thu, 10 Aug 2023 16:26:41 +0100 Subject: [PATCH] fix: allow running code after `_fallbackLSP17Extendable` function. --- .../LSP0ERC725AccountCore.sol | 70 ++++++-------- .../LSP17Extendable.sol | 51 ++++------ contracts/LSP9Vault/LSP9VaultCore.sol | 92 +++++++++---------- 3 files changed, 85 insertions(+), 128 deletions(-) diff --git a/contracts/LSP0ERC725Account/LSP0ERC725AccountCore.sol b/contracts/LSP0ERC725Account/LSP0ERC725AccountCore.sol index ecf7a4473..61b8206ea 100644 --- a/contracts/LSP0ERC725Account/LSP0ERC725AccountCore.sol +++ b/contracts/LSP0ERC725Account/LSP0ERC725AccountCore.sol @@ -152,16 +152,18 @@ abstract contract LSP0ERC725AccountCore is * * @custom:events {ValueReceived} event when receiving native tokens. */ - fallback() external payable virtual { + fallback( + bytes calldata callData + ) external payable virtual returns (bytes memory) { if (msg.value != 0) { emit ValueReceived(msg.sender, msg.value); } if (msg.data.length < 4) { - return; + return ""; } - _fallbackLSP17Extendable(); + return _fallbackLSP17Extendable(callData); } /** @@ -789,57 +791,37 @@ abstract contract LSP0ERC725AccountCore is * If there is an extension for the function selector being called, it calls the extension with the * CALL opcode, passing the `msg.data` appended with the 20 bytes of the `msg.sender` and * 32 bytes of the `msg.value` - * - * Because the function uses assembly `return()`/`revert()` to terminate the call, it cannot be - * called before other codes in {fallback()}. - * - * Otherwise, the codes after {_fallbackLSP17Extendable()} may never be reached. */ - function _fallbackLSP17Extendable() internal virtual override { + function _fallbackLSP17Extendable( + bytes calldata callData + ) internal virtual override returns (bytes memory) { // If there is a function selector address extension = _getExtension(msg.sig); // if no extension was found for bytes4(0) return don't revert - if (msg.sig == bytes4(0) && extension == address(0)) return; + if (msg.sig == bytes4(0) && extension == address(0)) return ""; // if no extension was found for other function selectors, revert if (extension == address(0)) revert NoExtensionFoundForFunctionSelector(msg.sig); - // solhint-disable no-inline-assembly - // if the extension was found, call the extension with the msg.data - // appended with bytes20(address) and bytes32(msg.value) - assembly { - calldatacopy(0, 0, calldatasize()) - - // The msg.sender address is shifted to the left by 12 bytes to remove the padding - // Then the address without padding is stored right after the calldata - mstore(calldatasize(), shl(96, caller())) - - // The msg.value is stored right after the calldata + msg.sender - mstore(add(calldatasize(), 20), callvalue()) - - // Add 52 bytes for the msg.sender and msg.value appended at the end of the calldata - let success := call( - gas(), - extension, - 0, - 0, - add(calldatasize(), 52), - 0, - 0 - ) - - // Copy the returned data - returndatacopy(0, 0, returndatasize()) - - switch success - // call returns 0 on failed calls - case 0 { - revert(0, returndatasize()) - } - default { - return(0, returndatasize()) + bytes memory calldataWithCallerInfos = abi.encodePacked( + callData, + msg.sender, + msg.value + ); + + (bool success, bytes memory result) = extension.call( + calldataWithCallerInfos + ); + + if (success) { + return result; + } else { + // `result` -> first word in memory where the length of `result` is stored + // `add(result, 32)` -> next word in memory is where the `result` data starts + assembly { + revert(add(result, 32), mload(result)) } } } diff --git a/contracts/LSP17ContractExtension/LSP17Extendable.sol b/contracts/LSP17ContractExtension/LSP17Extendable.sol index 66a5505cf..c5461988b 100644 --- a/contracts/LSP17ContractExtension/LSP17Extendable.sol +++ b/contracts/LSP17ContractExtension/LSP17Extendable.sol @@ -83,7 +83,9 @@ abstract contract LSP17Extendable is ERC165 { * * Otherwise, the codes after _fallbackLSP17Extendable() may never be reached. */ - function _fallbackLSP17Extendable() internal virtual { + function _fallbackLSP17Extendable( + bytes calldata callData + ) internal virtual returns (bytes memory) { // If there is a function selector address extension = _getExtension(msg.sig); @@ -91,40 +93,23 @@ abstract contract LSP17Extendable is ERC165 { if (extension == address(0)) revert NoExtensionFoundForFunctionSelector(msg.sig); - // solhint-disable no-inline-assembly - // if the extension was found, call the extension with the msg.data - // appended with bytes20(address) and bytes32(msg.value) - assembly { - calldatacopy(0, 0, calldatasize()) - - // The msg.sender address is shifted to the left by 12 bytes to remove the padding - // Then the address without padding is stored right after the calldata - mstore(calldatasize(), shl(96, caller())) - - // The msg.value is stored right after the calldata + msg.sender - mstore(add(calldatasize(), 20), callvalue()) - - // Add 52 bytes for the msg.sender and msg.value appended at the end of the calldata - let success := call( - gas(), - extension, - 0, - 0, - add(calldatasize(), 52), - 0, - 0 - ) + bytes memory calldataWithCallerInfos = abi.encodePacked( + callData, + msg.sender, + msg.value + ); - // Copy the returned data - returndatacopy(0, 0, returndatasize()) + (bool success, bytes memory result) = extension.call( + calldataWithCallerInfos + ); - switch success - // call returns 0 on failed calls - case 0 { - revert(0, returndatasize()) - } - default { - return(0, returndatasize()) + if (success) { + return result; + } else { + // `result` -> first word in memory where the length of `result` is stored + // `add(result, 32)` -> next word in memory is where the `result` data starts + assembly { + revert(add(result, 32), mload(result)) } } } diff --git a/contracts/LSP9Vault/LSP9VaultCore.sol b/contracts/LSP9Vault/LSP9VaultCore.sol index ee5fcc9ab..0a41223e4 100644 --- a/contracts/LSP9Vault/LSP9VaultCore.sol +++ b/contracts/LSP9Vault/LSP9VaultCore.sol @@ -116,73 +116,63 @@ contract LSP9VaultCore is * - the first 4 bytes of the calldata do not match any publicly callable functions from the contract ABI. * - receiving native tokens with some calldata. */ - fallback() external payable virtual { - if (msg.value != 0) emit ValueReceived(msg.sender, msg.value); - if (msg.data.length < 4) return; + fallback( + bytes calldata callData + ) external payable virtual returns (bytes memory) { + if (msg.value != 0) { + emit ValueReceived(msg.sender, msg.value); + } + + if (msg.data.length < 4) { + return ""; + } - _fallbackLSP17Extendable(); + return _fallbackLSP17Extendable(callData); } /** - * @dev Forwards the call to an extension mapped to a function selector. If no extension address - * is mapped to the function selector (address(0)), then revert. + * @dev Forwards the call to an extension mapped to a function selector. * - * The bytes4(0) msg.sig is an exception, the function won't revert if there is no extension found - * mapped to bytes4(0), but will execute the call to the extension in case it existed. - * - * The call to the extension is appended with bytes20 (msg.sender) and bytes32 (msg.value). - * Returns the return value on success and revert in case of failure. + * Calls {_getExtension} to get the address of the extension mapped to the function selector being + * called on the account. If there is no extension, the `address(0)` will be returned. * - * Because the function uses assembly {return()/revert()} to terminate the call, it cannot be - * called before other codes in fallback(). + * Reverts if there is no extension for the function being called, except for the bytes4(0) function + * selector, which passes even if there is no extension for it. * - * Otherwise, the codes after _fallbackLSP17Extendable() may never be reached. + * If there is an extension for the function selector being called, it calls the extension with the + * CALL opcode, passing the `msg.data` appended with the 20 bytes of the `msg.sender` and + * 32 bytes of the `msg.value` */ - function _fallbackLSP17Extendable() internal virtual override { + function _fallbackLSP17Extendable( + bytes calldata callData + ) internal virtual override returns (bytes memory) { // If there is a function selector address extension = _getExtension(msg.sig); // if no extension was found for bytes4(0) return don't revert - if (msg.sig == bytes4(0) && extension == address(0)) return; + if (msg.sig == bytes4(0) && extension == address(0)) return ""; - // if no extension was found, revert + // if no extension was found for other function selectors, revert if (extension == address(0)) revert NoExtensionFoundForFunctionSelector(msg.sig); - // solhint-disable no-inline-assembly - // if the extension was found, call the extension with the msg.data - // appended with bytes20(address) and bytes32(msg.value) - assembly { - calldatacopy(0, 0, calldatasize()) - - // The msg.sender address is shifted to the left by 12 bytes to remove the padding - // Then the address without padding is stored right after the calldata - mstore(calldatasize(), shl(96, caller())) - - // The msg.value is stored right after the calldata + msg.sender - mstore(add(calldatasize(), 20), callvalue()) - - // Add 52 bytes for the msg.sender and msg.value appended at the end of the calldata - let success := call( - gas(), - extension, - 0, - 0, - add(calldatasize(), 52), - 0, - 0 - ) - - // Copy the returned data - returndatacopy(0, 0, returndatasize()) - - switch success - // call returns 0 on failed calls - case 0 { - revert(0, returndatasize()) - } - default { - return(0, returndatasize()) + bytes memory calldataWithCallerInfos = abi.encodePacked( + callData, + msg.sender, + msg.value + ); + + (bool success, bytes memory result) = extension.call( + calldataWithCallerInfos + ); + + if (success) { + return result; + } else { + // `result` -> first word in memory where the length of `result` is stored + // `add(result, 32)` -> next word in memory is where the `result` data starts + assembly { + revert(add(result, 32), mload(result)) } } }