From 7a113092f8d740cde13e2ddc251e7a61f6970c03 Mon Sep 17 00:00:00 2001 From: Dori Medini Date: Tue, 12 Nov 2024 10:18:39 +0200 Subject: [PATCH] refactor(blockifier_reexecution): allow overriding chain ID explicitly --- crates/blockifier_reexecution/src/main.rs | 53 ++++++++++++++++++----- 1 file changed, 43 insertions(+), 10 deletions(-) diff --git a/crates/blockifier_reexecution/src/main.rs b/crates/blockifier_reexecution/src/main.rs index 261994643d6..9b6ff4791ef 100644 --- a/crates/blockifier_reexecution/src/main.rs +++ b/crates/blockifier_reexecution/src/main.rs @@ -11,6 +11,7 @@ use blockifier_reexecution::state_reader::utils::{ }; use clap::{Args, Parser, Subcommand}; use starknet_api::block::BlockNumber; +use starknet_api::core::ChainId; use starknet_gateway::config::RpcStateReaderConfig; /// BlockifierReexecution CLI. @@ -24,13 +25,38 @@ pub struct BlockifierReexecutionCliArgs { command: Command, } +#[derive(clap::ValueEnum, Clone, Debug)] +enum SupportedChainId { + Mainnet, + Testnet, +} + +impl From for ChainId { + fn from(chain_id: SupportedChainId) -> Self { + match chain_id { + SupportedChainId::Mainnet => ChainId::Mainnet, + SupportedChainId::Testnet => ChainId::Sepolia, + } + } +} + +#[derive(Debug, Args)] +struct RpcArgs { + /// Node url. + #[clap(long, short = 'n')] + node_url: String, + + /// Optional chain ID (if not provided, it will be guessed from the node url). + #[clap(long, short = 'c')] + chain_id: Option, +} + #[derive(Debug, Subcommand)] enum Command { /// Runs the RPC test. RpcTest { - /// Node url. - #[clap(long, short = 'n')] - node_url: String, + #[clap(flatten)] + rpc_args: RpcArgs, /// Block number. #[clap(long, short = 'b')] @@ -39,9 +65,8 @@ enum Command { /// Writes the RPC queries to json files. WriteRpcRepliesToJson { - /// Node url. - #[clap(long, short = 'n')] - node_url: String, + #[clap(flatten)] + rpc_args: RpcArgs, /// Block number. #[clap(long, short = 'b')] @@ -76,7 +101,7 @@ fn main() { let args = BlockifierReexecutionCliArgs::parse(); match args.command { - Command::RpcTest { node_url, block_number } => { + Command::RpcTest { block_number, rpc_args: RpcArgs { node_url, chain_id } } => { println!("Running RPC test for block number {block_number} using node url {node_url}.",); let config = RpcStateReaderConfig { @@ -87,7 +112,9 @@ fn main() { reexecute_and_verify_correctness(ConsecutiveTestStateReaders::new( BlockNumber(block_number - 1), Some(config), - guess_chain_id_from_node_url(node_url.as_str()).unwrap(), + chain_id + .map(ChainId::from) + .unwrap_or(guess_chain_id_from_node_url(node_url.as_str()).unwrap()), false, )); @@ -96,12 +123,18 @@ fn main() { println!("RPC test passed successfully."); } - Command::WriteRpcRepliesToJson { node_url, block_number, full_file_path } => { + Command::WriteRpcRepliesToJson { + block_number, + full_file_path, + rpc_args: RpcArgs { node_url, chain_id }, + } => { let full_file_path = full_file_path.unwrap_or(format!( "./crates/blockifier_reexecution/resources/block_{block_number}/reexecution_data.\ json" )); - let chain_id = guess_chain_id_from_node_url(node_url.as_str()).unwrap(); + let chain_id = chain_id + .map(ChainId::from) + .unwrap_or(guess_chain_id_from_node_url(node_url.as_str()).unwrap()); // TODO(Aner): refactor to reduce code duplication. let config = RpcStateReaderConfig {