Skip to content

Commit

Permalink
feat: add infra proc macro for matching responses
Browse files Browse the repository at this point in the history
commit-id:65ba4744
  • Loading branch information
Itay-Tsabary-Starkware committed Jul 30, 2024
1 parent 0476a55 commit 254491b
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 37 deletions.
9 changes: 9 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ members = [
"crates/committer",
"crates/committer_cli",
"crates/gateway",
"crates/infra_proc_macros",
"crates/mempool",
"crates/mempool_infra",
"crates/mempool_node",
Expand Down Expand Up @@ -131,6 +132,7 @@ prost-build = "0.12.1"
prost-types = "0.12.1"
pyo3 = "0.19.1"
pyo3-log = "0.8.1"
quote = "1.0"
rand = "0.8.5"
rand_chacha = "0.3.1"
rand_distr = "0.4.3"
Expand All @@ -155,6 +157,7 @@ static_assertions = "1.1.0"
statistical = "1.0.0"
strum = "0.25.0"
strum_macros = "0.25.2"
syn = "1.0"
tempfile = "3.7.0"
test-case = "3.2.1"
test-log = "0.2.14"
Expand All @@ -173,6 +176,7 @@ validator = "0.12"
void = "1.0.2"
zstd = "0.13.1"


[workspace.lints.rust]
future-incompatible = "deny"
nonstandard-style = "deny"
Expand Down
16 changes: 16 additions & 0 deletions crates/infra_proc_macros/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
[package]
name = "starknet_infra_proc_macros"
version.workspace = true
edition.workspace = true
repository.workspace = true
license.workspace = true

[dependencies]
quote.workspace = true
syn.workspace = true

[lib]
proc-macro = true

[lints]
workspace = true
66 changes: 66 additions & 0 deletions crates/infra_proc_macros/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
use proc_macro::TokenStream;
use quote::quote;
use syn::parse::{Parse, ParseStream, Result};
use syn::{parse_macro_input, Ident, Token};

struct MacroInput {
response_enum: Ident,
invocation_name: Ident,
component_client_error: Ident,
component_error: Ident,
}

impl Parse for MacroInput {
fn parse(input: ParseStream<'_>) -> Result<Self> {
let response_enum = input.parse()?;
input.parse::<Token![,]>()?;
let invocation_name = input.parse()?;
input.parse::<Token![,]>()?;
let component_client_error = input.parse()?;
input.parse::<Token![,]>()?;
let component_error = input.parse()?;
Ok(MacroInput { response_enum, invocation_name, component_client_error, component_error })
}
}

/// A macro for generating code that handles the received communication response.
/// Takes the following arguments:
/// * response_enum -- the response enum type
/// * invocation_name -- the request/response enum variant that was invoked
/// * component_client_error -- the component's client error type
/// * component_error -- the component's error type
///
/// For example, the following input:
/// """
/// handle_response_variants!(MempoolResponse, GetTransactions, MempoolClientError, MempoolError)
/// """
///
/// Results in:
/// """
/// match response {
/// MempoolResponse::GetTransactions(Ok(response)) => Ok(response),
/// MempoolResponse::GetTransactions(Err(response)) => {
/// Err(MempoolClientError::MempoolError(response))
/// }
/// unexpected_response => Err(MempoolClientError::ClientError(
/// ClientError::UnexpectedResponse(format!("{unexpected_response:?}")),
/// )),
/// }
/// """
#[proc_macro]
pub fn handle_response_variants(input: TokenStream) -> TokenStream {
let MacroInput { response_enum, invocation_name, component_client_error, component_error } =
parse_macro_input!(input as MacroInput);

let expanded = quote! {
match response {
#response_enum::#invocation_name(Ok(response)) => Ok(response),
#response_enum::#invocation_name(Err(response)) => {
Err(#component_client_error::#component_error(response))
}
unexpected_response => Err(#component_client_error::ClientError(ClientError::UnexpectedResponse(format!("{unexpected_response:?}")))),
}
};

TokenStream::from(expanded)
}
3 changes: 2 additions & 1 deletion crates/mempool_types/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@ workspace = true

[dependencies]
async-trait.workspace = true
starknet_api = { path = "../starknet_api", version = "0.13.0-rc.0"}
mockall.workspace = true
serde = { workspace = true, feat = ["derive"] }
starknet_api = { path = "../starknet_api", version = "0.13.0-rc.0" }
starknet_infra_proc_macros = { path = "../infra_proc_macros" }
starknet_mempool_infra = { path = "../mempool_infra" }
thiserror.workspace = true
51 changes: 15 additions & 36 deletions crates/mempool_types/src/communication.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use async_trait::async_trait;
use mockall::predicate::*;
use mockall::*;
use serde::{Deserialize, Serialize};
use starknet_infra_proc_macros::handle_response_variants;
use starknet_mempool_infra::component_client::{
ClientError,
LocalComponentClient,
Expand Down Expand Up @@ -57,29 +58,18 @@ impl MempoolClient for MempoolClientImpl {
async fn add_tx(&self, mempool_input: MempoolInput) -> MempoolClientResult<()> {
let request = MempoolRequest::AddTransaction(mempool_input);
let response = self.send(request).await;
match response {
MempoolResponse::AddTransaction(Ok(response)) => Ok(response),
MempoolResponse::AddTransaction(Err(response)) => {
Err(MempoolClientError::MempoolError(response))
}
unexpected_response => Err(MempoolClientError::ClientError(
ClientError::UnexpectedResponse(format!("{unexpected_response:?}")),
)),
}
handle_response_variants!(MempoolResponse, AddTransaction, MempoolClientError, MempoolError)
}

async fn get_txs(&self, n_txs: usize) -> MempoolClientResult<Vec<ThinTransaction>> {
let request = MempoolRequest::GetTransactions(n_txs);
let response = self.send(request).await;
match response {
MempoolResponse::GetTransactions(Ok(response)) => Ok(response),
MempoolResponse::GetTransactions(Err(response)) => {
Err(MempoolClientError::MempoolError(response))
}
unexpected_response => Err(MempoolClientError::ClientError(
ClientError::UnexpectedResponse(format!("{unexpected_response:?}")),
)),
}
handle_response_variants!(
MempoolResponse,
GetTransactions,
MempoolClientError,
MempoolError
)
}
}

Expand All @@ -88,28 +78,17 @@ impl MempoolClient for RemoteMempoolClientImpl {
async fn add_tx(&self, mempool_input: MempoolInput) -> MempoolClientResult<()> {
let request = MempoolRequest::AddTransaction(mempool_input);
let response = self.send(request).await?;
match response {
MempoolResponse::AddTransaction(Ok(response)) => Ok(response),
MempoolResponse::AddTransaction(Err(response)) => {
Err(MempoolClientError::MempoolError(response))
}
unexpected_response => Err(MempoolClientError::ClientError(
ClientError::UnexpectedResponse(format!("{unexpected_response:?}")),
)),
}
handle_response_variants!(MempoolResponse, AddTransaction, MempoolClientError, MempoolError)
}

async fn get_txs(&self, n_txs: usize) -> MempoolClientResult<Vec<ThinTransaction>> {
let request = MempoolRequest::GetTransactions(n_txs);
let response = self.send(request).await?;
match response {
MempoolResponse::GetTransactions(Ok(response)) => Ok(response),
MempoolResponse::GetTransactions(Err(response)) => {
Err(MempoolClientError::MempoolError(response))
}
unexpected_response => Err(MempoolClientError::ClientError(
ClientError::UnexpectedResponse(format!("{unexpected_response:?}")),
)),
}
handle_response_variants!(
MempoolResponse,
GetTransactions,
MempoolClientError,
MempoolError
)
}
}

0 comments on commit 254491b

Please sign in to comment.