Skip to content

Commit

Permalink
feat: Support reverts of inner calls
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyalesokhin-starkware committed Sep 9, 2024
1 parent c65f0b9 commit 5c1b557
Show file tree
Hide file tree
Showing 8 changed files with 8,299 additions and 5,183 deletions.
13,293 changes: 8,129 additions & 5,164 deletions crates/blockifier/feature_contracts/cairo1/compiled/test_contract.casm.json

Large diffs are not rendered by default.

30 changes: 30 additions & 0 deletions crates/blockifier/feature_contracts/cairo1/test_contract.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,36 @@ mod TestContract {
.span()
}


#[external(v0)]
fn test_call_contract_revert(
ref self: ContractState,
contract_address: ContractAddress,
entry_point_selector: felt252,
calldata: Array::<felt252>
) {
match syscalls::call_contract_syscall(
contract_address, entry_point_selector, calldata.span())
{
Result::Ok(_) => panic!("Expected revert"),
Result::Err(errors) => {
let mut error_span = errors.span();
assert(
*error_span.pop_back().unwrap() == 'ENTRYPOINT_FAILED',
'Unexpected error',
);
},
};
assert(self.my_storage_var.read() == 0, 'values should not change.');
}


#[external(v0)]
fn test_revert_helper(ref self: ContractState) {
self.my_storage_var.write(17);
panic!("test_revert_helper");
}

#[external(v0)]
fn test_emit_events(
self: @ContractState, events_number: u64, keys: Array::<felt252>, data: Array::<felt252>
Expand Down
57 changes: 56 additions & 1 deletion crates/blockifier/src/execution/entry_point.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
use std::cell::RefCell;
use std::cmp::min;
use std::collections::HashMap;
use std::sync::Arc;

use cairo_vm::vm::runners::cairo_runner::{ExecutionResources, ResourceTracker, RunResources};
use num_traits::{Inv, Zero};
use serde::Serialize;
use starknet_api::core::{ClassHash, ContractAddress, EntryPointSelector};
use starknet_api::deprecated_contract_class::EntryPointType;
use starknet_api::state::StorageKey;
use starknet_api::transaction::{Calldata, TransactionVersion};
use starknet_types_core::felt::Felt;

Expand All @@ -21,7 +23,7 @@ use crate::execution::errors::{
PreExecutionError,
};
use crate::execution::execution_utils::execute_entry_point_call;
use crate::state::state_api::State;
use crate::state::state_api::{State, StateResult};
use crate::transaction::objects::{HasRelatedFeeType, TransactionInfo};
use crate::transaction::transaction_types::TransactionType;
use crate::utils::{u128_from_usize, usize_from_u128};
Expand All @@ -37,6 +39,50 @@ pub const FAULTY_CLASS_HASH: &str =
pub type EntryPointExecutionResult<T> = Result<T, EntryPointExecutionError>;
pub type ConstructorEntryPointExecutionResult<T> = Result<T, ConstructorEntryPointExecutionError>;

/// Holds the the information required to revert the execution of an entry point.
#[derive(Debug)]
pub struct EntryPointRevertInfo {
// The contract address that the revert info applies to.
pub contract_address: ContractAddress,
/// The original class hash of the contract that was called.
pub original_class_hash: ClassHash,
/// The original storage values.
pub orig_values: HashMap<StorageKey, Felt>,
}
impl EntryPointRevertInfo {
pub fn new(contract_address: ContractAddress, original_class_hash: ClassHash) -> Self {
Self { contract_address, original_class_hash, orig_values: HashMap::new() }
}
}

/// The ExecutionRevertInfo stores a vector of entry point revert infos.
/// We don't merge infos related same contract as doing it on every nesting level would
/// result in O(N^2) complexity.
#[derive(Default, Debug)]
pub struct ExecutionRevertInfo(pub Vec<EntryPointRevertInfo>);

impl ExecutionRevertInfo {
/// Reverts the state back to the way it was when self.0[revert_idx] was created.
pub fn revert(&mut self, revert_idx: usize, state: &mut dyn State) -> StateResult<()> {
for contract_revert_info in self.0.drain(revert_idx..).rev() {
for (key, value) in contract_revert_info.orig_values.iter() {
state.set_storage_at(contract_revert_info.contract_address, *key, *value)?;
}
state.set_class_hash_at(
contract_revert_info.contract_address,
contract_revert_info.original_class_hash,
)?;
}

Ok(())
}
}

pub struct ExecutionResult {
pub call_info: CallInfo,
pub revert_info: ExecutionRevertInfo,
}

/// Represents a the type of the call (used for debugging).
#[cfg_attr(feature = "transaction_serde", derive(serde::Deserialize))]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Serialize)]
Expand Down Expand Up @@ -89,6 +135,7 @@ impl CallEntryPoint {
Some(class_hash) => class_hash,
None => storage_class_hash, // If not given, take the storage contract class hash.
};

// Hack to prevent version 0 attack on argent accounts.
if tx_context.tx_info.version() == TransactionVersion::ZERO
&& class_hash
Expand All @@ -102,6 +149,11 @@ impl CallEntryPoint {
self.class_hash = Some(class_hash);
let contract_class = state.get_compiled_contract_class(class_hash)?;

context
.revert_infos
.0
.push(EntryPointRevertInfo::new(self.storage_address, storage_class_hash));

execute_entry_point_call(self, contract_class, state, resources, context)
}
}
Expand Down Expand Up @@ -130,6 +182,8 @@ pub struct EntryPointExecutionContext {

// The execution mode affects the behavior of the hint processor.
pub execution_mode: ExecutionMode,

pub revert_infos: ExecutionRevertInfo,
}

