Skip to content

Commit

Permalink
refactor(starknet_batcher): add tests and refactor batcher_test
Browse files Browse the repository at this point in the history
  • Loading branch information
dafnamatsry committed Dec 10, 2024
1 parent cfc0dcc commit 2556745
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 122 deletions.
235 changes: 116 additions & 119 deletions crates/starknet_batcher/src/batcher_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use std::sync::Arc;
use assert_matches::assert_matches;
use blockifier::abi::constants;
use blockifier::test_utils::struct_impls::BlockInfoExt;
use chrono::Utc;
use mockall::predicate::eq;
use rstest::rstest;
use starknet_api::block::{BlockInfo, BlockNumber};
Expand Down Expand Up @@ -35,11 +34,9 @@ use crate::block_builder::{
BlockExecutionArtifacts,
FailOnErrorCause,
MockBlockBuilderFactoryTrait,
MockBlockBuilderTrait,
};
use crate::config::BatcherConfig;
use crate::test_utils::test_txs;
use crate::transaction_provider::NextTxs;
use crate::test_utils::{test_txs, FakeProposeBlockBuilder, FakeValidateBlockBuilder};
use crate::utils::ProposalOutput;

const INITIAL_HEIGHT: BlockNumber = BlockNumber(3);
Expand All @@ -49,16 +46,26 @@ const PROPOSAL_ID: ProposalId = ProposalId(0);
const BUILD_BLOCK_FAIL_ON_ERROR: BlockBuilderError =
BlockBuilderError::FailOnError(FailOnErrorCause::BlockFull);

fn initial_block_info() -> BlockInfo {
BlockInfo { block_number: INITIAL_HEIGHT, ..BlockInfo::create_for_testing() }
}

fn proposal_commitment() -> ProposalCommitment {
ProposalOutput::from(BlockExecutionArtifacts::create_for_testing()).commitment
}

