From d5b5bf893aa5cc417e9b81476e38b4c401d3210b Mon Sep 17 00:00:00 2001 From: dante <45801863+alexander-camuto@users.noreply.github.com> Date: Fri, 8 Nov 2024 19:57:48 +0000 Subject: [PATCH] patch recip --- src/circuit/ops/hybrid.rs | 4 ++-- src/circuit/ops/layouts.rs | 35 +++++++++++++++++++++++++++++------ src/circuit/ops/lookup.rs | 4 +--- src/graph/utilities.rs | 2 +- 4 files changed, 33 insertions(+), 12 deletions(-) diff --git a/src/circuit/ops/hybrid.rs b/src/circuit/ops/hybrid.rs index 75f0478b3..e8c1e456f 100644 --- a/src/circuit/ops/hybrid.rs +++ b/src/circuit/ops/hybrid.rs @@ -137,8 +137,8 @@ impl Op for Hybrid HybridOp::Floor { scale, legs } => format!("FLOOR(scale={}, legs={})", scale, legs), HybridOp::Round { scale, legs } => format!("ROUND(scale={}, legs={})", scale, legs), - HybridOp::Max => format!("MAX"), - HybridOp::Min => format!("MIN"), + HybridOp::Max => "MAX".to_string(), + HybridOp::Min => "MIN".to_string(), HybridOp::Recip { input_scale, output_scale, diff --git a/src/circuit/ops/layouts.rs b/src/circuit/ops/layouts.rs index 2461243c9..c9d9c8e94 100644 --- a/src/circuit/ops/layouts.rs +++ b/src/circuit/ops/layouts.rs @@ -71,6 +71,7 @@ fn optimum_convex_function, region: &mut RegionCtx, x: &ValTensor, + ignore_mask: &Option>, f: impl Fn(&BaseConfig, &mut RegionCtx, &ValTensor) -> Result, CircuitError>, ) -> Result<(), CircuitError> { let two = create_constant_tensor(F::from(2), 1); @@ -90,7 +91,11 @@ fn optimum_convex_function( Ok(distance) }; - optimum_convex_function(config, region, &claimed_output, err_func)?; + optimum_convex_function(config, region, &claimed_output, &None, err_func)?; Ok(claimed_output) } @@ -205,9 +210,16 @@ pub(crate) fn recip( let zero_inverse = create_constant_tensor(integer_rep_to_felt(zero_inverse_val), 1); let equal_zero_mask = equals_zero(config, region, &[input.clone()])?; - + let not_equal_zero_mask = not(config, region, &[equal_zero_mask.clone()])?; let equal_inverse_mask = equals(config, region, &[claimed_output.clone(), zero_inverse])?; + let masked_unit_scale = pairwise( + config, + region, + &[unit_scale.clone(), not_equal_zero_mask.clone()], + BaseOp::Mult, + )?; + // assert the two masks are equal enforce_equality( config, @@ -220,11 +232,22 @@ pub(crate) fn recip( x: &ValTensor| -> Result, CircuitError> { let product = pairwise(config, region, &[x.clone(), input.clone()], BaseOp::Mult)?; - let distance = l1_distance(config, region, &[product.clone(), unit_scale.clone()])?; + + let distance = l1_distance( + config, + region, + &[product.clone(), masked_unit_scale.clone()], + )?; Ok(distance) }; - optimum_convex_function(config, region, &claimed_output, err_func)?; + optimum_convex_function( + config, + region, + &claimed_output, + &Some(equal_zero_mask), + err_func, + )?; Ok(claimed_output) } @@ -306,7 +329,7 @@ pub fn sqrt( Ok(distance) }; - optimum_convex_function(config, region, &claimed_output, err_func)?; + optimum_convex_function(config, region, &claimed_output, &None, err_func)?; Ok(claimed_output) } diff --git a/src/circuit/ops/lookup.rs b/src/circuit/ops/lookup.rs index b77882a55..af5d61c99 100644 --- a/src/circuit/ops/lookup.rs +++ b/src/circuit/ops/lookup.rs @@ -198,9 +198,7 @@ impl Op for Lookup /// Returns the scale of the output of the operation. fn out_scale(&self, inputs_scale: Vec) -> Result { - let scale = match self { - _ => inputs_scale[0], - }; + let scale = inputs_scale[0]; Ok(scale) } diff --git a/src/graph/utilities.rs b/src/graph/utilities.rs index 6d505bf30..c9a1bb449 100644 --- a/src/graph/utilities.rs +++ b/src/graph/utilities.rs @@ -764,7 +764,7 @@ pub fn new_op_from_onnx( .collect::>(); if inputs.len() == 2 { - if const_inputs.len() > 0 { + if !const_inputs.is_empty() { let const_idx = const_inputs[0]; let boxed_op = inputs[const_idx].opkind(); let unit = if let Some(c) = extract_const_raw_values(boxed_op) {