Skip to content

Commit

Permalink
feat: add trace generation and constraints for maddu and ins (#159)
Browse files Browse the repository at this point in the history
* feat: add trace generation and constraints for maddu and ins

* fix constraints for maddu

* add comments for constraints

* fix: continue fix cliipy
  • Loading branch information
weilzkm authored Aug 28, 2024
1 parent 2faba08 commit 4a96161
Show file tree
Hide file tree
Showing 9 changed files with 421 additions and 5 deletions.
4 changes: 2 additions & 2 deletions emulator/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -863,8 +863,8 @@ impl InstrumentedState {
let mut rs = self.state.registers[((insn >> 21) & 0x1f) as usize];
let mut rd_reg = rt_reg;
let fun = insn & 0x3f;
if opcode == 0 || opcode == 0x1c || (opcode == 0x1F && fun == 0x20) {
// R-type (stores rd)
if opcode == 0 || opcode == 0x1c || (opcode == 0x1F && (fun == 0x20 || fun == 4)) {
// R-type (stores rd), partial Special3 insts: ins, seb, seh, wsbh
rt = self.state.registers[rt_reg as usize];
rd_reg = (insn >> 11) & 0x1f;
} else if opcode < 0x20 {
Expand Down
1 change: 1 addition & 0 deletions prover/src/arithmetic/shift.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ use crate::arithmetic::utils::{read_value, read_value_i64_limbs, u32_to_array};
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};

/// Generates a shift operation (either SLL(V) or SRL(V)).
///
/// The inputs are stored in the form `(shift, input, 1 << shift)`.
/// NB: if `shift >= 32`, then the third register holds 0.
/// We leverage the functions in mul.rs and div.rs to carry out
Expand Down
1 change: 1 addition & 0 deletions prover/src/arithmetic/sra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use crate::arithmetic::utils::{read_value, u32_to_array};
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};

/// Generates a shift operation (SRA(V).
///
/// The inputs are stored in the form `(shift, input, 1 >> shift)`.
/// NB: if `shift >= 32`, then the third register holds 0.
/// We leverage the functions in div.rs to carry out
Expand Down
2 changes: 2 additions & 0 deletions prover/src/cpu/columns/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ pub struct OpsColumnsView<T: Copy> {
pub m_op_store: T,
pub nop: T,
pub ext: T,
pub ins: T,
pub maddu: T,
pub rdhwr: T,
pub signext8: T,
pub signext16: T,
Expand Down
332 changes: 332 additions & 0 deletions prover/src/cpu/misc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,172 @@ pub fn eval_ext_circuit_extract<F: RichField + Extendable<D>, const D: usize>(
}
}

pub fn eval_packed_insert<P: PackedField>(
lv: &CpuColumnsView<P>,
yield_constr: &mut ConstraintConsumer<P>,
) {
let filter = lv.op.ins;

// Check rt Reg
// addr(channels[1]) == rt
// addr(channels[2]) == rt
{
let rt_reg_read = lv.mem_channels[1].addr_virtual;
let rt_reg_write = lv.mem_channels[2].addr_virtual;
let rt_src = limb_from_bits_le(lv.rt_bits);
yield_constr.constraint(filter * (rt_reg_read - rt_src));

yield_constr.constraint(filter * (rt_reg_write - rt_src));
}

// Check rs Reg
// addr(channels[0]) == rs
{
let rs_reg = lv.mem_channels[0].addr_virtual;
let rs_dst = limb_from_bits_le(lv.rs_bits);
yield_constr.constraint(filter * (rs_reg - rs_dst));
}

// Check ins result
// is_lsb[i] = 1 if i = lsb
// is_lsb[i] = 0 if i != lsb
// is_lsb[i] * (lsb - i) == 0
// auxs = 1 << lsd
// is_lsb[i] * (auxs - (i << 1)) == 0
// size = msb -lsb
// is_msb[i] = 1 if i = size
// is_msb[i] = 0 if i != size
// is_msb[i] * (size - i) == 0
// auxm = rt & !(mask << lsb)
// auxl = rs[0 : size+1]
// is_msb[i] * (auxl - rs[0:i+1]) == 0
// result == auxm + auxl * auxs
{
let msb = limb_from_bits_le(lv.rd_bits);
let rs_bits = lv.general.misc().rs_bits;
let lsb = limb_from_bits_le(lv.shamt_bits);

let auxm = lv.general.misc().auxm;
let auxl = lv.general.misc().auxl;
let auxs = lv.general.misc().auxs;
let rd_result = lv.mem_channels[2].value;

yield_constr.constraint(filter * (rd_result - auxm - auxl * auxs));

for i in 0..32 {
let is_msb = lv.general.misc().is_msb[i];
let is_lsb = lv.general.misc().is_lsb[i];
let cur_index = P::Scalar::from_canonical_usize(i);
let cur_mul = P::Scalar::from_canonical_usize(1 << i);

yield_constr.constraint(filter * is_lsb * (lsb - cur_index));
yield_constr.constraint(filter * is_lsb * (auxs - cur_mul));

yield_constr.constraint(filter * is_msb * (msb - lsb - cur_index));

let mut insert_bits = [P::ZEROS; 32];
insert_bits[0..i + 1].copy_from_slice(&rs_bits[0..i + 1]);
let insert_val = limb_from_bits_le(insert_bits.to_vec());
yield_constr.constraint(filter * is_msb * (auxl - insert_val));
}
}
}

pub fn eval_ext_circuit_insert<F: RichField + Extendable<D>, const D: usize>(
builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder<F, D>,
lv: &CpuColumnsView<ExtensionTarget<D>>,
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
let filter = lv.op.ins;

// Check rt Reg
// addr(channels[1]) == rt
// addr(channels[2]) == rt
{
let rt_reg_read = lv.mem_channels[1].addr_virtual;
let rt_reg_write = lv.mem_channels[2].addr_virtual;
let rt_src = limb_from_bits_le_recursive(builder, lv.rt_bits);

let constr = builder.sub_extension(rt_reg_read, rt_src);
let constr = builder.mul_extension(constr, filter);
yield_constr.constraint(builder, constr);

let constr = builder.sub_extension(rt_reg_write, rt_src);
let constr = builder.mul_extension(constr, filter);
yield_constr.constraint(builder, constr);
}

// Check rs Reg
// addr(channels[0]) == rs
{
let rs_reg = lv.mem_channels[0].addr_virtual;
let rs_src = limb_from_bits_le_recursive(builder, lv.rs_bits);
let constr = builder.sub_extension(rs_reg, rs_src);
let constr = builder.mul_extension(constr, filter);
yield_constr.constraint(builder, constr);
}

// Check ins result
// is_lsb[i] = 1 if i = lsb
// is_lsb[i] = 0 if i != lsb
// is_lsb[i] * (lsb - i) == 0
// auxs = 1 << lsd
// is_lsb[i] * (auxs - (i << 1)) == 0
// size = msb -lsb
// is_msb[i] = 1 if i = size
// is_msb[i] = 0 if i != size
// is_msb[i] * (size - i) == 0
// auxm = rt & !(mask << lsb)
// auxl = rs[0 : size+1]
// is_msb[i] * (auxl - rs[0:i+1]) == 0
// result == auxm + auxl * auxs
{
let msb = limb_from_bits_le_recursive(builder, lv.rd_bits);
let rs_bits = lv.general.misc().rs_bits;
let lsb = limb_from_bits_le_recursive(builder, lv.shamt_bits);
let auxm = lv.general.misc().auxm;
let auxl = lv.general.misc().auxl;
let auxs = lv.general.misc().auxs;
let rd_result = lv.mem_channels[2].value;

let constr = builder.mul_extension(auxl, auxs);
let constr = builder.sub_extension(rd_result, constr);
let constr = builder.sub_extension(constr, auxm);
let constr = builder.mul_extension(constr, filter);
yield_constr.constraint(builder, constr);

for i in 0..32 {
let is_msb = lv.general.misc().is_msb[i];
let is_lsb = lv.general.misc().is_lsb[i];
let cur_index = builder.constant_extension(F::Extension::from_canonical_usize(i));
let cur_mul = builder.constant_extension(F::Extension::from_canonical_usize(1 << i));

let constr_msb = builder.mul_extension(filter, is_msb);
let constr_lsb = builder.mul_extension(filter, is_lsb);

let constr = builder.sub_extension(lsb, cur_index);
let constr = builder.mul_extension(constr, constr_lsb);
yield_constr.constraint(builder, constr);

let constr = builder.sub_extension(auxs, cur_mul);
let constr = builder.mul_extension(constr, constr_lsb);
yield_constr.constraint(builder, constr);

let constr = builder.sub_extension(msb, lsb);
let constr = builder.sub_extension(constr, cur_index);
let constr = builder.mul_extension(constr, constr_msb);
yield_constr.constraint(builder, constr);

let mut insert_bits = [builder.zero_extension(); 32];
insert_bits[0..i + 1].copy_from_slice(&rs_bits[0..i + 1]);
let insert_val = limb_from_bits_le_recursive(builder, insert_bits);
let constr = builder.sub_extension(auxl, insert_val);
let constr = builder.mul_extension(constr, constr_msb);
yield_constr.constraint(builder, constr);
}
}
}

pub fn eval_packed_ror<P: PackedField>(
lv: &CpuColumnsView<P>,
yield_constr: &mut ConstraintConsumer<P>,
Expand Down Expand Up @@ -490,6 +656,168 @@ pub fn eval_ext_circuit_ror<F: RichField + Extendable<D>, const D: usize>(
}
}

pub fn eval_packed_maddu<P: PackedField>(
lv: &CpuColumnsView<P>,
yield_constr: &mut ConstraintConsumer<P>,
) {
let filter = lv.op.maddu;

// Check rs Reg
// addr(channels[0]) == rs
{
let rs_reg = lv.mem_channels[0].addr_virtual;
let rs_src = limb_from_bits_le(lv.rs_bits);
yield_constr.constraint(filter * (rs_reg - rs_src));
}

// Check rt Reg
// addr(channels[1]) == rt
{
let rt_reg = lv.mem_channels[1].addr_virtual;
let rt_dst = limb_from_bits_le(lv.rt_bits);
yield_constr.constraint(filter * (rt_reg - rt_dst));
}

// Check hi Reg
// addr(channels[2]) == 33
// addr(channels[4]) == 33
{
let hi_reg_read = lv.mem_channels[2].addr_virtual;
let hi_reg_write = lv.mem_channels[4].addr_virtual;
let hi_src = P::Scalar::from_canonical_usize(33);
yield_constr.constraint(filter * (hi_reg_read - hi_src));
yield_constr.constraint(filter * (hi_reg_write - hi_src));
}

// Check lo Reg
// addr(channels[3]) == 32
// addr(channels[5]) == 32
{
let lo_reg_read = lv.mem_channels[3].addr_virtual;
let lo_reg_write = lv.mem_channels[5].addr_virtual;
let lo_src = P::Scalar::from_canonical_usize(32);
yield_constr.constraint(filter * (lo_reg_read - lo_src));
yield_constr.constraint(filter * (lo_reg_write - lo_src));
}

// Check maddu result
// carry = overflow << 32
// scale = 1 << 32
// carry * (carry - scale) == 0
// result + (overflow << 32) == (hi,lo) + rs * rt
{
let rs = lv.mem_channels[0].value;
let rt = lv.mem_channels[1].value;
let hi = lv.mem_channels[2].value;
let lo = lv.mem_channels[3].value;
let hi_result: P = lv.mem_channels[4].value;
let lo_result = lv.mem_channels[5].value;
let carry = lv.general.misc().auxm;
let scale = P::Scalar::from_canonical_usize(1 << 32);
let result = hi_result * scale + lo_result;
let mul = rs * rt;
let addend = hi * scale + lo;
let overflow = carry * scale;

yield_constr.constraint(filter * carry * (carry - scale));
yield_constr.constraint(filter * (mul + addend - overflow - result));
}
}

pub fn eval_ext_circuit_maddu<F: RichField + Extendable<D>, const D: usize>(
builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder<F, D>,
lv: &CpuColumnsView<ExtensionTarget<D>>,
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
let filter = lv.op.maddu;

// Check rs Reg
// addr(channels[0]) == rs
{
let rs_reg = lv.mem_channels[0].addr_virtual;
let rs_src = limb_from_bits_le_recursive(builder, lv.rs_bits);
let constr = builder.sub_extension(rs_reg, rs_src);
let constr = builder.mul_extension(constr, filter);
yield_constr.constraint(builder, constr);
}

// Check rt Reg
// addr(channels[1]) == rt
{
let rt_reg = lv.mem_channels[1].addr_virtual;
let rt_src = limb_from_bits_le_recursive(builder, lv.rt_bits);
let constr = builder.sub_extension(rt_reg, rt_src);
let constr = builder.mul_extension(constr, filter);
yield_constr.constraint(builder, constr);
}

// Check hi Reg
// addr(channels[2]) == 33
// addr(channels[4]) == 33
{
let hi_reg_read = lv.mem_channels[2].addr_virtual;
let hi_reg_write = lv.mem_channels[4].addr_virtual;
let hi_src = builder.constant_extension(F::Extension::from_canonical_usize(33));
let constr = builder.sub_extension(hi_reg_read, hi_src);
let constr = builder.mul_extension(constr, filter);
yield_constr.constraint(builder, constr);

let constr = builder.sub_extension(hi_reg_write, hi_src);
let constr = builder.mul_extension(constr, filter);
yield_constr.constraint(builder, constr);
}

// Check lo Reg
// addr(channels[3]) == 32
// addr(channels[5]) == 32
{
let lo_reg_read = lv.mem_channels[3].addr_virtual;
let lo_reg_write = lv.mem_channels[5].addr_virtual;
let lo_src = builder.constant_extension(F::Extension::from_canonical_usize(32));
let constr = builder.sub_extension(lo_reg_read, lo_src);
let constr = builder.mul_extension(constr, filter);
yield_constr.constraint(builder, constr);

let constr = builder.sub_extension(lo_reg_write, lo_src);
let constr = builder.mul_extension(constr, filter);
yield_constr.constraint(builder, constr);
}

// Check maddu result
// carry = overflow << 32
// scale = 1 << 32
// carry * (carry - scale) == 0
// result + (overflow << 32) == (hi,lo) + rs * rt
{
let rs = lv.mem_channels[0].value;
let rt = lv.mem_channels[1].value;
let hi = lv.mem_channels[2].value;
let lo = lv.mem_channels[3].value;
let hi_result = lv.mem_channels[4].value;
let lo_result = lv.mem_channels[5].value;
let carry = lv.general.misc().auxm;
let scale = builder.constant_extension(F::Extension::from_canonical_usize(1 << 32));
let result = builder.mul_extension(hi_result, scale);
let result = builder.add_extension(result, lo_result);
let mul = builder.mul_extension(rs, rt);
let addend = builder.mul_extension(hi, scale);
let addend = builder.add_extension(addend, lo);

let overflow = builder.mul_extension(carry, scale);

let constr = builder.sub_extension(carry, scale);
let constr = builder.mul_extension(constr, carry);
let constr = builder.mul_extension(constr, filter);
yield_constr.constraint(builder, constr);

let constr = builder.add_extension(mul, addend);
let constr = builder.sub_extension(constr, overflow);
let constr = builder.sub_extension(constr, result);
let constr = builder.mul_extension(constr, filter);
yield_constr.constraint(builder, constr);
}
}

pub fn eval_packed<P: PackedField>(
lv: &CpuColumnsView<P>,
yield_constr: &mut ConstraintConsumer<P>,
Expand All @@ -499,6 +827,8 @@ pub fn eval_packed<P: PackedField>(
eval_packed_teq(lv, yield_constr);
eval_packed_extract(lv, yield_constr);
eval_packed_ror(lv, yield_constr);
eval_packed_insert(lv, yield_constr);
eval_packed_maddu(lv, yield_constr);
}

pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
Expand All @@ -511,4 +841,6 @@ pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
eval_ext_circuit_teq(builder, lv, yield_constr);
eval_ext_circuit_extract(builder, lv, yield_constr);
eval_ext_circuit_ror(builder, lv, yield_constr);
eval_ext_circuit_insert(builder, lv, yield_constr);
eval_ext_circuit_maddu(builder, lv, yield_constr);
}
Loading

0 comments on commit 4a96161

Please sign in to comment.