fn deadline() -> chrono::DateTime<Utc> {
chrono::Utc::now() + BLOCK_GENERATION_TIMEOUT
fn propose_block_input(proposal_id: ProposalId) -> ProposeBlockInput {
ProposeBlockInput {
proposal_id,
retrospective_block_hash: None,
deadline: chrono::Utc::now() + BLOCK_GENERATION_TIMEOUT,
block_info: BlockInfo { block_number: INITIAL_HEIGHT, ..BlockInfo::create_for_testing() },
}
}

fn validate_block_input(proposal_id: ProposalId) -> ValidateBlockInput {
ValidateBlockInput {
proposal_id,
retrospective_block_hash: None,
deadline: chrono::Utc::now() + BLOCK_GENERATION_TIMEOUT,
block_info: BlockInfo { block_number: INITIAL_HEIGHT, ..BlockInfo::create_for_testing() },
}
}

struct MockDependencies {
Expand Down Expand Up @@ -96,67 +103,48 @@ fn abort_signal_sender() -> AbortSignalSender {
}

fn mock_create_builder_for_validate_block(
block_builder_factory: &mut MockBlockBuilderFactoryTrait,
build_block_result: BlockBuilderResult<BlockExecutionArtifacts>,
) -> MockBlockBuilderFactoryTrait {
let mut block_builder_factory = MockBlockBuilderFactoryTrait::new();
) {
block_builder_factory.expect_create_block_builder().times(1).return_once(
|_, _, mut tx_provider, _| {
let mut block_builder = MockBlockBuilderTrait::new();
block_builder.expect_build_block().times(1).return_once(move || {
// Spawn a task to keep tx_provider alive until all transactions are read.
// Without this, the provider would be dropped, causing the batcher to fail when
// sending transactions to it during the test.
tokio::spawn(async move {
while tx_provider.get_txs(1).await.is_ok_and(|v| v != NextTxs::End) {
tokio::task::yield_now().await;
}
});
build_block_result
});
|_, _, tx_provider, _| {
let block_builder = FakeValidateBlockBuilder {
tx_provider,
build_block_result: Some(build_block_result),
};
Ok((Box::new(block_builder), abort_signal_sender()))
},
);
block_builder_factory
}

fn mock_create_builder_for_propose_block(
block_builder_factory: &mut MockBlockBuilderFactoryTrait,
output_txs: Vec<Transaction>,
build_block_result: BlockBuilderResult<BlockExecutionArtifacts>,
) -> MockBlockBuilderFactoryTrait {
let mut block_builder_factory = MockBlockBuilderFactoryTrait::new();
) {
block_builder_factory.expect_create_block_builder().times(1).return_once(
|_, _, _, output_content_sender| {
let mut block_builder = MockBlockBuilderTrait::new();
block_builder.expect_build_block().times(1).return_once(move || {
// Simulate the streaming of the block builder output.
for tx in output_txs {
output_content_sender.as_ref().unwrap().send(tx).unwrap();
}
build_block_result
});
move |_, _, _, output_content_sender| {
let block_builder = FakeProposeBlockBuilder {
output_content_sender: output_content_sender.unwrap(),
output_txs,
build_block_result: Some(build_block_result),
};
Ok((Box::new(block_builder), abort_signal_sender()))
},
);
block_builder_factory
}

async fn batcher_with_validated_proposal(
async fn batcher_with_active_validate_block(
build_block_result: BlockBuilderResult<BlockExecutionArtifacts>,
) -> Batcher {
let block_builder_factory = mock_create_builder_for_validate_block(build_block_result);
let mut block_builder_factory = MockBlockBuilderFactoryTrait::new();
mock_create_builder_for_validate_block(&mut block_builder_factory, build_block_result);

let mut batcher =
create_batcher(MockDependencies { block_builder_factory, ..Default::default() });

batcher.start_height(StartHeightInput { height: INITIAL_HEIGHT }).await.unwrap();

let validate_block_input = ValidateBlockInput {
proposal_id: PROPOSAL_ID,
deadline: deadline(),
retrospective_block_hash: None,
block_info: initial_block_info(),
};
batcher.validate_block(validate_block_input).await.unwrap();
batcher.validate_block(validate_block_input(PROPOSAL_ID)).await.unwrap();

batcher
}
Expand Down Expand Up @@ -206,70 +194,60 @@ async fn no_active_height() {

// Calling `propose_block` and `validate_block` without starting a height should fail.

let result = batcher
.propose_block(ProposeBlockInput {
proposal_id: ProposalId(0),
retrospective_block_hash: None,
deadline: chrono::Utc::now() + chrono::Duration::seconds(1),
block_info: Default::default(),
})
.await;
let result = batcher.propose_block(propose_block_input(PROPOSAL_ID)).await;
assert_eq!(result, Err(BatcherError::NoActiveHeight));

let result = batcher
.validate_block(ValidateBlockInput {
proposal_id: ProposalId(0),
retrospective_block_hash: None,
deadline: chrono::Utc::now() + chrono::Duration::seconds(1),
block_info: Default::default(),
})
.await;
let result = batcher.validate_block(validate_block_input(PROPOSAL_ID)).await;
assert_eq!(result, Err(BatcherError::NoActiveHeight));
}

#[rstest]
#[tokio::test]
async fn validate_block_full_flow() {
let block_builder_factory =
mock_create_builder_for_validate_block(Ok(BlockExecutionArtifacts::create_for_testing()));
let mut batcher =
create_batcher(MockDependencies { block_builder_factory, ..Default::default() });

batcher.start_height(StartHeightInput { height: INITIAL_HEIGHT }).await.unwrap();

let validate_block_input = ValidateBlockInput {
proposal_id: PROPOSAL_ID,
deadline: deadline(),
retrospective_block_hash: None,
block_info: initial_block_info(),
};
batcher.validate_block(validate_block_input).await.unwrap();
batcher_with_active_validate_block(Ok(BlockExecutionArtifacts::create_for_testing())).await;

let send_proposal_input_txs = SendProposalContentInput {
proposal_id: PROPOSAL_ID,
content: SendProposalContent::Txs(test_txs(0..1)),
};
let send_txs_result = batcher.send_proposal_content(send_proposal_input_txs).await.unwrap();
assert_eq!(
send_txs_result,
batcher.send_proposal_content(send_proposal_input_txs).await.unwrap(),
SendProposalContentResponse { response: ProposalStatus::Processing }
);

let send_proposal_input_finish =
let finish_proposal =
SendProposalContentInput { proposal_id: PROPOSAL_ID, content: SendProposalContent::Finish };
let send_finish_result =
batcher.send_proposal_content(send_proposal_input_finish).await.unwrap();
assert_eq!(
send_finish_result,
batcher.send_proposal_content(finish_proposal).await.unwrap(),
SendProposalContentResponse { response: ProposalStatus::Finished(proposal_commitment()) }
);
}

#[rstest]
#[tokio::test]
async fn send_proposal_content_abort() {
let mut batcher =
batcher_with_active_validate_block(Ok(BlockExecutionArtifacts::create_for_testing())).await;

let send_abort_proposal =
SendProposalContentInput { proposal_id: PROPOSAL_ID, content: SendProposalContent::Abort };
assert_eq!(
batcher.send_proposal_content(send_abort_proposal).await.unwrap(),
SendProposalContentResponse { response: ProposalStatus::Aborted }
);
}

#[rstest]
#[tokio::test]
async fn send_content_after_proposal_already_finished() {
let successful_build_block_result = Ok(BlockExecutionArtifacts::create_for_testing());
let mut batcher = batcher_with_validated_proposal(successful_build_block_result).await;
let mut batcher =
batcher_with_active_validate_block(Ok(BlockExecutionArtifacts::create_for_testing())).await;

// Finish the proposal, and wait for it to complete.
let finish_proposal =
SendProposalContentInput { proposal_id: PROPOSAL_ID, content: SendProposalContent::Finish };
batcher.send_proposal_content(finish_proposal).await.unwrap();
batcher.await_active_proposal().await;

// Send transactions after the proposal has finished.
Expand All @@ -278,7 +256,7 @@ async fn send_content_after_proposal_already_finished() {
content: SendProposalContent::Txs(test_txs(0..1)),
};
let result = batcher.send_proposal_content(send_proposal_input_txs).await;
assert_eq!(result, Err(BatcherError::ProposalAlreadyFinished { proposal_id: PROPOSAL_ID }));
assert_eq!(result, Err(BatcherError::ProposalNotFound { proposal_id: PROPOSAL_ID }));
}

#[rstest]
Expand All @@ -304,7 +282,7 @@ async fn send_content_to_unknown_proposal() {
#[rstest]
#[tokio::test]
async fn send_txs_to_an_invalid_proposal() {
let mut batcher = batcher_with_validated_proposal(Err(BUILD_BLOCK_FAIL_ON_ERROR)).await;
let mut batcher = batcher_with_active_validate_block(Err(BUILD_BLOCK_FAIL_ON_ERROR)).await;
batcher.await_active_proposal().await;

let send_proposal_input_txs = SendProposalContentInput {
Expand All @@ -318,7 +296,7 @@ async fn send_txs_to_an_invalid_proposal() {
#[rstest]
#[tokio::test]
async fn send_finish_to_an_invalid_proposal() {
let mut batcher = batcher_with_validated_proposal(Err(BUILD_BLOCK_FAIL_ON_ERROR)).await;
let mut batcher = batcher_with_active_validate_block(Err(BUILD_BLOCK_FAIL_ON_ERROR)).await;

let send_proposal_input_txs =
SendProposalContentInput { proposal_id: PROPOSAL_ID, content: SendProposalContent::Finish };
Expand All @@ -331,26 +309,19 @@ async fn send_finish_to_an_invalid_proposal() {
async fn propose_block_full_flow() {
// Expecting 3 chunks of streamed txs.
let expected_streamed_txs = test_txs(0..STREAMING_CHUNK_SIZE * 2 + 1);
let txs_to_stream = expected_streamed_txs.clone();

let block_builder_factory = mock_create_builder_for_propose_block(
txs_to_stream,
let mut block_builder_factory = MockBlockBuilderFactoryTrait::new();
mock_create_builder_for_propose_block(
&mut block_builder_factory,
expected_streamed_txs.clone(),
Ok(BlockExecutionArtifacts::create_for_testing()),
);

let mut batcher =
create_batcher(MockDependencies { block_builder_factory, ..Default::default() });

batcher.start_height(StartHeightInput { height: INITIAL_HEIGHT }).await.unwrap();
batcher
.propose_block(ProposeBlockInput {
proposal_id: PROPOSAL_ID,
retrospective_block_hash: None,
deadline: chrono::Utc::now() + chrono::Duration::seconds(1),
block_info: initial_block_info(),
})
.await
.unwrap();
batcher.propose_block(propose_block_input(PROPOSAL_ID)).await.unwrap();

let expected_n_chunks = expected_streamed_txs.len().div_ceil(STREAMING_CHUNK_SIZE);
let mut aggregated_streamed_txs = Vec::new();
Expand Down Expand Up @@ -394,14 +365,7 @@ async fn propose_block_without_retrospective_block_hash() {
.start_height(StartHeightInput { height: BlockNumber(constants::STORED_BLOCK_HASH_BUFFER) })
.await
.unwrap();
let result = batcher
.propose_block(ProposeBlockInput {
proposal_id: PROPOSAL_ID,
retrospective_block_hash: None,
deadline: deadline(),
block_info: Default::default(),
})
.await;
let result = batcher.propose_block(propose_block_input(PROPOSAL_ID)).await;

assert_matches!(result, Err(BatcherError::MissingRetrospectiveBlockHash));
}
Expand All @@ -416,6 +380,44 @@ async fn get_content_from_unknown_proposal() {
assert_eq!(result, Err(BatcherError::ProposalNotFound { proposal_id: PROPOSAL_ID }));
}

#[rstest]
#[tokio::test]
async fn consecutive_proposal_generation_success() {
let mut block_builder_factory = MockBlockBuilderFactoryTrait::new();
mock_create_builder_for_propose_block(
&mut block_builder_factory,
vec![],
Ok(BlockExecutionArtifacts::create_for_testing()),
);
mock_create_builder_for_validate_block(
&mut block_builder_factory,
Ok(BlockExecutionArtifacts::create_for_testing()),
);
let mut batcher =
create_batcher(MockDependencies { block_builder_factory, ..Default::default() });

batcher.start_height(StartHeightInput { height: INITIAL_HEIGHT }).await.unwrap();

// Generate the first proposal, and wait for it to complete.
batcher.propose_block(propose_block_input(ProposalId(0))).await.unwrap();
batcher.await_active_proposal().await;

// Make sure another proposal can be generated after the first one finished.
batcher.validate_block(validate_block_input(ProposalId(1))).await.unwrap();
}

#[rstest]
#[tokio::test]
async fn concurrent_proposals_generation_fail() {
let mut batcher =
batcher_with_active_validate_block(Ok(BlockExecutionArtifacts::create_for_testing())).await;

// Make sure another proposal can't be generated while the first one is still active.
let result = batcher.propose_block(propose_block_input(ProposalId(1))).await;

assert_matches!(result, Err(BatcherError::ServerBusy { .. }));
}

#[rstest]
#[tokio::test]
async fn decision_reached() {
Expand All @@ -438,20 +440,15 @@ async fn decision_reached() {
.with(eq(INITIAL_HEIGHT), eq(expected_proposal_output.state_diff))
.returning(|_, _| Ok(()));

mock_dependencies.block_builder_factory =
mock_create_builder_for_validate_block(Ok(BlockExecutionArtifacts::create_for_testing()));
mock_create_builder_for_propose_block(
&mut mock_dependencies.block_builder_factory,
vec![],
Ok(BlockExecutionArtifacts::create_for_testing()),
);

let mut batcher = create_batcher(mock_dependencies);
batcher.start_height(StartHeightInput { height: INITIAL_HEIGHT }).await.unwrap();
batcher
.propose_block(ProposeBlockInput {
proposal_id: PROPOSAL_ID,
retrospective_block_hash: None,
deadline: deadline(),
block_info: initial_block_info(),
})
.await
.unwrap();
batcher.propose_block(propose_block_input(PROPOSAL_ID)).await.unwrap();
batcher.await_active_proposal().await;

batcher.decision_reached(DecisionReachedInput { proposal_id: PROPOSAL_ID }).await.unwrap();
Expand Down
Loading

0 comments on commit 2556745

Please sign in to comment.