Skip to content

Commit

Permalink
feat(cairo_native): use contract class manager in papyrus reader
Browse files Browse the repository at this point in the history
  • Loading branch information
avi-starkware committed Dec 18, 2024
1 parent c9d984c commit 8e0c1dc
Show file tree
Hide file tree
Showing 10 changed files with 131 additions and 77 deletions.
1 change: 0 additions & 1 deletion crates/blockifier/src/state.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
pub mod cached_state;
#[cfg(feature = "cairo_native")]
pub mod contract_class_manager;
#[cfg(test)]
pub mod error_format_test;
Expand Down
90 changes: 67 additions & 23 deletions crates/blockifier/src/state/contract_class_manager.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,37 @@
#[cfg(feature = "cairo_native")]
use std::sync::mpsc::{sync_channel, Receiver, SyncSender, TrySendError};
#[cfg(feature = "cairo_native")]
use std::sync::Arc;

#[cfg(any(feature = "testing", test))]
use cached::Cached;
#[cfg(feature = "cairo_native")]
use log::{error, info};
#[cfg(feature = "cairo_native")]
use starknet_api::contract_class::SierraVersion;
use starknet_api::core::ClassHash;
#[cfg(feature = "cairo_native")]
use starknet_api::state::SierraContractClass;
#[cfg(feature = "cairo_native")]
use starknet_sierra_compile::command_line_compiler::CommandLineCompiler;
#[cfg(feature = "cairo_native")]
use starknet_sierra_compile::config::SierraToCasmCompilationConfig;
#[cfg(feature = "cairo_native")]
use starknet_sierra_compile::utils::into_contract_class_for_compilation;
#[cfg(feature = "cairo_native")]
use starknet_sierra_compile::SierraToNativeCompiler;

use crate::blockifier::config::ContractClassManagerConfig;
use crate::execution::contract_class::VersionedRunnableCompiledClass;
#[cfg(feature = "cairo_native")]
use crate::execution::contract_class::{CompiledClassV1, RunnableCompiledClass};
#[cfg(feature = "cairo_native")]
use crate::execution::native::contract_class::NativeCompiledClassV1;
use crate::state::global_cache::{CachedCairoNative, ContractCaches};
#[cfg(feature = "cairo_native")]
use crate::state::global_cache::CachedCairoNative;
use crate::state::global_cache::ContractCaches;

#[cfg(feature = "cairo_native")]
const CHANNEL_SIZE: usize = 1000;

/// Represents a request to compile a sierra contract class to a native compiled class.
Expand All @@ -23,46 +41,54 @@ const CHANNEL_SIZE: usize = 1000;
/// * `sierra_contract_class` - the sierra contract class to be compiled.
/// * `casm_compiled_class` - stored in [`NativeCompiledClassV1`] to allow fallback to cairo_vm
/// execution in case of unexpected failure during native execution.
#[cfg(feature = "cairo_native")]
type CompilationRequest = (ClassHash, Arc<SierraContractClass>, CompiledClassV1);

