diff --git a/crates/blockifier/src/execution/entry_point_execution.rs b/crates/blockifier/src/execution/entry_point_execution.rs index a60b5638ab..070e55be8a 100644 --- a/crates/blockifier/src/execution/entry_point_execution.rs +++ b/crates/blockifier/src/execution/entry_point_execution.rs @@ -77,7 +77,7 @@ pub fn execute_entry_point_call( } = initialize_execution_context(call, &contract_class, state, context)?; let args = prepare_call_arguments( - &syscall_handler.call, + &syscall_handler.base.call, &mut runner, initial_syscall_ptr, &mut syscall_handler.read_only_segments, @@ -93,7 +93,7 @@ pub fn execute_entry_point_call( // Collect the set PC values that were visited during the entry point execution. register_visited_pcs( &mut runner, - syscall_handler.state, + syscall_handler.base.state, class_hash, program_segment_size, bytecode_length, @@ -416,7 +416,7 @@ pub fn finalize_execution( .get_execution_resources() .map_err(VirtualMachineError::RunnerError)? .filter_unused_builtins(); - let versioned_constants = syscall_handler.context.versioned_constants(); + let versioned_constants = syscall_handler.base.context.versioned_constants(); if versioned_constants.segment_arena_cells { vm_resources_without_inner_calls .builtin_instance_counter @@ -435,24 +435,25 @@ pub fn finalize_execution( gas_for_fee: GasAmount(0), }; let charged_resources = &charged_resources_without_inner_calls - + &CallInfo::summarize_charged_resources(syscall_handler.inner_calls.iter()); + + &CallInfo::summarize_charged_resources(syscall_handler.base.inner_calls.iter()); + let syscall_handler_base = syscall_handler.base; Ok(CallInfo { - call: syscall_handler.call, + call: syscall_handler_base.call, execution: CallExecution { retdata: call_result.retdata, - events: syscall_handler.events, - l2_to_l1_messages: syscall_handler.l2_to_l1_messages, + events: syscall_handler_base.events, + l2_to_l1_messages: syscall_handler_base.l2_to_l1_messages, failed: call_result.failed, gas_consumed: call_result.gas_consumed, }, - inner_calls: syscall_handler.inner_calls, + inner_calls: syscall_handler_base.inner_calls, tracked_resource, charged_resources, - storage_read_values: syscall_handler.read_values, - accessed_storage_keys: syscall_handler.accessed_keys, - read_class_hash_values: syscall_handler.read_class_hash_values, - accessed_contract_addresses: syscall_handler.accessed_contract_addresses, + storage_read_values: syscall_handler_base.read_values, + accessed_storage_keys: syscall_handler_base.accessed_keys, + read_class_hash_values: syscall_handler_base.read_class_hash_values, + accessed_contract_addresses: syscall_handler_base.accessed_contract_addresses, }) } @@ -489,13 +490,13 @@ fn get_call_result( error_message: format!("Unexpected remaining gas: {gas}."), })?; - if gas > syscall_handler.call.initial_gas { + if gas > syscall_handler.base.call.initial_gas { return Err(PostExecutionError::MalformedReturnData { error_message: format!("Unexpected remaining gas: {gas}."), }); } - let gas_consumed = syscall_handler.call.initial_gas - gas; + let gas_consumed = syscall_handler.base.call.initial_gas - gas; Ok(CallResult { failed, retdata: read_execution_retdata(runner, retdata_size, retdata_start)?, diff --git a/crates/blockifier/src/execution/native/entry_point_execution.rs b/crates/blockifier/src/execution/native/entry_point_execution.rs index 6fb3efdd13..a8e1c405d0 100644 --- a/crates/blockifier/src/execution/native/entry_point_execution.rs +++ b/crates/blockifier/src/execution/native/entry_point_execution.rs @@ -27,7 +27,7 @@ pub fn execute_entry_point_call( let mut syscall_handler: NativeSyscallHandler<'_> = NativeSyscallHandler::new(call, state, context); - let gas_costs = &syscall_handler.context.versioned_constants().os_constants.gas_costs; + let gas_costs = &syscall_handler.base.context.versioned_constants().os_constants.gas_costs; let builtin_costs = BuiltinCosts { // todo(rodrigo): Unsure of what value `const` means, but 1 is the right value r#const: 1, @@ -41,10 +41,10 @@ pub fn execute_entry_point_call( // Fund the initial budget since the native executor charges it before the run. // TODO(Yoni): revert once the VM is aligned with this. - let gas = syscall_handler.call.initial_gas + gas_costs.entry_point_initial_budget; + let gas = syscall_handler.base.call.initial_gas + gas_costs.entry_point_initial_budget; let execution_result = contract_class.executor.run( entry_point.selector.0, - &syscall_handler.call.calldata.0.clone(), + &syscall_handler.base.call.calldata.0.clone(), Some(gas), Some(builtin_costs), &mut syscall_handler, @@ -65,25 +65,25 @@ fn create_callinfo( ) -> Result { let mut remaining_gas = call_result.remaining_gas; - if remaining_gas > syscall_handler.call.initial_gas { - if remaining_gas - syscall_handler.call.initial_gas - <= syscall_handler.context.gas_costs().entry_point_initial_budget + if remaining_gas > syscall_handler.base.call.initial_gas { + if remaining_gas - syscall_handler.base.call.initial_gas + <= syscall_handler.base.context.gas_costs().entry_point_initial_budget { // Revert the refund. // TODO(Yoni): temporary hack - this is probably a bug. Investigate and fix native. - remaining_gas = syscall_handler.call.initial_gas; + remaining_gas = syscall_handler.base.call.initial_gas; } else { return Err(PostExecutionError::MalformedReturnData { error_message: format!( "Unexpected remaining gas. Used gas is greater than initial gas: {} > {}", - remaining_gas, syscall_handler.call.initial_gas + remaining_gas, syscall_handler.base.call.initial_gas ), } .into()); } } - let gas_consumed = syscall_handler.call.initial_gas - remaining_gas; + let gas_consumed = syscall_handler.base.call.initial_gas - remaining_gas; let charged_resources_without_inner_calls = ChargedResources { vm_resources: ExecutionResources::default(), @@ -91,23 +91,23 @@ fn create_callinfo( gas_for_fee: GasAmount(0), }; let charged_resources = &charged_resources_without_inner_calls - + &CallInfo::summarize_charged_resources(syscall_handler.inner_calls.iter()); + + &CallInfo::summarize_charged_resources(syscall_handler.base.inner_calls.iter()); Ok(CallInfo { - call: syscall_handler.call, + call: syscall_handler.base.call, execution: CallExecution { retdata: Retdata(call_result.return_values), - events: syscall_handler.events, - l2_to_l1_messages: syscall_handler.l2_to_l1_messages, + events: syscall_handler.base.events, + l2_to_l1_messages: syscall_handler.base.l2_to_l1_messages, failed: call_result.failure_flag, gas_consumed, }, charged_resources, - inner_calls: syscall_handler.inner_calls, - storage_read_values: syscall_handler.read_values, - accessed_storage_keys: syscall_handler.accessed_keys, - accessed_contract_addresses: syscall_handler.accessed_contract_addresses, - read_class_hash_values: syscall_handler.read_class_hash_values, + inner_calls: syscall_handler.base.inner_calls, + storage_read_values: syscall_handler.base.read_values, + accessed_storage_keys: syscall_handler.base.accessed_keys, + accessed_contract_addresses: syscall_handler.base.accessed_contract_addresses, + read_class_hash_values: syscall_handler.base.read_class_hash_values, tracked_resource: TrackedResource::SierraGas, }) } diff --git a/crates/blockifier/src/execution/native/syscall_handler.rs b/crates/blockifier/src/execution/native/syscall_handler.rs index 3f1f7384ac..dec9744f12 100644 --- a/crates/blockifier/src/execution/native/syscall_handler.rs +++ b/crates/blockifier/src/execution/native/syscall_handler.rs @@ -1,7 +1,5 @@ -use std::collections::HashSet; use std::convert::From; use std::fmt; -use std::hash::RandomState; use std::sync::Arc; use ark_ec::short_weierstrass::{Affine, Projective, SWCurveConfig}; @@ -32,13 +30,7 @@ use starknet_api::transaction::fields::{Calldata, ContractAddressSalt}; use starknet_api::transaction::{EventContent, EventData, EventKey, L2ToL1Payload}; use starknet_types_core::felt::Felt; -use crate::execution::call_info::{ - CallInfo, - MessageToL1, - OrderedEvent, - OrderedL2ToL1Message, - Retdata, -}; +use crate::execution::call_info::{MessageToL1, OrderedEvent, OrderedL2ToL1Message, Retdata}; use crate::execution::common_hints::ExecutionMode; use crate::execution::contract_class::RunnableContractClass; use crate::execution::entry_point::{ @@ -59,24 +51,10 @@ use crate::execution::syscalls::hint_processor::{ use crate::execution::syscalls::{exceeds_event_size_limit, syscall_base}; use crate::state::state_api::State; use crate::transaction::objects::TransactionInfo; +use crate::versioned_constants::GasCosts; pub struct NativeSyscallHandler<'state> { - // Input for execution. - pub state: &'state mut dyn State, - pub context: &'state mut EntryPointExecutionContext, - pub call: CallEntryPoint, - - // Execution results. - pub events: Vec, - pub l2_to_l1_messages: Vec, - pub inner_calls: Vec, - - // Additional information gathered during execution. - pub read_values: Vec, - pub accessed_keys: HashSet, - pub read_class_hash_values: Vec, - // Accessed addresses by the `get_class_hash_at` syscall. - pub accessed_contract_addresses: HashSet, + pub base: syscall_base::SyscallHandlerBase<'state>, // It is set if an unrecoverable error happens during syscall execution pub unrecoverable_error: Option, @@ -89,16 +67,7 @@ impl<'state> NativeSyscallHandler<'state> { context: &'state mut EntryPointExecutionContext, ) -> NativeSyscallHandler<'state> { NativeSyscallHandler { - state, - call, - context, - events: Vec::new(), - l2_to_l1_messages: Vec::new(), - inner_calls: Vec::new(), - read_values: Vec::new(), - accessed_keys: HashSet::new(), - read_class_hash_values: Vec::new(), - accessed_contract_addresses: HashSet::new(), + base: syscall_base::SyscallHandlerBase::new(call, state, context), unrecoverable_error: None, } } @@ -109,7 +78,7 @@ impl<'state> NativeSyscallHandler<'state> { remaining_gas: &mut u64, ) -> SyscallResult { let call_info = entry_point - .execute(self.state, self.context, remaining_gas) + .execute(self.base.state, self.base.context, remaining_gas) .map_err(|e| self.handle_error(remaining_gas, e.into()))?; let retdata = call_info.execution.retdata.clone(); @@ -118,11 +87,15 @@ impl<'state> NativeSyscallHandler<'state> { return Err(self.handle_error(remaining_gas, error)); } - self.inner_calls.push(call_info); + self.base.inner_calls.push(call_info); Ok(retdata) } + pub fn gas_costs(&self) -> &GasCosts { + self.base.context.gas_costs() + } + /// Handles all gas-related logics and perform additional checks. In native, /// we need to explicitly call this method at the beginning of each syscall. fn pre_execute_syscall( @@ -136,7 +109,7 @@ impl<'state> NativeSyscallHandler<'state> { return Err(vec![]); } // Refund `SYSCALL_BASE_GAS_COST` as it was pre-charged. - let required_gas = syscall_gas_cost - self.context.gas_costs().syscall_base_gas_cost; + let required_gas = syscall_gas_cost - self.gas_costs().syscall_base_gas_cost; if *remaining_gas < required_gas { // Out of gas failure. @@ -181,7 +154,7 @@ impl<'state> NativeSyscallHandler<'state> { } fn get_tx_info_v1(&self) -> TxInfo { - let tx_info = &self.context.tx_context.tx_info; + let tx_info = &self.base.context.tx_context.tx_info; TxInfo { version: tx_info.version().0, account_contract_address: Felt::from(tx_info.sender_address()), @@ -189,7 +162,7 @@ impl<'state> NativeSyscallHandler<'state> { signature: tx_info.signature().0, transaction_hash: tx_info.transaction_hash().0, chain_id: Felt::from_hex( - &self.context.tx_context.block_context.chain_info.chain_id.as_hex(), + &self.base.context.tx_context.block_context.chain_info.chain_id.as_hex(), ) .expect("Failed to convert the chain_id to hex."), nonce: tx_info.nonce().0, @@ -197,9 +170,9 @@ impl<'state> NativeSyscallHandler<'state> { } fn get_block_info(&self) -> BlockInfo { - let block_info = &self.context.tx_context.block_context.block_info; - if self.context.execution_mode == ExecutionMode::Validate { - let versioned_constants = self.context.versioned_constants(); + let block_info = &self.base.context.tx_context.block_context.block_info; + if self.base.context.execution_mode == ExecutionMode::Validate { + let versioned_constants = self.base.context.versioned_constants(); let block_number = block_info.block_number.0; let block_timestamp = block_info.block_timestamp.0; // Round down to the nearest multiple of validate_block_number_rounding. @@ -226,7 +199,7 @@ impl<'state> NativeSyscallHandler<'state> { } fn get_tx_info_v2(&self) -> SyscallResult { - let tx_info = &self.context.tx_context.tx_info; + let tx_info = &self.base.context.tx_context.tx_info; let native_tx_info = TxV2Info { version: tx_info.version().0, account_contract_address: Felt::from(tx_info.sender_address()), @@ -234,7 +207,7 @@ impl<'state> NativeSyscallHandler<'state> { signature: tx_info.signature().0, transaction_hash: tx_info.transaction_hash().0, chain_id: Felt::from_hex( - &self.context.tx_context.block_context.chain_info.chain_id.as_hex(), + &self.base.context.tx_context.block_context.chain_info.chain_id.as_hex(), ) .expect("Failed to convert the chain_id to hex."), nonce: tx_info.nonce().0, @@ -262,26 +235,23 @@ impl<'state> StarknetSyscallHandler for &mut NativeSyscallHandler<'state> { block_number: u64, remaining_gas: &mut u64, ) -> SyscallResult { - self.pre_execute_syscall(remaining_gas, self.context.gas_costs().get_block_hash_gas_cost)?; + self.pre_execute_syscall(remaining_gas, self.gas_costs().get_block_hash_gas_cost)?; - match syscall_base::get_block_hash_base(self.context, block_number, self.state) { + match syscall_base::get_block_hash_base(self.base.context, block_number, self.base.state) { Ok(value) => Ok(value), Err(e) => Err(self.handle_error(remaining_gas, e)), } } fn get_execution_info(&mut self, remaining_gas: &mut u64) -> SyscallResult { - self.pre_execute_syscall( - remaining_gas, - self.context.gas_costs().get_execution_info_gas_cost, - )?; + self.pre_execute_syscall(remaining_gas, self.gas_costs().get_execution_info_gas_cost)?; Ok(ExecutionInfo { block_info: self.get_block_info(), tx_info: self.get_tx_info_v1(), - caller_address: Felt::from(self.call.caller_address), - contract_address: Felt::from(self.call.storage_address), - entry_point_selector: self.call.entry_point_selector.0, + caller_address: Felt::from(self.base.call.caller_address), + contract_address: Felt::from(self.base.call.storage_address), + entry_point_selector: self.base.call.entry_point_selector.0, }) } @@ -290,35 +260,30 @@ impl<'state> StarknetSyscallHandler for &mut NativeSyscallHandler<'state> { contract_address: Felt, remaining_gas: &mut u64, ) -> SyscallResult { - self.pre_execute_syscall( - remaining_gas, - self.context.gas_costs().get_class_hash_at_gas_cost, - )?; + self.pre_execute_syscall(remaining_gas, self.gas_costs().get_class_hash_at_gas_cost)?; let request = ContractAddress::try_from(contract_address) .map_err(|err| self.handle_error(remaining_gas, err.into()))?; - self.accessed_contract_addresses.insert(request); + self.base.accessed_contract_addresses.insert(request); let class_hash = self + .base .state .get_class_hash_at(request) .map_err(|err| self.handle_error(remaining_gas, err.into()))?; - self.read_class_hash_values.push(class_hash); + self.base.read_class_hash_values.push(class_hash); Ok(class_hash.0) } fn get_execution_info_v2(&mut self, remaining_gas: &mut u64) -> SyscallResult { - self.pre_execute_syscall( - remaining_gas, - self.context.gas_costs().get_execution_info_gas_cost, - )?; + self.pre_execute_syscall(remaining_gas, self.gas_costs().get_execution_info_gas_cost)?; Ok(ExecutionInfoV2 { block_info: self.get_block_info(), tx_info: self.get_tx_info_v2()?, - caller_address: Felt::from(self.call.caller_address), - contract_address: Felt::from(self.call.storage_address), - entry_point_selector: self.call.entry_point_selector.0, + caller_address: Felt::from(self.base.call.caller_address), + contract_address: Felt::from(self.base.call.storage_address), + entry_point_selector: self.base.call.entry_point_selector.0, }) } @@ -330,9 +295,9 @@ impl<'state> StarknetSyscallHandler for &mut NativeSyscallHandler<'state> { deploy_from_zero: bool, remaining_gas: &mut u64, ) -> SyscallResult<(Felt, Vec)> { - self.pre_execute_syscall(remaining_gas, self.context.gas_costs().deploy_gas_cost)?; + self.pre_execute_syscall(remaining_gas, self.gas_costs().deploy_gas_cost)?; - let deployer_address = self.call.storage_address; + let deployer_address = self.base.call.storage_address; let deployer_address_for_calculation = if deploy_from_zero { ContractAddress::default() } else { deployer_address }; @@ -354,21 +319,27 @@ impl<'state> StarknetSyscallHandler for &mut NativeSyscallHandler<'state> { caller_address: deployer_address, }; - let call_info = - execute_deployment(self.state, self.context, ctor_context, calldata, remaining_gas) - .map_err(|err| self.handle_error(remaining_gas, err.into()))?; + let call_info = execute_deployment( + self.base.state, + self.base.context, + ctor_context, + calldata, + remaining_gas, + ) + .map_err(|err| self.handle_error(remaining_gas, err.into()))?; let constructor_retdata = call_info.execution.retdata.0[..].to_vec(); - self.inner_calls.push(call_info); + self.base.inner_calls.push(call_info); Ok((Felt::from(deployed_contract_address), constructor_retdata)) } fn replace_class(&mut self, class_hash: Felt, remaining_gas: &mut u64) -> SyscallResult<()> { - self.pre_execute_syscall(remaining_gas, self.context.gas_costs().replace_class_gas_cost)?; + self.pre_execute_syscall(remaining_gas, self.gas_costs().replace_class_gas_cost)?; let class_hash = ClassHash(class_hash); let contract_class = self + .base .state .get_compiled_contract_class(class_hash) .map_err(|e| self.handle_error(remaining_gas, e.into()))?; @@ -379,8 +350,9 @@ impl<'state> StarknetSyscallHandler for &mut NativeSyscallHandler<'state> { SyscallExecutionError::ForbiddenClassReplacement { class_hash }, )), RunnableContractClass::V1(_) | RunnableContractClass::V1Native(_) => { - self.state - .set_class_hash_at(self.call.storage_address, class_hash) + self.base + .state + .set_class_hash_at(self.base.call.storage_address, class_hash) .map_err(|e| self.handle_error(remaining_gas, e.into()))?; Ok(()) @@ -395,7 +367,7 @@ impl<'state> StarknetSyscallHandler for &mut NativeSyscallHandler<'state> { calldata: &[Felt], remaining_gas: &mut u64, ) -> SyscallResult> { - self.pre_execute_syscall(remaining_gas, self.context.gas_costs().library_call_gas_cost)?; + self.pre_execute_syscall(remaining_gas, self.gas_costs().library_call_gas_cost)?; let class_hash = ClassHash(class_hash); @@ -408,8 +380,8 @@ impl<'state> StarknetSyscallHandler for &mut NativeSyscallHandler<'state> { entry_point_selector: EntryPointSelector(function_selector), calldata: wrapper_calldata, // The call context remains the same in a library call. - storage_address: self.call.storage_address, - caller_address: self.call.caller_address, + storage_address: self.base.call.storage_address, + caller_address: self.base.call.caller_address, call_type: CallType::Delegate, initial_gas: *remaining_gas, }; @@ -424,16 +396,16 @@ impl<'state> StarknetSyscallHandler for &mut NativeSyscallHandler<'state> { calldata: &[Felt], remaining_gas: &mut u64, ) -> SyscallResult> { - self.pre_execute_syscall(remaining_gas, self.context.gas_costs().call_contract_gas_cost)?; + self.pre_execute_syscall(remaining_gas, self.gas_costs().call_contract_gas_cost)?; let contract_address = ContractAddress::try_from(address) .map_err(|error| self.handle_error(remaining_gas, error.into()))?; - if self.context.execution_mode == ExecutionMode::Validate - && self.call.storage_address != contract_address + if self.base.context.execution_mode == ExecutionMode::Validate + && self.base.call.storage_address != contract_address { let err = SyscallExecutionError::InvalidSyscallInExecutionMode { syscall_name: "call_contract".to_string(), - execution_mode: self.context.execution_mode, + execution_mode: self.base.context.execution_mode, }; return Err(self.handle_error(remaining_gas, err)); } @@ -447,7 +419,7 @@ impl<'state> StarknetSyscallHandler for &mut NativeSyscallHandler<'state> { entry_point_selector: EntryPointSelector(entry_point_selector), calldata: wrapper_calldata, storage_address: contract_address, - caller_address: self.call.caller_address, + caller_address: self.base.call.caller_address, call_type: CallType::Call, initial_gas: *remaining_gas, }; @@ -461,7 +433,7 @@ impl<'state> StarknetSyscallHandler for &mut NativeSyscallHandler<'state> { address: Felt, remaining_gas: &mut u64, ) -> SyscallResult { - self.pre_execute_syscall(remaining_gas, self.context.gas_costs().storage_read_gas_cost)?; + self.pre_execute_syscall(remaining_gas, self.gas_costs().storage_read_gas_cost)?; if address_domain != 0 { let address_domain = Felt::from(address_domain); @@ -472,11 +444,11 @@ impl<'state> StarknetSyscallHandler for &mut NativeSyscallHandler<'state> { let key = StorageKey::try_from(address) .map_err(|e| self.handle_error(remaining_gas, e.into()))?; - let read_result = self.state.get_storage_at(self.call.storage_address, key); + let read_result = self.base.state.get_storage_at(self.base.call.storage_address, key); let value = read_result.map_err(|e| self.handle_error(remaining_gas, e.into()))?; - self.accessed_keys.insert(key); - self.read_values.push(value); + self.base.accessed_keys.insert(key); + self.base.read_values.push(value); Ok(value) } @@ -488,7 +460,7 @@ impl<'state> StarknetSyscallHandler for &mut NativeSyscallHandler<'state> { value: Felt, remaining_gas: &mut u64, ) -> SyscallResult<()> { - self.pre_execute_syscall(remaining_gas, self.context.gas_costs().storage_write_gas_cost)?; + self.pre_execute_syscall(remaining_gas, self.gas_costs().storage_write_gas_cost)?; if address_domain != 0 { let address_domain = Felt::from(address_domain); @@ -498,9 +470,10 @@ impl<'state> StarknetSyscallHandler for &mut NativeSyscallHandler<'state> { let key = StorageKey::try_from(address) .map_err(|e| self.handle_error(remaining_gas, e.into()))?; - self.accessed_keys.insert(key); + self.base.accessed_keys.insert(key); - let write_result = self.state.set_storage_at(self.call.storage_address, key, value); + let write_result = + self.base.state.set_storage_at(self.base.call.storage_address, key, value); write_result.map_err(|e| self.handle_error(remaining_gas, e.into()))?; Ok(()) @@ -512,23 +485,23 @@ impl<'state> StarknetSyscallHandler for &mut NativeSyscallHandler<'state> { data: &[Felt], remaining_gas: &mut u64, ) -> SyscallResult<()> { - self.pre_execute_syscall(remaining_gas, self.context.gas_costs().emit_event_gas_cost)?; + self.pre_execute_syscall(remaining_gas, self.gas_costs().emit_event_gas_cost)?; - let order = self.context.n_emitted_events; + let order = self.base.context.n_emitted_events; let event = EventContent { keys: keys.iter().copied().map(EventKey).collect(), data: EventData(data.to_vec()), }; exceeds_event_size_limit( - self.context.versioned_constants(), - self.context.n_emitted_events + 1, + self.base.context.versioned_constants(), + self.base.context.n_emitted_events + 1, &event, ) .map_err(|e| self.handle_error(remaining_gas, e.into()))?; - self.events.push(OrderedEvent { order, event }); - self.context.n_emitted_events += 1; + self.base.events.push(OrderedEvent { order, event }); + self.base.context.n_emitted_events += 1; Ok(()) } @@ -539,26 +512,23 @@ impl<'state> StarknetSyscallHandler for &mut NativeSyscallHandler<'state> { payload: &[Felt], remaining_gas: &mut u64, ) -> SyscallResult<()> { - self.pre_execute_syscall( - remaining_gas, - self.context.gas_costs().send_message_to_l1_gas_cost, - )?; + self.pre_execute_syscall(remaining_gas, self.gas_costs().send_message_to_l1_gas_cost)?; - let order = self.context.n_sent_messages_to_l1; + let order = self.base.context.n_sent_messages_to_l1; let to_address = EthAddress::try_from(to_address) .map_err(|e| self.handle_error(remaining_gas, e.into()))?; - self.l2_to_l1_messages.push(OrderedL2ToL1Message { + self.base.l2_to_l1_messages.push(OrderedL2ToL1Message { order, message: MessageToL1 { to_address, payload: L2ToL1Payload(payload.to_vec()) }, }); - self.context.n_sent_messages_to_l1 += 1; + self.base.context.n_sent_messages_to_l1 += 1; Ok(()) } fn keccak(&mut self, input: &[u64], remaining_gas: &mut u64) -> SyscallResult { - self.pre_execute_syscall(remaining_gas, self.context.gas_costs().keccak_gas_cost)?; + self.pre_execute_syscall(remaining_gas, self.gas_costs().keccak_gas_cost)?; const KECCAK_FULL_RATE_IN_WORDS: usize = 17; @@ -577,7 +547,7 @@ impl<'state> StarknetSyscallHandler for &mut NativeSyscallHandler<'state> { // TODO(Ori, 1/2/2024): Write an indicative expect message explaining why the conversion // works. let n_rounds = u64::try_from(n_rounds).expect("Failed to convert usize to u64."); - let gas_cost = n_rounds * self.context.gas_costs().keccak_round_cost_gas_cost; + let gas_cost = n_rounds * self.gas_costs().keccak_round_cost_gas_cost; if gas_cost > *remaining_gas { return Err(self.handle_error( @@ -609,7 +579,7 @@ impl<'state> StarknetSyscallHandler for &mut NativeSyscallHandler<'state> { y: U256, remaining_gas: &mut u64, ) -> SyscallResult> { - self.pre_execute_syscall(remaining_gas, self.context.gas_costs().secp256k1_new_gas_cost)?; + self.pre_execute_syscall(remaining_gas, self.gas_costs().secp256k1_new_gas_cost)?; Secp256Point::new(x, y) .map(|op| op.map(|p| p.into())) @@ -622,7 +592,7 @@ impl<'state> StarknetSyscallHandler for &mut NativeSyscallHandler<'state> { p1: Secp256k1Point, remaining_gas: &mut u64, ) -> SyscallResult { - self.pre_execute_syscall(remaining_gas, self.context.gas_costs().secp256k1_add_gas_cost)?; + self.pre_execute_syscall(remaining_gas, self.gas_costs().secp256k1_add_gas_cost)?; Ok(Secp256Point::add(p0.into(), p1.into()).into()) } @@ -633,7 +603,7 @@ impl<'state> StarknetSyscallHandler for &mut NativeSyscallHandler<'state> { m: U256, remaining_gas: &mut u64, ) -> SyscallResult { - self.pre_execute_syscall(remaining_gas, self.context.gas_costs().secp256k1_mul_gas_cost)?; + self.pre_execute_syscall(remaining_gas, self.gas_costs().secp256k1_mul_gas_cost)?; Ok(Secp256Point::mul(p.into(), m).into()) } @@ -646,7 +616,7 @@ impl<'state> StarknetSyscallHandler for &mut NativeSyscallHandler<'state> { ) -> SyscallResult> { self.pre_execute_syscall( remaining_gas, - self.context.gas_costs().secp256k1_get_point_from_x_gas_cost, + self.gas_costs().secp256k1_get_point_from_x_gas_cost, )?; Secp256Point::get_point_from_x(x, y_parity) @@ -659,10 +629,7 @@ impl<'state> StarknetSyscallHandler for &mut NativeSyscallHandler<'state> { p: Secp256k1Point, remaining_gas: &mut u64, ) -> SyscallResult<(U256, U256)> { - self.pre_execute_syscall( - remaining_gas, - self.context.gas_costs().secp256k1_get_xy_gas_cost, - )?; + self.pre_execute_syscall(remaining_gas, self.gas_costs().secp256k1_get_xy_gas_cost)?; Ok((p.x, p.y)) } @@ -673,7 +640,7 @@ impl<'state> StarknetSyscallHandler for &mut NativeSyscallHandler<'state> { y: U256, remaining_gas: &mut u64, ) -> SyscallResult> { - self.pre_execute_syscall(remaining_gas, self.context.gas_costs().secp256r1_new_gas_cost)?; + self.pre_execute_syscall(remaining_gas, self.gas_costs().secp256r1_new_gas_cost)?; Secp256Point::new(x, y) .map(|option| option.map(|p| p.into())) @@ -686,7 +653,7 @@ impl<'state> StarknetSyscallHandler for &mut NativeSyscallHandler<'state> { p1: Secp256r1Point, remaining_gas: &mut u64, ) -> SyscallResult { - self.pre_execute_syscall(remaining_gas, self.context.gas_costs().secp256r1_add_gas_cost)?; + self.pre_execute_syscall(remaining_gas, self.gas_costs().secp256r1_add_gas_cost)?; Ok(Secp256Point::add(p0.into(), p1.into()).into()) } @@ -696,7 +663,7 @@ impl<'state> StarknetSyscallHandler for &mut NativeSyscallHandler<'state> { m: U256, remaining_gas: &mut u64, ) -> SyscallResult { - self.pre_execute_syscall(remaining_gas, self.context.gas_costs().secp256r1_mul_gas_cost)?; + self.pre_execute_syscall(remaining_gas, self.gas_costs().secp256r1_mul_gas_cost)?; Ok(Secp256Point::mul(p.into(), m).into()) } @@ -709,7 +676,7 @@ impl<'state> StarknetSyscallHandler for &mut NativeSyscallHandler<'state> { ) -> SyscallResult> { self.pre_execute_syscall( remaining_gas, - self.context.gas_costs().secp256r1_get_point_from_x_gas_cost, + self.gas_costs().secp256r1_get_point_from_x_gas_cost, )?; Secp256Point::get_point_from_x(x, y_parity) @@ -722,10 +689,7 @@ impl<'state> StarknetSyscallHandler for &mut NativeSyscallHandler<'state> { p: Secp256r1Point, remaining_gas: &mut u64, ) -> SyscallResult<(U256, U256)> { - self.pre_execute_syscall( - remaining_gas, - self.context.gas_costs().secp256r1_get_xy_gas_cost, - )?; + self.pre_execute_syscall(remaining_gas, self.gas_costs().secp256r1_get_xy_gas_cost)?; Ok((p.x, p.y)) } @@ -736,10 +700,7 @@ impl<'state> StarknetSyscallHandler for &mut NativeSyscallHandler<'state> { current_block: &[u32; 16], remaining_gas: &mut u64, ) -> SyscallResult<()> { - self.pre_execute_syscall( - remaining_gas, - self.context.gas_costs().sha256_process_block_gas_cost, - )?; + self.pre_execute_syscall(remaining_gas, self.gas_costs().sha256_process_block_gas_cost)?; let data_as_bytes = sha2::digest::generic_array::GenericArray::from_exact_iter( current_block.iter().flat_map(|x| x.to_be_bytes()), diff --git a/crates/blockifier/src/execution/syscalls/hint_processor.rs b/crates/blockifier/src/execution/syscalls/hint_processor.rs index 624004e242..e595bf606e 100644 --- a/crates/blockifier/src/execution/syscalls/hint_processor.rs +++ b/crates/blockifier/src/execution/syscalls/hint_processor.rs @@ -1,5 +1,5 @@ use std::any::Any; -use std::collections::{hash_map, HashMap, HashSet}; +use std::collections::{hash_map, HashMap}; use cairo_lang_casm::hints::{Hint, StarknetHint}; use cairo_lang_runner::casm_run::execute_core_hint_base; @@ -26,7 +26,7 @@ use starknet_types_core::felt::{Felt, FromStrError}; use thiserror::Error; use crate::abi::sierra_types::SierraTypeError; -use crate::execution::call_info::{CallInfo, OrderedEvent, OrderedL2ToL1Message}; +use crate::execution::call_info::CallInfo; use crate::execution::common_hints::{ExecutionMode, HintExecutionResult}; use crate::execution::entry_point::{CallEntryPoint, EntryPointExecutionContext}; use crate::execution::errors::{ConstructorEntryPointExecutionError, EntryPointExecutionError}; @@ -64,6 +64,7 @@ use crate::execution::syscalls::{ sha_256_process_block, storage_read, storage_write, + syscall_base, StorageReadResponse, StorageWriteResponse, SyscallRequest, @@ -76,6 +77,7 @@ use crate::execution::syscalls::{ use crate::state::errors::StateError; use crate::state::state_api::State; use crate::transaction::objects::{CurrentTransactionInfo, TransactionInfo}; +use crate::versioned_constants::GasCosts; pub type SyscallCounter = HashMap; @@ -212,33 +214,15 @@ pub const INVALID_ARGUMENT: &str = /// Executes Starknet syscalls (stateful protocol hints) during the execution of an entry point /// call. pub struct SyscallHintProcessor<'a> { - // Input for execution. - pub state: &'a mut dyn State, - pub context: &'a mut EntryPointExecutionContext, - pub call: CallEntryPoint, - - // Execution results. - /// Inner calls invoked by the current execution. - pub inner_calls: Vec, - pub events: Vec, - pub l2_to_l1_messages: Vec, + pub base: syscall_base::SyscallHandlerBase<'a>, + + // VM-specific fields. pub syscall_counter: SyscallCounter, // Fields needed for execution and validation. pub read_only_segments: ReadOnlySegments, pub syscall_ptr: Relocatable, - // Additional information gathered during execution. - pub read_values: Vec, - pub accessed_keys: HashSet, - pub read_class_hash_values: Vec, - // Accessed addresses by the `get_class_hash_at` syscall. - pub accessed_contract_addresses: HashSet, - - // The original storage value of the executed contract. - // Should be moved back `context.revert_info` before executing an inner call. - pub original_values: HashMap, - // Secp hint processors. pub secp256k1_hint_processor: SecpHintProcessor, pub secp256r1_hint_processor: SecpHintProcessor, @@ -261,29 +245,11 @@ impl<'a> SyscallHintProcessor<'a> { hints: &'a HashMap, read_only_segments: ReadOnlySegments, ) -> Self { - let original_values = std::mem::take( - &mut context - .revert_infos - .0 - .last_mut() - .expect("Missing contract revert info.") - .original_values, - ); SyscallHintProcessor { - state, - context, - call, - inner_calls: vec![], - events: vec![], - l2_to_l1_messages: vec![], + base: syscall_base::SyscallHandlerBase::new(call, state, context), syscall_counter: SyscallCounter::default(), read_only_segments, syscall_ptr: initial_syscall_ptr, - read_values: vec![], - accessed_keys: HashSet::new(), - read_class_hash_values: vec![], - accessed_contract_addresses: HashSet::new(), - original_values, hints, execution_info_ptr: None, secp256k1_hint_processor: SecpHintProcessor::default(), @@ -293,25 +259,29 @@ impl<'a> SyscallHintProcessor<'a> { } pub fn storage_address(&self) -> ContractAddress { - self.call.storage_address + self.base.call.storage_address } pub fn caller_address(&self) -> ContractAddress { - self.call.caller_address + self.base.call.caller_address } pub fn entry_point_selector(&self) -> EntryPointSelector { - self.call.entry_point_selector + self.base.call.entry_point_selector } pub fn execution_mode(&self) -> ExecutionMode { - self.context.execution_mode + self.base.context.execution_mode } pub fn is_validate_mode(&self) -> bool { self.execution_mode() == ExecutionMode::Validate } + pub fn gas_costs(&self) -> &GasCosts { + self.base.context.gas_costs() + } + /// Infers and executes the next syscall. /// Must comply with the API of a hint function, as defined by the `HintProcessor`. pub fn execute_next_syscall( @@ -334,115 +304,91 @@ impl<'a> SyscallHintProcessor<'a> { } match selector { - SyscallSelector::CallContract => self.execute_syscall( - vm, - call_contract, - self.context.gas_costs().call_contract_gas_cost, - ), + SyscallSelector::CallContract => { + self.execute_syscall(vm, call_contract, self.gas_costs().call_contract_gas_cost) + } SyscallSelector::Deploy => { - self.execute_syscall(vm, deploy, self.context.gas_costs().deploy_gas_cost) + self.execute_syscall(vm, deploy, self.gas_costs().deploy_gas_cost) } SyscallSelector::EmitEvent => { - self.execute_syscall(vm, emit_event, self.context.gas_costs().emit_event_gas_cost) + self.execute_syscall(vm, emit_event, self.gas_costs().emit_event_gas_cost) + } + SyscallSelector::GetBlockHash => { + self.execute_syscall(vm, get_block_hash, self.gas_costs().get_block_hash_gas_cost) } - SyscallSelector::GetBlockHash => self.execute_syscall( - vm, - get_block_hash, - self.context.gas_costs().get_block_hash_gas_cost, - ), SyscallSelector::GetClassHashAt => self.execute_syscall( vm, get_class_hash_at, - self.context.gas_costs().get_class_hash_at_gas_cost, + self.gas_costs().get_class_hash_at_gas_cost, ), SyscallSelector::GetExecutionInfo => self.execute_syscall( vm, get_execution_info, - self.context.gas_costs().get_execution_info_gas_cost, + self.gas_costs().get_execution_info_gas_cost, ), SyscallSelector::Keccak => { - self.execute_syscall(vm, keccak, self.context.gas_costs().keccak_gas_cost) + self.execute_syscall(vm, keccak, self.gas_costs().keccak_gas_cost) } SyscallSelector::Sha256ProcessBlock => self.execute_syscall( vm, sha_256_process_block, - self.context.gas_costs().sha256_process_block_gas_cost, - ), - SyscallSelector::LibraryCall => self.execute_syscall( - vm, - library_call, - self.context.gas_costs().library_call_gas_cost, - ), - SyscallSelector::ReplaceClass => self.execute_syscall( - vm, - replace_class, - self.context.gas_costs().replace_class_gas_cost, - ), - SyscallSelector::Secp256k1Add => self.execute_syscall( - vm, - secp256k1_add, - self.context.gas_costs().secp256k1_add_gas_cost, + self.gas_costs().sha256_process_block_gas_cost, ), + SyscallSelector::LibraryCall => { + self.execute_syscall(vm, library_call, self.gas_costs().library_call_gas_cost) + } + SyscallSelector::ReplaceClass => { + self.execute_syscall(vm, replace_class, self.gas_costs().replace_class_gas_cost) + } + SyscallSelector::Secp256k1Add => { + self.execute_syscall(vm, secp256k1_add, self.gas_costs().secp256k1_add_gas_cost) + } SyscallSelector::Secp256k1GetPointFromX => self.execute_syscall( vm, secp256k1_get_point_from_x, - self.context.gas_costs().secp256k1_get_point_from_x_gas_cost, + self.gas_costs().secp256k1_get_point_from_x_gas_cost, ), SyscallSelector::Secp256k1GetXy => self.execute_syscall( vm, secp256k1_get_xy, - self.context.gas_costs().secp256k1_get_xy_gas_cost, - ), - SyscallSelector::Secp256k1Mul => self.execute_syscall( - vm, - secp256k1_mul, - self.context.gas_costs().secp256k1_mul_gas_cost, - ), - SyscallSelector::Secp256k1New => self.execute_syscall( - vm, - secp256k1_new, - self.context.gas_costs().secp256k1_new_gas_cost, - ), - SyscallSelector::Secp256r1Add => self.execute_syscall( - vm, - secp256r1_add, - self.context.gas_costs().secp256r1_add_gas_cost, + self.gas_costs().secp256k1_get_xy_gas_cost, ), + SyscallSelector::Secp256k1Mul => { + self.execute_syscall(vm, secp256k1_mul, self.gas_costs().secp256k1_mul_gas_cost) + } + SyscallSelector::Secp256k1New => { + self.execute_syscall(vm, secp256k1_new, self.gas_costs().secp256k1_new_gas_cost) + } + SyscallSelector::Secp256r1Add => { + self.execute_syscall(vm, secp256r1_add, self.gas_costs().secp256r1_add_gas_cost) + } SyscallSelector::Secp256r1GetPointFromX => self.execute_syscall( vm, secp256r1_get_point_from_x, - self.context.gas_costs().secp256r1_get_point_from_x_gas_cost, + self.gas_costs().secp256r1_get_point_from_x_gas_cost, ), SyscallSelector::Secp256r1GetXy => self.execute_syscall( vm, secp256r1_get_xy, - self.context.gas_costs().secp256r1_get_xy_gas_cost, - ), - SyscallSelector::Secp256r1Mul => self.execute_syscall( - vm, - secp256r1_mul, - self.context.gas_costs().secp256r1_mul_gas_cost, - ), - SyscallSelector::Secp256r1New => self.execute_syscall( - vm, - secp256r1_new, - self.context.gas_costs().secp256r1_new_gas_cost, + self.gas_costs().secp256r1_get_xy_gas_cost, ), + SyscallSelector::Secp256r1Mul => { + self.execute_syscall(vm, secp256r1_mul, self.gas_costs().secp256r1_mul_gas_cost) + } + SyscallSelector::Secp256r1New => { + self.execute_syscall(vm, secp256r1_new, self.gas_costs().secp256r1_new_gas_cost) + } SyscallSelector::SendMessageToL1 => self.execute_syscall( vm, send_message_to_l1, - self.context.gas_costs().send_message_to_l1_gas_cost, - ), - SyscallSelector::StorageRead => self.execute_syscall( - vm, - storage_read, - self.context.gas_costs().storage_read_gas_cost, - ), - SyscallSelector::StorageWrite => self.execute_syscall( - vm, - storage_write, - self.context.gas_costs().storage_write_gas_cost, + self.gas_costs().send_message_to_l1_gas_cost, ), + SyscallSelector::StorageRead => { + self.execute_syscall(vm, storage_read, self.gas_costs().storage_read_gas_cost) + } + SyscallSelector::StorageWrite => { + self.execute_syscall(vm, storage_write, self.gas_costs().storage_write_gas_cost) + } _ => Err(HintError::UnknownHint( format!("Unsupported syscall selector {selector:?}.").into(), )), @@ -515,7 +461,7 @@ impl<'a> SyscallHintProcessor<'a> { ) -> SyscallResult, { // Refund `SYSCALL_BASE_GAS_COST` as it was pre-charged. - let required_gas = syscall_gas_cost - self.context.gas_costs().syscall_base_gas_cost; + let required_gas = syscall_gas_cost - self.base.context.gas_costs().syscall_base_gas_cost; let SyscallRequestWrapper { gas_counter, request } = SyscallRequestWrapper::::read(vm, &mut self.syscall_ptr)?; @@ -586,10 +532,10 @@ impl<'a> SyscallHintProcessor<'a> { &mut self, vm: &mut VirtualMachine, ) -> SyscallResult { - let block_info = &self.context.tx_context.block_context.block_info; + let block_info = &self.base.context.tx_context.block_context.block_info; let block_timestamp = block_info.block_timestamp.0; let block_number = block_info.block_number.0; - let versioned_constants = self.context.versioned_constants(); + let versioned_constants = self.base.context.versioned_constants(); let block_data: Vec = if self.is_validate_mode() { // Round down to the nearest multiple of validate_block_number_rounding. let validate_block_number_rounding = @@ -626,7 +572,7 @@ impl<'a> SyscallHintProcessor<'a> { } fn allocate_tx_info_segment(&mut self, vm: &mut VirtualMachine) -> SyscallResult { - let tx_info = &self.context.tx_context.clone().tx_info; + let tx_info = &self.base.context.tx_context.clone().tx_info; let (tx_signature_start_ptr, tx_signature_end_ptr) = &self.allocate_data_segment(vm, &tx_info.signature().0)?; @@ -638,7 +584,7 @@ impl<'a> SyscallHintProcessor<'a> { tx_signature_end_ptr.into(), (tx_info).transaction_hash().0.into(), Felt::from_hex( - self.context.tx_context.block_context.chain_info.chain_id.as_hex().as_str(), + self.base.context.tx_context.block_context.chain_info.chain_id.as_hex().as_str(), )? .into(), (tx_info).nonce().0.into(), @@ -691,9 +637,9 @@ impl<'a> SyscallHintProcessor<'a> { &mut self, key: StorageKey, ) -> SyscallResult { - self.accessed_keys.insert(key); - let value = self.state.get_storage_at(self.storage_address(), key)?; - self.read_values.push(value); + self.base.accessed_keys.insert(key); + let value = self.base.state.get_storage_at(self.storage_address(), key)?; + self.base.read_values.push(value); Ok(StorageReadResponse { value }) } @@ -705,44 +651,45 @@ impl<'a> SyscallHintProcessor<'a> { ) -> SyscallResult { let contract_address = self.storage_address(); - match self.original_values.entry(key) { + match self.base.original_values.entry(key) { hash_map::Entry::Vacant(entry) => { - entry.insert(self.state.get_storage_at(contract_address, key)?); + entry.insert(self.base.state.get_storage_at(contract_address, key)?); } hash_map::Entry::Occupied(_) => {} } - self.accessed_keys.insert(key); - self.state.set_storage_at(contract_address, key, value)?; + self.base.accessed_keys.insert(key); + self.base.state.set_storage_at(contract_address, key, value)?; Ok(StorageWriteResponse {}) } pub fn finalize(&mut self) { - self.context + self.base + .context .revert_infos .0 .last_mut() .expect("Missing contract revert info.") - .original_values = std::mem::take(&mut self.original_values); + .original_values = std::mem::take(&mut self.base.original_values); } } impl ResourceTracker for SyscallHintProcessor<'_> { fn consumed(&self) -> bool { - self.context.vm_run_resources.consumed() + self.base.context.vm_run_resources.consumed() } fn consume_step(&mut self) { - self.context.vm_run_resources.consume_step() + self.base.context.vm_run_resources.consume_step() } fn get_n_steps(&self) -> Option { - self.context.vm_run_resources.get_n_steps() + self.base.context.vm_run_resources.get_n_steps() } fn run_resources(&self) -> &RunResources { - self.context.vm_run_resources.run_resources() + self.base.context.vm_run_resources.run_resources() } } @@ -803,18 +750,19 @@ pub fn execute_inner_call( syscall_handler: &mut SyscallHintProcessor<'_>, remaining_gas: &mut u64, ) -> SyscallResult { - let revert_idx = syscall_handler.context.revert_infos.0.len(); + let revert_idx = syscall_handler.base.context.revert_infos.0.len(); - let call_info = call.execute(syscall_handler.state, syscall_handler.context, remaining_gas)?; + let call_info = + call.execute(syscall_handler.base.state, syscall_handler.base.context, remaining_gas)?; let mut raw_retdata = call_info.execution.retdata.0.clone(); let failed = call_info.execution.failed; - syscall_handler.inner_calls.push(call_info); + syscall_handler.base.inner_calls.push(call_info); if failed { - syscall_handler.context.revert(revert_idx, syscall_handler.state)?; + syscall_handler.base.context.revert(revert_idx, syscall_handler.base.state)?; // Delete events and l2_to_l1_messages from the reverted call. - let reverted_call = &mut syscall_handler.inner_calls.last_mut().unwrap(); + let reverted_call = &mut syscall_handler.base.inner_calls.last_mut().unwrap(); let mut stack: Vec<&mut CallInfo> = vec![reverted_call]; while let Some(call_info) = stack.pop() { call_info.execution.events.clear(); diff --git a/crates/blockifier/src/execution/syscalls/mod.rs b/crates/blockifier/src/execution/syscalls/mod.rs index df9cbcfbda..7e9ce0f3e1 100644 --- a/crates/blockifier/src/execution/syscalls/mod.rs +++ b/crates/blockifier/src/execution/syscalls/mod.rs @@ -170,7 +170,7 @@ pub fn call_contract( remaining_gas: &mut u64, ) -> SyscallResult { let storage_address = request.contract_address; - let class_hash = syscall_handler.state.get_class_hash_at(storage_address)?; + let class_hash = syscall_handler.base.state.get_class_hash_at(storage_address)?; let selector = request.function_selector; if syscall_handler.is_validate_mode() && syscall_handler.storage_address() != storage_address { return Err(SyscallExecutionError::InvalidSyscallInExecutionMode { @@ -267,8 +267,8 @@ pub fn deploy( caller_address: deployer_address, }; let call_info = execute_deployment( - syscall_handler.state, - syscall_handler.context, + syscall_handler.base.state, + syscall_handler.base.context, ctor_context, request.constructor_calldata, remaining_gas, @@ -276,7 +276,7 @@ pub fn deploy( let constructor_retdata = create_retdata_segment(vm, syscall_handler, &call_info.execution.retdata.0)?; - syscall_handler.inner_calls.push(call_info); + syscall_handler.base.inner_calls.push(call_info); Ok(DeployResponse { contract_address: deployed_contract_address, constructor_retdata }) } @@ -332,7 +332,7 @@ pub fn emit_event( syscall_handler: &mut SyscallHintProcessor<'_>, _remaining_gas: &mut u64, ) -> SyscallResult { - let execution_context = &mut syscall_handler.context; + let execution_context = &mut syscall_handler.base.context; exceeds_event_size_limit( execution_context.versioned_constants(), execution_context.n_emitted_events + 1, @@ -340,7 +340,7 @@ pub fn emit_event( )?; let ordered_event = OrderedEvent { order: execution_context.n_emitted_events, event: request.content }; - syscall_handler.events.push(ordered_event); + syscall_handler.base.events.push(ordered_event); execution_context.n_emitted_events += 1; Ok(EmitEventResponse {}) @@ -390,9 +390,9 @@ pub fn get_block_hash( _remaining_gas: &mut u64, ) -> SyscallResult { let block_hash = BlockHash(syscall_base::get_block_hash_base( - syscall_handler.context, + syscall_handler.base.context, request.block_number.0, - syscall_handler.state, + syscall_handler.base.state, )?); Ok(GetBlockHashResponse { block_hash }) } @@ -501,12 +501,12 @@ pub fn replace_class( ) -> SyscallResult { // Ensure the class is declared (by reading it), and of type V1. let class_hash = request.class_hash; - let class = syscall_handler.state.get_compiled_contract_class(class_hash)?; + let class = syscall_handler.base.state.get_compiled_contract_class(class_hash)?; if !is_cairo1(&class) { return Err(SyscallExecutionError::ForbiddenClassReplacement { class_hash }); } - syscall_handler.state.set_class_hash_at(syscall_handler.storage_address(), class_hash)?; + syscall_handler.base.state.set_class_hash_at(syscall_handler.storage_address(), class_hash)?; Ok(ReplaceClassResponse {}) } @@ -535,12 +535,12 @@ pub fn send_message_to_l1( syscall_handler: &mut SyscallHintProcessor<'_>, _remaining_gas: &mut u64, ) -> SyscallResult { - let execution_context = &mut syscall_handler.context; + let execution_context = &mut syscall_handler.base.context; let ordered_message_to_l1 = OrderedL2ToL1Message { order: execution_context.n_sent_messages_to_l1, message: request.message, }; - syscall_handler.l2_to_l1_messages.push(ordered_message_to_l1); + syscall_handler.base.l2_to_l1_messages.push(ordered_message_to_l1); execution_context.n_sent_messages_to_l1 += 1; Ok(SendMessageToL1Response {}) @@ -672,7 +672,7 @@ pub fn keccak( // TODO(Ori, 1/2/2024): Write an indicative expect message explaining why the conversion works. let n_rounds_as_u64 = u64::try_from(n_rounds).expect("Failed to convert usize to u64."); - let gas_cost = n_rounds_as_u64 * syscall_handler.context.gas_costs().keccak_round_cost_gas_cost; + let gas_cost = n_rounds_as_u64 * syscall_handler.gas_costs().keccak_round_cost_gas_cost; if gas_cost > *remaining_gas { let out_of_gas_error = Felt::from_hex(OUT_OF_GAS_ERROR).map_err(SyscallExecutionError::from)?; @@ -803,8 +803,8 @@ pub(crate) fn get_class_hash_at( syscall_handler: &mut SyscallHintProcessor<'_>, _remaining_gas: &mut u64, ) -> SyscallResult { - syscall_handler.accessed_contract_addresses.insert(request); - let class_hash = syscall_handler.state.get_class_hash_at(request)?; - syscall_handler.read_class_hash_values.push(class_hash); + syscall_handler.base.accessed_contract_addresses.insert(request); + let class_hash = syscall_handler.base.state.get_class_hash_at(request)?; + syscall_handler.base.read_class_hash_values.push(class_hash); Ok(class_hash) } diff --git a/crates/blockifier/src/execution/syscalls/syscall_base.rs b/crates/blockifier/src/execution/syscalls/syscall_base.rs index 15ce73bbad..b561561b8e 100644 --- a/crates/blockifier/src/execution/syscalls/syscall_base.rs +++ b/crates/blockifier/src/execution/syscalls/syscall_base.rs @@ -1,11 +1,15 @@ -/// This file is for sharing common logic between Native and Casm syscalls implementations. -use starknet_api::core::ContractAddress; +use std::collections::{HashMap, HashSet}; +use std::convert::From; +use std::hash::RandomState; + +use starknet_api::core::{ClassHash, ContractAddress}; use starknet_api::state::StorageKey; use starknet_types_core::felt::Felt; use crate::abi::constants; +use crate::execution::call_info::{CallInfo, OrderedEvent, OrderedL2ToL1Message}; use crate::execution::common_hints::ExecutionMode; -use crate::execution::entry_point::EntryPointExecutionContext; +use crate::execution::entry_point::{CallEntryPoint, EntryPointExecutionContext}; use crate::execution::syscalls::hint_processor::{ SyscallExecutionError, BLOCK_NUMBER_OUT_OF_RANGE_ERROR, @@ -14,6 +18,61 @@ use crate::state::state_api::State; pub type SyscallResult = Result; +/// This file is for sharing common logic between Native and VM syscall implementations. + +pub struct SyscallHandlerBase<'state> { + // Input for execution. + pub state: &'state mut dyn State, + pub context: &'state mut EntryPointExecutionContext, + pub call: CallEntryPoint, + + // Execution results. + pub events: Vec, + pub l2_to_l1_messages: Vec, + pub inner_calls: Vec, + + // Additional information gathered during execution. + pub read_values: Vec, + pub accessed_keys: HashSet, + pub read_class_hash_values: Vec, + // Accessed addresses by the `get_class_hash_at` syscall. + pub accessed_contract_addresses: HashSet, + + // The original storage value of the executed contract. + // Should be moved back `context.revert_info` before executing an inner call. + pub original_values: HashMap, +} + +impl<'state> SyscallHandlerBase<'state> { + pub fn new( + call: CallEntryPoint, + state: &'state mut dyn State, + context: &'state mut EntryPointExecutionContext, + ) -> SyscallHandlerBase<'state> { + let original_values = std::mem::take( + &mut context + .revert_infos + .0 + .last_mut() + .expect("Missing contract revert info.") + .original_values, + ); + SyscallHandlerBase { + state, + call, + context, + events: Vec::new(), + l2_to_l1_messages: Vec::new(), + inner_calls: Vec::new(), + read_values: Vec::new(), + accessed_keys: HashSet::new(), + read_class_hash_values: Vec::new(), + accessed_contract_addresses: HashSet::new(), + original_values, + } + } +} + pub fn get_block_hash_base( context: &EntryPointExecutionContext, requested_block_number: u64,