From bc3e41b3f682685d391fe446a08e4fc06376c160 Mon Sep 17 00:00:00 2001 From: Arni Hod Date: Mon, 9 Dec 2024 20:11:08 +0200 Subject: [PATCH] feat(starknet_batcher): implement sync_block --- crates/starknet_batcher/src/batcher.rs | 36 +++++++++++---- crates/starknet_batcher/src/batcher_test.rs | 51 +++++++++++++++++++++ 2 files changed, 79 insertions(+), 8 deletions(-) diff --git a/crates/starknet_batcher/src/batcher.rs b/crates/starknet_batcher/src/batcher.rs index 1599fee8262..016be048826 100644 --- a/crates/starknet_batcher/src/batcher.rs +++ b/crates/starknet_batcher/src/batcher.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::sync::Arc; use blockifier::state::global_cache::GlobalContractCache; @@ -6,8 +6,10 @@ use blockifier::state::global_cache::GlobalContractCache; use mockall::automock; use papyrus_storage::state::{StateStorageReader, StateStorageWriter}; use starknet_api::block::BlockNumber; +use starknet_api::core::{ContractAddress, Nonce}; use starknet_api::executable_transaction::Transaction; use starknet_api::state::ThinStateDiff; +use starknet_api::transaction::TransactionHash; use starknet_batcher_types::batcher_types::{ BatcherResult, DecisionReachedInput, @@ -351,9 +353,20 @@ impl Batcher { Ok(GetProposalContentResponse { content: GetProposalContent::Finished(commitment) }) } - // TODO(Arni): Impl add sync block - pub async fn add_sync_block(&mut self, _sync_block: SyncBlock) -> BatcherResult<()> { - todo!("Implement add sync block"); + #[instrument(skip(self), err)] + pub async fn add_sync_block(&mut self, sync_block: SyncBlock) -> BatcherResult<()> { + if let Some(height) = self.active_height { + info!("Aborting all work on height {} due to state sync.", height); + self.abort_active_height().await; + self.active_height = None; + } + + let SyncBlock { state_diff, transaction_hashes } = sync_block; + let address_to_nonce = state_diff.nonces.iter().map(|(k, v)| (*k, *v)).collect(); + let tx_hashes = transaction_hashes.into_iter().collect(); + + // TODO(Arni): Assert the input `sync_block` corresponds to this `height`. + self.commit_proposal_and_block(state_diff, address_to_nonce, tx_hashes).await } #[instrument(skip(self), err)] @@ -367,12 +380,19 @@ impl Batcher { .map_err(|_| BatcherError::InternalError)?; let ProposalOutput { state_diff, nonces: address_to_nonce, tx_hashes, .. } = proposal_output; + + self.commit_proposal_and_block(state_diff, address_to_nonce, tx_hashes).await + } + + async fn commit_proposal_and_block( + &mut self, + state_diff: ThinStateDiff, + address_to_nonce: HashMap, + tx_hashes: HashSet, + ) -> BatcherResult<()> { // TODO: Keep the height from start_height or get it from the input. let height = self.get_height_from_storage()?; - info!( - "Committing proposal {} at height {} and notifying mempool of the block.", - proposal_id, height - ); + info!("Committing block at height {} and notifying mempool of the block.", height); trace!("Transactions: {:#?}, State diff: {:#?}.", tx_hashes, state_diff); self.storage_writer.commit_proposal(height, state_diff).map_err(|err| { error!("Failed to commit proposal to storage: {}", err); diff --git a/crates/starknet_batcher/src/batcher_test.rs b/crates/starknet_batcher/src/batcher_test.rs index d60068fc0a9..751c3cd2640 100644 --- a/crates/starknet_batcher/src/batcher_test.rs +++ b/crates/starknet_batcher/src/batcher_test.rs @@ -8,6 +8,7 @@ use blockifier::test_utils::struct_impls::BlockInfoExt; use chrono::Utc; use futures::future::BoxFuture; use futures::FutureExt; +use indexmap::indexmap; use mockall::automock; use mockall::predicate::{always, eq}; use rstest::rstest; @@ -37,6 +38,7 @@ use starknet_batcher_types::batcher_types::{ use starknet_batcher_types::errors::BatcherError; use starknet_mempool_types::communication::MockMempoolClient; use starknet_mempool_types::mempool_types::CommitBlockArgs; +use starknet_state_sync_types::state_sync_types::SyncBlock; use tokio::sync::Mutex; use crate::batcher::{Batcher, MockBatcherStorageReaderTrait, MockBatcherStorageWriterTrait}; @@ -476,6 +478,37 @@ async fn get_content_from_unknown_proposal() { assert_eq!(result, Err(BatcherError::ProposalNotFound { proposal_id: PROPOSAL_ID })); } +#[rstest] +#[tokio::test] +async fn add_sync_block() { + let mut mock_dependencies = MockDependencies::default(); + + mock_dependencies + .storage_writer + .expect_commit_proposal() + .times(1) + .with(eq(INITIAL_HEIGHT), eq(test_state_diff())) + .returning(|_, _| Ok(())); + + mock_dependencies + .mempool_client + .expect_commit_block() + .times(1) + .with(eq(CommitBlockArgs { + address_to_nonce: test_contract_nonces(), + tx_hashes: test_tx_hashes(), + })) + .returning(|_| Ok(())); + + let mut batcher = create_batcher(mock_dependencies); + + let sync_block = SyncBlock { + state_diff: test_state_diff(), + transaction_hashes: test_tx_hashes().into_iter().collect(), + }; + batcher.add_sync_block(sync_block).await.unwrap(); +} + #[rstest] #[tokio::test] async fn decision_reached() { @@ -501,6 +534,7 @@ async fn decision_reached() { mock_dependencies .mempool_client .expect_commit_block() + .times(1) .with(eq(CommitBlockArgs { address_to_nonce: test_contract_nonces(), tx_hashes: test_tx_hashes(), @@ -510,6 +544,7 @@ async fn decision_reached() { mock_dependencies .storage_writer .expect_commit_proposal() + .times(1) .with(eq(INITIAL_HEIGHT), eq(ThinStateDiff::default())) .returning(|_, _| Ok(())); @@ -613,3 +648,19 @@ fn test_tx_hashes() -> HashSet { fn test_contract_nonces() -> HashMap { HashMap::from_iter((0..3u8).map(|i| (contract_address!(i + 33), nonce!(i + 9)))) } + +pub fn test_state_diff() -> ThinStateDiff { + ThinStateDiff { + storage_diffs: indexmap! { + 4u64.into() => indexmap! { + 5u64.into() => 6u64.into(), + 7u64.into() => 8u64.into(), + }, + 9u64.into() => indexmap! { + 10u64.into() => 11u64.into(), + }, + }, + nonces: test_contract_nonces().into_iter().collect(), + ..Default::default() + } +}