/// Manages the global cache of contract classes and handles sierra-to-native compilation requests.
struct ContractClassManager {
#[derive(Clone)]
pub struct ContractClassManager {
#[cfg(feature = "cairo_native")]
config: ContractClassManagerConfig,
/// The global cache of contract classes: casm, sierra, and native.
contract_caches: Arc<ContractCaches>,
contract_caches: ContractCaches,
/// The sending half of the compilation request channel. Set to `None` if native compilation is
/// disabled.
#[cfg(feature = "cairo_native")]
sender: Option<SyncSender<CompilationRequest>>,
}

#[allow(dead_code)]
impl ContractClassManager {
/// Creates a new contract class manager and spawns a thread that listens for compilation
/// requests and processes them (a.k.a. the compilation worker).
/// Returns the contract class manager.
pub fn start(config: ContractClassManagerConfig) -> ContractClassManager {
// TODO(Avi, 15/12/2024): Add the size of the channel to the config.
let contract_caches = Arc::new(ContractCaches::new(config.contract_cache_size));
if !config.run_cairo_native {
// Native compilation is disabled - no need to start the compilation worker.
return ContractClassManager { config, contract_caches, sender: None };
}
let (sender, receiver) = sync_channel(CHANNEL_SIZE);
let compiler_config = SierraToCasmCompilationConfig::default();
let compiler = CommandLineCompiler::new(compiler_config);
let contract_caches = ContractCaches::new(config.contract_cache_size);
#[cfg(not(feature = "cairo_native"))]
return ContractClassManager { contract_caches };
#[cfg(feature = "cairo_native")]
{
if !config.run_cairo_native {
// Native compilation is disabled - no need to start the compilation worker.
return ContractClassManager { config, contract_caches, sender: None };
}
let (sender, receiver) = sync_channel(CHANNEL_SIZE);

std::thread::spawn({
let contract_caches = Arc::clone(&contract_caches);
let compiler = Arc::new(compiler);
std::thread::spawn({
let contract_caches = contract_caches.clone();
let compiler_config = SierraToCasmCompilationConfig::default();
let compiler = CommandLineCompiler::new(compiler_config);

move || run_compilation_worker(contract_caches, receiver, compiler)
});
move || run_compilation_worker(contract_caches, receiver, compiler)
});

ContractClassManager { config, contract_caches, sender: Some(sender) }
ContractClassManager { config, contract_caches, sender: Some(sender) }
}
}

/// Sends a compilation request to the compilation worker. Does not block the sender. Logs an
/// error if the channel is full.
#[cfg(feature = "cairo_native")]
pub fn send_compilation_request(&self, request: CompilationRequest) {
assert!(!self.config.run_cairo_native, "Native compilation is disabled.");
let sender = self.sender.as_ref().expect("Compilation channel not available.");
Expand All @@ -83,41 +109,59 @@ impl ContractClassManager {
}

/// Returns the native compiled class for the given class hash, if it exists in cache.
#[cfg(feature = "cairo_native")]
pub fn get_native(&self, class_hash: &ClassHash) -> Option<CachedCairoNative> {
self.contract_caches.get_native(class_hash)
}

/// Returns the Sierra contract class for the given class hash, if it exists in cache.
#[cfg(feature = "cairo_native")]
pub fn get_sierra(&self, class_hash: &ClassHash) -> Option<Arc<SierraContractClass>> {
self.contract_caches.get_sierra(class_hash)
}

/// Returns the casm compiled class for the given class hash, if it exists in cache.
pub fn get_casm(&self, class_hash: &ClassHash) -> Option<RunnableCompiledClass> {
pub fn get_casm(&self, class_hash: &ClassHash) -> Option<VersionedRunnableCompiledClass> {
self.contract_caches.get_casm(class_hash)
}

/// Sets the casm compiled class for the given class hash in the cache.
pub fn set_casm(&self, class_hash: ClassHash, compiled_class: RunnableCompiledClass) {
pub fn set_casm(&self, class_hash: ClassHash, compiled_class: VersionedRunnableCompiledClass) {
self.contract_caches.set_casm(class_hash, compiled_class);
}

/// Clear the contract caches.
pub fn clear(&mut self) {
self.contract_caches.clear();
}

/// Caches the sierra and casm contract classes of a compilation request.
#[cfg(feature = "cairo_native")]
fn cache_request_contracts(&self, request: &CompilationRequest) {
let (class_hash, sierra, casm) = request.clone();
let sierra_version = SierraVersion::extract_from_program(&sierra.sierra_program).unwrap();
self.contract_caches.set_sierra(class_hash, sierra);
let cached_casm = RunnableCompiledClass::from(casm);
let cached_casm = VersionedRunnableCompiledClass::Cairo1((
RunnableCompiledClass::from(casm),
sierra_version,
));
self.contract_caches.set_casm(class_hash, cached_casm);
}

#[cfg(any(feature = "testing", test))]
pub fn get_casm_cache_size(&self) -> usize {
self.contract_caches.casm_cache.lock().cache_size()
}
}

/// Handles compilation requests from the channel, holding the receiver end of the channel.
/// If no request is available, non-busy-waits until a request is available.
/// When the sender is dropped, the worker processes all pending requests and terminates.
#[cfg(feature = "cairo_native")]
fn run_compilation_worker(
contract_caches: Arc<ContractCaches>,
contract_caches: ContractCaches,
receiver: Receiver<CompilationRequest>,
compiler: Arc<dyn SierraToNativeCompiler>,
compiler: impl SierraToNativeCompiler,
) {
info!("Compilation worker started.");
for (class_hash, sierra, casm) in receiver.iter() {
Expand Down
22 changes: 15 additions & 7 deletions crates/blockifier/src/state/global_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@ use starknet_api::core::ClassHash;
#[cfg(feature = "cairo_native")]
use starknet_api::state::SierraContractClass;

#[cfg(feature = "cairo_native")]
use crate::execution::contract_class::RunnableCompiledClass;
use crate::execution::contract_class::VersionedRunnableCompiledClass;
#[cfg(feature = "cairo_native")]
use crate::execution::native::contract_class::NativeCompiledClassV1;

Expand Down Expand Up @@ -51,50 +50,59 @@ impl<T: Clone> GlobalContractCache<T> {
}
}

#[cfg(feature = "cairo_native")]
#[derive(Clone)]
pub struct ContractCaches {
pub casm_cache: GlobalContractCache<RunnableCompiledClass>,
pub casm_cache: GlobalContractCache<VersionedRunnableCompiledClass>,
#[cfg(feature = "cairo_native")]
pub native_cache: GlobalContractCache<CachedCairoNative>,
#[cfg(feature = "cairo_native")]
pub sierra_cache: GlobalContractCache<Arc<SierraContractClass>>,
}

#[cfg(feature = "cairo_native")]
impl ContractCaches {
pub fn get_casm(&self, class_hash: &ClassHash) -> Option<RunnableCompiledClass> {
pub fn get_casm(&self, class_hash: &ClassHash) -> Option<VersionedRunnableCompiledClass> {
self.casm_cache.get(class_hash)
}

pub fn set_casm(&self, class_hash: ClassHash, compiled_class: RunnableCompiledClass) {
pub fn set_casm(&self, class_hash: ClassHash, compiled_class: VersionedRunnableCompiledClass) {
self.casm_cache.set(class_hash, compiled_class);
}

#[cfg(feature = "cairo_native")]
pub fn get_native(&self, class_hash: &ClassHash) -> Option<CachedCairoNative> {
self.native_cache.get(class_hash)
}

#[cfg(feature = "cairo_native")]
pub fn set_native(&self, class_hash: ClassHash, contract_executor: CachedCairoNative) {
self.native_cache.set(class_hash, contract_executor);
}

#[cfg(feature = "cairo_native")]
pub fn get_sierra(&self, class_hash: &ClassHash) -> Option<Arc<SierraContractClass>> {
self.sierra_cache.get(class_hash)
}

#[cfg(feature = "cairo_native")]
pub fn set_sierra(&self, class_hash: ClassHash, contract_class: Arc<SierraContractClass>) {
self.sierra_cache.set(class_hash, contract_class);
}

pub fn new(cache_size: usize) -> Self {
Self {
casm_cache: GlobalContractCache::new(cache_size),
#[cfg(feature = "cairo_native")]
native_cache: GlobalContractCache::new(cache_size),
#[cfg(feature = "cairo_native")]
sierra_cache: GlobalContractCache::new(cache_size),
}
}

pub fn clear(&mut self) {
self.casm_cache.clear();
#[cfg(feature = "cairo_native")]
self.native_cache.clear();
#[cfg(feature = "cairo_native")]
self.sierra_cache.clear();
}
}
34 changes: 15 additions & 19 deletions crates/native_blockifier/src/py_block_executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@ use blockifier::blockifier::transaction_executor::{TransactionExecutor, Transact
use blockifier::bouncer::BouncerConfig;
use blockifier::context::{BlockContext, ChainInfo, FeeTokenAddresses};
use blockifier::execution::call_info::CallInfo;
use blockifier::execution::contract_class::VersionedRunnableCompiledClass;
use blockifier::fee::receipt::TransactionReceipt;
use blockifier::state::global_cache::GlobalContractCache;
use blockifier::state::contract_class_manager::ContractClassManager;
use blockifier::transaction::objects::{ExecutionResourcesTraits, TransactionExecutionInfo};
use blockifier::transaction::transaction_execution::Transaction;
use blockifier::utils::usize_from_u64;
Expand Down Expand Up @@ -137,8 +136,7 @@ pub struct PyBlockExecutor {
pub tx_executor: Option<TransactionExecutor<PapyrusReader>>,
/// `Send` trait is required for `pyclass` compatibility as Python objects must be threadsafe.
pub storage: Box<dyn Storage + Send>,
pub contract_class_manager_config: ContractClassManagerConfig,
pub global_contract_cache: GlobalContractCache<VersionedRunnableCompiledClass>,
pub contract_class_manager: ContractClassManager,
}

#[pymethods]
Expand Down Expand Up @@ -169,9 +167,8 @@ impl PyBlockExecutor {
versioned_constants,
tx_executor: None,
storage: Box::new(storage),
contract_class_manager_config: contract_class_manager_config.into(),
global_contract_cache: GlobalContractCache::new(
contract_class_manager_config.contract_cache_size,
contract_class_manager: ContractClassManager::start(
contract_class_manager_config.into(),
),
}
}
Expand Down Expand Up @@ -365,8 +362,8 @@ impl PyBlockExecutor {
/// (this is true for every partial existence of information at tables).
#[pyo3(signature = (block_number))]
pub fn revert_block(&mut self, block_number: u64) -> NativeBlockifierResult<()> {
// Clear global class cache, to peroperly revert classes declared in the reverted block.
self.global_contract_cache.clear();
// Clear global class cache, to properly revert classes declared in the reverted block.
self.contract_class_manager.clear();
self.storage.revert_block(block_number)
}

Expand Down Expand Up @@ -407,9 +404,8 @@ impl PyBlockExecutor {
chain_info: os_config.into_chain_info(),
versioned_constants,
tx_executor: None,
contract_class_manager_config: contract_class_manager_config.into(),
global_contract_cache: GlobalContractCache::new(
contract_class_manager_config.contract_cache_size,
contract_class_manager: ContractClassManager::start(
contract_class_manager_config.into(),
),
}
}
Expand All @@ -426,25 +422,25 @@ impl PyBlockExecutor {
PapyrusReader::new(
self.storage.reader().clone(),
next_block_number,
self.global_contract_cache.clone(),
self.contract_class_manager.clone(),
)
}

pub fn create_for_testing_with_storage(storage: impl Storage + Send + 'static) -> Self {
use blockifier::state::global_cache::GLOBAL_CONTRACT_CACHE_SIZE_FOR_TEST;
let contract_class_manager_config = ContractClassManagerConfig {
run_cairo_native: false,
wait_on_native_compilation: false,
contract_cache_size: GLOBAL_CONTRACT_CACHE_SIZE_FOR_TEST,
};
Self {
bouncer_config: BouncerConfig::max(),
tx_executor_config: TransactionExecutorConfig::create_for_testing(true),
storage: Box::new(storage),
chain_info: ChainInfo::default(),
versioned_constants: VersionedConstants::latest_constants().clone(),
tx_executor: None,
contract_class_manager_config: ContractClassManagerConfig {
run_cairo_native: false,
wait_on_native_compilation: false,
contract_cache_size: GLOBAL_CONTRACT_CACHE_SIZE_FOR_TEST,
},
global_contract_cache: GlobalContractCache::new(GLOBAL_CONTRACT_CACHE_SIZE_FOR_TEST),
contract_class_manager: ContractClassManager::start(contract_class_manager_config),
}
}

Expand Down
5 changes: 2 additions & 3 deletions crates/native_blockifier/src/py_block_executor_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use std::collections::HashMap;
use blockifier::blockifier::transaction_executor::BLOCK_STATE_ACCESS_ERR;
use blockifier::execution::contract_class::{CompiledClassV1, RunnableCompiledClass};
use blockifier::state::state_api::StateReader;
use cached::Cached;
use cairo_lang_starknet_classes::casm_contract_class::CasmContractClass;
use pretty_assertions::assert_eq;
use starknet_api::class_hash;
Expand Down Expand Up @@ -69,7 +68,7 @@ fn global_contract_cache_update() {
)
.unwrap();

assert_eq!(block_executor.global_contract_cache.lock().cache_size(), 0);
assert_eq!(block_executor.contract_class_manager.get_casm_cache_size(), 0);

let queried_contract_class = block_executor
.tx_executor()
Expand All @@ -80,7 +79,7 @@ fn global_contract_cache_update() {
.unwrap();

assert_eq!(queried_contract_class, contract_class);
assert_eq!(block_executor.global_contract_cache.lock().cache_size(), 1);
assert_eq!(block_executor.contract_class_manager.get_casm_cache_size(), 1);
}

#[test]
Expand Down
Loading

0 comments on commit 8e0c1dc

Please sign in to comment.