From 8711004b1f9157dec2fb6e270f9f8e960f807320 Mon Sep 17 00:00:00 2001 From: Hoang Trinh Date: Mon, 10 Jul 2023 17:00:57 +0700 Subject: [PATCH] initialize project --- .github/workflows/test.yml | 45 ++++ .gitignore | 36 +++ LICENSE | 21 ++ README.md | 74 +++++ constants/constants.go | 47 ++++ entities/nearestusabletick.go | 38 +++ entities/nearestusabletick_test.go | 39 +++ entities/pool.go | 341 ++++++++++++++++++++++++ entities/pool_test.go | 182 +++++++++++++ entities/tickdataprovider.go | 29 ++ entities/ticklist.go | 253 ++++++++++++++++++ entities/ticklist_test.go | 144 ++++++++++ entities/ticklistdataprovider.go | 25 ++ go.mod | 21 ++ go.sum | 40 +++ utils/calldata.go | 27 ++ utils/compute_pool_address.go | 63 +++++ utils/compute_pool_address_test.go | 33 +++ utils/encode.go | 16 ++ utils/encode_test.go | 25 ++ utils/full_math.go | 16 ++ utils/liquidity_math.go | 15 ++ utils/max_liquidity_for_amounts.go | 93 +++++++ utils/max_liquidity_for_amounts_test.go | 252 +++++++++++++++++ utils/most_significant_bit.go | 29 ++ utils/price_tick_conversions.go | 68 +++++ utils/price_tick_conversions_test.go | 111 ++++++++ utils/sqrtprice_math.go | 125 +++++++++ utils/swap_math.go | 76 ++++++ utils/tick_math.go | 197 ++++++++++++++ utils/tick_math_test.go | 34 +++ 31 files changed, 2515 insertions(+) create mode 100644 .github/workflows/test.yml create mode 100644 .gitignore create mode 100644 LICENSE create mode 100644 README.md create mode 100644 constants/constants.go create mode 100644 entities/nearestusabletick.go create mode 100644 entities/nearestusabletick_test.go create mode 100644 entities/pool.go create mode 100644 entities/pool_test.go create mode 100644 entities/tickdataprovider.go create mode 100644 entities/ticklist.go create mode 100644 entities/ticklist_test.go create mode 100644 entities/ticklistdataprovider.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 utils/calldata.go create mode 100644 utils/compute_pool_address.go create mode 100644 utils/compute_pool_address_test.go create mode 100644 utils/encode.go create mode 100644 utils/encode_test.go create mode 100644 utils/full_math.go create mode 100644 utils/liquidity_math.go create mode 100644 utils/max_liquidity_for_amounts.go create mode 100644 utils/max_liquidity_for_amounts_test.go create mode 100644 utils/most_significant_bit.go create mode 100644 utils/price_tick_conversions.go create mode 100644 utils/price_tick_conversions_test.go create mode 100644 utils/sqrtprice_math.go create mode 100644 utils/swap_math.go create mode 100644 utils/tick_math.go create mode 100644 utils/tick_math_test.go diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..76dac0e --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,45 @@ +on: [push, pull_request] +name: Test +jobs: + test: + strategy: + matrix: + go-version: [1.20.x] + os: [ubuntu-latest, macos-latest, windows-latest] + runs-on: ${{ matrix.os }} + steps: + - name: Install Go + uses: actions/setup-go@v2 + with: + go-version: ${{ matrix.go-version }} + - name: Checkout code + uses: actions/checkout@v2 + - name: Test + run: go test ./... + + test-cache: + runs-on: ubuntu-latest + steps: + - name: Install Go + uses: actions/setup-go@v2 + with: + go-version: 1.20.x + - name: Checkout code + uses: actions/checkout@v2 + - uses: actions/cache@v2 + with: + # In order: + # * Module download cache + # * Build cache (Linux) + # * Build cache (Mac) + # * Build cache (Windows) + path: | + ~/go/pkg/mod + ~/.cache/go-build + ~/Library/Caches/go-build + %LocalAppData%\go-build + key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go- + - name: Test + run: go test ./... diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..94e4e41 --- /dev/null +++ b/.gitignore @@ -0,0 +1,36 @@ +# OS specific +.DS_Store + +# Dependency directories +vendor/ + +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binary, built with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# Go workspace file +go.work + +# IDE configuration files +uniswapv3-sdk/.idea/ +.vscode/launch.json + +# Test result +test/logs +test/result + +# Environment Variables +.env + +# Miscellaneous +command +app diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..ef04430 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 daoleno + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..5318ee0 --- /dev/null +++ b/README.md @@ -0,0 +1,74 @@ +# Pancake V3 SDK + +[![API Reference](https://camo.githubusercontent.com/915b7be44ada53c290eb157634330494ebe3e30a/68747470733a2f2f676f646f632e6f72672f6769746875622e636f6d2f676f6c616e672f6764646f3f7374617475732e737667)](https://pkg.go.dev/github.com/KyberNetwork/pancake-v3-sdk) +[![Test](https://github.com/KyberNetwork/pancake-v3-sdk/actions/workflows/test.yml/badge.svg)](https://github.com/KyberNetwork/pancake-v3-sdk/actions/workflows/test.yml) +[![Go Report Card](https://goreportcard.com/badge/github.com/KyberNetwork/pancake-v3-sdk)](https://goreportcard.com/report/github.com/KyberNetwork/pancake-v3-sdk) + +🛠 A Go SDK for building applications on top of Pancake V3 + +## Installation + +```sh +go get github.com/KyberNetwork/pancake-v3-sdk +``` + +## Usage + +The following example shows how to create a pool, and get the inputAmount + +```go +package main + +import ( + "fmt" + "math/big" + + core "github.com/daoleno/uniswap-sdk-core/entities" + "github.com/KyberNetwork/pancake-v3-sdk/constants" + "github.com/KyberNetwork/pancake-v3-sdk/entities" + "github.com/KyberNetwork/pancake-v3-sdk/utils" + "github.com/ethereum/go-ethereum/common" +) + +var ( + USDC = core.NewToken(1, common.HexToAddress("0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48"), 6, "USDC", "USD Coin") + DAI = core.NewToken(1, common.HexToAddress("0x6B175474E89094C44Da98b954EedeAC495271d0F"), 18, "DAI", "Dai Stablecoin") + OneEther = big.NewInt(1e18) +) + +func main() { + // create demo ticks + ticks := []entities.Tick{ + { + Index: entities.NearestUsableTick(utils.MinTick, constants.TickSpacings[constants.FeeLow]), + LiquidityNet: OneEther, + LiquidityGross: OneEther, + }, + { + Index: entities.NearestUsableTick(utils.MaxTick, constants.TickSpacings[constants.FeeLow]), + LiquidityNet: new(big.Int).Mul(OneEther, constants.NegativeOne), + LiquidityGross: OneEther, + }, + } + + // create tick data provider + p, err := entities.NewTickListDataProvider(ticks, constants.TickSpacings[constants.FeeLow]) + if err != nil { + panic(err) + } + + // new pool + pool, err := entities.NewPool(USDC, DAI, constants.FeeLow, utils.EncodeSqrtRatioX96(constants.One, constants.One), OneEther, 0, p) + if err != nil { + panic(err) + } + + // USDC -> DAI + outputAmount := core.FromRawAmount(DAI, big.NewInt(98)) + inputAmount, _, err := pool.GetInputAmount(outputAmount, nil) + if err != nil { + panic(err) + } + fmt.Println(inputAmount.ToSignificant(4)) +} +``` diff --git a/constants/constants.go b/constants/constants.go new file mode 100644 index 0000000..c19bd6d --- /dev/null +++ b/constants/constants.go @@ -0,0 +1,47 @@ +package constants + +import ( + "math/big" + + "github.com/daoleno/uniswap-sdk-core/entities" + "github.com/ethereum/go-ethereum/common" +) + +const PoolInitCodeHash = "0xe34f199b19b2b4f47f68442619d555527d244f78a3297ea89325f843f87b8b54" + +var ( + FactoryAddress = common.HexToAddress("0x1F98431c8aD98523631AE4a59f267346ea31F984") + AddressZero = common.HexToAddress("0x0000000000000000000000000000000000000000") +) + +// The default factory enabled fee amounts, denominated in hundredths of bips. +type FeeAmount uint64 + +const ( + FeeLowest FeeAmount = 100 + FeeLow FeeAmount = 500 + FeeMedium FeeAmount = 2500 + FeeHigh FeeAmount = 10000 + + FeeMax FeeAmount = 1000000 +) + +// The default factory tick spacings by fee amount. +var TickSpacings = map[FeeAmount]int{ + FeeLowest: 1, + FeeLow: 10, + FeeMedium: 50, + FeeHigh: 200, +} + +var ( + NegativeOne = big.NewInt(-1) + Zero = big.NewInt(0) + One = big.NewInt(1) + + // used in liquidity amount math + Q96 = new(big.Int).Exp(big.NewInt(2), big.NewInt(96), nil) + Q192 = new(big.Int).Exp(Q96, big.NewInt(2), nil) + + PercentZero = entities.NewFraction(big.NewInt(0), big.NewInt(1)) +) diff --git a/entities/nearestusabletick.go b/entities/nearestusabletick.go new file mode 100644 index 0000000..39b8812 --- /dev/null +++ b/entities/nearestusabletick.go @@ -0,0 +1,38 @@ +package entities + +import ( + "math" + + "github.com/KyberNetwork/pancake-v3-sdk/utils" +) + +/** + * Returns the closest tick that is nearest a given tick and usable for the given tick spacing + * @param tick the target tick + * @param tickSpacing the spacing of the pool + */ +func NearestUsableTick(tick int, tickSpacing int) int { + if tickSpacing <= 0 { + panic("tickSpacing must be greater than 0") + } + if !(tick >= utils.MinTick && tick <= utils.MaxTick) { + panic("tick exceeds bounds") + } + + rounded := Round(float64(tick)/float64(tickSpacing)) * float64(tickSpacing) + if rounded < utils.MinTick { + return int(rounded) + tickSpacing + } + if rounded > utils.MaxTick { + return int(rounded) - tickSpacing + } + return int(rounded) +} + +// Round like javascript Math.round +// Note that this differs from many languages' round() functions, which often round this case to the next integer away from zero, instead giving a different result in the case of negative numbers with a fractional part of exactly 0.5. +// For example, -1.5 rounds to -2, but -1.5 rounds to -1. +// https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Math/round#description +func Round(x float64) float64 { + return math.Floor(x + 0.5) +} diff --git a/entities/nearestusabletick_test.go b/entities/nearestusabletick_test.go new file mode 100644 index 0000000..7dafdad --- /dev/null +++ b/entities/nearestusabletick_test.go @@ -0,0 +1,39 @@ +package entities + +import ( + "testing" + + "github.com/KyberNetwork/pancake-v3-sdk/utils" + "github.com/stretchr/testify/assert" +) + +func TestNearestUsableTick(t *testing.T) { + assert.Panics(t, func() { NearestUsableTick(1, 0) }, "panics if tickSpacing is 0") + assert.Panics(t, func() { NearestUsableTick(1, -5) }, "panics if tickSpacing is negative") + assert.Panics(t, func() { NearestUsableTick(utils.MaxTick+1, 1) }, "panics if tick is greater than MaxTick") + assert.Panics(t, func() { NearestUsableTick(utils.MinTick-1, 1) }, "panics if tick is smaller than MinTick") + + type args struct { + ticks int + tickSpacing int + } + tests := []struct { + name string + args args + want int + }{ + {name: "rounds at positive half", args: args{ticks: 5, tickSpacing: 10}, want: 10}, + {name: "rounds down below positive half", args: args{ticks: 4, tickSpacing: 10}, want: 0}, + {name: "rounds up for negative half 0", args: args{ticks: -5, tickSpacing: 10}, want: 0}, + {name: "rounds up for negative half 1", args: args{ticks: -6, tickSpacing: 10}, want: -10}, + {name: "cannot round past MinTick", args: args{ticks: utils.MinTick, tickSpacing: utils.MaxTick/2 + 100}, want: -(utils.MaxTick/2 + 100)}, + {name: "cannot round past MaxTick", args: args{ticks: utils.MaxTick, tickSpacing: utils.MaxTick/2 + 100}, want: utils.MaxTick/2 + 100}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := NearestUsableTick(tt.args.ticks, tt.args.tickSpacing); got != tt.want { + t.Errorf("NearestUsableTick() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/entities/pool.go b/entities/pool.go new file mode 100644 index 0000000..b1dbd66 --- /dev/null +++ b/entities/pool.go @@ -0,0 +1,341 @@ +package entities + +import ( + "errors" + "math/big" + + "github.com/daoleno/uniswap-sdk-core/entities" + "github.com/ethereum/go-ethereum/common" + + "github.com/KyberNetwork/pancake-v3-sdk/constants" + "github.com/KyberNetwork/pancake-v3-sdk/utils" +) + +var ( + ErrFeeTooHigh = errors.New("Fee too high") + ErrInvalidSqrtRatioX96 = errors.New("Invalid sqrtRatioX96") + ErrTokenNotInvolved = errors.New("Token not involved in pool") + ErrSqrtPriceLimitX96TooLow = errors.New("SqrtPriceLimitX96 too low") + ErrSqrtPriceLimitX96TooHigh = errors.New("SqrtPriceLimitX96 too high") +) + +type StepComputations struct { + sqrtPriceStartX96 *big.Int + tickNext int + initialized bool + sqrtPriceNextX96 *big.Int + amountIn *big.Int + amountOut *big.Int + feeAmount *big.Int +} + +// Represents a V3 pool +type Pool struct { + Token0 *entities.Token + Token1 *entities.Token + Fee constants.FeeAmount + SqrtRatioX96 *big.Int + Liquidity *big.Int + TickCurrent int + TickDataProvider TickDataProvider + + token0Price *entities.Price + token1Price *entities.Price +} + +func GetAddress(tokenA, tokenB *entities.Token, fee constants.FeeAmount, initCodeHashManualOverride string) (common.Address, error) { + return utils.ComputePoolAddress(constants.FactoryAddress, tokenA, tokenB, fee, initCodeHashManualOverride) +} + +/** + * Construct a pool + * @param tokenA One of the tokens in the pool + * @param tokenB The other token in the pool + * @param fee The fee in hundredths of a bips of the input amount of every swap that is collected by the pool + * @param sqrtRatioX96 The sqrt of the current ratio of amounts of token1 to token0 + * @param liquidity The current value of in range liquidity + * @param tickCurrent The current tick of the pool + * @param ticks The current state of the pool ticks or a data provider that can return tick data + */ +func NewPool(tokenA, tokenB *entities.Token, fee constants.FeeAmount, sqrtRatioX96 *big.Int, liquidity *big.Int, tickCurrent int, ticks TickDataProvider) (*Pool, error) { + if fee >= constants.FeeMax { + return nil, ErrFeeTooHigh + } + + tickCurrentSqrtRatioX96, err := utils.GetSqrtRatioAtTick(tickCurrent) + if err != nil { + return nil, err + } + nextTickSqrtRatioX96, err := utils.GetSqrtRatioAtTick(tickCurrent + 1) + if err != nil { + return nil, err + } + + if sqrtRatioX96.Cmp(tickCurrentSqrtRatioX96) < 0 || sqrtRatioX96.Cmp(nextTickSqrtRatioX96) > 0 { + return nil, ErrInvalidSqrtRatioX96 + } + token0 := tokenA + token1 := tokenB + isSorted, err := tokenA.SortsBefore(tokenB) + if err != nil { + return nil, err + } + if !isSorted { + token0 = tokenB + token1 = tokenA + } + + return &Pool{ + Token0: token0, + Token1: token1, + Fee: fee, + SqrtRatioX96: sqrtRatioX96, + Liquidity: liquidity, + TickCurrent: tickCurrent, + TickDataProvider: ticks, // TODO: new tick data provider + }, nil +} + +/** + * Returns true if the token is either token0 or token1 + * @param token The token to check + * @returns True if token is either token0 or token + */ +func (p *Pool) InvolvesToken(token *entities.Token) bool { + return p.Token0.Equal(token) || p.Token1.Equal(token) +} + +// Token0Price returns the current mid price of the pool in terms of token0, i.e. the ratio of token1 over token0 +func (p *Pool) Token0Price() *entities.Price { + if p.token0Price != nil { + return p.token0Price + } + p.token0Price = entities.NewPrice(p.Token0, p.Token1, constants.Q192, new(big.Int).Mul(p.SqrtRatioX96, p.SqrtRatioX96)) + return p.token0Price +} + +// Token1Price returns the current mid price of the pool in terms of token1, i.e. the ratio of token0 over token1 +func (p *Pool) Token1Price() *entities.Price { + if p.token1Price != nil { + return p.token1Price + } + p.token1Price = entities.NewPrice(p.Token1, p.Token0, new(big.Int).Mul(p.SqrtRatioX96, p.SqrtRatioX96), constants.Q192) + return p.token1Price +} + +/** + * Return the price of the given token in terms of the other token in the pool. + * @param token The token to return price of + * @returns The price of the given token, in terms of the other. + */ +func (p *Pool) PriceOf(token *entities.Token) (*entities.Price, error) { + if !p.InvolvesToken(token) { + return nil, ErrTokenNotInvolved + } + if p.Token0.Equal(token) { + return p.Token0Price(), nil + } + return p.Token1Price(), nil +} + +// ChainId returns the chain ID of the tokens in the pool. +func (p *Pool) ChainID() uint { + return p.Token0.ChainId() +} + +/** + * Given an input amount of a token, return the computed output amount, and a pool with state updated after the trade + * @param inputAmount The input amount for which to quote the output amount + * @param sqrtPriceLimitX96 The Q64.96 sqrt price limit + * @returns The output amount and the pool with updated state + */ +func (p *Pool) GetOutputAmount(inputAmount *entities.CurrencyAmount, sqrtPriceLimitX96 *big.Int) (*entities.CurrencyAmount, *Pool, error) { + if !(inputAmount.Currency.IsToken() && p.InvolvesToken(inputAmount.Currency.Wrapped())) { + return nil, nil, ErrTokenNotInvolved + } + zeroForOne := inputAmount.Currency.Equal(p.Token0) + outputAmount, sqrtRatioX96, liquidity, tickCurrent, err := p.swap(zeroForOne, inputAmount.Quotient(), sqrtPriceLimitX96) + if err != nil { + return nil, nil, err + } + var outputToken *entities.Token + if zeroForOne { + outputToken = p.Token1 + } else { + outputToken = p.Token0 + } + pool, err := NewPool(p.Token0, p.Token1, p.Fee, sqrtRatioX96, liquidity, tickCurrent, p.TickDataProvider) + if err != nil { + return nil, nil, err + } + return entities.FromRawAmount(outputToken, new(big.Int).Mul(outputAmount, constants.NegativeOne)), pool, nil +} + +/** + * Given a desired output amount of a token, return the computed input amount and a pool with state updated after the trade + * @param outputAmount the output amount for which to quote the input amount + * @param sqrtPriceLimitX96 The Q64.96 sqrt price limit. If zero for one, the price cannot be less than this value after the swap. If one for zero, the price cannot be greater than this value after the swap + * @returns The input amount and the pool with updated state + */ +func (p *Pool) GetInputAmount(outputAmount *entities.CurrencyAmount, sqrtPriceLimitX96 *big.Int) (*entities.CurrencyAmount, *Pool, error) { + if !(outputAmount.Currency.IsToken() && p.InvolvesToken(outputAmount.Currency.Wrapped())) { + return nil, nil, ErrTokenNotInvolved + } + zeroForOne := outputAmount.Currency.Equal(p.Token1) + inputAmount, sqrtRatioX96, liquidity, tickCurrent, err := p.swap(zeroForOne, new(big.Int).Mul(outputAmount.Quotient(), constants.NegativeOne), sqrtPriceLimitX96) + if err != nil { + return nil, nil, err + } + var inputToken *entities.Token + if zeroForOne { + inputToken = p.Token0 + } else { + inputToken = p.Token1 + } + pool, err := NewPool(p.Token0, p.Token1, p.Fee, sqrtRatioX96, liquidity, tickCurrent, p.TickDataProvider) + if err != nil { + return nil, nil, err + } + return entities.FromRawAmount(inputToken, inputAmount), pool, nil +} + +/** + * Executes a swap + * @param zeroForOne Whether the amount in is token0 or token1 + * @param amountSpecified The amount of the swap, which implicitly configures the swap as exact input (positive), or exact output (negative) + * @param sqrtPriceLimitX96 The Q64.96 sqrt price limit. If zero for one, the price cannot be less than this value after the swap. If one for zero, the price cannot be greater than this value after the swap + * @returns amountCalculated + * @returns sqrtRatioX96 + * @returns liquidity + * @returns tickCurrent + */ +func (p *Pool) swap(zeroForOne bool, amountSpecified, sqrtPriceLimitX96 *big.Int) (amountCalCulated *big.Int, sqrtRatioX96 *big.Int, liquidity *big.Int, tickCurrent int, err error) { + if sqrtPriceLimitX96 == nil { + if zeroForOne { + sqrtPriceLimitX96 = new(big.Int).Add(utils.MinSqrtRatio, constants.One) + } else { + sqrtPriceLimitX96 = new(big.Int).Sub(utils.MaxSqrtRatio, constants.One) + } + } + + if zeroForOne { + if sqrtPriceLimitX96.Cmp(utils.MinSqrtRatio) < 0 { + return nil, nil, nil, 0, ErrSqrtPriceLimitX96TooLow + } + if sqrtPriceLimitX96.Cmp(p.SqrtRatioX96) >= 0 { + return nil, nil, nil, 0, ErrSqrtPriceLimitX96TooHigh + } + } else { + if sqrtPriceLimitX96.Cmp(utils.MaxSqrtRatio) > 0 { + return nil, nil, nil, 0, ErrSqrtPriceLimitX96TooHigh + } + if sqrtPriceLimitX96.Cmp(p.SqrtRatioX96) <= 0 { + return nil, nil, nil, 0, ErrSqrtPriceLimitX96TooLow + } + } + + exactInput := amountSpecified.Cmp(constants.Zero) >= 0 + + // keep track of swap state + + state := struct { + amountSpecifiedRemaining *big.Int + amountCalculated *big.Int + sqrtPriceX96 *big.Int + tick int + liquidity *big.Int + }{ + amountSpecifiedRemaining: amountSpecified, + amountCalculated: constants.Zero, + sqrtPriceX96: p.SqrtRatioX96, + tick: p.TickCurrent, + liquidity: p.Liquidity, + } + + // start swap while loop + for state.amountSpecifiedRemaining.Cmp(constants.Zero) != 0 && state.sqrtPriceX96.Cmp(sqrtPriceLimitX96) != 0 { + var step StepComputations + step.sqrtPriceStartX96 = state.sqrtPriceX96 + + // because each iteration of the while loop rounds, we can't optimize this code (relative to the smart contract) + // by simply traversing to the next available tick, we instead need to exactly replicate + // tickBitmap.nextInitializedTickWithinOneWord + step.tickNext, step.initialized, err = p.TickDataProvider.NextInitializedTickIndex(state.tick, zeroForOne) + if err != nil { + return nil, nil, nil, 0, err + } + + if step.tickNext < utils.MinTick { + step.tickNext = utils.MinTick + } else if step.tickNext > utils.MaxTick { + step.tickNext = utils.MaxTick + } + + step.sqrtPriceNextX96, err = utils.GetSqrtRatioAtTick(step.tickNext) + if err != nil { + return nil, nil, nil, 0, err + } + var targetValue *big.Int + if zeroForOne { + if step.sqrtPriceNextX96.Cmp(sqrtPriceLimitX96) < 0 { + targetValue = sqrtPriceLimitX96 + } else { + targetValue = step.sqrtPriceNextX96 + } + } else { + if step.sqrtPriceNextX96.Cmp(sqrtPriceLimitX96) > 0 { + targetValue = sqrtPriceLimitX96 + } else { + targetValue = step.sqrtPriceNextX96 + } + } + + state.sqrtPriceX96, step.amountIn, step.amountOut, step.feeAmount, err = utils.ComputeSwapStep(state.sqrtPriceX96, targetValue, state.liquidity, state.amountSpecifiedRemaining, p.Fee) + if err != nil { + return nil, nil, nil, 0, err + } + + if exactInput { + state.amountSpecifiedRemaining = new(big.Int).Sub(state.amountSpecifiedRemaining, new(big.Int).Add(step.amountIn, step.feeAmount)) + state.amountCalculated = new(big.Int).Sub(state.amountCalculated, step.amountOut) + } else { + state.amountSpecifiedRemaining = new(big.Int).Add(state.amountSpecifiedRemaining, step.amountOut) + state.amountCalculated = new(big.Int).Add(state.amountCalculated, new(big.Int).Add(step.amountIn, step.feeAmount)) + } + + // TODO + if state.sqrtPriceX96.Cmp(step.sqrtPriceNextX96) == 0 { + // if the tick is initialized, run the tick transition + if step.initialized { + tick, err := p.TickDataProvider.GetTick(step.tickNext) + if err != nil { + return nil, nil, nil, 0, err + } + + liquidityNet := tick.LiquidityNet + // if we're moving leftward, we interpret liquidityNet as the opposite sign + // safe because liquidityNet cannot be type(int128).min + if zeroForOne { + liquidityNet = new(big.Int).Mul(liquidityNet, constants.NegativeOne) + } + state.liquidity = utils.AddDelta(state.liquidity, liquidityNet) + } + if zeroForOne { + state.tick = step.tickNext - 1 + } else { + state.tick = step.tickNext + } + } else if state.sqrtPriceX96.Cmp(step.sqrtPriceStartX96) != 0 { + // recompute unless we're on a lower tick boundary (i.e. already transitioned ticks), and haven't moved + state.tick, err = utils.GetTickAtSqrtRatio(state.sqrtPriceX96) + if err != nil { + return nil, nil, nil, 0, err + } + } + } + return state.amountCalculated, state.sqrtPriceX96, state.liquidity, state.tick, nil +} + +func (p *Pool) tickSpacing() int { + return constants.TickSpacings[p.Fee] +} diff --git a/entities/pool_test.go b/entities/pool_test.go new file mode 100644 index 0000000..6d7d355 --- /dev/null +++ b/entities/pool_test.go @@ -0,0 +1,182 @@ +package entities + +import ( + "math/big" + "testing" + + "github.com/KyberNetwork/pancake-v3-sdk/constants" + "github.com/KyberNetwork/pancake-v3-sdk/utils" + "github.com/daoleno/uniswap-sdk-core/entities" + "github.com/ethereum/go-ethereum/common" + "github.com/stretchr/testify/assert" +) + +var ( + USDC = entities.NewToken(1, common.HexToAddress("0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48"), 6, "USDC", "USD Coin") + DAI = entities.NewToken(1, common.HexToAddress("0x6B175474E89094C44Da98b954EedeAC495271d0F"), 18, "DAI", "Dai Stablecoin") + OneEther = big.NewInt(1e18) +) + +func TestNewPool(t *testing.T) { + _, err := NewPool(USDC, entities.WETH9[3], constants.FeeMedium, utils.EncodeSqrtRatioX96(constants.One, constants.One), big.NewInt(0), 0, nil) + assert.ErrorIs(t, err, entities.ErrDifferentChain, "cannot be used for tokens on different chains") + + _, err = NewPool(USDC, entities.WETH9[1], 1e6, utils.EncodeSqrtRatioX96(constants.One, constants.One), big.NewInt(0), 0, nil) + assert.ErrorIs(t, err, ErrFeeTooHigh, "fee cannot be more than 1e6'") + + _, err = NewPool(USDC, USDC, constants.FeeMedium, utils.EncodeSqrtRatioX96(constants.One, constants.One), big.NewInt(0), 0, nil) + assert.ErrorIs(t, err, entities.ErrSameAddress, "cannot be used for the same token") + + _, err = NewPool(USDC, entities.WETH9[1], constants.FeeMedium, utils.EncodeSqrtRatioX96(constants.One, constants.One), big.NewInt(0), 1, nil) + assert.ErrorIs(t, err, ErrInvalidSqrtRatioX96, "price must be within tick price bounds") + + _, err = NewPool(USDC, entities.WETH9[1], constants.FeeMedium, new(big.Int).Add(utils.EncodeSqrtRatioX96(constants.One, constants.One), big.NewInt(1)), big.NewInt(0), -1, nil) + assert.ErrorIs(t, err, ErrInvalidSqrtRatioX96, "price must be within tick price bounds") + + _, err = NewPool(USDC, entities.WETH9[1], constants.FeeMedium, utils.EncodeSqrtRatioX96(constants.One, constants.One), big.NewInt(0), 0, nil) + assert.NoError(t, err, "works with valid arguments for empty pool medium fee") + + _, err = NewPool(USDC, entities.WETH9[1], constants.FeeLow, utils.EncodeSqrtRatioX96(constants.One, constants.One), big.NewInt(0), 0, nil) + assert.NoError(t, err, "works with valid arguments for empty pool low fee") + + _, err = NewPool(USDC, entities.WETH9[1], constants.FeeHigh, utils.EncodeSqrtRatioX96(constants.One, constants.One), big.NewInt(0), 0, nil) + assert.NoError(t, err, "works with valid arguments for empty pool high fee") +} + +func TestGetAddress(t *testing.T) { + addr, _ := GetAddress(USDC, DAI, constants.FeeLow, "") + assert.Equal(t, addr, common.HexToAddress("0x6c6Bc977E13Df9b0de53b251522280BB72383700"), "matches an example") +} + +func TestToken0(t *testing.T) { + pool, _ := NewPool(USDC, DAI, constants.FeeLow, utils.EncodeSqrtRatioX96(constants.One, constants.One), big.NewInt(0), 0, nil) + assert.Equal(t, pool.Token0, DAI, "always is the token that sorts before") + + pool, _ = NewPool(DAI, USDC, constants.FeeLow, utils.EncodeSqrtRatioX96(constants.One, constants.One), big.NewInt(0), 0, nil) + assert.Equal(t, pool.Token0, DAI, "always is the token that sorts before") +} + +func TestToken1(t *testing.T) { + pool, _ := NewPool(USDC, DAI, constants.FeeLow, utils.EncodeSqrtRatioX96(constants.One, constants.One), big.NewInt(0), 0, nil) + assert.Equal(t, pool.Token1, USDC, "always is the token that sorts after") + + pool, _ = NewPool(DAI, USDC, constants.FeeLow, utils.EncodeSqrtRatioX96(constants.One, constants.One), big.NewInt(0), 0, nil) + assert.Equal(t, pool.Token1, USDC, "always is the token that sorts after") +} + +func TestToken0Price(t *testing.T) { + a1 := new(big.Int).Mul(big.NewInt(101), big.NewInt(1e6)) + a2 := new(big.Int).Mul(big.NewInt(100), big.NewInt(1e18)) + r, _ := utils.GetTickAtSqrtRatio(utils.EncodeSqrtRatioX96(a1, a2)) + pool0, _ := NewPool(USDC, DAI, constants.FeeLow, utils.EncodeSqrtRatioX96(a1, a2), big.NewInt(0), r, nil) + assert.Equal(t, pool0.Token0Price().ToSignificant(5), "1.01", "returns price of token0 in terms of token1") + + pool1, _ := NewPool(DAI, USDC, constants.FeeLow, utils.EncodeSqrtRatioX96(a1, a2), big.NewInt(0), r, nil) + assert.Equal(t, pool1.Token0Price().ToSignificant(5), "1.01", "returns price of token0 in terms of token1") +} + +func TestToken1Price(t *testing.T) { + a1 := new(big.Int).Mul(big.NewInt(101), big.NewInt(1e6)) + a2 := new(big.Int).Mul(big.NewInt(100), big.NewInt(1e18)) + r, _ := utils.GetTickAtSqrtRatio(utils.EncodeSqrtRatioX96(a1, a2)) + pool0, _ := NewPool(USDC, DAI, constants.FeeLow, utils.EncodeSqrtRatioX96(a1, a2), big.NewInt(0), r, nil) + assert.Equal(t, pool0.Token1Price().ToSignificant(5), "0.9901", "returns price of token1 in terms of token0") + + pool1, _ := NewPool(DAI, USDC, constants.FeeLow, utils.EncodeSqrtRatioX96(a1, a2), big.NewInt(0), r, nil) + assert.Equal(t, pool1.Token1Price().ToSignificant(5), "0.9901", "returns price of token1 in terms of token0") +} + +func TestPriceOf(t *testing.T) { + pool, _ := NewPool(USDC, DAI, constants.FeeLow, utils.EncodeSqrtRatioX96(constants.One, constants.One), big.NewInt(0), 0, nil) + price0, _ := pool.PriceOf(DAI) + assert.Equal(t, price0, pool.Token0Price(), "returns price of token in terms of other token") + price1, _ := pool.PriceOf(USDC) + assert.Equal(t, price1, pool.Token1Price(), "returns price of token in terms of other token") + + _, err := pool.PriceOf(entities.WETH9[1]) + assert.Error(t, err, "invalid token") +} + +func TestChainID(t *testing.T) { + pool0, _ := NewPool(USDC, DAI, constants.FeeLow, utils.EncodeSqrtRatioX96(constants.One, constants.One), big.NewInt(0), 0, nil) + assert.Equal(t, pool0.ChainID(), uint(1), "returns the token0 chainId") + + pool1, _ := NewPool(DAI, USDC, constants.FeeLow, utils.EncodeSqrtRatioX96(constants.One, constants.One), big.NewInt(0), 0, nil) + assert.Equal(t, pool1.ChainID(), uint(1), "returns the token0 chainId") +} + +func TestInvolvesToken(t *testing.T) { + pool, _ := NewPool(USDC, DAI, constants.FeeLow, utils.EncodeSqrtRatioX96(constants.One, constants.One), big.NewInt(0), 0, nil) + assert.True(t, pool.InvolvesToken(USDC), "involves USDC") + assert.True(t, pool.InvolvesToken(DAI), "involves DAI") + assert.False(t, pool.InvolvesToken(entities.WETH9[1]), "does not involve WETH9") +} + +func newTestPool() *Pool { + ticks := []Tick{ + { + Index: NearestUsableTick(utils.MinTick, constants.TickSpacings[constants.FeeLow]), + LiquidityNet: OneEther, + LiquidityGross: OneEther, + }, + { + Index: NearestUsableTick(utils.MaxTick, constants.TickSpacings[constants.FeeLow]), + LiquidityNet: new(big.Int).Mul(OneEther, constants.NegativeOne), + LiquidityGross: OneEther, + }, + } + + p, err := NewTickListDataProvider(ticks, constants.TickSpacings[constants.FeeLow]) + if err != nil { + panic(err) + } + + pool, err := NewPool(USDC, DAI, constants.FeeLow, utils.EncodeSqrtRatioX96(constants.One, constants.One), OneEther, 0, p) + if err != nil { + panic(err) + } + return pool +} +func TestGetOutputAmount(t *testing.T) { + pool := newTestPool() + + // USDC -> DAI + inputAmount := entities.FromRawAmount(USDC, big.NewInt(100)) + outputAmount, _, err := pool.GetOutputAmount(inputAmount, nil) + if err != nil { + t.Fatal(err) + } + assert.True(t, outputAmount.Currency.Equal(DAI)) + assert.Equal(t, outputAmount.Quotient(), big.NewInt(98)) + + // DAI -> USDC + inputAmount = entities.FromRawAmount(DAI, big.NewInt(100)) + outputAmount, _, err = pool.GetOutputAmount(inputAmount, nil) + if err != nil { + t.Fatal(err) + } + assert.True(t, outputAmount.Currency.Equal(USDC)) + assert.Equal(t, outputAmount.Quotient(), big.NewInt(98)) +} + +func TestGetInputAmount(t *testing.T) { + pool := newTestPool() + + // USDC -> DAI + outputAmount := entities.FromRawAmount(DAI, big.NewInt(98)) + inputAmount, _, err := pool.GetInputAmount(outputAmount, nil) + if err != nil { + t.Fatal(err) + } + assert.True(t, inputAmount.Currency.Equal(USDC)) + assert.Equal(t, inputAmount.Quotient(), big.NewInt(100)) + + // DAI -> USDC + outputAmount = entities.FromRawAmount(USDC, big.NewInt(98)) + inputAmount, _, err = pool.GetInputAmount(outputAmount, nil) + if err != nil { + t.Fatal(err) + } + assert.True(t, inputAmount.Currency.Equal(DAI)) + assert.Equal(t, inputAmount.Quotient(), big.NewInt(100)) +} diff --git a/entities/tickdataprovider.go b/entities/tickdataprovider.go new file mode 100644 index 0000000..6086e61 --- /dev/null +++ b/entities/tickdataprovider.go @@ -0,0 +1,29 @@ +package entities + +import "math/big" + +type Tick struct { + Index int + LiquidityGross *big.Int + LiquidityNet *big.Int +} + +// Provides information about ticks +type TickDataProvider interface { + /** + * Return information corresponding to a specific tick + * @param tick the tick to load + */ + GetTick(tick int) (Tick, error) + + /** + * Return the next tick that is initialized within a single word + * @param tick The current tick + * @param lte Whether the next tick should be lte the current tick + * @param tickSpacing The tick spacing of the pool + */ + NextInitializedTickWithinOneWord(tick int, lte bool, tickSpacing int) (int, bool, error) + + // NextInitializedTickIndex return the next tick that is initialized + NextInitializedTickIndex(tick int, lte bool) (int, bool, error) +} diff --git a/entities/ticklist.go b/entities/ticklist.go new file mode 100644 index 0000000..50cdc65 --- /dev/null +++ b/entities/ticklist.go @@ -0,0 +1,253 @@ +package entities + +import ( + "errors" + "math" + "math/big" + + "github.com/KyberNetwork/pancake-v3-sdk/constants" +) + +const ( + ZeroValueTickIndex = 0 + ZeroValueTickInitialized = false +) + +var ( + ErrZeroTickSpacing = errors.New("tick spacing must be greater than 0") + ErrInvalidTickSpacing = errors.New("invalid tick spacing") + ErrZeroNet = errors.New("tick net delta must be zero") + ErrSorted = errors.New("ticks must be sorted") + ErrEmptyTickList = errors.New("empty tick list") + ErrBelowSmallest = errors.New("below smallest") + ErrAtOrAboveLargest = errors.New("at or above largest") + ErrInvalidTickIndex = errors.New("invalid tick index") +) + +var ( + EmptyTick = Tick{} +) + +func ValidateList(ticks []Tick, tickSpacing int) error { + if tickSpacing <= 0 { + return ErrZeroTickSpacing + } + + // ensure ticks are spaced appropriately + for _, t := range ticks { + if t.Index%tickSpacing != 0 { + return ErrInvalidTickSpacing + } + } + + // ensure tick liquidity deltas sum to 0 + sum := big.NewInt(0) + for _, tick := range ticks { + sum.Add(sum, tick.LiquidityNet) + } + if sum.Cmp(big.NewInt(0)) != 0 { + return ErrZeroNet + } + + if !isTicksSorted(ticks) { + return ErrSorted + } + + return nil +} + +func IsBelowSmallest(ticks []Tick, tick int) (bool, error) { + if len(ticks) == 0 { + return true, ErrEmptyTickList + } + + return tick < ticks[0].Index, nil +} + +func IsAtOrAboveLargest(ticks []Tick, tick int) (bool, error) { + if len(ticks) == 0 { + return true, ErrEmptyTickList + } + + return tick >= ticks[len(ticks)-1].Index, nil +} + +func GetTick(ticks []Tick, index int) (Tick, error) { + tickIndex, err := binarySearch(ticks, index) + if err != nil { + return EmptyTick, err + } + + if tickIndex < 0 { + return EmptyTick, ErrInvalidTickIndex + } + + tick := ticks[tickIndex] + + return tick, nil +} + +func NextInitializedTick(ticks []Tick, tick int, lte bool) (Tick, error) { + if lte { + isBelowSmallest, err := IsBelowSmallest(ticks, tick) + if err != nil { + return EmptyTick, err + } + + if isBelowSmallest { + return EmptyTick, ErrBelowSmallest + } + + isAtOrAboveLargest, err := IsAtOrAboveLargest(ticks, tick) + if err != nil { + return EmptyTick, err + } + + if isAtOrAboveLargest { + return ticks[len(ticks)-1], nil + } + + index, err := binarySearch(ticks, tick) + if err != nil { + return EmptyTick, err + } + + return ticks[index], nil + } else { + isAtOrAboveLargest, err := IsAtOrAboveLargest(ticks, tick) + if err != nil { + return EmptyTick, err + } + + if isAtOrAboveLargest { + return EmptyTick, ErrAtOrAboveLargest + } + + isBelowSmallest, err := IsBelowSmallest(ticks, tick) + + if err != nil { + return EmptyTick, err + } + + if isBelowSmallest { + return ticks[0], nil + } + + index, err := binarySearch(ticks, tick) + if err != nil { + return EmptyTick, err + } + + return ticks[index+1], nil + } +} + +func NextInitializedTickWithinOneWord(ticks []Tick, tick int, lte bool, tickSpacing int) (int, bool, error) { + compressed := math.Floor(float64(tick) / float64(tickSpacing)) // matches rounding in the code + + if lte { + wordPos := int(compressed) >> 8 + minimum := (wordPos << 8) * tickSpacing + isBelowSmallest, err := IsBelowSmallest(ticks, tick) + if err != nil { + return ZeroValueTickIndex, ZeroValueTickInitialized, err + } + + if isBelowSmallest { + return minimum, ZeroValueTickInitialized, ErrBelowSmallest + } + + nextInitializedTick, err := NextInitializedTick(ticks, tick, lte) + if err != nil { + return ZeroValueTickIndex, ZeroValueTickInitialized, err + } + + index := nextInitializedTick.Index + nextInitializedTickIndex := math.Max(float64(minimum), float64(index)) + return int(nextInitializedTickIndex), int(nextInitializedTickIndex) == index, nil + } else { + wordPos := int(compressed+1) >> 8 + maximum := ((wordPos+1)<<8)*tickSpacing - 1 + isAtOrAboveLargest, err := IsAtOrAboveLargest(ticks, tick) + if err != nil { + return ZeroValueTickIndex, ZeroValueTickInitialized, err + } + + if isAtOrAboveLargest { + return maximum, ZeroValueTickInitialized, ErrAtOrAboveLargest + } + + nextInitializedTick, err := NextInitializedTick(ticks, tick, lte) + if err != nil { + return ZeroValueTickIndex, ZeroValueTickInitialized, err + } + + index := nextInitializedTick.Index + nextInitializedTickIndex := math.Min(float64(maximum), float64(index)) + return int(nextInitializedTickIndex), int(nextInitializedTickIndex) == index, nil + } +} + +func NextInitializedTickIndex(ticks []Tick, tick int, lte bool) (int, bool, error) { + nextInitializedTick, err := NextInitializedTick(ticks, tick, lte) + if err != nil { + return ZeroValueTickIndex, ZeroValueTickInitialized, err + } + + var isInitialized bool + if nextInitializedTick.LiquidityGross.Cmp(constants.Zero) != 0 { + isInitialized = true + } + + return nextInitializedTick.Index, isInitialized, nil +} + +// utils + +func isTicksSorted(ticks []Tick) bool { + for i := 0; i < len(ticks)-1; i++ { + if ticks[i].Index > ticks[i+1].Index { + return false + } + } + return true +} + +/** + * Finds the largest tick in the list of ticks that is less than or equal to tick + * @param ticks list of ticks + * @param tick tick to find the largest tick that is less than or equal to tick + * @private + */ +func binarySearch(ticks []Tick, tick int) (int, error) { + isBelowSmallest, err := IsBelowSmallest(ticks, tick) + if err != nil { + return ZeroValueTickIndex, err + } + + if isBelowSmallest { + return ZeroValueTickIndex, ErrBelowSmallest + } + + // binary search + start := 0 + end := len(ticks) - 1 + for start <= end { + mid := (start + end) / 2 + if ticks[mid].Index == tick { + return mid, nil + } else if ticks[mid].Index < tick { + start = mid + 1 + } else { + end = mid - 1 + } + } + + // if we get here, we didn't find a tick that is less than or equal to tick + // so we return the index of the tick that is closest to tick + if ticks[start].Index < tick { + return start, nil + } else { + return start - 1, nil + } +} diff --git a/entities/ticklist_test.go b/entities/ticklist_test.go new file mode 100644 index 0000000..403781a --- /dev/null +++ b/entities/ticklist_test.go @@ -0,0 +1,144 @@ +package entities + +import ( + "math/big" + "testing" + + "github.com/KyberNetwork/pancake-v3-sdk/utils" + "github.com/stretchr/testify/assert" +) + +var ( + lowTick = Tick{ + Index: utils.MinTick + 1, + LiquidityNet: big.NewInt(10), + LiquidityGross: big.NewInt(10), + } + midTick = Tick{ + Index: 0, + LiquidityNet: big.NewInt(-5), + LiquidityGross: big.NewInt(5), + } + highTick = Tick{ + Index: utils.MaxTick - 1, + LiquidityNet: big.NewInt(-5), + LiquidityGross: big.NewInt(5), + } +) + +func TestValidateList(t *testing.T) { + assert.ErrorIs(t, ValidateList([]Tick{lowTick}, 1), ErrZeroNet, "panics for incomplete lists") + assert.ErrorIs(t, ValidateList([]Tick{highTick, lowTick, midTick}, 1), ErrSorted, "panics for unsorted lists") + assert.ErrorIs(t, ValidateList([]Tick{highTick, midTick, lowTick}, 1337), ErrInvalidTickSpacing, "errors if ticks are not on multiples of tick spacing") +} + +func TestIsBelowSmallest(t *testing.T) { + result := []Tick{lowTick, midTick, highTick} + isBelowSmallest1, _ := IsBelowSmallest(result, utils.MinTick) + assert.True(t, isBelowSmallest1) + + isBelowSmallest2, _ := IsBelowSmallest(result, utils.MinTick+1) + assert.False(t, isBelowSmallest2) +} + +func TestIsAtOrAboveSmallest(t *testing.T) { + result := []Tick{lowTick, midTick, highTick} + + isAtOrAboveLargest1, _ := IsAtOrAboveLargest(result, utils.MaxTick-2) + assert.False(t, isAtOrAboveLargest1) + + isAtOrAboveLargest2, _ := IsAtOrAboveLargest(result, utils.MaxTick-1) + assert.True(t, isAtOrAboveLargest2) +} + +func TestNextInitializedTick(t *testing.T) { + ticks := []Tick{lowTick, midTick, highTick} + + type args struct { + ticks []Tick + tick int + lte bool + } + tests := []struct { + name string + args args + want Tick + }{ + {name: "low - lte = true 0", args: args{ticks: ticks, tick: utils.MinTick + 1, lte: true}, want: lowTick}, + {name: "low - lte = true 1", args: args{ticks: ticks, tick: utils.MinTick + 2, lte: true}, want: lowTick}, + {name: "low - lte = false 0", args: args{ticks: ticks, tick: utils.MinTick, lte: false}, want: lowTick}, + {name: "low - lte = false 1", args: args{ticks: ticks, tick: utils.MinTick + 1, lte: false}, want: midTick}, + {name: "mid - lte = true 0", args: args{ticks: ticks, tick: 0, lte: true}, want: midTick}, + {name: "mid - lte = true 1", args: args{ticks: ticks, tick: 1, lte: true}, want: midTick}, + {name: "mid - lte = false 0", args: args{ticks: ticks, tick: -1, lte: false}, want: midTick}, + {name: "mid - lte = false 1", args: args{ticks: ticks, tick: 0 + 1, lte: false}, want: highTick}, + {name: "high - lte = true 0", args: args{ticks: ticks, tick: utils.MaxTick - 1, lte: true}, want: highTick}, + {name: "high - lte = true 1", args: args{ticks: ticks, tick: utils.MaxTick, lte: true}, want: highTick}, + {name: "high - lte = false 0", args: args{ticks: ticks, tick: utils.MaxTick - 2, lte: false}, want: highTick}, + {name: "high - lte = false 1", args: args{ticks: ticks, tick: utils.MaxTick - 3, lte: false}, want: highTick}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + nextInitializedTick, _ := NextInitializedTick(tt.args.ticks, tt.args.tick, tt.args.lte) + assert.Equal(t, tt.want, nextInitializedTick) + }) + } + + nextInitializedTick1, err1 := NextInitializedTick(ticks, utils.MinTick, true) + assert.Zero(t, nextInitializedTick1, "below smallest") + assert.ErrorIs(t, err1, ErrBelowSmallest) + + nextInitializedTick2, err2 := NextInitializedTick(ticks, utils.MaxTick-1, false) + assert.Zero(t, nextInitializedTick2, "at or above largest") + assert.ErrorIs(t, err2, ErrAtOrAboveLargest) +} + +func TestNextInitializedTickWithinOneWord(t *testing.T) { + ticks := []Tick{lowTick, midTick, highTick} + + // words around 0, lte = true + type args struct { + ticks []Tick + tick int + lte bool + tickSpacing int + } + tests := []struct { + name string + args args + want0 int + want1 bool + }{ + // words around 0, lte = true + {name: "lte = true 0", args: args{ticks: ticks, tick: -257, lte: true, tickSpacing: 1}, want0: -512, want1: false}, + {name: "lte = true 1", args: args{ticks: ticks, tick: -256, lte: true, tickSpacing: 1}, want0: -256, want1: false}, + {name: "lte = true 2", args: args{ticks: ticks, tick: -1, lte: true, tickSpacing: 1}, want0: -256, want1: false}, + {name: "lte = true 3", args: args{ticks: ticks, tick: 0, lte: true, tickSpacing: 1}, want0: 0, want1: true}, + {name: "lte = true 4", args: args{ticks: ticks, tick: 1, lte: true, tickSpacing: 1}, want0: 0, want1: true}, + {name: "lte = true 5", args: args{ticks: ticks, tick: 255, lte: true, tickSpacing: 1}, want0: 0, want1: true}, + {name: "lte = true 6", args: args{ticks: ticks, tick: 256, lte: true, tickSpacing: 1}, want0: 256, want1: false}, + {name: "lte = true 7", args: args{ticks: ticks, tick: 257, lte: true, tickSpacing: 1}, want0: 256, want1: false}, + + // words around 0, lte = false + {name: "lte = false 0", args: args{ticks: ticks, tick: -258, lte: false, tickSpacing: 1}, want0: -257, want1: false}, + {name: "lte = false 1", args: args{ticks: ticks, tick: -257, lte: false, tickSpacing: 1}, want0: -1, want1: false}, + {name: "lte = false 2", args: args{ticks: ticks, tick: -256, lte: false, tickSpacing: 1}, want0: -1, want1: false}, + {name: "lte = false 3", args: args{ticks: ticks, tick: -2, lte: false, tickSpacing: 1}, want0: -1, want1: false}, + {name: "lte = false 4", args: args{ticks: ticks, tick: -1, lte: false, tickSpacing: 1}, want0: 0, want1: true}, + {name: "lte = false 5", args: args{ticks: ticks, tick: 0, lte: false, tickSpacing: 1}, want0: 255, want1: false}, + {name: "lte = false 6", args: args{ticks: ticks, tick: 1, lte: false, tickSpacing: 1}, want0: 255, want1: false}, + {name: "lte = false 7", args: args{ticks: ticks, tick: 254, lte: false, tickSpacing: 1}, want0: 255, want1: false}, + {name: "lte = false 8", args: args{ticks: ticks, tick: 255, lte: false, tickSpacing: 1}, want0: 511, want1: false}, + {name: "lte = false 9", args: args{ticks: ticks, tick: 256, lte: false, tickSpacing: 1}, want0: 511, want1: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got0, got1, _ := NextInitializedTickWithinOneWord(tt.args.ticks, tt.args.tick, tt.args.lte, tt.args.tickSpacing) + assert.Equal(t, tt.want0, got0) + assert.Equal(t, tt.want1, got1) + }) + } + +} diff --git a/entities/ticklistdataprovider.go b/entities/ticklistdataprovider.go new file mode 100644 index 0000000..e6e6429 --- /dev/null +++ b/entities/ticklistdataprovider.go @@ -0,0 +1,25 @@ +package entities + +// A data provider for ticks that is backed by an in-memory array of ticks. +type TickListDataProvider struct { + ticks []Tick +} + +func NewTickListDataProvider(ticks []Tick, tickSpacing int) (*TickListDataProvider, error) { + if err := ValidateList(ticks, tickSpacing); err != nil { + return nil, err + } + return &TickListDataProvider{ticks: ticks}, nil +} + +func (p *TickListDataProvider) GetTick(tick int) (Tick, error) { + return GetTick(p.ticks, tick) +} + +func (p *TickListDataProvider) NextInitializedTickWithinOneWord(tick int, lte bool, tickSpacing int) (int, bool, error) { + return NextInitializedTickWithinOneWord(p.ticks, tick, lte, tickSpacing) +} + +func (p *TickListDataProvider) NextInitializedTickIndex(tick int, lte bool) (int, bool, error) { + return NextInitializedTickIndex(p.ticks, tick, lte) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..3e5487f --- /dev/null +++ b/go.mod @@ -0,0 +1,21 @@ +module github.com/KyberNetwork/pancake-v3-sdk + +go 1.20 + +require ( + github.com/daoleno/uniswap-sdk-core v0.1.7 + github.com/ethereum/go-ethereum v1.10.22 + github.com/shopspring/decimal v1.3.1 + github.com/stretchr/testify v1.8.0 +) + +require ( + github.com/btcsuite/btcd/btcec/v2 v2.2.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1 // indirect + github.com/kr/pretty v0.3.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d // indirect + golang.org/x/sys v0.0.0-20220702020025-31831981b65f // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..e49d42d --- /dev/null +++ b/go.sum @@ -0,0 +1,40 @@ +github.com/btcsuite/btcd/btcec/v2 v2.2.0 h1:fzn1qaOt32TuLjFlkzYSsBC35Q3KUjT1SwPxiMSCF5k= +github.com/btcsuite/btcd/btcec/v2 v2.2.0/go.mod h1:U7MHm051Al6XmscBQ0BoNydpOTsFAn707034b5nY8zU= +github.com/btcsuite/btcd/chaincfg/chainhash v1.0.1 h1:q0rUy8C/TYNBQS1+CGKw68tLOFYSNEs0TFnxxnS9+4U= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/daoleno/uniswap-sdk-core v0.1.7 h1:PdZypLSzM5Mu2rFBjXK9XrHDppSt62GkxXjWLpuMAN4= +github.com/daoleno/uniswap-sdk-core v0.1.7/go.mod h1:DPzL8zNicstPzvX74ZeeHsiIUquZRpwviceDHQ8+UQ4= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/decred/dcrd/crypto/blake256 v1.0.0 h1:/8DMNYp9SGi5f0w7uCm6d6M4OU2rGFK09Y2A4Xv7EE0= +github.com/decred/dcrd/crypto/blake256 v1.0.0/go.mod h1:sQl2p6Y26YV+ZOcSTP6thNdn47hh8kt6rqSlvmrXFAc= +github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1 h1:YLtO71vCjJRCBcrPMtQ9nqBsqpA1m5sE92cU+pd5Mcc= +github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1/go.mod h1:hyedUtir6IdtD/7lIxGeCxkaw7y45JueMRL4DIyJDKs= +github.com/ethereum/go-ethereum v1.10.22 h1:HbEgsDo1YTGIf4KB/NNpn+XH+PiNJXUZ9ksRxiqWyMc= +github.com/ethereum/go-ethereum v1.10.22/go.mod h1:EYFyF19u3ezGLD4RqOkLq+ZCXzYbLoNDdZlMt7kyKFg= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8= +github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d h1:sK3txAijHtOK88l68nt020reeT1ZdKLIYetKl95FzVY= +golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/sys v0.0.0-20220702020025-31831981b65f h1:xdsejrW/0Wf2diT5CPp3XmKUNbr7Xvw8kYilQ+6qjRY= +golang.org/x/sys v0.0.0-20220702020025-31831981b65f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/utils/calldata.go b/utils/calldata.go new file mode 100644 index 0000000..e73037d --- /dev/null +++ b/utils/calldata.go @@ -0,0 +1,27 @@ +package utils + +import ( + "math/big" +) + +type MethodParameters struct { + Calldata []byte // The hex encoded calldata to perform the given operation + Value *big.Int // The amount of ether (wei) to send in hex +} + +/** + * Converts a big int to a hex string + * @param bigintIsh + * @returns The hex encoded calldata + */ +func ToHex(i *big.Int) string { + if i == nil { + return "0x00" + } + + hex := i.Text(16) + if len(hex)%2 != 0 { + hex = "0" + hex + } + return "0x" + hex +} diff --git a/utils/compute_pool_address.go b/utils/compute_pool_address.go new file mode 100644 index 0000000..7e01f3a --- /dev/null +++ b/utils/compute_pool_address.go @@ -0,0 +1,63 @@ +package utils + +import ( + "math/big" + + "github.com/daoleno/uniswap-sdk-core/entities" + "github.com/ethereum/go-ethereum/accounts/abi" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" + + "github.com/KyberNetwork/pancake-v3-sdk/constants" +) + +/** + * Computes a pool address + * @param factoryAddress The Pancake V3 factory address + * @param tokenA The first token of the pair, irrespective of sort order + * @param tokenB The second token of the pair, irrespective of sort order + * @param fee The fee tier of the pool + * @returns The pool address + */ +func ComputePoolAddress(factoryAddress common.Address, tokenA *entities.Token, tokenB *entities.Token, fee constants.FeeAmount, initCodeHashManualOverride string) (common.Address, error) { + isSorted, err := tokenA.SortsBefore(tokenB) + if err != nil { + return common.Address{}, err + } + var ( + token0 *entities.Token + token1 *entities.Token + ) + if isSorted { + token0 = tokenA + token1 = tokenB + } else { + token0 = tokenB + token1 = tokenA + } + return getCreate2Address(factoryAddress, token0.Address, token1.Address, fee, initCodeHashManualOverride), nil +} + +func getCreate2Address(factoyAddress, addressA, addressB common.Address, fee constants.FeeAmount, initCodeHashManualOverride string) common.Address { + var salt [32]byte + copy(salt[:], crypto.Keccak256(abiEncode(addressA, addressB, fee))) + + if initCodeHashManualOverride != "" { + crypto.CreateAddress2(factoyAddress, salt, common.FromHex(initCodeHashManualOverride)) + } + return crypto.CreateAddress2(factoyAddress, salt, common.FromHex(constants.PoolInitCodeHash)) +} + +func abiEncode(addressA, addressB common.Address, fee constants.FeeAmount) []byte { + addressTy, _ := abi.NewType("address", "address", nil) + uint256Ty, _ := abi.NewType("uint256", "uint256", nil) + + arguments := abi.Arguments{{Type: addressTy}, {Type: addressTy}, {Type: uint256Ty}} + + bytes, _ := arguments.Pack( + addressA, + addressB, + big.NewInt(int64(fee)), + ) + return bytes +} diff --git a/utils/compute_pool_address_test.go b/utils/compute_pool_address_test.go new file mode 100644 index 0000000..4b04685 --- /dev/null +++ b/utils/compute_pool_address_test.go @@ -0,0 +1,33 @@ +package utils + +import ( + "testing" + + "github.com/KyberNetwork/pancake-v3-sdk/constants" + "github.com/daoleno/uniswap-sdk-core/entities" + "github.com/ethereum/go-ethereum/common" + "github.com/stretchr/testify/assert" +) + +func TestComputePoolAddress(t *testing.T) { + factoryAddress := common.HexToAddress("0x1111111111111111111111111111111111111111") + tokenA := entities.NewToken(1, common.HexToAddress("0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48"), 18, "USDC", "USD Coin") + tokenB := entities.NewToken(1, common.HexToAddress("0x6B175474E89094C44Da98b954EedeAC495271d0F"), 18, "DAI", "Dai Stablecoin") + result, err := ComputePoolAddress(factoryAddress, tokenA, tokenB, constants.FeeLow, "") + if err != nil { + panic(err) + } + assert.Equal(t, result, common.HexToAddress("0x90B1b09A9715CaDbFD9331b3A7652B24BfBEfD32")) + + USDC := entities.NewToken(1, common.HexToAddress("0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48"), 18, "USDC", "USD Coin") + DAI := entities.NewToken(1, common.HexToAddress("0x6B175474E89094C44Da98b954EedeAC495271d0F"), 18, "DAI", "Dai Stablecoin") + resultA, err := ComputePoolAddress(factoryAddress, USDC, DAI, constants.FeeLow, "") + if err != nil { + panic(err) + } + resultB, err := ComputePoolAddress(factoryAddress, DAI, USDC, constants.FeeLow, "") + if err != nil { + panic(err) + } + assert.Equal(t, resultA, resultB, "should correctly compute the pool address") +} diff --git a/utils/encode.go b/utils/encode.go new file mode 100644 index 0000000..2646e7e --- /dev/null +++ b/utils/encode.go @@ -0,0 +1,16 @@ +package utils + +import "math/big" + +/** + * Returns the sqrt ratio as a Q64.96 corresponding to a given ratio of amount1 and amount0 + * @param amount1 The numerator amount i.e., the amount of token1 + * @param amount0 The denominator amount i.e., the amount of token0 + * @returns The sqrt ratio + */ +func EncodeSqrtRatioX96(amount1 *big.Int, amount0 *big.Int) *big.Int { + numerator := new(big.Int).Lsh(amount1, 192) + denominator := amount0 + ratioX192 := new(big.Int).Div(numerator, denominator) + return new(big.Int).Sqrt(ratioX192) +} diff --git a/utils/encode_test.go b/utils/encode_test.go new file mode 100644 index 0000000..43120a4 --- /dev/null +++ b/utils/encode_test.go @@ -0,0 +1,25 @@ +package utils + +import ( + "math/big" + "testing" + + "github.com/KyberNetwork/pancake-v3-sdk/constants" + "github.com/stretchr/testify/assert" +) + +func TestEncodeSqrtRatioX96(t *testing.T) { + assert.Equal(t, EncodeSqrtRatioX96(big.NewInt(1), big.NewInt(1)), constants.Q96, "1/1") + + r0, _ := new(big.Int).SetString("792281625142643375935439503360", 10) + assert.Equal(t, EncodeSqrtRatioX96(big.NewInt(100), big.NewInt(1)), r0, 10, "100/1") + + r1, _ := new(big.Int).SetString("7922816251426433759354395033", 10) + assert.Equal(t, EncodeSqrtRatioX96(big.NewInt(1), big.NewInt(100)), r1, 10, "1/100") + + r2, _ := new(big.Int).SetString("45742400955009932534161870629", 10) + assert.Equal(t, EncodeSqrtRatioX96(big.NewInt(111), big.NewInt(333)), r2, 10, "111/333") + + r3, _ := new(big.Int).SetString("137227202865029797602485611888", 10) + assert.Equal(t, EncodeSqrtRatioX96(big.NewInt(333), big.NewInt(111)), r3, 10, "333/111") +} diff --git a/utils/full_math.go b/utils/full_math.go new file mode 100644 index 0000000..9d4c0e5 --- /dev/null +++ b/utils/full_math.go @@ -0,0 +1,16 @@ +package utils + +import ( + "math/big" + + "github.com/KyberNetwork/pancake-v3-sdk/constants" +) + +func MulDivRoundingUp(a, b, denominator *big.Int) *big.Int { + product := new(big.Int).Mul(a, b) + result := new(big.Int).Div(product, denominator) + if new(big.Int).Rem(product, denominator).Cmp(big.NewInt(0)) != 0 { + result.Add(result, constants.One) + } + return result +} diff --git a/utils/liquidity_math.go b/utils/liquidity_math.go new file mode 100644 index 0000000..9d40b2e --- /dev/null +++ b/utils/liquidity_math.go @@ -0,0 +1,15 @@ +package utils + +import ( + "math/big" + + "github.com/KyberNetwork/pancake-v3-sdk/constants" +) + +func AddDelta(x, y *big.Int) *big.Int { + if y.Cmp(constants.Zero) < 0 { + return new(big.Int).Sub(x, new(big.Int).Mul(y, constants.NegativeOne)) + } else { + return new(big.Int).Add(x, y) + } +} diff --git a/utils/max_liquidity_for_amounts.go b/utils/max_liquidity_for_amounts.go new file mode 100644 index 0000000..9de9b89 --- /dev/null +++ b/utils/max_liquidity_for_amounts.go @@ -0,0 +1,93 @@ +package utils + +import ( + "math/big" + + "github.com/KyberNetwork/pancake-v3-sdk/constants" +) + +/** + * Returns an imprecise maximum amount of liquidity received for a given amount of token 0. + * This function is available to accommodate LiquidityAmounts#getLiquidityForAmount0 in the v3 periphery, + * which could be more precise by at least 32 bits by dividing by Q64 instead of Q96 in the intermediate step, + * and shifting the subtracted ratio left by 32 bits. This imprecise calculation will likely be replaced in a future + * v3 router contract. + * @param sqrtRatioAX96 The price at the lower boundary + * @param sqrtRatioBX96 The price at the upper boundary + * @param amount0 The token0 amount + * @returns liquidity for amount0, imprecise + */ +func maxLiquidityForAmount0Imprecise(sqrtRatioAX96, sqrtRatioBX96, amount0 *big.Int) *big.Int { + if sqrtRatioAX96.Cmp(sqrtRatioBX96) > 0 { + sqrtRatioAX96, sqrtRatioBX96 = sqrtRatioBX96, sqrtRatioAX96 + } + intermediate := new(big.Int).Div(new(big.Int).Mul(sqrtRatioAX96, sqrtRatioBX96), constants.Q96) + return new(big.Int).Div(new(big.Int).Mul(amount0, intermediate), new(big.Int).Sub(sqrtRatioBX96, sqrtRatioAX96)) +} + +/** + * Returns a precise maximum amount of liquidity received for a given amount of token 0 by dividing by Q64 instead of Q96 in the intermediate step, + * and shifting the subtracted ratio left by 32 bits. + * @param sqrtRatioAX96 The price at the lower boundary + * @param sqrtRatioBX96 The price at the upper boundary + * @param amount0 The token0 amount + * @returns liquidity for amount0, precise + */ +func maxLiquidityForAmount0Precise(sqrtRatioAX96, sqrtRatioBX96, amount0 *big.Int) *big.Int { + if sqrtRatioAX96.Cmp(sqrtRatioBX96) > 0 { + sqrtRatioAX96, sqrtRatioBX96 = sqrtRatioBX96, sqrtRatioAX96 + } + numerator := new(big.Int).Mul(new(big.Int).Mul(amount0, sqrtRatioAX96), sqrtRatioBX96) + denominator := new(big.Int).Mul(constants.Q96, new(big.Int).Sub(sqrtRatioBX96, sqrtRatioAX96)) + return new(big.Int).Div(numerator, denominator) +} + +/** + * Computes the maximum amount of liquidity received for a given amount of token1 + * @param sqrtRatioAX96 The price at the lower tick boundary + * @param sqrtRatioBX96 The price at the upper tick boundary + * @param amount1 The token1 amount + * @returns liquidity for amount1 + */ +func maxLiquidityForAmount1(sqrtRatioAX96, sqrtRatioBX96, amount1 *big.Int) *big.Int { + if sqrtRatioAX96.Cmp(sqrtRatioBX96) > 0 { + sqrtRatioAX96, sqrtRatioBX96 = sqrtRatioBX96, sqrtRatioAX96 + } + return new(big.Int).Div(new(big.Int).Mul(amount1, constants.Q96), new(big.Int).Sub(sqrtRatioBX96, sqrtRatioAX96)) +} + +/** + * Computes the maximum amount of liquidity received for a given amount of token0, token1, + * and the prices at the tick boundaries. + * @param sqrtRatioCurrentX96 the current price + * @param sqrtRatioAX96 price at lower boundary + * @param sqrtRatioBX96 price at upper boundary + * @param amount0 token0 amount + * @param amount1 token1 amount + * @param useFullPrecision if false, liquidity will be maximized according to what the router can calculate, + * not what core can theoretically support + */ +func MaxLiquidityForAmounts(sqrtRatioCurrentX96 *big.Int, sqrtRatioAX96, sqrtRatioBX96 *big.Int, amount0, amount1 *big.Int, useFullPrecision bool) *big.Int { + if sqrtRatioAX96.Cmp(sqrtRatioBX96) > 0 { + sqrtRatioAX96, sqrtRatioBX96 = sqrtRatioBX96, sqrtRatioAX96 + } + var maxLiquidityForAmount0 func(*big.Int, *big.Int, *big.Int) *big.Int + if useFullPrecision { + maxLiquidityForAmount0 = maxLiquidityForAmount0Precise + } else { + maxLiquidityForAmount0 = maxLiquidityForAmount0Imprecise + } + if sqrtRatioCurrentX96.Cmp(sqrtRatioAX96) <= 0 { + return maxLiquidityForAmount0(sqrtRatioAX96, sqrtRatioBX96, amount0) + } else if sqrtRatioCurrentX96.Cmp(sqrtRatioBX96) < 0 { + liquidity0 := maxLiquidityForAmount0(sqrtRatioCurrentX96, sqrtRatioBX96, amount0) + liquidity1 := maxLiquidityForAmount1(sqrtRatioAX96, sqrtRatioCurrentX96, amount1) + if liquidity0.Cmp(liquidity1) < 0 { + return liquidity0 + } + return liquidity1 + + } else { + return maxLiquidityForAmount1(sqrtRatioAX96, sqrtRatioBX96, amount1) + } +} diff --git a/utils/max_liquidity_for_amounts_test.go b/utils/max_liquidity_for_amounts_test.go new file mode 100644 index 0000000..dad539b --- /dev/null +++ b/utils/max_liquidity_for_amounts_test.go @@ -0,0 +1,252 @@ +package utils + +import ( + "math/big" + "reflect" + "testing" + + "github.com/daoleno/uniswap-sdk-core/entities" +) + +func TestMaxLiquidityForAmounts(t *testing.T) { + type args struct { + sqrtRatioCurrentX96 *big.Int + sqrtRatioAX96 *big.Int + sqrtRatioBX96 *big.Int + amount0 *big.Int + amount1 *big.Int + useFullPrecision bool + } + lgamounts0, _ := new(big.Int).SetString("1214437677402050006470401421068302637228917309992228326090730924516431320489727", 10) + lgamounts1, _ := new(big.Int).SetString("1214437677402050006470401421098959354205873606971497132040612572422243086574654", 10) + lgamounts2, _ := new(big.Int).SetString("1214437677402050006470401421082903520362793114274352355276488318240158678126184", 10) + tests := []struct { + name string + args args + want *big.Int + }{ + { + name: "imprecise - price inside - 100 token0, 200 token1", + args: args{ + EncodeSqrtRatioX96(big.NewInt(1), big.NewInt(1)), + EncodeSqrtRatioX96(big.NewInt(100), big.NewInt(110)), + EncodeSqrtRatioX96(big.NewInt(110), big.NewInt(100)), + big.NewInt(100), + big.NewInt(200), + false, + }, + want: big.NewInt(2148), + }, + { + name: "imprecise - price inside - 100 token0, max token1", + args: args{ + EncodeSqrtRatioX96(big.NewInt(1), big.NewInt(1)), + EncodeSqrtRatioX96(big.NewInt(100), big.NewInt(110)), + EncodeSqrtRatioX96(big.NewInt(110), big.NewInt(100)), + big.NewInt(100), + entities.MaxUint256, + false, + }, + want: big.NewInt(2148), + }, + { + name: "imprecise - price inside - max token0, 200 token1", + args: args{ + EncodeSqrtRatioX96(big.NewInt(1), big.NewInt(1)), + EncodeSqrtRatioX96(big.NewInt(100), big.NewInt(110)), + EncodeSqrtRatioX96(big.NewInt(110), big.NewInt(100)), + entities.MaxUint256, + big.NewInt(200), + false, + }, + want: big.NewInt(4297), + }, + { + name: "imprecise - price below - 100 token0, 200 token1", + args: args{ + EncodeSqrtRatioX96(big.NewInt(99), big.NewInt(110)), + EncodeSqrtRatioX96(big.NewInt(100), big.NewInt(110)), + EncodeSqrtRatioX96(big.NewInt(110), big.NewInt(100)), + big.NewInt(100), + big.NewInt(200), + false, + }, + want: big.NewInt(1048), + }, + { + name: "imprecise - price below - 100 token0, max token1", + args: args{ + EncodeSqrtRatioX96(big.NewInt(99), big.NewInt(110)), + EncodeSqrtRatioX96(big.NewInt(100), big.NewInt(110)), + EncodeSqrtRatioX96(big.NewInt(110), big.NewInt(100)), + big.NewInt(100), + entities.MaxUint256, + false, + }, + want: big.NewInt(1048), + }, + { + name: "imprecise - price below - max token0, 200 token1", + args: args{ + EncodeSqrtRatioX96(big.NewInt(99), big.NewInt(110)), + EncodeSqrtRatioX96(big.NewInt(100), big.NewInt(110)), + EncodeSqrtRatioX96(big.NewInt(110), big.NewInt(100)), + entities.MaxUint256, + big.NewInt(200), + false, + }, + want: lgamounts0, + }, + { + name: "imprecise - price above - 100 token0, 200 token1", + args: args{ + EncodeSqrtRatioX96(big.NewInt(111), big.NewInt(100)), + EncodeSqrtRatioX96(big.NewInt(100), big.NewInt(110)), + EncodeSqrtRatioX96(big.NewInt(110), big.NewInt(100)), + big.NewInt(100), + big.NewInt(200), + false, + }, + want: big.NewInt(2097), + }, + { + name: "imprecise - price above - 100 token0, max token1", + args: args{ + EncodeSqrtRatioX96(big.NewInt(111), big.NewInt(100)), + EncodeSqrtRatioX96(big.NewInt(100), big.NewInt(110)), + EncodeSqrtRatioX96(big.NewInt(110), big.NewInt(100)), + big.NewInt(100), + entities.MaxUint256, + false, + }, + want: lgamounts1, + }, + { + name: "imprecise - price above - max token0, 200 token1", + args: args{ + EncodeSqrtRatioX96(big.NewInt(111), big.NewInt(100)), + EncodeSqrtRatioX96(big.NewInt(100), big.NewInt(110)), + EncodeSqrtRatioX96(big.NewInt(110), big.NewInt(100)), + entities.MaxUint256, + big.NewInt(200), + false, + }, + want: big.NewInt(2097), + }, + { + name: "precise - price inside - 100 token0, 200 token1", + args: args{ + EncodeSqrtRatioX96(big.NewInt(1), big.NewInt(1)), + EncodeSqrtRatioX96(big.NewInt(100), big.NewInt(110)), + EncodeSqrtRatioX96(big.NewInt(110), big.NewInt(100)), + big.NewInt(100), + big.NewInt(200), + true, + }, + want: big.NewInt(2148), + }, + { + name: "precise - price inside - 100 token0, max token1", + args: args{ + EncodeSqrtRatioX96(big.NewInt(1), big.NewInt(1)), + EncodeSqrtRatioX96(big.NewInt(100), big.NewInt(110)), + EncodeSqrtRatioX96(big.NewInt(110), big.NewInt(100)), + big.NewInt(100), + entities.MaxUint256, + true, + }, + want: big.NewInt(2148), + }, + { + name: "precise - price inside - max token0, 200 token1", + args: args{ + EncodeSqrtRatioX96(big.NewInt(1), big.NewInt(1)), + EncodeSqrtRatioX96(big.NewInt(100), big.NewInt(110)), + EncodeSqrtRatioX96(big.NewInt(110), big.NewInt(100)), + entities.MaxUint256, + big.NewInt(200), + true, + }, + want: big.NewInt(4297), + }, + { + name: "precise - price below - 100 token0, 200 token1", + args: args{ + EncodeSqrtRatioX96(big.NewInt(99), big.NewInt(110)), + EncodeSqrtRatioX96(big.NewInt(100), big.NewInt(110)), + EncodeSqrtRatioX96(big.NewInt(110), big.NewInt(100)), + big.NewInt(100), + big.NewInt(200), + true, + }, + want: big.NewInt(1048), + }, + { + name: "precise - price below - 100 token0, max token1", + args: args{ + EncodeSqrtRatioX96(big.NewInt(99), big.NewInt(110)), + EncodeSqrtRatioX96(big.NewInt(100), big.NewInt(110)), + EncodeSqrtRatioX96(big.NewInt(110), big.NewInt(100)), + big.NewInt(100), + entities.MaxUint256, + true, + }, + want: big.NewInt(1048), + }, + { + name: "precise - price below - max token0, 200 token1", + args: args{ + EncodeSqrtRatioX96(big.NewInt(99), big.NewInt(110)), + EncodeSqrtRatioX96(big.NewInt(100), big.NewInt(110)), + EncodeSqrtRatioX96(big.NewInt(110), big.NewInt(100)), + entities.MaxUint256, + big.NewInt(200), + true, + }, + want: lgamounts2, + }, + { + name: "precise - price above - 100 token0, 200 token1", + args: args{ + EncodeSqrtRatioX96(big.NewInt(111), big.NewInt(100)), + EncodeSqrtRatioX96(big.NewInt(100), big.NewInt(110)), + EncodeSqrtRatioX96(big.NewInt(110), big.NewInt(100)), + big.NewInt(100), + big.NewInt(200), + true, + }, + want: big.NewInt(2097), + }, + { + name: "precise - price above - 100 token0, max token1", + args: args{ + EncodeSqrtRatioX96(big.NewInt(111), big.NewInt(100)), + EncodeSqrtRatioX96(big.NewInt(100), big.NewInt(110)), + EncodeSqrtRatioX96(big.NewInt(110), big.NewInt(100)), + big.NewInt(100), + entities.MaxUint256, + true, + }, + want: lgamounts1, + }, + { + name: "precise - price above - max token0, 200 token1", + args: args{ + EncodeSqrtRatioX96(big.NewInt(111), big.NewInt(100)), + EncodeSqrtRatioX96(big.NewInt(100), big.NewInt(110)), + EncodeSqrtRatioX96(big.NewInt(110), big.NewInt(100)), + entities.MaxUint256, + big.NewInt(200), + true, + }, + want: big.NewInt(2097), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := MaxLiquidityForAmounts(tt.args.sqrtRatioCurrentX96, tt.args.sqrtRatioAX96, tt.args.sqrtRatioBX96, tt.args.amount0, tt.args.amount1, tt.args.useFullPrecision); !reflect.DeepEqual(got, tt.want) { + t.Errorf("maxLiquidityForAmounts() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/utils/most_significant_bit.go b/utils/most_significant_bit.go new file mode 100644 index 0000000..09d639c --- /dev/null +++ b/utils/most_significant_bit.go @@ -0,0 +1,29 @@ +package utils + +import ( + "errors" + "math/big" + + "github.com/KyberNetwork/pancake-v3-sdk/constants" + "github.com/daoleno/uniswap-sdk-core/entities" +) + +var ErrInvalidInput = errors.New("invalid input") + +func MostSignificantBit(x *big.Int) (int64, error) { + if x.Cmp(constants.Zero) <= 0 { + return 0, ErrInvalidInput + } + if x.Cmp(entities.MaxUint256) > 0 { + return 0, ErrInvalidInput + } + var msb int64 + for _, power := range []int64{128, 64, 32, 16, 8, 4, 2, 1} { + min := new(big.Int).Exp(big.NewInt(2), big.NewInt(int64(power)), nil) + if x.Cmp(min) >= 0 { + x = new(big.Int).Rsh(x, uint(power)) + msb += power + } + } + return msb, nil +} diff --git a/utils/price_tick_conversions.go b/utils/price_tick_conversions.go new file mode 100644 index 0000000..e5270ec --- /dev/null +++ b/utils/price_tick_conversions.go @@ -0,0 +1,68 @@ +package utils + +import ( + "math/big" + + "github.com/KyberNetwork/pancake-v3-sdk/constants" + "github.com/daoleno/uniswap-sdk-core/entities" +) + +/** + * Returns a price object corresponding to the input tick and the base/quote token + * Inputs must be tokens because the address order is used to interpret the price represented by the tick + * @param baseToken the base token of the price + * @param quoteToken the quote token of the price + * @param tick the tick for which to return the price + */ +func TickToPrice(baseToken *entities.Token, quoteToken *entities.Token, tick int) (*entities.Price, error) { + sqrtRatioX96, err := GetSqrtRatioAtTick(tick) + if err != nil { + return nil, err + } + ratioX192 := new(big.Int).Mul(sqrtRatioX96, sqrtRatioX96) + + sorted, err := baseToken.SortsBefore(quoteToken) + if err != nil { + return nil, err + } + if sorted { + return entities.NewPrice(baseToken, quoteToken, constants.Q192, ratioX192), nil + } + return entities.NewPrice(baseToken, quoteToken, ratioX192, constants.Q192), nil +} + +/** + * Returns the first tick for which the given price is greater than or equal to the tick price + * @param price for which to return the closest tick that represents a price less than or equal to the input price, + * i.e. the price of the returned tick is less than or equal to the input price + */ +func PriceToClosestTick(price *entities.Price, baseToken, quoteToken *entities.Token) (int, error) { + sorted, err := baseToken.SortsBefore(quoteToken) + if err != nil { + return 0, err + } + var sqrtRatioX96 *big.Int + if sorted { + sqrtRatioX96 = EncodeSqrtRatioX96(price.Numerator, price.Denominator) + } else { + sqrtRatioX96 = EncodeSqrtRatioX96(price.Denominator, price.Numerator) + } + tick, err := GetTickAtSqrtRatio(sqrtRatioX96) + if err != nil { + return 0, err + } + nextTickPrice, err := TickToPrice(baseToken, quoteToken, tick+1) + if err != nil { + return 0, err + } + if sorted { + if !price.LessThan(nextTickPrice.Fraction) { + tick++ + } + } else { + if !price.GreaterThan(nextTickPrice.Fraction) { + tick++ + } + } + return tick, nil +} diff --git a/utils/price_tick_conversions_test.go b/utils/price_tick_conversions_test.go new file mode 100644 index 0000000..9253677 --- /dev/null +++ b/utils/price_tick_conversions_test.go @@ -0,0 +1,111 @@ +package utils + +import ( + "fmt" + "math/big" + "strings" + "testing" + + "github.com/daoleno/uniswap-sdk-core/entities" + "github.com/ethereum/go-ethereum/common" + "github.com/shopspring/decimal" +) + +func token(sortOrder, decimals, chainID uint) *entities.Token { + if sortOrder > 9 || sortOrder%1 != 0 { + panic("invalid sort order") + } + address := common.HexToAddress("0x" + strings.Repeat(fmt.Sprint(sortOrder), 40)) + return entities.NewToken(chainID, address, decimals, fmt.Sprintf("T%d", sortOrder), fmt.Sprintf("token%d", sortOrder)) +} + +var ( + token0 = token(0, 18, 1) + token1 = token(1, 18, 1) + token2_6decimals = token(2, 6, 1) +) + +func TestTickToPrice(t *testing.T) { + type args struct { + baseToken *entities.Token + quoteToken *entities.Token + tick int + } + tests := []struct { + name string + args args + wantSignificant string + }{ + {"1800 t0/1 t1", args{token1, token0, -74959}, "1800"}, + {"1 t1/1800 t0", args{token0, token1, -74959}, "0.00055556"}, + {"1800 t1/1 t0", args{token0, token1, 74959}, "1800"}, + {"1 t0/1800 t1", args{token1, token0, 74959}, "0.00055556"}, + + // 12 decimal difference + {"1.01 t2/1 t0", args{token1, token2_6decimals, -276225}, "1.01"}, + {"1 t0/1.01 t2", args{token2_6decimals, token0, -276225}, "0.99015"}, + {"1 t2/1.01 t0", args{token0, token2_6decimals, -276423}, "0.99015"}, + {"1.01 t0/1 t2", args{token2_6decimals, token0, -276423}, "1.0099"}, + {"1.01 t2/1 t0", args{token0, token2_6decimals, -276225}, "1.01"}, + {"1 t0/1.01 t2", args{token2_6decimals, token0, -276225}, "0.99015"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := TickToPrice(tt.args.baseToken, tt.args.quoteToken, tt.args.tick) + if err != nil { + t.Errorf("TickToPrice() error = %v", err) + return + } + if got.ToSignificant(5) != tt.wantSignificant { + t.Errorf("TickToPrice() = %v, want %v", got, tt.wantSignificant) + } + }) + } +} + +func TestPriceToClosestTick(t *testing.T) { + tickToPriceNoError := func(baseToken *entities.Token, quoteToken *entities.Token, tick int) *entities.Price { + p, err := TickToPrice(baseToken, quoteToken, tick) + if err != nil { + panic(err) + } + return p + } + type args struct { + price *entities.Price + baseToken *entities.Token + quoteToken *entities.Token + } + B100e18 := decimal.NewFromBigInt(big.NewInt(100), 18).BigInt() + + tests := []struct { + name string + args args + wantTick int + }{ + {"1800 t0/1 t1", args{entities.NewPrice(token1, token0, big.NewInt(1), big.NewInt(1800)), token1, token0}, -74960}, + {"1 t1/1800 t0", args{entities.NewPrice(token0, token1, big.NewInt(1800), big.NewInt(1)), token0, token1}, -74960}, + {"1.01 t2/1 t0", args{entities.NewPrice(token0, token2_6decimals, B100e18, big.NewInt(101e6)), token0, token2_6decimals}, -276225}, + {"1 t0/1.01 t2", args{entities.NewPrice(token2_6decimals, token0, big.NewInt(101e6), B100e18), token2_6decimals, token0}, -276225}, + + // reciprocal with tickToPrice + {"1800 t0/1 t1", args{tickToPriceNoError(token1, token0, -74960), token1, token0}, -74960}, + {"1 t0/1800 t1", args{tickToPriceNoError(token1, token0, 74960), token1, token0}, 74960}, + {"1 t1/1800 t0", args{tickToPriceNoError(token0, token1, -74960), token0, token1}, -74960}, + {"1800 t1/1 t0", args{tickToPriceNoError(token0, token1, 74960), token0, token1}, 74960}, + {"1.01 t2/1 t0", args{tickToPriceNoError(token0, token2_6decimals, -276225), token0, token2_6decimals}, -276225}, + {"1 t0/1.01 t2", args{tickToPriceNoError(token2_6decimals, token0, -276225), token2_6decimals, token0}, -276225}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := PriceToClosestTick(tt.args.price, tt.args.baseToken, tt.args.quoteToken) + if err != nil { + t.Errorf("PriceToClosestTick() error = %v", err) + return + } + if got != tt.wantTick { + t.Errorf("PriceToClosestTick() = %v, want %v", got, tt.wantTick) + } + }) + } +} diff --git a/utils/sqrtprice_math.go b/utils/sqrtprice_math.go new file mode 100644 index 0000000..5c40197 --- /dev/null +++ b/utils/sqrtprice_math.go @@ -0,0 +1,125 @@ +package utils + +import ( + "errors" + "math/big" + + "github.com/KyberNetwork/pancake-v3-sdk/constants" + "github.com/daoleno/uniswap-sdk-core/entities" +) + +var ( + ErrSqrtPriceLessThanZero = errors.New("sqrt price less than zero") + ErrLiquidityLessThanZero = errors.New("liquidity less than zero") + ErrInvariant = errors.New("invariant violation") +) +var MaxUint160 = new(big.Int).Sub(new(big.Int).Exp(big.NewInt(2), big.NewInt(160), nil), constants.One) + +func multiplyIn256(x, y *big.Int) *big.Int { + product := new(big.Int).Mul(x, y) + return new(big.Int).And(product, entities.MaxUint256) +} + +func addIn256(x, y *big.Int) *big.Int { + sum := new(big.Int).Add(x, y) + return new(big.Int).And(sum, entities.MaxUint256) +} + +func GetAmount0Delta(sqrtRatioAX96, sqrtRatioBX96, liquidity *big.Int, roundUp bool) *big.Int { + // https://github.com/Uniswap/v3-core/blob/d8b1c635c275d2a9450bd6a78f3fa2484fef73eb/contracts/libraries/SqrtPriceMath.sol#L159 + if sqrtRatioAX96.Cmp(sqrtRatioBX96) > 0 { + sqrtRatioAX96, sqrtRatioBX96 = sqrtRatioBX96, sqrtRatioAX96 + } + + numerator1 := new(big.Int).Lsh(liquidity, 96) + numerator2 := new(big.Int).Sub(sqrtRatioBX96, sqrtRatioAX96) + + if roundUp { + return MulDivRoundingUp(MulDivRoundingUp(numerator1, numerator2, sqrtRatioBX96), constants.One, sqrtRatioAX96) + } + return new(big.Int).Div(new(big.Int).Div(new(big.Int).Mul(numerator1, numerator2), sqrtRatioBX96), sqrtRatioAX96) +} + +func GetAmount1Delta(sqrtRatioAX96, sqrtRatioBX96, liquidity *big.Int, roundUp bool) *big.Int { + // https://github.com/Uniswap/v3-core/blob/d8b1c635c275d2a9450bd6a78f3fa2484fef73eb/contracts/libraries/SqrtPriceMath.sol#L188 + if sqrtRatioAX96.Cmp(sqrtRatioBX96) > 0 { + sqrtRatioAX96, sqrtRatioBX96 = sqrtRatioBX96, sqrtRatioAX96 + } + + if roundUp { + return MulDivRoundingUp(liquidity, new(big.Int).Sub(sqrtRatioBX96, sqrtRatioAX96), constants.Q96) + } + return new(big.Int).Div(new(big.Int).Mul(liquidity, new(big.Int).Sub(sqrtRatioBX96, sqrtRatioAX96)), constants.Q96) +} + +func GetNextSqrtPriceFromInput(sqrtPX96, liquidity, amountIn *big.Int, zeroForOne bool) (*big.Int, error) { + if sqrtPX96.Cmp(constants.Zero) <= 0 { + return nil, ErrSqrtPriceLessThanZero + } + if liquidity.Cmp(constants.Zero) <= 0 { + return nil, ErrLiquidityLessThanZero + } + if zeroForOne { + return getNextSqrtPriceFromAmount0RoundingUp(sqrtPX96, liquidity, amountIn, true) + } + return getNextSqrtPriceFromAmount1RoundingDown(sqrtPX96, liquidity, amountIn, true) +} + +func GetNextSqrtPriceFromOutput(sqrtPX96, liquidity, amountOut *big.Int, zeroForOne bool) (*big.Int, error) { + if sqrtPX96.Cmp(constants.Zero) <= 0 { + return nil, ErrSqrtPriceLessThanZero + } + if liquidity.Cmp(constants.Zero) <= 0 { + return nil, ErrLiquidityLessThanZero + } + if zeroForOne { + return getNextSqrtPriceFromAmount1RoundingDown(sqrtPX96, liquidity, amountOut, false) + } + return getNextSqrtPriceFromAmount0RoundingUp(sqrtPX96, liquidity, amountOut, false) +} + +func getNextSqrtPriceFromAmount0RoundingUp(sqrtPX96, liquidity, amount *big.Int, add bool) (*big.Int, error) { + if amount.Cmp(constants.Zero) == 0 { + return sqrtPX96, nil + } + + numerator1 := new(big.Int).Lsh(liquidity, 96) + if add { + product := multiplyIn256(amount, sqrtPX96) + if new(big.Int).Div(product, amount).Cmp(sqrtPX96) == 0 { + denominator := addIn256(numerator1, product) + if denominator.Cmp(numerator1) >= 0 { + return MulDivRoundingUp(numerator1, sqrtPX96, denominator), nil + } + } + return MulDivRoundingUp(numerator1, constants.One, new(big.Int).Add(new(big.Int).Div(numerator1, sqrtPX96), amount)), nil + } else { + product := multiplyIn256(amount, sqrtPX96) + if new(big.Int).Div(product, amount).Cmp(sqrtPX96) != 0 { + return nil, ErrInvariant + } + if numerator1.Cmp(product) <= 0 { + return nil, ErrInvariant + } + denominator := new(big.Int).Sub(numerator1, product) + return MulDivRoundingUp(numerator1, sqrtPX96, denominator), nil + } +} + +func getNextSqrtPriceFromAmount1RoundingDown(sqrtPX96, liquidity, amount *big.Int, add bool) (*big.Int, error) { + if add { + var quotient *big.Int + if amount.Cmp(MaxUint160) <= 0 { + quotient = new(big.Int).Div(new(big.Int).Lsh(amount, 96), liquidity) + } else { + quotient = new(big.Int).Div(new(big.Int).Mul(amount, constants.Q96), liquidity) + } + return new(big.Int).Add(sqrtPX96, quotient), nil + } + + quotient := MulDivRoundingUp(amount, constants.Q96, liquidity) + if sqrtPX96.Cmp(quotient) <= 0 { + return nil, ErrInvariant + } + return new(big.Int).Sub(sqrtPX96, quotient), nil +} diff --git a/utils/swap_math.go b/utils/swap_math.go new file mode 100644 index 0000000..c70e6e6 --- /dev/null +++ b/utils/swap_math.go @@ -0,0 +1,76 @@ +package utils + +import ( + "math/big" + + "github.com/KyberNetwork/pancake-v3-sdk/constants" +) + +var MaxFee = new(big.Int).Exp(big.NewInt(10), big.NewInt(6), nil) + +func ComputeSwapStep(sqrtRatioCurrentX96, sqrtRatioTargetX96, liquidity, amountRemaining *big.Int, feePips constants.FeeAmount) (sqrtRatioNextX96, amountIn, amountOut, feeAmount *big.Int, err error) { + zeroForOne := sqrtRatioCurrentX96.Cmp(sqrtRatioTargetX96) >= 0 + exactIn := amountRemaining.Cmp(constants.Zero) >= 0 + + if exactIn { + amountRemainingLessFee := new(big.Int).Div(new(big.Int).Mul(amountRemaining, new(big.Int).Sub(MaxFee, big.NewInt(int64(feePips)))), MaxFee) + if zeroForOne { + amountIn = GetAmount0Delta(sqrtRatioTargetX96, sqrtRatioCurrentX96, liquidity, true) + } else { + amountIn = GetAmount1Delta(sqrtRatioCurrentX96, sqrtRatioTargetX96, liquidity, true) + } + if amountRemainingLessFee.Cmp(amountIn) >= 0 { + sqrtRatioNextX96 = sqrtRatioTargetX96 + } else { + sqrtRatioNextX96, err = GetNextSqrtPriceFromInput(sqrtRatioCurrentX96, liquidity, amountRemainingLessFee, zeroForOne) + if err != nil { + return + } + } + } else { + if zeroForOne { + amountOut = GetAmount1Delta(sqrtRatioTargetX96, sqrtRatioCurrentX96, liquidity, false) + } else { + amountOut = GetAmount0Delta(sqrtRatioCurrentX96, sqrtRatioTargetX96, liquidity, false) + } + if new(big.Int).Mul(amountRemaining, constants.NegativeOne).Cmp(amountOut) >= 0 { + sqrtRatioNextX96 = sqrtRatioTargetX96 + } else { + sqrtRatioNextX96, err = GetNextSqrtPriceFromOutput(sqrtRatioCurrentX96, liquidity, new(big.Int).Mul(amountRemaining, constants.NegativeOne), zeroForOne) + if err != nil { + return + } + } + } + + max := sqrtRatioTargetX96.Cmp(sqrtRatioNextX96) == 0 + + if zeroForOne { + if !(max && exactIn) { + amountIn = GetAmount0Delta(sqrtRatioNextX96, sqrtRatioCurrentX96, liquidity, true) + } + if !(max && !exactIn) { + amountOut = GetAmount1Delta(sqrtRatioNextX96, sqrtRatioCurrentX96, liquidity, false) + } + } else { + if !(max && exactIn) { + amountIn = GetAmount1Delta(sqrtRatioCurrentX96, sqrtRatioNextX96, liquidity, true) + } + if !(max && !exactIn) { + amountOut = GetAmount0Delta(sqrtRatioCurrentX96, sqrtRatioNextX96, liquidity, false) + } + } + + if !exactIn && amountOut.Cmp(new(big.Int).Mul(amountRemaining, constants.NegativeOne)) > 0 { + amountOut = new(big.Int).Mul(amountRemaining, constants.NegativeOne) + } + + if exactIn && sqrtRatioNextX96.Cmp(sqrtRatioTargetX96) != 0 { + // we didn't reach the target, so take the remainder of the maximum input as fee + feeAmount = new(big.Int).Sub(amountRemaining, amountIn) + } else { + feeAmount = MulDivRoundingUp(amountIn, big.NewInt(int64(feePips)), new(big.Int).Sub(MaxFee, big.NewInt(int64(feePips)))) + } + + return +} diff --git a/utils/tick_math.go b/utils/tick_math.go new file mode 100644 index 0000000..4a56a1f --- /dev/null +++ b/utils/tick_math.go @@ -0,0 +1,197 @@ +package utils + +import ( + "errors" + "math/big" + + "github.com/KyberNetwork/pancake-v3-sdk/constants" + "github.com/daoleno/uniswap-sdk-core/entities" +) + +const ( + MinTick = -887272 // The minimum tick that can be used on any pool. + MaxTick = -MinTick // The maximum tick that can be used on any pool. +) + +var ( + Q32 = big.NewInt(1 << 32) + MinSqrtRatio = big.NewInt(4295128739) // The sqrt ratio corresponding to the minimum tick that could be used on any pool. + MaxSqrtRatio, _ = new(big.Int).SetString("1461446703485210103287273052203988822378723970342", 10) // The sqrt ratio corresponding to the maximum tick that could be used on any pool. +) + +var ( + ErrInvalidTick = errors.New("invalid tick") + ErrInvalidSqrtRatio = errors.New("invalid sqrt ratio") +) + +func mulShift(val *big.Int, mulBy *big.Int) *big.Int { + + return new(big.Int).Rsh(new(big.Int).Mul(val, mulBy), 128) +} + +var ( + sqrtConst1, _ = new(big.Int).SetString("fffcb933bd6fad37aa2d162d1a594001", 16) + sqrtConst2, _ = new(big.Int).SetString("100000000000000000000000000000000", 16) + sqrtConst3, _ = new(big.Int).SetString("fff97272373d413259a46990580e213a", 16) + sqrtConst4, _ = new(big.Int).SetString("fff2e50f5f656932ef12357cf3c7fdcc", 16) + sqrtConst5, _ = new(big.Int).SetString("ffe5caca7e10e4e61c3624eaa0941cd0", 16) + sqrtConst6, _ = new(big.Int).SetString("ffcb9843d60f6159c9db58835c926644", 16) + sqrtConst7, _ = new(big.Int).SetString("ff973b41fa98c081472e6896dfb254c0", 16) + sqrtConst8, _ = new(big.Int).SetString("ff2ea16466c96a3843ec78b326b52861", 16) + sqrtConst9, _ = new(big.Int).SetString("fe5dee046a99a2a811c461f1969c3053", 16) + sqrtConst10, _ = new(big.Int).SetString("fcbe86c7900a88aedcffc83b479aa3a4", 16) + sqrtConst11, _ = new(big.Int).SetString("f987a7253ac413176f2b074cf7815e54", 16) + sqrtConst12, _ = new(big.Int).SetString("f3392b0822b70005940c7a398e4b70f3", 16) + sqrtConst13, _ = new(big.Int).SetString("e7159475a2c29b7443b29c7fa6e889d9", 16) + sqrtConst14, _ = new(big.Int).SetString("d097f3bdfd2022b8845ad8f792aa5825", 16) + sqrtConst15, _ = new(big.Int).SetString("a9f746462d870fdf8a65dc1f90e061e5", 16) + sqrtConst16, _ = new(big.Int).SetString("70d869a156d2a1b890bb3df62baf32f7", 16) + sqrtConst17, _ = new(big.Int).SetString("31be135f97d08fd981231505542fcfa6", 16) + sqrtConst18, _ = new(big.Int).SetString("9aa508b5b7a84e1c677de54f3e99bc9", 16) + sqrtConst19, _ = new(big.Int).SetString("5d6af8dedb81196699c329225ee604", 16) + sqrtConst20, _ = new(big.Int).SetString("2216e584f5fa1ea926041bedfe98", 16) + sqrtConst21, _ = new(big.Int).SetString("48a170391f7dc42444e8fa2", 16) +) + +/** + * Returns the sqrt ratio as a Q64.96 for the given tick. The sqrt ratio is computed as sqrt(1.0001)^tick + * @param tick the tick for which to compute the sqrt ratio + */ +func GetSqrtRatioAtTick(tick int) (*big.Int, error) { + if tick < MinTick || tick > MaxTick { + return nil, ErrInvalidTick + } + absTick := tick + if tick < 0 { + absTick = -tick + } + var ratio *big.Int + if absTick&0x1 != 0 { + ratio = sqrtConst1 + } else { + ratio = sqrtConst2 + } + if (absTick & 0x2) != 0 { + ratio = mulShift(ratio, sqrtConst3) + } + if (absTick & 0x4) != 0 { + ratio = mulShift(ratio, sqrtConst4) + } + if (absTick & 0x8) != 0 { + ratio = mulShift(ratio, sqrtConst5) + } + if (absTick & 0x10) != 0 { + ratio = mulShift(ratio, sqrtConst6) + } + if (absTick & 0x20) != 0 { + ratio = mulShift(ratio, sqrtConst7) + } + if (absTick & 0x40) != 0 { + ratio = mulShift(ratio, sqrtConst8) + } + if (absTick & 0x80) != 0 { + ratio = mulShift(ratio, sqrtConst9) + } + if (absTick & 0x100) != 0 { + ratio = mulShift(ratio, sqrtConst10) + } + if (absTick & 0x200) != 0 { + ratio = mulShift(ratio, sqrtConst11) + } + if (absTick & 0x400) != 0 { + ratio = mulShift(ratio, sqrtConst12) + } + if (absTick & 0x800) != 0 { + ratio = mulShift(ratio, sqrtConst13) + } + if (absTick & 0x1000) != 0 { + ratio = mulShift(ratio, sqrtConst14) + } + if (absTick & 0x2000) != 0 { + ratio = mulShift(ratio, sqrtConst15) + } + if (absTick & 0x4000) != 0 { + ratio = mulShift(ratio, sqrtConst16) + } + if (absTick & 0x8000) != 0 { + ratio = mulShift(ratio, sqrtConst17) + } + if (absTick & 0x10000) != 0 { + ratio = mulShift(ratio, sqrtConst18) + } + if (absTick & 0x20000) != 0 { + ratio = mulShift(ratio, sqrtConst19) + } + if (absTick & 0x40000) != 0 { + ratio = mulShift(ratio, sqrtConst20) + } + if (absTick & 0x80000) != 0 { + ratio = mulShift(ratio, sqrtConst21) + } + if tick > 0 { + ratio = new(big.Int).Div(entities.MaxUint256, ratio) + } + + // back to Q96 + if new(big.Int).Rem(ratio, Q32).Cmp(constants.Zero) > 0 { + return new(big.Int).Add((new(big.Int).Div(ratio, Q32)), constants.One), nil + } else { + return new(big.Int).Div(ratio, Q32), nil + } +} + +var ( + magicSqrt10001, _ = new(big.Int).SetString("255738958999603826347141", 10) + magicTickLow, _ = new(big.Int).SetString("3402992956809132418596140100660247210", 10) + magicTickHigh, _ = new(big.Int).SetString("291339464771989622907027621153398088495", 10) +) + +/** + * Returns the tick corresponding to a given sqrt ratio, s.t. #getSqrtRatioAtTick(tick) <= sqrtRatioX96 + * and #getSqrtRatioAtTick(tick + 1) > sqrtRatioX96 + * @param sqrtRatioX96 the sqrt ratio as a Q64.96 for which to compute the tick + */ +func GetTickAtSqrtRatio(sqrtRatioX96 *big.Int) (int, error) { + if sqrtRatioX96.Cmp(MinSqrtRatio) < 0 || sqrtRatioX96.Cmp(MaxSqrtRatio) >= 0 { + return 0, ErrInvalidSqrtRatio + } + sqrtRatioX128 := new(big.Int).Lsh(sqrtRatioX96, 32) + msb, err := MostSignificantBit(sqrtRatioX128) + if err != nil { + return 0, err + } + var r *big.Int + if big.NewInt(msb).Cmp(big.NewInt(128)) >= 0 { + r = new(big.Int).Rsh(sqrtRatioX128, uint(msb-127)) + } else { + r = new(big.Int).Lsh(sqrtRatioX128, uint(127-msb)) + } + + log2 := new(big.Int).Lsh(new(big.Int).Sub(big.NewInt(msb), big.NewInt(128)), 64) + + for i := 0; i < 14; i++ { + r = new(big.Int).Rsh(new(big.Int).Mul(r, r), 127) + f := new(big.Int).Rsh(r, 128) + log2 = new(big.Int).Or(log2, new(big.Int).Lsh(f, uint(63-i))) + r = new(big.Int).Rsh(r, uint(f.Int64())) + } + + logSqrt10001 := new(big.Int).Mul(log2, magicSqrt10001) + + tickLow := new(big.Int).Rsh(new(big.Int).Sub(logSqrt10001, magicTickLow), 128).Int64() + tickHigh := new(big.Int).Rsh(new(big.Int).Add(logSqrt10001, magicTickHigh), 128).Int64() + + if tickLow == tickHigh { + return int(tickLow), nil + } + + sqrtRatio, err := GetSqrtRatioAtTick(int(tickHigh)) + if err != nil { + return 0, err + } + if sqrtRatio.Cmp(sqrtRatioX96) <= 0 { + return int(tickHigh), nil + } else { + return int(tickLow), nil + } +} diff --git a/utils/tick_math_test.go b/utils/tick_math_test.go new file mode 100644 index 0000000..745c826 --- /dev/null +++ b/utils/tick_math_test.go @@ -0,0 +1,34 @@ +package utils + +import ( + "math/big" + "testing" + + "github.com/KyberNetwork/pancake-v3-sdk/constants" + "github.com/stretchr/testify/assert" +) + +func TestGetSqrtRatioAtTick(t *testing.T) { + _, err := GetSqrtRatioAtTick(MinTick - 1) + assert.ErrorIs(t, err, ErrInvalidTick, "tick tool small") + + _, err = GetSqrtRatioAtTick(MaxTick + 1) + assert.ErrorIs(t, err, ErrInvalidTick, "tick tool large") + + rmax, _ := GetSqrtRatioAtTick(MinTick) + assert.Equal(t, rmax, MinSqrtRatio, "returns the correct value for min tick") + + r0, _ := GetSqrtRatioAtTick(0) + assert.Equal(t, r0, new(big.Int).Lsh(constants.One, 96), "returns the correct value for tick 0") + + rmin, _ := GetSqrtRatioAtTick(MaxTick) + assert.Equal(t, rmin, MaxSqrtRatio, "returns the correct value for max tick") +} + +func TestGetTickAtSqrtRatio(t *testing.T) { + tmin, _ := GetTickAtSqrtRatio(MinSqrtRatio) + assert.Equal(t, tmin, MinTick, "returns the correct value for sqrt ratio at min tick") + + tmax, _ := GetTickAtSqrtRatio(new(big.Int).Sub(MaxSqrtRatio, constants.One)) + assert.Equal(t, tmax, MaxTick-1, "returns the correct value for sqrt ratio at max tick") +}