Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add trace generation and constraints for maddu and ins #159

Merged
merged 5 commits into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions emulator/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -863,8 +863,8 @@ impl InstrumentedState {
let mut rs = self.state.registers[((insn >> 21) & 0x1f) as usize];
let mut rd_reg = rt_reg;
let fun = insn & 0x3f;
if opcode == 0 || opcode == 0x1c || (opcode == 0x1F && fun == 0x20) {
// R-type (stores rd)
if opcode == 0 || opcode == 0x1c || (opcode == 0x1F && (fun == 0x20 || fun == 4)) {
// R-type (stores rd), partial Special3 insts: ins, seb, seh, wsbh
rt = self.state.registers[rt_reg as usize];
rd_reg = (insn >> 11) & 0x1f;
} else if opcode < 0x20 {
Expand Down
1 change: 1 addition & 0 deletions prover/src/arithmetic/shift.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ use crate::arithmetic::utils::{read_value, read_value_i64_limbs, u32_to_array};
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};

/// Generates a shift operation (either SLL(V) or SRL(V)).
///
/// The inputs are stored in the form `(shift, input, 1 << shift)`.
/// NB: if `shift >= 32`, then the third register holds 0.
/// We leverage the functions in mul.rs and div.rs to carry out
Expand Down
1 change: 1 addition & 0 deletions prover/src/arithmetic/sra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use crate::arithmetic::utils::{read_value, u32_to_array};
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};

/// Generates a shift operation (SRA(V).
///
/// The inputs are stored in the form `(shift, input, 1 >> shift)`.
/// NB: if `shift >= 32`, then the third register holds 0.
/// We leverage the functions in div.rs to carry out
Expand Down
2 changes: 2 additions & 0 deletions prover/src/cpu/columns/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ pub struct OpsColumnsView<T: Copy> {
pub m_op_store: T,
pub nop: T,
pub ext: T,
pub ins: T,
pub maddu: T,
pub rdhwr: T,
pub signext8: T,
pub signext16: T,
Expand Down
332 changes: 332 additions & 0 deletions prover/src/cpu/misc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,172 @@ pub fn eval_ext_circuit_extract<F: RichField + Extendable<D>, const D: usize>(
}
}

pub fn eval_packed_insert<P: PackedField>(
lv: &CpuColumnsView<P>,
yield_constr: &mut ConstraintConsumer<P>,
) {
let filter = lv.op.ins;

// Check rt Reg
// addr(channels[1]) == rt
// addr(channels[2]) == rt
{
let rt_reg_read = lv.mem_channels[1].addr_virtual;
let rt_reg_write = lv.mem_channels[2].addr_virtual;
let rt_src = limb_from_bits_le(lv.rt_bits);
yield_constr.constraint(filter * (rt_reg_read - rt_src));

yield_constr.constraint(filter * (rt_reg_write - rt_src));
}

// Check rs Reg
// addr(channels[0]) == rs
{
let rs_reg = lv.mem_channels[0].addr_virtual;
let rs_dst = limb_from_bits_le(lv.rs_bits);
yield_constr.constraint(filter * (rs_reg - rs_dst));
}

// Check ins result
// is_lsb[i] = 1 if i = lsb
// is_lsb[i] = 0 if i != lsb
// is_lsb[i] * (lsb - i) == 0
// auxs = 1 << lsd
// is_lsb[i] * (auxs - (i << 1)) == 0
// size = msb -lsb
// is_msb[i] = 1 if i = size
// is_msb[i] = 0 if i != size
// is_msb[i] * (size - i) == 0
// auxm = rt & !(mask << lsb)
// auxl = rs[0 : size+1]
// is_msb[i] * (auxl - rs[0:i+1]) == 0
// result == auxm + auxl * auxs
{
let msb = limb_from_bits_le(lv.rd_bits);
let rs_bits = lv.general.misc().rs_bits;
let lsb = limb_from_bits_le(lv.shamt_bits);

let auxm = lv.general.misc().auxm;
let auxl = lv.general.misc().auxl;
let auxs = lv.general.misc().auxs;
let rd_result = lv.mem_channels[2].value;

yield_constr.constraint(filter * (rd_result - auxm - auxl * auxs));

for i in 0..32 {
let is_msb = lv.general.misc().is_msb[i];
let is_lsb = lv.general.misc().is_lsb[i];
let cur_index = P::Scalar::from_canonical_usize(i);
let cur_mul = P::Scalar::from_canonical_usize(1 << i);

yield_constr.constraint(filter * is_lsb * (lsb - cur_index));
yield_constr.constraint(filter * is_lsb * (auxs - cur_mul));

yield_constr.constraint(filter * is_msb * (msb - lsb - cur_index));

let mut insert_bits = [P::ZEROS; 32];
insert_bits[0..i + 1].copy_from_slice(&rs_bits[0..i + 1]);
let insert_val = limb_from_bits_le(insert_bits.to_vec());
yield_constr.constraint(filter * is_msb * (auxl - insert_val));
}
}
}

pub fn eval_ext_circuit_insert<F: RichField + Extendable<D>, const D: usize>(
builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder<F, D>,
lv: &CpuColumnsView<ExtensionTarget<D>>,
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
let filter = lv.op.ins;

// Check rt Reg
// addr(channels[1]) == rt
// addr(channels[2]) == rt
{
let rt_reg_read = lv.mem_channels[1].addr_virtual;
let rt_reg_write = lv.mem_channels[2].addr_virtual;
let rt_src = limb_from_bits_le_recursive(builder, lv.rt_bits);

let constr = builder.sub_extension(rt_reg_read, rt_src);
let constr = builder.mul_extension(constr, filter);
yield_constr.constraint(builder, constr);

let constr = builder.sub_extension(rt_reg_write, rt_src);
let constr = builder.mul_extension(constr, filter);
yield_constr.constraint(builder, constr);
}

// Check rs Reg
// addr(channels[0]) == rs
{
let rs_reg = lv.mem_channels[0].addr_virtual;
let rs_src = limb_from_bits_le_recursive(builder, lv.rs_bits);
let constr = builder.sub_extension(rs_reg, rs_src);
let constr = builder.mul_extension(constr, filter);
yield_constr.constraint(builder, constr);
}

// Check ins result
// is_lsb[i] = 1 if i = lsb
// is_lsb[i] = 0 if i != lsb
// is_lsb[i] * (lsb - i) == 0
// auxs = 1 << lsd
// is_lsb[i] * (auxs - (i << 1)) == 0
// size = msb -lsb
// is_msb[i] = 1 if i = size
// is_msb[i] = 0 if i != size
// is_msb[i] * (size - i) == 0
// auxm = rt & !(mask << lsb)
// auxl = rs[0 : size+1]
// is_msb[i] * (auxl - rs[0:i+1]) == 0
// result == auxm + auxl * auxs
{
let msb = limb_from_bits_le_recursive(builder, lv.rd_bits);
let rs_bits = lv.general.misc().rs_bits;
let lsb = limb_from_bits_le_recursive(builder, lv.shamt_bits);
let auxm = lv.general.misc().auxm;
let auxl = lv.general.misc().auxl;
let auxs = lv.general.misc().auxs;
let rd_result = lv.mem_channels[2].value;

let constr = builder.mul_extension(auxl, auxs);
let constr = builder.sub_extension(rd_result, constr);
let constr = builder.sub_extension(constr, auxm);
let constr = builder.mul_extension(constr, filter);
yield_constr.constraint(builder, constr);

for i in 0..32 {
let is_msb = lv.general.misc().is_msb[i];
let is_lsb = lv.general.misc().is_lsb[i];
let cur_index = builder.constant_extension(F::Extension::from_canonical_usize(i));
let cur_mul = builder.constant_extension(F::Extension::from_canonical_usize(1 << i));

let constr_msb = builder.mul_extension(filter, is_msb);
let constr_lsb = builder.mul_extension(filter, is_lsb);

let constr = builder.sub_extension(lsb, cur_index);
let constr = builder.mul_extension(constr, constr_lsb);
yield_constr.constraint(builder, constr);

let constr = builder.sub_extension(auxs, cur_mul);
let constr = builder.mul_extension(constr, constr_lsb);
yield_constr.constraint(builder, constr);

let constr = builder.sub_extension(msb, lsb);
let constr = builder.sub_extension(constr, cur_index);
let constr = builder.mul_extension(constr, constr_msb);
yield_constr.constraint(builder, constr);

let mut insert_bits = [builder.zero_extension(); 32];
insert_bits[0..i + 1].copy_from_slice(&rs_bits[0..i + 1]);
let insert_val = limb_from_bits_le_recursive(builder, insert_bits);
let constr = builder.sub_extension(auxl, insert_val);
let constr = builder.mul_extension(constr, constr_msb);
yield_constr.constraint(builder, constr);
}
}
}

pub fn eval_packed_ror<P: PackedField>(
lv: &CpuColumnsView<P>,
yield_constr: &mut ConstraintConsumer<P>,
Expand Down Expand Up @@ -490,6 +656,168 @@ pub fn eval_ext_circuit_ror<F: RichField + Extendable<D>, const D: usize>(
}
}

pub fn eval_packed_maddu<P: PackedField>(
lv: &CpuColumnsView<P>,
yield_constr: &mut ConstraintConsumer<P>,
) {
let filter = lv.op.maddu;

// Check rs Reg
// addr(channels[0]) == rs
{
let rs_reg = lv.mem_channels[0].addr_virtual;
let rs_src = limb_from_bits_le(lv.rs_bits);
yield_constr.constraint(filter * (rs_reg - rs_src));
}

// Check rt Reg
// addr(channels[1]) == rt
{
let rt_reg = lv.mem_channels[1].addr_virtual;
let rt_dst = limb_from_bits_le(lv.rt_bits);
yield_constr.constraint(filter * (rt_reg - rt_dst));
}

// Check hi Reg
// addr(channels[2]) == 33
// addr(channels[4]) == 33
{
let hi_reg_read = lv.mem_channels[2].addr_virtual;
let hi_reg_write = lv.mem_channels[4].addr_virtual;
let hi_src = P::Scalar::from_canonical_usize(33);
yield_constr.constraint(filter * (hi_reg_read - hi_src));
yield_constr.constraint(filter * (hi_reg_write - hi_src));
}

// Check lo Reg
// addr(channels[3]) == 32
// addr(channels[5]) == 32
{
let lo_reg_read = lv.mem_channels[3].addr_virtual;
let lo_reg_write = lv.mem_channels[5].addr_virtual;
let lo_src = P::Scalar::from_canonical_usize(32);
yield_constr.constraint(filter * (lo_reg_read - lo_src));
yield_constr.constraint(filter * (lo_reg_write - lo_src));
}

// Check maddu result
// carry = overflow << 32
// scale = 1 << 32
// carry * (carry - scale) == 0
// result + (overflow << 32) == (hi,lo) + rs * rt
{
let rs = lv.mem_channels[0].value;
let rt = lv.mem_channels[1].value;
let hi = lv.mem_channels[2].value;
let lo = lv.mem_channels[3].value;
let hi_result: P = lv.mem_channels[4].value;
let lo_result = lv.mem_channels[5].value;
let carry = lv.general.misc().auxm;
let scale = P::Scalar::from_canonical_usize(1 << 32);
let result = hi_result * scale + lo_result;
let mul = rs * rt;
let addend = hi * scale + lo;
let overflow = carry * scale;

yield_constr.constraint(filter * carry * (carry - scale));
yield_constr.constraint(filter * (mul + addend - overflow - result));
}
}

pub fn eval_ext_circuit_maddu<F: RichField + Extendable<D>, const D: usize>(
builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder<F, D>,
lv: &CpuColumnsView<ExtensionTarget<D>>,
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
let filter = lv.op.maddu;

// Check rs Reg
// addr(channels[0]) == rs
{
let rs_reg = lv.mem_channels[0].addr_virtual;
let rs_src = limb_from_bits_le_recursive(builder, lv.rs_bits);
let constr = builder.sub_extension(rs_reg, rs_src);
let constr = builder.mul_extension(constr, filter);
yield_constr.constraint(builder, constr);
}

// Check rt Reg
// addr(channels[1]) == rt
{
let rt_reg = lv.mem_channels[1].addr_virtual;
let rt_src = limb_from_bits_le_recursive(builder, lv.rt_bits);
let constr = builder.sub_extension(rt_reg, rt_src);
let constr = builder.mul_extension(constr, filter);
yield_constr.constraint(builder, constr);
}

// Check hi Reg
// addr(channels[2]) == 33
// addr(channels[4]) == 33
{
let hi_reg_read = lv.mem_channels[2].addr_virtual;
let hi_reg_write = lv.mem_channels[4].addr_virtual;
let hi_src = builder.constant_extension(F::Extension::from_canonical_usize(33));
let constr = builder.sub_extension(hi_reg_read, hi_src);
let constr = builder.mul_extension(constr, filter);
yield_constr.constraint(builder, constr);

let constr = builder.sub_extension(hi_reg_write, hi_src);
let constr = builder.mul_extension(constr, filter);
yield_constr.constraint(builder, constr);
}

// Check lo Reg
// addr(channels[3]) == 32
// addr(channels[5]) == 32
{
let lo_reg_read = lv.mem_channels[3].addr_virtual;
let lo_reg_write = lv.mem_channels[5].addr_virtual;
let lo_src = builder.constant_extension(F::Extension::from_canonical_usize(32));
let constr = builder.sub_extension(lo_reg_read, lo_src);
let constr = builder.mul_extension(constr, filter);
yield_constr.constraint(builder, constr);

let constr = builder.sub_extension(lo_reg_write, lo_src);
let constr = builder.mul_extension(constr, filter);
yield_constr.constraint(builder, constr);
}

// Check maddu result
// carry = overflow << 32
// scale = 1 << 32
// carry * (carry - scale) == 0
// result + (overflow << 32) == (hi,lo) + rs * rt
{
let rs = lv.mem_channels[0].value;
let rt = lv.mem_channels[1].value;
let hi = lv.mem_channels[2].value;
let lo = lv.mem_channels[3].value;
let hi_result = lv.mem_channels[4].value;
let lo_result = lv.mem_channels[5].value;
let carry = lv.general.misc().auxm;
let scale = builder.constant_extension(F::Extension::from_canonical_usize(1 << 32));
let result = builder.mul_extension(hi_result, scale);
let result = builder.add_extension(result, lo_result);
let mul = builder.mul_extension(rs, rt);
let addend = builder.mul_extension(hi, scale);
let addend = builder.add_extension(addend, lo);

let overflow = builder.mul_extension(carry, scale);

let constr = builder.sub_extension(carry, scale);
let constr = builder.mul_extension(constr, carry);
let constr = builder.mul_extension(constr, filter);
yield_constr.constraint(builder, constr);

let constr = builder.add_extension(mul, addend);
let constr = builder.sub_extension(constr, overflow);
let constr = builder.sub_extension(constr, result);
let constr = builder.mul_extension(constr, filter);
yield_constr.constraint(builder, constr);
}
}

pub fn eval_packed<P: PackedField>(
lv: &CpuColumnsView<P>,
yield_constr: &mut ConstraintConsumer<P>,
Expand All @@ -499,6 +827,8 @@ pub fn eval_packed<P: PackedField>(
eval_packed_teq(lv, yield_constr);
eval_packed_extract(lv, yield_constr);
eval_packed_ror(lv, yield_constr);
eval_packed_insert(lv, yield_constr);
eval_packed_maddu(lv, yield_constr);
}

pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
Expand All @@ -511,4 +841,6 @@ pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
eval_ext_circuit_teq(builder, lv, yield_constr);
eval_ext_circuit_extract(builder, lv, yield_constr);
eval_ext_circuit_ror(builder, lv, yield_constr);
eval_ext_circuit_insert(builder, lv, yield_constr);
eval_ext_circuit_maddu(builder, lv, yield_constr);
}
Loading
Loading