diff --git a/crates/blockifier/src/test_utils/contracts.rs b/crates/blockifier/src/test_utils/contracts.rs index d41d2a7f69..1f6d7bb603 100644 --- a/crates/blockifier/src/test_utils/contracts.rs +++ b/crates/blockifier/src/test_utils/contracts.rs @@ -1,6 +1,7 @@ use std::collections::HashMap; use cairo_lang_starknet_classes::casm_contract_class::CasmContractClass; +use cairo_lang_starknet_classes::contract_class::ContractClass as CairoLangContractClass; use itertools::Itertools; use starknet_api::abi::abi_utils::selector_from_name; use starknet_api::abi::constants::CONSTRUCTOR_ENTRY_POINT_NAME; @@ -10,6 +11,7 @@ use starknet_api::deprecated_contract_class::{ ContractClass as DeprecatedContractClass, EntryPointOffset, }; +use starknet_api::state::SierraContractClass; use starknet_api::{class_hash, contract_address, felt}; use starknet_types_core::felt::Felt; use strum::IntoEnumIterator; @@ -205,6 +207,21 @@ impl FeatureContract { self.get_class().try_into().unwrap() } + pub fn get_raw_sierra(&self) -> String { + if self.cairo_version() == CairoVersion::Cairo0 { + panic!("The sierra contract is only available for Cairo1."); + } + + get_raw_contract_class(&self.get_sierra_path()) + } + + pub fn get_sierra(&self) -> SierraContractClass { + let raw_sierra = self.get_raw_sierra(); + let cairo_contract_class: CairoLangContractClass = + serde_json::from_str(&raw_sierra).unwrap(); + SierraContractClass::from(cairo_contract_class) + } + pub fn get_raw_class(&self) -> String { get_raw_contract_class(&self.get_compiled_path()) } diff --git a/crates/starknet_api/src/rpc_transaction.rs b/crates/starknet_api/src/rpc_transaction.rs index 49295a230d..ea79d2652c 100644 --- a/crates/starknet_api/src/rpc_transaction.rs +++ b/crates/starknet_api/src/rpc_transaction.rs @@ -4,6 +4,7 @@ mod rpc_transaction_test; use std::collections::HashMap; +use cairo_lang_starknet_classes::contract_class::ContractEntryPoints as CairoLangContractEntryPoints; use serde::{Deserialize, Serialize}; use crate::contract_class::EntryPointType; @@ -282,6 +283,18 @@ pub struct EntryPointByType { pub l1handler: Vec, } +// TODO(AVIV): Consider removing this conversion and using the one in the +// CairoLangContractEntryPoints +impl From for EntryPointByType { + fn from(value: CairoLangContractEntryPoints) -> Self { + Self { + constructor: value.constructor.into_iter().map(EntryPoint::from).collect(), + external: value.external.into_iter().map(EntryPoint::from).collect(), + l1handler: value.l1_handler.into_iter().map(EntryPoint::from).collect(), + } + } +} + impl EntryPointByType { pub fn from_hash_map(entry_points_by_type: HashMap>) -> Self { macro_rules! get_entrypoint_by_type { diff --git a/crates/starknet_api/src/state.rs b/crates/starknet_api/src/state.rs index ada375f13f..46ac61a0b9 100644 --- a/crates/starknet_api/src/state.rs +++ b/crates/starknet_api/src/state.rs @@ -4,6 +4,10 @@ mod state_test; use std::fmt::Debug; +use cairo_lang_starknet_classes::contract_class::{ + ContractClass as CairoLangContractClass, + ContractEntryPoint as CairoLangContractEntryPoint, +}; use indexmap::IndexMap; use serde::{Deserialize, Serialize}; use starknet_types_core::felt::Felt; @@ -228,6 +232,26 @@ impl Default for SierraContractClass { } } +impl From for SierraContractClass { + fn from(cairo_lang_contract_class: CairoLangContractClass) -> Self { + Self { + sierra_program: cairo_lang_contract_class + .sierra_program + .into_iter() + .map(|big_uint_as_hex| Felt::from(big_uint_as_hex.value)) + .collect(), + contract_class_version: cairo_lang_contract_class.contract_class_version, + entry_points_by_type: cairo_lang_contract_class.entry_points_by_type.into(), + abi: { + match cairo_lang_contract_class.abi { + Some(abi) => abi.json(), + None => Default::default(), + } + }, + } + } +} + /// An entry point of a [ContractClass](`crate::state::ContractClass`). #[derive(Debug, Default, Clone, Eq, PartialEq, Hash, Deserialize, Serialize, PartialOrd, Ord)] pub struct EntryPoint { @@ -235,6 +259,15 @@ pub struct EntryPoint { pub selector: EntryPointSelector, } +impl From for EntryPoint { + fn from(entry_point: CairoLangContractEntryPoint) -> Self { + Self { + function_idx: FunctionIndex(entry_point.function_idx), + selector: EntryPointSelector(entry_point.selector.into()), + } + } +} + #[derive( Debug, Copy, Clone, Default, Eq, PartialEq, Hash, Deserialize, Serialize, PartialOrd, Ord, )]