diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 2b5820b..34cfdb8 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -4,7 +4,7 @@ jobs: test: strategy: matrix: - go-version: [ 1.16.x, 1.17.x ] + go-version: [ 1.18.x, 1.19.x, 1.20.x ] runs-on: ubuntu-latest steps: - name: Install Go diff --git a/entities/pool.go b/entities/pool.go index e9afec4..a383b8b 100644 --- a/entities/pool.go +++ b/entities/pool.go @@ -20,13 +20,13 @@ var ( ) type StepComputations struct { - sqrtPriceStartX96 *big.Int - tickNext int - initialized bool - sqrtPriceNextX96 *big.Int - amountIn *big.Int - amountOut *big.Int - deltaL *big.Int + startSqrtP *big.Int + tickNext int + initialized bool + nextSqrtP *big.Int + usedAmount *big.Int + returnedAmount *big.Int + deltaL *big.Int } // Represents a V3 pool @@ -229,23 +229,23 @@ func (p *Pool) GetInputAmount( * @param zeroForOne Whether the amount in is token0 or token1 * @param amountSpecified The amount of the swap, which implicitly configures the swap as exact input (positive), or exact output (negative) * @param sqrtPriceLimitX96 The Q64.96 sqrt price limit. If zero for one, the price cannot be less than this value after the swap. If one for zero, the price cannot be greater than this value after the swap - * @returns amountCalculated + * @returns returnedAmount * @returns sqrtRatioX96 * @returns liquidity * @returns tickCurrent */ -func (p *Pool) swap(zeroForOne bool, amountSpecified, sqrtPriceLimitX96 *big.Int) ( +func (p *Pool) swap(isToken0 bool, amountSpecified, sqrtPriceLimitX96 *big.Int) ( amountCalCulated *big.Int, sqrtRatioX96 *big.Int, liquidity, reinvestLiquidity *big.Int, tickCurrent int, err error, ) { if sqrtPriceLimitX96 == nil { - if zeroForOne { + if isToken0 { sqrtPriceLimitX96 = new(big.Int).Add(utils.MinSqrtRatio, constants.One) } else { sqrtPriceLimitX96 = new(big.Int).Sub(utils.MaxSqrtRatio, constants.One) } } - if zeroForOne { + if isToken0 { if sqrtPriceLimitX96.Cmp(utils.MinSqrtRatio) < 0 { return nil, nil, nil, nil, 0, ErrSqrtPriceLimitX96TooLow } @@ -261,36 +261,36 @@ func (p *Pool) swap(zeroForOne bool, amountSpecified, sqrtPriceLimitX96 *big.Int } } - exactInput := amountSpecified.Cmp(constants.Zero) >= 0 + isExactInput := amountSpecified.Cmp(constants.Zero) >= 0 // keep track of swap state state := struct { - amountSpecifiedRemaining *big.Int - amountCalculated *big.Int - sqrtPriceX96 *big.Int - tick int - liquidity *big.Int - reinvestLiquidity *big.Int + specifiedAmount *big.Int + returnedAmount *big.Int + sqrtP *big.Int + tick int + liquidity *big.Int + reinvestL *big.Int }{ - amountSpecifiedRemaining: amountSpecified, - amountCalculated: constants.Zero, - sqrtPriceX96: p.SqrtRatioX96, - tick: p.TickCurrent, - liquidity: p.Liquidity, - reinvestLiquidity: p.ReinvestLiquidity, + specifiedAmount: amountSpecified, + returnedAmount: constants.Zero, + sqrtP: p.SqrtRatioX96, + tick: p.TickCurrent, + liquidity: p.Liquidity, + reinvestL: p.ReinvestLiquidity, } // start swap while loop - for state.amountSpecifiedRemaining.Cmp(constants.Zero) != 0 && state.sqrtPriceX96.Cmp(sqrtPriceLimitX96) != 0 { + for state.specifiedAmount.Cmp(constants.Zero) != 0 && state.sqrtP.Cmp(sqrtPriceLimitX96) != 0 { var step StepComputations - step.sqrtPriceStartX96 = state.sqrtPriceX96 + step.startSqrtP = state.sqrtP // because each iteration of the while loop rounds, we can't optimize this code (relative to the smart contract) // by simply traversing to the next available tick, we instead need to exactly replicate // tickBitmap.nextInitializedTickWithinOneWord step.tickNext, step.initialized, err = p.TickDataProvider.NextInitializedTickWithinFixedDistance( - state.tick, zeroForOne, 480, + state.tick, isToken0, 480, ) if err != nil { return nil, nil, nil, nil, 0, err @@ -302,39 +302,44 @@ func (p *Pool) swap(zeroForOne bool, amountSpecified, sqrtPriceLimitX96 *big.Int step.tickNext = utils.MaxTick } - step.sqrtPriceNextX96, err = utils.GetSqrtRatioAtTick(step.tickNext) + step.nextSqrtP, err = utils.GetSqrtRatioAtTick(step.tickNext) if err != nil { return nil, nil, nil, nil, 0, err } var targetValue *big.Int - if zeroForOne { - if step.sqrtPriceNextX96.Cmp(sqrtPriceLimitX96) < 0 { + if isToken0 { + if step.nextSqrtP.Cmp(sqrtPriceLimitX96) < 0 { targetValue = sqrtPriceLimitX96 } else { - targetValue = step.sqrtPriceNextX96 + targetValue = step.nextSqrtP } } else { - if step.sqrtPriceNextX96.Cmp(sqrtPriceLimitX96) > 0 { + if step.nextSqrtP.Cmp(sqrtPriceLimitX96) > 0 { targetValue = sqrtPriceLimitX96 } else { - targetValue = step.sqrtPriceNextX96 + targetValue = step.nextSqrtP } } - state.sqrtPriceX96, step.amountIn, step.amountOut, step.deltaL, err = utils.ComputeSwapStep( - state.sqrtPriceX96, targetValue, new(big.Int).Add(state.liquidity, state.reinvestLiquidity), - state.amountSpecifiedRemaining, p.Fee, exactInput, zeroForOne, + step.usedAmount, step.returnedAmount, step.deltaL, state.sqrtP, err = utils.ComputeSwapStep( + new(big.Int).Add(state.liquidity, state.reinvestL), + state.sqrtP, + targetValue, + p.Fee, + state.specifiedAmount, + isExactInput, + isToken0, ) if err != nil { return nil, nil, nil, nil, 0, err } - state.amountSpecifiedRemaining = new(big.Int).Sub(state.amountSpecifiedRemaining, step.amountIn) - state.amountCalculated = new(big.Int).Add(state.amountCalculated, step.amountOut) - state.reinvestLiquidity = new(big.Int).Add(state.reinvestLiquidity, step.deltaL) + state.specifiedAmount = new(big.Int).Sub(state.specifiedAmount, step.usedAmount) + state.returnedAmount = new(big.Int).Add(state.returnedAmount, step.returnedAmount) + state.reinvestL = new(big.Int).Add(state.reinvestL, step.deltaL) // TODO - if state.sqrtPriceX96.Cmp(step.sqrtPriceNextX96) == 0 { + if state.sqrtP.Cmp(step.nextSqrtP) == 0 { // if the tick is initialized, run the tick transition if step.initialized { tick, err := p.TickDataProvider.GetTick(step.tickNext) @@ -345,25 +350,25 @@ func (p *Pool) swap(zeroForOne bool, amountSpecified, sqrtPriceLimitX96 *big.Int liquidityNet := tick.LiquidityNet // if we're moving leftward, we interpret liquidityNet as the opposite sign // safe because liquidityNet cannot be type(int128).min - if zeroForOne { + if isToken0 { liquidityNet = new(big.Int).Mul(liquidityNet, constants.NegativeOne) } state.liquidity = utils.AddDelta(state.liquidity, liquidityNet) } - if zeroForOne { + if isToken0 { state.tick = step.tickNext - 1 } else { state.tick = step.tickNext } } else { // recompute unless we're on a lower tick boundary (i.e. already transitioned ticks), and haven't moved - state.tick, err = utils.GetTickAtSqrtRatio(state.sqrtPriceX96) + state.tick, err = utils.GetTickAtSqrtRatio(state.sqrtP) if err != nil { return nil, nil, nil, nil, 0, err } } } - return state.amountCalculated, state.sqrtPriceX96, state.liquidity, state.reinvestLiquidity, state.tick, nil + return state.returnedAmount, state.sqrtP, state.liquidity, state.reinvestL, state.tick, nil } func (p *Pool) tickSpacing() int { diff --git a/utils/swap_math.go b/utils/swap_math.go index 86b5cb0..df47040 100644 --- a/utils/swap_math.go +++ b/utils/swap_math.go @@ -12,49 +12,88 @@ var TwoFeeUnits = new(big.Int).Mul(FeeUnits, big.NewInt(2)) // ComputeSwapStep computes the actual swap input / output amounts to be deducted or added, // the swap fee to be collected and the resulting sqrtP func ComputeSwapStep( - sqrtRatioCurrentX96, sqrtRatioTargetX96, liquidity, amountRemaining *big.Int, feeInUnits constants.FeeAmount, - exactIn, isToken0 bool, -) (sqrtRatioNextX96, amountIn, amountOut, deltaL *big.Int, err error) { + liquidity *big.Int, + currentSqrtP *big.Int, + targetSqrtP *big.Int, + feeInFeeUnits constants.FeeAmount, + specifiedAmount *big.Int, + isExactInput bool, + isToken0 bool, +) ( + usedAmount *big.Int, + returnedAmount *big.Int, + deltaL *big.Int, + nextSqrtP *big.Int, + err error, +) { // in the event currentSqrtP == targetSqrtP because of tick movements, return // e.g. swapped up tick where specified price limit is on an initialised tick // then swapping down tick will cause next tick to be the same as the current tick - if sqrtRatioCurrentX96.Cmp(sqrtRatioTargetX96) == 0 { - return sqrtRatioCurrentX96, constants.Zero, constants.Zero, constants.Zero, nil + if currentSqrtP.Cmp(targetSqrtP) == 0 { + return currentSqrtP, constants.Zero, constants.Zero, constants.Zero, nil } - sqrtRatioNextX96 = big.NewInt(0) + nextSqrtP = big.NewInt(0) - usedAmount := calcReachAmount(sqrtRatioCurrentX96, sqrtRatioTargetX96, liquidity, feeInUnits, exactIn, isToken0) + usedAmount = calcReachAmount( + liquidity, + currentSqrtP, + targetSqrtP, + feeInFeeUnits, + isExactInput, + isToken0, + ) - if exactIn && usedAmount.Cmp(amountRemaining) >= 0 || (!exactIn && usedAmount.Cmp(amountRemaining) <= 0) { - usedAmount = amountRemaining + if isExactInput && usedAmount.Cmp(specifiedAmount) > 0 || (!isExactInput && usedAmount.Cmp(specifiedAmount) <= 0) { + usedAmount = specifiedAmount } else { - sqrtRatioNextX96 = sqrtRatioTargetX96 + nextSqrtP = targetSqrtP } - amountIn = usedAmount - - var absUsedAmount *big.Int - + var absDelta *big.Int if usedAmount.Cmp(constants.Zero) >= 0 { - absUsedAmount = usedAmount + absDelta = usedAmount } else { - absUsedAmount = new(big.Int).Mul(usedAmount, constants.NegativeOne) + absDelta = new(big.Int).Mul(usedAmount, constants.NegativeOne) } - if sqrtRatioNextX96.Cmp(constants.Zero) == 0 { + if nextSqrtP.Cmp(constants.Zero) == 0 { deltaL = estimateIncrementalLiquidity( - absUsedAmount, liquidity, sqrtRatioCurrentX96, feeInUnits, exactIn, isToken0, + absDelta, + liquidity, + currentSqrtP, + feeInFeeUnits, + isExactInput, + isToken0, ) - sqrtRatioNextX96 = calcFinalPrice(absUsedAmount, liquidity, deltaL, sqrtRatioCurrentX96, exactIn, isToken0) + nextSqrtP = calcFinalPrice( + absDelta, + liquidity, + deltaL, + currentSqrtP, + isExactInput, + isToken0, + ) } else { deltaL = calcIncrementalLiquidity( - sqrtRatioCurrentX96, sqrtRatioTargetX96, liquidity, absUsedAmount, exactIn, isToken0, + absDelta, + liquidity, + currentSqrtP, + nextSqrtP, + isExactInput, + isToken0, ) } - amountOut = calcReturnedAmount(sqrtRatioCurrentX96, sqrtRatioNextX96, liquidity, deltaL, exactIn, isToken0) + returnedAmount = calcReturnedAmount( + liquidity, + currentSqrtP, + nextSqrtP, + deltaL, + isExactInput, + isToken0, + ) return } @@ -62,18 +101,22 @@ func ComputeSwapStep( // calcReachAmount calculates the amount needed to reach targetSqrtP from currentSqrtP // we cast currentSqrtP and targetSqrtP to uint256 as they are multiplied by TWO_FEE_UNITS or feeInFeeUnits func calcReachAmount( - sqrtRatioCurrentX96, sqrtRatioTargetX96, liquidity *big.Int, feeInUnits constants.FeeAmount, - exactIn, isToken0 bool, + liquidity *big.Int, + currentSqrtP *big.Int, + targetSqrtP *big.Int, + feeInFeeUnits constants.FeeAmount, + isExactInput bool, + isToken0 bool, ) (reachAmount *big.Int) { var absPriceDiff *big.Int - if sqrtRatioCurrentX96.Cmp(sqrtRatioTargetX96) >= 0 { - absPriceDiff = new(big.Int).Sub(sqrtRatioCurrentX96, sqrtRatioTargetX96) + if currentSqrtP.Cmp(targetSqrtP) >= 0 { + absPriceDiff = new(big.Int).Sub(currentSqrtP, targetSqrtP) } else { - absPriceDiff = new(big.Int).Sub(sqrtRatioTargetX96, sqrtRatioCurrentX96) + absPriceDiff = new(big.Int).Sub(targetSqrtP, currentSqrtP) } - if exactIn { + if isExactInput { // we round down so that we avoid taking giving away too much for the specified input // i.e. require less input qty to move ticks if isToken0 { @@ -82,24 +125,24 @@ func calcReachAmount( // denominator = currentSqrtP * (2 * targetSqrtP - currentSqrtP * feeInFeeUnits / FEE_UNITS) // overflow should not happen because the absPriceDiff is capped to ~5% denominator := new(big.Int).Sub( - new(big.Int).Mul(TwoFeeUnits, sqrtRatioTargetX96), - new(big.Int).Mul(big.NewInt(int64(feeInUnits)), sqrtRatioCurrentX96), + new(big.Int).Mul(TwoFeeUnits, targetSqrtP), + new(big.Int).Mul(big.NewInt(int64(feeInFeeUnits)), currentSqrtP), ) numerator := MulDiv(liquidity, new(big.Int).Mul(TwoFeeUnits, absPriceDiff), denominator) - reachAmount = MulDiv(numerator, constants.Q96, sqrtRatioCurrentX96) + reachAmount = MulDiv(numerator, constants.Q96, currentSqrtP) } else { // exactInput + swap 1 -> 0 // numerator: liquidity * absPriceDiff * (TWO_FEE_UNITS * targetSqrtP - feeInFeeUnits * (targetSqrtP + currentSqrtP)) // denominator: (TWO_FEE_UNITS * targetSqrtP - feeInFeeUnits * currentSqrtP) // overflow should not happen because the absPriceDiff is capped to ~5% denominator := new(big.Int).Sub( - new(big.Int).Mul(TwoFeeUnits, sqrtRatioCurrentX96), - new(big.Int).Mul(big.NewInt(int64(feeInUnits)), sqrtRatioTargetX96), + new(big.Int).Mul(TwoFeeUnits, currentSqrtP), + new(big.Int).Mul(big.NewInt(int64(feeInFeeUnits)), targetSqrtP), ) numerator := MulDiv(liquidity, new(big.Int).Mul(TwoFeeUnits, absPriceDiff), denominator) - reachAmount = MulDiv(numerator, sqrtRatioCurrentX96, constants.Q96) + reachAmount = MulDiv(numerator, currentSqrtP, constants.Q96) } } else { // we will perform negation as the last step @@ -110,15 +153,15 @@ func calcReachAmount( // denominator: (currentSqrtP * targetSqrtP) * (2 * currentSqrtP - deltaL * targetSqrtP) // overflow should not happen because the absPriceDiff is capped to ~5% denominator := new(big.Int).Sub( - new(big.Int).Mul(TwoFeeUnits, sqrtRatioCurrentX96), - new(big.Int).Mul(big.NewInt(int64(feeInUnits)), sqrtRatioTargetX96), + new(big.Int).Mul(TwoFeeUnits, currentSqrtP), + new(big.Int).Mul(big.NewInt(int64(feeInFeeUnits)), targetSqrtP), ) numerator := new(big.Int).Sub( - denominator, new(big.Int).Mul(big.NewInt(int64(feeInUnits)), sqrtRatioCurrentX96), + denominator, new(big.Int).Mul(big.NewInt(int64(feeInFeeUnits)), currentSqrtP), ) numerator = MulDiv(new(big.Int).Lsh(liquidity, 96), numerator, denominator) - reachAmount = new(big.Int).Div(MulDiv(numerator, absPriceDiff, sqrtRatioCurrentX96), sqrtRatioTargetX96) + reachAmount = new(big.Int).Div(MulDiv(numerator, absPriceDiff, currentSqrtP), targetSqrtP) reachAmount = new(big.Int).Mul(reachAmount, constants.NegativeOne) } else { // exactOut + swap 1 -> 0 @@ -126,11 +169,11 @@ func calcReachAmount( // denominator: (TWO_FEE_UNITS * targetSqrtP - feeInFeeUnits * currentSqrtP) // overflow should not happen because the absPriceDiff is capped to ~5% denominator := new(big.Int).Sub( - new(big.Int).Mul(TwoFeeUnits, sqrtRatioTargetX96), - new(big.Int).Mul(big.NewInt(int64(feeInUnits)), sqrtRatioCurrentX96), + new(big.Int).Mul(TwoFeeUnits, targetSqrtP), + new(big.Int).Mul(big.NewInt(int64(feeInFeeUnits)), currentSqrtP), ) numerator := new(big.Int).Sub( - denominator, new(big.Int).Mul(big.NewInt(int64(feeInUnits)), sqrtRatioTargetX96), + denominator, new(big.Int).Mul(big.NewInt(int64(feeInFeeUnits)), targetSqrtP), ) numerator = MulDiv(liquidity, numerator, denominator) @@ -142,54 +185,68 @@ func calcReachAmount( return reachAmount } -// calcReturnedAmount calculates returned output | input tokens in exchange for specified amount -// round down when calculating returned output (isExactInput) so we avoid sending too much -// round up when calculating returned input (!isExactInput) so we get desired output amount -func calcReturnedAmount( - sqrtRatioCurrentX96, sqrtRatioTargetX96, liquidity, deltaL *big.Int, exactIn, isToken0 bool, -) (returnedAmount *big.Int) { - if isToken0 { - if exactIn { - // minimise actual output (<0, make less negative) so we avoid sending too much - // returnedAmount = deltaL * nextSqrtP - liquidity * (currentSqrtP - nextSqrtP) - returnedAmount = new(big.Int).Add( - MulDivRoundingUp(deltaL, sqrtRatioTargetX96, constants.Q96), - new(big.Int).Mul( - MulDiv( - liquidity, new(big.Int).Sub(sqrtRatioCurrentX96, sqrtRatioTargetX96), constants.Q96, - ), constants.NegativeOne, - ), - ) +// estimateIncrementalLiquidity estimates deltaL, the swap fee to be collected based on amount specified +// for the final swap step to be performed, +// where the next (temporary) tick will not be crossed +func estimateIncrementalLiquidity( + absDelta *big.Int, + liquidity *big.Int, + currentSqrtP *big.Int, + feeInFeeUnits constants.FeeAmount, + isExactInput bool, + isToken0 bool, +) (deltaL *big.Int) { + // this is when we didn't reach the target (last step before loop end), then we have to recalculate the target_X96, deltaL ... + fee := big.NewInt(int64(feeInFeeUnits)) + + if isExactInput { + if isToken0 { + // deltaL = feeInFeeUnits * absDelta * currentSqrtP / 2 + deltaL = MulDiv(currentSqrtP, new(big.Int).Mul(absDelta, fee), new(big.Int).Lsh(TwoFeeUnits, 96)) } else { - // maximise actual input (>0) so we get desired output amount - // returnedAmount = deltaL * nextSqrtP + liquidity * (nextSqrtP - currentSqrtP) - returnedAmount = new(big.Int).Add( - MulDivRoundingUp(deltaL, sqrtRatioTargetX96, constants.Q96), - MulDivRoundingUp(liquidity, new(big.Int).Sub(sqrtRatioTargetX96, sqrtRatioCurrentX96), constants.Q96), + // deltaL = feeInFeeUnits * absDelta * / (currentSqrtP * 2) + // Because nextSqrtP = (liquidity + absDelta / currentSqrtP) * currentSqrtP / (liquidity + deltaL) + // so we round down deltaL, to round up nextSqrtP + deltaL = MulDivRoundingDown( + constants.Q96, new(big.Int).Mul(absDelta, fee), new(big.Int).Mul(TwoFeeUnits, currentSqrtP), ) } } else { - // returnedAmount = (liquidity + deltaL)/nextSqrtP - (liquidity)/currentSqrtP - // if exactInput, minimise actual output (<0, make less negative) so we avoid sending too much - // if exactOutput, maximise actual input (>0) so we get desired output amount - returnedAmount = new(big.Int).Add( - MulDivRoundingUp(new(big.Int).Add(liquidity, deltaL), constants.Q96, sqrtRatioTargetX96), - new(big.Int).Mul(MulDivRoundingUp(liquidity, constants.Q96, sqrtRatioCurrentX96), constants.NegativeOne), - ) - } + // obtain the smaller root of the quadratic equation + // ax^2 - 2bx + c = 0 such that b > 0, and x denotes deltaL + a := fee + b := new(big.Int).Mul(new(big.Int).Sub(FeeUnits, fee), liquidity) + c := new(big.Int).Mul(new(big.Int).Mul(fee, liquidity), absDelta) - if exactIn && returnedAmount.Cmp(constants.One) == 0 { - // rounding make returnedAmount == 1 - returnedAmount = constants.Zero + if isToken0 { + // a = feeInFeeUnits + // b = (FEE_UNITS - feeInFeeUnits) * liquidity - FEE_UNITS * absDelta * currentSqrtP + // c = feeInFeeUnits * liquidity * absDelta * currentSqrtP + b = new(big.Int).Sub(b, MulDiv(new(big.Int).Mul(FeeUnits, absDelta), currentSqrtP, constants.Q96)) + c = MulDiv(c, currentSqrtP, constants.Q96) + } else { + // a = feeInFeeUnits + // b = (FEE_UNITS - feeInFeeUnits) * liquidity - FEE_UNITS * absDelta / currentSqrtP + // c = liquidity * feeInFeeUnits * absDelta / currentSqrtP + b = new(big.Int).Sub(b, MulDiv(new(big.Int).Mul(FeeUnits, absDelta), constants.Q96, currentSqrtP)) + c = MulDiv(c, constants.Q96, currentSqrtP) + } + + deltaL = GetSmallerRootOfQuadEqn(a, b, c) } - return returnedAmount + return deltaL } // calcIncrementalLiquidity calculates deltaL, the swap fee to be collected for an intermediate swap step, // where the next (temporary) tick will be crossed func calcIncrementalLiquidity( - sqrtRatioCurrentX96, sqrtRatioTargetX96, liquidity, absAmount *big.Int, exactIn, isToken0 bool, + absDelta *big.Int, + liquidity *big.Int, + currentSqrtP *big.Int, + nextSqrtP *big.Int, + isExactInput bool, + isToken0 bool, ) (deltaL *big.Int) { var tmp1, tmp2, tmp3 *big.Int @@ -197,13 +254,13 @@ func calcIncrementalLiquidity( if isToken0 { // deltaL = nextSqrtP * (liquidity / currentSqrtP +/- absDelta)) - liquidity // needs to be minimum - tmp1 = MulDiv(liquidity, constants.Q96, sqrtRatioCurrentX96) - if exactIn { - tmp2 = new(big.Int).Add(tmp1, absAmount) + tmp1 = MulDiv(liquidity, constants.Q96, currentSqrtP) + if isExactInput { + tmp2 = new(big.Int).Add(tmp1, absDelta) } else { - tmp2 = new(big.Int).Sub(tmp1, absAmount) + tmp2 = new(big.Int).Sub(tmp1, absDelta) } - tmp3 = MulDiv(sqrtRatioTargetX96, tmp2, constants.Q96) + tmp3 = MulDiv(nextSqrtP, tmp2, constants.Q96) // in edge cases where liquidity or absDelta is small // liquidity might be greater than nextSqrtP * ((liquidity / currentSqrtP) +/- absDelta)) @@ -216,13 +273,13 @@ func calcIncrementalLiquidity( } else { // deltaL = (liquidity * currentSqrtP +/- absDelta) / nextSqrtP - liquidity // needs to be minimum - tmp1 = MulDiv(liquidity, sqrtRatioCurrentX96, constants.Q96) - if exactIn { - tmp2 = new(big.Int).Add(tmp1, absAmount) + tmp1 = MulDiv(liquidity, currentSqrtP, constants.Q96) + if isExactInput { + tmp2 = new(big.Int).Add(tmp1, absDelta) } else { - tmp2 = new(big.Int).Sub(tmp1, absAmount) + tmp2 = new(big.Int).Sub(tmp1, absDelta) } - tmp3 = MulDiv(tmp2, constants.Q96, sqrtRatioTargetX96) + tmp3 = MulDiv(tmp2, constants.Q96, nextSqrtP) // in edge cases where liquidity or absDelta is small // liquidity might be greater than nextSqrtP * ((liquidity / currentSqrtP) +/- absDelta)) @@ -237,98 +294,104 @@ func calcIncrementalLiquidity( return deltaL } -// estimateIncrementalLiquidity estimates deltaL, the swap fee to be collected based on amount specified -// for the final swap step to be performed, -// where the next (temporary) tick will not be crossed -func estimateIncrementalLiquidity( - absAmount, liquidity, sqrtRatioCurrentX96 *big.Int, feeInUnits constants.FeeAmount, exactIn, isToken0 bool, -) (deltaL *big.Int) { - // this is when we didn't reach the target (last step before loop end), then we have to recalculate the target_X96, deltaL ... - fee := big.NewInt(int64(feeInUnits)) +// calcFinalPrice calculates returned output | input tokens in exchange for specified amount +// round down when calculating returned output (isExactInput) so we avoid sending too much +// round up when calculating returned input (!isExactInput) so we get desired output amount +func calcFinalPrice( + absDelta *big.Int, + liquidity *big.Int, + deltaL *big.Int, + currentSqrtP *big.Int, + isExactInput bool, + isToken0 bool, +) *big.Int { + finalPrice := constants.Zero - if exactIn { - if isToken0 { - // deltaL = feeInFeeUnits * absDelta * currentSqrtP / 2 - deltaL = MulDiv(sqrtRatioCurrentX96, new(big.Int).Mul(absAmount, fee), new(big.Int).Lsh(TwoFeeUnits, 96)) + if isToken0 { + tmp := MulDiv(absDelta, currentSqrtP, constants.Q96) + + if isExactInput { + // minimise actual output (<0, make less negative) so we avoid sending too much + // returnedAmount = deltaL * nextSqrtP - liquidity * (currentSqrtP - nextSqrtP) + finalPrice = MulDivRoundingUp( + new(big.Int).Add(liquidity, deltaL), currentSqrtP, new(big.Int).Add(liquidity, tmp), + ) } else { - // deltaL = feeInFeeUnits * absDelta * / (currentSqrtP * 2) - // Because nextSqrtP = (liquidity + absDelta / currentSqrtP) * currentSqrtP / (liquidity + deltaL) - // so we round down deltaL, to round up nextSqrtP - deltaL = MulDivRoundingDown( - constants.Q96, new(big.Int).Mul(absAmount, fee), new(big.Int).Mul(TwoFeeUnits, sqrtRatioCurrentX96), + // maximise actual input (>0) so we get desired output amount + // returnedAmount = deltaL * nextSqrtP + liquidity * (nextSqrtP - currentSqrtP) + finalPrice = MulDiv( + new(big.Int).Add(liquidity, deltaL), currentSqrtP, new(big.Int).Sub(liquidity, tmp), ) } } else { - // obtain the smaller root of the quadratic equation - // ax^2 - 2bx + c = 0 such that b > 0, and x denotes deltaL - a := fee - b := new(big.Int).Mul(new(big.Int).Sub(FeeUnits, fee), liquidity) - c := new(big.Int).Mul(new(big.Int).Mul(fee, liquidity), absAmount) + // returnedAmount = (liquidity + deltaL)/nextSqrtP - (liquidity)/currentSqrtP + // if exactInput, minimise actual output (<0, make less negative) so we avoid sending too much + // if exactOutput, maximise actual input (>0) so we get desired output amount + tmp := MulDiv(absDelta, constants.Q96, currentSqrtP) - if isToken0 { - // a = feeInFeeUnits - // b = (FEE_UNITS - feeInFeeUnits) * liquidity - FEE_UNITS * absDelta * currentSqrtP - // c = feeInFeeUnits * liquidity * absDelta * currentSqrtP - b = new(big.Int).Sub(b, MulDiv(new(big.Int).Mul(FeeUnits, absAmount), sqrtRatioCurrentX96, constants.Q96)) - c = MulDiv(c, sqrtRatioCurrentX96, constants.Q96) + if isExactInput { + finalPrice = MulDiv( + new(big.Int).Add(liquidity, tmp), currentSqrtP, new(big.Int).Add(liquidity, deltaL), + ) } else { - // a = feeInFeeUnits - // b = (FEE_UNITS - feeInFeeUnits) * liquidity - FEE_UNITS * absDelta / currentSqrtP - // c = liquidity * feeInFeeUnits * absDelta / currentSqrtP - b = new(big.Int).Sub(b, MulDiv(new(big.Int).Mul(FeeUnits, absAmount), constants.Q96, sqrtRatioCurrentX96)) - c = MulDiv(c, constants.Q96, sqrtRatioCurrentX96) + finalPrice = MulDivRoundingUp( + new(big.Int).Sub(liquidity, tmp), currentSqrtP, new(big.Int).Add(liquidity, deltaL), + ) } + } - deltaL = GetSmallerRootOfQuadEqn(a, b, c) + if isExactInput && finalPrice.Cmp(constants.One) == 0 { + finalPrice = constants.Zero } - return deltaL + return finalPrice } -// calcFinalPrice calculates returned output | input tokens in exchange for specified amount +// calcReturnedAmount calculates returned output | input tokens in exchange for specified amount // round down when calculating returned output (isExactInput) so we avoid sending too much // round up when calculating returned input (!isExactInput) so we get desired output amount -func calcFinalPrice( - absAmount, liquidity, deltaL, sqrtRatioCurrentX96 *big.Int, exactIn, isToken0 bool, -) (returnAmount *big.Int) { - returnAmount = constants.Zero - +func calcReturnedAmount( + liquidity *big.Int, + currentSqrtP *big.Int, + nextSqrtP *big.Int, + deltaL *big.Int, + isExactInput bool, + isToken0 bool, +) (returnedAmount *big.Int) { if isToken0 { - tmp := MulDiv(absAmount, sqrtRatioCurrentX96, constants.Q96) - - if exactIn { + if isExactInput { // minimise actual output (<0, make less negative) so we avoid sending too much // returnedAmount = deltaL * nextSqrtP - liquidity * (currentSqrtP - nextSqrtP) - returnAmount = MulDivRoundingUp( - new(big.Int).Add(liquidity, deltaL), sqrtRatioCurrentX96, new(big.Int).Add(liquidity, tmp), + returnedAmount = new(big.Int).Add( + MulDivRoundingUp(deltaL, nextSqrtP, constants.Q96), + new(big.Int).Mul( + MulDiv( + liquidity, new(big.Int).Sub(currentSqrtP, nextSqrtP), constants.Q96, + ), constants.NegativeOne, + ), ) } else { // maximise actual input (>0) so we get desired output amount // returnedAmount = deltaL * nextSqrtP + liquidity * (nextSqrtP - currentSqrtP) - returnAmount = MulDiv( - new(big.Int).Add(liquidity, deltaL), sqrtRatioCurrentX96, new(big.Int).Sub(liquidity, tmp), + returnedAmount = new(big.Int).Add( + MulDivRoundingUp(deltaL, nextSqrtP, constants.Q96), + MulDivRoundingUp(liquidity, new(big.Int).Sub(nextSqrtP, currentSqrtP), constants.Q96), ) } } else { // returnedAmount = (liquidity + deltaL)/nextSqrtP - (liquidity)/currentSqrtP // if exactInput, minimise actual output (<0, make less negative) so we avoid sending too much // if exactOutput, maximise actual input (>0) so we get desired output amount - tmp := MulDiv(absAmount, constants.Q96, sqrtRatioCurrentX96) - - if exactIn { - returnAmount = MulDiv( - new(big.Int).Add(liquidity, tmp), sqrtRatioCurrentX96, new(big.Int).Add(liquidity, deltaL), - ) - } else { - returnAmount = MulDivRoundingUp( - new(big.Int).Sub(liquidity, tmp), sqrtRatioCurrentX96, new(big.Int).Add(liquidity, deltaL), - ) - } + returnedAmount = new(big.Int).Add( + MulDivRoundingUp(new(big.Int).Add(liquidity, deltaL), constants.Q96, nextSqrtP), + new(big.Int).Mul(MulDivRoundingUp(liquidity, constants.Q96, currentSqrtP), constants.NegativeOne), + ) } - if exactIn && returnAmount.Cmp(constants.One) == 0 { - returnAmount = constants.Zero + if isExactInput && returnedAmount.Cmp(constants.One) == 0 { + // rounding make returnedAmount == 1 + returnedAmount = constants.Zero } - return returnAmount + return returnedAmount }