Skip to content

Commit

Permalink
feat(cairo_native): use contract class manager in papyrus reader
Browse files Browse the repository at this point in the history
  • Loading branch information
avi-starkware committed Dec 17, 2024
1 parent 49520e2 commit a624362
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 56 deletions.
1 change: 0 additions & 1 deletion crates/blockifier/src/state.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
pub mod cached_state;
#[cfg(feature = "cairo_native")]
pub mod contract_class_manager;
#[cfg(test)]
pub mod error_format_test;
Expand Down
58 changes: 40 additions & 18 deletions crates/blockifier/src/state/contract_class_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::sync::mpsc::{sync_channel, Receiver, SyncSender, TrySendError};
use std::sync::Arc;

use log::{error, info};
use starknet_api::contract_class::SierraVersion;
use starknet_api::core::ClassHash;
use starknet_api::state::SierraContractClass;
use starknet_sierra_compile::command_line_compiler::CommandLineCompiler;
Expand All @@ -10,10 +11,15 @@ use starknet_sierra_compile::utils::into_contract_class_for_compilation;
use starknet_sierra_compile::SierraToNativeCompiler;

use crate::blockifier::config::ContractClassManagerConfig;
use crate::execution::contract_class::{CompiledClassV1, RunnableCompiledClass};
use crate::execution::contract_class::{
CompiledClassV1,
RunnableCompiledClass,
VersionedRunnableCompiledClass,
};
use crate::execution::native::contract_class::NativeCompiledClassV1;
use crate::state::global_cache::{CachedCairoNative, ContractCaches};

#[cfg(feature = "cairo_native")]
const CHANNEL_SIZE: usize = 1000;

/// Represents a request to compile a sierra contract class to a native compiled class.
Expand All @@ -26,7 +32,7 @@ const CHANNEL_SIZE: usize = 1000;
type CompilationRequest = (ClassHash, Arc<SierraContractClass>, CompiledClassV1);

