diff --git a/emulator/src/state.rs b/emulator/src/state.rs index 300b50a1..8a26b0fd 100644 --- a/emulator/src/state.rs +++ b/emulator/src/state.rs @@ -863,8 +863,8 @@ impl InstrumentedState { let mut rs = self.state.registers[((insn >> 21) & 0x1f) as usize]; let mut rd_reg = rt_reg; let fun = insn & 0x3f; - if opcode == 0 || opcode == 0x1c || (opcode == 0x1F && fun == 0x20) { - // R-type (stores rd) + if opcode == 0 || opcode == 0x1c || (opcode == 0x1F && (fun == 0x20 || fun == 4)) { + // R-type (stores rd), partial Special3 insts: ins, seb, seh, wsbh rt = self.state.registers[rt_reg as usize]; rd_reg = (insn >> 11) & 0x1f; } else if opcode < 0x20 { diff --git a/prover/src/arithmetic/shift.rs b/prover/src/arithmetic/shift.rs index ff69dcf6..9cf1b41d 100644 --- a/prover/src/arithmetic/shift.rs +++ b/prover/src/arithmetic/shift.rs @@ -34,6 +34,7 @@ use crate::arithmetic::utils::{read_value, read_value_i64_limbs, u32_to_array}; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; /// Generates a shift operation (either SLL(V) or SRL(V)). +/// /// The inputs are stored in the form `(shift, input, 1 << shift)`. /// NB: if `shift >= 32`, then the third register holds 0. /// We leverage the functions in mul.rs and div.rs to carry out diff --git a/prover/src/arithmetic/sra.rs b/prover/src/arithmetic/sra.rs index 08c225a4..bfacbe32 100644 --- a/prover/src/arithmetic/sra.rs +++ b/prover/src/arithmetic/sra.rs @@ -23,6 +23,7 @@ use crate::arithmetic::utils::{read_value, u32_to_array}; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; /// Generates a shift operation (SRA(V). +/// /// The inputs are stored in the form `(shift, input, 1 >> shift)`. /// NB: if `shift >= 32`, then the third register holds 0. 
/// We leverage the functions in div.rs to carry out diff --git a/prover/src/cpu/columns/ops.rs b/prover/src/cpu/columns/ops.rs index cbb5340d..be76bcde 100644 --- a/prover/src/cpu/columns/ops.rs +++ b/prover/src/cpu/columns/ops.rs @@ -31,6 +31,8 @@ pub struct OpsColumnsView { pub m_op_store: T, pub nop: T, pub ext: T, + pub ins: T, + pub maddu: T, pub rdhwr: T, pub signext8: T, pub signext16: T, diff --git a/prover/src/cpu/misc.rs b/prover/src/cpu/misc.rs index 042d280f..2d55cd88 100644 --- a/prover/src/cpu/misc.rs +++ b/prover/src/cpu/misc.rs @@ -394,6 +394,172 @@ pub fn eval_ext_circuit_extract, const D: usize>( } } +pub fn eval_packed_insert( + lv: &CpuColumnsView

, + yield_constr: &mut ConstraintConsumer

, +) { + let filter = lv.op.ins; + + // Check rt Reg + // addr(channels[1]) == rt + // addr(channels[2]) == rt + { + let rt_reg_read = lv.mem_channels[1].addr_virtual; + let rt_reg_write = lv.mem_channels[2].addr_virtual; + let rt_src = limb_from_bits_le(lv.rt_bits); + yield_constr.constraint(filter * (rt_reg_read - rt_src)); + + yield_constr.constraint(filter * (rt_reg_write - rt_src)); + } + + // Check rs Reg + // addr(channels[0]) == rs + { + let rs_reg = lv.mem_channels[0].addr_virtual; + let rs_dst = limb_from_bits_le(lv.rs_bits); + yield_constr.constraint(filter * (rs_reg - rs_dst)); + } + + // Check ins result + // is_lsb[i] = 1 if i = lsb + // is_lsb[i] = 0 if i != lsb + // is_lsb[i] * (lsb - i) == 0 + // auxs = 1 << lsb + // is_lsb[i] * (auxs - (1 << i)) == 0 + // size = msb - lsb + // is_msb[i] = 1 if i = size + // is_msb[i] = 0 if i != size + // is_msb[i] * (size - i) == 0 + // auxm = rt & !(mask << lsb) + // auxl = rs[0 : size+1] + // is_msb[i] * (auxl - rs[0:i+1]) == 0 + // result == auxm + auxl * auxs + { + let msb = limb_from_bits_le(lv.rd_bits); + let rs_bits = lv.general.misc().rs_bits; + let lsb = limb_from_bits_le(lv.shamt_bits); + + let auxm = lv.general.misc().auxm; + let auxl = lv.general.misc().auxl; + let auxs = lv.general.misc().auxs; + let rd_result = lv.mem_channels[2].value; + + yield_constr.constraint(filter * (rd_result - auxm - auxl * auxs)); + + for i in 0..32 { + let is_msb = lv.general.misc().is_msb[i]; + let is_lsb = lv.general.misc().is_lsb[i]; + let cur_index = P::Scalar::from_canonical_usize(i); + let cur_mul = P::Scalar::from_canonical_usize(1 << i); + + yield_constr.constraint(filter * is_lsb * (lsb - cur_index)); + yield_constr.constraint(filter * is_lsb * (auxs - cur_mul)); + + yield_constr.constraint(filter * is_msb * (msb - lsb - cur_index)); + + let mut insert_bits = [P::ZEROS; 32]; + insert_bits[0..i + 1].copy_from_slice(&rs_bits[0..i + 1]); + let insert_val = limb_from_bits_le(insert_bits.to_vec()); + 
yield_constr.constraint(filter * is_msb * (auxl - insert_val)); + } + } +} + +pub fn eval_ext_circuit_insert, const D: usize>( + builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, + lv: &CpuColumnsView>, + yield_constr: &mut RecursiveConstraintConsumer, +) { + let filter = lv.op.ins; + + // Check rt Reg + // addr(channels[1]) == rt + // addr(channels[2]) == rt + { + let rt_reg_read = lv.mem_channels[1].addr_virtual; + let rt_reg_write = lv.mem_channels[2].addr_virtual; + let rt_src = limb_from_bits_le_recursive(builder, lv.rt_bits); + + let constr = builder.sub_extension(rt_reg_read, rt_src); + let constr = builder.mul_extension(constr, filter); + yield_constr.constraint(builder, constr); + + let constr = builder.sub_extension(rt_reg_write, rt_src); + let constr = builder.mul_extension(constr, filter); + yield_constr.constraint(builder, constr); + } + + // Check rs Reg + // addr(channels[0]) == rs + { + let rs_reg = lv.mem_channels[0].addr_virtual; + let rs_src = limb_from_bits_le_recursive(builder, lv.rs_bits); + let constr = builder.sub_extension(rs_reg, rs_src); + let constr = builder.mul_extension(constr, filter); + yield_constr.constraint(builder, constr); + } + + // Check ins result + // is_lsb[i] = 1 if i = lsb + // is_lsb[i] = 0 if i != lsb + // is_lsb[i] * (lsb - i) == 0 + // auxs = 1 << lsb + // is_lsb[i] * (auxs - (1 << i)) == 0 + // size = msb - lsb + // is_msb[i] = 1 if i = size + // is_msb[i] = 0 if i != size + // is_msb[i] * (size - i) == 0 + // auxm = rt & !(mask << lsb) + // auxl = rs[0 : size+1] + // is_msb[i] * (auxl - rs[0:i+1]) == 0 + // result == auxm + auxl * auxs + { + let msb = limb_from_bits_le_recursive(builder, lv.rd_bits); + let rs_bits = lv.general.misc().rs_bits; + let lsb = limb_from_bits_le_recursive(builder, lv.shamt_bits); + let auxm = lv.general.misc().auxm; + let auxl = lv.general.misc().auxl; + let auxs = lv.general.misc().auxs; + let rd_result = lv.mem_channels[2].value; + + let constr = builder.mul_extension(auxl, 
auxs); + let constr = builder.sub_extension(rd_result, constr); + let constr = builder.sub_extension(constr, auxm); + let constr = builder.mul_extension(constr, filter); + yield_constr.constraint(builder, constr); + + for i in 0..32 { + let is_msb = lv.general.misc().is_msb[i]; + let is_lsb = lv.general.misc().is_lsb[i]; + let cur_index = builder.constant_extension(F::Extension::from_canonical_usize(i)); + let cur_mul = builder.constant_extension(F::Extension::from_canonical_usize(1 << i)); + + let constr_msb = builder.mul_extension(filter, is_msb); + let constr_lsb = builder.mul_extension(filter, is_lsb); + + let constr = builder.sub_extension(lsb, cur_index); + let constr = builder.mul_extension(constr, constr_lsb); + yield_constr.constraint(builder, constr); + + let constr = builder.sub_extension(auxs, cur_mul); + let constr = builder.mul_extension(constr, constr_lsb); + yield_constr.constraint(builder, constr); + + let constr = builder.sub_extension(msb, lsb); + let constr = builder.sub_extension(constr, cur_index); + let constr = builder.mul_extension(constr, constr_msb); + yield_constr.constraint(builder, constr); + + let mut insert_bits = [builder.zero_extension(); 32]; + insert_bits[0..i + 1].copy_from_slice(&rs_bits[0..i + 1]); + let insert_val = limb_from_bits_le_recursive(builder, insert_bits); + let constr = builder.sub_extension(auxl, insert_val); + let constr = builder.mul_extension(constr, constr_msb); + yield_constr.constraint(builder, constr); + } + } +} + pub fn eval_packed_ror( lv: &CpuColumnsView

, yield_constr: &mut ConstraintConsumer

, @@ -490,6 +656,168 @@ pub fn eval_ext_circuit_ror, const D: usize>( } } +pub fn eval_packed_maddu( + lv: &CpuColumnsView

, + yield_constr: &mut ConstraintConsumer

, +) { + let filter = lv.op.maddu; + + // Check rs Reg + // addr(channels[0]) == rs + { + let rs_reg = lv.mem_channels[0].addr_virtual; + let rs_src = limb_from_bits_le(lv.rs_bits); + yield_constr.constraint(filter * (rs_reg - rs_src)); + } + + // Check rt Reg + // addr(channels[1]) == rt + { + let rt_reg = lv.mem_channels[1].addr_virtual; + let rt_dst = limb_from_bits_le(lv.rt_bits); + yield_constr.constraint(filter * (rt_reg - rt_dst)); + } + + // Check hi Reg + // addr(channels[2]) == 33 + // addr(channels[4]) == 33 + { + let hi_reg_read = lv.mem_channels[2].addr_virtual; + let hi_reg_write = lv.mem_channels[4].addr_virtual; + let hi_src = P::Scalar::from_canonical_usize(33); + yield_constr.constraint(filter * (hi_reg_read - hi_src)); + yield_constr.constraint(filter * (hi_reg_write - hi_src)); + } + + // Check lo Reg + // addr(channels[3]) == 32 + // addr(channels[5]) == 32 + { + let lo_reg_read = lv.mem_channels[3].addr_virtual; + let lo_reg_write = lv.mem_channels[5].addr_virtual; + let lo_src = P::Scalar::from_canonical_usize(32); + yield_constr.constraint(filter * (lo_reg_read - lo_src)); + yield_constr.constraint(filter * (lo_reg_write - lo_src)); + } + + // Check maddu result + // carry = overflow << 32 + // scale = 1 << 32 + // carry * (carry - scale) == 0 + // result + carry * scale == (hi,lo) + rs * rt + { + let rs = lv.mem_channels[0].value; + let rt = lv.mem_channels[1].value; + let hi = lv.mem_channels[2].value; + let lo = lv.mem_channels[3].value; + let hi_result: P = lv.mem_channels[4].value; + let lo_result = lv.mem_channels[5].value; + let carry = lv.general.misc().auxm; + let scale = P::Scalar::from_canonical_usize(1 << 32); + let result = hi_result * scale + lo_result; + let mul = rs * rt; + let addend = hi * scale + lo; + let overflow = carry * scale; + + yield_constr.constraint(filter * carry * (carry - scale)); + yield_constr.constraint(filter * (mul + addend - overflow - result)); + } +} + +pub fn eval_ext_circuit_maddu, const D: 
usize>( + builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, + lv: &CpuColumnsView>, + yield_constr: &mut RecursiveConstraintConsumer, +) { + let filter = lv.op.maddu; + + // Check rs Reg + // addr(channels[0]) == rs + { + let rs_reg = lv.mem_channels[0].addr_virtual; + let rs_src = limb_from_bits_le_recursive(builder, lv.rs_bits); + let constr = builder.sub_extension(rs_reg, rs_src); + let constr = builder.mul_extension(constr, filter); + yield_constr.constraint(builder, constr); + } + + // Check rt Reg + // addr(channels[1]) == rt + { + let rt_reg = lv.mem_channels[1].addr_virtual; + let rt_src = limb_from_bits_le_recursive(builder, lv.rt_bits); + let constr = builder.sub_extension(rt_reg, rt_src); + let constr = builder.mul_extension(constr, filter); + yield_constr.constraint(builder, constr); + } + + // Check hi Reg + // addr(channels[2]) == 33 + // addr(channels[4]) == 33 + { + let hi_reg_read = lv.mem_channels[2].addr_virtual; + let hi_reg_write = lv.mem_channels[4].addr_virtual; + let hi_src = builder.constant_extension(F::Extension::from_canonical_usize(33)); + let constr = builder.sub_extension(hi_reg_read, hi_src); + let constr = builder.mul_extension(constr, filter); + yield_constr.constraint(builder, constr); + + let constr = builder.sub_extension(hi_reg_write, hi_src); + let constr = builder.mul_extension(constr, filter); + yield_constr.constraint(builder, constr); + } + + // Check lo Reg + // addr(channels[3]) == 32 + // addr(channels[5]) == 32 + { + let lo_reg_read = lv.mem_channels[3].addr_virtual; + let lo_reg_write = lv.mem_channels[5].addr_virtual; + let lo_src = builder.constant_extension(F::Extension::from_canonical_usize(32)); + let constr = builder.sub_extension(lo_reg_read, lo_src); + let constr = builder.mul_extension(constr, filter); + yield_constr.constraint(builder, constr); + + let constr = builder.sub_extension(lo_reg_write, lo_src); + let constr = builder.mul_extension(constr, filter); + yield_constr.constraint(builder, 
constr); + } + + // Check maddu result + // carry = overflow << 32 + // scale = 1 << 32 + // carry * (carry - scale) == 0 + // result + carry * scale == (hi,lo) + rs * rt + { + let rs = lv.mem_channels[0].value; + let rt = lv.mem_channels[1].value; + let hi = lv.mem_channels[2].value; + let lo = lv.mem_channels[3].value; + let hi_result = lv.mem_channels[4].value; + let lo_result = lv.mem_channels[5].value; + let carry = lv.general.misc().auxm; + let scale = builder.constant_extension(F::Extension::from_canonical_usize(1 << 32)); + let result = builder.mul_extension(hi_result, scale); + let result = builder.add_extension(result, lo_result); + let mul = builder.mul_extension(rs, rt); + let addend = builder.mul_extension(hi, scale); + let addend = builder.add_extension(addend, lo); + + let overflow = builder.mul_extension(carry, scale); + + let constr = builder.sub_extension(carry, scale); + let constr = builder.mul_extension(constr, carry); + let constr = builder.mul_extension(constr, filter); + yield_constr.constraint(builder, constr); + + let constr = builder.add_extension(mul, addend); + let constr = builder.sub_extension(constr, overflow); + let constr = builder.sub_extension(constr, result); + let constr = builder.mul_extension(constr, filter); + yield_constr.constraint(builder, constr); + } +} + pub fn eval_packed( lv: &CpuColumnsView

