diff --git a/pychiquito/chiquito_ast.py b/pychiquito/chiquito_ast.py index f33a41d..0b5c103 100644 --- a/pychiquito/chiquito_ast.py +++ b/pychiquito/chiquito_ast.py @@ -2,7 +2,7 @@ from typing import Callable, List, Dict, Optional, Any, Tuple from dataclasses import dataclass, field, asdict -from wit_gen import TraceContext, FixedGenContext, StepInstance +from wit_gen import FixedGenContext, StepInstance from expr import Expr from util import uuid from query import Queriable diff --git a/pychiquito/dsl.py b/pychiquito/dsl.py index bb5b0e5..5ca991d 100644 --- a/pychiquito/dsl.py +++ b/pychiquito/dsl.py @@ -1,13 +1,12 @@ from __future__ import annotations from enum import Enum from typing import Callable, Any -from dataclasses import dataclass import rust_chiquito # rust bindings import json from chiquito_ast import ASTCircuit, ASTStepType, ExposeOffset from query import Internal, Forward, Queriable, Shared, Fixed -from wit_gen import FixedGenContext, TraceContext, StepInstance, TraceWitness +from wit_gen import FixedGenContext, StepInstance, TraceWitness from cb import Constraint, Typing, ToConstraint, to_constraint from util import CustomEncoder, F @@ -21,7 +20,7 @@ class CircuitMode(Enum): class Circuit: def __init__(self: Circuit): self.ast = ASTCircuit() - self.trace_context = TraceContext() + self.witness = TraceWitness() self.rust_ast_id = 0 self.mode = CircuitMode.SETUP self.setup() @@ -68,12 +67,10 @@ def fixed_gen(self: Circuit, fixed_gen_def: Callable[[FixedGenContext], None]): def pragma_first_step(self: Circuit, step_type: StepType) -> None: assert self.mode == CircuitMode.SETUP self.ast.first_step = step_type.step_type.id - # print(f"first step id: {step_type.step_type.id}") def pragma_last_step(self: Circuit, step_type: StepType) -> None: assert self.mode == CircuitMode.SETUP self.ast.last_step = step_type.step_type.id - # print(f"last step id: {step_type.step_type.id}") def pragma_num_steps(self: Circuit, num_steps: int) -> None: assert self.mode == CircuitMode.SETUP @@ -85,27 +82,21 @@ def pragma_disable_q_enable(self: Circuit) -> None: def add(self: Circuit, step_type: StepType, args: Any): assert self.mode == CircuitMode.Trace - self.trace_context.add(self, step_type, args) + step_instance: StepInstance = step_type.gen_step_instance(args) + self.witness.step_instances.append(step_instance) def gen_witness(self: Circuit, args: Any) -> TraceWitness: self.mode = CircuitMode.Trace - self.trace_context = TraceContext() + self.witness = TraceWitness() self.trace(args) self.mode = CircuitMode.NoMode - witness = self.trace_context.witness - del self.trace_context + witness = self.witness + del self.witness return witness def get_ast_json(self: Circuit) -> str: return json.dumps(self.ast, cls=CustomEncoder, indent=4) - def convert_and_print_ast(self: Circuit): - ast_json: str = self.get_ast_json() - print( - "Call rust bindings, parse json to Chiquito ASTCircuit, and print using Debug trait:" - ) - rust_chiquito.convert_and_print_ast(ast_json) - def ast_to_halo2(self: Circuit): ast_json: str = self.get_ast_json() self.rust_ast_id: int = rust_chiquito.ast_to_halo2(ast_json) @@ -117,24 +108,6 @@ def verify_proof(self: Circuit, witness: TraceWitness): rust_chiquito.verify_proof(witness_json, self.rust_ast_id) -# Debug method -def convert_and_print_witness(witness: TraceWitness): - witness_json: str = witness.get_witness_json() - rust_chiquito.convert_and_print_trace_witness(witness_json) - - -# Debug method -def print_ast(ast: ASTCircuit): - print("Print ASTCircuit using custom __str__ method in python:") - print(ast) - - -# Debug method -def print_witness(witness: TraceWitness): - print("Print TraceWitness using custom __str__ method in python:") - print(witness) - - class StepTypeMode(Enum): NoMode = 0 SETUP = 1 diff --git a/pychiquito/fibonacci.py b/pychiquito/fibonacci.py index f87a7da..284c667 100644 --- a/pychiquito/fibonacci.py +++ b/pychiquito/fibonacci.py @@ -1,7 +1,7 @@ from __future__ import annotations from typing import Tuple -from dsl import Circuit, StepType, print_ast, print_witness, convert_and_print_witness +from dsl import Circuit, StepType from cb import eq from query import Queriable from util import F @@ -64,11 +64,5 @@ def wg(self: FiboLastStep, values=Tuple[int, int]): fibo = Fibonacci() fibo_witness = fibo.gen_witness(None) -fibo.convert_and_print_ast() fibo.ast_to_halo2() fibo.verify_proof(fibo_witness) - -# Debug methods -# print_ast(fibo.ast) -# print_witness(fibo_witness) -# convert_and_print_witness(fibo_witness) diff --git a/pychiquito/wit_gen.py b/pychiquito/wit_gen.py index 346c9ea..bdad541 100644 --- a/pychiquito/wit_gen.py +++ b/pychiquito/wit_gen.py @@ -89,31 +89,6 @@ def get_witness_json(self: TraceWitness) -> str: return json.dumps(self, cls=CustomEncoder, indent=4) -@dataclass -class TraceContext: - witness: TraceWitness = field(default_factory=TraceWitness) - - def add(self: TraceContext, circuit: Circuit, step: StepType, args: Any): - step_instance: StepInstance = step.gen_step_instance(args) - self.witness.step_instances.append(step_instance) - - def set_height(self: TraceContext, height: int): - self.witness.height = height - - -Trace = Callable[[TraceContext, Any], None] # TraceArgs are Any. - - -@dataclass -class TraceGenerator: - trace: Trace - - def generate(self: TraceGenerator, args: Any) -> TraceWitness: # Args are Any. - ctx = TraceContext() - self.trace(ctx, args) - return ctx.witness - - FixedAssigment = Dict[Queriable, List[F]]