diff --git a/crates/blockifier/src/execution/native/entry_point_execution.rs b/crates/blockifier/src/execution/native/entry_point_execution.rs index a8e1c405d0..9a338d593b 100644 --- a/crates/blockifier/src/execution/native/entry_point_execution.rs +++ b/crates/blockifier/src/execution/native/entry_point_execution.rs @@ -49,6 +49,7 @@ pub fn execute_entry_point_call( Some(builtin_costs), &mut syscall_handler, ); + syscall_handler.finalize(); let call_result = execution_result.map_err(EntryPointExecutionError::NativeUnexpectedError)?; diff --git a/crates/blockifier/src/execution/native/syscall_handler.rs b/crates/blockifier/src/execution/native/syscall_handler.rs index 48ab901ece..0debda771b 100644 --- a/crates/blockifier/src/execution/native/syscall_handler.rs +++ b/crates/blockifier/src/execution/native/syscall_handler.rs @@ -78,19 +78,12 @@ impl<'state> NativeSyscallHandler<'state> { entry_point: CallEntryPoint, remaining_gas: &mut u64, ) -> SyscallResult { - let call_info = entry_point - .execute(self.base.state, self.base.context, remaining_gas) - .map_err(|e| self.handle_error(remaining_gas, e.into()))?; - let retdata = call_info.execution.retdata.clone(); - - if call_info.execution.failed { - let error = SyscallExecutionError::SyscallError { error_data: retdata.0 }; - return Err(self.handle_error(remaining_gas, error)); - } - - self.base.inner_calls.push(call_info); + let raw_retdata = self + .base + .execute_inner_call(entry_point, remaining_gas) + .map_err(|e| self.handle_error(remaining_gas, e))?; - Ok(retdata) + Ok(Retdata(raw_retdata)) } pub fn gas_costs(&self) -> &GasCosts { @@ -228,6 +221,9 @@ impl<'state> NativeSyscallHandler<'state> { }), } } + pub fn finalize(&mut self) { + self.base.finalize(); + } } impl<'state> StarknetSyscallHandler for &mut NativeSyscallHandler<'state> { @@ -471,11 +467,7 @@ impl<'state> StarknetSyscallHandler for &mut NativeSyscallHandler<'state> { let key = StorageKey::try_from(address) .map_err(|e| self.handle_error(remaining_gas, e.into()))?; - self.base.accessed_keys.insert(key); - - let write_result = - self.base.state.set_storage_at(self.base.call.storage_address, key, value); - write_result.map_err(|e| self.handle_error(remaining_gas, e.into()))?; + self.base.storage_write(key, value).map_err(|e| self.handle_error(remaining_gas, e))?; Ok(()) } diff --git a/crates/blockifier/src/execution/syscalls/hint_processor.rs b/crates/blockifier/src/execution/syscalls/hint_processor.rs index 86d4cf3d25..847292ab25 100644 --- a/crates/blockifier/src/execution/syscalls/hint_processor.rs +++ b/crates/blockifier/src/execution/syscalls/hint_processor.rs @@ -1,5 +1,5 @@ use std::any::Any; -use std::collections::{hash_map, HashMap}; +use std::collections::HashMap; use cairo_lang_casm::hints::{Hint, StarknetHint}; use cairo_lang_runner::casm_run::execute_core_hint_base; @@ -65,7 +65,6 @@ use crate::execution::syscalls::{ storage_read, storage_write, StorageReadResponse, - StorageWriteResponse, SyscallRequest, SyscallRequestWrapper, SyscallResponse, @@ -643,34 +642,8 @@ impl<'a> SyscallHintProcessor<'a> { Ok(StorageReadResponse { value }) } - pub fn set_contract_storage_at( - &mut self, - key: StorageKey, - value: Felt, - ) -> SyscallResult { - let contract_address = self.storage_address(); - - match self.base.original_values.entry(key) { - hash_map::Entry::Vacant(entry) => { - entry.insert(self.base.state.get_storage_at(contract_address, key)?); - } - hash_map::Entry::Occupied(_) => {} - } - - self.base.accessed_keys.insert(key); - self.base.state.set_storage_at(contract_address, key, value)?; - - Ok(StorageWriteResponse {}) - } - pub fn finalize(&mut self) { - self.base - .context - .revert_infos - .0 - .last_mut() - .expect("Missing contract revert info.") - .original_values = std::mem::take(&mut self.base.original_values); + self.base.finalize(); } } diff --git a/crates/blockifier/src/execution/syscalls/mod.rs b/crates/blockifier/src/execution/syscalls/mod.rs index 407e7eec38..dcef96c904 100644 --- a/crates/blockifier/src/execution/syscalls/mod.rs +++ b/crates/blockifier/src/execution/syscalls/mod.rs @@ -612,7 +612,8 @@ pub fn storage_write( syscall_handler: &mut SyscallHintProcessor<'_>, _remaining_gas: &mut u64, ) -> SyscallResult { - syscall_handler.set_contract_storage_at(request.address, request.value) + syscall_handler.base.storage_write(request.address, request.value)?; + Ok(StorageWriteResponse {}) } // Keccak syscall. diff --git a/crates/blockifier/src/execution/syscalls/syscall_base.rs b/crates/blockifier/src/execution/syscalls/syscall_base.rs index c4b049da00..bdfe1185ab 100644 --- a/crates/blockifier/src/execution/syscalls/syscall_base.rs +++ b/crates/blockifier/src/execution/syscalls/syscall_base.rs @@ -1,4 +1,4 @@ -use std::collections::{HashMap, HashSet}; +use std::collections::{hash_map, HashMap, HashSet}; use std::convert::From; use starknet_api::core::{ClassHash, ContractAddress}; @@ -99,6 +99,22 @@ impl<'state> SyscallHandlerBase<'state> { Ok(self.state.get_storage_at(block_hash_contract_address, key)?) } + pub fn storage_write(&mut self, key: StorageKey, value: Felt) -> SyscallResult<()> { + let contract_address = self.call.storage_address; + + match self.original_values.entry(key) { + hash_map::Entry::Vacant(entry) => { + entry.insert(self.state.get_storage_at(contract_address, key)?); + } + hash_map::Entry::Occupied(_) => {} + } + + self.accessed_keys.insert(key); + self.state.set_storage_at(contract_address, key, value)?; + + Ok(()) + } + pub fn execute_inner_call( &mut self, call: CallEntryPoint, @@ -138,4 +154,13 @@ impl<'state> SyscallHandlerBase<'state> { Ok(raw_retdata) } + + pub fn finalize(&mut self) { + self.context + .revert_infos + .0 + .last_mut() + .expect("Missing contract revert info.") + .original_values = std::mem::take(&mut self.original_values); + } } diff --git a/crates/blockifier/src/execution/syscalls/syscall_tests/call_contract.rs b/crates/blockifier/src/execution/syscalls/syscall_tests/call_contract.rs index ae77f25cdd..606cc9a8e8 100644 --- a/crates/blockifier/src/execution/syscalls/syscall_tests/call_contract.rs +++ b/crates/blockifier/src/execution/syscalls/syscall_tests/call_contract.rs @@ -27,10 +27,10 @@ use crate::test_utils::{ BALANCE, }; -// TODO: Add test for native once reverts are supported. -#[test] -fn test_call_contract_that_panics() { - let test_contract = FeatureContract::TestContract(CairoVersion::Cairo1); +#[cfg_attr(feature = "cairo_native", test_case(CairoVersion::Native; "Native"))] +#[test_case(CairoVersion::Cairo1;"VM")] +fn test_call_contract_that_panics(cairo_version: CairoVersion) { + let test_contract = FeatureContract::TestContract(cairo_version); let empty_contract = FeatureContract::Empty(CairoVersion::Cairo1); let chain_info = &ChainInfo::create_for_testing(); let mut state = test_state(chain_info, BALANCE, &[(test_contract, 1), (empty_contract, 0)]); @@ -61,6 +61,11 @@ fn test_call_contract_that_panics() { ); assert!(inner_call.execution.events.is_empty()); assert!(inner_call.execution.l2_to_l1_messages.is_empty()); + + // Check that the tracked resource is SierraGas to make sure that Native is running. + for call in res.iter() { + assert_eq!(call.tracked_resource, TrackedResource::SierraGas); + } } #[cfg_attr( diff --git a/crates/blockifier/src/execution/syscalls/syscall_tests/library_call.rs b/crates/blockifier/src/execution/syscalls/syscall_tests/library_call.rs index 0bb5816777..9a3ede71b5 100644 --- a/crates/blockifier/src/execution/syscalls/syscall_tests/library_call.rs +++ b/crates/blockifier/src/execution/syscalls/syscall_tests/library_call.rs @@ -83,20 +83,16 @@ fn test_library_call_assert_fails(cairo_version: CairoVersion) { ..trivial_external_entry_point_new(test_contract) }; let call_info = entry_point_call.execute_directly(&mut state).unwrap(); - let expected_err_retdata = match test_contract.cairo_version() { - CairoVersion::Cairo0 | CairoVersion::Cairo1 => { - // 'x != y', 'ENTRYPOINT_FAILED'. - vec![felt!("0x7820213d2079"), felt!("0x454e545259504f494e545f4641494c4544")] - } - #[cfg(feature = "cairo_native")] - // 'x != y'. - CairoVersion::Native => vec![felt!("0x7820213d2079")], - }; assert_eq!( call_info.execution, CallExecution { - retdata: Retdata(expected_err_retdata), + retdata: Retdata(vec![ + // 'x != y'. + felt!("0x7820213d2079"), + // 'ENTRYPOINT_FAILED'. + felt!("0x454e545259504f494e545f4641494c4544") + ]), gas_consumed: 150980, failed: true, ..Default::default() diff --git a/crates/blockifier/src/transaction/account_transactions_test.rs b/crates/blockifier/src/transaction/account_transactions_test.rs index 52e040fcde..245d82f61e 100644 --- a/crates/blockifier/src/transaction/account_transactions_test.rs +++ b/crates/blockifier/src/transaction/account_transactions_test.rs @@ -55,6 +55,7 @@ use starknet_types_core::felt::Felt; use crate::check_tx_execution_error_for_invalid_scenario; use crate::context::{BlockContext, TransactionContext}; use crate::execution::call_info::CallInfo; +use crate::execution::contract_class::TrackedResource; use crate::execution::entry_point::EntryPointExecutionContext; use crate::execution::syscalls::SyscallSelector; use crate::fee::fee_utils::{get_fee_by_gas_vector, get_sequencer_balance_keys}; @@ -1761,7 +1762,10 @@ fn test_revert_in_execute( } #[rstest] +#[cfg_attr(feature = "cairo_native", case::native(CairoVersion::Native))] +#[case::vm(CairoVersion::Cairo1)] fn test_call_contract_that_panics( + #[case] cairo_version: CairoVersion, mut block_context: BlockContext, default_all_resource_bounds: ValidResourceBounds, #[values(true, false)] enable_reverts: bool, @@ -1769,7 +1773,8 @@ fn test_call_contract_that_panics( ) { // Override enable reverts. block_context.versioned_constants.enable_reverts = enable_reverts; - let test_contract = FeatureContract::TestContract(CairoVersion::Cairo1); + let test_contract = FeatureContract::TestContract(cairo_version); + // TODO(Yoni): use `class_version` here once the feature contract fully supports Native. let account = FeatureContract::AccountWithoutValidations(CairoVersion::Cairo1); let chain_info = &block_context.chain_info; let state = &mut test_state(chain_info, BALANCE, &[(test_contract, 1), (account, 1)]); @@ -1809,4 +1814,11 @@ fn test_call_contract_that_panics( // If reverts are enabled, `test_call_contract_revert` should catch it and ignore it. // Otherwise, the transaction should revert. assert_eq!(tx_execution_info.is_reverted(), !enable_reverts); + + if enable_reverts { + // Check that the tracked resource is SierraGas to make sure that Native is running. + for call in tx_execution_info.execute_call_info.unwrap().iter() { + assert_eq!(call.tracked_resource, TrackedResource::SierraGas); + } + } }