Skip to content

Commit

Permalink
Steve/refactor circuit (#8)
Browse files Browse the repository at this point in the history
* refactored everything

* refactored everything and example

* removed unused imports

* renamed step_type_context vairables

* fixed type and field namings

* updated generate_witness method

* addressed leo comments and modified wg function

* updated tracecontext and cleaned up debug methods

* cleaned up other files

* removed tests
  • Loading branch information
qwang98 authored Aug 1, 2023
1 parent 8ee46cd commit 905adc4
Show file tree
Hide file tree
Showing 10 changed files with 195 additions and 358 deletions.
12 changes: 4 additions & 8 deletions pychiquito/cb.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,7 @@
from util import F
from expr import Expr, Const, Neg, to_expr, ToExpr
from query import StepTypeNext
from chiquito_ast import StepType

##########
# dsl/cb #
##########
from chiquito_ast import ASTStepType


class Typing(Enum):
Expand Down Expand Up @@ -165,7 +161,7 @@ def isz(constraint: ToConstraint) -> Constraint:
)


def if_next_step(step_type: StepType, constraint: ToConstraint) -> Constraint:
def if_next_step(step_type: ASTStepType, constraint: ToConstraint) -> Constraint:
constraint = to_constraint(constraint)
return Constraint(
f"if(next step is {step_type.annotation})then({constraint.annotation})",
Expand All @@ -174,15 +170,15 @@ def if_next_step(step_type: StepType, constraint: ToConstraint) -> Constraint:
)


def next_step_must_be(step_type: StepType) -> Constraint:
def next_step_must_be(step_type: ASTStepType) -> Constraint:
return Constraint(
f"next step must be {step_type.annotation}",
Constraint.cb_not(StepTypeNext(step_type)),
Typing.AntiBooly,
)


def next_step_must_not_be(step_type: StepType) -> Constraint:
def next_step_must_not_be(step_type: ASTStepType) -> Constraint:
return Constraint(
f"next step must not be {step_type.annotation}",
StepTypeNext(step_type),
Expand Down
95 changes: 34 additions & 61 deletions pychiquito/chiquito_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,14 @@
from typing import Callable, List, Dict, Optional, Any, Tuple
from dataclasses import dataclass, field, asdict

from wit_gen import TraceContext, FixedGenContext, StepInstance
from wit_gen import FixedGenContext, StepInstance
from expr import Expr
from util import uuid
from query import Queriable

#######
# ast #
#######

# pub struct Circuit<F, TraceArgs> {
# pub step_types: HashMap<u32, Rc<StepType<F>>>,
# pub step_types: HashMap<u32, Rc<ASTStepType<F>>>,

# pub forward_signals: Vec<ForwardSignal>,
# pub shared_signals: Vec<SharedSignal>,
Expand All @@ -26,29 +23,28 @@
# pub trace: Option<Rc<Trace<F, TraceArgs>>>,
# pub fixed_gen: Option<Rc<FixedGen<F>>>,

# pub first_step: Option<StepTypeUUID>,
# pub last_step: Option<StepTypeUUID>,
# pub first_step: Option<ASTStepTypeUUID>,
# pub last_step: Option<ASTStepTypeUUID>,
# pub num_steps: usize,
# }


@dataclass
class Circuit:
step_types: Dict[int, StepType] = field(default_factory=dict)
class ASTCircuit:
step_types: Dict[int, ASTStepType] = field(default_factory=dict)
forward_signals: List[ForwardSignal] = field(default_factory=list)
shared_signals: List[SharedSignal] = field(default_factory=list)
fixed_signals: List[FixedSignal] = field(default_factory=list)
exposed: List[Tuple[Queriable, ExposeOffset]] = field(default_factory=list)
annotations: Dict[int, str] = field(default_factory=dict)
trace: Optional[Callable[[TraceContext, Any], None]] = None
fixed_gen: Optional[Callable] = None
first_step: Optional[int] = None
last_step: Optional[int] = None
num_steps: int = 0
q_enable: bool = True
id: int = uuid()

def __str__(self: Circuit):
def __str__(self: ASTCircuit):
step_types_str = (
"\n\t\t"
+ ",\n\t\t".join(f"{k}: {v}" for k, v in self.step_types.items())
Expand All @@ -57,20 +53,17 @@ def __str__(self: Circuit):
else ""
)
forward_signals_str = (
"\n\t\t" + ",\n\t\t".join(str(fs)
for fs in self.forward_signals) + "\n\t"
"\n\t\t" + ",\n\t\t".join(str(fs) for fs in self.forward_signals) + "\n\t"
if self.forward_signals
else ""
)
shared_signals_str = (
"\n\t\t" + ",\n\t\t".join(str(ss)
for ss in self.shared_signals) + "\n\t"
"\n\t\t" + ",\n\t\t".join(str(ss) for ss in self.shared_signals) + "\n\t"
if self.shared_signals
else ""
)
fixed_signals_str = (
"\n\t\t" + ",\n\t\t".join(str(fs)
for fs in self.fixed_signals) + "\n\t"
"\n\t\t" + ",\n\t\t".join(str(fs) for fs in self.fixed_signals) + "\n\t"
if self.fixed_signals
else ""
)
Expand All @@ -90,14 +83,13 @@ def __str__(self: Circuit):
)

