From 86b2e9451eab108d7e28f5060db6a8a5470d573a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=9Fingen?= Date: Wed, 4 Sep 2024 15:30:08 +0100 Subject: [PATCH 1/2] fix: Prevent arbitrary calls to leverage zappers --- .../Zappers/Interfaces/IFlashLoanProvider.sol | 10 +- .../Zappers/Interfaces/ILeverageZapper.sol | 5 +- contracts/src/Zappers/LeverageLSTZapper.sol | 18 +- contracts/src/Zappers/LeverageWETHZapper.sol | 18 +- .../Modules/FlashLoans/BalancerFlashLoan.sol | 33 +- contracts/src/test/zapperLeverage.t.sol | 476 +++++++++++++++--- 6 files changed, 443 insertions(+), 117 deletions(-) diff --git a/contracts/src/Zappers/Interfaces/IFlashLoanProvider.sol b/contracts/src/Zappers/Interfaces/IFlashLoanProvider.sol index 406b39eb..d773ac30 100644 --- a/contracts/src/Zappers/Interfaces/IFlashLoanProvider.sol +++ b/contracts/src/Zappers/Interfaces/IFlashLoanProvider.sol @@ -13,11 +13,7 @@ interface IFlashLoanProvider { LeverDownTrove } - function makeFlashLoan( - IERC20 _token, - uint256 _amount, - IFlashLoanReceiver _caller, - Operation _operation, - bytes calldata userData - ) external; + function receiver() external view returns (IFlashLoanReceiver); + + function makeFlashLoan(IERC20 _token, uint256 _amount, Operation _operation, bytes calldata userData) external; } diff --git a/contracts/src/Zappers/Interfaces/ILeverageZapper.sol b/contracts/src/Zappers/Interfaces/ILeverageZapper.sol index 5f66b07e..2dbf6810 100644 --- a/contracts/src/Zappers/Interfaces/ILeverageZapper.sol +++ b/contracts/src/Zappers/Interfaces/ILeverageZapper.sol @@ -2,6 +2,7 @@ pragma solidity ^0.8.0; +import "./IFlashLoanProvider.sol"; import "./IExchange.sol"; interface ILeverageZapper { @@ -33,7 +34,9 @@ interface ILeverageZapper { uint256 minBoldAmount; } - function exchange() external returns (IExchange); + function flashLoanProvider() external view returns (IFlashLoanProvider); + + function exchange() external view returns (IExchange); function openLeveragedTroveWithRawETH(OpenLeveragedTroveParams calldata _params) external payable; diff --git a/contracts/src/Zappers/LeverageLSTZapper.sol b/contracts/src/Zappers/LeverageLSTZapper.sol index e5b7ccf9..1c0e3faf 100644 --- a/contracts/src/Zappers/LeverageLSTZapper.sol +++ b/contracts/src/Zappers/LeverageLSTZapper.sol @@ -64,11 +64,7 @@ contract LeverageLSTZapper is GasCompZapper, IFlashLoanReceiver, ILeverageZapper // Flash loan coll flashLoanProvider.makeFlashLoan( - collTokenCached, - _params.flashLoanAmount, - IFlashLoanReceiver(address(this)), - IFlashLoanProvider.Operation.OpenTrove, - abi.encode(_params) + collTokenCached, _params.flashLoanAmount, IFlashLoanProvider.Operation.OpenTrove, abi.encode(_params) ); } @@ -120,11 +116,7 @@ contract LeverageLSTZapper is GasCompZapper, IFlashLoanReceiver, ILeverageZapper // Flash loan coll flashLoanProvider.makeFlashLoan( - collToken, - _params.flashLoanAmount, - IFlashLoanReceiver(address(this)), - IFlashLoanProvider.Operation.LeverUpTrove, - abi.encode(_params) + collToken, _params.flashLoanAmount, IFlashLoanProvider.Operation.LeverUpTrove, abi.encode(_params) ); } @@ -161,11 +153,7 @@ contract LeverageLSTZapper is GasCompZapper, IFlashLoanReceiver, ILeverageZapper // Flash loan coll flashLoanProvider.makeFlashLoan( - collToken, - _params.flashLoanAmount, - IFlashLoanReceiver(address(this)), - IFlashLoanProvider.Operation.LeverDownTrove, - abi.encode(_params) + collToken, _params.flashLoanAmount, IFlashLoanProvider.Operation.LeverDownTrove, abi.encode(_params) ); } diff --git a/contracts/src/Zappers/LeverageWETHZapper.sol b/contracts/src/Zappers/LeverageWETHZapper.sol index 2ccfb486..ab3478fb 100644 --- a/contracts/src/Zappers/LeverageWETHZapper.sol +++ b/contracts/src/Zappers/LeverageWETHZapper.sol @@ -55,11 +55,7 @@ contract LeverageWETHZapper is WETHZapper, IFlashLoanReceiver, ILeverageZapper { // Flash loan coll flashLoanProvider.makeFlashLoan( - WETHCached, - _params.flashLoanAmount, - IFlashLoanReceiver(address(this)), - IFlashLoanProvider.Operation.OpenTrove, - abi.encode(_params) + WETHCached, _params.flashLoanAmount, IFlashLoanProvider.Operation.OpenTrove, abi.encode(_params) ); } @@ -112,11 +108,7 @@ contract LeverageWETHZapper is WETHZapper, IFlashLoanReceiver, ILeverageZapper { // Flash loan coll flashLoanProvider.makeFlashLoan( - WETH, - _params.flashLoanAmount, - IFlashLoanReceiver(address(this)), - IFlashLoanProvider.Operation.LeverUpTrove, - abi.encode(_params) + WETH, _params.flashLoanAmount, IFlashLoanProvider.Operation.LeverUpTrove, abi.encode(_params) ); } @@ -153,11 +145,7 @@ contract LeverageWETHZapper is WETHZapper, IFlashLoanReceiver, ILeverageZapper { // Flash loan coll flashLoanProvider.makeFlashLoan( - WETH, - _params.flashLoanAmount, - IFlashLoanReceiver(address(this)), - IFlashLoanProvider.Operation.LeverDownTrove, - abi.encode(_params) + WETH, _params.flashLoanAmount, IFlashLoanProvider.Operation.LeverDownTrove, abi.encode(_params) ); } diff --git a/contracts/src/Zappers/Modules/FlashLoans/BalancerFlashLoan.sol b/contracts/src/Zappers/Modules/FlashLoans/BalancerFlashLoan.sol index 9f1915b5..3920f500 100644 --- a/contracts/src/Zappers/Modules/FlashLoans/BalancerFlashLoan.sol +++ b/contracts/src/Zappers/Modules/FlashLoans/BalancerFlashLoan.sol @@ -16,14 +16,9 @@ contract BalancerFlashLoan is IFlashLoanRecipient, IFlashLoanProvider { using SafeERC20 for IERC20; IVault private constant vault = IVault(0xBA12222222228d8Ba445958a75a0704d566BF2C8); + IFlashLoanReceiver public receiver; - function makeFlashLoan( - IERC20 _token, - uint256 _amount, - IFlashLoanReceiver _caller, // TODO: should it always be msg.sender? - Operation _operation, - bytes calldata _params - ) external { + function makeFlashLoan(IERC20 _token, uint256 _amount, Operation _operation, bytes calldata _params) external { IERC20[] memory tokens = new IERC20[](1); tokens[0] = _token; uint256[] memory amounts = new uint256[](1); @@ -34,19 +29,22 @@ contract BalancerFlashLoan is IFlashLoanRecipient, IFlashLoanProvider { if (_operation == Operation.OpenTrove) { ILeverageZapper.OpenLeveragedTroveParams memory openTroveParams = abi.decode(_params, (ILeverageZapper.OpenLeveragedTroveParams)); - userData = abi.encode(_caller, _operation, openTroveParams); + userData = abi.encode(_operation, openTroveParams); } else if (_operation == Operation.LeverUpTrove) { ILeverageZapper.LeverUpTroveParams memory leverUpTroveParams = abi.decode(_params, (ILeverageZapper.LeverUpTroveParams)); - userData = abi.encode(_caller, _operation, leverUpTroveParams); + userData = abi.encode(_operation, leverUpTroveParams); } else if (_operation == Operation.LeverDownTrove) { ILeverageZapper.LeverDownTroveParams memory leverDownTroveParams = abi.decode(_params, (ILeverageZapper.LeverDownTroveParams)); - userData = abi.encode(_caller, _operation, leverDownTroveParams); + userData = abi.encode(_operation, leverDownTroveParams); } else { revert("LZ: Wrong Operation"); } + // This will be used by the callback below no + receiver = IFlashLoanReceiver(msg.sender); + vault.flashLoan(this, tokens, amounts, userData); } @@ -57,16 +55,16 @@ contract BalancerFlashLoan is IFlashLoanRecipient, IFlashLoanProvider { bytes calldata userData ) external override { require(msg.sender == address(vault), "Caller is not Vault"); + require(address(receiver) != address(0), "Flash loan not properly initiated"); - // decode receiver and operation - IFlashLoanReceiver receiver = IFlashLoanReceiver(abi.decode(userData[0:32], (address))); - Operation operation = abi.decode(userData[32:64], (Operation)); + // decode and operation + Operation operation = abi.decode(userData[0:32], (Operation)); if (operation == Operation.OpenTrove) { // Open // decode params ILeverageZapper.OpenLeveragedTroveParams memory openTroveParams = - abi.decode(userData[64:], (ILeverageZapper.OpenLeveragedTroveParams)); + abi.decode(userData[32:], (ILeverageZapper.OpenLeveragedTroveParams)); // Flash loan minus fees uint256 effectiveFlashLoanAmount = amounts[0] - feeAmounts[0]; // We send only effective flash loan, keeping fees here @@ -77,7 +75,7 @@ contract BalancerFlashLoan is IFlashLoanRecipient, IFlashLoanProvider { // Lever up // decode params ILeverageZapper.LeverUpTroveParams memory leverUpTroveParams = - abi.decode(userData[64:], (ILeverageZapper.LeverUpTroveParams)); + abi.decode(userData[32:], (ILeverageZapper.LeverUpTroveParams)); // Flash loan minus fees uint256 effectiveFlashLoanAmount = amounts[0] - feeAmounts[0]; // We send only effective flash loan, keeping fees here @@ -88,7 +86,7 @@ contract BalancerFlashLoan is IFlashLoanRecipient, IFlashLoanProvider { // Lever down // decode params ILeverageZapper.LeverDownTroveParams memory leverDownTroveParams = - abi.decode(userData[64:], (ILeverageZapper.LeverDownTroveParams)); + abi.decode(userData[32:], (ILeverageZapper.LeverDownTroveParams)); // Flash loan minus fees uint256 effectiveFlashLoanAmount = amounts[0] - feeAmounts[0]; // We send only effective flash loan, keeping fees here @@ -101,5 +99,8 @@ contract BalancerFlashLoan is IFlashLoanRecipient, IFlashLoanProvider { // Return flash loan tokens[0].safeTransfer(address(vault), amounts[0] + feeAmounts[0]); + + // Reset receiver + receiver = IFlashLoanReceiver(address(0)); } } diff --git a/contracts/src/test/zapperLeverage.t.sol b/contracts/src/test/zapperLeverage.t.sol index e6bb6dc0..11e4d158 100644 --- a/contracts/src/test/zapperLeverage.t.sol +++ b/contracts/src/test/zapperLeverage.t.sol @@ -11,6 +11,8 @@ import "../Zappers/Modules/Exchanges/UniswapV3/IUniswapV3Pool.sol"; import "../Zappers/Modules/Exchanges/UniV3Exchange.sol"; import "../Zappers/Modules/Exchanges/UniswapV3/INonfungiblePositionManager.sol"; import "../Zappers/Modules/Exchanges/UniswapV3/IUniswapV3Factory.sol"; +import "../Zappers/Interfaces/IFlashLoanProvider.sol"; +import "../Zappers/Modules/FlashLoans/Balancer/vault/IVault.sol"; import "./Utils/UniPriceConverter.sol"; contract ZapperLeverageLSTMainnet is DevTestSetup { @@ -28,6 +30,16 @@ contract ZapperLeverageLSTMainnet is DevTestSetup { TestDeployer.LiquityContracts[] contractsArray; + struct OpenTroveVars { + uint256 price; + uint256 flashLoanAmount; + uint256 expectedBoldAmount; + uint256 maxNetDebt; + uint256 effectiveBoldAmount; + uint256 value; + uint256 troveId; + } + struct LeverVars { uint256 price; uint256 currentCR; @@ -46,6 +58,7 @@ contract ZapperLeverageLSTMainnet is DevTestSetup { uint256 initialDebt; uint256 newLeverageRatio; uint256 resultingCollateralRatio; + uint256 boldBalanceBefore; uint256 ethBalanceBefore; uint256 collBalanceBefore; uint256 flashLoanAmount; @@ -215,21 +228,34 @@ contract ZapperLeverageLSTMainnet is DevTestSetup { IPriceFeed _priceFeed, bool _lst ) internal returns (uint256) { - (uint256 price,) = _priceFeed.fetchPrice(); + return openLeveragedTroveWithIndex(_leverageZapper, 0, _collAmount, _leverageRatio, _priceFeed, _lst); + } + + function openLeveragedTroveWithIndex( + ILeverageZapper _leverageZapper, + uint256 _index, + uint256 _collAmount, + uint256 _leverageRatio, + IPriceFeed _priceFeed, + bool _lst + ) internal returns (uint256) { + OpenTroveVars memory vars; + (vars.price,) = _priceFeed.fetchPrice(); IExchange exchange = _leverageZapper.exchange(); // This should be done in the frontend - uint256 flashLoanAmount = _collAmount * (_leverageRatio - DECIMAL_PRECISION) / DECIMAL_PRECISION; - uint256 expectedBoldAmount = flashLoanAmount * price / DECIMAL_PRECISION; - uint256 maxNetDebt = expectedBoldAmount * 105 / 100; // slippage - uint256 effectiveBoldAmount = exchange.getBoldAmountToSwap(expectedBoldAmount, maxNetDebt, flashLoanAmount); + vars.flashLoanAmount = _collAmount * (_leverageRatio - DECIMAL_PRECISION) / DECIMAL_PRECISION; + vars.expectedBoldAmount = vars.flashLoanAmount * vars.price / DECIMAL_PRECISION; + vars.maxNetDebt = vars.expectedBoldAmount * 105 / 100; // slippage + vars.effectiveBoldAmount = + exchange.getBoldAmountToSwap(vars.expectedBoldAmount, vars.maxNetDebt, vars.flashLoanAmount); ILeverageZapper.OpenLeveragedTroveParams memory params = ILeverageZapper.OpenLeveragedTroveParams({ owner: A, - ownerIndex: 0, + ownerIndex: _index, collAmount: _collAmount, - flashLoanAmount: flashLoanAmount, - boldAmount: effectiveBoldAmount, + flashLoanAmount: vars.flashLoanAmount, + boldAmount: vars.effectiveBoldAmount, upperHint: 0, lowerHint: 0, annualInterestRate: 5e16, @@ -239,12 +265,12 @@ contract ZapperLeverageLSTMainnet is DevTestSetup { receiver: address(0) }); vm.startPrank(A); - uint256 value = _lst ? ETH_GAS_COMPENSATION : _collAmount + ETH_GAS_COMPENSATION; - _leverageZapper.openLeveragedTroveWithRawETH{value: value}(params); - uint256 troveId = addressToTroveId(A); + vars.value = _lst ? ETH_GAS_COMPENSATION : _collAmount + ETH_GAS_COMPENSATION; + _leverageZapper.openLeveragedTroveWithRawETH{value: vars.value}(params); + vars.troveId = addressToTroveId(A, _index); vm.stopPrank(); - return troveId; + return vars.troveId; } function testCanOpenTroveWithCurve() external { @@ -265,6 +291,7 @@ contract ZapperLeverageLSTMainnet is DevTestSetup { uint256 leverageRatio = 2e18; uint256 resultingCollateralRatio = _leverageZapper.leverageRatioToCollateralRatio(leverageRatio); + uint256 boldBalanceBefore = boldToken.balanceOf(A); uint256 ethBalanceBefore = A.balance; uint256 collBalanceBefore = contractsArray[_branch].collToken.balanceOf(A); @@ -299,7 +326,7 @@ contract ZapperLeverageLSTMainnet is DevTestSetup { "Wrong CR" ); // token balances - assertApproxEqAbs(boldToken.balanceOf(A), 0, 15, "BOLD bal mismatch"); + assertEq(boldToken.balanceOf(A), boldBalanceBefore, "BOLD bal mismatch"); if (lst) { assertEq(A.balance, ethBalanceBefore - ETH_GAS_COMPENSATION, "ETH bal mismatch"); assertEq( @@ -342,31 +369,24 @@ contract ZapperLeverageLSTMainnet is DevTestSetup { vm.expectRevert("LZ: Caller not FlashLoan provider"); IFlashLoanReceiver(address(_leverageZapper)).receiveFlashLoanOnOpenLeveragedTrove(params, 10 ether); vm.stopPrank(); - } - // Lever up - /* - function leverUpTroveWithCurve(uint256 _troveId, uint256 _leverageRatio) internal returns (uint256) { - return leverUpTrove(leverageZapperCurve, _troveId, _leverageRatio); + // Check receiver is back to zero + assertEq(address(_leverageZapper.flashLoanProvider().receiver()), address(0), "Receiver should be zero"); } - function leverUpTroveWithUniV3(uint256 _troveId, uint256 _leverageRatio) internal returns (uint256) { - return leverUpTrove(leverageZapperUniV3, _troveId, _leverageRatio); - } - */ + // Lever up - function leverUpTrove( + function _getLeverUpFlashLoanAndBoldAmount( ILeverageZapper _leverageZapper, uint256 _troveId, uint256 _leverageRatio, ITroveManager _troveManager, IPriceFeed _priceFeed - ) internal returns (uint256) { - LeverVars memory vars; - (vars.price,) = _priceFeed.fetchPrice(); + ) internal returns (uint256, uint256) { IExchange exchange = _leverageZapper.exchange(); - // This should be done in the frontend + LeverVars memory vars; + (vars.price,) = _priceFeed.fetchPrice(); vars.currentCR = _troveManager.getCurrentICR(_troveId, vars.price); vars.currentLR = _leverageZapper.leverageRatioToCollateralRatio(vars.currentCR); assertGt(_leverageRatio, vars.currentLR, "Leverage ratio should increase"); @@ -378,17 +398,31 @@ contract ZapperLeverageLSTMainnet is DevTestSetup { vars.effectiveBoldAmount = exchange.getBoldAmountToSwap(vars.expectedBoldAmount, vars.maxNetDebtIncrease, vars.flashLoanAmount); + return (vars.flashLoanAmount, vars.effectiveBoldAmount); + } + + function leverUpTrove( + ILeverageZapper _leverageZapper, + uint256 _troveId, + uint256 _leverageRatio, + ITroveManager _troveManager, + IPriceFeed _priceFeed + ) internal returns (uint256) { + // This should be done in the frontend + (uint256 flashLoanAmount, uint256 effectiveBoldAmount) = + _getLeverUpFlashLoanAndBoldAmount(_leverageZapper, _troveId, _leverageRatio, _troveManager, _priceFeed); + ILeverageZapper.LeverUpTroveParams memory params = ILeverageZapper.LeverUpTroveParams({ troveId: _troveId, - flashLoanAmount: vars.flashLoanAmount, - boldAmount: vars.effectiveBoldAmount, + flashLoanAmount: flashLoanAmount, + boldAmount: effectiveBoldAmount, maxUpfrontFee: 1000e18 }); vm.startPrank(A); _leverageZapper.leverUpTrove(params); vm.stopPrank(); - return vars.flashLoanAmount; + return flashLoanAmount; } function testCanLeverUpTroveWithCurve() external { @@ -417,6 +451,7 @@ contract ZapperLeverageLSTMainnet is DevTestSetup { vars.newLeverageRatio = 2.5e18; vars.resultingCollateralRatio = _leverageZapper.leverageRatioToCollateralRatio(vars.newLeverageRatio); + vars.boldBalanceBefore = boldToken.balanceOf(A); vars.ethBalanceBefore = A.balance; vars.collBalanceBefore = contractsArray[_branch].collToken.balanceOf(A); @@ -451,9 +486,12 @@ contract ZapperLeverageLSTMainnet is DevTestSetup { "Wrong CR" ); // token balances - assertApproxEqAbs(boldToken.balanceOf(A), 0, 10, "BOLD bal mismatch"); + assertEq(boldToken.balanceOf(A), vars.boldBalanceBefore, "BOLD bal mismatch"); assertEq(A.balance, vars.ethBalanceBefore, "ETH bal mismatch"); assertEq(contractsArray[_branch].collToken.balanceOf(A), vars.collBalanceBefore, "Coll bal mismatch"); + + // Check receiver is back to zero + assertEq(address(_leverageZapper.flashLoanProvider().receiver()), address(0), "Receiver should be zero"); } function testOnlyFlashLoanProviderCanCallLeverUpCallbackWithCurve() external { @@ -481,27 +519,169 @@ contract ZapperLeverageLSTMainnet is DevTestSetup { vm.stopPrank(); } - // Lever down - /* - function leverDownTroveWithCurve(uint256 _troveId, uint256 _leverageRatio) internal returns (uint256) { - return leverDownTrove(leverageZapperCurve, _troveId, _leverageRatio); + function testOnlyOwnerOrManagerCanLeverUpWithCurveFromZapper() external { + for (uint256 i = 0; i < NUM_COLLATERALS; i++) { + _testOnlyOwnerOrManagerCanLeverUpFromZapper(leverageZapperCurveArray[i], i); + } } - function leverDownTroveWithUniV3(uint256 _troveId, uint256 _leverageRatio) internal returns (uint256) { - return leverDownTrove(leverageZapperUniV3, _troveId, _leverageRatio); + function testOnlyOwnerOrManagerCanLeverUpWithUnFromZapperi() external { + for (uint256 i = 0; i < NUM_COLLATERALS; i++) { + _testOnlyOwnerOrManagerCanLeverUpFromZapper(leverageZapperUniV3Array[i], i); + } } - */ - function leverDownTrove( + function _testOnlyOwnerOrManagerCanLeverUpFromZapper(ILeverageZapper _leverageZapper, uint256 _branch) internal { + // Open trove + uint256 collAmount = 10 ether; + uint256 leverageRatio = 2e18; + bool lst = _branch > 0; + uint256 troveId = openLeveragedTroveWithIndex( + _leverageZapper, 0, collAmount, leverageRatio, contractsArray[_branch].priceFeed, lst + ); + + (uint256 flashLoanAmount, uint256 effectiveBoldAmount) = _getLeverUpFlashLoanAndBoldAmount( + _leverageZapper, + troveId, + 2.5e18, // _leverageRatio, + contractsArray[_branch].troveManager, + contractsArray[_branch].priceFeed + ); + + ILeverageZapper.LeverUpTroveParams memory params = ILeverageZapper.LeverUpTroveParams({ + troveId: troveId, + flashLoanAmount: flashLoanAmount, + boldAmount: effectiveBoldAmount, + maxUpfrontFee: 1000e18 + }); + // B tries to lever up A’s trove + vm.startPrank(B); + vm.expectRevert(AddRemoveManagers.NotOwnerNorRemoveManager.selector); + _leverageZapper.leverUpTrove(params); + vm.stopPrank(); + + // Check receiver is back to zero + assertEq(address(_leverageZapper.flashLoanProvider().receiver()), address(0), "Receiver should be zero"); + } + + function testOnlyOwnerOrManagerCanLeverUpWithCurveFromBalancerFLProvider() external { + for (uint256 i = 0; i < NUM_COLLATERALS; i++) { + _testOnlyOwnerOrManagerCanLeverUpFromBalancerFLProvider(leverageZapperCurveArray[i], i); + } + } + + function testOnlyOwnerOrManagerCanLeverUpWithUniFromBalancerFLProvider() external { + for (uint256 i = 0; i < NUM_COLLATERALS; i++) { + _testOnlyOwnerOrManagerCanLeverUpFromBalancerFLProvider(leverageZapperUniV3Array[i], i); + } + } + + function _testOnlyOwnerOrManagerCanLeverUpFromBalancerFLProvider(ILeverageZapper _leverageZapper, uint256 _branch) + internal + { + // Open trove + uint256 collAmount = 10 ether; + uint256 leverageRatio = 2e18; + bool lst = _branch > 0; + uint256 troveId = openLeveragedTroveWithIndex( + _leverageZapper, 1, collAmount, leverageRatio, contractsArray[_branch].priceFeed, lst + ); + + (uint256 flashLoanAmount, uint256 effectiveBoldAmount) = _getLeverUpFlashLoanAndBoldAmount( + _leverageZapper, + troveId, + 2.5e18, // _leverageRatio, + contractsArray[_branch].troveManager, + contractsArray[_branch].priceFeed + ); + + // B tries to lever up A’s trove calling our flash loan provider module + ILeverageZapper.LeverUpTroveParams memory params = ILeverageZapper.LeverUpTroveParams({ + troveId: troveId, + flashLoanAmount: flashLoanAmount, + boldAmount: effectiveBoldAmount, + maxUpfrontFee: 1000e18 + }); + IFlashLoanProvider flashLoanProvider = _leverageZapper.flashLoanProvider(); + vm.startPrank(B); + vm.expectRevert(); // reverts without data because it calls back B + flashLoanProvider.makeFlashLoan( + contractsArray[_branch].collToken, + flashLoanAmount, + IFlashLoanProvider.Operation.LeverUpTrove, + abi.encode(params) + ); + vm.stopPrank(); + + // Check receiver is back to zero + assertEq(address(flashLoanProvider.receiver()), address(0), "Receiver should be zero"); + } + + function testOnlyOwnerOrManagerCanLeverUpWithCurveFromBalancerVault() external { + for (uint256 i = 0; i < NUM_COLLATERALS; i++) { + _testOnlyOwnerOrManagerCanLeverUpFromBalancerVault(leverageZapperCurveArray[i], i); + } + } + + function testOnlyOwnerOrManagerCanLeverUpWithUniFromBalancerVault() external { + for (uint256 i = 0; i < NUM_COLLATERALS; i++) { + _testOnlyOwnerOrManagerCanLeverUpFromBalancerVault(leverageZapperUniV3Array[i], i); + } + } + + function _testOnlyOwnerOrManagerCanLeverUpFromBalancerVault(ILeverageZapper _leverageZapper, uint256 _branch) + internal + { + // Open trove + uint256 collAmount = 10 ether; + uint256 leverageRatio = 2e18; + bool lst = _branch > 0; + uint256 troveId = openLeveragedTroveWithIndex( + _leverageZapper, 2, collAmount, leverageRatio, contractsArray[_branch].priceFeed, lst + ); + + // B tries to lever up A’s trove calling Balancer Vault directly + (uint256 flashLoanAmount, uint256 effectiveBoldAmount) = _getLeverUpFlashLoanAndBoldAmount( + _leverageZapper, + troveId, + 2.5e18, // _leverageRatio, + contractsArray[_branch].troveManager, + contractsArray[_branch].priceFeed + ); + + ILeverageZapper.LeverUpTroveParams memory params = ILeverageZapper.LeverUpTroveParams({ + troveId: troveId, + flashLoanAmount: flashLoanAmount, + boldAmount: effectiveBoldAmount, + maxUpfrontFee: 1000e18 + }); + IFlashLoanProvider flashLoanProvider = _leverageZapper.flashLoanProvider(); + IERC20[] memory tokens = new IERC20[](1); + tokens[0] = contractsArray[_branch].collToken; + uint256[] memory amounts = new uint256[](1); + amounts[0] = flashLoanAmount; + bytes memory userData = abi.encode(address(_leverageZapper), IFlashLoanProvider.Operation.LeverUpTrove, params); + IVault vault = IVault(0xBA12222222228d8Ba445958a75a0704d566BF2C8); + vm.startPrank(B); + vm.expectRevert("Flash loan not properly initiated"); + vault.flashLoan(IFlashLoanRecipient(address(flashLoanProvider)), tokens, amounts, userData); + vm.stopPrank(); + + // Check receiver is back to zero + assertEq(address(flashLoanProvider.receiver()), address(0), "Receiver should be zero"); + } + + // Lever down + + function _getLeverDownFlashLoanAndBoldAmount( ILeverageZapper _leverageZapper, uint256 _troveId, uint256 _leverageRatio, ITroveManager _troveManager, IPriceFeed _priceFeed - ) internal returns (uint256) { + ) internal returns (uint256, uint256) { (uint256 price,) = _priceFeed.fetchPrice(); - // This should be done in the frontend uint256 currentCR = _troveManager.getCurrentICR(_troveId, price); uint256 currentLR = _leverageZapper.leverageRatioToCollateralRatio(currentCR); assertLt(_leverageRatio, currentLR, "Leverage ratio should decrease"); @@ -510,6 +690,20 @@ contract ZapperLeverageLSTMainnet is DevTestSetup { uint256 expectedBoldAmount = flashLoanAmount * price / DECIMAL_PRECISION; uint256 minBoldDebt = expectedBoldAmount * 95 / 100; // slippage + return (flashLoanAmount, minBoldDebt); + } + + function leverDownTrove( + ILeverageZapper _leverageZapper, + uint256 _troveId, + uint256 _leverageRatio, + ITroveManager _troveManager, + IPriceFeed _priceFeed + ) internal returns (uint256) { + // This should be done in the frontend + (uint256 flashLoanAmount, uint256 minBoldDebt) = + _getLeverDownFlashLoanAndBoldAmount(_leverageZapper, _troveId, _leverageRatio, _troveManager, _priceFeed); + ILeverageZapper.LeverDownTroveParams memory params = ILeverageZapper.LeverDownTroveParams({ troveId: _troveId, flashLoanAmount: flashLoanAmount, @@ -536,54 +730,60 @@ contract ZapperLeverageLSTMainnet is DevTestSetup { } function _testCanLeverDownTrove(ILeverageZapper _leverageZapper, uint256 _branch) internal { - uint256 collAmount = 10 ether; - uint256 initialLeverageRatio = 2e18; + TestVars memory vars; + vars.collAmount = 10 ether; + vars.initialLeverageRatio = 2e18; - uint256 troveId = openLeveragedTrove( - _leverageZapper, collAmount, initialLeverageRatio, contractsArray[_branch].priceFeed, _branch > 0 + vars.troveId = openLeveragedTrove( + _leverageZapper, vars.collAmount, vars.initialLeverageRatio, contractsArray[_branch].priceFeed, _branch > 0 ); - uint256 initialDebt = getTroveEntireDebt(contractsArray[_branch].troveManager, troveId); + vars.initialDebt = getTroveEntireDebt(contractsArray[_branch].troveManager, vars.troveId); - uint256 newLeverageRatio = 1.5e18; - uint256 resultingCollateralRatio = _leverageZapper.leverageRatioToCollateralRatio(newLeverageRatio); + vars.newLeverageRatio = 1.5e18; + vars.resultingCollateralRatio = _leverageZapper.leverageRatioToCollateralRatio(vars.newLeverageRatio); - uint256 ethBalanceBefore = A.balance; - uint256 collBalanceBefore = contractsArray[_branch].collToken.balanceOf(A); + vars.boldBalanceBefore = boldToken.balanceOf(A); + vars.ethBalanceBefore = A.balance; + vars.collBalanceBefore = contractsArray[_branch].collToken.balanceOf(A); - uint256 flashLoanAmount = leverDownTrove( + vars.flashLoanAmount = leverDownTrove( _leverageZapper, - troveId, - newLeverageRatio, + vars.troveId, + vars.newLeverageRatio, contractsArray[_branch].troveManager, contractsArray[_branch].priceFeed ); // Checks - (uint256 price,) = contractsArray[_branch].priceFeed.fetchPrice(); + (vars.price,) = contractsArray[_branch].priceFeed.fetchPrice(); // coll assertApproxEqAbs( - getTroveEntireColl(contractsArray[_branch].troveManager, troveId), - collAmount * newLeverageRatio / DECIMAL_PRECISION, + getTroveEntireColl(contractsArray[_branch].troveManager, vars.troveId), + vars.collAmount * vars.newLeverageRatio / DECIMAL_PRECISION, 22e16, "Coll mismatch" ); // debt - uint256 expectedMinNetDebt = initialDebt - flashLoanAmount * price / DECIMAL_PRECISION * 101 / 100; + uint256 expectedMinNetDebt = + vars.initialDebt - vars.flashLoanAmount * vars.price / DECIMAL_PRECISION * 101 / 100; uint256 expectedMaxNetDebt = expectedMinNetDebt * 105 / 100; - uint256 troveEntireDebt = getTroveEntireDebt(contractsArray[_branch].troveManager, troveId); + uint256 troveEntireDebt = getTroveEntireDebt(contractsArray[_branch].troveManager, vars.troveId); assertGe(troveEntireDebt, expectedMinNetDebt, "Debt too low"); assertLe(troveEntireDebt, expectedMaxNetDebt, "Debt too high"); // CR assertApproxEqAbs( - contractsArray[_branch].troveManager.getCurrentICR(troveId, price), - resultingCollateralRatio, + contractsArray[_branch].troveManager.getCurrentICR(vars.troveId, vars.price), + vars.resultingCollateralRatio, 3e15, "Wrong CR" ); // token balances - assertApproxEqAbs(boldToken.balanceOf(A), 0, 15, "BOLD bal mismatch"); - assertEq(A.balance, ethBalanceBefore, "ETH bal mismatch"); - assertEq(contractsArray[_branch].collToken.balanceOf(A), collBalanceBefore, "Coll bal mismatch"); + assertEq(boldToken.balanceOf(A), vars.boldBalanceBefore, "BOLD bal mismatch"); + assertEq(A.balance, vars.ethBalanceBefore, "ETH bal mismatch"); + assertEq(contractsArray[_branch].collToken.balanceOf(A), vars.collBalanceBefore, "Coll bal mismatch"); + + // Check receiver is back to zero + assertEq(address(_leverageZapper.flashLoanProvider().receiver()), address(0), "Receiver should be zero"); } function testOnlyFlashLoanProviderCanCallLeverDownCallbackWithCurve() external { @@ -609,4 +809,154 @@ contract ZapperLeverageLSTMainnet is DevTestSetup { IFlashLoanReceiver(address(_leverageZapper)).receiveFlashLoanOnLeverDownTrove(params, 10 ether); vm.stopPrank(); } + + function testOnlyOwnerOrManagerCanLeverDownWithCurveFromZapper() external { + for (uint256 i = 0; i < NUM_COLLATERALS; i++) { + _testOnlyOwnerOrManagerCanLeverDownFromZapper(leverageZapperCurveArray[i], i); + } + } + + function testOnlyOwnerOrManagerCanLeverDownWithUnFromZapperi() external { + for (uint256 i = 0; i < NUM_COLLATERALS; i++) { + _testOnlyOwnerOrManagerCanLeverDownFromZapper(leverageZapperUniV3Array[i], i); + } + } + + function _testOnlyOwnerOrManagerCanLeverDownFromZapper(ILeverageZapper _leverageZapper, uint256 _branch) internal { + // Open trove + uint256 collAmount = 10 ether; + uint256 leverageRatio = 2e18; + bool lst = _branch > 0; + uint256 troveId = openLeveragedTroveWithIndex( + _leverageZapper, 0, collAmount, leverageRatio, contractsArray[_branch].priceFeed, lst + ); + + // B tries to lever up A’s trove + (uint256 flashLoanAmount, uint256 minBoldDebt) = _getLeverDownFlashLoanAndBoldAmount( + _leverageZapper, + troveId, + 1.5e18, // _leverageRatio, + contractsArray[_branch].troveManager, + contractsArray[_branch].priceFeed + ); + + ILeverageZapper.LeverDownTroveParams memory params = ILeverageZapper.LeverDownTroveParams({ + troveId: troveId, + flashLoanAmount: flashLoanAmount, + minBoldAmount: minBoldDebt + }); + vm.startPrank(B); + vm.expectRevert(AddRemoveManagers.NotOwnerNorRemoveManager.selector); + _leverageZapper.leverDownTrove(params); + vm.stopPrank(); + + // Check receiver is back to zero + assertEq(address(_leverageZapper.flashLoanProvider().receiver()), address(0), "Receiver should be zero"); + } + + function testOnlyOwnerOrManagerCanLeverDownWithCurveFromBalancerFLProvider() external { + for (uint256 i = 0; i < NUM_COLLATERALS; i++) { + _testOnlyOwnerOrManagerCanLeverDownFromBalancerFLProvider(leverageZapperCurveArray[i], i); + } + } + + function testOnlyOwnerOrManagerCanLeverDownWithUniFromBalancerFLProvider() external { + for (uint256 i = 0; i < NUM_COLLATERALS; i++) { + _testOnlyOwnerOrManagerCanLeverDownFromBalancerFLProvider(leverageZapperUniV3Array[i], i); + } + } + + function _testOnlyOwnerOrManagerCanLeverDownFromBalancerFLProvider(ILeverageZapper _leverageZapper, uint256 _branch) + internal + { + // Open trove + uint256 collAmount = 10 ether; + uint256 leverageRatio = 2e18; + bool lst = _branch > 0; + uint256 troveId = openLeveragedTroveWithIndex( + _leverageZapper, 1, collAmount, leverageRatio, contractsArray[_branch].priceFeed, lst + ); + + // B tries to lever up A’s trove calling our flash loan provider module + (uint256 flashLoanAmount, uint256 minBoldDebt) = _getLeverDownFlashLoanAndBoldAmount( + _leverageZapper, + troveId, + 1.5e18, // _leverageRatio, + contractsArray[_branch].troveManager, + contractsArray[_branch].priceFeed + ); + + ILeverageZapper.LeverDownTroveParams memory params = ILeverageZapper.LeverDownTroveParams({ + troveId: troveId, + flashLoanAmount: flashLoanAmount, + minBoldAmount: minBoldDebt + }); + IFlashLoanProvider flashLoanProvider = _leverageZapper.flashLoanProvider(); + vm.startPrank(B); + vm.expectRevert(); // reverts without data because it calls back B + flashLoanProvider.makeFlashLoan( + contractsArray[_branch].collToken, + flashLoanAmount, + IFlashLoanProvider.Operation.LeverDownTrove, + abi.encode(params) + ); + vm.stopPrank(); + + // Check receiver is back to zero + assertEq(address(flashLoanProvider.receiver()), address(0), "Receiver should be zero"); + } + + function testOnlyOwnerOrManagerCanLeverDownWithCurveFromBalancerVault() external { + for (uint256 i = 0; i < NUM_COLLATERALS; i++) { + _testOnlyOwnerOrManagerCanLeverDownFromBalancerVault(leverageZapperCurveArray[i], i); + } + } + + function testOnlyOwnerOrManagerCanLeverDownWithUniFromBalancerVault() external { + for (uint256 i = 0; i < NUM_COLLATERALS; i++) { + _testOnlyOwnerOrManagerCanLeverDownFromBalancerVault(leverageZapperUniV3Array[i], i); + } + } + + function _testOnlyOwnerOrManagerCanLeverDownFromBalancerVault(ILeverageZapper _leverageZapper, uint256 _branch) + internal + { + // Open trove + uint256 collAmount = 10 ether; + uint256 leverageRatio = 2e18; + bool lst = _branch > 0; + uint256 troveId = openLeveragedTroveWithIndex( + _leverageZapper, 2, collAmount, leverageRatio, contractsArray[_branch].priceFeed, lst + ); + + // B tries to lever up A’s trove calling Balancer Vault directly + (uint256 flashLoanAmount, uint256 minBoldDebt) = _getLeverDownFlashLoanAndBoldAmount( + _leverageZapper, + troveId, + 1.5e18, // _leverageRatio, + contractsArray[_branch].troveManager, + contractsArray[_branch].priceFeed + ); + + ILeverageZapper.LeverDownTroveParams memory params = ILeverageZapper.LeverDownTroveParams({ + troveId: troveId, + flashLoanAmount: flashLoanAmount, + minBoldAmount: minBoldDebt + }); + IFlashLoanProvider flashLoanProvider = _leverageZapper.flashLoanProvider(); + IERC20[] memory tokens = new IERC20[](1); + tokens[0] = contractsArray[_branch].collToken; + uint256[] memory amounts = new uint256[](1); + amounts[0] = flashLoanAmount; + bytes memory userData = + abi.encode(address(_leverageZapper), IFlashLoanProvider.Operation.LeverDownTrove, params); + IVault vault = IVault(0xBA12222222228d8Ba445958a75a0704d566BF2C8); + vm.startPrank(B); + vm.expectRevert("Flash loan not properly initiated"); + vault.flashLoan(IFlashLoanRecipient(address(flashLoanProvider)), tokens, amounts, userData); + vm.stopPrank(); + + // Check receiver is back to zero + assertEq(address(flashLoanProvider.receiver()), address(0), "Receiver should be zero"); + } } From 0f89df3d7820a3d3c7ed19decdaf7d8517486665 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=9Fingen?= Date: Mon, 9 Sep 2024 10:52:56 +0100 Subject: [PATCH 2/2] fix: Move Balancer flash loan receiver reset to main function For better readability (address PR #409 comments). --- .../src/Zappers/Modules/FlashLoans/BalancerFlashLoan.sol | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/contracts/src/Zappers/Modules/FlashLoans/BalancerFlashLoan.sol b/contracts/src/Zappers/Modules/FlashLoans/BalancerFlashLoan.sol index 3920f500..0c01ee43 100644 --- a/contracts/src/Zappers/Modules/FlashLoans/BalancerFlashLoan.sol +++ b/contracts/src/Zappers/Modules/FlashLoans/BalancerFlashLoan.sol @@ -46,6 +46,9 @@ contract BalancerFlashLoan is IFlashLoanRecipient, IFlashLoanProvider { receiver = IFlashLoanReceiver(msg.sender); vault.flashLoan(this, tokens, amounts, userData); + + // Reset receiver + receiver = IFlashLoanReceiver(address(0)); } function receiveFlashLoan( @@ -99,8 +102,5 @@ contract BalancerFlashLoan is IFlashLoanRecipient, IFlashLoanProvider { // Return flash loan tokens[0].safeTransfer(address(vault), amounts[0] + feeAmounts[0]); - - // Reset receiver - receiver = IFlashLoanReceiver(address(0)); } }