From 905adc4595cdd6482d1831d315ad051e4b9eaa17 Mon Sep 17 00:00:00 2001 From: Steve Wang Date: Tue, 1 Aug 2023 08:21:49 +0800 Subject: [PATCH] Steve/refactor circuit (#8) * refactored everything * refactored everything and example * removed unused imports * renamed step_type_context vairables * fixed type and field namings * updated generate_witness method * addressed leo comments and modified wg function * updated tracecontext and cleaned up debug methods * cleaned up other files * removed tests --- pychiquito/cb.py | 12 +-- pychiquito/chiquito_ast.py | 95 +++++++------------- pychiquito/dsl.py | 180 +++++++++++++++++++++---------------- pychiquito/expr.py | 9 -- pychiquito/fibonacci.py | 134 +++++++-------------------- pychiquito/query.py | 17 ++-- pychiquito/tests.py | 50 ----------- pychiquito/util.py | 8 ++ pychiquito/wit_gen.py | 47 ++-------- src/lib.rs | 1 - 10 files changed, 195 insertions(+), 358 deletions(-) delete mode 100644 pychiquito/tests.py diff --git a/pychiquito/cb.py b/pychiquito/cb.py index 4ceab24..cd3c808 100644 --- a/pychiquito/cb.py +++ b/pychiquito/cb.py @@ -6,11 +6,7 @@ from util import F from expr import Expr, Const, Neg, to_expr, ToExpr from query import StepTypeNext -from chiquito_ast import StepType - -########## -# dsl/cb # -########## +from chiquito_ast import ASTStepType class Typing(Enum): @@ -165,7 +161,7 @@ def isz(constraint: ToConstraint) -> Constraint: ) -def if_next_step(step_type: StepType, constraint: ToConstraint) -> Constraint: +def if_next_step(step_type: ASTStepType, constraint: ToConstraint) -> Constraint: constraint = to_constraint(constraint) return Constraint( f"if(next step is {step_type.annotation})then({constraint.annotation})", @@ -174,7 +170,7 @@ def if_next_step(step_type: StepType, constraint: ToConstraint) -> Constraint: ) -def next_step_must_be(step_type: StepType) -> Constraint: +def next_step_must_be(step_type: ASTStepType) -> Constraint: return Constraint( f"next step must be {step_type.annotation}", Constraint.cb_not(StepTypeNext(step_type)), @@ -182,7 +178,7 @@ def next_step_must_be(step_type: StepType) -> Constraint: ) -def next_step_must_not_be(step_type: StepType) -> Constraint: +def next_step_must_not_be(step_type: ASTStepType) -> Constraint: return Constraint( f"next step must not be {step_type.annotation}", StepTypeNext(step_type), diff --git a/pychiquito/chiquito_ast.py b/pychiquito/chiquito_ast.py index 5dd746e..0655276 100644 --- a/pychiquito/chiquito_ast.py +++ b/pychiquito/chiquito_ast.py @@ -2,17 +2,14 @@ from typing import Callable, List, Dict, Optional, Any, Tuple from dataclasses import dataclass, field, asdict -from wit_gen import TraceContext, FixedGenContext, StepInstance +from wit_gen import FixedGenContext, StepInstance from expr import Expr from util import uuid from query import Queriable -####### -# ast # -####### # pub struct Circuit { -# pub step_types: HashMap>>, +# pub step_types: HashMap>>, # pub forward_signals: Vec, # pub shared_signals: Vec, @@ -26,21 +23,20 @@ # pub trace: Option>>, # pub fixed_gen: Option>>, -# pub first_step: Option, -# pub last_step: Option, +# pub first_step: Option, +# pub last_step: Option, # pub num_steps: usize, # } @dataclass -class Circuit: - step_types: Dict[int, StepType] = field(default_factory=dict) +class ASTCircuit: + step_types: Dict[int, ASTStepType] = field(default_factory=dict) forward_signals: List[ForwardSignal] = field(default_factory=list) shared_signals: List[SharedSignal] = field(default_factory=list) fixed_signals: List[FixedSignal] = field(default_factory=list) exposed: List[Tuple[Queriable, ExposeOffset]] = field(default_factory=list) annotations: Dict[int, str] = field(default_factory=dict) - trace: Optional[Callable[[TraceContext, Any], None]] = None fixed_gen: Optional[Callable] = None first_step: Optional[int] = None last_step: Optional[int] = None @@ -48,7 +44,7 @@ class Circuit: q_enable: bool = True id: int = uuid() - def __str__(self: Circuit): + def __str__(self: ASTCircuit): step_types_str = ( "\n\t\t" + ",\n\t\t".join(f"{k}: {v}" for k, v in self.step_types.items()) @@ -57,20 +53,17 @@ def __str__(self: Circuit): else "" ) forward_signals_str = ( - "\n\t\t" + ",\n\t\t".join(str(fs) - for fs in self.forward_signals) + "\n\t" + "\n\t\t" + ",\n\t\t".join(str(fs) for fs in self.forward_signals) + "\n\t" if self.forward_signals else "" ) shared_signals_str = ( - "\n\t\t" + ",\n\t\t".join(str(ss) - for ss in self.shared_signals) + "\n\t" + "\n\t\t" + ",\n\t\t".join(str(ss) for ss in self.shared_signals) + "\n\t" if self.shared_signals else "" ) fixed_signals_str = ( - "\n\t\t" + ",\n\t\t".join(str(fs) - for fs in self.fixed_signals) + "\n\t" + "\n\t\t" + ",\n\t\t".join(str(fs) for fs in self.fixed_signals) + "\n\t" if self.fixed_signals else "" ) @@ -90,14 +83,13 @@ def __str__(self: Circuit): ) return ( - f"Circuit(\n" + f"ASTCircuit(\n" f"\tstep_types={{{step_types_str}}},\n" f"\tforward_signals=[{forward_signals_str}],\n" f"\tshared_signals=[{shared_signals_str}],\n" f"\tfixed_signals=[{fixed_signals_str}],\n" f"\texposed=[{exposed_str}],\n" f"\tannotations={{{annotations_str}}},\n" - f"\ttrace={self.trace},\n" f"\tfixed_gen={self.fixed_gen},\n" f"\tfirst_step={self.first_step},\n" f"\tlast_step={self.last_step},\n" @@ -106,7 +98,7 @@ def __str__(self: Circuit): f")" ) - def __json__(self: Circuit): + def __json__(self: ASTCircuit): return { "step_types": {k: v.__json__() for k, v in self.step_types.items()}, "forward_signals": [x.__json__() for x in self.forward_signals], @@ -124,52 +116,42 @@ def __json__(self: Circuit): "id": self.id, } - def add_forward(self: Circuit, name: str, phase: int) -> ForwardSignal: + def add_forward(self: ASTCircuit, name: str, phase: int) -> ForwardSignal: signal = ForwardSignal(phase, name) self.forward_signals.append(signal) self.annotations[signal.id] = name return signal - def add_shared(self: Circuit, name: str, phase: int) -> SharedSignal: + def add_shared(self: ASTCircuit, name: str, phase: int) -> SharedSignal: signal = SharedSignal(phase, name) self.shared_signals.append(signal) self.annotations[signal.id] = name return signal - def add_fixed(self: Circuit, name: str) -> FixedSignal: + def add_fixed(self: ASTCircuit, name: str) -> FixedSignal: signal = FixedSignal(name) self.fixed_signals.append(signal) self.annotations[signal.id] = name return signal - def expose(self: Circuit, signal: Queriable, offset: ExposeOffset): + def expose(self: ASTCircuit, signal: Queriable, offset: ExposeOffset): self.exposed.append((signal, offset)) - def add_step_type(self: Circuit, step_type: StepType, name: str): + def add_step_type(self: ASTCircuit, step_type: ASTStepType, name: str): self.annotations[step_type.id] = name self.step_types[step_type.id] = step_type - def set_trace( - self: Circuit, trace_def: Callable[[TraceContext, Any], None] - ): # TraceArgs are Any. - if self.trace is not None: - raise Exception( - "Circuit cannot have more than one trace generator.") - else: - self.trace = trace_def - def set_fixed_gen(self, fixed_gen_def: Callable[[FixedGenContext], None]): if self.fixed_gen is not None: - raise Exception( - "Circuit cannot have more than one fixed generator.") + raise Exception("ASTCircuit cannot have more than one fixed generator.") else: self.fixed_gen = fixed_gen_def - def get_step_type(self, uuid: int) -> StepType: + def get_step_type(self, uuid: int) -> ASTStepType: if uuid in self.step_types.keys(): return self.step_types[uuid] else: - raise ValueError("StepType not found.") + raise ValueError("ASTStepType not found.") # pub struct StepType { @@ -185,19 +167,16 @@ def get_step_type(self, uuid: int) -> StepType: @dataclass -class StepType: +class ASTStepType: id: int name: str signals: List[InternalSignal] constraints: List[ASTConstraint] transition_constraints: List[TransitionConstraint] annotations: Dict[int, str] - wg: Optional[ - Callable[[StepInstance, Any], None] - ] # Args are Any. Not passed to Rust Chiquito. - def new(name: str) -> StepType: - return StepType(uuid(), name, [], [], [], {}, None) + def new(name: str) -> ASTStepType: + return ASTStepType(uuid(), name, [], [], [], {}) def __str__(self): signals_str = ( @@ -209,8 +188,7 @@ def __str__(self): ) constraints_str = ( "\n\t\t\t\t" - + ",\n\t\t\t\t".join(str(constraint) - for constraint in self.constraints) + + ",\n\t\t\t\t".join(str(constraint) for constraint in self.constraints) + "\n\t\t\t" if self.constraints else "" @@ -224,15 +202,14 @@ def __str__(self): ) annotations_str = ( "\n\t\t\t\t" - + ",\n\t\t\t\t".join(f"{k}: {v}" for k, - v in self.annotations.items()) + + ",\n\t\t\t\t".join(f"{k}: {v}" for k, v in self.annotations.items()) + "\n\t\t\t" if self.annotations else "" ) return ( - f"StepType(\n" + f"ASTStepType(\n" f"\t\t\tid={self.id},\n" f"\t\t\tname='{self.name}',\n" f"\t\t\tsignals=[{signals_str}],\n" @@ -254,33 +231,29 @@ def __json__(self): "annotations": self.annotations, } - def add_signal(self: StepType, name: str) -> InternalSignal: + def add_signal(self: ASTStepType, name: str) -> InternalSignal: signal = InternalSignal(name) self.signals.append(signal) self.annotations[signal.id] = name return signal - def add_constr(self: StepType, annotation: str, expr: Expr): + def add_constr(self: ASTStepType, annotation: str, expr: Expr): condition = ASTConstraint(annotation, expr) self.constraints.append(condition) - def add_transition(self: StepType, annotation: str, expr: Expr): + def add_transition(self: ASTStepType, annotation: str, expr: Expr): condition = TransitionConstraint(annotation, expr) self.transition_constraints.append(condition) - # Args are Any. - def set_wg(self, wg_def: Callable[[StepInstance, Any], None]): - self.wg = wg_def - - def __eq__(self: StepType, other: StepType) -> bool: - if isinstance(self, StepType) and isinstance(other, StepType): + def __eq__(self: ASTStepType, other: ASTStepType) -> bool: + if isinstance(self, ASTStepType) and isinstance(other, ASTStepType): return self.id == other.id return False - def __req__(other: StepType, self: StepType) -> bool: - return StepType.__eq__(self, other) + def __req__(other: ASTStepType, self: ASTStepType) -> bool: + return ASTStepType.__eq__(self, other) - def __hash__(self: StepType): + def __hash__(self: ASTStepType): return hash(self.id) diff --git a/pychiquito/dsl.py b/pychiquito/dsl.py index a941e4b..5ca991d 100644 --- a/pychiquito/dsl.py +++ b/pychiquito/dsl.py @@ -1,84 +1,111 @@ from __future__ import annotations from enum import Enum from typing import Callable, Any -from dataclasses import dataclass +import rust_chiquito # rust bindings +import json -from chiquito_ast import Circuit, StepType, ExposeOffset, ForwardSignal, SharedSignal +from chiquito_ast import ASTCircuit, ASTStepType, ExposeOffset from query import Internal, Forward, Queriable, Shared, Fixed -from wit_gen import FixedGenContext, TraceContext +from wit_gen import FixedGenContext, StepInstance, TraceWitness from cb import Constraint, Typing, ToConstraint, to_constraint +from util import CustomEncoder, F -####### -# dsl # -####### +class CircuitMode(Enum): + NoMode = 0 + SETUP = 1 + Trace = 2 -class CircuitContext: - def __init__(self): - self.circuit = Circuit() +class Circuit: + def __init__(self: Circuit): + self.ast = ASTCircuit() + self.witness = TraceWitness() + self.rust_ast_id = 0 + self.mode = CircuitMode.SETUP + self.setup() - def forward(self: CircuitContext, name: str) -> Forward: - return Forward(self.circuit.add_forward(name, 0), False) + def forward(self: Circuit, name: str) -> Forward: + assert self.mode == CircuitMode.SETUP + return Forward(self.ast.add_forward(name, 0), False) - def forward_with_phase(self: CircuitContext, name: str, phase: int) -> Forward: - return Forward(self.circuit.add_forward(name, phase), False) + def forward_with_phase(self: Circuit, name: str, phase: int) -> Forward: + assert self.mode == CircuitMode.SETUP + return Forward(self.ast.add_forward(name, phase), False) - def shared(self: CircuitContext, name: str) -> Shared: - return Shared(self.circuit.add_shared(name, 0), 0) + def shared(self: Circuit, name: str) -> Shared: + assert self.mode == CircuitMode.SETUP + return Shared(self.ast.add_shared(name, 0), 0) - def shared_with_phase(self: CircuitContext, name: str, phase: int) -> Shared: - return Shared(self.circuit.add_shared(name, phase), 0) + def shared_with_phase(self: Circuit, name: str, phase: int) -> Shared: + assert self.mode == CircuitMode.SETUP + return Shared(self.ast.add_shared(name, phase), 0) - def fixed(self: CircuitContext, name: str) -> Fixed: - return Fixed(self.circuit.add_fixed(name), 0) + def fixed(self: Circuit, name: str) -> Fixed: + assert self.mode == CircuitMode.SETUP + return Fixed(self.ast.add_fixed(name), 0) - def expose(self: CircuitContext, signal: Queriable, offset: ExposeOffset): + def expose(self: Circuit, signal: Queriable, offset: ExposeOffset): + assert self.mode == CircuitMode.SETUP if isinstance(signal, (Forward, Shared)): - self.circuit.expose(signal, offset) + self.ast.expose(signal, offset) else: raise TypeError(f"Can only expose ForwardSignal or SharedSignal.") - # import_halo2_advice and import_halo2_fixed are ignored. + def step_type(self: Circuit, step_type: StepType) -> StepType: + assert self.mode == CircuitMode.SETUP + self.ast.add_step_type(step_type.step_type, step_type.step_type.name) + return step_type + + def step_type_def(self: StepType) -> StepType: + assert self.mode == CircuitMode.SETUP + self.ast.add_step_type_def() + + def fixed_gen(self: Circuit, fixed_gen_def: Callable[[FixedGenContext], None]): + self.ast.set_fixed_gen(fixed_gen_def) + + def pragma_first_step(self: Circuit, step_type: StepType) -> None: + assert self.mode == CircuitMode.SETUP + self.ast.first_step = step_type.step_type.id - def step_type( - self: CircuitContext, step_type_context: StepTypeContext - ) -> StepTypeContext: - self.circuit.add_step_type( - step_type_context.step_type, step_type_context.step_type.name - ) - return step_type_context + def pragma_last_step(self: Circuit, step_type: StepType) -> None: + assert self.mode == CircuitMode.SETUP + self.ast.last_step = step_type.step_type.id - def step_type_def(self: StepTypeContext) -> StepTypeContext: - self.circuit.add_step_type_def() + def pragma_num_steps(self: Circuit, num_steps: int) -> None: + assert self.mode == CircuitMode.SETUP + self.ast.num_steps = num_steps - def trace( - self: CircuitContext, trace_def: Callable[[TraceContext, Any], None] - ): # TraceArgs are Any. - self.circuit.set_trace(trace_def) + def pragma_disable_q_enable(self: Circuit) -> None: + assert self.mode == CircuitMode.SETUP + self.ast.q_enable = False - def fixed_gen( - self: CircuitContext, fixed_gen_def: Callable[[FixedGenContext], None] - ): - self.circuit.set_fixed_gen(fixed_gen_def) + def add(self: Circuit, step_type: StepType, args: Any): + assert self.mode == CircuitMode.Trace + step_instance: StepInstance = step_type.gen_step_instance(args) + self.witness.step_instances.append(step_instance) - def pragma_first_step( - self: CircuitContext, step_type_context: StepTypeContext - ) -> None: - self.circuit.first_step = step_type_context.step_type.id - print(f"first step id: {step_type_context.step_type.id}") + def gen_witness(self: Circuit, args: Any) -> TraceWitness: + self.mode = CircuitMode.Trace + self.witness = TraceWitness() + self.trace(args) + self.mode = CircuitMode.NoMode + witness = self.witness + del self.witness + return witness - def pragma_last_step( - self: CircuitContext, step_type_context: StepTypeContext - ) -> None: - self.circuit.last_step = step_type_context.step_type.id - print(f"last step id: {step_type_context.step_type.id}") + def get_ast_json(self: Circuit) -> str: + return json.dumps(self.ast, cls=CustomEncoder, indent=4) - def pragma_num_steps(self: CircuitContext, num_steps: int) -> None: - self.circuit.num_steps = num_steps + def ast_to_halo2(self: Circuit): + ast_json: str = self.get_ast_json() + self.rust_ast_id: int = rust_chiquito.ast_to_halo2(ast_json) - def pragma_disable_q_enable(self: CircuitContext) -> None: - self.circuit.q_enable = False + def verify_proof(self: Circuit, witness: TraceWitness): + if self.rust_ast_id == 0: + self.ast_to_halo2() + witness_json: str = witness.get_witness_json() + rust_chiquito.verify_proof(witness_json, self.rust_ast_id) class StepTypeMode(Enum): @@ -87,37 +114,39 @@ class StepTypeMode(Enum): WG = 2 -class StepTypeContext: - - def __init__(self: StepTypeContext, circuit, step_type_name: str, ): - self.step_type = StepType.new(step_type_name) +class StepType: + def __init__(self: StepType, circuit: Circuit, step_type_name: str): + self.step_type = ASTStepType.new(step_type_name) self.circuit = circuit self.mode = StepTypeMode.SETUP self.setup() + + def gen_step_instance(self: StepType, args: Any) -> StepInstance: + self.mode = StepTypeMode.WG + self.step_instance = StepInstance.new(self.step_type.id) + self.wg(args) self.mode = StepTypeMode.NoMode + step_instance = self.step_instance + del self.step_instance + return step_instance - def internal(self: StepTypeContext, name: str) -> Internal: - assert (self.mode == StepTypeMode.SETUP) + def internal(self: StepType, name: str) -> Internal: + assert self.mode == StepTypeMode.SETUP return Internal(self.step_type.add_signal(name)) - def wg( - self: StepTypeContext, wg_def: Callable[[TraceContext, Any], None] - ): # Args are Any. - self.step_type.set_wg(wg_def) - - def constr(self: StepTypeContext, constraint: ToConstraint): - assert (self.mode == StepTypeMode.SETUP) + def constr(self: StepType, constraint: ToConstraint): + assert self.mode == StepTypeMode.SETUP constraint = to_constraint(constraint) - StepTypeContext.enforce_constraint_typing(constraint) + StepType.enforce_constraint_typing(constraint) self.step_type.add_constr(constraint.annotation, constraint.expr) - def transition(self: StepTypeContext, constraint: ToConstraint): - assert (self.mode == StepTypeMode.SETUP) + def transition(self: StepType, constraint: ToConstraint): + assert self.mode == StepTypeMode.SETUP constraint = to_constraint(constraint) - StepTypeContext.enforce_constraint_typing(constraint) + StepType.enforce_constraint_typing(constraint) self.step_type.add_transition(constraint.annotation, constraint.expr) def enforce_constraint_typing(constraint: Constraint): @@ -126,12 +155,9 @@ def enforce_constraint_typing(constraint: Constraint): f"Expected AntiBooly constraint, got {constraint.typing} (constraint: {constraint.annotation})" ) - # TODO: Implement add_lookup after lookup abstraction PR is merged. + def assign(self: StepType, lhs: Queriable, rhs: F): + assert self.mode == StepTypeMode.WG + self.step_instance.assign(lhs, rhs) -def circuit( - name: str, circuit_context_def: Callable[[CircuitContext], None] -) -> Circuit: - ctx = CircuitContext() - circuit_context_def(ctx) - return ctx.circuit + # TODO: Implement add_lookup after lookup abstraction PR is merged. diff --git a/pychiquito/expr.py b/pychiquito/expr.py index df0e51e..32fd988 100644 --- a/pychiquito/expr.py +++ b/pychiquito/expr.py @@ -4,9 +4,6 @@ from util import F -############ -# ast/expr # -############ # pub enum Expr { # Const(F), @@ -141,11 +138,7 @@ def __json__(self): return {"Pow": [self.expr.__json__(), self.pow]} -# Ignored Expr::Halo2Expr. - -# Removed Constraint variant to avoid circular reference. ToExpr = Expr | int | F -# | Constraint def to_expr(v: ToExpr) -> Expr: @@ -158,8 +151,6 @@ def to_expr(v: ToExpr) -> Expr: return Neg(Const(F(-v))) elif isinstance(v, F): return Const(v) - # elif isinstance(v, Constraint): - # return v.expr else: raise TypeError( f"Type {type(v)} is not ToExpr (one of Expr, int, F, or Constraint)." diff --git a/pychiquito/fibonacci.py b/pychiquito/fibonacci.py index a256580..284c667 100644 --- a/pychiquito/fibonacci.py +++ b/pychiquito/fibonacci.py @@ -1,31 +1,19 @@ from __future__ import annotations -from typing import Any, Tuple -from py_ecc import bn128 -import json -import rust_chiquito # rust bindings +from typing import Tuple -from dsl import CircuitContext, StepTypeContext -from chiquito_ast import StepType, First, Last, Step +from dsl import Circuit, StepType from cb import eq from query import Queriable -from wit_gen import TraceContext, StepInstance, TraceGenerator +from util import F -F = bn128.FQ - -class Fibonacci(CircuitContext): - def __init__(self: Fibonacci): - super().__init__() - self.a: Queriable = self.forward( - "a" - ) # `self.a` is required instead of `a`, because steps need to access `circuit.a`. +class Fibonacci(Circuit): + def setup(self: Fibonacci): + self.a: Queriable = self.forward("a") self.b: Queriable = self.forward("b") - self.fibo_step = self.step_type( - FiboStep(self, "fibo_step")) - self.fibo_last_step = self.step_type( - FiboLastStep(self, "fibo_last_step") - ) + self.fibo_step = self.step_type(FiboStep(self, "fibo_step")) + self.fibo_last_step = self.step_type(FiboLastStep(self, "fibo_last_step")) self.pragma_first_step(self.fibo_step) self.pragma_last_step(self.fibo_last_step) @@ -36,99 +24,45 @@ def __init__(self: Fibonacci): # self.expose(self.a, Last()) # self.expose(self.a, Step(1)) - def trace(self: Fibonacci): - def trace_def(ctx: TraceContext, _: Any): # Any instead of TraceArgs - ctx.add(self, self.fibo_step, (1, 1)) - a = 1 - b = 2 - for i in range(1, 10): - ctx.add(self, self.fibo_step, (a, b)) - prev_a = a - a = b - b += prev_a - ctx.add(self, self.fibo_last_step, (a, b)) - - super().trace(trace_def) + def trace(self: Fibonacci, args: Any): + self.add(self.fibo_step, (1, 1)) + a = 1 + b = 2 + for i in range(1, 10): + self.add(self.fibo_step, (a, b)) + prev_a = a + a = b + b += prev_a + self.add(self.fibo_last_step, (a, b)) -class FiboStep(StepTypeContext): +class FiboStep(StepType): def setup(self: FiboStep): - self.c = self.internal( - "c" - ) # `self.c` is required instead of `c`, because wg needs to access `self.c`. + self.c = self.internal("c") self.constr(eq(self.circuit.a + self.circuit.b, self.c)) self.transition(eq(self.circuit.b, self.circuit.a.next())) self.transition(eq(self.c, self.circuit.b.next())) - def wg(self: FiboStep, circuit: Fibonacci): - # Any instead of Args - def wg_def(ctx: StepInstance, values: Tuple[int, int]): - a_value, b_value = values - # print(f"fib step wg: {a_value}, {b_value}, {a_value + b_value}") - ctx.assign(circuit.a, F(a_value)) - ctx.assign(circuit.b, F(b_value)) - ctx.assign(self.c, F(a_value + b_value)) + def wg(self: FiboStep, args: Tuple[int, int]): + a_value, b_value = args + self.assign(self.circuit.a, F(a_value)) + self.assign(self.circuit.b, F(b_value)) + self.assign(self.c, F(a_value + b_value)) - super().wg(wg_def) - -class FiboLastStep(StepTypeContext): +class FiboLastStep(StepType): def setup(self: FiboLastStep): self.c = self.internal("c") self.constr(eq(self.circuit.a + self.circuit.b, self.c)) - def wg(self: FiboLastStep, circuit: Fibonacci): - # Any instead of Args - def wg_def(ctx: StepInstance, values: Tuple[int, int]): - a_value, b_value = values - print( - f"fib last step wg: {a_value}, {b_value}, {a_value + b_value}\n") - ctx.assign(circuit.a, F(a_value)) - ctx.assign(circuit.b, F(b_value)) - ctx.assign(self.c, F(a_value + b_value)) - - super().wg(wg_def) + def wg(self: FiboLastStep, values=Tuple[int, int]): + a_value, b_value = values + self.assign(self.circuit.a, F(a_value)) + self.assign(self.circuit.b, F(b_value)) + self.assign(self.c, F(a_value + b_value)) -# Print Circuit fibo = Fibonacci() -fibo.trace() -print("Print Circuit using custom __str__ method in python:") -print(fibo.circuit) -print("Print Circuit using __json__ method in python:") - - -class CustomEncoder(json.JSONEncoder): - def default(self, obj): - if hasattr(obj, "__json__"): - return obj.__json__() - return super().default(obj) - - -# Print Circuit -print("Print Circuit using custom __str__ method in python:") -print(fibo.circuit) -print("Print Circuit using __json__ method in python:") -circuit_json = json.dumps(fibo.circuit, cls=CustomEncoder, indent=4) -print(circuit_json) - -# Print TraceWitness -trace_generator = TraceGenerator(fibo.circuit.trace) -trace_witness = trace_generator.generate(None) -print("Print TraceWitness using custom __str__ method in python:") -print(trace_witness) -print("Print TraceWitness using __json__ method in python:") -trace_witness_json = json.dumps(trace_witness, cls=CustomEncoder, indent=4) -print(trace_witness_json) - -# Rust bindings for Circuit -print("Call rust bindings, parse json to Chiquito ast, and print using Debug trait:") -rust_chiquito.convert_and_print_ast(circuit_json) -print( - "Call rust bindings, parse json to Chiquito TraceWitness, and print using Debug trait:" -) -rust_chiquito.convert_and_print_trace_witness(trace_witness_json) -print("Parse json to Chiquito Halo2, and obtain UUID:") -ast_uuid: int = rust_chiquito.ast_to_halo2(circuit_json) -print("Verify ciruit with ast uuid and trace witness json:") -rust_chiquito.verify_proof(trace_witness_json, ast_uuid) +fibo_witness = fibo.gen_witness(None) +fibo.ast_to_halo2() +fibo.verify_proof(fibo_witness) diff --git a/pychiquito/query.py b/pychiquito/query.py index 99079b1..a4ff2b1 100644 --- a/pychiquito/query.py +++ b/pychiquito/query.py @@ -3,19 +3,15 @@ from expr import Expr # Commented out to avoid circular reference -# from chiquito_ast import InternalSignal, ForwardSignal, SharedSignal, FixedSignal, StepType +# from chiquito_ast import InternalSignal, ForwardSignal, SharedSignal, FixedSignal, ASTStepType -###################### -# ast/expr/queriable # -###################### - # pub enum Queriable { # Internal(InternalSignal), # Forward(ForwardSignal, bool), # Shared(SharedSignal, i32), # Fixed(FixedSignal, i32), -# StepTypeNext(StepTypeHandler), +# StepTypeNext(ASTStepTypeHandler), # Halo2AdviceQuery(ImportedHalo2Advice, i32), # Halo2FixedQuery(ImportedHalo2Fixed, i32), # #[allow(non_camel_case_types)] @@ -127,19 +123,16 @@ def __json__(self): class StepTypeNext(Queriable): - def __init__(self: StepTypeNext, step_type: StepType): + def __init__(self: StepTypeNext, step_type: ASTStepType): self.step_type = step_type - def uuid(self: StepType) -> int: + def uuid(self: ASTStepType) -> int: return self.id - def __str__(self: StepType) -> str: + def __str__(self: ASTStepType) -> str: return self.name def __json__(self): return { "StepTypeNext": {"id": self.step_type.id, "annotation": self.step_type.name} } - - -# Ignored Queriable::Halo2AdviceQuery and Queriable::Halo2FixedQuery. diff --git a/pychiquito/tests.py b/pychiquito/tests.py deleted file mode 100644 index a05909b..0000000 --- a/pychiquito/tests.py +++ /dev/null @@ -1,50 +0,0 @@ -from chiquito_ast import ( - StepType, - ASTConstraint, - TransitionConstraint, - InternalSignal, - ForwardSignal, - SharedSignal, - FixedSignal, -) -from query import Internal, Forward, Shared, Fixed -from expr import Const, Sum, Mul - -######## -# test # -######## -# print(Internal(InternalSignal("a")).__json__()) -# print(Forward(ForwardSignal(1, "a"), True).__json__()) -# print(Shared(SharedSignal(0, "a"), 2).__json__()) -# print(Fixed(FixedSignal("a"), 2).__json__()) -# print(StepTypeNext(StepType.new("fibo")).__json__()) -# print(ASTConstraint("constraint", Sum([Const(1), Mul([Internal(InternalSignal("a")), Const(3)])])).__json__()) -# print(TransitionConstraint("trans", Sum([Const(1), Mul([Internal(InternalSignal("a")), Const(3)])])).__json__()) -print( - StepType( - 1, - "fibo", - [InternalSignal("a"), InternalSignal("b")], - [ - ASTConstraint( - "constraint", - Sum([Const(1), Mul([Internal(InternalSignal("c")), Const(3)])]), - ), - ASTConstraint( - "constraint", - Sum([Const(1), Mul([Shared(SharedSignal(2, "d"), 1), Const(3)])]), - ), - ], - [ - TransitionConstraint( - "trans", - Sum([Const(1), Mul([Forward(ForwardSignal(1, "e"), True), Const(3)])]), - ), - TransitionConstraint( - "trans", Sum([Const(1), Mul([Fixed(FixedSignal("e"), 2), Const(3)])]) - ), - ], - {5: "a", 6: "b", 7: "c"}, - None, - ).__json__() -) diff --git a/pychiquito/util.py b/pychiquito/util.py index 40b144a..d838aa1 100644 --- a/pychiquito/util.py +++ b/pychiquito/util.py @@ -1,6 +1,7 @@ from __future__ import annotations from py_ecc import bn128 from uuid import uuid1 +import json F = bn128.FQ @@ -18,6 +19,13 @@ def json_method(self: F): F.__json__ = json_method +class CustomEncoder(json.JSONEncoder): + def default(self, obj): + if hasattr(obj, "__json__"): + return obj.__json__() + return super().default(obj) + + # int field is the u128 version of uuid. def uuid() -> int: return uuid1(node=int.from_bytes([10, 10, 10, 10, 10, 10], byteorder="little")).int diff --git a/pychiquito/wit_gen.py b/pychiquito/wit_gen.py index a92eee3..b0ba8f9 100644 --- a/pychiquito/wit_gen.py +++ b/pychiquito/wit_gen.py @@ -1,16 +1,13 @@ from __future__ import annotations from dataclasses import dataclass, field from typing import Dict, List, Callable, Any +import json from query import Queriable, Fixed -from util import F +from util import F, CustomEncoder # Commented out to avoid circular reference -# from dsl import CircuitContext, StepTypeContext - -########### -# wit_gen # -########### +# from dsl import Circuit, StepType @dataclass @@ -41,7 +38,7 @@ def __str__(self: StepInstance): f"\t\t)" ) - # For assignments, return "uuid: F" rather than "Queriable: F", because JSON doesn't accept Dict as key. + # For assignments, return "uuid: (Queriable, F)" rather than "Queriable: F", because JSON doesn't accept Dict as key. def __json__(self: StepInstance): return { "step_type_uuid": self.step_type_uuid, @@ -84,38 +81,8 @@ def __json__(self: TraceWitness): "height": self.height, } - -@dataclass -class TraceContext: - witness: TraceWitness = field(default_factory=TraceWitness) - - def add( - self: TraceContext, circuit: CircuitContext, step: StepTypeContext, args: Any - ): # Use StepTypeContext instead of StepTypeWGHandler, because StepTypeContext contains step type id and `wg` method that returns witness generation function. - witness = StepInstance.new(step.step_type.id) - step.wg(circuit) - if step.step_type.wg is None: - raise ValueError( - f"Step type {step.step_type.name} does not have a witness generator." - ) - step.step_type.wg(witness, args) - self.witness.step_instances.append(witness) - - def set_height(self: TraceContext, height: int): - self.witness.height = height - - -Trace = Callable[[TraceContext, Any], None] # TraceArgs are Any. - - -@dataclass -class TraceGenerator: - trace: Trace - - def generate(self: TraceGenerator, args: Any) -> TraceWitness: # Args are Any. - ctx = TraceContext() - self.trace(ctx, args) - return ctx.witness + def get_witness_json(self: TraceWitness) -> str: + return json.dumps(self, cls=CustomEncoder, indent=4) FixedAssigment = Dict[Queriable, List[F]] @@ -140,7 +107,7 @@ def assign(self: FixedGenContext, offset: int, lhs: Queriable, rhs: F): def is_fixed_queriable(q: Queriable) -> bool: match q.enum: - case Fixed(_, _): # Ignored Halo2FixedQuery enum type. + case Fixed(_, _): return True case _: return False diff --git a/src/lib.rs b/src/lib.rs index 8e05260..99edb92 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -28,7 +28,6 @@ fn convert_and_print_trace_witness(json: &PyString) { #[pyfunction] fn ast_to_halo2(json: &PyString) -> u128 { let uuid = chiquito_ast_to_halo2(json.to_str().expect("PyString convertion failed.")); - println!("{:?}", uuid); uuid }