Skip to content

Commit

Permalink
feat(amm): add test case that valid the possible exploit of join pool… (
Browse files Browse the repository at this point in the history
#883)

* feat(amm): add test case that valid the possible exploit of join pool flow due to positive effect of weight breaking fee on rebalancing without new liquidities to the pool

* test(amm): fix test

* feat(amm): join pool weight recovery reward

* test(leveragelp): fix test

* Update x/amm/keeper/msg_server_join_pool_test.go

Co-authored-by: Amit Yadav <amy29981@gmail.com>

---------

Co-authored-by: Amit Yadav <amy29981@gmail.com>
  • Loading branch information
cosmic-vagabond and amityadav0 authored Oct 28, 2024
1 parent 9aec191 commit a793d42
Show file tree
Hide file tree
Showing 6 changed files with 226 additions and 24 deletions.
32 changes: 31 additions & 1 deletion x/amm/keeper/apply_join_pool_state_change.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,14 @@ import (
"github.com/elys-network/elys/x/amm/types"
)

func (k Keeper) applyJoinPoolStateChange(ctx sdk.Context, pool types.Pool, joiner sdk.AccAddress, numShares math.Int, joinCoins sdk.Coins) error {
func (k Keeper) ApplyJoinPoolStateChange(
ctx sdk.Context,
pool types.Pool,
joiner sdk.AccAddress,
numShares math.Int,
joinCoins sdk.Coins,
weightBalanceBonus sdk.Dec,
) error {
if err := k.bankKeeper.SendCoins(ctx, joiner, sdk.MustAccAddressFromBech32(pool.GetAddress()), joinCoins); err != nil {
return err
}
Expand All @@ -17,6 +24,29 @@ func (k Keeper) applyJoinPoolStateChange(ctx sdk.Context, pool types.Pool, joine

k.SetPool(ctx, pool)

rebalanceTreasuryAddr := sdk.MustAccAddressFromBech32(pool.GetRebalanceTreasury())

if weightBalanceBonus.IsPositive() {
// calculate treasury amounts to send as bonus
weightBalanceBonusCoins := PortionCoins(joinCoins, weightBalanceBonus)
for _, coin := range weightBalanceBonusCoins {
treasuryTokenAmount := k.bankKeeper.GetBalance(ctx, rebalanceTreasuryAddr, coin.Denom).Amount
if treasuryTokenAmount.LT(coin.Amount) {
// override coin amount by treasuryTokenAmount
weightBalanceBonusCoins = weightBalanceBonusCoins.
Sub(coin). // remove the original coin
Add(sdk.NewCoin(coin.Denom, treasuryTokenAmount)) // add the treasuryTokenAmount
}
}

// send bonus tokens to recipient if positive
if weightBalanceBonusCoins.IsAllPositive() {
if err := k.bankKeeper.SendCoins(ctx, rebalanceTreasuryAddr, joiner, weightBalanceBonusCoins); err != nil {
return err
}
}
}

types.EmitAddLiquidityEvent(ctx, joiner, pool.GetPoolId(), joinCoins)
if k.hooks != nil {
err := k.hooks.AfterJoinPool(ctx, joiner, pool, joinCoins, numShares)
Expand Down
8 changes: 4 additions & 4 deletions x/amm/keeper/keeper_join_pool_no_swap.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func (k Keeper) JoinPoolNoSwap(
}

snapshot := k.GetPoolSnapshotOrSet(ctx, pool)
sharesOut, _, _, err = pool.JoinPool(ctx, &snapshot, k.oracleKeeper, k.accountedPoolKeeper, tokensIn)
sharesOut, _, weightBalanceBonus, err := pool.JoinPool(ctx, &snapshot, k.oracleKeeper, k.accountedPoolKeeper, tokensIn)
if err != nil {
return nil, sdk.ZeroInt(), err
}
Expand All @@ -78,7 +78,7 @@ func (k Keeper) JoinPoolNoSwap(
shareOutAmount, sharesOut))
}

err = k.applyJoinPoolStateChange(ctx, pool, sender, sharesOut, tokensIn)
err = k.ApplyJoinPoolStateChange(ctx, pool, sender, sharesOut, tokensIn, weightBalanceBonus)
if err != nil {
return nil, math.Int{}, err
}
Expand All @@ -93,7 +93,7 @@ func (k Keeper) JoinPoolNoSwap(

// on oracle pool, full tokenInMaxs are used regardless shareOutAmount
snapshot := k.GetPoolSnapshotOrSet(ctx, pool)
sharesOut, _, _, err = pool.JoinPool(ctx, &snapshot, k.oracleKeeper, k.accountedPoolKeeper, tokenInMaxs)
sharesOut, _, weightBalanceBonus, err := pool.JoinPool(ctx, &snapshot, k.oracleKeeper, k.accountedPoolKeeper, tokenInMaxs)
if err != nil {
return nil, sdk.ZeroInt(), err
}
Expand All @@ -104,7 +104,7 @@ func (k Keeper) JoinPoolNoSwap(
shareOutAmount, sharesOut))
}

err = k.applyJoinPoolStateChange(ctx, pool, sender, sharesOut, tokenInMaxs)
err = k.ApplyJoinPoolStateChange(ctx, pool, sender, sharesOut, tokenInMaxs, weightBalanceBonus)
if err != nil {
return nil, math.Int{}, err
}
Expand Down
51 changes: 51 additions & 0 deletions x/amm/keeper/keeper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,41 @@ import (
"github.com/stretchr/testify/suite"
)

type assetPriceInfo struct {
denom string
display string
price sdk.Dec
}

const (
initChain = true
)

var (
priceMap = map[string]assetPriceInfo{
"uusdc": {
denom: ptypes.BaseCurrency,
display: "USDC",
price: sdk.OneDec(),
},
"uusdt": {
denom: "uusdt",
display: "USDT",
price: sdk.OneDec(),
},
"uelys": {
denom: ptypes.Elys,
display: "ELYS",
price: sdk.MustNewDecFromStr("3.0"),
},
"uatom": {
denom: ptypes.ATOM,
display: "ATOM",
price: sdk.MustNewDecFromStr("1.0"),
},
}
)

type KeeperTestSuite struct {
suite.Suite

Expand Down Expand Up @@ -80,6 +111,26 @@ func (suite *KeeperTestSuite) SetupStableCoinPrices() {
})
}

func (suite *KeeperTestSuite) SetupCoinPrices() {
// prices set for USDT and USDC
provider := sdk.AccAddress(ed25519.GenPrivKey().PubKey().Address())

for _, v := range priceMap {
suite.app.OracleKeeper.SetAssetInfo(suite.ctx, oracletypes.AssetInfo{
Denom: v.denom,
Display: v.display,
Decimal: 6,
})
suite.app.OracleKeeper.SetPrice(suite.ctx, oracletypes.Price{
Asset: v.display,
Price: v.price,
Source: "elys",
Provider: provider.String(),
Timestamp: uint64(suite.ctx.BlockTime().Unix()),
})
}
}

func SetupMockPools(k *keeper.Keeper, ctx sdk.Context) {
// Create and set mock pools
pools := []types.Pool{
Expand Down
118 changes: 117 additions & 1 deletion x/amm/keeper/msg_server_join_pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ func (suite *KeeperTestSuite) TestMsgServerJoinPool() {
FeeDenom: ptypes.BaseCurrency,
},
// shareOutAmount: sdk.NewInt(805987500000000000), // weight recovery direction - slippage enable
shareOutAmount: sdk.NewInt(1002500000000000000), // weight recovery direction - slippage disable
shareOutAmount: sdk.NewInt(1000000000000000000), // weight recovery direction - slippage disable
expSenderBalance: sdk.Coins{},
expTokenIn: sdk.Coins{sdk.NewInt64Coin("uusdt", 1000000)},
expPass: true,
Expand Down Expand Up @@ -215,3 +215,119 @@ func (suite *KeeperTestSuite) TestMsgServerJoinPool() {
})
}
}

