Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
pass over super witness rather than storing them
Browse files Browse the repository at this point in the history
  • Loading branch information
qwang98 committed Sep 19, 2023
1 parent e891154 commit 9b3186b
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 50 deletions.
10 changes: 6 additions & 4 deletions examples/mimc7.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ def setup(self):
self.pragma_num_steps(ROUNDS)
self.lookup_row = self.fixed("constant row")
self.lookup_c = self.fixed("constant value")
self.lookup_table = self.new_table(table().add(self.lookup_row).add(self.lookup_c))
self.lookup_table = self.new_table(
table().add(self.lookup_row).add(self.lookup_c)
)

def fixed_gen(self):
for i, round_key in enumerate(ROUND_KEYS):
Expand Down Expand Up @@ -161,7 +163,7 @@ def mapping(self, x_in_value, k_value):


mimc7 = Mimc7SuperCircuit()
mimc7_witnesses = mimc7.gen_witness(F(1), F(2))
# for key, value in mimc7_witnesses.items():
mimc7_super_witness = mimc7.gen_witness(F(1), F(2))
# for key, value in mimc7_super_witness.items():
# print(f"{key}: {str(value)}")
mimc7.halo2_mock_prover(mimc7_witnesses)
mimc7.halo2_mock_prover(mimc7_super_witness)
6 changes: 4 additions & 2 deletions src/frontend/python/chiquito/cb.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def apply(self: LookupTable, constraint: ToConstraint) -> LookupTableBuilder:
def when(self: LookupTable, enable: ToConstraint) -> LookupTableBuilder:
assert self.finished_flag == True
return LookupTableBuilder(self).when(enable)

def set_finished_flag(self: LookupTable):
assert self.finished_flag == False
self.finished_flag = True
Expand Down Expand Up @@ -296,7 +296,9 @@ def build(self: LookupTableBuilder) -> Lookup:
lookup.enable(self.enable.annotation, self.enable.expr)

for i in range(len(self.src)):
lookup.add(self.src[i].annotation, self.src[i].expr, self.lookup_table.dest[i])
lookup.add(
self.src[i].annotation, self.src[i].expr, self.lookup_table.dest[i]
)

return lookup

Expand Down
2 changes: 1 addition & 1 deletion src/frontend/python/chiquito/chiquito_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
@dataclass
class ASTSuperCircuit:
sub_circuits: Dict[int, ASTCircuit] = field(default_factory=dict)
witnesses: Dict[int, TraceWitness] = field(default_factory=dict)
super_witness: Dict[int, TraceWitness] = field(default_factory=dict)


@dataclass
Expand Down
22 changes: 12 additions & 10 deletions src/frontend/python/chiquito/dsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def map(self: SuperCircuit, sub_circuit: Circuit, *args: Any) -> TraceWitness:
raise ValueError(
"SuperCircuit: must call sub_circuit() before calling map() on a Circuit."
)
self.ast.witnesses[sub_circuit.rust_id] = witness
self.ast.super_witness[sub_circuit.rust_id] = witness
return witness

# called at the outermost level
Expand All @@ -64,22 +64,22 @@ def gen_witness(self: SuperCircuit, *args: Any) -> Dict[int, TraceWitness]:
self.mode = SuperCircuitMode.Mapping
self.mapping(*args)
self.mode = SuperCircuitMode.NoMode
witnesses: Dict[int, TraceWitness] = self.ast.witnesses
super_witness: Dict[int, TraceWitness] = self.ast.super_witness
del (
self.ast.witnesses
self.ast.super_witness
) # so that we can generate different witness mapping in the next gen_witness() call
return witnesses
return super_witness

def halo2_mock_prover(self: SuperCircuit, witnesses: Dict[int, TraceWitness]):
for rust_id, witness in witnesses.items():
witness_json: str = witness.get_witness_json()
def halo2_mock_prover(self: SuperCircuit, super_witness: Dict[int, TraceWitness]):
for rust_id, witness in super_witness.items():
if rust_id not in self.ast.sub_circuits:
raise ValueError(
f"SuperCircuit.halo2_mock_prover(): TraceWitness with rust_id {rust_id} not found in sub_circuits."
)
rust_chiquito.add_witness_to_rust_id(witness_json, rust_id)
witness_json: str = witness.get_witness_json()
super_witness[rust_id] = witness_json
rust_chiquito.super_circuit_halo2_mock_prover(
list(self.ast.sub_circuits.keys())
list(self.ast.sub_circuits.keys()), super_witness
)


Expand All @@ -91,7 +91,9 @@ class CircuitMode(Enum):


class Circuit:
def __init__(self: Circuit, super_circuit: SuperCircuit=None, imports: Any=None):
def __init__(
self: Circuit, super_circuit: SuperCircuit = None, imports: Any = None
):
self.ast = ASTCircuit()
self.witness = TraceWitness()
self.rust_id = 0
Expand Down
66 changes: 33 additions & 33 deletions src/frontend/python/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use pyo3::{
prelude::*,
types::{PyList, PyLong, PyString},
types::{PyDict, PyList, PyLong, PyString},
};

