Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chore: Executor checks for precompile invariant violations #1811

Draft
wants to merge 11 commits into
base: dev
Choose a base branch
from
73 changes: 52 additions & 21 deletions crates/core/executor/src/events/precompiles/ec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,19 +98,20 @@ pub struct EllipticCurveDecompressEvent {
/// The generic parameter `N` is the number of u32 words in the point representation. For example,
/// for the secp256k1 curve, `N` would be 16 (64 bytes) because the x and y coordinates are 32 bytes
/// each.
///
/// This function also checks if the inputs are valid curve points, setting the invariant violated
/// flag if not.
pub fn create_ec_add_event<E: EllipticCurve>(
rt: &mut SyscallContext,
arg1: u32,
arg2: u32,
) -> EllipticCurveAddEvent {
) -> Option<EllipticCurveAddEvent> {
let start_clk = rt.clk;
let p_ptr = arg1;
if p_ptr % 4 != 0 {
panic!();
}
let q_ptr = arg2;
if q_ptr % 4 != 0 {
panic!();
if p_ptr % 4 > 0 || q_ptr % 4 > 0 {
eprintln!("EC_ADD: ptr alignment violation");
return rt.invariant_violated();
}

let num_words = <E::BaseField as NumWords>::WordsCurvePoint::USIZE;
Expand All @@ -123,14 +124,24 @@ pub fn create_ec_add_event<E: EllipticCurve>(
rt.clk += 1;

let p_affine = AffinePoint::<E>::from_words_le(&p);
if !E::is_valid_point(&p_affine) {
eprintln!("EC_ADD: invalid p point, invariant violation");
return rt.invariant_violated();
}

let q_affine = AffinePoint::<E>::from_words_le(&q);
if !E::is_valid_point(&q_affine) {
eprintln!("EC_ADD: invalid q point, invariant violation");
return rt.invariant_violated();
}

let result_affine = p_affine + q_affine;

let result_words = result_affine.to_words_le();

let p_memory_records = rt.mw_slice(p_ptr, &result_words);

EllipticCurveAddEvent {
Some(EllipticCurveAddEvent {
lookup_id: rt.syscall_lookup_id,
shard: rt.current_shard(),
clk: start_clk,
Expand All @@ -141,7 +152,7 @@ pub fn create_ec_add_event<E: EllipticCurve>(
p_memory_records,
q_memory_records,
local_mem_access: rt.postprocess(),
}
})
}

/// Create an elliptic curve double event.
Expand All @@ -152,48 +163,57 @@ pub fn create_ec_double_event<E: EllipticCurve>(
rt: &mut SyscallContext,
arg1: u32,
_: u32,
) -> EllipticCurveDoubleEvent {
) -> Option<EllipticCurveDoubleEvent> {
let start_clk = rt.clk;
let p_ptr = arg1;
if p_ptr % 4 != 0 {
panic!();
if p_ptr % 4 > 0 {
return rt.invariant_violated();
}

let num_words = <E::BaseField as NumWords>::WordsCurvePoint::USIZE;

let p = rt.slice_unsafe(p_ptr, num_words);

let p_affine = AffinePoint::<E>::from_words_le(&p);
if !E::is_valid_point(&p_affine) {
eprintln!("EC_DOUBLE: invalid point, invariant violation");
return rt.invariant_violated();
}

let result_affine = E::ec_double(&p_affine);

let result_words = result_affine.to_words_le();

let p_memory_records = rt.mw_slice(p_ptr, &result_words);

EllipticCurveDoubleEvent {
Some(EllipticCurveDoubleEvent {
lookup_id: rt.syscall_lookup_id,
shard: rt.current_shard(),
clk: start_clk,
p_ptr,
p,
p_memory_records,
local_mem_access: rt.postprocess(),
}
})
}

/// Create an elliptic curve decompress event.
///
/// It takes a pointer to a memory location, reads the point from memory, decompresses it, and
/// writes the result back to the memory location.
///
/// This function also checks if the input is a valid curve point, setting the invariant `invariant_violated`
/// flag if not.
pub fn create_ec_decompress_event<E: EllipticCurve>(
rt: &mut SyscallContext,
slice_ptr: u32,
sign_bit: u32,
) -> EllipticCurveDecompressEvent {
) -> Option<EllipticCurveDecompressEvent> {
let start_clk = rt.clk;
assert!(slice_ptr % 4 == 0, "slice_ptr must be 4-byte aligned");
assert!(sign_bit <= 1, "is_odd must be 0 or 1");

if slice_ptr % 4 > 0 || sign_bit > 1 {
return rt.invariant_violated();
}

let num_limbs = <E::BaseField as NumLimbs>::Limbs::USIZE;
let num_words_field_element = num_limbs / 4;
Expand All @@ -202,25 +222,36 @@ pub fn create_ec_decompress_event<E: EllipticCurve>(
rt.mr_slice(slice_ptr + (num_limbs as u32), num_words_field_element);

let x_bytes = words_to_bytes_le_vec(&x_vec);
let mut x_bytes_be = x_bytes.clone();
x_bytes_be.reverse();
let x_bytes_be = {
let mut x_bytes_be = x_bytes.clone();
x_bytes_be.reverse();
x_bytes_be
};

// The decompress_fn takes in an X coordinate and a parity,
// This means whatever point we get is guarnteed to be on the curve
// (it computes the corresponding y)
//
// However its falliable if the X coordinate is not in the field
let decompress_fn = match E::CURVE_TYPE {
CurveType::Secp256k1 => secp256k1_decompress::<E>,
CurveType::Secp256r1 => secp256r1_decompress::<E>,
CurveType::Bls12381 => bls12381_decompress::<E>,
_ => panic!("Unsupported curve"),
};

let computed_point: AffinePoint<E> = decompress_fn(&x_bytes_be, sign_bit);
let Some(computed_point) = decompress_fn(&x_bytes_be, sign_bit) else {
eprintln!("EC_DECOMPRESS: decompression failed, invariant violation");
return rt.invariant_violated();
};

let mut decompressed_y_bytes = computed_point.y.to_bytes_le();
decompressed_y_bytes.resize(num_limbs, 0u8);
let y_words = bytes_to_words_le_vec(&decompressed_y_bytes);

let y_memory_records = rt.mw_slice(slice_ptr, &y_words);

EllipticCurveDecompressEvent {
Some(EllipticCurveDecompressEvent {
lookup_id: rt.syscall_lookup_id,
shard: rt.current_shard(),
clk: start_clk,
Expand All @@ -231,5 +262,5 @@ pub fn create_ec_decompress_event<E: EllipticCurve>(
x_memory_records,
y_memory_records,
local_mem_access: rt.postprocess(),
}
})
}
8 changes: 8 additions & 0 deletions crates/core/executor/src/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,10 @@ pub enum ExecutionError {
#[error("unimplemented syscall {0}")]
UnsupportedSyscall(u32),

/// An invariant of the syscall has been violated.
#[error("A syscall invariant has been violated. Syscall code: {0}")]
SyscallInvariantViolation(SyscallCode),

/// The execution failed with a breakpoint.
#[error("breakpoint encountered")]
Breakpoint(),
Expand Down Expand Up @@ -970,6 +974,10 @@ impl<'a> Executor<'a> {
// register. If it returns None, we just keep the
// syscall_id in t0.
let res = syscall_impl.execute(&mut precompile_rt, syscall, b, c);
if precompile_rt.invariant_violation {
return Err(ExecutionError::SyscallInvariantViolation(syscall));
}

if let Some(val) = res {
a = val;
} else {
Expand Down
13 changes: 13 additions & 0 deletions crates/core/executor/src/syscalls/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
pub syscall_lookup_id: LookupId,
/// The local memory access events for the syscall.
pub local_memory_access: HashMap<u32, MemoryLocalEvent>,
/// An invariant of the syscall has been violated.
pub invariant_violation: bool,
}

impl<'a, 'b> SyscallContext<'a, 'b> {
Expand All @@ -44,9 +46,20 @@
rt: runtime,
syscall_lookup_id: LookupId::default(),
local_memory_access: HashMap::new(),
invariant_violation: false,
}
}

/// An invariant of the current syscall has been violated.
///
/// This only happens in precompiles.
/// This is a convience function for returning `None` and setting the invariant violation flag.

Check warning on line 56 in crates/core/executor/src/syscalls/context.rs

View workflow job for this annotation

GitHub Actions / Spell Check

"convience" should be "convince" or "convenience".
pub fn invariant_violated<T>(&mut self) -> Option<T> {
self.invariant_violation = true;

None
}

/// Get a mutable reference to the execution record.
pub fn record_mut(&mut self) -> &mut ExecutionRecord {
&mut self.rt.record
Expand Down
3 changes: 2 additions & 1 deletion crates/core/executor/src/syscalls/precompiles/edwards/add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ impl<E: EllipticCurve + EdwardsParameters> Syscall for EdwardsAddAssignSyscall<E
arg1: u32,
arg2: u32,
) -> Option<u32> {
let event = create_ec_add_event::<E>(rt, arg1, arg2);
let event = create_ec_add_event::<E>(rt, arg1, arg2)?;

let syscall_event =
rt.rt.syscall_event(event.clk, syscall_code.syscall_id(), arg1, arg2, event.lookup_id);
rt.add_precompile_event(syscall_code, syscall_event, PrecompileEvent::EdAdd(event));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ impl<E: EdwardsParameters> Syscall for EdwardsDecompressSyscall<E> {
) -> Option<u32> {
let start_clk = rt.clk;
let slice_ptr = arg1;
assert!(slice_ptr % 4 == 0, "Pointer must be 4-byte aligned.");
assert!(sign <= 1, "Sign bit must be 0 or 1.");
if slice_ptr % 4 > 0 || sign > 1 {
return rt.invariant_violated();
}

let (y_memory_records_vec, y_vec) =
rt.mr_slice(slice_ptr + (COMPRESSED_POINT_BYTES as u32), WORDS_FIELD_ELEMENT);
Expand All @@ -53,7 +54,10 @@ impl<E: EdwardsParameters> Syscall for EdwardsDecompressSyscall<E> {

// Compute actual decompressed X
let compressed_y = CompressedEdwardsY(compressed_edwards_y);
let decompressed = decompress(&compressed_y);
// This is falliable as Y might not be a valid field element
let Some(decompressed) = decompress(&compressed_y) else {
return rt.invariant_violated();
};

let mut decompressed_x_bytes = decompressed.x.to_bytes_le();
decompressed_x_bytes.resize(32, 0u8);
Expand Down
22 changes: 15 additions & 7 deletions crates/core/executor/src/syscalls/precompiles/fptower/fp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,11 @@ impl<P: FpOpField> Syscall for FpOpSyscall<P> {
) -> Option<u32> {
let clk = rt.clk;
let x_ptr = arg1;
if x_ptr % 4 != 0 {
panic!();
}
let y_ptr = arg2;
if y_ptr % 4 != 0 {
panic!();
// Need to check alignment
if x_ptr % 4 > 0 || y_ptr % 4 > 0 {
eprintln!("FpOpSyscall: alignment violation");
return rt.invariant_violated();
}

let num_words = <P as NumWords>::WordsFieldElement::USIZE;
Expand All @@ -46,8 +45,17 @@ impl<P: FpOpField> Syscall for FpOpSyscall<P> {
let (y_memory_records, y) = rt.mr_slice(y_ptr, num_words);

let modulus = &BigUint::from_bytes_le(P::MODULUS);
let a = BigUint::from_slice(&x) % modulus;
let b = BigUint::from_slice(&y) % modulus;
let a = BigUint::from_slice(&x);
if &a >= modulus {
eprintln!("FpOpSyscall: a >= modulus, invariant violation");
return rt.invariant_violated();
}

let b = BigUint::from_slice(&y);
if &b >= modulus {
eprintln!("FpOpSyscall: b >= modulus, invariant violation");
return rt.invariant_violated();
}

let result = match self.op {
FieldOperation::Add => (a + b) % modulus,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,10 @@ impl<P: FpOpField> Syscall for Fp2AddSubSyscall<P> {
) -> Option<u32> {
let clk = rt.clk;
let x_ptr = arg1;
if x_ptr % 4 != 0 {
panic!();
}
let y_ptr = arg2;
if y_ptr % 4 != 0 {
panic!();
// Need to check alignment
if x_ptr % 4 > 0 || y_ptr % 4 > 0 {
return rt.invariant_violated();
}

let num_words = <P as NumWords>::WordsCurvePoint::USIZE;
Expand All @@ -55,6 +53,11 @@ impl<P: FpOpField> Syscall for Fp2AddSubSyscall<P> {
let bc1 = &BigUint::from_slice(bc1);
let modulus = &BigUint::from_bytes_le(P::MODULUS);

if ac0 >= modulus || ac1 >= modulus || bc0 >= modulus || bc1 >= modulus {
eprintln!("Fp2AddSubSyscall: invariant violation, inputs greater than modulus");
return rt.invariant_violated();
}

let (c0, c1) = match self.op {
FieldOperation::Add => ((ac0 + bc0) % modulus, (ac1 + bc1) % modulus),
FieldOperation::Sub => {
Expand Down
13 changes: 8 additions & 5 deletions crates/core/executor/src/syscalls/precompiles/fptower/fp2_mul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,10 @@ impl<P: FpOpField> Syscall for Fp2MulSyscall<P> {
) -> Option<u32> {
let clk = rt.clk;
let x_ptr = arg1;
if x_ptr % 4 != 0 {
panic!();
}
let y_ptr = arg2;
if y_ptr % 4 != 0 {
panic!();
// Need to check alignment
if x_ptr % 4 > 0 || y_ptr % 4 > 0 {
return rt.invariant_violated();
}

let num_words = <P as NumWords>::WordsCurvePoint::USIZE;
Expand All @@ -55,6 +53,11 @@ impl<P: FpOpField> Syscall for Fp2MulSyscall<P> {
let bc1 = &BigUint::from_slice(bc1);
let modulus = &BigUint::from_bytes_le(P::MODULUS);

if ac0 >= modulus || ac1 >= modulus || bc0 >= modulus || bc1 >= modulus {
eprintln!("Fp2MulSyscall: invariant violation, inputs greater than modulus");
return rt.invariant_violated();
}

#[allow(clippy::match_bool)]
let c0 = match (ac0 * bc0) % modulus < (ac1 * bc1) % modulus {
true => ((modulus + (ac0 * bc0) % modulus) - (ac1 * bc1) % modulus) % modulus,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@ impl Syscall for Keccak256PermuteSyscall {
let start_clk = rt.clk;
let state_ptr = arg1;
if arg2 != 0 {
panic!("Expected arg2 to be 0, got {arg2}");
eprintln!(
"Expected arg2 to be 0, got {arg2}, this violates the Keccak precompile invariant."
);
return rt.invariant_violated();
}

let mut state_read_records = Vec::new();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ impl Syscall for Sha256CompressSyscall {
) -> Option<u32> {
let w_ptr = arg1;
let h_ptr = arg2;
assert_ne!(w_ptr, h_ptr);
if w_ptr == h_ptr {
eprintln!("w_ptr == h_ptr, violation of the sha256 invariant");
return rt.invariant_violated();
}

let start_clk = rt.clk;
let mut h_read_records = Vec::new();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@ impl Syscall for Sha256ExtendSyscall {
) -> Option<u32> {
let clk_init = rt.clk;
let w_ptr = arg1;
assert!(arg2 == 0, "arg2 must be 0");
if arg2 != 0 {
eprintln!("Warning: sha256_extend syscall arg2 is not zero, this violates the precompile invariants");

return rt.invariant_violated();
}

let w_ptr_init = w_ptr;
let mut w_i_minus_15_reads = Vec::with_capacity(48);
Expand Down
Loading
Loading