From 6e4f49ea7251b71f72f5c48a239d0af88aaca7d7 Mon Sep 17 00:00:00 2001 From: Hoang Trinh Date: Thu, 8 Jun 2023 12:25:30 +0700 Subject: [PATCH] fix: case liquidityNet is nil (#6) * fix: case liquidityNet is nil * refactor: remove unused code --- entities/pool.go | 4 + entities/ticklist.go | 94 +++++----------- entities/ticklist_test.go | 222 +++++++++++++++++++++++++++++--------- entities/trade_test.go | 10 +- 4 files changed, 211 insertions(+), 119 deletions(-) diff --git a/entities/pool.go b/entities/pool.go index 5f69da5..639e491 100644 --- a/entities/pool.go +++ b/entities/pool.go @@ -415,6 +415,10 @@ func (p *Pool) _updateLiquidityAndCrossTick( liquidityNet = new(big.Int).Mul(liquidityNet, constants.NegativeOne) } + if liquidityNet == nil { + return constants.Zero, newNextTick + } + var liquidityDelta *big.Int if liquidityNet.Cmp(constants.Zero) >= 0 { liquidityDelta = liquidityNet diff --git a/entities/ticklist.go b/entities/ticklist.go index c113545..f766dc8 100644 --- a/entities/ticklist.go +++ b/entities/ticklist.go @@ -2,15 +2,13 @@ package entities import ( "errors" - "math" "math/big" "github.com/KyberNetwork/elastic-go-sdk/v2/utils" ) const ( - TickIndexZero = 0 - TickNotInitialized = false + TickIndexZero = 0 ) var ( @@ -21,7 +19,6 @@ var ( ErrEmptyTickList = errors.New("empty tick list") ErrBelowSmallest = errors.New("below smallest") ErrAtOrAboveLargest = errors.New("at or above largest") - ErrInvalidTickIndex = errors.New("invalid tick index") ) var ( @@ -72,21 +69,6 @@ func IsAtOrAboveLargest(ticks []Tick, tick int) (bool, error) { return tick >= ticks[len(ticks)-1].Index, nil } -func GetTick(ticks []Tick, index int) (Tick, error) { - tickIndex, err := binarySearch(ticks, index) - if err != nil { - return EmptyTick, err - } - - if tickIndex < 0 { - return EmptyTick, ErrInvalidTickIndex - } - - tick := ticks[tickIndex] - - return tick, nil -} - func NextInitializedTick(ticks []Tick, tick int, lte bool) (Tick, error) { if lte { isBelowSmallest, err := IsBelowSmallest(ticks, tick) @@ -142,53 +124,9 @@ func NextInitializedTick(ticks []Tick, tick int, lte bool) (Tick, error) { } } -func NextInitializedTickWithinOneWord(ticks []Tick, tick int, lte bool, tickSpacing int) (int, bool, error) { - compressed := math.Floor(float64(tick) / float64(tickSpacing)) // matches rounding in the code - - if lte { - wordPos := int(compressed) >> 8 - minimum := (wordPos << 8) * tickSpacing - isBelowSmallest, err := IsBelowSmallest(ticks, tick) - if err != nil { - return TickIndexZero, TickNotInitialized, err - } - - if isBelowSmallest { - return minimum, TickNotInitialized, ErrBelowSmallest - } - - nextInitializedTick, err := NextInitializedTick(ticks, tick, lte) - if err != nil { - return TickIndexZero, TickNotInitialized, err - } - - index := nextInitializedTick.Index - nextInitializedTickIndex := math.Max(float64(minimum), float64(index)) - return int(nextInitializedTickIndex), int(nextInitializedTickIndex) == index, nil - } else { - wordPos := int(compressed+1) >> 8 - maximum := ((wordPos+1)<<8)*tickSpacing - 1 - isAtOrAboveLargest, err := IsAtOrAboveLargest(ticks, tick) - if err != nil { - return TickIndexZero, TickNotInitialized, err - } - - if isAtOrAboveLargest { - return maximum, TickNotInitialized, ErrAtOrAboveLargest - } - - nextInitializedTick, err := NextInitializedTick(ticks, tick, lte) - if err != nil { - return TickIndexZero, TickNotInitialized, err - } - - index := nextInitializedTick.Index - nextInitializedTickIndex := math.Min(float64(maximum), float64(index)) - return int(nextInitializedTickIndex), int(nextInitializedTickIndex) == index, nil - } -} - func GetNearestCurrentTick(ticks []Tick, currentTick int) (int, error) { + // https://github.com/KyberNetwork/ks-elastic-sc/blob/3ba84353cbd88f30f222bb9c673e242a2e46fd12/contracts/PoolStorage.sol#L114 + // NearestCurrentTick is initialized with MinTick at the beginning isBelowSmallest, err := IsBelowSmallest(ticks, currentTick) if err != nil { return utils.MinTick, err @@ -210,6 +148,16 @@ func TransformToMap(ticks []Tick) (map[int]TickData, map[int]LinkedListData) { tickDataByIndex := make(map[int]TickData) initializedTicks := make(map[int]LinkedListData) + // Init the initializedTicks + initializedTicks[utils.MinTick] = LinkedListData{ + Previous: utils.MinTick, + Next: utils.MaxTick, + } + initializedTicks[utils.MaxTick] = LinkedListData{ + Previous: utils.MinTick, + Next: utils.MaxTick, + } + for i, t := range ticks { tickDataByIndex[t.Index] = TickData{ LiquidityGross: t.LiquidityGross, @@ -221,16 +169,32 @@ func TransformToMap(ticks []Tick) (map[int]TickData, map[int]LinkedListData) { Next: utils.MaxTick, Previous: utils.MinTick, } + initializedTicks[utils.MinTick] = LinkedListData{ + Previous: utils.MinTick, + Next: t.Index, + } + initializedTicks[utils.MaxTick] = LinkedListData{ + Previous: t.Index, + Next: utils.MaxTick, + } } else if i == 0 { initializedTicks[t.Index] = LinkedListData{ Next: ticks[i+1].Index, Previous: utils.MinTick, } + initializedTicks[utils.MinTick] = LinkedListData{ + Previous: utils.MinTick, + Next: t.Index, + } } else if i == len(ticks)-1 { initializedTicks[t.Index] = LinkedListData{ Next: utils.MaxTick, Previous: ticks[i-1].Index, } + initializedTicks[utils.MaxTick] = LinkedListData{ + Previous: t.Index, + Next: utils.MaxTick, + } } else { initializedTicks[t.Index] = LinkedListData{ Next: ticks[i+1].Index, diff --git a/entities/ticklist_test.go b/entities/ticklist_test.go index 3027c8d..13e7d60 100644 --- a/entities/ticklist_test.go +++ b/entities/ticklist_test.go @@ -95,55 +95,6 @@ func TestNextInitializedTick(t *testing.T) { assert.ErrorIs(t, err2, ErrAtOrAboveLargest) } -func TestNextInitializedTickWithinOneWord(t *testing.T) { - ticks := []Tick{lowTick, midTick, highTick} - - // words around 0, lte = true - type args struct { - ticks []Tick - tick int - lte bool - tickSpacing int - } - tests := []struct { - name string - args args - want0 int - want1 bool - }{ - // words around 0, lte = true - {name: "lte = true 0", args: args{ticks: ticks, tick: -257, lte: true, tickSpacing: 1}, want0: -512, want1: false}, - {name: "lte = true 1", args: args{ticks: ticks, tick: -256, lte: true, tickSpacing: 1}, want0: -256, want1: false}, - {name: "lte = true 2", args: args{ticks: ticks, tick: -1, lte: true, tickSpacing: 1}, want0: -256, want1: false}, - {name: "lte = true 3", args: args{ticks: ticks, tick: 0, lte: true, tickSpacing: 1}, want0: 0, want1: true}, - {name: "lte = true 4", args: args{ticks: ticks, tick: 1, lte: true, tickSpacing: 1}, want0: 0, want1: true}, - {name: "lte = true 5", args: args{ticks: ticks, tick: 255, lte: true, tickSpacing: 1}, want0: 0, want1: true}, - {name: "lte = true 6", args: args{ticks: ticks, tick: 256, lte: true, tickSpacing: 1}, want0: 256, want1: false}, - {name: "lte = true 7", args: args{ticks: ticks, tick: 257, lte: true, tickSpacing: 1}, want0: 256, want1: false}, - - // words around 0, lte = false - {name: "lte = false 0", args: args{ticks: ticks, tick: -258, lte: false, tickSpacing: 1}, want0: -257, want1: false}, - {name: "lte = false 1", args: args{ticks: ticks, tick: -257, lte: false, tickSpacing: 1}, want0: -1, want1: false}, - {name: "lte = false 2", args: args{ticks: ticks, tick: -256, lte: false, tickSpacing: 1}, want0: -1, want1: false}, - {name: "lte = false 3", args: args{ticks: ticks, tick: -2, lte: false, tickSpacing: 1}, want0: -1, want1: false}, - {name: "lte = false 4", args: args{ticks: ticks, tick: -1, lte: false, tickSpacing: 1}, want0: 0, want1: true}, - {name: "lte = false 5", args: args{ticks: ticks, tick: 0, lte: false, tickSpacing: 1}, want0: 255, want1: false}, - {name: "lte = false 6", args: args{ticks: ticks, tick: 1, lte: false, tickSpacing: 1}, want0: 255, want1: false}, - {name: "lte = false 7", args: args{ticks: ticks, tick: 254, lte: false, tickSpacing: 1}, want0: 255, want1: false}, - {name: "lte = false 8", args: args{ticks: ticks, tick: 255, lte: false, tickSpacing: 1}, want0: 511, want1: false}, - {name: "lte = false 9", args: args{ticks: ticks, tick: 256, lte: false, tickSpacing: 1}, want0: 511, want1: false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got0, got1, _ := NextInitializedTickWithinOneWord(tt.args.ticks, tt.args.tick, tt.args.lte, tt.args.tickSpacing) - assert.Equal(t, tt.want0, got0) - assert.Equal(t, tt.want1, got1) - }) - } - -} - func TestGetNearestCurrentTick(t *testing.T) { testCases := []struct { name string @@ -177,3 +128,176 @@ func TestGetNearestCurrentTick(t *testing.T) { }) } } + +func TestTransformToMap(t *testing.T) { + type args struct { + ticks []Tick + } + tests := []struct { + name string + args args + wantTickData map[int]TickData + wantLinkedListData map[int]LinkedListData + }{ + { + name: "it should return correct data when there is no initialized tick", + args: args{ + ticks: []Tick{}, + }, + wantTickData: map[int]TickData{}, + wantLinkedListData: map[int]LinkedListData{ + utils.MinTick: { + Previous: utils.MinTick, + Next: utils.MaxTick, + }, + utils.MaxTick: { + Previous: utils.MinTick, + Next: utils.MaxTick, + }, + }, + }, + { + name: "it should return correct data when there is only 1 initialized tick", + args: args{ + ticks: []Tick{ + { + Index: 10000, + LiquidityNet: big.NewInt(-100000), + LiquidityGross: big.NewInt(100000), + }, + }, + }, + wantTickData: map[int]TickData{ + 10000: { + LiquidityNet: big.NewInt(-100000), + LiquidityGross: big.NewInt(100000), + }, + }, + wantLinkedListData: map[int]LinkedListData{ + utils.MinTick: { + Previous: utils.MinTick, + Next: 10000, + }, + 10000: { + Previous: utils.MinTick, + Next: utils.MaxTick, + }, + utils.MaxTick: { + Previous: 10000, + Next: utils.MaxTick, + }, + }, + }, + { + name: "it should return correct data when there are more than 1 initialized tick (2 ticks)", + args: args{ + ticks: []Tick{ + { + Index: 10000, + LiquidityNet: big.NewInt(-100000), + LiquidityGross: big.NewInt(100000), + }, + { + Index: 20000, + LiquidityNet: big.NewInt(-200000), + LiquidityGross: big.NewInt(200000), + }, + }, + }, + wantTickData: map[int]TickData{ + 10000: { + LiquidityNet: big.NewInt(-100000), + LiquidityGross: big.NewInt(100000), + }, + 20000: { + LiquidityNet: big.NewInt(-200000), + LiquidityGross: big.NewInt(200000), + }, + }, + wantLinkedListData: map[int]LinkedListData{ + utils.MinTick: { + Previous: utils.MinTick, + Next: 10000, + }, + 10000: { + Previous: utils.MinTick, + Next: 20000, + }, + 20000: { + Previous: 10000, + Next: utils.MaxTick, + }, + utils.MaxTick: { + Previous: 20000, + Next: utils.MaxTick, + }, + }, + }, + { + name: "it should return correct data when there are more than 1 initialized tick (3 ticks)", + args: args{ + ticks: []Tick{ + { + Index: 10000, + LiquidityNet: big.NewInt(-100000), + LiquidityGross: big.NewInt(100000), + }, + { + Index: 20000, + LiquidityNet: big.NewInt(-200000), + LiquidityGross: big.NewInt(200000), + }, + { + Index: 30000, + LiquidityNet: big.NewInt(-300000), + LiquidityGross: big.NewInt(300000), + }, + }, + }, + wantTickData: map[int]TickData{ + 10000: { + LiquidityNet: big.NewInt(-100000), + LiquidityGross: big.NewInt(100000), + }, + 20000: { + LiquidityNet: big.NewInt(-200000), + LiquidityGross: big.NewInt(200000), + }, + 30000: { + LiquidityNet: big.NewInt(-300000), + LiquidityGross: big.NewInt(300000), + }, + }, + wantLinkedListData: map[int]LinkedListData{ + utils.MinTick: { + Previous: utils.MinTick, + Next: 10000, + }, + 10000: { + Previous: utils.MinTick, + Next: 20000, + }, + 20000: { + Previous: 10000, + Next: 30000, + }, + 30000: { + Previous: 20000, + Next: utils.MaxTick, + }, + utils.MaxTick: { + Previous: 30000, + Next: utils.MaxTick, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotTickData, gotLinkedListData := TransformToMap(tt.args.ticks) + assert.Equalf(t, tt.wantTickData, gotTickData, "TransformToMap(%v)", tt.args.ticks) + assert.Equalf(t, tt.wantLinkedListData, gotLinkedListData, "TransformToMap(%v)", tt.args.ticks) + }) + } +} diff --git a/entities/trade_test.go b/entities/trade_test.go index 548b7ad..403a9de 100644 --- a/entities/trade_test.go +++ b/entities/trade_test.go @@ -550,11 +550,11 @@ func TestMaximumAmountIn(t *testing.T) { // returns slippage amount if nonzero amountIn, _ = exactOut.MaximumAmountIn(entities.NewPercent(big.NewInt(0), big.NewInt(100)), nil) - assert.True(t, amountIn.EqualTo(entities.FromRawAmount(token1, big.NewInt(10014)).Fraction)) + assert.True(t, amountIn.EqualTo(entities.FromRawAmount(token1, big.NewInt(10011)).Fraction)) amountIn, _ = exactOut.MaximumAmountIn(entities.NewPercent(big.NewInt(5), big.NewInt(100)), nil) - assert.True(t, amountIn.EqualTo(entities.FromRawAmount(token0, big.NewInt(10514)).Fraction)) + assert.True(t, amountIn.EqualTo(entities.FromRawAmount(token0, big.NewInt(10511)).Fraction)) amountIn, _ = exactOut.MaximumAmountIn(entities.NewPercent(big.NewInt(200), big.NewInt(100)), nil) - assert.True(t, amountIn.EqualTo(entities.FromRawAmount(token0, big.NewInt(30042)).Fraction)) + assert.True(t, amountIn.EqualTo(entities.FromRawAmount(token0, big.NewInt(30033)).Fraction)) } func TestMinimumAmountOut(t *testing.T) { @@ -610,11 +610,11 @@ func TestBestTradeExactOut(t *testing.T) { assert.Equal(t, len(result), 2) assert.Equal(t, len(result[1].Swaps[0].Route.Pools), 1) assert.Equal(t, result[1].Swaps[0].Route.TokenPath, []*entities.Token{token0, token2}) - assert.True(t, result[1].InputAmount().EqualTo(entities.FromRawAmount(result[1].InputAmount().Currency, big.NewInt(12229)).Fraction)) + assert.True(t, result[1].InputAmount().EqualTo(entities.FromRawAmount(result[1].InputAmount().Currency, big.NewInt(12228)).Fraction)) assert.True(t, result[1].OutputAmount().EqualTo(entities.FromRawAmount(token2, big.NewInt(10000)).Fraction)) assert.Equal(t, len(result[0].Swaps[0].Route.Pools), 2) assert.Equal(t, result[0].Swaps[0].Route.TokenPath, []*entities.Token{token0, token1, token2}) - assert.True(t, result[0].InputAmount().EqualTo(entities.FromRawAmount(token0, big.NewInt(10014)).Fraction)) + assert.True(t, result[0].InputAmount().EqualTo(entities.FromRawAmount(token0, big.NewInt(10011)).Fraction)) assert.True(t, result[0].OutputAmount().EqualTo(entities.FromRawAmount(token2, big.NewInt(10000)).Fraction)) // respects maxHops