diff --git a/Cargo.lock b/Cargo.lock index b33194aee3..2e1a7412ad 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1537,7 +1537,6 @@ dependencies = [ "blockifier", "cairo-lang-starknet-classes", "cairo-lang-utils", - "cairo-vm", "clap", "flate2", "indexmap 2.6.0", @@ -10444,6 +10443,7 @@ dependencies = [ "futures", "mempool_test_utils", "papyrus_config", + "papyrus_proc_macros", "pretty_assertions", "rstest", "serde", diff --git a/Cargo.toml b/Cargo.toml index 750bb508fa..b13e462cd5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,6 @@ members = [ "crates/l1-provider", "crates/mempool", "crates/mempool_infra", - "crates/mempool_node", "crates/mempool_p2p", "crates/mempool_p2p_types", "crates/mempool_test_utils", @@ -39,6 +38,7 @@ members = [ "crates/papyrus_storage", "crates/papyrus_sync", "crates/papyrus_test_utils", + "crates/sequencer_node", "crates/sequencing/papyrus_consensus", "crates/sequencing/papyrus_consensus_orchestrator", "crates/starknet_api", @@ -218,7 +218,7 @@ starknet_mempool_types = { path = "crates/mempool_types", version = "0.0.0" } starknet_monitoring_endpoint = { path = "crates/monitoring_endpoint", version = "0.0.0" } starknet_patricia = { path = "crates/starknet_patricia", version = "0.0.0" } starknet_sequencer_infra = { path = "crates/mempool_infra", version = "0.0.0" } -starknet_sequencer_node = { path = "crates/mempool_node", version = "0.0.0" } +starknet_sequencer_node = { path = "crates/sequencer_node", version = "0.0.0" } starknet_sierra_compile = { path = "crates/starknet_sierra_compile", version = "0.0.0" } starknet_task_executor = { path = "crates/task_executor", version = "0.0.0" } static_assertions = "1.1.0" diff --git a/config/mempool/default_config.json b/config/mempool/default_config.json index 6c153760d4..0525672a8b 100644 --- a/config/mempool/default_config.json +++ b/config/mempool/default_config.json @@ -81,13 +81,13 @@ }, "batcher_config.block_builder_config.chain_info.fee_token_addresses.eth_fee_token_address": { "description": "Address of the ETH fee token.", - "privacy": "Public", - "value": "0x0" + "pointer_target": "eth_fee_token_address", + "privacy": "Public" }, "batcher_config.block_builder_config.chain_info.fee_token_addresses.strk_fee_token_address": { "description": "Address of the STRK fee token.", - "privacy": "Public", - "value": "0x0" + "pointer_target": "strk_fee_token_address", + "privacy": "Public" }, "batcher_config.block_builder_config.execute_config.concurrency_config.chunk_size": { "description": "The size of the transaction chunk executed in parallel.", @@ -204,10 +204,10 @@ "privacy": "Public", "value": 81920 }, - "components.batcher.execution_mode.LocalExecution.enable_remote_connection": { - "description": "Specifies whether the component, when running locally, allows remote connections.", + "components.batcher.execution_mode": { + "description": "The component execution mode.", "privacy": "Public", - "value": false + "value": "LocalExecutionWithRemoteDisabled" }, "components.batcher.local_server_config.#is_none": { "description": "Flag for an optional field.", @@ -254,10 +254,10 @@ "privacy": "Public", "value": "0.0.0.0:8080" }, - "components.consensus_manager.execution_mode.LocalExecution.enable_remote_connection": { - "description": "Specifies whether the component, when running locally, allows remote connections.", + "components.consensus_manager.execution_mode": { + "description": "The component execution mode.", "privacy": "Public", - "value": false + "value": "LocalExecutionWithRemoteDisabled" }, "components.consensus_manager.local_server_config.#is_none": { "description": "Flag for an optional field.", @@ -304,10 +304,10 @@ "privacy": "Public", "value": "0.0.0.0:8080" }, - "components.gateway.execution_mode.LocalExecution.enable_remote_connection": { - "description": "Specifies whether the component, when running locally, allows remote connections.", + "components.gateway.execution_mode": { + "description": "The component execution mode.", "privacy": "Public", - "value": false + "value": "LocalExecutionWithRemoteDisabled" }, "components.gateway.local_server_config.#is_none": { "description": "Flag for an optional field.", @@ -354,10 +354,10 @@ "privacy": "Public", "value": "0.0.0.0:8080" }, - "components.http_server.execution_mode.LocalExecution.enable_remote_connection": { - "description": "Specifies whether the component, when running locally, allows remote connections.", + "components.http_server.execution_mode": { + "description": "The component execution mode.", "privacy": "Public", - "value": true + "value": "LocalExecutionWithRemoteEnabled" }, "components.http_server.local_server_config.#is_none": { "description": "Flag for an optional field.", @@ -404,10 +404,10 @@ "privacy": "Public", "value": "0.0.0.0:8080" }, - "components.mempool.execution_mode.LocalExecution.enable_remote_connection": { - "description": "Specifies whether the component, when running locally, allows remote connections.", + "components.mempool.execution_mode": { + "description": "The component execution mode.", "privacy": "Public", - "value": false + "value": "LocalExecutionWithRemoteDisabled" }, "components.mempool.local_server_config.#is_none": { "description": "Flag for an optional field.", @@ -454,10 +454,10 @@ "privacy": "Public", "value": "0.0.0.0:8080" }, - "components.mempool_p2p.execution_mode.LocalExecution.enable_remote_connection": { - "description": "Specifies whether the component, when running locally, allows remote connections.", + "components.mempool_p2p.execution_mode": { + "description": "The component execution mode.", "privacy": "Public", - "value": false + "value": "LocalExecutionWithRemoteDisabled" }, "components.mempool_p2p.local_server_config.#is_none": { "description": "Flag for an optional field.", @@ -504,10 +504,10 @@ "privacy": "Public", "value": "0.0.0.0:8080" }, - "components.monitoring_endpoint.execution_mode.LocalExecution.enable_remote_connection": { - "description": "Specifies whether the component, when running locally, allows remote connections.", + "components.monitoring_endpoint.execution_mode": { + "description": "The component execution mode.", "privacy": "Public", - "value": true + "value": "LocalExecutionWithRemoteEnabled" }, "components.monitoring_endpoint.local_server_config.#is_none": { "description": "Flag for an optional field.", @@ -594,6 +594,11 @@ "privacy": "Public", "value": "0x0" }, + "eth_fee_token_address": { + "description": "A required param! Address of the ETH fee token.", + "param_type": "String", + "privacy": "TemporaryValue" + }, "gateway_config.chain_info.chain_id": { "description": "The chain ID of the StarkNet chain.", "pointer_target": "chain_id", @@ -601,13 +606,13 @@ }, "gateway_config.chain_info.fee_token_addresses.eth_fee_token_address": { "description": "Address of the ETH fee token.", - "privacy": "Public", - "value": "0x0" + "pointer_target": "eth_fee_token_address", + "privacy": "Public" }, "gateway_config.chain_info.fee_token_addresses.strk_fee_token_address": { "description": "Address of the STRK fee token.", - "privacy": "Public", - "value": "0x0" + "pointer_target": "strk_fee_token_address", + "privacy": "Public" }, "gateway_config.stateful_tx_validator_config.max_nonce_for_validation_skip": { "description": "Maximum nonce for which the validation is skipped.", @@ -793,5 +798,10 @@ "description": "The url of the rpc server.", "privacy": "Public", "value": "" + }, + "strk_fee_token_address": { + "description": "A required param! Address of the STRK fee token.", + "param_type": "String", + "privacy": "TemporaryValue" } } diff --git a/crates/batcher/src/batcher.rs b/crates/batcher/src/batcher.rs index d8086e31d5..f878d69e9a 100644 --- a/crates/batcher/src/batcher.rs +++ b/crates/batcher/src/batcher.rs @@ -16,7 +16,10 @@ use starknet_batcher_types::batcher_types::{ GetProposalContentInput, GetProposalContentResponse, ProposalId, + SendProposalContentInput, + SendProposalContentResponse, StartHeightInput, + ValidateProposalInput, }; use starknet_batcher_types::errors::BatcherError; use starknet_mempool_types::communication::SharedMempoolClient; @@ -99,6 +102,22 @@ impl Batcher { Ok(()) } + #[instrument(skip(self), err)] + pub async fn validate_proposal( + &mut self, + validate_proposal_input: ValidateProposalInput, + ) -> BatcherResult<()> { + todo!(); + } + + #[instrument(skip(self), err)] + pub async fn send_proposal_content( + &mut self, + send_proposal_content_input: SendProposalContentInput, + ) -> BatcherResult { + todo!(); + } + #[instrument(skip(self), err)] pub async fn get_proposal_content( &mut self, diff --git a/crates/batcher/src/block_builder.rs b/crates/batcher/src/block_builder.rs index 92193f9c26..9c471f5740 100644 --- a/crates/batcher/src/block_builder.rs +++ b/crates/batcher/src/block_builder.rs @@ -145,10 +145,10 @@ impl BlockBuilderTrait for BlockBuilder { tx_provider: Box, output_content_sender: tokio::sync::mpsc::UnboundedSender, ) -> BlockBuilderResult { - let mut should_close_block = false; + let mut block_is_full = false; let mut execution_infos = IndexMap::new(); // TODO(yael 6/10/2024): delete the timeout condition once the executor has a timeout - while !should_close_block && tokio::time::Instant::now() < deadline { + while !block_is_full && tokio::time::Instant::now() < deadline { let next_tx_chunk = tx_provider.get_txs(self.tx_chunk_size).await?; debug!("Got {} transactions from the transaction provider.", next_tx_chunk.len()); if next_tx_chunk.is_empty() { @@ -164,7 +164,7 @@ impl BlockBuilderTrait for BlockBuilder { } let results = self.executor.lock().await.add_txs_to_block(&executor_input_chunk); trace!("Transaction execution results: {:?}", results); - should_close_block = collect_execution_results_and_stream_txs( + block_is_full = collect_execution_results_and_stream_txs( next_tx_chunk, results, &mut execution_infos, diff --git a/crates/batcher/src/communication.rs b/crates/batcher/src/communication.rs index 4a0ee89148..d82b984318 100644 --- a/crates/batcher/src/communication.rs +++ b/crates/batcher/src/communication.rs @@ -35,7 +35,12 @@ impl ComponentRequestHandler for Batcher { BatcherRequest::DecisionReached(input) => { BatcherResponse::DecisionReached(self.decision_reached(input).await) } - _ => unimplemented!(), + BatcherRequest::ValidateProposal(input) => { + BatcherResponse::ValidateProposal(self.validate_proposal(input).await) + } + BatcherRequest::SendProposalContent(input) => { + BatcherResponse::SendProposalContent(self.send_proposal_content(input).await) + } } } } diff --git a/crates/batcher/src/papyrus_state.rs b/crates/batcher/src/papyrus_state.rs index 6ef8c8a0d2..2ee1295f74 100644 --- a/crates/batcher/src/papyrus_state.rs +++ b/crates/batcher/src/papyrus_state.rs @@ -1,6 +1,10 @@ // TODO(yael 22/9/2024): This module is copied from native_blockifier, need to how to share it // between the crates. -use blockifier::execution::contract_class::{ContractClass, ContractClassV0, ContractClassV1}; +use blockifier::execution::contract_class::{ + ContractClassV0, + ContractClassV1, + RunnableContractClass, +}; use blockifier::state::errors::StateError; use blockifier::state::global_cache::GlobalContractCache; use blockifier::state::state_api::{StateReader, StateResult}; @@ -40,7 +44,7 @@ impl PapyrusReader { fn get_compiled_contract_class_inner( &self, class_hash: ClassHash, - ) -> StateResult { + ) -> StateResult { let state_number = StateNumber(self.latest_block); let class_declaration_block_number = self .reader()? @@ -60,7 +64,7 @@ impl PapyrusReader { inconsistent.", ); - return Ok(ContractClass::V1(ContractClassV1::try_from(casm_contract_class)?)); + return Ok(RunnableContractClass::V1(ContractClassV1::try_from(casm_contract_class)?)); } let v0_contract_class = self @@ -118,7 +122,10 @@ impl StateReader for PapyrusReader { } } - fn get_compiled_contract_class(&self, class_hash: ClassHash) -> StateResult { + fn get_compiled_contract_class( + &self, + class_hash: ClassHash, + ) -> StateResult { // Assumption: the global cache is cleared upon reverted blocks. let contract_class = self.global_class_hash_to_class.get(&class_hash); diff --git a/crates/batcher_types/src/batcher_types.rs b/crates/batcher_types/src/batcher_types.rs index 994a631bb0..49b3f3f1af 100644 --- a/crates/batcher_types/src/batcher_types.rs +++ b/crates/batcher_types/src/batcher_types.rs @@ -59,6 +59,7 @@ pub enum GetProposalContent { pub struct ValidateProposalInput { pub proposal_id: ProposalId, pub deadline: chrono::DateTime, + pub retrospective_block_hash: Option, } impl BuildProposalInput { diff --git a/crates/blockifier/src/blockifier/transaction_executor_test.rs b/crates/blockifier/src/blockifier/transaction_executor_test.rs index d739b4662a..4259c207d6 100644 --- a/crates/blockifier/src/blockifier/transaction_executor_test.rs +++ b/crates/blockifier/src/blockifier/transaction_executor_test.rs @@ -20,6 +20,7 @@ use crate::test_utils::contracts::FeatureContract; use crate::test_utils::declare::declare_tx; use crate::test_utils::deploy_account::deploy_account_tx; use crate::test_utils::initial_test_state::test_state; +use crate::test_utils::l1_handler::l1handler_tx; use crate::test_utils::{ create_calldata, maybe_dummy_block_hash_and_number, @@ -38,7 +39,6 @@ use crate::transaction::test_utils::{ TestInitData, }; use crate::transaction::transaction_execution::Transaction; -use crate::transaction::transactions::L1HandlerTransaction; fn tx_executor_test_body( state: CachedState, block_context: BlockContext, @@ -230,7 +230,7 @@ fn test_l1_handler(block_context: BlockContext) { let test_contract = FeatureContract::TestContract(CairoVersion::Cairo1); let state = test_state(&block_context.chain_info, BALANCE, &[(test_contract, 1)]); - let tx = Transaction::L1Handler(L1HandlerTransaction::create_for_testing( + let tx = Transaction::L1Handler(l1handler_tx( Fee(1908000000000000), test_contract.get_instance_address(0), )); diff --git a/crates/blockifier/src/concurrency/fee_utils_test.rs b/crates/blockifier/src/concurrency/fee_utils_test.rs index ba6cc8df95..44619d8cf1 100644 --- a/crates/blockifier/src/concurrency/fee_utils_test.rs +++ b/crates/blockifier/src/concurrency/fee_utils_test.rs @@ -16,20 +16,20 @@ use crate::transaction::objects::FeeType; use crate::transaction::test_utils::{ account_invoke_tx, block_context, - default_l1_resource_bounds, + default_all_resource_bounds, }; #[rstest] pub fn test_fill_sequencer_balance_reads( block_context: BlockContext, - default_l1_resource_bounds: ValidResourceBounds, + default_all_resource_bounds: ValidResourceBounds, #[values(CairoVersion::Cairo0, CairoVersion::Cairo1)] erc20_version: CairoVersion, ) { let account = FeatureContract::AccountWithoutValidations(CairoVersion::Cairo1); let account_tx = account_invoke_tx(invoke_tx_args! { sender_address: account.get_instance_address(0), calldata: create_trivial_calldata(account.get_instance_address(0)), - resource_bounds: default_l1_resource_bounds, + resource_bounds: default_all_resource_bounds, }); let chain_info = &block_context.chain_info; let state = &mut test_state_inner(chain_info, BALANCE, &[(account, 1)], erc20_version); diff --git a/crates/blockifier/src/concurrency/versioned_state.rs b/crates/blockifier/src/concurrency/versioned_state.rs index a6edb590ee..33d97670b9 100644 --- a/crates/blockifier/src/concurrency/versioned_state.rs +++ b/crates/blockifier/src/concurrency/versioned_state.rs @@ -7,7 +7,7 @@ use starknet_types_core::felt::Felt; use crate::concurrency::versioned_storage::VersionedStorage; use crate::concurrency::TxIndex; -use crate::execution::contract_class::ContractClass; +use crate::execution::contract_class::RunnableContractClass; use crate::state::cached_state::{ContractClassMapping, StateMaps}; use crate::state::errors::StateError; use crate::state::state_api::{StateReader, StateResult, UpdatableState}; @@ -34,7 +34,7 @@ pub struct VersionedState { // the compiled contract classes mapping. Each key with value false, sohuld not apprear // in the compiled contract classes mapping. declared_contracts: VersionedStorage, - compiled_contract_classes: VersionedStorage, + compiled_contract_classes: VersionedStorage, } impl VersionedState { @@ -336,7 +336,10 @@ impl StateReader for VersionedStateProxy { } } - fn get_compiled_contract_class(&self, class_hash: ClassHash) -> StateResult { + fn get_compiled_contract_class( + &self, + class_hash: ClassHash, + ) -> StateResult { let mut state = self.state(); match state.compiled_contract_classes.read(self.tx_index, class_hash) { Some(value) => Ok(value), diff --git a/crates/blockifier/src/concurrency/versioned_state_test.rs b/crates/blockifier/src/concurrency/versioned_state_test.rs index 8a69d59a39..261e175f59 100644 --- a/crates/blockifier/src/concurrency/versioned_state_test.rs +++ b/crates/blockifier/src/concurrency/versioned_state_test.rs @@ -72,7 +72,7 @@ fn test_versioned_state_proxy() { let class_hash = class_hash!(27_u8); let another_class_hash = class_hash!(28_u8); let compiled_class_hash = compiled_class_hash!(29_u8); - let contract_class = test_contract.get_class(); + let contract_class = test_contract.get_runnable_class(); // Create the versioned state let cached_state = CachedState::from(DictStateReader { @@ -118,7 +118,8 @@ fn test_versioned_state_proxy() { let class_hash_v7 = class_hash!(28_u8); let class_hash_v10 = class_hash!(29_u8); let compiled_class_hash_v18 = compiled_class_hash!(30_u8); - let contract_class_v11 = FeatureContract::TestContract(CairoVersion::Cairo1).get_class(); + let contract_class_v11 = + FeatureContract::TestContract(CairoVersion::Cairo1).get_runnable_class(); versioned_state_proxys[3].state().apply_writes( 3, @@ -404,7 +405,8 @@ fn test_false_validate_reads_declared_contracts( ..Default::default() }; let version_state_proxy = safe_versioned_state.pin_version(0); - let compiled_contract_calss = FeatureContract::TestContract(CairoVersion::Cairo1).get_class(); + let compiled_contract_calss = + FeatureContract::TestContract(CairoVersion::Cairo1).get_runnable_class(); let class_hash_to_class = HashMap::from([(class_hash!(1_u8), compiled_contract_calss)]); version_state_proxy.state().apply_writes(0, &tx_0_writes, &class_hash_to_class); assert!(!safe_versioned_state.pin_version(1).validate_reads(&tx_1_reads)); @@ -429,7 +431,7 @@ fn test_apply_writes( assert_eq!(transactional_states[0].cache.borrow().writes.class_hashes.len(), 1); // Transaction 0 contract class. - let contract_class_0 = FeatureContract::TestContract(CairoVersion::Cairo1).get_class(); + let contract_class_0 = FeatureContract::TestContract(CairoVersion::Cairo1).get_runnable_class(); assert!(transactional_states[0].class_hash_to_class.borrow().is_empty()); transactional_states[0].set_contract_class(class_hash, contract_class_0.clone()).unwrap(); assert_eq!(transactional_states[0].class_hash_to_class.borrow().len(), 1); @@ -509,7 +511,10 @@ fn test_delete_writes( } // Modify the `class_hash_to_class` member of the CachedState. tx_state - .set_contract_class(feature_contract.get_class_hash(), feature_contract.get_class()) + .set_contract_class( + feature_contract.get_class_hash(), + feature_contract.get_runnable_class(), + ) .unwrap(); safe_versioned_state.pin_version(i).apply_writes( &tx_state.cache.borrow().writes, @@ -568,7 +573,7 @@ fn test_delete_writes_completeness( declared_contracts: HashMap::from([(feature_contract.get_class_hash(), true)]), }; let class_hash_to_class_writes = - HashMap::from([(feature_contract.get_class_hash(), feature_contract.get_class())]); + HashMap::from([(feature_contract.get_class_hash(), feature_contract.get_runnable_class())]); let tx_index = 0; let mut versioned_state_proxy = safe_versioned_state.pin_version(tx_index); @@ -631,9 +636,9 @@ fn test_versioned_proxy_state_flow( transactional_states[3].set_class_hash_at(contract_address, class_hash_3).unwrap(); // Clients contract class values. - let contract_class_0 = FeatureContract::TestContract(CairoVersion::Cairo0).get_class(); + let contract_class_0 = FeatureContract::TestContract(CairoVersion::Cairo0).get_runnable_class(); let contract_class_2 = - FeatureContract::AccountWithLongValidate(CairoVersion::Cairo1).get_class(); + FeatureContract::AccountWithLongValidate(CairoVersion::Cairo1).get_runnable_class(); transactional_states[0].set_contract_class(class_hash, contract_class_0).unwrap(); transactional_states[2].set_contract_class(class_hash, contract_class_2.clone()).unwrap(); diff --git a/crates/blockifier/src/execution/contract_class.rs b/crates/blockifier/src/execution/contract_class.rs index d64a0c8c69..e65cd08787 100644 --- a/crates/blockifier/src/execution/contract_class.rs +++ b/crates/blockifier/src/execution/contract_class.rs @@ -21,7 +21,7 @@ use itertools::Itertools; use semver::Version; use serde::de::Error as DeserializationError; use serde::{Deserialize, Deserializer, Serialize}; -use starknet_api::contract_class::{ContractClass as RawContractClass, EntryPointType}; +use starknet_api::contract_class::{ContractClass, EntryPointType}; use starknet_api::core::EntryPointSelector; use starknet_api::deprecated_contract_class::{ ContractClass as DeprecatedContractClass, @@ -62,46 +62,42 @@ pub enum TrackedResource { /// Represents a runnable Starknet contract class (meaning, the program is runnable by the VM). #[derive(Clone, Debug, Eq, PartialEq, derive_more::From)] -pub enum ContractClass { +pub enum RunnableContractClass { V0(ContractClassV0), V1(ContractClassV1), #[cfg(feature = "cairo_native")] V1Native(NativeContractClassV1), } -impl TryFrom for ContractClass { +impl TryFrom for RunnableContractClass { type Error = ProgramError; - fn try_from(raw_contract_class: RawContractClass) -> Result { - let contract_class: ContractClass = match raw_contract_class { - RawContractClass::V0(raw_contract_class) => { - ContractClass::V0(raw_contract_class.try_into()?) - } - RawContractClass::V1(raw_contract_class) => { - ContractClass::V1(raw_contract_class.try_into()?) - } + fn try_from(raw_contract_class: ContractClass) -> Result { + let contract_class: Self = match raw_contract_class { + ContractClass::V0(raw_contract_class) => Self::V0(raw_contract_class.try_into()?), + ContractClass::V1(raw_contract_class) => Self::V1(raw_contract_class.try_into()?), }; Ok(contract_class) } } -impl ContractClass { +impl RunnableContractClass { pub fn constructor_selector(&self) -> Option { match self { - ContractClass::V0(class) => class.constructor_selector(), - ContractClass::V1(class) => class.constructor_selector(), + Self::V0(class) => class.constructor_selector(), + Self::V1(class) => class.constructor_selector(), #[cfg(feature = "cairo_native")] - ContractClass::V1Native(class) => class.constructor_selector(), + Self::V1Native(class) => class.constructor_selector(), } } pub fn estimate_casm_hash_computation_resources(&self) -> ExecutionResources { match self { - ContractClass::V0(class) => class.estimate_casm_hash_computation_resources(), - ContractClass::V1(class) => class.estimate_casm_hash_computation_resources(), + Self::V0(class) => class.estimate_casm_hash_computation_resources(), + Self::V1(class) => class.estimate_casm_hash_computation_resources(), #[cfg(feature = "cairo_native")] - ContractClass::V1Native(_) => { + Self::V1Native(_) => { todo!("Use casm to estimate casm hash computation resources") } } @@ -112,12 +108,12 @@ impl ContractClass { visited_pcs: &HashSet, ) -> Result, TransactionExecutionError> { match self { - ContractClass::V0(_) => { + Self::V0(_) => { panic!("get_visited_segments is not supported for v0 contracts.") } - ContractClass::V1(class) => class.get_visited_segments(visited_pcs), + Self::V1(class) => class.get_visited_segments(visited_pcs), #[cfg(feature = "cairo_native")] - ContractClass::V1Native(_) => { + Self::V1Native(_) => { panic!("get_visited_segments is not supported for native contracts.") } } @@ -125,10 +121,10 @@ impl ContractClass { pub fn bytecode_length(&self) -> usize { match self { - ContractClass::V0(class) => class.bytecode_length(), - ContractClass::V1(class) => class.bytecode_length(), + Self::V0(class) => class.bytecode_length(), + Self::V1(class) => class.bytecode_length(), #[cfg(feature = "cairo_native")] - ContractClass::V1Native(_) => { + Self::V1Native(_) => { todo!("implement bytecode_length for native contracts.") } } @@ -137,12 +133,10 @@ impl ContractClass { /// Returns whether this contract should run using Cairo steps or Sierra gas. pub fn tracked_resource(&self, min_sierra_version: &CompilerVersion) -> TrackedResource { match self { - ContractClass::V0(_) => TrackedResource::CairoSteps, - ContractClass::V1(contract_class) => { - contract_class.tracked_resource(min_sierra_version) - } + Self::V0(_) => TrackedResource::CairoSteps, + Self::V1(contract_class) => contract_class.tracked_resource(min_sierra_version), #[cfg(feature = "cairo_native")] - ContractClass::V1Native(_) => TrackedResource::SierraGas, + Self::V1Native(_) => TrackedResource::SierraGas, } } } @@ -545,13 +539,16 @@ impl TryFrom for ClassInfo { abi_length, } = class_info; - Ok(Self { contract_class: contract_class.try_into()?, sierra_program_length, abi_length }) + Ok(Self { contract_class: contract_class.clone(), sierra_program_length, abi_length }) } } impl ClassInfo { pub fn bytecode_length(&self) -> usize { - self.contract_class.bytecode_length() + match &self.contract_class { + ContractClass::V0(contract_class) => contract_class.bytecode_length(), + ContractClass::V1(contract_class) => contract_class.bytecode.len(), + } } pub fn contract_class(&self) -> ContractClass { @@ -581,8 +578,6 @@ impl ClassInfo { let (contract_class_version, condition) = match contract_class { ContractClass::V0(_) => (0, sierra_program_length == 0), ContractClass::V1(_) => (1, sierra_program_length > 0), - #[cfg(feature = "cairo_native")] - ContractClass::V1Native(_) => (1, sierra_program_length > 0), }; if condition { diff --git a/crates/blockifier/src/execution/execution_utils.rs b/crates/blockifier/src/execution/execution_utils.rs index 8896f171bb..a17b2a71ba 100644 --- a/crates/blockifier/src/execution/execution_utils.rs +++ b/crates/blockifier/src/execution/execution_utils.rs @@ -30,7 +30,7 @@ use super::errors::{ }; use super::syscalls::hint_processor::ENTRYPOINT_NOT_FOUND_ERROR; use crate::execution::call_info::{CallInfo, Retdata}; -use crate::execution::contract_class::{ContractClass, TrackedResource}; +use crate::execution::contract_class::{RunnableContractClass, TrackedResource}; use crate::execution::entry_point::{ execute_constructor_entry_point, CallEntryPoint, @@ -54,7 +54,7 @@ pub const SEGMENT_ARENA_BUILTIN_SIZE: usize = 3; /// A wrapper for execute_entry_point_call that performs pre and post-processing. pub fn execute_entry_point_call_wrapper( mut call: CallEntryPoint, - contract_class: ContractClass, + contract_class: RunnableContractClass, state: &mut dyn State, resources: &mut ExecutionResources, context: &mut EntryPointExecutionContext, @@ -118,13 +118,13 @@ pub fn execute_entry_point_call_wrapper( /// Executes a specific call to a contract entry point and returns its output. pub fn execute_entry_point_call( call: CallEntryPoint, - contract_class: ContractClass, + contract_class: RunnableContractClass, state: &mut dyn State, resources: &mut ExecutionResources, context: &mut EntryPointExecutionContext, ) -> EntryPointExecutionResult { match contract_class { - ContractClass::V0(contract_class) => { + RunnableContractClass::V0(contract_class) => { deprecated_entry_point_execution::execute_entry_point_call( call, contract_class, @@ -133,15 +133,17 @@ pub fn execute_entry_point_call( context, ) } - ContractClass::V1(contract_class) => entry_point_execution::execute_entry_point_call( - call, - contract_class, - state, - resources, - context, - ), + RunnableContractClass::V1(contract_class) => { + entry_point_execution::execute_entry_point_call( + call, + contract_class, + state, + resources, + context, + ) + } #[cfg(feature = "cairo_native")] - ContractClass::V1Native(contract_class) => { + RunnableContractClass::V1Native(contract_class) => { if context.tracked_resource_stack.last() == Some(&TrackedResource::CairoSteps) { // We cannot run native with cairo steps as the tracked resources (it's a vm // resouorce). diff --git a/crates/blockifier/src/execution/stack_trace.rs b/crates/blockifier/src/execution/stack_trace.rs index 2a52ea831d..c17b8b58ba 100644 --- a/crates/blockifier/src/execution/stack_trace.rs +++ b/crates/blockifier/src/execution/stack_trace.rs @@ -96,6 +96,7 @@ impl From<&VmExceptionFrame> for String { } #[cfg_attr(feature = "transaction_serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(derive_more::From)] pub enum Frame { EntryPoint(EntryPointErrorFrame), Vm(VmExceptionFrame), @@ -112,24 +113,6 @@ impl From<&Frame> for String { } } -impl From for Frame { - fn from(value: EntryPointErrorFrame) -> Self { - Frame::EntryPoint(value) - } -} - -impl From for Frame { - fn from(value: VmExceptionFrame) -> Self { - Frame::Vm(value) - } -} - -impl From for Frame { - fn from(value: String) -> Self { - Frame::StringFrame(value) - } -} - #[cfg_attr(feature = "transaction_serde", derive(serde::Serialize, serde::Deserialize))] #[derive(Default)] pub struct ErrorStack { diff --git a/crates/blockifier/src/execution/syscalls/syscall_tests/library_call.rs b/crates/blockifier/src/execution/syscalls/syscall_tests/library_call.rs index 66849a3634..1d41a94838 100644 --- a/crates/blockifier/src/execution/syscalls/syscall_tests/library_call.rs +++ b/crates/blockifier/src/execution/syscalls/syscall_tests/library_call.rs @@ -147,7 +147,7 @@ fn test_nested_library_call(test_contract: FeatureContract, expected_gas: u64) { }; // The default VersionedConstants is used in the execute_directly call bellow. - let tracked_resource = test_contract.get_class().tracked_resource( + let tracked_resource = test_contract.get_runnable_class().tracked_resource( &VersionedConstants::create_for_testing().min_compiler_version_for_sierra_gas, ); diff --git a/crates/blockifier/src/state/cached_state.rs b/crates/blockifier/src/state/cached_state.rs index ab51525b8d..d9a4a598f6 100644 --- a/crates/blockifier/src/state/cached_state.rs +++ b/crates/blockifier/src/state/cached_state.rs @@ -8,7 +8,7 @@ use starknet_types_core::felt::Felt; use crate::abi::abi_utils::get_fee_token_var_address; use crate::context::TransactionContext; -use crate::execution::contract_class::ContractClass; +use crate::execution::contract_class::RunnableContractClass; use crate::state::errors::StateError; use crate::state::state_api::{State, StateReader, StateResult, UpdatableState}; use crate::transaction::objects::TransactionExecutionInfo; @@ -18,7 +18,7 @@ use crate::utils::{strict_subtract_mappings, subtract_mappings}; #[path = "cached_state_test.rs"] mod test; -pub type ContractClassMapping = HashMap; +pub type ContractClassMapping = HashMap; /// Caches read and write requests. /// @@ -174,7 +174,10 @@ impl StateReader for CachedState { Ok(*class_hash) } - fn get_compiled_contract_class(&self, class_hash: ClassHash) -> StateResult { + fn get_compiled_contract_class( + &self, + class_hash: ClassHash, + ) -> StateResult { let mut cache = self.cache.borrow_mut(); let class_hash_to_class = &mut *self.class_hash_to_class.borrow_mut(); @@ -257,7 +260,7 @@ impl State for CachedState { fn set_contract_class( &mut self, class_hash: ClassHash, - contract_class: ContractClass, + contract_class: RunnableContractClass, ) -> StateResult<()> { self.class_hash_to_class.get_mut().insert(class_hash, contract_class); let mut cache = self.cache.borrow_mut(); @@ -493,7 +496,10 @@ impl<'a, S: StateReader + ?Sized> StateReader for MutRefState<'a, S> { self.0.get_class_hash_at(contract_address) } - fn get_compiled_contract_class(&self, class_hash: ClassHash) -> StateResult { + fn get_compiled_contract_class( + &self, + class_hash: ClassHash, + ) -> StateResult { self.0.get_compiled_contract_class(class_hash) } diff --git a/crates/blockifier/src/state/cached_state_test.rs b/crates/blockifier/src/state/cached_state_test.rs index 05da0ad607..569c38f828 100644 --- a/crates/blockifier/src/state/cached_state_test.rs +++ b/crates/blockifier/src/state/cached_state_test.rs @@ -108,7 +108,7 @@ fn declare_contract() { let mut state = CachedState::from(DictStateReader { ..Default::default() }); let test_contract = FeatureContract::TestContract(CairoVersion::Cairo0); let class_hash = test_contract.get_class_hash(); - let contract_class = test_contract.get_class(); + let contract_class = test_contract.get_runnable_class(); assert_eq!(state.cache.borrow().writes.declared_contracts.get(&class_hash), None); assert_eq!(state.cache.borrow().initial_reads.declared_contracts.get(&class_hash), None); @@ -167,7 +167,7 @@ fn get_contract_class() { let state = test_state(&ChainInfo::create_for_testing(), Fee(0), &[(test_contract, 0)]); assert_eq!( state.get_compiled_contract_class(test_contract.get_class_hash()).unwrap(), - test_contract.get_class() + test_contract.get_runnable_class() ); // Negative flow. @@ -214,7 +214,8 @@ fn cached_state_state_diff_conversion() { // are aligned with. let test_contract = FeatureContract::TestContract(CairoVersion::Cairo0); let test_class_hash = test_contract.get_class_hash(); - let class_hash_to_class = HashMap::from([(test_class_hash, test_contract.get_class())]); + let class_hash_to_class = + HashMap::from([(test_class_hash, test_contract.get_runnable_class())]); let nonce_initial_values = HashMap::new(); @@ -418,7 +419,7 @@ fn test_contract_cache_is_used() { // cache. let test_contract = FeatureContract::TestContract(CairoVersion::Cairo0); let class_hash = test_contract.get_class_hash(); - let contract_class = test_contract.get_class(); + let contract_class = test_contract.get_runnable_class(); let mut reader = DictStateReader::default(); reader.class_hash_to_class.insert(class_hash, contract_class.clone()); let state = CachedState::new(reader); diff --git a/crates/blockifier/src/state/global_cache.rs b/crates/blockifier/src/state/global_cache.rs index 54d71a1fac..670045fe5f 100644 --- a/crates/blockifier/src/state/global_cache.rs +++ b/crates/blockifier/src/state/global_cache.rs @@ -3,10 +3,10 @@ use std::sync::{Arc, Mutex, MutexGuard}; use cached::{Cached, SizedCache}; use starknet_api::core::ClassHash; -use crate::execution::contract_class::ContractClass; +use crate::execution::contract_class::RunnableContractClass; // Note: `ContractClassLRUCache` key-value types must align with `ContractClassMapping`. -type ContractClassLRUCache = SizedCache; +type ContractClassLRUCache = SizedCache; pub type LockedContractClassCache<'a> = MutexGuard<'a, ContractClassLRUCache>; #[derive(Debug, Clone)] // Thread-safe LRU cache for contract classes, optimized for inter-language sharing when @@ -23,11 +23,11 @@ impl GlobalContractCache { self.0.lock().expect("Global contract cache is poisoned.") } - pub fn get(&self, class_hash: &ClassHash) -> Option { + pub fn get(&self, class_hash: &ClassHash) -> Option { self.lock().cache_get(class_hash).cloned() } - pub fn set(&self, class_hash: ClassHash, contract_class: ContractClass) { + pub fn set(&self, class_hash: ClassHash, contract_class: RunnableContractClass) { self.lock().cache_set(class_hash, contract_class); } diff --git a/crates/blockifier/src/state/state_api.rs b/crates/blockifier/src/state/state_api.rs index b6c20b2e45..8d430a1892 100644 --- a/crates/blockifier/src/state/state_api.rs +++ b/crates/blockifier/src/state/state_api.rs @@ -7,7 +7,7 @@ use starknet_types_core::felt::Felt; use super::cached_state::{ContractClassMapping, StateMaps}; use crate::abi::abi_utils::get_fee_token_var_address; use crate::abi::sierra_types::next_storage_key; -use crate::execution::contract_class::ContractClass; +use crate::execution::contract_class::RunnableContractClass; use crate::state::errors::StateError; pub type StateResult = Result; @@ -41,7 +41,10 @@ pub trait StateReader { fn get_class_hash_at(&self, contract_address: ContractAddress) -> StateResult; /// Returns the contract class of the given class hash. - fn get_compiled_contract_class(&self, class_hash: ClassHash) -> StateResult; + fn get_compiled_contract_class( + &self, + class_hash: ClassHash, + ) -> StateResult; /// Returns the compiled class hash of the given class hash. fn get_compiled_class_hash(&self, class_hash: ClassHash) -> StateResult; @@ -94,7 +97,7 @@ pub trait State: StateReader { fn set_contract_class( &mut self, class_hash: ClassHash, - contract_class: ContractClass, + contract_class: RunnableContractClass, ) -> StateResult<()>; /// Sets the given compiled class hash under the given class hash. diff --git a/crates/blockifier/src/test_utils.rs b/crates/blockifier/src/test_utils.rs index 7d6d0256f6..fe8993dd06 100644 --- a/crates/blockifier/src/test_utils.rs +++ b/crates/blockifier/src/test_utils.rs @@ -5,6 +5,7 @@ pub mod deploy_account; pub mod dict_state_reader; pub mod initial_test_state; pub mod invoke; +pub mod l1_handler; pub mod prices; pub mod struct_impls; pub mod syscall; diff --git a/crates/blockifier/src/test_utils/contracts.rs b/crates/blockifier/src/test_utils/contracts.rs index ec84d7ea33..62b400060d 100644 --- a/crates/blockifier/src/test_utils/contracts.rs +++ b/crates/blockifier/src/test_utils/contracts.rs @@ -1,4 +1,5 @@ -use starknet_api::contract_class::EntryPointType; +use cairo_lang_starknet_classes::casm_contract_class::CasmContractClass; +use starknet_api::contract_class::{ContractClass, EntryPointType}; use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, EntryPointSelector}; use starknet_api::deprecated_contract_class::{ ContractClass as DeprecatedContractClass, @@ -11,9 +12,10 @@ use strum_macros::EnumIter; use crate::abi::abi_utils::selector_from_name; use crate::abi::constants::CONSTRUCTOR_ENTRY_POINT_NAME; -use crate::execution::contract_class::{ContractClass, ContractClassV0, ContractClassV1}; +use crate::execution::contract_class::RunnableContractClass; use crate::execution::entry_point::CallEntryPoint; use crate::test_utils::cairo_compile::{cairo0_compile, cairo1_compile}; +use crate::test_utils::struct_impls::LoadContractFromFile; use crate::test_utils::{get_raw_contract_class, CairoVersion}; // This file contains featured contracts, used for tests. Use the function 'test_state' in @@ -150,11 +152,19 @@ impl FeatureContract { pub fn get_class(&self) -> ContractClass { match self.cairo_version() { - CairoVersion::Cairo0 => ContractClassV0::from_file(&self.get_compiled_path()).into(), - CairoVersion::Cairo1 => ContractClassV1::from_file(&self.get_compiled_path()).into(), + CairoVersion::Cairo0 => { + ContractClass::V0(DeprecatedContractClass::from_file(&self.get_compiled_path())) + } + CairoVersion::Cairo1 => { + ContractClass::V1(CasmContractClass::from_file(&self.get_compiled_path())) + } } } + pub fn get_runnable_class(&self) -> RunnableContractClass { + self.get_class().try_into().unwrap() + } + // TODO(Arni, 1/1/2025): Remove this function, and use the get_class function instead. pub fn get_deprecated_contract_class(&self) -> DeprecatedContractClass { let mut raw_contract_class: serde_json::Value = @@ -310,8 +320,8 @@ impl FeatureContract { entry_point_selector: EntryPointSelector, entry_point_type: EntryPointType, ) -> EntryPointOffset { - match self.get_class() { - ContractClass::V0(class) => { + match self.get_runnable_class() { + RunnableContractClass::V0(class) => { class .entry_points_by_type .get(&entry_point_type) @@ -321,7 +331,7 @@ impl FeatureContract { .unwrap() .offset } - ContractClass::V1(class) => { + RunnableContractClass::V1(class) => { class .entry_points_by_type .get_entry_point(&CallEntryPoint { @@ -333,7 +343,7 @@ impl FeatureContract { .offset } #[cfg(feature = "cairo_native")] - ContractClass::V1Native(_) => { + RunnableContractClass::V1Native(_) => { panic!("Not implemented for cairo native contracts") } } diff --git a/crates/blockifier/src/test_utils/dict_state_reader.rs b/crates/blockifier/src/test_utils/dict_state_reader.rs index 54fcd890e2..469e6cbdb7 100644 --- a/crates/blockifier/src/test_utils/dict_state_reader.rs +++ b/crates/blockifier/src/test_utils/dict_state_reader.rs @@ -4,7 +4,7 @@ use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce}; use starknet_api::state::StorageKey; use starknet_types_core::felt::Felt; -use crate::execution::contract_class::ContractClass; +use crate::execution::contract_class::RunnableContractClass; use crate::state::cached_state::StorageEntry; use crate::state::errors::StateError; use crate::state::state_api::{StateReader, StateResult}; @@ -15,7 +15,7 @@ pub struct DictStateReader { pub storage_view: HashMap, pub address_to_nonce: HashMap, pub address_to_class_hash: HashMap, - pub class_hash_to_class: HashMap, + pub class_hash_to_class: HashMap, pub class_hash_to_compiled_class_hash: HashMap, } @@ -35,7 +35,10 @@ impl StateReader for DictStateReader { Ok(nonce) } - fn get_compiled_contract_class(&self, class_hash: ClassHash) -> StateResult { + fn get_compiled_contract_class( + &self, + class_hash: ClassHash, + ) -> StateResult { let contract_class = self.class_hash_to_class.get(&class_hash).cloned(); match contract_class { Some(contract_class) => Ok(contract_class), diff --git a/crates/blockifier/src/test_utils/initial_test_state.rs b/crates/blockifier/src/test_utils/initial_test_state.rs index 6962244bfc..82711da2ef 100644 --- a/crates/blockifier/src/test_utils/initial_test_state.rs +++ b/crates/blockifier/src/test_utils/initial_test_state.rs @@ -49,7 +49,7 @@ pub fn test_state_inner( // Declare and deploy account and ERC20 contracts. let erc20 = FeatureContract::ERC20(erc20_contract_version); - class_hash_to_class.insert(erc20.get_class_hash(), erc20.get_class()); + class_hash_to_class.insert(erc20.get_class_hash(), erc20.get_runnable_class()); address_to_class_hash .insert(chain_info.fee_token_address(&FeeType::Eth), erc20.get_class_hash()); address_to_class_hash @@ -58,7 +58,7 @@ pub fn test_state_inner( // Set up the rest of the requested contracts. for (contract, n_instances) in contract_instances.iter() { let class_hash = contract.get_class_hash(); - class_hash_to_class.insert(class_hash, contract.get_class()); + class_hash_to_class.insert(class_hash, contract.get_runnable_class()); for instance in 0..*n_instances { let instance_address = contract.get_instance_address(instance); address_to_class_hash.insert(instance_address, class_hash); diff --git a/crates/blockifier/src/test_utils/l1_handler.rs b/crates/blockifier/src/test_utils/l1_handler.rs new file mode 100644 index 0000000000..0a95567abb --- /dev/null +++ b/crates/blockifier/src/test_utils/l1_handler.rs @@ -0,0 +1,24 @@ +use starknet_api::calldata; +use starknet_api::core::{ContractAddress, Nonce}; +use starknet_api::executable_transaction::L1HandlerTransaction; +use starknet_api::transaction::{Fee, TransactionHash, TransactionVersion}; +use starknet_types_core::felt::Felt; + +use crate::abi::abi_utils::selector_from_name; + +pub fn l1handler_tx(l1_fee: Fee, contract_address: ContractAddress) -> L1HandlerTransaction { + let calldata = calldata![ + Felt::from(0x123), // from_address. + Felt::from(0x876), // key. + Felt::from(0x44) // value. + ]; + let tx = starknet_api::transaction::L1HandlerTransaction { + version: TransactionVersion::ZERO, + nonce: Nonce::default(), + contract_address, + entry_point_selector: selector_from_name("l1_handler_set_value"), + calldata, + }; + let tx_hash = TransactionHash::default(); + L1HandlerTransaction { tx, tx_hash, paid_fee_on_l1: l1_fee } +} diff --git a/crates/blockifier/src/test_utils/struct_impls.rs b/crates/blockifier/src/test_utils/struct_impls.rs index 8db656841a..5c5ba7882b 100644 --- a/crates/blockifier/src/test_utils/struct_impls.rs +++ b/crates/blockifier/src/test_utils/struct_impls.rs @@ -1,15 +1,14 @@ use std::sync::Arc; +use cairo_lang_starknet_classes::casm_contract_class::CasmContractClass; use cairo_vm::vm::runners::cairo_runner::ExecutionResources; use serde_json::Value; use starknet_api::block::{BlockNumber, BlockTimestamp, NonzeroGasPrice}; -use starknet_api::core::{ChainId, ClassHash, ContractAddress, Nonce}; -use starknet_api::transaction::{Fee, TransactionHash, TransactionVersion}; -use starknet_api::{calldata, contract_address}; -use starknet_types_core::felt::Felt; +use starknet_api::contract_address; +use starknet_api::core::{ChainId, ClassHash}; +use starknet_api::deprecated_contract_class::ContractClass as DeprecatedContractClass; use super::update_json_value; -use crate::abi::abi_utils::selector_from_name; use crate::blockifier::block::{BlockInfo, GasPrices}; use crate::bouncer::{BouncerConfig, BouncerWeights, BuiltinCount}; use crate::context::{BlockContext, ChainInfo, FeeTokenAddresses, TransactionContext}; @@ -35,7 +34,6 @@ use crate::test_utils::{ TEST_SEQUENCER_ADDRESS, }; use crate::transaction::objects::{DeprecatedTransactionInfo, TransactionInfo}; -use crate::transaction::transactions::L1HandlerTransaction; use crate::versioned_constants::{ GasCosts, OsConstants, @@ -210,6 +208,17 @@ impl CallExecution { // Contract loaders. +// TODO(Noa): Consider using PathBuf. +pub trait LoadContractFromFile: serde::de::DeserializeOwned { + fn from_file(contract_path: &str) -> Self { + let raw_contract_class = get_raw_contract_class(contract_path); + serde_json::from_str(&raw_contract_class).unwrap() + } +} + +impl LoadContractFromFile for CasmContractClass {} +impl LoadContractFromFile for DeprecatedContractClass {} + impl ContractClassV0 { pub fn from_file(contract_path: &str) -> Self { let raw_contract_class = get_raw_contract_class(contract_path); @@ -224,25 +233,6 @@ impl ContractClassV1 { } } -impl L1HandlerTransaction { - pub fn create_for_testing(l1_fee: Fee, contract_address: ContractAddress) -> Self { - let calldata = calldata![ - Felt::from(0x123), // from_address. - Felt::from(0x876), // key. - Felt::from(0x44) // value. - ]; - let tx = starknet_api::transaction::L1HandlerTransaction { - version: TransactionVersion::ZERO, - nonce: Nonce::default(), - contract_address, - entry_point_selector: selector_from_name("l1_handler_set_value"), - calldata, - }; - let tx_hash = TransactionHash::default(); - Self { tx, tx_hash, paid_fee_on_l1: l1_fee } - } -} - impl BouncerWeights { pub fn create_for_testing(builtin_count: BuiltinCount) -> Self { Self { builtin_count, ..Self::empty() } diff --git a/crates/blockifier/src/transaction/account_transaction.rs b/crates/blockifier/src/transaction/account_transaction.rs index 930aa990eb..76bad0b168 100644 --- a/crates/blockifier/src/transaction/account_transaction.rs +++ b/crates/blockifier/src/transaction/account_transaction.rs @@ -24,7 +24,7 @@ use starknet_types_core::felt::Felt; use crate::abi::abi_utils::selector_from_name; use crate::context::{BlockContext, TransactionContext}; use crate::execution::call_info::CallInfo; -use crate::execution::contract_class::ContractClass; +use crate::execution::contract_class::RunnableContractClass; use crate::execution::entry_point::{CallEntryPoint, CallType, EntryPointExecutionContext}; use crate::execution::stack_trace::extract_trailing_cairo1_revert_trace; use crate::fee::fee_checks::{FeeCheckReportFields, PostExecutionReport}; @@ -959,11 +959,11 @@ impl ValidatableTransaction for AccountTransaction { } } -pub fn is_cairo1(contract_class: &ContractClass) -> bool { +pub fn is_cairo1(contract_class: &RunnableContractClass) -> bool { match contract_class { - ContractClass::V0(_) => false, - ContractClass::V1(_) => true, + RunnableContractClass::V0(_) => false, + RunnableContractClass::V1(_) => true, #[cfg(feature = "cairo_native")] - ContractClass::V1Native(_) => true, + RunnableContractClass::V1Native(_) => true, } } diff --git a/crates/blockifier/src/transaction/account_transactions_test.rs b/crates/blockifier/src/transaction/account_transactions_test.rs index d7a3ad7e41..3d39a48b54 100644 --- a/crates/blockifier/src/transaction/account_transactions_test.rs +++ b/crates/blockifier/src/transaction/account_transactions_test.rs @@ -47,7 +47,6 @@ use crate::abi::abi_utils::{ use crate::check_tx_execution_error_for_invalid_scenario; use crate::context::{BlockContext, TransactionContext}; use crate::execution::call_info::CallInfo; -use crate::execution::contract_class::{ContractClass, ContractClassV1}; use crate::execution::entry_point::EntryPointExecutionContext; use crate::execution::syscalls::SyscallSelector; use crate::fee::fee_utils::{get_fee_by_gas_vector, get_sequencer_balance_keys}; @@ -748,7 +747,7 @@ fn test_fail_declare(block_context: BlockContext, max_fee: Fee) { let TestInitData { mut state, account_address, mut nonce_manager, .. } = create_test_init_data(chain_info, CairoVersion::Cairo0); let class_hash = class_hash!(0xdeadeadeaf72_u128); - let contract_class = ContractClass::V1(ContractClassV1::empty_for_testing()); + let contract_class = FeatureContract::Empty(CairoVersion::Cairo1).get_class(); let next_nonce = nonce_manager.next(account_address); // Cannot fail executing a declare tx unless it's V2 or above, and already declared. @@ -758,7 +757,7 @@ fn test_fail_declare(block_context: BlockContext, max_fee: Fee) { sender_address: account_address, ..Default::default() }; - state.set_contract_class(class_hash, contract_class.clone()).unwrap(); + state.set_contract_class(class_hash, contract_class.clone().try_into().unwrap()).unwrap(); state.set_compiled_class_hash(class_hash, declare_tx.compiled_class_hash).unwrap(); let class_info = calculate_class_info_for_testing(contract_class); let declare_account_tx = AccountTransaction::Declare( @@ -1257,7 +1256,7 @@ fn test_insufficient_max_fee_reverts( #[rstest] fn test_deploy_account_constructor_storage_write( - default_l1_resource_bounds: ValidResourceBounds, + default_all_resource_bounds: ValidResourceBounds, block_context: BlockContext, #[values(CairoVersion::Cairo0, CairoVersion::Cairo1)] cairo_version: CairoVersion, ) { @@ -1275,7 +1274,7 @@ fn test_deploy_account_constructor_storage_write( chain_info, deploy_account_tx_args! { class_hash, - resource_bounds: default_l1_resource_bounds, + resource_bounds: default_all_resource_bounds, constructor_calldata: constructor_calldata.clone(), }, ); @@ -1301,7 +1300,7 @@ fn test_deploy_account_constructor_storage_write( fn test_count_actual_storage_changes( max_fee: Fee, block_context: BlockContext, - default_l1_resource_bounds: ValidResourceBounds, + default_all_resource_bounds: ValidResourceBounds, #[case] version: TransactionVersion, #[case] fee_type: FeeType, #[values(CairoVersion::Cairo0, CairoVersion::Cairo1)] cairo_version: CairoVersion, @@ -1344,7 +1343,7 @@ fn test_count_actual_storage_changes( let mut state = TransactionalState::create_transactional(&mut state); let invoke_args = invoke_tx_args! { max_fee, - resource_bounds: default_l1_resource_bounds, + resource_bounds: default_all_resource_bounds, version, sender_address: account_address, calldata: write_1_calldata, @@ -1482,7 +1481,7 @@ fn test_count_actual_storage_changes( #[case::tx_version_3(TransactionVersion::THREE)] fn test_concurrency_execute_fee_transfer( max_fee: Fee, - default_l1_resource_bounds: ValidResourceBounds, + default_all_resource_bounds: ValidResourceBounds, #[case] version: TransactionVersion, ) { // TODO(Meshi, 01/06/2024): make the test so it will include changes in @@ -1501,7 +1500,7 @@ fn test_concurrency_execute_fee_transfer( sender_address: account.get_instance_address(0), max_fee, calldata: create_trivial_calldata(test_contract.get_instance_address(0)), - resource_bounds: default_l1_resource_bounds, + resource_bounds: default_all_resource_bounds, version }); let fee_type = &account_tx.fee_type(); @@ -1548,7 +1547,7 @@ fn test_concurrency_execute_fee_transfer( sender_address: account.get_instance_address(0), calldata: transfer_calldata, max_fee, - resource_bounds: default_l1_resource_bounds, + resource_bounds: default_all_resource_bounds, }); let execution_result = @@ -1582,7 +1581,7 @@ fn test_concurrency_execute_fee_transfer( #[case::tx_version_3(TransactionVersion::THREE)] fn test_concurrent_fee_transfer_when_sender_is_sequencer( max_fee: Fee, - default_l1_resource_bounds: ValidResourceBounds, + default_all_resource_bounds: ValidResourceBounds, #[case] version: TransactionVersion, ) { let mut block_context = BlockContext::create_for_account_testing(); @@ -1599,7 +1598,7 @@ fn test_concurrent_fee_transfer_when_sender_is_sequencer( max_fee, sender_address: account_address, calldata: create_trivial_calldata(test_contract.get_instance_address(0)), - resource_bounds: default_l1_resource_bounds, + resource_bounds: default_all_resource_bounds, version }); let fee_type = &account_tx.fee_type(); @@ -1717,7 +1716,7 @@ fn test_initial_gas( #[rstest] fn test_revert_in_execute( block_context: BlockContext, - default_l1_resource_bounds: ValidResourceBounds, + default_all_resource_bounds: ValidResourceBounds, ) { let account = FeatureContract::AccountWithoutValidations(CairoVersion::Cairo1); let chain_info = &block_context.chain_info; @@ -1735,7 +1734,7 @@ fn test_revert_in_execute( // Skip validate phase, as we want to test the revert in the execute phase. let validate = false; let tx_execution_info = account_invoke_tx(invoke_tx_args! { - resource_bounds: default_l1_resource_bounds, + resource_bounds: default_all_resource_bounds, ..tx_args }) .execute(state, &block_context, true, validate) @@ -1748,7 +1747,7 @@ fn test_revert_in_execute( #[rstest] fn test_call_contract_that_panics( mut block_context: BlockContext, - default_l1_resource_bounds: ValidResourceBounds, + default_all_resource_bounds: ValidResourceBounds, #[values(true, false)] enable_reverts: bool, #[values("test_revert_helper", "bad_selector")] inner_selector: &str, ) { @@ -1785,7 +1784,7 @@ fn test_call_contract_that_panics( state, &block_context, invoke_tx_args! { - resource_bounds: default_l1_resource_bounds, + resource_bounds: default_all_resource_bounds, ..tx_args }, ) diff --git a/crates/blockifier/src/transaction/test_utils.rs b/crates/blockifier/src/transaction/test_utils.rs index 0dd1bb6bcd..e8a8936760 100644 --- a/crates/blockifier/src/transaction/test_utils.rs +++ b/crates/blockifier/src/transaction/test_utils.rs @@ -1,5 +1,6 @@ use rstest::fixture; use starknet_api::block::GasPrice; +use starknet_api::contract_class::ContractClass; use starknet_api::core::{ClassHash, ContractAddress, Nonce}; use starknet_api::execution_resources::GasAmount; use starknet_api::test_utils::deploy_account::DeployAccountTxArgs; @@ -25,7 +26,7 @@ use strum::IntoEnumIterator; use crate::abi::abi_utils::get_fee_token_var_address; use crate::context::{BlockContext, ChainInfo}; -use crate::execution::contract_class::{ClassInfo, ContractClass}; +use crate::execution::contract_class::ClassInfo; use crate::state::cached_state::CachedState; use crate::state::state_api::State; use crate::test_utils::contracts::FeatureContract; @@ -378,8 +379,6 @@ pub fn calculate_class_info_for_testing(contract_class: ContractClass) -> ClassI let sierra_program_length = match contract_class { ContractClass::V0(_) => 0, ContractClass::V1(_) => 100, - #[cfg(feature = "cairo_native")] - ContractClass::V1Native(_) => 100, }; ClassInfo::new(&contract_class, sierra_program_length, 100).unwrap() } diff --git a/crates/blockifier/src/transaction/transaction_execution.rs b/crates/blockifier/src/transaction/transaction_execution.rs index 5b4a1b5acc..64ca5b5901 100644 --- a/crates/blockifier/src/transaction/transaction_execution.rs +++ b/crates/blockifier/src/transaction/transaction_execution.rs @@ -2,6 +2,7 @@ use std::sync::Arc; use cairo_vm::vm::runners::cairo_runner::ExecutionResources; use starknet_api::core::{calculate_contract_address, ContractAddress, Nonce}; +use starknet_api::executable_transaction::L1HandlerTransaction; use starknet_api::transaction::{Fee, Transaction as StarknetApiTransaction, TransactionHash}; use crate::bouncer::verify_tx_weights_within_max_capacity; @@ -27,7 +28,6 @@ use crate::transaction::transactions::{ ExecutableTransaction, ExecutionFlags, InvokeTransaction, - L1HandlerTransaction, }; // TODO: Move into transaction.rs, makes more sense to be defined there. diff --git a/crates/blockifier/src/transaction/transactions.rs b/crates/blockifier/src/transaction/transactions.rs index 56785195e5..ad8b13df78 100644 --- a/crates/blockifier/src/transaction/transactions.rs +++ b/crates/blockifier/src/transaction/transactions.rs @@ -1,8 +1,9 @@ use std::sync::Arc; use cairo_vm::vm::runners::cairo_runner::ExecutionResources; -use starknet_api::contract_class::EntryPointType; +use starknet_api::contract_class::{ContractClass, EntryPointType}; use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce}; +use starknet_api::executable_transaction::L1HandlerTransaction; use starknet_api::transaction::{ AccountDeploymentData, Calldata, @@ -18,7 +19,7 @@ use starknet_api::transaction::{ use crate::abi::abi_utils::selector_from_name; use crate::context::{BlockContext, TransactionContext}; use crate::execution::call_info::CallInfo; -use crate::execution::contract_class::{ClassInfo, ContractClass}; +use crate::execution::contract_class::ClassInfo; use crate::execution::entry_point::{ CallEntryPoint, CallType, @@ -153,7 +154,8 @@ impl DeclareTransaction { ) -> TransactionExecutionResult { let declare_version = declare_tx.version(); // Verify contract class version. - if !is_cairo1(&class_info.contract_class()) { + // TODO(Noa): Avoid the unnecessary conversion. + if !is_cairo1(&class_info.contract_class().try_into()?) { if declare_version > TransactionVersion::ONE { Err(TransactionExecutionError::ContractClassVersionMismatch { declare_version, @@ -229,7 +231,7 @@ impl DeclareTransaction { match state.get_compiled_contract_class(class_hash) { Err(StateError::UndeclaredClassHash(_)) => { // Class is undeclared; declare it. - state.set_contract_class(class_hash, self.contract_class())?; + state.set_contract_class(class_hash, self.contract_class().try_into()?)?; if let Some(compiled_class_hash) = compiled_class_hash { state.set_compiled_class_hash(class_hash, compiled_class_hash)?; } @@ -263,7 +265,7 @@ impl Executable for DeclareTransaction { // We allow redeclaration of the class for backward compatibility. // In the past, we allowed redeclaration of Cairo 0 contracts since there was // no class commitment (so no need to check if the class is already declared). - state.set_contract_class(class_hash, self.contract_class())?; + state.set_contract_class(class_hash, self.contract_class().try_into()?)?; } } starknet_api::transaction::DeclareTransaction::V2(DeclareTransactionV2 { @@ -557,20 +559,6 @@ impl TransactionInfoCreator for InvokeTransaction { } } -#[derive(Clone, Debug)] -pub struct L1HandlerTransaction { - pub tx: starknet_api::transaction::L1HandlerTransaction, - pub tx_hash: TransactionHash, - pub paid_fee_on_l1: Fee, -} - -impl L1HandlerTransaction { - pub fn payload_size(&self) -> usize { - // The calldata includes the "from" field, which is not a part of the payload. - self.tx.calldata.0.len() - 1 - } -} - impl HasRelatedFeeType for L1HandlerTransaction { fn version(&self) -> TransactionVersion { self.tx.version diff --git a/crates/blockifier/src/transaction/transactions_test.rs b/crates/blockifier/src/transaction/transactions_test.rs index 5498091cc9..ed9ac027de 100644 --- a/crates/blockifier/src/transaction/transactions_test.rs +++ b/crates/blockifier/src/transaction/transactions_test.rs @@ -25,6 +25,7 @@ use starknet_api::transaction::{ Fee, GasVectorComputationMode, L2ToL1Payload, + Resource, ResourceBounds, TransactionSignature, TransactionVersion, @@ -88,6 +89,7 @@ use crate::test_utils::deploy_account::deploy_account_tx; use crate::test_utils::dict_state_reader::DictStateReader; use crate::test_utils::initial_test_state::test_state; use crate::test_utils::invoke::invoke_tx; +use crate::test_utils::l1_handler::l1handler_tx; use crate::test_utils::prices::Prices; use crate::test_utils::{ create_calldata, @@ -102,7 +104,9 @@ use crate::test_utils::{ CURRENT_BLOCK_NUMBER_FOR_VALIDATE, CURRENT_BLOCK_TIMESTAMP, CURRENT_BLOCK_TIMESTAMP_FOR_VALIDATE, + DEFAULT_L1_DATA_GAS_MAX_AMOUNT, DEFAULT_L1_GAS_AMOUNT, + DEFAULT_L2_GAS_MAX_AMOUNT, DEFAULT_STRK_L1_DATA_GAS_PRICE, DEFAULT_STRK_L1_GAS_PRICE, DEFAULT_STRK_L2_GAS_PRICE, @@ -144,7 +148,7 @@ use crate::transaction::test_utils::{ VALID, }; use crate::transaction::transaction_types::TransactionType; -use crate::transaction::transactions::{ExecutableTransaction, L1HandlerTransaction}; +use crate::transaction::transactions::ExecutableTransaction; use crate::versioned_constants::VersionedConstants; use crate::{ check_tx_execution_error_for_custom_hint, @@ -466,7 +470,7 @@ fn test_invoke_tx( let actual_execution_info = account_tx.execute(state, block_context, true, true).unwrap(); let tracked_resource = account_contract - .get_class() + .get_runnable_class() .tracked_resource(&versioned_constants.min_compiler_version_for_sierra_gas); // Build expected validate call info. @@ -1033,7 +1037,7 @@ fn test_max_fee_exceeds_balance( } #[rstest] -fn test_insufficient_new_resource_bounds( +fn test_insufficient_new_resource_bounds_pre_validation( mut block_context: BlockContext, #[values(true, false)] use_kzg_da: bool, #[values(CairoVersion::Cairo0, CairoVersion::Cairo1)] account_cairo_version: CairoVersion, @@ -1166,7 +1170,7 @@ fn test_insufficient_new_resource_bounds( } #[rstest] -fn test_insufficient_resource_bounds( +fn test_insufficient_deprecated_resource_bounds_pre_validation( block_context: BlockContext, #[values(CairoVersion::Cairo0, CairoVersion::Cairo1)] account_cairo_version: CairoVersion, ) { @@ -1252,55 +1256,88 @@ fn test_insufficient_resource_bounds( ); } -// TODO(Aner, 21/01/24) modify test for 4844. #[rstest] +#[case::l1_bounds(default_l1_resource_bounds(), Resource::L1Gas)] +#[case::all_bounds_l1_gas_overdraft(default_all_resource_bounds(), Resource::L1Gas)] +#[case::all_bounds_l2_gas_overdraft(default_all_resource_bounds(), Resource::L2Gas)] +#[case::all_bounds_l1_data_gas_overdraft(default_all_resource_bounds(), Resource::L1DataGas)] fn test_actual_fee_gt_resource_bounds( - block_context: BlockContext, - default_l1_resource_bounds: ValidResourceBounds, + mut block_context: BlockContext, + #[case] resource_bounds: ValidResourceBounds, + #[case] overdraft_resource: Resource, #[values(CairoVersion::Cairo0, CairoVersion::Cairo1)] account_cairo_version: CairoVersion, ) { - let block_context = &block_context; + let block_context = &mut block_context; + block_context.block_info.use_kzg_da = true; + let mut nonce_manager = NonceManager::default(); + let gas_mode = resource_bounds.get_gas_vector_computation_mode(); + let gas_prices = block_context.block_info.gas_prices.get_gas_prices_by_fee_type(&FeeType::Strk); let account_contract = FeatureContract::AccountWithoutValidations(account_cairo_version); let test_contract = FeatureContract::TestContract(CairoVersion::Cairo0); let state = &mut test_state( &block_context.chain_info, BALANCE, - &[(account_contract, 1), (test_contract, 1)], + &[(account_contract, 2), (test_contract, 1)], ); - let invoke_tx_args = invoke_tx_args! { - sender_address: account_contract.get_instance_address(0), - calldata: create_trivial_calldata(test_contract.get_instance_address(0)), - resource_bounds: default_l1_resource_bounds + let sender_address0 = account_contract.get_instance_address(0); + let sender_address1 = account_contract.get_instance_address(1); + let tx_args = invoke_tx_args! { + sender_address: sender_address0, + calldata: create_calldata( + test_contract.get_instance_address(0), "write_a_lot", &[felt!(2_u8), felt!(7_u8)] + ), + resource_bounds, + nonce: nonce_manager.next(sender_address0), }; - let tx = &account_invoke_tx(invoke_tx_args.clone()); - let minimal_l1_gas = - estimate_minimal_gas_vector(block_context, tx, &GasVectorComputationMode::NoL2Gas).l1_gas; - let minimal_resource_bounds = l1_resource_bounds( - minimal_l1_gas, - block_context.block_info.gas_prices.get_l1_gas_price_by_fee_type(&FeeType::Strk).into(), - ); - // The estimated minimal fee is lower than the actual fee. - let invalid_tx = account_invoke_tx( - invoke_tx_args! { resource_bounds: minimal_resource_bounds, ..invoke_tx_args }, - ); + // Execute the tx to compute the final gas costs. + let tx = &account_invoke_tx(tx_args.clone()); + let execution_result = tx.execute(state, block_context, true, true).unwrap(); + let mut actual_gas = execution_result.receipt.gas; + + // Create new gas bounds that are lower than the actual gas. + let (expected_fee, overdraft_resource_bounds) = match gas_mode { + GasVectorComputationMode::NoL2Gas => { + let l1_gas_bound = GasAmount(actual_gas.to_discounted_l1_gas(gas_prices).0 - 1); + ( + GasVector::from_l1_gas(l1_gas_bound).cost(gas_prices), + l1_resource_bounds(l1_gas_bound, gas_prices.l1_gas_price.into()), + ) + } + GasVectorComputationMode::All => { + match overdraft_resource { + Resource::L1Gas => actual_gas.l1_gas.0 -= 1, + Resource::L2Gas => actual_gas.l2_gas.0 -= 1, + Resource::L1DataGas => actual_gas.l1_data_gas.0 -= 1, + } + ( + actual_gas.cost(gas_prices), + ValidResourceBounds::all_bounds_from_vectors(&actual_gas, gas_prices), + ) + } + }; + let invalid_tx = account_invoke_tx(invoke_tx_args! { + sender_address: sender_address1, + resource_bounds: overdraft_resource_bounds, + // To get the same DA cost, write a different value. + calldata: create_calldata( + test_contract.get_instance_address(0), "write_a_lot", &[felt!(2_u8), felt!(8_u8)] + ), + nonce: nonce_manager.next(sender_address1), + }); let execution_result = invalid_tx.execute(state, block_context, true, true).unwrap(); let execution_error = execution_result.revert_error.unwrap(); - // Test error. - assert!(execution_error.starts_with(&format!("Insufficient max {resource}", resource = L1Gas))); - // Test that fee was charged. - let minimal_fee = minimal_l1_gas - .checked_mul( - block_context.block_info.gas_prices.get_l1_gas_price_by_fee_type(&FeeType::Strk).into(), - ) - .unwrap(); - assert_eq!(execution_result.receipt.fee, minimal_fee); + + // Test error and that fee was charged. Should be at most the fee charged in a successful + // execution. + assert!(execution_error.starts_with(&format!("Insufficient max {overdraft_resource}"))); + assert_eq!(execution_result.receipt.fee, expected_fee); } #[rstest] fn test_invalid_nonce( block_context: BlockContext, - default_l1_resource_bounds: ValidResourceBounds, + default_all_resource_bounds: ValidResourceBounds, #[values(CairoVersion::Cairo0, CairoVersion::Cairo1)] account_cairo_version: CairoVersion, ) { let account_contract = FeatureContract::AccountWithoutValidations(account_cairo_version); @@ -1313,7 +1350,7 @@ fn test_invalid_nonce( let valid_invoke_tx_args = invoke_tx_args! { sender_address: account_contract.get_instance_address(0), calldata: create_trivial_calldata(test_contract.get_instance_address(0)), - resource_bounds: default_l1_resource_bounds, + resource_bounds: default_all_resource_bounds, }; let mut transactional_state = TransactionalState::create_transactional(state); @@ -1424,7 +1461,7 @@ fn declare_expected_state_changes_count(version: TransactionVersion) -> StateCha #[case(TransactionVersion::TWO, CairoVersion::Cairo1)] #[case(TransactionVersion::THREE, CairoVersion::Cairo1)] fn test_declare_tx( - default_l1_resource_bounds: ValidResourceBounds, + default_all_resource_bounds: ValidResourceBounds, #[values(CairoVersion::Cairo0, CairoVersion::Cairo1)] account_cairo_version: CairoVersion, #[case] tx_version: TransactionVersion, #[case] empty_contract_version: CairoVersion, @@ -1455,7 +1492,7 @@ fn test_declare_tx( max_fee: MAX_FEE, sender_address, version: tx_version, - resource_bounds: default_l1_resource_bounds, + resource_bounds: default_all_resource_bounds, class_hash, compiled_class_hash, nonce: nonce_manager.next(sender_address), @@ -1472,6 +1509,7 @@ fn test_declare_tx( let fee_type = &account_tx.fee_type(); let tx_context = &block_context.to_tx_context(&account_tx); let actual_execution_info = account_tx.execute(state, block_context, true, true).unwrap(); + assert_eq!(actual_execution_info.revert_error, None); // Build expected validate call info. let expected_validate_call_info = declare_validate_callinfo( @@ -1481,10 +1519,10 @@ fn test_declare_tx( account.get_class_hash(), sender_address, account - .get_class() + .get_runnable_class() .tracked_resource(&versioned_constants.min_compiler_version_for_sierra_gas), if tx_version >= TransactionVersion::THREE { - user_initial_gas_from_bounds(default_l1_resource_bounds) + user_initial_gas_from_bounds(default_all_resource_bounds) } else { None }, @@ -1529,7 +1567,7 @@ fn test_declare_tx( let expected_total_gas = expected_actual_resources.to_gas_vector( versioned_constants, use_kzg_da, - &GasVectorComputationMode::NoL2Gas, + &tx_context.get_gas_vector_computation_mode(), ); let expected_execution_info = TransactionExecutionInfo { @@ -1566,7 +1604,7 @@ fn test_declare_tx( // Verify class declaration. let contract_class_from_state = state.get_compiled_contract_class(class_hash).unwrap(); - assert_eq!(contract_class_from_state, class_info.contract_class()); + assert_eq!(contract_class_from_state, class_info.contract_class().try_into().unwrap()); // Checks that redeclaring the same contract fails. let account_tx2 = declare_tx( @@ -1574,7 +1612,7 @@ fn test_declare_tx( max_fee: MAX_FEE, sender_address, version: tx_version, - resource_bounds: default_l1_resource_bounds, + resource_bounds: default_all_resource_bounds, class_hash, compiled_class_hash, nonce: nonce_manager.next(sender_address), @@ -1626,7 +1664,7 @@ fn test_declare_tx_v0(default_l1_resource_bounds: ValidResourceBounds) { fn test_deploy_account_tx( #[values(CairoVersion::Cairo0, CairoVersion::Cairo1)] cairo_version: CairoVersion, #[values(false, true)] use_kzg_da: bool, - default_l1_resource_bounds: ValidResourceBounds, + default_all_resource_bounds: ValidResourceBounds, ) { let block_context = &BlockContext::create_for_account_testing_with_kzg(use_kzg_da); let versioned_constants = &block_context.versioned_constants; @@ -1636,7 +1674,10 @@ fn test_deploy_account_tx( let account_class_hash = account.get_class_hash(); let state = &mut test_state(chain_info, BALANCE, &[(account, 1)]); let deploy_account = deploy_account_tx( - deploy_account_tx_args! { resource_bounds: default_l1_resource_bounds, class_hash: account_class_hash }, + deploy_account_tx_args! { + resource_bounds: default_all_resource_bounds, + class_hash: account_class_hash + }, &mut nonce_manager, ); @@ -1646,6 +1687,7 @@ fn test_deploy_account_tx( let deployed_account_address = deploy_account.contract_address(); let constructor_calldata = deploy_account.constructor_calldata(); let salt = deploy_account.contract_address_salt(); + let user_initial_gas = user_initial_gas_from_bounds(default_all_resource_bounds); // Update the balance of the about to be deployed account contract in the erc20 contract, so it // can pay for the transaction execution. @@ -1677,9 +1719,9 @@ fn test_deploy_account_tx( deployed_account_address, cairo_version, account - .get_class() + .get_runnable_class() .tracked_resource(&versioned_constants.min_compiler_version_for_sierra_gas), - user_initial_gas_from_bounds(default_l1_resource_bounds), + user_initial_gas, ); // Build expected execute call info. @@ -1690,7 +1732,7 @@ fn test_deploy_account_tx( entry_point_type: EntryPointType::Constructor, entry_point_selector: selector_from_name(abi_constants::CONSTRUCTOR_ENTRY_POINT_NAME), storage_address: deployed_account_address, - initial_gas: default_initial_gas_cost(), + initial_gas: user_initial_gas.unwrap_or(GasAmount(default_initial_gas_cost())).0, ..Default::default() }, ..Default::default() @@ -1738,7 +1780,7 @@ fn test_deploy_account_tx( let expected_total_gas = actual_resources.to_gas_vector( &block_context.versioned_constants, block_context.block_info.use_kzg_da, - &GasVectorComputationMode::NoL2Gas, + &tx_context.get_gas_vector_computation_mode(), ); let expected_execution_info = TransactionExecutionInfo { @@ -1779,7 +1821,10 @@ fn test_deploy_account_tx( // Negative flow. // Deploy to an existing address. let deploy_account = deploy_account_tx( - deploy_account_tx_args! { resource_bounds: default_l1_resource_bounds, class_hash: account_class_hash }, + deploy_account_tx_args! { + resource_bounds: default_all_resource_bounds, + class_hash: account_class_hash + }, &mut nonce_manager, ); let account_tx = AccountTransaction::DeployAccount(deploy_account); @@ -1800,7 +1845,7 @@ fn test_deploy_account_tx( #[rstest] fn test_fail_deploy_account_undeclared_class_hash( block_context: BlockContext, - default_l1_resource_bounds: ValidResourceBounds, + default_all_resource_bounds: ValidResourceBounds, ) { let block_context = &block_context; let chain_info = &block_context.chain_info; @@ -1808,7 +1853,9 @@ fn test_fail_deploy_account_undeclared_class_hash( let mut nonce_manager = NonceManager::default(); let undeclared_hash = class_hash!("0xdeadbeef"); let deploy_account = deploy_account_tx( - deploy_account_tx_args! {resource_bounds: default_l1_resource_bounds, class_hash: undeclared_hash }, + deploy_account_tx_args! { + resource_bounds: default_all_resource_bounds, class_hash: undeclared_hash + }, &mut nonce_manager, ); let tx_context = block_context.to_tx_context(&deploy_account); @@ -2038,7 +2085,7 @@ fn test_validate_accounts_tx( #[rstest] fn test_valid_flag( block_context: BlockContext, - default_l1_resource_bounds: ValidResourceBounds, + default_all_resource_bounds: ValidResourceBounds, #[values(CairoVersion::Cairo0, CairoVersion::Cairo1)] account_cairo_version: CairoVersion, #[values(CairoVersion::Cairo0, CairoVersion::Cairo1)] test_contract_cairo_version: CairoVersion, ) { @@ -2054,7 +2101,7 @@ fn test_valid_flag( let account_tx = account_invoke_tx(invoke_tx_args! { sender_address: account_contract.get_instance_address(0), calldata: create_trivial_calldata(test_contract.get_instance_address(0)), - resource_bounds: default_l1_resource_bounds, + resource_bounds: default_all_resource_bounds, }); let actual_execution_info = account_tx.execute(state, block_context, true, false).unwrap(); @@ -2066,7 +2113,7 @@ fn test_valid_flag( #[rstest] fn test_only_query_flag( block_context: BlockContext, - default_l1_resource_bounds: ValidResourceBounds, + default_all_resource_bounds: ValidResourceBounds, #[values(true, false)] only_query: bool, ) { let account_balance = BALANCE; @@ -2096,13 +2143,16 @@ fn test_only_query_flag( ]; let expected_resource_bounds = vec![ - Felt::TWO, // Length of ResourceBounds array. - felt!(L1Gas.to_hex()), // Resource. - felt!(DEFAULT_L1_GAS_AMOUNT.0), // Max amount. - felt!(DEFAULT_STRK_L1_GAS_PRICE.get().0), // Max price per unit. - felt!(L2Gas.to_hex()), // Resource. - Felt::ZERO, // Max amount. - Felt::ZERO, // Max price per unit. + Felt::THREE, // Length of ResourceBounds array. + felt!(L1Gas.to_hex()), // Resource. + felt!(DEFAULT_L1_GAS_AMOUNT.0), // Max amount. + felt!(DEFAULT_STRK_L1_GAS_PRICE.get().0), // Max price per unit. + felt!(L2Gas.to_hex()), // Resource. + felt!(DEFAULT_L2_GAS_MAX_AMOUNT.0), // Max amount. + felt!(DEFAULT_STRK_L2_GAS_PRICE.get().0), // Max price per unit. + felt!(L1DataGas.to_hex()), // Resource. + felt!(DEFAULT_L1_DATA_GAS_MAX_AMOUNT.0), // Max amount. + felt!(DEFAULT_STRK_L1_DATA_GAS_PRICE.get().0), // Max price per unit. ]; let expected_unsupported_fields = vec![ @@ -2150,14 +2200,14 @@ fn test_only_query_flag( ); let invoke_tx = crate::test_utils::invoke::invoke_tx(invoke_tx_args! { calldata: execute_calldata, - resource_bounds: default_l1_resource_bounds, + resource_bounds: default_all_resource_bounds, sender_address, only_query, }); let account_tx = AccountTransaction::Invoke(invoke_tx); let tx_execution_info = account_tx.execute(state, block_context, true, true).unwrap(); - assert!(!tx_execution_info.is_reverted()) + assert_eq!(tx_execution_info.revert_error, None); } #[rstest] @@ -2169,7 +2219,7 @@ fn test_l1_handler(#[values(false, true)] use_kzg_da: bool) { let block_context = &BlockContext::create_for_account_testing_with_kzg(use_kzg_da); let contract_address = test_contract.get_instance_address(0); let versioned_constants = &block_context.versioned_constants; - let tx = L1HandlerTransaction::create_for_testing(Fee(1), contract_address); + let tx = l1handler_tx(Fee(1), contract_address); let calldata = tx.tx.calldata.clone(); let key = calldata.0[1]; let value = calldata.0[2]; @@ -2202,7 +2252,7 @@ fn test_l1_handler(#[values(false, true)] use_kzg_da: bool) { }, accessed_storage_keys: HashSet::from_iter(vec![accessed_storage_key]), tracked_resource: test_contract - .get_class() + .get_runnable_class() .tracked_resource(&versioned_constants.min_compiler_version_for_sierra_gas), ..Default::default() }; @@ -2303,7 +2353,7 @@ fn test_l1_handler(#[values(false, true)] use_kzg_da: bool) { // TODO(Meshi, 15/6/2024): change the l1_handler_set_value cairo function to // always uptade the storage instad. state.set_storage_at(contract_address, StorageKey::try_from(key).unwrap(), Felt::ZERO).unwrap(); - let tx_no_fee = L1HandlerTransaction::create_for_testing(Fee(0), contract_address); + let tx_no_fee = l1handler_tx(Fee(0), contract_address); let error = tx_no_fee.execute(state, block_context, false, true).unwrap_err(); // Do not charge fee as L1Handler's resource bounds (/max fee) is 0. // Today, we check that the paid_fee is positive, no matter what was the actual fee. let expected_actual_fee = @@ -2320,7 +2370,7 @@ fn test_l1_handler(#[values(false, true)] use_kzg_da: bool) { #[rstest] fn test_execute_tx_with_invalid_tx_version( block_context: BlockContext, - default_l1_resource_bounds: ValidResourceBounds, + default_all_resource_bounds: ValidResourceBounds, ) { let cairo_version = CairoVersion::Cairo0; let account = FeatureContract::AccountWithoutValidations(cairo_version); @@ -2335,7 +2385,7 @@ fn test_execute_tx_with_invalid_tx_version( &[felt!(invalid_version)], ); let account_tx = account_invoke_tx(invoke_tx_args! { - resource_bounds: default_l1_resource_bounds, + resource_bounds: default_all_resource_bounds, sender_address: account.get_instance_address(0), calldata, }); @@ -2393,7 +2443,7 @@ fn max_event_data() -> usize { }))] fn test_emit_event_exceeds_limit( block_context: BlockContext, - default_l1_resource_bounds: ValidResourceBounds, + default_all_resource_bounds: ValidResourceBounds, #[case] event_keys: Vec, #[case] event_data: Vec, #[case] n_emitted_events: usize, @@ -2432,7 +2482,7 @@ fn test_emit_event_exceeds_limit( let account_tx = account_invoke_tx(invoke_tx_args! { sender_address: account_contract.get_instance_address(0), calldata: execute_calldata, - resource_bounds: default_l1_resource_bounds, + resource_bounds: default_all_resource_bounds, nonce: nonce!(0_u8), }); let execution_info = account_tx.execute(state, block_context, true, true).unwrap(); diff --git a/crates/blockifier_reexecution/Cargo.toml b/crates/blockifier_reexecution/Cargo.toml index de662f4994..528997ecd9 100644 --- a/crates/blockifier_reexecution/Cargo.toml +++ b/crates/blockifier_reexecution/Cargo.toml @@ -12,7 +12,6 @@ blockifier_regression_https_testing = [] blockifier.workspace = true cairo-lang-starknet-classes.workspace = true cairo-lang-utils.workspace = true -cairo-vm.workspace = true clap = { workspace = true, features = ["cargo", "derive"] } flate2.workspace = true indexmap = { workspace = true, features = ["serde"] } diff --git a/crates/blockifier_reexecution/src/state_reader/compile.rs b/crates/blockifier_reexecution/src/state_reader/compile.rs index 96ab2675d1..4088dcea60 100644 --- a/crates/blockifier_reexecution/src/state_reader/compile.rs +++ b/crates/blockifier_reexecution/src/state_reader/compile.rs @@ -4,18 +4,20 @@ use std::collections::HashMap; use std::io::{self, Read}; -use std::sync::Arc; -use blockifier::execution::contract_class::{ContractClass, ContractClassV0, ContractClassV0Inner}; use blockifier::state::state_api::StateResult; use cairo_lang_starknet_classes::contract_class::ContractEntryPoints; use cairo_lang_utils::bigint::BigUintAsHex; -use cairo_vm::types::program::Program; use flate2::bufread; use serde::Deserialize; -use starknet_api::contract_class::EntryPointType; +use starknet_api::contract_class::{ContractClass, EntryPointType}; use starknet_api::core::EntryPointSelector; -use starknet_api::deprecated_contract_class::{EntryPointOffset, EntryPointV0}; +use starknet_api::deprecated_contract_class::{ + ContractClass as DeprecatedContractClass, + EntryPointOffset, + EntryPointV0, + Program, +}; use starknet_api::hash::StarkHash; use starknet_core::types::{ CompressedLegacyContractClass, @@ -90,8 +92,9 @@ pub fn sierra_to_contact_class_v1(sierra: FlattenedSierraClass) -> StateResult StateResult { let as_str = decode_reader(legacy.program).unwrap(); - let program = Program::from_bytes(as_str.as_bytes(), None).unwrap(); + let program: Program = serde_json::from_str(&as_str).unwrap(); let entry_points_by_type = map_entry_points_by_type_legacy(legacy.entry_points_by_type); - let inner = Arc::new(ContractClassV0Inner { program, entry_points_by_type }); - Ok(ContractClass::V0(ContractClassV0(inner))) + Ok((DeprecatedContractClass { program, entry_points_by_type, abi: None }).into()) } diff --git a/crates/blockifier_reexecution/src/state_reader/reexecution_state_reader.rs b/crates/blockifier_reexecution/src/state_reader/reexecution_state_reader.rs index 843a3278d9..e0b1b67734 100644 --- a/crates/blockifier_reexecution/src/state_reader/reexecution_state_reader.rs +++ b/crates/blockifier_reexecution/src/state_reader/reexecution_state_reader.rs @@ -1,4 +1,5 @@ use blockifier::execution::contract_class::ClassInfo; +use blockifier::state::state_api::StateResult; use blockifier::transaction::transaction_execution::Transaction as BlockifierTransaction; use papyrus_execution::DEPRECATED_CONTRACT_SIERRA_SIZE; use starknet_api::core::ClassHash; @@ -10,11 +11,10 @@ use crate::state_reader::errors::ReexecutionError; use crate::state_reader::test_state_reader::ReexecutionResult; pub(crate) trait ReexecutionStateReader { - fn get_contract_class(&self, class_hash: ClassHash) - -> ReexecutionResult; + fn get_contract_class(&self, class_hash: &ClassHash) -> StateResult; fn get_class_info(&self, class_hash: ClassHash) -> ReexecutionResult { - match self.get_contract_class(class_hash)? { + match self.get_contract_class(&class_hash)? { StarknetContractClass::Sierra(sierra) => { let abi_length = sierra.abi.len(); let sierra_length = sierra.sierra_program.len(); diff --git a/crates/blockifier_reexecution/src/state_reader/test_state_reader.rs b/crates/blockifier_reexecution/src/state_reader/test_state_reader.rs index e7902b3c9e..9931019190 100644 --- a/crates/blockifier_reexecution/src/state_reader/test_state_reader.rs +++ b/crates/blockifier_reexecution/src/state_reader/test_state_reader.rs @@ -6,8 +6,8 @@ use blockifier::blockifier::config::TransactionExecutorConfig; use blockifier::blockifier::transaction_executor::TransactionExecutor; use blockifier::bouncer::BouncerConfig; use blockifier::context::BlockContext; -use blockifier::execution::contract_class::ContractClass as BlockifierContractClass; -use blockifier::state::cached_state::{CachedState, CommitmentStateDiff}; +use blockifier::execution::contract_class::RunnableContractClass; +use blockifier::state::cached_state::{CachedState, CommitmentStateDiff, StateMaps}; use blockifier::state::errors::StateError; use blockifier::state::state_api::{StateReader, StateResult}; use blockifier::transaction::transaction_execution::Transaction as BlockifierTransaction; @@ -42,6 +42,14 @@ pub type ReexecutionResult = Result; pub type StarknetContractClassMapping = HashMap; +pub struct OfflineReexecutionData { + state_maps: StateMaps, + contract_class_mapping: StarknetContractClassMapping, + block_context_next_block: BlockContext, + transactions_next_block: Vec, + state_diff_next_block: CommitmentStateDiff, +} + pub struct TestStateReader { rpc_state_reader: RpcStateReader, #[allow(dead_code)] @@ -70,10 +78,14 @@ impl StateReader for TestStateReader { fn get_compiled_contract_class( &self, class_hash: ClassHash, - ) -> StateResult { + ) -> StateResult { match self.get_contract_class(&class_hash)? { - StarknetContractClass::Sierra(sierra) => sierra_to_contact_class_v1(sierra), - StarknetContractClass::Legacy(legacy) => legacy_to_contract_class_v0(legacy), + StarknetContractClass::Sierra(sierra) => { + Ok(sierra_to_contact_class_v1(sierra).unwrap().try_into().unwrap()) + } + StarknetContractClass::Legacy(legacy) => { + Ok(legacy_to_contract_class_v0(legacy).unwrap().try_into().unwrap()) + } } } @@ -158,24 +170,6 @@ impl TestStateReader { )?) } - pub fn get_contract_class(&self, class_hash: &ClassHash) -> StateResult { - let params = json!({ - "block_id": self.rpc_state_reader.block_id, - "class_hash": class_hash.0.to_string(), - }); - let contract_class: StarknetContractClass = serde_json::from_value( - self.rpc_state_reader.send_rpc_request("starknet_getClass", params.clone())?, - ) - .map_err(serde_err_to_state_err)?; - // Create a binding to avoid value being dropped. - let mut dumper_binding = self.contract_class_mapping_dumper.lock().unwrap(); - // If dumper exists, insert the contract class to the mapping. - if let Some(contract_class_mapping_dumper) = dumper_binding.as_mut() { - contract_class_mapping_dumper.insert(*class_hash, contract_class.clone()); - } - Ok(contract_class) - } - pub fn get_all_txs_in_block(&self) -> ReexecutionResult> { // TODO(Aviv): Use batch request to get all txs in a block. self.get_tx_hashes()? @@ -262,10 +256,7 @@ impl TestStateReader { } impl ReexecutionStateReader for TestStateReader { - fn get_contract_class( - &self, - class_hash: ClassHash, - ) -> ReexecutionResult { + fn get_contract_class(&self, class_hash: &ClassHash) -> StateResult { let params = json!({ "block_id": self.rpc_state_reader.block_id, "class_hash": class_hash.0.to_string(), @@ -274,6 +265,12 @@ impl ReexecutionStateReader for TestStateReader { self.rpc_state_reader.send_rpc_request("starknet_getClass", params.clone())?, ) .map_err(serde_err_to_state_err)?; + // Create a binding to avoid value being dropped. + let mut dumper_binding = self.contract_class_mapping_dumper.lock().unwrap(); + // If dumper exists, insert the contract class to the mapping. + if let Some(contract_class_mapping_dumper) = dumper_binding.as_mut() { + contract_class_mapping_dumper.insert(*class_hash, contract_class.clone()); + } Ok(contract_class) } } @@ -337,3 +334,133 @@ impl ConsecutiveStateReaders for ConsecutiveTestStateReaders { self.next_block_state_reader.get_state_diff() } } + +pub struct OfflineStateReader { + pub state_maps: StateMaps, + pub contract_class_mapping: StarknetContractClassMapping, +} + +impl StateReader for OfflineStateReader { + fn get_storage_at( + &self, + contract_address: ContractAddress, + key: StorageKey, + ) -> StateResult { + Ok(*self.state_maps.storage.get(&(contract_address, key)).ok_or( + StateError::StateReadError(format!( + "Missing Storage Value at contract_address: {}, key:{:?}", + contract_address, key + )), + )?) + } + + fn get_nonce_at(&self, contract_address: ContractAddress) -> StateResult { + Ok(*self.state_maps.nonces.get(&contract_address).ok_or(StateError::StateReadError( + format!("Missing nonce at contract_address: {contract_address}"), + ))?) + } + + fn get_class_hash_at(&self, contract_address: ContractAddress) -> StateResult { + Ok(*self.state_maps.class_hashes.get(&contract_address).ok_or( + StateError::StateReadError(format!( + "Missing class hash at contract_address: {contract_address}" + )), + )?) + } + + fn get_compiled_contract_class( + &self, + class_hash: ClassHash, + ) -> StateResult { + match self.get_contract_class(&class_hash)? { + StarknetContractClass::Sierra(sierra) => { + Ok(sierra_to_contact_class_v1(sierra).unwrap().try_into().unwrap()) + } + StarknetContractClass::Legacy(legacy) => { + Ok(legacy_to_contract_class_v0(legacy).unwrap().try_into().unwrap()) + } + } + } + + fn get_compiled_class_hash(&self, class_hash: ClassHash) -> StateResult { + Ok(*self.state_maps.compiled_class_hashes.get(&class_hash).ok_or( + StateError::StateReadError(format!( + "Missing compiled class hash at class hash: {class_hash}" + )), + )?) + } +} + +impl OfflineStateReader { + pub fn get_transaction_executor( + self, + block_context_next_block: BlockContext, + transaction_executor_config: Option, + ) -> ReexecutionResult> { + Ok(TransactionExecutor::::new( + CachedState::new(self), + block_context_next_block, + transaction_executor_config.unwrap_or_default(), + )) + } +} + +impl ReexecutionStateReader for OfflineStateReader { + fn get_contract_class(&self, class_hash: &ClassHash) -> StateResult { + Ok(self + .contract_class_mapping + .get(class_hash) + .ok_or(StateError::StateReadError(format!( + "Missing contract class at class hash: {class_hash}" + )))? + .clone()) + } +} + +pub struct OfflineConsecutiveStateReaders { + pub offline_state_reader_prev_block: OfflineStateReader, + pub block_context_next_block: BlockContext, + pub transactions_next_block: Vec, + pub state_diff_next_block: CommitmentStateDiff, +} + +impl OfflineConsecutiveStateReaders { + // TODO(Aner): create directly from json. + pub fn new( + OfflineReexecutionData { + state_maps, + contract_class_mapping, + block_context_next_block, + transactions_next_block, + state_diff_next_block, + }: OfflineReexecutionData, + ) -> Self { + OfflineConsecutiveStateReaders { + offline_state_reader_prev_block: OfflineStateReader { + state_maps, + contract_class_mapping, + }, + block_context_next_block, + transactions_next_block, + state_diff_next_block, + } + } +} + +impl ConsecutiveStateReaders for OfflineConsecutiveStateReaders { + fn get_transaction_executor( + self, + transaction_executor_config: Option, + ) -> ReexecutionResult> { + self.offline_state_reader_prev_block + .get_transaction_executor(self.block_context_next_block, transaction_executor_config) + } + + fn get_next_block_txs(&self) -> ReexecutionResult> { + Ok(self.transactions_next_block.clone()) + } + + fn get_next_block_state_diff(&self) -> ReexecutionResult { + Ok(self.state_diff_next_block.clone()) + } +} diff --git a/crates/gateway/src/rpc_state_reader.rs b/crates/gateway/src/rpc_state_reader.rs index 60fec240bd..39c325b003 100644 --- a/crates/gateway/src/rpc_state_reader.rs +++ b/crates/gateway/src/rpc_state_reader.rs @@ -1,5 +1,9 @@ use blockifier::blockifier::block::BlockInfo; -use blockifier::execution::contract_class::{ContractClass, ContractClassV0, ContractClassV1}; +use blockifier::execution::contract_class::{ + ContractClassV0, + ContractClassV1, + RunnableContractClass, +}; use blockifier::state::errors::StateError; use blockifier::state::state_api::{StateReader as BlockifierStateReader, StateResult}; use papyrus_rpc::CompiledContractClass; @@ -125,7 +129,10 @@ impl BlockifierStateReader for RpcStateReader { } } - fn get_compiled_contract_class(&self, class_hash: ClassHash) -> StateResult { + fn get_compiled_contract_class( + &self, + class_hash: ClassHash, + ) -> StateResult { let get_compiled_class_params = GetCompiledContractClassParams { class_hash, block_id: self.block_id }; @@ -134,10 +141,10 @@ impl BlockifierStateReader for RpcStateReader { let contract_class: CompiledContractClass = serde_json::from_value(result).map_err(serde_err_to_state_err)?; match contract_class { - CompiledContractClass::V1(contract_class_v1) => Ok(ContractClass::V1( + CompiledContractClass::V1(contract_class_v1) => Ok(RunnableContractClass::V1( ContractClassV1::try_from(contract_class_v1).map_err(StateError::ProgramError)?, )), - CompiledContractClass::V0(contract_class_v0) => Ok(ContractClass::V0( + CompiledContractClass::V0(contract_class_v0) => Ok(RunnableContractClass::V0( ContractClassV0::try_from(contract_class_v0).map_err(StateError::ProgramError)?, )), } diff --git a/crates/gateway/src/rpc_state_reader_test.rs b/crates/gateway/src/rpc_state_reader_test.rs index 54acae03db..c5c53487eb 100644 --- a/crates/gateway/src/rpc_state_reader_test.rs +++ b/crates/gateway/src/rpc_state_reader_test.rs @@ -1,4 +1,4 @@ -use blockifier::execution::contract_class::ContractClass; +use blockifier::execution::contract_class::RunnableContractClass; use blockifier::state::state_api::StateReader; use cairo_lang_starknet_classes::casm_contract_class::CasmContractClass; use papyrus_rpc::CompiledContractClass; @@ -181,7 +181,7 @@ async fn test_get_compiled_contract_class() { .await .unwrap() .unwrap(); - assert_eq!(result, ContractClass::V1(expected_result.try_into().unwrap())); + assert_eq!(result, RunnableContractClass::V1(expected_result.try_into().unwrap())); mock.assert_async().await; } diff --git a/crates/gateway/src/state_reader.rs b/crates/gateway/src/state_reader.rs index 0aa555b050..993fe4b5d5 100644 --- a/crates/gateway/src/state_reader.rs +++ b/crates/gateway/src/state_reader.rs @@ -1,5 +1,5 @@ use blockifier::blockifier::block::BlockInfo; -use blockifier::execution::contract_class::ContractClass; +use blockifier::execution::contract_class::RunnableContractClass; use blockifier::state::errors::StateError; use blockifier::state::state_api::{StateReader as BlockifierStateReader, StateResult}; #[cfg(test)] @@ -45,7 +45,10 @@ impl BlockifierStateReader for Box { self.as_ref().get_class_hash_at(contract_address) } - fn get_compiled_contract_class(&self, class_hash: ClassHash) -> StateResult { + fn get_compiled_contract_class( + &self, + class_hash: ClassHash, + ) -> StateResult { self.as_ref().get_compiled_contract_class(class_hash) } diff --git a/crates/gateway/src/state_reader_test_utils.rs b/crates/gateway/src/state_reader_test_utils.rs index c080a3e0d7..d2b33bc6a5 100644 --- a/crates/gateway/src/state_reader_test_utils.rs +++ b/crates/gateway/src/state_reader_test_utils.rs @@ -1,6 +1,6 @@ use blockifier::blockifier::block::BlockInfo; use blockifier::context::BlockContext; -use blockifier::execution::contract_class::ContractClass; +use blockifier::execution::contract_class::RunnableContractClass; use blockifier::state::errors::StateError; use blockifier::state::state_api::{StateReader as BlockifierStateReader, StateResult}; use blockifier::test_utils::contracts::FeatureContract; @@ -44,7 +44,10 @@ impl BlockifierStateReader for TestStateReader { self.blockifier_state_reader.get_class_hash_at(contract_address) } - fn get_compiled_contract_class(&self, class_hash: ClassHash) -> StateResult { + fn get_compiled_contract_class( + &self, + class_hash: ClassHash, + ) -> StateResult { self.blockifier_state_reader.get_compiled_contract_class(class_hash) } diff --git a/crates/mempool/src/mempool.rs b/crates/mempool/src/mempool.rs index d5d70ded6b..e7a33f4ef9 100644 --- a/crates/mempool/src/mempool.rs +++ b/crates/mempool/src/mempool.rs @@ -109,18 +109,20 @@ impl Mempool { )] pub fn add_tx(&mut self, args: AddTransactionArgs) -> MempoolResult<()> { let AddTransactionArgs { tx, account_state } = args; - self.validate_incoming_tx_nonce(tx.contract_address(), tx.nonce())?; + let tx_ref = TransactionReference::new(&tx); + self.validate_incoming_tx_nonce(tx_ref.address, tx_ref.nonce)?; self.handle_fee_escalation(&tx)?; self.tx_pool.insert(tx)?; // Align to account nonce, only if it is at least the one stored. let AccountState { address, nonce: incoming_account_nonce } = account_state; - match self.account_nonces.get(&address) { - Some(stored_account_nonce) if &incoming_account_nonce < stored_account_nonce => {} - _ => { - self.align_to_account_state(account_state); - } + // TODO(Elin): abstract mempool nonces. + let mempool_account_nonce = self.mempool_state.get(&address).unwrap_or_else(|| { + self.account_nonces.entry(address).or_insert(incoming_account_nonce) + }); + if tx_ref.nonce == *mempool_account_nonce { + self.tx_queue.insert(tx_ref); } Ok(()) @@ -130,12 +132,24 @@ impl Mempool { /// updates account balances). #[tracing::instrument(skip(self, args), err)] pub fn commit_block(&mut self, args: CommitBlockArgs) -> MempoolResult<()> { - let CommitBlockArgs { nonces, tx_hashes } = args; + let CommitBlockArgs { nonces: address_to_nonce, tx_hashes } = args; tracing::debug!("Committing block with {} transactions to mempool.", tx_hashes.len()); // Align mempool data to committed nonces. - for (&address, &nonce) in &nonces { - let next_nonce = try_increment_nonce(nonce)?; + for (&address, &next_nonce) in &address_to_nonce { + // FIXME: Remove after first POC. + // If commit_block wants to decrease the stored account nonce this can mean one of two + // things: + // 1. this is a reorg, which should be handled by a dedicated TBD mechanism and not + // inside commit_block + // 2. the stored nonce originated from add_tx, so should be treated as tentative due + // to possible races with the gateway; these types of nonces should be tagged somehow + // so that commit_block can override them. Regardless, in the first POC this cannot + // happen because the GW nonces are always 1. + if let Some(&stored_nonce) = self.account_nonces.get(&address) { + assert!(stored_nonce <= next_nonce, "NOT SUPPORTED YET {address:?} {next_nonce:?}.") + } + let account_state = AccountState { address, nonce: next_nonce }; self.align_to_account_state(account_state); } @@ -143,7 +157,7 @@ impl Mempool { // Rewind nonces of addresses that were not included in block. let known_addresses_not_included_in_block = - self.mempool_state.keys().filter(|&key| !nonces.contains_key(key)); + self.mempool_state.keys().filter(|&key| !address_to_nonce.contains_key(key)); for address in known_addresses_not_included_in_block { // Account nonce is the minimal nonce of this address: it was proposed but not included. let tx_reference = self diff --git a/crates/mempool/src/mempool_test.rs b/crates/mempool/src/mempool_test.rs index 604b463d0d..c2a410451f 100644 --- a/crates/mempool/src/mempool_test.rs +++ b/crates/mempool/src/mempool_test.rs @@ -102,8 +102,7 @@ impl MempoolContentBuilder { } fn with_fee_escalation_percentage(mut self, fee_escalation_percentage: u8) -> Self { - self.config.enable_fee_escalation = true; - self.config.fee_escalation_percentage = fee_escalation_percentage; + self.config = MempoolConfig { enable_fee_escalation: true, fee_escalation_percentage }; self } @@ -316,7 +315,7 @@ fn test_add_tx(mut mempool: Mempool) { } #[rstest] -fn test_add_tx_multi_nonce_success(mut mempool: Mempool) { +fn test_add_tx_correctly_places_txs_in_queue_and_pool(mut mempool: Mempool) { // Setup. let input_address_0_nonce_0 = add_tx_input!(tx_hash: 1, address: "0x0", tx_nonce: 0, account_nonce: 0); @@ -421,34 +420,6 @@ fn test_add_tx_with_identical_tip_succeeds(mut mempool: Mempool) { expected_mempool_content.assert_eq(&mempool); } -#[rstest] -fn test_add_tx_delete_tx_with_lower_nonce_than_account_nonce() { - // Setup. - let tx_nonce_0_account_nonce_0 = - add_tx_input!(tx_hash: 1, address: "0x0", tx_nonce: 0, account_nonce: 0); - let tx_nonce_1_account_nonce_1 = - add_tx_input!(tx_hash: 2, address: "0x0", tx_nonce: 1, account_nonce: 1); - - let queue_txs = [TransactionReference::new(&tx_nonce_0_account_nonce_0.tx)]; - let pool_txs = [tx_nonce_0_account_nonce_0.tx]; - let mut mempool = MempoolContentBuilder::new() - .with_pool(pool_txs) - .with_priority_queue(queue_txs) - .build_into_mempool(); - - // Test. - add_tx(&mut mempool, &tx_nonce_1_account_nonce_1); - - // Assert the transaction with the lower nonce is removed. - let expected_queue_txs = [TransactionReference::new(&tx_nonce_1_account_nonce_1.tx)]; - let expected_pool_txs = [tx_nonce_1_account_nonce_1.tx]; - let expected_mempool_content = MempoolContentBuilder::new() - .with_pool(expected_pool_txs) - .with_priority_queue(expected_queue_txs) - .build(); - expected_mempool_content.assert_eq(&mempool); -} - #[rstest] fn test_add_tx_tip_priority_over_tx_hash(mut mempool: Mempool) { // Setup. @@ -470,49 +441,6 @@ fn test_add_tx_tip_priority_over_tx_hash(mut mempool: Mempool) { expected_mempool_content.assert_eq(&mempool); } -#[rstest] -fn test_add_tx_account_state_fills_nonce_gap(mut mempool: Mempool) { - // Setup. - let tx_input_nonce_1 = add_tx_input!(tx_hash: 1, tx_nonce: 1, account_nonce: 0); - // Input that increments the account state. - let tx_input_nonce_2 = add_tx_input!(tx_hash: 2, tx_nonce: 2, account_nonce: 1); - - // Test and assert. - - // First, with gap. - add_tx(&mut mempool, &tx_input_nonce_1); - let expected_mempool_content = MempoolContentBuilder::new().with_priority_queue([]).build(); - expected_mempool_content.assert_eq(&mempool); - - // Then, fill it. - add_tx(&mut mempool, &tx_input_nonce_2); - let expected_mempool_content = MempoolContentBuilder::new() - .with_priority_queue([TransactionReference::new(&tx_input_nonce_1.tx)]) - .build(); - expected_mempool_content.assert_eq(&mempool); -} - -#[rstest] -fn test_add_tx_sequential_nonces(mut mempool: Mempool) { - // Setup. - let input_nonce_0 = add_tx_input!(tx_hash: 0, tx_nonce: 0, account_nonce: 0); - let input_nonce_1 = add_tx_input!(tx_hash: 1, tx_nonce: 1, account_nonce: 0); - - // Test. - for input in [&input_nonce_0, &input_nonce_1] { - add_tx(&mut mempool, input); - } - - // Assert: only eligible transaction appears in the queue. - let expected_queue_txs = [TransactionReference::new(&input_nonce_0.tx)]; - let expected_pool_txs = [input_nonce_0.tx, input_nonce_1.tx]; - let expected_mempool_content = MempoolContentBuilder::new() - .with_pool(expected_pool_txs) - .with_priority_queue(expected_queue_txs) - .build(); - expected_mempool_content.assert_eq(&mempool); -} - #[rstest] fn test_add_tx_fills_nonce_gap(mut mempool: Mempool) { // Setup. @@ -553,7 +481,7 @@ fn test_commit_block_includes_all_proposed_txs() { let tx_address_1_nonce_3 = tx!(tx_hash: 5, address: "0x1", tx_nonce: 3); let tx_address_2_nonce_1 = tx!(tx_hash: 6, address: "0x2", tx_nonce: 1); - let queue_txs = [&tx_address_0_nonce_4, &tx_address_1_nonce_3, &tx_address_2_nonce_1] + let queue_txs = [&tx_address_2_nonce_1, &tx_address_1_nonce_3, &tx_address_0_nonce_4] .map(TransactionReference::new); let pool_txs = [ tx_address_0_nonce_3, @@ -569,7 +497,7 @@ fn test_commit_block_includes_all_proposed_txs() { .build_into_mempool(); // Test. - let nonces = [("0x0", 3), ("0x1", 2)]; + let nonces = [("0x0", 4), ("0x1", 3)]; let tx_hashes = [1, 4]; commit_block(&mut mempool, nonces, tx_hashes); @@ -581,80 +509,6 @@ fn test_commit_block_includes_all_proposed_txs() { expected_mempool_content.assert_eq(&mempool); } -#[rstest] -fn test_commit_block_rewinds_queued_nonce() { - // Setup. - let tx_address_0_nonce_3 = tx!(tx_hash: 1, address: "0x0", tx_nonce: 3); - let tx_address_0_nonce_4 = tx!(tx_hash: 2, address: "0x0", tx_nonce: 4); - let tx_address_0_nonce_5 = tx!(tx_hash: 3, address: "0x0", tx_nonce: 5); - let tx_address_1_nonce_1 = tx!(tx_hash: 4, address: "0x1", tx_nonce: 1); - - let queued_txs = [&tx_address_0_nonce_5, &tx_address_1_nonce_1].map(TransactionReference::new); - let pool_txs = [ - tx_address_0_nonce_3, - tx_address_0_nonce_4.clone(), - tx_address_0_nonce_5, - tx_address_1_nonce_1, - ]; - let mut mempool = MempoolContentBuilder::new() - .with_pool(pool_txs) - .with_priority_queue(queued_txs) - .build_into_mempool(); - - // Test. - let nonces = [("0x0", 3), ("0x1", 1)]; - let tx_hashes = [1, 4]; - commit_block(&mut mempool, nonces, tx_hashes); - - // Assert. - let expected_queue_txs = [TransactionReference::new(&tx_address_0_nonce_4)]; - let expected_mempool_content = - MempoolContentBuilder::new().with_priority_queue(expected_queue_txs).build(); - expected_mempool_content.assert_eq(&mempool); -} - -#[rstest] -fn test_commit_block_from_different_leader() { - // Setup. - let tx_address_0_nonce_3 = tx!(tx_hash: 1, address: "0x0", tx_nonce: 3); - let tx_address_0_nonce_5 = tx!(tx_hash: 2, address: "0x0", tx_nonce: 5); - let tx_address_0_nonce_6 = tx!(tx_hash: 3, address: "0x0", tx_nonce: 6); - let tx_address_1_nonce_2 = tx!(tx_hash: 4, address: "0x1", tx_nonce: 2); - - let queued_txs = [TransactionReference::new(&tx_address_1_nonce_2)]; - let pool_txs = [ - tx_address_0_nonce_3, - tx_address_0_nonce_5, - tx_address_0_nonce_6.clone(), - tx_address_1_nonce_2.clone(), - ]; - let mut mempool = MempoolContentBuilder::new() - .with_pool(pool_txs) - .with_priority_queue(queued_txs) - .build_into_mempool(); - - // Test. - let nonces = [ - ("0x0", 5), - ("0x1", 0), // A hole, missing nonce 1 for address "0x1". - ("0x2", 1), - ]; - let tx_hashes = [ - 1, 2, // Hashes known to mempool. - 5, 6, // Hashes unknown to mempool, from a different node. - ]; - commit_block(&mut mempool, nonces, tx_hashes); - - // Assert. - let expected_queue_txs = [TransactionReference::new(&tx_address_0_nonce_6)]; - let expected_pool_txs = [tx_address_0_nonce_6, tx_address_1_nonce_2]; - let expected_mempool_content = MempoolContentBuilder::new() - .with_pool(expected_pool_txs) - .with_priority_queue(expected_queue_txs) - .build(); - expected_mempool_content.assert_eq(&mempool); -} - // Fee escalation tests. #[rstest] diff --git a/crates/mempool/src/transaction_queue_test_utils.rs b/crates/mempool/src/transaction_queue_test_utils.rs index 4c4b3f10b7..38dcf993ad 100644 --- a/crates/mempool/src/transaction_queue_test_utils.rs +++ b/crates/mempool/src/transaction_queue_test_utils.rs @@ -1,18 +1,20 @@ -use std::collections::{BTreeSet, HashMap}; +use std::collections::HashMap; -use pretty_assertions::assert_eq; use starknet_api::block::GasPrice; use crate::mempool::TransactionReference; use crate::transaction_queue::{PendingTransaction, PriorityTransaction, TransactionQueue}; +type OptionalPriorityTransactions = Option>; +type OptionalPendingTransactions = Option>; + /// Represents the internal content of the transaction queue. /// Enables customized (and potentially inconsistent) creation for unit testing. /// Note: gas price threshold is only used for building the (non-test) queue struct. #[derive(Debug, Default)] pub struct TransactionQueueContent { - priority_queue: Option>, - pending_queue: Option>, + priority_queue: OptionalPriorityTransactions, + pending_queue: OptionalPendingTransactions, gas_price_threshold: Option, } @@ -20,11 +22,16 @@ impl TransactionQueueContent { #[track_caller] pub fn assert_eq(&self, tx_queue: &TransactionQueue) { if let Some(priority_queue) = &self.priority_queue { - assert_eq!(&tx_queue.priority_queue, priority_queue); + let expected_priority_txs: Vec<_> = priority_queue.iter().map(|tx| &tx.0).collect(); + let actual_priority_txs: Vec<_> = tx_queue.iter_over_ready_txs().collect(); + assert_eq!(actual_priority_txs, expected_priority_txs); } if let Some(pending_queue) = &self.pending_queue { - assert_eq!(&tx_queue.pending_queue, pending_queue); + let expected_pending_txs: Vec<_> = pending_queue.iter().map(|tx| &tx.0).collect(); + let actual_pending_txs: Vec<_> = + tx_queue.pending_queue.iter().rev().map(|tx| &tx.0).collect(); + assert_eq!(actual_pending_txs, expected_pending_txs); } } @@ -46,14 +53,19 @@ impl TransactionQueueContent { } } - TransactionQueue { priority_queue, pending_queue, address_to_tx, gas_price_threshold } + TransactionQueue { + priority_queue: priority_queue.into_iter().collect(), + pending_queue: pending_queue.into_iter().collect(), + address_to_tx, + gas_price_threshold, + } } } #[derive(Debug, Default)] pub struct TransactionQueueContentBuilder { - priority_queue: Option>, - pending_queue: Option>, + priority_queue: OptionalPriorityTransactions, + pending_queue: OptionalPendingTransactions, gas_price_threshold: Option, } diff --git a/crates/mempool/tests/flow_test.rs b/crates/mempool/tests/flow_test.rs index 04c8a626dc..8236a1625b 100644 --- a/crates/mempool/tests/flow_test.rs +++ b/crates/mempool/tests/flow_test.rs @@ -86,7 +86,7 @@ fn test_add_same_nonce_tx_after_previous_not_included_in_block(mut mempool: Memp &[tx_nonce_3_account_nonce_3.tx, tx_nonce_4_account_nonce_3.tx.clone()], ); - let nonces = [("0x0", 3)]; // Transaction with nonce 4 is not included in the block. + let nonces = [("0x0", 4)]; // Transaction with nonce 3 was included, 4 was not. let tx_hashes = [1]; commit_block(&mut mempool, nonces, tx_hashes); @@ -105,31 +105,54 @@ fn test_add_same_nonce_tx_after_previous_not_included_in_block(mut mempool: Memp ); } +#[rstest] +fn test_add_tx_handles_nonces_correctly(mut mempool: Mempool) { + // Setup. + let input_nonce_0 = add_tx_input!(tx_hash: 1, address: "0x0", tx_nonce: 0, account_nonce: 0); + let input_nonce_1 = add_tx_input!(tx_hash: 2, address: "0x0", tx_nonce: 1, account_nonce: 1); + let input_nonce_2 = add_tx_input!(tx_hash: 3, address: "0x0", tx_nonce: 2, account_nonce: 0); + + // Test. + // Account is registered in mempool. + add_tx(&mut mempool, &input_nonce_0); + // Although the input account nonce is higher, mempool looks at its internal registry. + add_tx(&mut mempool, &input_nonce_1); + get_txs_and_assert_expected(&mut mempool, 2, &[input_nonce_0.tx, input_nonce_1.tx]); + // Although the input account nonce is lower, mempool looks at internal registry. + add_tx(&mut mempool, &input_nonce_2); + get_txs_and_assert_expected(&mut mempool, 1, &[input_nonce_2.tx]); +} + #[rstest] fn test_commit_block_includes_proposed_txs_subset(mut mempool: Mempool) { // Setup. + let tx_address_0_nonce_1 = + add_tx_input!(tx_hash: 1, address: "0x0", tx_nonce: 1, account_nonce: 1); let tx_address_0_nonce_3 = - add_tx_input!(tx_hash: 1, address: "0x0", tx_nonce: 3, account_nonce: 3); - let tx_address_0_nonce_5 = - add_tx_input!(tx_hash: 2, address: "0x0", tx_nonce: 5, account_nonce: 3); - let tx_address_0_nonce_6 = - add_tx_input!(tx_hash: 3, address: "0x0", tx_nonce: 6, account_nonce: 3); - let tx_address_1_nonce_0 = - add_tx_input!(tx_hash: 4, address: "0x1", tx_nonce: 0, account_nonce: 0); - let tx_address_1_nonce_1 = - add_tx_input!(tx_hash: 5, address: "0x1", tx_nonce: 1, account_nonce: 0); + add_tx_input!(tx_hash: 2, address: "0x0", tx_nonce: 3, account_nonce: 1); + let tx_address_0_nonce_4 = + add_tx_input!(tx_hash: 3, address: "0x0", tx_nonce: 4, account_nonce: 1); + let tx_address_1_nonce_2 = - add_tx_input!(tx_hash: 6, address: "0x1", tx_nonce: 2, account_nonce: 0); + add_tx_input!(tx_hash: 4, address: "0x1", tx_nonce: 2, account_nonce: 2); + let tx_address_1_nonce_3 = + add_tx_input!(tx_hash: 5, address: "0x1", tx_nonce: 3, account_nonce: 2); + let tx_address_1_nonce_4 = + add_tx_input!(tx_hash: 6, address: "0x1", tx_nonce: 4, account_nonce: 2); + + let tx_address_2_nonce_1 = + add_tx_input!(tx_hash: 7, address: "0x2", tx_nonce: 1, account_nonce: 1); let tx_address_2_nonce_2 = - add_tx_input!(tx_hash: 7, address: "0x2", tx_nonce: 2, account_nonce: 2); + add_tx_input!(tx_hash: 8, address: "0x2", tx_nonce: 2, account_nonce: 1); for input in [ - &tx_address_0_nonce_5, - &tx_address_0_nonce_6, &tx_address_0_nonce_3, + &tx_address_0_nonce_4, + &tx_address_0_nonce_1, + &tx_address_1_nonce_4, + &tx_address_1_nonce_3, &tx_address_1_nonce_2, - &tx_address_1_nonce_1, - &tx_address_1_nonce_0, + &tx_address_2_nonce_1, &tx_address_2_nonce_2, ] { add_tx(&mut mempool, input); @@ -139,28 +162,33 @@ fn test_commit_block_includes_proposed_txs_subset(mut mempool: Mempool) { get_txs_and_assert_expected( &mut mempool, 2, - &[tx_address_2_nonce_2.tx.clone(), tx_address_1_nonce_0.tx], + &[tx_address_2_nonce_1.tx.clone(), tx_address_1_nonce_2.tx], ); get_txs_and_assert_expected( &mut mempool, - 2, - &[tx_address_1_nonce_1.tx.clone(), tx_address_0_nonce_3.tx], + 4, + &[ + tx_address_2_nonce_2.tx, + tx_address_1_nonce_3.tx.clone(), + tx_address_0_nonce_1.tx, + tx_address_1_nonce_4.tx.clone(), + ], ); - // Not included in block: address "0x2" nonce 2, address "0x1" nonce 1. - let nonces = [("0x0", 3), ("0x1", 0)]; + // Address 0x0 stays as proposed, address 0x1 rewinds nonce 4, address 0x2 rewinds completely. + let nonces = [("0x0", 2), ("0x1", 4)]; let tx_hashes = [1, 4]; commit_block(&mut mempool, nonces, tx_hashes); get_txs_and_assert_expected( &mut mempool, 2, - &[tx_address_2_nonce_2.tx, tx_address_1_nonce_1.tx], + &[tx_address_2_nonce_1.tx, tx_address_1_nonce_4.tx], ); } #[rstest] -fn test_flow_commit_block_fills_nonce_gap(mut mempool: Mempool) { +fn test_commit_block_fills_nonce_gap(mut mempool: Mempool) { // Setup. let tx_nonce_3_account_nonce_3 = add_tx_input!(tx_hash: 1, address: "0x0", tx_nonce: 3, account_nonce: 3); @@ -174,7 +202,7 @@ fn test_flow_commit_block_fills_nonce_gap(mut mempool: Mempool) { get_txs_and_assert_expected(&mut mempool, 2, &[tx_nonce_3_account_nonce_3.tx]); - let nonces = [("0x0", 4)]; + let nonces = [("0x0", 5)]; let tx_hashes = [1, 3]; commit_block(&mut mempool, nonces, tx_hashes); @@ -189,3 +217,58 @@ fn test_flow_commit_block_fills_nonce_gap(mut mempool: Mempool) { get_txs_and_assert_expected(&mut mempool, 2, &[tx_nonce_5_account_nonce_3.tx]); } + +#[rstest] +fn test_flow_commit_block_rewinds_queued_nonce(mut mempool: Mempool) { + // Setup. + let tx_nonce_2 = add_tx_input!(tx_hash: 1, address: "0x0", tx_nonce: 2, account_nonce: 2); + let tx_nonce_3 = add_tx_input!(tx_hash: 2, address: "0x0", tx_nonce: 3, account_nonce: 2); + let tx_nonce_4 = add_tx_input!(tx_hash: 3, address: "0x0", tx_nonce: 4, account_nonce: 2); + + for input in [&tx_nonce_2, &tx_nonce_3, &tx_nonce_4] { + add_tx(&mut mempool, input); + } + + get_txs_and_assert_expected( + &mut mempool, + 3, + &[tx_nonce_2.tx, tx_nonce_3.tx.clone(), tx_nonce_4.tx.clone()], + ); + + // Test. + let nonces = [("0x0", 3)]; + let tx_hashes = [1]; + // Nonce 2 was accepted, but 3 and 4 were not, so are rewound. + commit_block(&mut mempool, nonces, tx_hashes); + + // Nonces 3 and 4 were re-enqueued correctly. + get_txs_and_assert_expected(&mut mempool, 2, &[tx_nonce_3.tx, tx_nonce_4.tx]); +} + +#[rstest] +fn test_flow_commit_block_from_different_leader(mut mempool: Mempool) { + // Setup. + // TODO: set the mempool to `validate` mode once supported. + + let tx_nonce_2 = add_tx_input!(tx_hash: 1, address: "0x0", tx_nonce: 2, account_nonce: 2); + let tx_nonce_3 = add_tx_input!(tx_hash: 2, address: "0x0", tx_nonce: 3, account_nonce: 2); + let tx_nonce_4 = add_tx_input!(tx_hash: 3, address: "0x0", tx_nonce: 4, account_nonce: 2); + + for input in [&tx_nonce_2, &tx_nonce_3, &tx_nonce_4] { + add_tx(&mut mempool, input); + } + + // Test. + let nonces = [("0x0", 4), ("0x1", 2)]; + let tx_hashes = [ + 1, // Address 0: known hash accepted for nonce 2. + 99, // Address 0: unknown hash accepted for nonce 3. + 4, // Unknown Address 1 (with unknown hash) for nonce 2. + ]; + commit_block(&mut mempool, nonces, tx_hashes); + + // Assert: two stale transactions were removed, one was added to a block by a different leader + // and the other "lost" to a different transaction with the same nonce that was added by the + // different leader. + get_txs_and_assert_expected(&mut mempool, 1, &[tx_nonce_4.tx]); +} diff --git a/crates/native_blockifier/src/py_block_executor_test.rs b/crates/native_blockifier/src/py_block_executor_test.rs index 313d5c5c14..0f5b59426d 100644 --- a/crates/native_blockifier/src/py_block_executor_test.rs +++ b/crates/native_blockifier/src/py_block_executor_test.rs @@ -1,7 +1,7 @@ use std::collections::HashMap; use blockifier::blockifier::transaction_executor::BLOCK_STATE_ACCESS_ERR; -use blockifier::execution::contract_class::{ContractClass, ContractClassV1}; +use blockifier::execution::contract_class::{ContractClassV1, RunnableContractClass}; use blockifier::state::state_api::StateReader; use cached::Cached; use cairo_lang_starknet_classes::casm_contract_class::CasmContractClass; @@ -19,7 +19,8 @@ use crate::test_utils::MockStorage; fn global_contract_cache_update() { // Initialize executor and set a contract class on the state. let casm = CasmContractClass { compiler_version: "0.1.0".to_string(), ..Default::default() }; - let contract_class = ContractClass::V1(ContractClassV1::try_from(casm.clone()).unwrap()); + let contract_class = + RunnableContractClass::V1(ContractClassV1::try_from(casm.clone()).unwrap()); let class_hash = class_hash!("0x1"); let temp_storage_path = tempfile::tempdir().unwrap().into_path(); diff --git a/crates/native_blockifier/src/py_l1_handler.rs b/crates/native_blockifier/src/py_l1_handler.rs index 3c72eb7508..3aa9f33d8e 100644 --- a/crates/native_blockifier/src/py_l1_handler.rs +++ b/crates/native_blockifier/src/py_l1_handler.rs @@ -1,9 +1,9 @@ use std::sync::Arc; use blockifier::abi::constants; -use blockifier::transaction::transactions::L1HandlerTransaction; use pyo3::prelude::*; use starknet_api::core::{ContractAddress, EntryPointSelector, Nonce}; +use starknet_api::executable_transaction::L1HandlerTransaction; use starknet_api::transaction::{Calldata, Fee, TransactionHash}; use crate::errors::{NativeBlockifierInputError, NativeBlockifierResult}; diff --git a/crates/native_blockifier/src/py_transaction.rs b/crates/native_blockifier/src/py_transaction.rs index f4c54aa395..5ac4b56df8 100644 --- a/crates/native_blockifier/src/py_transaction.rs +++ b/crates/native_blockifier/src/py_transaction.rs @@ -1,17 +1,13 @@ use std::collections::BTreeMap; -use blockifier::execution::contract_class::{ - ClassInfo, - ContractClass, - ContractClassV0, - ContractClassV1, -}; +use blockifier::execution::contract_class::ClassInfo; use blockifier::transaction::account_transaction::AccountTransaction; use blockifier::transaction::transaction_execution::Transaction; use blockifier::transaction::transaction_types::TransactionType; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use starknet_api::block::GasPrice; +use starknet_api::contract_class::ContractClass; use starknet_api::execution_resources::GasAmount; use starknet_api::transaction::{ DeprecatedResourceBoundsMapping, @@ -170,14 +166,14 @@ impl PyClassInfo { py_class_info: PyClassInfo, tx: &starknet_api::transaction::DeclareTransaction, ) -> NativeBlockifierResult { - let contract_class: ContractClass = match tx { + let contract_class = match tx { starknet_api::transaction::DeclareTransaction::V0(_) | starknet_api::transaction::DeclareTransaction::V1(_) => { - ContractClassV0::try_from_json_string(&py_class_info.raw_contract_class)?.into() + ContractClass::V0(serde_json::from_str(&py_class_info.raw_contract_class)?) } starknet_api::transaction::DeclareTransaction::V2(_) | starknet_api::transaction::DeclareTransaction::V3(_) => { - ContractClassV1::try_from_json_string(&py_class_info.raw_contract_class)?.into() + ContractClass::V1(serde_json::from_str(&py_class_info.raw_contract_class)?) } }; let class_info = ClassInfo::new( diff --git a/crates/native_blockifier/src/state_readers/papyrus_state.rs b/crates/native_blockifier/src/state_readers/papyrus_state.rs index 463d242b0c..bdf38d748c 100644 --- a/crates/native_blockifier/src/state_readers/papyrus_state.rs +++ b/crates/native_blockifier/src/state_readers/papyrus_state.rs @@ -1,4 +1,8 @@ -use blockifier::execution::contract_class::{ContractClass, ContractClassV0, ContractClassV1}; +use blockifier::execution::contract_class::{ + ContractClassV0, + ContractClassV1, + RunnableContractClass, +}; use blockifier::state::errors::StateError; use blockifier::state::global_cache::GlobalContractCache; use blockifier::state::state_api::{StateReader, StateResult}; @@ -43,7 +47,7 @@ impl PapyrusReader { fn get_compiled_contract_class_inner( &self, class_hash: ClassHash, - ) -> StateResult { + ) -> StateResult { let state_number = StateNumber(self.latest_block); let class_declaration_block_number = self .reader()? @@ -63,7 +67,7 @@ impl PapyrusReader { inconsistent.", ); - return Ok(ContractClass::V1(ContractClassV1::try_from(casm_contract_class)?)); + return Ok(RunnableContractClass::V1(ContractClassV1::try_from(casm_contract_class)?)); } let v0_contract_class = self @@ -121,7 +125,10 @@ impl StateReader for PapyrusReader { } } - fn get_compiled_contract_class(&self, class_hash: ClassHash) -> StateResult { + fn get_compiled_contract_class( + &self, + class_hash: ClassHash, + ) -> StateResult { // Assumption: the global cache is cleared upon reverted blocks. let contract_class = self.global_class_hash_to_class.get(&class_hash); diff --git a/crates/native_blockifier/src/state_readers/py_state_reader.rs b/crates/native_blockifier/src/state_readers/py_state_reader.rs index 35d78eea15..f379dcb05b 100644 --- a/crates/native_blockifier/src/state_readers/py_state_reader.rs +++ b/crates/native_blockifier/src/state_readers/py_state_reader.rs @@ -1,4 +1,8 @@ -use blockifier::execution::contract_class::{ContractClass, ContractClassV0, ContractClassV1}; +use blockifier::execution::contract_class::{ + ContractClassV0, + ContractClassV1, + RunnableContractClass, +}; use blockifier::state::errors::StateError; use blockifier::state::state_api::{StateReader, StateResult}; use pyo3::{FromPyObject, PyAny, PyErr, PyObject, PyResult, Python}; @@ -64,8 +68,11 @@ impl StateReader for PyStateReader { .map_err(|err| StateError::StateReadError(err.to_string())) } - fn get_compiled_contract_class(&self, class_hash: ClassHash) -> StateResult { - Python::with_gil(|py| -> Result { + fn get_compiled_contract_class( + &self, + class_hash: ClassHash, + ) -> StateResult { + Python::with_gil(|py| -> Result { let args = (PyFelt::from(class_hash),); let py_raw_compiled_class: PyRawCompiledClass = self .state_reader_proxy @@ -73,7 +80,7 @@ impl StateReader for PyStateReader { .call_method1("get_raw_compiled_class", args)? .extract()?; - Ok(ContractClass::try_from(py_raw_compiled_class)?) + Ok(RunnableContractClass::try_from(py_raw_compiled_class)?) }) .map_err(|err| { if Python::with_gil(|py| err.is_instance_of::(py)) { @@ -103,7 +110,7 @@ pub struct PyRawCompiledClass { pub version: usize, } -impl TryFrom for ContractClass { +impl TryFrom for RunnableContractClass { type Error = NativeBlockifierError; fn try_from(raw_compiled_class: PyRawCompiledClass) -> NativeBlockifierResult { diff --git a/crates/papyrus_common/src/lib.rs b/crates/papyrus_common/src/lib.rs index 12e6eeeab7..3fc0c90bea 100644 --- a/crates/papyrus_common/src/lib.rs +++ b/crates/papyrus_common/src/lib.rs @@ -8,6 +8,7 @@ pub mod pending_classes; pub mod python_json; pub mod state; pub mod storage_query; +pub mod tcp; pub(crate) fn usize_into_felt(u: usize) -> Felt { u128::try_from(u).expect("Expect at most 128 bits").into() diff --git a/crates/papyrus_common/src/tcp.rs b/crates/papyrus_common/src/tcp.rs new file mode 100644 index 0000000000..87199415ce --- /dev/null +++ b/crates/papyrus_common/src/tcp.rs @@ -0,0 +1,18 @@ +use std::net::TcpListener; + +pub fn find_free_port() -> u16 { + // The socket is automatically closed when the function exits. + // The port may still be available when accessed, but this is not guaranteed. + // TODO(Asmaa): find a reliable way to ensure the port stays free. + let listener = TcpListener::bind("0.0.0.0:0").expect("Failed to bind"); + listener.local_addr().expect("Failed to get local address").port() +} + +pub fn find_n_free_ports() -> [u16; N] { + // The socket is automatically closed when the function exits. + // The port may still be available when accessed, but this is not guaranteed. + // TODO(Asmaa): find a reliable way to ensure the port stays free. + let listeners: [TcpListener; N] = + core::array::from_fn(|_i| TcpListener::bind("0.0.0.0:0").expect("Failed to bind")); + core::array::from_fn(|i| listeners[i].local_addr().expect("Failed to get local address").port()) +} diff --git a/crates/papyrus_config/src/dumping.rs b/crates/papyrus_config/src/dumping.rs index 64eb593328..006e76ba4b 100644 --- a/crates/papyrus_config/src/dumping.rs +++ b/crates/papyrus_config/src/dumping.rs @@ -287,10 +287,10 @@ pub fn ser_pointer_target_required_param( ) } -// Takes a config map and a vector of {target param, serialized pointer, and vector of params that -// will point to it}. -// Adds to the map the target params. -// Replaces the value of the pointers to contain only the name of the target they point to. +/// Takes a config map and a vector of target parameters with their serialized representations. +/// Adds each target param to the config map. +/// Updates entries in the map to point to these targets, replacing values of entries that match +/// the target parameter paths to contain only the name of the target they point to. pub(crate) fn combine_config_map_and_pointers( mut config_map: BTreeMap, pointers: &Vec<(ParamPath, SerializedParam)>, diff --git a/crates/papyrus_execution/src/execution_utils.rs b/crates/papyrus_execution/src/execution_utils.rs index 439a903c23..df77e900fe 100644 --- a/crates/papyrus_execution/src/execution_utils.rs +++ b/crates/papyrus_execution/src/execution_utils.rs @@ -5,9 +5,9 @@ use std::path::PathBuf; // Expose the tool for creating entry point selectors from function names. pub use blockifier::abi::abi_utils::selector_from_name; use blockifier::execution::contract_class::{ - ContractClass as BlockifierContractClass, ContractClassV0, ContractClassV1, + RunnableContractClass, }; use blockifier::state::cached_state::{CachedState, CommitmentStateDiff, MutRefState}; use blockifier::state::state_api::StateReader; @@ -59,14 +59,14 @@ pub(crate) fn get_contract_class( txn: &StorageTxn<'_, RO>, class_hash: &ClassHash, state_number: StateNumber, -) -> Result, ExecutionUtilsError> { +) -> Result, ExecutionUtilsError> { match txn.get_state_reader()?.get_class_definition_block_number(class_hash)? { Some(block_number) if state_number.is_before(block_number) => return Ok(None), Some(_block_number) => { let Some(casm) = txn.get_casm(class_hash)? else { return Err(ExecutionUtilsError::CasmTableNotSynced); }; - return Ok(Some(BlockifierContractClass::V1( + return Ok(Some(RunnableContractClass::V1( ContractClassV1::try_from(casm).map_err(ExecutionUtilsError::ProgramError)?, ))); } @@ -78,7 +78,7 @@ pub(crate) fn get_contract_class( else { return Ok(None); }; - Ok(Some(BlockifierContractClass::V0( + Ok(Some(RunnableContractClass::V0( ContractClassV0::try_from(deprecated_class).map_err(ExecutionUtilsError::ProgramError)?, ))) } diff --git a/crates/papyrus_execution/src/lib.rs b/crates/papyrus_execution/src/lib.rs index 6b6253847f..cf393a232e 100644 --- a/crates/papyrus_execution/src/lib.rs +++ b/crates/papyrus_execution/src/lib.rs @@ -27,7 +27,7 @@ use blockifier::blockifier::block::{pre_process_block, BlockInfo, GasPrices}; use blockifier::bouncer::BouncerConfig; use blockifier::context::{BlockContext, ChainInfo, FeeTokenAddresses, TransactionContext}; use blockifier::execution::call_info::CallExecution; -use blockifier::execution::contract_class::{ClassInfo, ContractClass as BlockifierContractClass}; +use blockifier::execution::contract_class::ClassInfo; use blockifier::execution::entry_point::{ CallEntryPoint, CallType as BlockifierCallType, @@ -790,19 +790,16 @@ fn to_blockifier_tx( abi_length, only_query, ) => { - let class_v0 = BlockifierContractClass::V0(deprecated_class.try_into().map_err( - |e: cairo_vm::types::errors::program_errors::ProgramError| { - ExecutionError::TransactionExecutionError { - transaction_index, - execution_error: e.to_string(), - } - }, - )?); - let class_info = ClassInfo::new(&class_v0, DEPRECATED_CONTRACT_SIERRA_SIZE, abi_length) - .map_err(|err| ExecutionError::BadDeclareTransaction { - tx: DeclareTransaction::V0(declare_tx.clone()), - err, - })?; + let class_info = ClassInfo::new( + &deprecated_class.into(), + DEPRECATED_CONTRACT_SIERRA_SIZE, + abi_length, + ) + .map_err(|err| ExecutionError::BadDeclareTransaction { + tx: DeclareTransaction::V0(declare_tx.clone()), + err, + })?; + BlockifierTransaction::from_api( Transaction::Declare(DeclareTransaction::V0(declare_tx)), tx_hash, @@ -819,14 +816,15 @@ fn to_blockifier_tx( abi_length, only_query, ) => { - let class_v0 = BlockifierContractClass::V0( - deprecated_class.try_into().map_err(BlockifierError::new)?, - ); - let class_info = ClassInfo::new(&class_v0, DEPRECATED_CONTRACT_SIERRA_SIZE, abi_length) - .map_err(|err| ExecutionError::BadDeclareTransaction { - tx: DeclareTransaction::V1(declare_tx.clone()), - err, - })?; + let class_info = ClassInfo::new( + &deprecated_class.into(), + DEPRECATED_CONTRACT_SIERRA_SIZE, + abi_length, + ) + .map_err(|err| ExecutionError::BadDeclareTransaction { + tx: DeclareTransaction::V1(declare_tx.clone()), + err, + })?; BlockifierTransaction::from_api( Transaction::Declare(DeclareTransaction::V1(declare_tx)), tx_hash, @@ -844,16 +842,13 @@ fn to_blockifier_tx( abi_length, only_query, ) => { - let class_v1 = BlockifierContractClass::V1( - compiled_class.try_into().map_err(BlockifierError::new)?, - ); let class_info = - ClassInfo::new(&class_v1, sierra_program_length, abi_length).map_err(|err| { - ExecutionError::BadDeclareTransaction { + ClassInfo::new(&compiled_class.into(), sierra_program_length, abi_length).map_err( + |err| ExecutionError::BadDeclareTransaction { tx: DeclareTransaction::V2(declare_tx.clone()), err, - } - })?; + }, + )?; BlockifierTransaction::from_api( Transaction::Declare(DeclareTransaction::V2(declare_tx)), tx_hash, @@ -871,16 +866,13 @@ fn to_blockifier_tx( abi_length, only_query, ) => { - let class_v1 = BlockifierContractClass::V1( - compiled_class.try_into().map_err(BlockifierError::new)?, - ); let class_info = - ClassInfo::new(&class_v1, sierra_program_length, abi_length).map_err(|err| { - ExecutionError::BadDeclareTransaction { + ClassInfo::new(&compiled_class.into(), sierra_program_length, abi_length).map_err( + |err| ExecutionError::BadDeclareTransaction { tx: DeclareTransaction::V3(declare_tx.clone()), err, - } - })?; + }, + )?; BlockifierTransaction::from_api( Transaction::Declare(DeclareTransaction::V3(declare_tx)), tx_hash, diff --git a/crates/papyrus_execution/src/state_reader.rs b/crates/papyrus_execution/src/state_reader.rs index a15df3c3c8..a963241f61 100644 --- a/crates/papyrus_execution/src/state_reader.rs +++ b/crates/papyrus_execution/src/state_reader.rs @@ -5,9 +5,9 @@ mod state_reader_test; use std::cell::Cell; use blockifier::execution::contract_class::{ - ContractClass as BlockifierContractClass, ContractClassV0, ContractClassV1, + RunnableContractClass, }; use blockifier::state::errors::StateError; use blockifier::state::state_api::{StateReader as BlockifierStateReader, StateResult}; @@ -78,13 +78,13 @@ impl BlockifierStateReader for ExecutionStateReader { fn get_compiled_contract_class( &self, class_hash: ClassHash, - ) -> StateResult { + ) -> StateResult { if let Some(pending_casm) = self .maybe_pending_data .as_ref() .and_then(|pending_data| pending_data.classes.get_compiled_class(class_hash)) { - return Ok(BlockifierContractClass::V1( + return Ok(RunnableContractClass::V1( ContractClassV1::try_from(pending_casm).map_err(StateError::ProgramError)?, )); } @@ -93,7 +93,7 @@ impl BlockifierStateReader for ExecutionStateReader { .as_ref() .and_then(|pending_data| pending_data.classes.get_class(class_hash)) { - return Ok(BlockifierContractClass::V0( + return Ok(RunnableContractClass::V0( ContractClassV0::try_from(pending_deprecated_class) .map_err(StateError::ProgramError)?, )); diff --git a/crates/papyrus_execution/src/state_reader_test.rs b/crates/papyrus_execution/src/state_reader_test.rs index 136012eebe..ce52a279ef 100644 --- a/crates/papyrus_execution/src/state_reader_test.rs +++ b/crates/papyrus_execution/src/state_reader_test.rs @@ -2,9 +2,9 @@ use std::cell::Cell; use assert_matches::assert_matches; use blockifier::execution::contract_class::{ - ContractClass as BlockifierContractClass, ContractClassV0, ContractClassV1, + RunnableContractClass, }; use blockifier::state::errors::StateError; use blockifier::state::state_api::StateReader; @@ -50,7 +50,7 @@ fn read_state() { let class0 = ContractClass::default(); let casm0 = get_test_casm(); let blockifier_casm0 = - BlockifierContractClass::V1(ContractClassV1::try_from(casm0.clone()).unwrap()); + RunnableContractClass::V1(ContractClassV1::try_from(casm0.clone()).unwrap()); let compiled_class_hash0 = CompiledClassHash(StarkHash::default()); let class_hash1 = ClassHash(1u128.into()); @@ -65,7 +65,7 @@ fn read_state() { let mut casm1 = get_test_casm(); casm1.bytecode[0] = BigUintAsHex { value: 12345u32.into() }; let blockifier_casm1 = - BlockifierContractClass::V1(ContractClassV1::try_from(casm1.clone()).unwrap()); + RunnableContractClass::V1(ContractClassV1::try_from(casm1.clone()).unwrap()); let nonce1 = Nonce(felt!(2_u128)); let class_hash3 = ClassHash(567_u128.into()); let class_hash4 = ClassHash(89_u128.into()); @@ -241,7 +241,7 @@ fn read_state() { // Test that if the class is deprecated it is returned. assert_eq!( state_reader2.get_compiled_contract_class(class_hash4).unwrap(), - BlockifierContractClass::V0(ContractClassV0::try_from(class1).unwrap()) + RunnableContractClass::V0(ContractClassV0::try_from(class1).unwrap()) ); // Test get_class_hash_at when the class is replaced. diff --git a/crates/papyrus_network/src/lib.rs b/crates/papyrus_network/src/lib.rs index 1d3a00bc52..e9a5ea8989 100644 --- a/crates/papyrus_network/src/lib.rs +++ b/crates/papyrus_network/src/lib.rs @@ -13,7 +13,7 @@ mod peer_manager; mod sqmr; #[cfg(test)] mod test_utils; -mod utils; +pub mod utils; use std::collections::BTreeMap; use std::time::Duration; diff --git a/crates/papyrus_network/src/network_manager/mod.rs b/crates/papyrus_network/src/network_manager/mod.rs index 1972bb352b..b41e692af0 100644 --- a/crates/papyrus_network/src/network_manager/mod.rs +++ b/crates/papyrus_network/src/network_manager/mod.rs @@ -2,7 +2,7 @@ mod swarm_trait; #[cfg(test)] mod test; -#[cfg(feature = "testing")] +#[cfg(any(test, feature = "testing"))] pub mod test_utils; use std::collections::HashMap; diff --git a/crates/papyrus_network/src/network_manager/test_utils.rs b/crates/papyrus_network/src/network_manager/test_utils.rs index dfd2b9f433..1aaba292d9 100644 --- a/crates/papyrus_network/src/network_manager/test_utils.rs +++ b/crates/papyrus_network/src/network_manager/test_utils.rs @@ -1,25 +1,33 @@ +use core::net::Ipv4Addr; + use futures::channel::mpsc::{Receiver, SendError, Sender}; use futures::channel::oneshot; use futures::future::{ready, Ready}; use futures::sink::With; use futures::stream::Map; use futures::{FutureExt, SinkExt, StreamExt}; +use libp2p::core::multiaddr::Protocol; use libp2p::gossipsub::SubscriptionError; -use libp2p::PeerId; +use libp2p::identity::Keypair; +use libp2p::{Multiaddr, PeerId}; +use papyrus_common::tcp::find_n_free_ports; use super::{ BroadcastTopicClient, BroadcastedMessageMetadata, GenericReceiver, + NetworkManager, ReportReceiver, ServerQueryManager, ServerResponsesSender, SqmrClientPayload, SqmrClientSender, SqmrServerReceiver, + Topic, }; use crate::network_manager::{BroadcastReceivedMessagesConverterFn, BroadcastTopicChannels}; use crate::sqmr::Bytes; +use crate::NetworkConfig; pub fn mock_register_sqmr_protocol_client( buffer_size: usize, @@ -139,6 +147,45 @@ where Ok(TestSubscriberChannels { subscriber_channels, mock_network }) } +pub fn create_network_config_connected_to_broadcast_channels( + topic: Topic, +) -> (NetworkConfig, BroadcastTopicChannels) +where + T: TryFrom + 'static, + Bytes: From, +{ + const BUFFER_SIZE: usize = 1000; + + let [channels_port, config_port] = find_n_free_ports::<2>(); + + let channels_secret_key = [1u8; 64]; + let channels_public_key = Keypair::ed25519_from_bytes(channels_secret_key).unwrap().public(); + + let channels_config = NetworkConfig { + tcp_port: channels_port, + secret_key: Some(channels_secret_key.to_vec()), + ..Default::default() + }; + let result_config = NetworkConfig { + tcp_port: config_port, + bootstrap_peer_multiaddr: Some( + Multiaddr::empty() + .with(Protocol::Ip4(Ipv4Addr::LOCALHOST)) + .with(Protocol::Tcp(channels_port)) + .with(Protocol::P2p(PeerId::from_public_key(&channels_public_key))), + ), + ..Default::default() + }; + + let mut channels_network_manager = NetworkManager::new(channels_config, None); + let broadcast_channels = + channels_network_manager.register_broadcast_topic(topic, BUFFER_SIZE).unwrap(); + + tokio::task::spawn(channels_network_manager.run()); + + (result_config, broadcast_channels) +} + pub struct MockClientResponsesManager, Response: TryFrom> { query: Result>::Error>, report_receiver: ReportReceiver, diff --git a/crates/mempool_node/Cargo.toml b/crates/sequencer_node/Cargo.toml similarity index 92% rename from crates/mempool_node/Cargo.toml rename to crates/sequencer_node/Cargo.toml index 6c34b492cc..316d7b77e7 100644 --- a/crates/mempool_node/Cargo.toml +++ b/crates/sequencer_node/Cargo.toml @@ -6,7 +6,7 @@ repository.workspace = true license.workspace = true [features] -testing = ["thiserror"] +testing = ["papyrus_proc_macros", "thiserror"] [lints] workspace = true @@ -17,6 +17,7 @@ clap.workspace = true const_format.workspace = true futures.workspace = true papyrus_config.workspace = true +papyrus_proc_macros = { workspace = true, optional = true } rstest.workspace = true serde.workspace = true starknet_api.workspace = true diff --git a/crates/mempool_node/build.rs b/crates/sequencer_node/build.rs similarity index 100% rename from crates/mempool_node/build.rs rename to crates/sequencer_node/build.rs diff --git a/crates/mempool_node/src/bin/sequencer_dump_config.rs b/crates/sequencer_node/src/bin/sequencer_dump_config.rs similarity index 100% rename from crates/mempool_node/src/bin/sequencer_dump_config.rs rename to crates/sequencer_node/src/bin/sequencer_dump_config.rs diff --git a/crates/mempool_node/src/communication.rs b/crates/sequencer_node/src/communication.rs similarity index 89% rename from crates/mempool_node/src/communication.rs rename to crates/sequencer_node/src/communication.rs index ab1f30617f..27e756d143 100644 --- a/crates/mempool_node/src/communication.rs +++ b/crates/sequencer_node/src/communication.rs @@ -127,21 +127,24 @@ pub fn create_node_clients( ) -> SequencerNodeClients { let batcher_client: Option = match config.components.batcher.execution_mode { - ComponentExecutionMode::LocalExecution { enable_remote_connection: _ } => { + ComponentExecutionMode::LocalExecutionWithRemoteDisabled + | ComponentExecutionMode::LocalExecutionWithRemoteEnabled => { Some(Arc::new(LocalBatcherClient::new(channels.take_batcher_tx()))) } ComponentExecutionMode::Disabled => None, }; let mempool_client: Option = match config.components.mempool.execution_mode { - ComponentExecutionMode::LocalExecution { enable_remote_connection: _ } => { + ComponentExecutionMode::LocalExecutionWithRemoteDisabled + | ComponentExecutionMode::LocalExecutionWithRemoteEnabled => { Some(Arc::new(LocalMempoolClient::new(channels.take_mempool_tx()))) } ComponentExecutionMode::Disabled => None, }; let gateway_client: Option = match config.components.gateway.execution_mode { - ComponentExecutionMode::LocalExecution { enable_remote_connection: _ } => { + ComponentExecutionMode::LocalExecutionWithRemoteDisabled + | ComponentExecutionMode::LocalExecutionWithRemoteEnabled => { Some(Arc::new(LocalGatewayClient::new(channels.take_gateway_tx()))) } ComponentExecutionMode::Disabled => None, @@ -149,11 +152,10 @@ pub fn create_node_clients( let mempool_p2p_propagator_client: Option = match config.components.mempool.execution_mode { - ComponentExecutionMode::LocalExecution { enable_remote_connection: _ } => { - Some(Arc::new(LocalMempoolP2pPropagatorClient::new( - channels.take_mempool_p2p_propagator_tx(), - ))) - } + ComponentExecutionMode::LocalExecutionWithRemoteDisabled + | ComponentExecutionMode::LocalExecutionWithRemoteEnabled => Some(Arc::new( + LocalMempoolP2pPropagatorClient::new(channels.take_mempool_p2p_propagator_tx()), + )), ComponentExecutionMode::Disabled => None, }; SequencerNodeClients { diff --git a/crates/mempool_node/src/components.rs b/crates/sequencer_node/src/components.rs similarity index 81% rename from crates/mempool_node/src/components.rs rename to crates/sequencer_node/src/components.rs index b1445be735..e5bed7e8f6 100644 --- a/crates/mempool_node/src/components.rs +++ b/crates/sequencer_node/src/components.rs @@ -31,7 +31,8 @@ pub fn create_node_components( clients: &SequencerNodeClients, ) -> SequencerNodeComponents { let batcher = match config.components.batcher.execution_mode { - ComponentExecutionMode::LocalExecution { enable_remote_connection: _ } => { + ComponentExecutionMode::LocalExecutionWithRemoteDisabled + | ComponentExecutionMode::LocalExecutionWithRemoteEnabled => { let mempool_client = clients.get_mempool_client().expect("Mempool Client should be available"); Some(create_batcher(config.batcher_config.clone(), mempool_client)) @@ -39,7 +40,8 @@ pub fn create_node_components( ComponentExecutionMode::Disabled => None, }; let consensus_manager = match config.components.consensus_manager.execution_mode { - ComponentExecutionMode::LocalExecution { enable_remote_connection: _ } => { + ComponentExecutionMode::LocalExecutionWithRemoteDisabled + | ComponentExecutionMode::LocalExecutionWithRemoteEnabled => { let batcher_client = clients.get_batcher_client().expect("Batcher Client should be available"); Some(ConsensusManager::new(config.consensus_manager_config.clone(), batcher_client)) @@ -47,7 +49,8 @@ pub fn create_node_components( ComponentExecutionMode::Disabled => None, }; let gateway = match config.components.gateway.execution_mode { - ComponentExecutionMode::LocalExecution { enable_remote_connection: _ } => { + ComponentExecutionMode::LocalExecutionWithRemoteDisabled + | ComponentExecutionMode::LocalExecutionWithRemoteEnabled => { let mempool_client = clients.get_mempool_client().expect("Mempool Client should be available"); @@ -61,7 +64,8 @@ pub fn create_node_components( ComponentExecutionMode::Disabled => None, }; let http_server = match config.components.http_server.execution_mode { - ComponentExecutionMode::LocalExecution { enable_remote_connection: _ } => { + ComponentExecutionMode::LocalExecutionWithRemoteDisabled + | ComponentExecutionMode::LocalExecutionWithRemoteEnabled => { let gateway_client = clients.get_gateway_client().expect("Gateway Client should be available"); @@ -72,7 +76,8 @@ pub fn create_node_components( let (mempool_p2p_propagator, mempool_p2p_runner) = match config.components.mempool_p2p.execution_mode { - ComponentExecutionMode::LocalExecution { enable_remote_connection: _ } => { + ComponentExecutionMode::LocalExecutionWithRemoteDisabled + | ComponentExecutionMode::LocalExecutionWithRemoteEnabled => { let gateway_client = clients.get_gateway_client().expect("Gateway Client should be available"); let (mempool_p2p_propagator, mempool_p2p_runner) = create_p2p_propagator_and_runner( @@ -85,7 +90,8 @@ pub fn create_node_components( }; let mempool = match config.components.mempool.execution_mode { - ComponentExecutionMode::LocalExecution { enable_remote_connection: _ } => { + ComponentExecutionMode::LocalExecutionWithRemoteDisabled + | ComponentExecutionMode::LocalExecutionWithRemoteEnabled => { let mempool_p2p_propagator_client = clients .get_mempool_p2p_propagator_client() .expect("Propagator Client should be available"); @@ -96,10 +102,10 @@ pub fn create_node_components( }; let monitoring_endpoint = match config.components.monitoring_endpoint.execution_mode { - ComponentExecutionMode::LocalExecution { enable_remote_connection: true } => Some( + ComponentExecutionMode::LocalExecutionWithRemoteEnabled => Some( create_monitoring_endpoint(config.monitoring_endpoint_config.clone(), VERSION_FULL), ), - ComponentExecutionMode::LocalExecution { enable_remote_connection: false } => None, + ComponentExecutionMode::LocalExecutionWithRemoteDisabled => None, ComponentExecutionMode::Disabled => None, }; diff --git a/crates/mempool_node/src/config/component_config.rs b/crates/sequencer_node/src/config/component_config.rs similarity index 100% rename from crates/mempool_node/src/config/component_config.rs rename to crates/sequencer_node/src/config/component_config.rs diff --git a/crates/mempool_node/src/config/component_execution_config.rs b/crates/sequencer_node/src/config/component_execution_config.rs similarity index 72% rename from crates/mempool_node/src/config/component_execution_config.rs rename to crates/sequencer_node/src/config/component_execution_config.rs index ba80865292..094765666c 100644 --- a/crates/mempool_node/src/config/component_execution_config.rs +++ b/crates/sequencer_node/src/config/component_execution_config.rs @@ -1,11 +1,6 @@ use std::collections::BTreeMap; -use papyrus_config::dumping::{ - append_sub_config_name, - ser_optional_sub_config, - ser_param, - SerializeConfig, -}; +use papyrus_config::dumping::{ser_optional_sub_config, ser_param, SerializeConfig}; use papyrus_config::{ParamPath, ParamPrivacyInput, SerializedParam}; use serde::{Deserialize, Serialize}; use starknet_sequencer_infra::component_definitions::{ @@ -18,30 +13,10 @@ use validator::{Validate, ValidationError}; #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] pub enum ComponentExecutionMode { Disabled, - LocalExecution { enable_remote_connection: bool }, + LocalExecutionWithRemoteEnabled, + LocalExecutionWithRemoteDisabled, } -impl ComponentExecutionMode { - fn dump(&self) -> BTreeMap { - match self { - ComponentExecutionMode::Disabled => BTreeMap::from_iter([ser_param( - "Disabled", - &"Disabled", - "The component is disabled.", - ParamPrivacyInput::Public, - )]), - ComponentExecutionMode::LocalExecution { enable_remote_connection } => { - BTreeMap::from_iter([ser_param( - "LocalExecution.enable_remote_connection", - enable_remote_connection, - "Specifies whether the component, when running locally, allows remote \ - connections.", - ParamPrivacyInput::Public, - )]) - } - } - } -} // TODO(Lev/Tsabary): When papyrus_config will support it, change to include communication config in // the enum. @@ -57,8 +32,14 @@ pub struct ComponentExecutionConfig { impl SerializeConfig for ComponentExecutionConfig { fn dump(&self) -> BTreeMap { + let members = BTreeMap::from_iter([ser_param( + "execution_mode", + &self.execution_mode, + "The component execution mode.", + ParamPrivacyInput::Public, + )]); vec![ - append_sub_config_name(self.execution_mode.dump(), "execution_mode"), + members, ser_optional_sub_config(&self.local_server_config, "local_server_config"), ser_optional_sub_config(&self.remote_client_config, "remote_client_config"), ser_optional_sub_config(&self.remote_server_config, "remote_server_config"), @@ -72,9 +53,7 @@ impl SerializeConfig for ComponentExecutionConfig { impl Default for ComponentExecutionConfig { fn default() -> Self { Self { - execution_mode: ComponentExecutionMode::LocalExecution { - enable_remote_connection: false, - }, + execution_mode: ComponentExecutionMode::LocalExecutionWithRemoteDisabled, local_server_config: Some(LocalServerConfig::default()), remote_client_config: None, remote_server_config: None, @@ -86,9 +65,7 @@ impl Default for ComponentExecutionConfig { impl ComponentExecutionConfig { pub fn gateway_default_config() -> Self { Self { - execution_mode: ComponentExecutionMode::LocalExecution { - enable_remote_connection: false, - }, + execution_mode: ComponentExecutionMode::LocalExecutionWithRemoteDisabled, local_server_config: Some(LocalServerConfig::default()), remote_client_config: None, remote_server_config: None, @@ -100,9 +77,7 @@ impl ComponentExecutionConfig { // a workaround I've set the local one, but this should be addressed. pub fn http_server_default_config() -> Self { Self { - execution_mode: ComponentExecutionMode::LocalExecution { - enable_remote_connection: true, - }, + execution_mode: ComponentExecutionMode::LocalExecutionWithRemoteEnabled, local_server_config: Some(LocalServerConfig::default()), remote_client_config: None, remote_server_config: Some(RemoteServerConfig::default()), @@ -114,9 +89,7 @@ impl ComponentExecutionConfig { // one of them is set. As a workaround I've set the local one, but this should be addressed. pub fn monitoring_endpoint_default_config() -> Self { Self { - execution_mode: ComponentExecutionMode::LocalExecution { - enable_remote_connection: true, - }, + execution_mode: ComponentExecutionMode::LocalExecutionWithRemoteEnabled, local_server_config: Some(LocalServerConfig::default()), remote_client_config: None, remote_server_config: Some(RemoteServerConfig::default()), @@ -125,9 +98,7 @@ impl ComponentExecutionConfig { pub fn mempool_default_config() -> Self { Self { - execution_mode: ComponentExecutionMode::LocalExecution { - enable_remote_connection: false, - }, + execution_mode: ComponentExecutionMode::LocalExecutionWithRemoteDisabled, local_server_config: Some(LocalServerConfig::default()), remote_client_config: None, remote_server_config: None, @@ -136,9 +107,7 @@ impl ComponentExecutionConfig { pub fn batcher_default_config() -> Self { Self { - execution_mode: ComponentExecutionMode::LocalExecution { - enable_remote_connection: false, - }, + execution_mode: ComponentExecutionMode::LocalExecutionWithRemoteDisabled, local_server_config: Some(LocalServerConfig::default()), remote_client_config: None, remote_server_config: None, @@ -147,9 +116,7 @@ impl ComponentExecutionConfig { pub fn consensus_manager_default_config() -> Self { Self { - execution_mode: ComponentExecutionMode::LocalExecution { - enable_remote_connection: false, - }, + execution_mode: ComponentExecutionMode::LocalExecutionWithRemoteDisabled, local_server_config: Some(LocalServerConfig::default()), remote_client_config: None, remote_server_config: None, @@ -158,9 +125,7 @@ impl ComponentExecutionConfig { pub fn mempool_p2p_default_config() -> Self { Self { - execution_mode: ComponentExecutionMode::LocalExecution { - enable_remote_connection: false, - }, + execution_mode: ComponentExecutionMode::LocalExecutionWithRemoteDisabled, local_server_config: Some(LocalServerConfig::default()), remote_client_config: None, remote_server_config: None, @@ -178,18 +143,8 @@ pub fn validate_single_component_config( component_config.remote_server_config.is_some(), ) { (ComponentExecutionMode::Disabled, false, false, false) => Ok(()), - ( - ComponentExecutionMode::LocalExecution { enable_remote_connection: true }, - true, - false, - true, - ) => Ok(()), - ( - ComponentExecutionMode::LocalExecution { enable_remote_connection: false }, - true, - false, - false, - ) => Ok(()), + (ComponentExecutionMode::LocalExecutionWithRemoteEnabled, true, false, true) => Ok(()), + (ComponentExecutionMode::LocalExecutionWithRemoteDisabled, true, false, false) => Ok(()), _ => { let mut error = ValidationError::new("Invalid component execution configuration."); error.message = Some("Ensure settings align with the chosen execution mode.".into()); diff --git a/crates/mempool_node/src/config/config_test.rs b/crates/sequencer_node/src/config/config_test.rs similarity index 83% rename from crates/mempool_node/src/config/config_test.rs rename to crates/sequencer_node/src/config/config_test.rs index 8f1513525e..3ad5600f26 100644 --- a/crates/mempool_node/src/config/config_test.rs +++ b/crates/sequencer_node/src/config/config_test.rs @@ -15,8 +15,8 @@ use starknet_sequencer_infra::component_definitions::{ }; use validator::Validate; +use crate::config::test_utils::{create_test_config_load_args, RequiredParams}; use crate::config::{ - create_test_config_load_args, ComponentExecutionConfig, ComponentExecutionMode, SequencerNodeConfig, @@ -26,9 +26,9 @@ use crate::config::{ }; const LOCAL_EXECUTION_MODE: ComponentExecutionMode = - ComponentExecutionMode::LocalExecution { enable_remote_connection: false }; + ComponentExecutionMode::LocalExecutionWithRemoteDisabled; const ENABLE_REMOTE_CONNECTION_MODE: ComponentExecutionMode = - ComponentExecutionMode::LocalExecution { enable_remote_connection: true }; + ComponentExecutionMode::LocalExecutionWithRemoteEnabled; /// Test the validation of the struct ComponentExecutionConfig. /// Validates that execution mode of the component and the local/remote config are at sync. @@ -93,10 +93,20 @@ fn test_default_config_file_is_up_to_date() { /// Tests parsing a node config without additional args. #[test] fn test_config_parsing() { - let args = create_test_config_load_args(&REQUIRED_PARAM_CONFIG_POINTERS); + let required_params = RequiredParams::create_for_testing(); + let args = create_test_config_load_args(required_params); let config = SequencerNodeConfig::load_and_process(args); let config = config.expect("Parsing function failed."); let result = config_validate(&config); assert_matches!(result, Ok(_), "Expected Ok but got {:?}", result); } + +/// Tests compatibility of the required parameter settings: pointer targets and test util struct. +#[test] +fn test_required_params_setting() { + let required_pointers = + REQUIRED_PARAM_CONFIG_POINTERS.iter().map(|(x, _)| x.to_owned()).collect::>(); + let required_params = RequiredParams::field_names(); + assert_eq!(required_pointers, required_params); +} diff --git a/crates/mempool_node/src/config/mod.rs b/crates/sequencer_node/src/config/mod.rs similarity index 62% rename from crates/mempool_node/src/config/mod.rs rename to crates/sequencer_node/src/config/mod.rs index ba406012b9..83a4f8f5e3 100644 --- a/crates/mempool_node/src/config/mod.rs +++ b/crates/sequencer_node/src/config/mod.rs @@ -4,7 +4,10 @@ mod config_test; pub mod component_config; pub mod component_execution_config; pub mod node_config; +#[cfg(any(feature = "testing", test))] +pub mod test_utils; +// TODO(Tsabary): Remove these, and replace with direct imports. pub use component_config::*; pub use component_execution_config::*; pub use node_config::*; diff --git a/crates/mempool_node/src/config/node_config.rs b/crates/sequencer_node/src/config/node_config.rs similarity index 62% rename from crates/mempool_node/src/config/node_config.rs rename to crates/sequencer_node/src/config/node_config.rs index fc5206179a..4021c7b49e 100644 --- a/crates/mempool_node/src/config/node_config.rs +++ b/crates/sequencer_node/src/config/node_config.rs @@ -11,9 +11,6 @@ use papyrus_config::dumping::{ SerializeConfig, }; use papyrus_config::loading::load_and_process_config; -use papyrus_config::validators::validate_ascii; -#[cfg(any(feature = "testing", test))] -use papyrus_config::SerializedContent; use papyrus_config::{ConfigError, ParamPath, SerializationType, SerializedParam}; use serde::{Deserialize, Serialize}; use starknet_api::core::ChainId; @@ -39,17 +36,25 @@ pub const DEFAULT_CHAIN_ID: ChainId = ChainId::Mainnet; // Required target parameters. pub static REQUIRED_PARAM_CONFIG_POINTERS: LazyLock> = LazyLock::new(|| { - vec![ser_pointer_target_required_param( - "chain_id", - SerializationType::String, - "The chain to follow.", - )] + vec![ + ser_pointer_target_required_param( + "chain_id", + SerializationType::String, + "The chain to follow.", + ), + ser_pointer_target_required_param( + "eth_fee_token_address", + SerializationType::String, + "Address of the ETH fee token.", + ), + ser_pointer_target_required_param( + "strk_fee_token_address", + SerializationType::String, + "Address of the STRK fee token.", + ), + ] }); -// TODO(Tsabary): Create a struct detailing all the required parameters and their types. Add a macro -// that verifies the correctness of the required parameters compared to -// REQUIRED_PARAM_CONFIG_POINTERS. - // Optional target parameters, i.e., target parameters with default values. pub static DEFAULT_PARAM_CONFIG_POINTERS: LazyLock> = LazyLock::new(Vec::new); @@ -61,49 +66,10 @@ pub static CONFIG_POINTERS: LazyLock> = LazyLo combined }); -// TODO(Tsabary): Bundle required config values in a struct, detailing whether they are pointer -// targets or not. Then, derive their values in the config (struct and pointers). -// Also, add functionality to derive them for testing. -// Creates a vector of strings with the command name and required parameters that can be used as -// arguments to load a config. -#[cfg(any(feature = "testing", test))] -pub fn create_test_config_load_args(pointers: &Vec<(ParamPath, SerializedParam)>) -> Vec { - let mut dummy_values = Vec::new(); - - // Command name. - dummy_values.push(node_command().to_string()); - - // Iterate over required config parameters and add them as args with suitable arbitrary values. - for (target_param, serialized_pointer) in pointers { - // Param name. - let required_param_name_as_arg = format!("--{}", target_param); - dummy_values.push(required_param_name_as_arg); - - // Param value. - let serialization_type = match &serialized_pointer.content { - SerializedContent::ParamType(serialization_type) => serialization_type, - _ => panic!("Required parameters have to be of type ParamType."), - }; - let arbitrary_value = match serialization_type { - SerializationType::Boolean => "false", - SerializationType::Float => "15.2", - SerializationType::NegativeInteger => "-30", - SerializationType::PositiveInteger => "17", - SerializationType::String => "ArbitraryString", - } - .to_string(); - dummy_values.push(arbitrary_value); - } - dummy_values -} - // TODO(yair): Make the GW and batcher execution config point to the same values. /// The configurations of the various components of the node. -#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Validate)] +#[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq, Validate)] pub struct SequencerNodeConfig { - /// The [chain id](https://docs.rs/starknet_api/latest/starknet_api/core/struct.ChainId.html) of the Starknet network. - #[validate(custom = "validate_ascii")] - pub chain_id: ChainId, #[validate] pub components: ComponentConfig, #[validate] @@ -148,23 +114,6 @@ impl SerializeConfig for SequencerNodeConfig { } } -impl Default for SequencerNodeConfig { - fn default() -> Self { - Self { - chain_id: DEFAULT_CHAIN_ID, - components: Default::default(), - batcher_config: Default::default(), - consensus_manager_config: Default::default(), - gateway_config: Default::default(), - http_server_config: Default::default(), - rpc_state_reader_config: Default::default(), - compiler_config: Default::default(), - mempool_p2p_config: Default::default(), - monitoring_endpoint_config: Default::default(), - } - } -} - impl SequencerNodeConfig { /// Creates a config object. Selects the values from the default file and from resources with /// higher priority. diff --git a/crates/sequencer_node/src/config/test_utils.rs b/crates/sequencer_node/src/config/test_utils.rs new file mode 100644 index 0000000000..cf9e48f138 --- /dev/null +++ b/crates/sequencer_node/src/config/test_utils.rs @@ -0,0 +1,50 @@ +use std::vec::Vec; // Used by #[gen_field_names_fn]. + +use papyrus_proc_macros::gen_field_names_fn; +use starknet_api::core::{ChainId, ContractAddress}; + +use crate::config::node_command; + +/// Required parameters utility struct. +#[gen_field_names_fn] +pub struct RequiredParams { + pub chain_id: ChainId, + pub eth_fee_token_address: ContractAddress, + pub strk_fee_token_address: ContractAddress, +} + +impl RequiredParams { + pub fn create_for_testing() -> Self { + Self { + chain_id: ChainId::create_for_testing(), + eth_fee_token_address: ContractAddress::from(2_u128), + strk_fee_token_address: ContractAddress::from(3_u128), + } + } + + // TODO(Tsabary): replace with a macro. + pub fn cli_args(&self) -> Vec { + let args = vec![ + "--chain_id".to_string(), + self.chain_id.to_string(), + "--eth_fee_token_address".to_string(), + self.eth_fee_token_address.to_string(), + "--strk_fee_token_address".to_string(), + self.strk_fee_token_address.to_string(), + ]; + // Verify all arguments and their values are present. + assert!( + args.len() == Self::field_names().len() * 2, + "Required parameter cli generation failure." + ); + args + } +} + +// Creates a vector of strings with the command name and required parameters that can be used as +// arguments to load a config. +pub fn create_test_config_load_args(required_params: RequiredParams) -> Vec { + let mut cli_args = vec![node_command().to_string()]; + cli_args.extend(required_params.cli_args()); + cli_args +} diff --git a/crates/mempool_node/src/lib.rs b/crates/sequencer_node/src/lib.rs similarity index 100% rename from crates/mempool_node/src/lib.rs rename to crates/sequencer_node/src/lib.rs diff --git a/crates/mempool_node/src/main.rs b/crates/sequencer_node/src/main.rs similarity index 100% rename from crates/mempool_node/src/main.rs rename to crates/sequencer_node/src/main.rs diff --git a/crates/mempool_node/src/servers.rs b/crates/sequencer_node/src/servers.rs similarity index 88% rename from crates/mempool_node/src/servers.rs rename to crates/sequencer_node/src/servers.rs index 3792f2e13a..39bd42594a 100644 --- a/crates/mempool_node/src/servers.rs +++ b/crates/sequencer_node/src/servers.rs @@ -54,7 +54,8 @@ pub fn create_node_servers( components: SequencerNodeComponents, ) -> SequencerNodeServers { let batcher_server = match config.components.batcher.execution_mode { - ComponentExecutionMode::LocalExecution { enable_remote_connection: _ } => { + ComponentExecutionMode::LocalExecutionWithRemoteDisabled + | ComponentExecutionMode::LocalExecutionWithRemoteEnabled => { Some(Box::new(create_local_batcher_server( components.batcher.expect("Batcher is not initialized."), communication.take_batcher_rx(), @@ -63,7 +64,8 @@ pub fn create_node_servers( ComponentExecutionMode::Disabled => None, }; let consensus_manager_server = match config.components.consensus_manager.execution_mode { - ComponentExecutionMode::LocalExecution { enable_remote_connection: _ } => { + ComponentExecutionMode::LocalExecutionWithRemoteDisabled + | ComponentExecutionMode::LocalExecutionWithRemoteEnabled => { Some(Box::new(create_consensus_manager_server( components.consensus_manager.expect("Consensus Manager is not initialized."), ))) @@ -71,7 +73,8 @@ pub fn create_node_servers( ComponentExecutionMode::Disabled => None, }; let gateway_server = match config.components.gateway.execution_mode { - ComponentExecutionMode::LocalExecution { enable_remote_connection: _ } => { + ComponentExecutionMode::LocalExecutionWithRemoteDisabled + | ComponentExecutionMode::LocalExecutionWithRemoteEnabled => { Some(Box::new(create_gateway_server( components.gateway.expect("Gateway is not initialized."), communication.take_gateway_rx(), @@ -80,13 +83,15 @@ pub fn create_node_servers( ComponentExecutionMode::Disabled => None, }; let http_server = match config.components.http_server.execution_mode { - ComponentExecutionMode::LocalExecution { enable_remote_connection: _ } => Some(Box::new( + ComponentExecutionMode::LocalExecutionWithRemoteDisabled + | ComponentExecutionMode::LocalExecutionWithRemoteEnabled => Some(Box::new( create_http_server(components.http_server.expect("Http Server is not initialized.")), )), ComponentExecutionMode::Disabled => None, }; let monitoring_endpoint_server = match config.components.monitoring_endpoint.execution_mode { - ComponentExecutionMode::LocalExecution { enable_remote_connection: _ } => { + ComponentExecutionMode::LocalExecutionWithRemoteDisabled + | ComponentExecutionMode::LocalExecutionWithRemoteEnabled => { Some(Box::new(create_monitoring_endpoint_server( components.monitoring_endpoint.expect("Monitoring Endpoint is not initialized."), ))) @@ -94,7 +99,8 @@ pub fn create_node_servers( ComponentExecutionMode::Disabled => None, }; let mempool_server = match config.components.mempool.execution_mode { - ComponentExecutionMode::LocalExecution { enable_remote_connection: _ } => { + ComponentExecutionMode::LocalExecutionWithRemoteDisabled + | ComponentExecutionMode::LocalExecutionWithRemoteEnabled => { Some(Box::new(create_mempool_server( components.mempool.expect("Mempool is not initialized."), communication.take_mempool_rx(), @@ -104,7 +110,8 @@ pub fn create_node_servers( }; let mempool_p2p_propagator_server = match config.components.mempool_p2p.execution_mode { - ComponentExecutionMode::LocalExecution { enable_remote_connection: _ } => { + ComponentExecutionMode::LocalExecutionWithRemoteDisabled + | ComponentExecutionMode::LocalExecutionWithRemoteEnabled => { Some(Box::new(create_mempool_p2p_propagator_server( components .mempool_p2p_propagator @@ -116,7 +123,8 @@ pub fn create_node_servers( }; let mempool_p2p_runner_server = match config.components.mempool_p2p.execution_mode { - ComponentExecutionMode::LocalExecution { enable_remote_connection: _ } => { + ComponentExecutionMode::LocalExecutionWithRemoteDisabled + | ComponentExecutionMode::LocalExecutionWithRemoteEnabled => { Some(Box::new(MempoolP2pRunnerServer::new( components.mempool_p2p_runner.expect("Mempool P2P Runner is not initialized."), ))) diff --git a/crates/mempool_node/src/test_utils/compilation.rs b/crates/sequencer_node/src/test_utils/compilation.rs similarity index 98% rename from crates/mempool_node/src/test_utils/compilation.rs rename to crates/sequencer_node/src/test_utils/compilation.rs index 3307385d37..645fe536ae 100644 --- a/crates/mempool_node/src/test_utils/compilation.rs +++ b/crates/sequencer_node/src/test_utils/compilation.rs @@ -25,6 +25,7 @@ fn compile_node() -> io::Result { let compilation_result = Command::new("cargo") .arg("build") .current_dir(&project_path) + .arg("--quiet") .stderr(Stdio::inherit()) .stdout(Stdio::inherit()) .status(); diff --git a/crates/mempool_node/src/test_utils/compilation_test.rs b/crates/sequencer_node/src/test_utils/compilation_test.rs similarity index 100% rename from crates/mempool_node/src/test_utils/compilation_test.rs rename to crates/sequencer_node/src/test_utils/compilation_test.rs diff --git a/crates/mempool_node/src/test_utils/mod.rs b/crates/sequencer_node/src/test_utils/mod.rs similarity index 100% rename from crates/mempool_node/src/test_utils/mod.rs rename to crates/sequencer_node/src/test_utils/mod.rs diff --git a/crates/mempool_node/src/utils.rs b/crates/sequencer_node/src/utils.rs similarity index 100% rename from crates/mempool_node/src/utils.rs rename to crates/sequencer_node/src/utils.rs diff --git a/crates/mempool_node/src/version.rs b/crates/sequencer_node/src/version.rs similarity index 100% rename from crates/mempool_node/src/version.rs rename to crates/sequencer_node/src/version.rs diff --git a/crates/mempool_node/src/version_test.rs b/crates/sequencer_node/src/version_test.rs similarity index 100% rename from crates/mempool_node/src/version_test.rs rename to crates/sequencer_node/src/version_test.rs diff --git a/crates/sequencing/papyrus_consensus/src/bin/run_simulation.rs b/crates/sequencing/papyrus_consensus/src/bin/run_simulation.rs index 6b425d72f2..90b877f18b 100644 --- a/crates/sequencing/papyrus_consensus/src/bin/run_simulation.rs +++ b/crates/sequencing/papyrus_consensus/src/bin/run_simulation.rs @@ -4,7 +4,6 @@ //! uses the `run_consensus` binary which is able to simulate network issues for consensus messages. use std::collections::HashSet; use std::fs::{self, File}; -use std::net::TcpListener; use std::os::unix::process::CommandExt; use std::process::Command; use std::str::FromStr; @@ -14,6 +13,7 @@ use clap::Parser; use fs2::FileExt; use lazy_static::lazy_static; use nix::unistd::Pid; +use papyrus_common::tcp::find_free_port; use tokio::process::Command as TokioCommand; lazy_static! { @@ -188,14 +188,6 @@ fn parse_duration(s: &str) -> Result { Ok(Duration::from_secs(secs)) } -fn find_free_port() -> u16 { - // The socket is automatically closed when the function exits. - // The port may still be available when accessed, but this is not guaranteed. - // TODO(Asmaa): find a reliable way to ensure the port stays free. - let listener = TcpListener::bind("0.0.0.0:0").expect("Failed to bind"); - listener.local_addr().expect("Failed to get local address").port() -} - // Returns if the simulation should exit. async fn monitor_simulation( nodes: &mut Vec, diff --git a/crates/sequencing/papyrus_consensus/src/stream_handler.rs b/crates/sequencing/papyrus_consensus/src/stream_handler.rs index f2cee108dc..0531ba48a6 100644 --- a/crates/sequencing/papyrus_consensus/src/stream_handler.rs +++ b/crates/sequencing/papyrus_consensus/src/stream_handler.rs @@ -1,4 +1,5 @@ //! Stream handler, see StreamManager struct. + use std::cmp::Ordering; use std::collections::btree_map::Entry as BTreeEntry; use std::collections::hash_map::Entry as HashMapEntry; @@ -6,7 +7,12 @@ use std::collections::{BTreeMap, HashMap}; use futures::channel::mpsc; use futures::StreamExt; -use papyrus_network::network_manager::BroadcastTopicServer; +use papyrus_network::network_manager::{ + BroadcastTopicClient, + BroadcastTopicClientTrait, + BroadcastTopicServer, +}; +use papyrus_network::utils::StreamHashMap; use papyrus_network_types::network_types::{BroadcastedMessageMetadata, OpaquePeerId}; use papyrus_protobuf::consensus::{StreamMessage, StreamMessageBody}; use papyrus_protobuf::converters::ProtobufConversionError; @@ -17,19 +23,18 @@ use tracing::{instrument, warn}; mod stream_handler_test; type PeerId = OpaquePeerId; +type StreamId = u64; type MessageId = u64; -type StreamKey = (PeerId, u64); +type StreamKey = (PeerId, StreamId); const CHANNEL_BUFFER_LENGTH: usize = 100; #[derive(Debug, Clone)] struct StreamData> + TryFrom, Error = ProtobufConversionError>> { next_message_id: MessageId, - // The message_id of the message that is marked as "fin" (the last message), - // if None, it means we have not yet gotten to it. + // Last message ID. If None, it means we have not yet gotten to it. fin_message_id: Option, max_message_id_received: MessageId, - // The sender that corresponds to the receiver that was sent out for this stream. sender: mpsc::Sender, // A buffer for messages that were received out of order. message_buffer: BTreeMap>, @@ -47,41 +52,84 @@ impl> + TryFrom, Error = ProtobufConversionError } } -/// A StreamHandler is responsible for buffering and sending messages in order. +/// A StreamHandler is responsible for: +/// - Buffering inbound messages and reporting them to the application in order. +/// - Sending outbound messages to the network, wrapped in StreamMessage. pub struct StreamHandler< T: Clone + Into> + TryFrom, Error = ProtobufConversionError>, > { - // An end of a channel used to send out receivers, one for each stream. + // For each stream ID from the network, send the application a Receiver + // that will receive the messages in order. This allows sending such Receivers. inbound_channel_sender: mpsc::Sender>, - // An end of a channel used to receive messages. + // This receives messages from the network. inbound_receiver: BroadcastTopicServer>, - // A map from stream_id to a struct that contains all the information about the stream. - // This includes both the message buffer and some metadata (like the latest message_id). + // A map from (peer_id, stream_id) to a struct that contains all the information + // about the stream. This includes both the message buffer and some metadata + // (like the latest message ID). inbound_stream_data: HashMap>, - // TODO(guyn): perhaps make input_stream_data and output_stream_data? + // Whenever application wants to start a new stream, it must send out a + // (stream_id, Receiver) pair. Each receiver gets messages that should + // be sent out to the network. + outbound_channel_receiver: mpsc::Receiver<(StreamId, mpsc::Receiver)>, + // A map where the abovementioned Receivers are stored. + outbound_stream_receivers: StreamHashMap>, + // A network sender that allows sending StreamMessages to peers. + outbound_sender: BroadcastTopicClient>, + // For each stream, keep track of the message_id of the last message sent. + outbound_stream_number: HashMap, } -impl> + TryFrom, Error = ProtobufConversionError>> +impl> + TryFrom, Error = ProtobufConversionError>> StreamHandler { /// Create a new StreamHandler. pub fn new( inbound_channel_sender: mpsc::Sender>, inbound_receiver: BroadcastTopicServer>, + outbound_channel_receiver: mpsc::Receiver<(StreamId, mpsc::Receiver)>, + outbound_sender: BroadcastTopicClient>, ) -> Self { - StreamHandler { + Self { inbound_channel_sender, inbound_receiver, inbound_stream_data: HashMap::new(), + outbound_channel_receiver, + outbound_sender, + outbound_stream_receivers: StreamHashMap::new(HashMap::new()), + outbound_stream_number: HashMap::new(), } } - /// Listen for messages on the receiver channel, buffering them if necessary. - /// Guarantees that messages are sent in order. + /// Listen for messages coming from the network and from the application. + /// - Outbound messages are wrapped as StreamMessage and sent to the network directly. + /// - Inbound messages are stripped of StreamMessage and buffered until they can be sent in the + /// correct order to the application. + #[instrument(skip_all)] pub async fn run(&mut self) { loop { - // TODO(guyn): this select is here to allow us to add the outbound flow. tokio::select!( + // Go over the channel receiver to see if there is a new channel. + Some((stream_id, receiver)) = self.outbound_channel_receiver.next() => { + self.outbound_stream_receivers.insert(stream_id, receiver); + } + // Go over all existing outbound receivers to see if there are any messages. + output = self.outbound_stream_receivers.next() => { + match output { + Some((key, Some(message))) => { + self.broadcast(key, message).await; + } + Some((key, None)) => { + self.broadcast_fin(key).await; + } + None => { + warn!( + "StreamHashMap should not be closed! \ + Usually only the individual channels are closed. " + ) + } + } + } + // Check if there is an inbound message from the network. Some(message) = self.inbound_receiver.next() => { self.handle_message(message); } @@ -98,6 +146,31 @@ impl> + TryFrom, Error = ProtobufConversionError } } + // Send the message to the network. + async fn broadcast(&mut self, stream_id: StreamId, message: T) { + let message = StreamMessage { + message: StreamMessageBody::Content(message), + stream_id, + message_id: *self.outbound_stream_number.get(&stream_id).unwrap_or(&0), + }; + // TODO(guyn): reconsider the "expect" here. + self.outbound_sender.broadcast_message(message).await.expect("Send should succeed"); + self.outbound_stream_number + .insert(stream_id, self.outbound_stream_number.get(&stream_id).unwrap_or(&0) + 1); + } + + // Send a fin message to the network. + async fn broadcast_fin(&mut self, stream_id: StreamId) { + let message = StreamMessage { + message: StreamMessageBody::Fin, + stream_id, + message_id: *self.outbound_stream_number.get(&stream_id).unwrap_or(&0), + }; + self.outbound_sender.broadcast_message(message).await.expect("Send should succeed"); + self.outbound_stream_number.remove(&stream_id); + } + + // Handle a message that was received from the network. #[instrument(skip_all, level = "warn")] fn handle_message( &mut self, @@ -134,7 +207,7 @@ impl> + TryFrom, Error = ProtobufConversionError data.max_message_id_received = message_id; } - // Check for Fin type message + // Check for Fin type message. match message.message { StreamMessageBody::Content(_) => {} StreamMessageBody::Fin => { @@ -167,7 +240,6 @@ impl> + TryFrom, Error = ProtobufConversionError match message_id.cmp(&data.next_message_id) { Ordering::Equal => { Self::inbound_send(data, message); - Self::process_buffer(data); if data.message_buffer.is_empty() && data.fin_message_id.is_some() { @@ -190,6 +262,7 @@ impl> + TryFrom, Error = ProtobufConversionError } } + // Store an inbound message in the buffer. fn store(data: &mut StreamData, key: StreamKey, message: StreamMessage) { let message_id = message.message_id; diff --git a/crates/sequencing/papyrus_consensus/src/stream_handler_test.rs b/crates/sequencing/papyrus_consensus/src/stream_handler_test.rs index 962ead1e9e..0bd7250f12 100644 --- a/crates/sequencing/papyrus_consensus/src/stream_handler_test.rs +++ b/crates/sequencing/papyrus_consensus/src/stream_handler_test.rs @@ -9,21 +9,22 @@ use papyrus_network::network_manager::test_utils::{ TestSubscriberChannels, }; use papyrus_network::network_manager::BroadcastTopicChannels; +use papyrus_network_types::network_types::BroadcastedMessageMetadata; use papyrus_protobuf::consensus::{ConsensusMessage, Proposal, StreamMessage, StreamMessageBody}; use papyrus_test_utils::{get_rng, GetTestInstance}; -use super::StreamHandler; +use super::{MessageId, StreamHandler, StreamId}; + +const TIMEOUT: Duration = Duration::from_millis(100); +const CHANNEL_SIZE: usize = 100; #[cfg(test)] mod tests { - - use papyrus_network_types::network_types::BroadcastedMessageMetadata; - use super::*; fn make_test_message( - stream_id: u64, - message_id: u64, + stream_id: StreamId, + message_id: MessageId, fin: bool, ) -> StreamMessage { let content = match fin { @@ -53,26 +54,72 @@ mod tests { MockBroadcastedMessagesSender>, mpsc::Receiver>, BroadcastedMessageMetadata, + mpsc::Sender<(StreamId, mpsc::Receiver)>, + futures::stream::Map< + mpsc::Receiver>, + fn(Vec) -> StreamMessage, + >, ) { + // The outbound_sender is the network connector for broadcasting messages. + // The network_broadcast_receiver is used to catch those messages in the test. + let TestSubscriberChannels { mock_network: mock_broadcast_network, subscriber_channels } = + mock_register_broadcast_topic().unwrap(); + let BroadcastTopicChannels { + broadcasted_messages_receiver: _, + broadcast_topic_client: outbound_sender, + } = subscriber_channels; + + let network_broadcast_receiver = mock_broadcast_network.messages_to_broadcast_receiver; + + // This is used to feed receivers of messages to StreamHandler for broadcasting. + // The receiver goes into StreamHandler, sender is used by the test (as mock Consensus). + // Note that each new channel comes in a tuple with (stream_id, receiver). + let (outbound_channel_sender, outbound_channel_receiver) = + mpsc::channel::<(StreamId, mpsc::Receiver)>(CHANNEL_SIZE); + + // The network_sender_to_inbound is the sender of the mock network, that is used by the + // test to send messages into the StreamHandler (from the mock network). let TestSubscriberChannels { mock_network, subscriber_channels } = mock_register_broadcast_topic().unwrap(); - let network_sender = mock_network.broadcasted_messages_sender; - let BroadcastTopicChannels { broadcasted_messages_receiver, broadcast_topic_client: _ } = - subscriber_channels; + let network_sender_to_inbound = mock_network.broadcasted_messages_sender; + + // The inbound_receiver is given to StreamHandler to inbound to mock network messages. + let BroadcastTopicChannels { + broadcasted_messages_receiver: inbound_receiver, + broadcast_topic_client: _, + } = subscriber_channels; + + // The inbound_channel_sender is given to StreamHandler so it can output new channels for + // each stream. The inbound_channel_receiver is given to the "mock consensus" that + // gets new channels and inbounds to them. + let (inbound_channel_sender, inbound_channel_receiver) = + mpsc::channel::>(CHANNEL_SIZE); // TODO(guyn): We should also give the broadcast_topic_client to the StreamHandler - let (tx_output, rx_output) = mpsc::channel::>(100); - let handler = StreamHandler::new(tx_output, broadcasted_messages_receiver); + // This will allow reporting to the network things like bad peers. + let handler = StreamHandler::new( + inbound_channel_sender, + inbound_receiver, + outbound_channel_receiver, + outbound_sender, + ); - let broadcasted_message_metadata = - BroadcastedMessageMetadata::get_test_instance(&mut get_rng()); + let inbound_metadata = BroadcastedMessageMetadata::get_test_instance(&mut get_rng()); - (handler, network_sender, rx_output, broadcasted_message_metadata) + ( + handler, + network_sender_to_inbound, + inbound_channel_receiver, + inbound_metadata, + outbound_channel_sender, + network_broadcast_receiver, + ) } #[tokio::test] - async fn stream_handler_in_order() { - let (mut stream_handler, mut network_sender, mut rx_output, metadata) = setup_test(); + async fn inbound_in_order() { + let (mut stream_handler, mut network_sender, mut inbound_channel_receiver, metadata, _, _) = + setup_test(); let stream_id = 127; for i in 0..10 { @@ -81,12 +128,12 @@ mod tests { } let join_handle = tokio::spawn(async move { - let _ = tokio::time::timeout(Duration::from_millis(100), stream_handler.run()).await; + let _ = tokio::time::timeout(TIMEOUT, stream_handler.run()).await; }); join_handle.await.expect("Task should succeed"); - let mut receiver = rx_output.next().await.unwrap(); + let mut receiver = inbound_channel_receiver.next().await.unwrap(); for _ in 0..9 { // message number 9 is Fin, so it will not be sent! let _ = receiver.next().await.unwrap(); @@ -96,23 +143,32 @@ mod tests { } #[tokio::test] - async fn stream_handler_in_reverse() { - let (mut stream_handler, mut network_sender, mut rx_output, metadata) = setup_test(); - let peer_id = metadata.originator_id.clone(); + async fn inbound_in_reverse() { + let ( + mut stream_handler, + mut network_sender, + mut inbound_channel_receiver, + inbound_metadata, + _, + _, + ) = setup_test(); + let peer_id = inbound_metadata.originator_id.clone(); let stream_id = 127; for i in 0..5 { let message = make_test_message(stream_id, 5 - i, i == 0); - send(&mut network_sender, &metadata, message).await; + send(&mut network_sender, &inbound_metadata, message).await; } + + // Run the loop for a short duration to process the message. let join_handle = tokio::spawn(async move { - let _ = tokio::time::timeout(Duration::from_millis(100), stream_handler.run()).await; + let _ = tokio::time::timeout(TIMEOUT, stream_handler.run()).await; stream_handler }); let mut stream_handler = join_handle.await.expect("Task should succeed"); // Get the receiver for the stream. - let mut receiver = rx_output.next().await.unwrap(); + let mut receiver = inbound_channel_receiver.next().await.unwrap(); // Check that the channel is empty (no messages were sent yet). assert!(receiver.try_next().is_err()); @@ -130,9 +186,11 @@ mod tests { assert!(do_vecs_match(&keys, &range)); // Now send the last message: - send(&mut network_sender, &metadata, make_test_message(stream_id, 0, false)).await; + send(&mut network_sender, &inbound_metadata, make_test_message(stream_id, 0, false)).await; + + // Run the loop for a short duration to process the message. let join_handle = tokio::spawn(async move { - let _ = tokio::time::timeout(Duration::from_millis(100), stream_handler.run()).await; + let _ = tokio::time::timeout(TIMEOUT, stream_handler.run()).await; stream_handler }); @@ -148,9 +206,16 @@ mod tests { } #[tokio::test] - async fn stream_handler_multiple_streams() { - let (mut stream_handler, mut network_sender, mut rx_output, metadata) = setup_test(); - let peer_id = metadata.originator_id.clone(); + async fn inbound_multiple_streams() { + let ( + mut stream_handler, + mut network_sender, + mut inbound_channel_receiver, + inbound_metadata, + _, + _, + ) = setup_test(); + let peer_id = inbound_metadata.originator_id.clone(); let stream_id1 = 127; // Send all messages in order (except the first one). let stream_id2 = 10; // Send in reverse order (except the first one). @@ -158,30 +223,32 @@ mod tests { for i in 1..10 { let message = make_test_message(stream_id1, i, i == 9); - send(&mut network_sender, &metadata, message).await; + send(&mut network_sender, &inbound_metadata, message).await; } for i in 0..5 { let message = make_test_message(stream_id2, 5 - i, i == 0); - send(&mut network_sender, &metadata, message).await; + send(&mut network_sender, &inbound_metadata, message).await; } for i in 5..10 { let message = make_test_message(stream_id3, i, false); - send(&mut network_sender, &metadata, message).await; + send(&mut network_sender, &inbound_metadata, message).await; } + for i in 1..5 { let message = make_test_message(stream_id3, i, false); - send(&mut network_sender, &metadata, message).await; + send(&mut network_sender, &inbound_metadata, message).await; } + // Run the loop for a short duration to process the message. let join_handle = tokio::spawn(async move { - let _ = tokio::time::timeout(Duration::from_millis(100), stream_handler.run()).await; + let _ = tokio::time::timeout(TIMEOUT, stream_handler.run()).await; stream_handler }); let mut stream_handler = join_handle.await.expect("Task should succeed"); - let values = vec![(peer_id.clone(), 1), (peer_id.clone(), 10), (peer_id.clone(), 127)]; + let values = [(peer_id.clone(), 1), (peer_id.clone(), 10), (peer_id.clone(), 127)]; assert!( stream_handler .inbound_stream_data @@ -221,40 +288,38 @@ mod tests { )); // Get the receiver for the first stream. - let mut receiver1 = rx_output.next().await.unwrap(); + let mut receiver1 = inbound_channel_receiver.next().await.unwrap(); // Check that the channel is empty (no messages were sent yet). assert!(receiver1.try_next().is_err()); // Get the receiver for the second stream. - let mut receiver2 = rx_output.next().await.unwrap(); + let mut receiver2 = inbound_channel_receiver.next().await.unwrap(); // Check that the channel is empty (no messages were sent yet). assert!(receiver2.try_next().is_err()); // Get the receiver for the third stream. - let mut receiver3 = rx_output.next().await.unwrap(); + let mut receiver3 = inbound_channel_receiver.next().await.unwrap(); // Check that the channel is empty (no messages were sent yet). assert!(receiver3.try_next().is_err()); // Send the last message on stream_id1: - send(&mut network_sender, &metadata, make_test_message(stream_id1, 0, false)).await; + send(&mut network_sender, &inbound_metadata, make_test_message(stream_id1, 0, false)).await; + + // Run the loop for a short duration to process the message. let join_handle = tokio::spawn(async move { - let _ = tokio::time::timeout(Duration::from_millis(100), stream_handler.run()).await; + let _ = tokio::time::timeout(TIMEOUT, stream_handler.run()).await; stream_handler }); - let mut stream_handler = join_handle.await.expect("Task should succeed"); - // Should be able to read all the messages for stream_id1. for _ in 0..9 { // message number 9 is Fin, so it will not be sent! let _ = receiver1.next().await.unwrap(); } - - // Check that the receiver was closed: - assert!(matches!(receiver1.try_next(), Ok(None))); + let mut stream_handler = join_handle.await.expect("Task should succeed"); // stream_id1 should be gone let values = [(peer_id.clone(), 1), (peer_id.clone(), 10)]; @@ -267,22 +332,21 @@ mod tests { ); // Send the last message on stream_id2: - send(&mut network_sender, &metadata, make_test_message(stream_id2, 0, false)).await; + send(&mut network_sender, &inbound_metadata, make_test_message(stream_id2, 0, false)).await; + + // Run the loop for a short duration to process the message. let join_handle = tokio::spawn(async move { - let _ = tokio::time::timeout(Duration::from_millis(100), stream_handler.run()).await; + let _ = tokio::time::timeout(TIMEOUT, stream_handler.run()).await; stream_handler }); - let mut stream_handler = join_handle.await.expect("Task should succeed"); - // Should be able to read all the messages for stream_id2. for _ in 0..5 { // message number 5 is Fin, so it will not be sent! let _ = receiver2.next().await.unwrap(); } - // Check that the receiver was closed: - assert!(matches!(receiver2.try_next(), Ok(None))); + let mut stream_handler = join_handle.await.expect("Task should succeed"); // Stream_id2 should also be gone. let values = [(peer_id.clone(), 1)]; @@ -295,10 +359,11 @@ mod tests { ); // Send the last message on stream_id3: - send(&mut network_sender, &metadata, make_test_message(stream_id3, 0, false)).await; + send(&mut network_sender, &inbound_metadata, make_test_message(stream_id3, 0, false)).await; + // Run the loop for a short duration to process the message. let join_handle = tokio::spawn(async move { - let _ = tokio::time::timeout(Duration::from_millis(100), stream_handler.run()).await; + let _ = tokio::time::timeout(TIMEOUT, stream_handler.run()).await; stream_handler }); @@ -308,9 +373,6 @@ mod tests { let _ = receiver3.next().await.unwrap(); } - // In this case the receiver is not closed, because we didn't send a fin. - assert!(receiver3.try_next().is_err()); - // Stream_id3 should still be there, because we didn't send a fin. let values = [(peer_id.clone(), 1)]; assert!( @@ -326,4 +388,126 @@ mod tests { stream_handler.inbound_stream_data[&(peer_id, stream_id3)].message_buffer.is_empty() ); } + + // This test does two things: + // 1. Opens two outbound channels and checks that messages get correctly sent on both. + // 2. Closes the first channel and checks that Fin is sent and that the relevant structures + // inside the stream handler are cleaned up. + #[tokio::test] + async fn outbound_multiple_streams() { + let ( + mut stream_handler, + _, + _, + _, + mut broadcast_channel_sender, + mut broadcasted_messages_receiver, + ) = setup_test(); + + let stream_id1: StreamId = 42; + let stream_id2: StreamId = 127; + + // Start a new stream by sending the (stream_id, receiver). + let (mut sender1, receiver1) = mpsc::channel(CHANNEL_SIZE); + broadcast_channel_sender.send((stream_id1, receiver1)).await.unwrap(); + + // Send a message on the stream. + let message1 = ConsensusMessage::Proposal(Proposal::default()); + sender1.send(message1.clone()).await.unwrap(); + + // Run the loop for a short duration to process the message. + let join_handle = tokio::spawn(async move { + let _ = tokio::time::timeout(TIMEOUT, stream_handler.run()).await; + stream_handler + }); + + // Wait for an incoming message. + let broadcasted_message = broadcasted_messages_receiver.next().await.unwrap(); + let mut stream_handler = join_handle.await.expect("Task should succeed"); + + // Check that message was broadcasted. + assert_eq!(broadcasted_message.message, StreamMessageBody::Content(message1)); + assert_eq!(broadcasted_message.stream_id, stream_id1); + assert_eq!(broadcasted_message.message_id, 0); + + // Check that internally, stream_handler holds this receiver. + assert_eq!( + stream_handler.outbound_stream_receivers.keys().collect::>(), + vec![&stream_id1] + ); + // Check that the number of messages sent on this stream is 1. + assert_eq!(stream_handler.outbound_stream_number[&stream_id1], 1); + + // Send another message on the same stream. + let message2 = ConsensusMessage::Proposal(Proposal::default()); + sender1.send(message2.clone()).await.unwrap(); + + // Run the loop for a short duration to process the message. + let join_handle = tokio::spawn(async move { + let _ = tokio::time::timeout(TIMEOUT, stream_handler.run()).await; + stream_handler + }); + + // Wait for an incoming message. + let broadcasted_message = broadcasted_messages_receiver.next().await.unwrap(); + + let mut stream_handler = join_handle.await.expect("Task should succeed"); + + // Check that message was broadcasted. + assert_eq!(broadcasted_message.message, StreamMessageBody::Content(message2)); + assert_eq!(broadcasted_message.stream_id, stream_id1); + assert_eq!(broadcasted_message.message_id, 1); + assert_eq!(stream_handler.outbound_stream_number[&stream_id1], 2); + + // Start a new stream by sending the (stream_id, receiver). + let (mut sender2, receiver2) = mpsc::channel(CHANNEL_SIZE); + broadcast_channel_sender.send((stream_id2, receiver2)).await.unwrap(); + + // Send a message on the stream. + let message3 = ConsensusMessage::Proposal(Proposal::default()); + sender2.send(message3.clone()).await.unwrap(); + + // Run the loop for a short duration to process the message. + let join_handle = tokio::spawn(async move { + let _ = tokio::time::timeout(TIMEOUT, stream_handler.run()).await; + stream_handler + }); + + // Wait for an incoming message. + let broadcasted_message = broadcasted_messages_receiver.next().await.unwrap(); + + let mut stream_handler = join_handle.await.expect("Task should succeed"); + + // Check that message was broadcasted. + assert_eq!(broadcasted_message.message, StreamMessageBody::Content(message3)); + assert_eq!(broadcasted_message.stream_id, stream_id2); + assert_eq!(broadcasted_message.message_id, 0); + let mut vec1 = stream_handler.outbound_stream_receivers.keys().collect::>(); + vec1.sort(); + let mut vec2 = vec![&stream_id1, &stream_id2]; + vec2.sort(); + do_vecs_match(&vec1, &vec2); + assert_eq!(stream_handler.outbound_stream_number[&stream_id2], 1); + + // Close the first channel. + sender1.close_channel(); + + // Run the loop for a short duration to process that the channel was closed. + let join_handle = tokio::spawn(async move { + let _ = tokio::time::timeout(TIMEOUT, stream_handler.run()).await; + stream_handler + }); + + // Check that we got a fin message. + let broadcasted_message = broadcasted_messages_receiver.next().await.unwrap(); + assert_eq!(broadcasted_message.message, StreamMessageBody::Fin); + + let stream_handler = join_handle.await.expect("Task should succeed"); + + // Check that the information about this stream is gone. + assert_eq!( + stream_handler.outbound_stream_receivers.keys().collect::>(), + vec![&stream_id2] + ); + } } diff --git a/crates/sequencing/papyrus_consensus_orchestrator/src/sequencer_consensus_context.rs b/crates/sequencing/papyrus_consensus_orchestrator/src/sequencer_consensus_context.rs index 148ee8e4ce..0e8f237027 100644 --- a/crates/sequencing/papyrus_consensus_orchestrator/src/sequencer_consensus_context.rs +++ b/crates/sequencing/papyrus_consensus_orchestrator/src/sequencer_consensus_context.rs @@ -140,8 +140,15 @@ impl ConsensusContext for SequencerConsensusContext { let chrono_timeout = chrono::Duration::from_std(timeout).expect("Can't convert timeout to chrono::Duration"); - let input = - ValidateProposalInput { proposal_id, deadline: chrono::Utc::now() + chrono_timeout }; + let input = ValidateProposalInput { + proposal_id, + deadline: chrono::Utc::now() + chrono_timeout, + // TODO(Matan 3/11/2024): Add the real value of the retrospective block hash. + retrospective_block_hash: Some(BlockHashAndNumber { + number: BlockNumber::default(), + hash: BlockHash::default(), + }), + }; self.maybe_start_height(height).await; batcher.validate_proposal(input).await.expect("Failed to initiate proposal validation"); tokio::spawn( diff --git a/crates/starknet_api/src/contract_class.rs b/crates/starknet_api/src/contract_class.rs index 6470347578..e633e3e4ae 100644 --- a/crates/starknet_api/src/contract_class.rs +++ b/crates/starknet_api/src/contract_class.rs @@ -22,7 +22,7 @@ pub enum EntryPointType { } /// Represents a raw Starknet contract class. -#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)] +#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize, derive_more::From)] pub enum ContractClass { V0(DeprecatedContractClass), V1(CasmContractClass), diff --git a/crates/starknet_api/src/deprecated_contract_class.rs b/crates/starknet_api/src/deprecated_contract_class.rs index 81b56ffaf2..5ea52037af 100644 --- a/crates/starknet_api/src/deprecated_contract_class.rs +++ b/crates/starknet_api/src/deprecated_contract_class.rs @@ -26,6 +26,12 @@ pub struct ContractClass { pub entry_points_by_type: HashMap>, } +impl ContractClass { + pub fn bytecode_length(&self) -> usize { + self.program.data.as_array().expect("The program data must be an array.").len() + } +} + /// A [ContractClass](`crate::deprecated_contract_class::ContractClass`) abi entry. // Using untagged so the serialization will be sorted by the keys (the default behavior of Serde for // untagged enums). We care about the order of the fields in the serialization because it affects diff --git a/crates/starknet_api/src/executable_transaction.rs b/crates/starknet_api/src/executable_transaction.rs index 3b4b258a54..8c3ad500f9 100644 --- a/crates/starknet_api/src/executable_transaction.rs +++ b/crates/starknet_api/src/executable_transaction.rs @@ -14,6 +14,7 @@ use crate::transaction::{ AllResourceBounds, Calldata, ContractAddressSalt, + Fee, PaymasterData, Tip, TransactionHash, @@ -292,3 +293,17 @@ impl InvokeTransaction { Self::create(invoke_tx, chain_id) } } + +#[derive(Clone, Debug)] +pub struct L1HandlerTransaction { + pub tx: crate::transaction::L1HandlerTransaction, + pub tx_hash: TransactionHash, + pub paid_fee_on_l1: Fee, +} + +impl L1HandlerTransaction { + pub fn payload_size(&self) -> usize { + // The calldata includes the "from" field, which is not a part of the payload. + self.tx.calldata.0.len() - 1 + } +} diff --git a/crates/tests-integration/Cargo.toml b/crates/tests-integration/Cargo.toml index 294c3470fb..2b96fed2ec 100644 --- a/crates/tests-integration/Cargo.toml +++ b/crates/tests-integration/Cargo.toml @@ -34,7 +34,7 @@ starknet_gateway = { workspace = true, features = ["testing"] } starknet_gateway_types.workspace = true starknet_http_server.workspace = true starknet_sequencer_infra.workspace = true -starknet_sequencer_node.workspace = true +starknet_sequencer_node = { workspace = true, features = ["testing"] } starknet_task_executor.workspace = true strum.workspace = true tempfile.workspace = true diff --git a/crates/tests-integration/src/bin/run_test_rpc_state_reader.rs b/crates/tests-integration/src/bin/run_test_rpc_state_reader.rs index be5ee0f285..9afbc08eff 100644 --- a/crates/tests-integration/src/bin/run_test_rpc_state_reader.rs +++ b/crates/tests-integration/src/bin/run_test_rpc_state_reader.rs @@ -24,11 +24,12 @@ async fn main() -> anyhow::Result<()> { .await; // Derive the configuration for the sequencer node. - let config = create_config(rpc_server_addr, storage_for_test.batcher_storage_config).await; + let (config, required_params) = + create_config(rpc_server_addr, storage_for_test.batcher_storage_config).await; // Note: the batcher storage file handle is passed as a reference to maintain its ownership in // this scope, such that the handle is not dropped and the storage is maintained. - dump_config_file_changes(config)?; + dump_config_file_changes(config, required_params)?; // Keep the program running so the rpc state reader server, its storage, and the batcher // storage, are all maintained. diff --git a/crates/tests-integration/src/integration_test_config_utils.rs b/crates/tests-integration/src/integration_test_config_utils.rs index b393f2cd50..bef874bb5d 100644 --- a/crates/tests-integration/src/integration_test_config_utils.rs +++ b/crates/tests-integration/src/integration_test_config_utils.rs @@ -2,6 +2,7 @@ use std::fs::File; use std::io::Write; use serde_json::{json, Value}; +use starknet_sequencer_node::config::test_utils::RequiredParams; use starknet_sequencer_node::config::SequencerNodeConfig; use tokio::io::Result; use tracing::info; @@ -42,24 +43,27 @@ macro_rules! config_fields_to_json { /// cargo run --bin starknet_sequencer_node -- --config_file NODE_CONFIG_CHANGES_FILE_PATH /// Transaction generator: /// cargo run --bin run_test_tx_generator -- --config_file TX_GEN_CONFIG_CHANGES_FILE_PATH -pub fn dump_config_file_changes(config: SequencerNodeConfig) -> anyhow::Result<()> { +pub fn dump_config_file_changes( + config: SequencerNodeConfig, + required_params: RequiredParams, +) -> anyhow::Result<()> { // Dump config changes file for the sequencer node. let json_data = config_fields_to_json!( - config.chain_id, + required_params.chain_id, + required_params.eth_fee_token_address, + required_params.strk_fee_token_address, config.rpc_state_reader_config.json_rpc_version, config.rpc_state_reader_config.url, config.batcher_config.storage.db_config.path_prefix, config.http_server_config.ip, config.http_server_config.port, - config.gateway_config.chain_info.fee_token_addresses.eth_fee_token_address, - config.gateway_config.chain_info.fee_token_addresses.strk_fee_token_address, config.consensus_manager_config.consensus_config.start_height, ); dump_json_data(json_data, NODE_CONFIG_CHANGES_FILE_PATH)?; // Dump config changes file for the transaction generator. let json_data = config_fields_to_json!( - config.chain_id, + required_params.chain_id, config.http_server_config.ip, config.http_server_config.port, ); @@ -81,7 +85,10 @@ fn dump_json_data(json_data: Value, path: &str) -> Result<()> { Ok(()) } -/// Strips the "config." prefix from the input string. +/// Strips the "config." and "required_params." prefixes from the input string. fn strip_config_prefix(input: &str) -> &str { - input.strip_prefix("config.").unwrap_or(input) + input + .strip_prefix("config.") + .or_else(|| input.strip_prefix("required_params.")) + .unwrap_or(input) } diff --git a/crates/tests-integration/src/integration_test_setup.rs b/crates/tests-integration/src/integration_test_setup.rs index cb7dec0d37..28eb4e1af6 100644 --- a/crates/tests-integration/src/integration_test_setup.rs +++ b/crates/tests-integration/src/integration_test_setup.rs @@ -53,7 +53,8 @@ impl IntegrationTestSetup { .await; // Derive the configuration for the mempool node. - let config = create_config(rpc_server_addr, storage_for_test.batcher_storage_config).await; + let (config, _required_params) = + create_config(rpc_server_addr, storage_for_test.batcher_storage_config).await; let (clients, servers) = create_node_modules(&config); diff --git a/crates/tests-integration/src/integration_test_utils.rs b/crates/tests-integration/src/integration_test_utils.rs index 2f05beecd6..4740fb91bd 100644 --- a/crates/tests-integration/src/integration_test_utils.rs +++ b/crates/tests-integration/src/integration_test_utils.rs @@ -29,6 +29,7 @@ use starknet_gateway::config::{ use starknet_gateway_types::errors::GatewaySpecError; use starknet_http_server::config::HttpServerConfig; use starknet_sequencer_node::config::component_config::ComponentConfig; +use starknet_sequencer_node::config::test_utils::RequiredParams; use starknet_sequencer_node::config::{ ComponentExecutionConfig, ComponentExecutionMode, @@ -39,19 +40,22 @@ use tokio::net::TcpListener; pub async fn create_config( rpc_server_addr: SocketAddr, batcher_storage_config: StorageConfig, -) -> SequencerNodeConfig { +) -> (SequencerNodeConfig, RequiredParams) { // TODO(Arni/ Matan): Enable the consensus in the end to end test. let components = ComponentConfig { consensus_manager: ComponentExecutionConfig { execution_mode: ComponentExecutionMode::Disabled, + local_server_config: None, ..Default::default() }, ..Default::default() }; let chain_id = batcher_storage_config.db_config.chain_id.clone(); + // TODO(Tsabary): create chain_info in setup, and pass relevant values throughout. let mut chain_info = ChainInfo::create_for_testing(); chain_info.chain_id = chain_id.clone(); + let fee_token_addresses = chain_info.fee_token_addresses.clone(); let batcher_config = create_batcher_config(batcher_storage_config, chain_info.clone()); let gateway_config = create_gateway_config(chain_info).await; let http_server_config = create_http_server_config().await; @@ -59,16 +63,22 @@ pub async fn create_config( let consensus_manager_config = ConsensusManagerConfig { consensus_config: ConsensusConfig { start_height: BlockNumber(1), ..Default::default() }, }; - SequencerNodeConfig { - chain_id, - components, - batcher_config, - consensus_manager_config, - gateway_config, - http_server_config, - rpc_state_reader_config, - ..SequencerNodeConfig::default() - } + ( + SequencerNodeConfig { + components, + batcher_config, + consensus_manager_config, + gateway_config, + http_server_config, + rpc_state_reader_config, + ..SequencerNodeConfig::default() + }, + RequiredParams { + chain_id, + eth_fee_token_address: fee_token_addresses.eth_fee_token_address, + strk_fee_token_address: fee_token_addresses.strk_fee_token_address, + }, + ) } pub fn test_rpc_state_reader_config(rpc_server_addr: SocketAddr) -> RpcStateReaderConfig { diff --git a/deployments/sequencer/main.py b/deployments/sequencer/main.py index 159a2db70e..6cdd6df84f 100644 --- a/deployments/sequencer/main.py +++ b/deployments/sequencer/main.py @@ -23,6 +23,7 @@ class SystemStructure: def __post_init__(self): self.config.validate() + class SequencerSystem(Chart): def __init__( self, diff --git a/deployments/sequencer/services/defaults.py b/deployments/sequencer/services/defaults.py index 10f348f407..3eb865c338 100644 --- a/deployments/sequencer/services/defaults.py +++ b/deployments/sequencer/services/defaults.py @@ -4,4 +4,4 @@ startup_probe=Probe(port="http", path="/", period_seconds=5, failure_threshold=10, timeout_seconds=5), readiness_probe=Probe(port="http", path="/", period_seconds=5, failure_threshold=10, timeout_seconds=5), liveness_probe=Probe(port="http", path="/", period_seconds=5, failure_threshold=10, timeout_seconds=5) -) \ No newline at end of file +) diff --git a/deployments/sequencer/services/service.py b/deployments/sequencer/services/service.py index 31c8140e62..71d5b65daa 100644 --- a/deployments/sequencer/services/service.py +++ b/deployments/sequencer/services/service.py @@ -1,5 +1,6 @@ import json import dataclasses + from typing import Optional, Dict, Union from constructs import Construct from cdk8s import Names