Skip to content

Commit

Permalink
refactor(katana): fix feeder gateway types (dojoengine#2760)
Browse files Browse the repository at this point in the history
  • Loading branch information
kariy authored and augustin-v committed Dec 4, 2024
1 parent e7c98c5 commit b69273e
Show file tree
Hide file tree
Showing 4 changed files with 194 additions and 68 deletions.
58 changes: 45 additions & 13 deletions crates/katana/feeder-gateway/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
use katana_primitives::block::{BlockIdOrTag, BlockTag};
use katana_primitives::class::CasmContractClass;
use katana_primitives::Felt;
use reqwest::header::{HeaderMap, HeaderValue};
use reqwest::{Client, StatusCode};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use url::Url;

use crate::types::{ContractClass, StateUpdate, StateUpdateWithBlock};
use crate::types::{Block, ContractClass, StateUpdate, StateUpdateWithBlock};

/// HTTP request header for the feeder gateway API key. This allow bypassing the rate limiting.
const X_THROTTLING_BYPASS: &str = "X-Throttling-Bypass";

#[derive(Debug, thiserror::Error)]
pub enum Error {
Expand All @@ -16,15 +20,22 @@ pub enum Error {
#[error(transparent)]
Sequencer(SequencerError),

#[error("Request rate limited")]
#[error("failed to parse header value '{value}'")]
InvalidHeaderValue { value: String },

#[error("request rate limited")]
RateLimited,
}

/// Client for interacting with the Starknet's feeder gateway.
#[derive(Debug, Clone)]
pub struct SequencerGateway {
/// The feeder gateway base URL.
base_url: Url,
client: Client,
/// The HTTP client used to send the requests.
http_client: Client,
/// The API key used to bypass the rate limiting of the feeder gateway.
api_key: Option<String>,
}

impl SequencerGateway {
Expand All @@ -44,8 +55,19 @@ impl SequencerGateway {

/// Creates a new gateway client at the given base URL.
pub fn new(base_url: Url) -> Self {
let api_key = None;
let client = Client::new();
Self { client, base_url }
Self { http_client: client, base_url, api_key }
}

/// Sets the API key.
pub fn with_api_key(mut self, api_key: String) -> Self {
self.api_key = Some(api_key);
self
}

pub async fn get_block(&self, block_id: BlockIdOrTag) -> Result<Block, Error> {
self.feeder_gateway("get_block").with_block_id(block_id).send().await
}

pub async fn get_state_update(&self, block_id: BlockIdOrTag) -> Result<StateUpdate, Error> {
Expand Down Expand Up @@ -90,7 +112,7 @@ impl SequencerGateway {
fn feeder_gateway(&self, method: &str) -> RequestBuilder<'_> {
let mut url = self.base_url.clone();
url.path_segments_mut().expect("invalid base url").extend(["feeder_gateway", method]);
RequestBuilder { client: &self.client, url }
RequestBuilder { gateway_client: self, url }
}
}

Expand All @@ -103,7 +125,7 @@ enum Response<T> {

#[derive(Debug, Clone)]
struct RequestBuilder<'a> {
client: &'a Client,
gateway_client: &'a SequencerGateway,
url: Url,
}

Expand All @@ -124,7 +146,17 @@ impl<'a> RequestBuilder<'a> {
}

async fn send<T: DeserializeOwned>(self) -> Result<T, Error> {
let response = self.client.get(self.url).send().await?;
let mut headers = HeaderMap::new();

if let Some(key) = self.gateway_client.api_key.as_ref() {
let value = HeaderValue::from_str(key)
.map_err(|_| Error::InvalidHeaderValue { value: key.to_string() })?;
headers.insert(X_THROTTLING_BYPASS, value);
}

let response =
self.gateway_client.http_client.get(self.url).headers(headers).send().await?;

if response.status() == StatusCode::TOO_MANY_REQUESTS {
Err(Error::RateLimited)
} else {
Expand Down Expand Up @@ -186,9 +218,9 @@ mod tests {

#[test]
fn request_block_id() {
let client = Client::new();
let base_url = Url::parse("https://example.com/").unwrap();
let req = RequestBuilder { client: &client, url: base_url };
let client = SequencerGateway::new(base_url);
let req = client.feeder_gateway("test");

// Test pending block
let pending_url = req.clone().with_block_id(BlockIdOrTag::Tag(BlockTag::Pending)).url;
Expand All @@ -210,9 +242,9 @@ mod tests {

#[test]
fn multiple_query_params() {
let client = Client::new();
let base_url = Url::parse("https://example.com/").unwrap();
let req = RequestBuilder { client: &client, url: base_url };
let client = SequencerGateway::new(base_url);
let req = client.feeder_gateway("test");

let url = req
.add_query_param("param1", "value1")
Expand All @@ -229,9 +261,9 @@ mod tests {
#[test]
#[ignore]
fn request_block_id_overwrite() {
let client = Client::new();
let base_url = Url::parse("https://example.com/").unwrap();
let req = RequestBuilder { client: &client, url: base_url };
let client = SequencerGateway::new(base_url);
let req = client.feeder_gateway("test");

let url = req
.clone()
Expand Down
1 change: 1 addition & 0 deletions crates/katana/feeder-gateway/src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use starknet::core::types::ResourcePrice;
use starknet::providers::sequencer::models::BlockStatus;

mod receipt;
mod serde_utils;
mod transaction;

pub use receipt::*;
Expand Down
106 changes: 106 additions & 0 deletions crates/katana/feeder-gateway/src/types/serde_utils.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
use serde::de::Visitor;
use serde::Deserialize;

pub fn deserialize_u64<'de, D>(deserializer: D) -> Result<u64, D::Error>
where
D: serde::Deserializer<'de>,
{
struct U64HexVisitor;

impl<'de> Visitor<'de> for U64HexVisitor {
type Value = u64;

fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(formatter, "0x-prefix hex string or decimal number")
}

fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
if let Some(hex) = v.strip_prefix("0x") {
u64::from_str_radix(hex, 16).map_err(serde::de::Error::custom)
} else {
v.parse::<u64>().map_err(serde::de::Error::custom)
}
}
}

deserializer.deserialize_any(U64HexVisitor)
}

pub fn deserialize_u128<'de, D>(deserializer: D) -> Result<u128, D::Error>
where
D: serde::Deserializer<'de>,
{
struct U128HexVisitor;

impl<'de> Visitor<'de> for U128HexVisitor {
type Value = u128;

fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(formatter, "0x-prefix hex string or decimal number")
}

fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
if let Some(hex) = v.strip_prefix("0x") {
u128::from_str_radix(hex, 16).map_err(serde::de::Error::custom)
} else {
v.parse::<u128>().map_err(serde::de::Error::custom)
}
}
}

deserializer.deserialize_any(U128HexVisitor)
}

pub fn deserialize_optional_u64<'de, D>(deserializer: D) -> Result<Option<u64>, D::Error>
where
D: serde::Deserializer<'de>,
{
#[derive(Deserialize)]
#[serde(untagged)]
enum StringOrNum {
String(String),
Number(u64),
}

match Option::<StringOrNum>::deserialize(deserializer)? {
None => Ok(None),
Some(StringOrNum::Number(n)) => Ok(Some(n)),
Some(StringOrNum::String(s)) => {
if let Some(hex) = s.strip_prefix("0x") {
u64::from_str_radix(hex, 16).map(Some).map_err(serde::de::Error::custom)
} else {
s.parse().map(Some).map_err(serde::de::Error::custom)
}
}
}
}

pub fn deserialize_optional_u128<'de, D>(deserializer: D) -> Result<Option<u128>, D::Error>
where
D: serde::Deserializer<'de>,
{
#[derive(Deserialize)]
#[serde(untagged)]
enum StringOrNum {
String(String),
Number(u128),
}

match Option::<StringOrNum>::deserialize(deserializer)? {
None => Ok(None),
Some(StringOrNum::Number(n)) => Ok(Some(n)),
Some(StringOrNum::String(s)) => {
if let Some(hex) = s.strip_prefix("0x") {
u128::from_str_radix(hex, 16).map(Some).map_err(serde::de::Error::custom)
} else {
s.parse().map(Some).map_err(serde::de::Error::custom)
}
}
}
}
Loading

0 comments on commit b69273e

Please sign in to comment.