From c4a60aca217481522be4975b4a5b7e8ff219d861 Mon Sep 17 00:00:00 2001 From: Tzahi Taub Date: Tue, 15 Oct 2024 17:08:26 +0300 Subject: [PATCH] test(blockifier): function to build calldata for recursive call contract calls --- .../syscalls/syscall_tests/call_contract.rs | 90 +++++++++---------- crates/blockifier/src/test_utils.rs | 28 ++++++ crates/blockifier/src/test_utils/syscall.rs | 33 +++++++ 3 files changed, 105 insertions(+), 46 deletions(-) create mode 100644 crates/blockifier/src/test_utils/syscall.rs diff --git a/crates/blockifier/src/execution/syscalls/syscall_tests/call_contract.rs b/crates/blockifier/src/execution/syscalls/syscall_tests/call_contract.rs index 9e3b573c98..c260162a23 100644 --- a/crates/blockifier/src/execution/syscalls/syscall_tests/call_contract.rs +++ b/crates/blockifier/src/execution/syscalls/syscall_tests/call_contract.rs @@ -19,7 +19,14 @@ use crate::retdata; use crate::state::state_api::StateReader; use crate::test_utils::contracts::FeatureContract; use crate::test_utils::initial_test_state::test_state; -use crate::test_utils::{create_calldata, trivial_external_entry_point_new, CairoVersion, BALANCE}; +use crate::test_utils::syscall::build_recurse_calldata; +use crate::test_utils::{ + create_calldata, + trivial_external_entry_point_new, + CairoVersion, + CompilerBasedVersion, + BALANCE, +}; #[test] fn test_call_contract_that_panics() { @@ -76,7 +83,7 @@ fn test_call_contract( inner_contract.get_instance_address(0), "test_storage_read_write", &[ - felt!(405_u16), // Calldata: address. + felt!(405_u16), // Calldata: storage address. felt!(48_u8), // Calldata: value. ], ); @@ -96,26 +103,29 @@ fn test_call_contract( ); } -/// Cairo0 / Cairo1 calls to Cairo0 / Cairo1. +/// Cairo0 / Old Cairo1 / Cairo1 calls to Cairo0 / Old Cairo1/ Cairo1. #[rstest] -fn test_track_resources( - #[values(CairoVersion::Cairo0, CairoVersion::Cairo1)] outer_version: CairoVersion, - #[values(CairoVersion::Cairo0, CairoVersion::Cairo1)] inner_version: CairoVersion, +fn test_tracked_resources( + #[values( + CompilerBasedVersion::CairoVersion(CairoVersion::Cairo0), + CompilerBasedVersion::OldCairo1, + CompilerBasedVersion::CairoVersion(CairoVersion::Cairo1) + )] + outer_version: CompilerBasedVersion, + #[values( + CompilerBasedVersion::CairoVersion(CairoVersion::Cairo0), + CompilerBasedVersion::OldCairo1, + CompilerBasedVersion::CairoVersion(CairoVersion::Cairo1) + )] + inner_version: CompilerBasedVersion, ) { - let outer_contract = FeatureContract::TestContract(outer_version); - let inner_contract = FeatureContract::TestContract(inner_version); + let outer_contract = outer_version.get_test_contract(); + let inner_contract = inner_version.get_test_contract(); let chain_info = &ChainInfo::create_for_testing(); let mut state = test_state(chain_info, BALANCE, &[(outer_contract, 1), (inner_contract, 1)]); let outer_entry_point_selector = selector_from_name("test_call_contract"); - let calldata = create_calldata( - inner_contract.get_instance_address(0), - "test_storage_read_write", - &[ - felt!(405_u16), // Calldata: address. - felt!(48_u8), // Calldata: value. - ], - ); + let calldata = build_recurse_calldata(&[inner_version]); let entry_point_call = CallEntryPoint { entry_point_selector: outer_entry_point_selector, calldata, @@ -123,15 +133,14 @@ fn test_track_resources( }; let execution = entry_point_call.execute_directly(&mut state).unwrap(); - let expected_outer_resource = match outer_version { - CairoVersion::Cairo0 => TrackedResource::CairoSteps, - CairoVersion::Cairo1 => TrackedResource::SierraGas, - }; + let expected_outer_resource = outer_version.own_tracked_resource(); assert_eq!(execution.tracked_resource, expected_outer_resource); - let expected_inner_resource = match (outer_version, inner_version) { - (CairoVersion::Cairo1, CairoVersion::Cairo1) => TrackedResource::SierraGas, - _ => TrackedResource::CairoSteps, + let expected_inner_resource = if expected_outer_resource == inner_version.own_tracked_resource() + { + expected_outer_resource + } else { + TrackedResource::CairoSteps }; assert_eq!(execution.inner_calls.first().unwrap().tracked_resource, expected_inner_resource); } @@ -140,37 +149,26 @@ fn test_track_resources( /// 1) Cairo-Steps contract that calls Sierra-Gas (nested) contract. /// 2) Sierra-Gas contract. #[rstest] -fn test_track_resources_nested( +fn test_tracked_resources_nested( #[values( - FeatureContract::TestContract(CairoVersion::Cairo0), - FeatureContract::CairoStepsTestContract + CompilerBasedVersion::CairoVersion(CairoVersion::Cairo0), + CompilerBasedVersion::OldCairo1 )] - cairo_steps_contract: FeatureContract, + cairo_steps_contract_version: CompilerBasedVersion, ) { + let cairo_steps_contract = cairo_steps_contract_version.get_test_contract(); let sierra_gas_contract = FeatureContract::TestContract(CairoVersion::Cairo1); let chain_info = &ChainInfo::create_for_testing(); let mut state = test_state(chain_info, BALANCE, &[(sierra_gas_contract, 1), (cairo_steps_contract, 1)]); - let first_calldata = create_calldata( - cairo_steps_contract.get_instance_address(0), - "test_call_contract", - &[ - sierra_gas_contract.get_instance_address(0).into(), - selector_from_name("test_storage_read_write").0, - felt!(2_u8), // Calldata length - felt!(405_u16), // Calldata: address. - felt!(48_u8), // Calldata: value. - ], - ); - let second_calldata = create_calldata( - sierra_gas_contract.get_instance_address(0), - "test_storage_read_write", - &[ - felt!(406_u16), // Calldata: address. - felt!(49_u8), // Calldata: value. - ], - ); + let first_calldata = build_recurse_calldata(&[ + cairo_steps_contract_version, + CompilerBasedVersion::CairoVersion(CairoVersion::Cairo1), + ]); + + let second_calldata = + build_recurse_calldata(&[CompilerBasedVersion::CairoVersion(CairoVersion::Cairo1)]); let concated_calldata_felts = [first_calldata.0, second_calldata.0] .into_iter() diff --git a/crates/blockifier/src/test_utils.rs b/crates/blockifier/src/test_utils.rs index 111b425dcf..5bebf91d12 100644 --- a/crates/blockifier/src/test_utils.rs +++ b/crates/blockifier/src/test_utils.rs @@ -7,6 +7,7 @@ pub mod initial_test_state; pub mod invoke; pub mod prices; pub mod struct_impls; +pub mod syscall; pub mod transfers_generator; use std::collections::HashMap; use std::fs; @@ -30,6 +31,7 @@ use starknet_types_core::felt::Felt; use crate::abi::abi_utils::{get_fee_token_var_address, selector_from_name}; use crate::execution::call_info::ExecutionSummary; +use crate::execution::contract_class::TrackedResource; use crate::execution::deprecated_syscalls::hint_processor::SyscallCounter; use crate::execution::entry_point::CallEntryPoint; use crate::execution::syscalls::SyscallSelector; @@ -90,6 +92,32 @@ impl CairoVersion { } } +#[derive(Clone, Copy, PartialEq, Eq, Debug)] +pub enum CompilerBasedVersion { + CairoVersion(CairoVersion), + OldCairo1, +} + +impl CompilerBasedVersion { + pub fn get_test_contract(&self) -> FeatureContract { + match self { + Self::CairoVersion(version) => FeatureContract::TestContract(*version), + Self::OldCairo1 => FeatureContract::CairoStepsTestContract, + } + } + + /// Returns the tracked resource for a contract execution with the current version, assuming no + /// calls were made to other contracts prior to this execution. + pub fn own_tracked_resource(&self) -> TrackedResource { + match self { + Self::CairoVersion(CairoVersion::Cairo0) | Self::OldCairo1 => { + TrackedResource::CairoSteps + } + Self::CairoVersion(CairoVersion::Cairo1) => TrackedResource::SierraGas, + } + } +} + // Storage keys. pub fn test_erc20_sequencer_balance_key() -> StorageKey { get_fee_token_var_address(contract_address!(TEST_SEQUENCER_ADDRESS)) diff --git a/crates/blockifier/src/test_utils/syscall.rs b/crates/blockifier/src/test_utils/syscall.rs new file mode 100644 index 0000000000..cbf8c16dc9 --- /dev/null +++ b/crates/blockifier/src/test_utils/syscall.rs @@ -0,0 +1,33 @@ +use starknet_api::felt; +use starknet_api::transaction::Calldata; + +use crate::test_utils::{create_calldata, CompilerBasedVersion}; + +/// Returns the calldata for N recursive call contract syscalls, where N is the length of versions. +/// versions determines the cairo version of the called contract in each recursive call. Final call +/// is a simple local contract call (test_storage_read_write). +/// The first element in the returned value is the calldata for a call from a contract of the first +/// element in versions, to the a contract of the second element, etc. +pub fn build_recurse_calldata(versions: &[CompilerBasedVersion]) -> Calldata { + if versions.is_empty() { + return Calldata(vec![].into()); + } + let last_version = versions.last().unwrap(); + let mut calldata = create_calldata( + last_version.get_test_contract().get_instance_address(0), + "test_storage_read_write", + &[ + felt!(123_u16), // Calldata: address. + felt!(45_u8), // Calldata: value. + ], + ); + + for version in versions[..versions.len() - 1].iter().rev() { + calldata = create_calldata( + version.get_test_contract().get_instance_address(0), + "test_call_contract", + &calldata.0, + ); + } + return calldata; +}