diff --git a/x/evm/keeper/context.go b/x/evm/keeper/context.go index 870cad4..927bcad 100644 --- a/x/evm/keeper/context.go +++ b/x/evm/keeper/context.go @@ -161,6 +161,12 @@ func (k Keeper) CreateEVM(ctx context.Context, caller common.Address, tracer *tr return ctx, nil, err } + // prepare SDK context for EVM execution + ctx, err = prepareSDKContext(sdk.UnwrapSDKContext(ctx)) + if err != nil { + return ctx, nil, err + } + evm := &vm.EVM{} blockContext, err := k.buildBlockContext(ctx, evm, fee) if err != nil { @@ -181,19 +187,13 @@ func (k Keeper) CreateEVM(ctx context.Context, caller common.Address, tracer *tr NumRetainBlockHashes: ¶ms.NumRetainBlockHashes, } - // prepare SDK context for EVM execution - ctx, err = prepareSDKContext(sdk.UnwrapSDKContext(ctx)) - if err != nil { - return ctx, nil, err - } - *evm = *vm.NewEVMWithPrecompiles( blockContext, txContext, stateDB, types.DefaultChainConfig(ctx), vmConfig, - k.precompiles.toMap(ctx), + k.precompiles.toMap(stateDB), ) if tracer != nil { diff --git a/x/evm/keeper/precompiles.go b/x/evm/keeper/precompiles.go index 5f7a39c..e290b68 100644 --- a/x/evm/keeper/precompiles.go +++ b/x/evm/keeper/precompiles.go @@ -1,8 +1,6 @@ package keeper import ( - "context" - "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/vm" @@ -55,10 +53,10 @@ func (k *Keeper) loadPrecompiles() error { type precompiles []precompile // toMap converts the precompiles to a map. -func (ps precompiles) toMap(ctx context.Context) map[common.Address]vm.PrecompiledContract { +func (ps precompiles) toMap(stateDB types.StateDB) map[common.Address]vm.PrecompiledContract { m := make(map[common.Address]vm.PrecompiledContract) for _, p := range ps { - m[p.addr] = p.contract.(types.WithContext).WithContext(ctx) + m[p.addr] = p.contract.(types.WithStateDB).WithStateDB(stateDB) } return m diff --git a/x/evm/precompiles/cosmos/common_test.go b/x/evm/precompiles/cosmos/common_test.go new file mode 100644 index 0000000..f52d9e2 --- /dev/null +++ b/x/evm/precompiles/cosmos/common_test.go @@ -0,0 +1,352 @@ +package cosmosprecompile_test + +import ( + "context" + + "cosmossdk.io/core/address" + "github.com/cosmos/cosmos-sdk/baseapp" + sdk "github.com/cosmos/cosmos-sdk/types" + sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" + authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/stateless" + "github.com/ethereum/go-ethereum/core/tracing" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/params" + "github.com/ethereum/go-ethereum/trie/utils" + "github.com/holiman/uint256" + + "github.com/initia-labs/minievm/x/evm/state" + evmtypes "github.com/initia-labs/minievm/x/evm/types" +) + +var _ evmtypes.StateDB = &MockStateDB{} + +type MockStateDB struct { + ctx sdk.Context + initialCtx sdk.Context + + // Snapshot stack + snaps []*state.Snapshot +} + +func NewMockStateDB(ctx sdk.Context) *MockStateDB { + return &MockStateDB{ + ctx: ctx, + initialCtx: ctx, + } +} + +// Snapshot implements types.StateDB. +func (m *MockStateDB) Snapshot() int { + // get a current snapshot id + sid := len(m.snaps) - 1 + + // create a new snapshot + snap := state.NewSnapshot(m.ctx) + m.snaps = append(m.snaps, snap) + + // use the new snapshot context + m.ctx = snap.Context() + + // return the current snapshot id + return sid +} + +// RevertToSnapshot implements types.StateDB. +func (m *MockStateDB) RevertToSnapshot(i int) { + if i == -1 { + m.ctx = m.initialCtx + m.snaps = m.snaps[:0] + return + } + + // revert to the snapshot with the given id + snap := m.snaps[i] + m.ctx = snap.Context() + + // clear the snapshots after the given id + m.snaps = m.snaps[:i] +} + +// ContextOfSnapshot implements types.StateDB. +func (m *MockStateDB) ContextOfSnapshot(i int) sdk.Context { + if i == -1 { + return m.initialCtx + } + + return m.snaps[i].Context() +} + +//////////////////////// MOCKED METHODS //////////////////////// + +// AddAddressToAccessList implements types.StateDB. +func (m *MockStateDB) AddAddressToAccessList(addr common.Address) { + panic("unimplemented") +} + +// AddBalance implements types.StateDB. +func (m *MockStateDB) AddBalance(common.Address, *uint256.Int, tracing.BalanceChangeReason) { + panic("unimplemented") +} + +// AddLog implements types.StateDB. +func (m *MockStateDB) AddLog(*types.Log) { + panic("unimplemented") +} + +// AddPreimage implements types.StateDB. +func (m *MockStateDB) AddPreimage(common.Hash, []byte) { + panic("unimplemented") +} + +// AddRefund implements types.StateDB. +func (m *MockStateDB) AddRefund(uint64) { + panic("unimplemented") +} + +// AddSlotToAccessList implements types.StateDB. +func (m *MockStateDB) AddSlotToAccessList(addr common.Address, slot common.Hash) { + panic("unimplemented") +} + +// AddressInAccessList implements types.StateDB. +func (m *MockStateDB) AddressInAccessList(addr common.Address) bool { + panic("unimplemented") +} + +// CreateAccount implements types.StateDB. +func (m *MockStateDB) CreateAccount(common.Address) { + panic("unimplemented") +} + +// CreateContract implements types.StateDB. +func (m *MockStateDB) CreateContract(common.Address) { + panic("unimplemented") +} + +// Empty implements types.StateDB. +func (m *MockStateDB) Empty(common.Address) bool { + panic("unimplemented") +} + +// Exist implements types.StateDB. +func (m *MockStateDB) Exist(common.Address) bool { + panic("unimplemented") +} + +// GetBalance implements types.StateDB. +func (m *MockStateDB) GetBalance(common.Address) *uint256.Int { + panic("unimplemented") +} + +// GetCode implements types.StateDB. +func (m *MockStateDB) GetCode(common.Address) []byte { + panic("unimplemented") +} + +// GetCodeHash implements types.StateDB. +func (m *MockStateDB) GetCodeHash(common.Address) common.Hash { + panic("unimplemented") +} + +// GetCodeSize implements types.StateDB. +func (m *MockStateDB) GetCodeSize(common.Address) int { + panic("unimplemented") +} + +// GetCommittedState implements types.StateDB. +func (m *MockStateDB) GetCommittedState(common.Address, common.Hash) common.Hash { + panic("unimplemented") +} + +// GetNonce implements types.StateDB. +func (m *MockStateDB) GetNonce(common.Address) uint64 { + panic("unimplemented") +} + +// GetRefund implements types.StateDB. +func (m *MockStateDB) GetRefund() uint64 { + panic("unimplemented") +} + +// GetState implements types.StateDB. +func (m *MockStateDB) GetState(common.Address, common.Hash) common.Hash { + panic("unimplemented") +} + +// GetStorageRoot implements types.StateDB. +func (m *MockStateDB) GetStorageRoot(addr common.Address) common.Hash { + panic("unimplemented") +} + +// GetTransientState implements types.StateDB. +func (m *MockStateDB) GetTransientState(addr common.Address, key common.Hash) common.Hash { + panic("unimplemented") +} + +// HasSelfDestructed implements types.StateDB. +func (m *MockStateDB) HasSelfDestructed(common.Address) bool { + panic("unimplemented") +} + +// PointCache implements types.StateDB. +func (m *MockStateDB) PointCache() *utils.PointCache { + panic("unimplemented") +} + +// Prepare implements types.StateDB. +func (m *MockStateDB) Prepare(rules params.Rules, sender common.Address, coinbase common.Address, dest *common.Address, precompiles []common.Address, txAccesses types.AccessList) { + panic("unimplemented") +} + +// SelfDestruct implements types.StateDB. +func (m *MockStateDB) SelfDestruct(common.Address) { + panic("unimplemented") +} + +// Selfdestruct6780 implements types.StateDB. +func (m *MockStateDB) Selfdestruct6780(common.Address) { + panic("unimplemented") +} + +// SetCode implements types.StateDB. +func (m *MockStateDB) SetCode(common.Address, []byte) { + panic("unimplemented") +} + +// SetNonce implements types.StateDB. +func (m *MockStateDB) SetNonce(common.Address, uint64) { + panic("unimplemented") +} + +// SetState implements types.StateDB. +func (m *MockStateDB) SetState(common.Address, common.Hash, common.Hash) { + panic("unimplemented") +} + +// SetTransientState implements types.StateDB. +func (m *MockStateDB) SetTransientState(addr common.Address, key common.Hash, value common.Hash) { + panic("unimplemented") +} + +// SlotInAccessList implements types.StateDB. +func (m *MockStateDB) SlotInAccessList(addr common.Address, slot common.Hash) (addressOk bool, slotOk bool) { + panic("unimplemented") +} + +// SubBalance implements types.StateDB. +func (m *MockStateDB) SubBalance(common.Address, *uint256.Int, tracing.BalanceChangeReason) { + panic("unimplemented") +} + +// SubRefund implements types.StateDB. +func (m *MockStateDB) SubRefund(uint64) { + panic("unimplemented") +} + +// Witness implements types.StateDB. +func (m *MockStateDB) Witness() *stateless.Witness { + panic("unimplemented") +} + +var _ evmtypes.AccountKeeper = &MockAccountKeeper{} + +// mock account keeper for testing +type MockAccountKeeper struct { + ac address.Codec + accounts map[string]sdk.AccountI +} + +// GetAccount implements types.AccountKeeper. +func (k MockAccountKeeper) GetAccount(ctx context.Context, addr sdk.AccAddress) sdk.AccountI { + str, _ := k.ac.BytesToString(addr.Bytes()) + return k.accounts[str] +} + +// HasAccount implements types.AccountKeeper. +func (k MockAccountKeeper) HasAccount(ctx context.Context, addr sdk.AccAddress) bool { + str, _ := k.ac.BytesToString(addr.Bytes()) + _, ok := k.accounts[str] + return ok +} + +// NewAccount implements types.AccountKeeper. +func (k *MockAccountKeeper) NewAccount(ctx context.Context, acc sdk.AccountI) sdk.AccountI { + acc.SetAccountNumber(uint64(len(k.accounts))) + return acc +} + +// NewAccountWithAddress implements types.AccountKeeper. +func (k MockAccountKeeper) NewAccountWithAddress(ctx context.Context, addr sdk.AccAddress) sdk.AccountI { + return authtypes.NewBaseAccount(addr, nil, uint64(len(k.accounts)), 0) +} + +// NextAccountNumber implements types.AccountKeeper. +func (k MockAccountKeeper) NextAccountNumber(ctx context.Context) uint64 { + return uint64(len(k.accounts)) +} + +// SetAccount implements types.AccountKeeper. +func (k MockAccountKeeper) SetAccount(ctx context.Context, acc sdk.AccountI) { + str, _ := k.ac.BytesToString(acc.GetAddress().Bytes()) + k.accounts[str] = acc +} + +// RemoveAccount implements types.AccountKeeper. +func (k MockAccountKeeper) RemoveAccount(ctx context.Context, acc sdk.AccountI) { + str, _ := k.ac.BytesToString(acc.GetAddress().Bytes()) + delete(k.accounts, str) +} + +var _ evmtypes.BankKeeper = &MockBankKeeper{} + +// mock bank keeper for testing +type MockBankKeeper struct { + ac address.Codec + blockedAddresses map[string]bool +} + +// BlockedAddr implements types.BankKeeper. +func (k MockBankKeeper) BlockedAddr(addr sdk.AccAddress) bool { + str, _ := k.ac.BytesToString(addr.Bytes()) + return k.blockedAddresses[str] +} + +var _ evmtypes.GRPCRouter = MockGRPCRouter{} + +type MockGRPCRouter struct { + routes map[string]baseapp.GRPCQueryHandler +} + +func (router MockGRPCRouter) Route(path string) baseapp.GRPCQueryHandler { + return router.routes[path] +} + +var _ evmtypes.ERC20DenomKeeper = &MockERC20DenomKeeper{} + +type MockERC20DenomKeeper struct { + denomMap map[string]common.Address + addrMap map[common.Address]string +} + +// GetContractAddrByDenom implements types.ERC20DenomKeeper. +func (e *MockERC20DenomKeeper) GetContractAddrByDenom(_ context.Context, denom string) (common.Address, error) { + addr, found := e.denomMap[denom] + if !found { + return common.Address{}, sdkerrors.ErrNotFound + } + + return addr, nil +} + +// GetDenomByContractAddr implements types.ERC20DenomKeeper. +func (e *MockERC20DenomKeeper) GetDenomByContractAddr(_ context.Context, addr common.Address) (string, error) { + denom, found := e.addrMap[addr] + if !found { + return "", sdkerrors.ErrNotFound + } + + return denom, nil +} diff --git a/x/evm/precompiles/cosmos/contract.go b/x/evm/precompiles/cosmos/contract.go index 78faf3a..33e42e2 100644 --- a/x/evm/precompiles/cosmos/contract.go +++ b/x/evm/precompiles/cosmos/contract.go @@ -24,14 +24,14 @@ import ( var _ vm.ExtendedPrecompiledContract = CosmosPrecompile{} var _ vm.PrecompiledContract = CosmosPrecompile{} -var _ types.WithContext = CosmosPrecompile{} +var _ types.WithStateDB = CosmosPrecompile{} type CosmosPrecompile struct { *abi.ABI - ctx context.Context - cdc codec.Codec - ac address.Codec + stateDB types.StateDB + cdc codec.Codec + ac address.Codec ak types.AccountKeeper bk types.BankKeeper @@ -67,8 +67,8 @@ func NewCosmosPrecompile( }, nil } -func (e CosmosPrecompile) WithContext(ctx context.Context) vm.PrecompiledContract { - e.ctx = ctx +func (e CosmosPrecompile) WithStateDB(stateDB types.StateDB) vm.PrecompiledContract { + e.stateDB = stateDB return e } @@ -88,6 +88,9 @@ func (e CosmosPrecompile) originAddress(ctx context.Context, addrBz []byte) (sdk // ExtendedRun implements vm.ExtendedPrecompiledContract. func (e CosmosPrecompile) ExtendedRun(caller vm.ContractRef, input []byte, suppliedGas uint64, readOnly bool) (resBz []byte, usedGas uint64, err error) { + snapshot := e.stateDB.Snapshot() + ctx := e.stateDB.ContextOfSnapshot(snapshot).WithGasMeter(storetypes.NewGasMeter(suppliedGas)) + defer func() { if r := recover(); r != nil { switch r.(type) { @@ -99,6 +102,10 @@ func (e CosmosPrecompile) ExtendedRun(caller vm.ContractRef, input []byte, suppl panic(r) } } + + if err != nil { + e.stateDB.RevertToSnapshot(snapshot) + } }() method, err := e.ABI.MethodById(input) @@ -111,8 +118,6 @@ func (e CosmosPrecompile) ExtendedRun(caller vm.ContractRef, input []byte, suppl return nil, 0, types.ErrPrecompileFailed.Wrap(err.Error()) } - ctx := sdk.UnwrapSDKContext(e.ctx).WithGasMeter(storetypes.NewGasMeter(suppliedGas)) - // charge input gas ctx.GasMeter().ConsumeGas(storetypes.Gas(len(input))*GAS_PER_BYTE, "input bytes") diff --git a/x/evm/precompiles/cosmos/contract_test.go b/x/evm/precompiles/cosmos/contract_test.go index e319d5b..07ad717 100644 --- a/x/evm/precompiles/cosmos/contract_test.go +++ b/x/evm/precompiles/cosmos/contract_test.go @@ -1,7 +1,6 @@ package cosmosprecompile_test import ( - "context" "fmt" "testing" "time" @@ -67,7 +66,8 @@ func Test_CosmosPrecompile_IsBlockedAddress(t *testing.T) { cosmosPrecompile, err := precompiles.NewCosmosPrecompile(cdc, ac, ak, bk, nil, nil, nil) require.NoError(t, err) - cosmosPrecompile = cosmosPrecompile.WithContext(ctx).(precompiles.CosmosPrecompile) + stateDB := NewMockStateDB(ctx) + cosmosPrecompile = cosmosPrecompile.WithStateDB(stateDB).(precompiles.CosmosPrecompile) evmAddr := common.HexToAddress("0x1") cosmosAddr, err := ac.BytesToString(evmAddr.Bytes()) @@ -112,7 +112,8 @@ func Test_CosmosPrecompile_IsModuleAddress(t *testing.T) { cosmosPrecompile, err := precompiles.NewCosmosPrecompile(cdc, ac, ak, bk, nil, nil, nil) require.NoError(t, err) - cosmosPrecompile = cosmosPrecompile.WithContext(ctx).(precompiles.CosmosPrecompile) + stateDB := NewMockStateDB(ctx) + cosmosPrecompile = cosmosPrecompile.WithStateDB(stateDB).(precompiles.CosmosPrecompile) evmAddr := common.HexToAddress("0x1") cosmosAddr, err := ac.BytesToString(evmAddr.Bytes()) @@ -157,7 +158,8 @@ func Test_CosmosPrecompile_ToCosmosAddress(t *testing.T) { cosmosPrecompile, err := precompiles.NewCosmosPrecompile(cdc, ac, ak, bk, nil, nil, nil) require.NoError(t, err) - cosmosPrecompile = cosmosPrecompile.WithContext(ctx).(precompiles.CosmosPrecompile) + stateDB := NewMockStateDB(ctx) + cosmosPrecompile = cosmosPrecompile.WithStateDB(stateDB).(precompiles.CosmosPrecompile) evmAddr := common.HexToAddress("0x1") cosmosAddr, err := ac.BytesToString(evmAddr.Bytes()) @@ -187,7 +189,8 @@ func Test_CosmosPrecompile_ToEVMAddress(t *testing.T) { cosmosPrecompile, err := precompiles.NewCosmosPrecompile(cdc, ac, ak, bk, nil, nil, nil) require.NoError(t, err) - cosmosPrecompile = cosmosPrecompile.WithContext(ctx).(precompiles.CosmosPrecompile) + stateDB := NewMockStateDB(ctx) + cosmosPrecompile = cosmosPrecompile.WithStateDB(stateDB).(precompiles.CosmosPrecompile) evmAddr := common.HexToAddress("0x1") cosmosAddr, err := ac.BytesToString(evmAddr.Bytes()) @@ -217,7 +220,8 @@ func Test_ExecuteCosmos(t *testing.T) { cosmosPrecompile, err := precompiles.NewCosmosPrecompile(cdc, ac, ak, bk, nil, nil, nil) require.NoError(t, err) - cosmosPrecompile = cosmosPrecompile.WithContext(ctx).(precompiles.CosmosPrecompile) + stateDB := NewMockStateDB(ctx) + cosmosPrecompile = cosmosPrecompile.WithStateDB(stateDB).(precompiles.CosmosPrecompile) evmAddr := common.HexToAddress("0x1") cosmosAddr, err := ac.BytesToString(evmAddr.Bytes()) @@ -316,7 +320,8 @@ func Test_QueryCosmos(t *testing.T) { }) require.NoError(t, err) - cosmosPrecompile = cosmosPrecompile.WithContext(ctx).(precompiles.CosmosPrecompile) + stateDB := NewMockStateDB(ctx) + cosmosPrecompile = cosmosPrecompile.WithStateDB(stateDB).(precompiles.CosmosPrecompile) evmAddr := common.HexToAddress("0x1") @@ -362,7 +367,8 @@ func Test_ToDenom(t *testing.T) { }, nil, nil) require.NoError(t, err) - cosmosPrecompile = cosmosPrecompile.WithContext(ctx).(precompiles.CosmosPrecompile) + stateDB := NewMockStateDB(ctx) + cosmosPrecompile = cosmosPrecompile.WithStateDB(stateDB).(precompiles.CosmosPrecompile) evmAddr := common.HexToAddress("0x1") @@ -403,7 +409,8 @@ func Test_ToErc20(t *testing.T) { }, nil, nil) require.NoError(t, err) - cosmosPrecompile = cosmosPrecompile.WithContext(ctx).(precompiles.CosmosPrecompile) + stateDB := NewMockStateDB(ctx) + cosmosPrecompile = cosmosPrecompile.WithStateDB(stateDB).(precompiles.CosmosPrecompile) evmAddr := common.HexToAddress("0x1") @@ -427,103 +434,3 @@ func Test_ToErc20(t *testing.T) { require.NoError(t, err) require.Equal(t, erc20Addr, unpackedRet[0].(common.Address)) } - -var _ types.AccountKeeper = &MockAccountKeeper{} - -// mock account keeper for testing -type MockAccountKeeper struct { - ac address.Codec - accounts map[string]sdk.AccountI -} - -// GetAccount implements types.AccountKeeper. -func (k MockAccountKeeper) GetAccount(ctx context.Context, addr sdk.AccAddress) sdk.AccountI { - str, _ := k.ac.BytesToString(addr.Bytes()) - return k.accounts[str] -} - -// HasAccount implements types.AccountKeeper. -func (k MockAccountKeeper) HasAccount(ctx context.Context, addr sdk.AccAddress) bool { - str, _ := k.ac.BytesToString(addr.Bytes()) - _, ok := k.accounts[str] - return ok -} - -// NewAccount implements types.AccountKeeper. -func (k *MockAccountKeeper) NewAccount(ctx context.Context, acc sdk.AccountI) sdk.AccountI { - acc.SetAccountNumber(uint64(len(k.accounts))) - return acc -} - -// NewAccountWithAddress implements types.AccountKeeper. -func (k MockAccountKeeper) NewAccountWithAddress(ctx context.Context, addr sdk.AccAddress) sdk.AccountI { - return authtypes.NewBaseAccount(addr, nil, uint64(len(k.accounts)), 0) -} - -// NextAccountNumber implements types.AccountKeeper. -func (k MockAccountKeeper) NextAccountNumber(ctx context.Context) uint64 { - return uint64(len(k.accounts)) -} - -// SetAccount implements types.AccountKeeper. -func (k MockAccountKeeper) SetAccount(ctx context.Context, acc sdk.AccountI) { - str, _ := k.ac.BytesToString(acc.GetAddress().Bytes()) - k.accounts[str] = acc -} - -// RemoveAccount implements types.AccountKeeper. -func (k MockAccountKeeper) RemoveAccount(ctx context.Context, acc sdk.AccountI) { - str, _ := k.ac.BytesToString(acc.GetAddress().Bytes()) - delete(k.accounts, str) -} - -var _ types.BankKeeper = &MockBankKeeper{} - -// mock bank keeper for testing -type MockBankKeeper struct { - ac address.Codec - blockedAddresses map[string]bool -} - -// BlockedAddr implements types.BankKeeper. -func (k MockBankKeeper) BlockedAddr(addr sdk.AccAddress) bool { - str, _ := k.ac.BytesToString(addr.Bytes()) - return k.blockedAddresses[str] -} - -var _ types.GRPCRouter = MockGRPCRouter{} - -type MockGRPCRouter struct { - routes map[string]baseapp.GRPCQueryHandler -} - -func (router MockGRPCRouter) Route(path string) baseapp.GRPCQueryHandler { - return router.routes[path] -} - -var _ types.ERC20DenomKeeper = &MockERC20DenomKeeper{} - -type MockERC20DenomKeeper struct { - denomMap map[string]common.Address - addrMap map[common.Address]string -} - -// GetContractAddrByDenom implements types.ERC20DenomKeeper. -func (e *MockERC20DenomKeeper) GetContractAddrByDenom(_ context.Context, denom string) (common.Address, error) { - addr, found := e.denomMap[denom] - if !found { - return common.Address{}, sdkerrors.ErrNotFound - } - - return addr, nil -} - -// GetDenomByContractAddr implements types.ERC20DenomKeeper. -func (e *MockERC20DenomKeeper) GetDenomByContractAddr(_ context.Context, addr common.Address) (string, error) { - denom, found := e.addrMap[addr] - if !found { - return "", sdkerrors.ErrNotFound - } - - return denom, nil -} diff --git a/x/evm/precompiles/erc20_registry/common_test.go b/x/evm/precompiles/erc20_registry/common_test.go new file mode 100644 index 0000000..dba0ab6 --- /dev/null +++ b/x/evm/precompiles/erc20_registry/common_test.go @@ -0,0 +1,246 @@ +package erc20registryprecompile_test + +import ( + sdk "github.com/cosmos/cosmos-sdk/types" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/stateless" + "github.com/ethereum/go-ethereum/core/tracing" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/params" + "github.com/ethereum/go-ethereum/trie/utils" + "github.com/holiman/uint256" + + "github.com/initia-labs/minievm/x/evm/state" + evmtypes "github.com/initia-labs/minievm/x/evm/types" +) + +var _ evmtypes.StateDB = &MockStateDB{} + +type MockStateDB struct { + ctx sdk.Context + initialCtx sdk.Context + + // Snapshot stack + snaps []*state.Snapshot +} + +func NewMockStateDB(ctx sdk.Context) *MockStateDB { + return &MockStateDB{ + ctx: ctx, + initialCtx: ctx, + } +} + +// Snapshot implements types.StateDB. +func (m *MockStateDB) Snapshot() int { + // get a current snapshot id + sid := len(m.snaps) - 1 + + // create a new snapshot + snap := state.NewSnapshot(m.ctx) + m.snaps = append(m.snaps, snap) + + // use the new snapshot context + m.ctx = snap.Context() + + // return the current snapshot id + return sid +} + +// RevertToSnapshot implements types.StateDB. +func (m *MockStateDB) RevertToSnapshot(i int) { + if i == -1 { + m.ctx = m.initialCtx + m.snaps = m.snaps[:0] + return + } + + // revert to the snapshot with the given id + snap := m.snaps[i] + m.ctx = snap.Context() + + // clear the snapshots after the given id + m.snaps = m.snaps[:i] +} + +// ContextOfSnapshot implements types.StateDB. +func (m *MockStateDB) ContextOfSnapshot(i int) sdk.Context { + if i == -1 { + return m.initialCtx + } + + return m.snaps[i].Context() +} + +//////////////////////// MOCKED METHODS //////////////////////// + +// AddAddressToAccessList implements types.StateDB. +func (m *MockStateDB) AddAddressToAccessList(addr common.Address) { + panic("unimplemented") +} + +// AddBalance implements types.StateDB. +func (m *MockStateDB) AddBalance(common.Address, *uint256.Int, tracing.BalanceChangeReason) { + panic("unimplemented") +} + +// AddLog implements types.StateDB. +func (m *MockStateDB) AddLog(*types.Log) { + panic("unimplemented") +} + +// AddPreimage implements types.StateDB. +func (m *MockStateDB) AddPreimage(common.Hash, []byte) { + panic("unimplemented") +} + +// AddRefund implements types.StateDB. +func (m *MockStateDB) AddRefund(uint64) { + panic("unimplemented") +} + +// AddSlotToAccessList implements types.StateDB. +func (m *MockStateDB) AddSlotToAccessList(addr common.Address, slot common.Hash) { + panic("unimplemented") +} + +// AddressInAccessList implements types.StateDB. +func (m *MockStateDB) AddressInAccessList(addr common.Address) bool { + panic("unimplemented") +} + +// CreateAccount implements types.StateDB. +func (m *MockStateDB) CreateAccount(common.Address) { + panic("unimplemented") +} + +// CreateContract implements types.StateDB. +func (m *MockStateDB) CreateContract(common.Address) { + panic("unimplemented") +} + +// Empty implements types.StateDB. +func (m *MockStateDB) Empty(common.Address) bool { + panic("unimplemented") +} + +// Exist implements types.StateDB. +func (m *MockStateDB) Exist(common.Address) bool { + panic("unimplemented") +} + +// GetBalance implements types.StateDB. +func (m *MockStateDB) GetBalance(common.Address) *uint256.Int { + panic("unimplemented") +} + +// GetCode implements types.StateDB. +func (m *MockStateDB) GetCode(common.Address) []byte { + panic("unimplemented") +} + +// GetCodeHash implements types.StateDB. +func (m *MockStateDB) GetCodeHash(common.Address) common.Hash { + panic("unimplemented") +} + +// GetCodeSize implements types.StateDB. +func (m *MockStateDB) GetCodeSize(common.Address) int { + panic("unimplemented") +} + +// GetCommittedState implements types.StateDB. +func (m *MockStateDB) GetCommittedState(common.Address, common.Hash) common.Hash { + panic("unimplemented") +} + +// GetNonce implements types.StateDB. +func (m *MockStateDB) GetNonce(common.Address) uint64 { + panic("unimplemented") +} + +// GetRefund implements types.StateDB. +func (m *MockStateDB) GetRefund() uint64 { + panic("unimplemented") +} + +// GetState implements types.StateDB. +func (m *MockStateDB) GetState(common.Address, common.Hash) common.Hash { + panic("unimplemented") +} + +// GetStorageRoot implements types.StateDB. +func (m *MockStateDB) GetStorageRoot(addr common.Address) common.Hash { + panic("unimplemented") +} + +// GetTransientState implements types.StateDB. +func (m *MockStateDB) GetTransientState(addr common.Address, key common.Hash) common.Hash { + panic("unimplemented") +} + +// HasSelfDestructed implements types.StateDB. +func (m *MockStateDB) HasSelfDestructed(common.Address) bool { + panic("unimplemented") +} + +// PointCache implements types.StateDB. +func (m *MockStateDB) PointCache() *utils.PointCache { + panic("unimplemented") +} + +// Prepare implements types.StateDB. +func (m *MockStateDB) Prepare(rules params.Rules, sender common.Address, coinbase common.Address, dest *common.Address, precompiles []common.Address, txAccesses types.AccessList) { + panic("unimplemented") +} + +// SelfDestruct implements types.StateDB. +func (m *MockStateDB) SelfDestruct(common.Address) { + panic("unimplemented") +} + +// Selfdestruct6780 implements types.StateDB. +func (m *MockStateDB) Selfdestruct6780(common.Address) { + panic("unimplemented") +} + +// SetCode implements types.StateDB. +func (m *MockStateDB) SetCode(common.Address, []byte) { + panic("unimplemented") +} + +// SetNonce implements types.StateDB. +func (m *MockStateDB) SetNonce(common.Address, uint64) { + panic("unimplemented") +} + +// SetState implements types.StateDB. +func (m *MockStateDB) SetState(common.Address, common.Hash, common.Hash) { + panic("unimplemented") +} + +// SetTransientState implements types.StateDB. +func (m *MockStateDB) SetTransientState(addr common.Address, key common.Hash, value common.Hash) { + panic("unimplemented") +} + +// SlotInAccessList implements types.StateDB. +func (m *MockStateDB) SlotInAccessList(addr common.Address, slot common.Hash) (addressOk bool, slotOk bool) { + panic("unimplemented") +} + +// SubBalance implements types.StateDB. +func (m *MockStateDB) SubBalance(common.Address, *uint256.Int, tracing.BalanceChangeReason) { + panic("unimplemented") +} + +// SubRefund implements types.StateDB. +func (m *MockStateDB) SubRefund(uint64) { + panic("unimplemented") +} + +// Witness implements types.StateDB. +func (m *MockStateDB) Witness() *stateless.Witness { + panic("unimplemented") +} diff --git a/x/evm/precompiles/erc20_registry/contract.go b/x/evm/precompiles/erc20_registry/contract.go index 593b7ea..3e2739b 100644 --- a/x/evm/precompiles/erc20_registry/contract.go +++ b/x/evm/precompiles/erc20_registry/contract.go @@ -1,14 +1,12 @@ package erc20registryprecompile import ( - "context" "errors" "github.com/ethereum/go-ethereum/accounts/abi" "github.com/ethereum/go-ethereum/core/vm" storetypes "cosmossdk.io/store/types" - sdk "github.com/cosmos/cosmos-sdk/types" "github.com/initia-labs/minievm/x/evm/contracts/i_erc20_registry" "github.com/initia-labs/minievm/x/evm/types" @@ -16,12 +14,12 @@ import ( var _ vm.ExtendedPrecompiledContract = ERC20RegistryPrecompile{} var _ vm.PrecompiledContract = ERC20RegistryPrecompile{} -var _ types.WithContext = ERC20RegistryPrecompile{} +var _ types.WithStateDB = ERC20RegistryPrecompile{} type ERC20RegistryPrecompile struct { *abi.ABI - ctx context.Context - k types.IERC20StoresKeeper + stateDB types.StateDB + k types.IERC20StoresKeeper } func NewERC20RegistryPrecompile(k types.IERC20StoresKeeper) (ERC20RegistryPrecompile, error) { @@ -33,8 +31,8 @@ func NewERC20RegistryPrecompile(k types.IERC20StoresKeeper) (ERC20RegistryPrecom return ERC20RegistryPrecompile{ABI: abi, k: k}, nil } -func (e ERC20RegistryPrecompile) WithContext(ctx context.Context) vm.PrecompiledContract { - e.ctx = ctx +func (e ERC20RegistryPrecompile) WithStateDB(stateDB types.StateDB) vm.PrecompiledContract { + e.stateDB = stateDB return e } @@ -47,17 +45,24 @@ const ( // ExtendedRun implements vm.ExtendedPrecompiledContract. func (e ERC20RegistryPrecompile) ExtendedRun(caller vm.ContractRef, input []byte, suppliedGas uint64, readOnly bool) (resBz []byte, usedGas uint64, err error) { + snapshot := e.stateDB.Snapshot() + ctx := e.stateDB.ContextOfSnapshot(snapshot).WithGasMeter(storetypes.NewGasMeter(suppliedGas)) + defer func() { if r := recover(); r != nil { switch r.(type) { case storetypes.ErrorOutOfGas: // convert cosmos out of gas error to EVM out of gas error - usedGas = suppliedGas + 1 - err = nil + usedGas = suppliedGas + err = vm.ErrOutOfGas default: panic(r) } } + + if err != nil { + e.stateDB.RevertToSnapshot(snapshot) + } }() method, err := e.ABI.MethodById(input) @@ -70,7 +75,6 @@ func (e ERC20RegistryPrecompile) ExtendedRun(caller vm.ContractRef, input []byte return nil, 0, types.ErrPrecompileFailed.Wrap(err.Error()) } - ctx := sdk.UnwrapSDKContext(e.ctx).WithGasMeter(storetypes.NewGasMeter(suppliedGas)) ctx.GasMeter().ConsumeGas(storetypes.Gas(len(input))*GAS_PER_BYTE, "input bytes") switch method.Name { diff --git a/x/evm/precompiles/erc20_registry/contract_test.go b/x/evm/precompiles/erc20_registry/contract_test.go index 528fd5a..d7494f9 100644 --- a/x/evm/precompiles/erc20_registry/contract_test.go +++ b/x/evm/precompiles/erc20_registry/contract_test.go @@ -75,7 +75,8 @@ func Test_ERC20RegistryPrecompile(t *testing.T) { require.NoError(t, err) // set context - registry = registry.WithContext(ctx).(precompiles.ERC20RegistryPrecompile) + stateDB := NewMockStateDB(ctx) + registry = registry.WithStateDB(stateDB).(precompiles.ERC20RegistryPrecompile) erc20Addr := common.HexToAddress("0x1") accountAddr := common.HexToAddress("0x2") @@ -89,9 +90,8 @@ func Test_ERC20RegistryPrecompile(t *testing.T) { require.NoError(t, err) // out of gas error - _, gasUsed, err := registry.ExtendedRun(vm.AccountRef(erc20Addr), bz, precompiles.REGISTER_GAS-1, false) - require.NoError(t, err) - require.Equal(t, gasUsed, uint64(precompiles.REGISTER_GAS)) + _, _, err = registry.ExtendedRun(vm.AccountRef(erc20Addr), bz, precompiles.REGISTER_GAS-1, false) + require.ErrorIs(t, err, vm.ErrOutOfGas) // non read only method fail _, _, err = registry.ExtendedRun(vm.AccountRef(erc20Addr), bz, precompiles.REGISTER_GAS+uint64(len(bz)), true) @@ -107,9 +107,8 @@ func Test_ERC20RegistryPrecompile(t *testing.T) { require.NoError(t, err) // out of gas error - _, gasUsed, err = registry.ExtendedRun(vm.AccountRef(erc20FactoryAddr), bz, precompiles.REGISTER_FROM_FACTORY_GAS-1, false) - require.NoError(t, err) - require.Equal(t, gasUsed, uint64(precompiles.REGISTER_FROM_FACTORY_GAS)) + _, _, err = registry.ExtendedRun(vm.AccountRef(erc20FactoryAddr), bz, precompiles.REGISTER_FROM_FACTORY_GAS-1, false) + require.ErrorIs(t, err, vm.ErrOutOfGas) // non read only method fail _, _, err = registry.ExtendedRun(vm.AccountRef(erc20FactoryAddr), bz, precompiles.REGISTER_FROM_FACTORY_GAS+uint64(len(bz)), true) @@ -140,9 +139,8 @@ func Test_ERC20RegistryPrecompile(t *testing.T) { require.NoError(t, err) // out of gas error - _, gasUsed, err = registry.ExtendedRun(vm.AccountRef(erc20Addr), bz, precompiles.REGISTER_STORE_GAS-1, false) - require.NoError(t, err) - require.Equal(t, gasUsed, uint64(precompiles.REGISTER_STORE_GAS)) + _, _, err = registry.ExtendedRun(vm.AccountRef(erc20Addr), bz, precompiles.REGISTER_STORE_GAS-1, false) + require.ErrorIs(t, err, vm.ErrOutOfGas) // non read only method fail _, _, err = registry.ExtendedRun(vm.AccountRef(erc20Addr), bz, precompiles.REGISTER_STORE_GAS+uint64(len(bz)), true) @@ -158,9 +156,8 @@ func Test_ERC20RegistryPrecompile(t *testing.T) { require.NoError(t, err) // out of gas panic - _, gasUsed, err = registry.ExtendedRun(vm.AccountRef(erc20Addr), bz, precompiles.IS_STORE_REGISTERED_GAS-1, true) - require.NoError(t, err) - require.Equal(t, gasUsed, uint64(precompiles.IS_STORE_REGISTERED_GAS)) + _, _, err = registry.ExtendedRun(vm.AccountRef(erc20Addr), bz, precompiles.IS_STORE_REGISTERED_GAS-1, true) + require.ErrorIs(t, err, vm.ErrOutOfGas) resBz, usedGas, err = registry.ExtendedRun(vm.AccountRef(erc20Addr), bz, precompiles.IS_STORE_REGISTERED_GAS+uint64(len(bz)), true) require.NoError(t, err) diff --git a/x/evm/state/snapshot.go b/x/evm/state/snapshot.go index de8f266..bed936e 100644 --- a/x/evm/state/snapshot.go +++ b/x/evm/state/snapshot.go @@ -23,3 +23,7 @@ func NewSnapshot(ctx context.Context) *Snapshot { func (s *Snapshot) Commit() { s.commit() } + +func (s *Snapshot) Context() sdk.Context { + return s.ctx +} diff --git a/x/evm/state/statedb.go b/x/evm/state/statedb.go index 4eb7926..682f4b6 100644 --- a/x/evm/state/statedb.go +++ b/x/evm/state/statedb.go @@ -589,6 +589,14 @@ func (s *StateDB) RevertToSnapshot(i int) { s.snaps = s.snaps[:i] } +func (s *StateDB) ContextOfSnapshot(i int) sdk.Context { + if i == -1 { + return s.initialCtx + } + + return s.snaps[i].ctx +} + // Prepare handles the preparatory steps for executing a state transition with. // This method must be invoked before state transition. // diff --git a/x/evm/types/expected_keeper.go b/x/evm/types/expected_keeper.go index 0bcbba5..6546333 100644 --- a/x/evm/types/expected_keeper.go +++ b/x/evm/types/expected_keeper.go @@ -85,8 +85,13 @@ type IERC721Keeper interface { GetTokenInfos(ctx context.Context, classId string, tokenIds []string) (tokenUris []string, tokenData []string, err error) } -type WithContext interface { - WithContext(ctx context.Context) vm.PrecompiledContract +type StateDB interface { + vm.StateDB + ContextOfSnapshot(i int) sdk.Context +} + +type WithStateDB interface { + WithStateDB(stateDB StateDB) vm.PrecompiledContract } type GRPCRouter interface {