Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
Add factorial circuit (#149)
Browse files Browse the repository at this point in the history
Adds an example of a circuit implementation that computes the factorial
any given number, to a certain maximum.

Using Python frontend to verify the approach, will continue to port the
example to Rust.
  • Loading branch information
sraver authored Oct 17, 2023
1 parent c527f5e commit 7d7c7e6
Show file tree
Hide file tree
Showing 2 changed files with 301 additions and 0 deletions.
121 changes: 121 additions & 0 deletions examples/factorial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from chiquito.dsl import Circuit, StepType
from chiquito.cb import eq
from chiquito.util import F
from chiquito.chiquito_ast import Last

MAX_FACTORIAL = 10

"""
| step_type | i | x |
----------------------------------
| first_step | 0 | 1 |
| operation_step | 1 | 1 |
| operation_step | 2 | 2 |
| operation_step | 3 | 6 |
| result_step | 4 | 24 |
| result_step | 4 | 24 |
| result_step | 4 | 24 |
...
"""


class FirstStep(StepType):
def setup(self):
# constrain `i` to zero
self.constr(eq(self.circuit.i, 0))
# constrain `x` to one
self.constr(eq(self.circuit.x, 1))
# constrain the next `x` to be equal to the current `x`
self.transition(eq(self.circuit.x, self.circuit.x.next()))

def wg(self):
self.assign(self.circuit.i, F(0))
self.assign(self.circuit.x, F(1))


class OperationStep(StepType):
def setup(self):
# constrain i.prev() + 1 == i
self.transition(eq(self.circuit.i.rot(-1) + 1, self.circuit.i))
# constrain i + 1 == i.next()
self.transition(eq(self.circuit.i + 1, self.circuit.i.next()))
# constrain the next `x` to be the product of the current `x` and the next `i`
self.transition(eq(self.circuit.x * (self.circuit.i + 1), self.circuit.x.next()))

def wg(self, i_value, x_value):
self.assign(self.circuit.i, F(i_value))
self.assign(self.circuit.x, F(x_value))


class ResultStep(StepType):
def setup(self):
# constrain `x` to not change
self.transition(eq(self.circuit.x, self.circuit.x.next()))
# constrain `i` to not change
self.transition(eq(self.circuit.i, self.circuit.i.next()))

def wg(self, i_value, x_value):
self.assign(self.circuit.i, F(i_value))
self.assign(self.circuit.x, F(x_value))


class Factorial(Circuit):
def setup(self):
# `i` holds the current iteration number
self.i = self.shared("i")
# `x` holds the current total result
self.x = self.forward("x")

self.first_step = self.step_type(FirstStep(self, "first_step"))
self.operation_step = self.step_type(OperationStep(self, "operation_step"))
self.result_step = self.step_type(ResultStep(self, "result_step"))

self.pragma_num_steps(MAX_FACTORIAL + 1)
self.pragma_first_step(self.first_step)
self.pragma_last_step(self.result_step)

self.expose(self.x, Last())
self.expose(self.i, Last())

def trace(self, n):
self.add(self.first_step)
current_result = 1

for i in range(1, n + 1):
current_result *= i
if i == n:
# we found the result
self.add(self.result_step, i, current_result)
else:
# more operations need to be done
self.add(self.operation_step, i, current_result)

while self.needs_padding():
# if padding is needed, we propagate final values
self.add(self.result_step, n, current_result)


class Examples:
def test_zero(self):
factorial = Factorial()
factorial_witness = factorial.gen_witness(0)
last_assignments = list(factorial_witness.step_instances[10].assignments.values())
assert last_assignments[0] == 0 # i
assert last_assignments[1] == 1 # x
factorial.halo2_mock_prover(factorial_witness)

def test_basic(self):
factorial = Factorial()
factorial_witness = factorial.gen_witness(7)
last_assignments = list(factorial_witness.step_instances[10].assignments.values())
assert last_assignments[0] == 7 # i
assert last_assignments[1] == 5040 # x
factorial.halo2_mock_prover(factorial_witness)


if __name__ == "__main__":
x = Examples()
for method in [
method for method in dir(x) if callable(getattr(x, method)) if not method.startswith('_')
]:
getattr(x, method)()
180 changes: 180 additions & 0 deletions examples/factorial.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
use std::hash::Hash;

use chiquito::{
field::Field,
frontend::dsl::circuit, // main function for constructing an AST circuit
plonkish::backend::halo2::{chiquito2Halo2, ChiquitoHalo2Circuit}, /* compiles to
* Chiquito Halo2
* backend,
* which can be
* integrated into
* Halo2
* circuit */
plonkish::compiler::{
cell_manager::SingleRowCellManager, // input for constructing the compiler
compile, // input for constructing the compiler
config,
step_selector::SimpleStepSelectorBuilder,
},
plonkish::ir::{assignments::AssignmentGenerator, Circuit}, // compiled circuit type
poly::ToField,
};
use halo2_proofs::{dev::MockProver, halo2curves::bn256::Fr};

const MAX_FACTORIAL: usize = 10;

fn generate<F: Field + From<u64> + Hash>() -> (Circuit<F>, Option<AssignmentGenerator<F, u32>>) {
//
// table for the circuit:
// | step_type | i | x |
// ----------------------------------
// | first_step | 0 | 1 |
// | operation_step | 1 | 1 |
// | operation_step | 2 | 2 |
// | operation_step | 3 | 6 |
// | result_step | 4 | 24 |
// | result_step | 4 | 24 |
// | result_step | 4 | 24 |
// ...

use chiquito::frontend::dsl::cb::*; // functions for constraint building

let factorial_circuit = circuit::<F, u32, _>("factorial", |ctx| {
let i = ctx.shared("i");
let x = ctx.forward("x");

// first step will make sure the circuit is initialized correctly
let first_step = ctx.step_type_def("first_step", |ctx| {
// define the setup
ctx.setup(move |ctx| {
// constrain `i` to zero
ctx.constr(eq(i, 0));
// constrain `x` to one
ctx.constr(eq(x, 1));
// constrain the next `x` to be equal to the current `x`
ctx.transition(eq(x, x.next()));
});
// witness assignment
ctx.wg(move |ctx, ()| {
ctx.assign(i, 0.into());
ctx.assign(x, 1.into());
})
});

// operation step will make sure every state transition is correct
let operation_step = ctx.step_type_def("operation_step", |ctx| {
// define the setup
ctx.setup(move |ctx| {
// constrain i.prev() + 1 == i
ctx.transition(eq(i.rot(-1) + 1, i));
// constrain i + 1 == i.next()
ctx.transition(eq(i + 1, i.next()));
// constrain the next `x` to be the product of the current `x` and the next `i`
ctx.transition(eq(x * i.next(), x.next()));
});
// witness assignment
ctx.wg(move |ctx, (i_value, x_value): (u32, u32)| {
ctx.assign(i, i_value.field());
ctx.assign(x, x_value.field());
})
});

// result step will hold and propagate the value
let result_step = ctx.step_type_def("result_step", |ctx| {
// define the setup
ctx.setup(move |ctx| {
// constrain `i` to not change
ctx.transition(eq(i, i.next()));
// constrain `x` to not change
ctx.transition(eq(x, x.next()));
});
// witness assignment
ctx.wg(move |ctx, (i_value, x_value): (u32, u32)| {
ctx.assign(i, i_value.field());
ctx.assign(x, x_value.field());
})
});

ctx.pragma_first_step(&first_step);
ctx.pragma_last_step(&result_step);
ctx.pragma_num_steps(MAX_FACTORIAL + 1);

ctx.trace(move |ctx, n| {
ctx.add(&first_step, ());

let mut current_result = 1;

for i in 1..n + 1 {
current_result *= i;
if i == n {
// we found the result
ctx.add(&result_step, (i, current_result));
} else {
// more operations need to be done
ctx.add(&operation_step, (i, current_result));
}
}

// if padding is needed, propagate final values
ctx.padding(&result_step, || (n, current_result));
})
});

compile(
config(SingleRowCellManager {}, SimpleStepSelectorBuilder {}),
&factorial_circuit,
)
}

// After compiling Chiquito AST to an IR, it is further parsed by a Chiquito Halo2 backend and
// integrated into a Halo2 circuit, which is done by the boilerplate code below.

// standard main function for a Halo2 circuit
fn main() {
let (chiquito, wit_gen) = generate::<Fr>();
let compiled = chiquito2Halo2(chiquito);
let circuit = ChiquitoHalo2Circuit::new(compiled, wit_gen.map(|g| g.generate(0)));

let prover = MockProver::<Fr>::run(10, &circuit, circuit.instance()).unwrap();

let result = prover.verify_par();

println!("result = {:#?}", result);

if let Err(failures) = &result {
for failure in failures.iter() {
println!("{}", failure);
}
}

// plaf boilerplate
use chiquito::plonkish::backend::plaf::chiquito2Plaf;
use polyexen::plaf::{backends::halo2::PlafH2Circuit};

// get Chiquito ir
let (circuit, wit_gen) = generate::<Fr>();
// get Plaf
let (plaf, plaf_wit_gen) = chiquito2Plaf(circuit, 8, false);
let wit = plaf_wit_gen.generate(wit_gen.map(|v| v.generate(7)));

// debug only: print witness
// println!("{}", polyexen::plaf::WitnessDisplayCSV(&wit));

// get Plaf halo2 circuit from Plaf's halo2 backend
// this is just a proof of concept, because Plaf only has backend for halo2
// this is unnecessary because Chiquito has a halo2 backend already
let plaf_circuit = PlafH2Circuit { plaf, wit };

// same as halo2 boilerplate above
let prover_plaf = MockProver::<Fr>::run(8, &plaf_circuit, Vec::new()).unwrap();

let result_plaf = prover_plaf.verify_par();

println!("result = {:#?}", result_plaf);

if let Err(failures) = &result_plaf {
for failure in failures.iter() {
println!("{}", failure);
}
}
}

0 comments on commit 7d7c7e6

Please sign in to comment.