Skip to content
This repository has been archived by the owner on Jul 2, 2024. It is now read-only.

Commit

Permalink
Nose/fix rollup audit (#183)
Browse files Browse the repository at this point in the history
* convert to action enum

* cheaper calldata allocation

* magic collision byte

* compiler destructuring fix

* enforce 1 byte fallback sig

* fix previous commit

* minor fixes

* changs magic byte + add collision assertions

* enforce and test invoke id byte
  • Loading branch information
test9955667 authored May 17, 2023
1 parent b6d8380 commit dba7e8a
Show file tree
Hide file tree
Showing 5 changed files with 227 additions and 165 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ interface IMultiInvokerRollup is IMultiInvoker {

error MultiInvokerRollupAddressIndexOutOfBoundsError();
error MultiInvokerRollupInvalidUint256LengthError();
error MultiInvokerRollupMissingMagicByteError();

function addressCache(uint256 index) external view returns(address);
function addressLookup(address value) external view returns(uint256 index);
Expand Down
154 changes: 97 additions & 57 deletions packages/perennial/contracts/multiinvoker/MultiInvokerRollup.sol
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pragma solidity 0.8.17;
import "./MultiInvoker.sol";
import "../interfaces/IMultiInvokerRollup.sol";


/**
* @title MultiInvokerRollup
* @notice A calldata-optimized implementation of the Perennial MultiInvoker
Expand Down Expand Up @@ -41,6 +42,10 @@ contract MultiInvokerRollup is IMultiInvokerRollup, MultiInvoker {
/// @dev Index lookup of above array for constructing calldata
mapping(address => uint256) public addressLookup;

/// @dev magic byte to prepend to calldata for the fallback.
/// Prevents public fns from being called by arbitrary fallback data
uint8 public constant INVOKE_ID = 73;

/**
* @notice Constructs the contract
* @param usdc_ The USDC token contract address
Expand All @@ -57,12 +62,15 @@ contract MultiInvokerRollup is IMultiInvokerRollup, MultiInvoker {
/**
* @notice This function serves exactly the same as invoke(Invocation[] memory invocations),
* but includes logic to handle the highly packed calldata
* @dev Fallback eliminates the need to include function sig in calldata
* @dev Fallback eliminates need for 4 byte sig. MUST prepend INVOKE_ID to calldata
* @param input Packed data to pass to invoke logic
* @return required no-op
*/
fallback (bytes calldata input) external returns (bytes memory) {
_decodeFallbackAndInvoke(input);
PTR memory ptr;
if (_readUint8(input, ptr) != INVOKE_ID) revert MultiInvokerRollupMissingMagicByteError();

_decodeFallbackAndInvoke(input, ptr);
return "";
}

Expand All @@ -75,81 +83,112 @@ contract MultiInvokerRollup is IMultiInvokerRollup, MultiInvoker {
* [2:length] => current encoded type (see individual type decoding functions)
* @param input Packed data to pass to invoke logic
*/
function _decodeFallbackAndInvoke(bytes calldata input) internal {
PTR memory ptr;
function _decodeFallbackAndInvoke(bytes calldata input, PTR memory ptr) internal {

while (ptr.pos < input.length) {
uint8 action = _readUint8(input, ptr);
PerennialAction action = PerennialAction(_readUint8(input, ptr));

if (action == PerennialAction.DEPOSIT) {
address account = _readAndCacheAddress(input, ptr);
address product = _readAndCacheAddress(input, ptr);
UFixed18 amount = _readUFixed18(input, ptr);

if (action == 1) { // DEPOSIT
(address account, address product, UFixed18 amount) =
(_readAndCacheAddress(input, ptr), _readAndCacheAddress(input, ptr), _readUFixed18(input, ptr));
_deposit(account, IProduct(product), amount);

} else if (action == 2) { // WITHDRAW
(address receiver, address product, UFixed18 amount) =
(_readAndCacheAddress(input, ptr), _readAndCacheAddress(input, ptr), _readUFixed18(input, ptr));
} else if (action == PerennialAction.WITHDRAW) {
address receiver = _readAndCacheAddress(input, ptr);
address product = _readAndCacheAddress(input, ptr);
UFixed18 amount = _readUFixed18(input, ptr);

_withdraw(receiver, IProduct(product), amount);

} else if (action == 3) { // OPEN_TAKE
(address product, UFixed18 amount) = (_readAndCacheAddress(input, ptr), _readUFixed18(input, ptr));
} else if (action == PerennialAction.OPEN_TAKE) {
address product = _readAndCacheAddress(input, ptr);
UFixed18 amount = _readUFixed18(input, ptr);

_openTake(IProduct(product), amount);

} else if (action == 4) { // CLOSE_TAKE
(address product, UFixed18 amount) = (_readAndCacheAddress(input, ptr), _readUFixed18(input, ptr));
} else if (action == PerennialAction.CLOSE_TAKE) {
address product = _readAndCacheAddress(input, ptr);
UFixed18 amount = _readUFixed18(input, ptr);

_closeTake(IProduct(product), amount);

} else if (action == 5) { // OPEN_MAKE
(address product, UFixed18 amount) = (_readAndCacheAddress(input, ptr), _readUFixed18(input, ptr));
} else if (action == PerennialAction.OPEN_MAKE) {
address product = _readAndCacheAddress(input, ptr);
UFixed18 amount = _readUFixed18(input, ptr);

_openMake(IProduct(product), amount);

} else if (action == 6) { // CLOSE_MAKE
(address product, UFixed18 amount) = (_readAndCacheAddress(input, ptr), _readUFixed18(input, ptr));
} else if (action == PerennialAction.CLOSE_MAKE) {
address product = _readAndCacheAddress(input, ptr);
UFixed18 amount = _readUFixed18(input, ptr);

_closeMake(IProduct(product), amount);

} else if (action == 7) { // CLAIM
(address product, uint256[] memory programIds) =
(_readAndCacheAddress(input, ptr), _readUint256Array(input, ptr));
} else if (action == PerennialAction.CLAIM) {
address product = _readAndCacheAddress(input, ptr);
uint256[] memory programIds = _readUint256Array(input, ptr);

_claim(IProduct(product), programIds);

} else if (action == 8) { // WRAP
(address receiver, UFixed18 amount) = (_readAndCacheAddress(input, ptr), _readUFixed18(input, ptr));
} else if (action == PerennialAction.WRAP) {
address receiver = _readAndCacheAddress(input, ptr);
UFixed18 amount = _readUFixed18(input, ptr);

_wrap(receiver, amount);

} else if (action == 9) { // UNWRAP
(address receiver, UFixed18 amount) = (_readAndCacheAddress(input, ptr), _readUFixed18(input, ptr));
} else if (action == PerennialAction.UNWRAP) {
address receiver = _readAndCacheAddress(input, ptr);
UFixed18 amount = _readUFixed18(input, ptr);

_unwrap(receiver, amount);

} else if (action == 10) { // WRAP_AND_DEPOSIT
(address account, address product, UFixed18 amount) =
(_readAndCacheAddress(input, ptr), _readAndCacheAddress(input, ptr), _readUFixed18(input, ptr));
} else if (action == PerennialAction.WRAP_AND_DEPOSIT) {
address account = _readAndCacheAddress(input, ptr);
address product = _readAndCacheAddress(input, ptr);
UFixed18 amount = _readUFixed18(input, ptr);

_wrapAndDeposit(account, IProduct(product), amount);

} else if (action == 11) { // WITHDRAW_AND_UNWRAP
(address receiver, address product, UFixed18 amount) =
(_readAndCacheAddress(input, ptr), _readAndCacheAddress(input, ptr), _readUFixed18(input, ptr));
} else if (action == PerennialAction.WITHDRAW_AND_UNWRAP) {
address receiver = _readAndCacheAddress(input, ptr);
address product = _readAndCacheAddress(input, ptr);
UFixed18 amount = _readUFixed18(input, ptr);

_withdrawAndUnwrap(receiver, IProduct(product), amount);

} else if (action == 12) { // VAULT_DEPOSIT
(address depositer, address vault, UFixed18 amount) =
(_readAndCacheAddress(input, ptr), _readAndCacheAddress(input, ptr), _readUFixed18(input, ptr));
} else if (action == PerennialAction.VAULT_DEPOSIT) {
address depositer = _readAndCacheAddress(input, ptr);
address vault = _readAndCacheAddress(input, ptr);
UFixed18 amount = _readUFixed18(input, ptr);

_vaultDeposit(depositer, IPerennialVault(vault), amount);

} else if (action == 13) { // VAULT_REDEEM
(address vault, UFixed18 shares) = (_readAndCacheAddress(input, ptr), _readUFixed18(input, ptr));
} else if (action == PerennialAction.VAULT_REDEEM) {
address vault = _readAndCacheAddress(input, ptr);
UFixed18 shares = _readUFixed18(input, ptr);

_vaultRedeem(IPerennialVault(vault), shares);

} else if (action == 14) { // VAULT_CLAIM
(address owner, address vault) = (_readAndCacheAddress(input, ptr), _readAndCacheAddress(input, ptr));
} else if (action == PerennialAction.VAULT_CLAIM) {
address owner = _readAndCacheAddress(input, ptr);
address vault = _readAndCacheAddress(input, ptr);

_vaultClaim(IPerennialVault(vault), owner);

} else if (action == 15) { // VAULT_WRAP_AND_DEPOSIT
(address account, address vault, UFixed18 amount) =
(_readAndCacheAddress(input, ptr), _readAndCacheAddress(input, ptr), _readUFixed18(input, ptr));
} else if (action == PerennialAction.VAULT_WRAP_AND_DEPOSIT) {
address account = _readAndCacheAddress(input, ptr);
address vault = _readAndCacheAddress(input, ptr);
UFixed18 amount = _readUFixed18(input, ptr);

_vaultWrapAndDeposit(account, IPerennialVault(vault), amount);
} else if (action == 16) { // CHARGE_FEE
(address receiver, UFixed18 amount, bool wrapped) =
(_readAndCacheAddress(input, ptr), _readUFixed18(input, ptr), _readBool(input, ptr));

} else if (action == PerennialAction.CHARGE_FEE) {
address receiver = _readAndCacheAddress(input, ptr);
UFixed18 amount = _readUFixed18(input, ptr);
bool wrapped = _readBool(input, ptr);

_chargeFee(receiver, amount, wrapped);
}
}
Expand Down Expand Up @@ -183,7 +222,7 @@ contract MultiInvokerRollup is IMultiInvokerRollup, MultiInvoker {

_cacheAddress(result);
} else {
uint256 idx = _bytesToUint256(input[ptr.pos:ptr.pos + len]);
uint256 idx = _bytesToUint256(input, ptr.pos, len);
ptr.pos += len;

result = _lookupAddress(idx);
Expand Down Expand Up @@ -244,7 +283,7 @@ contract MultiInvokerRollup is IMultiInvokerRollup, MultiInvoker {
* @return result The decoded uint8 length
*/
function _readUint8(bytes calldata input, PTR memory ptr) private pure returns (uint8 result) {
result = _bytesToUint8(input[ptr.pos:ptr.pos + UINT8_LENGTH]);
result = _bytesToUint8(input, ptr.pos);
ptr.pos += UINT8_LENGTH;
}

Expand All @@ -258,7 +297,7 @@ contract MultiInvokerRollup is IMultiInvokerRollup, MultiInvoker {
uint8 len = _readUint8(input, ptr);
if (len > UINT256_LENGTH) revert MultiInvokerRollupInvalidUint256LengthError();

result = _bytesToUint256(input[ptr.pos:ptr.pos + len]);
result = _bytesToUint256(input, ptr.pos, len);
ptr.pos += len;
}

Expand All @@ -267,9 +306,12 @@ contract MultiInvokerRollup is IMultiInvokerRollup, MultiInvoker {
* @param input 1 byte slice to convert to uint8 to decode lengths
* @return result The uint8 representation of input
*/
function _bytesToUint8(bytes memory input) private pure returns (uint8 result) {
function _bytesToUint8(bytes calldata input, uint256 pos) private pure returns (uint8 result) {
assembly {
result := mload(add(input, UINT8_LENGTH))
// 1) load calldata into temp starting at ptr position
let temp := calldataload(add(input.offset, pos))
// 2) shifts the calldata such that only the first byte is stored in result
result := shr(mul(8, sub(UINT256_LENGTH, UINT8_LENGTH)), temp)
}
}

Expand All @@ -291,14 +333,12 @@ contract MultiInvokerRollup is IMultiInvokerRollup, MultiInvoker {
* @param input The bytes to convert to uint256
* @return result The resulting uint256
*/
function _bytesToUint256(bytes memory input) private pure returns (uint256 result) {
uint256 len = input.length;

function _bytesToUint256(bytes calldata input, uint256 pos, uint256 len) private pure returns (uint256 result) {
assembly {
result := mload(add(input, UINT256_LENGTH))
// 1) load the calldata into result starting at the ptr position
result := calldataload(add(input.offset, pos))
// 2) shifts the calldata such that only the next length of bytes specified by `len` populates the uint256 result
result := shr(mul(8, sub(UINT256_LENGTH, len)), result)
}

// readable right shift to change right padding of mload to left padding
result >>= (UINT256_LENGTH - len) * 8;
}
}
Loading

0 comments on commit dba7e8a

Please sign in to comment.