diff --git a/test/invariants/handlers/OwnerHandler.sol b/test/invariants/handlers/OwnerHandler.sol index 8cb93b1..507ebc5 100644 --- a/test/invariants/handlers/OwnerHandler.sol +++ b/test/invariants/handlers/OwnerHandler.sol @@ -80,6 +80,13 @@ contract OwnerHandler is BaseHandler { function setFees(uint256 _seed) external { numberOfCalls["ownerHandler.setFees"]++; + uint256 feeAccrued = arm.feesAccrued(); + if (feeAccrued > weth.balanceOf(address(arm))) { + console.log("OwnerHandler.setFees() - Not enough liquidity to collect fees"); + numberOfCalls["ownerHandler.setFees.skip"]++; + return; + } + uint256 fee = _bound(_seed, 0, maxFees); console.log("OwnerHandler.setFees(%2e)", fee); @@ -91,6 +98,9 @@ contract OwnerHandler is BaseHandler { // Stop prank vm.stopPrank(); + + // Update sum of fees + sum_of_fees += feeAccrued; } /// @notice Collect fees from the ARM @@ -98,7 +108,8 @@ contract OwnerHandler is BaseHandler { function collectFees(uint256) external { numberOfCalls["ownerHandler.collectFees"]++; - if (_estimatedFeesAccrued() > weth.balanceOf(address(arm))) { + uint256 feeAccrued = arm.feesAccrued(); + if (feeAccrued > weth.balanceOf(address(arm))) { console.log("OwnerHandler.collectFees() - Not enough liquidity to collect fees"); numberOfCalls["ownerHandler.collectFees.skip"]++; return; @@ -108,38 +119,9 @@ contract OwnerHandler is BaseHandler { // Collect fees uint256 fees = arm.collectFees(); + require(feeAccrued == fees, "OwnerHandler.collectFees() - Fees collected do not match fees accrued"); // Update sum of fees sum_of_fees += fees; } - - ////////////////////////////////////////////////////// - /// --- HELPERS - ////////////////////////////////////////////////////// - function _estimateAvailableTotalAssets() internal view returns (uint256) { - uint256 assets = steth.balanceOf(address(arm)) + weth.balanceOf(address(arm)) + arm.outstandingEther(); - - uint256 queuedMem = arm.withdrawsQueued(); - uint256 claimedMem = arm.withdrawsClaimed(); - - if (assets + claimedMem < queuedMem + arm.feesAccrued()) { - return 0; - } - - return assets + claimedMem - queuedMem + arm.feesAccrued(); - } - - function _estimatedFeesAccrued() internal view returns (uint256) { - uint256 newTotalAssets = _estimateAvailableTotalAssets(); - - uint256 lastAvailableAssets = arm.lastAvailableAssets(); - if (newTotalAssets <= lastAvailableAssets) { - return arm.feesAccrued(); - } - - uint256 assetIncrease = newTotalAssets - lastAvailableAssets; - uint256 newFeesAccrued = (assetIncrease * arm.fee()) / arm.FEE_SCALE(); - - return arm.feesAccrued() + newFeesAccrued; - } }