From 72ce408bef996a727cd6547a363cba9b482b39a0 Mon Sep 17 00:00:00 2001 From: Dafna Matsry Date: Thu, 5 Dec 2024 13:19:20 +0200 Subject: [PATCH] refactor(starknet_batcher): delete the proposal manager --- crates/starknet_batcher/src/batcher.rs | 214 ++++++++----- crates/starknet_batcher/src/batcher_test.rs | 317 +++++--------------- crates/starknet_batcher/src/lib.rs | 3 - crates/starknet_batcher/src/test_utils.rs | 13 +- 4 files changed, 237 insertions(+), 310 deletions(-) diff --git a/crates/starknet_batcher/src/batcher.rs b/crates/starknet_batcher/src/batcher.rs index 007457a21a6..bab2b207465 100644 --- a/crates/starknet_batcher/src/batcher.rs +++ b/crates/starknet_batcher/src/batcher.rs @@ -29,16 +29,18 @@ use starknet_batcher_types::errors::BatcherError; use starknet_mempool_types::communication::SharedMempoolClient; use starknet_mempool_types::mempool_types::CommitBlockArgs; use starknet_sequencer_infra::component_definitions::ComponentStarter; -use tracing::{debug, error, info, instrument, trace}; +use tokio::sync::Mutex; +use tracing::{debug, error, info, instrument, trace, Instrument}; use crate::block_builder::{ + BlockBuilderError, BlockBuilderExecutionParams, BlockBuilderFactory, BlockBuilderFactoryTrait, + BlockBuilderTrait, BlockMetadata, }; use crate::config::BatcherConfig; -use crate::proposal_manager::{GenerateProposalError, ProposalManager, ProposalManagerTrait}; use crate::transaction_provider::{ DummyL1ProviderClient, ProposeTransactionProvider, @@ -50,6 +52,7 @@ use crate::utils::{ verify_block_input, ProposalOutput, ProposalResult, + ProposalTask, }; type OutputStreamReceiver = tokio::sync::mpsc::UnboundedReceiver; @@ -61,11 +64,30 @@ pub struct Batcher { pub storage_writer: Box, pub mempool_client: SharedMempoolClient, + // Used to create block builders. + // Using the factory pattern to allow for easier testing. + block_builder_factory: Box, + + // The height that the batcher is currently working on. + // All proposals are considered to be at this height. active_height: Option, - proposal_manager: Box, - block_builder_factory: Box, + // The block proposal that is currently being built, if any. + // At any given time, there can be only one proposal being actively executed (either proposed + // or validated). + active_proposal: Arc>>, + active_proposal_task: Option, + + // Holds all the proposals that completed execution in the current height. + executed_proposals: Arc>>>, + + // The propose blocks transaction streams, used to stream out the proposal transactions. + // Each stream is kept until all the transactions are streamed out, or a new height is started. propose_tx_streams: HashMap, + + // The validate blocks transaction streams, used to stream in the transactions to validate. + // Each stream is kept until SendProposalContent::Finish/Abort is received, or a new height is + // started. validate_tx_streams: HashMap, } @@ -76,16 +98,17 @@ impl Batcher { storage_writer: Box, mempool_client: SharedMempoolClient, block_builder_factory: Box, - proposal_manager: Box, ) -> Self { Self { config: config.clone(), storage_reader, storage_writer, mempool_client, - active_height: None, block_builder_factory, - proposal_manager, + active_height: None, + active_proposal: Arc::new(Mutex::new(None)), + active_proposal_task: None, + executed_proposals: Arc::new(Mutex::new(HashMap::new())), propose_tx_streams: HashMap::new(), validate_tx_streams: HashMap::new(), } @@ -113,7 +136,8 @@ impl Batcher { } // Clear all the proposals from the previous height. - self.proposal_manager.reset().await; + self.abort_active_proposal().await; + self.executed_proposals.lock().await.clear(); self.propose_tx_streams.clear(); self.validate_tx_streams.clear(); @@ -135,6 +159,8 @@ impl Batcher { propose_block_input.retrospective_block_hash, )?; + self.set_active_proposal(propose_block_input.proposal_id).await?; + let tx_provider = ProposeTransactionProvider::new( self.mempool_client.clone(), // TODO: use a real L1 provider client. @@ -161,8 +187,7 @@ impl Batcher { ) .map_err(|_| BatcherError::InternalError)?; - self.proposal_manager - .spawn_proposal(propose_block_input.proposal_id, block_builder, abort_signal_sender) + self.spawn_proposal(propose_block_input.proposal_id, block_builder, abort_signal_sender) .await?; self.propose_tx_streams.insert(propose_block_input.proposal_id, output_tx_receiver); @@ -181,6 +206,8 @@ impl Batcher { validate_block_input.retrospective_block_hash, )?; + self.set_active_proposal(validate_block_input.proposal_id).await?; + // A channel to send the transactions to include in the block being validated. let (input_tx_sender, input_tx_receiver) = tokio::sync::mpsc::channel(self.config.input_stream_content_buffer_size); @@ -207,8 +234,7 @@ impl Batcher { ) .map_err(|_| BatcherError::InternalError)?; - self.proposal_manager - .spawn_proposal(validate_block_input.proposal_id, block_builder, abort_signal_sender) + self.spawn_proposal(validate_block_input.proposal_id, block_builder, abort_signal_sender) .await?; self.validate_tx_streams.insert(validate_block_input.proposal_id, input_tx_sender); @@ -258,6 +284,8 @@ impl Batcher { let proposal_result = self.get_completed_proposal_result(proposal_id).await.expect("Proposal should exist."); match proposal_result { + // TODO(dafna): at this point the proposal result must be an error, since it finsisehd + // earlier than expected. Consider panicking instead of returning an error. Ok(_) => Err(BatcherError::ProposalAlreadyFinished { proposal_id }), Err(err) => Ok(SendProposalContentResponse { response: proposal_status_from(err)? }), } @@ -269,9 +297,11 @@ impl Batcher { ) -> BatcherResult { debug!("Send proposal content done for {}", proposal_id); - self.close_input_transaction_stream(proposal_id)?; + self.validate_tx_streams.remove(&proposal_id); if self.is_active(proposal_id).await { - self.proposal_manager.await_active_proposal().await; + if let Some(proposal_task) = self.active_proposal_task.take() { + proposal_task.join_handle.await.ok(); + } } let proposal_result = @@ -287,18 +317,17 @@ impl Batcher { &mut self, proposal_id: ProposalId, ) -> BatcherResult { - self.proposal_manager.abort_proposal(proposal_id).await; - self.close_input_transaction_stream(proposal_id)?; + if self.is_active(proposal_id).await { + self.abort_active_proposal().await; + self.executed_proposals + .lock() + .await + .insert(proposal_id, Err(Arc::new(BlockBuilderError::Aborted))); + } + self.validate_tx_streams.remove(&proposal_id); Ok(SendProposalContentResponse { response: ProposalStatus::Aborted }) } - fn close_input_transaction_stream(&mut self, proposal_id: ProposalId) -> BatcherResult<()> { - self.validate_tx_streams - .remove(&proposal_id) - .ok_or(BatcherError::ProposalNotFound { proposal_id })?; - Ok(()) - } - #[instrument(skip(self), err)] pub async fn get_height(&mut self) -> BatcherResult { let height = self.storage_reader.height().map_err(|err| { @@ -345,40 +374,97 @@ impl Batcher { #[instrument(skip(self), err)] pub async fn decision_reached(&mut self, input: DecisionReachedInput) -> BatcherResult<()> { + let height = self.active_height.ok_or(BatcherError::NoActiveHeight)?; + let proposal_id = input.proposal_id; - let proposal_output = self - .proposal_manager - .take_proposal_result(proposal_id) - .await - .ok_or(BatcherError::ExecutedProposalNotFound { proposal_id })? - .map_err(|_| BatcherError::InternalError)?; + let proposal_result = self.executed_proposals.lock().await.remove(&proposal_id); let ProposalOutput { state_diff, nonces: address_to_nonce, tx_hashes, .. } = - proposal_output; - // TODO: Keep the height from start_height or get it from the input. - let height = self.storage_reader.height().map_err(|err| { - error!("Failed to get height from storage: {}", err); - BatcherError::InternalError - })?; + proposal_result + .ok_or(BatcherError::ExecutedProposalNotFound { proposal_id })? + .map_err(|_| BatcherError::InternalError)?; + info!( "Committing proposal {} at height {} and notifying mempool of the block.", proposal_id, height ); trace!("Transactions: {:#?}, State diff: {:#?}.", tx_hashes, state_diff); + + // Commit the proposal to the storage and notify the mempool. self.storage_writer.commit_proposal(height, state_diff).map_err(|err| { error!("Failed to commit proposal to storage: {}", err); BatcherError::InternalError })?; - if let Err(mempool_err) = - self.mempool_client.commit_block(CommitBlockArgs { address_to_nonce, tx_hashes }).await - { + let mempool_result = + self.mempool_client.commit_block(CommitBlockArgs { address_to_nonce, tx_hashes }).await; + + if let Err(mempool_err) = mempool_result { error!("Failed to commit block to mempool: {}", mempool_err); // TODO: Should we rollback the state diff and return an error? - } + }; + Ok(()) } async fn is_active(&self, proposal_id: ProposalId) -> bool { - self.proposal_manager.get_active_proposal().await == Some(proposal_id) + *self.active_proposal.lock().await == Some(proposal_id) + } + + // Sets a new active proposal task. + // Fails if there is another proposal being currently generated, or a proposal with the same ID + // already exists. + async fn set_active_proposal(&mut self, proposal_id: ProposalId) -> BatcherResult<()> { + if self.executed_proposals.lock().await.contains_key(&proposal_id) { + return Err(BatcherError::ProposalAlreadyExists { proposal_id }); + } + + let mut active_proposal = self.active_proposal.lock().await; + if let Some(active_proposal_id) = *active_proposal { + return Err(BatcherError::ServerBusy { + active_proposal_id, + new_proposal_id: proposal_id, + }); + } + + debug!("Set proposal {} as the one being generated.", proposal_id); + *active_proposal = Some(proposal_id); + Ok(()) + } + + // Starts a new block proposal generation task for the given proposal_id. + // Uses the given block_builder to generate the proposal. + async fn spawn_proposal( + &mut self, + proposal_id: ProposalId, + mut block_builder: Box, + abort_signal_sender: tokio::sync::oneshot::Sender<()>, + ) -> BatcherResult<()> { + info!("Starting generation of a new proposal with id {}.", proposal_id); + + let active_proposal = self.active_proposal.clone(); + let executed_proposals = self.executed_proposals.clone(); + + let join_handle = tokio::spawn( + async move { + let result = block_builder + .build_block() + .await + .map(ProposalOutput::from) + .map_err(|e| Arc::new(e)); + + // The proposal is done, clear the active proposal. + // Keep the proposal result only if it is the same as the active proposal. + // The active proposal might have changed if this proposal was aborted. + let mut active_proposal = active_proposal.lock().await; + if *active_proposal == Some(proposal_id) { + active_proposal.take(); + executed_proposals.lock().await.insert(proposal_id, result); + } + } + .in_current_span(), + ); + + self.active_proposal_task = Some(ProposalTask { abort_signal_sender, join_handle }); + Ok(()) } // Returns a completed proposal result, either its commitment or an error if the proposal @@ -387,8 +473,7 @@ impl Batcher { &self, proposal_id: ProposalId, ) -> Option> { - let completed_proposals = self.proposal_manager.get_completed_proposals().await; - let guard = completed_proposals.lock().await; + let guard = self.executed_proposals.lock().await; let proposal_result = guard.get(&proposal_id); match proposal_result { @@ -397,6 +482,22 @@ impl Batcher { None => None, } } + + // Ends the current active proposal. + // This call is non-blocking. + async fn abort_active_proposal(&mut self) { + self.active_proposal.lock().await.take(); + if let Some(proposal_task) = self.active_proposal_task.take() { + proposal_task.abort_signal_sender.send(()).ok(); + } + } + + #[cfg(test)] + pub async fn await_active_proposal(&mut self) { + if let Some(proposal_task) = self.active_proposal_task.take() { + proposal_task.join_handle.await.ok(); + } + } } pub fn create_batcher(config: BatcherConfig, mempool_client: SharedMempoolClient) -> Batcher { @@ -410,15 +511,7 @@ pub fn create_batcher(config: BatcherConfig, mempool_client: SharedMempoolClient }); let storage_reader = Arc::new(storage_reader); let storage_writer = Box::new(storage_writer); - let proposal_manager = Box::new(ProposalManager::new()); - Batcher::new( - config, - storage_reader, - storage_writer, - mempool_client, - block_builder_factory, - proposal_manager, - ) + Batcher::new(config, storage_reader, storage_writer, mempool_client, block_builder_factory) } #[cfg_attr(test, automock)] @@ -453,23 +546,4 @@ impl BatcherStorageWriterTrait for papyrus_storage::StorageWriter { } } -impl From for BatcherError { - fn from(err: GenerateProposalError) -> Self { - match err { - GenerateProposalError::AlreadyGeneratingProposal { - current_generating_proposal_id, - new_proposal_id, - } => BatcherError::ServerBusy { - active_proposal_id: current_generating_proposal_id, - new_proposal_id, - }, - GenerateProposalError::BlockBuilderError(..) => BatcherError::InternalError, - GenerateProposalError::NoActiveHeight => BatcherError::NoActiveHeight, - GenerateProposalError::ProposalAlreadyExists { proposal_id } => { - BatcherError::ProposalAlreadyExists { proposal_id } - } - } - } -} - impl ComponentStarter for Batcher {} diff --git a/crates/starknet_batcher/src/batcher_test.rs b/crates/starknet_batcher/src/batcher_test.rs index d60068fc0a9..1e5894d4f1b 100644 --- a/crates/starknet_batcher/src/batcher_test.rs +++ b/crates/starknet_batcher/src/batcher_test.rs @@ -1,23 +1,13 @@ -use std::collections::{HashMap, HashSet}; use std::sync::Arc; use assert_matches::assert_matches; -use async_trait::async_trait; use blockifier::abi::constants; use blockifier::test_utils::struct_impls::BlockInfoExt; use chrono::Utc; -use futures::future::BoxFuture; -use futures::FutureExt; -use mockall::automock; -use mockall::predicate::{always, eq}; +use mockall::predicate::eq; use rstest::rstest; use starknet_api::block::{BlockInfo, BlockNumber}; -use starknet_api::core::{ContractAddress, Nonce, StateDiffCommitment}; use starknet_api::executable_transaction::Transaction; -use starknet_api::hash::PoseidonHash; -use starknet_api::state::ThinStateDiff; -use starknet_api::transaction::TransactionHash; -use starknet_api::{contract_address, felt, nonce, tx_hash}; use starknet_batcher_types::batcher_types::{ DecisionReachedInput, GetHeightResponse, @@ -37,40 +27,35 @@ use starknet_batcher_types::batcher_types::{ use starknet_batcher_types::errors::BatcherError; use starknet_mempool_types::communication::MockMempoolClient; use starknet_mempool_types::mempool_types::CommitBlockArgs; -use tokio::sync::Mutex; use crate::batcher::{Batcher, MockBatcherStorageReaderTrait, MockBatcherStorageWriterTrait}; use crate::block_builder::{ AbortSignalSender, BlockBuilderError, - BlockBuilderTrait, + BlockBuilderResult, + BlockExecutionArtifacts, FailOnErrorCause, MockBlockBuilderFactoryTrait, MockBlockBuilderTrait, }; use crate::config::BatcherConfig; -use crate::proposal_manager::{GenerateProposalError, ProposalManagerTrait}; use crate::test_utils::test_txs; use crate::transaction_provider::NextTxs; -use crate::utils::{ProposalOutput, ProposalResult}; +use crate::utils::ProposalOutput; const INITIAL_HEIGHT: BlockNumber = BlockNumber(3); const STREAMING_CHUNK_SIZE: usize = 3; const BLOCK_GENERATION_TIMEOUT: tokio::time::Duration = tokio::time::Duration::from_secs(1); const PROPOSAL_ID: ProposalId = ProposalId(0); +const BUILD_BLOCK_FAIL_ON_ERROR: BlockBuilderError = + BlockBuilderError::FailOnError(FailOnErrorCause::BlockFull); fn initial_block_info() -> BlockInfo { BlockInfo { block_number: INITIAL_HEIGHT, ..BlockInfo::create_for_testing() } } fn proposal_commitment() -> ProposalCommitment { - ProposalCommitment { - state_diff_commitment: StateDiffCommitment(PoseidonHash(felt!(u128::try_from(7).unwrap()))), - } -} - -fn proposal_output() -> ProposalOutput { - ProposalOutput { commitment: proposal_commitment(), ..Default::default() } + ProposalOutput::from(BlockExecutionArtifacts::create_for_testing()).commitment } fn deadline() -> chrono::DateTime { @@ -95,15 +80,10 @@ fn validate_block_input() -> ValidateBlockInput { } } -fn invalid_proposal_result() -> ProposalResult { - Err(Arc::new(BlockBuilderError::FailOnError(FailOnErrorCause::BlockFull))) -} - struct MockDependencies { storage_reader: MockBatcherStorageReaderTrait, storage_writer: MockBatcherStorageWriterTrait, mempool_client: MockMempoolClient, - proposal_manager: MockProposalManagerTraitWrapper, block_builder_factory: MockBlockBuilderFactoryTrait, } @@ -115,7 +95,6 @@ impl Default for MockDependencies { storage_reader, storage_writer: MockBatcherStorageWriterTrait::new(), mempool_client: MockMempoolClient::new(), - proposal_manager: MockProposalManagerTraitWrapper::new(), block_builder_factory: MockBlockBuilderFactoryTrait::new(), } } @@ -128,7 +107,6 @@ fn create_batcher(mock_dependencies: MockDependencies) -> Batcher { Box::new(mock_dependencies.storage_writer), Arc::new(mock_dependencies.mempool_client), Box::new(mock_dependencies.block_builder_factory), - Box::new(mock_dependencies.proposal_manager), ) } @@ -136,19 +114,25 @@ fn abort_signal_sender() -> AbortSignalSender { tokio::sync::oneshot::channel().0 } -fn mock_create_builder_for_validate_block() -> MockBlockBuilderFactoryTrait { +fn mock_create_builder_for_validate_block( + build_block_result: BlockBuilderResult, +) -> MockBlockBuilderFactoryTrait { let mut block_builder_factory = MockBlockBuilderFactoryTrait::new(); block_builder_factory.expect_create_block_builder().times(1).return_once( |_, _, mut tx_provider, _| { - // Spawn a task to keep tx_provider alive until all transactions are read. - // Without this, the provider would be dropped, causing the batcher to fail when sending - // transactions to it during the test. - tokio::spawn(async move { - while tx_provider.get_txs(1).await.is_ok_and(|v| v != NextTxs::End) { - tokio::task::yield_now().await; - } + let mut block_builder = MockBlockBuilderTrait::new(); + block_builder.expect_build_block().times(1).return_once(move || { + // Spawn a task to keep tx_provider alive until all transactions are read. + // Without this, the provider would be dropped, causing the batcher to fail when + // sending transactions to it during the test. + tokio::spawn(async move { + while tx_provider.get_txs(1).await.is_ok_and(|v| v != NextTxs::End) { + tokio::task::yield_now().await; + } + }); + build_block_result }); - Ok((Box::new(MockBlockBuilderTrait::new()), abort_signal_sender())) + Ok((Box::new(block_builder), abort_signal_sender())) }, ); block_builder_factory @@ -156,52 +140,32 @@ fn mock_create_builder_for_validate_block() -> MockBlockBuilderFactoryTrait { fn mock_create_builder_for_propose_block( output_txs: Vec, + build_block_result: BlockBuilderResult, ) -> MockBlockBuilderFactoryTrait { let mut block_builder_factory = MockBlockBuilderFactoryTrait::new(); block_builder_factory.expect_create_block_builder().times(1).return_once( |_, _, _, output_content_sender| { - // Simulate the streaming of the block builder output. - for tx in output_txs { - output_content_sender.as_ref().unwrap().send(tx).unwrap(); - } - Ok((Box::new(MockBlockBuilderTrait::new()), abort_signal_sender())) + let mut block_builder = MockBlockBuilderTrait::new(); + block_builder.expect_build_block().times(1).return_once(move || { + // Simulate the streaming of the block builder output. + for tx in output_txs { + output_content_sender.as_ref().unwrap().send(tx).unwrap(); + } + build_block_result + }); + Ok((Box::new(block_builder), abort_signal_sender())) }, ); block_builder_factory } -fn mock_start_proposal(proposal_manager: &mut MockProposalManagerTraitWrapper) { - proposal_manager.expect_wrap_reset().times(1).return_once(|| async {}.boxed()); - proposal_manager - .expect_wrap_spawn_proposal() - .times(1) - .with(eq(PROPOSAL_ID), always(), always()) - .return_once(|_, _, _| { async move { Ok(()) } }.boxed()); -} - -fn mock_completed_proposal( - proposal_manager: &mut MockProposalManagerTraitWrapper, - proposal_result: ProposalResult, -) { - proposal_manager.expect_wrap_get_completed_proposals().times(1).return_once(move || { - async move { Arc::new(Mutex::new(HashMap::from([(PROPOSAL_ID, proposal_result)]))) }.boxed() - }); -} - async fn batcher_with_validated_proposal( - proposal_result: ProposalResult, + build_block_result: BlockBuilderResult, ) -> Batcher { - let block_builder_factory = mock_create_builder_for_validate_block(); - let mut proposal_manager = MockProposalManagerTraitWrapper::new(); - mock_start_proposal(&mut proposal_manager); - mock_completed_proposal(&mut proposal_manager, proposal_result); - proposal_manager.expect_wrap_get_active_proposal().returning(|| async move { None }.boxed()); - - let mut batcher = create_batcher(MockDependencies { - proposal_manager, - block_builder_factory, - ..Default::default() - }); + let block_builder_factory = mock_create_builder_for_validate_block(build_block_result); + + let mut batcher = + create_batcher(MockDependencies { block_builder_factory, ..Default::default() }); batcher.start_height(StartHeightInput { height: INITIAL_HEIGHT }).await.unwrap(); @@ -210,27 +174,10 @@ async fn batcher_with_validated_proposal( batcher } -fn mock_proposal_manager_validate_flow() -> MockProposalManagerTraitWrapper { - let mut proposal_manager = MockProposalManagerTraitWrapper::new(); - mock_start_proposal(&mut proposal_manager); - proposal_manager - .expect_wrap_get_active_proposal() - .returning(|| async move { Some(PROPOSAL_ID) }.boxed()); - proposal_manager - .expect_wrap_await_active_proposal() - .times(1) - .returning(|| async move { true }.boxed()); - mock_completed_proposal(&mut proposal_manager, Ok(proposal_output())); - proposal_manager -} - #[rstest] #[tokio::test] async fn start_height_success() { - let mut proposal_manager = MockProposalManagerTraitWrapper::new(); - proposal_manager.expect_wrap_reset().times(1).return_once(|| async {}.boxed()); - - let mut batcher = create_batcher(MockDependencies { proposal_manager, ..Default::default() }); + let mut batcher = create_batcher(MockDependencies::default()); assert_eq!(batcher.start_height(StartHeightInput { height: INITIAL_HEIGHT }).await, Ok(())); } @@ -251,20 +198,14 @@ async fn start_height_success() { )] #[tokio::test] async fn start_height_fail(#[case] height: BlockNumber, #[case] expected_error: BatcherError) { - let mut proposal_manager = MockProposalManagerTraitWrapper::new(); - proposal_manager.expect_wrap_reset().never(); - - let mut batcher = create_batcher(MockDependencies { proposal_manager, ..Default::default() }); + let mut batcher = create_batcher(MockDependencies::default()); assert_eq!(batcher.start_height(StartHeightInput { height }).await, Err(expected_error)); } #[rstest] #[tokio::test] async fn duplicate_start_height() { - let mut proposal_manager = MockProposalManagerTraitWrapper::new(); - proposal_manager.expect_wrap_reset().times(1).return_once(|| async {}.boxed()); - - let mut batcher = create_batcher(MockDependencies { proposal_manager, ..Default::default() }); + let mut batcher = create_batcher(MockDependencies::default()); let initial_height = StartHeightInput { height: INITIAL_HEIGHT }; assert_eq!(batcher.start_height(initial_height.clone()).await, Ok(())); @@ -274,8 +215,7 @@ async fn duplicate_start_height() { #[rstest] #[tokio::test] async fn no_active_height() { - let proposal_manager = MockProposalManagerTraitWrapper::new(); - let mut batcher = create_batcher(MockDependencies { proposal_manager, ..Default::default() }); + let mut batcher = create_batcher(MockDependencies::default()); // Calling `propose_block` and `validate_block` without starting a height should fail. @@ -289,13 +229,10 @@ async fn no_active_height() { #[rstest] #[tokio::test] async fn validate_block_full_flow() { - let block_builder_factory = mock_create_builder_for_validate_block(); - let proposal_manager = mock_proposal_manager_validate_flow(); - let mut batcher = create_batcher(MockDependencies { - proposal_manager, - block_builder_factory, - ..Default::default() - }); + let block_builder_factory = + mock_create_builder_for_validate_block(Ok(BlockExecutionArtifacts::create_for_testing())); + let mut batcher = + create_batcher(MockDependencies { block_builder_factory, ..Default::default() }); batcher.start_height(StartHeightInput { height: INITIAL_HEIGHT }).await.unwrap(); batcher.validate_block(validate_block_input()).await.unwrap(); @@ -323,8 +260,9 @@ async fn validate_block_full_flow() { #[rstest] #[tokio::test] async fn send_content_after_proposal_already_finished() { - let successful_proposal_result = Ok(proposal_output()); - let mut batcher = batcher_with_validated_proposal(successful_proposal_result).await; + let successful_build_block_result = Ok(BlockExecutionArtifacts::create_for_testing()); + let mut batcher = batcher_with_validated_proposal(successful_build_block_result).await; + batcher.await_active_proposal().await; // Send transactions after the proposal has finished. let send_proposal_input_txs = SendProposalContentInput { @@ -358,7 +296,8 @@ async fn send_content_to_unknown_proposal() { #[rstest] #[tokio::test] async fn send_txs_to_an_invalid_proposal() { - let mut batcher = batcher_with_validated_proposal(invalid_proposal_result()).await; + let mut batcher = batcher_with_validated_proposal(Err(BUILD_BLOCK_FAIL_ON_ERROR)).await; + batcher.await_active_proposal().await; let send_proposal_input_txs = SendProposalContentInput { proposal_id: PROPOSAL_ID, @@ -371,7 +310,7 @@ async fn send_txs_to_an_invalid_proposal() { #[rstest] #[tokio::test] async fn send_finish_to_an_invalid_proposal() { - let mut batcher = batcher_with_validated_proposal(invalid_proposal_result()).await; + let mut batcher = batcher_with_validated_proposal(Err(BUILD_BLOCK_FAIL_ON_ERROR)).await; let send_proposal_input_txs = SendProposalContentInput { proposal_id: PROPOSAL_ID, content: SendProposalContent::Finish }; @@ -386,16 +325,13 @@ async fn propose_block_full_flow() { let expected_streamed_txs = test_txs(0..STREAMING_CHUNK_SIZE * 2 + 1); let txs_to_stream = expected_streamed_txs.clone(); - let block_builder_factory = mock_create_builder_for_propose_block(txs_to_stream); - let mut proposal_manager = MockProposalManagerTraitWrapper::new(); - mock_start_proposal(&mut proposal_manager); - mock_completed_proposal(&mut proposal_manager, Ok(proposal_output())); + let block_builder_factory = mock_create_builder_for_propose_block( + txs_to_stream, + Ok(BlockExecutionArtifacts::create_for_testing()), + ); - let mut batcher = create_batcher(MockDependencies { - proposal_manager, - block_builder_factory, - ..Default::default() - }); + let mut batcher = + create_batcher(MockDependencies { block_builder_factory, ..Default::default() }); batcher.start_height(StartHeightInput { height: INITIAL_HEIGHT }).await.unwrap(); batcher.propose_block(propose_block_input()).await.unwrap(); @@ -443,16 +379,12 @@ async fn get_height() { #[rstest] #[tokio::test] async fn propose_block_without_retrospective_block_hash() { - let mut proposal_manager = MockProposalManagerTraitWrapper::new(); - proposal_manager.expect_wrap_reset().times(1).return_once(|| async {}.boxed()); - let mut storage_reader = MockBatcherStorageReaderTrait::new(); storage_reader .expect_height() .returning(|| Ok(BlockNumber(constants::STORED_BLOCK_HASH_BUFFER))); - let mut batcher = - create_batcher(MockDependencies { proposal_manager, storage_reader, ..Default::default() }); + let mut batcher = create_batcher(MockDependencies { storage_reader, ..Default::default() }); batcher .start_height(StartHeightInput { height: BlockNumber(constants::STORED_BLOCK_HASH_BUFFER) }) @@ -466,10 +398,7 @@ async fn propose_block_without_retrospective_block_hash() { #[rstest] #[tokio::test] async fn get_content_from_unknown_proposal() { - let mut proposal_manager = MockProposalManagerTraitWrapper::new(); - proposal_manager.expect_wrap_get_completed_proposals().times(0); - - let mut batcher = create_batcher(MockDependencies { proposal_manager, ..Default::default() }); + let mut batcher = create_batcher(MockDependencies::default()); let get_proposal_content_input = GetProposalContentInput { proposal_id: PROPOSAL_ID }; let result = batcher.get_proposal_content(get_proposal_content_input).await; @@ -480,40 +409,41 @@ async fn get_content_from_unknown_proposal() { #[tokio::test] async fn decision_reached() { let mut mock_dependencies = MockDependencies::default(); - - mock_dependencies - .proposal_manager - .expect_wrap_take_proposal_result() - .times(1) - .with(eq(PROPOSAL_ID)) - .return_once(move |_| { - async move { - Some(Ok(ProposalOutput { - state_diff: ThinStateDiff::default(), - commitment: ProposalCommitment::default(), - tx_hashes: test_tx_hashes(), - nonces: test_contract_nonces(), - })) - } - .boxed() - }); + let expected_proposal_output = + ProposalOutput::from(BlockExecutionArtifacts::create_for_testing()); mock_dependencies .mempool_client .expect_commit_block() .with(eq(CommitBlockArgs { - address_to_nonce: test_contract_nonces(), - tx_hashes: test_tx_hashes(), + address_to_nonce: expected_proposal_output.nonces, + tx_hashes: expected_proposal_output.tx_hashes, })) .returning(|_| Ok(())); mock_dependencies .storage_writer .expect_commit_proposal() - .with(eq(INITIAL_HEIGHT), eq(ThinStateDiff::default())) + .with(eq(INITIAL_HEIGHT), eq(expected_proposal_output.state_diff)) .returning(|_, _| Ok(())); + mock_dependencies.block_builder_factory = mock_create_builder_for_propose_block( + vec![], + Ok(BlockExecutionArtifacts::create_for_testing()), + ); + let mut batcher = create_batcher(mock_dependencies); + batcher.start_height(StartHeightInput { height: INITIAL_HEIGHT }).await.unwrap(); + batcher + .propose_block(ProposeBlockInput { + proposal_id: PROPOSAL_ID, + retrospective_block_hash: None, + deadline: deadline(), + block_info: initial_block_info(), + }) + .await + .unwrap(); + batcher.await_active_proposal().await; batcher.decision_reached(DecisionReachedInput { proposal_id: PROPOSAL_ID }).await.unwrap(); } @@ -523,93 +453,10 @@ async fn decision_reached() { async fn decision_reached_no_executed_proposal() { let expected_error = BatcherError::ExecutedProposalNotFound { proposal_id: PROPOSAL_ID }; - let mut proposal_manager = MockProposalManagerTraitWrapper::new(); - proposal_manager - .expect_wrap_take_proposal_result() - .times(1) - .with(eq(PROPOSAL_ID)) - .return_once(|_| async move { None }.boxed()); + let mut batcher = create_batcher(MockDependencies::default()); + batcher.start_height(StartHeightInput { height: INITIAL_HEIGHT }).await.unwrap(); - let mut batcher = create_batcher(MockDependencies { proposal_manager, ..Default::default() }); let decision_reached_result = batcher.decision_reached(DecisionReachedInput { proposal_id: PROPOSAL_ID }).await; assert_eq!(decision_reached_result, Err(expected_error)); } - -// A wrapper trait to allow mocking the ProposalManagerTrait in tests. -#[automock] -trait ProposalManagerTraitWrapper: Send + Sync { - fn wrap_spawn_proposal( - &mut self, - proposal_id: ProposalId, - block_builder: Box, - abort_signal_sender: tokio::sync::oneshot::Sender<()>, - ) -> BoxFuture<'_, Result<(), GenerateProposalError>>; - - fn wrap_take_proposal_result( - &mut self, - proposal_id: ProposalId, - ) -> BoxFuture<'_, Option>>; - - fn wrap_get_active_proposal(&self) -> BoxFuture<'_, Option>; - - #[allow(clippy::type_complexity)] - fn wrap_get_completed_proposals( - &self, - ) -> BoxFuture<'_, Arc>>>>; - - fn wrap_await_active_proposal(&mut self) -> BoxFuture<'_, bool>; - - fn wrap_abort_proposal(&mut self, proposal_id: ProposalId) -> BoxFuture<'_, ()>; - - fn wrap_reset(&mut self) -> BoxFuture<'_, ()>; -} - -#[async_trait] -impl ProposalManagerTrait for T { - async fn spawn_proposal( - &mut self, - proposal_id: ProposalId, - block_builder: Box, - abort_signal_sender: tokio::sync::oneshot::Sender<()>, - ) -> Result<(), GenerateProposalError> { - self.wrap_spawn_proposal(proposal_id, block_builder, abort_signal_sender).await - } - - async fn take_proposal_result( - &mut self, - proposal_id: ProposalId, - ) -> Option> { - self.wrap_take_proposal_result(proposal_id).await - } - - async fn get_active_proposal(&self) -> Option { - self.wrap_get_active_proposal().await - } - - async fn get_completed_proposals( - &self, - ) -> Arc>>> { - self.wrap_get_completed_proposals().await - } - - async fn await_active_proposal(&mut self) -> bool { - self.wrap_await_active_proposal().await - } - - async fn abort_proposal(&mut self, proposal_id: ProposalId) { - self.wrap_abort_proposal(proposal_id).await - } - - async fn reset(&mut self) { - self.wrap_reset().await - } -} - -fn test_tx_hashes() -> HashSet { - (0..5u8).map(|i| tx_hash!(i + 12)).collect() -} - -fn test_contract_nonces() -> HashMap { - HashMap::from_iter((0..3u8).map(|i| (contract_address!(i + 33), nonce!(i + 9)))) -} diff --git a/crates/starknet_batcher/src/lib.rs b/crates/starknet_batcher/src/lib.rs index 199f3544695..a17a44409cd 100644 --- a/crates/starknet_batcher/src/lib.rs +++ b/crates/starknet_batcher/src/lib.rs @@ -7,9 +7,6 @@ mod block_builder_test; pub mod communication; pub mod config; pub mod fee_market; -mod proposal_manager; -#[cfg(test)] -mod proposal_manager_test; #[cfg(test)] mod test_utils; mod transaction_executor; diff --git a/crates/starknet_batcher/src/test_utils.rs b/crates/starknet_batcher/src/test_utils.rs index 73f251de039..31a9e67a77b 100644 --- a/crates/starknet_batcher/src/test_utils.rs +++ b/crates/starknet_batcher/src/test_utils.rs @@ -6,7 +6,7 @@ use blockifier::state::cached_state::CommitmentStateDiff; use indexmap::IndexMap; use starknet_api::executable_transaction::Transaction; use starknet_api::test_utils::invoke::{executable_invoke_tx, InvokeTxArgs}; -use starknet_api::tx_hash; +use starknet_api::{class_hash, contract_address, nonce, tx_hash}; use crate::block_builder::BlockExecutionArtifacts; @@ -23,9 +23,18 @@ pub fn test_txs(tx_hash_range: Range) -> Vec { impl BlockExecutionArtifacts { pub fn create_for_testing() -> Self { + // Use a non-empty commitment_state_diff to make the tests more realistic. Self { execution_infos: IndexMap::default(), - commitment_state_diff: CommitmentStateDiff::default(), + commitment_state_diff: CommitmentStateDiff { + address_to_class_hash: IndexMap::from_iter([( + contract_address!("0x7"), + class_hash!("0x11111111"), + )]), + storage_updates: IndexMap::new(), + class_hash_to_compiled_class_hash: IndexMap::new(), + address_to_nonce: IndexMap::from_iter([(contract_address!("0x7"), nonce!(1_u64))]), + }, visited_segments_mapping: VisitedSegmentsMapping::default(), bouncer_weights: BouncerWeights::empty(), }