From e1bb0cc9bc770b13aa204506ce11ed57b0a96f6f Mon Sep 17 00:00:00 2001 From: kenta-elys Date: Thu, 14 Sep 2023 17:00:41 +0000 Subject: [PATCH] fix: unit margin unit test failure --- x/margin/keeper/open.go | 4 +- x/margin/keeper/open_test.go | 13 ++ x/margin/types/expected_keepers.go | 4 + x/margin/types/mocks/open_checker.go | 172 +++++++++++++++++++++++++++ 4 files changed, 191 insertions(+), 2 deletions(-) diff --git a/x/margin/keeper/open.go b/x/margin/keeper/open.go index 04b6031ef..9354ab129 100644 --- a/x/margin/keeper/open.go +++ b/x/margin/keeper/open.go @@ -16,11 +16,11 @@ func (k Keeper) Open(ctx sdk.Context, msg *types.MsgOpen) (*types.MsgOpenRespons } // Check if it is the same direction position for the same trader. - if mtp := k.CheckSamePosition(ctx, msg); mtp != nil { + if mtp := k.OpenChecker.CheckSamePosition(ctx, msg); mtp != nil { return k.OpenConsolidate(ctx, mtp, msg) } - if err := k.CheckMaxOpenPositions(ctx); err != nil { + if err := k.OpenChecker.CheckMaxOpenPositions(ctx); err != nil { return nil, err } diff --git a/x/margin/keeper/open_test.go b/x/margin/keeper/open_test.go index cbbb6a69f..5f60d32ff 100644 --- a/x/margin/keeper/open_test.go +++ b/x/margin/keeper/open_test.go @@ -78,14 +78,18 @@ func TestOpen_ErrorCheckMaxOpenPositions(t *testing.T) { var ( ctx = sdk.Context{} // Mock or setup a context msg = &types.MsgOpen{ + Creator: "creator", CollateralAsset: "aaa", BorrowAsset: "bbb", + Position: types.Position_LONG, + Leverage: sdk.NewDec(10), } ) // Mock behavior mockChecker.On("CheckLongingAssets", ctx, msg.CollateralAsset, msg.BorrowAsset).Return(nil) mockChecker.On("CheckUserAuthorization", ctx, msg).Return(nil) + mockChecker.On("CheckSamePosition", ctx, msg).Return(nil) mockChecker.On("CheckMaxOpenPositions", ctx).Return(sdkerrors.Wrap(types.ErrMaxOpenPositions, "cannot open new positions")) _, err := k.Open(ctx, msg) @@ -106,14 +110,18 @@ func TestOpen_ErrorPreparePools(t *testing.T) { var ( ctx = sdk.Context{} // Mock or setup a context msg = &types.MsgOpen{ + Creator: "creator", CollateralAsset: "aaa", BorrowAsset: "bbb", + Position: types.Position_LONG, + Leverage: sdk.NewDec(10), } ) // Mock behavior mockChecker.On("CheckLongingAssets", ctx, msg.CollateralAsset, msg.BorrowAsset).Return(nil) mockChecker.On("CheckUserAuthorization", ctx, msg).Return(nil) + mockChecker.On("CheckSamePosition", ctx, msg).Return(nil) mockChecker.On("CheckMaxOpenPositions", ctx).Return(nil) mockChecker.On("GetTradingAsset", msg.CollateralAsset, msg.BorrowAsset).Return(msg.BorrowAsset) mockChecker.On("PreparePools", ctx, msg.BorrowAsset).Return(uint64(0), ammtypes.Pool{}, types.Pool{}, errors.New("error executing prepare pools")) @@ -145,6 +153,7 @@ func TestOpen_ErrorCheckPoolHealth(t *testing.T) { // Mock behavior mockChecker.On("CheckLongingAssets", ctx, msg.CollateralAsset, msg.BorrowAsset).Return(nil) mockChecker.On("CheckUserAuthorization", ctx, msg).Return(nil) + mockChecker.On("CheckSamePosition", ctx, msg).Return(nil) mockChecker.On("CheckMaxOpenPositions", ctx).Return(nil) mockChecker.On("GetTradingAsset", msg.CollateralAsset, msg.BorrowAsset).Return(msg.BorrowAsset) mockChecker.On("PreparePools", ctx, msg.BorrowAsset).Return(poolId, ammtypes.Pool{}, types.Pool{}, nil) @@ -177,6 +186,7 @@ func TestOpen_ErrorInvalidPosition(t *testing.T) { // Mock behavior mockChecker.On("CheckLongingAssets", ctx, msg.CollateralAsset, msg.BorrowAsset).Return(nil) mockChecker.On("CheckUserAuthorization", ctx, msg).Return(nil) + mockChecker.On("CheckSamePosition", ctx, msg).Return(nil) mockChecker.On("CheckMaxOpenPositions", ctx).Return(nil) mockChecker.On("GetTradingAsset", msg.CollateralAsset, msg.BorrowAsset).Return(msg.BorrowAsset) mockChecker.On("PreparePools", ctx, msg.BorrowAsset).Return(poolId, ammtypes.Pool{}, types.Pool{}, nil) @@ -210,6 +220,7 @@ func TestOpen_ErrorOpenLong(t *testing.T) { // Mock behavior mockChecker.On("CheckLongingAssets", ctx, msg.CollateralAsset, msg.BorrowAsset).Return(nil) mockChecker.On("CheckUserAuthorization", ctx, msg).Return(nil) + mockChecker.On("CheckSamePosition", ctx, msg).Return(nil) mockChecker.On("CheckMaxOpenPositions", ctx).Return(nil) mockChecker.On("GetTradingAsset", msg.CollateralAsset, msg.BorrowAsset).Return(msg.BorrowAsset) mockChecker.On("PreparePools", ctx, msg.BorrowAsset).Return(poolId, ammtypes.Pool{}, types.Pool{}, nil) @@ -244,6 +255,7 @@ func TestOpen_ErrorOpenShort(t *testing.T) { // Mock behavior mockChecker.On("CheckLongingAssets", ctx, msg.CollateralAsset, msg.BorrowAsset).Return(nil) mockChecker.On("CheckUserAuthorization", ctx, msg).Return(nil) + mockChecker.On("CheckSamePosition", ctx, msg).Return(nil) mockChecker.On("CheckMaxOpenPositions", ctx).Return(nil) mockChecker.On("GetTradingAsset", msg.CollateralAsset, msg.BorrowAsset).Return(msg.BorrowAsset) mockChecker.On("PreparePools", ctx, msg.BorrowAsset).Return(poolId, ammtypes.Pool{}, types.Pool{}, nil) @@ -279,6 +291,7 @@ func TestOpen_Successful(t *testing.T) { // Mock behavior mockChecker.On("CheckLongingAssets", ctx, msg.CollateralAsset, msg.BorrowAsset).Return(nil) mockChecker.On("CheckUserAuthorization", ctx, msg).Return(nil) + mockChecker.On("CheckSamePosition", ctx, msg).Return(nil) mockChecker.On("CheckMaxOpenPositions", ctx).Return(nil) mockChecker.On("GetTradingAsset", msg.CollateralAsset, msg.BorrowAsset).Return(msg.BorrowAsset) mockChecker.On("PreparePools", ctx, msg.BorrowAsset).Return(poolId, ammtypes.Pool{}, types.Pool{}, nil) diff --git a/x/margin/types/expected_keepers.go b/x/margin/types/expected_keepers.go index b27d21583..93907e732 100644 --- a/x/margin/types/expected_keepers.go +++ b/x/margin/types/expected_keepers.go @@ -38,6 +38,10 @@ type OpenChecker interface { OpenLong(ctx sdk.Context, poolId uint64, msg *MsgOpen) (*MTP, error) OpenShort(ctx sdk.Context, poolId uint64, msg *MsgOpen) (*MTP, error) EmitOpenEvent(ctx sdk.Context, mtp *MTP) + SetMTP(ctx sdk.Context, mtp *MTP) error + CheckSamePosition(ctx sdk.Context, msg *MsgOpen) *MTP + GetOpenMTPCount(ctx sdk.Context) uint64 + GetMaxOpenPositions(ctx sdk.Context) uint64 } //go:generate mockery --srcpkg . --name OpenLongChecker --structname OpenLongChecker --filename open_long_checker.go --with-expecter diff --git a/x/margin/types/mocks/open_checker.go b/x/margin/types/mocks/open_checker.go index 9b7b0d9ef..472802669 100644 --- a/x/margin/types/mocks/open_checker.go +++ b/x/margin/types/mocks/open_checker.go @@ -153,6 +153,51 @@ func (_c *OpenChecker_CheckPoolHealth_Call) RunAndReturn(run func(types.Context, return _c } +// CheckSamePosition provides a mock function with given fields: ctx, msg +func (_m *OpenChecker) CheckSamePosition(ctx types.Context, msg *margintypes.MsgOpen) *margintypes.MTP { + ret := _m.Called(ctx, msg) + + var r0 *margintypes.MTP + if rf, ok := ret.Get(0).(func(types.Context, *margintypes.MsgOpen) *margintypes.MTP); ok { + r0 = rf(ctx, msg) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*margintypes.MTP) + } + } + + return r0 +} + +// OpenChecker_CheckSamePosition_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CheckSamePosition' +type OpenChecker_CheckSamePosition_Call struct { + *mock.Call +} + +// CheckSamePosition is a helper method to define mock.On call +// - ctx types.Context +// - msg *margintypes.MsgOpen +func (_e *OpenChecker_Expecter) CheckSamePosition(ctx interface{}, msg interface{}) *OpenChecker_CheckSamePosition_Call { + return &OpenChecker_CheckSamePosition_Call{Call: _e.mock.On("CheckSamePosition", ctx, msg)} +} + +func (_c *OpenChecker_CheckSamePosition_Call) Run(run func(ctx types.Context, msg *margintypes.MsgOpen)) *OpenChecker_CheckSamePosition_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(types.Context), args[1].(*margintypes.MsgOpen)) + }) + return _c +} + +func (_c *OpenChecker_CheckSamePosition_Call) Return(_a0 *margintypes.MTP) *OpenChecker_CheckSamePosition_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *OpenChecker_CheckSamePosition_Call) RunAndReturn(run func(types.Context, *margintypes.MsgOpen) *margintypes.MTP) *OpenChecker_CheckSamePosition_Call { + _c.Call.Return(run) + return _c +} + // CheckUserAuthorization provides a mock function with given fields: ctx, msg func (_m *OpenChecker) CheckUserAuthorization(ctx types.Context, msg *margintypes.MsgOpen) error { ret := _m.Called(ctx, msg) @@ -230,6 +275,90 @@ func (_c *OpenChecker_EmitOpenEvent_Call) RunAndReturn(run func(types.Context, * return _c } +// GetMaxOpenPositions provides a mock function with given fields: ctx +func (_m *OpenChecker) GetMaxOpenPositions(ctx types.Context) uint64 { + ret := _m.Called(ctx) + + var r0 uint64 + if rf, ok := ret.Get(0).(func(types.Context) uint64); ok { + r0 = rf(ctx) + } else { + r0 = ret.Get(0).(uint64) + } + + return r0 +} + +// OpenChecker_GetMaxOpenPositions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetMaxOpenPositions' +type OpenChecker_GetMaxOpenPositions_Call struct { + *mock.Call +} + +// GetMaxOpenPositions is a helper method to define mock.On call +// - ctx types.Context +func (_e *OpenChecker_Expecter) GetMaxOpenPositions(ctx interface{}) *OpenChecker_GetMaxOpenPositions_Call { + return &OpenChecker_GetMaxOpenPositions_Call{Call: _e.mock.On("GetMaxOpenPositions", ctx)} +} + +func (_c *OpenChecker_GetMaxOpenPositions_Call) Run(run func(ctx types.Context)) *OpenChecker_GetMaxOpenPositions_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(types.Context)) + }) + return _c +} + +func (_c *OpenChecker_GetMaxOpenPositions_Call) Return(_a0 uint64) *OpenChecker_GetMaxOpenPositions_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *OpenChecker_GetMaxOpenPositions_Call) RunAndReturn(run func(types.Context) uint64) *OpenChecker_GetMaxOpenPositions_Call { + _c.Call.Return(run) + return _c +} + +// GetOpenMTPCount provides a mock function with given fields: ctx +func (_m *OpenChecker) GetOpenMTPCount(ctx types.Context) uint64 { + ret := _m.Called(ctx) + + var r0 uint64 + if rf, ok := ret.Get(0).(func(types.Context) uint64); ok { + r0 = rf(ctx) + } else { + r0 = ret.Get(0).(uint64) + } + + return r0 +} + +// OpenChecker_GetOpenMTPCount_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetOpenMTPCount' +type OpenChecker_GetOpenMTPCount_Call struct { + *mock.Call +} + +// GetOpenMTPCount is a helper method to define mock.On call +// - ctx types.Context +func (_e *OpenChecker_Expecter) GetOpenMTPCount(ctx interface{}) *OpenChecker_GetOpenMTPCount_Call { + return &OpenChecker_GetOpenMTPCount_Call{Call: _e.mock.On("GetOpenMTPCount", ctx)} +} + +func (_c *OpenChecker_GetOpenMTPCount_Call) Run(run func(ctx types.Context)) *OpenChecker_GetOpenMTPCount_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(types.Context)) + }) + return _c +} + +func (_c *OpenChecker_GetOpenMTPCount_Call) Return(_a0 uint64) *OpenChecker_GetOpenMTPCount_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *OpenChecker_GetOpenMTPCount_Call) RunAndReturn(run func(types.Context) uint64) *OpenChecker_GetOpenMTPCount_Call { + _c.Call.Return(run) + return _c +} + // GetTradingAsset provides a mock function with given fields: collateralAsset, borrowAsset func (_m *OpenChecker) GetTradingAsset(collateralAsset string, borrowAsset string) string { ret := _m.Called(collateralAsset, borrowAsset) @@ -452,6 +581,49 @@ func (_c *OpenChecker_PreparePools_Call) RunAndReturn(run func(types.Context, st return _c } +// SetMTP provides a mock function with given fields: ctx, mtp +func (_m *OpenChecker) SetMTP(ctx types.Context, mtp *margintypes.MTP) error { + ret := _m.Called(ctx, mtp) + + var r0 error + if rf, ok := ret.Get(0).(func(types.Context, *margintypes.MTP) error); ok { + r0 = rf(ctx, mtp) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// OpenChecker_SetMTP_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetMTP' +type OpenChecker_SetMTP_Call struct { + *mock.Call +} + +// SetMTP is a helper method to define mock.On call +// - ctx types.Context +// - mtp *margintypes.MTP +func (_e *OpenChecker_Expecter) SetMTP(ctx interface{}, mtp interface{}) *OpenChecker_SetMTP_Call { + return &OpenChecker_SetMTP_Call{Call: _e.mock.On("SetMTP", ctx, mtp)} +} + +func (_c *OpenChecker_SetMTP_Call) Run(run func(ctx types.Context, mtp *margintypes.MTP)) *OpenChecker_SetMTP_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(types.Context), args[1].(*margintypes.MTP)) + }) + return _c +} + +func (_c *OpenChecker_SetMTP_Call) Return(_a0 error) *OpenChecker_SetMTP_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *OpenChecker_SetMTP_Call) RunAndReturn(run func(types.Context, *margintypes.MTP) error) *OpenChecker_SetMTP_Call { + _c.Call.Return(run) + return _c +} + // NewOpenChecker creates a new instance of OpenChecker. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewOpenChecker(t interface {