, yield_constr: &mut ConstraintConsumer

, @@ -499,6 +827,8 @@ pub fn eval_packed( eval_packed_teq(lv, yield_constr); eval_packed_extract(lv, yield_constr); eval_packed_ror(lv, yield_constr); + eval_packed_insert(lv, yield_constr); + eval_packed_maddu(lv, yield_constr); } pub fn eval_ext_circuit, const D: usize>( @@ -511,4 +841,6 @@ pub fn eval_ext_circuit, const D: usize>( eval_ext_circuit_teq(builder, lv, yield_constr); eval_ext_circuit_extract(builder, lv, yield_constr); eval_ext_circuit_ror(builder, lv, yield_constr); + eval_ext_circuit_insert(builder, lv, yield_constr); + eval_ext_circuit_maddu(builder, lv, yield_constr); } diff --git a/prover/src/fixed_recursive_verifier.rs b/prover/src/fixed_recursive_verifier.rs index 471f5d93..276a44a5 100644 --- a/prover/src/fixed_recursive_verifier.rs +++ b/prover/src/fixed_recursive_verifier.rs @@ -47,9 +47,11 @@ use crate::verifier::verify_proof; /// The recursion threshold. We end a chain of recursive proofs once we reach this size. const THRESHOLD_DEGREE_BITS: usize = 13; -/// Contains all recursive circuits used in the system. For each STARK and each initial -/// `degree_bits`, this contains a chain of recursive circuits for shrinking that STARK from -/// `degree_bits` to a constant `THRESHOLD_DEGREE_BITS`. It also contains a special root circuit +/// Contains all recursive circuits used in the system. +/// +/// For each STARK and each initial `degree_bits`, this contains a chain of +/// recursive circuits for shrinking that STARK from `degree_bits` to a constant +/// `THRESHOLD_DEGREE_BITS`. It also contains a special root circuit /// for combining each STARK's shrunk wrapper proof into a single proof. #[derive(Eq, PartialEq, Debug)] pub struct AllRecursiveCircuits diff --git a/prover/src/keccak_sponge/mod.rs b/prover/src/keccak_sponge/mod.rs index 92b7f0c1..c7a5f12e 100644 --- a/prover/src/keccak_sponge/mod.rs +++ b/prover/src/keccak_sponge/mod.rs @@ -1,4 +1,5 @@ //! 
The Keccak sponge STARK is used to hash a variable amount of data which is read from memory. +//! //! It connects to the memory STARK to read input data, and to the Keccak-f STARK to evaluate the //! permutation at each absorption step. diff --git a/prover/src/witness/operation.rs b/prover/src/witness/operation.rs index b699ca73..37047aa0 100644 --- a/prover/src/witness/operation.rs +++ b/prover/src/witness/operation.rs @@ -127,6 +127,8 @@ pub(crate) enum Operation { MstoreGeneral(MemOp, u8, u8, u32), Nop, Ext(u8, u8, u8, u8), + Ins(u8, u8, u8, u8), + Maddu(u8, u8), Ror(u8, u8, u8), Rdhwr(u8, u8), Signext(u8, u8, u8), @@ -1392,6 +1394,75 @@ pub(crate) fn generate_extract( Ok(()) } +pub(crate) fn generate_insert( + rt: u8, + rs: u8, + msb: u8, + lsb: u8, + state: &mut GenerationState, + mut row: CpuColumnsView, +) -> Result<(), ProgramError> { + assert!(msb < 32); + assert!(lsb <= msb); + let (in0, log_in0) = reg_read_with_log(rs, 0, state, &mut row)?; + let (in1, log_in1) = reg_read_with_log(rt, 1, state, &mut row)?; + let mask = (1 << (msb - lsb + 1)) - 1; + let mask_field = mask << lsb; + + let rs_bits_le = (0..32) + .map(|i| { + let bit: usize = (in0 >> i) & 0x01; + F::from_canonical_u32(bit as u32) + }) + .collect_vec(); + row.general.misc_mut().rs_bits = rs_bits_le.try_into().unwrap(); + + row.general.misc_mut().is_msb = [F::ZERO; 32]; + row.general.misc_mut().is_msb[(msb - lsb) as usize] = F::ONE; + row.general.misc_mut().is_lsb = [F::ZERO; 32]; + row.general.misc_mut().is_lsb[lsb as usize] = F::ONE; + row.general.misc_mut().auxs = F::from_canonical_u32(1 << lsb); + + row.general.misc_mut().auxm = F::from_canonical_u32((in1 & !mask_field) as u32); + row.general.misc_mut().auxl = F::from_canonical_u32((in0 & mask) as u32); + row.general.misc_mut().auxs = F::from_canonical_u32((1 << lsb) as u32); + + let result = (in1 & !mask_field) | ((in0 << lsb) & mask_field); + let log_out0 = reg_write_with_log(rt, 2, result, state, &mut row)?; + 
state.traces.push_memory(log_in0); + state.traces.push_memory(log_in1); + state.traces.push_memory(log_out0); + state.traces.push_cpu(row); + + Ok(()) +} + +pub(crate) fn generate_maddu( + rt: u8, + rs: u8, + state: &mut GenerationState, + mut row: CpuColumnsView, +) -> Result<(), ProgramError> { + let (in0, log_in0) = reg_read_with_log(rs, 0, state, &mut row)?; + let (in1, log_in1) = reg_read_with_log(rt, 1, state, &mut row)?; + let (in2, log_in2) = reg_read_with_log(33, 2, state, &mut row)?; + let (in3, log_in3) = reg_read_with_log(32, 3, state, &mut row)?; + let mul = in0 * in1; + let addend = (in2 << 32) + in3; + let (result, overflow) = (mul as u64).overflowing_add(addend as u64); + let log_out0 = reg_write_with_log(33, 4, (result >> 32) as usize, state, &mut row)?; + let log_out1 = reg_write_with_log(32, 5, (result as u32) as usize, state, &mut row)?; + row.general.misc_mut().auxm = F::from_canonical_usize((overflow as usize) << 32); + state.traces.push_memory(log_in0); + state.traces.push_memory(log_in1); + state.traces.push_memory(log_in2); + state.traces.push_memory(log_in3); + state.traces.push_memory(log_out0); + state.traces.push_memory(log_out1); + state.traces.push_cpu(row); + Ok(()) +} pub(crate) fn generate_rdhwr( rt: u8, rd: u8, diff --git a/prover/src/witness/transition.rs b/prover/src/witness/transition.rs index 99fb116b..b1710ab4 100644 --- a/prover/src/witness/transition.rs +++ b/prover/src/witness/transition.rs @@ -277,7 +277,9 @@ fn decode(registers: RegistersState, insn: u32) -> Result Ok(Operation::BinaryLogicImm(logic::Op::Xor, rs, rt, offset)), // XORI: rt = rs + zext(imm) (0b000000, 0b001100, _) => Ok(Operation::Syscall), // Syscall (0b110011, _, _) => Ok(Operation::Nop), // Pref + (0b011100, 0b000001, _) => Ok(Operation::Maddu(rt, rs)), // maddu (0b011111, 0b000000, _) => Ok(Operation::Ext(rt, rs, rd, sa)), // ext + (0b011111, 0b000100, _) => Ok(Operation::Ins(rt, rs, rd, sa)), // ins (0b011111, 0b111011, _) => Ok(Operation::Rdhwr(rt, 
rd)), // rdhwr (0b011111, 0b100000, _) => { if sa == 0b011000 { @@ -334,6 +336,8 @@ fn fill_op_flag(op: Operation, row: &mut CpuColumnsView) { Operation::MstoreGeneral(..) => &mut flags.m_op_store, Operation::Nop => &mut flags.nop, Operation::Ext(_, _, _, _) => &mut flags.ext, + Operation::Ins(_, _, _, _) => &mut flags.ins, + Operation::Maddu(_, _) => &mut flags.maddu, Operation::Ror(_, _, _) => &mut flags.ror, Operation::Rdhwr(_, _) => &mut flags.rdhwr, Operation::Signext(_, _, 8u8) => &mut flags.signext8, @@ -447,6 +451,8 @@ fn perform_op( Operation::SetContext => generate_set_context(state, row)?, Operation::Nop => generate_nop(state, row)?, Operation::Ext(rt, rs, msbd, lsb) => generate_extract(rt, rs, msbd, lsb, state, row)?, + Operation::Ins(rt, rs, msb, lsb) => generate_insert(rt, rs, msb, lsb, state, row)?, + Operation::Maddu(rt, rs) => generate_maddu(rt, rs, state, row)?, Operation::Ror(rd, rt, sa) => generate_ror(rd, rt, sa, state, row)?, Operation::Rdhwr(rt, rd) => generate_rdhwr(rt, rd, state, row)?, Operation::Signext(rd, rt, bits) => generate_signext(rd, rt, bits, state, row)?,