use crate::{
Expand Down Expand Up @@ -34,7 +34,7 @@ type CircuitMapStore = (
Circuit<Fr, ()>,
ChiquitoHalo2<Fr>,
Option<AssignmentGenerator<Fr, ()>>,
Option<TraceWitness<Fr>>,
// Option<TraceWitness<Fr>>,
);
type CircuitMap = RefCell<HashMap<UUID, CircuitMapStore>>;

Expand All @@ -58,28 +58,14 @@ pub fn chiquito_ast_to_halo2(ast_json: &str) -> UUID {
CIRCUIT_MAP.with(|circuit_map| {
circuit_map
.borrow_mut()
.insert(uuid, (circuit, chiquito_halo2, assignment_generator, None));
.insert(uuid, (circuit, chiquito_halo2, assignment_generator));
});

println!("{:?}", uuid);

uuid
}

/// Parses JSON into `TraceWitness` and insert it into `CIRCUIT_MAP` with `rust_id` as the key.
pub fn chiquito_add_witness_to_rust_id(witness_json: &str, rust_id: UUID) {
let witness: TraceWitness<Fr> =
serde_json::from_str(witness_json).expect("Json deserialization to TraceWitness failed.");

CIRCUIT_MAP.with(|circuit_map| {
let mut circuit_map = circuit_map.borrow_mut();
let circuit_map_store = circuit_map.get_mut(&rust_id).unwrap();
circuit_map_store.3 = Some(witness);
});

println!("Added TraceWitness to rust_id: {:?}", rust_id);
}

fn add_assignment_generator_to_rust_id(
assignment_generator: AssignmentGenerator<Fr, ()>,
rust_id: UUID,
Expand All @@ -97,14 +83,17 @@ fn add_assignment_generator_to_rust_id(
/// sub-circuit. The `ChiquitoHalo2SuperCircuit` object is then passed to `MockProver` for
/// verification. `TraceWitness`, if any, should have been inserted to each rust_id prior to
/// invoking this function.
pub fn chiquito_super_circuit_halo2_mock_prover(rust_ids: Vec<UUID>) {
pub fn chiquito_super_circuit_halo2_mock_prover(
rust_ids: Vec<UUID>,
super_witness: HashMap<UUID, &str>,
) {
let mut super_circuit_ctx = SuperCircuitContext::<Fr, ()>::default();

// super_circuit def
let config = config(SingleRowCellManager {}, SimpleStepSelectorBuilder {});
for rust_id in rust_ids.clone() {
let circuit_map_store = rust_id_to_halo2(rust_id);
let (circuit, _, _, _) = circuit_map_store;
let (circuit, _, _) = circuit_map_store;
let assignment = super_circuit_ctx.sub_circuit_with_ast(config.clone(), circuit);
add_assignment_generator_to_rust_id(assignment, rust_id);
}
Expand All @@ -115,8 +104,11 @@ pub fn chiquito_super_circuit_halo2_mock_prover(rust_ids: Vec<UUID>) {
let mut mapping_ctx = MappingContext::default();
for rust_id in rust_ids {
let circuit_map_store = rust_id_to_halo2(rust_id);
let (_, _, assignment_generator, witness) = circuit_map_store;
if let Some(witness) = witness {
let (_, _, assignment_generator) = circuit_map_store;

if let Some(witness_json) = super_witness.get(&rust_id) {
let witness: TraceWitness<Fr> = serde_json::from_str(witness_json)
.expect("Json deserialization to TraceWitness failed.");
mapping_ctx.map_with_witness(&assignment_generator.unwrap(), witness);
}
}
Expand Down Expand Up @@ -152,7 +144,7 @@ fn rust_id_to_halo2(uuid: UUID) -> CircuitMapStore {
pub fn chiquito_halo2_mock_prover(witness_json: &str, rust_id: UUID) {
let trace_witness: TraceWitness<Fr> =
serde_json::from_str(witness_json).expect("Json deserialization to TraceWitness failed.");
let (_, compiled, assignment_generator, _) = rust_id_to_halo2(rust_id);
let (_, compiled, assignment_generator) = rust_id_to_halo2(rust_id);
let circuit: ChiquitoHalo2Circuit<_> = ChiquitoHalo2Circuit::new(
compiled,
assignment_generator.map(|g| g.generate_with_witness(trace_witness)),
Expand Down Expand Up @@ -1846,15 +1838,7 @@ fn halo2_mock_prover(witness_json: &PyString, rust_id: &PyLong) {
}

#[pyfunction]
fn add_witness_to_rust_id(witness_json: &PyString, rust_id: &PyLong) {
chiquito_add_witness_to_rust_id(
witness_json.to_str().expect("PyString convertion failed."),
rust_id.extract().expect("PyLong convertion failed."),
);
}

#[pyfunction]
fn super_circuit_halo2_mock_prover(rust_ids: &PyList) {
fn super_circuit_halo2_mock_prover(rust_ids: &PyList, super_witness: &PyDict) {
let uuids = rust_ids
.iter()
.map(|rust_id| {
Expand All @@ -1866,7 +1850,24 @@ fn super_circuit_halo2_mock_prover(rust_ids: &PyList) {
})
.collect::<Vec<UUID>>();

chiquito_super_circuit_halo2_mock_prover(uuids)
let super_witness = super_witness
.iter()
.map(|(key, value)| {
(
key.downcast::<PyLong>()
.expect("PyAny downcast failed.")
.extract()
.expect("PyLong convertion failed."),
value
.downcast::<PyString>()
.expect("PyAny downcast failed.")
.to_str()
.expect("PyString convertion failed."),
)
})
.collect::<HashMap<u128, &str>>();

chiquito_super_circuit_halo2_mock_prover(uuids, super_witness)
}

#[pymodule]
Expand All @@ -1875,7 +1876,6 @@ fn rust_chiquito(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(convert_and_print_trace_witness, m)?)?;
m.add_function(wrap_pyfunction!(ast_to_halo2, m)?)?;
m.add_function(wrap_pyfunction!(halo2_mock_prover, m)?)?;
m.add_function(wrap_pyfunction!(add_witness_to_rust_id, m)?)?;
m.add_function(wrap_pyfunction!(super_circuit_halo2_mock_prover, m)?)?;
Ok(())
}

0 comments on commit 9b3186b

Please sign in to comment.