diff --git a/Cargo.toml b/Cargo.toml index 9702da73..de46569e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "chiquito" -version = "0.1.2023070800" +version = "0.1.2023092400" edition = "2021" license = "MIT OR Apache-2.0" authors = ["Leo Lara "] diff --git a/examples/mimc7.py b/examples/mimc7.py new file mode 100644 index 00000000..569e6b48 --- /dev/null +++ b/examples/mimc7.py @@ -0,0 +1,169 @@ +from __future__ import annotations +from chiquito.dsl import SuperCircuit, Circuit, StepType +from chiquito.cb import eq, table +from chiquito.util import F + +from mimc7_constants import ROUND_KEYS + +ROUNDS = 91 + + +# It's the best practice to wrap all values in F, even though the `assign` functions automatically wrap values in F. +class Mimc7Constants(Circuit): + def setup(self): + self.pragma_num_steps(ROUNDS) + self.lookup_row = self.fixed("constant row") + self.lookup_c = self.fixed("constant value") + self.lookup_table = self.new_table( + table().add(self.lookup_row).add(self.lookup_c) + ) + + def fixed_gen(self): + for i, round_key in enumerate(ROUND_KEYS): + self.assign(i, self.lookup_row, F(i)) + self.assign(i, self.lookup_c, F(round_key)) + + +class Mimc7FirstStep(StepType): + def setup(self): + self.xkc = self.internal("xkc") + self.y = self.internal("y") + + self.constr(eq(self.circuit.x + self.circuit.k + self.circuit.c, self.xkc)) + self.constr( + eq( + self.xkc + * self.xkc + * self.xkc + * self.xkc + * self.xkc + * self.xkc + * self.xkc, + self.y, + ) + ) + + self.transition(eq(self.y, self.circuit.x.next())) + self.transition(eq(self.circuit.k, self.circuit.k.next())) + self.transition(eq(self.circuit.row, 0)) + self.transition(eq(self.circuit.row + 1, self.circuit.row.next())) + + self.add_lookup( + self.circuit.constants_table.apply(self.circuit.row).apply(self.circuit.c) + ) + + def wg(self, x_value, k_value, c_value, row_value): + self.assign(self.circuit.x, F(x_value)) + self.assign(self.circuit.k, F(k_value)) + self.assign(self.circuit.c, F(c_value)) + self.assign(self.circuit.row, F(row_value)) + + xkc_value = F(x_value + k_value + c_value) + self.assign(self.xkc, F(xkc_value)) + self.assign(self.y, F(xkc_value**7)) + + +class Mimc7Step(StepType): + def setup(self): + self.xkc = self.internal("xkc") + self.y = self.internal("y") + + self.constr(eq(self.circuit.x + self.circuit.k + self.circuit.c, self.xkc)) + self.constr( + eq( + self.xkc + * self.xkc + * self.xkc + * self.xkc + * self.xkc + * self.xkc + * self.xkc, + self.y, + ) + ) + + self.transition(eq(self.y, self.circuit.x.next())) + self.transition(eq(self.circuit.k, self.circuit.k.next())) + self.transition(eq(self.circuit.row + 1, self.circuit.row.next())) + + self.add_lookup( + self.circuit.constants_table.apply(self.circuit.row).apply(self.circuit.c) + ) + + def wg(self, x_value, k_value, c_value, row_value): + self.assign(self.circuit.x, F(x_value)) + self.assign(self.circuit.k, F(k_value)) + self.assign(self.circuit.c, F(c_value)) + self.assign(self.circuit.row, F(row_value)) + + xkc_value = F(x_value + k_value + c_value) + self.assign(self.xkc, F(xkc_value)) + self.assign(self.y, F(xkc_value**7)) + + +class Mimc7LastStep(StepType): + def setup(self): + self.out = self.internal("out") + + self.constr(eq(self.circuit.x + self.circuit.k, self.out)) + + def wg(self, x_value, k_value, _, row_value): + self.assign(self.circuit.x, F(x_value)) + self.assign(self.circuit.k, F(k_value)) + self.assign(self.circuit.row, F(row_value)) + self.assign(self.out, F(x_value + k_value)) + + +class Mimc7Circuit(Circuit): + def setup(self): + self.x = self.forward("x") + self.k = self.forward("k") + self.c = self.forward("c") + self.row = self.forward("row") + + self.mimc7_first_step = self.step_type(Mimc7FirstStep(self, "mimc7_first_step")) + self.mimc7_step = self.step_type(Mimc7Step(self, "mimc7_step")) + self.mimc7_last_step = self.step_type(Mimc7LastStep(self, "mimc7_last_step")) + + self.pragma_first_step(self.mimc7_first_step) + self.pragma_last_step(self.mimc7_last_step) + self.pragma_num_steps(ROUNDS + 2 - 1) + + def trace(self, x_in_value, k_value): + c_value = F(ROUND_KEYS[0]) + x_value = F(x_in_value) + row_value = F(0) + + self.add(self.mimc7_first_step, x_value, k_value, c_value, row_value) + + for i in range(1, ROUNDS): + row_value += F(1) + x_value += F(k_value + c_value) + x_value = F(x_value**7) + c_value = F(ROUND_KEYS[i]) + + self.add(self.mimc7_step, x_value, k_value, c_value, row_value) + + row_value += F(1) + x_value += F(k_value + c_value) + x_value = F(x_value**7) + + self.add(self.mimc7_last_step, x_value, k_value, c_value, row_value) + + +class Mimc7SuperCircuit(SuperCircuit): + def setup(self): + self.mimc7_constants = self.sub_circuit(Mimc7Constants(self)) + self.mimc7_circuit = self.sub_circuit( + Mimc7Circuit(self, constants_table=self.mimc7_constants.lookup_table) + ) + + def mapping(self, x_in_value, k_value): + self.map(self.mimc7_circuit, x_in_value, k_value) + + +mimc7 = Mimc7SuperCircuit() +mimc7_super_witness = mimc7.gen_witness(F(1), F(2)) +# for key, value in mimc7_super_witness.items(): +# print(f"{key}: {str(value)}") +mimc7.halo2_mock_prover(mimc7_super_witness) diff --git a/examples/mimc7_constants.py b/examples/mimc7_constants.py new file mode 100644 index 00000000..08d9fe57 --- /dev/null +++ b/examples/mimc7_constants.py @@ -0,0 +1,93 @@ +ROUND_KEYS = [ + 0, + 20888961410941983456478427210666206549300505294776164667214940546594746570981, + 15265126113435022738560151911929040668591755459209400716467504685752745317193, + 8334177627492981984476504167502758309043212251641796197711684499645635709656, + 1374324219480165500871639364801692115397519265181803854177629327624133579404, + 11442588683664344394633565859260176446561886575962616332903193988751292992472, + 2558901189096558760448896669327086721003508630712968559048179091037845349145, + 11189978595292752354820141775598510151189959177917284797737745690127318076389, + 3262966573163560839685415914157855077211340576201936620532175028036746741754, + 17029914891543225301403832095880481731551830725367286980611178737703889171730, + 4614037031668406927330683909387957156531244689520944789503628527855167665518, + 19647356996769918391113967168615123299113119185942498194367262335168397100658, + 5040699236106090655289931820723926657076483236860546282406111821875672148900, + 2632385916954580941368956176626336146806721642583847728103570779270161510514, + 17691411851977575435597871505860208507285462834710151833948561098560743654671, + 11482807709115676646560379017491661435505951727793345550942389701970904563183, + 8360838254132998143349158726141014535383109403565779450210746881879715734773, + 12663821244032248511491386323242575231591777785787269938928497649288048289525, + 3067001377342968891237590775929219083706800062321980129409398033259904188058, + 8536471869378957766675292398190944925664113548202769136103887479787957959589, + 19825444354178182240559170937204690272111734703605805530888940813160705385792, + 16703465144013840124940690347975638755097486902749048533167980887413919317592, + 13061236261277650370863439564453267964462486225679643020432589226741411380501, + 10864774797625152707517901967943775867717907803542223029967000416969007792571, + 10035653564014594269791753415727486340557376923045841607746250017541686319774, + 3446968588058668564420958894889124905706353937375068998436129414772610003289, + 4653317306466493184743870159523234588955994456998076243468148492375236846006, + 8486711143589723036499933521576871883500223198263343024003617825616410932026, + 250710584458582618659378487568129931785810765264752039738223488321597070280, + 2104159799604932521291371026105311735948154964200596636974609406977292675173, + 16313562605837709339799839901240652934758303521543693857533755376563489378839, + 6032365105133504724925793806318578936233045029919447519826248813478479197288, + 14025118133847866722315446277964222215118620050302054655768867040006542798474, + 7400123822125662712777833064081316757896757785777291653271747396958201309118, + 1744432620323851751204287974553233986555641872755053103823939564833813704825, + 8316378125659383262515151597439205374263247719876250938893842106722210729522, + 6739722627047123650704294650168547689199576889424317598327664349670094847386, + 21211457866117465531949733809706514799713333930924902519246949506964470524162, + 13718112532745211817410303291774369209520657938741992779396229864894885156527, + 5264534817993325015357427094323255342713527811596856940387954546330728068658, + 18884137497114307927425084003812022333609937761793387700010402412840002189451, + 5148596049900083984813839872929010525572543381981952060869301611018636120248, + 19799686398774806587970184652860783461860993790013219899147141137827718662674, + 19240878651604412704364448729659032944342952609050243268894572835672205984837, + 10546185249390392695582524554167530669949955276893453512788278945742408153192, + 5507959600969845538113649209272736011390582494851145043668969080335346810411, + 18177751737739153338153217698774510185696788019377850245260475034576050820091, + 19603444733183990109492724100282114612026332366576932662794133334264283907557, + 10548274686824425401349248282213580046351514091431715597441736281987273193140, + 1823201861560942974198127384034483127920205835821334101215923769688644479957, + 11867589662193422187545516240823411225342068709600734253659804646934346124945, + 18718569356736340558616379408444812528964066420519677106145092918482774343613, + 10530777752259630125564678480897857853807637120039176813174150229243735996839, + 20486583726592018813337145844457018474256372770211860618687961310422228379031, + 12690713110714036569415168795200156516217175005650145422920562694422306200486, + 17386427286863519095301372413760745749282643730629659997153085139065756667205, + 2216432659854733047132347621569505613620980842043977268828076165669557467682, + 6309765381643925252238633914530877025934201680691496500372265330505506717193, + 20806323192073945401862788605803131761175139076694468214027227878952047793390, + 4037040458505567977365391535756875199663510397600316887746139396052445718861, + 19948974083684238245321361840704327952464170097132407924861169241740046562673, + 845322671528508199439318170916419179535949348988022948153107378280175750024, + 16222384601744433420585982239113457177459602187868460608565289920306145389382, + 10232118865851112229330353999139005145127746617219324244541194256766741433339, + 6699067738555349409504843460654299019000594109597429103342076743347235369120, + 6220784880752427143725783746407285094967584864656399181815603544365010379208, + 6129250029437675212264306655559561251995722990149771051304736001195288083309, + 10773245783118750721454994239248013870822765715268323522295722350908043393604, + 4490242021765793917495398271905043433053432245571325177153467194570741607167, + 19596995117319480189066041930051006586888908165330319666010398892494684778526, + 837850695495734270707668553360118467905109360511302468085569220634750561083, + 11803922811376367215191737026157445294481406304781326649717082177394185903907, + 10201298324909697255105265958780781450978049256931478989759448189112393506592, + 13564695482314888817576351063608519127702411536552857463682060761575100923924, + 9262808208636973454201420823766139682381973240743541030659775288508921362724, + 173271062536305557219323722062711383294158572562695717740068656098441040230, + 18120430890549410286417591505529104700901943324772175772035648111937818237369, + 20484495168135072493552514219686101965206843697794133766912991150184337935627, + 19155651295705203459475805213866664350848604323501251939850063308319753686505, + 11971299749478202793661982361798418342615500543489781306376058267926437157297, + 18285310723116790056148596536349375622245669010373674803854111592441823052978, + 7069216248902547653615508023941692395371990416048967468982099270925308100727, + 6465151453746412132599596984628739550147379072443683076388208843341824127379, + 16143532858389170960690347742477978826830511669766530042104134302796355145785, + 19362583304414853660976404410208489566967618125972377176980367224623492419647, + 1702213613534733786921602839210290505213503664731919006932367875629005980493, + 10781825404476535814285389902565833897646945212027592373510689209734812292327, + 4212716923652881254737947578600828255798948993302968210248673545442808456151, + 7594017890037021425366623750593200398174488805473151513558919864633711506220, + 18979889247746272055963929241596362599320706910852082477600815822482192194401, + 13602139229813231349386885113156901793661719180900395818909719758150455500533, +] diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 7c19523b..5b85d642 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -1,12 +1,12 @@ pub mod query; -use std::{collections::HashMap, fmt::Debug, rc::Rc}; +use std::{collections::HashMap, fmt::Debug, hash::Hash, rc::Rc}; use crate::{ frontend::dsl::StepTypeHandler, poly::Expr, util::{uuid, UUID}, - wit_gen::{FixedGenContext, Trace, TraceContext}, + wit_gen::{FixedAssignment, FixedGenContext, Trace, TraceContext}, }; use halo2_proofs::plonk::{Advice, Column as Halo2Column, ColumnType, Fixed}; @@ -28,7 +28,7 @@ pub struct Circuit { pub annotations: HashMap, pub trace: Option>>, - pub fixed_gen: Option>>, + pub fixed_assignments: Option>, pub first_step: Option, pub last_step: Option, @@ -41,10 +41,19 @@ pub struct Circuit { impl Debug for Circuit { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Circuit") + .field("step_types", &self.step_types) .field("forward_signals", &self.forward_signals) + .field("shared_signals", &self.shared_signals) + .field("fixed_signals", &self.fixed_signals) .field("halo2_advice", &self.halo2_advice) - .field("step_types", &self.step_types) + .field("halo2_fixed", &self.halo2_fixed) + .field("exposed", &self.exposed) .field("annotations", &self.annotations) + .field("fixed_assignments", &self.fixed_assignments) + .field("first_step", &self.first_step) + .field("last_step", &self.last_step) + .field("num_steps", &self.num_steps) + .field("q_enable", &self.q_enable) .finish() } } @@ -65,7 +74,7 @@ impl Default for Circuit { annotations: Default::default(), trace: None, - fixed_gen: None, + fixed_assignments: None, first_step: None, last_step: None, @@ -166,21 +175,20 @@ impl Circuit { } } - pub fn set_fixed_gen(&mut self, def: D) - where - D: Fn(&mut FixedGenContext) + 'static, - { - match self.fixed_gen { - None => self.fixed_gen = Some(Rc::new(def)), - Some(_) => panic!("circuit cannot have more than one fixed generator"), - } - } - pub fn get_step_type(&self, uuid: UUID) -> Rc> { let step_rc = self.step_types.get(&uuid).expect("step type not found"); Rc::clone(step_rc) } + + pub fn set_fixed_assignments(&mut self, assignments: FixedAssignment) { + match self.fixed_assignments { + None => { + self.fixed_assignments = Some(assignments); + } + Some(_) => panic!("circuit cannot have more than one fixed generator"), + } + } } pub type FixedGen = dyn Fn(&mut FixedGenContext) + 'static; @@ -451,7 +459,7 @@ impl FixedSignal { } } -#[derive(Clone, Copy)] +#[derive(Clone, Copy, Debug)] pub enum ExposeOffset { First, Last, diff --git a/src/frontend/dsl/mod.rs b/src/frontend/dsl/mod.rs index 4948b9dc..968d7725 100644 --- a/src/frontend/dsl/mod.rs +++ b/src/frontend/dsl/mod.rs @@ -1,12 +1,13 @@ use crate::{ ast::{query::Queriable, Circuit, ExposeOffset, StepType, StepTypeUUID}, + field::Field, util::{uuid, UUID}, wit_gen::{FixedGenContext, StepInstance, TraceContext}, }; use halo2_proofs::plonk::{Advice, Column as Halo2Column, Fixed}; -use core::fmt::Debug; +use core::{fmt::Debug, hash::Hash}; use std::marker::PhantomData; use self::{ @@ -133,17 +134,6 @@ impl CircuitContext { self.circuit.set_trace(def); } - /// Sets the fixed generation function for the circuit. The fixed generation function is - /// responsible for assigning fixed values to fixed columns. It is entirely left - /// for the user to implement and is Turing complete. Users typically generate cell values and - /// call the `assign` function to fill the fixed columns. - pub fn fixed_gen(&mut self, def: D) - where - D: Fn(&mut FixedGenContext) + 'static, - { - self.circuit.set_fixed_gen(def); - } - pub fn new_table(&self, table: LookupTableStore) -> LookupTable { let uuid = table.uuid(); self.tables.add(table); @@ -172,6 +162,27 @@ impl CircuitContext { } } +impl CircuitContext { + /// Executes the fixed generation function provided by the user and sets the fixed assignments + /// for the circuit. The fixed generation function is responsible for assigning fixed values to + /// fixed columns. It is entirely left for the user to implement and is Turing complete. Users + /// typically generate cell values and call the `assign` function to fill the fixed columns. + pub fn fixed_gen(&mut self, def: D) + where + D: Fn(&mut FixedGenContext) + 'static, + { + if self.circuit.num_steps == 0 { + panic!("circuit must call pragma_num_steps before calling fixed_gen"); + } + let mut ctx = FixedGenContext::new(self.circuit.num_steps); + (def)(&mut ctx); + + let assignments = ctx.get_assignments(); + + self.circuit.set_fixed_assignments(assignments); + } +} + pub enum StepTypeDefInput { Handler(StepTypeHandler), String(&'static str), diff --git a/src/frontend/dsl/sc.rs b/src/frontend/dsl/sc.rs index 2cc11326..db382df7 100644 --- a/src/frontend/dsl/sc.rs +++ b/src/frontend/dsl/sc.rs @@ -21,7 +21,7 @@ use super::{lb::LookupTableRegistry, CircuitContext}; pub struct SuperCircuitContext { super_circuit: SuperCircuit, sub_circuit_phase1: Vec>, - tables: LookupTableRegistry, + pub tables: LookupTableRegistry, } impl Default for SuperCircuitContext { @@ -48,8 +48,9 @@ impl SuperCircuitContext { circuit: Circuit::default(), tables: self.tables.clone(), }; - + println!("super circuit table registry 2: {:?}", self.tables); let exports = sub_circuit_def(&mut sub_circuit_context, imports); + println!("super circuit table registry 3: {:?}", self.tables); let sub_circuit = sub_circuit_context.circuit; @@ -61,11 +62,24 @@ impl SuperCircuitContext { (assignment, exports) } + pub fn sub_circuit_with_ast( + &mut self, + config: CompilerConfig, + sub_circuit: Circuit, // directly input ast + ) -> AssignmentGenerator { + let (unit, assignment) = compile_phase1(config, &sub_circuit); + let assignment = assignment.unwrap_or_else(|| AssignmentGenerator::empty(unit.uuid)); + + self.sub_circuit_phase1.push(unit); + + assignment + } + pub fn mapping, MappingArgs) + 'static>(&mut self, def: D) { self.super_circuit.set_mapping(def); } - fn compile(mut self) -> SuperCircuit { + pub fn compile(mut self) -> SuperCircuit { let other = Rc::new(self.sub_circuit_phase1.clone()); // let columns = other // .iter() diff --git a/src/frontend/python/chiquito/cb.py b/src/frontend/python/chiquito/cb.py index f6058b67..6d25d920 100644 --- a/src/frontend/python/chiquito/cb.py +++ b/src/frontend/python/chiquito/cb.py @@ -1,12 +1,12 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field from enum import Enum, auto -from typing import List +from typing import List, Dict, Optional -from chiquito.util import F +from chiquito.util import F, uuid from chiquito.expr import Expr, Const, Neg, to_expr, ToExpr from chiquito.query import StepTypeNext -from chiquito.chiquito_ast import ASTStepType +from chiquito.chiquito_ast import ASTStepType, Lookup class Typing(Enum): @@ -197,7 +197,12 @@ def rlc(exprs: List[ToExpr], randomness: Expr) -> Expr: return Expr(Const(F(0))) -# TODO: Implement lookup table after the lookup abstraction PR is merged. +def lookup() -> InPlaceLookupBuilder: + return InPlaceLookupBuilder() + + +def table() -> LookupTable: + return LookupTable() ToConstraint = Constraint | Expr | int | F @@ -222,3 +227,99 @@ def to_constraint(v: ToConstraint) -> Constraint: raise TypeError( f"Type `{type(v)}` is not ToConstraint (one of Constraint, Expr, int, or F)." ) + + +@dataclass +class LookupTable: + uuid: int = 0 + dest: List[Expr] = field(default_factory=list) + finished_flag: bool = False + + def __init__(self: LookupTable): + self.uuid: int = uuid() + self.dest = [] + self.finished_flag = False + + def add(self: LookupTable, expr: ToExpr) -> LookupTable: + assert self.finished_flag == False + self.dest.append(to_expr(expr)) + return self + + def apply(self: LookupTable, constraint: ToConstraint) -> LookupTableBuilder: + assert self.finished_flag == True + # just pass in lookuptable itself rather than finding it from uuid + return LookupTableBuilder(self).apply(constraint) + + def when(self: LookupTable, enable: ToConstraint) -> LookupTableBuilder: + assert self.finished_flag == True + return LookupTableBuilder(self).when(enable) + + def set_finished_flag(self: LookupTable): + assert self.finished_flag == False + self.finished_flag = True + + +@dataclass +class LookupTableBuilder: + lookup_table: LookupTable + src: List[Constraint] = field(default_factory=list) + enable: Optional[Constraint] = None + + def __init__(self: LookupTableBuilder, lookup_table: LookupTable): + self.lookup_table = lookup_table + self.src = [] + self.enable = None + + def apply(self: LookupTableBuilder, constraint: ToConstraint) -> LookupTableBuilder: + self.src.append(to_constraint(constraint)) + return self + + def when(self: LookupTableBuilder, enable: ToConstraint) -> LookupTableBuilder: + if self.enable is not None: + raise ValueError("LookupTableBuilder: when() can only be called once.") + self.enable = to_constraint(enable) + return self + + def build(self: LookupTableBuilder) -> Lookup: + if self.lookup_table is None: + raise ValueError( + f"LookupTableBuilder: cannot call build() if self.lookup_table is None" + ) + if len(self.src) != len(self.lookup_table.dest): + raise ValueError( + "LookupTableBuilder: build() has different number of source columns and destination columns." + ) + + lookup = Lookup() + + if self.enable is not None: + lookup.enable(self.enable.annotation, self.enable.expr) + + for i in range(len(self.src)): + lookup.add( + self.src[i].annotation, self.src[i].expr, self.lookup_table.dest[i] + ) + + return lookup + + +@dataclass +class InPlaceLookupBuilder: + lookup: Lookup = field(default_factory=Lookup) + + def build(self: InPlaceLookupBuilder) -> Lookup: + return self.lookup + + def add( + self: InPlaceLookupBuilder, constraint: ToConstraint, expression: ToExpr + ) -> InPlaceLookupBuilder: + constraint = to_constraint(constraint) + self.lookup.add(constraint.annotation, constraint.expr, to_expr(expression)) + return self + + def enable( + self: InPlaceLookupBuilder, enable: ToConstraint + ) -> InPlaceLookupBuilder: + enable = to_constraint(enable) + self.lookup.enable(enable.annotation, enable.expr) + return self diff --git a/src/frontend/python/chiquito/chiquito_ast.py b/src/frontend/python/chiquito/chiquito_ast.py index 447720b5..fc6e6d0b 100644 --- a/src/frontend/python/chiquito/chiquito_ast.py +++ b/src/frontend/python/chiquito/chiquito_ast.py @@ -1,13 +1,13 @@ from __future__ import annotations -from typing import Callable, List, Dict, Optional, Any, Tuple +from typing import Callable, List, Dict, Optional, Tuple from dataclasses import dataclass, field, asdict # from chiquito import wit_gen, expr, query, util -from chiquito.wit_gen import FixedGenContext, StepInstance +from chiquito.wit_gen import FixedAssignment, TraceWitness from chiquito.expr import Expr -from chiquito.util import uuid -from chiquito.query import Queriable +from chiquito.util import uuid, F +from chiquito.query import Queriable, Fixed # pub struct Circuit { @@ -31,6 +31,12 @@ # } +@dataclass +class ASTSuperCircuit: + sub_circuits: Dict[int, ASTCircuit] = field(default_factory=dict) + super_witness: Dict[int, TraceWitness] = field(default_factory=dict) + + @dataclass class ASTCircuit: step_types: Dict[int, ASTStepType] = field(default_factory=dict) @@ -39,7 +45,7 @@ class ASTCircuit: fixed_signals: List[FixedSignal] = field(default_factory=list) exposed: List[Tuple[Queriable, ExposeOffset]] = field(default_factory=list) annotations: Dict[int, str] = field(default_factory=dict) - fixed_gen: Optional[Callable] = None + fixed_assignments: Optional[FixedAssignment] = None first_step: Optional[int] = None last_step: Optional[int] = None num_steps: int = 0 @@ -92,7 +98,7 @@ def __str__(self: ASTCircuit): f"\tfixed_signals=[{fixed_signals_str}],\n" f"\texposed=[{exposed_str}],\n" f"\tannotations={{{annotations_str}}},\n" - f"\tfixed_gen={self.fixed_gen},\n" + f"\tfixed_assignments={self.fixed_assignments},\n" f"\tfirst_step={self.first_step},\n" f"\tlast_step={self.last_step},\n" f"\tnum_steps={self.num_steps}\n" @@ -111,6 +117,11 @@ def __json__(self: ASTCircuit): for (queriable, offset) in self.exposed ], "annotations": self.annotations, + "fixed_assignments": None + if self.fixed_assignments is None + else { + lhs.uuid(): [lhs, rhs] for (lhs, rhs) in self.fixed_assignments.items() + }, "first_step": self.first_step, "last_step": self.last_step, "num_steps": self.num_steps, @@ -143,11 +154,14 @@ def add_step_type(self: ASTCircuit, step_type: ASTStepType, name: str): self.annotations[step_type.id] = name self.step_types[step_type.id] = step_type - def set_fixed_gen(self, fixed_gen_def: Callable[[FixedGenContext], None]): - if self.fixed_gen is not None: - raise Exception("ASTCircuit cannot have more than one fixed generator.") + def add_fixed_assignment(self: ASTCircuit, offset: int, lhs: Queriable, rhs: F): + if not isinstance(lhs, Fixed): + raise ValueError(f"Cannot assign to non-fixed signal.") + if lhs in self.fixed_assignments.keys(): + self.fixed_assignments[lhs][offset] = rhs else: - self.fixed_gen = fixed_gen_def + self.fixed_assignments[lhs] = [F.zero()] * self.num_steps + self.fixed_assignments[lhs][offset] = rhs def get_step_type(self, uuid: int) -> ASTStepType: if uuid in self.step_types.keys(): @@ -175,10 +189,11 @@ class ASTStepType: signals: List[InternalSignal] constraints: List[ASTConstraint] transition_constraints: List[TransitionConstraint] + lookups: List[Lookup] annotations: Dict[int, str] def new(name: str) -> ASTStepType: - return ASTStepType(uuid(), name, [], [], [], {}) + return ASTStepType(uuid(), name, [], [], [], [], {}) def __str__(self): signals_str = ( @@ -202,6 +217,13 @@ def __str__(self): if self.transition_constraints else "" ) + lookups_str = ( + "\n\t\t\t\t" + + ",\n\t\t\t\t".join(str(lookup) for lookup in self.lookups) + + "\n\t\t\t" + if self.lookups + else "" + ) annotations_str = ( "\n\t\t\t\t" + ",\n\t\t\t\t".join(f"{k}: {v}" for k, v in self.annotations.items()) @@ -217,6 +239,7 @@ def __str__(self): f"\t\t\tsignals=[{signals_str}],\n" f"\t\t\tconstraints=[{constraints_str}],\n" f"\t\t\ttransition_constraints=[{transition_constraints_str}],\n" + f"\t\t\tlookups=[{lookups_str}],\n" f"\t\t\tannotations={{{annotations_str}}}\n" f"\t\t)" ) @@ -230,6 +253,7 @@ def __json__(self): "transition_constraints": [ x.__json__() for x in self.transition_constraints ], + "lookups": [x.__json__() for x in self.lookups], "annotations": self.annotations, } @@ -385,3 +409,53 @@ def __str__(self: InternalSignal): def __json__(self: InternalSignal): return asdict(self) + + +@dataclass +class Lookup: + annotation: str = "" + exprs: List[Tuple[ASTConstraint, Expr]] = field(default_factory=list) + enabler: Optional[ + ASTConstraint + ] = None # called enabler because cannot have field and method both called "enable" + + def add( + self: Lookup, + constraint_annotation: str, + constraint_expr: Expr, + expression: Expr, + ): + constraint = ASTConstraint(constraint_annotation, constraint_expr) + self.annotation += f"match({constraint.annotation} => {str(expression)}) " + if self.enabler is None: + self.exprs.append((constraint, expression)) + else: + self.exprs.append( + (multiply_constraints(self.enabler, constraint), expression) + ) + + def enable(self: Lookup, enable_annotation: str, enable_expr: Expr): + enabler = ASTConstraint(enable_annotation, enable_expr) + if self.enabler is None: + for constraint, _ in self.exprs: + constraint = multiply_constraints(enabler, constraint) + self.enabler = enabler + self.annotation = f"if {enable_annotation}, {self.annotation}" + else: + raise ValueError("Lookup: enable() can only be called once.") + + def __str__(self: Lookup): + return f"Lookup({self.annotation})" + + def __json__(self: Lookup): + return { + "annotation": self.annotation, + "exprs": [[x.__json__(), y.__json__()] for (x, y) in self.exprs], + "enable": self.enabler.__json__() if self.enabler is not None else None, + } + + +def multiply_constraints( + enabler: ASTConstraint, constraint: ASTConstraint +) -> ASTConstraint: + return ASTConstraint(constraint.annotation, enabler.expr * constraint.expr) diff --git a/src/frontend/python/chiquito/dsl.py b/src/frontend/python/chiquito/dsl.py index 3f00cd85..52abfc62 100644 --- a/src/frontend/python/chiquito/dsl.py +++ b/src/frontend/python/chiquito/dsl.py @@ -1,32 +1,118 @@ from __future__ import annotations +from typing import List, Dict from enum import Enum from typing import Callable, Any -# import rust_chiquito # rust bindings from chiquito import rust_chiquito import json -from chiquito import chiquito_ast, wit_gen -from chiquito.chiquito_ast import ASTCircuit, ASTStepType, ExposeOffset +from chiquito.chiquito_ast import ASTCircuit, ASTStepType, ExposeOffset, ASTSuperCircuit from chiquito.query import Internal, Forward, Queriable, Shared, Fixed -from chiquito.wit_gen import FixedGenContext, StepInstance, TraceWitness -from chiquito.cb import Constraint, Typing, ToConstraint, to_constraint +from chiquito.wit_gen import StepInstance, TraceWitness +from chiquito.cb import ( + Constraint, + Typing, + ToConstraint, + to_constraint, + LookupTable, + LookupTableBuilder, + InPlaceLookupBuilder, +) from chiquito.util import CustomEncoder, F +class SuperCircuitMode(Enum): + NoMode = 0 + SETUP = 1 + Mapping = 2 + + +class SuperCircuit: + def __init__(self: SuperCircuit): + self.ast = ASTSuperCircuit() + self.mode = SuperCircuitMode.SETUP + self.setup() + self.mode = SuperCircuitMode.NoMode + + # called under setup() + def sub_circuit(self: SuperCircuit, sub_circuit: Circuit) -> Circuit: + assert self.mode == SuperCircuitMode.SETUP + if sub_circuit.rust_id != 0: + raise ValueError( + "SuperCircuit: sub_circuit() cannot be called twice on the same circuit." + ) + ast_json: str = sub_circuit.get_ast_json() + sub_circuit.rust_id: int = rust_chiquito.ast_to_halo2(ast_json) + self.ast.sub_circuits[sub_circuit.rust_id] = sub_circuit.ast + return sub_circuit + + # called under mapping() + # generates TraceWitness for sub_circuit + def map(self: SuperCircuit, sub_circuit: Circuit, *args: Any) -> TraceWitness: + assert self.mode == SuperCircuitMode.Mapping + witness: TraceWitness = sub_circuit.gen_witness(*args) + if sub_circuit.rust_id == 0: + raise ValueError( + "SuperCircuit: must call sub_circuit() before calling map() on a Circuit." + ) + self.ast.super_witness[sub_circuit.rust_id] = witness + return witness + + # called at the outermost level + # generates TraceWitness mapping + def gen_witness(self: SuperCircuit, *args: Any) -> Dict[int, TraceWitness]: + self.mode = SuperCircuitMode.Mapping + self.mapping(*args) + self.mode = SuperCircuitMode.NoMode + super_witness: Dict[int, TraceWitness] = self.ast.super_witness + del ( + self.ast.super_witness + ) # so that we can generate different witness mapping in the next gen_witness() call + return super_witness + + def halo2_mock_prover(self: SuperCircuit, super_witness: Dict[int, TraceWitness]): + for rust_id, witness in super_witness.items(): + if rust_id not in self.ast.sub_circuits: + raise ValueError( + f"SuperCircuit.halo2_mock_prover(): TraceWitness with rust_id {rust_id} not found in sub_circuits." + ) + witness_json: str = witness.get_witness_json() + super_witness[rust_id] = witness_json + rust_chiquito.super_circuit_halo2_mock_prover( + list(self.ast.sub_circuits.keys()), super_witness + ) + + class CircuitMode(Enum): NoMode = 0 SETUP = 1 Trace = 2 + FixedGen = 3 class Circuit: - def __init__(self: Circuit): + def __init__( + self: Circuit, + super_circuit: SuperCircuit = None, + **kwargs, # **kwargs is intended for arbitrary names for imports + ): self.ast = ASTCircuit() self.witness = TraceWitness() - self.rust_ast_id = 0 + self.rust_id = 0 + self.super_circuit = super_circuit + for key, value in kwargs.items(): + setattr(self, key, value) self.mode = CircuitMode.SETUP self.setup() + if hasattr(self, "fixed_gen") and callable(self.fixed_gen): + self.mode = CircuitMode.FixedGen + if self.ast.num_steps == 0: + raise ValueError( + "Must set num_steps by calling pragma_num_steps() in setup before calling fixed_gen()." + ) + self.ast.fixed_assignments = {} + self.fixed_gen() + self.mode = CircuitMode.NoMode def forward(self: Circuit, name: str) -> Forward: assert self.mode == CircuitMode.SETUP @@ -64,9 +150,6 @@ def step_type_def(self: StepType) -> StepType: assert self.mode == CircuitMode.SETUP self.ast.add_step_type_def() - def fixed_gen(self: Circuit, fixed_gen_def: Callable[[FixedGenContext], None]): - self.ast.set_fixed_gen(fixed_gen_def) - def pragma_first_step(self: Circuit, step_type: StepType) -> None: assert self.mode == CircuitMode.SETUP self.ast.first_step = step_type.step_type.id @@ -83,7 +166,15 @@ def pragma_disable_q_enable(self: Circuit) -> None: assert self.mode == CircuitMode.SETUP self.ast.q_enable = False - def add(self: Circuit, step_type: StepType, *args): + def new_table(self: Circuit, table: LookupTable) -> LookupTable: + assert self.mode == CircuitMode.SETUP + # have a method called set_finished_flag() to encapsulate + # call finished_flag "finished" instead + table.set_finished_flag() + return table + + # called under trace() + def add(self: Circuit, step_type: StepType, *args: Any): assert self.mode == CircuitMode.Trace if len(self.witness.step_instances) >= self.ast.num_steps: raise ValueError(f"Number of step instances exceeds {self.ast.num_steps}") @@ -99,6 +190,15 @@ def padding(self: Circuit, step_type: StepType, *args): while self.needs_padding(): self.add(step_type, *args) + # called under fixed_gen() + def assign(self: Circuit, offset: int, lhs: Queriable, rhs: F): + assert self.mode == CircuitMode.FixedGen + if self.ast.fixed_assignments is None: + raise ValueError( + "FixedAssignment: must have initiated fixed_assignments before calling assign()" + ) + self.ast.add_fixed_assignment(offset, lhs, F(rhs)) + def gen_witness(self: Circuit, *args) -> TraceWitness: self.mode = CircuitMode.Trace self.witness = TraceWitness() @@ -112,11 +212,11 @@ def get_ast_json(self: Circuit) -> str: return json.dumps(self.ast, cls=CustomEncoder, indent=4) def halo2_mock_prover(self: Circuit, witness: TraceWitness): - if self.rust_ast_id == 0: + if self.rust_id == 0: ast_json: str = self.get_ast_json() - self.rust_ast_id: int = rust_chiquito.ast_to_halo2(ast_json) + self.rust_id: int = rust_chiquito.ast_to_halo2(ast_json) witness_json: str = witness.get_witness_json() - rust_chiquito.halo2_mock_prover(witness_json, self.rust_ast_id) + rust_chiquito.halo2_mock_prover(witness_json, self.rust_id) def __str__(self: Circuit) -> str: return self.ast.__str__() @@ -172,6 +272,10 @@ def enforce_constraint_typing(constraint: Constraint): def assign(self: StepType, lhs: Queriable, rhs: F): assert self.mode == StepTypeMode.WG - self.step_instance.assign(lhs, rhs) + self.step_instance.assign(lhs, F(rhs)) + + def add_lookup(self: StepType, lookup_builder: LookupBuilder): + self.step_type.lookups.append(lookup_builder.build()) + - # TODO: Implement add_lookup after lookup abstraction PR is merged. +LookupBuilder = LookupTableBuilder | InPlaceLookupBuilder diff --git a/src/frontend/python/chiquito/util.py b/src/frontend/python/chiquito/util.py index d838aa19..0533fd6f 100644 --- a/src/frontend/python/chiquito/util.py +++ b/src/frontend/python/chiquito/util.py @@ -3,20 +3,22 @@ from uuid import uuid1 import json -F = bn128.FQ - -def json_method(self: F): - # Convert the integer to a byte array - byte_array = self.n.to_bytes(32, "little") - - # Split into four 64-bit integers - ints = [int.from_bytes(byte_array[i * 8 : i * 8 + 8], "little") for i in range(4)] - - return ints - - -F.__json__ = json_method +class F(bn128.FQ): + field_modulus = ( + 21888242871839275222246405745257275088548364400416034343698204186575808495617 + ) + + def __json__(self: F): + R = 2**256 + # Convert the integer to a byte array + montgomery_form = self.n * R % F.field_modulus + byte_array = montgomery_form.to_bytes(32, "little") + # Split into four 64-bit integers + ints = [ + int.from_bytes(byte_array[i * 8 : i * 8 + 8], "little") for i in range(4) + ] + return ints class CustomEncoder(json.JSONEncoder): diff --git a/src/frontend/python/chiquito/wit_gen.py b/src/frontend/python/chiquito/wit_gen.py index b7408756..cec59905 100644 --- a/src/frontend/python/chiquito/wit_gen.py +++ b/src/frontend/python/chiquito/wit_gen.py @@ -1,6 +1,6 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Dict, List, Callable, Any +from typing import Dict, List import json from chiquito.query import Queriable, Fixed @@ -94,29 +94,4 @@ def evil_witness_test( return TraceWitness(new_step_instances) -FixedAssigment = Dict[Queriable, List[F]] - - -@dataclass -class FixedGenContext: - assignments: FixedAssigment = field(default_factory=dict) - num_steps: int = 0 - - def new(num_steps: int) -> FixedGenContext: - return FixedGenContext({}, num_steps) - - def assign(self: FixedGenContext, offset: int, lhs: Queriable, rhs: F): - if not FixedGenContext.is_fixed_queriable(lhs): - raise ValueError(f"Cannot assign to non-fixed signal.") - if lhs in self.assignments.keys(): - self.assignments[lhs][offset] = rhs - else: - self.assignments[lhs] = [F.zero()] * self.num_steps - self.assignments[lhs][offset] = rhs - - def is_fixed_queriable(q: Queriable) -> bool: - match q.enum: - case Fixed(_, _): - return True - case _: - return False +FixedAssignment = Dict[Queriable, List[F]] diff --git a/src/frontend/python/mod.rs b/src/frontend/python/mod.rs index 722ae5bc..489d7cc9 100644 --- a/src/frontend/python/mod.rs +++ b/src/frontend/python/mod.rs @@ -1,21 +1,24 @@ use pyo3::{ prelude::*, - types::{PyLong, PyString}, + types::{PyDict, PyList, PyLong, PyString}, }; use crate::{ ast::{ query::Queriable, Circuit, Constraint, ExposeOffset, FixedSignal, ForwardSignal, - InternalSignal, SharedSignal, StepType, StepTypeUUID, TransitionConstraint, + InternalSignal, Lookup, SharedSignal, StepType, StepTypeUUID, TransitionConstraint, }, - frontend::dsl::StepTypeHandler, + frontend::dsl::{StepTypeHandler, SuperCircuitContext}, plonkish::{ - backend::halo2::{chiquito2Halo2, ChiquitoHalo2, ChiquitoHalo2Circuit}, + backend::halo2::{ + chiquito2Halo2, chiquitoSuperCircuit2Halo2, ChiquitoHalo2, ChiquitoHalo2Circuit, + ChiquitoHalo2SuperCircuit, + }, compiler::{ cell_manager::SingleRowCellManager, compile, config, step_selector::SimpleStepSelectorBuilder, }, - ir::assignments::AssignmentGenerator, + ir::{assignments::AssignmentGenerator, sc::MappingContext}, }, poly::Expr, util::{uuid, UUID}, @@ -27,13 +30,21 @@ use halo2_proofs::{dev::MockProver, halo2curves::bn256::Fr}; use serde::de::{self, Deserialize, Deserializer, IgnoredAny, MapAccess, Visitor}; use std::{cell::RefCell, collections::HashMap, fmt, rc::Rc}; -type CircuitMapStore = (ChiquitoHalo2, Option>); +type CircuitMapStore = ( + Circuit, + ChiquitoHalo2, + Option>, +); type CircuitMap = RefCell>; thread_local! { pub static CIRCUIT_MAP: CircuitMap = RefCell::new(HashMap::new()); } +/// Parses JSON into `ast::Circuit` and compile. Generates a Rust UUID. Inserts tuple of +/// (`ast::Circuit`, `ChiquitoHalo2`, `AssignmentGenerator`, _) to `CIRCUIT_MAP` with the Rust UUID +/// as the key. Return the Rust UUID to Python. The last field of the tuple, `TraceWitness`, is left +/// as None, for `chiquito_add_witness_to_rust_id` to insert. pub fn chiquito_ast_to_halo2(ast_json: &str) -> UUID { let circuit: Circuit = serde_json::from_str(ast_json).expect("Json deserialization to Circuit failed."); @@ -46,7 +57,7 @@ pub fn chiquito_ast_to_halo2(ast_json: &str) -> UUID { CIRCUIT_MAP.with(|circuit_map| { circuit_map .borrow_mut() - .insert(uuid, (chiquito_halo2, assignment_generator)); + .insert(uuid, (circuit, chiquito_halo2, assignment_generator)); }); println!("{:?}", uuid); @@ -54,17 +65,83 @@ pub fn chiquito_ast_to_halo2(ast_json: &str) -> UUID { uuid } -fn uuid_to_halo2(uuid: UUID) -> CircuitMapStore { +fn add_assignment_generator_to_rust_id( + assignment_generator: AssignmentGenerator, + rust_id: UUID, +) { + CIRCUIT_MAP.with(|circuit_map| { + let mut circuit_map = circuit_map.borrow_mut(); + let circuit_map_store = circuit_map.get_mut(&rust_id).unwrap(); + circuit_map_store.2 = Some(assignment_generator); + }); +} + +/// Compile a `ChiquitoHalo2SuperCircuit` object from a list of `rust_ids`, each corresponding to a +/// sub-circuit. The `ChiquitoHalo2SuperCircuit` object is then passed to `MockProver` for +/// verification. `TraceWitness`, if any, should have been inserted to each rust_id prior to +/// invoking this function. +pub fn chiquito_super_circuit_halo2_mock_prover( + rust_ids: Vec, + super_witness: HashMap, +) { + let mut super_circuit_ctx = SuperCircuitContext::::default(); + + // super_circuit def + let config = config(SingleRowCellManager {}, SimpleStepSelectorBuilder {}); + for rust_id in rust_ids.clone() { + let circuit_map_store = rust_id_to_halo2(rust_id); + let (circuit, _, _) = circuit_map_store; + let assignment = super_circuit_ctx.sub_circuit_with_ast(config.clone(), circuit); + add_assignment_generator_to_rust_id(assignment, rust_id); + } + + let super_circuit = super_circuit_ctx.compile(); + let compiled = chiquitoSuperCircuit2Halo2(&super_circuit); + + let mut mapping_ctx = MappingContext::default(); + for rust_id in rust_ids { + let circuit_map_store = rust_id_to_halo2(rust_id); + let (_, _, assignment_generator) = circuit_map_store; + + if let Some(witness_json) = super_witness.get(&rust_id) { + let witness: TraceWitness = serde_json::from_str(witness_json) + .expect("Json deserialization to TraceWitness failed."); + mapping_ctx.map_with_witness(&assignment_generator.unwrap(), witness); + } + } + + let super_assignments = mapping_ctx.get_super_assignments(); + + let circuit = ChiquitoHalo2SuperCircuit::new(compiled, super_assignments); + + let prover = MockProver::::run(10, &circuit, circuit.instance()).unwrap(); + + let result = prover.verify_par(); + + println!("result = {:#?}", result); + + if let Err(failures) = &result { + for failure in failures.iter() { + println!("{}", failure); + } + } +} + +/// Returns the (`ast::Circuit`, `ChiquitoHalo2`, `AssignmentGenerator`, `TraceWitness`) tuple +/// corresponding to `rust_id`. +fn rust_id_to_halo2(uuid: UUID) -> CircuitMapStore { CIRCUIT_MAP.with(|circuit_map| { let circuit_map = circuit_map.borrow(); circuit_map.get(&uuid).unwrap().clone() }) } -pub fn chiquito_halo2_mock_prover(witness_json: &str, ast_id: UUID) { +/// Runs `MockProver` for a single circuit given JSON of `TraceWitness` and `rust_id` of the +/// circuit. +pub fn chiquito_halo2_mock_prover(witness_json: &str, rust_id: UUID) { let trace_witness: TraceWitness = serde_json::from_str(witness_json).expect("Json deserialization to TraceWitness failed."); - let (compiled, assignment_generator) = uuid_to_halo2(ast_id); + let (_, compiled, assignment_generator) = rust_id_to_halo2(rust_id); let circuit: ChiquitoHalo2Circuit<_> = ChiquitoHalo2Circuit::new( compiled, assignment_generator.map(|g| g.generate_with_witness(trace_witness)), @@ -102,6 +179,7 @@ impl<'de> Visitor<'de> for CircuitVisitor { let mut fixed_signals = None; let mut exposed = None; let mut annotations = None; + let mut fixed_assignments = None; let mut first_step = None; let mut last_step = None; let mut num_steps = None; @@ -146,6 +224,13 @@ impl<'de> Visitor<'de> for CircuitVisitor { } annotations = Some(map.next_value::>()?); } + "fixed_assignments" => { + if fixed_assignments.is_some() { + return Err(de::Error::duplicate_field("fixed_assignments")); + } + fixed_assignments = + Some(map.next_value::, Vec)>>>()?); + } "first_step" => { if first_step.is_some() { return Err(de::Error::duplicate_field("first_step")); @@ -186,6 +271,7 @@ impl<'de> Visitor<'de> for CircuitVisitor { "fixed_signals", "exposed", "annotations", + "fixed_assignments", "first_step", "last_step", "num_steps", @@ -209,6 +295,9 @@ impl<'de> Visitor<'de> for CircuitVisitor { fixed_signals.ok_or_else(|| de::Error::missing_field("fixed_signals"))?; let exposed = exposed.ok_or_else(|| de::Error::missing_field("exposed"))?; let annotations = annotations.ok_or_else(|| de::Error::missing_field("annotations"))?; + let fixed_assignments = fixed_assignments + .ok_or_else(|| de::Error::missing_field("fixed_assignments"))? + .map(|inner| inner.into_values().collect()); let first_step = first_step.ok_or_else(|| de::Error::missing_field("first_step"))?; let last_step = last_step.ok_or_else(|| de::Error::missing_field("last_step"))?; let num_steps = num_steps.ok_or_else(|| de::Error::missing_field("num_steps"))?; @@ -226,7 +315,7 @@ impl<'de> Visitor<'de> for CircuitVisitor { num_steps, annotations, trace: Some(Rc::new(|_: &mut TraceContext<_>, _: _| {})), - fixed_gen: None, + fixed_assignments, first_step, last_step, q_enable, @@ -252,6 +341,7 @@ impl<'de> Visitor<'de> for StepTypeVisitor { let mut signals = None; let mut constraints = None; let mut transition_constraints = None; + let mut lookups = None; let mut annotations = None; while let Some(key) = map.next_key::()? { @@ -287,6 +377,12 @@ impl<'de> Visitor<'de> for StepTypeVisitor { transition_constraints = Some(map.next_value::>>()?); } + "lookups" => { + if lookups.is_some() { + return Err(de::Error::duplicate_field("lookups")); + } + lookups = Some(map.next_value::>>()?); + } "annotations" => { if annotations.is_some() { return Err(de::Error::duplicate_field("annotations")); @@ -302,6 +398,7 @@ impl<'de> Visitor<'de> for StepTypeVisitor { "signals", "constraints", "transition_constraints", + "lookups", "annotations", ], )) @@ -377,6 +474,62 @@ impl_visitor_constraint_transition!( "struct TransitionConstraint" ); +struct LookupVisitor; + +impl<'de> Visitor<'de> for LookupVisitor { + type Value = Lookup; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("struct Lookup") + } + + fn visit_map(self, mut map: A) -> Result, A::Error> + where + A: MapAccess<'de>, + { + let mut annotation = None; + let mut exprs = None; + let mut enable = None; + while let Some(key) = map.next_key::()? { + match key.as_str() { + "annotation" => { + if annotation.is_some() { + return Err(de::Error::duplicate_field("annotation")); + } + annotation = Some(map.next_value::()?); + } + "exprs" => { + if exprs.is_some() { + return Err(de::Error::duplicate_field("exprs")); + } + exprs = + Some(map.next_value::, Expr>)>>()?); + } + "enable" => { + if enable.is_some() { + return Err(de::Error::duplicate_field("enable")); + } + enable = Some(map.next_value::>>()?); + } + _ => { + return Err(de::Error::unknown_field( + &key, + &["annotation", "exprs", "enable"], + )) + } + } + } + let annotation = annotation.ok_or_else(|| de::Error::missing_field("annotation"))?; + let exprs = exprs.ok_or_else(|| de::Error::missing_field("exprs"))?; + let enable = enable.ok_or_else(|| de::Error::missing_field("enable"))?; + Ok(Self::Value { + annotation, + exprs, + enable, + }) + } +} + struct ExprVisitor; impl<'de> Visitor<'de> for ExprVisitor { @@ -725,6 +878,7 @@ impl_deserialize!(TransitionConstraintVisitor, TransitionConstraint); impl_deserialize!(StepTypeVisitor, StepType); impl_deserialize!(TraceWitnessVisitor, TraceWitness); impl_deserialize!(StepInstanceVisitor, StepInstance); +impl_deserialize!(LookupVisitor, Lookup); impl<'de> Deserialize<'de> for Circuit { fn deserialize(deserializer: D) -> Result, D::Error> @@ -887,16 +1041,70 @@ mod tests { let json = r#" { "step_types": { - "205524326356431126935662643926474033674": { - "id": 205524326356431126935662643926474033674, - "name": "fibo_step", + "258869595755756204079859764249309612554": { + "id": 258869595755756204079859764249309612554, + "name": "fibo_first_step", "signals": [ { - "id": 205524332694684128074575021569884162570, + "id": 258869599717164329791616633222308956682, "annotation": "c" } ], "constraints": [ + { + "annotation": "(a == 1)", + "expr": { + "Sum": [ + { + "Forward": [ + { + "id": 258869580702405326369584955980151130634, + "phase": 0, + "annotation": "a" + }, + false + ] + }, + { + "Neg": { + "Const": [ + 1, + 0, + 0, + 0 + ] + } + } + ] + } + }, + { + "annotation": "(b == 1)", + "expr": { + "Sum": [ + { + "Forward": [ + { + "id": 258869587040658327507391136965088381450, + "phase": 0, + "annotation": "b" + }, + false + ] + }, + { + "Neg": { + "Const": [ + 1, + 0, + 0, + 0 + ] + } + } + ] + } + }, { "annotation": "((a + b) == c)", "expr": { @@ -904,7 +1112,7 @@ mod tests { { "Forward": [ { - "id": 205524314472206749795829327634996267530, + "id": 258869580702405326369584955980151130634, "phase": 0, "annotation": "a" }, @@ -914,7 +1122,7 @@ mod tests { { "Forward": [ { - "id": 205524322395023001221676493137926294026, + "id": 258869587040658327507391136965088381450, "phase": 0, "annotation": "b" }, @@ -924,7 +1132,7 @@ mod tests { { "Neg": { "Internal": { - "id": 205524332694684128074575021569884162570, + "id": 258869599717164329791616633222308956682, "annotation": "c" } } @@ -941,7 +1149,7 @@ mod tests { { "Forward": [ { - "id": 205524322395023001221676493137926294026, + "id": 258869587040658327507391136965088381450, "phase": 0, "annotation": "b" }, @@ -952,7 +1160,7 @@ mod tests { "Neg": { "Forward": [ { - "id": 205524314472206749795829327634996267530, + "id": 258869580702405326369584955980151130634, "phase": 0, "annotation": "a" }, @@ -969,7 +1177,7 @@ mod tests { "Sum": [ { "Internal": { - "id": 205524332694684128074575021569884162570, + "id": 258869599717164329791616633222308956682, "annotation": "c" } }, @@ -977,7 +1185,7 @@ mod tests { "Neg": { "Forward": [ { - "id": 205524322395023001221676493137926294026, + "id": 258869587040658327507391136965088381450, "phase": 0, "annotation": "b" }, @@ -987,18 +1195,48 @@ mod tests { } ] } + }, + { + "annotation": "(n == next(n))", + "expr": { + "Sum": [ + { + "Forward": [ + { + "id": 258869589417503202934383108674030275082, + "phase": 0, + "annotation": "n" + }, + false + ] + }, + { + "Neg": { + "Forward": [ + { + "id": 258869589417503202934383108674030275082, + "phase": 0, + "annotation": "n" + }, + true + ] + } + } + ] + } } ], + "lookups": [], "annotations": { - "205524332694684128074575021569884162570": "c" + "258869599717164329791616633222308956682": "c" } }, - "205524373893328635494146417612672338442": { - "id": 205524373893328635494146417612672338442, - "name": "fibo_last_step", + "258869628239302834927102989021255174666": { + "id": 258869628239302834927102989021255174666, + "name": "fibo_step", "signals": [ { - "id": 205524377062455136063336753318874188298, + "id": 258869632200710960639812650790420089354, "annotation": "c" } ], @@ -1010,7 +1248,7 @@ mod tests { { "Forward": [ { - "id": 205524314472206749795829327634996267530, + "id": 258869580702405326369584955980151130634, "phase": 0, "annotation": "a" }, @@ -1020,7 +1258,7 @@ mod tests { { "Forward": [ { - "id": 205524322395023001221676493137926294026, + "id": 258869587040658327507391136965088381450, "phase": 0, "annotation": "b" }, @@ -1030,7 +1268,7 @@ mod tests { { "Neg": { "Internal": { - "id": 205524377062455136063336753318874188298, + "id": 258869632200710960639812650790420089354, "annotation": "c" } } @@ -1039,22 +1277,180 @@ mod tests { } } ], - "transition_constraints": [], + "transition_constraints": [ + { + "annotation": "(b == next(a))", + "expr": { + "Sum": [ + { + "Forward": [ + { + "id": 258869587040658327507391136965088381450, + "phase": 0, + "annotation": "b" + }, + false + ] + }, + { + "Neg": { + "Forward": [ + { + "id": 258869580702405326369584955980151130634, + "phase": 0, + "annotation": "a" + }, + true + ] + } + } + ] + } + }, + { + "annotation": "(c == next(b))", + "expr": { + "Sum": [ + { + "Internal": { + "id": 258869632200710960639812650790420089354, + "annotation": "c" + } + }, + { + "Neg": { + "Forward": [ + { + "id": 258869587040658327507391136965088381450, + "phase": 0, + "annotation": "b" + }, + true + ] + } + } + ] + } + }, + { + "annotation": "(n == next(n))", + "expr": { + "Sum": [ + { + "Forward": [ + { + "id": 258869589417503202934383108674030275082, + "phase": 0, + "annotation": "n" + }, + false + ] + }, + { + "Neg": { + "Forward": [ + { + "id": 258869589417503202934383108674030275082, + "phase": 0, + "annotation": "n" + }, + true + ] + } + } + ] + } + } + ], + "lookups": [], "annotations": { - "205524377062455136063336753318874188298": "c" + "258869632200710960639812650790420089354": "c" } + }, + "258869646461780213207493341245063432714": { + "id": 258869646461780213207493341245063432714, + "name": "padding", + "signals": [], + "constraints": [], + "transition_constraints": [ + { + "annotation": "(b == next(b))", + "expr": { + "Sum": [ + { + "Forward": [ + { + "id": 258869587040658327507391136965088381450, + "phase": 0, + "annotation": "b" + }, + false + ] + }, + { + "Neg": { + "Forward": [ + { + "id": 258869587040658327507391136965088381450, + "phase": 0, + "annotation": "b" + }, + true + ] + } + } + ] + } + }, + { + "annotation": "(n == next(n))", + "expr": { + "Sum": [ + { + "Forward": [ + { + "id": 258869589417503202934383108674030275082, + "phase": 0, + "annotation": "n" + }, + false + ] + }, + { + "Neg": { + "Forward": [ + { + "id": 258869589417503202934383108674030275082, + "phase": 0, + "annotation": "n" + }, + true + ] + } + } + ] + } + } + ], + "lookups": [], + "annotations": {} } }, "forward_signals": [ { - "id": 205524314472206749795829327634996267530, + "id": 258869580702405326369584955980151130634, "phase": 0, "annotation": "a" }, { - "id": 205524322395023001221676493137926294026, + "id": 258869587040658327507391136965088381450, "phase": 0, "annotation": "b" + }, + { + "id": 258869589417503202934383108674030275082, + "phase": 0, + "annotation": "n" } ], "shared_signals": [], @@ -1064,28 +1460,13 @@ mod tests { { "Forward": [ { - "id": 205524322395023001221676493137926294026, + "id": 258869587040658327507391136965088381450, "phase": 0, "annotation": "b" }, false ] }, - { - "First": 0 - } - ], - [ - { - "Forward": [ - { - "id": 205524314472206749795829327634996267530, - "phase": 0, - "annotation": "a" - }, - false - ] - }, { "Last": -1 } @@ -1094,29 +1475,32 @@ mod tests { { "Forward": [ { - "id": 205524314472206749795829327634996267530, + "id": 258869589417503202934383108674030275082, "phase": 0, - "annotation": "a" + "annotation": "n" }, false ] }, { - "Step": 1 + "Last": -1 } ] ], "annotations": { - "205524314472206749795829327634996267530": "a", - "205524322395023001221676493137926294026": "b", - "205524326356431126935662643926474033674": "fibo_step", - "205524373893328635494146417612672338442": "fibo_last_step" + "258869580702405326369584955980151130634": "a", + "258869587040658327507391136965088381450": "b", + "258869589417503202934383108674030275082": "n", + "258869595755756204079859764249309612554": "fibo_first_step", + "258869628239302834927102989021255174666": "fibo_step", + "258869646461780213207493341245063432714": "padding" }, - "first_step": 205524326356431126935662643926474033674, - "last_step": 205524373893328635494146417612672338442, - "num_steps": 0, - "q_enable": false, - "id": 205522563529815184552233780032226069002 + "fixed_assignments": null, + "first_step": 258869595755756204079859764249309612554, + "last_step": 258869646461780213207493341245063432714, + "num_steps": 10, + "q_enable": true, + "id": 258867373405797678961444396351437277706 } "#; let circuit: Circuit = serde_json::from_str(json).unwrap(); @@ -1443,18 +1827,52 @@ fn ast_to_halo2(json: &PyString) -> u128 { } #[pyfunction] -fn halo2_mock_prover(witness_json: &PyString, ast_uuid: &PyLong) { +fn halo2_mock_prover(witness_json: &PyString, rust_id: &PyLong) { chiquito_halo2_mock_prover( witness_json.to_str().expect("PyString convertion failed."), - ast_uuid.extract().expect("PyLong convertion failed."), + rust_id.extract().expect("PyLong convertion failed."), ); } +#[pyfunction] +fn super_circuit_halo2_mock_prover(rust_ids: &PyList, super_witness: &PyDict) { + let uuids = rust_ids + .iter() + .map(|rust_id| { + rust_id + .downcast::() + .expect("PyAny downcast failed.") + .extract() + .expect("PyLong convertion failed.") + }) + .collect::>(); + + let super_witness = super_witness + .iter() + .map(|(key, value)| { + ( + key.downcast::() + .expect("PyAny downcast failed.") + .extract() + .expect("PyLong convertion failed."), + value + .downcast::() + .expect("PyAny downcast failed.") + .to_str() + .expect("PyString convertion failed."), + ) + }) + .collect::>(); + + chiquito_super_circuit_halo2_mock_prover(uuids, super_witness) +} + #[pymodule] fn rust_chiquito(_py: Python, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(convert_and_print_ast, m)?)?; m.add_function(wrap_pyfunction!(convert_and_print_trace_witness, m)?)?; m.add_function(wrap_pyfunction!(ast_to_halo2, m)?)?; m.add_function(wrap_pyfunction!(halo2_mock_prover, m)?)?; + m.add_function(wrap_pyfunction!(super_circuit_halo2_mock_prover, m)?)?; Ok(()) } diff --git a/src/plonkish/compiler/mod.rs b/src/plonkish/compiler/mod.rs index 2efacbf1..5647772f 100644 --- a/src/plonkish/compiler/mod.rs +++ b/src/plonkish/compiler/mod.rs @@ -6,7 +6,7 @@ use crate::{ Circuit, Column, Poly, PolyExpr, PolyLookup, }, poly::Expr, - wit_gen::{FixedAssignment, FixedGenContext, TraceGenerator}, + wit_gen::{FixedAssignment, TraceGenerator}, }; use std::{hash::Hash, rc::Rc}; @@ -228,13 +228,8 @@ fn compile_fixed( ast: &astCircuit, unit: &mut CompilationUnit, ) { - if let Some(fixed_gen) = &ast.fixed_gen { - let mut ctx = FixedGenContext::new(unit.num_steps); - (*fixed_gen)(&mut ctx); - - let assignments = ctx.get_assignments(); - - unit.fixed_assignments = place_fixed_assignments(unit, assignments); + if let Some(fixed_assignments) = &ast.fixed_assignments { + unit.fixed_assignments = place_fixed_assignments(unit, fixed_assignments.clone()); } } diff --git a/src/plonkish/ir/sc.rs b/src/plonkish/ir/sc.rs index ee27f616..f283c162 100644 --- a/src/plonkish/ir/sc.rs +++ b/src/plonkish/ir/sc.rs @@ -1,6 +1,6 @@ use std::{collections::HashMap, rc::Rc}; -use crate::{field::Field, util::UUID}; +use crate::{field::Field, util::UUID, wit_gen::TraceWitness}; use super::{ assignments::{AssignmentGenerator, Assignments}, @@ -65,7 +65,16 @@ impl MappingContext { self.assignments.insert(gen.uuid(), gen.generate(args)); } - fn get_super_assignments(self) -> SuperAssignments { + pub fn map_with_witness( + &mut self, + gen: &AssignmentGenerator, + witness: TraceWitness, + ) { + self.assignments + .insert(gen.uuid(), gen.generate_with_witness(witness)); + } + + pub fn get_super_assignments(self) -> SuperAssignments { self.assignments } } diff --git a/src/wit_gen.rs b/src/wit_gen.rs index 7d7ff3e2..9589d850 100644 --- a/src/wit_gen.rs +++ b/src/wit_gen.rs @@ -33,7 +33,7 @@ impl StepInstance { pub type Witness = Vec>; -#[derive(Debug, Default)] +#[derive(Debug, Default, Clone)] pub struct TraceWitness { pub step_instances: Witness, }