func (suite *KeeperTestSuite) TestMsgServerJoinPoolExploitScenario() {
for _, tc := range []struct {
desc string
senderInitBalance sdk.Coins
poolInitBalance sdk.Coins
poolParams types.PoolParams
shareOutAmount math.Int
expSenderBalance sdk.Coins
expTotalLiquidity sdk.Coins
expTokenIn sdk.Coins
expPass bool
}{
{
desc: "Exploit scenario for Join Pool - unfair liquidity extraction",
senderInitBalance: sdk.Coins{sdk.NewInt64Coin(ptypes.ATOM, 100_000_000_000000)},
poolInitBalance: sdk.Coins{sdk.NewInt64Coin(ptypes.ATOM, 100_000_000_000000), sdk.NewInt64Coin(ptypes.BaseCurrency, 100_000_000_000000)},
poolParams: types.PoolParams{
SwapFee: sdk.ZeroDec(),
ExitFee: sdk.ZeroDec(),
UseOracle: true,
WeightBreakingFeeMultiplier: sdk.NewDecWithPrec(1, 2), // 0.01
WeightBreakingFeeExponent: sdk.NewDecWithPrec(25, 1), // 2.5
ExternalLiquidityRatio: sdk.NewDec(1),
WeightRecoveryFeePortion: sdk.NewDecWithPrec(10, 2), // 10%
ThresholdWeightDifference: sdk.NewDecWithPrec(2, 1), // 20%
WeightBreakingFeePortion: sdk.NewDecWithPrec(50, 2), // 50%
FeeDenom: ptypes.BaseCurrency,
},
shareOutAmount: sdk.NewInt(2_000000000000000000),
expSenderBalance: sdk.Coins{},
expTokenIn: sdk.Coins{sdk.NewInt64Coin(ptypes.ATOM, 1_000000)},
expPass: false,
},
} {
suite.Run(tc.desc, func() {
suite.SetupTest()
suite.SetupCoinPrices()

// Step 1: Bootstrap accounts
// Create sender, pool, and treasury accounts
sender := sdk.AccAddress(ed25519.GenPrivKey().PubKey().Address())
poolAddr := types.NewPoolAddress(1)
treasuryAddr := sdk.AccAddress(ed25519.GenPrivKey().PubKey().Address())

// Step 2: Bootstrap balances
err := suite.app.BankKeeper.MintCoins(suite.ctx, minttypes.ModuleName, tc.senderInitBalance)
suite.Require().NoError(err)
err = suite.app.BankKeeper.SendCoinsFromModuleToAccount(suite.ctx, minttypes.ModuleName, sender, tc.senderInitBalance)
suite.Require().NoError(err)
err = suite.app.BankKeeper.MintCoins(suite.ctx, minttypes.ModuleName, tc.poolInitBalance)
suite.Require().NoError(err)
err = suite.app.BankKeeper.SendCoinsFromModuleToAccount(suite.ctx, minttypes.ModuleName, poolAddr, tc.poolInitBalance)
suite.Require().NoError(err)

suite.app.AmmKeeper.SetDenomLiquidity(suite.ctx, types.DenomLiquidity{
Denom: tc.poolInitBalance[0].Denom,
Liquidity: tc.poolInitBalance[0].Amount,
})
suite.app.AmmKeeper.SetDenomLiquidity(suite.ctx, types.DenomLiquidity{
Denom: tc.poolInitBalance[1].Denom,
Liquidity: tc.poolInitBalance[1].Amount,
})

// Step 3: Setup initial pool with 50:50 weight
pool := types.Pool{
PoolId: 1,
Address: poolAddr.String(),
RebalanceTreasury: treasuryAddr.String(),
PoolParams: tc.poolParams,
TotalShares: sdk.NewCoin("amm/pool/1", sdk.NewInt(2).Mul(types.OneShare)),
PoolAssets: []types.PoolAsset{
{
Token: tc.poolInitBalance[0],
Weight: sdk.NewInt(1),
},
{
Token: tc.poolInitBalance[1],
Weight: sdk.NewInt(1),
},
},
TotalWeight: sdk.ZeroInt(),
}
suite.app.AmmKeeper.SetPool(suite.ctx, pool)

// Step 4: Simulate market price movement - adjust weights to 10:1
pool.PoolAssets[0].Weight = sdk.NewInt(10)
pool.PoolAssets[1].Weight = sdk.NewInt(1)
suite.app.AmmKeeper.SetPool(suite.ctx, pool)

// Step 5: New LP adds single-sided liquidity
msgServer := keeper.NewMsgServerImpl(suite.app.AmmKeeper)
resp, err := msgServer.JoinPool(
sdk.WrapSDKContext(suite.ctx),
&types.MsgJoinPool{
Sender: sender.String(),
PoolId: 1,
MaxAmountsIn: tc.senderInitBalance,
ShareAmountOut: tc.shareOutAmount,
})

suite.Require().NoError(err)

// Step 6: Validate if exploit was successful (It should fail)
// Calculate expected number of shares without weight balance bonus
totalShares := pool.TotalShares.Amount
joinValueWithoutSlippage, _ := pool.CalcJoinValueWithoutSlippage(suite.ctx, suite.app.OracleKeeper, suite.app.AccountedPoolKeeper, tc.senderInitBalance)
tvl, _ := pool.TVL(suite.ctx, suite.app.OracleKeeper, suite.app.AccountedPoolKeeper)
expectedNumShares := totalShares.ToLegacyDec().
Mul(joinValueWithoutSlippage).Quo(tvl).RoundInt()

// Number of shares must be lesser or equal to expected
suite.Require().GreaterOrEqual(expectedNumShares.String(), resp.ShareAmountOut.String(), "Exploit detected: Sender received more shares than expected")
})
}
}
34 changes: 18 additions & 16 deletions x/amm/types/pool_join_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,32 +183,34 @@ func (p *Pool) JoinPool(
return sdk.ZeroInt(), sdk.ZeroDec(), sdk.ZeroDec(), err
}
weightDistance := p.WeightDistanceFromTarget(ctx, oracleKeeper, accountedPoolKeeper, newAssetPools)

