diff --git a/Cargo.toml b/Cargo.toml index 9702da73..93a7915f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "chiquito" -version = "0.1.2023070800" +version = "0.1.2023091000" edition = "2021" license = "MIT OR Apache-2.0" authors = ["Leo Lara "] diff --git a/examples/mimc7.py b/examples/mimc7.py new file mode 100644 index 00000000..cdfb7ed5 --- /dev/null +++ b/examples/mimc7.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +from chiquito.dsl import Circuit, StepType +from chiquito.cb import eq, table +from chiquito.util import F + +from chiquito.rust_chiquito import convert_and_print_ast + +from mimc7_constants import ROUND_KEYS + +ROUNDS = 91 + + +class Mimc7Constants(Circuit): + def setup(self): + self.pragma_num_steps(ROUNDS) + self.lookup_row = self.fixed("constant row") + self.lookup_c = self.fixed("constant value") + self.new_table(table().add(self.lookup_row).add(self.lookup_c)) + + def fixed_gen(self): + for i, round_key in enumerate(ROUND_KEYS): + self.assign(i, self.lookup_row, F(i)) + self.assign(i, self.lookup_c, F(round_key)) + + +mimc7_constants = Mimc7Constants() +mimc7_constants_json = mimc7_constants.get_ast_json() +convert_and_print_ast(mimc7_constants_json) diff --git a/examples/mimc7_constants.py b/examples/mimc7_constants.py new file mode 100644 index 00000000..08d9fe57 --- /dev/null +++ b/examples/mimc7_constants.py @@ -0,0 +1,93 @@ +ROUND_KEYS = [ + 0, + 20888961410941983456478427210666206549300505294776164667214940546594746570981, + 15265126113435022738560151911929040668591755459209400716467504685752745317193, + 8334177627492981984476504167502758309043212251641796197711684499645635709656, + 1374324219480165500871639364801692115397519265181803854177629327624133579404, + 11442588683664344394633565859260176446561886575962616332903193988751292992472, + 2558901189096558760448896669327086721003508630712968559048179091037845349145, + 11189978595292752354820141775598510151189959177917284797737745690127318076389, + 3262966573163560839685415914157855077211340576201936620532175028036746741754, + 17029914891543225301403832095880481731551830725367286980611178737703889171730, + 4614037031668406927330683909387957156531244689520944789503628527855167665518, + 19647356996769918391113967168615123299113119185942498194367262335168397100658, + 5040699236106090655289931820723926657076483236860546282406111821875672148900, + 2632385916954580941368956176626336146806721642583847728103570779270161510514, + 17691411851977575435597871505860208507285462834710151833948561098560743654671, + 11482807709115676646560379017491661435505951727793345550942389701970904563183, + 8360838254132998143349158726141014535383109403565779450210746881879715734773, + 12663821244032248511491386323242575231591777785787269938928497649288048289525, + 3067001377342968891237590775929219083706800062321980129409398033259904188058, + 8536471869378957766675292398190944925664113548202769136103887479787957959589, + 19825444354178182240559170937204690272111734703605805530888940813160705385792, + 16703465144013840124940690347975638755097486902749048533167980887413919317592, + 13061236261277650370863439564453267964462486225679643020432589226741411380501, + 10864774797625152707517901967943775867717907803542223029967000416969007792571, + 10035653564014594269791753415727486340557376923045841607746250017541686319774, + 3446968588058668564420958894889124905706353937375068998436129414772610003289, + 4653317306466493184743870159523234588955994456998076243468148492375236846006, + 8486711143589723036499933521576871883500223198263343024003617825616410932026, + 250710584458582618659378487568129931785810765264752039738223488321597070280, + 2104159799604932521291371026105311735948154964200596636974609406977292675173, + 16313562605837709339799839901240652934758303521543693857533755376563489378839, + 6032365105133504724925793806318578936233045029919447519826248813478479197288, + 14025118133847866722315446277964222215118620050302054655768867040006542798474, + 7400123822125662712777833064081316757896757785777291653271747396958201309118, + 1744432620323851751204287974553233986555641872755053103823939564833813704825, + 8316378125659383262515151597439205374263247719876250938893842106722210729522, + 6739722627047123650704294650168547689199576889424317598327664349670094847386, + 21211457866117465531949733809706514799713333930924902519246949506964470524162, + 13718112532745211817410303291774369209520657938741992779396229864894885156527, + 5264534817993325015357427094323255342713527811596856940387954546330728068658, + 18884137497114307927425084003812022333609937761793387700010402412840002189451, + 5148596049900083984813839872929010525572543381981952060869301611018636120248, + 19799686398774806587970184652860783461860993790013219899147141137827718662674, + 19240878651604412704364448729659032944342952609050243268894572835672205984837, + 10546185249390392695582524554167530669949955276893453512788278945742408153192, + 5507959600969845538113649209272736011390582494851145043668969080335346810411, + 18177751737739153338153217698774510185696788019377850245260475034576050820091, + 19603444733183990109492724100282114612026332366576932662794133334264283907557, + 10548274686824425401349248282213580046351514091431715597441736281987273193140, + 1823201861560942974198127384034483127920205835821334101215923769688644479957, + 11867589662193422187545516240823411225342068709600734253659804646934346124945, + 18718569356736340558616379408444812528964066420519677106145092918482774343613, + 10530777752259630125564678480897857853807637120039176813174150229243735996839, + 20486583726592018813337145844457018474256372770211860618687961310422228379031, + 12690713110714036569415168795200156516217175005650145422920562694422306200486, + 17386427286863519095301372413760745749282643730629659997153085139065756667205, + 2216432659854733047132347621569505613620980842043977268828076165669557467682, + 6309765381643925252238633914530877025934201680691496500372265330505506717193, + 20806323192073945401862788605803131761175139076694468214027227878952047793390, + 4037040458505567977365391535756875199663510397600316887746139396052445718861, + 19948974083684238245321361840704327952464170097132407924861169241740046562673, + 845322671528508199439318170916419179535949348988022948153107378280175750024, + 16222384601744433420585982239113457177459602187868460608565289920306145389382, + 10232118865851112229330353999139005145127746617219324244541194256766741433339, + 6699067738555349409504843460654299019000594109597429103342076743347235369120, + 6220784880752427143725783746407285094967584864656399181815603544365010379208, + 6129250029437675212264306655559561251995722990149771051304736001195288083309, + 10773245783118750721454994239248013870822765715268323522295722350908043393604, + 4490242021765793917495398271905043433053432245571325177153467194570741607167, + 19596995117319480189066041930051006586888908165330319666010398892494684778526, + 837850695495734270707668553360118467905109360511302468085569220634750561083, + 11803922811376367215191737026157445294481406304781326649717082177394185903907, + 10201298324909697255105265958780781450978049256931478989759448189112393506592, + 13564695482314888817576351063608519127702411536552857463682060761575100923924, + 9262808208636973454201420823766139682381973240743541030659775288508921362724, + 173271062536305557219323722062711383294158572562695717740068656098441040230, + 18120430890549410286417591505529104700901943324772175772035648111937818237369, + 20484495168135072493552514219686101965206843697794133766912991150184337935627, + 19155651295705203459475805213866664350848604323501251939850063308319753686505, + 11971299749478202793661982361798418342615500543489781306376058267926437157297, + 18285310723116790056148596536349375622245669010373674803854111592441823052978, + 7069216248902547653615508023941692395371990416048967468982099270925308100727, + 6465151453746412132599596984628739550147379072443683076388208843341824127379, + 16143532858389170960690347742477978826830511669766530042104134302796355145785, + 19362583304414853660976404410208489566967618125972377176980367224623492419647, + 1702213613534733786921602839210290505213503664731919006932367875629005980493, + 10781825404476535814285389902565833897646945212027592373510689209734812292327, + 4212716923652881254737947578600828255798948993302968210248673545442808456151, + 7594017890037021425366623750593200398174488805473151513558919864633711506220, + 18979889247746272055963929241596362599320706910852082477600815822482192194401, + 13602139229813231349386885113156901793661719180900395818909719758150455500533, +] diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 7c19523b..5b85d642 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -1,12 +1,12 @@ pub mod query; -use std::{collections::HashMap, fmt::Debug, rc::Rc}; +use std::{collections::HashMap, fmt::Debug, hash::Hash, rc::Rc}; use crate::{ frontend::dsl::StepTypeHandler, poly::Expr, util::{uuid, UUID}, - wit_gen::{FixedGenContext, Trace, TraceContext}, + wit_gen::{FixedAssignment, FixedGenContext, Trace, TraceContext}, }; use halo2_proofs::plonk::{Advice, Column as Halo2Column, ColumnType, Fixed}; @@ -28,7 +28,7 @@ pub struct Circuit { pub annotations: HashMap, pub trace: Option>>, - pub fixed_gen: Option>>, + pub fixed_assignments: Option>, pub first_step: Option, pub last_step: Option, @@ -41,10 +41,19 @@ pub struct Circuit { impl Debug for Circuit { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Circuit") + .field("step_types", &self.step_types) .field("forward_signals", &self.forward_signals) + .field("shared_signals", &self.shared_signals) + .field("fixed_signals", &self.fixed_signals) .field("halo2_advice", &self.halo2_advice) - .field("step_types", &self.step_types) + .field("halo2_fixed", &self.halo2_fixed) + .field("exposed", &self.exposed) .field("annotations", &self.annotations) + .field("fixed_assignments", &self.fixed_assignments) + .field("first_step", &self.first_step) + .field("last_step", &self.last_step) + .field("num_steps", &self.num_steps) + .field("q_enable", &self.q_enable) .finish() } } @@ -65,7 +74,7 @@ impl Default for Circuit { annotations: Default::default(), trace: None, - fixed_gen: None, + fixed_assignments: None, first_step: None, last_step: None, @@ -166,21 +175,20 @@ impl Circuit { } } - pub fn set_fixed_gen(&mut self, def: D) - where - D: Fn(&mut FixedGenContext) + 'static, - { - match self.fixed_gen { - None => self.fixed_gen = Some(Rc::new(def)), - Some(_) => panic!("circuit cannot have more than one fixed generator"), - } - } - pub fn get_step_type(&self, uuid: UUID) -> Rc> { let step_rc = self.step_types.get(&uuid).expect("step type not found"); Rc::clone(step_rc) } + + pub fn set_fixed_assignments(&mut self, assignments: FixedAssignment) { + match self.fixed_assignments { + None => { + self.fixed_assignments = Some(assignments); + } + Some(_) => panic!("circuit cannot have more than one fixed generator"), + } + } } pub type FixedGen = dyn Fn(&mut FixedGenContext) + 'static; @@ -451,7 +459,7 @@ impl FixedSignal { } } -#[derive(Clone, Copy)] +#[derive(Clone, Copy, Debug)] pub enum ExposeOffset { First, Last, diff --git a/src/frontend/dsl/mod.rs b/src/frontend/dsl/mod.rs index 4948b9dc..5845244a 100644 --- a/src/frontend/dsl/mod.rs +++ b/src/frontend/dsl/mod.rs @@ -1,12 +1,13 @@ use crate::{ ast::{query::Queriable, Circuit, ExposeOffset, StepType, StepTypeUUID}, + field::Field, util::{uuid, UUID}, wit_gen::{FixedGenContext, StepInstance, TraceContext}, }; use halo2_proofs::plonk::{Advice, Column as Halo2Column, Fixed}; -use core::fmt::Debug; +use core::{fmt::Debug, hash::Hash}; use std::marker::PhantomData; use self::{ @@ -133,17 +134,6 @@ impl CircuitContext { self.circuit.set_trace(def); } - /// Sets the fixed generation function for the circuit. The fixed generation function is - /// responsible for assigning fixed values to fixed columns. It is entirely left - /// for the user to implement and is Turing complete. Users typically generate cell values and - /// call the `assign` function to fill the fixed columns. - pub fn fixed_gen(&mut self, def: D) - where - D: Fn(&mut FixedGenContext) + 'static, - { - self.circuit.set_fixed_gen(def); - } - pub fn new_table(&self, table: LookupTableStore) -> LookupTable { let uuid = table.uuid(); self.tables.add(table); @@ -172,6 +162,27 @@ impl CircuitContext { } } +impl CircuitContext { + /// Sets the fixed generation function for the circuit. The fixed generation function is + /// responsible for assigning fixed values to fixed columns. It is entirely left + /// for the user to implement and is Turing complete. Users typically generate cell values and + /// call the `assign` function to fill the fixed columns. + pub fn fixed_gen(&mut self, def: D) + where + D: Fn(&mut FixedGenContext) + 'static, + { + if self.circuit.num_steps == 0 { + panic!("circuit must call pragma_num_steps before calling fixed_gen"); + } + let mut ctx = FixedGenContext::new(self.circuit.num_steps); + (def)(&mut ctx); + + let assignments = ctx.get_assignments(); + + self.circuit.set_fixed_assignments(assignments); + } +} + pub enum StepTypeDefInput { Handler(StepTypeHandler), String(&'static str), diff --git a/src/frontend/python/chiquito/cb.py b/src/frontend/python/chiquito/cb.py index f6058b67..e0dd77df 100644 --- a/src/frontend/python/chiquito/cb.py +++ b/src/frontend/python/chiquito/cb.py @@ -1,12 +1,12 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field from enum import Enum, auto -from typing import List +from typing import List, Dict, Optional -from chiquito.util import F +from chiquito.util import F, uuid from chiquito.expr import Expr, Const, Neg, to_expr, ToExpr from chiquito.query import StepTypeNext -from chiquito.chiquito_ast import ASTStepType +from chiquito.chiquito_ast import ASTStepType, Lookup class Typing(Enum): @@ -197,7 +197,12 @@ def rlc(exprs: List[ToExpr], randomness: Expr) -> Expr: return Expr(Const(F(0))) -# TODO: Implement lookup table after the lookup abstraction PR is merged. +def lookup() -> InPlaceLookupBuilder: + return InPlaceLookupBuilder() + + +def table() -> LookupTable: + return LookupTable() ToConstraint = Constraint | Expr | int | F @@ -222,3 +227,87 @@ def to_constraint(v: ToConstraint) -> Constraint: raise TypeError( f"Type `{type(v)}` is not ToConstraint (one of Constraint, Expr, int, or F)." ) + + +@dataclass +class LookupTable: + uuid: int = 0 + dest: List[Expr] = field(default_factory=list) + read_only: bool = False + + def __init__(self: LookupTable): + self.uuid: int = uuid() + self.dest = [] + self.read_only = False + + def add(self: LookupTable, expr: ToExpr) -> LookupTable: + assert self.read_only == False + self.dest.append(to_expr(expr)) + return self + + def apply(self: LookupTable, constraint: ToConstraint) -> LookupTableBuilder: + assert self.read_only == True + return LookupTableBuilder(self.uuid).apply(constraint) + + def when(self: LookupTable, enable: ToConstraint) -> LookupTableBuilder: + assert self.read_only == True + return LookupTableBuilder(self.uuid).when(enable) + + +@dataclass +class LookupTableBuilder: + uuid: int + src: List[Constraint] = field(default_factory=list) + enable: Optional[Constraint] = None + + def __init__(self: LookupTableBuilder, uuid: int): + self.uuid: int = uuid + + def apply(self: LookupTableBuilder, constraint: ToConstraint) -> LookupTableBuilder: + self.src.append(to_constraint(constraint)) + return self + + def when(self: LookupTableBuilder, enable: ToConstraint) -> LookupTableBuilder: + if self.enable is not None: + raise ValueError("LookupTableBuilder: when() can only be called once.") + self.enable = to_constraint(enable) + return self + + def build(self: LookupTableBuilder, super_circuit: StepType) -> Lookup: + table = step_type.circuit.tables.get(self.id) + if self.src.len() != table.dest.len(): + raise ValueError( + "LookupTableBuilder: build() has different number of source columns and destination columns." + ) + + lookup = Lookup() + + if self.enable is not None: + lookup.enable(self.enable.annotation, self.enable.expr) + + for i in range(self.src.len()): + lookup.add(self.src[i].annotation, self.src[i].expr, table.dest[i]) + + return lookup + + +@dataclass +class InPlaceLookupBuilder: + lookup: Lookup = field(default_factory=Lookup) + + def build(self: InPlaceLookupBuilder, _: StepType) -> Lookup: + return self.lookup + + def add( + self: InPlaceLookupBuilder, constraint: ToConstraint, expression: ToExpr + ) -> InPlaceLookupBuilder: + constraint = to_constraint(constraint) + self.lookup.add(constraint.annotation, constraint.expr, to_expr(expression)) + return self + + def enable( + self: InPlaceLookupBuilder, enable: ToConstraint + ) -> InPlaceLookupBuilder: + enable = to_constraint(enable) + self.lookup.enable(enable.annotation, enable.expr) + return self diff --git a/src/frontend/python/chiquito/chiquito_ast.py b/src/frontend/python/chiquito/chiquito_ast.py index 447720b5..f4818b80 100644 --- a/src/frontend/python/chiquito/chiquito_ast.py +++ b/src/frontend/python/chiquito/chiquito_ast.py @@ -1,13 +1,10 @@ from __future__ import annotations -from typing import Callable, List, Dict, Optional, Any, Tuple +from typing import List, Dict, Optional, Tuple from dataclasses import dataclass, field, asdict -# from chiquito import wit_gen, expr, query, util - -from chiquito.wit_gen import FixedGenContext, StepInstance from chiquito.expr import Expr -from chiquito.util import uuid -from chiquito.query import Queriable +from chiquito.util import uuid, F +from chiquito.query import Queriable, Fixed # pub struct Circuit { @@ -30,6 +27,8 @@ # pub num_steps: usize, # } +FixedAssignment = Dict[Queriable, List[F]] + @dataclass class ASTCircuit: @@ -39,7 +38,7 @@ class ASTCircuit: fixed_signals: List[FixedSignal] = field(default_factory=list) exposed: List[Tuple[Queriable, ExposeOffset]] = field(default_factory=list) annotations: Dict[int, str] = field(default_factory=dict) - fixed_gen: Optional[Callable] = None + fixed_assignments: Optional[FixedAssignment] = None first_step: Optional[int] = None last_step: Optional[int] = None num_steps: int = 0 @@ -92,7 +91,7 @@ def __str__(self: ASTCircuit): f"\tfixed_signals=[{fixed_signals_str}],\n" f"\texposed=[{exposed_str}],\n" f"\tannotations={{{annotations_str}}},\n" - f"\tfixed_gen={self.fixed_gen},\n" + f"\tfixed_assignments={self.fixed_assignments},\n" f"\tfirst_step={self.first_step},\n" f"\tlast_step={self.last_step},\n" f"\tnum_steps={self.num_steps}\n" @@ -111,6 +110,11 @@ def __json__(self: ASTCircuit): for (queriable, offset) in self.exposed ], "annotations": self.annotations, + "fixed_assignments": None + if self.fixed_assignments is None + else { + lhs.uuid(): [lhs, rhs] for (lhs, rhs) in self.fixed_assignments.items() + }, "first_step": self.first_step, "last_step": self.last_step, "num_steps": self.num_steps, @@ -143,11 +147,14 @@ def add_step_type(self: ASTCircuit, step_type: ASTStepType, name: str): self.annotations[step_type.id] = name self.step_types[step_type.id] = step_type - def set_fixed_gen(self, fixed_gen_def: Callable[[FixedGenContext], None]): - if self.fixed_gen is not None: - raise Exception("ASTCircuit cannot have more than one fixed generator.") + def add_fixed_assignment(self: ASTCircuit, offset: int, lhs: Queriable, rhs: F): + if not isinstance(lhs, Fixed): + raise ValueError(f"Cannot assign to non-fixed signal.") + if lhs in self.fixed_assignments.keys(): + self.fixed_assignments[lhs][offset] = rhs else: - self.fixed_gen = fixed_gen_def + self.fixed_assignments[lhs] = [F.zero()] * self.num_steps + self.fixed_assignments[lhs][offset] = rhs def get_step_type(self, uuid: int) -> ASTStepType: if uuid in self.step_types.keys(): @@ -175,10 +182,11 @@ class ASTStepType: signals: List[InternalSignal] constraints: List[ASTConstraint] transition_constraints: List[TransitionConstraint] + lookups: List[Lookup] annotations: Dict[int, str] def new(name: str) -> ASTStepType: - return ASTStepType(uuid(), name, [], [], [], {}) + return ASTStepType(uuid(), name, [], [], [], [], {}) def __str__(self): signals_str = ( @@ -202,6 +210,13 @@ def __str__(self): if self.transition_constraints else "" ) + lookups_str = ( + "\n\t\t\t\t" + + ",\n\t\t\t\t".join(str(lookup) for lookup in self.lookups) + + "\n\t\t\t" + if self.lookups + else "" + ) annotations_str = ( "\n\t\t\t\t" + ",\n\t\t\t\t".join(f"{k}: {v}" for k, v in self.annotations.items()) @@ -217,6 +232,7 @@ def __str__(self): f"\t\t\tsignals=[{signals_str}],\n" f"\t\t\tconstraints=[{constraints_str}],\n" f"\t\t\ttransition_constraints=[{transition_constraints_str}],\n" + f"\t\t\tlookups=[{lookups_str}],\n" f"\t\t\tannotations={{{annotations_str}}}\n" f"\t\t)" ) @@ -230,6 +246,7 @@ def __json__(self): "transition_constraints": [ x.__json__() for x in self.transition_constraints ], + "lookups": [x.__json__() for x in self.lookups], "annotations": self.annotations, } @@ -385,3 +402,50 @@ def __str__(self: InternalSignal): def __json__(self: InternalSignal): return asdict(self) + + +@dataclass +class Lookup: + annotation: str = "" + exprs: List[Tuple[ASTConstraint, Expr]] = field(default_factory=list) + enable: Optional[ASTConstraint] = None + + def add( + self: Lookup, + constraint_annotation: str, + constraint_expr: Expr, + expression: Expr, + ): + constraint = ASTConstraint(constraint_annotation, constraint_expr) + self.annotation += f"match({constraint.annotation} => {str(expression)}) " + if self.enable is None: + self.exprs.append((constraint, expression)) + else: + self.exprs.append( + (self.multiply_constraints(self.enable, constraint), expression) + ) + + def enable(self: Lookup, enable_annotation: str, enable_expr: Expr): + enable = ASTConstraint(enable_annotation, enable_expr) + if self.enable is None: + for constraint, _ in self.exprs: + constraint = self.multiply_constraints(enable, constraint) + self.enable = enable + self.annotation = f"if {enable_annotation}, {self.annotation}" + else: + raise ValueError("Lookup: enable() can only be called once.") + + def multiply_constraints( + enable: ASTConstraint, constraint: ASTConstraint + ) -> ASTConstraint: + return ASTConstraint(constraint.annotation, enable.expr * constraint.expr) + + def __str__(self: Lookup): + return f"Lookup({self.annotation})" + + def __json__(self: Lookup): + return { + "annotation": self.annotation, + "exprs": [[x.__json__(), y.__json__()] for (x, y) in self.exprs], + "enable": self.enable.__json__() if self.enable is not None else None, + } diff --git a/src/frontend/python/chiquito/dsl.py b/src/frontend/python/chiquito/dsl.py index 3f00cd85..9232a615 100644 --- a/src/frontend/python/chiquito/dsl.py +++ b/src/frontend/python/chiquito/dsl.py @@ -1,16 +1,21 @@ from __future__ import annotations from enum import Enum from typing import Callable, Any - -# import rust_chiquito # rust bindings -from chiquito import rust_chiquito +from chiquito import rust_chiquito # rust bindings import json -from chiquito import chiquito_ast, wit_gen from chiquito.chiquito_ast import ASTCircuit, ASTStepType, ExposeOffset from chiquito.query import Internal, Forward, Queriable, Shared, Fixed -from chiquito.wit_gen import FixedGenContext, StepInstance, TraceWitness -from chiquito.cb import Constraint, Typing, ToConstraint, to_constraint +from chiquito.wit_gen import StepInstance, TraceWitness +from chiquito.cb import ( + Constraint, + Typing, + ToConstraint, + to_constraint, + LookupTable, + LookupTableBuilder, + InPlaceLookupBuilder, +) from chiquito.util import CustomEncoder, F @@ -18,15 +23,26 @@ class CircuitMode(Enum): NoMode = 0 SETUP = 1 Trace = 2 + FixedGen = 3 class Circuit: def __init__(self: Circuit): self.ast = ASTCircuit() + self.tables: Dict[int, LookupTable] = {} self.witness = TraceWitness() self.rust_ast_id = 0 self.mode = CircuitMode.SETUP self.setup() + if hasattr(self, "fixed_gen") and callable(self.fixed_gen): + self.mode = CircuitMode.FixedGen + if self.ast.num_steps == 0: + raise ValueError( + "Must set num_steps by calling pragma_num_steps() in setup before calling fixed_gen()." + ) + self.ast.fixed_assignments = {} + self.fixed_gen() + self.mode = CircuitMode.NoMode def forward(self: Circuit, name: str) -> Forward: assert self.mode == CircuitMode.SETUP @@ -64,9 +80,6 @@ def step_type_def(self: StepType) -> StepType: assert self.mode == CircuitMode.SETUP self.ast.add_step_type_def() - def fixed_gen(self: Circuit, fixed_gen_def: Callable[[FixedGenContext], None]): - self.ast.set_fixed_gen(fixed_gen_def) - def pragma_first_step(self: Circuit, step_type: StepType) -> None: assert self.mode == CircuitMode.SETUP self.ast.first_step = step_type.step_type.id @@ -83,7 +96,14 @@ def pragma_disable_q_enable(self: Circuit) -> None: assert self.mode == CircuitMode.SETUP self.ast.q_enable = False - def add(self: Circuit, step_type: StepType, *args): + def new_table(self: Circuit, table: LookupTable) -> LookupTable: + assert self.mode == CircuitMode.SETUP + table.read_only = True + self.tables[table.uuid] = table + return table + + # called under trace() + def add(self: Circuit, step_type: StepType, args: Any): assert self.mode == CircuitMode.Trace if len(self.witness.step_instances) >= self.ast.num_steps: raise ValueError(f"Number of step instances exceeds {self.ast.num_steps}") @@ -99,6 +119,15 @@ def padding(self: Circuit, step_type: StepType, *args): while self.needs_padding(): self.add(step_type, *args) + # called under fixed_gen() + def assign(self: Circuit, offset: int, lhs: Queriable, rhs: F): + assert self.mode == CircuitMode.FixedGen + if self.ast.fixed_assignments is None: + raise ValueError( + "FixedAssignment: must have initiated fixed_assignments before calling assign()" + ) + self.ast.add_fixed_assignment(offset, lhs, rhs) + def gen_witness(self: Circuit, *args) -> TraceWitness: self.mode = CircuitMode.Trace self.witness = TraceWitness() @@ -174,4 +203,8 @@ def assign(self: StepType, lhs: Queriable, rhs: F): self.step_instance.assign(lhs, rhs) - # TODO: Implement add_lookup after lookup abstraction PR is merged. + def add_lookup(self: StepType, lookup_builder: LookupBuilder): + self.step_type.lookups.append(lookup_builder.build(self)) + + +LookupBuilder = LookupTableBuilder | InPlaceLookupBuilder diff --git a/src/frontend/python/chiquito/wit_gen.py b/src/frontend/python/chiquito/wit_gen.py index b7408756..cec59905 100644 --- a/src/frontend/python/chiquito/wit_gen.py +++ b/src/frontend/python/chiquito/wit_gen.py @@ -1,6 +1,6 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Dict, List, Callable, Any +from typing import Dict, List import json from chiquito.query import Queriable, Fixed @@ -94,29 +94,4 @@ def evil_witness_test( return TraceWitness(new_step_instances) -FixedAssigment = Dict[Queriable, List[F]] - - -@dataclass -class FixedGenContext: - assignments: FixedAssigment = field(default_factory=dict) - num_steps: int = 0 - - def new(num_steps: int) -> FixedGenContext: - return FixedGenContext({}, num_steps) - - def assign(self: FixedGenContext, offset: int, lhs: Queriable, rhs: F): - if not FixedGenContext.is_fixed_queriable(lhs): - raise ValueError(f"Cannot assign to non-fixed signal.") - if lhs in self.assignments.keys(): - self.assignments[lhs][offset] = rhs - else: - self.assignments[lhs] = [F.zero()] * self.num_steps - self.assignments[lhs][offset] = rhs - - def is_fixed_queriable(q: Queriable) -> bool: - match q.enum: - case Fixed(_, _): - return True - case _: - return False +FixedAssignment = Dict[Queriable, List[F]] diff --git a/src/frontend/python/mod.rs b/src/frontend/python/mod.rs index 722ae5bc..55c9510c 100644 --- a/src/frontend/python/mod.rs +++ b/src/frontend/python/mod.rs @@ -6,7 +6,7 @@ use pyo3::{ use crate::{ ast::{ query::Queriable, Circuit, Constraint, ExposeOffset, FixedSignal, ForwardSignal, - InternalSignal, SharedSignal, StepType, StepTypeUUID, TransitionConstraint, + InternalSignal, Lookup, SharedSignal, StepType, StepTypeUUID, TransitionConstraint, }, frontend::dsl::StepTypeHandler, plonkish::{ @@ -102,6 +102,7 @@ impl<'de> Visitor<'de> for CircuitVisitor { let mut fixed_signals = None; let mut exposed = None; let mut annotations = None; + let mut fixed_assignments = None; let mut first_step = None; let mut last_step = None; let mut num_steps = None; @@ -146,6 +147,13 @@ impl<'de> Visitor<'de> for CircuitVisitor { } annotations = Some(map.next_value::>()?); } + "fixed_assignments" => { + if fixed_assignments.is_some() { + return Err(de::Error::duplicate_field("fixed_assignments")); + } + fixed_assignments = + Some(map.next_value::, Vec)>>>()?); + } "first_step" => { if first_step.is_some() { return Err(de::Error::duplicate_field("first_step")); @@ -186,6 +194,7 @@ impl<'de> Visitor<'de> for CircuitVisitor { "fixed_signals", "exposed", "annotations", + "fixed_assignments", "first_step", "last_step", "num_steps", @@ -209,6 +218,9 @@ impl<'de> Visitor<'de> for CircuitVisitor { fixed_signals.ok_or_else(|| de::Error::missing_field("fixed_signals"))?; let exposed = exposed.ok_or_else(|| de::Error::missing_field("exposed"))?; let annotations = annotations.ok_or_else(|| de::Error::missing_field("annotations"))?; + let fixed_assignments = fixed_assignments + .ok_or_else(|| de::Error::missing_field("fixed_assignments"))? + .map(|inner| inner.into_values().collect()); let first_step = first_step.ok_or_else(|| de::Error::missing_field("first_step"))?; let last_step = last_step.ok_or_else(|| de::Error::missing_field("last_step"))?; let num_steps = num_steps.ok_or_else(|| de::Error::missing_field("num_steps"))?; @@ -226,7 +238,7 @@ impl<'de> Visitor<'de> for CircuitVisitor { num_steps, annotations, trace: Some(Rc::new(|_: &mut TraceContext<_>, _: _| {})), - fixed_gen: None, + fixed_assignments, first_step, last_step, q_enable, @@ -252,6 +264,7 @@ impl<'de> Visitor<'de> for StepTypeVisitor { let mut signals = None; let mut constraints = None; let mut transition_constraints = None; + let mut lookups = None; let mut annotations = None; while let Some(key) = map.next_key::()? { @@ -287,6 +300,12 @@ impl<'de> Visitor<'de> for StepTypeVisitor { transition_constraints = Some(map.next_value::>>()?); } + "lookups" => { + if lookups.is_some() { + return Err(de::Error::duplicate_field("lookups")); + } + lookups = Some(map.next_value::>>()?); + } "annotations" => { if annotations.is_some() { return Err(de::Error::duplicate_field("annotations")); @@ -302,6 +321,7 @@ impl<'de> Visitor<'de> for StepTypeVisitor { "signals", "constraints", "transition_constraints", + "lookups", "annotations", ], )) @@ -377,6 +397,62 @@ impl_visitor_constraint_transition!( "struct TransitionConstraint" ); +struct LookupVisitor; + +impl<'de> Visitor<'de> for LookupVisitor { + type Value = Lookup; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("struct Lookup") + } + + fn visit_map(self, mut map: A) -> Result, A::Error> + where + A: MapAccess<'de>, + { + let mut annotation = None; + let mut exprs = None; + let mut enable = None; + while let Some(key) = map.next_key::()? { + match key.as_str() { + "annotation" => { + if annotation.is_some() { + return Err(de::Error::duplicate_field("annotation")); + } + annotation = Some(map.next_value::()?); + } + "exprs" => { + if exprs.is_some() { + return Err(de::Error::duplicate_field("exprs")); + } + exprs = + Some(map.next_value::, Expr>)>>()?); + } + "enable" => { + if enable.is_some() { + return Err(de::Error::duplicate_field("enable")); + } + enable = Some(map.next_value::>>()?); + } + _ => { + return Err(de::Error::unknown_field( + &key, + &["annotation", "exprs", "enable"], + )) + } + } + } + let annotation = annotation.ok_or_else(|| de::Error::missing_field("annotation"))?; + let exprs = exprs.ok_or_else(|| de::Error::missing_field("exprs"))?; + let enable = enable.ok_or_else(|| de::Error::missing_field("enable"))?; + Ok(Self::Value { + annotation, + exprs, + enable, + }) + } +} + struct ExprVisitor; impl<'de> Visitor<'de> for ExprVisitor { @@ -725,6 +801,7 @@ impl_deserialize!(TransitionConstraintVisitor, TransitionConstraint); impl_deserialize!(StepTypeVisitor, StepType); impl_deserialize!(TraceWitnessVisitor, TraceWitness); impl_deserialize!(StepInstanceVisitor, StepInstance); +impl_deserialize!(LookupVisitor, Lookup); impl<'de> Deserialize<'de> for Circuit { fn deserialize(deserializer: D) -> Result, D::Error> @@ -887,16 +964,70 @@ mod tests { let json = r#" { "step_types": { - "205524326356431126935662643926474033674": { - "id": 205524326356431126935662643926474033674, - "name": "fibo_step", + "258869595755756204079859764249309612554": { + "id": 258869595755756204079859764249309612554, + "name": "fibo_first_step", "signals": [ { - "id": 205524332694684128074575021569884162570, + "id": 258869599717164329791616633222308956682, "annotation": "c" } ], "constraints": [ + { + "annotation": "(a == 1)", + "expr": { + "Sum": [ + { + "Forward": [ + { + "id": 258869580702405326369584955980151130634, + "phase": 0, + "annotation": "a" + }, + false + ] + }, + { + "Neg": { + "Const": [ + 1, + 0, + 0, + 0 + ] + } + } + ] + } + }, + { + "annotation": "(b == 1)", + "expr": { + "Sum": [ + { + "Forward": [ + { + "id": 258869587040658327507391136965088381450, + "phase": 0, + "annotation": "b" + }, + false + ] + }, + { + "Neg": { + "Const": [ + 1, + 0, + 0, + 0 + ] + } + } + ] + } + }, { "annotation": "((a + b) == c)", "expr": { @@ -904,7 +1035,7 @@ mod tests { { "Forward": [ { - "id": 205524314472206749795829327634996267530, + "id": 258869580702405326369584955980151130634, "phase": 0, "annotation": "a" }, @@ -914,7 +1045,7 @@ mod tests { { "Forward": [ { - "id": 205524322395023001221676493137926294026, + "id": 258869587040658327507391136965088381450, "phase": 0, "annotation": "b" }, @@ -924,7 +1055,7 @@ mod tests { { "Neg": { "Internal": { - "id": 205524332694684128074575021569884162570, + "id": 258869599717164329791616633222308956682, "annotation": "c" } } @@ -941,7 +1072,7 @@ mod tests { { "Forward": [ { - "id": 205524322395023001221676493137926294026, + "id": 258869587040658327507391136965088381450, "phase": 0, "annotation": "b" }, @@ -952,7 +1083,7 @@ mod tests { "Neg": { "Forward": [ { - "id": 205524314472206749795829327634996267530, + "id": 258869580702405326369584955980151130634, "phase": 0, "annotation": "a" }, @@ -969,7 +1100,7 @@ mod tests { "Sum": [ { "Internal": { - "id": 205524332694684128074575021569884162570, + "id": 258869599717164329791616633222308956682, "annotation": "c" } }, @@ -977,7 +1108,7 @@ mod tests { "Neg": { "Forward": [ { - "id": 205524322395023001221676493137926294026, + "id": 258869587040658327507391136965088381450, "phase": 0, "annotation": "b" }, @@ -987,18 +1118,48 @@ mod tests { } ] } + }, + { + "annotation": "(n == next(n))", + "expr": { + "Sum": [ + { + "Forward": [ + { + "id": 258869589417503202934383108674030275082, + "phase": 0, + "annotation": "n" + }, + false + ] + }, + { + "Neg": { + "Forward": [ + { + "id": 258869589417503202934383108674030275082, + "phase": 0, + "annotation": "n" + }, + true + ] + } + } + ] + } } ], + "lookups": [], "annotations": { - "205524332694684128074575021569884162570": "c" + "258869599717164329791616633222308956682": "c" } }, - "205524373893328635494146417612672338442": { - "id": 205524373893328635494146417612672338442, - "name": "fibo_last_step", + "258869628239302834927102989021255174666": { + "id": 258869628239302834927102989021255174666, + "name": "fibo_step", "signals": [ { - "id": 205524377062455136063336753318874188298, + "id": 258869632200710960639812650790420089354, "annotation": "c" } ], @@ -1010,7 +1171,7 @@ mod tests { { "Forward": [ { - "id": 205524314472206749795829327634996267530, + "id": 258869580702405326369584955980151130634, "phase": 0, "annotation": "a" }, @@ -1020,7 +1181,7 @@ mod tests { { "Forward": [ { - "id": 205524322395023001221676493137926294026, + "id": 258869587040658327507391136965088381450, "phase": 0, "annotation": "b" }, @@ -1030,7 +1191,7 @@ mod tests { { "Neg": { "Internal": { - "id": 205524377062455136063336753318874188298, + "id": 258869632200710960639812650790420089354, "annotation": "c" } } @@ -1039,22 +1200,180 @@ mod tests { } } ], - "transition_constraints": [], + "transition_constraints": [ + { + "annotation": "(b == next(a))", + "expr": { + "Sum": [ + { + "Forward": [ + { + "id": 258869587040658327507391136965088381450, + "phase": 0, + "annotation": "b" + }, + false + ] + }, + { + "Neg": { + "Forward": [ + { + "id": 258869580702405326369584955980151130634, + "phase": 0, + "annotation": "a" + }, + true + ] + } + } + ] + } + }, + { + "annotation": "(c == next(b))", + "expr": { + "Sum": [ + { + "Internal": { + "id": 258869632200710960639812650790420089354, + "annotation": "c" + } + }, + { + "Neg": { + "Forward": [ + { + "id": 258869587040658327507391136965088381450, + "phase": 0, + "annotation": "b" + }, + true + ] + } + } + ] + } + }, + { + "annotation": "(n == next(n))", + "expr": { + "Sum": [ + { + "Forward": [ + { + "id": 258869589417503202934383108674030275082, + "phase": 0, + "annotation": "n" + }, + false + ] + }, + { + "Neg": { + "Forward": [ + { + "id": 258869589417503202934383108674030275082, + "phase": 0, + "annotation": "n" + }, + true + ] + } + } + ] + } + } + ], + "lookups": [], "annotations": { - "205524377062455136063336753318874188298": "c" + "258869632200710960639812650790420089354": "c" } + }, + "258869646461780213207493341245063432714": { + "id": 258869646461780213207493341245063432714, + "name": "padding", + "signals": [], + "constraints": [], + "transition_constraints": [ + { + "annotation": "(b == next(b))", + "expr": { + "Sum": [ + { + "Forward": [ + { + "id": 258869587040658327507391136965088381450, + "phase": 0, + "annotation": "b" + }, + false + ] + }, + { + "Neg": { + "Forward": [ + { + "id": 258869587040658327507391136965088381450, + "phase": 0, + "annotation": "b" + }, + true + ] + } + } + ] + } + }, + { + "annotation": "(n == next(n))", + "expr": { + "Sum": [ + { + "Forward": [ + { + "id": 258869589417503202934383108674030275082, + "phase": 0, + "annotation": "n" + }, + false + ] + }, + { + "Neg": { + "Forward": [ + { + "id": 258869589417503202934383108674030275082, + "phase": 0, + "annotation": "n" + }, + true + ] + } + } + ] + } + } + ], + "lookups": [], + "annotations": {} } }, "forward_signals": [ { - "id": 205524314472206749795829327634996267530, + "id": 258869580702405326369584955980151130634, "phase": 0, "annotation": "a" }, { - "id": 205524322395023001221676493137926294026, + "id": 258869587040658327507391136965088381450, "phase": 0, "annotation": "b" + }, + { + "id": 258869589417503202934383108674030275082, + "phase": 0, + "annotation": "n" } ], "shared_signals": [], @@ -1064,28 +1383,13 @@ mod tests { { "Forward": [ { - "id": 205524322395023001221676493137926294026, + "id": 258869587040658327507391136965088381450, "phase": 0, "annotation": "b" }, false ] }, - { - "First": 0 - } - ], - [ - { - "Forward": [ - { - "id": 205524314472206749795829327634996267530, - "phase": 0, - "annotation": "a" - }, - false - ] - }, { "Last": -1 } @@ -1094,29 +1398,32 @@ mod tests { { "Forward": [ { - "id": 205524314472206749795829327634996267530, + "id": 258869589417503202934383108674030275082, "phase": 0, - "annotation": "a" + "annotation": "n" }, false ] }, { - "Step": 1 + "Last": -1 } ] ], "annotations": { - "205524314472206749795829327634996267530": "a", - "205524322395023001221676493137926294026": "b", - "205524326356431126935662643926474033674": "fibo_step", - "205524373893328635494146417612672338442": "fibo_last_step" + "258869580702405326369584955980151130634": "a", + "258869587040658327507391136965088381450": "b", + "258869589417503202934383108674030275082": "n", + "258869595755756204079859764249309612554": "fibo_first_step", + "258869628239302834927102989021255174666": "fibo_step", + "258869646461780213207493341245063432714": "padding" }, - "first_step": 205524326356431126935662643926474033674, - "last_step": 205524373893328635494146417612672338442, - "num_steps": 0, - "q_enable": false, - "id": 205522563529815184552233780032226069002 + "fixed_assignments": null, + "first_step": 258869595755756204079859764249309612554, + "last_step": 258869646461780213207493341245063432714, + "num_steps": 10, + "q_enable": true, + "id": 258867373405797678961444396351437277706 } "#; let circuit: Circuit = serde_json::from_str(json).unwrap(); diff --git a/src/plonkish/compiler/mod.rs b/src/plonkish/compiler/mod.rs index 2efacbf1..5647772f 100644 --- a/src/plonkish/compiler/mod.rs +++ b/src/plonkish/compiler/mod.rs @@ -6,7 +6,7 @@ use crate::{ Circuit, Column, Poly, PolyExpr, PolyLookup, }, poly::Expr, - wit_gen::{FixedAssignment, FixedGenContext, TraceGenerator}, + wit_gen::{FixedAssignment, TraceGenerator}, }; use std::{hash::Hash, rc::Rc}; @@ -228,13 +228,8 @@ fn compile_fixed( ast: &astCircuit, unit: &mut CompilationUnit, ) { - if let Some(fixed_gen) = &ast.fixed_gen { - let mut ctx = FixedGenContext::new(unit.num_steps); - (*fixed_gen)(&mut ctx); - - let assignments = ctx.get_assignments(); - - unit.fixed_assignments = place_fixed_assignments(unit, assignments); + if let Some(fixed_assignments) = &ast.fixed_assignments { + unit.fixed_assignments = place_fixed_assignments(unit, fixed_assignments.clone()); } }