Skip to content

Commit

Permalink
patch recip
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-camuto committed Nov 8, 2024
1 parent 7e3c3ff commit d5b5bf8
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 12 deletions.
4 changes: 2 additions & 2 deletions src/circuit/ops/hybrid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
HybridOp::Floor { scale, legs } => format!("FLOOR(scale={}, legs={})", scale, legs),
HybridOp::Round { scale, legs } => format!("ROUND(scale={}, legs={})", scale, legs),

HybridOp::Max => format!("MAX"),
HybridOp::Min => format!("MIN"),
HybridOp::Max => "MAX".to_string(),
HybridOp::Min => "MIN".to_string(),
HybridOp::Recip {
input_scale,
output_scale,
Expand Down
35 changes: 29 additions & 6 deletions src/circuit/ops/layouts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ fn optimum_convex_function<F: PrimeField + TensorType + PartialOrd + std::hash::
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
x: &ValTensor<F>,
ignore_mask: &Option<ValTensor<F>>,
f: impl Fn(&BaseConfig<F>, &mut RegionCtx<F>, &ValTensor<F>) -> Result<ValTensor<F>, CircuitError>,
) -> Result<(), CircuitError> {
let two = create_constant_tensor(F::from(2), 1);
Expand All @@ -90,7 +91,11 @@ fn optimum_convex_function<F: PrimeField + TensorType + PartialOrd + std::hash::
let f_x_is_opt_rhs = less(config, region, &[f_x.clone(), f_x_plus_2])?;
let f_x_is_opt_lhs = less(config, region, &[f_x.clone(), f_x_minus_2])?;

let is_opt = and(config, region, &[f_x_is_opt_lhs, f_x_is_opt_rhs])?;
let mut is_opt = and(config, region, &[f_x_is_opt_lhs, f_x_is_opt_rhs])?;

if let Some(ignore_mask) = ignore_mask {
is_opt = or(config, region, &[is_opt.clone(), ignore_mask.clone()])?;
}

let mut comparison_unit = create_constant_tensor(integer_rep_to_felt(1), is_opt.len());
comparison_unit.reshape(is_opt.dims())?;
Expand Down Expand Up @@ -155,7 +160,7 @@ pub(crate) fn div<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
Ok(distance)
};

optimum_convex_function(config, region, &claimed_output, err_func)?;
optimum_convex_function(config, region, &claimed_output, &None, err_func)?;

Ok(claimed_output)
}
Expand Down Expand Up @@ -205,9 +210,16 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
let zero_inverse = create_constant_tensor(integer_rep_to_felt(zero_inverse_val), 1);

let equal_zero_mask = equals_zero(config, region, &[input.clone()])?;

let not_equal_zero_mask = not(config, region, &[equal_zero_mask.clone()])?;
let equal_inverse_mask = equals(config, region, &[claimed_output.clone(), zero_inverse])?;

let masked_unit_scale = pairwise(
config,
region,
&[unit_scale.clone(), not_equal_zero_mask.clone()],
BaseOp::Mult,
)?;

// assert the two masks are equal
enforce_equality(
config,
Expand All @@ -220,11 +232,22 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
x: &ValTensor<F>|
-> Result<ValTensor<F>, CircuitError> {
let product = pairwise(config, region, &[x.clone(), input.clone()], BaseOp::Mult)?;
let distance = l1_distance(config, region, &[product.clone(), unit_scale.clone()])?;

let distance = l1_distance(
config,
region,
&[product.clone(), masked_unit_scale.clone()],
)?;
Ok(distance)
};

optimum_convex_function(config, region, &claimed_output, err_func)?;
optimum_convex_function(
config,
region,
&claimed_output,
&Some(equal_zero_mask),
err_func,
)?;

Ok(claimed_output)
}
Expand Down Expand Up @@ -306,7 +329,7 @@ pub fn sqrt<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
Ok(distance)
};

optimum_convex_function(config, region, &claimed_output, err_func)?;
optimum_convex_function(config, region, &claimed_output, &None, err_func)?;

Ok(claimed_output)
}
Expand Down
4 changes: 1 addition & 3 deletions src/circuit/ops/lookup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup

/// Returns the scale of the output of the operation.
fn out_scale(&self, inputs_scale: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
let scale = match self {
_ => inputs_scale[0],
};
let scale = inputs_scale[0];
Ok(scale)
}

Expand Down
2 changes: 1 addition & 1 deletion src/graph/utilities.rs
Original file line number Diff line number Diff line change
Expand Up @@ -764,7 +764,7 @@ pub fn new_op_from_onnx(
.collect::<Vec<_>>();

if inputs.len() == 2 {
if const_inputs.len() > 0 {
if !const_inputs.is_empty() {
let const_idx = const_inputs[0];
let boxed_op = inputs[const_idx].opkind();
let unit = if let Some(c) = extract_const_raw_values(boxed_op) {
Expand Down

0 comments on commit d5b5bf8

Please sign in to comment.