Skip to content

Commit

Permalink
Merge pull request #674 from lukso-network/DEV-5243_C4-139--related-t…
Browse files Browse the repository at this point in the history
…o-58-Refactor-function-_fallbackLSP17Extendable-to-enable-to-run-code-after-this-function-is-called_Jean

fix: (C4 #139) refactor `_fallbackLSP17Extendable` function to enable to run code after it is called + prevent potential solc bug "storage write removal"
  • Loading branch information
CJ42 authored Aug 14, 2023
2 parents ef2bbe9 + bf16c54 commit 39ff1f4
Show file tree
Hide file tree
Showing 9 changed files with 330 additions and 132 deletions.
1 change: 1 addition & 0 deletions .github/workflows/build-lint-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ jobs:
"lsp9init",
"lsp11",
"lsp11init",
"lsp17",
"lsp20",
"lsp20init",
"lsp23",
Expand Down
65 changes: 21 additions & 44 deletions contracts/LSP0ERC725Account/LSP0ERC725AccountCore.sol
Original file line number Diff line number Diff line change
Expand Up @@ -152,16 +152,18 @@ abstract contract LSP0ERC725AccountCore is
*
* @custom:events {ValueReceived} event when receiving native tokens.
*/
fallback() external payable virtual {
fallback(
bytes calldata callData
) external payable virtual returns (bytes memory) {
if (msg.value != 0) {
emit ValueReceived(msg.sender, msg.value);
}

if (msg.data.length < 4) {
return;
return "";
}

_fallbackLSP17Extendable();
return _fallbackLSP17Extendable(callData);
}

/**
Expand Down Expand Up @@ -789,57 +791,32 @@ abstract contract LSP0ERC725AccountCore is
* If there is an extension for the function selector being called, it calls the extension with the
* CALL opcode, passing the `msg.data` appended with the 20 bytes of the `msg.sender` and
* 32 bytes of the `msg.value`
*
* Because the function uses assembly `return()`/`revert()` to terminate the call, it cannot be
* called before other codes in {fallback()}.
*
* Otherwise, the codes after {_fallbackLSP17Extendable()} may never be reached.
*/
function _fallbackLSP17Extendable() internal virtual override {
function _fallbackLSP17Extendable(
bytes calldata callData
) internal virtual override returns (bytes memory) {
// If there is a function selector
address extension = _getExtension(msg.sig);

// if no extension was found for bytes4(0) return don't revert
if (msg.sig == bytes4(0) && extension == address(0)) return;
if (msg.sig == bytes4(0) && extension == address(0)) return "";

// if no extension was found for other function selectors, revert
if (extension == address(0))
revert NoExtensionFoundForFunctionSelector(msg.sig);

// solhint-disable no-inline-assembly
// if the extension was found, call the extension with the msg.data
// appended with bytes20(address) and bytes32(msg.value)
assembly {
calldatacopy(0, 0, calldatasize())

// The msg.sender address is shifted to the left by 12 bytes to remove the padding
// Then the address without padding is stored right after the calldata
mstore(calldatasize(), shl(96, caller()))

// The msg.value is stored right after the calldata + msg.sender
mstore(add(calldatasize(), 20), callvalue())

// Add 52 bytes for the msg.sender and msg.value appended at the end of the calldata
let success := call(
gas(),
extension,
0,
0,
add(calldatasize(), 52),
0,
0
)

// Copy the returned data
returndatacopy(0, 0, returndatasize())

switch success
// call returns 0 on failed calls
case 0 {
revert(0, returndatasize())
}
default {
return(0, returndatasize())
(bool success, bytes memory result) = extension.call(
abi.encodePacked(callData, msg.sender, msg.value)
);

if (success) {
return result;
} else {
// `mload(result)` -> offset in memory where `result.length` is located
// `add(result, 32)` -> offset in memory where `result` data starts
assembly {
let resultdata_size := mload(result)
revert(add(result, 32), resultdata_size)
}
}
}
Expand Down
50 changes: 16 additions & 34 deletions contracts/LSP17ContractExtension/LSP17Extendable.sol
Original file line number Diff line number Diff line change
Expand Up @@ -83,48 +83,30 @@ abstract contract LSP17Extendable is ERC165 {
*
* Otherwise, the codes after _fallbackLSP17Extendable() may never be reached.
*/
function _fallbackLSP17Extendable() internal virtual {
function _fallbackLSP17Extendable(
bytes calldata callData
) internal virtual returns (bytes memory) {
// If there is a function selector
address extension = _getExtension(msg.sig);

// if no extension was found, revert
if (extension == address(0))
revert NoExtensionFoundForFunctionSelector(msg.sig);

// solhint-disable no-inline-assembly
// if the extension was found, call the extension with the msg.data
// appended with bytes20(address) and bytes32(msg.value)
assembly {
calldatacopy(0, 0, calldatasize())

// The msg.sender address is shifted to the left by 12 bytes to remove the padding
// Then the address without padding is stored right after the calldata
mstore(calldatasize(), shl(96, caller()))

// The msg.value is stored right after the calldata + msg.sender
mstore(add(calldatasize(), 20), callvalue())

// Add 52 bytes for the msg.sender and msg.value appended at the end of the calldata
let success := call(
gas(),
extension,
0,
0,
add(calldatasize(), 52),
0,
0
)

// Copy the returned data
returndatacopy(0, 0, returndatasize())
(bool success, bytes memory result) = extension.call(
abi.encodePacked(callData, msg.sender, msg.value)
);

switch success
// call returns 0 on failed calls
case 0 {
revert(0, returndatasize())
}
default {
return(0, returndatasize())
if (success) {
return result;
} else {
// `mload(result)` -> offset in memory where `result.length` is located
// `add(result, 32)` -> offset in memory where `result` data starts
// solhint-disable no-inline-assembly
/// @solidity memory-safe-assembly
assembly {
let resultdata_size := mload(result)
revert(add(result, 32), resultdata_size)
}
}
}
Expand Down
89 changes: 38 additions & 51 deletions contracts/LSP9Vault/LSP9VaultCore.sol
Original file line number Diff line number Diff line change
Expand Up @@ -116,73 +116,60 @@ contract LSP9VaultCore is
* - the first 4 bytes of the calldata do not match any publicly callable functions from the contract ABI.
* - receiving native tokens with some calldata.
*/
fallback() external payable virtual {
if (msg.value != 0) emit ValueReceived(msg.sender, msg.value);
if (msg.data.length < 4) return;
fallback(
bytes calldata callData
) external payable virtual returns (bytes memory) {
if (msg.value != 0) {
emit ValueReceived(msg.sender, msg.value);
}

if (msg.data.length < 4) {
return "";
}

_fallbackLSP17Extendable();
return _fallbackLSP17Extendable(callData);
}

/**
* @dev Forwards the call to an extension mapped to a function selector. If no extension address
* is mapped to the function selector (address(0)), then revert.
* @dev Forwards the call to an extension mapped to a function selector.
*
* The bytes4(0) msg.sig is an exception, the function won't revert if there is no extension found
* mapped to bytes4(0), but will execute the call to the extension in case it existed.
* Calls {_getExtension} to get the address of the extension mapped to the function selector being
* called on the account. If there is no extension, the `address(0)` will be returned.
*
* The call to the extension is appended with bytes20 (msg.sender) and bytes32 (msg.value).
* Returns the return value on success and revert in case of failure.
* Reverts if there is no extension for the function being called, except for the bytes4(0) function
* selector, which passes even if there is no extension for it.
*
* Because the function uses assembly {return()/revert()} to terminate the call, it cannot be
* called before other codes in fallback().
*
* Otherwise, the codes after _fallbackLSP17Extendable() may never be reached.
* If there is an extension for the function selector being called, it calls the extension with the
* CALL opcode, passing the `msg.data` appended with the 20 bytes of the `msg.sender` and
* 32 bytes of the `msg.value`
*/
function _fallbackLSP17Extendable() internal virtual override {
function _fallbackLSP17Extendable(
bytes calldata callData
) internal virtual override returns (bytes memory) {
// If there is a function selector
address extension = _getExtension(msg.sig);

// if no extension was found for bytes4(0) return don't revert
if (msg.sig == bytes4(0) && extension == address(0)) return;
if (msg.sig == bytes4(0) && extension == address(0)) return "";

// if no extension was found, revert
// if no extension was found for other function selectors, revert
if (extension == address(0))
revert NoExtensionFoundForFunctionSelector(msg.sig);

// solhint-disable no-inline-assembly
// if the extension was found, call the extension with the msg.data
// appended with bytes20(address) and bytes32(msg.value)
assembly {
calldatacopy(0, 0, calldatasize())

// The msg.sender address is shifted to the left by 12 bytes to remove the padding
// Then the address without padding is stored right after the calldata
mstore(calldatasize(), shl(96, caller()))

// The msg.value is stored right after the calldata + msg.sender
mstore(add(calldatasize(), 20), callvalue())

// Add 52 bytes for the msg.sender and msg.value appended at the end of the calldata
let success := call(
gas(),
extension,
0,
0,
add(calldatasize(), 52),
0,
0
)

// Copy the returned data
returndatacopy(0, 0, returndatasize())

switch success
// call returns 0 on failed calls
case 0 {
revert(0, returndatasize())
}
default {
return(0, returndatasize())
(bool success, bytes memory result) = extension.call(
abi.encodePacked(callData, msg.sender, msg.value)
);

if (success) {
return result;
} else {
// `mload(result)` -> offset in memory where `result.length` is located
// `add(result, 32)` -> offset in memory where `result` data starts
// solhint-disable no-inline-assembly
/// @solidity memory-safe-assembly
assembly {
let resultdata_size := mload(result)
revert(add(result, 32), resultdata_size)
}
}
}
Expand Down
31 changes: 31 additions & 0 deletions contracts/Mocks/FallbackExtensions/RevertErrorsTestExtension.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// SPDX-License-Identifier: Apache-2.0
pragma solidity ^0.8.0;

/**
* @dev This contract is used only for testing purposes
*/
contract RevertErrorsTestExtension {
error SomeCustomError(address someAddress);

function revertWithCustomError() public view {
revert SomeCustomError(msg.sender);
}

function revertWithErrorString() public pure {
revert("some error message");
}

function revertWithPanicError() public pure {
uint256 number = 2;

// trigger an arithmetic underflow.
// this should trigger a error of type `Panic(uint256)`
// with error code 17 (0x11) --> Panic(0x11)
number -= 10;
}

function revertWithNoErrorData() public pure {
// solhint-disable reason-string
revert();
}
}
68 changes: 68 additions & 0 deletions contracts/Mocks/LSP17ExtendableTester.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// SPDX-License-Identifier: Apache-2.0
pragma solidity ^0.8.4;

import {LSP17Extendable} from "../LSP17ContractExtension/LSP17Extendable.sol";

/**
* @dev This contract is used only for testing purposes
*/
contract LSP17ExtendableTester is LSP17Extendable {
mapping(bytes4 => address) internal _extensions;

string internal _someStorageData;
string internal _anotherStorageData;

// This `receive()` function is just put there to disable the following solc compiler warning:
//
// "This contract has a payable fallback function, but no receive ether function.
// Consider adding a receive ether function."
receive() external payable {}

// solhint-disable no-complex-fallback
fallback() external payable {
// CHECK we can update the contract's storage BEFORE calling an extension
setStorageData("updated BEFORE calling `_fallbackLSP17Extendable`");

_fallbackLSP17Extendable(msg.data);

// CHECK we can update the contract's storage AFTER calling an extension
setAnotherStorageData(
"updated AFTER calling `_fallbackLSP17Extendable`"
);
}

function getExtension(
bytes4 functionSelector
) public view returns (address) {
return _getExtension(functionSelector);
}

function setExtension(
bytes4 functionSelector,
address extensionContract
) public {
_extensions[functionSelector] = extensionContract;
}

function getStorageData() public view returns (string memory) {
return _someStorageData;
}

function setStorageData(string memory newData) public {
_someStorageData = newData;
}

function getAnotherStorageData() public view returns (string memory) {
return _anotherStorageData;
}

function setAnotherStorageData(string memory newData) public {
_anotherStorageData = newData;
}

function _getExtension(
bytes4 functionSelector
) internal view override returns (address) {
return _extensions[functionSelector];
}
}
1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
"test:lsp9init": "hardhat test --no-compile tests/LSP9Vault/LSP9VaultInit.test.ts",
"test:lsp11": "hardhat test --no-compile tests/LSP11BasicSocialRecovery/LSP11BasicSocialRecovery.test.ts",
"test:lsp11init": "hardhat test --no-compile tests/LSP11BasicSocialRecovery/LSP11BasicSocialRecoveryInit.test.ts",
"test:lsp17": "hardhat test --no-compile tests/LSP17ContractExtension/LSP17Extendable.test.ts",
"test:lsp20": "hardhat test --no-compile tests/LSP20CallVerification/LSP6/LSP20WithLSP6.test.ts",
"test:lsp20init": "hardhat test --no-compile tests/LSP20CallVerification/LSP6/LSP20WithLSP6Init.test.ts",
"test:lsp23": "hardhat test --no-compile tests/LSP23LinkedContractsDeployment/LSP23LinkedContractsDeployment.test.ts",
Expand Down
Loading

0 comments on commit 39ff1f4

Please sign in to comment.