return (
f"Circuit(\n"
f"ASTCircuit(\n"
f"\tstep_types={{{step_types_str}}},\n"
f"\tforward_signals=[{forward_signals_str}],\n"
f"\tshared_signals=[{shared_signals_str}],\n"
f"\tfixed_signals=[{fixed_signals_str}],\n"
f"\texposed=[{exposed_str}],\n"
f"\tannotations={{{annotations_str}}},\n"
f"\ttrace={self.trace},\n"
f"\tfixed_gen={self.fixed_gen},\n"
f"\tfirst_step={self.first_step},\n"
f"\tlast_step={self.last_step},\n"
Expand All @@ -106,7 +98,7 @@ def __str__(self: Circuit):
f")"
)

def __json__(self: Circuit):
def __json__(self: ASTCircuit):
return {
"step_types": {k: v.__json__() for k, v in self.step_types.items()},
"forward_signals": [x.__json__() for x in self.forward_signals],
Expand All @@ -124,52 +116,42 @@ def __json__(self: Circuit):
"id": self.id,
}

def add_forward(self: Circuit, name: str, phase: int) -> ForwardSignal:
def add_forward(self: ASTCircuit, name: str, phase: int) -> ForwardSignal:
signal = ForwardSignal(phase, name)
self.forward_signals.append(signal)
self.annotations[signal.id] = name
return signal

def add_shared(self: Circuit, name: str, phase: int) -> SharedSignal:
def add_shared(self: ASTCircuit, name: str, phase: int) -> SharedSignal:
signal = SharedSignal(phase, name)
self.shared_signals.append(signal)
self.annotations[signal.id] = name
return signal

def add_fixed(self: Circuit, name: str) -> FixedSignal:
def add_fixed(self: ASTCircuit, name: str) -> FixedSignal:
signal = FixedSignal(name)
self.fixed_signals.append(signal)
self.annotations[signal.id] = name
return signal

def expose(self: Circuit, signal: Queriable, offset: ExposeOffset):
def expose(self: ASTCircuit, signal: Queriable, offset: ExposeOffset):
self.exposed.append((signal, offset))

def add_step_type(self: Circuit, step_type: StepType, name: str):
def add_step_type(self: ASTCircuit, step_type: ASTStepType, name: str):
self.annotations[step_type.id] = name
self.step_types[step_type.id] = step_type

def set_trace(
self: Circuit, trace_def: Callable[[TraceContext, Any], None]
): # TraceArgs are Any.
if self.trace is not None:
raise Exception(
"Circuit cannot have more than one trace generator.")
else:
self.trace = trace_def

