Skip to content

Commit

Permalink
fix: allow running code after _fallbackLSP17Extendable function.
Browse files Browse the repository at this point in the history
  • Loading branch information
CJ42 committed Aug 11, 2023
1 parent 99f3d15 commit ada5adf
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 128 deletions.
70 changes: 26 additions & 44 deletions contracts/LSP0ERC725Account/LSP0ERC725AccountCore.sol
Original file line number Diff line number Diff line change
Expand Up @@ -152,16 +152,18 @@ abstract contract LSP0ERC725AccountCore is
*
* @custom:events {ValueReceived} event when receiving native tokens.
*/
fallback() external payable virtual {
fallback(
bytes calldata callData
) external payable virtual returns (bytes memory) {
if (msg.value != 0) {
emit ValueReceived(msg.sender, msg.value);
}

if (msg.data.length < 4) {
return;
return "";
}

_fallbackLSP17Extendable();
return _fallbackLSP17Extendable(callData);
}

/**
Expand Down Expand Up @@ -789,57 +791,37 @@ abstract contract LSP0ERC725AccountCore is
* If there is an extension for the function selector being called, it calls the extension with the
* CALL opcode, passing the `msg.data` appended with the 20 bytes of the `msg.sender` and
* 32 bytes of the `msg.value`
*
* Because the function uses assembly `return()`/`revert()` to terminate the call, it cannot be
* called before other codes in {fallback()}.
*
* Otherwise, the codes after {_fallbackLSP17Extendable()} may never be reached.
*/
function _fallbackLSP17Extendable() internal virtual override {
function _fallbackLSP17Extendable(
bytes calldata callData
) internal virtual override returns (bytes memory) {
// If there is a function selector
address extension = _getExtension(msg.sig);

// if no extension was found for bytes4(0) return don't revert
if (msg.sig == bytes4(0) && extension == address(0)) return;
if (msg.sig == bytes4(0) && extension == address(0)) return "";

// if no extension was found for other function selectors, revert
if (extension == address(0))
revert NoExtensionFoundForFunctionSelector(msg.sig);

// solhint-disable no-inline-assembly
// if the extension was found, call the extension with the msg.data
// appended with bytes20(address) and bytes32(msg.value)
assembly {
calldatacopy(0, 0, calldatasize())

// The msg.sender address is shifted to the left by 12 bytes to remove the padding
// Then the address without padding is stored right after the calldata
mstore(calldatasize(), shl(96, caller()))

// The msg.value is stored right after the calldata + msg.sender
mstore(add(calldatasize(), 20), callvalue())

// Add 52 bytes for the msg.sender and msg.value appended at the end of the calldata
let success := call(
gas(),
extension,
0,
0,
add(calldatasize(), 52),
0,
0
)

// Copy the returned data
returndatacopy(0, 0, returndatasize())

switch success
// call returns 0 on failed calls
case 0 {
revert(0, returndatasize())
}
default {
return(0, returndatasize())
bytes memory calldataWithCallerInfos = abi.encodePacked(
callData,
msg.sender,
msg.value
);

(bool success, bytes memory result) = extension.call(
calldataWithCallerInfos
);

if (success) {
return result;
} else {
// `result` -> first word in memory where the length of `result` is stored
// `add(result, 32)` -> next word in memory is where the `result` data starts
assembly {
revert(add(result, 32), mload(result))
}
}
}
Expand Down
51 changes: 18 additions & 33 deletions contracts/LSP17ContractExtension/LSP17Extendable.sol
Original file line number Diff line number Diff line change
Expand Up @@ -83,48 +83,33 @@ abstract contract LSP17Extendable is ERC165 {
*
* Otherwise, the codes after _fallbackLSP17Extendable() may never be reached.
*/
function _fallbackLSP17Extendable() internal virtual {
function _fallbackLSP17Extendable(
bytes calldata callData
) internal virtual returns (bytes memory) {
// If there is a function selector
address extension = _getExtension(msg.sig);

// if no extension was found, revert
if (extension == address(0))
revert NoExtensionFoundForFunctionSelector(msg.sig);

// solhint-disable no-inline-assembly
// if the extension was found, call the extension with the msg.data
// appended with bytes20(address) and bytes32(msg.value)
assembly {
calldatacopy(0, 0, calldatasize())

// The msg.sender address is shifted to the left by 12 bytes to remove the padding
// Then the address without padding is stored right after the calldata
mstore(calldatasize(), shl(96, caller()))

// The msg.value is stored right after the calldata + msg.sender
mstore(add(calldatasize(), 20), callvalue())

// Add 52 bytes for the msg.sender and msg.value appended at the end of the calldata
let success := call(
gas(),
extension,
0,
0,
add(calldatasize(), 52),
0,
0
)
bytes memory calldataWithCallerInfos = abi.encodePacked(
callData,
msg.sender,
msg.value
);

// Copy the returned data
returndatacopy(0, 0, returndatasize())
(bool success, bytes memory result) = extension.call(
calldataWithCallerInfos
);

switch success
// call returns 0 on failed calls
case 0 {
revert(0, returndatasize())
}
default {
return(0, returndatasize())
if (success) {
return result;
} else {
// `result` -> first word in memory where the length of `result` is stored
// `add(result, 32)` -> next word in memory is where the `result` data starts
assembly {
revert(add(result, 32), mload(result))
}
}
}
Expand Down
92 changes: 41 additions & 51 deletions contracts/LSP9Vault/LSP9VaultCore.sol
Original file line number Diff line number Diff line change
Expand Up @@ -116,73 +116,63 @@ contract LSP9VaultCore is
* - the first 4 bytes of the calldata do not match any publicly callable functions from the contract ABI.
* - receiving native tokens with some calldata.
*/
fallback() external payable virtual {
if (msg.value != 0) emit ValueReceived(msg.sender, msg.value);
if (msg.data.length < 4) return;
fallback(
bytes calldata callData
) external payable virtual returns (bytes memory) {
if (msg.value != 0) {
emit ValueReceived(msg.sender, msg.value);
}

if (msg.data.length < 4) {
return "";
}

_fallbackLSP17Extendable();
return _fallbackLSP17Extendable(callData);
}

/**
* @dev Forwards the call to an extension mapped to a function selector. If no extension address
* is mapped to the function selector (address(0)), then revert.
* @dev Forwards the call to an extension mapped to a function selector.
*
* The bytes4(0) msg.sig is an exception, the function won't revert if there is no extension found
* mapped to bytes4(0), but will execute the call to the extension in case it existed.
*
* The call to the extension is appended with bytes20 (msg.sender) and bytes32 (msg.value).
* Returns the return value on success and revert in case of failure.
* Calls {_getExtension} to get the address of the extension mapped to the function selector being
* called on the account. If there is no extension, the `address(0)` will be returned.
*
* Because the function uses assembly {return()/revert()} to terminate the call, it cannot be
* called before other codes in fallback().
* Reverts if there is no extension for the function being called, except for the bytes4(0) function
* selector, which passes even if there is no extension for it.
*
* Otherwise, the codes after _fallbackLSP17Extendable() may never be reached.
* If there is an extension for the function selector being called, it calls the extension with the
* CALL opcode, passing the `msg.data` appended with the 20 bytes of the `msg.sender` and
* 32 bytes of the `msg.value`
*/
function _fallbackLSP17Extendable() internal virtual override {
function _fallbackLSP17Extendable(
bytes calldata callData
) internal virtual override returns (bytes memory) {
// If there is a function selector
address extension = _getExtension(msg.sig);

// if no extension was found for bytes4(0) return don't revert
if (msg.sig == bytes4(0) && extension == address(0)) return;
if (msg.sig == bytes4(0) && extension == address(0)) return "";

// if no extension was found, revert
// if no extension was found for other function selectors, revert
if (extension == address(0))
revert NoExtensionFoundForFunctionSelector(msg.sig);

// solhint-disable no-inline-assembly
// if the extension was found, call the extension with the msg.data
// appended with bytes20(address) and bytes32(msg.value)
assembly {
calldatacopy(0, 0, calldatasize())

// The msg.sender address is shifted to the left by 12 bytes to remove the padding
// Then the address without padding is stored right after the calldata
mstore(calldatasize(), shl(96, caller()))

// The msg.value is stored right after the calldata + msg.sender
mstore(add(calldatasize(), 20), callvalue())

// Add 52 bytes for the msg.sender and msg.value appended at the end of the calldata
let success := call(
gas(),
extension,
0,
0,
add(calldatasize(), 52),
0,
0
)

// Copy the returned data
returndatacopy(0, 0, returndatasize())

switch success
// call returns 0 on failed calls
case 0 {
revert(0, returndatasize())
}
default {
return(0, returndatasize())
bytes memory calldataWithCallerInfos = abi.encodePacked(
callData,
msg.sender,
msg.value
);

(bool success, bytes memory result) = extension.call(
calldataWithCallerInfos
);

if (success) {
return result;
} else {
// `result` -> first word in memory where the length of `result` is stored
// `add(result, 32)` -> next word in memory is where the `result` data starts
assembly {
revert(add(result, 32), mload(result))
}
}
}
Expand Down

0 comments on commit ada5adf

Please sign in to comment.