Skip to content

Commit

Permalink
remove corelib + add tests + integrate with main
Browse files Browse the repository at this point in the history
  • Loading branch information
FrancoGiachetta committed Aug 27, 2024
1 parent e0af1a1 commit ba9808a
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 95 deletions.
8 changes: 4 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ version = "0.1.0"
edition = "2021"

[dependencies]
cairo-lang-sierra = "2.7.1"
cairo-lang-utils = "2.7.1"
cairo-lang-sierra = "=2.7.1"
cairo-lang-utils = "=2.7.1"
clap = { version = "4.5.16", features = ["derive"] }
k256 = "0.13.3"
keccak = "0.1.5"
Expand All @@ -25,8 +25,8 @@ tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }

[dev-dependencies]
cairo-lang-compiler = "2.7.0"
cairo-lang-starknet = "2.7.0"
cairo-lang-compiler = "=2.7.0"
cairo-lang-starknet = "=2.7.0"

# On dev optimize dependencies a bit so it's not as slow.
[profile.dev.package."*"]
Expand Down
93 changes: 22 additions & 71 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
use self::args::CmdArgs;
use args::EntryPoint;
use cairo_lang_sierra::{
extensions::{
circuit::CircuitTypeConcrete, core::CoreTypeConcrete, starknet::StarkNetTypeConcrete,
},
program::Program,
extensions::{core::CoreTypeConcrete, starknet::StarkNetTypeConcrete},
ProgramParser,
};
use clap::Parser;
Expand All @@ -18,8 +14,6 @@ use tracing::{debug, info, Level};
use tracing_subscriber::{EnvFilter, FmtSubscriber};

mod args;
#[cfg(test)]
mod utils;

fn main() -> Result<(), Box<dyn std::error::Error>> {
let args = CmdArgs::parse();
Expand All @@ -32,44 +26,23 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
)?;

info!("Loading the Sierra program from disk.");
let source_code = fs::read_to_string(&args.program)?;
let source_code = fs::read_to_string(args.program)?;

info!("Parsing the Sierra program.");
let program = Arc::new(
ProgramParser::new()
.parse(&source_code)
.map_err(|e| e.to_string())?,
);

let mut vm = create_vm(program, args.entry_point, args.args, args.available_gas)?;

let mut trace = ProgramTrace::new();

info!("Running the program.");
while let Some((statement_idx, state)) = vm.step() {
trace.push(StateDump::new(statement_idx, state));
}

match args.output {
Some(path) => serde_json::to_writer(File::create(path)?, &trace)?,
None => serde_json::to_writer(stdout().lock(), &trace)?,
};

Ok(())
}

