Skip to content

Commit

Permalink
refactor(batcher): move height mgmt out to the batcher (#2192)
Browse files Browse the repository at this point in the history
  • Loading branch information
dafnamatsry authored Nov 21, 2024
1 parent 7661556 commit 2ff3249
Show file tree
Hide file tree
Showing 4 changed files with 176 additions and 227 deletions.
59 changes: 37 additions & 22 deletions crates/starknet_batcher/src/batcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ use crate::proposal_manager::{
ProposalManager,
ProposalManagerTrait,
ProposalOutput,
StartHeightError,
};
use crate::transaction_provider::{
DummyL1ProviderClient,
Expand All @@ -55,6 +54,8 @@ pub struct Batcher {
pub storage_reader: Arc<dyn BatcherStorageReaderTrait>,
pub storage_writer: Box<dyn BatcherStorageWriterTrait>,
pub mempool_client: SharedMempoolClient,

active_height: Option<BlockNumber>,
proposal_manager: Box<dyn ProposalManagerTrait>,
propose_tx_streams: HashMap<ProposalId, OutputStreamReceiver>,
validate_tx_streams: HashMap<ProposalId, InputStreamSender>,
Expand All @@ -73,23 +74,52 @@ impl Batcher {
storage_reader,
storage_writer,
mempool_client,
active_height: None,
proposal_manager,
propose_tx_streams: HashMap::new(),
validate_tx_streams: HashMap::new(),
}
}

#[instrument(skip(self), err)]
pub async fn start_height(&mut self, input: StartHeightInput) -> BatcherResult<()> {
if self.active_height == Some(input.height) {
return Err(BatcherError::HeightInProgress);
}

let storage_height =
self.storage_reader.height().map_err(|_| BatcherError::InternalError)?;
if storage_height < input.height {
return Err(BatcherError::StorageNotSynced {
storage_height,
requested_height: input.height,
});
}
if storage_height > input.height {
return Err(BatcherError::HeightAlreadyPassed {
storage_height,
requested_height: input.height,
});
}

// Clear all the proposals from the previous height.
self.proposal_manager.reset().await;
self.propose_tx_streams.clear();
self.validate_tx_streams.clear();
self.proposal_manager.start_height(input.height).await.map_err(BatcherError::from)

info!("Starting to work on height {}.", input.height);
self.active_height = Some(input.height);

Ok(())
}

#[instrument(skip(self), err)]
pub async fn propose_block(
&mut self,
propose_block_input: ProposeBlockInput,
) -> BatcherResult<()> {
let active_height = self.active_height.ok_or(BatcherError::NoActiveHeight)?;

let proposal_id = propose_block_input.proposal_id;
let deadline = deadline_as_instant(propose_block_input.deadline)?;

Expand All @@ -103,6 +133,7 @@ impl Batcher {

self.proposal_manager
.propose_block(
active_height,
proposal_id,
propose_block_input.retrospective_block_hash,
deadline,
Expand All @@ -120,6 +151,8 @@ impl Batcher {
&mut self,
validate_block_input: ValidateBlockInput,
) -> BatcherResult<()> {
let active_height = self.active_height.ok_or(BatcherError::NoActiveHeight)?;

let proposal_id = validate_block_input.proposal_id;
let deadline = deadline_as_instant(validate_block_input.deadline)?;

Expand All @@ -133,6 +166,7 @@ impl Batcher {

self.proposal_manager
.validate_block(
active_height,
proposal_id,
validate_block_input.retrospective_block_hash,
deadline,
Expand Down Expand Up @@ -299,8 +333,7 @@ pub fn create_batcher(config: BatcherConfig, mempool_client: SharedMempoolClient
});
let storage_reader = Arc::new(storage_reader);
let storage_writer = Box::new(storage_writer);
let proposal_manager =
Box::new(ProposalManager::new(block_builder_factory, storage_reader.clone()));
let proposal_manager = Box::new(ProposalManager::new(block_builder_factory));
Batcher::new(config, storage_reader, storage_writer, mempool_client, proposal_manager)
}

Expand Down Expand Up @@ -336,24 +369,6 @@ impl BatcherStorageWriterTrait for papyrus_storage::StorageWriter {
}
}

impl From<StartHeightError> for BatcherError {
fn from(err: StartHeightError) -> Self {
match err {
StartHeightError::HeightAlreadyPassed { storage_height, requested_height } => {
BatcherError::HeightAlreadyPassed { storage_height, requested_height }
}
StartHeightError::StorageError(err) => {
error!("{}", err);
BatcherError::InternalError
}
StartHeightError::StorageNotSynced { storage_height, requested_height } => {
BatcherError::StorageNotSynced { storage_height, requested_height }
}
StartHeightError::HeightInProgress => BatcherError::HeightInProgress,
}
}
}

impl From<GenerateProposalError> for BatcherError {
fn from(err: GenerateProposalError) -> Self {
match err {
Expand Down
131 changes: 110 additions & 21 deletions crates/starknet_batcher/src/batcher_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ use crate::proposal_manager::{
ProposalManagerTrait,
ProposalOutput,
ProposalResult,
StartHeightError,
};
use crate::test_utils::test_txs;
use crate::transaction_provider::{ProposeTransactionProvider, ValidateTransactionProvider};
Expand Down Expand Up @@ -100,11 +99,7 @@ fn batcher(proposal_manager: MockProposalManagerTraitWrapper) -> Batcher {
fn mock_proposal_manager_common_expectations(
proposal_manager: &mut MockProposalManagerTraitWrapper,
) {
proposal_manager
.expect_wrap_start_height()
.times(1)
.with(eq(INITIAL_HEIGHT))
.return_once(|_| async { Ok(()) }.boxed());
proposal_manager.expect_wrap_reset().times(1).return_once(|| async {}.boxed());
proposal_manager
.expect_wrap_await_proposal_commitment()
.times(1)
Expand All @@ -118,8 +113,8 @@ fn mock_proposal_manager_validate_flow() -> MockProposalManagerTraitWrapper {
proposal_manager
.expect_wrap_validate_block()
.times(1)
.with(eq(PROPOSAL_ID), eq(None), always(), always())
.return_once(|_, _, _, tx_provider| {
.with(eq(INITIAL_HEIGHT), eq(PROPOSAL_ID), eq(None), always(), always())
.return_once(|_, _, _, _, tx_provider| {
{
async move {
// Spawn a task to keep tx_provider alive until the transactions sender is
Expand All @@ -144,6 +139,89 @@ fn mock_proposal_manager_validate_flow() -> MockProposalManagerTraitWrapper {
proposal_manager
}

#[rstest]
#[case::height_already_passed(
INITIAL_HEIGHT.prev().unwrap(),
Result::Err(BatcherError::HeightAlreadyPassed {
storage_height: INITIAL_HEIGHT,
requested_height: INITIAL_HEIGHT.prev().unwrap()
}
))]
#[case::happy(
INITIAL_HEIGHT,
Result::Ok(())
)]
#[case::storage_not_synced(
INITIAL_HEIGHT.unchecked_next(),
Result::Err(BatcherError::StorageNotSynced {
storage_height: INITIAL_HEIGHT,
requested_height: INITIAL_HEIGHT.unchecked_next()
}
))]
#[tokio::test]
async fn start_height(
#[case] height: BlockNumber,
#[case] expected_result: Result<(), BatcherError>,
) {
let mut proposal_manager = MockProposalManagerTraitWrapper::new();
let reset_times = if expected_result.is_ok() { 1 } else { 0 };
proposal_manager.expect_wrap_reset().times(reset_times).returning(|| async {}.boxed());

let mut batcher = batcher(proposal_manager);
let result = batcher.start_height(StartHeightInput { height }).await;
assert_eq!(format!("{:?}", result), format!("{:?}", expected_result));
}

#[rstest]
#[tokio::test]
async fn duplicate_start_height() {
let mut proposal_manager = MockProposalManagerTraitWrapper::new();
proposal_manager.expect_wrap_reset().times(1).return_once(|| async {}.boxed());

let mut batcher = batcher(proposal_manager);

assert_matches!(
batcher.start_height(StartHeightInput { height: INITIAL_HEIGHT }).await,
Ok(())
);
assert_matches!(
batcher.start_height(StartHeightInput { height: INITIAL_HEIGHT }).await,
Err(BatcherError::HeightInProgress)
);
}

#[rstest]
#[tokio::test]
async fn propose_block_fails_without_start_height() {
let proposal_manager = MockProposalManagerTraitWrapper::new();
let mut batcher = batcher(proposal_manager);

let result = batcher
.propose_block(ProposeBlockInput {
proposal_id: ProposalId(0),
retrospective_block_hash: None,
deadline: chrono::Utc::now() + chrono::Duration::seconds(1),
})
.await;
assert_matches!(result, Err(BatcherError::NoActiveHeight));
}

#[rstest]
#[tokio::test]
async fn validate_proposal_fails_without_start_height() {
let proposal_manager = MockProposalManagerTraitWrapper::new();
let mut batcher = batcher(proposal_manager);

let err = batcher
.validate_block(ValidateBlockInput {
proposal_id: ProposalId(0),
retrospective_block_hash: None,
deadline: chrono::Utc::now() + chrono::Duration::seconds(1),
})
.await;
assert_matches!(err, Err(BatcherError::NoActiveHeight));
}

#[rstest]
#[tokio::test]
async fn validate_block_full_flow() {
Expand Down Expand Up @@ -254,11 +332,12 @@ async fn send_txs_to_an_invalid_proposal() {
#[tokio::test]
async fn send_finish_to_an_invalid_proposal() {
let mut proposal_manager = MockProposalManagerTraitWrapper::new();
proposal_manager.expect_wrap_reset().times(1).return_once(|| async {}.boxed());
proposal_manager
.expect_wrap_validate_block()
.times(1)
.with(eq(PROPOSAL_ID), eq(None), always(), always())
.return_once(|_, _, _, _| { async move { Ok(()) } }.boxed());
.with(eq(INITIAL_HEIGHT), eq(PROPOSAL_ID), eq(None), always(), always())
.return_once(|_, _, _, _, _| { async move { Ok(()) } }.boxed());

let proposal_error = GetProposalResultError::BlockBuilderError(Arc::new(
BlockBuilderError::FailOnError(FailOnErrorCause::BlockFull),
Expand All @@ -270,6 +349,7 @@ async fn send_finish_to_an_invalid_proposal() {
.return_once(move |_| { async move { Err(proposal_error) } }.boxed());

let mut batcher = batcher(proposal_manager);
batcher.start_height(StartHeightInput { height: INITIAL_HEIGHT }).await.unwrap();

let validate_block_input = ValidateBlockInput {
proposal_id: PROPOSAL_ID,
Expand All @@ -294,7 +374,7 @@ async fn propose_block_full_flow() {
let mut proposal_manager = MockProposalManagerTraitWrapper::new();
mock_proposal_manager_common_expectations(&mut proposal_manager);
proposal_manager.expect_wrap_propose_block().times(1).return_once(
move |_proposal_id, _block_hash, _deadline, tx_sender, _tx_provider| {
move |_height, _proposal_id, _block_hash, _deadline, tx_sender, _tx_provider| {
simulate_build_block_proposal(tx_sender, txs_to_stream).boxed()
},
);
Expand Down Expand Up @@ -437,13 +517,9 @@ async fn simulate_build_block_proposal(
// A wrapper trait to allow mocking the ProposalManagerTrait in tests.
#[automock]
trait ProposalManagerTraitWrapper: Send + Sync {
fn wrap_start_height(
&mut self,
height: BlockNumber,
) -> BoxFuture<'_, Result<(), StartHeightError>>;

fn wrap_propose_block(
&mut self,
height: BlockNumber,
proposal_id: ProposalId,
retrospective_block_hash: Option<BlockHashAndNumber>,
deadline: tokio::time::Instant,
Expand All @@ -453,6 +529,7 @@ trait ProposalManagerTraitWrapper: Send + Sync {

fn wrap_validate_block(
&mut self,
height: BlockNumber,
proposal_id: ProposalId,
retrospective_block_hash: Option<BlockHashAndNumber>,
deadline: tokio::time::Instant,
Expand All @@ -475,23 +552,23 @@ trait ProposalManagerTraitWrapper: Send + Sync {
) -> BoxFuture<'_, ProposalResult<ProposalCommitment>>;

fn wrap_abort_proposal(&mut self, proposal_id: ProposalId) -> BoxFuture<'_, ()>;

fn wrap_reset(&mut self) -> BoxFuture<'_, ()>;
}

#[async_trait]
impl<T: ProposalManagerTraitWrapper> ProposalManagerTrait for T {
async fn start_height(&mut self, height: BlockNumber) -> Result<(), StartHeightError> {
self.wrap_start_height(height).await
}

async fn propose_block(
&mut self,
height: BlockNumber,
proposal_id: ProposalId,
retrospective_block_hash: Option<BlockHashAndNumber>,
deadline: tokio::time::Instant,
output_content_sender: tokio::sync::mpsc::UnboundedSender<Transaction>,
tx_provider: ProposeTransactionProvider,
) -> Result<(), GenerateProposalError> {
self.wrap_propose_block(
height,
proposal_id,
retrospective_block_hash,
deadline,
Expand All @@ -503,12 +580,20 @@ impl<T: ProposalManagerTraitWrapper> ProposalManagerTrait for T {

async fn validate_block(
&mut self,
height: BlockNumber,
proposal_id: ProposalId,
retrospective_block_hash: Option<BlockHashAndNumber>,
deadline: tokio::time::Instant,
tx_provider: ValidateTransactionProvider,
) -> Result<(), GenerateProposalError> {
self.wrap_validate_block(proposal_id, retrospective_block_hash, deadline, tx_provider).await
self.wrap_validate_block(
height,
proposal_id,
retrospective_block_hash,
deadline,
tx_provider,
)
.await
}

async fn take_proposal_result(
Expand All @@ -532,6 +617,10 @@ impl<T: ProposalManagerTraitWrapper> ProposalManagerTrait for T {
async fn abort_proposal(&mut self, proposal_id: ProposalId) {
self.wrap_abort_proposal(proposal_id).await
}

async fn reset(&mut self) {
self.wrap_reset().await
}
}

fn test_tx_hashes(range: std::ops::Range<u128>) -> HashSet<TransactionHash> {
Expand Down
Loading

0 comments on commit 2ff3249

Please sign in to comment.