Skip to content

Commit

Permalink
optimize logic and reduce alloc (#2)
Browse files Browse the repository at this point in the history
* improve GetOutputAmount, MulDiv, MulDivRoundingUp

* reduce alloc

* add comment

* add RemainingAmountIn
  • Loading branch information
it4rb authored Mar 25, 2024
1 parent 9e9d779 commit 4caa4e0
Show file tree
Hide file tree
Showing 11 changed files with 423 additions and 139 deletions.
69 changes: 49 additions & 20 deletions entities/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ var (
)

type StepComputations struct {
sqrtPriceStartX96 *utils.Uint160
sqrtPriceStartX96 utils.Uint160
tickNext int
initialized bool
sqrtPriceNextX96 *utils.Uint160
amountIn *utils.Uint256
amountOut *utils.Uint256
feeAmount *utils.Uint256
sqrtPriceNextX96 utils.Uint160
amountIn utils.Uint256
amountOut utils.Uint256
feeAmount utils.Uint256
}

// Represents a V3 pool
Expand Down Expand Up @@ -60,6 +60,15 @@ type GetAmountResult struct {
CrossInitTickLoops int
}

type GetAmountResultV2 struct {
ReturnedAmount *utils.Int256
RemainingAmountIn *utils.Int256
SqrtRatioX96 *utils.Uint160
Liquidity *utils.Uint128
CurrentTick int
CrossInitTickLoops int
}

func GetAddress(tokenA, tokenB *entities.Token, fee constants.FeeAmount, initCodeHashManualOverride string) (common.Address, error) {
return utils.ComputePoolAddress(constants.FactoryAddress, tokenA, tokenB, fee, initCodeHashManualOverride)
}
Expand Down Expand Up @@ -90,16 +99,17 @@ func NewPoolV2(tokenA, tokenB *entities.Token, fee constants.FeeAmount, sqrtRati
return nil, ErrFeeTooHigh
}

tickCurrentSqrtRatioX96, err := utils.GetSqrtRatioAtTickV2(tickCurrent)
var tickCurrentSqrtRatioX96, nextTickSqrtRatioX96 utils.Uint160
err := utils.GetSqrtRatioAtTickV2(tickCurrent, &tickCurrentSqrtRatioX96)
if err != nil {
return nil, err
}
nextTickSqrtRatioX96, err := utils.GetSqrtRatioAtTickV2(tickCurrent + 1)
err = utils.GetSqrtRatioAtTickV2(tickCurrent+1, &nextTickSqrtRatioX96)
if err != nil {
return nil, err
}

if sqrtRatioX96.Cmp(tickCurrentSqrtRatioX96) < 0 || sqrtRatioX96.Cmp(nextTickSqrtRatioX96) > 0 {
if sqrtRatioX96.Cmp(&tickCurrentSqrtRatioX96) < 0 || sqrtRatioX96.Cmp(&nextTickSqrtRatioX96) > 0 {
return nil, ErrInvalidSqrtRatioX96
}
token0 := tokenA
Expand Down Expand Up @@ -216,6 +226,21 @@ func (p *Pool) GetOutputAmount(inputAmount *entities.CurrencyAmount, sqrtPriceLi
}, nil
}

func (p *Pool) GetOutputAmountV2(inputAmount *utils.Int256, zeroForOne bool, sqrtPriceLimitX96 *utils.Uint160) (*GetAmountResultV2, error) {
swapResult, err := p.swap(zeroForOne, inputAmount, sqrtPriceLimitX96)
if err != nil {
return nil, err
}
return &GetAmountResultV2{
ReturnedAmount: new(utils.Int256).Neg(swapResult.amountCalculated),
RemainingAmountIn: new(utils.Int256).Set(swapResult.remainingAmountIn),
SqrtRatioX96: swapResult.sqrtRatioX96,
Liquidity: swapResult.liquidity,
CurrentTick: swapResult.currentTick,
CrossInitTickLoops: swapResult.crossInitTickLoops,
}, nil
}

/**
* Given a desired output amount of a token, return the computed input amount and a pool with state updated after the trade
* @param outputAmount the output amount for which to quote the input amount
Expand Down Expand Up @@ -318,7 +343,7 @@ func (p *Pool) swap(zeroForOne bool, amountSpecified *utils.Int256, sqrtPriceLim
// start swap while loop
for !state.amountSpecifiedRemaining.IsZero() && state.sqrtPriceX96.Cmp(sqrtPriceLimitX96) != 0 {
var step StepComputations
step.sqrtPriceStartX96 = state.sqrtPriceX96
step.sqrtPriceStartX96.Set(state.sqrtPriceX96)

// because each iteration of the while loop rounds, we can't optimize this code (relative to the smart contract)
// by simply traversing to the next available tick, we instead need to exactly replicate
Expand All @@ -334,32 +359,35 @@ func (p *Pool) swap(zeroForOne bool, amountSpecified *utils.Int256, sqrtPriceLim
step.tickNext = utils.MaxTick
}

step.sqrtPriceNextX96, err = utils.GetSqrtRatioAtTickV2(step.tickNext)
err = utils.GetSqrtRatioAtTickV2(step.tickNext, &step.sqrtPriceNextX96)
if err != nil {
return nil, err
}
var targetValue *utils.Uint160
var targetValue utils.Uint160
if zeroForOne {
if step.sqrtPriceNextX96.Cmp(sqrtPriceLimitX96) < 0 {
targetValue = sqrtPriceLimitX96
targetValue.Set(sqrtPriceLimitX96)
} else {
targetValue = step.sqrtPriceNextX96
targetValue.Set(&step.sqrtPriceNextX96)
}
} else {
if step.sqrtPriceNextX96.Cmp(sqrtPriceLimitX96) > 0 {
targetValue = sqrtPriceLimitX96
targetValue.Set(sqrtPriceLimitX96)
} else {
targetValue = step.sqrtPriceNextX96
targetValue.Set(&step.sqrtPriceNextX96)
}
}

state.sqrtPriceX96, step.amountIn, step.amountOut, step.feeAmount, err = utils.ComputeSwapStep(state.sqrtPriceX96, targetValue, state.liquidity, state.amountSpecifiedRemaining, p.Fee)
var nxtSqrtPriceX96 utils.Uint160
err = utils.ComputeSwapStep(state.sqrtPriceX96, &targetValue, state.liquidity, state.amountSpecifiedRemaining, p.Fee,
&nxtSqrtPriceX96, &step.amountIn, &step.amountOut, &step.feeAmount)
if err != nil {
return nil, err
}
state.sqrtPriceX96.Set(&nxtSqrtPriceX96)

var amountInPlusFee utils.Uint256
amountInPlusFee.Add(step.amountIn, step.feeAmount)
amountInPlusFee.Add(&step.amountIn, &step.feeAmount)

var amountInPlusFeeSigned utils.Int256
err = utils.ToInt256(&amountInPlusFee, &amountInPlusFeeSigned)
Expand All @@ -368,7 +396,7 @@ func (p *Pool) swap(zeroForOne bool, amountSpecified *utils.Int256, sqrtPriceLim
}

var amountOutSigned utils.Int256
err = utils.ToInt256(step.amountOut, &amountOutSigned)
err = utils.ToInt256(&step.amountOut, &amountOutSigned)
if err != nil {
return nil, err
}
Expand All @@ -382,7 +410,7 @@ func (p *Pool) swap(zeroForOne bool, amountSpecified *utils.Int256, sqrtPriceLim
}

// TODO
if state.sqrtPriceX96.Cmp(step.sqrtPriceNextX96) == 0 {
if state.sqrtPriceX96.Cmp(&step.sqrtPriceNextX96) == 0 {
// if the tick is initialized, run the tick transition
if step.initialized {
tick, err := p.TickDataProvider.GetTick(step.tickNext)
Expand All @@ -406,14 +434,15 @@ func (p *Pool) swap(zeroForOne bool, amountSpecified *utils.Int256, sqrtPriceLim
state.tick = step.tickNext
}

} else if state.sqrtPriceX96.Cmp(step.sqrtPriceStartX96) != 0 {
} else if state.sqrtPriceX96.Cmp(&step.sqrtPriceStartX96) != 0 {
// recompute unless we're on a lower tick boundary (i.e. already transitioned ticks), and haven't moved
state.tick, err = utils.GetTickAtSqrtRatioV2(state.sqrtPriceX96)
if err != nil {
return nil, err
}
}
}

return &SwapResult{
amountCalculated: state.amountCalculated,
sqrtRatioX96: state.sqrtPriceX96,
Expand Down
131 changes: 127 additions & 4 deletions utils/full_math.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,130 @@ func MulDivRoundingUp(a, b, denominator *uint256.Int) (*uint256.Int, error) {
return resultU, nil
}

func MulDivRoundingUpV2(a, b, denominator, result *uint256.Int) error {
var remainder Uint256
err := MulDivV2(a, b, denominator, result, &remainder)
if err != nil {
return err
}

if !remainder.IsZero() {
if result.Cmp(MaxUint256) == 0 {
return ErrInvariant
}
result.AddUint64(result, 1)
}
return nil
}

// result=floor(a×b÷denominator), remainder=a×b%denominator
// (pass remainder=nil if not required)
// (the main usage for `remainder` is to be used in `MulDivRoundingUpV2` to determine if we need to round up, so it won't have to call MulMod again)
func MulDivV2(a, b, denominator, result, remainder *uint256.Int) error {
// https://github.com/Uniswap/v3-core/blob/main/contracts/libraries/FullMath.sol
// 512-bit multiply [prod1 prod0] = a * b
// Compute the product mod 2**256 and mod 2**256 - 1
// then use the Chinese Remainder Theorem to reconstruct
// the 512 bit result. The result is stored in two 256
// variables such that product = prod1 * 2**256 + prod0
var prod0 Uint256 // Least significant 256 bits of the product
var prod1 Uint256 // Most significant 256 bits of the product

var denominatorTmp Uint256 // temp var (need to modify denominator along the way)
denominatorTmp.Set(denominator)

var mm Uint256
mm.MulMod(a, b, MaxUint256)
prod0.Mul(a, b)
prod1.Sub(&mm, &prod0)
if mm.Cmp(&prod0) < 0 {
prod1.SubUint64(&prod1, 1)
}

// Handle non-overflow cases, 256 by 256 division
if prod1.IsZero() {
if denominatorTmp.IsZero() {
return ErrInvariant
}

if remainder != nil {
// if the caller request then calculate remainder
remainder.MulMod(a, b, &denominatorTmp)
}
result.Div(&prod0, &denominatorTmp)
return nil
}

// Make sure the result is less than 2**256.
// Also prevents denominator == 0
if denominatorTmp.Cmp(&prod1) <= 0 {
return ErrInvariant
}

///////////////////////////////////////////////
// 512 by 256 division.
///////////////////////////////////////////////

// Make division exact by subtracting the remainder from [prod1 prod0]
// Compute remainder using mulmod
if remainder == nil {
// the caller doesn't request but we need it so use a temporary variable here
var remainderTmp Uint256
remainder = &remainderTmp
}
remainder.MulMod(a, b, &denominatorTmp)
// Subtract 256 bit number from 512 bit number
if remainder.Cmp(&prod0) > 0 {
prod1.SubUint64(&prod1, 1)
}
prod0.Sub(&prod0, remainder)

// Factor powers of two out of denominator
// Compute largest power of two divisor of denominator.
// Always >= 1.
var twos, tmp, tmp1, zero, two, three Uint256
twos.And(tmp.Neg(&denominatorTmp), &denominatorTmp)
// Divide denominator by power of two
denominatorTmp.Div(&denominatorTmp, &twos)

// Divide [prod1 prod0] by the factors of two
prod0.Div(&prod0, &twos)
// Shift in bits from prod1 into prod0. For this we need
// to flip `twos` such that it is 2**256 / twos.
// If twos is zero, then it becomes one
zero.Clear()
twos.AddUint64(tmp.Div(tmp1.Sub(&zero, &twos), &twos), 1)
prod0.Or(&prod0, tmp.Mul(&prod1, &twos))

// Invert denominator mod 2**256
// Now that denominator is an odd number, it has an inverse
// modulo 2**256 such that denominator * inv = 1 mod 2**256.
// Compute the inverse by starting with a seed that is correct
// correct for four bits. That is, denominator * inv = 1 mod 2**4
var inv Uint256
two.SetUint64(2)
three.SetUint64(3)
inv.Xor(tmp.Mul(&denominatorTmp, &three), &two)
// Now use Newton-Raphson iteration to improve the precision.
// Thanks to Hensel's lifting lemma, this also works in modular
// arithmetic, doubling the correct bits in each step.
inv.Mul(&inv, tmp.Sub(&two, tmp1.Mul(&denominatorTmp, &inv))) // inverse mod 2**8
inv.Mul(&inv, tmp.Sub(&two, tmp1.Mul(&denominatorTmp, &inv))) // inverse mod 2**16
inv.Mul(&inv, tmp.Sub(&two, tmp1.Mul(&denominatorTmp, &inv))) // inverse mod 2**32
inv.Mul(&inv, tmp.Sub(&two, tmp1.Mul(&denominatorTmp, &inv))) // inverse mod 2**64
inv.Mul(&inv, tmp.Sub(&two, tmp1.Mul(&denominatorTmp, &inv))) // inverse mod 2**128
inv.Mul(&inv, tmp.Sub(&two, tmp1.Mul(&denominatorTmp, &inv))) // inverse mod 2**256

// Because the division is now exact we can divide by multiplying
// with the modular inverse of denominator. This will give us the
// correct result modulo 2**256. Since the precoditions guarantee
// that the outcome is less than 2**256, this is the final result.
// We don't need to compute the high bits of the result and prod1
// is no longer required.
result.Mul(&prod0, &inv)
return nil
}

// Calculates floor(a×b÷denominator) with full precision
func MulDiv(a, b, denominator *uint256.Int) (*uint256.Int, error) {
// the product can overflow so need to use big.Int here
Expand All @@ -46,11 +170,10 @@ func MulDiv(a, b, denominator *uint256.Int) (*uint256.Int, error) {
}

// Returns ceil(x / y)
func DivRoundingUp(a, denominator *uint256.Int) *uint256.Int {
var result, rem uint256.Int
func DivRoundingUp(a, denominator, result *uint256.Int) {
var rem uint256.Int
result.DivMod(a, denominator, &rem)
if !rem.IsZero() {
result.AddUint64(&result, 1)
result.AddUint64(result, 1)
}
return &result
}
Loading

0 comments on commit 4caa4e0

Please sign in to comment.