pub fn create_vm(
program: Arc<Program>,
entry_point: EntryPoint,
args: Vec<String>,
available_gas: Option<u128>,
) -> Result<VirtualMachine, Box<dyn std::error::Error>> {
info!("Preparing the virtual machine.");
let mut vm = VirtualMachine::new(program.clone());

debug!("Pushing the entry point's frame.");
let function = program
.funcs
.iter()
.find(|f| match &entry_point {
.find(|f| match &args.entry_point {
args::EntryPoint::Number(x) => f.id.id == *x,
args::EntryPoint::String(x) => f.id.debug_name.as_deref() == Some(x.as_str()),
})
Expand All @@ -79,7 +52,7 @@ pub fn create_vm(
"Entry point argument types: {:?}",
function.signature.param_types
);
let mut iter = args.into_iter();
let mut iter = args.args.into_iter();
vm.push_frame(
function.id.clone(),
function
Expand All @@ -90,16 +63,12 @@ pub fn create_vm(
let type_info = vm.registry().get_type(type_id).unwrap();
match type_info {
CoreTypeConcrete::Felt252(_) => Value::parse_felt(&iter.next().unwrap()),
CoreTypeConcrete::GasBuiltin(_) => Value::U128(available_gas.unwrap()),
CoreTypeConcrete::GasBuiltin(_) => Value::U128(args.available_gas.unwrap()),
CoreTypeConcrete::RangeCheck(_)
| CoreTypeConcrete::RangeCheck96(_)
| CoreTypeConcrete::Bitwise(_)
| CoreTypeConcrete::Pedersen(_)
| CoreTypeConcrete::Poseidon(_)
| CoreTypeConcrete::SegmentArena(_)
| CoreTypeConcrete::Circuit(
CircuitTypeConcrete::AddMod(_) | CircuitTypeConcrete::MulMod(_),
) => Value::Unit,
| CoreTypeConcrete::SegmentArena(_) => Value::Unit,
CoreTypeConcrete::StarkNet(inner) => match inner {
StarkNetTypeConcrete::System(_) => Value::Unit,
_ => todo!(),
Expand All @@ -110,7 +79,19 @@ pub fn create_vm(
.collect::<Vec<_>>(),
);

Ok(vm)
let mut trace = ProgramTrace::new();

info!("Running the program.");
while let Some((statement_idx, state)) = vm.step() {
trace.push(StateDump::new(statement_idx, state));
}

match args.output {
Some(path) => serde_json::to_writer(File::create(path)?, &trace)?,
None => serde_json::to_writer(stdout().lock(), &trace)?,
};

Ok(())
}

#[cfg(test)]
Expand All @@ -119,12 +100,9 @@ mod test {

use cairo_lang_compiler::CompilerConfig;
use cairo_lang_starknet::compile::compile_path;
use num_bigint::BigInt;
use sierra_emu::{
find_entry_point_by_idx, ContractExecutionResult, ProgramTrace, Value, StateDump, VirtualMachine,
find_entry_point_by_idx, ContractExecutionResult, ProgramTrace, StateDump, VirtualMachine,
};

use crate::utils::run_program_assert_result;

#[test]
fn test_contract() {
Expand Down Expand Up @@ -204,31 +182,4 @@ mod test {
// let trace_str = serde_json::to_string_pretty(&trace).unwrap();
// std::fs::write("contract_trace.json", trace_str).unwrap();
}

#[test]
fn run_full_circuit() {
let path = Path::new("programs/circuits.cairo");

let range96 = BigInt::ZERO..(BigInt::from(1) << 96);
let limb0 = Value::BoundedInt {
range: range96.clone(),
value: 36699840570117848377038274035_u128.into(),
};
let limb1 = Value::BoundedInt {
range: range96.clone(),
value: 72042528776886984408017100026_u128.into(),
};
let limb2 = Value::BoundedInt {
range: range96.clone(),
value: 54251667697617050795983757117_u128.into(),
};
let limb3 = Value::BoundedInt {
range: range96,
value: 7.into(),
};

let expected_output = vec![Value::Struct(vec![limb0, limb1, limb2, limb3])];

run_program_assert_result(path, "circuits::circuits::main", expected_output)
}
}
14 changes: 1 addition & 13 deletions src/vm/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,7 @@ pub fn eval_eval(
let res = match r.modinv(&modulus) {
Some(inv) => inv,
None => {
// attempted to get the inverse of 0,
break false;
panic!("attempt to divide by 0");
}
};
// if it is a inv_gate the output index is store in lhs
Expand Down Expand Up @@ -354,14 +353,3 @@ pub fn eval_into_u96_guarantee(

EvalAction::NormalBranch(0, smallvec![Value::U128(value.try_into().unwrap())])
}

#[cfg(test)]
mod tests {
use std::path::Path;

use cairo_lang_compiler::CompilerConfig;
use cairo_lang_starknet::compile::compile_path;
use num_bigint::BigInt;

use crate::{ProgramTrace, StateDump, Value, VirtualMachine};
}
61 changes: 59 additions & 2 deletions tests/libfuncs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@ use std::{path::Path, sync::Arc};

use cairo_lang_compiler::{compile_cairo_project_at_path, CompilerConfig};
use cairo_lang_sierra::{
extensions::{core::CoreTypeConcrete, starknet::StarkNetTypeConcrete},
extensions::{
circuit::CircuitTypeConcrete, core::CoreTypeConcrete, starknet::StarkNetTypeConcrete,
},
program::{GenFunction, Program, StatementIdx},
};
use num_bigint::BigInt;
use sierra_emu::{ProgramTrace, StateDump, Value, VirtualMachine};

fn run_program(path: &str, func_name: &str, args: &[Value]) -> Vec<Value> {
Expand Down Expand Up @@ -40,11 +43,15 @@ fn run_program(path: &str, func_name: &str, args: &[Value]) -> Vec<Value> {
CoreTypeConcrete::GasBuiltin(_) => Value::U128(initial_gas),
CoreTypeConcrete::StarkNet(StarkNetTypeConcrete::System(_)) => Value::Unit,
CoreTypeConcrete::RangeCheck(_)
| CoreTypeConcrete::RangeCheck96(_)
| CoreTypeConcrete::Pedersen(_)
| CoreTypeConcrete::Poseidon(_)
| CoreTypeConcrete::Bitwise(_)
| CoreTypeConcrete::BuiltinCosts(_)
| CoreTypeConcrete::SegmentArena(_) => Value::Unit,
| CoreTypeConcrete::SegmentArena(_)
| CoreTypeConcrete::Circuit(
CircuitTypeConcrete::AddMod(_) | CircuitTypeConcrete::MulMod(_),
) => Value::Unit,
_ => args.next().unwrap(),
}
})
Expand Down Expand Up @@ -131,3 +138,53 @@ pub fn find_entry_point_by_name<'a>(
.iter()
.find(|x| x.id.debug_name.as_ref().map(|x| x.as_str()) == Some(name))
}

// CIRCUITS

#[test]
fn test_run_full_circuit() {
let range96 = BigInt::ZERO..(BigInt::from(1) << 96);
let limb0 = Value::BoundedInt {
range: range96.clone(),
value: 36699840570117848377038274035_u128.into(),
};
let limb1 = Value::BoundedInt {
range: range96.clone(),
value: 72042528776886984408017100026_u128.into(),
};
let limb2 = Value::BoundedInt {
range: range96.clone(),
value: 54251667697617050795983757117_u128.into(),
};
let limb3 = Value::BoundedInt {
range: range96,
value: 7.into(),
};

let output = run_program(
"tests/tests/circuits.cairo",
"circuits::circuits::main",
&[],
);
let expected_output = Value::Struct(vec![Value::Struct(vec![limb0, limb1, limb2, limb3])]);
let Value::Enum {
self_ty: _,
index: _,
payload,
} = output.last().unwrap()
else {
panic!("No output");
};

assert_eq!(**payload, expected_output);
}

#[test]
#[should_panic(expected = "attempt to divide by 0")]
fn test_circuit_failure() {
run_program(
"tests/tests/circuits_failure.cairo",
"circuits_failure::circuits_failure::main",
&[],
);
}
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,12 @@ use core::circuit::{

fn main() {
let in1 = CircuitElement::<CircuitInput<0>> {};
let in2 = CircuitElement::<CircuitInput<1>> {};
let sub = circuit_sub(in1, in2);
let inv = circuit_inverse(sub);
let inv = circuit_inverse(in1);

let modulus = TryInto::<_, CircuitModulus>::try_into([7, 0, 0, 0]).unwrap();
let outputs = (inv,)
.new_inputs()
.next([6, 0, 0, 0])
.next([6, 0, 0, 0])
.next([0, 0, 0, 0])
.done()
.eval(modulus)
.unwrap();
Expand Down

0 comments on commit ba9808a

Please sign in to comment.