def set_fixed_gen(self, fixed_gen_def: Callable[[FixedGenContext], None]):
if self.fixed_gen is not None:
raise Exception(
"Circuit cannot have more than one fixed generator.")
raise Exception("ASTCircuit cannot have more than one fixed generator.")
else:
self.fixed_gen = fixed_gen_def

def get_step_type(self, uuid: int) -> StepType:
def get_step_type(self, uuid: int) -> ASTStepType:
if uuid in self.step_types.keys():
return self.step_types[uuid]
else:
raise ValueError("StepType not found.")
raise ValueError("ASTStepType not found.")


# pub struct StepType<F> {
Expand All @@ -185,19 +167,16 @@ def get_step_type(self, uuid: int) -> StepType:


@dataclass
class StepType:
class ASTStepType:
id: int
name: str
signals: List[InternalSignal]
constraints: List[ASTConstraint]
transition_constraints: List[TransitionConstraint]
annotations: Dict[int, str]
wg: Optional[
Callable[[StepInstance, Any], None]
] # Args are Any. Not passed to Rust Chiquito.

def new(name: str) -> StepType:
return StepType(uuid(), name, [], [], [], {}, None)
def new(name: str) -> ASTStepType:
return ASTStepType(uuid(), name, [], [], [], {})

def __str__(self):
signals_str = (
Expand All @@ -209,8 +188,7 @@ def __str__(self):
)
constraints_str = (
"\n\t\t\t\t"
+ ",\n\t\t\t\t".join(str(constraint)
for constraint in self.constraints)
+ ",\n\t\t\t\t".join(str(constraint) for constraint in self.constraints)
+ "\n\t\t\t"
if self.constraints
else ""
Expand All @@ -224,15 +202,14 @@ def __str__(self):
)
annotations_str = (
"\n\t\t\t\t"
+ ",\n\t\t\t\t".join(f"{k}: {v}" for k,
v in self.annotations.items())
+ ",\n\t\t\t\t".join(f"{k}: {v}" for k, v in self.annotations.items())
+ "\n\t\t\t"
if self.annotations
else ""
)

return (
f"StepType(\n"
f"ASTStepType(\n"
f"\t\t\tid={self.id},\n"
f"\t\t\tname='{self.name}',\n"
f"\t\t\tsignals=[{signals_str}],\n"
Expand All @@ -254,33 +231,29 @@ def __json__(self):
"annotations": self.annotations,
}

def add_signal(self: StepType, name: str) -> InternalSignal:
def add_signal(self: ASTStepType, name: str) -> InternalSignal:
signal = InternalSignal(name)
self.signals.append(signal)
self.annotations[signal.id] = name
return signal

def add_constr(self: StepType, annotation: str, expr: Expr):
def add_constr(self: ASTStepType, annotation: str, expr: Expr):
condition = ASTConstraint(annotation, expr)
self.constraints.append(condition)

def add_transition(self: StepType, annotation: str, expr: Expr):
def add_transition(self: ASTStepType, annotation: str, expr: Expr):
condition = TransitionConstraint(annotation, expr)
self.transition_constraints.append(condition)

# Args are Any.
def set_wg(self, wg_def: Callable[[StepInstance, Any], None]):
self.wg = wg_def

def __eq__(self: StepType, other: StepType) -> bool:
if isinstance(self, StepType) and isinstance(other, StepType):
def __eq__(self: ASTStepType, other: ASTStepType) -> bool:
if isinstance(self, ASTStepType) and isinstance(other, ASTStepType):
return self.id == other.id
return False

def __req__(other: StepType, self: StepType) -> bool:
return StepType.__eq__(self, other)
def __req__(other: ASTStepType, self: ASTStepType) -> bool:
return ASTStepType.__eq__(self, other)

def __hash__(self: StepType):
def __hash__(self: ASTStepType):
return hash(self.id)


Expand Down
Loading

0 comments on commit 905adc4

Please sign in to comment.