From 254491b3750f625ada31619e70f348d39a99c201 Mon Sep 17 00:00:00 2001 From: Itay Tsabary Date: Mon, 29 Jul 2024 13:27:12 +0300 Subject: [PATCH] feat: add infra proc macro for matching responses commit-id:65ba4744 --- Cargo.lock | 9 ++++ Cargo.toml | 4 ++ crates/infra_proc_macros/Cargo.toml | 16 ++++++ crates/infra_proc_macros/src/lib.rs | 66 +++++++++++++++++++++++ crates/mempool_types/Cargo.toml | 3 +- crates/mempool_types/src/communication.rs | 51 ++++++------------ 6 files changed, 112 insertions(+), 37 deletions(-) create mode 100644 crates/infra_proc_macros/Cargo.toml create mode 100644 crates/infra_proc_macros/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 586494fda0..f2432c5e36 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8968,6 +8968,14 @@ dependencies = [ "validator", ] +[[package]] +name = "starknet_infra_proc_macros" +version = "0.0.0" +dependencies = [ + "quote", + "syn 1.0.109", +] + [[package]] name = "starknet_mempool" version = "0.0.0" @@ -9065,6 +9073,7 @@ dependencies = [ "mockall", "serde", "starknet_api", + "starknet_infra_proc_macros", "starknet_mempool_infra", "thiserror", ] diff --git a/Cargo.toml b/Cargo.toml index 0a545b6278..04e33fe1b9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ members = [ "crates/committer", "crates/committer_cli", "crates/gateway", + "crates/infra_proc_macros", "crates/mempool", "crates/mempool_infra", "crates/mempool_node", @@ -131,6 +132,7 @@ prost-build = "0.12.1" prost-types = "0.12.1" pyo3 = "0.19.1" pyo3-log = "0.8.1" +quote = "1.0" rand = "0.8.5" rand_chacha = "0.3.1" rand_distr = "0.4.3" @@ -155,6 +157,7 @@ static_assertions = "1.1.0" statistical = "1.0.0" strum = "0.25.0" strum_macros = "0.25.2" +syn = "1.0" tempfile = "3.7.0" test-case = "3.2.1" test-log = "0.2.14" @@ -173,6 +176,7 @@ validator = "0.12" void = "1.0.2" zstd = "0.13.1" + [workspace.lints.rust] future-incompatible = "deny" nonstandard-style = "deny" diff --git a/crates/infra_proc_macros/Cargo.toml b/crates/infra_proc_macros/Cargo.toml new file mode 100644 index 0000000000..46a01f8783 --- /dev/null +++ b/crates/infra_proc_macros/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "starknet_infra_proc_macros" +version.workspace = true +edition.workspace = true +repository.workspace = true +license.workspace = true + +[dependencies] +quote.workspace = true +syn.workspace = true + +[lib] +proc-macro = true + +[lints] +workspace = true diff --git a/crates/infra_proc_macros/src/lib.rs b/crates/infra_proc_macros/src/lib.rs new file mode 100644 index 0000000000..64be564e08 --- /dev/null +++ b/crates/infra_proc_macros/src/lib.rs @@ -0,0 +1,66 @@ +use proc_macro::TokenStream; +use quote::quote; +use syn::parse::{Parse, ParseStream, Result}; +use syn::{parse_macro_input, Ident, Token}; + +struct MacroInput { + response_enum: Ident, + invocation_name: Ident, + component_client_error: Ident, + component_error: Ident, +} + +impl Parse for MacroInput { + fn parse(input: ParseStream<'_>) -> Result { + let response_enum = input.parse()?; + input.parse::()?; + let invocation_name = input.parse()?; + input.parse::()?; + let component_client_error = input.parse()?; + input.parse::()?; + let component_error = input.parse()?; + Ok(MacroInput { response_enum, invocation_name, component_client_error, component_error }) + } +} + +/// A macro for generating code that handles the received communication response. +/// Takes the following arguments: +/// * response_enum -- the response enum type +/// * invocation_name -- the request/response enum variant that was invoked +/// * component_client_error -- the component's client error type +/// * component_error -- the component's error type +/// +/// For example, the following input: +/// """ +/// handle_response_variants!(MempoolResponse, GetTransactions, MempoolClientError, MempoolError) +/// """ +/// +/// Results in: +/// """ +/// match response { +/// MempoolResponse::GetTransactions(Ok(response)) => Ok(response), +/// MempoolResponse::GetTransactions(Err(response)) => { +/// Err(MempoolClientError::MempoolError(response)) +/// } +/// unexpected_response => Err(MempoolClientError::ClientError( +/// ClientError::UnexpectedResponse(format!("{unexpected_response:?}")), +/// )), +/// } +/// """ +#[proc_macro] +pub fn handle_response_variants(input: TokenStream) -> TokenStream { + let MacroInput { response_enum, invocation_name, component_client_error, component_error } = + parse_macro_input!(input as MacroInput); + + let expanded = quote! { + match response { + #response_enum::#invocation_name(Ok(response)) => Ok(response), + #response_enum::#invocation_name(Err(response)) => { + Err(#component_client_error::#component_error(response)) + } + unexpected_response => Err(#component_client_error::ClientError(ClientError::UnexpectedResponse(format!("{unexpected_response:?}")))), + } + }; + + TokenStream::from(expanded) +} diff --git a/crates/mempool_types/Cargo.toml b/crates/mempool_types/Cargo.toml index e82e6dba98..74d2f3f845 100644 --- a/crates/mempool_types/Cargo.toml +++ b/crates/mempool_types/Cargo.toml @@ -10,8 +10,9 @@ workspace = true [dependencies] async-trait.workspace = true -starknet_api = { path = "../starknet_api", version = "0.13.0-rc.0"} mockall.workspace = true serde = { workspace = true, feat = ["derive"] } +starknet_api = { path = "../starknet_api", version = "0.13.0-rc.0" } +starknet_infra_proc_macros = { path = "../infra_proc_macros" } starknet_mempool_infra = { path = "../mempool_infra" } thiserror.workspace = true diff --git a/crates/mempool_types/src/communication.rs b/crates/mempool_types/src/communication.rs index 0bb8c5aaa2..de9bc04d7d 100644 --- a/crates/mempool_types/src/communication.rs +++ b/crates/mempool_types/src/communication.rs @@ -4,6 +4,7 @@ use async_trait::async_trait; use mockall::predicate::*; use mockall::*; use serde::{Deserialize, Serialize}; +use starknet_infra_proc_macros::handle_response_variants; use starknet_mempool_infra::component_client::{ ClientError, LocalComponentClient, @@ -57,29 +58,18 @@ impl MempoolClient for MempoolClientImpl { async fn add_tx(&self, mempool_input: MempoolInput) -> MempoolClientResult<()> { let request = MempoolRequest::AddTransaction(mempool_input); let response = self.send(request).await; - match response { - MempoolResponse::AddTransaction(Ok(response)) => Ok(response), - MempoolResponse::AddTransaction(Err(response)) => { - Err(MempoolClientError::MempoolError(response)) - } - unexpected_response => Err(MempoolClientError::ClientError( - ClientError::UnexpectedResponse(format!("{unexpected_response:?}")), - )), - } + handle_response_variants!(MempoolResponse, AddTransaction, MempoolClientError, MempoolError) } async fn get_txs(&self, n_txs: usize) -> MempoolClientResult> { let request = MempoolRequest::GetTransactions(n_txs); let response = self.send(request).await; - match response { - MempoolResponse::GetTransactions(Ok(response)) => Ok(response), - MempoolResponse::GetTransactions(Err(response)) => { - Err(MempoolClientError::MempoolError(response)) - } - unexpected_response => Err(MempoolClientError::ClientError( - ClientError::UnexpectedResponse(format!("{unexpected_response:?}")), - )), - } + handle_response_variants!( + MempoolResponse, + GetTransactions, + MempoolClientError, + MempoolError + ) } } @@ -88,28 +78,17 @@ impl MempoolClient for RemoteMempoolClientImpl { async fn add_tx(&self, mempool_input: MempoolInput) -> MempoolClientResult<()> { let request = MempoolRequest::AddTransaction(mempool_input); let response = self.send(request).await?; - match response { - MempoolResponse::AddTransaction(Ok(response)) => Ok(response), - MempoolResponse::AddTransaction(Err(response)) => { - Err(MempoolClientError::MempoolError(response)) - } - unexpected_response => Err(MempoolClientError::ClientError( - ClientError::UnexpectedResponse(format!("{unexpected_response:?}")), - )), - } + handle_response_variants!(MempoolResponse, AddTransaction, MempoolClientError, MempoolError) } async fn get_txs(&self, n_txs: usize) -> MempoolClientResult> { let request = MempoolRequest::GetTransactions(n_txs); let response = self.send(request).await?; - match response { - MempoolResponse::GetTransactions(Ok(response)) => Ok(response), - MempoolResponse::GetTransactions(Err(response)) => { - Err(MempoolClientError::MempoolError(response)) - } - unexpected_response => Err(MempoolClientError::ClientError( - ClientError::UnexpectedResponse(format!("{unexpected_response:?}")), - )), - } + handle_response_variants!( + MempoolResponse, + GetTransactions, + MempoolClientError, + MempoolError + ) } }