diff --git a/src/solvers/ConstantSum/ConstantSumSolver.sol b/src/solvers/ConstantSum/ConstantSumSolver.sol index 5a374ecb..19ff6d0d 100644 --- a/src/solvers/ConstantSum/ConstantSumSolver.sol +++ b/src/solvers/ConstantSum/ConstantSumSolver.sol @@ -51,21 +51,12 @@ contract ConstantSumSolver { ); uint256 amountOut; if (swapXIn) { - uint256 fees = amountIn.mulWadUp(poolParams.swapFee); - uint256 deltaL = - fees.mulWadUp(startReserves.L.divWadDown(startReserves.rx)); - - // amountOut = amountIn.mulWadDown( - // poolParams.price.mulWadDown(ONE - poolParams.swapFee) - // - poolParams.swapFee - // ); + uint256 deltaL = amountIn.mulWadUp(poolParams.swapFee); - amountOut = amountIn.mulWadDown( - (ONE + poolParams.price).mulWadDown(poolParams.swapFee) - ONE + amountOut = amountIn.mulWadDown(poolParams.price).mulWadDown( + ONE - poolParams.swapFee ); - console2.log("Fees: ", fees); - endReserves.rx = startReserves.rx + amountIn; endReserves.L = startReserves.L + deltaL; @@ -75,17 +66,11 @@ contract ConstantSumSolver { if (startReserves.ry < amountOut) revert NotEnoughLiquidity(); endReserves.ry = startReserves.ry - amountOut; } else { - uint256 fees = amountIn.mulWadUp(poolParams.swapFee); uint256 deltaL = - fees.mulWadUp(startReserves.L.divWadDown(startReserves.ry)); - uint256 effectiveAmountIn = amountIn - fees; - - console2.log("Effective amount in: ", effectiveAmountIn); - - amountOut = - effectiveAmountIn.mulWadDown(ONE.divWadDown(poolParams.price)); + amountIn.mulWadUp(poolParams.swapFee).divWadUp(poolParams.price); - console2.log("Fees: ", fees); + amountOut = (ONE - poolParams.swapFee).mulWadDown(amountIn) + .divWadDown(poolParams.price); endReserves.ry = startReserves.ry + amountIn; endReserves.L = startReserves.L + deltaL; diff --git a/src/strategies/ConstantSum/ConstantSum.sol b/src/strategies/ConstantSum/ConstantSum.sol index 8ff3b551..1142bff3 100644 --- a/src/strategies/ConstantSum/ConstantSum.sol +++ b/src/strategies/ConstantSum/ConstantSum.sol @@ -100,13 +100,13 @@ contract ConstantSum is IStrategy { console2.log("amountIn in validate: ", amountIn); fees = amountIn.mulWadUp(params.swapFee); console2.log("fees in validate: ", fees); - minLiquidityDelta += fees.mulWadUp(startL).divWadUp(startRx); + minLiquidityDelta += fees; } else if (nextRy > startRy) { amountIn = nextRy - startRy; console2.log("amountIn in validate: ", amountIn); fees = amountIn.mulWadUp(params.swapFee); console2.log("fees in validate: ", fees); - minLiquidityDelta += fees.mulWadUp(startL).divWadUp(startRy); + minLiquidityDelta += fees.divWadUp(params.price); } else { revert("invalid swap: inputs x and y have the same sign!"); } diff --git a/src/test/ConstantSum/ConstantSumTest.t.sol b/src/test/ConstantSum/ConstantSumTest.t.sol index 0753029e..4fc22136 100644 --- a/src/test/ConstantSum/ConstantSumTest.t.sol +++ b/src/test/ConstantSum/ConstantSumTest.t.sol @@ -39,11 +39,7 @@ contract ConstantSumTest is Test { vm.warp(0); ConstantSum.ConstantSumParams memory params = ConstantSum - .ConstantSumParams({ - price: ONE * 2, - swapFee: TEST_SWAP_FEE, - controller: address(0) - }); + .ConstantSumParams({ price: ONE * 2, swapFee: 0, controller: address(0) }); uint256 init_x = ONE * 1; uint256 init_y = ONE * 1; @@ -106,7 +102,7 @@ contract ConstantSumTest is Test { assertEq(initL, 1.5 ether); } - function test_constant_sum_swap_x_in() public basic_feeless { + function test_constant_sum_swap_x_in_no_fee() public basic_feeless { bool xIn = true; uint256 amountIn = 0.1 ether; uint256 poolId = dfmm.nonce() - 1; @@ -115,11 +111,26 @@ contract ConstantSumTest is Test { console2.log("Valid: ", valid); console2.log("AmountOut: ", amountOut); assert(valid); - assert(amountOut == 200_000_000_000_000_000); + + console2.log("AmountOut: ", amountOut); + assertEq(amountOut, 0.2 ether); + + (uint256 endReservesRx, uint256 endReservesRy, uint256 endReservesL) = + abi.decode(swapData, (uint256, uint256, uint256)); + + console2.log("endReservesRx: ", endReservesRx); + assertEq(endReservesRx, 1.1 ether); + + console2.log("endReservesRy: ", endReservesRy); + assertEq(endReservesRy, 0.8 ether); + + console2.log("endReservesL: ", endReservesL); + assertEq(endReservesL, 1.5 ether); + dfmm.swap(poolId, swapData); } - function test_constant_sum_swap_y_in() public basic_feeless { + function test_constant_sum_swap_y_in_no_fee() public basic_feeless { bool xIn = false; uint256 amountIn = 0.1 ether; uint256 poolId = dfmm.nonce() - 1; @@ -128,7 +139,22 @@ contract ConstantSumTest is Test { console2.log("Valid: ", valid); console2.log("AmountOut: ", amountOut); assert(valid); - assert(amountOut == 50_000_000_000_000_000); + + console2.log("AmountOut: ", amountOut); + assertEq(amountOut, 0.05 ether); + + (uint256 endReservesRx, uint256 endReservesRy, uint256 endReservesL) = + abi.decode(swapData, (uint256, uint256, uint256)); + + console2.log("endReservesRx: ", endReservesRx); + assertEq(endReservesRx, 0.95 ether); + + console2.log("endReservesRy: ", endReservesRy); + assertEq(endReservesRy, 1.1 ether); + + console2.log("endReservesL: ", endReservesL); + assertEq(endReservesL, 1.5 ether); + dfmm.swap(poolId, swapData); } @@ -159,7 +185,7 @@ contract ConstantSumTest is Test { assert(valid); console2.log("AmountOut: ", amountOut); - assertEq(amountOut, 199_100_000_000_000_000); + assertEq(amountOut, 0.1994 ether); (uint256 endReservesRx, uint256 endReservesRy, uint256 endReservesL) = abi.decode(swapData, (uint256, uint256, uint256)); @@ -168,10 +194,10 @@ contract ConstantSumTest is Test { assertEq(endReservesRx, 1.1 ether); console2.log("endReservesRy: ", endReservesRy); - assertEq(endReservesRy, 0.8009 ether); + assertEq(endReservesRy, 0.8006 ether); console2.log("endReservesL: ", endReservesL); - assertEq(endReservesL, 1.50045 ether); + assertEq(endReservesL, 1.5003 ether); dfmm.swap(poolId, swapData); } @@ -185,7 +211,22 @@ contract ConstantSumTest is Test { console2.log("Valid: ", valid); console2.log("AmountOut: ", amountOut); assert(valid); - assert(amountOut == 50_000_000_000_000_000); + + console2.log("AmountOut: ", amountOut); + assertEq(amountOut, 0.04985 ether); + + (uint256 endReservesRx, uint256 endReservesRy, uint256 endReservesL) = + abi.decode(swapData, (uint256, uint256, uint256)); + + console2.log("endReservesRx: ", endReservesRx); + assertEq(endReservesRx, 0.95015 ether); + + console2.log("endReservesRy: ", endReservesRy); + assertEq(endReservesRy, 1.1 ether); + + console2.log("endReservesL: ", endReservesL); + assertEq(endReservesL, 1.50015 ether); + dfmm.swap(poolId, swapData); } }