diff --git a/zk/tests/zk_counters_test.go b/zk/tests/zk_counters_test.go index dcc2f38d8e0..11d1c824033 100644 --- a/zk/tests/zk_counters_test.go +++ b/zk/tests/zk_counters_test.go @@ -17,6 +17,7 @@ import ( "github.com/ledgerwatch/erigon-lib/common/hexutil" "github.com/ledgerwatch/erigon-lib/kv" "github.com/ledgerwatch/erigon-lib/kv/memdb" + "github.com/ledgerwatch/erigon/consensus" "github.com/ledgerwatch/erigon/consensus/ethash/ethashcfg" "github.com/ledgerwatch/erigon/core" "github.com/ledgerwatch/erigon/core/rawdb" @@ -38,13 +39,13 @@ import ( "github.com/status-im/keycard-go/hexutils" ) -const root = "./testdata" -const transactionGasLimit = 30000000 - -var ( - noop = state.NewNoopWriter() +const ( + root = "./testdata" + transactionGasLimit = 30000000 ) +var noop = state.NewNoopWriter() + type vector struct { BatchL2Data string `json:"batchL2Data"` BatchL2DataDecoded []byte @@ -84,6 +85,12 @@ type vector struct { } `json:"txs"` } +type vectorTest struct { + idx int + fileName string + vector vector +} + func Test_RunTestVectors(t *testing.T) { // we need to ensure we're running in a sequencer context to wrap the jump table os.Setenv(seq.SEQUENCER_ENV_KEY, "1") @@ -92,47 +99,136 @@ func Test_RunTestVectors(t *testing.T) { m := mock.Mock(t) blockReader, _ := m.BlocksIO() - files, err := os.ReadDir(root) + vectorTests, err := vectorTests() if err != nil { - t.Fatal(err) + t.Fatalf("could not get vector tests: %v", err) + } + + for idx, test := range vectorTests { + t.Run(test.fileName, func(t *testing.T) { + runTest(t, blockReader, vectorTests[idx]) + }) } +} - var tests []vector - var fileNames []string +func vectorTests() ([]vectorTest, error) { + files, err := os.ReadDir(root) + if err != nil { + return nil, fmt.Errorf("could not read directory %s: %v", root, err) + } + var tests []vectorTest for _, file := range files { - var inner []vector + var vectors []vector contents, err := os.ReadFile(fmt.Sprintf("%s/%s", root, file.Name())) if err != nil { - t.Fatal(err) + return nil, fmt.Errorf("could not read file %s: %v", file.Name(), err) } - if err = json.Unmarshal(contents, &inner); err != nil { - t.Fatal(err) + if err = json.Unmarshal(contents, &vectors); err != nil { + return nil, fmt.Errorf("could not unmarshal file %s: %v", file.Name(), err) } - for i := len(inner) - 1; i >= 0; i-- { - fileNames = append(fileNames, file.Name()) + for i := len(vectors) - 1; i >= 0; i-- { + tests = append(tests, vectorTest{ + idx: i, + fileName: file.Name(), + vector: vectors[i], + }) } - tests = append(tests, inner...) } - for idx, test := range tests { - t.Run(fileNames[idx], func(t *testing.T) { - runTest(t, blockReader, test, err, fileNames[idx], idx) - }) + return tests, nil +} + +func runTest(t *testing.T, blockReader services.FullBlockReader, test vectorTest) { + // arrange + decodedBlocks, err := decodeBlocks(test.vector, test.fileName) + if err != nil { + t.Fatalf("could not decode blocks: %v", err) + } + + db, tx := memdb.NewTestTx(t) + defer db.Close() + defer tx.Rollback() + + for _, table := range kv.ChaindataTables { + if err = tx.CreateBucket(table); err != nil { + t.Fatalf("could not create bucket: %v", err) + } + } + + genesisBlock, err := writeGenesisState(&test.vector, decodedBlocks, test.idx, tx) + if err != nil { + t.Fatalf("could not write genesis state: %v", err) + } + + genesisRoot := genesisBlock.Root() + expectedGenesisRoot := common.HexToHash(test.vector.ExpectedOldRoot) + if genesisRoot != expectedGenesisRoot { + t.Fatal("genesis root did not match expected") + } + + header := &types.Header{ + Number: big.NewInt(1), + Difficulty: big.NewInt(0), + } + + chainCfg := chainConfig(test.vector.ChainId) + ethashCfg := newEthashConfig() + vmCfg := newVmConfig() + + engine := newEngine(chainCfg, ethashCfg, blockReader, db) + ibs := state.New(state.NewPlainStateReader(tx)) + + shouldVerifyMerkleProof := false + if test.vector.Txs[0].Type == 11 { + if test.vector.Txs[0].IndexL1InfoTree != 0 { + shouldVerifyMerkleProof = true + } + if err := updateGER(test.vector, ibs, chainCfg); err != nil { + t.Fatalf("could not update ger: %v", err) + } + } + + batchCollector := vm.NewBatchCounterCollector(test.vector.SmtDepths[0], uint16(test.vector.ForkId), 0.6, false, nil) + + if err = applyTransactionsToDecodedBlocks( + decodedBlocks, + test.vector, + chainCfg, + vmCfg, + header, + tx, + engine, + ibs, + batchCollector, + ); err != nil { + t.Fatalf("could not apply transactions: %v", err) + } + + // act + errors, err := testVirtualCounters(test.vector, batchCollector, shouldVerifyMerkleProof) + if err != nil { + t.Fatalf("could not test virtual counters: %v", err) + } + + // assert + if len(errors) > 0 { + t.Errorf("counter mismath in file %s: %s \n", test.fileName, strings.Join(errors, " ")) } } -func runTest(t *testing.T, blockReader services.FullBlockReader, test vector, err error, fileName string, idx int) { - test.BatchL2DataDecoded, err = hex.DecodeHex(test.BatchL2Data) +func decodeBlocks(v vector, fileName string) ([]tx.DecodedBatchL2Data, error) { + batchL2DataDecoded, err := hex.DecodeHex(v.BatchL2Data) if err != nil { - t.Fatal(err) + return nil, fmt.Errorf("could not decode batchL2Data: %v", err) } - decodedBlocks, err := tx.DecodeBatchL2Blocks(test.BatchL2DataDecoded, test.ForkId) + decodedBlocks, err := tx.DecodeBatchL2Blocks(batchL2DataDecoded, v.ForkId) if err != nil { - t.Fatal(err) + return nil, fmt.Errorf("could not decode batchL2Blocks: %v", err) } + fmt.Println(decodedBlocks[0].Transactions) if len(decodedBlocks) == 0 { fmt.Printf("found no blocks in file %s", fileName) } @@ -142,35 +238,29 @@ func runTest(t *testing.T, blockReader services.FullBlockReader, test vector, er } } - db, tx := memdb.NewTestTx(t) - defer db.Close() - defer tx.Rollback() - - for _, table := range kv.ChaindataTables { - if err = tx.CreateBucket(table); err != nil { - t.Fatal(err) - } - } + return decodedBlocks, nil +} +func writeGenesisState(v *vector, decodedBlocks []tx.DecodedBatchL2Data, idx int, tx kv.RwTx) (*types.Block, error) { genesisAccounts := map[common.Address]types.GenesisAccount{} - for _, g := range test.Genesis { + for _, g := range v.Genesis { addr := common.HexToAddress(g.Address) key, err := hex.DecodeHex(g.PvtKey) if err != nil { - t.Fatal(err) + return nil, fmt.Errorf("could not decode private key: %v", err) } nonce, err := strconv.ParseUint(g.Nonce, 10, 64) if err != nil { - t.Fatal(err) + return nil, fmt.Errorf("could not parse nonce: %v", err) } balance, ok := new(big.Int).SetString(g.Balance, 10) if !ok { - t.Fatal(errors.New("could not parse balance")) + return nil, errors.New("could not parse balance") } code, err := hex.DecodeHex(g.ByteCode) if err != nil { - t.Fatal(err) + return nil, fmt.Errorf("could not decode bytecode: %v", err) } acc := types.GenesisAccount{ Balance: balance, @@ -185,39 +275,28 @@ func runTest(t *testing.T, blockReader services.FullBlockReader, test vector, er genesis := &types.Genesis{ Alloc: genesisAccounts, Config: &chain.Config{ - ChainID: big.NewInt(test.ChainId), + ChainID: big.NewInt(v.ChainId), }, } genesisBlock, _, sparseTree, err := core.WriteGenesisState(genesis, tx, fmt.Sprintf("%s/temp-%v", os.TempDir(), idx), log.New()) if err != nil { - t.Fatal(err) + return nil, fmt.Errorf("could not write genesis state: %v", err) } smtDepth := sparseTree.GetDepth() - for len(test.SmtDepths) < len(decodedBlocks) { - test.SmtDepths = append(test.SmtDepths, smtDepth) + for len(v.SmtDepths) < len(decodedBlocks) { + v.SmtDepths = append(v.SmtDepths, smtDepth) } - if len(test.SmtDepths) == 0 { - test.SmtDepths = append(test.SmtDepths, smtDepth) + if len(v.SmtDepths) == 0 { + v.SmtDepths = append(v.SmtDepths, smtDepth) } - genesisRoot := genesisBlock.Root() - expectedGenesisRoot := common.HexToHash(test.ExpectedOldRoot) - if genesisRoot != expectedGenesisRoot { - t.Fatal("genesis root did not match expected") - } - - sequencer := common.HexToAddress(test.SequencerAddress) - - header := &types.Header{ - Number: big.NewInt(1), - Difficulty: big.NewInt(0), - } - getHeader := func(hash common.Hash, number uint64) *types.Header { return rawdb.ReadHeader(tx, hash, number) } - blockHashFunc := core.GetHashFn(header, getHeader) + return genesisBlock, nil +} +func chainConfig(chainId int64) *chain.Config { chainConfig := params.ChainConfigByChainName("hermez-dev") - chainConfig.ChainID = big.NewInt(test.ChainId) + chainConfig.ChainID = big.NewInt(chainId) chainConfig.ForkID4Block = big.NewInt(0) chainConfig.ForkID5DragonfruitBlock = big.NewInt(0) @@ -225,7 +304,11 @@ func runTest(t *testing.T, blockReader services.FullBlockReader, test vector, er chainConfig.ForkID7EtrogBlock = big.NewInt(0) chainConfig.ForkID88ElderberryBlock = big.NewInt(0) - ethashCfg := ðashcfg.Config{ + return chainConfig +} + +func newEthashConfig() *ethashcfg.Config { + return ðashcfg.Config{ CachesInMem: 1, CachesLockMmap: true, DatasetDir: "./dataset", @@ -236,11 +319,10 @@ func runTest(t *testing.T, blockReader services.FullBlockReader, test vector, er NotifyFull: false, Log: nil, } +} - logger := log.New() - engine := ethconsensusconfig.CreateConsensusEngine(context.Background(), &nodecfg.Config{Dirs: datadir.New("./datadir")}, chainConfig, ethashCfg, []string{}, true, heimdall.NewHeimdallClient("", logger), true, blockReader, db.ReadOnly(), logger) - - vmCfg := vm.ZkConfig{ +func newVmConfig() vm.ZkConfig { + return vm.ZkConfig{ Config: vm.Config{ Debug: false, Tracer: nil, @@ -255,75 +337,93 @@ func runTest(t *testing.T, blockReader services.FullBlockReader, test vector, er ExtraEips: nil, }, } +} - stateReader := state.NewPlainStateReader(tx) - ibs := state.New(stateReader) - verifyMerkleProof := false +func updateGER(v vector, ibs *state.IntraBlockState, chainCfg *chain.Config) error { + parentRoot := common.Hash{} + deltaTimestamp, err := strconv.ParseUint(v.Txs[0].DeltaTimestamp, 10, 64) + if err != nil { + return fmt.Errorf("could not parse delta timestamp: %v", err) + } + ibs.PreExecuteStateSet(chainCfg, 1, deltaTimestamp, &parentRoot) - if test.Txs[0].Type == 11 { - if test.Txs[0].IndexL1InfoTree != 0 { - verifyMerkleProof = true - } - parentRoot := common.Hash{} - deltaTimestamp, _ := strconv.ParseUint(test.Txs[0].DeltaTimestamp, 10, 64) - ibs.PreExecuteStateSet(chainConfig, 1, deltaTimestamp, &parentRoot) - - // handle writing to the ger manager contract - if test.Txs[0].L1Info != nil { - timestamp, _ := strconv.ParseUint(test.Txs[0].L1Info.Timestamp, 10, 64) - ger := string(test.Txs[0].L1Info.GlobalExitRoot) - blockHash := string(test.Txs[0].L1Info.BlockHash) - - hexutil.Remove0xPrefixIfExists(&ger) - hexutil.Remove0xPrefixIfExists(&blockHash) - - l1info := &zktypes.L1InfoTreeUpdate{ - GER: common.BytesToHash(hexutils.HexToBytes(ger)), - ParentHash: common.BytesToHash(hexutils.HexToBytes(blockHash)), - Timestamp: timestamp, - } - // first check if this ger has already been written - l1BlockHash := ibs.ReadGerManagerL1BlockHash(l1info.GER) - if l1BlockHash == (common.Hash{}) { - // not in the contract so let's write it! - ibs.WriteGerManagerL1BlockHash(l1info.GER, l1info.ParentHash) - } - } + if v.Txs[0].L1Info == nil { + return nil } - batchCollector := vm.NewBatchCounterCollector(test.SmtDepths[0], uint16(test.ForkId), 0.6, false, nil) + // handle writing to the ger manager contract + timestamp, err := strconv.ParseUint(v.Txs[0].L1Info.Timestamp, 10, 64) + if err != nil { + return fmt.Errorf("could not parse timestamp: %v", err) + } + ger := string(v.Txs[0].L1Info.GlobalExitRoot) + blockHash := string(v.Txs[0].L1Info.BlockHash) + + hexutil.Remove0xPrefixIfExists(&ger) + hexutil.Remove0xPrefixIfExists(&blockHash) + + l1info := &zktypes.L1InfoTreeUpdate{ + GER: common.BytesToHash(hexutils.HexToBytes(ger)), + ParentHash: common.BytesToHash(hexutils.HexToBytes(blockHash)), + Timestamp: timestamp, + } + // first check if this ger has already been written + l1BlockHash := ibs.ReadGerManagerL1BlockHash(l1info.GER) + if l1BlockHash == (common.Hash{}) { + // not in the contract so let's write it! + ibs.WriteGerManagerL1BlockHash(l1info.GER, l1info.ParentHash) + } + + return nil +} + +func applyTransactionsToDecodedBlocks( + decodedBlocks []tx.DecodedBatchL2Data, + v vector, + chainCfg *chain.Config, + vmCfg vm.ZkConfig, + header *types.Header, + tx kv.RwTx, + engine consensus.Engine, + ibs *state.IntraBlockState, + batchCollector *vm.BatchCounterCollector, +) error { + sequencerAddress := common.HexToAddress(v.SequencerAddress) + + getHeader := func(hash common.Hash, number uint64) *types.Header { return rawdb.ReadHeader(tx, hash, number) } + blockHashFunc := core.GetHashFn(header, getHeader) blockStarted := false for i, block := range decodedBlocks { for _, transaction := range block.Transactions { - vmCfg.Config.SkipAnalysis = core.SkipAnalysis(chainConfig, header.Number.Uint64()) + vmCfg.Config.SkipAnalysis = core.SkipAnalysis(chainCfg, header.Number.Uint64()) - blockContext := core.NewEVMBlockContext(header, blockHashFunc, engine, &sequencer) + blockContext := core.NewEVMBlockContext(header, blockHashFunc, engine, &sequencerAddress) if !blockStarted { overflow, err := batchCollector.StartNewBlock(false) if err != nil { - t.Fatal(err) + return fmt.Errorf("could not start new block: %v", err) } if overflow { - t.Fatal("unexpected overflow") + return fmt.Errorf("unexpected overflow") } blockStarted = true } - txCounters := vm.NewTransactionCounter(transaction, test.SmtDepths[i], uint16(test.ForkId), 0.6, false) + txCounters := vm.NewTransactionCounter(transaction, v.SmtDepths[i], uint16(v.ForkId), 0.6, false) overflow, err := batchCollector.AddNewTransactionCounters(txCounters) if err != nil { - t.Fatal(err) + return fmt.Errorf("could not add new transaction counters: %v", err) } gasPool := new(core.GasPool).AddGas(transactionGasLimit) vmCfg.CounterCollector = txCounters.ExecutionCounters() - evm := vm.NewZkEVM(blockContext, evmtypes.TxContext{}, ibs, chainConfig, vmCfg) + evm := vm.NewZkEVM(blockContext, evmtypes.TxContext{}, ibs, chainCfg, vmCfg) _, result, err := core.ApplyTransaction_zkevm( - chainConfig, + chainCfg, engine, evm, gasPool, @@ -335,30 +435,50 @@ func runTest(t *testing.T, blockReader services.FullBlockReader, test vector, er zktypes.EFFECTIVE_GAS_PRICE_PERCENTAGE_MAXIMUM, true, ) - if err != nil { // this could be deliberate in the test so just move on and note it fmt.Println("err handling tx", err) continue } if overflow { - t.Fatal("unexpected overflow") + return fmt.Errorf("unexpected overflow") } if err = txCounters.ProcessTx(ibs, result.ReturnData); err != nil { - t.Fatal(err) + return fmt.Errorf("could not process tx: %v", err) } batchCollector.UpdateExecutionAndProcessingCountersCache(txCounters) } } - combined, err := batchCollector.CombineCollectors(verifyMerkleProof) + return nil +} + +func newEngine(chainCfg *chain.Config, ethashCfg *ethashcfg.Config, blockReader services.FullBlockReader, db kv.RwDB) consensus.Engine { + logger := log.New() + return ethconsensusconfig.CreateConsensusEngine( + context.Background(), + &nodecfg.Config{Dirs: datadir.New("./datadir")}, + chainCfg, + ethashCfg, + []string{}, + true, + heimdall.NewHeimdallClient("", logger), + true, + blockReader, + db.ReadOnly(), + logger, + ) +} + +func testVirtualCounters(v vector, batchCollector *vm.BatchCounterCollector, shouldVerifyMerkleProof bool) ([]string, error) { + combined, err := batchCollector.CombineCollectors(shouldVerifyMerkleProof) if err != nil { - t.Fatal(err) + return nil, fmt.Errorf("could not combine collectors: %v", err) } - vc := test.VirtualCounters + vc := v.VirtualCounters var errors []string if vc.Keccaks != combined[vm.K].Used() { @@ -385,7 +505,6 @@ func runTest(t *testing.T, blockReader services.FullBlockReader, test vector, er if vc.Steps != combined[vm.S].Used() { errors = append(errors, fmt.Sprintf("S=%v:%v", combined[vm.S].Used(), vc.Steps)) } - if len(errors) > 0 { - t.Errorf("counter mismath in file %s: %s \n", fileName, strings.Join(errors, " ")) - } + + return errors, nil }