diff --git a/app/app.go b/app/app.go index aef89c38e..371f0f4b4 100644 --- a/app/app.go +++ b/app/app.go @@ -675,6 +675,15 @@ func NewElysApp( app.AssetprofileKeeper, ) + app.AccountedPoolKeeper = *accountedpoolmodulekeeper.NewKeeper( + appCodec, + keys[accountedpoolmoduletypes.StoreKey], + keys[accountedpoolmoduletypes.MemStoreKey], + app.GetSubspace(accountedpoolmoduletypes.ModuleName), + app.BankKeeper, + ) + accountedPoolModule := accountedpoolmodule.NewAppModule(appCodec, app.AccountedPoolKeeper, app.AccountKeeper, app.BankKeeper) + app.AmmKeeper = *ammmodulekeeper.NewKeeper( appCodec, keys[ammmoduletypes.StoreKey], @@ -700,7 +709,6 @@ func NewElysApp( app.BankKeeper, app.AmmKeeper, app.OracleKeeper, - app.AccountedPoolKeeper, authtypes.FeeCollectorName, DexRevenueCollectorName, ) @@ -826,16 +834,6 @@ func NewElysApp( ) marginModule := marginmodule.NewAppModule(appCodec, app.MarginKeeper, app.AccountKeeper, app.BankKeeper) - app.AccountedPoolKeeper = *accountedpoolmodulekeeper.NewKeeper( - appCodec, - keys[accountedpoolmoduletypes.StoreKey], - keys[accountedpoolmoduletypes.MemStoreKey], - app.GetSubspace(accountedpoolmoduletypes.ModuleName), - app.MarginKeeper, - app.BankKeeper, - ) - accountedPoolModule := accountedpoolmodule.NewAppModule(appCodec, app.AccountedPoolKeeper, app.AccountKeeper, app.BankKeeper) - // this line is used by starport scaffolding # stargate/app/keeperDefinition /**** IBC Routing ****/ @@ -879,7 +877,7 @@ func NewElysApp( ammmoduletypes.NewMultiAmmHooks( // insert amm hooks receivers here app.IncentiveKeeper.AmmHooks(), - app.AccountedPoolKeeper.AmmHooks(), + app.MarginKeeper.AmmHooks(), ), ) diff --git a/testutil/keeper/accountedpool.go b/testutil/keeper/accountedpool.go index 5726a5094..a6a65dd4f 100644 --- a/testutil/keeper/accountedpool.go +++ b/testutil/keeper/accountedpool.go @@ -42,7 +42,6 @@ func AccountedPoolKeeper(t testing.TB) (*keeper.Keeper, sdk.Context) { memStoreKey, paramsSubspace, nil, - nil, ) ctx := sdk.NewContext(stateStore, tmproto.Header{}, false, log.NewNopLogger()) diff --git a/testutil/keeper/incentive.go b/testutil/keeper/incentive.go index 639b0357e..ab05b5346 100644 --- a/testutil/keeper/incentive.go +++ b/testutil/keeper/incentive.go @@ -47,7 +47,6 @@ func IncentiveKeeper(t testing.TB) (*keeper.Keeper, sdk.Context) { nil, nil, nil, - nil, "", "", ) diff --git a/x/accountedpool/keeper/accounted_pool_update.go b/x/accountedpool/keeper/accounted_pool_update.go index 33f898ca3..ace0666b7 100644 --- a/x/accountedpool/keeper/accounted_pool_update.go +++ b/x/accountedpool/keeper/accounted_pool_update.go @@ -30,23 +30,6 @@ func (k Keeper) GetMarginPoolBalances(marginPool margintypes.Pool, denom string) return sdk.ZeroInt(), sdk.ZeroInt(), sdk.ZeroInt() } -// Update accounted pool balance -func (k Keeper) UpdateAccountedPoolByAMM(ctx sdk.Context, ammPool ammtypes.Pool) error { - poolId := ammPool.PoolId - // Get margin pool - marginPool, found := k.margin.GetPool(ctx, poolId) - if !found { - return errors.New("pool doesn't exist!") - } - - return k.UpdateAccountedPool(ctx, ammPool, marginPool) -} - -// Update accounted pool balance -func (k Keeper) UpdateAccountedPoolByMargin(ctx sdk.Context, ammPool ammtypes.Pool, marginPool margintypes.Pool) error { - return k.UpdateAccountedPool(ctx, ammPool, marginPool) -} - func (k Keeper) UpdateAccountedPool(ctx sdk.Context, ammPool ammtypes.Pool, marginPool margintypes.Pool) error { poolId := ammPool.PoolId // Check if already exists diff --git a/x/accountedpool/keeper/accounted_pool_update_test.go b/x/accountedpool/keeper/accounted_pool_update_test.go new file mode 100644 index 000000000..3f0efd3ea --- /dev/null +++ b/x/accountedpool/keeper/accounted_pool_update_test.go @@ -0,0 +1,80 @@ +package keeper_test + +import ( + "testing" + + tmproto "github.com/cometbft/cometbft/proto/tendermint/types" + sdk "github.com/cosmos/cosmos-sdk/types" + simapp "github.com/elys-network/elys/app" + ammtypes "github.com/elys-network/elys/x/amm/types" + + "github.com/elys-network/elys/x/accountedpool/types" + margintypes "github.com/elys-network/elys/x/margin/types" + ptypes "github.com/elys-network/elys/x/parameter/types" +) + +func TestAccountedPoolUpdate(t *testing.T) { + app := simapp.InitElysTestApp(true) + ctx := app.BaseApp.NewContext(true, tmproto.Header{}) + + apk := app.AccountedPoolKeeper + + // Generate 1 random account with 1000stake balanced + addr := simapp.AddTestAddrs(app, ctx, 1, sdk.NewInt(1000000)) + + // Initiate pool + ammPool := ammtypes.Pool{ + PoolId: 0, + Address: addr[0].String(), + PoolParams: ammtypes.PoolParams{}, + TotalShares: sdk.NewCoin("lp-token", sdk.NewInt(100)), + PoolAssets: []ammtypes.PoolAsset{ + {Token: sdk.NewCoin(ptypes.ATOM, sdk.NewInt(100))}, + {Token: sdk.NewCoin(ptypes.USDC, sdk.NewInt(1000))}, + }, + TotalWeight: sdk.NewInt(100), + RebalanceTreasury: addr[0].String(), + } + // Initiate pool + accountedPool := types.AccountedPool{ + PoolId: 0, + TotalShares: ammPool.TotalShares, + PoolAssets: []ammtypes.PoolAsset{}, + TotalWeight: ammPool.TotalWeight, + } + + for _, asset := range ammPool.PoolAssets { + accountedPool.PoolAssets = append(accountedPool.PoolAssets, asset) + } + // Set accounted pool + apk.SetAccountedPool(ctx, accountedPool) + + marginPool := margintypes.Pool{ + AmmPoolId: 0, + Health: sdk.NewDec(1), + Enabled: true, + Closed: false, + InterestRate: sdk.NewDec(1), + PoolAssets: []margintypes.PoolAsset{ + { + Liabilities: sdk.NewInt(400), + Custody: sdk.NewInt(0), + AssetBalance: sdk.NewInt(100), + UnsettledLiabilities: sdk.NewInt(0), + BlockInterest: sdk.NewInt(0), + AssetDenom: ptypes.USDC, + }, + { + Liabilities: sdk.NewInt(50), + Custody: sdk.NewInt(0), + AssetBalance: sdk.NewInt(0), + UnsettledLiabilities: sdk.NewInt(0), + BlockInterest: sdk.NewInt(0), + AssetDenom: ptypes.ATOM, + }, + }, + } + // Update accounted pool + apk.UpdateAccountedPool(ctx, ammPool, marginPool) + +} diff --git a/x/accountedpool/keeper/hooks_margin.go b/x/accountedpool/keeper/hooks_margin.go index bce80fe0a..ba61058e3 100644 --- a/x/accountedpool/keeper/hooks_margin.go +++ b/x/accountedpool/keeper/hooks_margin.go @@ -7,15 +7,35 @@ import ( ) func (k Keeper) AfterMarginPositionOpended(ctx sdk.Context, ammPool ammtypes.Pool, marginPool margintypes.Pool) { - k.UpdateAccountedPoolByMargin(ctx, ammPool, marginPool) + k.UpdateAccountedPool(ctx, ammPool, marginPool) } func (k Keeper) AfterMarginPositionModified(ctx sdk.Context, ammPool ammtypes.Pool, marginPool margintypes.Pool) { - k.UpdateAccountedPoolByMargin(ctx, ammPool, marginPool) + k.UpdateAccountedPool(ctx, ammPool, marginPool) } func (k Keeper) AfterMarginPositionClosed(ctx sdk.Context, ammPool ammtypes.Pool, marginPool margintypes.Pool) { - k.UpdateAccountedPoolByMargin(ctx, ammPool, marginPool) + k.UpdateAccountedPool(ctx, ammPool, marginPool) +} + +// AfterPoolCreated is called after CreatePool +func (k Keeper) AfterAmmPoolCreated(ctx sdk.Context, ammPool ammtypes.Pool) { + k.InitiateAccountedPool(ctx, ammPool) +} + +// AfterJoinPool is called after JoinPool, JoinSwapExternAmountIn, and JoinSwapShareAmountOut +func (k Keeper) AfterAmmJoinPool(ctx sdk.Context, ammPool ammtypes.Pool, marginPool margintypes.Pool) { + k.UpdateAccountedPool(ctx, ammPool, marginPool) +} + +// AfterExitPool is called after ExitPool, ExitSwapShareAmountIn, and ExitSwapExternAmountOut +func (k Keeper) AfterAmmExitPool(ctx sdk.Context, ammPool ammtypes.Pool, marginPool margintypes.Pool) { + k.UpdateAccountedPool(ctx, ammPool, marginPool) +} + +// AfterSwap is called after SwapExactAmountIn and SwapExactAmountOut +func (k Keeper) AfterAmmSwap(ctx sdk.Context, ammPool ammtypes.Pool, marginPool margintypes.Pool) { + k.UpdateAccountedPool(ctx, ammPool, marginPool) } // Hooks wrapper struct for tvl keeper @@ -41,3 +61,19 @@ func (h MarginHooks) AfterMarginPositionModified(ctx sdk.Context, ammPool ammtyp func (h MarginHooks) AfterMarginPositionClosed(ctx sdk.Context, ammPool ammtypes.Pool, marginPool margintypes.Pool) { h.k.AfterMarginPositionClosed(ctx, ammPool, marginPool) } + +func (h MarginHooks) AfterAmmPoolCreated(ctx sdk.Context, ammPool ammtypes.Pool) { + h.k.AfterAmmPoolCreated(ctx, ammPool) +} + +func (h MarginHooks) AfterAmmJoinPool(ctx sdk.Context, ammPool ammtypes.Pool, marginPool margintypes.Pool) { + h.k.AfterAmmJoinPool(ctx, ammPool, marginPool) +} + +func (h MarginHooks) AfterAmmExitPool(ctx sdk.Context, ammPool ammtypes.Pool, marginPool margintypes.Pool) { + h.k.AfterAmmExitPool(ctx, ammPool, marginPool) +} + +func (h MarginHooks) AfterAmmSwap(ctx sdk.Context, ammPool ammtypes.Pool, marginPool margintypes.Pool) { + h.k.AfterAmmSwap(ctx, ammPool, marginPool) +} diff --git a/x/accountedpool/keeper/keeper.go b/x/accountedpool/keeper/keeper.go index 312d9c365..0e7d7a43c 100644 --- a/x/accountedpool/keeper/keeper.go +++ b/x/accountedpool/keeper/keeper.go @@ -18,7 +18,6 @@ type ( storeKey storetypes.StoreKey memKey storetypes.StoreKey paramstore paramtypes.Subspace - margin types.MarginKeeper bankKeeper types.BankKeeper } ) @@ -28,7 +27,6 @@ func NewKeeper( storeKey, memKey storetypes.StoreKey, ps paramtypes.Subspace, - margin types.MarginKeeper, bk types.BankKeeper, ) *Keeper { // set KeyTable if it has not already been set @@ -41,7 +39,6 @@ func NewKeeper( storeKey: storeKey, memKey: memKey, paramstore: ps, - margin: margin, bankKeeper: bk, } } diff --git a/x/amm/keeper/initialize_pool.go b/x/amm/keeper/initialize_pool.go index b72771e19..e67163a52 100644 --- a/x/amm/keeper/initialize_pool.go +++ b/x/amm/keeper/initialize_pool.go @@ -15,7 +15,7 @@ import ( // - Records total liquidity increase // - Calls the AfterPoolCreated hook func (k Keeper) InitializePool(ctx sdk.Context, pool *types.Pool, sender sdk.AccAddress) (err error) { - tvl, err := pool.TVL(ctx, k.oracleKeeper, k.accountedPoolKeeper) + tvl, err := pool.TVL(ctx, k.oracleKeeper) if err != nil { return err } diff --git a/x/amm/keeper/swap_in_amt_given_out.go b/x/amm/keeper/swap_in_amt_given_out.go new file mode 100644 index 000000000..ee6cfe48a --- /dev/null +++ b/x/amm/keeper/swap_in_amt_given_out.go @@ -0,0 +1,22 @@ +package keeper + +import ( + "fmt" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/elys-network/elys/x/amm/types" +) + +// SwapInAmtGivenOut is a mutative method for CalcOutAmtGivenIn, which includes the actual swap. +func (k Keeper) SwapInAmtGivenOut( + ctx sdk.Context, poolId uint64, oracleKeeper types.OracleKeeper, snapshot *types.Pool, + tokensOut sdk.Coins, tokenInDenom string, swapFee sdk.Dec) ( + tokenIn sdk.Coin, weightBalanceBonus sdk.Dec, err error, +) { + ammPool, found := k.GetPool(ctx, poolId) + if !found { + return sdk.Coin{}, sdk.ZeroDec(), fmt.Errorf("invalid pool: %d", poolId) + } + + return ammPool.SwapInAmtGivenOut(ctx, oracleKeeper, snapshot, tokensOut, tokenInDenom, swapFee, k.accountedPoolKeeper) +} diff --git a/x/amm/keeper/swap_out_amt_given_in.go b/x/amm/keeper/swap_out_amt_given_in.go new file mode 100644 index 000000000..e10bf44de --- /dev/null +++ b/x/amm/keeper/swap_out_amt_given_in.go @@ -0,0 +1,25 @@ +package keeper + +import ( + fmt "fmt" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/elys-network/elys/x/amm/types" +) + +// SwapOutAmtGivenIn is a mutative method for CalcOutAmtGivenIn, which includes the actual swap. +func (k Keeper) SwapOutAmtGivenIn( + ctx sdk.Context, poolId uint64, + oracleKeeper types.OracleKeeper, + snapshot *types.Pool, + tokensIn sdk.Coins, + tokenOutDenom string, + swapFee sdk.Dec, +) (tokenOut sdk.Coin, weightBalanceBonus sdk.Dec, err error) { + ammPool, found := k.GetPool(ctx, poolId) + if !found { + return sdk.Coin{}, sdk.ZeroDec(), fmt.Errorf("invalid pool: %d", poolId) + } + + return ammPool.SwapOutAmtGivenIn(ctx, oracleKeeper, snapshot, tokensIn, tokenOutDenom, swapFee, k.accountedPoolKeeper) +} diff --git a/x/amm/types/calc_exit_pool.go b/x/amm/types/calc_exit_pool.go index f3d947eb1..848de46bc 100644 --- a/x/amm/types/calc_exit_pool.go +++ b/x/amm/types/calc_exit_pool.go @@ -9,7 +9,7 @@ import ( ) func CalcExitValueWithoutSlippage(ctx sdk.Context, oracleKeeper OracleKeeper, accPoolKeeper AccountedPoolKeeper, pool Pool, exitingShares sdk.Int, tokenOutDenom string) (sdk.Dec, error) { - tvl, err := pool.TVL(ctx, oracleKeeper, accPoolKeeper) + tvl, err := pool.TVL(ctx, oracleKeeper) if err != nil { return sdk.ZeroDec(), err } diff --git a/x/amm/types/pool_join_pool_no_swap.go b/x/amm/types/pool_join_pool_no_swap.go index 4150f6191..bb4f00cd9 100644 --- a/x/amm/types/pool_join_pool_no_swap.go +++ b/x/amm/types/pool_join_pool_no_swap.go @@ -131,7 +131,7 @@ func (p *Pool) JoinPoolNoSwap(ctx sdk.Context, oracleKeeper OracleKeeper, accoun } initialWeightDistance := p.WeightDistanceFromTarget(ctx, oracleKeeper, p.PoolAssets) - tvl, err := p.TVL(ctx, oracleKeeper, accountedPoolKeeper) + tvl, err := p.TVL(ctx, oracleKeeper) if err != nil { return sdk.ZeroInt(), err } diff --git a/x/amm/types/tvl.go b/x/amm/types/tvl.go index 97fb23f1f..8de115533 100644 --- a/x/amm/types/tvl.go +++ b/x/amm/types/tvl.go @@ -6,7 +6,7 @@ import ( sdk "github.com/cosmos/cosmos-sdk/types" ) -func (p *Pool) TVL(ctx sdk.Context, oracleKeeper OracleKeeper, accountedPoolKeeepr AccountedPoolKeeper) (sdk.Dec, error) { +func (p *Pool) TVL(ctx sdk.Context, oracleKeeper OracleKeeper) (sdk.Dec, error) { // OracleAssetsTVL * TotalWeight / OracleAssetsWeight // E.g. JUNO / USDT / USDC (30:30:30) // TVL = USDC_USDT_liquidity * 90 / 60 diff --git a/x/amm/types/tvl_test.go b/x/amm/types/tvl_test.go index de9d35a57..cd63ae6e5 100644 --- a/x/amm/types/tvl_test.go +++ b/x/amm/types/tvl_test.go @@ -107,7 +107,7 @@ func (suite *TestSuite) TestTVL() { PoolAssets: tc.poolAssets, TotalWeight: sdk.ZeroInt(), } - tvl, err := pool.TVL(suite.ctx, suite.app.OracleKeeper, suite.app.AccountedPoolKeeper) + tvl, err := pool.TVL(suite.ctx, suite.app.OracleKeeper) if tc.expError { suite.Require().Error(err) } else { diff --git a/x/incentive/keeper/keeper.go b/x/incentive/keeper/keeper.go index ebbed91a1..d8023171c 100644 --- a/x/incentive/keeper/keeper.go +++ b/x/incentive/keeper/keeper.go @@ -17,18 +17,17 @@ import ( type ( Keeper struct { - cdc codec.BinaryCodec - storeKey storetypes.StoreKey - memKey storetypes.StoreKey - paramstore paramtypes.Subspace - cmk types.CommitmentKeeper - stk types.StakingKeeper - tci *types.TotalCommitmentInfo - authKeeper types.AccountKeeper - bankKeeper types.BankKeeper - amm types.AmmKeeper - oracleKeeper types.OracleKeeper - accountedPoolKeeper types.AccountedPoolKeeper + cdc codec.BinaryCodec + storeKey storetypes.StoreKey + memKey storetypes.StoreKey + paramstore paramtypes.Subspace + cmk types.CommitmentKeeper + stk types.StakingKeeper + tci *types.TotalCommitmentInfo + authKeeper types.AccountKeeper + bankKeeper types.BankKeeper + amm types.AmmKeeper + oracleKeeper types.OracleKeeper feeCollectorName string // name of the FeeCollector ModuleAccount dexRevCollectorName string // name of the Dex Revenue ModuleAccount @@ -46,7 +45,6 @@ func NewKeeper( bk types.BankKeeper, amm types.AmmKeeper, ok types.OracleKeeper, - apk types.AccountedPoolKeeper, feeCollectorName string, dexRevCollectorName string, ) *Keeper { @@ -69,7 +67,6 @@ func NewKeeper( bankKeeper: bk, amm: amm, oracleKeeper: ok, - accountedPoolKeeper: apk, } } @@ -232,7 +229,7 @@ func (k Keeper) UpdateTokensCommitment(commitments *ctypes.Commitments, new_unco func (k Keeper) CalculateProxyTVL(ctx sdk.Context) sdk.Dec { multipliedShareSum := sdk.ZeroDec() k.amm.IterateLiquidityPools(ctx, func(p ammtypes.Pool) bool { - tvl, err := p.TVL(ctx, k.oracleKeeper, k.accountedPoolKeeper) + tvl, err := p.TVL(ctx, k.oracleKeeper) if err != nil { return false } @@ -260,7 +257,7 @@ func (k Keeper) CalculateTVL(ctx sdk.Context) sdk.Dec { TVL := sdk.ZeroDec() k.amm.IterateLiquidityPools(ctx, func(p ammtypes.Pool) bool { - tvl, err := p.TVL(ctx, k.oracleKeeper, k.accountedPoolKeeper) + tvl, err := p.TVL(ctx, k.oracleKeeper) if err != nil { return false } diff --git a/x/incentive/keeper/keeper_fees.go b/x/incentive/keeper/keeper_fees.go index 7d582a506..764b2f188 100644 --- a/x/incentive/keeper/keeper_fees.go +++ b/x/incentive/keeper/keeper_fees.go @@ -73,7 +73,7 @@ func (k Keeper) CollectGasFeesToIncentiveModule(ctx sdk.Context) sdk.Coins { // Executes the swap in the pool and stores the output. Updates pool assets but // does not actually transfer any tokens to or from the pool. snapshot := k.amm.GetPoolSnapshotOrSet(ctx, pool) - tokenOutCoin, _, err := pool.SwapOutAmtGivenIn(ctx, k.oracleKeeper, &snapshot, sdk.Coins{tokenIn}, ptypes.USDC, sdk.ZeroDec(), k.accountedPoolKeeper) + tokenOutCoin, _, err := k.amm.SwapOutAmtGivenIn(ctx, pool.PoolId, k.oracleKeeper, &snapshot, sdk.Coins{tokenIn}, ptypes.USDC, sdk.ZeroDec()) if err != nil { continue } diff --git a/x/incentive/keeper/keeper_lps.go b/x/incentive/keeper/keeper_lps.go index 0c7ec72f2..6c3865f51 100644 --- a/x/incentive/keeper/keeper_lps.go +++ b/x/incentive/keeper/keeper_lps.go @@ -20,7 +20,7 @@ func (k Keeper) CalculateRewardsForLPs(ctx sdk.Context, totalProxyTVL sdk.Dec, c // newEdenAllocated = 80 / ( 80 + 90 + 200 + 0) * 100 // Pool share = 80 // edenAmountPerEpochLp = 100 - tvl, err := p.TVL(ctx, k.oracleKeeper, k.accountedPoolKeeper) + tvl, err := p.TVL(ctx, k.oracleKeeper) if err != nil { return false } diff --git a/x/incentive/types/expected_keepers.go b/x/incentive/types/expected_keepers.go index 47841a775..7ed4191fe 100644 --- a/x/incentive/types/expected_keepers.go +++ b/x/incentive/types/expected_keepers.go @@ -92,6 +92,15 @@ type AmmKeeper interface { // IterateCommitments iterates over all Commitments and performs a callback. IterateLiquidityPools(sdk.Context, func(ammtypes.Pool) bool) GetPoolSnapshotOrSet(ctx sdk.Context, pool ammtypes.Pool) (val ammtypes.Pool) + + SwapOutAmtGivenIn( + ctx sdk.Context, poolId uint64, + oracleKeeper ammtypes.OracleKeeper, + snapshot *ammtypes.Pool, + tokensIn sdk.Coins, + tokenOutDenom string, + swapFee sdk.Dec, + ) (tokenOut sdk.Coin, weightBalanceBonus sdk.Dec, err error) } // OracleKeeper defines the expected interface needed to retrieve price info diff --git a/x/margin/keeper/check_max_open_positions_test.go b/x/margin/keeper/check_max_open_positions_test.go index 891452846..7afb95d8d 100644 --- a/x/margin/keeper/check_max_open_positions_test.go +++ b/x/margin/keeper/check_max_open_positions_test.go @@ -24,7 +24,7 @@ func TestCheckMaxOpenPositions_OpenPositionsBelowMax(t *testing.T) { // Mock behavior mockChecker.On("GetOpenMTPCount", ctx).Return(uint64(5)) - mockChecker.On("GetMaxOpenPositions", ctx).Return(10) + mockChecker.On("GetMaxOpenPositions", ctx).Return(uint64(10)) err := k.CheckMaxOpenPositions(ctx) @@ -46,7 +46,7 @@ func TestCheckMaxOpenPositions_OpenPositionsEqualToMax(t *testing.T) { // Mock behavior mockChecker.On("GetOpenMTPCount", ctx).Return(uint64(10)) - mockChecker.On("GetMaxOpenPositions", ctx).Return(10) + mockChecker.On("GetMaxOpenPositions", ctx).Return(uint64(10)) err := k.CheckMaxOpenPositions(ctx) @@ -68,7 +68,7 @@ func TestCheckMaxOpenPositions_OpenPositionsExceedMax(t *testing.T) { // Mock behavior mockChecker.On("GetOpenMTPCount", ctx).Return(uint64(11)) - mockChecker.On("GetMaxOpenPositions", ctx).Return(10) + mockChecker.On("GetMaxOpenPositions", ctx).Return(uint64(10)) err := k.CheckMaxOpenPositions(ctx) diff --git a/x/accountedpool/keeper/hooks_amm.go b/x/margin/keeper/hooks_amm.go similarity index 71% rename from x/accountedpool/keeper/hooks_amm.go rename to x/margin/keeper/hooks_amm.go index 493f7bf2f..26fd476fd 100644 --- a/x/accountedpool/keeper/hooks_amm.go +++ b/x/margin/keeper/hooks_amm.go @@ -6,23 +6,45 @@ import ( ) // AfterPoolCreated is called after CreatePool -func (k Keeper) AfterPoolCreated(ctx sdk.Context, sender sdk.AccAddress, pool ammtypes.Pool) { - k.InitiateAccountedPool(ctx, pool) +func (k Keeper) AfterPoolCreated(ctx sdk.Context, sender sdk.AccAddress, ammPool ammtypes.Pool) { + if k.hooks != nil { + k.hooks.AfterAmmPoolCreated(ctx, ammPool) + } } // AfterJoinPool is called after JoinPool, JoinSwapExternAmountIn, and JoinSwapShareAmountOut -func (k Keeper) AfterJoinPool(ctx sdk.Context, sender sdk.AccAddress, pool ammtypes.Pool, enterCoins sdk.Coins, shareOutAmount sdk.Int) { - k.UpdateAccountedPoolByAMM(ctx, pool) +func (k Keeper) AfterJoinPool(ctx sdk.Context, sender sdk.AccAddress, ammPool ammtypes.Pool, enterCoins sdk.Coins, shareOutAmount sdk.Int) { + marginPool, found := k.GetPool(ctx, ammPool.PoolId) + if !found { + return + } + + if k.hooks != nil { + k.hooks.AfterAmmJoinPool(ctx, ammPool, marginPool) + } } // AfterExitPool is called after ExitPool, ExitSwapShareAmountIn, and ExitSwapExternAmountOut -func (k Keeper) AfterExitPool(ctx sdk.Context, sender sdk.AccAddress, pool ammtypes.Pool, shareInAmount sdk.Int, exitCoins sdk.Coins) { - k.UpdateAccountedPoolByAMM(ctx, pool) +func (k Keeper) AfterExitPool(ctx sdk.Context, sender sdk.AccAddress, ammPool ammtypes.Pool, shareInAmount sdk.Int, exitCoins sdk.Coins) { + marginPool, found := k.GetPool(ctx, ammPool.PoolId) + if !found { + return + } + + if k.hooks != nil { + k.hooks.AfterAmmExitPool(ctx, ammPool, marginPool) + } } // AfterSwap is called after SwapExactAmountIn and SwapExactAmountOut -func (k Keeper) AfterSwap(ctx sdk.Context, sender sdk.AccAddress, pool ammtypes.Pool, input sdk.Coins, output sdk.Coins) { - k.UpdateAccountedPoolByAMM(ctx, pool) +func (k Keeper) AfterSwap(ctx sdk.Context, sender sdk.AccAddress, ammPool ammtypes.Pool, input sdk.Coins, output sdk.Coins) { + marginPool, found := k.GetPool(ctx, ammPool.PoolId) + if !found { + return + } + if k.hooks != nil { + k.hooks.AfterAmmSwap(ctx, ammPool, marginPool) + } } // Hooks wrapper struct for tvl keeper diff --git a/x/margin/keeper/params.go b/x/margin/keeper/params.go index 0c83fa3ef..404ae750a 100644 --- a/x/margin/keeper/params.go +++ b/x/margin/keeper/params.go @@ -95,8 +95,8 @@ func (k Keeper) GetIncrementalInterestPaymentFundAddress(ctx sdk.Context) sdk.Ac return addr } -func (k Keeper) GetMaxOpenPositions(ctx sdk.Context) int64 { - return k.GetParams(ctx).MaxOpenPositions +func (k Keeper) GetMaxOpenPositions(ctx sdk.Context) uint64 { + return (uint64)(k.GetParams(ctx).MaxOpenPositions) } func (k Keeper) GetIncrementalInterestPaymentEnabled(ctx sdk.Context) bool { diff --git a/x/margin/types/expected_keepers.go b/x/margin/types/expected_keepers.go index 6ed938bfd..1e67c13bf 100644 --- a/x/margin/types/expected_keepers.go +++ b/x/margin/types/expected_keepers.go @@ -16,7 +16,7 @@ type AuthorizationChecker interface { //go:generate mockery --srcpkg . --name PositionChecker --structname PositionChecker --filename position_checker.go --with-expecter type PositionChecker interface { GetOpenMTPCount(ctx sdk.Context) uint64 - GetMaxOpenPositions(ctx sdk.Context) int64 + GetMaxOpenPositions(ctx sdk.Context) uint64 } //go:generate mockery --srcpkg . --name PoolChecker --structname PoolChecker --filename pool_checker.go --with-expecter diff --git a/x/margin/types/hooks.go b/x/margin/types/hooks.go index 547ccabb3..adde87679 100644 --- a/x/margin/types/hooks.go +++ b/x/margin/types/hooks.go @@ -14,6 +14,18 @@ type MarginHooks interface { // AfterMarginPositionClosed is called after a position gets closed. AfterMarginPositionClosed(ctx sdk.Context, ammPool ammtypes.Pool, marginPool Pool) + + // AfterPoolCreated is called after CreatePool + AfterAmmPoolCreated(ctx sdk.Context, ammPool ammtypes.Pool) + + // AfterJoinPool is called after JoinPool, JoinSwapExternAmountIn, and JoinSwapShareAmountOut + AfterAmmJoinPool(ctx sdk.Context, ammPool ammtypes.Pool, marginPool Pool) + + // AfterExitPool is called after ExitPool, ExitSwapShareAmountIn, and ExitSwapExternAmountOut + AfterAmmExitPool(ctx sdk.Context, ammPool ammtypes.Pool, marginPool Pool) + + // AfterSwap is called after SwapExactAmountIn and SwapExactAmountOut + AfterAmmSwap(ctx sdk.Context, ammPool ammtypes.Pool, marginPool Pool) } var _ MarginHooks = MultiMarginHooks{} @@ -43,3 +55,27 @@ func (h MultiMarginHooks) AfterMarginPositionClosed(ctx sdk.Context, ammPool amm h[i].AfterMarginPositionClosed(ctx, ammPool, marginPool) } } + +func (h MultiMarginHooks) AfterAmmPoolCreated(ctx sdk.Context, ammPool ammtypes.Pool) { + for i := range h { + h[i].AfterAmmPoolCreated(ctx, ammPool) + } +} + +func (h MultiMarginHooks) AfterAmmJoinPool(ctx sdk.Context, ammPool ammtypes.Pool, marginPool Pool) { + for i := range h { + h[i].AfterAmmJoinPool(ctx, ammPool, marginPool) + } +} + +func (h MultiMarginHooks) AfterAmmExitPool(ctx sdk.Context, ammPool ammtypes.Pool, marginPool Pool) { + for i := range h { + h[i].AfterAmmExitPool(ctx, ammPool, marginPool) + } +} + +func (h MultiMarginHooks) AfterAmmSwap(ctx sdk.Context, ammPool ammtypes.Pool, marginPool Pool) { + for i := range h { + h[i].AfterAmmSwap(ctx, ammPool, marginPool) + } +} diff --git a/x/margin/types/mocks/position_checker.go b/x/margin/types/mocks/position_checker.go index 6132367bc..ee18ecc89 100644 --- a/x/margin/types/mocks/position_checker.go +++ b/x/margin/types/mocks/position_checker.go @@ -21,14 +21,14 @@ func (_m *PositionChecker) EXPECT() *PositionChecker_Expecter { } // GetMaxOpenPositions provides a mock function with given fields: ctx -func (_m *PositionChecker) GetMaxOpenPositions(ctx types.Context) int64 { +func (_m *PositionChecker) GetMaxOpenPositions(ctx types.Context) uint64 { ret := _m.Called(ctx) - var r0 int64 - if rf, ok := ret.Get(0).(func(types.Context) int64); ok { + var r0 uint64 + if rf, ok := ret.Get(0).(func(types.Context) uint64); ok { r0 = rf(ctx) } else { - r0 = ret.Get(0).(int64) + r0 = ret.Get(0).(uint64) } return r0 @@ -52,12 +52,12 @@ func (_c *PositionChecker_GetMaxOpenPositions_Call) Run(run func(ctx types.Conte return _c } -func (_c *PositionChecker_GetMaxOpenPositions_Call) Return(_a0 int64) *PositionChecker_GetMaxOpenPositions_Call { +func (_c *PositionChecker_GetMaxOpenPositions_Call) Return(_a0 uint64) *PositionChecker_GetMaxOpenPositions_Call { _c.Call.Return(_a0) return _c } -func (_c *PositionChecker_GetMaxOpenPositions_Call) RunAndReturn(run func(types.Context) int64) *PositionChecker_GetMaxOpenPositions_Call { +func (_c *PositionChecker_GetMaxOpenPositions_Call) RunAndReturn(run func(types.Context) uint64) *PositionChecker_GetMaxOpenPositions_Call { _c.Call.Return(run) return _c }