From c173118450eafb9cd1ae2d1433bc28f220520fd9 Mon Sep 17 00:00:00 2001 From: Cosmic Vagabond <121588426+cosmic-vagabond@users.noreply.github.com> Date: Wed, 20 Sep 2023 11:31:07 +0200 Subject: [PATCH] feat: add additional short flow logic --- x/margin/keeper/open_long_process.go | 2 +- x/margin/keeper/open_short_process.go | 152 +++-- x/margin/keeper/open_short_test.go | 637 +++++++++++++++++++++ x/margin/types/expected_keepers.go | 3 + x/margin/types/mocks/open_short_checker.go | 120 ++++ 5 files changed, 860 insertions(+), 54 deletions(-) create mode 100644 x/margin/keeper/open_short_test.go diff --git a/x/margin/keeper/open_long_process.go b/x/margin/keeper/open_long_process.go index d3fd4f433..45ef9f9e0 100644 --- a/x/margin/keeper/open_long_process.go +++ b/x/margin/keeper/open_long_process.go @@ -39,7 +39,7 @@ func (k Keeper) ProcessOpenLong(ctx sdk.Context, mtp *types.MTP, leverage sdk.De return nil, err } if !k.OpenLongChecker.HasSufficientPoolBalance(ctx, ammPool, ptypes.BaseCurrency, borrowingAmount) { - return nil, sdkerrors.Wrap(types.ErrBorrowTooHigh, leveragedAmount.String()) + return nil, sdkerrors.Wrap(types.ErrBorrowTooHigh, borrowingAmount.String()) } } else { if !k.OpenLongChecker.HasSufficientPoolBalance(ctx, ammPool, msg.CollateralAsset, leveragedAmount) { diff --git a/x/margin/keeper/open_short_process.go b/x/margin/keeper/open_short_process.go index d497882e1..5ce15fb65 100644 --- a/x/margin/keeper/open_short_process.go +++ b/x/margin/keeper/open_short_process.go @@ -2,64 +2,110 @@ package keeper import ( sdk "github.com/cosmos/cosmos-sdk/types" + sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" "github.com/elys-network/elys/x/margin/types" + ptypes "github.com/elys-network/elys/x/parameter/types" ) func (k Keeper) ProcessOpenShort(ctx sdk.Context, mtp *types.MTP, leverage sdk.Dec, eta sdk.Dec, collateralAmountDec sdk.Dec, poolId uint64, msg *types.MsgOpen) (*types.MTP, error) { // Determine the trading asset. - // tradingAsset := k.OpenShortChecker.GetTradingAsset(msg.CollateralAsset, msg.BorrowAsset) - - // // Fetch the pool associated with the given pool ID. - // pool, found := k.OpenShortChecker.GetPool(ctx, poolId) - // if !found { - // return nil, sdkerrors.Wrap(types.ErrPoolDoesNotExist, tradingAsset) - // } - - // // Check if the pool is enabled. - // if !k.OpenShortChecker.IsPoolEnabled(ctx, poolId) { - // return nil, sdkerrors.Wrap(types.ErrMTPDisabled, tradingAsset) - // } - - // // Fetch the corresponding AMM (Automated Market Maker) pool. - // ammPool, err := k.OpenShortChecker.GetAmmPool(ctx, poolId, tradingAsset) - // if err != nil { - // return nil, err - // } - - // // Calculate the leveraged amount based on the collateral provided and the leverage. - // leveragedAmount := sdk.NewInt(collateralAmountDec.Mul(leverage).TruncateInt().Int64()) - - // // Borrow the asset the user wants to short. - // // err = k.OpenLongChecker.Borrow(ctx, msg.CollateralAsset, msg.BorrowAsset, msg.CollateralAmount, custodyAmount, mtp, &ammPool, &pool, eta) - // // if err != nil { - // // return nil, err - // // } - - // // Calculate the custody amount. - // swappedAmount, err := k.OpenShortChecker.EstimateSwap(ctx, leveragedAmount, ptypes.BaseCurrency, ammPool) - // if err != nil { - // return nil, err - // } - - // // Ensure the AMM pool has enough balance. - // if !k.OpenShortChecker.HasSufficientPoolBalance(ctx, ammPool, msg.BorrowAsset, swappedAmount) { - // return nil, sdkerrors.Wrap(types.ErrSwapTooHigh, swappedAmount.String()) - // } - - // // Additional checks and operations: - // // 1. Check minimum liabilities. - // err = k.OpenShortChecker.CheckMinLiabilities(ctx, swappedAmount, eta, pool, ammPool, msg.CollateralAsset) - // if err != nil { - // return nil, err - // } - - // // 2. Update the pool and MTP health. - // if err = k.OpenShortChecker.UpdatePoolHealth(ctx, &pool); err != nil { - // return nil, err - // } - // if err = k.OpenShortChecker.UpdateMTPHealth(ctx, *mtp, ammPool); err != nil { - // return nil, err - // } + tradingAsset := k.OpenShortChecker.GetTradingAsset(msg.CollateralAsset, msg.BorrowAsset) + + // Fetch the pool associated with the given pool ID. + pool, found := k.OpenShortChecker.GetPool(ctx, poolId) + if !found { + return nil, sdkerrors.Wrap(types.ErrPoolDoesNotExist, tradingAsset) + } + + // Check if the pool is enabled. + if !k.OpenShortChecker.IsPoolEnabled(ctx, poolId) { + return nil, sdkerrors.Wrap(types.ErrMTPDisabled, tradingAsset) + } + + // Fetch the corresponding AMM (Automated Market Maker) pool. + ammPool, err := k.OpenShortChecker.GetAmmPool(ctx, poolId, tradingAsset) + if err != nil { + return nil, err + } + + // Calculate the leveraged amount based on the collateral provided and the leverage. + leveragedAmount := sdk.NewInt(collateralAmountDec.Mul(leverage).TruncateInt().Int64()) + + if msg.CollateralAsset != ptypes.BaseCurrency { + return nil, sdkerrors.Wrap(types.ErrInvalidBorrowingAsset, "collateral must be base currency") + } + + custodyAmtToken := sdk.NewCoin(ptypes.BaseCurrency, leveragedAmount) + borrowingAmount, err := k.OpenShortChecker.EstimateSwapGivenOut(ctx, custodyAmtToken, msg.BorrowAsset, ammPool) + if err != nil { + return nil, err + } + + // check the balance + if !k.OpenShortChecker.HasSufficientPoolBalance(ctx, ammPool, ptypes.BaseCurrency, borrowingAmount) { + return nil, sdkerrors.Wrap(types.ErrBorrowTooHigh, borrowingAmount.String()) + } + + // Check minimum liabilities. + collateralTokenAmt := sdk.NewCoin(msg.CollateralAsset, msg.CollateralAmount) + err = k.OpenShortChecker.CheckMinLiabilities(ctx, collateralTokenAmt, eta, pool, ammPool, msg.BorrowAsset) + if err != nil { + return nil, err + } + + // Calculate custody amount. + leveragedAmtTokenIn := sdk.NewCoin(msg.BorrowAsset, borrowingAmount) + custodyAmount, err := k.OpenShortChecker.EstimateSwap(ctx, leveragedAmtTokenIn, ptypes.BaseCurrency, ammPool) + if err != nil { + return nil, err + } + + // Ensure the AMM pool has enough balance. + if !k.OpenShortChecker.HasSufficientPoolBalance(ctx, ammPool, ptypes.BaseCurrency, custodyAmount) { + return nil, sdkerrors.Wrap(types.ErrCustodyTooHigh, custodyAmount.String()) + } + + // if position is short then override the custody asset to the base currency + if mtp.Position == types.Position_SHORT { + mtp.CustodyAssets = []string{ptypes.BaseCurrency} + } + + // Borrow the asset the user wants to short. + err = k.OpenShortChecker.Borrow(ctx, msg.CollateralAsset, ptypes.BaseCurrency, msg.CollateralAmount, custodyAmount, mtp, &ammPool, &pool, eta) + if err != nil { + return nil, err + } + + // Update the pool health. + if err = k.OpenShortChecker.UpdatePoolHealth(ctx, &pool); err != nil { + return nil, err + } + + // Take custody from the pool balance. + if err = k.OpenShortChecker.TakeInCustody(ctx, *mtp, &pool); err != nil { + return nil, err + } + + // Update the MTP health. + lr, err := k.OpenShortChecker.UpdateMTPHealth(ctx, *mtp, ammPool) + if err != nil { + return nil, err + } + + // Check if the MTP is unhealthy + safetyFactor := k.OpenShortChecker.GetSafetyFactor(ctx) + if lr.LTE(safetyFactor) { + return nil, types.ErrMTPUnhealthy + } + + // Update consolidated collateral amount + k.OpenShortChecker.CalcMTPConsolidateCollateral(ctx, mtp) + + // Calculate consolidate liabiltiy + k.OpenShortChecker.CalcMTPConsolidateLiability(ctx, mtp) + + // Set MTP + k.OpenShortChecker.SetMTP(ctx, mtp) // Return the updated Margin Trading Position (MTP). return mtp, nil diff --git a/x/margin/keeper/open_short_test.go b/x/margin/keeper/open_short_test.go new file mode 100644 index 000000000..b765a2a8b --- /dev/null +++ b/x/margin/keeper/open_short_test.go @@ -0,0 +1,637 @@ +package keeper_test + +import ( + "errors" + "testing" + + "cosmossdk.io/math" + sdk "github.com/cosmos/cosmos-sdk/types" + sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" + ammtypes "github.com/elys-network/elys/x/amm/types" + "github.com/elys-network/elys/x/margin/keeper" + "github.com/elys-network/elys/x/margin/types" + "github.com/elys-network/elys/x/margin/types/mocks" + "github.com/stretchr/testify/assert" + + tmproto "github.com/cometbft/cometbft/proto/tendermint/types" + simapp "github.com/elys-network/elys/app" + ptypes "github.com/elys-network/elys/x/parameter/types" + "github.com/stretchr/testify/require" +) + +func TestOpenShort_PoolNotFound(t *testing.T) { + // Setup the mock checker + mockChecker := new(mocks.OpenShortChecker) + + // Create an instance of Keeper with the mock checker + k := keeper.Keeper{ + OpenShortChecker: mockChecker, + } + + var ( + ctx = sdk.Context{} // Mock or setup a context + msg = &types.MsgOpen{ + Leverage: math.LegacyNewDec(10), + CollateralAmount: math.NewInt(1), + CollateralAsset: "aaa", + BorrowAsset: "bbb", + } + poolId = uint64(42) + ) + + // Mock behavior + mockChecker.On("GetMaxLeverageParam", ctx).Return(msg.Leverage) + mockChecker.On("GetTradingAsset", msg.CollateralAsset, msg.BorrowAsset).Return(msg.CollateralAsset) + mockChecker.On("GetPool", ctx, poolId).Return(types.Pool{}, false) + + _, err := k.OpenShort(ctx, poolId, msg) + + // Expect an error about the pool not existing + assert.True(t, errors.Is(err, types.ErrPoolDoesNotExist)) + mockChecker.AssertExpectations(t) +} + +func TestOpenShort_PoolDisabled(t *testing.T) { + // Setup the mock checker + mockChecker := new(mocks.OpenShortChecker) + + // Create an instance of Keeper with the mock checker + k := keeper.Keeper{ + OpenShortChecker: mockChecker, + } + + var ( + ctx = sdk.Context{} // Mock or setup a context + msg = &types.MsgOpen{ + Leverage: math.LegacyNewDec(10), + CollateralAmount: math.NewInt(1), + } + poolId = uint64(42) + ) + + // Mock behaviors + mockChecker.On("GetMaxLeverageParam", ctx).Return(msg.Leverage) + mockChecker.On("GetTradingAsset", msg.CollateralAsset, msg.BorrowAsset).Return(msg.CollateralAsset) + mockChecker.On("GetPool", ctx, poolId).Return(types.Pool{}, true) + mockChecker.On("IsPoolEnabled", ctx, poolId).Return(false) + + _, err := k.OpenShort(ctx, poolId, msg) + + // Expect an error about the pool being disabled + assert.True(t, errors.Is(err, types.ErrMTPDisabled)) + mockChecker.AssertExpectations(t) +} + +func TestOpenShort_InsufficientAmmPoolBalanceForLeveragedAmount(t *testing.T) { + // Setup the mock checker + mockChecker := new(mocks.OpenShortChecker) + + // Create an instance of Keeper with the mock checker + k := keeper.Keeper{ + OpenShortChecker: mockChecker, + } + + var ( + ctx = sdk.Context{} // Mock or setup a context + msg = &types.MsgOpen{ + Leverage: math.LegacyNewDec(2), + CollateralAmount: math.NewInt(1000), + Creator: "", + CollateralAsset: ptypes.BaseCurrency, + BorrowAsset: "uatom", + Position: types.Position_SHORT, + } + poolId = uint64(42) + ) + + // Mock the behaviors to get to the HasSufficientPoolBalance check + mockChecker.On("GetMaxLeverageParam", ctx).Return(msg.Leverage) + mockChecker.On("GetTradingAsset", msg.CollateralAsset, msg.BorrowAsset).Return(msg.BorrowAsset) + mockChecker.On("GetPool", ctx, poolId).Return(types.Pool{}, true) + mockChecker.On("IsPoolEnabled", ctx, poolId).Return(true) + mockChecker.On("GetAmmPool", ctx, poolId, msg.BorrowAsset).Return(ammtypes.Pool{}, nil) // Assuming a valid pool is returned + + collateralAmountDec := sdk.NewDecFromBigInt(msg.CollateralAmount.BigInt()) + leveragedAmount := sdk.NewInt(collateralAmountDec.Mul(msg.Leverage).TruncateInt().Int64()) + custodyAmtToken := sdk.NewCoin(ptypes.BaseCurrency, leveragedAmount) + borrowingAmount := sdk.NewInt(99) + + mockChecker.On("EstimateSwapGivenOut", ctx, custodyAmtToken, msg.BorrowAsset, ammtypes.Pool{}).Return(borrowingAmount, nil) + mockChecker.On("HasSufficientPoolBalance", ctx, ammtypes.Pool{}, ptypes.BaseCurrency, borrowingAmount).Return(false) + + _, err := k.OpenShort(ctx, poolId, msg) + + // Expect an error about the borrow amount being too high + assert.True(t, errors.Is(err, types.ErrBorrowTooHigh)) + mockChecker.AssertExpectations(t) +} + +func TestOpenShort_InsufficientLiabilities(t *testing.T) { + // Setup the mock checker + mockChecker := new(mocks.OpenShortChecker) + + // Create an instance of Keeper with the mock checker + k := keeper.Keeper{ + OpenShortChecker: mockChecker, + } + + var ( + ctx = sdk.Context{} // Mock or setup a context + msg = &types.MsgOpen{ + Leverage: math.LegacyNewDec(2), + CollateralAmount: math.NewInt(1000), + Creator: "", + CollateralAsset: ptypes.BaseCurrency, + BorrowAsset: "uatom", + Position: types.Position_SHORT, + } + poolId = uint64(42) + ) + + // Mock the behaviors to get to the CheckMinLiabilities check + mockChecker.On("GetMaxLeverageParam", ctx).Return(msg.Leverage) + mockChecker.On("GetTradingAsset", msg.CollateralAsset, msg.BorrowAsset).Return(msg.BorrowAsset) + mockChecker.On("GetPool", ctx, poolId).Return(types.Pool{}, true) + mockChecker.On("IsPoolEnabled", ctx, poolId).Return(true) + mockChecker.On("GetAmmPool", ctx, poolId, msg.BorrowAsset).Return(ammtypes.Pool{}, nil) // Assuming a valid pool is returned + + collateralAmountDec := sdk.NewDecFromBigInt(msg.CollateralAmount.BigInt()) + leveragedAmount := sdk.NewInt(collateralAmountDec.Mul(msg.Leverage).TruncateInt().Int64()) + custodyAmtToken := sdk.NewCoin(ptypes.BaseCurrency, leveragedAmount) + borrowingAmount := sdk.NewInt(99) + + mockChecker.On("EstimateSwapGivenOut", ctx, custodyAmtToken, msg.BorrowAsset, ammtypes.Pool{}).Return(borrowingAmount, nil) + mockChecker.On("HasSufficientPoolBalance", ctx, ammtypes.Pool{}, ptypes.BaseCurrency, borrowingAmount).Return(true) + + // Mock the behavior where CheckMinLiabilities returns an error indicating insufficient liabilities + liabilityError := errors.New("insufficient liabilities") + collateralTokenAmt := sdk.NewCoin(msg.CollateralAsset, msg.CollateralAmount) + + mockChecker.On("CheckMinLiabilities", ctx, collateralTokenAmt, sdk.NewDec(1), types.Pool{}, ammtypes.Pool{}, msg.BorrowAsset).Return(liabilityError) + + _, err := k.OpenShort(ctx, poolId, msg) + + // Expect the custom error indicating insufficient liabilities + assert.True(t, errors.Is(err, liabilityError)) + mockChecker.AssertExpectations(t) +} + +func TestOpenShort_InsufficientAmmPoolBalanceForCustody(t *testing.T) { + // Setup the mock checker + mockChecker := new(mocks.OpenShortChecker) + + // Create an instance of Keeper with the mock checker + k := keeper.Keeper{ + OpenShortChecker: mockChecker, + } + + var ( + ctx = sdk.Context{} // Mock or setup a context + msg = &types.MsgOpen{ + Leverage: math.LegacyNewDec(10), + CollateralAmount: math.NewInt(1000), + Creator: "", + CollateralAsset: ptypes.BaseCurrency, + BorrowAsset: "uatom", + Position: types.Position_SHORT, + } + poolId = uint64(42) + ) + // Mock behaviors + mockChecker.On("GetMaxLeverageParam", ctx).Return(msg.Leverage) + mockChecker.On("GetTradingAsset", msg.CollateralAsset, msg.BorrowAsset).Return(msg.BorrowAsset) + mockChecker.On("GetPool", ctx, poolId).Return(types.Pool{}, true) + mockChecker.On("IsPoolEnabled", ctx, poolId).Return(true) + mockChecker.On("GetAmmPool", ctx, poolId, msg.BorrowAsset).Return(ammtypes.Pool{}, nil) + + collateralAmountDec := sdk.NewDecFromBigInt(msg.CollateralAmount.BigInt()) + leveragedAmount := sdk.NewInt(collateralAmountDec.Mul(msg.Leverage).TruncateInt().Int64()) + custodyAmtToken := sdk.NewCoin(ptypes.BaseCurrency, leveragedAmount) + borrowingAmount := sdk.NewInt(99) + + mockChecker.On("EstimateSwapGivenOut", ctx, custodyAmtToken, msg.BorrowAsset, ammtypes.Pool{}).Return(borrowingAmount, nil) + mockChecker.On("HasSufficientPoolBalance", ctx, ammtypes.Pool{}, ptypes.BaseCurrency, borrowingAmount).Return(true) + + collateralTokenAmt := sdk.NewCoin(msg.CollateralAsset, msg.CollateralAmount) + eta := math.LegacyNewDec(9) + + mockChecker.On("CheckMinLiabilities", ctx, collateralTokenAmt, eta, types.Pool{}, ammtypes.Pool{}, msg.BorrowAsset).Return(nil) + + leveragedAmtTokenIn := sdk.NewCoin(msg.BorrowAsset, borrowingAmount) + custodyAmount := math.NewInt(199) + + mockChecker.On("EstimateSwap", ctx, leveragedAmtTokenIn, ptypes.BaseCurrency, ammtypes.Pool{}).Return(custodyAmount, nil) + mockChecker.On("HasSufficientPoolBalance", ctx, ammtypes.Pool{}, ptypes.BaseCurrency, custodyAmount).Return(false) + + _, err := k.OpenShort(ctx, poolId, msg) + + // Expect an error about custody amount being too high + assert.True(t, errors.Is(err, types.ErrCustodyTooHigh)) + mockChecker.AssertExpectations(t) +} + +func TestOpenShort_ErrorsDuringOperations(t *testing.T) { + // Setup the mock checker + mockChecker := new(mocks.OpenShortChecker) + + // Create an instance of Keeper with the mock checker + k := keeper.Keeper{ + OpenShortChecker: mockChecker, + } + + var ( + ctx = sdk.Context{} // Mock or setup a context + msg = &types.MsgOpen{ + Leverage: math.LegacyNewDec(10), + CollateralAmount: math.NewInt(1000), + Creator: "", + CollateralAsset: ptypes.BaseCurrency, + BorrowAsset: "uatom", + Position: types.Position_SHORT, + } + poolId = uint64(42) + ) + + // Mock behaviors + mockChecker.On("GetMaxLeverageParam", ctx).Return(msg.Leverage) + mockChecker.On("GetTradingAsset", msg.CollateralAsset, msg.BorrowAsset).Return(msg.BorrowAsset) + mockChecker.On("GetPool", ctx, poolId).Return(types.Pool{}, true) + mockChecker.On("IsPoolEnabled", ctx, poolId).Return(true) + mockChecker.On("GetAmmPool", ctx, poolId, msg.BorrowAsset).Return(ammtypes.Pool{}, nil) + + collateralAmountDec := sdk.NewDecFromBigInt(msg.CollateralAmount.BigInt()) + leveragedAmount := sdk.NewInt(collateralAmountDec.Mul(msg.Leverage).TruncateInt().Int64()) + custodyAmtToken := sdk.NewCoin(ptypes.BaseCurrency, leveragedAmount) + borrowingAmount := sdk.NewInt(99) + + mockChecker.On("EstimateSwapGivenOut", ctx, custodyAmtToken, msg.BorrowAsset, ammtypes.Pool{}).Return(borrowingAmount, nil) + mockChecker.On("HasSufficientPoolBalance", ctx, ammtypes.Pool{}, ptypes.BaseCurrency, borrowingAmount).Return(true) + + collateralTokenAmt := sdk.NewCoin(msg.CollateralAsset, msg.CollateralAmount) + eta := math.LegacyNewDec(9) + + mockChecker.On("CheckMinLiabilities", ctx, collateralTokenAmt, eta, types.Pool{}, ammtypes.Pool{}, msg.BorrowAsset).Return(nil) + + leveragedAmtTokenIn := sdk.NewCoin(msg.BorrowAsset, borrowingAmount) + custodyAmount := math.NewInt(199) + + mockChecker.On("EstimateSwap", ctx, leveragedAmtTokenIn, ptypes.BaseCurrency, ammtypes.Pool{}).Return(custodyAmount, nil) + mockChecker.On("HasSufficientPoolBalance", ctx, ammtypes.Pool{}, ptypes.BaseCurrency, custodyAmount).Return(true) + + mtp := types.NewMTP(msg.Creator, msg.CollateralAsset, ptypes.BaseCurrency, msg.Position, msg.Leverage, poolId) + + borrowError := errors.New("borrow error") + mockChecker.On("Borrow", ctx, msg.CollateralAsset, ptypes.BaseCurrency, msg.CollateralAmount, custodyAmount, mtp, &ammtypes.Pool{}, &types.Pool{}, eta).Return(borrowError) + + _, err := k.OpenShort(ctx, poolId, msg) + + // Expect the borrow error + assert.True(t, errors.Is(err, borrowError)) + mockChecker.AssertExpectations(t) +} + +func TestOpenShort_LeverageRatioLessThanSafetyFactor(t *testing.T) { + // Setup the mock checker + mockChecker := new(mocks.OpenShortChecker) + + // Create an instance of Keeper with the mock checker + k := keeper.Keeper{ + OpenShortChecker: mockChecker, + } + + var ( + ctx = sdk.Context{} // Mock or setup a context + msg = &types.MsgOpen{ + Leverage: math.LegacyNewDec(10), + CollateralAmount: math.NewInt(1000), + Creator: "", + CollateralAsset: ptypes.BaseCurrency, + BorrowAsset: "uatom", + Position: types.Position_SHORT, + } + poolId = uint64(42) + ) + + // Mock behaviors + mockChecker.On("GetMaxLeverageParam", ctx).Return(msg.Leverage) + mockChecker.On("GetTradingAsset", msg.CollateralAsset, msg.BorrowAsset).Return(msg.BorrowAsset) + mockChecker.On("GetPool", ctx, poolId).Return(types.Pool{}, true) + mockChecker.On("IsPoolEnabled", ctx, poolId).Return(true) + mockChecker.On("GetAmmPool", ctx, poolId, msg.BorrowAsset).Return(ammtypes.Pool{}, nil) + + collateralAmountDec := sdk.NewDecFromBigInt(msg.CollateralAmount.BigInt()) + leveragedAmount := sdk.NewInt(collateralAmountDec.Mul(msg.Leverage).TruncateInt().Int64()) + custodyAmtToken := sdk.NewCoin(ptypes.BaseCurrency, leveragedAmount) + borrowingAmount := sdk.NewInt(99) + + mockChecker.On("EstimateSwapGivenOut", ctx, custodyAmtToken, msg.BorrowAsset, ammtypes.Pool{}).Return(borrowingAmount, nil) + mockChecker.On("HasSufficientPoolBalance", ctx, ammtypes.Pool{}, ptypes.BaseCurrency, borrowingAmount).Return(true) + + collateralTokenAmt := sdk.NewCoin(msg.CollateralAsset, msg.CollateralAmount) + eta := math.LegacyNewDec(9) + + mockChecker.On("CheckMinLiabilities", ctx, collateralTokenAmt, eta, types.Pool{}, ammtypes.Pool{}, msg.BorrowAsset).Return(nil) + + leveragedAmtTokenIn := sdk.NewCoin(msg.BorrowAsset, borrowingAmount) + custodyAmount := math.NewInt(199) + + mockChecker.On("EstimateSwap", ctx, leveragedAmtTokenIn, ptypes.BaseCurrency, ammtypes.Pool{}).Return(custodyAmount, nil) + mockChecker.On("HasSufficientPoolBalance", ctx, ammtypes.Pool{}, ptypes.BaseCurrency, custodyAmount).Return(true) + + mtp := types.NewMTP(msg.Creator, msg.CollateralAsset, ptypes.BaseCurrency, msg.Position, msg.Leverage, poolId) + + mockChecker.On("Borrow", ctx, msg.CollateralAsset, ptypes.BaseCurrency, msg.CollateralAmount, custodyAmount, mtp, &ammtypes.Pool{}, &types.Pool{}, eta).Return(nil) + mockChecker.On("UpdatePoolHealth", ctx, &types.Pool{}).Return(nil) + mockChecker.On("TakeInCustody", ctx, *mtp, &types.Pool{}).Return(nil) + + lr := math.LegacyNewDec(50) + + mockChecker.On("UpdateMTPHealth", ctx, *mtp, ammtypes.Pool{}).Return(lr, nil) + mockChecker.On("GetSafetyFactor", ctx).Return(sdk.NewDec(100)) + + _, err := k.OpenShort(ctx, poolId, msg) + + // Expect an error indicating MTP is unhealthy + assert.True(t, errors.Is(err, types.ErrMTPUnhealthy)) + mockChecker.AssertExpectations(t) +} + +func TestOpenShort_Success(t *testing.T) { + // Setup the mock checker + mockChecker := new(mocks.OpenShortChecker) + + // Create an instance of Keeper with the mock checker + k := keeper.Keeper{ + OpenShortChecker: mockChecker, + } + + var ( + ctx = sdk.Context{} // Mock or setup a context + msg = &types.MsgOpen{ + Leverage: math.LegacyNewDec(10), + CollateralAmount: math.NewInt(1000), + Creator: "", + CollateralAsset: ptypes.BaseCurrency, + BorrowAsset: "uatom", + Position: types.Position_SHORT, + } + poolId = uint64(42) + ) + + // Mock behaviors + mockChecker.On("GetMaxLeverageParam", ctx).Return(msg.Leverage) + mockChecker.On("GetTradingAsset", msg.CollateralAsset, msg.BorrowAsset).Return(msg.BorrowAsset) + mockChecker.On("GetPool", ctx, poolId).Return(types.Pool{}, true) + mockChecker.On("IsPoolEnabled", ctx, poolId).Return(true) + mockChecker.On("GetAmmPool", ctx, poolId, msg.BorrowAsset).Return(ammtypes.Pool{}, nil) + + collateralAmountDec := sdk.NewDecFromBigInt(msg.CollateralAmount.BigInt()) + leveragedAmount := sdk.NewInt(collateralAmountDec.Mul(msg.Leverage).TruncateInt().Int64()) + custodyAmtToken := sdk.NewCoin(ptypes.BaseCurrency, leveragedAmount) + borrowingAmount := sdk.NewInt(99) + + mockChecker.On("EstimateSwapGivenOut", ctx, custodyAmtToken, msg.BorrowAsset, ammtypes.Pool{}).Return(borrowingAmount, nil) + mockChecker.On("HasSufficientPoolBalance", ctx, ammtypes.Pool{}, ptypes.BaseCurrency, borrowingAmount).Return(true) + + collateralTokenAmt := sdk.NewCoin(msg.CollateralAsset, msg.CollateralAmount) + eta := math.LegacyNewDec(9) + + mockChecker.On("CheckMinLiabilities", ctx, collateralTokenAmt, eta, types.Pool{}, ammtypes.Pool{}, msg.BorrowAsset).Return(nil) + + leveragedAmtTokenIn := sdk.NewCoin(msg.BorrowAsset, borrowingAmount) + custodyAmount := math.NewInt(199) + + mockChecker.On("EstimateSwap", ctx, leveragedAmtTokenIn, ptypes.BaseCurrency, ammtypes.Pool{}).Return(custodyAmount, nil) + mockChecker.On("HasSufficientPoolBalance", ctx, ammtypes.Pool{}, ptypes.BaseCurrency, custodyAmount).Return(true) + + mtp := types.NewMTP(msg.Creator, msg.CollateralAsset, ptypes.BaseCurrency, msg.Position, msg.Leverage, poolId) + + mockChecker.On("Borrow", ctx, msg.CollateralAsset, ptypes.BaseCurrency, msg.CollateralAmount, custodyAmount, mtp, &ammtypes.Pool{}, &types.Pool{}, eta).Return(nil) + mockChecker.On("UpdatePoolHealth", ctx, &types.Pool{}).Return(nil) + mockChecker.On("TakeInCustody", ctx, *mtp, &types.Pool{}).Return(nil) + + lr := math.LegacyNewDec(50) + + mockChecker.On("UpdateMTPHealth", ctx, *mtp, ammtypes.Pool{}).Return(lr, nil) + + safetyFactor := math.LegacyNewDec(10) + + mockChecker.On("GetSafetyFactor", ctx).Return(safetyFactor) + + mockChecker.On("CalcMTPConsolidateCollateral", ctx, mtp).Return(nil) + mockChecker.On("CalcMTPConsolidateLiability", ctx, mtp).Return() + mockChecker.On("SetMTP", ctx, mtp).Return(nil) + + _, err := k.OpenShort(ctx, poolId, msg) + // Expect no error + assert.Nil(t, err) + mockChecker.AssertExpectations(t) +} + +func TestOpenShort_BaseCurrency_Collateral(t *testing.T) { + app := simapp.InitElysTestApp(true) + ctx := app.BaseApp.NewContext(true, tmproto.Header{}) + + mk, amm, oracle := app.MarginKeeper, app.AmmKeeper, app.OracleKeeper + + // Setup coin prices + SetupStableCoinPrices(ctx, oracle) + + // Generate 1 random account with 1000stake balanced + addr := simapp.AddTestAddrs(app, ctx, 1, sdk.NewInt(1000000)) + + // Create a pool + // Mint 100000USDC + usdcToken := sdk.NewCoins(sdk.NewCoin(ptypes.BaseCurrency, sdk.NewInt(100000))) + // Mint 100000ATOM + atomToken := sdk.NewCoins(sdk.NewCoin(ptypes.ATOM, sdk.NewInt(100000))) + + err := app.BankKeeper.MintCoins(ctx, ammtypes.ModuleName, usdcToken) + require.NoError(t, err) + err = app.BankKeeper.SendCoinsFromModuleToAccount(ctx, ammtypes.ModuleName, addr[0], usdcToken) + require.NoError(t, err) + + err = app.BankKeeper.MintCoins(ctx, ammtypes.ModuleName, atomToken) + require.NoError(t, err) + err = app.BankKeeper.SendCoinsFromModuleToAccount(ctx, ammtypes.ModuleName, addr[0], atomToken) + require.NoError(t, err) + + poolAssets := []ammtypes.PoolAsset{ + { + Weight: sdk.NewInt(50), + Token: sdk.NewCoin(ptypes.ATOM, sdk.NewInt(100000)), + }, + { + Weight: sdk.NewInt(50), + Token: sdk.NewCoin(ptypes.BaseCurrency, sdk.NewInt(10000)), + }, + } + + argSwapFee := sdk.MustNewDecFromStr("0.0") + argExitFee := sdk.MustNewDecFromStr("0.0") + + poolParams := &ammtypes.PoolParams{ + SwapFee: argSwapFee, + ExitFee: argExitFee, + } + + msg := ammtypes.NewMsgCreatePool( + addr[0].String(), + poolParams, + poolAssets, + ) + + // Create a ATOM+USDC pool + poolId, err := amm.CreatePool(ctx, msg) + require.NoError(t, err) + require.Equal(t, poolId, uint64(0)) + + pools := amm.GetAllPool(ctx) + + // check length of pools + require.Equal(t, len(pools), 1) + + // check block height + require.Equal(t, int64(0), ctx.BlockHeight()) + + pool, found := amm.GetPool(ctx, poolId) + require.Equal(t, found, true) + + poolAddress := sdk.MustAccAddressFromBech32(pool.GetAddress()) + require.NoError(t, err) + + // Balance check before create a margin position + balances := app.BankKeeper.GetAllBalances(ctx, poolAddress) + require.Equal(t, balances.AmountOf(ptypes.BaseCurrency), sdk.NewInt(10000)) + require.Equal(t, balances.AmountOf(ptypes.ATOM), sdk.NewInt(100000)) + + // Create a margin position open msg + msg2 := types.NewMsgOpen( + addr[0].String(), + ptypes.BaseCurrency, + sdk.NewInt(100), + ptypes.ATOM, + types.Position_SHORT, + sdk.NewDec(5), + ) + + _, err = mk.Open(ctx, msg2) + require.NoError(t, err) + + mtps := mk.GetAllMTPs(ctx) + require.Equal(t, len(mtps), 1) + + balances = app.BankKeeper.GetAllBalances(ctx, poolAddress) + require.Equal(t, balances.AmountOf(ptypes.BaseCurrency), sdk.NewInt(10100)) + require.Equal(t, balances.AmountOf(ptypes.ATOM), sdk.NewInt(100000)) + + _, found = mk.OpenShortChecker.GetPool(ctx, pool.PoolId) + require.Equal(t, found, true) + + err = mk.InvariantCheck(ctx) + require.Equal(t, err, nil) +} + +func TestOpenShort_ATOM_Collateral(t *testing.T) { + app := simapp.InitElysTestApp(true) + ctx := app.BaseApp.NewContext(true, tmproto.Header{}) + + mk, amm, oracle := app.MarginKeeper, app.AmmKeeper, app.OracleKeeper + + // Setup coin prices + SetupStableCoinPrices(ctx, oracle) + + // Generate 1 random account with 1000stake balanced + addr := simapp.AddTestAddrs(app, ctx, 1, sdk.NewInt(1000000)) + + // Create a pool + // Mint 100000USDC + usdcToken := sdk.NewCoins(sdk.NewCoin(ptypes.BaseCurrency, sdk.NewInt(100000))) + // Mint 100000ATOM + atomToken := sdk.NewCoins(sdk.NewCoin(ptypes.ATOM, sdk.NewInt(100000))) + + err := app.BankKeeper.MintCoins(ctx, ammtypes.ModuleName, usdcToken) + require.NoError(t, err) + err = app.BankKeeper.SendCoinsFromModuleToAccount(ctx, ammtypes.ModuleName, addr[0], usdcToken) + require.NoError(t, err) + + err = app.BankKeeper.MintCoins(ctx, ammtypes.ModuleName, atomToken) + require.NoError(t, err) + err = app.BankKeeper.SendCoinsFromModuleToAccount(ctx, ammtypes.ModuleName, addr[0], atomToken) + require.NoError(t, err) + + poolAssets := []ammtypes.PoolAsset{ + { + Weight: sdk.NewInt(50), + Token: sdk.NewCoin(ptypes.ATOM, sdk.NewInt(1000)), + }, + { + Weight: sdk.NewInt(50), + Token: sdk.NewCoin(ptypes.BaseCurrency, sdk.NewInt(10000)), + }, + } + + argSwapFee := sdk.MustNewDecFromStr("0.0") + argExitFee := sdk.MustNewDecFromStr("0.0") + + poolParams := &ammtypes.PoolParams{ + SwapFee: argSwapFee, + ExitFee: argExitFee, + } + + msg := ammtypes.NewMsgCreatePool( + addr[0].String(), + poolParams, + poolAssets, + ) + + // Create a ATOM+USDC pool + poolId, err := amm.CreatePool(ctx, msg) + require.NoError(t, err) + require.Equal(t, poolId, uint64(0)) + + pools := amm.GetAllPool(ctx) + + // check length of pools + require.Equal(t, len(pools), 1) + + // check block height + require.Equal(t, int64(0), ctx.BlockHeight()) + + pool, found := amm.GetPool(ctx, poolId) + require.Equal(t, found, true) + + poolAddress := sdk.MustAccAddressFromBech32(pool.GetAddress()) + require.NoError(t, err) + + // Balance check before create a margin position + balances := app.BankKeeper.GetAllBalances(ctx, poolAddress) + require.Equal(t, balances.AmountOf(ptypes.BaseCurrency), sdk.NewInt(10000)) + require.Equal(t, balances.AmountOf(ptypes.ATOM), sdk.NewInt(1000)) + + // Create a margin position open msg + msg2 := types.NewMsgOpen( + addr[0].String(), + ptypes.ATOM, + sdk.NewInt(10), + ptypes.ATOM, + types.Position_SHORT, + sdk.NewDec(5), + ) + + _, err = mk.Open(ctx, msg2) + assert.True(t, errors.Is(err, sdkerrors.Wrap(types.ErrInvalidCollateralAsset, "collateral asset cannot be the same as the borrowed asset in a short position"))) + + mtps := mk.GetAllMTPs(ctx) + require.Equal(t, len(mtps), 0) + + balances = app.BankKeeper.GetAllBalances(ctx, poolAddress) + require.Equal(t, balances.AmountOf(ptypes.BaseCurrency), sdk.NewInt(10000)) + require.Equal(t, balances.AmountOf(ptypes.ATOM), sdk.NewInt(1000)) + + _, found = mk.OpenShortChecker.GetPool(ctx, pool.PoolId) + require.Equal(t, found, false) + + err = mk.InvariantCheck(ctx) + require.Equal(t, err, nil) +} diff --git a/x/margin/types/expected_keepers.go b/x/margin/types/expected_keepers.go index 54b7c9647..7f44bd7c1 100644 --- a/x/margin/types/expected_keepers.go +++ b/x/margin/types/expected_keepers.go @@ -90,6 +90,9 @@ type OpenShortChecker interface { GetAmmPoolBalance(ctx sdk.Context, ammPool ammtypes.Pool, assetDenom string) (sdk.Int, error) CheckShortAssets(ctx sdk.Context, collateralAsset string, borrowAsset string) error CheckSamePosition(ctx sdk.Context, msg *MsgOpen) *MTP + SetMTP(ctx sdk.Context, mtp *MTP) error + CalcMTPConsolidateCollateral(ctx sdk.Context, mtp *MTP) error + CalcMTPConsolidateLiability(ctx sdk.Context, mtp *MTP) } //go:generate mockery --srcpkg . --name CloseLongChecker --structname CloseLongChecker --filename close_long_checker.go --with-expecter diff --git a/x/margin/types/mocks/open_short_checker.go b/x/margin/types/mocks/open_short_checker.go index 2edbab27a..a0aaecd63 100644 --- a/x/margin/types/mocks/open_short_checker.go +++ b/x/margin/types/mocks/open_short_checker.go @@ -76,6 +76,83 @@ func (_c *OpenShortChecker_Borrow_Call) RunAndReturn(run func(types.Context, str return _c } +// CalcMTPConsolidateCollateral provides a mock function with given fields: ctx, mtp +func (_m *OpenShortChecker) CalcMTPConsolidateCollateral(ctx types.Context, mtp *margintypes.MTP) error { + ret := _m.Called(ctx, mtp) + + var r0 error + if rf, ok := ret.Get(0).(func(types.Context, *margintypes.MTP) error); ok { + r0 = rf(ctx, mtp) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// OpenShortChecker_CalcMTPConsolidateCollateral_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CalcMTPConsolidateCollateral' +type OpenShortChecker_CalcMTPConsolidateCollateral_Call struct { + *mock.Call +} + +// CalcMTPConsolidateCollateral is a helper method to define mock.On call +// - ctx types.Context +// - mtp *margintypes.MTP +func (_e *OpenShortChecker_Expecter) CalcMTPConsolidateCollateral(ctx interface{}, mtp interface{}) *OpenShortChecker_CalcMTPConsolidateCollateral_Call { + return &OpenShortChecker_CalcMTPConsolidateCollateral_Call{Call: _e.mock.On("CalcMTPConsolidateCollateral", ctx, mtp)} +} + +func (_c *OpenShortChecker_CalcMTPConsolidateCollateral_Call) Run(run func(ctx types.Context, mtp *margintypes.MTP)) *OpenShortChecker_CalcMTPConsolidateCollateral_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(types.Context), args[1].(*margintypes.MTP)) + }) + return _c +} + +func (_c *OpenShortChecker_CalcMTPConsolidateCollateral_Call) Return(_a0 error) *OpenShortChecker_CalcMTPConsolidateCollateral_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *OpenShortChecker_CalcMTPConsolidateCollateral_Call) RunAndReturn(run func(types.Context, *margintypes.MTP) error) *OpenShortChecker_CalcMTPConsolidateCollateral_Call { + _c.Call.Return(run) + return _c +} + +// CalcMTPConsolidateLiability provides a mock function with given fields: ctx, mtp +func (_m *OpenShortChecker) CalcMTPConsolidateLiability(ctx types.Context, mtp *margintypes.MTP) { + _m.Called(ctx, mtp) +} + +// OpenShortChecker_CalcMTPConsolidateLiability_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CalcMTPConsolidateLiability' +type OpenShortChecker_CalcMTPConsolidateLiability_Call struct { + *mock.Call +} + +// CalcMTPConsolidateLiability is a helper method to define mock.On call +// - ctx types.Context +// - mtp *margintypes.MTP +func (_e *OpenShortChecker_Expecter) CalcMTPConsolidateLiability(ctx interface{}, mtp interface{}) *OpenShortChecker_CalcMTPConsolidateLiability_Call { + return &OpenShortChecker_CalcMTPConsolidateLiability_Call{Call: _e.mock.On("CalcMTPConsolidateLiability", ctx, mtp)} +} + +func (_c *OpenShortChecker_CalcMTPConsolidateLiability_Call) Run(run func(ctx types.Context, mtp *margintypes.MTP)) *OpenShortChecker_CalcMTPConsolidateLiability_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(types.Context), args[1].(*margintypes.MTP)) + }) + return _c +} + +func (_c *OpenShortChecker_CalcMTPConsolidateLiability_Call) Return() *OpenShortChecker_CalcMTPConsolidateLiability_Call { + _c.Call.Return() + return _c +} + +func (_c *OpenShortChecker_CalcMTPConsolidateLiability_Call) RunAndReturn(run func(types.Context, *margintypes.MTP)) *OpenShortChecker_CalcMTPConsolidateLiability_Call { + _c.Call.Return(run) + return _c +} + // CheckMinLiabilities provides a mock function with given fields: ctx, collateralTokenAmt, eta, pool, ammPool, borrowAsset func (_m *OpenShortChecker) CheckMinLiabilities(ctx types.Context, collateralTokenAmt types.Coin, eta math.LegacyDec, pool margintypes.Pool, ammPool ammtypes.Pool, borrowAsset string) error { ret := _m.Called(ctx, collateralTokenAmt, eta, pool, ammPool, borrowAsset) @@ -698,6 +775,49 @@ func (_c *OpenShortChecker_IsPoolEnabled_Call) RunAndReturn(run func(types.Conte return _c } +// SetMTP provides a mock function with given fields: ctx, mtp +func (_m *OpenShortChecker) SetMTP(ctx types.Context, mtp *margintypes.MTP) error { + ret := _m.Called(ctx, mtp) + + var r0 error + if rf, ok := ret.Get(0).(func(types.Context, *margintypes.MTP) error); ok { + r0 = rf(ctx, mtp) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// OpenShortChecker_SetMTP_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetMTP' +type OpenShortChecker_SetMTP_Call struct { + *mock.Call +} + +// SetMTP is a helper method to define mock.On call +// - ctx types.Context +// - mtp *margintypes.MTP +func (_e *OpenShortChecker_Expecter) SetMTP(ctx interface{}, mtp interface{}) *OpenShortChecker_SetMTP_Call { + return &OpenShortChecker_SetMTP_Call{Call: _e.mock.On("SetMTP", ctx, mtp)} +} + +func (_c *OpenShortChecker_SetMTP_Call) Run(run func(ctx types.Context, mtp *margintypes.MTP)) *OpenShortChecker_SetMTP_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(types.Context), args[1].(*margintypes.MTP)) + }) + return _c +} + +func (_c *OpenShortChecker_SetMTP_Call) Return(_a0 error) *OpenShortChecker_SetMTP_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *OpenShortChecker_SetMTP_Call) RunAndReturn(run func(types.Context, *margintypes.MTP) error) *OpenShortChecker_SetMTP_Call { + _c.Call.Return(run) + return _c +} + // SetPool provides a mock function with given fields: ctx, pool func (_m *OpenShortChecker) SetPool(ctx types.Context, pool margintypes.Pool) { _m.Called(ctx, pool)