Skip to content

Commit

Permalink
feat(cairo_native): add batcher compiler struct
Browse files Browse the repository at this point in the history
  • Loading branch information
avi-starkware committed Nov 28, 2024
1 parent 0459bde commit 5c6e955
Show file tree
Hide file tree
Showing 5 changed files with 235 additions and 6 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion crates/blockifier/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ description = "The transaction-executing component in the Starknet sequencer."
workspace = true

[features]
cairo_native = ["dep:cairo-native"]
cairo_native = ["dep:cairo-native", "starknet_sierra_compile/cairo_native"]
jemalloc = ["dep:tikv-jemallocator"]
reexecution = ["transaction_serde"]
testing = ["rand", "rstest", "starknet_api/testing"]
Expand Down Expand Up @@ -50,6 +50,7 @@ serde_json = { workspace = true, features = ["arbitrary_precision"] }
sha2.workspace = true
starknet-types-core.workspace = true
starknet_api.workspace = true
starknet_sierra_compile = { workspace = true, optional = true }
strum.workspace = true
strum_macros.workspace = true
tempfile.workspace = true
Expand Down
2 changes: 2 additions & 0 deletions crates/blockifier/src/state.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
pub mod cached_state;
#[cfg(feature = "cairo_native")]
pub mod contract_class_manager;
#[cfg(test)]
pub mod error_format_test;
pub mod errors;
Expand Down
225 changes: 225 additions & 0 deletions crates/blockifier/src/state/contract_class_manager.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
use std::fmt::{Display, Formatter};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::{sync_channel, Receiver, SyncSender, TrySendError};
use std::sync::{Arc, Mutex};
use std::thread::JoinHandle;

use log::{error, info};
use starknet_api::core::ClassHash;
use starknet_api::state::ContractClass as SierraContractClass;
use starknet_sierra_compile::command_line_compiler::CommandLineCompiler;
use starknet_sierra_compile::config::SierraToCasmCompilationConfig;
use starknet_sierra_compile::SierraToNativeCompiler;

use crate::execution::contract_class::{ContractClassV1, RunnableContractClass};
use crate::execution::native::contract_class::NativeContractClassV1;
use crate::state::global_cache::{CachedCairoNative, ContractClassCaches};

const CHANNEL_SIZE: usize = 1000;

/// Represents a request for the compilation handler.
///
/// # `Request` variants:
/// * `Terminate` - signals the compilation handler to terminate.
/// * `Compile` - requests compilation of a sierra contract class to a native compiled class.
///
/// # `Request::Compile` fields:
/// * `class_hash` - used to identify the contract class in the cache.
/// * `sierra_contract_class` - the sierra contract class to be compiled.
/// * `casm_contract_class` - stored in [`NativeContractClassV1`] to allow fallback to cairo_vm
/// execution in case of unxecpected failure during native execution.
pub enum Request {
Terminate,
Compile(ClassHash, Arc<SierraContractClass>, ContractClassV1),
}

impl Display for Request {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
Request::Terminate => write!(f, "Terminate"),
Request::Compile(class_hash, _, _) => {
write!(f, "Request(class_hash: {:?})", class_hash)
}
}
}
}

/// Manages the global cache of contract classes and handles sierra-to-native compilation requests.
struct ContractClassManager {
// The global cache of contract classes: casm, sierra, and native.
contract_class_caches: ContractClassCaches,
// The sending half of the compilation request channel.
sender: SyncSender<Request>,
// A flag that signals the termination of the compilation handler.
stop_marker: AtomicBool,
// The join handle to the thread running the compilation handler.
join_handle: Mutex<Option<JoinHandle<()>>>,
}

