Skip to content

Commit

Permalink
addressed leo comments and modified wg function
Browse files Browse the repository at this point in the history
  • Loading branch information
qwang98 committed Jul 28, 2023
1 parent 0bebfb7 commit daf1201
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 103 deletions.
15 changes: 1 addition & 14 deletions pychiquito/chiquito_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ class ASTCircuit:
fixed_signals: List[FixedSignal] = field(default_factory=list)
exposed: List[Tuple[Queriable, ExposeOffset]] = field(default_factory=list)
annotations: Dict[int, str] = field(default_factory=dict)
trace: Optional[Callable[[ASTStepType, Any], None]] = None
fixed_gen: Optional[Callable] = None
first_step: Optional[int] = None
last_step: Optional[int] = None
Expand Down Expand Up @@ -94,7 +93,6 @@ def __str__(self: ASTCircuit):
f"\tfixed_signals=[{fixed_signals_str}],\n"
f"\texposed=[{exposed_str}],\n"
f"\tannotations={{{annotations_str}}},\n"
f"\ttrace={self.trace},\n"
f"\tfixed_gen={self.fixed_gen},\n"
f"\tfirst_step={self.first_step},\n"
f"\tlast_step={self.last_step},\n"
Expand Down Expand Up @@ -146,14 +144,6 @@ def add_step_type(self: ASTCircuit, step_type: ASTStepType, name: str):
self.annotations[step_type.id] = name
self.step_types[step_type.id] = step_type

def set_trace(
self: ASTCircuit, trace_def: Callable[[TraceContext, Any], None]
): # TraceArgs are Any.
if self.trace is not None:
raise Exception("ASTCircuit cannot have more than one trace generator.")
else:
self.trace = trace_def

def set_fixed_gen(self, fixed_gen_def: Callable[[FixedGenContext], None]):
if self.fixed_gen is not None:
raise Exception("ASTCircuit cannot have more than one fixed generator.")
Expand Down Expand Up @@ -187,12 +177,9 @@ class ASTStepType:
constraints: List[ASTConstraint]
transition_constraints: List[TransitionConstraint]
annotations: Dict[int, str]
wg: Optional[
Callable[[StepInstance, Any], None]
] # Args are Any. Not passed to Rust Chiquito.

def new(name: str) -> ASTStepType:
return ASTStepType(uuid(), name, [], [], [], {}, None)
return ASTStepType(uuid(), name, [], [], [], {})

