From 55a6aae8a001b8569d29f9e3f5d06dcc283b1777 Mon Sep 17 00:00:00 2001 From: Arni Hod Date: Mon, 7 Oct 2024 21:30:32 +0300 Subject: [PATCH] refactor: streamline class info object to match contract class --- .../src/execution/contract_class.rs | 101 +++++++++--------- .../blockifier/src/transaction/test_utils.rs | 7 +- crates/gateway/src/compilation.rs | 6 +- .../native_blockifier/src/py_transaction.rs | 10 +- crates/papyrus_execution/src/lib.rs | 16 +-- crates/starknet_api/src/contract_class.rs | 28 ++++- .../src/executable_transaction.rs | 3 +- 7 files changed, 96 insertions(+), 75 deletions(-) diff --git a/crates/blockifier/src/execution/contract_class.rs b/crates/blockifier/src/execution/contract_class.rs index eab04d9fe7..774ddb163f 100644 --- a/crates/blockifier/src/execution/contract_class.rs +++ b/crates/blockifier/src/execution/contract_class.rs @@ -542,41 +542,79 @@ fn convert_entry_points_v1(external: &[CasmContractEntryPoint]) -> Vec for ClassInfo { type Error = ProgramError; fn try_from(class_info: starknet_api::contract_class::ClassInfo) -> Result { - let starknet_api::contract_class::ClassInfo { - contract_class, - sierra_program_length, - abi_length, - } = class_info; - - Ok(Self { contract_class: contract_class.try_into()?, sierra_program_length, abi_length }) + match class_info { + starknet_api::contract_class::ClassInfo::V0 { contract_class, abi_length } => { + Ok(Self::V0 { contract_class: contract_class.try_into()?, abi_length }) + } + starknet_api::contract_class::ClassInfo::V1 { + contract_class, + sierra_program_length, + abi_length, + } => Ok(Self::V1 { + contract_class: contract_class.try_into()?, + sierra_program_length, + abi_length, + }), + } } } impl ClassInfo { pub fn bytecode_length(&self) -> usize { - self.contract_class.bytecode_length() + match self { + ClassInfo::V0 { contract_class, .. } => contract_class.bytecode_length(), + ClassInfo::V1 { contract_class, .. } => contract_class.bytecode_length(), + ClassInfo::NativeV1 { contract_class: _contract_class, .. } => { + unimplemented!("implement bytecode_length for native contracts.") + } + } } pub fn contract_class(&self) -> ContractClass { - self.contract_class.clone() + match self { + ClassInfo::V0 { contract_class, .. } => ContractClass::V0(contract_class.clone()), + ClassInfo::V1 { contract_class, .. } => ContractClass::V1(contract_class.clone()), + ClassInfo::NativeV1 { contract_class, .. } => { + ContractClass::V1Native(contract_class.clone()) + } + } } pub fn sierra_program_length(&self) -> usize { - self.sierra_program_length + match self { + ClassInfo::V0 { .. } => 0, + ClassInfo::V1 { sierra_program_length, .. } => *sierra_program_length, + ClassInfo::NativeV1 { sierra_program_length, .. } => *sierra_program_length, + } } pub fn abi_length(&self) -> usize { - self.abi_length + match self { + ClassInfo::V0 { abi_length, .. } => *abi_length, + ClassInfo::V1 { abi_length, .. } => *abi_length, + ClassInfo::NativeV1 { abi_length, .. } => *abi_length, + } } pub fn code_size(&self) -> usize { @@ -585,39 +623,6 @@ impl ClassInfo { * eth_gas_constants::WORD_WIDTH + self.abi_length() } - - pub fn new_v1_native( - contract_class: NativeContractClassV1, - sierra_program_length: usize, - abi_length: usize, - ) -> Self { - Self { - contract_class: ContractClass::V1Native(contract_class), - sierra_program_length, - abi_length, - } - } - - pub fn new_v1( - contract_class: ContractClassV1, - sierra_program_length: usize, - abi_length: usize, - ) -> Self { - assert!(sierra_program_length > 0, "Sierra program length must be > 0 for Cairo1"); - Self { - contract_class: ContractClass::V1(contract_class), - sierra_program_length, - abi_length, - } - } - - pub fn new_v0(contract_class: ContractClassV0, abi_length: usize) -> Self { - Self { - contract_class: ContractClass::V0(contract_class), - sierra_program_length: 0, - abi_length, - } - } } // Cairo-native utilities. diff --git a/crates/blockifier/src/transaction/test_utils.rs b/crates/blockifier/src/transaction/test_utils.rs index a1594625c3..d7ae3e3e49 100644 --- a/crates/blockifier/src/transaction/test_utils.rs +++ b/crates/blockifier/src/transaction/test_utils.rs @@ -376,14 +376,13 @@ fn create_all_resource_bounds( pub fn calculate_class_info_for_testing(contract_class: ContractClass) -> ClassInfo { let abi_length = 100; match contract_class { - ContractClass::V0(contract_class) => ClassInfo::new_v0(contract_class, abi_length), + ContractClass::V0(contract_class) => ClassInfo::V0 { contract_class, abi_length }, ContractClass::V1(contract_class) => { - let sierra_program_length = 100; - ClassInfo::new_v1(contract_class, sierra_program_length, abi_length) + ClassInfo::V1 { contract_class, sierra_program_length: 100, abi_length } } ContractClass::V1Native(contract_class) => { let sierra_program_length = 100; - ClassInfo::new_v1_native(contract_class, sierra_program_length, abi_length) + ClassInfo::NativeV1 { contract_class, sierra_program_length, abi_length } } } } diff --git a/crates/gateway/src/compilation.rs b/crates/gateway/src/compilation.rs index 647704c8cd..c4403be7ed 100644 --- a/crates/gateway/src/compilation.rs +++ b/crates/gateway/src/compilation.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use cairo_lang_starknet_classes::casm_contract_class::CasmContractClass; use cairo_lang_starknet_classes::contract_class::ContractClass as CairoLangContractClass; -use starknet_api::contract_class::{ClassInfo, ContractClass}; +use starknet_api::contract_class::ClassInfo; use starknet_api::rpc_transaction::RpcDeclareTransaction; use starknet_gateway_types::errors::GatewaySpecError; use starknet_sierra_compile::cairo_lang_compiler::CairoLangSierraToCasmCompiler; @@ -47,8 +47,8 @@ impl GatewayCompiler { let casm_contract_class = self.compile(cairo_lang_contract_class)?; - Ok(ClassInfo { - contract_class: ContractClass::V1(casm_contract_class), + Ok(ClassInfo::V1 { + contract_class: casm_contract_class, sierra_program_length: rpc_contract_class.sierra_program.len(), abi_length: rpc_contract_class.abi.len(), }) diff --git a/crates/native_blockifier/src/py_transaction.rs b/crates/native_blockifier/src/py_transaction.rs index 4389bb5ae6..26a7024271 100644 --- a/crates/native_blockifier/src/py_transaction.rs +++ b/crates/native_blockifier/src/py_transaction.rs @@ -170,18 +170,18 @@ impl PyClassInfo { ContractClassV0::try_from_json_string(&py_class_info.raw_contract_class)?; assert_eq!(py_class_info.sierra_program_length, 0); - ClassInfo::new_v0(contract_class, py_class_info.abi_length) + ClassInfo::V0 { contract_class, abi_length: py_class_info.abi_length } } starknet_api::transaction::DeclareTransaction::V2(_) | starknet_api::transaction::DeclareTransaction::V3(_) => { let contract_class = ContractClassV1::try_from_json_string(&py_class_info.raw_contract_class)?; - ClassInfo::new_v1( + ClassInfo::V1 { contract_class, - py_class_info.sierra_program_length, - py_class_info.abi_length, - ) + sierra_program_length: py_class_info.sierra_program_length, + abi_length: py_class_info.abi_length, + } } }; Ok(class_info) diff --git a/crates/papyrus_execution/src/lib.rs b/crates/papyrus_execution/src/lib.rs index 3fab20956e..01df3f8982 100644 --- a/crates/papyrus_execution/src/lib.rs +++ b/crates/papyrus_execution/src/lib.rs @@ -736,7 +736,7 @@ fn to_blockifier_tx( abi_length, only_query, ) => { - let class_v0 = deprecated_class.try_into().map_err( + let contract_class = deprecated_class.try_into().map_err( |e: cairo_vm::types::errors::program_errors::ProgramError| { ExecutionError::TransactionExecutionError { transaction_index, @@ -744,7 +744,7 @@ fn to_blockifier_tx( } }, )?; - let class_info = ClassInfo::new_v0(class_v0, abi_length); + let class_info = ClassInfo::V0 { contract_class, abi_length }; BlockifierTransaction::from_api( Transaction::Declare(DeclareTransaction::V0(declare_tx)), tx_hash, @@ -761,8 +761,8 @@ fn to_blockifier_tx( abi_length, only_query, ) => { - let class_v0 = deprecated_class.try_into().map_err(BlockifierError::new)?; - let class_info = ClassInfo::new_v0(class_v0, abi_length); + let contract_class = deprecated_class.try_into().map_err(BlockifierError::new)?; + let class_info = ClassInfo::V0 { contract_class, abi_length }; BlockifierTransaction::from_api( Transaction::Declare(DeclareTransaction::V1(declare_tx)), tx_hash, @@ -780,8 +780,8 @@ fn to_blockifier_tx( abi_length, only_query, ) => { - let class_v1 = compiled_class.try_into().map_err(BlockifierError::new)?; - let class_info = ClassInfo::new_v1(class_v1, sierra_program_length, abi_length); + let contract_class = compiled_class.try_into().map_err(BlockifierError::new)?; + let class_info = ClassInfo::V1 { contract_class, sierra_program_length, abi_length }; BlockifierTransaction::from_api( Transaction::Declare(DeclareTransaction::V2(declare_tx)), tx_hash, @@ -799,8 +799,8 @@ fn to_blockifier_tx( abi_length, only_query, ) => { - let class_v1 = compiled_class.try_into().map_err(BlockifierError::new)?; - let class_info = ClassInfo::new_v1(class_v1, sierra_program_length, abi_length); + let contract_class = compiled_class.try_into().map_err(BlockifierError::new)?; + let class_info = ClassInfo::V1 { contract_class, sierra_program_length, abi_length }; BlockifierTransaction::from_api( Transaction::Declare(DeclareTransaction::V3(declare_tx)), tx_hash, diff --git a/crates/starknet_api/src/contract_class.rs b/crates/starknet_api/src/contract_class.rs index 381d73d122..93b55c7411 100644 --- a/crates/starknet_api/src/contract_class.rs +++ b/crates/starknet_api/src/contract_class.rs @@ -24,9 +24,27 @@ impl ContractClass { /// All relevant information about a declared contract class, including the compiled contract class /// and other parameters derived from the original declare transaction required for billing. #[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)] -pub struct ClassInfo { - // TODO(Noa): Consider using Arc. - pub contract_class: ContractClass, - pub sierra_program_length: usize, - pub abi_length: usize, +pub enum ClassInfo { + V0 { + // TODO(Noa): Consider using Arc. + contract_class: DeprecatedContractClass, + abi_length: usize, + }, + V1 { + // TODO(Noa): Consider using Arc. + contract_class: CasmContractClass, + sierra_program_length: usize, + abi_length: usize, + }, +} + +impl ClassInfo { + pub fn compiled_class_hash(&self) -> CompiledClassHash { + match self { + ClassInfo::V0 { .. } => panic!("Cairo 0 doesn't have compiled class hash."), + ClassInfo::V1 { contract_class, .. } => { + CompiledClassHash(contract_class.compiled_class_hash()) + } + } + } } diff --git a/crates/starknet_api/src/executable_transaction.rs b/crates/starknet_api/src/executable_transaction.rs index d88e8d0cfe..840e8fa06f 100644 --- a/crates/starknet_api/src/executable_transaction.rs +++ b/crates/starknet_api/src/executable_transaction.rs @@ -163,8 +163,7 @@ impl DeclareTransaction { | crate::transaction::DeclareTransaction::V0(_) => return true, }; - let contract_class = &self.class_info.contract_class; - let compiled_class_hash = contract_class.compiled_class_hash(); + let compiled_class_hash = self.class_info.compiled_class_hash(); compiled_class_hash == supplied_compiled_class_hash }