distanceDiff := weightDistance.Sub(initialWeightDistance)

weightBreakingFee := sdk.ZeroDec()
if distanceDiff.IsPositive() {
// we only allow
tokenInDenom := tokensIn[0].Denom
// target weight
targetWeightIn := GetDenomNormalizedWeight(p.PoolAssets, tokenInDenom)
targetWeightOut := sdk.OneDec().Sub(targetWeightIn)

// weight breaking fee as in Plasma pool
weightIn := GetDenomOracleAssetWeight(ctx, p.PoolId, oracleKeeper, accountedPoolKeeper, newAssetPools, tokenInDenom)
weightOut := sdk.OneDec().Sub(weightIn)
weightBreakingFee = GetWeightBreakingFee(weightIn, weightOut, targetWeightIn, targetWeightOut, p.PoolParams, distanceDiff)
}
// we only allow
tokenInDenom := tokensIn[0].Denom
// target weight
targetWeightIn := GetDenomNormalizedWeight(p.PoolAssets, tokenInDenom)
targetWeightOut := sdk.OneDec().Sub(targetWeightIn)

// weight breaking fee as in Plasma pool
weightIn := GetDenomOracleAssetWeight(ctx, p.PoolId, oracleKeeper, accountedPoolKeeper, newAssetPools, tokenInDenom)
weightOut := sdk.OneDec().Sub(weightIn)
weightBreakingFee := GetWeightBreakingFee(weightIn, weightOut, targetWeightIn, targetWeightOut, p.PoolParams, distanceDiff)