def __str__(self):
signals_str = (
Expand Down
112 changes: 48 additions & 64 deletions pychiquito/dsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from chiquito_ast import ASTCircuit, ASTStepType, ExposeOffset
from query import Internal, Forward, Queriable, Shared, Fixed
from wit_gen import FixedGenContext, TraceContext, StepInstance
from wit_gen import FixedGenContext, TraceContext, StepInstance, TraceWitness
from cb import Constraint, Typing, ToConstraint, to_constraint
from util import CustomEncoder, F

Expand Down Expand Up @@ -68,12 +68,12 @@ def fixed_gen(self: Circuit, fixed_gen_def: Callable[[FixedGenContext], None]):
def pragma_first_step(self: Circuit, step_type: StepType) -> None:
assert self.mode == CircuitMode.SETUP
self.ast.first_step = step_type.step_type.id
print(f"first step id: {step_type.step_type.id}")
# print(f"first step id: {step_type.step_type.id}")

def pragma_last_step(self: Circuit, step_type: StepType) -> None:
assert self.mode == CircuitMode.SETUP
self.ast.last_step = step_type.step_type.id
print(f"last step id: {step_type.step_type.id}")
# print(f"last step id: {step_type.step_type.id}")

def pragma_num_steps(self: Circuit, num_steps: int) -> None:
assert self.mode == CircuitMode.SETUP
Expand All @@ -83,80 +83,58 @@ def pragma_disable_q_enable(self: Circuit) -> None:
assert self.mode == CircuitMode.SETUP
self.ast.q_enable = False

def gen_witness(self: Circuit, args: Any) -> TraceWitness:
self.mode = CircuitMode.Trace
self.trace_context = TraceContext()
self.ast.set_trace(self.trace)
self.trace(args)
return self.trace_context.witness

def add(self: Circuit, step_type: StepType, args: Any):
print(self)
print(step_type)
print(args)
assert self.mode == CircuitMode.Trace
self.trace_context.add(self, step_type, args)

def print_ast(self: Circuit):
print("Print ASTCircuit using custom __str__ method in python:")
print(self.ast)

def get_ast_json(self: Circuit, print_json=False) -> str:
ast_json: str = json.dumps(self.ast, cls=CustomEncoder, indent=4)
if print_json:
print("Print ASTCircuit using __json__ method in python:")
print(ast_json)
return ast_json

def print_witness(self: Circuit):
print("Print TraceWitness using custom __str__ method in python:")
print(self.trace_context.witness)
def gen_witness(self: Circuit, args: Any) -> TraceWitness:
self.mode = CircuitMode.Trace
self.trace_context = TraceContext()
self.trace(args)
self.mode = CircuitMode.NoMode
witness = self.trace_context.witness
del self.trace_context
return witness

def get_witness_json(self: Circuit, print_json=False) -> str:
witness_json: str = json.dumps(
self.trace_context.witness, cls=CustomEncoder, indent=4
)
if print_json:
print("Print TraceWitness using __json__ method in python:")
print(witness_json)
return witness_json
def get_ast_json(self: Circuit) -> str:
return json.dumps(self.ast, cls=CustomEncoder, indent=4)

def convert_and_print_ast(self: Circuit, print_ast=False):
def convert_and_print_ast(self: Circuit):
ast_json: str = self.get_ast_json()
if print_ast:
print(
"Call rust bindings, parse json to Chiquito ASTCircuit, and print using Debug trait:"
)
print(rust_chiquito.convert_and_print_ast(ast_json))

def convert_and_print_witness(self: Circuit, print_witness=False):
witness_json: str = self.get_witness_json()
if print_witness:
print(
"Call rust bindings, parse json to Chiquito TraceWitness, and print using Debug trait:"
)
print(rust_chiquito.convert_and_print_trace_witness(witness_json))
print(
"Call rust bindings, parse json to Chiquito ASTCircuit, and print using Debug trait:"
)
rust_chiquito.convert_and_print_ast(ast_json)

def ast_to_halo2(self: Circuit, print_ast_id=False):
def ast_to_halo2(self: Circuit):
ast_json: str = self.get_ast_json()
self.rust_ast_id: int = rust_chiquito.ast_to_halo2(ast_json)
if print_ast_id:
print("Parse json to Chiquito Halo2, and obtain UUID:")
print(self.rust_ast_id)

def verify_proof(self: Circuit, print_inputs=False):
def verify_proof(self: Circuit, witness: TraceWitness):
if self.rust_ast_id == 0:
self.rust_ast_id = self.ast_to_halo2()
witness_json: str = self.get_witness_json()
if print_inputs:
print("Rust AST UUID:")
print(self.rust_ast_id)
print("Print TraceWitness using __json__ method in python:")
print(witness_json)
print("Verify ciruit with AST uuid and witness json:")
self.ast_to_halo2()
witness_json: str = witness.get_witness_json()
rust_chiquito.verify_proof(witness_json, self.rust_ast_id)


# Debug method
def convert_and_print_witness(witness: TraceWitness):
witness_json: str = witness.get_witness_json()
rust_chiquito.convert_and_print_trace_witness(witness_json)


# Debug method
def print_ast(ast: ASTCircuit):
print("Print ASTCircuit using custom __str__ method in python:")
print(ast)


# Debug method
def print_witness(witness: TraceWitness):
print("Print TraceWitness using custom __str__ method in python:")
print(witness)


class StepTypeMode(Enum):
NoMode = 0
SETUP = 1
Expand All @@ -167,11 +145,17 @@ class StepType:
def __init__(self: StepType, circuit: Circuit, step_type_name: str):
self.step_type = ASTStepType.new(step_type_name)
self.circuit = circuit
self.step_instance = StepInstance.new(self.step_type.id)
self.mode = StepTypeMode.SETUP
self.setup()

def gen_step_instance(self: StepType, args: Any) -> StepInstance:
self.mode = StepTypeMode.WG
self.step_type.set_wg(self.wg)
self.step_instance = StepInstance.new(self.step_type.id)
self.wg(args)
self.mode = StepTypeMode.NoMode
step_instance = self.step_instance
del self.step_instance
return step_instance

def internal(self: StepType, name: str) -> Internal:
assert self.mode == StepTypeMode.SETUP
Expand Down
26 changes: 13 additions & 13 deletions pychiquito/fibonacci.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations
from typing import Tuple

from dsl import Circuit, StepType
from dsl import Circuit, StepType, print_ast, print_witness, convert_and_print_witness
from cb import eq
from query import Queriable
from util import F
Expand Down Expand Up @@ -38,15 +38,13 @@ def trace(self: Fibonacci, args: Any):

class FiboStep(StepType):
def setup(self: FiboStep):
self.c = self.internal(
"c"
) # `self.c` is required instead of `c`, because wg needs to access `self.c`.
self.c = self.internal("c")
self.constr(eq(self.circuit.a + self.circuit.b, self.c))
self.transition(eq(self.circuit.b, self.circuit.a.next()))
self.transition(eq(self.c, self.circuit.b.next()))

def wg(self: FiboStep, values: Tuple[int, int]):
a_value, b_value = values
def wg(self: FiboStep, args: Tuple[int, int]):
a_value, b_value = args
self.assign(self.circuit.a, F(a_value))
self.assign(self.circuit.b, F(b_value))
self.assign(self.c, F(a_value + b_value))
Expand All @@ -65,10 +63,12 @@ def wg(self: FiboLastStep, values=Tuple[int, int]):


fibo = Fibonacci()
fibo.print_ast()
fibo.gen_witness(None)
fibo.print_witness()
fibo.convert_and_print_ast(print_ast=True)
fibo.convert_and_print_witness(print_witness=True)
fibo.ast_to_halo2(print_ast_id=True)
fibo.verify_proof(print_inputs=True)
fibo_witness = fibo.gen_witness(None)
fibo.convert_and_print_ast()
fibo.ast_to_halo2()
fibo.verify_proof(fibo_witness)

# Debug methods
# print_ast(fibo.ast)
# print_witness(fibo_witness)
# convert_and_print_witness(fibo_witness)
19 changes: 8 additions & 11 deletions pychiquito/wit_gen.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Dict, List, Callable, Any
import json

from query import Queriable, Fixed
from util import F
from util import F, CustomEncoder

# Commented out to avoid circular reference
# from dsl import Circuit, StepType
Expand Down Expand Up @@ -84,21 +85,17 @@ def __json__(self: TraceWitness):
"height": self.height,
}

def get_witness_json(self: TraceWitness) -> str:
return json.dumps(self, cls=CustomEncoder, indent=4)


@dataclass
class TraceContext:
witness: TraceWitness = field(default_factory=TraceWitness)

def add(
self: TraceContext, circuit: Circuit, step: StepType, args: Any
): # Use StepType instead of StepTypeWGHandler, because StepType contains step type id and `wg` method that returns witness generation function.
step.wg(args)
if step.step_type.wg is None:
raise ValueError(
f"Step type {step.step_type.name} does not have a witness generator."
)
step.step_type.wg(args)
self.witness.step_instances.append(step.step_instance)
def add(self: TraceContext, circuit: Circuit, step: StepType, args: Any):
step_instance: StepInstance = step.gen_step_instance(args)
self.witness.step_instances.append(step_instance)

def set_height(self: TraceContext, height: int):
self.witness.height = height
Expand Down
1 change: 0 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ fn convert_and_print_trace_witness(json: &PyString) {
#[pyfunction]
fn ast_to_halo2(json: &PyString) -> u128 {
let uuid = chiquito_ast_to_halo2(json.to_str().expect("PyString convertion failed."));
println!("{:?}", uuid);

uuid
}
Expand Down

0 comments on commit daf1201

Please sign in to comment.