/// Manages the global cache of contract classes and handles sierra-to-native compilation requests.
struct ContractClassManager {
pub struct ContractClassManager {
config: ContractClassManagerConfig,
/// The global cache of contract classes: casm, sierra, and native.
contract_caches: Arc<ContractCaches>,
Expand All @@ -35,7 +41,6 @@ struct ContractClassManager {
sender: Option<SyncSender<CompilationRequest>>,
}

#[allow(dead_code)]
impl ContractClassManager {
/// Creates a new contract class manager and spawns a thread that listens for compilation
/// requests and processes them (a.k.a. the compilation worker).
Expand All @@ -47,18 +52,25 @@ impl ContractClassManager {
// Native compilation is disabled - no need to start the compilation worker.
return ContractClassManager { config, contract_caches, sender: None };
}
let (sender, receiver) = sync_channel(CHANNEL_SIZE);
let compiler_config = SierraToCasmCompilationConfig::default();
let compiler = CommandLineCompiler::new(compiler_config);

std::thread::spawn({
let contract_caches = Arc::clone(&contract_caches);
let compiler = Arc::new(compiler);

move || run_compilation_worker(contract_caches, receiver, compiler)
});

ContractClassManager { config, contract_caches, sender: Some(sender) }
#[cfg(not(feature = "cairo_native"))]
unimplemented!(
"Native compilation cannot be enabled when the cairo_native feature is turned off."
);
#[cfg(feature = "cairo_native")]
{
let (sender, receiver) = sync_channel(CHANNEL_SIZE);
let compiler_config = SierraToCasmCompilationConfig::default();
let compiler = CommandLineCompiler::new(compiler_config);

std::thread::spawn({
let contract_caches = Arc::clone(&contract_caches);
let compiler = Arc::new(compiler);

move || run_compilation_worker(contract_caches, receiver, compiler)
});

ContractClassManager { config, contract_caches, sender: Some(sender) }
}
}

/// Sends a compilation request to the compilation worker. Does not block the sender. Logs an
Expand All @@ -83,6 +95,7 @@ impl ContractClassManager {
}

/// Returns the native compiled class for the given class hash, if it exists in cache.
#[cfg(feature = "cairo_native")]
pub fn get_native(&self, class_hash: &ClassHash) -> Option<CachedCairoNative> {
self.contract_caches.get_native(class_hash)
}
Expand All @@ -93,20 +106,29 @@ impl ContractClassManager {
}

/// Returns the casm compiled class for the given class hash, if it exists in cache.
pub fn get_casm(&self, class_hash: &ClassHash) -> Option<RunnableCompiledClass> {
pub fn get_casm(&self, class_hash: &ClassHash) -> Option<VersionedRunnableCompiledClass> {
self.contract_caches.get_casm(class_hash)
}

/// Sets the casm compiled class for the given class hash in the cache.
pub fn set_casm(&self, class_hash: ClassHash, compiled_class: RunnableCompiledClass) {
pub fn set_casm(&self, class_hash: ClassHash, compiled_class: VersionedRunnableCompiledClass) {
self.contract_caches.set_casm(class_hash, compiled_class);
}

/// Clear the contract caches.
pub fn clear(&mut self) {
self.contract_caches.clear();
}

/// Caches the sierra and casm contract classes of a compilation request.
fn cache_request_contracts(&self, request: &CompilationRequest) {
let (class_hash, sierra, casm) = request.clone();
let sierra_version = SierraVersion::extract_from_program(&sierra.sierra_program).unwrap();
self.contract_caches.set_sierra(class_hash, sierra);
let cached_casm = RunnableCompiledClass::from(casm);
let cached_casm = VersionedRunnableCompiledClass::Cairo1((
RunnableCompiledClass::from(casm),
sierra_version,
));
self.contract_caches.set_casm(class_hash, cached_casm);
}
}
Expand Down
13 changes: 9 additions & 4 deletions crates/blockifier/src/state/global_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use starknet_api::state::SierraContractClass;

#[cfg(feature = "cairo_native")]
use crate::execution::contract_class::RunnableCompiledClass;
use crate::execution::contract_class::VersionedRunnableCompiledClass;
#[cfg(feature = "cairo_native")]
use crate::execution::native::contract_class::NativeCompiledClassV1;

Expand Down Expand Up @@ -51,27 +52,29 @@ impl<T: Clone> GlobalContractCache<T> {
}
}

#[cfg(feature = "cairo_native")]
pub struct ContractCaches {
pub casm_cache: GlobalContractCache<RunnableCompiledClass>,
pub casm_cache: GlobalContractCache<VersionedRunnableCompiledClass>,
#[cfg(feature = "cairo_native")]
pub native_cache: GlobalContractCache<CachedCairoNative>,
pub sierra_cache: GlobalContractCache<Arc<SierraContractClass>>,
}

#[cfg(feature = "cairo_native")]
impl ContractCaches {
pub fn get_casm(&self, class_hash: &ClassHash) -> Option<RunnableCompiledClass> {
pub fn get_casm(&self, class_hash: &ClassHash) -> Option<VersionedRunnableCompiledClass> {
self.casm_cache.get(class_hash)
}

pub fn set_casm(&self, class_hash: ClassHash, compiled_class: RunnableCompiledClass) {
pub fn set_casm(&self, class_hash: ClassHash, compiled_class: VersionedRunnableCompiledClass) {
self.casm_cache.set(class_hash, compiled_class);
}

#[cfg(feature = "cairo_native")]
pub fn get_native(&self, class_hash: &ClassHash) -> Option<CachedCairoNative> {
self.native_cache.get(class_hash)
}

#[cfg(feature = "cairo_native")]
pub fn set_native(&self, class_hash: ClassHash, contract_executor: CachedCairoNative) {
self.native_cache.set(class_hash, contract_executor);
}
Expand All @@ -87,13 +90,15 @@ impl ContractCaches {
pub fn new(cache_size: usize) -> Self {
Self {
casm_cache: GlobalContractCache::new(cache_size),
#[cfg(feature = "cairo_native")]
native_cache: GlobalContractCache::new(cache_size),
sierra_cache: GlobalContractCache::new(cache_size),
}
}

pub fn clear(&mut self) {
self.casm_cache.clear();
#[cfg(feature = "cairo_native")]
self.native_cache.clear();
self.sierra_cache.clear();
}
Expand Down
42 changes: 21 additions & 21 deletions crates/native_blockifier/src/py_block_executor.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
#![allow(non_local_definitions)]

use std::collections::HashMap;
use std::sync::Arc;

use blockifier::abi::constants as abi_constants;
use blockifier::blockifier::config::{ContractClassManagerConfig, TransactionExecutorConfig};
use blockifier::blockifier::transaction_executor::{TransactionExecutor, TransactionExecutorError};
use blockifier::bouncer::BouncerConfig;
use blockifier::context::{BlockContext, ChainInfo, FeeTokenAddresses};
use blockifier::execution::call_info::CallInfo;
use blockifier::execution::contract_class::VersionedRunnableCompiledClass;
use blockifier::fee::receipt::TransactionReceipt;
use blockifier::state::global_cache::GlobalContractCache;
use blockifier::state::contract_class_manager::ContractClassManager;
use blockifier::transaction::objects::{ExecutionResourcesTraits, TransactionExecutionInfo};
use blockifier::transaction::transaction_execution::Transaction;
use blockifier::utils::usize_from_u64;
Expand Down Expand Up @@ -137,8 +137,7 @@ pub struct PyBlockExecutor {
pub tx_executor: Option<TransactionExecutor<PapyrusReader>>,
/// `Send` trait is required for `pyclass` compatibility as Python objects must be threadsafe.
pub storage: Box<dyn Storage + Send>,
pub contract_class_manager_config: ContractClassManagerConfig,
pub global_contract_cache: GlobalContractCache<VersionedRunnableCompiledClass>,
pub contract_class_manager: Arc<ContractClassManager>,
}

#[pymethods]
Expand Down Expand Up @@ -169,10 +168,9 @@ impl PyBlockExecutor {
versioned_constants,
tx_executor: None,
storage: Box::new(storage),
contract_class_manager_config: contract_class_manager_config.into(),
global_contract_cache: GlobalContractCache::new(
contract_class_manager_config.contract_cache_size,
),
contract_class_manager: Arc::new(ContractClassManager::start(
contract_class_manager_config.into(),
)),
}
}

Expand Down Expand Up @@ -365,8 +363,9 @@ impl PyBlockExecutor {
/// (this is true for every partial existence of information at tables).
#[pyo3(signature = (block_number))]
pub fn revert_block(&mut self, block_number: u64) -> NativeBlockifierResult<()> {
// Clear global class cache, to peroperly revert classes declared in the reverted block.
self.global_contract_cache.clear();
// Clear global class cache, to properly revert classes declared in the reverted block.
// TODO(Avi, 01/01/2025): Consider what exactly to clear in native compilation context.
self.contract_class_manager.clear();
self.storage.revert_block(block_number)
}

Expand Down Expand Up @@ -407,10 +406,9 @@ impl PyBlockExecutor {
chain_info: os_config.into_chain_info(),
versioned_constants,
tx_executor: None,
contract_class_manager_config: contract_class_manager_config.into(),
global_contract_cache: GlobalContractCache::new(
contract_class_manager_config.contract_cache_size,
),
contract_class_manager: Arc::new(ContractClassManager::start(
contract_class_manager_config.into(),
)),
}
}
}
Expand All @@ -426,25 +424,27 @@ impl PyBlockExecutor {
PapyrusReader::new(
self.storage.reader().clone(),
next_block_number,
self.global_contract_cache.clone(),
self.contract_class_manager.clone(),
)
}

pub fn create_for_testing_with_storage(storage: impl Storage + Send + 'static) -> Self {
use blockifier::state::global_cache::GLOBAL_CONTRACT_CACHE_SIZE_FOR_TEST;
let contract_class_manager_config = ContractClassManagerConfig {
run_cairo_native: false,
wait_on_native_compilation: false,
contract_cache_size: GLOBAL_CONTRACT_CACHE_SIZE_FOR_TEST,
};
Self {
bouncer_config: BouncerConfig::max(),
tx_executor_config: TransactionExecutorConfig::create_for_testing(true),
storage: Box::new(storage),
chain_info: ChainInfo::default(),
versioned_constants: VersionedConstants::latest_constants().clone(),
tx_executor: None,
contract_class_manager_config: ContractClassManagerConfig {
run_cairo_native: false,
wait_on_native_compilation: false,
contract_cache_size: GLOBAL_CONTRACT_CACHE_SIZE_FOR_TEST,
},
global_contract_cache: GlobalContractCache::new(GLOBAL_CONTRACT_CACHE_SIZE_FOR_TEST),
contract_class_manager: Arc::new(ContractClassManager::start(
contract_class_manager_config,
)),
}
}

Expand Down
16 changes: 9 additions & 7 deletions crates/papyrus_state_reader/src/papyrus_state.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
use std::sync::Arc;

use blockifier::execution::contract_class::{
CompiledClassV0,
CompiledClassV1,
RunnableCompiledClass,
VersionedRunnableCompiledClass,
};
use blockifier::state::contract_class_manager::ContractClassManager;
use blockifier::state::errors::{couple_casm_and_sierra, StateError};
use blockifier::state::global_cache::GlobalContractCache;
use blockifier::state::state_api::{StateReader, StateResult};
use papyrus_storage::compiled_class::CasmStorageReader;
use papyrus_storage::db::RO;
Expand All @@ -25,16 +27,16 @@ type RawPapyrusReader<'env> = papyrus_storage::StorageTxn<'env, RO>;
pub struct PapyrusReader {
storage_reader: StorageReader,
latest_block: BlockNumber,
global_class_hash_to_class: GlobalContractCache<VersionedRunnableCompiledClass>,
contract_class_manager: Arc<ContractClassManager>,
}

impl PapyrusReader {
pub fn new(
storage_reader: StorageReader,
latest_block: BlockNumber,
global_class_hash_to_class: GlobalContractCache<VersionedRunnableCompiledClass>,
contract_class_manager: Arc<ContractClassManager>,
) -> Self {
Self { storage_reader, latest_block, global_class_hash_to_class }
Self { storage_reader, latest_block, contract_class_manager }
}

fn reader(&self) -> StateResult<RawPapyrusReader<'_>> {
Expand Down Expand Up @@ -132,15 +134,15 @@ impl StateReader for PapyrusReader {

fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult<RunnableCompiledClass> {
// Assumption: the global cache is cleared upon reverted blocks.
let versioned_contract_class = self.global_class_hash_to_class.get(&class_hash);
let versioned_contract_class = self.contract_class_manager.get_casm(&class_hash);

match versioned_contract_class {
Some(contract_class) => Ok(RunnableCompiledClass::from(contract_class)),
None => {
let versioned_contract_class_from_db = self.get_compiled_class_inner(class_hash)?;
// The class was declared in a previous (finalized) state; update the global cache.
self.global_class_hash_to_class
.set(class_hash, versioned_contract_class_from_db.clone());
self.contract_class_manager
.set_casm(class_hash, versioned_contract_class_from_db.clone());
Ok(RunnableCompiledClass::from(versioned_contract_class_from_db))
}
}
Expand Down
11 changes: 10 additions & 1 deletion crates/starknet_batcher/src/batcher.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use blockifier::blockifier::config::ContractClassManagerConfig;
use blockifier::state::contract_class_manager::ContractClassManager;
use blockifier::state::global_cache::GlobalContractCache;
#[cfg(test)]
use mockall::automock;
Expand Down Expand Up @@ -433,10 +435,17 @@ pub fn create_batcher(config: BatcherConfig, mempool_client: SharedMempoolClient
let (storage_reader, storage_writer) = papyrus_storage::open_storage(config.storage.clone())
.expect("Failed to open batcher's storage");

let contract_class_manager_config = ContractClassManagerConfig {
run_cairo_native: false,
wait_on_native_compilation: false,
contract_cache_size: config.global_contract_cache_size,
};
let block_builder_factory = Box::new(BlockBuilderFactory {
block_builder_config: config.block_builder_config.clone(),
storage_reader: storage_reader.clone(),
global_class_hash_to_class: GlobalContractCache::new(config.global_contract_cache_size),
contract_class_manager: Arc::new(ContractClassManager::start(
contract_class_manager_config,
)),
});
let storage_reader = Arc::new(storage_reader);
let storage_writer = Box::new(storage_writer);
Expand Down
Loading

0 comments on commit a624362

Please sign in to comment.