diff --git a/crates/starknet_batcher/src/batcher.rs b/crates/starknet_batcher/src/batcher.rs index afc638b428..48768d403d 100644 --- a/crates/starknet_batcher/src/batcher.rs +++ b/crates/starknet_batcher/src/batcher.rs @@ -17,7 +17,7 @@ use starknet_batcher_types::batcher_types::{ GetProposalContentInput, GetProposalContentResponse, ProposalId, - ProposalStatus as ResponseProposalStatus, + ProposalStatus, SendProposalContent, SendProposalContentInput, SendProposalContentResponse, @@ -35,13 +35,17 @@ use crate::config::BatcherConfig; use crate::proposal_manager::{ GenerateProposalError, GetProposalResultError, + InternalProposalStatus, ProposalManager, ProposalManagerTrait, ProposalOutput, - ProposalStatus, StartHeightError, }; -use crate::transaction_provider::{DummyL1ProviderClient, ProposeTransactionProvider}; +use crate::transaction_provider::{ + DummyL1ProviderClient, + ProposeTransactionProvider, + ValidateTransactionProvider, +}; type OutputStreamReceiver = tokio::sync::mpsc::UnboundedReceiver; type InputStreamSender = tokio::sync::mpsc::Sender; @@ -89,7 +93,7 @@ impl Batcher { let proposal_id = build_proposal_input.proposal_id; let deadline = deadline_as_instant(build_proposal_input.deadline)?; - let (tx_sender, tx_receiver) = tokio::sync::mpsc::unbounded_channel(); + let (output_tx_sender, output_tx_receiver) = tokio::sync::mpsc::unbounded_channel(); let tx_provider = ProposeTransactionProvider::new( self.mempool_client.clone(), // TODO: use a real L1 provider client. @@ -102,14 +106,12 @@ impl Batcher { build_proposal_input.proposal_id, build_proposal_input.retrospective_block_hash, deadline, - tx_sender, + output_tx_sender, tx_provider, ) - .await - .map_err(BatcherError::from)?; + .await?; - let output_tx_stream = tx_receiver; - self.build_proposals.insert(proposal_id, output_tx_stream); + self.build_proposals.insert(proposal_id, output_tx_receiver); Ok(()) } @@ -118,7 +120,28 @@ impl Batcher { &mut self, validate_proposal_input: ValidateProposalInput, ) -> BatcherResult<()> { - todo!(); + let proposal_id = validate_proposal_input.proposal_id; + let deadline = deadline_as_instant(validate_proposal_input.deadline)?; + + let (input_tx_sender, input_tx_receiver) = + tokio::sync::mpsc::channel(self.config.input_stream_content_buffer_size); + let tx_provider = ValidateTransactionProvider { + tx_receiver: input_tx_receiver, + // TODO: use a real L1 provider client. + l1_provider_client: Arc::new(DummyL1ProviderClient), + }; + + self.proposal_manager + .validate_block_proposal( + validate_proposal_input.proposal_id, + validate_proposal_input.retrospective_block_hash, + deadline, + tx_provider, + ) + .await?; + + self.validate_proposals.insert(proposal_id, input_tx_sender); + Ok(()) } // This function assumes that requests are received in order, otherwise the content could @@ -133,7 +156,7 @@ impl Batcher { match send_proposal_content_input.content { SendProposalContent::Txs(txs) => self.send_txs_and_get_status(proposal_id, txs).await, SendProposalContent::Finish => { - self.close_tx_channel_and_get_commitement(proposal_id).await + self.close_tx_channel_and_get_commitment(proposal_id).await } SendProposalContent::Abort => { unimplemented!("Abort not implemented yet."); @@ -147,8 +170,7 @@ impl Batcher { txs: Vec, ) -> BatcherResult { match self.proposal_manager.get_proposal_status(proposal_id).await { - ProposalStatus::Processing => { - // TODO: validate L1 transactions. + InternalProposalStatus::Processing => { let tx_provider_sender = &self .validate_proposals .get(&proposal_id) @@ -159,18 +181,20 @@ impl Batcher { BatcherError::InternalError })?; } - Ok(SendProposalContentResponse { response: ResponseProposalStatus::Processing }) + Ok(SendProposalContentResponse { response: ProposalStatus::Processing }) } // Proposal Got an Error while processing transactions. - ProposalStatus::Failed => Ok(SendProposalContentResponse { - response: ResponseProposalStatus::InvalidProposal, - }), - ProposalStatus::Finished => Err(BatcherError::ProposalAlreadyFinished { proposal_id }), - ProposalStatus::NotFound => Err(BatcherError::ProposalNotFound { proposal_id }), + InternalProposalStatus::Failed => { + Ok(SendProposalContentResponse { response: ProposalStatus::InvalidProposal }) + } + InternalProposalStatus::Finished => { + Err(BatcherError::ProposalAlreadyFinished { proposal_id }) + } + InternalProposalStatus::NotFound => Err(BatcherError::ProposalNotFound { proposal_id }), } } - async fn close_tx_channel_and_get_commitement( + async fn close_tx_channel_and_get_commitment( &mut self, proposal_id: ProposalId, ) -> BatcherResult { @@ -184,9 +208,7 @@ impl Batcher { let proposal_commitment = self.proposal_manager.await_proposal_commitment(proposal_id).await?; - Ok(SendProposalContentResponse { - response: ResponseProposalStatus::Finished(proposal_commitment), - }) + Ok(SendProposalContentResponse { response: ProposalStatus::Finished(proposal_commitment) }) } #[instrument(skip(self), err)] diff --git a/crates/starknet_batcher/src/batcher_test.rs b/crates/starknet_batcher/src/batcher_test.rs index c01f738fe4..c518f4e1d2 100644 --- a/crates/starknet_batcher/src/batcher_test.rs +++ b/crates/starknet_batcher/src/batcher_test.rs @@ -3,10 +3,11 @@ use std::sync::Arc; use assert_matches::assert_matches; use async_trait::async_trait; +use chrono::Utc; use futures::future::BoxFuture; use futures::FutureExt; use mockall::automock; -use mockall::predicate::eq; +use mockall::predicate::{always, eq}; use rstest::{fixture, rstest}; use starknet_api::block::{BlockHashAndNumber, BlockNumber}; use starknet_api::core::{ContractAddress, Nonce, StateDiffCommitment}; @@ -23,7 +24,12 @@ use starknet_batcher_types::batcher_types::{ GetProposalContentResponse, ProposalCommitment, ProposalId, + ProposalStatus, + SendProposalContent, + SendProposalContentInput, + SendProposalContentResponse, StartHeightInput, + ValidateProposalInput, }; use starknet_batcher_types::errors::BatcherError; use starknet_mempool_types::communication::MockMempoolClient; @@ -34,10 +40,10 @@ use crate::config::BatcherConfig; use crate::proposal_manager::{ GenerateProposalError, GetProposalResultError, + InternalProposalStatus, ProposalManagerTrait, ProposalOutput, ProposalResult, - ProposalStatus, StartHeightError, }; use crate::test_utils::test_txs; @@ -45,6 +51,18 @@ use crate::transaction_provider::{ProposeTransactionProvider, ValidateTransactio const INITIAL_HEIGHT: BlockNumber = BlockNumber(3); const STREAMING_CHUNK_SIZE: usize = 3; +const BLOCK_GENERATION_TIMEOUT: tokio::time::Duration = tokio::time::Duration::from_secs(1); +const PROPOSAL_ID: ProposalId = ProposalId(0); + +fn proposal_commitment() -> ProposalCommitment { + ProposalCommitment { + state_diff_commitment: StateDiffCommitment(PoseidonHash(felt!(u128::try_from(7).unwrap()))), + } +} + +fn deadline() -> chrono::DateTime { + chrono::Utc::now() + BLOCK_GENERATION_TIMEOUT +} #[fixture] fn storage_reader() -> MockBatcherStorageReaderTrait { @@ -78,28 +96,106 @@ fn batcher(proposal_manager: MockProposalManagerTraitWrapper) -> Batcher { ) } +fn mock_proposal_manager_common_expectations( + proposal_manager: &mut MockProposalManagerTraitWrapper, +) { + proposal_manager + .expect_wrap_start_height() + .times(1) + .with(eq(INITIAL_HEIGHT)) + .return_once(|_| async { Ok(()) }.boxed()); + proposal_manager + .expect_wrap_await_proposal_commitment() + .times(1) + .with(eq(PROPOSAL_ID)) + .return_once(move |_| { async move { Ok(proposal_commitment()) } }.boxed()); +} + +fn mock_proposal_manager_validate_flow() -> MockProposalManagerTraitWrapper { + let mut proposal_manager = MockProposalManagerTraitWrapper::new(); + mock_proposal_manager_common_expectations(&mut proposal_manager); + proposal_manager + .expect_wrap_validate_block_proposal() + .times(1) + .with(eq(PROPOSAL_ID), eq(None), always(), always()) + .return_once(|_, _, _, tx_provider| { + { + async move { + // Spawn a task to keep tx_provider alive until the transactions sender is + // dropped. Without this, the provider would be dropped, + // causing the batcher to fail when sending transactions to + // it during the test. + tokio::spawn(async move { + while !tx_provider.tx_receiver.is_closed() { + tokio::task::yield_now().await; + } + }); + Ok(()) + } + } + .boxed() + }); + proposal_manager + .expect_wrap_get_proposal_status() + .times(1) + .with(eq(PROPOSAL_ID)) + .returning(move |_| async move { InternalProposalStatus::Processing }.boxed()); + proposal_manager +} + +// TODO: add negative tests +#[rstest] +#[tokio::test] +async fn validate_proposal_full_flow() { + let proposal_manager = mock_proposal_manager_validate_flow(); + let mut batcher = batcher(proposal_manager); + + // TODO(Yael 14/11/2024): The test will pass without calling start height (if we delete the mock + // expectation). Leaving this here for future compatibility with the upcoming + // batcher-proposal_manager unification. + batcher.start_height(StartHeightInput { height: INITIAL_HEIGHT }).await.unwrap(); + + let validate_proposal_input = ValidateProposalInput { + proposal_id: PROPOSAL_ID, + deadline: deadline(), + retrospective_block_hash: None, + }; + batcher.validate_proposal(validate_proposal_input).await.unwrap(); + + let send_proposal_input_txs = SendProposalContentInput { + proposal_id: PROPOSAL_ID, + content: SendProposalContent::Txs(test_txs(0..1)), + }; + let send_txs_result = batcher.send_proposal_content(send_proposal_input_txs).await.unwrap(); + assert_eq!( + send_txs_result, + SendProposalContentResponse { response: ProposalStatus::Processing } + ); + + let send_proposal_input_finish = + SendProposalContentInput { proposal_id: PROPOSAL_ID, content: SendProposalContent::Finish }; + let send_finish_result = + batcher.send_proposal_content(send_proposal_input_finish).await.unwrap(); + assert_eq!( + send_finish_result, + SendProposalContentResponse { response: ProposalStatus::Finished(proposal_commitment()) } + ); +} + #[rstest] #[tokio::test] -async fn get_stream_content() { - const PROPOSAL_ID: ProposalId = ProposalId(0); +async fn build_proposal_full_flow() { // Expecting 3 chunks of streamed txs. let expected_streamed_txs = test_txs(0..STREAMING_CHUNK_SIZE * 2 + 1); let txs_to_stream = expected_streamed_txs.clone(); - let expected_proposal_commitment = ProposalCommitment { - state_diff_commitment: StateDiffCommitment(PoseidonHash(felt!(u128::try_from(7).unwrap()))), - }; let mut proposal_manager = MockProposalManagerTraitWrapper::new(); - proposal_manager.expect_wrap_start_height().times(1).return_once(|_| async { Ok(()) }.boxed()); + mock_proposal_manager_common_expectations(&mut proposal_manager); proposal_manager.expect_wrap_build_block_proposal().times(1).return_once( move |_proposal_id, _block_hash, _deadline, tx_sender, _tx_provider| { simulate_build_block_proposal(tx_sender, txs_to_stream).boxed() }, ); - proposal_manager - .expect_wrap_executed_proposal_commitment() - .times(1) - .return_once(move |_| async move { Ok(expected_proposal_commitment) }.boxed()); let mut batcher = batcher(proposal_manager); @@ -134,8 +230,8 @@ async fn get_stream_content() { assert_matches!( commitment, GetProposalContentResponse { - content: GetProposalContent::Finished(proposal_commitment) - } if proposal_commitment == expected_proposal_commitment + content: GetProposalContent::Finished(commitment) + } if commitment == proposal_commitment() ); let exhausted = @@ -151,7 +247,6 @@ async fn decision_reached( mut storage_writer: MockBatcherStorageWriterTrait, mut mempool_client: MockMempoolClient, ) { - const PROPOSAL_ID: ProposalId = ProposalId(0); let expected_state_diff = ThinStateDiff::default(); let state_diff_clone = expected_state_diff.clone(); let expected_proposal_commitment = ProposalCommitment::default(); @@ -191,13 +286,12 @@ async fn decision_reached( Arc::new(mempool_client), Box::new(proposal_manager), ); - batcher.decision_reached(DecisionReachedInput { proposal_id: ProposalId(0) }).await.unwrap(); + batcher.decision_reached(DecisionReachedInput { proposal_id: PROPOSAL_ID }).await.unwrap(); } #[rstest] #[tokio::test] async fn decision_reached_no_executed_proposal() { - const PROPOSAL_ID: ProposalId = ProposalId(0); let expected_error = BatcherError::ExecutedProposalNotFound { proposal_id: PROPOSAL_ID }; let mut proposal_manager = MockProposalManagerTraitWrapper::new(); @@ -255,9 +349,12 @@ trait ProposalManagerTraitWrapper: Send + Sync { proposal_id: ProposalId, ) -> BoxFuture<'_, ProposalResult>; - fn wrap_get_proposal_status(&self, proposal_id: ProposalId) -> BoxFuture<'_, ProposalStatus>; + fn wrap_get_proposal_status( + &self, + proposal_id: ProposalId, + ) -> BoxFuture<'_, InternalProposalStatus>; - fn wrap_executed_proposal_commitment( + fn wrap_await_proposal_commitment( &self, proposal_id: ProposalId, ) -> BoxFuture<'_, ProposalResult>; @@ -312,7 +409,7 @@ impl ProposalManagerTrait for T { self.wrap_take_proposal_result(proposal_id).await } - async fn get_proposal_status(&self, proposal_id: ProposalId) -> ProposalStatus { + async fn get_proposal_status(&self, proposal_id: ProposalId) -> InternalProposalStatus { self.wrap_get_proposal_status(proposal_id).await } @@ -320,7 +417,7 @@ impl ProposalManagerTrait for T { &mut self, proposal_id: ProposalId, ) -> ProposalResult { - self.wrap_executed_proposal_commitment(proposal_id).await + self.wrap_await_proposal_commitment(proposal_id).await } async fn abort_proposal(&mut self, proposal_id: ProposalId) { diff --git a/crates/starknet_batcher/src/proposal_manager.rs b/crates/starknet_batcher/src/proposal_manager.rs index d553eeba8f..34472258a8 100644 --- a/crates/starknet_batcher/src/proposal_manager.rs +++ b/crates/starknet_batcher/src/proposal_manager.rs @@ -71,7 +71,7 @@ pub enum GetProposalResultError { Aborted, } -pub enum ProposalStatus { +pub(crate) enum InternalProposalStatus { Processing, Finished, Failed, @@ -91,8 +91,6 @@ pub trait ProposalManagerTrait: Send + Sync { tx_provider: ProposeTransactionProvider, ) -> Result<(), GenerateProposalError>; - // TODO: delete allow dead code once the batcher uses this code. - #[allow(dead_code)] async fn validate_block_proposal( &mut self, proposal_id: ProposalId, @@ -106,7 +104,7 @@ pub trait ProposalManagerTrait: Send + Sync { proposal_id: ProposalId, ) -> ProposalResult; - async fn get_proposal_status(&self, proposal_id: ProposalId) -> ProposalStatus; + async fn get_proposal_status(&self, proposal_id: ProposalId) -> InternalProposalStatus; async fn await_proposal_commitment( &mut self, @@ -265,15 +263,15 @@ impl ProposalManagerTrait for ProposalManager { } // Returns None if the proposal does not exist, otherwise, returns the status of the proposal. - async fn get_proposal_status(&self, proposal_id: ProposalId) -> ProposalStatus { + async fn get_proposal_status(&self, proposal_id: ProposalId) -> InternalProposalStatus { match self.executed_proposals.lock().await.get(&proposal_id) { - Some(Ok(_)) => ProposalStatus::Finished, - Some(Err(_)) => ProposalStatus::Failed, + Some(Ok(_)) => InternalProposalStatus::Finished, + Some(Err(_)) => InternalProposalStatus::Failed, None => { if self.active_proposal.lock().await.as_ref() == Some(&proposal_id) { - ProposalStatus::Processing + InternalProposalStatus::Processing } else { - ProposalStatus::NotFound + InternalProposalStatus::NotFound } } } diff --git a/crates/starknet_batcher_types/src/batcher_types.rs b/crates/starknet_batcher_types/src/batcher_types.rs index 4be67e9a71..fe2b05549f 100644 --- a/crates/starknet_batcher_types/src/batcher_types.rs +++ b/crates/starknet_batcher_types/src/batcher_types.rs @@ -79,12 +79,12 @@ pub enum SendProposalContent { Abort, } -#[derive(Clone, Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] pub struct SendProposalContentResponse { pub response: ProposalStatus, } -#[derive(Clone, Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] pub enum ProposalStatus { Processing, // Only sent in response to `Finish`.