From d132628f7c7e2663b22bdd3978d4d68a937c28eb Mon Sep 17 00:00:00 2001 From: jelysn <129082781+jelysn@users.noreply.github.com> Date: Fri, 1 Sep 2023 15:47:35 +0800 Subject: [PATCH] add slippage on join/exit pool - oracle pool (#179) * basic code to add slippage on join/exit pool * Resolve unit test on add/remove liquidity on oracle pool, and add more case on join pool * Resolve comments on swap txs sorting mechanism PR --- x/amm/keeper/abci.go | 66 +----------- x/amm/keeper/abci_test.go | 4 +- x/amm/keeper/batch_processing.go | 6 +- x/amm/keeper/msg_server_exit_pool_test.go | 8 +- x/amm/keeper/msg_server_join_pool_test.go | 25 ++++- x/amm/types/calc_exit_pool.go | 66 +++++++++++- x/amm/types/errors.go | 3 +- x/amm/types/key_batch_txs.go | 16 ++- x/amm/types/pool_join_pool_no_swap.go | 120 ++++++++++++++++++++-- 9 files changed, 215 insertions(+), 99 deletions(-) diff --git a/x/amm/keeper/abci.go b/x/amm/keeper/abci.go index c89b46b91..48b49f084 100644 --- a/x/amm/keeper/abci.go +++ b/x/amm/keeper/abci.go @@ -1,7 +1,6 @@ package keeper import ( - "fmt" "strings" "time" @@ -22,9 +21,8 @@ func (k Keeper) GetStackedSlippage(ctx sdk.Context, poolId uint64) sdk.Dec { } func (k Keeper) ApplySwapRequest(ctx sdk.Context, msg sdk.Msg) error { - switch msg.(type) { + switch msg := msg.(type) { case *types.MsgSwapExactAmountIn: - msg := msg.(*types.MsgSwapExactAmountIn) sender, err := sdk.AccAddressFromBech32(msg.Sender) if err != nil { return err @@ -35,7 +33,6 @@ func (k Keeper) ApplySwapRequest(ctx sdk.Context, msg sdk.Msg) error { } return nil case *types.MsgSwapExactAmountOut: - msg := msg.(*types.MsgSwapExactAmountOut) sender, err := sdk.AccAddressFromBech32(msg.Sender) if err != nil { return err @@ -46,17 +43,15 @@ func (k Keeper) ApplySwapRequest(ctx sdk.Context, msg sdk.Msg) error { } return nil default: - return fmt.Errorf("unexpected swap message") + return types.ErrInvalidSwapMsgType } } func (k Keeper) DeleteSwapRequest(ctx sdk.Context, msg sdk.Msg, index uint64) { - switch msg.(type) { + switch msg := msg.(type) { case *types.MsgSwapExactAmountIn: - msg := msg.(*types.MsgSwapExactAmountIn) k.DeleteSwapExactAmountInRequest(ctx, msg, index) case *types.MsgSwapExactAmountOut: - msg := msg.(*types.MsgSwapExactAmountOut) k.DeleteSwapExactAmountOutRequest(ctx, msg, index) } } @@ -72,12 +67,10 @@ func (k Keeper) SelectOneSwapRequest(ctx sdk.Context, sprefix []byte) (sdk.Msg, func (k Keeper) SelectReverseSwapRequest(ctx sdk.Context, msg sdk.Msg) (sdk.Msg, uint64) { sprefix := []byte{} - switch msg.(type) { + switch msg := msg.(type) { case *types.MsgSwapExactAmountIn: - msg := msg.(*types.MsgSwapExactAmountIn) sprefix = types.TKeyPrefixSwapExactAmountInPrefix(msg) case *types.MsgSwapExactAmountOut: - msg := msg.(*types.MsgSwapExactAmountOut) sprefix = types.TKeyPrefixSwapExactAmountOutPrefix(msg) } @@ -90,12 +83,10 @@ func (k Keeper) SelectReverseSwapRequest(ctx sdk.Context, msg sdk.Msg) (sdk.Msg, } func (k Keeper) FirstPoolId(msg sdk.Msg) uint64 { - switch msg.(type) { + switch msg := msg.(type) { case *types.MsgSwapExactAmountIn: - msg := msg.(*types.MsgSwapExactAmountIn) return types.FirstPoolIdFromSwapExactAmountIn(msg) case *types.MsgSwapExactAmountOut: - msg := msg.(*types.MsgSwapExactAmountOut) return types.FirstPoolIdFromSwapExactAmountOut(msg) } return 0 @@ -176,51 +167,4 @@ func (k Keeper) EndBlocker(ctx sdk.Context) { defer telemetry.ModuleMeasureSince(types.ModuleName, time.Now(), telemetry.MetricKeyBeginBlocker) k.ExecuteSwapRequests(ctx) - // swapInRequests := k.GetAllSwapExactAmountInRequests(ctx) - // for _, msg := range swapInRequests { - // sender, err := sdk.AccAddressFromBech32(msg.Sender) - // if err != nil { - // continue - // } - - // cacheCtx, write := ctx.CacheContext() - // _, err = k.RouteExactAmountIn(cacheCtx, sender, msg.Routes, msg.TokenIn, math.Int(msg.TokenOutMinAmount)) - // if err != nil { - // continue - // } - // write() - - // // Swap event is handled elsewhere - // ctx.EventManager().EmitEvents(sdk.Events{ - // sdk.NewEvent( - // sdk.EventTypeMessage, - // sdk.NewAttribute(sdk.AttributeKeyModule, types.AttributeValueCategory), - // sdk.NewAttribute(sdk.AttributeKeySender, msg.Sender), - // ), - // }) - - // } - // swapOutRequests := k.GetAllSwapExactAmountOutRequests(ctx) - // for _, msg := range swapOutRequests { - // sender, err := sdk.AccAddressFromBech32(msg.Sender) - // if err != nil { - // continue - // } - - // cacheCtx, write := ctx.CacheContext() - // _, err = k.RouteExactAmountOut(cacheCtx, sender, msg.Routes, msg.TokenInMaxAmount, msg.TokenOut) - // if err != nil { - // continue - // } - // write() - - // // Swap event is handled elsewhere - // ctx.EventManager().EmitEvents(sdk.Events{ - // sdk.NewEvent( - // sdk.EventTypeMessage, - // sdk.NewAttribute(sdk.AttributeKeyModule, types.AttributeValueCategory), - // sdk.NewAttribute(sdk.AttributeKeySender, msg.Sender), - // ), - // }) - // } } diff --git a/x/amm/keeper/abci_test.go b/x/amm/keeper/abci_test.go index 777613b73..6f5b978fc 100644 --- a/x/amm/keeper/abci_test.go +++ b/x/amm/keeper/abci_test.go @@ -276,16 +276,14 @@ func (suite *KeeperTestSuite) TestExecuteSwapRequests() { msgServer := keeper.NewMsgServerImpl(suite.app.AmmKeeper) for _, msg := range tc.swapMsgs { - switch msg.(type) { + switch msg := msg.(type) { case *types.MsgSwapExactAmountIn: - msg := msg.(*types.MsgSwapExactAmountIn) _, err := msgServer.SwapExactAmountIn( sdk.WrapSDKContext(suite.ctx), msg, ) suite.Require().NoError(err) case *types.MsgSwapExactAmountOut: - msg := msg.(*types.MsgSwapExactAmountOut) _, err := msgServer.SwapExactAmountOut( sdk.WrapSDKContext(suite.ctx), msg, diff --git a/x/amm/keeper/batch_processing.go b/x/amm/keeper/batch_processing.go index efd32a263..7ce70cdeb 100644 --- a/x/amm/keeper/batch_processing.go +++ b/x/amm/keeper/batch_processing.go @@ -21,11 +21,10 @@ func (k Keeper) GetLastSwapRequestIndex(ctx sdk.Context) uint64 { } // SetSwapExactAmountInRequests stores swap exact amount in request -func (k Keeper) SetSwapExactAmountInRequests(ctx sdk.Context, msg *types.MsgSwapExactAmountIn, index uint64) error { +func (k Keeper) SetSwapExactAmountInRequests(ctx sdk.Context, msg *types.MsgSwapExactAmountIn, index uint64) { store := prefix.NewStore(ctx.TransientStore(k.transientStoreKey), types.KeyPrefix(types.TSwapExactAmountInKey)) b := k.cdc.MustMarshal(msg) store.Set(types.TKeyPrefixSwapExactAmountIn(msg, index), b) - return nil } // DeleteSwapExactAmountInRequest removes a swap exact amount in request @@ -65,11 +64,10 @@ func (k Keeper) GetFirstSwapExactAmountInRequest(ctx sdk.Context, sprefix []byte } // SetSwapExactAmountInRequests stores swap exact amount out request -func (k Keeper) SetSwapExactAmountOutRequests(ctx sdk.Context, msg *types.MsgSwapExactAmountOut, index uint64) error { +func (k Keeper) SetSwapExactAmountOutRequests(ctx sdk.Context, msg *types.MsgSwapExactAmountOut, index uint64) { store := prefix.NewStore(ctx.TransientStore(k.transientStoreKey), types.KeyPrefix(types.TSwapExactAmountOutKey)) b := k.cdc.MustMarshal(msg) store.Set(types.TKeyPrefixSwapExactAmountOut(msg, index), b) - return nil } // DeleteSwapExactAmountOutRequest deletes a swap exact amount out request diff --git a/x/amm/keeper/msg_server_exit_pool_test.go b/x/amm/keeper/msg_server_exit_pool_test.go index 09eda173d..343eeda76 100644 --- a/x/amm/keeper/msg_server_exit_pool_test.go +++ b/x/amm/keeper/msg_server_exit_pool_test.go @@ -79,8 +79,8 @@ func (suite *KeeperTestSuite) TestMsgServerExitPool() { }, shareInAmount: types.OneShare.Quo(sdk.NewInt(10)), tokenOutDenom: "uusdt", - minAmountsOut: sdk.Coins{sdk.NewInt64Coin("uusdt", 97368)}, - expSenderBalance: sdk.Coins{sdk.NewInt64Coin("uusdt", 97368)}, + minAmountsOut: sdk.Coins{sdk.NewInt64Coin("uusdt", 95114)}, + expSenderBalance: sdk.Coins{sdk.NewInt64Coin("uusdt", 95114)}, expPass: true, }, { @@ -100,8 +100,8 @@ func (suite *KeeperTestSuite) TestMsgServerExitPool() { }, shareInAmount: types.OneShare.Quo(sdk.NewInt(10)), tokenOutDenom: "uusdc", - minAmountsOut: sdk.Coins{sdk.NewInt64Coin("uusdc", 100000)}, - expSenderBalance: sdk.Coins{sdk.NewInt64Coin("uusdc", 100000)}, + minAmountsOut: sdk.Coins{sdk.NewInt64Coin("uusdc", 99197)}, + expSenderBalance: sdk.Coins{sdk.NewInt64Coin("uusdc", 99197)}, expPass: true, }, } { diff --git a/x/amm/keeper/msg_server_join_pool_test.go b/x/amm/keeper/msg_server_join_pool_test.go index 42abae6f4..498189187 100644 --- a/x/amm/keeper/msg_server_join_pool_test.go +++ b/x/amm/keeper/msg_server_join_pool_test.go @@ -78,7 +78,7 @@ func (suite *KeeperTestSuite) TestMsgServerJoinPool() { ThresholdWeightDifference: sdk.NewDecWithPrec(2, 1), // 20% FeeDenom: "uusdc", }, - shareOutAmount: sdk.NewInt(833333333333333333), // weight breaking fee + shareOutAmount: sdk.NewInt(694444166666666666), // weight breaking fee expSenderBalance: sdk.Coins{}, expTokenIn: sdk.Coins{sdk.NewInt64Coin("uusdt", 1000000)}, expPass: true, @@ -99,11 +99,32 @@ func (suite *KeeperTestSuite) TestMsgServerJoinPool() { ThresholdWeightDifference: sdk.NewDecWithPrec(2, 1), // 20% FeeDenom: "uusdc", }, - shareOutAmount: sdk.NewInt(1250000000000000000), // weight breaking fee + shareOutAmount: sdk.NewInt(805987500000000000), // weight recovery direction expSenderBalance: sdk.Coins{}, expTokenIn: sdk.Coins{sdk.NewInt64Coin("uusdt", 1000000)}, expPass: true, }, + { + desc: "oracle pool join - zero slippage add liquidity", + senderInitBalance: sdk.Coins{sdk.NewInt64Coin("uusdc", 1500000), sdk.NewInt64Coin("uusdt", 500000)}, + poolInitBalance: sdk.Coins{sdk.NewInt64Coin("uusdc", 1500000), sdk.NewInt64Coin("uusdt", 500000)}, + poolParams: types.PoolParams{ + SwapFee: sdk.ZeroDec(), + ExitFee: sdk.ZeroDec(), + UseOracle: true, + WeightBreakingFeeMultiplier: sdk.NewDecWithPrec(1, 0), // 1.00 + ExternalLiquidityRatio: sdk.NewDec(1), + LpFeePortion: sdk.ZeroDec(), + StakingFeePortion: sdk.ZeroDec(), + WeightRecoveryFeePortion: sdk.ZeroDec(), + ThresholdWeightDifference: sdk.NewDecWithPrec(2, 1), // 20% + FeeDenom: "uusdc", + }, + shareOutAmount: sdk.NewInt(2000000000000000000), + expSenderBalance: sdk.Coins{}, + expTokenIn: sdk.Coins{sdk.NewInt64Coin("uusdc", 1500000), sdk.NewInt64Coin("uusdt", 500000)}, + expPass: true, + }, } { suite.Run(tc.desc, func() { suite.SetupTest() diff --git a/x/amm/types/calc_exit_pool.go b/x/amm/types/calc_exit_pool.go index 4e8c73947..e470e0ddf 100644 --- a/x/amm/types/calc_exit_pool.go +++ b/x/amm/types/calc_exit_pool.go @@ -8,6 +8,66 @@ import ( sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" ) +func CalcExitValueWithoutSlippage(ctx sdk.Context, oracleKeeper OracleKeeper, pool Pool, exitingShares sdk.Int, tokenOutDenom string) (sdk.Dec, error) { + tvl, err := pool.TVL(ctx, oracleKeeper) + if err != nil { + return sdk.ZeroDec(), err + } + + totalShares := pool.GetTotalShares() + var refundedShares sdk.Dec + refundedShares = sdk.NewDecFromInt(exitingShares) + exitValue := tvl.Mul(refundedShares).Quo(sdk.NewDecFromInt(totalShares.Amount)) + + if exitingShares.GTE(totalShares.Amount) { + return sdk.ZeroDec(), sdkerrors.Wrapf(ErrLimitMaxAmount, ErrMsgFormatSharesLargerThanMax, exitingShares, totalShares) + } + + shareOutRatio := refundedShares.QuoInt(totalShares.Amount) + // exitedCoins = shareOutRatio * pool liquidity + exitedCoins := sdk.Coins{} + poolLiquidity := pool.GetTotalPoolLiquidity() + + for _, asset := range poolLiquidity { + // round down here, due to not wanting to over-exit + exitAmt := shareOutRatio.MulInt(asset.Amount).TruncateInt() + if exitAmt.LTE(sdk.ZeroInt()) { + continue + } + if exitAmt.GTE(asset.Amount) { + return sdk.ZeroDec(), errors.New("too many shares out") + } + exitedCoins = exitedCoins.Add(sdk.NewCoin(asset.Denom, exitAmt)) + } + + slippageValue := sdk.ZeroDec() + for _, exitedCoin := range exitedCoins { + if exitedCoin.Denom == tokenOutDenom { + continue + } + inTokenPrice := oracleKeeper.GetAssetPriceFromDenom(ctx, exitedCoin.Denom) + if inTokenPrice.IsZero() { + return sdk.ZeroDec(), fmt.Errorf("token price not set: %s", exitedCoin.Denom) + } + resizedAmount := sdk.NewDecFromInt(exitedCoin.Amount). + Quo(pool.PoolParams.ExternalLiquidityRatio).RoundInt() + slippageAmount, err := pool.CalcGivenInSlippage( + ctx, + oracleKeeper, + &pool, + sdk.Coins{sdk.NewCoin(exitedCoin.Denom, resizedAmount)}, + tokenOutDenom, + ) + if err != nil { + return sdk.ZeroDec(), err + } + + slippageValue = slippageValue.Add(slippageAmount.Mul(inTokenPrice)) + } + exitValueWithoutSlippage := exitValue.Sub(slippageValue) + return exitValueWithoutSlippage, nil +} + // CalcExitPool returns how many tokens should come out, when exiting k LP shares against a "standard" CFMM func CalcExitPool(ctx sdk.Context, oracleKeeper OracleKeeper, pool Pool, exitingShares sdk.Int, tokenOutDenom string) (sdk.Coins, error) { totalShares := pool.GetTotalShares() @@ -27,13 +87,13 @@ func CalcExitPool(ctx sdk.Context, oracleKeeper OracleKeeper, pool Pool, exiting if pool.PoolParams.UseOracle && tokenOutDenom != "" { initialWeightDistance := pool.WeightDistanceFromTarget(ctx, oracleKeeper, pool.PoolAssets) - tvl, err := pool.TVL(ctx, oracleKeeper) + tokenPrice := oracleKeeper.GetAssetPriceFromDenom(ctx, tokenOutDenom) + exitValueWithoutSlippage, err := CalcExitValueWithoutSlippage(ctx, oracleKeeper, pool, exitingShares, tokenOutDenom) if err != nil { return sdk.Coins{}, err } - tokenPrice := oracleKeeper.GetAssetPriceFromDenom(ctx, tokenOutDenom) - oracleOutAmount := tvl.Mul(refundedShares).Quo(sdk.NewDecFromInt(totalShares.Amount)).Quo(tokenPrice) + oracleOutAmount := exitValueWithoutSlippage.Quo(tokenPrice) newAssetPools, err := pool.NewPoolAssetsAfterSwap( sdk.Coins{}, diff --git a/x/amm/types/errors.go b/x/amm/types/errors.go index 3e7b9b4ab..042cc833d 100644 --- a/x/amm/types/errors.go +++ b/x/amm/types/errors.go @@ -22,7 +22,8 @@ var ( ErrTooManyTokensOut = sdkerrors.Register(ModuleName, 31, "tx is trying to get more tokens out of the pool than exist") - ErrInvalidPoolId = sdkerrors.Register(ModuleName, 91, "invalid pool id") + ErrInvalidPoolId = sdkerrors.Register(ModuleName, 91, "invalid pool id") + ErrInvalidSwapMsgType = sdkerrors.Register(ModuleName, 92, "unexpected swap message type") ) const ( diff --git a/x/amm/types/key_batch_txs.go b/x/amm/types/key_batch_txs.go index e34d0c9db..0777d0a53 100644 --- a/x/amm/types/key_batch_txs.go +++ b/x/amm/types/key_batch_txs.go @@ -1,15 +1,12 @@ package types import ( - "encoding/binary" fmt "fmt" "strings" sdk "github.com/cosmos/cosmos-sdk/types" ) -var _ binary.ByteOrder - const ( TLastSwapRequestIndex = "last-swap-request-index" TSwapExactAmountInKey = "batch/swap-exact-amount-in" @@ -17,7 +14,7 @@ const ( ) func TKeyPrefixSwapExactAmountInPrefix(m *MsgSwapExactAmountIn) []byte { - prefix := []byte(m.TokenIn.Denom + "/") + prefix := []byte(fmt.Sprintf("%s/", m.TokenIn.Denom)) routeKeys := []string{} for _, route := range m.Routes[:1] { routeKeys = append(routeKeys, fmt.Sprintf("%d/%s", route.PoolId, route.TokenOutDenom)) @@ -27,8 +24,8 @@ func TKeyPrefixSwapExactAmountInPrefix(m *MsgSwapExactAmountIn) []byte { } func FirstPoolIdFromSwapExactAmountIn(m *MsgSwapExactAmountIn) uint64 { - for _, route := range m.Routes { - return route.PoolId + if len(m.Routes) > 0 { + return m.Routes[0].PoolId } return 0 } @@ -39,12 +36,11 @@ func TKeyPrefixSwapExactAmountIn(m *MsgSwapExactAmountIn, index uint64) []byte { } func TKeyPrefixSwapExactAmountOutPrefix(m *MsgSwapExactAmountOut) []byte { - prefix := []byte("/" + m.TokenOut.Denom) + prefix := []byte(fmt.Sprintf("/%s", m.TokenOut.Denom)) routeKeys := []string{} - for i := len(m.Routes) - 1; i >= 0; i-- { - route := m.Routes[i] + if len(m.Routes) > 0 { + route := m.Routes[len(m.Routes)-1] routeKeys = append(routeKeys, fmt.Sprintf("%s/%d", route.TokenInDenom, route.PoolId)) - break } prefix = append([]byte(strings.Join(routeKeys, "/")), prefix...) return prefix diff --git a/x/amm/types/pool_join_pool_no_swap.go b/x/amm/types/pool_join_pool_no_swap.go index 5f72fcf1a..0fc951cf1 100644 --- a/x/amm/types/pool_join_pool_no_swap.go +++ b/x/amm/types/pool_join_pool_no_swap.go @@ -7,6 +7,109 @@ import ( sdk "github.com/cosmos/cosmos-sdk/types" ) +type PoolAssetUSDValue struct { + Asset string + Value sdk.Dec +} + +type InternalSwapRequest struct { + InAmount sdk.Coin + OutToken string +} + +func (p *Pool) CalcJoinValueWithoutSlippage(ctx sdk.Context, oracleKeeper OracleKeeper, tokensIn sdk.Coins) (math.LegacyDec, error) { + joinValue := sdk.ZeroDec() + for _, asset := range tokensIn { + tokenPrice := oracleKeeper.GetAssetPriceFromDenom(ctx, asset.Denom) + if tokenPrice.IsZero() { + return sdk.ZeroDec(), fmt.Errorf("token price not set: %s", asset.Denom) + } + v := tokenPrice.Mul(sdk.NewDecFromInt(asset.Amount)) + joinValue = joinValue.Add(v) + } + + // weights := NormalizedWeights(p.PoolAssets) + weights, err := OraclePoolNormalizedWeights(ctx, oracleKeeper, p.PoolAssets) + if err != nil { + return sdk.ZeroDec(), err + } + + inAmounts := []PoolAssetUSDValue{} + outAmounts := []PoolAssetUSDValue{} + + for _, weight := range weights { + targetAmount := joinValue.Mul(weight.Weight) + tokenPrice := oracleKeeper.GetAssetPriceFromDenom(ctx, weight.Asset) + if tokenPrice.IsZero() { + return sdk.ZeroDec(), fmt.Errorf("token price not set: %s", weight.Asset) + } + inAmount := tokenPrice.Mul(sdk.NewDecFromInt(tokensIn.AmountOf(weight.Asset))) + if targetAmount.GT(inAmount) { + outAmounts = append(outAmounts, PoolAssetUSDValue{ + Asset: weight.Asset, + Value: targetAmount.Sub(inAmount), + }) + } + + if targetAmount.LT(inAmount) { + inAmounts = append(inAmounts, PoolAssetUSDValue{ + Asset: weight.Asset, + Value: inAmount.Sub(targetAmount), + }) + } + } + + internalSwapRequests := []InternalSwapRequest{} + for i, j := 0, 0; i < len(inAmounts) && j < len(outAmounts); { + inTokenPrice := oracleKeeper.GetAssetPriceFromDenom(ctx, inAmounts[i].Asset) + if inTokenPrice.IsZero() { + return sdk.ZeroDec(), fmt.Errorf("token price not set: %s", inAmounts[i].Asset) + } + inAsset := inAmounts[i].Asset + outAsset := outAmounts[j].Asset + inAmount := sdk.ZeroInt() + if inAmounts[i].Value.GT(outAmounts[j].Value) { + inAmount = outAmounts[j].Value.Quo(inTokenPrice).RoundInt() + j++ + } else if inAmounts[i].Value.LT(outAmounts[j].Value) { + inAmount = inAmounts[i].Value.Quo(inTokenPrice).RoundInt() + i++ + } else { + inAmount = inAmounts[i].Value.Quo(inTokenPrice).RoundInt() + i++ + j++ + } + internalSwapRequests = append(internalSwapRequests, InternalSwapRequest{ + InAmount: sdk.NewCoin(inAsset, inAmount), + OutToken: outAsset, + }) + } + + slippageValue := sdk.ZeroDec() + for _, req := range internalSwapRequests { + inTokenPrice := oracleKeeper.GetAssetPriceFromDenom(ctx, req.InAmount.Denom) + if inTokenPrice.IsZero() { + return sdk.ZeroDec(), fmt.Errorf("token price not set: %s", req.InAmount.Denom) + } + resizedAmount := sdk.NewDecFromInt(req.InAmount.Amount). + Quo(p.PoolParams.ExternalLiquidityRatio).RoundInt() + slippageAmount, err := p.CalcGivenInSlippage( + ctx, + oracleKeeper, + p, + sdk.Coins{sdk.NewCoin(req.InAmount.Denom, resizedAmount)}, + req.OutToken, + ) + if err != nil { + return sdk.ZeroDec(), err + } + + slippageValue = slippageValue.Add(slippageAmount.Mul(inTokenPrice)) + } + joinValueWithoutSlippage := joinValue.Sub(slippageValue) + return joinValueWithoutSlippage, nil +} + // JoinPoolNoSwap calculates the number of shares needed for an all-asset join given tokensIn with swapFee applied. // It updates the liquidity if the pool is joined successfully. If not, returns error. func (p *Pool) JoinPoolNoSwap(ctx sdk.Context, oracleKeeper OracleKeeper, tokensIn sdk.Coins) (numShares math.Int, err error) { @@ -21,20 +124,15 @@ func (p *Pool) JoinPoolNoSwap(ctx sdk.Context, oracleKeeper OracleKeeper, tokens return numShares, nil } - initialWeightDistance := p.WeightDistanceFromTarget(ctx, oracleKeeper, p.PoolAssets) - tvl, err := p.TVL(ctx, oracleKeeper) + joinValueWithoutSlippage, err := p.CalcJoinValueWithoutSlippage(ctx, oracleKeeper, tokensIn) if err != nil { return sdk.ZeroInt(), err } - joinValue := sdk.ZeroDec() - for _, asset := range tokensIn { - tokenPrice := oracleKeeper.GetAssetPriceFromDenom(ctx, asset.Denom) - if tokenPrice.IsZero() { - return sdk.ZeroInt(), fmt.Errorf("token price not set: %s", asset.Denom) - } - v := tokenPrice.Mul(sdk.NewDecFromInt(asset.Amount)) - joinValue = joinValue.Add(v) + initialWeightDistance := p.WeightDistanceFromTarget(ctx, oracleKeeper, p.PoolAssets) + tvl, err := p.TVL(ctx, oracleKeeper) + if err != nil { + return sdk.ZeroInt(), err } newAssetPools, err := p.NewPoolAssetsAfterSwap( @@ -58,7 +156,7 @@ func (p *Pool) JoinPoolNoSwap(ctx sdk.Context, oracleKeeper OracleKeeper, tokens totalShares := p.GetTotalShares() numSharesDec := sdk.NewDecFromInt(totalShares.Amount). - Mul(joinValue).Quo(tvl). + Mul(joinValueWithoutSlippage).Quo(tvl). Mul(sdk.OneDec().Add(weightBalanceBonus).Sub(weightBreakingFee)) numShares = numSharesDec.RoundInt() p.IncreaseLiquidity(numShares, tokensIn)