impl ContractClassManager {
/// Creates a new contract class manager and spawns a thread that listens for compilation
/// requests and processes them (a.k.a. the compilation handler).
/// Returns an `Arc` to the created contract class manager.
pub fn start(contract_class_caches: ContractClassCaches) -> Arc<ContractClassManager> {
// TODO(Avi, 15/12/2024): Add the size of the channel to the config.
let (sender, receiver) = sync_channel(CHANNEL_SIZE);
let compiler_config = SierraToCasmCompilationConfig::default();
let compiler = CommandLineCompiler::new(compiler_config);

// Create the contract class manager.
let contract_class_manager = Arc::new(ContractClassManager {
contract_class_caches,
sender,
stop_marker: AtomicBool::new(false),
// Store `None` in a mutex-guarded Option to allow setting the join handle after
// spawning the thread.
join_handle: Mutex::new(None),
});

// Spawn a thread running the compilation handler.
let join_handle = std::thread::spawn({
let contract_class_manager = Arc::clone(&contract_class_manager);
move || contract_class_manager.run_compilation_handler(receiver, compiler)
});

// Store the join handle to allow waiting for the thread to finish.
let mut mutex_guard = contract_class_manager
.join_handle
.lock()
.expect("No other thread should access the join handle.");
*mutex_guard = Some(join_handle);
drop(mutex_guard);

// Return the contract class manager.
contract_class_manager
}

/// Stops the compilation handler.
pub fn stop(&self) {
self.stop_marker.store(true, Ordering::Relaxed);
// For cases where the compilation request channel is empty, send a termination request to
// unblock the compilation handler. This is necessary because the handler is
// non-busy-waiting for requests.
self.send_request(Request::Terminate);
}

/// Sends a request to the compilation handler.
pub fn send_request(&self, request: Request) {
self.cache_request_contracts(&request);

match request {
Request::Terminate => {
// When sending a termination request, block the sender until the request is sent.
self.sender.send(request).expect("Compilation request channel is closed.");
}
Request::Compile(_, _, _) => {
// When sending a compilation request, send the request without blocking the sender.
// TODO(Avi, 15/12/2024): Check for duplicated requests.
self.sender.try_send(request).map_err(|err| match err {
TrySendError::Full(request) => {
error!(
"Compilation request channel is full (size: {}). Compilation request \
{} was not sent.",
CHANNEL_SIZE, request
)
}
TrySendError::Disconnected(_) => {
panic!("Compilation request channel is closed.")
}
});
}
}
}

/// Returns the native compiled class for the given class hash, if it exists in cache.
pub fn get_native(&self, class_hash: &ClassHash) -> Option<CachedCairoNative> {
self.contract_class_caches.get_native(class_hash)
}

/// Returns the Sierra contract class for the given class hash, if it exists in cache.
pub fn get_sierra(&self, class_hash: &ClassHash) -> Option<Arc<SierraContractClass>> {
self.contract_class_caches.get_sierra(class_hash)
}

/// Returns the casm compiled class for the given class hash, if it exists in cache.
pub fn get_casm(&self, class_hash: &ClassHash) -> Option<RunnableContractClass> {
self.contract_class_caches.get_casm(class_hash)
}

/// Handles requests on the compilation request channel.
/// If no request is available, non-busy-waits until a request is available.
fn run_compilation_handler(&self, receiver: Receiver<Request>, compiler: CommandLineCompiler) {
info!("Compilation handler started.");
for request in receiver.iter() {
if self.stopped() {
info!("Compilation handler terminated.");
return;
}
match request {
Request::Terminate => {
info!("Compilation handler terminated.");
return;
}
Request::Compile(class_hash, sierra, casm) => {
if self.contract_class_caches.get_native(&class_hash).is_some() {
// The contract class is already compiled to native - skip the compilation.
continue;
}
// TODO(Avi): Convert `sierra_contract_class` to
// `cairo_lang_starknet_classes::contract_class::ContractClass`
let compilation_result = compiler.compile_to_native(sierra.into());
match compilation_result {
Ok(executor) => {
let native_contract_class = NativeContractClassV1::new(executor, casm);
self.contract_class_caches.set_native(
class_hash,
CachedCairoNative::Compiled(native_contract_class),
);
}
Err(err) => {
error!("Error compiling contract class: {}", err);
self.contract_class_caches
.set_native(class_hash, CachedCairoNative::CompilationFailed);
}
}
}
}
}
}

/// Caches the sierra and casm contract classes of a compilation request.
fn cache_request_contracts(&self, request: &Request) {
match request {
Request::Terminate => {}
Request::Compile(class_hash, sierra, casm) => {
self.contract_class_caches.set_sierra(class_hash.clone(), sierra.clone());
let cached_casm = RunnableContractClass::from(casm.clone());
self.contract_class_caches.set_casm(class_hash.clone(), cached_casm);
}
}
}

/// Returns true if the compilation handler has been stopped.
fn stopped(&self) -> bool {
self.stop_marker.load(Ordering::Relaxed)
}
}

impl Drop for ContractClassManager {
/// Ensures the thread running the compilation handler is terminated when the contract class
/// manager is dropped.
fn drop(&mut self) {
self.stop();
let join_handle = self
.join_handle
.lock()
.expect("The lock should only be accessed when the contract class manager is dropped.")
.take()
.expect(
"The join handle should be set when the thread running the compilation handler is
spawned.",
);

join_handle.join().unwrap();
}
}
10 changes: 5 additions & 5 deletions crates/blockifier/src/state/global_cache.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
use std::sync::{Arc, Mutex, MutexGuard};

use cached::{Cached, SizedCache};
#[cfg(feature = "cairo_native")]
use cairo_native::executor::AotContractExecutor;
use starknet_api::core::ClassHash;
#[cfg(feature = "cairo_native")]
use starknet_api::state::ContractClass as SierraContractClass;

#[cfg(feature = "cairo_native")]
use crate::execution::contract_class::RunnableContractClass;
#[cfg(feature = "cairo_native")]
use crate::execution::native::contract_class::NativeContractClassV1;

type ContractClassLRUCache<T> = SizedCache<ClassHash, T>;
pub type LockedContractClassCache<'a, T> = MutexGuard<'a, ContractClassLRUCache<T>>;
Expand All @@ -21,7 +21,7 @@ pub struct GlobalContractCache<T: Clone>(pub Arc<Mutex<ContractClassLRUCache<T>>
#[cfg(feature = "cairo_native")]
#[derive(Debug, Clone)]
pub enum CachedCairoNative {
Compiled(AotContractExecutor),
Compiled(NativeContractClassV1),
CompilationFailed,
}

Expand Down Expand Up @@ -52,14 +52,14 @@ impl<T: Clone> GlobalContractCache<T> {
}

#[cfg(feature = "cairo_native")]
pub struct GlobalContractCacheManager {
pub struct ContractClassCaches {
pub casm_cache: GlobalContractCache<RunnableContractClass>,
pub native_cache: GlobalContractCache<CachedCairoNative>,
pub sierra_cache: GlobalContractCache<Arc<SierraContractClass>>,
}

#[cfg(feature = "cairo_native")]
impl GlobalContractCacheManager {
impl ContractClassCaches {
pub fn get_casm(&self, class_hash: &ClassHash) -> Option<RunnableContractClass> {
self.casm_cache.get(class_hash)
}
Expand Down

0 comments on commit 5c6e955

Please sign in to comment.