diff --git a/crates/blockifier_reexecution/src/main.rs b/crates/blockifier_reexecution/src/main.rs index 261994643d6..f94d9bcc1f0 100644 --- a/crates/blockifier_reexecution/src/main.rs +++ b/crates/blockifier_reexecution/src/main.rs @@ -11,6 +11,7 @@ use blockifier_reexecution::state_reader::utils::{ }; use clap::{Args, Parser, Subcommand}; use starknet_api::block::BlockNumber; +use starknet_api::core::ChainId; use starknet_gateway::config::RpcStateReaderConfig; /// BlockifierReexecution CLI. @@ -24,13 +25,49 @@ pub struct BlockifierReexecutionCliArgs { command: Command, } +#[derive(clap::ValueEnum, Clone, Debug)] +enum SupportedChainId { + Mainnet, + Testnet, + Integration, +} + +impl From for ChainId { + fn from(chain_id: SupportedChainId) -> Self { + match chain_id { + SupportedChainId::Mainnet => Self::Mainnet, + SupportedChainId::Testnet => Self::Sepolia, + SupportedChainId::Integration => Self::IntegrationSepolia, + } + } +} + +#[derive(Debug, Args)] +struct RpcArgs { + /// Node url. + #[clap(long, short = 'n')] + node_url: String, + + /// Optional chain ID (if not provided, it will be guessed from the node url). + #[clap(long, short = 'c')] + chain_id: Option, +} + +impl RpcArgs { + pub(crate) fn parse_chain_id(&self) -> ChainId { + self.chain_id + .clone() + .map(ChainId::from) + .unwrap_or(guess_chain_id_from_node_url(self.node_url.as_str()).unwrap()) + } +} + #[derive(Debug, Subcommand)] enum Command { /// Runs the RPC test. RpcTest { - /// Node url. - #[clap(long, short = 'n')] - node_url: String, + #[clap(flatten)] + rpc_args: RpcArgs, /// Block number. #[clap(long, short = 'b')] @@ -39,9 +76,8 @@ enum Command { /// Writes the RPC queries to json files. WriteRpcRepliesToJson { - /// Node url. - #[clap(long, short = 'n')] - node_url: String, + #[clap(flatten)] + rpc_args: RpcArgs, /// Block number. #[clap(long, short = 'b')] @@ -76,18 +112,21 @@ fn main() { let args = BlockifierReexecutionCliArgs::parse(); match args.command { - Command::RpcTest { node_url, block_number } => { - println!("Running RPC test for block number {block_number} using node url {node_url}.",); + Command::RpcTest { block_number, rpc_args } => { + println!( + "Running RPC test for block number {block_number} using node url {}.", + rpc_args.node_url + ); let config = RpcStateReaderConfig { - url: node_url.clone(), + url: rpc_args.node_url.clone(), json_rpc_version: JSON_RPC_VERSION.to_string(), }; reexecute_and_verify_correctness(ConsecutiveTestStateReaders::new( BlockNumber(block_number - 1), Some(config), - guess_chain_id_from_node_url(node_url.as_str()).unwrap(), + rpc_args.parse_chain_id(), false, )); @@ -96,16 +135,16 @@ fn main() { println!("RPC test passed successfully."); } - Command::WriteRpcRepliesToJson { node_url, block_number, full_file_path } => { + Command::WriteRpcRepliesToJson { block_number, full_file_path, rpc_args } => { let full_file_path = full_file_path.unwrap_or(format!( "./crates/blockifier_reexecution/resources/block_{block_number}/reexecution_data.\ json" )); - let chain_id = guess_chain_id_from_node_url(node_url.as_str()).unwrap(); + let chain_id = rpc_args.parse_chain_id(); // TODO(Aner): refactor to reduce code duplication. let config = RpcStateReaderConfig { - url: node_url, + url: rpc_args.node_url, json_rpc_version: JSON_RPC_VERSION.to_string(), };