// weight recovery reward = weight breaking fee * weight recovery fee portion
weightRecoveryReward := weightBreakingFee.Mul(p.PoolParams.WeightRecoveryFeePortion)

// bonus is valid when distance is lower than original distance and when threshold weight reached
weightBalanceBonus = weightBreakingFee.Neg()
if initialWeightDistance.GT(p.PoolParams.ThresholdWeightDifference) && distanceDiff.IsNegative() {
weightBalanceBonus = p.PoolParams.WeightBreakingFeeMultiplier.Mul(distanceDiff).Abs()
weightBalanceBonus = weightRecoveryReward
// set weight breaking fee to zero if bonus is applied
weightBreakingFee = sdk.ZeroDec()
}

totalShares := p.GetTotalShares()
numSharesDec := sdk.NewDecFromInt(totalShares.Amount).
Mul(joinValueWithoutSlippage).Quo(tvl).
Mul(sdk.OneDec().Add(weightBalanceBonus))
Mul(sdk.OneDec().Sub(weightBreakingFee))
numShares = numSharesDec.RoundInt()
err = p.IncreaseLiquidity(numShares, tokensIn)
if err != nil {
Expand Down
7 changes: 5 additions & 2 deletions x/leveragelp/keeper/msg_server_close_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package keeper_test

import (
"time"

"cosmossdk.io/math"
sdk "github.com/cosmos/cosmos-sdk/types"
minttypes "github.com/cosmos/cosmos-sdk/x/mint/types"
Expand All @@ -10,7 +12,6 @@ import (
ptypes "github.com/elys-network/elys/x/parameter/types"
stablekeeper "github.com/elys-network/elys/x/stablestake/keeper"
stabletypes "github.com/elys-network/elys/x/stablestake/types"
"time"
)

func initializeForClose(suite *KeeperTestSuite, addresses []sdk.AccAddress, asset1, asset2 string) {
Expand Down Expand Up @@ -265,7 +266,9 @@ func (suite *KeeperTestSuite) TestClose() {
},
func() {
position, _ := suite.app.LeveragelpKeeper.GetPosition(suite.ctx, addresses[0], 1)
suite.Require().Equal(position.LeveragedLpAmount, leverageLPShares.QuoRaw(2))
actualShares, ok := sdk.NewIntFromString("9999952380952380950")
suite.Require().True(ok)
suite.Require().Equal(position.LeveragedLpAmount, actualShares)
},
},
{"Closing whole position",
Expand Down

0 comments on commit a793d42

Please sign in to comment.