Skip to content

Commit

Permalink
fix: use snapshot to support revert (#90)
Browse files Browse the repository at this point in the history
* use snapshot to support revert

* add test for call revert
  • Loading branch information
beer-1 authored Oct 29, 2024
1 parent 5a41c08 commit 0ff2791
Show file tree
Hide file tree
Showing 14 changed files with 774 additions and 155 deletions.
25 changes: 23 additions & 2 deletions x/evm/contracts/counter/Counter.go

Large diffs are not rendered by default.

11 changes: 11 additions & 0 deletions x/evm/contracts/counter/Counter.sol
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,17 @@ contract Counter is IIBCAsyncCallback {
return COSMOS_CONTRACT.query_cosmos(path, req);
}

function execute_cosmos(
string memory exec_msg,
bool call_revert
) external {
COSMOS_CONTRACT.execute_cosmos(exec_msg);

if (call_revert) {
revert("revert");
}
}

function get_blockhash(uint64 n) external view returns (bytes32) {
return blockhash(n);
}
Expand Down
14 changes: 7 additions & 7 deletions x/evm/keeper/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,12 @@ func (k Keeper) CreateEVM(ctx context.Context, caller common.Address, tracer *tr
return ctx, nil, err
}

// prepare SDK context for EVM execution
ctx, err = prepareSDKContext(sdk.UnwrapSDKContext(ctx))
if err != nil {
return ctx, nil, err
}

evm := &vm.EVM{}
blockContext, err := k.buildBlockContext(ctx, evm, fee)
if err != nil {
Expand All @@ -181,19 +187,13 @@ func (k Keeper) CreateEVM(ctx context.Context, caller common.Address, tracer *tr
NumRetainBlockHashes: &params.NumRetainBlockHashes,
}

// prepare SDK context for EVM execution
ctx, err = prepareSDKContext(sdk.UnwrapSDKContext(ctx))
if err != nil {
return ctx, nil, err
}

*evm = *vm.NewEVMWithPrecompiles(
blockContext,
txContext,
stateDB,
types.DefaultChainConfig(ctx),
vmConfig,
k.precompiles.toMap(ctx),
k.precompiles.toMap(stateDB),
)

if tracer != nil {
Expand Down
61 changes: 61 additions & 0 deletions x/evm/keeper/context_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package keeper_test

import (
"fmt"
"strings"
"testing"

Expand Down Expand Up @@ -237,3 +238,63 @@ func Test_RecursiveDepth(t *testing.T) {
_, _, err = input.EVMKeeper.EVMCall(ctx, caller, contractAddr, inputBz, nil, nil)
require.ErrorContains(t, err, types.ErrExceedMaxRecursiveDepth.Error())
}

func Test_RevertAfterExecuteCosmos(t *testing.T) {
ctx, input := createDefaultTestInput(t)
_, _, addr := keyPubAddr()

counterBz, err := hexutil.Decode(counter.CounterBin)
require.NoError(t, err)

// deploy counter contract
caller := common.BytesToAddress(addr.Bytes())
retBz, contractAddr, _, err := input.EVMKeeper.EVMCreate(ctx, caller, counterBz, nil, nil)
require.NoError(t, err)
require.NotEmpty(t, retBz)
require.Len(t, contractAddr, 20)

// call execute cosmos function
parsed, err := counter.CounterMetaData.GetAbi()
require.NoError(t, err)

denom := sdk.DefaultBondDenom
amount := math.NewInt(1000000000)
input.Faucet.Mint(ctx, contractAddr.Bytes(), sdk.NewCoin(denom, amount))

// call execute_cosmos with revert
inputBz, err := parsed.Pack("execute_cosmos",
fmt.Sprintf(`{"@type":"/cosmos.bank.v1beta1.MsgSend","from_address":"%s","to_address":"%s","amount":[{"denom":"%s","amount":"%s"}]}`,
sdk.AccAddress(contractAddr.Bytes()).String(),
addr.String(), // caller
denom,
amount,
),
true,
)
require.NoError(t, err)

_, _, err = input.EVMKeeper.EVMCall(ctx, caller, contractAddr, inputBz, nil, nil)
require.ErrorContains(t, err, types.ErrReverted.Error())

// check balance
require.Equal(t, amount, input.BankKeeper.GetBalance(ctx, sdk.AccAddress(contractAddr.Bytes()), denom).Amount)
require.Equal(t, math.ZeroInt(), input.BankKeeper.GetBalance(ctx, addr, denom).Amount)

// call execute_cosmos without revert
inputBz, err = parsed.Pack("execute_cosmos",
fmt.Sprintf(`{"@type":"/cosmos.bank.v1beta1.MsgSend","from_address":"%s","to_address":"%s","amount":[{"denom":"%s","amount":"%s"}]}`,
sdk.AccAddress(contractAddr.Bytes()).String(),
addr.String(), // caller
denom,
amount,
),
false,
)
require.NoError(t, err)

_, _, err = input.EVMKeeper.EVMCall(ctx, caller, contractAddr, inputBz, nil, nil)
require.NoError(t, err, types.ErrReverted.Error())

require.Equal(t, math.ZeroInt(), input.BankKeeper.GetBalance(ctx, sdk.AccAddress(contractAddr.Bytes()), denom).Amount)
require.Equal(t, amount, input.BankKeeper.GetBalance(ctx, addr, denom).Amount)
}
6 changes: 2 additions & 4 deletions x/evm/keeper/precompiles.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package keeper

import (
"context"

"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/vm"

Expand Down Expand Up @@ -55,10 +53,10 @@ func (k *Keeper) loadPrecompiles() error {
type precompiles []precompile

// toMap converts the precompiles to a map.
func (ps precompiles) toMap(ctx context.Context) map[common.Address]vm.PrecompiledContract {
func (ps precompiles) toMap(stateDB types.StateDB) map[common.Address]vm.PrecompiledContract {
m := make(map[common.Address]vm.PrecompiledContract)
for _, p := range ps {
m[p.addr] = p.contract.(types.WithContext).WithContext(ctx)
m[p.addr] = p.contract.(types.WithStateDB).WithStateDB(stateDB)
}

return m
Expand Down
Loading

0 comments on commit 0ff2791

Please sign in to comment.