From 154cf2167b62db9c76094c9a9b1a6227b5acfa1a Mon Sep 17 00:00:00 2001 From: Dan Brownstein Date: Thu, 7 Sep 2023 23:36:12 +0300 Subject: [PATCH] fix: assert ascii chainID --- src/core.rs | 20 +++++++++++++++++--- src/serde_utils.rs | 20 +++++++++++++++++++- src/serde_utils_test.rs | 14 ++++++++++++++ 3 files changed, 50 insertions(+), 4 deletions(-) diff --git a/src/core.rs b/src/core.rs index 70fbea46..f74f8082 100644 --- a/src/core.rs +++ b/src/core.rs @@ -11,15 +11,29 @@ use serde::{Deserialize, Serialize}; use starknet_crypto::FieldElement; use crate::hash::{pedersen_hash_array, StarkFelt, StarkHash}; -use crate::serde_utils::{BytesAsHex, PrefixedBytesAsHex}; +use crate::serde_utils::{deserialize_ascii_string, BytesAsHex, PrefixedBytesAsHex}; use crate::transaction::{Calldata, ContractAddressSalt}; use crate::{impl_from_through_intermediate, StarknetApiError}; -/// A chain id. +/// A chain id. Must contain only ASCII characters. #[derive(Clone, Debug, Display, Eq, PartialEq, Hash, Deserialize, Serialize, PartialOrd, Ord)] -pub struct ChainId(pub String); +pub struct ChainId(#[serde(deserialize_with = "deserialize_ascii_string")] String); impl ChainId { + /// Returns a new [`ChainId`]. + pub fn new(chain_id: String) -> Result { + match chain_id.chars().all(|c| c.is_ascii()) { + true => Ok(Self(chain_id)), + false => Err(StarknetApiError::OutOfRange { string: chain_id }), + } + } + + /// Returns the chain id as a string. + pub fn as_str(&self) -> &str { + &self.0 + } + + /// Returns the chain id as a hex string. pub fn as_hex(&self) -> String { format!("0x{}", hex::encode(&self.0)) } diff --git a/src/serde_utils.rs b/src/serde_utils.rs index 067fddeb..8d53d887 100644 --- a/src/serde_utils.rs +++ b/src/serde_utils.rs @@ -3,9 +3,10 @@ #[path = "serde_utils_test.rs"] mod serde_utils_test; -use serde::de::{Deserialize, Visitor}; +use serde::de::{Deserialize, Error as DeserializationError, Visitor}; use serde::ser::{Serialize, SerializeTuple}; use serde::Deserializer; +use serde_json::Value; use crate::deprecated_contract_class::ContractClassAbiEntry; @@ -144,3 +145,20 @@ where Err(_) => Ok(None), } } + +pub fn deserialize_ascii_string<'de, D: Deserializer<'de>>( + deserializer: D, +) -> Result { + let chian_id: _ = match Value::deserialize(deserializer)? { + Value::String(string) => match string.chars().all(|c| c.is_ascii()) { + true => string, + false => { + return Err(DeserializationError::custom(format!( + "Chain id ({string}) must contain only ASCII characters." + ))); + } + }, + _ => return Err(DeserializationError::custom("Cannot cast value into String.")), + }; + Ok(chian_id) +} diff --git a/src/serde_utils_test.rs b/src/serde_utils_test.rs index b3d6134b..3e239bc7 100644 --- a/src/serde_utils_test.rs +++ b/src/serde_utils_test.rs @@ -175,3 +175,17 @@ fn deserialize_optional_contract_class_abi_entry_vector_none() { let res: DummyContractClass = serde_json::from_str(json).unwrap(); assert_eq!(res, DummyContractClass { abi: None }); } + +#[test] +fn chain_id_non_ascii() { + let chain_id_str = r#""חלודה""#; + let chain_id = serde_json::from_str::(chain_id_str); + assert_matches!(chain_id, Err(serde_json::Error { .. })); +} + +#[test] +fn valid_chain_id() { + let chain_id_str = r#""chain_ID""#; + let chain_id = serde_json::from_str::(chain_id_str).unwrap(); + assert_eq!(chain_id_str, serde_json::to_string(&chain_id).unwrap()) +}