diff --git a/crates/starknet_batcher/src/batcher.rs b/crates/starknet_batcher/src/batcher.rs index 8e0a341cf02..1f2006696e4 100644 --- a/crates/starknet_batcher/src/batcher.rs +++ b/crates/starknet_batcher/src/batcher.rs @@ -1,15 +1,19 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::sync::Arc; use blockifier::abi::constants; use blockifier::state::global_cache::GlobalContractCache; use chrono::Utc; +use indexmap::IndexMap; #[cfg(test)] use mockall::automock; use papyrus_storage::state::{StateStorageReader, StateStorageWriter}; use starknet_api::block::{BlockHashAndNumber, BlockNumber}; +use starknet_api::block_hash::state_diff_hash::calculate_state_diff_hash; +use starknet_api::core::{ContractAddress, Nonce}; use starknet_api::executable_transaction::Transaction; use starknet_api::state::ThinStateDiff; +use starknet_api::transaction::TransactionHash; use starknet_batcher_types::batcher_types::{ BatcherResult, DecisionReachedInput, @@ -30,24 +34,20 @@ use starknet_batcher_types::errors::BatcherError; use starknet_mempool_types::communication::SharedMempoolClient; use starknet_mempool_types::mempool_types::CommitBlockArgs; use starknet_sequencer_infra::component_definitions::ComponentStarter; -use tracing::{debug, error, info, instrument, trace}; +use thiserror::Error; +use tokio::sync::Mutex; +use tracing::{debug, error, info, instrument, trace, Instrument}; use crate::block_builder::{ BlockBuilderError, BlockBuilderExecutionParams, BlockBuilderFactory, BlockBuilderFactoryTrait, + BlockBuilderTrait, + BlockExecutionArtifacts, BlockMetadata, }; use crate::config::BatcherConfig; -use crate::proposal_manager::{ - GenerateProposalError, - ProposalError, - ProposalManager, - ProposalManagerTrait, - ProposalOutput, - ProposalResult, -}; use crate::transaction_provider::{ DummyL1ProviderClient, ProposeTransactionProvider, @@ -56,6 +56,60 @@ use crate::transaction_provider::{ type OutputStreamReceiver = tokio::sync::mpsc::UnboundedReceiver; type InputStreamSender = tokio::sync::mpsc::Sender; +pub(crate) type ProposalResult = Result; + +// Represents an error completing a proposal. +#[derive(Clone, Debug, Error)] +pub(crate) enum ProposalError { + #[error(transparent)] + BlockBuilderError(Arc), + #[error("Proposal was aborted")] + Aborted, +} + +// Represents a spawned task of building new block proposal. +struct ProposalTask { + pub abort_signal_sender: tokio::sync::oneshot::Sender<()>, + pub join_handle: tokio::task::JoinHandle<()>, +} + +#[derive(Debug, PartialEq)] +pub(crate) struct ProposalOutput { + pub state_diff: ThinStateDiff, + pub commitment: ProposalCommitment, + pub tx_hashes: HashSet, + pub nonces: HashMap, +} + +impl From for ProposalOutput { + fn from(artifacts: BlockExecutionArtifacts) -> Self { + let commitment_state_diff = artifacts.commitment_state_diff; + let nonces = HashMap::from_iter( + commitment_state_diff + .address_to_nonce + .iter() + .map(|(address, nonce)| (*address, *nonce)), + ); + + // TODO: Get these from the transactions. + let deployed_contracts = IndexMap::new(); + let declared_classes = IndexMap::new(); + let state_diff = ThinStateDiff { + deployed_contracts, + storage_diffs: commitment_state_diff.storage_updates, + declared_classes, + nonces: commitment_state_diff.address_to_nonce, + // TODO: Remove this when the structure of storage diffs changes. + deprecated_declared_classes: Vec::new(), + replaced_classes: IndexMap::new(), + }; + let commitment = + ProposalCommitment { state_diff_commitment: calculate_state_diff_hash(&state_diff) }; + let tx_hashes = HashSet::from_iter(artifacts.execution_infos.keys().copied()); + + Self { state_diff, commitment, tx_hashes, nonces } + } +} pub struct Batcher { pub config: BatcherConfig, @@ -63,11 +117,30 @@ pub struct Batcher { pub storage_writer: Box, pub mempool_client: SharedMempoolClient, + // Used to create block builders. + // Using the factory pattern to allow for easier testing. + block_builder_factory: Box, + + // The height that the batcher is currently working on. + // All proposals are considered to be at this height. active_height: Option, - proposal_manager: Box, - block_builder_factory: Box, + // The block proposal that is currently being built, if any. + // At any given time, there can be only one proposal being actively executed (either proposed + // or validated). + active_proposal: Arc>>, + active_proposal_task: Option, + + // Holds all the proposals that completed execution in the current height. + executed_proposals: Arc>>>, + + // The propose blocks transaction streams, used to stream out the proposal transactions. + // Each stream is kept until all the transactions are streamed out, or a new height is started. propose_tx_streams: HashMap, + + // The validate blocks transaction streams, used to stream in the transactions to validate. + // Each stream is kept until SendProposalContent::Finish/Abort is received, or a new height is + // started. validate_tx_streams: HashMap, } @@ -78,16 +151,17 @@ impl Batcher { storage_writer: Box, mempool_client: SharedMempoolClient, block_builder_factory: Box, - proposal_manager: Box, ) -> Self { Self { config: config.clone(), storage_reader, storage_writer, mempool_client, - active_height: None, block_builder_factory, - proposal_manager, + active_height: None, + active_proposal: Arc::new(Mutex::new(None)), + active_proposal_task: None, + executed_proposals: Arc::new(Mutex::new(HashMap::new())), propose_tx_streams: HashMap::new(), validate_tx_streams: HashMap::new(), } @@ -115,7 +189,8 @@ impl Batcher { } // Clear all the proposals from the previous height. - self.proposal_manager.reset().await; + self.abort_active_proposal().await; + self.executed_proposals.lock().await.clear(); self.propose_tx_streams.clear(); self.validate_tx_streams.clear(); @@ -137,6 +212,8 @@ impl Batcher { propose_block_input.retrospective_block_hash, )?; + self.set_active_proposal(propose_block_input.proposal_id).await?; + let tx_provider = ProposeTransactionProvider::new( self.mempool_client.clone(), // TODO: use a real L1 provider client. @@ -163,8 +240,7 @@ impl Batcher { ) .map_err(|_| BatcherError::InternalError)?; - self.proposal_manager - .spawn_proposal(propose_block_input.proposal_id, block_builder, abort_signal_sender) + self.spawn_proposal(propose_block_input.proposal_id, block_builder, abort_signal_sender) .await?; self.propose_tx_streams.insert(propose_block_input.proposal_id, output_tx_receiver); @@ -183,6 +259,8 @@ impl Batcher { validate_block_input.retrospective_block_hash, )?; + self.set_active_proposal(validate_block_input.proposal_id).await?; + // A channel to send the transactions to include in the block being validated. let (input_tx_sender, input_tx_receiver) = tokio::sync::mpsc::channel(self.config.input_stream_content_buffer_size); @@ -209,8 +287,7 @@ impl Batcher { ) .map_err(|_| BatcherError::InternalError)?; - self.proposal_manager - .spawn_proposal(validate_block_input.proposal_id, block_builder, abort_signal_sender) + self.spawn_proposal(validate_block_input.proposal_id, block_builder, abort_signal_sender) .await?; self.validate_tx_streams.insert(validate_block_input.proposal_id, input_tx_sender); @@ -232,11 +309,7 @@ impl Batcher { match send_proposal_content_input.content { SendProposalContent::Txs(txs) => self.handle_send_txs_request(proposal_id, txs).await, SendProposalContent::Finish => self.handle_finish_proposal_request(proposal_id).await, - SendProposalContent::Abort => { - self.proposal_manager.abort_proposal(proposal_id).await; - self.close_input_transaction_stream(proposal_id)?; - Ok(SendProposalContentResponse { response: ProposalStatus::Aborted }) - } + SendProposalContent::Abort => self.handle_abort_proposal_request(proposal_id).await, } } @@ -264,6 +337,8 @@ impl Batcher { let proposal_result = self.get_completed_proposal_result(proposal_id).await.expect("Proposal should exist."); match proposal_result { + // TODO(dafna): at this point the proposal result must be an error, since it finsisehd + // earlier than expected. Consider panicking instead of returning an error. Ok(_) => Err(BatcherError::ProposalAlreadyFinished { proposal_id }), Err(err) => Ok(SendProposalContentResponse { response: proposal_status_from(err)? }), } @@ -275,9 +350,11 @@ impl Batcher { ) -> BatcherResult { debug!("Send proposal content done for {}", proposal_id); - self.close_input_transaction_stream(proposal_id)?; + self.validate_tx_streams.remove(&proposal_id); if self.is_active(proposal_id).await { - self.proposal_manager.await_active_proposal().await; + if let Some(proposal_task) = self.active_proposal_task.take() { + proposal_task.join_handle.await.ok(); + } } let proposal_result = @@ -289,11 +366,16 @@ impl Batcher { Ok(SendProposalContentResponse { response: proposal_status }) } - fn close_input_transaction_stream(&mut self, proposal_id: ProposalId) -> BatcherResult<()> { - self.validate_tx_streams - .remove(&proposal_id) - .ok_or(BatcherError::ProposalNotFound { proposal_id })?; - Ok(()) + async fn handle_abort_proposal_request( + &mut self, + proposal_id: ProposalId, + ) -> BatcherResult { + if self.is_active(proposal_id).await { + self.abort_active_proposal().await; + self.executed_proposals.lock().await.insert(proposal_id, Err(ProposalError::Aborted)); + } + self.validate_tx_streams.remove(&proposal_id); + Ok(SendProposalContentResponse { response: ProposalStatus::Aborted }) } #[instrument(skip(self), err)] @@ -332,39 +414,96 @@ impl Batcher { #[instrument(skip(self), err)] pub async fn decision_reached(&mut self, input: DecisionReachedInput) -> BatcherResult<()> { + let height = self.active_height.ok_or(BatcherError::NoActiveHeight)?; + let proposal_id = input.proposal_id; - let proposal_output = self - .proposal_manager - .take_proposal_result(proposal_id) - .await - .ok_or(BatcherError::ExecutedProposalNotFound { proposal_id })??; + let proposal_result = self.executed_proposals.lock().await.remove(&proposal_id); let ProposalOutput { state_diff, nonces: address_to_nonce, tx_hashes, .. } = - proposal_output; - // TODO: Keep the height from start_height or get it from the input. - let height = self.storage_reader.height().map_err(|err| { - error!("Failed to get height from storage: {}", err); - BatcherError::InternalError - })?; + proposal_result.ok_or(BatcherError::ExecutedProposalNotFound { proposal_id })??; + info!( "Committing proposal {} at height {} and notifying mempool of the block.", proposal_id, height ); trace!("Transactions: {:#?}, State diff: {:#?}.", tx_hashes, state_diff); + + // Commit the proposal to the storage and notify the mempool. self.storage_writer.commit_proposal(height, state_diff).map_err(|err| { error!("Failed to commit proposal to storage: {}", err); BatcherError::InternalError })?; - if let Err(mempool_err) = - self.mempool_client.commit_block(CommitBlockArgs { address_to_nonce, tx_hashes }).await - { + let mempool_result = + self.mempool_client.commit_block(CommitBlockArgs { address_to_nonce, tx_hashes }).await; + + if let Err(mempool_err) = mempool_result { error!("Failed to commit block to mempool: {}", mempool_err); // TODO: Should we rollback the state diff and return an error? - } + }; + Ok(()) } async fn is_active(&self, proposal_id: ProposalId) -> bool { - self.proposal_manager.get_active_proposal().await == Some(proposal_id) + *self.active_proposal.lock().await == Some(proposal_id) + } + + // Sets a new active proposal task. + // Fails if there is another proposal being currently generated, or a proposal with the same ID + // already exists. + async fn set_active_proposal(&mut self, proposal_id: ProposalId) -> BatcherResult<()> { + if self.executed_proposals.lock().await.contains_key(&proposal_id) { + return Err(BatcherError::ProposalAlreadyExists { proposal_id }); + } + + let mut active_proposal = self.active_proposal.lock().await; + if let Some(active_proposal_id) = *active_proposal { + return Err(BatcherError::ServerBusy { + active_proposal_id, + new_proposal_id: proposal_id, + }); + } + + debug!("Set proposal {} as the one being generated.", proposal_id); + *active_proposal = Some(proposal_id); + Ok(()) + } + + // Starts a new block proposal generation task for the given proposal_id. + // Uses the given block_builder to generate the proposal. + #[instrument(skip(self, block_builder), err)] + async fn spawn_proposal( + &mut self, + proposal_id: ProposalId, + mut block_builder: Box, + abort_signal_sender: tokio::sync::oneshot::Sender<()>, + ) -> BatcherResult<()> { + info!("Starting generation of a new proposal with id {}.", proposal_id); + + let active_proposal = self.active_proposal.clone(); + let executed_proposals = self.executed_proposals.clone(); + + let join_handle = tokio::spawn( + async move { + let result = block_builder + .build_block() + .await + .map(ProposalOutput::from) + .map_err(|e| ProposalError::BlockBuilderError(Arc::new(e))); + + // The proposal is done, clear the active proposal. + // Keep the proposal result only if it is the same as the active proposal. + // The active proposal might have changed if this proposal was aborted. + let mut active_proposal = active_proposal.lock().await; + if *active_proposal == Some(proposal_id) { + active_proposal.take(); + executed_proposals.lock().await.insert(proposal_id, result); + } + } + .in_current_span(), + ); + + self.active_proposal_task = Some(ProposalTask { abort_signal_sender, join_handle }); + Ok(()) } // Returns a completed proposal result, either its commitment or an error if the proposal @@ -373,8 +512,7 @@ impl Batcher { &self, proposal_id: ProposalId, ) -> Option> { - let completed_proposals = self.proposal_manager.get_completed_proposals().await; - let guard = completed_proposals.lock().await; + let guard = self.executed_proposals.lock().await; let proposal_result = guard.get(&proposal_id); match proposal_result { @@ -383,6 +521,22 @@ impl Batcher { None => None, } } + + // Ends the current active proposal. + // This call is non-blocking. + async fn abort_active_proposal(&mut self) { + self.active_proposal.lock().await.take(); + if let Some(proposal_task) = self.active_proposal_task.take() { + proposal_task.abort_signal_sender.send(()).ok(); + } + } + + #[cfg(test)] + pub async fn await_active_proposal(&mut self) { + if let Some(proposal_task) = self.active_proposal_task.take() { + proposal_task.join_handle.await.ok(); + } + } } pub fn create_batcher(config: BatcherConfig, mempool_client: SharedMempoolClient) -> Batcher { @@ -396,15 +550,7 @@ pub fn create_batcher(config: BatcherConfig, mempool_client: SharedMempoolClient }); let storage_reader = Arc::new(storage_reader); let storage_writer = Box::new(storage_writer); - let proposal_manager = Box::new(ProposalManager::new()); - Batcher::new( - config, - storage_reader, - storage_writer, - mempool_client, - block_builder_factory, - proposal_manager, - ) + Batcher::new(config, storage_reader, storage_writer, mempool_client, block_builder_factory) } #[cfg_attr(test, automock)] @@ -439,25 +585,6 @@ impl BatcherStorageWriterTrait for papyrus_storage::StorageWriter { } } -impl From for BatcherError { - fn from(err: GenerateProposalError) -> Self { - match err { - GenerateProposalError::AlreadyGeneratingProposal { - current_generating_proposal_id, - new_proposal_id, - } => BatcherError::ServerBusy { - active_proposal_id: current_generating_proposal_id, - new_proposal_id, - }, - GenerateProposalError::BlockBuilderError(..) => BatcherError::InternalError, - GenerateProposalError::NoActiveHeight => BatcherError::NoActiveHeight, - GenerateProposalError::ProposalAlreadyExists { proposal_id } => { - BatcherError::ProposalAlreadyExists { proposal_id } - } - } - } -} - impl From for BatcherError { fn from(err: ProposalError) -> Self { match err { diff --git a/crates/starknet_batcher/src/batcher_test.rs b/crates/starknet_batcher/src/batcher_test.rs index 5f62087bc8f..c520a3bfebd 100644 --- a/crates/starknet_batcher/src/batcher_test.rs +++ b/crates/starknet_batcher/src/batcher_test.rs @@ -1,23 +1,13 @@ -use std::collections::{HashMap, HashSet}; use std::sync::Arc; use assert_matches::assert_matches; -use async_trait::async_trait; use blockifier::abi::constants; use blockifier::test_utils::struct_impls::BlockInfoExt; use chrono::Utc; -use futures::future::BoxFuture; -use futures::FutureExt; -use mockall::automock; -use mockall::predicate::{always, eq}; +use mockall::predicate::eq; use rstest::rstest; use starknet_api::block::{BlockInfo, BlockNumber}; -use starknet_api::core::{ContractAddress, Nonce, StateDiffCommitment}; use starknet_api::executable_transaction::Transaction; -use starknet_api::hash::PoseidonHash; -use starknet_api::state::ThinStateDiff; -use starknet_api::transaction::TransactionHash; -use starknet_api::{contract_address, felt, nonce, tx_hash}; use starknet_batcher_types::batcher_types::{ DecisionReachedInput, GetProposalContent, @@ -36,27 +26,23 @@ use starknet_batcher_types::batcher_types::{ use starknet_batcher_types::errors::BatcherError; use starknet_mempool_types::communication::MockMempoolClient; use starknet_mempool_types::mempool_types::CommitBlockArgs; -use tokio::sync::Mutex; -use crate::batcher::{Batcher, MockBatcherStorageReaderTrait, MockBatcherStorageWriterTrait}; +use crate::batcher::{ + Batcher, + MockBatcherStorageReaderTrait, + MockBatcherStorageWriterTrait, + ProposalOutput, +}; use crate::block_builder::{ AbortSignalSender, BlockBuilderError, - BlockBuilderTrait, + BlockBuilderResult, + BlockExecutionArtifacts, FailOnErrorCause, MockBlockBuilderFactoryTrait, - MockBlockBuilderTrait, }; use crate::config::BatcherConfig; -use crate::proposal_manager::{ - GenerateProposalError, - ProposalError, - ProposalManagerTrait, - ProposalOutput, - ProposalResult, -}; -use crate::test_utils::test_txs; -use crate::transaction_provider::NextTxs; +use crate::test_utils::{test_txs, FakeProposeBlockBuilder, FakeValidateBlockBuilder}; const INITIAL_HEIGHT: BlockNumber = BlockNumber(3); const STREAMING_CHUNK_SIZE: usize = 3; @@ -68,30 +54,17 @@ fn initial_block_info() -> BlockInfo { } fn proposal_commitment() -> ProposalCommitment { - ProposalCommitment { - state_diff_commitment: StateDiffCommitment(PoseidonHash(felt!(u128::try_from(7).unwrap()))), - } -} - -fn proposal_output() -> ProposalOutput { - ProposalOutput { commitment: proposal_commitment(), ..Default::default() } + ProposalOutput::from(BlockExecutionArtifacts::create_for_testing()).commitment } fn deadline() -> chrono::DateTime { chrono::Utc::now() + BLOCK_GENERATION_TIMEOUT } -fn invalid_proposal_result() -> ProposalResult { - Err(ProposalError::BlockBuilderError(Arc::new(BlockBuilderError::FailOnError( - FailOnErrorCause::BlockFull, - )))) -} - struct MockDependencies { storage_reader: MockBatcherStorageReaderTrait, storage_writer: MockBatcherStorageWriterTrait, mempool_client: MockMempoolClient, - proposal_manager: MockProposalManagerTraitWrapper, block_builder_factory: MockBlockBuilderFactoryTrait, } @@ -103,7 +76,6 @@ impl Default for MockDependencies { storage_reader, storage_writer: MockBatcherStorageWriterTrait::new(), mempool_client: MockMempoolClient::new(), - proposal_manager: MockProposalManagerTraitWrapper::new(), block_builder_factory: MockBlockBuilderFactoryTrait::new(), } } @@ -116,7 +88,6 @@ fn create_batcher(mock_dependencies: MockDependencies) -> Batcher { Box::new(mock_dependencies.storage_writer), Arc::new(mock_dependencies.mempool_client), Box::new(mock_dependencies.block_builder_factory), - Box::new(mock_dependencies.proposal_manager), ) } @@ -124,73 +95,82 @@ fn abort_signal_sender() -> AbortSignalSender { tokio::sync::oneshot::channel().0 } -fn mock_create_builder_for_validate_block() -> MockBlockBuilderFactoryTrait { - let mut block_builder_factory = MockBlockBuilderFactoryTrait::new(); +fn expect_create_validate_block_builder( + block_builder_factory: &mut MockBlockBuilderFactoryTrait, + build_block_result: BlockBuilderResult, +) { block_builder_factory.expect_create_block_builder().times(1).return_once( - |_, _, mut tx_provider, _| { - // Spawn a task to keep tx_provider alive until all transactions are read. - // Without this, the provider would be dropped, causing the batcher to fail when sending - // transactions to it during the test. - tokio::spawn(async move { - while tx_provider.get_txs(1).await.is_ok_and(|v| v != NextTxs::End) { - tokio::task::yield_now().await; - } - }); - Ok((Box::new(MockBlockBuilderTrait::new()), abort_signal_sender())) + |_, _, tx_provider, _| { + Ok(( + Box::new(FakeValidateBlockBuilder { + tx_provider, + build_block_result: Some(build_block_result), + }), + abort_signal_sender(), + )) }, ); - block_builder_factory } -fn mock_create_builder_for_propose_block( +fn expect_create_propose_block_builder( + block_builder_factory: &mut MockBlockBuilderFactoryTrait, output_txs: Vec, -) -> MockBlockBuilderFactoryTrait { - let mut block_builder_factory = MockBlockBuilderFactoryTrait::new(); + build_block_result: BlockBuilderResult, +) { block_builder_factory.expect_create_block_builder().times(1).return_once( - |_, _, _, output_content_sender| { - // Simulate the streaming of the block builder output. - for tx in output_txs { - output_content_sender.as_ref().unwrap().send(tx).unwrap(); - } - Ok((Box::new(MockBlockBuilderTrait::new()), abort_signal_sender())) + move |_, _, _, output_content_sender| { + Ok(( + Box::new(FakeProposeBlockBuilder { + output_content_sender: output_content_sender.unwrap(), + output_txs, + build_block_result: Some(build_block_result), + }), + abort_signal_sender(), + )) }, ); +} + +fn successful_validate_block_builder() -> MockBlockBuilderFactoryTrait { + let mut block_builder_factory = MockBlockBuilderFactoryTrait::new(); + expect_create_validate_block_builder( + &mut block_builder_factory, + Ok(BlockExecutionArtifacts::create_for_testing()), + ); block_builder_factory } -fn mock_start_proposal(proposal_manager: &mut MockProposalManagerTraitWrapper) { - proposal_manager.expect_wrap_reset().times(1).return_once(|| async {}.boxed()); - proposal_manager - .expect_wrap_spawn_proposal() - .times(1) - .with(eq(PROPOSAL_ID), always(), always()) - .return_once(|_, _, _| { async move { Ok(()) } }.boxed()); +fn failed_validate_block_builder() -> MockBlockBuilderFactoryTrait { + let mut block_builder_factory = MockBlockBuilderFactoryTrait::new(); + expect_create_validate_block_builder( + &mut block_builder_factory, + Err(BlockBuilderError::FailOnError(FailOnErrorCause::BlockFull)), + ); + block_builder_factory } -fn mock_completed_proposal( - proposal_manager: &mut MockProposalManagerTraitWrapper, - proposal_result: ProposalResult, -) { - proposal_manager.expect_wrap_get_completed_proposals().times(1).return_once(move || { - async move { Arc::new(Mutex::new(HashMap::from([(PROPOSAL_ID, proposal_result)]))) }.boxed() - }); +fn successful_propose_block_builder(output_txs: Vec) -> MockBlockBuilderFactoryTrait { + let mut block_builder_factory = MockBlockBuilderFactoryTrait::new(); + expect_create_propose_block_builder( + &mut block_builder_factory, + output_txs, + Ok(BlockExecutionArtifacts::create_for_testing()), + ); + block_builder_factory } -async fn batcher_with_validated_proposal( - proposal_result: ProposalResult, -) -> Batcher { - let block_builder_factory = mock_create_builder_for_validate_block(); - let mut proposal_manager = MockProposalManagerTraitWrapper::new(); - mock_start_proposal(&mut proposal_manager); - mock_completed_proposal(&mut proposal_manager, proposal_result); - proposal_manager.expect_wrap_get_active_proposal().returning(|| async move { None }.boxed()); +async fn create_completed_validate_proposal(batcher: &mut Batcher) { + create_active_validate_proposal(batcher).await; - let mut batcher = create_batcher(MockDependencies { - proposal_manager, - block_builder_factory, - ..Default::default() - }); + let finish_proposal_input = + SendProposalContentInput { proposal_id: PROPOSAL_ID, content: SendProposalContent::Finish }; + batcher.send_proposal_content(finish_proposal_input).await.unwrap(); + + // Make sure the proposal is finished. + batcher.await_active_proposal().await; +} +async fn create_active_validate_proposal(batcher: &mut Batcher) { batcher.start_height(StartHeightInput { height: INITIAL_HEIGHT }).await.unwrap(); let validate_block_input = ValidateBlockInput { @@ -200,31 +180,12 @@ async fn batcher_with_validated_proposal( block_info: initial_block_info(), }; batcher.validate_block(validate_block_input).await.unwrap(); - - batcher -} - -fn mock_proposal_manager_validate_flow() -> MockProposalManagerTraitWrapper { - let mut proposal_manager = MockProposalManagerTraitWrapper::new(); - mock_start_proposal(&mut proposal_manager); - proposal_manager - .expect_wrap_get_active_proposal() - .returning(|| async move { Some(PROPOSAL_ID) }.boxed()); - proposal_manager - .expect_wrap_await_active_proposal() - .times(1) - .returning(|| async move { true }.boxed()); - mock_completed_proposal(&mut proposal_manager, Ok(proposal_output())); - proposal_manager } #[rstest] #[tokio::test] async fn start_height_success() { - let mut proposal_manager = MockProposalManagerTraitWrapper::new(); - proposal_manager.expect_wrap_reset().times(1).return_once(|| async {}.boxed()); - - let mut batcher = create_batcher(MockDependencies { proposal_manager, ..Default::default() }); + let mut batcher = create_batcher(MockDependencies::default()); assert_eq!(batcher.start_height(StartHeightInput { height: INITIAL_HEIGHT }).await, Ok(())); } @@ -245,20 +206,14 @@ async fn start_height_success() { )] #[tokio::test] async fn start_height_fail(#[case] height: BlockNumber, #[case] expected_error: BatcherError) { - let mut proposal_manager = MockProposalManagerTraitWrapper::new(); - proposal_manager.expect_wrap_reset().never(); - - let mut batcher = create_batcher(MockDependencies { proposal_manager, ..Default::default() }); + let mut batcher = create_batcher(MockDependencies::default()); assert_eq!(batcher.start_height(StartHeightInput { height }).await, Err(expected_error)); } #[rstest] #[tokio::test] async fn duplicate_start_height() { - let mut proposal_manager = MockProposalManagerTraitWrapper::new(); - proposal_manager.expect_wrap_reset().times(1).return_once(|| async {}.boxed()); - - let mut batcher = create_batcher(MockDependencies { proposal_manager, ..Default::default() }); + let mut batcher = create_batcher(MockDependencies::default()); let initial_height = StartHeightInput { height: INITIAL_HEIGHT }; assert_eq!(batcher.start_height(initial_height.clone()).await, Ok(())); @@ -268,8 +223,7 @@ async fn duplicate_start_height() { #[rstest] #[tokio::test] async fn no_active_height() { - let proposal_manager = MockProposalManagerTraitWrapper::new(); - let mut batcher = create_batcher(MockDependencies { proposal_manager, ..Default::default() }); + let mut batcher = create_batcher(MockDependencies::default()); // Calling `propose_block` and `validate_block` without starting a height should fail. @@ -296,24 +250,66 @@ async fn no_active_height() { #[rstest] #[tokio::test] -async fn validate_block_full_flow() { - let block_builder_factory = mock_create_builder_for_validate_block(); - let proposal_manager = mock_proposal_manager_validate_flow(); +async fn consecutive_proposal_generation_success() { + let mut block_builder_factory = MockBlockBuilderFactoryTrait::new(); + expect_create_validate_block_builder( + &mut block_builder_factory, + Ok(BlockExecutionArtifacts::create_for_testing()), + ); + expect_create_propose_block_builder( + &mut block_builder_factory, + vec![], + Ok(BlockExecutionArtifacts::create_for_testing()), + ); + let mut batcher = + create_batcher(MockDependencies { block_builder_factory, ..Default::default() }); + + create_completed_validate_proposal(&mut batcher).await; + + // Make sure another proposal can be generated after the first one finished. + batcher + .propose_block(ProposeBlockInput { + proposal_id: ProposalId(1), + retrospective_block_hash: None, + deadline: chrono::Utc::now() + chrono::Duration::seconds(1), + block_info: initial_block_info(), + }) + .await + .unwrap(); +} + +#[rstest] +#[tokio::test] +async fn concurrent_proposals_generation_fail() { let mut batcher = create_batcher(MockDependencies { - proposal_manager, - block_builder_factory, + block_builder_factory: successful_validate_block_builder(), ..Default::default() }); - batcher.start_height(StartHeightInput { height: INITIAL_HEIGHT }).await.unwrap(); + // Start a validate proposal that will remain active. + create_active_validate_proposal(&mut batcher).await; - let validate_block_input = ValidateBlockInput { - proposal_id: PROPOSAL_ID, - deadline: deadline(), - retrospective_block_hash: None, - block_info: initial_block_info(), - }; - batcher.validate_block(validate_block_input).await.unwrap(); + // Make sure another proposal can't be generated while the first one is still active. + let result = batcher + .propose_block(ProposeBlockInput { + proposal_id: ProposalId(1), + retrospective_block_hash: None, + deadline: chrono::Utc::now() + chrono::Duration::seconds(1), + block_info: initial_block_info(), + }) + .await; + + assert_matches!(result, Err(BatcherError::ServerBusy { .. })); +} + +#[rstest] +#[tokio::test] +async fn validate_block_full_flow() { + let mut batcher = create_batcher(MockDependencies { + block_builder_factory: successful_validate_block_builder(), + ..Default::default() + }); + create_active_validate_proposal(&mut batcher).await; let send_proposal_input_txs = SendProposalContentInput { proposal_id: PROPOSAL_ID, @@ -335,11 +331,31 @@ async fn validate_block_full_flow() { ); } +#[rstest] +#[tokio::test] +async fn send_proposal_content_abort() { + let mut batcher = create_batcher(MockDependencies { + block_builder_factory: successful_validate_block_builder(), + ..Default::default() + }); + create_active_validate_proposal(&mut batcher).await; + + let send_abort_request = + SendProposalContentInput { proposal_id: PROPOSAL_ID, content: SendProposalContent::Abort }; + assert_eq!( + batcher.send_proposal_content(send_abort_request).await.unwrap(), + SendProposalContentResponse { response: ProposalStatus::Aborted } + ); +} + #[rstest] #[tokio::test] async fn send_content_after_proposal_already_finished() { - let successful_proposal_result = Ok(proposal_output()); - let mut batcher = batcher_with_validated_proposal(successful_proposal_result).await; + let mut batcher = create_batcher(MockDependencies { + block_builder_factory: successful_validate_block_builder(), + ..Default::default() + }); + create_completed_validate_proposal(&mut batcher).await; // Send transactions after the proposal has finished. let send_proposal_input_txs = SendProposalContentInput { @@ -347,7 +363,7 @@ async fn send_content_after_proposal_already_finished() { content: SendProposalContent::Txs(test_txs(0..1)), }; let result = batcher.send_proposal_content(send_proposal_input_txs).await; - assert_eq!(result, Err(BatcherError::ProposalAlreadyFinished { proposal_id: PROPOSAL_ID })); + assert_eq!(result, Err(BatcherError::ProposalNotFound { proposal_id: PROPOSAL_ID })); } #[rstest] @@ -373,7 +389,12 @@ async fn send_content_to_unknown_proposal() { #[rstest] #[tokio::test] async fn send_txs_to_an_invalid_proposal() { - let mut batcher = batcher_with_validated_proposal(invalid_proposal_result()).await; + let mut batcher = create_batcher(MockDependencies { + block_builder_factory: failed_validate_block_builder(), + ..Default::default() + }); + create_active_validate_proposal(&mut batcher).await; + batcher.await_active_proposal().await; let send_proposal_input_txs = SendProposalContentInput { proposal_id: PROPOSAL_ID, @@ -383,32 +404,14 @@ async fn send_txs_to_an_invalid_proposal() { assert_eq!(result, SendProposalContentResponse { response: ProposalStatus::InvalidProposal }); } -#[rstest] -#[tokio::test] -async fn send_finish_to_an_invalid_proposal() { - let mut batcher = batcher_with_validated_proposal(invalid_proposal_result()).await; - - let send_proposal_input_txs = - SendProposalContentInput { proposal_id: PROPOSAL_ID, content: SendProposalContent::Finish }; - let result = batcher.send_proposal_content(send_proposal_input_txs).await.unwrap(); - assert_eq!(result, SendProposalContentResponse { response: ProposalStatus::InvalidProposal }); -} - #[rstest] #[tokio::test] async fn propose_block_full_flow() { // Expecting 3 chunks of streamed txs. let expected_streamed_txs = test_txs(0..STREAMING_CHUNK_SIZE * 2 + 1); - let txs_to_stream = expected_streamed_txs.clone(); - - let block_builder_factory = mock_create_builder_for_propose_block(txs_to_stream); - let mut proposal_manager = MockProposalManagerTraitWrapper::new(); - mock_start_proposal(&mut proposal_manager); - mock_completed_proposal(&mut proposal_manager, Ok(proposal_output())); let mut batcher = create_batcher(MockDependencies { - proposal_manager, - block_builder_factory, + block_builder_factory: successful_propose_block_builder(expected_streamed_txs.clone()), ..Default::default() }); @@ -454,16 +457,12 @@ async fn propose_block_full_flow() { #[rstest] #[tokio::test] async fn propose_block_without_retrospective_block_hash() { - let mut proposal_manager = MockProposalManagerTraitWrapper::new(); - proposal_manager.expect_wrap_reset().times(1).return_once(|| async {}.boxed()); - let mut storage_reader = MockBatcherStorageReaderTrait::new(); storage_reader .expect_height() .returning(|| Ok(BlockNumber(constants::STORED_BLOCK_HASH_BUFFER))); - let mut batcher = - create_batcher(MockDependencies { proposal_manager, storage_reader, ..Default::default() }); + let mut batcher = create_batcher(MockDependencies { storage_reader, ..Default::default() }); batcher .start_height(StartHeightInput { height: BlockNumber(constants::STORED_BLOCK_HASH_BUFFER) }) @@ -484,10 +483,7 @@ async fn propose_block_without_retrospective_block_hash() { #[rstest] #[tokio::test] async fn get_content_from_unknown_proposal() { - let mut proposal_manager = MockProposalManagerTraitWrapper::new(); - proposal_manager.expect_wrap_get_completed_proposals().times(0); - - let mut batcher = create_batcher(MockDependencies { proposal_manager, ..Default::default() }); + let mut batcher = create_batcher(MockDependencies::default()); let get_proposal_content_input = GetProposalContentInput { proposal_id: PROPOSAL_ID }; let result = batcher.get_proposal_content(get_proposal_content_input).await; @@ -498,41 +494,29 @@ async fn get_content_from_unknown_proposal() { #[tokio::test] async fn decision_reached() { let mut mock_dependencies = MockDependencies::default(); - - mock_dependencies - .proposal_manager - .expect_wrap_take_proposal_result() - .times(1) - .with(eq(PROPOSAL_ID)) - .return_once(move |_| { - async move { - Some(Ok(ProposalOutput { - state_diff: ThinStateDiff::default(), - commitment: ProposalCommitment::default(), - tx_hashes: test_tx_hashes(), - nonces: test_contract_nonces(), - })) - } - .boxed() - }); + let expected_proposal_output = + ProposalOutput::from(BlockExecutionArtifacts::create_for_testing()); mock_dependencies .mempool_client .expect_commit_block() .with(eq(CommitBlockArgs { - address_to_nonce: test_contract_nonces(), - tx_hashes: test_tx_hashes(), + address_to_nonce: expected_proposal_output.nonces, + tx_hashes: expected_proposal_output.tx_hashes, })) .returning(|_| Ok(())); mock_dependencies .storage_writer .expect_commit_proposal() - .with(eq(INITIAL_HEIGHT), eq(ThinStateDiff::default())) + .with(eq(INITIAL_HEIGHT), eq(expected_proposal_output.state_diff)) .returning(|_, _| Ok(())); + mock_dependencies.block_builder_factory = successful_validate_block_builder(); let mut batcher = create_batcher(mock_dependencies); + create_completed_validate_proposal(&mut batcher).await; + batcher.decision_reached(DecisionReachedInput { proposal_id: PROPOSAL_ID }).await.unwrap(); } @@ -541,92 +525,10 @@ async fn decision_reached() { async fn decision_reached_no_executed_proposal() { let expected_error = BatcherError::ExecutedProposalNotFound { proposal_id: PROPOSAL_ID }; - let mut proposal_manager = MockProposalManagerTraitWrapper::new(); - proposal_manager - .expect_wrap_take_proposal_result() - .times(1) - .with(eq(PROPOSAL_ID)) - .return_once(|_| async move { None }.boxed()); + let mut batcher = create_batcher(MockDependencies::default()); + batcher.start_height(StartHeightInput { height: INITIAL_HEIGHT }).await.unwrap(); - let mut batcher = create_batcher(MockDependencies { proposal_manager, ..Default::default() }); let decision_reached_result = batcher.decision_reached(DecisionReachedInput { proposal_id: PROPOSAL_ID }).await; assert_eq!(decision_reached_result, Err(expected_error)); } - -// A wrapper trait to allow mocking the ProposalManagerTrait in tests. -#[automock] -trait ProposalManagerTraitWrapper: Send + Sync { - fn wrap_spawn_proposal( - &mut self, - proposal_id: ProposalId, - block_builder: Box, - abort_signal_sender: tokio::sync::oneshot::Sender<()>, - ) -> BoxFuture<'_, Result<(), GenerateProposalError>>; - - fn wrap_take_proposal_result( - &mut self, - proposal_id: ProposalId, - ) -> BoxFuture<'_, Option>>; - - fn wrap_get_active_proposal(&self) -> BoxFuture<'_, Option>; - - fn wrap_get_completed_proposals( - &self, - ) -> BoxFuture<'_, Arc>>>>; - - fn wrap_await_active_proposal(&mut self) -> BoxFuture<'_, bool>; - - fn wrap_abort_proposal(&mut self, proposal_id: ProposalId) -> BoxFuture<'_, ()>; - - fn wrap_reset(&mut self) -> BoxFuture<'_, ()>; -} - -#[async_trait] -impl ProposalManagerTrait for T { - async fn spawn_proposal( - &mut self, - proposal_id: ProposalId, - block_builder: Box, - abort_signal_sender: tokio::sync::oneshot::Sender<()>, - ) -> Result<(), GenerateProposalError> { - self.wrap_spawn_proposal(proposal_id, block_builder, abort_signal_sender).await - } - - async fn take_proposal_result( - &mut self, - proposal_id: ProposalId, - ) -> Option> { - self.wrap_take_proposal_result(proposal_id).await - } - - async fn get_active_proposal(&self) -> Option { - self.wrap_get_active_proposal().await - } - - async fn get_completed_proposals( - &self, - ) -> Arc>>> { - self.wrap_get_completed_proposals().await - } - - async fn await_active_proposal(&mut self) -> bool { - self.wrap_await_active_proposal().await - } - - async fn abort_proposal(&mut self, proposal_id: ProposalId) { - self.wrap_abort_proposal(proposal_id).await - } - - async fn reset(&mut self) { - self.wrap_reset().await - } -} - -fn test_tx_hashes() -> HashSet { - (0..5u8).map(|i| tx_hash!(i + 12)).collect() -} - -fn test_contract_nonces() -> HashMap { - HashMap::from_iter((0..3u8).map(|i| (contract_address!(i + 33), nonce!(i + 9)))) -} diff --git a/crates/starknet_batcher/src/lib.rs b/crates/starknet_batcher/src/lib.rs index f25938de6eb..1deed345d4a 100644 --- a/crates/starknet_batcher/src/lib.rs +++ b/crates/starknet_batcher/src/lib.rs @@ -7,9 +7,6 @@ mod block_builder_test; pub mod communication; pub mod config; pub mod fee_market; -mod proposal_manager; -#[cfg(test)] -mod proposal_manager_test; #[cfg(test)] mod test_utils; mod transaction_executor; diff --git a/crates/starknet_batcher/src/proposal_manager.rs b/crates/starknet_batcher/src/proposal_manager.rs deleted file mode 100644 index 69cc25540ef..00000000000 --- a/crates/starknet_batcher/src/proposal_manager.rs +++ /dev/null @@ -1,259 +0,0 @@ -use std::collections::{HashMap, HashSet}; -use std::sync::Arc; - -use async_trait::async_trait; -use indexmap::IndexMap; -use starknet_api::block_hash::state_diff_hash::calculate_state_diff_hash; -use starknet_api::core::{ContractAddress, Nonce}; -use starknet_api::state::ThinStateDiff; -use starknet_api::transaction::TransactionHash; -use starknet_batcher_types::batcher_types::{ProposalCommitment, ProposalId}; -use thiserror::Error; -use tokio::sync::Mutex; -use tracing::{debug, error, info, instrument, Instrument}; - -use crate::block_builder::{BlockBuilderError, BlockBuilderTrait, BlockExecutionArtifacts}; - -#[derive(Debug, Error)] -pub enum GenerateProposalError { - #[error( - "Received proposal generation request with id {new_proposal_id} while already generating \ - proposal with id {current_generating_proposal_id}." - )] - AlreadyGeneratingProposal { - current_generating_proposal_id: ProposalId, - new_proposal_id: ProposalId, - }, - #[error(transparent)] - BlockBuilderError(#[from] BlockBuilderError), - #[error("No active height to work on.")] - NoActiveHeight, - #[error("Proposal with id {proposal_id} already exists.")] - ProposalAlreadyExists { proposal_id: ProposalId }, -} - -#[derive(Clone, Debug, Error)] -pub enum ProposalError { - #[error(transparent)] - BlockBuilderError(Arc), - #[error("Proposal was aborted")] - Aborted, -} - -#[async_trait] -pub trait ProposalManagerTrait: Send + Sync { - async fn spawn_proposal( - &mut self, - proposal_id: ProposalId, - mut block_builder: Box, - abort_signal_sender: tokio::sync::oneshot::Sender<()>, - ) -> Result<(), GenerateProposalError>; - - async fn take_proposal_result( - &mut self, - proposal_id: ProposalId, - ) -> Option>; - - async fn get_active_proposal(&self) -> Option; - - async fn get_completed_proposals( - &self, - ) -> Arc>>>; - - async fn await_active_proposal(&mut self) -> bool; - - async fn abort_proposal(&mut self, proposal_id: ProposalId); - - // Resets the proposal manager, aborting any active proposal. - async fn reset(&mut self); -} - -// Represents a spawned task of building new block proposal. -struct ProposalTask { - abort_signal_sender: tokio::sync::oneshot::Sender<()>, - join_handle: tokio::task::JoinHandle<()>, -} - -/// Main struct for handling block proposals. -/// Taking care of: -/// - Proposing new blocks. -/// - Validating incoming proposals. -/// - Committing accepted proposals to the storage. -/// -/// Triggered by the consensus. -pub(crate) struct ProposalManager { - /// The block proposal that is currently being built, if any. - /// At any given time, there can be only one proposal being actively executed (either proposed - /// or validated). - active_proposal: Arc>>, - active_proposal_task: Option, - - executed_proposals: Arc>>>, -} - -pub type ProposalResult = Result; - -#[derive(Debug, Default, PartialEq)] -pub struct ProposalOutput { - pub state_diff: ThinStateDiff, - pub commitment: ProposalCommitment, - pub tx_hashes: HashSet, - pub nonces: HashMap, -} - -#[async_trait] -impl ProposalManagerTrait for ProposalManager { - /// Starts a new block proposal generation task for the given proposal_id. - /// Uses the given block_builder to generate the proposal. - #[instrument(skip(self, block_builder), err)] - async fn spawn_proposal( - &mut self, - proposal_id: ProposalId, - mut block_builder: Box, - abort_signal_sender: tokio::sync::oneshot::Sender<()>, - ) -> Result<(), GenerateProposalError> { - self.set_active_proposal(proposal_id).await?; - - info!("Starting generation of a new proposal with id {}.", proposal_id); - - let active_proposal = self.active_proposal.clone(); - let executed_proposals = self.executed_proposals.clone(); - - let join_handle = tokio::spawn( - async move { - let result = block_builder - .build_block() - .await - .map(ProposalOutput::from) - .map_err(|e| ProposalError::BlockBuilderError(Arc::new(e))); - - // The proposal is done, clear the active proposal. - // Keep the proposal result only if it is the same as the active proposal. - // The active proposal might have changed if this proposal was aborted. - let mut active_proposal = active_proposal.lock().await; - if *active_proposal == Some(proposal_id) { - active_proposal.take(); - executed_proposals.lock().await.insert(proposal_id, result); - } - } - .in_current_span(), - ); - - self.active_proposal_task = Some(ProposalTask { abort_signal_sender, join_handle }); - Ok(()) - } - - async fn take_proposal_result( - &mut self, - proposal_id: ProposalId, - ) -> Option> { - self.executed_proposals.lock().await.remove(&proposal_id) - } - - async fn get_active_proposal(&self) -> Option { - *self.active_proposal.lock().await - } - - async fn get_completed_proposals( - &self, - ) -> Arc>>> { - self.executed_proposals.clone() - } - - // Awaits the active proposal. - // Returns true if there was an active proposal, and false otherwise. - async fn await_active_proposal(&mut self) -> bool { - if let Some(proposal_task) = self.active_proposal_task.take() { - proposal_task.join_handle.await.ok(); - return true; - } - false - } - - // Aborts the proposal with the given ID, if active. - // Should be used in validate flow, if the consensus decides to abort the proposal. - async fn abort_proposal(&mut self, proposal_id: ProposalId) { - if *self.active_proposal.lock().await == Some(proposal_id) { - self.abort_active_proposal().await; - self.executed_proposals.lock().await.insert(proposal_id, Err(ProposalError::Aborted)); - } - } - - async fn reset(&mut self) { - self.abort_active_proposal().await; - self.executed_proposals.lock().await.clear(); - } -} - -impl ProposalManager { - pub fn new() -> Self { - Self { - active_proposal: Arc::new(Mutex::new(None)), - active_proposal_task: None, - executed_proposals: Arc::new(Mutex::new(HashMap::new())), - } - } - - // Sets a new active proposal task. - // Fails if either there is no active height, there is another proposal being generated, or a - // proposal with the same ID already exists. - async fn set_active_proposal( - &mut self, - proposal_id: ProposalId, - ) -> Result<(), GenerateProposalError> { - if self.executed_proposals.lock().await.contains_key(&proposal_id) { - return Err(GenerateProposalError::ProposalAlreadyExists { proposal_id }); - } - - let mut active_proposal = self.active_proposal.lock().await; - if let Some(current_generating_proposal_id) = *active_proposal { - return Err(GenerateProposalError::AlreadyGeneratingProposal { - current_generating_proposal_id, - new_proposal_id: proposal_id, - }); - } - - debug!("Set proposal {} as the one being generated.", proposal_id); - *active_proposal = Some(proposal_id); - Ok(()) - } - - // Ends the current active proposal. - // This call is non-blocking. - async fn abort_active_proposal(&mut self) { - self.active_proposal.lock().await.take(); - if let Some(proposal_task) = self.active_proposal_task.take() { - proposal_task.abort_signal_sender.send(()).ok(); - } - } -} - -impl From for ProposalOutput { - fn from(artifacts: BlockExecutionArtifacts) -> Self { - let commitment_state_diff = artifacts.commitment_state_diff; - let nonces = HashMap::from_iter( - commitment_state_diff - .address_to_nonce - .iter() - .map(|(address, nonce)| (*address, *nonce)), - ); - - // TODO: Get these from the transactions. - let deployed_contracts = IndexMap::new(); - let declared_classes = IndexMap::new(); - let state_diff = ThinStateDiff { - deployed_contracts, - storage_diffs: commitment_state_diff.storage_updates, - declared_classes, - nonces: commitment_state_diff.address_to_nonce, - // TODO: Remove this when the structure of storage diffs changes. - deprecated_declared_classes: Vec::new(), - replaced_classes: IndexMap::new(), - }; - let commitment = - ProposalCommitment { state_diff_commitment: calculate_state_diff_hash(&state_diff) }; - let tx_hashes = HashSet::from_iter(artifacts.execution_infos.keys().copied()); - - Self { state_diff, commitment, tx_hashes, nonces } - } -} diff --git a/crates/starknet_batcher/src/proposal_manager_test.rs b/crates/starknet_batcher/src/proposal_manager_test.rs deleted file mode 100644 index 2f92a8492be..00000000000 --- a/crates/starknet_batcher/src/proposal_manager_test.rs +++ /dev/null @@ -1,160 +0,0 @@ -use assert_matches::assert_matches; -use rstest::{fixture, rstest}; -use starknet_api::executable_transaction::Transaction; -use starknet_batcher_types::batcher_types::ProposalId; - -use crate::block_builder::{BlockBuilderTrait, BlockExecutionArtifacts, MockBlockBuilderTrait}; -use crate::proposal_manager::{ - GenerateProposalError, - ProposalError, - ProposalManager, - ProposalManagerTrait, - ProposalOutput, -}; - -const BLOCK_GENERATION_TIMEOUT: tokio::time::Duration = tokio::time::Duration::from_secs(1); - -#[fixture] -fn output_streaming() -> ( - tokio::sync::mpsc::UnboundedSender, - tokio::sync::mpsc::UnboundedReceiver, -) { - tokio::sync::mpsc::unbounded_channel() -} - -#[fixture] -fn proposal_manager() -> ProposalManager { - ProposalManager::new() -} - -fn mock_build_block() -> Box { - let mut mock_block_builder = MockBlockBuilderTrait::new(); - mock_block_builder - .expect_build_block() - .times(1) - .return_once(move || Ok(BlockExecutionArtifacts::create_for_testing())); - Box::new(mock_block_builder) -} - -// This function simulates a long build block operation. This is required for a test that -// tries to run other operations while a block is being built. -fn mock_long_build_block() -> Box { - let mut mock_block_builder = MockBlockBuilderTrait::new(); - mock_block_builder.expect_build_block().times(1).return_once(move || { - std::thread::sleep(BLOCK_GENERATION_TIMEOUT * 10); - Ok(BlockExecutionArtifacts::create_for_testing()) - }); - Box::new(mock_block_builder) -} - -async fn spawn_proposal_non_blocking( - proposal_manager: &mut ProposalManager, - proposal_id: ProposalId, - block_builder: Box, -) -> Result<(), GenerateProposalError> { - let (abort_sender, _rec) = tokio::sync::oneshot::channel(); - proposal_manager.spawn_proposal(proposal_id, block_builder, abort_sender).await -} - -async fn spawn_proposal( - proposal_manager: &mut ProposalManager, - proposal_id: ProposalId, - block_builder: Box, -) { - spawn_proposal_non_blocking(proposal_manager, proposal_id, block_builder).await.unwrap(); - assert!(proposal_manager.await_active_proposal().await); -} - -#[rstest] -#[tokio::test] -async fn spawn_proposal_success(mut proposal_manager: ProposalManager) { - spawn_proposal(&mut proposal_manager, ProposalId(0), mock_build_block()).await; - - proposal_manager.take_proposal_result(ProposalId(0)).await.unwrap().unwrap(); -} - -#[rstest] -#[tokio::test] -async fn consecutive_proposal_generations_success(mut proposal_manager: ProposalManager) { - // Build and validate multiple proposals consecutively (awaiting on them to - // make sure they finished successfully). - spawn_proposal(&mut proposal_manager, ProposalId(0), mock_build_block()).await; - spawn_proposal(&mut proposal_manager, ProposalId(1), mock_build_block()).await; -} - -// This test checks that trying to generate a proposal while another one is being generated will -// fail. First the test will generate a new proposal that takes a very long time, and during -// that time it will send another build proposal request. -#[rstest] -#[tokio::test] -async fn multiple_proposals_generation_fail(mut proposal_manager: ProposalManager) { - // Build a proposal that will take a very long time to finish. - spawn_proposal_non_blocking(&mut proposal_manager, ProposalId(0), mock_long_build_block()) - .await - .unwrap(); - - // Try to generate another proposal while the first one is still being generated. - let mut block_builder = MockBlockBuilderTrait::new(); - block_builder.expect_build_block().never(); - let another_generate_request = - spawn_proposal_non_blocking(&mut proposal_manager, ProposalId(1), Box::new(block_builder)) - .await; - - assert_matches!( - another_generate_request, - Err(GenerateProposalError::AlreadyGeneratingProposal { - current_generating_proposal_id, - new_proposal_id - }) if current_generating_proposal_id == ProposalId(0) && new_proposal_id == ProposalId(1) - ); -} - -#[rstest] -#[tokio::test] -async fn take_proposal_result_no_active_proposal(mut proposal_manager: ProposalManager) { - spawn_proposal(&mut proposal_manager, ProposalId(0), mock_build_block()).await; - - let expected_proposal_output = - ProposalOutput::from(BlockExecutionArtifacts::create_for_testing()); - assert_eq!( - proposal_manager.take_proposal_result(ProposalId(0)).await.unwrap().unwrap(), - expected_proposal_output - ); - assert_matches!(proposal_manager.take_proposal_result(ProposalId(0)).await, None); -} - -#[rstest] -#[tokio::test] -async fn abort_active_proposal(mut proposal_manager: ProposalManager) { - spawn_proposal_non_blocking(&mut proposal_manager, ProposalId(0), mock_long_build_block()) - .await - .unwrap(); - - proposal_manager.abort_proposal(ProposalId(0)).await; - - assert_matches!( - proposal_manager.take_proposal_result(ProposalId(0)).await, - Some(Err(ProposalError::Aborted)) - ); - - // Make sure there is no active proposal. - assert!(!proposal_manager.await_active_proposal().await); -} - -#[rstest] -#[tokio::test] -async fn reset(mut proposal_manager: ProposalManager) { - // Create 2 proposals, one will remain active. - spawn_proposal(&mut proposal_manager, ProposalId(0), mock_build_block()).await; - spawn_proposal_non_blocking(&mut proposal_manager, ProposalId(1), mock_long_build_block()) - .await - .unwrap(); - - proposal_manager.reset().await; - - // Make sure executed proposals are deleted. - assert_matches!(proposal_manager.take_proposal_result(ProposalId(0)).await, None); - - // Make sure there is no active proposal. - assert!(!proposal_manager.await_active_proposal().await); -} diff --git a/crates/starknet_batcher/src/test_utils.rs b/crates/starknet_batcher/src/test_utils.rs index 73f251de039..430a0dec7a6 100644 --- a/crates/starknet_batcher/src/test_utils.rs +++ b/crates/starknet_batcher/src/test_utils.rs @@ -1,14 +1,62 @@ use std::ops::Range; +use async_trait::async_trait; use blockifier::blockifier::transaction_executor::VisitedSegmentsMapping; use blockifier::bouncer::BouncerWeights; use blockifier::state::cached_state::CommitmentStateDiff; use indexmap::IndexMap; use starknet_api::executable_transaction::Transaction; use starknet_api::test_utils::invoke::{executable_invoke_tx, InvokeTxArgs}; -use starknet_api::tx_hash; +use starknet_api::{class_hash, contract_address, nonce, tx_hash}; +use tokio::sync::mpsc::UnboundedSender; -use crate::block_builder::BlockExecutionArtifacts; +use crate::block_builder::{BlockBuilderResult, BlockBuilderTrait, BlockExecutionArtifacts}; +use crate::transaction_provider::{NextTxs, TransactionProvider}; + +// A fake block builder for validate flow, that fetches transactions from the transaction provider +// until it is exhausted. +// This ensures the block builder (and specifically the tx_provider) is not dropped before all +// transactions are processed. Otherwise, the batcher would fail during tests when attempting to +// send transactions to it. +pub(crate) struct FakeValidateBlockBuilder { + pub tx_provider: Box, + pub build_block_result: Option>, +} + +#[async_trait] +impl BlockBuilderTrait for FakeValidateBlockBuilder { + async fn build_block(&mut self) -> BlockBuilderResult { + // build_block should be called only once, so we can safely take the result. + let build_block_result = self.build_block_result.take().unwrap(); + + if build_block_result.is_ok() { + while self.tx_provider.get_txs(1).await.is_ok_and(|v| v != NextTxs::End) { + tokio::task::yield_now().await; + } + } + build_block_result + } +} + +// A fake block builder for propose flow, that sends the given transactions to the output content +// sender. +pub(crate) struct FakeProposeBlockBuilder { + pub output_content_sender: UnboundedSender, + pub output_txs: Vec, + pub build_block_result: Option>, +} + +#[async_trait] +impl BlockBuilderTrait for FakeProposeBlockBuilder { + async fn build_block(&mut self) -> BlockBuilderResult { + for tx in &self.output_txs { + self.output_content_sender.send(tx.clone()).unwrap(); + } + + // build_block should be called only once, so we can safely take the result. + self.build_block_result.take().unwrap() + } +} pub fn test_txs(tx_hash_range: Range) -> Vec { tx_hash_range @@ -23,9 +71,18 @@ pub fn test_txs(tx_hash_range: Range) -> Vec { impl BlockExecutionArtifacts { pub fn create_for_testing() -> Self { + // Use a non-empty commitment_state_diff to make the tests more realistic. Self { execution_infos: IndexMap::default(), - commitment_state_diff: CommitmentStateDiff::default(), + commitment_state_diff: CommitmentStateDiff { + address_to_class_hash: IndexMap::from_iter([( + contract_address!("0x7"), + class_hash!("0x11111111"), + )]), + storage_updates: IndexMap::new(), + class_hash_to_compiled_class_hash: IndexMap::new(), + address_to_nonce: IndexMap::from_iter([(contract_address!("0x7"), nonce!(1_u64))]), + }, visited_segments_mapping: VisitedSegmentsMapping::default(), bouncer_weights: BouncerWeights::empty(), }