Skip to content

Commit

Permalink
refactor(mempool): wrap mempool_input with optional broadcasted messa…
Browse files Browse the repository at this point in the history
…ge metadata (#746)
  • Loading branch information
AlonLStarkWare authored Sep 17, 2024
1 parent eab9d28 commit 8aa893b
Show file tree
Hide file tree
Showing 7 changed files with 39 additions and 24 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions crates/gateway/src/gateway.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use starknet_api::rpc_transaction::RpcTransaction;
use starknet_api::transaction::TransactionHash;
use starknet_gateway_types::errors::GatewaySpecError;
use starknet_mempool_infra::component_runner::{ComponentStartError, ComponentStarter};
use starknet_mempool_types::communication::SharedMempoolClient;
use starknet_mempool_types::communication::{MempoolWrapperInput, SharedMempoolClient};
use starknet_mempool_types::mempool_types::{Account, AccountState, MempoolInput};
use starknet_sierra_compile::config::SierraToCasmCompilationConfig;
use tracing::{error, info, instrument};
Expand Down Expand Up @@ -127,7 +127,8 @@ async fn internal_add_tx(

let tx_hash = mempool_input.tx.tx_hash();

app_state.mempool_client.add_tx(mempool_input).await.map_err(|e| {
let mempool_wrapper_input = MempoolWrapperInput { mempool_input, message_metadata: None };
app_state.mempool_client.add_tx(mempool_wrapper_input).await.map_err(|e| {
error!("Failed to send tx to mempool: {}", e);
GatewaySpecError::UnexpectedError { data: "Internal server error".to_owned() }
})?;
Expand Down
24 changes: 14 additions & 10 deletions crates/gateway/src/gateway_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use starknet_api::core::{CompiledClassHash, ContractAddress};
use starknet_api::rpc_transaction::{RpcDeclareTransaction, RpcTransaction};
use starknet_api::transaction::{TransactionHash, ValidResourceBounds};
use starknet_gateway_types::errors::GatewaySpecError;
use starknet_mempool_types::communication::MockMempoolClient;
use starknet_mempool_types::communication::{MempoolWrapperInput, MockMempoolClient};
use starknet_mempool_types::mempool_types::{Account, AccountState, MempoolInput};
use starknet_sierra_compile::config::SierraToCasmCompilationConfig;

Expand Down Expand Up @@ -57,6 +57,7 @@ fn create_tx() -> (RpcTransaction, SenderAddress) {
(tx, sender_address)
}

// TODO: add test with Some broadcasted message metadata
#[tokio::test]
async fn test_add_tx() {
let (tx, sender_address) = create_tx();
Expand All @@ -66,17 +67,20 @@ async fn test_add_tx() {
mock_mempool_client
.expect_add_tx()
.once()
.with(eq(MempoolInput {
.with(eq(MempoolWrapperInput {
// TODO(Arni): Use external_to_executable_tx instead of `create_executable_tx`. Consider
// creating a `convertor for testing` that does not do the compilation.
tx: create_executable_tx(
sender_address,
tx_hash,
*tx.tip(),
*tx.nonce(),
ValidResourceBounds::AllResources(*tx.resource_bounds()),
),
account: Account { sender_address, state: AccountState { nonce: *tx.nonce() } },
mempool_input: MempoolInput {
tx: create_executable_tx(
sender_address,
tx_hash,
*tx.tip(),
*tx.nonce(),
ValidResourceBounds::AllResources(*tx.resource_bounds()),
),
account: Account { sender_address, state: AccountState { nonce: *tx.nonce() } },
},
message_metadata: None,
}))
.return_once(|_| Ok(()));
let state_reader_factory = local_test_state_reader_factory(CairoVersion::Cairo1, false);
Expand Down
11 changes: 6 additions & 5 deletions crates/mempool/src/communication.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ use starknet_mempool_types::communication::{
MempoolRequest,
MempoolRequestAndResponseSender,
MempoolResponse,
MempoolWrapperInput,
};
use starknet_mempool_types::mempool_types::{MempoolInput, MempoolResult};
use starknet_mempool_types::mempool_types::MempoolResult;
use tokio::sync::mpsc::Receiver;

use crate::mempool::Mempool;
Expand Down Expand Up @@ -48,8 +49,8 @@ impl MempoolCommunicationWrapper {
MempoolCommunicationWrapper { mempool }
}

fn add_tx(&mut self, mempool_input: MempoolInput) -> MempoolResult<()> {
self.mempool.add_tx(mempool_input)
fn add_tx(&mut self, mempool_wrapper_input: MempoolWrapperInput) -> MempoolResult<()> {
self.mempool.add_tx(mempool_wrapper_input.mempool_input)
}

fn get_txs(&mut self, n_txs: usize) -> MempoolResult<Vec<Transaction>> {
Expand All @@ -61,8 +62,8 @@ impl MempoolCommunicationWrapper {
impl ComponentRequestHandler<MempoolRequest, MempoolResponse> for MempoolCommunicationWrapper {
async fn handle_request(&mut self, request: MempoolRequest) -> MempoolResponse {
match request {
MempoolRequest::AddTransaction(mempool_input) => {
MempoolResponse::AddTransaction(self.add_tx(mempool_input))
MempoolRequest::AddTransaction(mempool_wrapper_input) => {
MempoolResponse::AddTransaction(self.add_tx(mempool_wrapper_input))
}
MempoolRequest::GetTransactions(n_txs) => {
MempoolResponse::GetTransactions(self.get_txs(n_txs))
Expand Down
1 change: 1 addition & 0 deletions crates/mempool_types/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ workspace = true
[dependencies]
async-trait.workspace = true
mockall.workspace = true
papyrus_network.workspace = true
papyrus_proc_macros.workspace = true
serde = { workspace = true, features = ["derive"] }
starknet_api.workspace = true
Expand Down
19 changes: 13 additions & 6 deletions crates/mempool_types/src/communication.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::sync::Arc;
use async_trait::async_trait;
use mockall::predicate::*;
use mockall::*;
use papyrus_network::network_manager::BroadcastedMessageManager;
use papyrus_proc_macros::handle_response_variants;
use serde::{Deserialize, Serialize};
use starknet_api::executable_transaction::Transaction;
Expand All @@ -25,20 +26,26 @@ pub type MempoolRequestAndResponseSender =
ComponentRequestAndResponseSender<MempoolRequest, MempoolResponse>;
pub type SharedMempoolClient = Arc<dyn MempoolClient>;

#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct MempoolWrapperInput {
pub mempool_input: MempoolInput,
pub message_metadata: Option<BroadcastedMessageManager>,
}

/// Serves as the mempool's shared interface. Requires `Send + Sync` to allow transferring and
/// sharing resources (inputs, futures) across threads.
#[automock]
#[async_trait]
pub trait MempoolClient: Send + Sync {
// TODO: Add Option<BroadcastedMessageManager> as an argument for add_transaction
// TODO: Rename tx to transaction
async fn add_tx(&self, mempool_input: MempoolInput) -> MempoolClientResult<()>;
async fn add_tx(&self, mempool_input: MempoolWrapperInput) -> MempoolClientResult<()>;
async fn get_txs(&self, n_txs: usize) -> MempoolClientResult<Vec<Transaction>>;
}

#[derive(Debug, Serialize, Deserialize)]
pub enum MempoolRequest {
AddTransaction(MempoolInput),
AddTransaction(MempoolWrapperInput),
GetTransactions(usize),
}

Expand All @@ -58,8 +65,8 @@ pub enum MempoolClientError {

#[async_trait]
impl MempoolClient for LocalMempoolClientImpl {
async fn add_tx(&self, mempool_input: MempoolInput) -> MempoolClientResult<()> {
let request = MempoolRequest::AddTransaction(mempool_input);
async fn add_tx(&self, mempool_wrapper_input: MempoolWrapperInput) -> MempoolClientResult<()> {
let request = MempoolRequest::AddTransaction(mempool_wrapper_input);
let response = self.send(request).await;
handle_response_variants!(MempoolResponse, AddTransaction, MempoolClientError, MempoolError)
}
Expand All @@ -78,8 +85,8 @@ impl MempoolClient for LocalMempoolClientImpl {

#[async_trait]
impl MempoolClient for RemoteMempoolClientImpl {
async fn add_tx(&self, mempool_input: MempoolInput) -> MempoolClientResult<()> {
let request = MempoolRequest::AddTransaction(mempool_input);
async fn add_tx(&self, mempool_wrapper_input: MempoolWrapperInput) -> MempoolClientResult<()> {
let request = MempoolRequest::AddTransaction(mempool_wrapper_input);
let response = self.send(request).await?;
handle_response_variants!(MempoolResponse, AddTransaction, MempoolClientError, MempoolError)
}
Expand Down
2 changes: 1 addition & 1 deletion crates/papyrus_network/src/network_manager/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -839,7 +839,7 @@ pub type BroadcastTopicSender<T> = With<
>;

// TODO(alonl): remove clone
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct BroadcastedMessageManager {
peer_id: PeerId,
}
Expand Down

0 comments on commit 8aa893b

Please sign in to comment.