Skip to content

Commit

Permalink
feat: add infra proc macro for matching responses
Browse files Browse the repository at this point in the history
commit-id:65ba4744
  • Loading branch information
Itay-Tsabary-Starkware committed Jul 29, 2024
1 parent c60f856 commit ef5836a
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 19 deletions.
10 changes: 10 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ members = [
"crates/committer",
"crates/committer_cli",
"crates/gateway",
"crates/infra_proc_macros",
"crates/mempool",
"crates/mempool_infra",
"crates/mempool_node",
Expand Down Expand Up @@ -125,12 +126,14 @@ paste = "1.0.15"
phf = { version = "0.11", features = ["macros"] }
pretty_assertions = "1.4.0"
primitive-types = "0.12.1"
proc-macro2 = "1.0"
prometheus-parse = "0.2.4"
prost = "0.12.1"
prost-build = "0.12.1"
prost-types = "0.12.1"
pyo3 = "0.19.1"
pyo3-log = "0.8.1"
quote = "1.0"
rand = "0.8.5"
rand_chacha = "0.3.1"
rand_distr = "0.4.3"
Expand All @@ -155,6 +158,7 @@ static_assertions = "1.1.0"
statistical = "1.0.0"
strum = "0.25.0"
strum_macros = "0.25.2"
syn = "1.0"
tempfile = "3.7.0"
test-case = "3.2.1"
test-log = "0.2.14"
Expand All @@ -173,6 +177,7 @@ validator = "0.12"
void = "1.0.2"
zstd = "0.13.1"


[workspace.lints.rust]
future-incompatible = "deny"
nonstandard-style = "deny"
Expand Down
17 changes: 17 additions & 0 deletions crates/infra_proc_macros/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
[package]
name = "starknet_infra_proc_macros"
version.workspace = true
edition.workspace = true
repository.workspace = true
license.workspace = true

[dependencies]
proc-macro2.workspace = true
quote.workspace = true
syn.workspace = true

[lib]
proc-macro = true

[lints]
workspace = true
66 changes: 66 additions & 0 deletions crates/infra_proc_macros/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
use proc_macro::TokenStream;
use quote::quote;
use syn::parse::{Parse, ParseStream, Result};
use syn::{parse_macro_input, Ident, Token};

struct MacroInput {
response_enum: Ident,
invocation_name: Ident,
component_client_error: Ident,
component_error: Ident,
}

impl Parse for MacroInput {
fn parse(input: ParseStream<'_>) -> Result<Self> {
let response_enum = input.parse()?;
input.parse::<Token![,]>()?;
let invocation_name = input.parse()?;
input.parse::<Token![,]>()?;
let component_client_error = input.parse()?;
input.parse::<Token![,]>()?;
let component_error = input.parse()?;
Ok(MacroInput { response_enum, invocation_name, component_client_error, component_error })
}
}

/// A macro for generating code that handles the received communication response.
/// Takes the following arguments:
/// * response_enum -- the response enum type
/// * invocation_name -- the request/response enum variant that was invoked
/// * component_client_error -- the component's client error type
/// * component_error -- the component's error type
///
/// For example, the following input:
/// ```
/// handle_response_variants!(MempoolResponse, GetTransactions, MempoolClientError, MempoolError)
/// ```
///
/// Results in:
/// ```
/// match response {
/// MempoolResponse::GetTransactions(Ok(response)) => Ok(response),
/// MempoolResponse::GetTransactions(Err(response)) => {
/// Err(MempoolClientError::MempoolError(response))
/// }
/// unexpected_response => Err(MempoolClientError::ClientError(
/// ClientError::UnexpectedResponse(format!("{unexpected_response:?}")),
/// )),
/// }
/// ```
#[proc_macro]
pub fn handle_response_variants(input: TokenStream) -> TokenStream {
let MacroInput { response_enum, invocation_name, component_client_error, component_error } =
parse_macro_input!(input as MacroInput);

let expanded = quote! {
match response {
#response_enum::#invocation_name(Ok(response)) => Ok(response),
#response_enum::#invocation_name(Err(response)) => {
Err(#component_client_error::#component_error(response))
}
unexpected_response => Err(#component_client_error::ClientError(ClientError::UnexpectedResponse(format!("{unexpected_response:?}")))),
}
};

TokenStream::from(expanded)
}
3 changes: 2 additions & 1 deletion crates/mempool_types/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@ workspace = true

[dependencies]
async-trait.workspace = true
starknet_api = { path = "../starknet_api", version = "0.13.0-rc.0"}
mockall.workspace = true
serde = { workspace = true, feat = ["derive"] }
starknet_api = { path = "../starknet_api", version = "0.13.0-rc.0" }
starknet_infra_proc_macros = { path = "../infra_proc_macros" }
starknet_mempool_infra = { path = "../mempool_infra" }
thiserror.workspace = true
26 changes: 8 additions & 18 deletions crates/mempool_types/src/communication.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use async_trait::async_trait;
use mockall::predicate::*;
use mockall::*;
use serde::{Deserialize, Serialize};
use starknet_infra_proc_macros::handle_response_variants;
use starknet_mempool_infra::component_client::{
ClientError,
LocalComponentClient,
Expand Down Expand Up @@ -88,28 +89,17 @@ impl MempoolClient for RemoteMempoolClientImpl {
async fn add_tx(&self, mempool_input: MempoolInput) -> MempoolClientResult<()> {
let request = MempoolRequest::AddTransaction(mempool_input);
let response = self.send(request).await?;
match response {
MempoolResponse::AddTransaction(Ok(response)) => Ok(response),
MempoolResponse::AddTransaction(Err(response)) => {
Err(MempoolClientError::MempoolError(response))
}
unexpected_response => Err(MempoolClientError::ClientError(
ClientError::UnexpectedResponse(format!("{unexpected_response:?}")),
)),
}
handle_response_variants!(MempoolResponse, AddTransaction, MempoolClientError, MempoolError)
}

async fn get_txs(&self, n_txs: usize) -> MempoolClientResult<Vec<ThinTransaction>> {
let request = MempoolRequest::GetTransactions(n_txs);
let response = self.send(request).await?;
match response {
MempoolResponse::GetTransactions(Ok(response)) => Ok(response),
MempoolResponse::GetTransactions(Err(response)) => {
Err(MempoolClientError::MempoolError(response))
}
unexpected_response => Err(MempoolClientError::ClientError(
ClientError::UnexpectedResponse(format!("{unexpected_response:?}")),
)),
}
handle_response_variants!(
MempoolResponse,
GetTransactions,
MempoolClientError,
MempoolError
)
}
}

0 comments on commit ef5836a

Please sign in to comment.