From 3f3a42a41074115a4b4ec06393e3d4a27f0e0389 Mon Sep 17 00:00:00 2001 From: ume Date: Fri, 13 Dec 2024 15:53:12 +0800 Subject: [PATCH] add origin payer for unxv3 swap --- contracts/8/DexRouter.sol | 2 +- contracts/8/UnxswapV3Router.sol | 120 +++++++++++++++++++------------- src/tests/UniV3Refund.t.sol | 11 +-- 3 files changed, 80 insertions(+), 53 deletions(-) diff --git a/contracts/8/DexRouter.sol b/contracts/8/DexRouter.sol index 61c92e0..4d66a8f 100644 --- a/contracts/8/DexRouter.sol +++ b/contracts/8/DexRouter.sol @@ -679,7 +679,7 @@ contract DexRouter is amount ); - (uint256 swappedAmount, ) = _uniswapV3Swap( + uint256 swappedAmount = _uniswapV3Swap( payer, payable(middleReceiver), amount, diff --git a/contracts/8/UnxswapV3Router.sol b/contracts/8/UnxswapV3Router.sol index 7faa752..e9bf338 100644 --- a/contracts/8/UnxswapV3Router.sol +++ b/contracts/8/UnxswapV3Router.sol @@ -45,7 +45,6 @@ contract UnxswapV3Router is IUniswapV3SwapCallback, CommonUtils { /// @param minReturn The minimum amount of tokens that must be received for the swap to be valid, safeguarding against excessive slippage. /// @param pools An array of pool identifiers defining the swap route within Uniswap V3. /// @return returnAmount The amount of tokens received from the swap. - /// @return srcTokenAddr The address of the source token used for the swap. /// @dev This internal function encapsulates the core logic for executing swaps on Uniswap V3. It is intended to be used by other functions in the contract that prepare and pass the necessary parameters. The function handles the swapping process, ensuring that the minimum return is met and managing the transfer of tokens. function _uniswapV3Swap( address payer, @@ -53,7 +52,7 @@ contract UnxswapV3Router is IUniswapV3SwapCallback, CommonUtils { uint256 amount, uint256 minReturn, uint256[] calldata pools - ) internal returns (uint256 returnAmount, address srcTokenAddr) { + ) internal returns (uint256 returnAmount) { assembly { function _revertWithReason(m, len) { mstore( @@ -67,7 +66,7 @@ contract UnxswapV3Router is IUniswapV3SwapCallback, CommonUtils { mstore(0x40, m) revert(0, len) } - function _makeSwap(_receiver, _payer, _pool, _amount) + function _makeSwap(_receiver, _payer, _payerOrigin, _pool, _amount) -> _returnAmount { if lt(_INT256_MAX, _amount) { @@ -90,9 +89,18 @@ contract UnxswapV3Router is IUniswapV3SwapCallback, CommonUtils { mstore(add(paramPtr, 0x40), _amount) mstore(add(paramPtr, 0x60), _MIN_SQRT_RATIO) mstore(add(paramPtr, 0x80), 0xa0) - mstore(add(paramPtr, 0xa0), 32) + mstore(add(paramPtr, 0xa0), 64) mstore(add(paramPtr, 0xc0), _payer) - let success := call(gas(), poolAddr, 0, freePtr, 0xe4, 0, 0) + mstore(add(paramPtr, 0xe0), _payerOrigin) + let success := call( + gas(), + poolAddr, + 0, + freePtr, + 0x104, + 0, + 0 + ) if iszero(success) { revert(0, 32) } @@ -106,9 +114,18 @@ contract UnxswapV3Router is IUniswapV3SwapCallback, CommonUtils { mstore(add(paramPtr, 0x40), _amount) mstore(add(paramPtr, 0x60), _MAX_SQRT_RATIO) mstore(add(paramPtr, 0x80), 0xa0) - mstore(add(paramPtr, 0xa0), 32) + mstore(add(paramPtr, 0xa0), 64) mstore(add(paramPtr, 0xc0), _payer) - let success := call(gas(), poolAddr, 0, freePtr, 0xe4, 0, 0) + mstore(add(paramPtr, 0xe0), _payerOrigin) + let success := call( + gas(), + poolAddr, + 0, + freePtr, + 0x104, + 0, + 0 + ) if iszero(success) { revert(0, 32) } @@ -211,13 +228,11 @@ contract UnxswapV3Router is IUniswapV3SwapCallback, CommonUtils { function _emitEvent( _firstPoolStart, _lastPoolStart, - _returnAmount, - wrapWeth, - unwrapWeth - ) -> srcToken { - srcToken := _ETH + _returnAmount + ) { + let srcToken := _ETH let toToken := _ETH - if eq(wrapWeth, false) { + if eq(callvalue(), 0) { let firstPool := calldataload(_firstPoolStart) switch eq(0, and(firstPool, _ONE_FOR_ZERO_MASK)) case true { @@ -227,7 +242,7 @@ contract UnxswapV3Router is IUniswapV3SwapCallback, CommonUtils { srcToken := _token1(firstPool) } } - if eq(unwrapWeth, false) { + if eq(and(calldataload(_lastPoolStart), _WETH_UNWRAP_MASK), 0) { let lastPool := calldataload(_lastPoolStart) switch eq(0, and(lastPool, _ONE_FOR_ZERO_MASK)) case true { @@ -252,6 +267,7 @@ contract UnxswapV3Router is IUniswapV3SwapCallback, CommonUtils { } let firstPoolStart let lastPoolStart + { let len := pools.length firstPoolStart := pools.offset // @@ -265,11 +281,13 @@ contract UnxswapV3Router is IUniswapV3SwapCallback, CommonUtils { revert(0, 4) } } - - let wrapWeth := gt(callvalue(), 0) - if wrapWeth { - _wrapWeth(amount) - payer := address() + let payerOrigin := payer + { + let wrapWeth := gt(callvalue(), 0) + if wrapWeth { + _wrapWeth(amount) + payer := address() + } } mstore(96, amount) // 96 is not override by _makeSwap, since it only use freePtr memory, and it is not override by unWrapWeth ethier @@ -278,46 +296,51 @@ contract UnxswapV3Router is IUniswapV3SwapCallback, CommonUtils { } lt(i, lastPoolStart) { i := add(i, 32) } { - amount := _makeSwap(address(), payer, calldataload(i), amount) - payer := address() - } - let unwrapWeth := gt( - and(calldataload(lastPoolStart), _WETH_UNWRAP_MASK), - 0 - ) // pools[lastIndex] & _WETH_UNWRAP_MASK > 0 - - // last one or only one - switch unwrapWeth - case 1 { - returnAmount := _makeSwap( + amount := _makeSwap( address(), payer, - calldataload(lastPoolStart), + payerOrigin, + calldataload(i), amount ) - _unWrapWeth(receiver, returnAmount) + payer := address() } - case 0 { - returnAmount := _makeSwap( - receiver, - payer, - calldataload(lastPoolStart), - amount - ) + { + let unwrapWeth := gt( + and(calldataload(lastPoolStart), _WETH_UNWRAP_MASK), + 0 + ) // pools[lastIndex] & _WETH_UNWRAP_MASK > 0 + + // last one or only one + switch unwrapWeth + case 1 { + returnAmount := _makeSwap( + address(), + payer, + payerOrigin, + calldataload(lastPoolStart), + amount + ) + _unWrapWeth(receiver, returnAmount) + } + case 0 { + returnAmount := _makeSwap( + receiver, + payer, + payerOrigin, + calldataload(lastPoolStart), + amount + ) + } } + if lt(returnAmount, minReturn) { _revertWithReason( 0x000000164d696e2072657475726e206e6f742072656163686564000000000000, 90 ) // Min return not reached } - srcTokenAddr := _emitEvent( - firstPoolStart, - lastPoolStart, - returnAmount, - wrapWeth, - unwrapWeth - ) + _emitEvent(firstPoolStart, lastPoolStart, returnAmount) } } @@ -343,7 +366,8 @@ contract UnxswapV3Router is IUniswapV3SwapCallback, CommonUtils { } let amount := mload(0) if gt(amount, 0) { - mstore(add(4, emptyPtr), origin()) + let payerOrigin := calldataload(164) + mstore(add(4, emptyPtr), payerOrigin) mstore(add(36, emptyPtr), amount) validateERC20Transfer( call(gas(), token, 0, emptyPtr, 0x44, 0, 0x20) diff --git a/src/tests/UniV3Refund.t.sol b/src/tests/UniV3Refund.t.sol index 145a22c..8939eba 100644 --- a/src/tests/UniV3Refund.t.sol +++ b/src/tests/UniV3Refund.t.sol @@ -39,9 +39,10 @@ contract UniswapV3Test is Test { IERC20(USDC).approve(token_approve, type(uint256).max); deal(USDC, user, 120000 * 10 ** 6); + deal(WETH, user, 0); } - function _test_okx() public { + function test_okx() public { uint256[] memory pools = new uint256[](1); pools[0] = uint256(bytes32(abi.encodePacked(bytes12(0), pool))); vm.prank(user, user); @@ -54,9 +55,10 @@ contract UniswapV3Test is Test { pools ); console2.log("film balance", IERC20(FILM).balanceOf(pool)); + require(IERC20(WETH).balanceOf(user) > 2.33 ether, "not valid"); } - function _test_okx_unxv3() public { + function test_okx_unxv3() public { uint256[] memory pools = new uint256[](2); pools[0] = uint256( bytes32(abi.encodePacked(bytes1(0x00), bytes11(0), pool_usdc)) @@ -71,6 +73,7 @@ contract UniswapV3Test is Test { 390789165003, pools ); + require(IERC20(WETH).balanceOf(user) > 2.33 ether, "not valid"); } struct SwapInfo { @@ -81,7 +84,7 @@ contract UniswapV3Test is Test { PMMLib.PMMSwapRequest[] extraData; } - function test_okx_smartswap() public { + function _test_okx_smartswap() public { uint256 amount = 3 ether; SwapInfo memory swapInfo; swapInfo.baseRequest.fromToken = uint256(uint160(address(ETH_ADDRESS))); @@ -121,7 +124,7 @@ contract UniswapV3Test is Test { ); } - function test_okx_smartswap_usdc() public { + function _test_okx_smartswap_usdc() public { uint256 amount = 120000 * 10 ** 6; SwapInfo memory swapInfo; swapInfo.baseRequest.fromToken = uint256(uint160(address(USDC)));