impl EntryPointExecutionContext {
Expand All @@ -146,6 +200,7 @@ impl EntryPointExecutionContext {
tx_context: tx_context.clone(),
current_recursion_depth: Default::default(),
execution_mode: mode,
revert_infos: ExecutionRevertInfo(vec![]),
}
}

Expand Down
9 changes: 3 additions & 6 deletions crates/blockifier/src/execution/entry_point_execution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,6 @@ pub fn execute_entry_point_call(
n_total_args,
program_extra_data_length,
)?;
if call_info.execution.failed {
return Err(EntryPointExecutionError::ExecutionFailed {
error_data: call_info.execution.retdata.0,
});
}

Ok(call_info)
}
Expand Down Expand Up @@ -369,7 +364,7 @@ fn maybe_fill_holes(

pub fn finalize_execution(
mut runner: CairoRunner,
syscall_handler: SyscallHintProcessor<'_>,
mut syscall_handler: SyscallHintProcessor<'_>,
previous_resources: ExecutionResources,
n_total_args: usize,
program_extra_data_length: usize,
Expand Down Expand Up @@ -409,6 +404,8 @@ pub fn finalize_execution(
*syscall_handler.resources += &versioned_constants
.get_additional_os_syscall_resources(&syscall_handler.syscall_counter)?;

syscall_handler.finalize();

let full_call_resources = &*syscall_handler.resources - &previous_resources;
Ok(CallInfo {
call: syscall_handler.call,
Expand Down
51 changes: 44 additions & 7 deletions crates/blockifier/src/execution/syscalls/hint_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,9 @@ pub const L1_GAS: &str = "0x0000000000000000000000000000000000000000000000000000
pub const L2_GAS: &str = "0x00000000000000000000000000000000000000000000000000004c325f474153";
// "L1_DATA";
pub const L1_DATA: &str = "0x000000000000000000000000000000000000000000000000004c315f44415441";
// "ENTRYPOINT_FAILED";
pub const ENTRYPOINT_FAILED_ERROR: &str =
"0x000000000000000000000000000000454e545259504f494e545f4641494c4544";

/// Executes Starknet syscalls (stateful protocol hints) during the execution of an entry point
/// call.
Expand All @@ -232,6 +235,10 @@ pub struct SyscallHintProcessor<'a> {
pub read_values: Vec<Felt>,
pub accessed_keys: HashSet<StorageKey>,

// The original storage value of the executed contract.
// Should be moved back `context.revert_info` before executing an inner call.
pub orig_values: HashMap<StorageKey, Felt>,

// Secp hint processors.
pub secp256k1_hint_processor: SecpHintProcessor<ark_secp256k1::Config>,
pub secp256r1_hint_processor: SecpHintProcessor<ark_secp256r1::Config>,
Expand All @@ -254,6 +261,14 @@ impl<'a> SyscallHintProcessor<'a> {
hints: &'a HashMap<String, Hint>,
read_only_segments: ReadOnlySegments,
) -> Self {
let orig_values = std::mem::take(
&mut context
.revert_infos
.0
.last_mut()
.expect("Missing contract revert info.")
.orig_values,
);
SyscallHintProcessor {
state,
resources,
Expand All @@ -267,6 +282,7 @@ impl<'a> SyscallHintProcessor<'a> {
syscall_ptr: initial_syscall_ptr,
read_values: vec![],
accessed_keys: HashSet::new(),
orig_values,
hints,
execution_info_ptr: None,
secp256k1_hint_processor: SecpHintProcessor::default(),
Expand Down Expand Up @@ -698,11 +714,24 @@ impl<'a> SyscallHintProcessor<'a> {
key: StorageKey,
value: Felt,
) -> SyscallResult<StorageWriteResponse> {
let contract_address = self.storage_address();
self.orig_values
.entry(key)
.or_insert_with(|| self.state.get_storage_at(contract_address, key).unwrap());
self.accessed_keys.insert(key);
self.state.set_storage_at(self.storage_address(), key, value)?;
self.state.set_storage_at(contract_address, key, value)?;

Ok(StorageWriteResponse {})
}

pub fn finalize(&mut self) {
self.context
.revert_infos
.0
.last_mut()
.expect("Missing contract revert info.")
.orig_values = std::mem::take(&mut self.orig_values);
}
}

/// Retrieves a [Relocatable] from the VM given a [ResOperand].
Expand Down Expand Up @@ -800,21 +829,29 @@ pub fn execute_inner_call(
syscall_handler: &mut SyscallHintProcessor<'_>,
remaining_gas: &mut u64,
) -> SyscallResult<ReadOnlySegment> {
let revert_idx = syscall_handler.context.revert_infos.0.len();

let call_info =
call.execute(syscall_handler.state, syscall_handler.resources, syscall_handler.context)?;
let raw_retdata = &call_info.execution.retdata.0;

if call_info.execution.failed {
// TODO(spapini): Append an error word according to starknet spec if needed.
// Something like "EXECUTION_ERROR".
return Err(SyscallExecutionError::SyscallError { error_data: raw_retdata.clone() });
}
syscall_handler.context.revert_infos.revert(revert_idx, syscall_handler.state)?;
};

let mut raw_retdata = call_info.execution.retdata.0.clone();

let retdata_segment = create_retdata_segment(vm, syscall_handler, raw_retdata)?;
update_remaining_gas(remaining_gas, &call_info);

let failed = call_info.execution.failed;
syscall_handler.inner_calls.push(call_info);

if failed {
raw_retdata
.push(Felt::from_hex(ENTRYPOINT_FAILED_ERROR).map_err(SyscallExecutionError::from)?);
return Err(SyscallExecutionError::SyscallError { error_data: raw_retdata });
}

let retdata_segment = create_retdata_segment(vm, syscall_handler, &raw_retdata)?;
Ok(retdata_segment)
}

Expand Down
6 changes: 4 additions & 2 deletions crates/blockifier/src/execution/syscalls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,11 @@ pub fn call_contract(
call_type: CallType::Call,
initial_gas: *remaining_gas,
};

let retdata_segment = execute_inner_call(entry_point, vm, syscall_handler, remaining_gas)
.map_err(|error| {
error.as_call_contract_execution_error(class_hash, storage_address, selector)
.map_err(|error| match error {
SyscallExecutionError::SyscallError { .. } => error,
_ => error.as_call_contract_execution_error(class_hash, storage_address, selector),
})?;

Ok(CallContractResponse { segment: retdata_segment })
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,36 @@ use crate::test_utils::contracts::FeatureContract;
use crate::test_utils::initial_test_state::test_state;
use crate::test_utils::{create_calldata, trivial_external_entry_point_new, CairoVersion, BALANCE};

#[test]
fn test_call_contract_that_panics() {
let test_contract = FeatureContract::TestContract(CairoVersion::Cairo1);
let chain_info = &ChainInfo::create_for_testing();
let mut state = test_state(chain_info, BALANCE, &[(test_contract, 1)]);

let outer_entry_point_selector = selector_from_name("test_call_contract_revert");
let calldata = create_calldata(
FeatureContract::TestContract(CairoVersion::Cairo1).get_instance_address(0),
"test_revert_helper",
&[],
);
let entry_point_call = CallEntryPoint {
entry_point_selector: outer_entry_point_selector,
calldata,
..trivial_external_entry_point_new(test_contract)
};

let res = entry_point_call.execute_directly(&mut state).unwrap();
assert_eq!(
res.execution,
CallExecution {
retdata: retdata![],
gas_consumed: 164420,
failed: false,
..CallExecution::default()
}
);
}

#[test_case(
FeatureContract::TestContract(CairoVersion::Cairo1),
FeatureContract::TestContract(CairoVersion::Cairo1),
Expand Down
6 changes: 3 additions & 3 deletions crates/blockifier/src/transaction/transactions_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1853,11 +1853,11 @@ fn test_l1_handler(#[values(false, true)] use_kzg_da: bool) {
},
execution: CallExecution {
retdata: Retdata(vec![value]),
gas_consumed: 6820,
gas_consumed: 6120,
..Default::default()
},
resources: ExecutionResources {
n_steps: 158,
n_steps: 151,
n_memory_holes: 0,
builtin_instance_counter: HashMap::from([(BuiltinName::range_check, 6)]),
},
Expand Down Expand Up @@ -1893,7 +1893,7 @@ fn test_l1_handler(#[values(false, true)] use_kzg_da: bool) {
+ 6,
),
]),
n_steps: get_tx_resources(TransactionType::L1Handler).n_steps + 171,
n_steps: get_tx_resources(TransactionType::L1Handler).n_steps + 164,
n_memory_holes: 0,
};

Expand Down

0 comments on commit 5c1b557

Please sign in to comment.