diff --git a/Cargo.lock b/Cargo.lock index f80d0d87c22..d0a5f6b28d5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8969,6 +8969,15 @@ dependencies = [ "validator", ] +[[package]] +name = "starknet_infra_proc_macros" +version = "0.0.0" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "starknet_mempool" version = "0.0.0" @@ -9066,6 +9075,7 @@ dependencies = [ "mockall", "serde", "starknet_api", + "starknet_infra_proc_macros", "starknet_mempool_infra", "thiserror", ] diff --git a/Cargo.toml b/Cargo.toml index 0a545b6278e..5f27e57a91b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ members = [ "crates/committer", "crates/committer_cli", "crates/gateway", + "crates/infra_proc_macros", "crates/mempool", "crates/mempool_infra", "crates/mempool_node", @@ -125,12 +126,14 @@ paste = "1.0.15" phf = { version = "0.11", features = ["macros"] } pretty_assertions = "1.4.0" primitive-types = "0.12.1" +proc-macro2 = "1.0" prometheus-parse = "0.2.4" prost = "0.12.1" prost-build = "0.12.1" prost-types = "0.12.1" pyo3 = "0.19.1" pyo3-log = "0.8.1" +quote = "1.0" rand = "0.8.5" rand_chacha = "0.3.1" rand_distr = "0.4.3" @@ -155,6 +158,7 @@ static_assertions = "1.1.0" statistical = "1.0.0" strum = "0.25.0" strum_macros = "0.25.2" +syn = "1.0" tempfile = "3.7.0" test-case = "3.2.1" test-log = "0.2.14" @@ -173,6 +177,7 @@ validator = "0.12" void = "1.0.2" zstd = "0.13.1" + [workspace.lints.rust] future-incompatible = "deny" nonstandard-style = "deny" diff --git a/crates/infra_proc_macros/Cargo.toml b/crates/infra_proc_macros/Cargo.toml new file mode 100644 index 00000000000..15e949205ab --- /dev/null +++ b/crates/infra_proc_macros/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "starknet_infra_proc_macros" +version.workspace = true +edition.workspace = true +repository.workspace = true +license.workspace = true + +[dependencies] +proc-macro2.workspace = true +quote.workspace = true +syn.workspace = true + +[lib] +proc-macro = true + +[lints] +workspace = true diff --git a/crates/infra_proc_macros/src/lib.rs b/crates/infra_proc_macros/src/lib.rs new file mode 100644 index 00000000000..030855c9fa2 --- /dev/null +++ b/crates/infra_proc_macros/src/lib.rs @@ -0,0 +1,66 @@ +use proc_macro::TokenStream; +use quote::quote; +use syn::parse::{Parse, ParseStream, Result}; +use syn::{parse_macro_input, Ident, Token}; + +struct MacroInput { + response_enum: Ident, + invocation_name: Ident, + component_client_error: Ident, + component_error: Ident, +} + +impl Parse for MacroInput { + fn parse(input: ParseStream<'_>) -> Result { + let response_enum = input.parse()?; + input.parse::()?; + let invocation_name = input.parse()?; + input.parse::()?; + let component_client_error = input.parse()?; + input.parse::()?; + let component_error = input.parse()?; + Ok(MacroInput { response_enum, invocation_name, component_client_error, component_error }) + } +} + +/// A macro for generating code that handles the received communication response. +/// Takes the following arguments: +/// * response_enum -- the response enum type +/// * invocation_name -- the request/response enum variant that was invoked +/// * component_client_error -- the component's client error type +/// * component_error -- the component's error type +/// +/// For example, the following input: +/// ``` +/// handle_response_variants!(MempoolResponse, GetTransactions, MempoolClientError, MempoolError) +/// ``` +/// +/// Results in: +/// ``` +/// match response { +/// MempoolResponse::GetTransactions(Ok(response)) => Ok(response), +/// MempoolResponse::GetTransactions(Err(response)) => { +/// Err(MempoolClientError::MempoolError(response)) +/// } +/// unexpected_response => Err(MempoolClientError::ClientError( +/// ClientError::UnexpectedResponse(format!("{unexpected_response:?}")), +/// )), +/// } +/// ``` +#[proc_macro] +pub fn handle_response_variants(input: TokenStream) -> TokenStream { + let MacroInput { response_enum, invocation_name, component_client_error, component_error } = + parse_macro_input!(input as MacroInput); + + let expanded = quote! { + match response { + #response_enum::#invocation_name(Ok(response)) => Ok(response), + #response_enum::#invocation_name(Err(response)) => { + Err(#component_client_error::#component_error(response)) + } + unexpected_response => Err(#component_client_error::ClientError(ClientError::UnexpectedResponse(format!("{unexpected_response:?}")))), + } + }; + + TokenStream::from(expanded) +} diff --git a/crates/mempool_types/Cargo.toml b/crates/mempool_types/Cargo.toml index e82e6dba984..74d2f3f8456 100644 --- a/crates/mempool_types/Cargo.toml +++ b/crates/mempool_types/Cargo.toml @@ -10,8 +10,9 @@ workspace = true [dependencies] async-trait.workspace = true -starknet_api = { path = "../starknet_api", version = "0.13.0-rc.0"} mockall.workspace = true serde = { workspace = true, feat = ["derive"] } +starknet_api = { path = "../starknet_api", version = "0.13.0-rc.0" } +starknet_infra_proc_macros = { path = "../infra_proc_macros" } starknet_mempool_infra = { path = "../mempool_infra" } thiserror.workspace = true diff --git a/crates/mempool_types/src/communication.rs b/crates/mempool_types/src/communication.rs index 0bb8c5aaa20..e93f37e0255 100644 --- a/crates/mempool_types/src/communication.rs +++ b/crates/mempool_types/src/communication.rs @@ -4,6 +4,7 @@ use async_trait::async_trait; use mockall::predicate::*; use mockall::*; use serde::{Deserialize, Serialize}; +use starknet_infra_proc_macros::handle_response_variants; use starknet_mempool_infra::component_client::{ ClientError, LocalComponentClient, @@ -88,28 +89,17 @@ impl MempoolClient for RemoteMempoolClientImpl { async fn add_tx(&self, mempool_input: MempoolInput) -> MempoolClientResult<()> { let request = MempoolRequest::AddTransaction(mempool_input); let response = self.send(request).await?; - match response { - MempoolResponse::AddTransaction(Ok(response)) => Ok(response), - MempoolResponse::AddTransaction(Err(response)) => { - Err(MempoolClientError::MempoolError(response)) - } - unexpected_response => Err(MempoolClientError::ClientError( - ClientError::UnexpectedResponse(format!("{unexpected_response:?}")), - )), - } + handle_response_variants!(MempoolResponse, AddTransaction, MempoolClientError, MempoolError) } async fn get_txs(&self, n_txs: usize) -> MempoolClientResult> { let request = MempoolRequest::GetTransactions(n_txs); let response = self.send(request).await?; - match response { - MempoolResponse::GetTransactions(Ok(response)) => Ok(response), - MempoolResponse::GetTransactions(Err(response)) => { - Err(MempoolClientError::MempoolError(response)) - } - unexpected_response => Err(MempoolClientError::ClientError( - ClientError::UnexpectedResponse(format!("{unexpected_response:?}")), - )), - } + handle_response_variants!( + MempoolResponse, + GetTransactions, + MempoolClientError, + MempoolError + ) } }