Skip to content

Commit

Permalink
CS5.1 - Disallow callbacks at the start of a script's callback functi…
Browse files Browse the repository at this point in the history
…on (#92)

Brings over the exact same
[fixes](compound-finance/quark#229) to the same
files from the `quark` repo.
  • Loading branch information
kevincheng96 authored Oct 22, 2024
1 parent 08f7025 commit 6d738f6
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/UniswapFlashLoan.sol
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ contract UniswapFlashLoan is IUniswapV3FlashCallback, QuarkScript {
* @param data FlashLoanCallbackPayload encoded to bytes passed from IUniswapV3Pool.flash(); contains scripts info to execute before repaying the flash loan
*/
function uniswapV3FlashCallback(uint256 fee0, uint256 fee1, bytes calldata data) external {
disallowCallback();

FlashLoanCallbackPayload memory input = abi.decode(data, (FlashLoanCallbackPayload));
IUniswapV3Pool pool =
IUniswapV3Pool(PoolAddress.computeAddress(UniswapFactoryAddress.getAddress(), input.poolKey));
Expand Down
2 changes: 2 additions & 0 deletions src/UniswapFlashSwapExactOut.sol
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ contract UniswapFlashSwapExactOut is IUniswapV3SwapCallback, QuarkScript {
* @param data FlashSwap encoded to bytes passed from UniswapV3Pool.swap(); contains script info to execute (possibly with checks) before returning the owed amount
*/
function uniswapV3SwapCallback(int256 amount0Delta, int256 amount1Delta, bytes calldata data) external {
disallowCallback();

FlashSwapExactOutInput memory input = abi.decode(data, (FlashSwapExactOutInput));
IUniswapV3Pool pool =
IUniswapV3Pool(PoolAddress.computeAddress(UniswapFactoryAddress.getAddress(), input.poolKey));
Expand Down
43 changes: 43 additions & 0 deletions test/UniswapFlashLoan.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,49 @@ contract UniswapFlashLoanTest is Test {
assertEq(IComet(comet).borrowBalanceOf(address(wallet)), 1000e6);
}

function testRevertsForSecondCallback() public {
vm.pauseGasMetering();
QuarkWallet wallet = QuarkWallet(factory.create(alice, address(0)));
address[] memory callContracts = new address[](1);
bytes[] memory callDatas = new bytes[](1);
// Call into the wallet and try to execute the fallback function again using the callback mechanism
callContracts[0] = address(wallet);
callDatas[0] = abi.encodeWithSelector(
Ethcall.run.selector,
address(wallet),
abi.encodeCall(UniswapFlashLoan.uniswapV3FlashCallback, (100, 500, bytes(""))),
0
);
QuarkWallet.QuarkOperation memory op = new QuarkOperationHelper().newBasicOpWithCalldata(
wallet,
uniswapFlashLoan,
abi.encodeWithSelector(
UniswapFlashLoan.run.selector,
UniswapFlashLoan.UniswapFlashLoanPayload({
token0: USDC,
token1: DAI,
fee: 100,
amount0: 1000e6,
amount1: 0,
callContract: multicallAddress,
callData: abi.encodeWithSelector(Multicall.run.selector, callContracts, callDatas)
})
),
ScriptType.ScriptAddress
);
bytes memory signature = new SignatureHelper().signOp(alicePrivateKey, wallet, op);
vm.resumeGasMetering();
vm.expectRevert(
abi.encodeWithSelector(
Multicall.MulticallError.selector,
0,
callContracts[0],
abi.encodeWithSelector(QuarkWallet.NoActiveCallback.selector)
)
);
wallet.executeQuarkOperation(op, signature);
}

function testRevertsForInvalidCaller() public {
vm.pauseGasMetering();
QuarkWallet wallet = QuarkWallet(factory.create(alice, address(0)));
Expand Down
45 changes: 45 additions & 0 deletions test/UniswapFlashSwapExactOut.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,51 @@ contract UniswapFlashSwapExactOutTest is Test {
assertEq(IComet(comet).borrowBalanceOf(address(wallet)), borrowAmountOfUSDC);
}

function testRevertsForSecondCallback() public {
vm.pauseGasMetering();
QuarkWallet wallet = QuarkWallet(factory.create(alice, address(0)));
// Set up some funds for test
deal(WETH, address(wallet), 10 ether);
address[] memory callContracts = new address[](1);
bytes[] memory callDatas = new bytes[](1);
// Call into the wallet and try to execute the fallback function again using the callback mechanism
callContracts[0] = address(wallet);
callDatas[0] = abi.encodeWithSelector(
Ethcall.run.selector,
address(wallet),
abi.encodeCall(UniswapFlashSwapExactOut.uniswapV3SwapCallback, (100, 500, bytes(""))),
0
);
QuarkWallet.QuarkOperation memory op = new QuarkOperationHelper().newBasicOpWithCalldata(
wallet,
uniswapFlashSwapExactOut,
abi.encodeWithSelector(
UniswapFlashSwapExactOut.run.selector,
UniswapFlashSwapExactOut.UniswapFlashSwapExactOutPayload({
tokenOut: WETH,
tokenIn: USDC,
fee: 500,
amountOut: 1 ether,
sqrtPriceLimitX96: 0,
callContract: multicallAddress,
callData: abi.encodeWithSelector(Multicall.run.selector, callContracts, callDatas)
})
),
ScriptType.ScriptAddress
);
bytes memory signature = new SignatureHelper().signOp(alicePrivateKey, wallet, op);
vm.resumeGasMetering();
vm.expectRevert(
abi.encodeWithSelector(
Multicall.MulticallError.selector,
0,
callContracts[0],
abi.encodeWithSelector(QuarkWallet.NoActiveCallback.selector)
)
);
wallet.executeQuarkOperation(op, signature);
}

function testInvalidCallerFlashSwap() public {
vm.pauseGasMetering();
QuarkWallet wallet = QuarkWallet(factory.create(alice, address(0)));
Expand Down

0 comments on commit 6d738f6

Please sign in to comment.