Skip to content

Commit

Permalink
refactor(blockifier): rename ContractClass to CompiledClass (#2296)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yoni-Starkware authored Nov 27, 2024
1 parent e5576a4 commit cc5c84b
Show file tree
Hide file tree
Showing 44 changed files with 255 additions and 308 deletions.
2 changes: 1 addition & 1 deletion crates/blockifier/src/blockifier/transaction_executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ impl<S: StateReader> TransactionExecutor<S> {
.block_state
.as_ref()
.expect(BLOCK_STATE_ACCESS_ERR)
.get_compiled_contract_class(*class_hash)?;
.get_compiled_class(*class_hash)?;
Ok((*class_hash, contract_class.get_visited_segments(class_visited_pcs)?))
})
.collect::<TransactionExecutorResult<_>>()?;
Expand Down
2 changes: 1 addition & 1 deletion crates/blockifier/src/bouncer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,7 @@ pub fn get_casm_hash_calculation_resources<S: StateReader>(
let mut casm_hash_computation_resources = ExecutionResources::default();

for class_hash in executed_class_hashes {
let class = state_reader.get_compiled_contract_class(*class_hash)?;
let class = state_reader.get_compiled_class(*class_hash)?;
casm_hash_computation_resources += &class.estimate_casm_hash_computation_resources();
}

Expand Down
11 changes: 4 additions & 7 deletions crates/blockifier/src/concurrency/versioned_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use starknet_types_core::felt::Felt;

use crate::concurrency::versioned_storage::VersionedStorage;
use crate::concurrency::TxIndex;
use crate::execution::contract_class::RunnableContractClass;
use crate::execution::contract_class::RunnableCompiledClass;
use crate::state::cached_state::{ContractClassMapping, StateMaps};
use crate::state::errors::StateError;
use crate::state::state_api::{StateReader, StateResult, UpdatableState};
Expand All @@ -34,7 +34,7 @@ pub struct VersionedState<S: StateReader> {
// the compiled contract classes mapping. Each key with value false, sohuld not apprear
// in the compiled contract classes mapping.
declared_contracts: VersionedStorage<ClassHash, bool>,
compiled_contract_classes: VersionedStorage<ClassHash, RunnableContractClass>,
compiled_contract_classes: VersionedStorage<ClassHash, RunnableCompiledClass>,
}

impl<S: StateReader> VersionedState<S> {
Expand Down Expand Up @@ -336,14 +336,11 @@ impl<S: StateReader> StateReader for VersionedStateProxy<S> {
}
}

fn get_compiled_contract_class(
&self,
class_hash: ClassHash,
) -> StateResult<RunnableContractClass> {
fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult<RunnableCompiledClass> {
let mut state = self.state();
match state.compiled_contract_classes.read(self.tx_index, class_hash) {
Some(value) => Ok(value),
None => match state.initial_state.get_compiled_contract_class(class_hash) {
None => match state.initial_state.get_compiled_class(class_hash) {
Ok(initial_value) => {
state.declared_contracts.set_initial_value(class_hash, true);
state
Expand Down
20 changes: 6 additions & 14 deletions crates/blockifier/src/concurrency/versioned_state_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,9 @@ fn test_versioned_state_proxy() {
versioned_state_proxys[5].get_compiled_class_hash(class_hash).unwrap(),
compiled_class_hash
);
assert_eq!(
versioned_state_proxys[7].get_compiled_contract_class(class_hash).unwrap(),
contract_class
);
assert_eq!(versioned_state_proxys[7].get_compiled_class(class_hash).unwrap(), contract_class);
assert_matches!(
versioned_state_proxys[7].get_compiled_contract_class(another_class_hash).unwrap_err(),
versioned_state_proxys[7].get_compiled_class(another_class_hash).unwrap_err(),
StateError::UndeclaredClassHash(class_hash) if
another_class_hash == class_hash
);
Expand Down Expand Up @@ -197,7 +194,7 @@ fn test_versioned_state_proxy() {
compiled_class_hash_v18
);
assert_eq!(
versioned_state_proxys[15].get_compiled_contract_class(class_hash).unwrap(),
versioned_state_proxys[15].get_compiled_class(class_hash).unwrap(),
contract_class_v11
);
}
Expand Down Expand Up @@ -321,7 +318,7 @@ fn test_validate_reads(

assert!(transactional_state.cache.borrow().initial_reads.declared_contracts.is_empty());
assert_matches!(
transactional_state.get_compiled_contract_class(class_hash),
transactional_state.get_compiled_class(class_hash),
Err(StateError::UndeclaredClassHash(err_class_hash)) if
err_class_hash == class_hash
);
Expand Down Expand Up @@ -440,10 +437,7 @@ fn test_apply_writes(
&HashMap::default(),
);
assert!(transactional_states[1].get_class_hash_at(contract_address).unwrap() == class_hash_0);
assert!(
transactional_states[1].get_compiled_contract_class(class_hash).unwrap()
== contract_class_0
);
assert!(transactional_states[1].get_compiled_class(class_hash).unwrap() == contract_class_0);
}

#[rstest]
Expand Down Expand Up @@ -659,7 +653,5 @@ fn test_versioned_proxy_state_flow(
.commit_chunk_and_recover_block_state(4, HashMap::new());

assert!(modified_block_state.get_class_hash_at(contract_address).unwrap() == class_hash_3);
assert!(
modified_block_state.get_compiled_contract_class(class_hash).unwrap() == contract_class_2
);
assert!(modified_block_state.get_compiled_class(class_hash).unwrap() == contract_class_2);
}
53 changes: 27 additions & 26 deletions crates/blockifier/src/execution/contract_class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ use crate::execution::entry_point::CallEntryPoint;
use crate::execution::errors::PreExecutionError;
use crate::execution::execution_utils::{poseidon_hash_many_cost, sn_api_to_cairo_vm_program};
#[cfg(feature = "cairo_native")]
use crate::execution::native::contract_class::NativeContractClassV1;
use crate::execution::native::contract_class::NativeCompiledClassV1;
use crate::transaction::errors::TransactionExecutionError;
use crate::versioned_constants::CompilerVersion;

Expand All @@ -58,16 +58,17 @@ pub enum TrackedResource {
SierraGas, // AKA Sierra mode.
}

/// Represents a runnable Starknet contract class (meaning, the program is runnable by the VM).
/// Represents a runnable Starknet compiled class.
/// Meaning, the program is runnable by the VM (or natively).
#[derive(Clone, Debug, Eq, PartialEq, derive_more::From)]
pub enum RunnableContractClass {
V0(ContractClassV0),
V1(ContractClassV1),
pub enum RunnableCompiledClass {
V0(CompiledClassV0),
V1(CompiledClassV1),
#[cfg(feature = "cairo_native")]
V1Native(NativeContractClassV1),
V1Native(NativeCompiledClassV1),
}

impl TryFrom<ContractClass> for RunnableContractClass {
impl TryFrom<ContractClass> for RunnableCompiledClass {
type Error = ProgramError;

fn try_from(raw_contract_class: ContractClass) -> Result<Self, Self::Error> {
Expand All @@ -80,7 +81,7 @@ impl TryFrom<ContractClass> for RunnableContractClass {
}
}

impl RunnableContractClass {
impl RunnableCompiledClass {
pub fn constructor_selector(&self) -> Option<EntryPointSelector> {
match self {
Self::V0(class) => class.constructor_selector(),
Expand Down Expand Up @@ -156,16 +157,16 @@ impl RunnableContractClass {
// Note: when deserializing from a SN API class JSON string, the ABI field is ignored
// by serde, since it is not required for execution.
#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq)]
pub struct ContractClassV0(pub Arc<ContractClassV0Inner>);
impl Deref for ContractClassV0 {
type Target = ContractClassV0Inner;
pub struct CompiledClassV0(pub Arc<CompiledClassV0Inner>);
impl Deref for CompiledClassV0 {
type Target = CompiledClassV0Inner;

fn deref(&self) -> &Self::Target {
&self.0
}
}

impl ContractClassV0 {
impl CompiledClassV0 {
fn constructor_selector(&self) -> Option<EntryPointSelector> {
Some(self.entry_points_by_type[&EntryPointType::Constructor].first()?.selector)
}
Expand Down Expand Up @@ -201,24 +202,24 @@ impl ContractClassV0 {
TrackedResource::CairoSteps
}

pub fn try_from_json_string(raw_contract_class: &str) -> Result<ContractClassV0, ProgramError> {
let contract_class: ContractClassV0Inner = serde_json::from_str(raw_contract_class)?;
Ok(ContractClassV0(Arc::new(contract_class)))
pub fn try_from_json_string(raw_contract_class: &str) -> Result<CompiledClassV0, ProgramError> {
let contract_class: CompiledClassV0Inner = serde_json::from_str(raw_contract_class)?;
Ok(CompiledClassV0(Arc::new(contract_class)))
}
}

#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq)]
pub struct ContractClassV0Inner {
pub struct CompiledClassV0Inner {
#[serde(deserialize_with = "deserialize_program")]
pub program: Program,
pub entry_points_by_type: HashMap<EntryPointType, Vec<EntryPointV0>>,
}

impl TryFrom<DeprecatedContractClass> for ContractClassV0 {
impl TryFrom<DeprecatedContractClass> for CompiledClassV0 {
type Error = ProgramError;

fn try_from(class: DeprecatedContractClass) -> Result<Self, Self::Error> {
Ok(Self(Arc::new(ContractClassV0Inner {
Ok(Self(Arc::new(CompiledClassV0Inner {
program: sn_api_to_cairo_vm_program(class.program)?,
entry_points_by_type: class.entry_points_by_type,
})))
Expand All @@ -227,20 +228,20 @@ impl TryFrom<DeprecatedContractClass> for ContractClassV0 {

// V1.

/// Represents a runnable Cario (Cairo 1) Starknet contract class (meaning, the program is runnable
/// Represents a runnable Cario (Cairo 1) Starknet compiled class (meaning, the program is runnable
/// by the VM). We wrap the actual class in an Arc to avoid cloning the program when cloning the
/// class.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ContractClassV1(pub Arc<ContractClassV1Inner>);
impl Deref for ContractClassV1 {
pub struct CompiledClassV1(pub Arc<ContractClassV1Inner>);
impl Deref for CompiledClassV1 {
type Target = ContractClassV1Inner;

fn deref(&self) -> &Self::Target {
&self.0
}
}

impl ContractClassV1 {
impl CompiledClassV1 {
pub fn constructor_selector(&self) -> Option<EntryPointSelector> {
self.0.entry_points_by_type.constructor.first().map(|ep| ep.selector)
}
Expand Down Expand Up @@ -286,9 +287,9 @@ impl ContractClassV1 {
get_visited_segments(&self.bytecode_segment_lengths, &mut reversed_visited_pcs, &mut 0)
}

pub fn try_from_json_string(raw_contract_class: &str) -> Result<ContractClassV1, ProgramError> {
pub fn try_from_json_string(raw_contract_class: &str) -> Result<CompiledClassV1, ProgramError> {
let casm_contract_class: CasmContractClass = serde_json::from_str(raw_contract_class)?;
let contract_class = ContractClassV1::try_from(casm_contract_class)?;
let contract_class = CompiledClassV1::try_from(casm_contract_class)?;

Ok(contract_class)
}
Expand Down Expand Up @@ -413,7 +414,7 @@ impl HasSelector for EntryPointV1 {
}
}

impl TryFrom<CasmContractClass> for ContractClassV1 {
impl TryFrom<CasmContractClass> for CompiledClassV1 {
type Error = ProgramError;

fn try_from(class: CasmContractClass) -> Result<Self, Self::Error> {
Expand Down Expand Up @@ -466,7 +467,7 @@ impl TryFrom<CasmContractClass> for ContractClassV1 {
Version::parse(&class.compiler_version)
.unwrap_or_else(|_| panic!("Invalid version: '{}'", class.compiler_version)),
);
Ok(ContractClassV1(Arc::new(ContractClassV1Inner {
Ok(CompiledClassV1(Arc::new(ContractClassV1Inner {
program,
entry_points_by_type,
hints: string_to_hint,
Expand Down
4 changes: 2 additions & 2 deletions crates/blockifier/src/execution/contract_class_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@ use assert_matches::assert_matches;
use cairo_lang_starknet_classes::NestedIntList;
use rstest::rstest;

use crate::execution::contract_class::{ContractClassV1, ContractClassV1Inner};
use crate::execution::contract_class::{CompiledClassV1, ContractClassV1Inner};
use crate::transaction::errors::TransactionExecutionError;

#[rstest]
fn test_get_visited_segments() {
let test_contract = ContractClassV1(Arc::new(ContractClassV1Inner {
let test_contract = CompiledClassV1(Arc::new(ContractClassV1Inner {
program: Default::default(),
entry_points_by_type: Default::default(),
hints: Default::default(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use starknet_api::hash::StarkHash;

use super::execution_utils::SEGMENT_ARENA_BUILTIN_SIZE;
use crate::execution::call_info::{CallExecution, CallInfo, ChargedResources};
use crate::execution::contract_class::{ContractClassV0, TrackedResource};
use crate::execution::contract_class::{CompiledClassV0, TrackedResource};
use crate::execution::deprecated_syscalls::hint_processor::DeprecatedSyscallHintProcessor;
use crate::execution::entry_point::{
CallEntryPoint,
Expand Down Expand Up @@ -44,12 +44,12 @@ pub const CAIRO0_BUILTINS_NAMES: [BuiltinName; 6] = [
/// Executes a specific call to a contract entry point and returns its output.
pub fn execute_entry_point_call(
call: CallEntryPoint,
contract_class: ContractClassV0,
compiled_class: CompiledClassV0,
state: &mut dyn State,
context: &mut EntryPointExecutionContext,
) -> EntryPointExecutionResult<CallInfo> {
let VmExecutionContext { mut runner, mut syscall_handler, initial_syscall_ptr, entry_point_pc } =
initialize_execution_context(&call, contract_class, state, context)?;
initialize_execution_context(&call, compiled_class, state, context)?;

let (implicit_args, args) = prepare_call_arguments(
&call,
Expand All @@ -67,13 +67,13 @@ pub fn execute_entry_point_call(

pub fn initialize_execution_context<'a>(
call: &CallEntryPoint,
contract_class: ContractClassV0,
compiled_class: CompiledClassV0,
state: &'a mut dyn State,
context: &'a mut EntryPointExecutionContext,
) -> Result<VmExecutionContext<'a>, PreExecutionError> {
// Verify use of cairo0 builtins only.
let program_builtins: HashSet<&BuiltinName> =
HashSet::from_iter(contract_class.program.iter_builtins());
HashSet::from_iter(compiled_class.program.iter_builtins());
let unsupported_builtins =
&program_builtins - &HashSet::from_iter(CAIRO0_BUILTINS_NAMES.iter());
if !unsupported_builtins.is_empty() {
Expand All @@ -83,14 +83,14 @@ pub fn initialize_execution_context<'a>(
}

// Resolve initial PC from EP indicator.
let entry_point_pc = resolve_entry_point_pc(call, &contract_class)?;
let entry_point_pc = resolve_entry_point_pc(call, &compiled_class)?;
// Instantiate Cairo runner.
let proof_mode = false;
let trace_enabled = false;
let allow_missing_builtins = false;
let program_base = None;
let mut runner =
CairoRunner::new(&contract_class.program, LayoutName::starknet, proof_mode, trace_enabled)?;
CairoRunner::new(&compiled_class.program, LayoutName::starknet, proof_mode, trace_enabled)?;

runner.initialize_builtins(allow_missing_builtins)?;
runner.initialize_segments(program_base);
Expand All @@ -110,15 +110,15 @@ pub fn initialize_execution_context<'a>(

pub fn resolve_entry_point_pc(
call: &CallEntryPoint,
contract_class: &ContractClassV0,
compiled_class: &CompiledClassV0,
) -> Result<usize, PreExecutionError> {
if call.entry_point_type == EntryPointType::Constructor
&& call.entry_point_selector != selector_from_name(CONSTRUCTOR_ENTRY_POINT_NAME)
{
return Err(PreExecutionError::InvalidConstructorEntryPointName);
}

let entry_points_of_same_type = &contract_class.entry_points_by_type[&call.entry_point_type];
let entry_points_of_same_type = &compiled_class.entry_points_by_type[&call.entry_point_type];
let filtered_entry_points: Vec<_> = entry_points_of_same_type
.iter()
.filter(|ep| ep.selector == call.entry_point_selector)
Expand Down
2 changes: 1 addition & 1 deletion crates/blockifier/src/execution/deprecated_syscalls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,7 @@ pub fn replace_class(
syscall_handler: &mut DeprecatedSyscallHintProcessor<'_>,
) -> DeprecatedSyscallResult<ReplaceClassResponse> {
// Ensure the class is declared (by reading it).
syscall_handler.state.get_compiled_contract_class(request.class_hash)?;
syscall_handler.state.get_compiled_class(request.class_hash)?;
syscall_handler.state.set_class_hash_at(syscall_handler.storage_address, request.class_hash)?;

Ok(ReplaceClassResponse {})
Expand Down
13 changes: 6 additions & 7 deletions crates/blockifier/src/execution/entry_point.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ impl CallEntryPoint {
}
// Add class hash to the call, that will appear in the output (call info).
self.class_hash = Some(class_hash);
let contract_class = state.get_compiled_contract_class(class_hash)?;
let compiled_class = state.get_compiled_class(class_hash)?;

context.revert_infos.0.push(EntryPointRevertInfo::new(
self.storage_address,
Expand All @@ -157,7 +157,7 @@ impl CallEntryPoint {
));

// This is the last operation of this function.
execute_entry_point_call_wrapper(self, contract_class, state, context, remaining_gas)
execute_entry_point_call_wrapper(self, compiled_class, state, context, remaining_gas)
}

/// Similar to `execute`, but returns an error if the outer call is reverted.
Expand Down Expand Up @@ -406,11 +406,10 @@ pub fn execute_constructor_entry_point(
remaining_gas: &mut u64,
) -> ConstructorEntryPointExecutionResult<CallInfo> {
// Ensure the class is declared (by reading it).
let contract_class =
state.get_compiled_contract_class(ctor_context.class_hash).map_err(|error| {
ConstructorEntryPointExecutionError::new(error.into(), &ctor_context, None)
})?;
let Some(constructor_selector) = contract_class.constructor_selector() else {
let compiled_class = state.get_compiled_class(ctor_context.class_hash).map_err(|error| {
ConstructorEntryPointExecutionError::new(error.into(), &ctor_context, None)
})?;
let Some(constructor_selector) = compiled_class.constructor_selector() else {
// Contract has no constructor.
return handle_empty_constructor(&ctor_context, calldata, *remaining_gas)
.map_err(|error| ConstructorEntryPointExecutionError::new(error, &ctor_context, None));
Expand Down
Loading

0 comments on commit cc5c84b

Please sign in to comment.