Commit

chore: update before break
kcelia committed Nov 27, 2023
1 parent d1265c2 commit f2a080e
Showing 8 changed files with 303 additions and 190 deletions.
31 changes: 24 additions & 7 deletions src/concrete/ml/onnx/convert.py
@@ -2,15 +2,20 @@

import tempfile
from pathlib import Path
from typing import Callable, Tuple, Union
from typing import Callable, Optional, Tuple, Union

import numpy
import onnx
import onnxoptimizer
import torch
from onnx import checker, helper

from .onnx_utils import IMPLEMENTED_ONNX_OPS, execute_onnx_with_numpy, get_op_type
from .onnx_utils import (
IMPLEMENTED_ONNX_OPS,
compute_lsb_to_remove_for_trees,
execute_onnx_with_numpy,
get_op_type,
)

OPSET_VERSION_FOR_ONNX_EXPORT = 14

@@ -158,15 +163,21 @@ def get_equivalent_numpy_forward_from_torch(

def get_equivalent_numpy_forward_from_onnx(
onnx_model: onnx.ModelProto,
check_model: bool = True,
x: Optional[numpy.ndarray] = None,
check_model: Optional[bool] = True,
use_rounding: Optional[bool] = False,
) -> Tuple[Callable[..., Tuple[numpy.ndarray, ...]], onnx.ModelProto]:
"""Get the numpy equivalent forward of the provided ONNX model.
Args:
onnx_model (onnx.ModelProto): the ONNX model for which to get the equivalent numpy
forward.
x (numpy.ndarray): Quantized input used to compute the LSBs to remove when
'use_rounding' is enabled. Defaults to None.
check_model (bool): Set to True to run the ONNX checker on the model.
Defaults to True.
use_rounding (bool): Whether to use the rounding feature.
Defaults to False.
Raises:
ValueError: Raised if there is an unsupported ONNX operator required to convert the torch
@@ -176,9 +187,11 @@ def get_equivalent_numpy_forward_from_onnx(
Callable[..., Tuple[numpy.ndarray, ...]]: The function that will execute
the equivalent numpy function.
"""

lsbs_to_remove: list = [0, 0]

if check_model:
checker.check_model(onnx_model)
checker.check_model(onnx_model)

# Optimize ONNX graph
# List of all currently supported onnx optimizer passes
@@ -200,7 +213,6 @@

# Check supported operators
required_onnx_operators = set(get_op_type(node) for node in equivalent_onnx_model.graph.node)
print(f"{required_onnx_operators=}")
unsupported_operators = required_onnx_operators - IMPLEMENTED_ONNX_OPS
if len(unsupported_operators) > 0:
raise ValueError(
@@ -209,7 +221,12 @@
f"Available ONNX operators: {', '.join(sorted(IMPLEMENTED_ONNX_OPS))}"
)

# Return lambda of numpy equivalent of onnx execution
# Compute the LSBs to remove manually while waiting for the truncate feature
# FIXME: https://github.com/zama-ai/concrete-ml/issues/397
if use_rounding:
assert x is not None, "A quantized and representative inputset is needed to use rounding"
lsbs_to_remove = compute_lsb_to_remove_for_trees(equivalent_onnx_model, x)

return (
lambda *args: execute_onnx_with_numpy(equivalent_onnx_model.graph, *args)
lambda *args: execute_onnx_with_numpy(equivalent_onnx_model.graph, lsbs_to_remove, *args)
), equivalent_onnx_model
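
For context, here is a minimal sketch of how the updated entry point can be called; the model path, input shape, and calibration data are hypothetical, and only the function name and keyword arguments come from this diff:

import numpy
import onnx

from concrete.ml.onnx.convert import get_equivalent_numpy_forward_from_onnx

# Hypothetical tree-ensemble model and quantized calibration inputs
onnx_model = onnx.load("tree_model.onnx")
q_x = numpy.random.randint(0, 16, size=(100, 10))

# With use_rounding=True, a representative quantized input set must be passed
# through `x` so that the LSBs to remove can be estimated for both stages.
numpy_forward, optimized_model = get_equivalent_numpy_forward_from_onnx(
    onnx_model,
    x=q_x,
    check_model=True,
    use_rounding=True,
)

outputs = numpy_forward(q_x)  # tuple of numpy arrays
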
26 changes: 25 additions & 1 deletion src/concrete/ml/onnx/onnx_impl_utils.py
@@ -1,10 +1,11 @@
"""Utility functions for onnx operator implementations."""

from typing import Tuple, Union
from typing import Callable, Tuple, Union

import numpy
from concrete.fhe import conv as cnp_conv
from concrete.fhe import ones as cnp_ones
from concrete.fhe import round_bit_pattern

from ..common.debugging import assert_true

@@ -223,3 +224,26 @@ def onnx_avgpool_compute_norm_const(
norm_const = numpy.prod(kernel_shape)

return norm_const


def rounded_comparison(
x: numpy.ndarray, y: numpy.ndarray, lsbs_to_remove: int, operation: Callable
) -> Tuple[bool]:
"""Comparison operation using AutoRounder.
Args:
x (numpy.ndarray): Input tensor
y (numpy.ndarray): Input tensor
lsbs_to_remove (int): The number of least significant bits to remove
operation (Callable): Comparison operation
Returns:
Tuple[bool]: Whether x and y satisfy the comparison operator.
"""

assert isinstance(lsbs_to_remove, int)

half = 1 << (lsbs_to_remove - 1)
rounded_subtraction = round_bit_pattern((x - y) - half, lsbs_to_remove=lsbs_to_remove)

return (operation(rounded_subtraction),)
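
To make the rounding trick concrete, here is a small clear-value sketch: `round_to_multiple` is a hypothetical stand-in that only mimics `concrete.fhe.round_bit_pattern` on plain integers, and `rounded_less` mirrors `rounded_comparison` specialized to a less-than test:

import numpy

def round_to_multiple(values, lsbs_to_remove):
    # Stand-in for round_bit_pattern: round to the nearest multiple of 2**lsbs_to_remove
    half = 1 << (lsbs_to_remove - 1)
    return ((values + half) >> lsbs_to_remove) << lsbs_to_remove

def rounded_less(x, y, lsbs_to_remove):
    # Same structure as rounded_comparison: pre-subtracting `half` makes the
    # rounding step behave like flooring (x - y) to a multiple of 2**lsbs_to_remove
    half = 1 << (lsbs_to_remove - 1)
    rounded_subtraction = round_to_multiple((x - y) - half, lsbs_to_remove)
    return rounded_subtraction < 0

x = numpy.array([3, 17, 64])
y = numpy.array([5, 16, 70])
print(rounded_less(x, y, lsbs_to_remove=2))  # [ True False  True]
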
106 changes: 104 additions & 2 deletions src/concrete/ml/onnx/onnx_utils.py
@@ -214,8 +214,8 @@
# Original file:
# https://github.com/google/jax/blob/f6d329b2d9b5f83c6a59e5739aa1ca8d4d1ffa1c/examples/onnx2xla.py


from typing import Any, Callable, Dict, Tuple
import math
from typing import Any, Callable, Dict, List, Tuple

import numpy
import onnx
@@ -413,6 +413,9 @@
# All numpy operators used for tree-based models
ONNX_OPS_TO_NUMPY_IMPL_BOOL = {**ONNX_OPS_TO_NUMPY_IMPL, **ONNX_COMPARISON_OPS_TO_NUMPY_IMPL_BOOL}

# ONNX comparison operators used by tree-based models that support auto rounding
SUPPORTED_ROUNDED_OPERATIONS = ["Less", "LessOrEqual", "Greater", "GreaterOrEqual", "Equal"]


IMPLEMENTED_ONNX_OPS = set(ONNX_OPS_TO_NUMPY_IMPL.keys())

@@ -443,12 +446,14 @@ def get_op_type(node):

def execute_onnx_with_numpy(
graph: onnx.GraphProto,
lsbs_to_remove: List,
*inputs: numpy.ndarray,
) -> Tuple[numpy.ndarray, ...]:
"""Execute the provided ONNX graph on the given inputs.
Args:
graph (onnx.GraphProto): The ONNX graph to execute.
lsbs_to_remove (List): The number of least significant bits to remove in each stage.
*inputs: The inputs of the graph.
Returns:
@@ -461,9 +466,16 @@ def execute_onnx_with_numpy(
for initializer in graph.initializer
},
)

for node in graph.node:
curr_inputs = (node_results[input_name] for input_name in node.input)
attributes = {attribute.name: get_attribute(attribute) for attribute in node.attribute}

if node.op_type in SUPPORTED_ROUNDED_OPERATIONS:
attributes["lsbs_to_remove"] = (
lsbs_to_remove[0] if node.op_type != "Equal" else lsbs_to_remove[1]
)

outputs = ONNX_OPS_TO_NUMPY_IMPL_BOOL[node.op_type](*curr_inputs, **attributes)

node_results.update(zip(node.output, outputs))
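
For readability, the dispatch above can be summarized by a tiny standalone helper (the name `lsbs_for_node` is hypothetical; only the op types and the two-element list come from this diff):

def lsbs_for_node(op_type, lsbs_to_remove):
    # Stage-1 threshold comparisons use lsbs_to_remove[0];
    # the stage-2 Equal nodes use lsbs_to_remove[1].
    return lsbs_to_remove[1] if op_type == "Equal" else lsbs_to_remove[0]

assert lsbs_for_node("Less", [4, 2]) == 4
assert lsbs_for_node("Equal", [4, 2]) == 2
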
@@ -495,3 +507,93 @@ def remove_initializer_from_input(model: onnx.ModelProto): # pragma: no cover
inputs.remove(name_to_input[initializer.name])

return model


# Remove this function once the truncate feature is released
# FIXME: https://github.com/zama-ai/concrete-ml/issues/397
def compute_lsb_to_remove_for_trees(
onnx_model: onnx.ModelProto, quant_x: numpy.ndarray
) -> List[int]:
"""Compute the LSBs to remove for trees in the first and second stages.
Args:
onnx_model (onnx.ModelProto): The ONNX model of the tree ensemble
quant_x (numpy.ndarray): The quantized inputs
Returns:
List[int]: The number of LSBs to remove for stage 1 and stage 2
"""

def get_bitwidth(array: numpy.ndarray) -> int:
"""Compute the bitwidth required to represent the largest value in `array`.
Args:
array (numpy.ndarray): The array for which the bitwidth needs to be checked.
Returns:
int: The number of bits required to represent the array.
"""

max_val = numpy.max(numpy.abs(array))
bitwidth = math.ceil(math.log2(max_val + 1)) + 1
return bitwidth

def update_lsbs_if_overflow_detected(array: numpy.ndarray, initial_bitwidth: int) -> int:
"""Update the number of LSBs to remove based on overflow detection.
Args:
array (numpy.ndarray): The array for which the bitwidth needs to be checked.
initial_bitwidth (int): The target bitwidth that should not be exceeded.
Returns:
int: The updated number of LSBs to remove.
"""

lsbs_to_remove = initial_bitwidth

if lsbs_to_remove > 0:
half = 1 << (lsbs_to_remove - 1)
if get_bitwidth(array - half) <= initial_bitwidth:
lsbs_to_remove -= 1

return lsbs_to_remove

quant_params = {
onnx_init.name: numpy_helper.to_array(onnx_init)
for onnx_init in onnx_model.graph.initializer
if "weight" in onnx_init.name or "bias" in onnx_init.name
}

key_mat_1 = [key for key in quant_params.keys() if "_1" in key and "weight" in key][0]
key_bias_1 = [key for key in quant_params.keys() if "_1" in key and "bias" in key][0]

key_mat_2 = [key for key in quant_params.keys() if "_2" in key and "weight" in key][0]
key_bias_2 = [key for key in quant_params.keys() if "_2" in key and "bias" in key][0]

# shape: (nodes, features) or (trees * nodes, features)
mat_1 = quant_params[key_mat_1]
# shape: (nodes, 1) or (trees * nodes, 1)
bias_1 = quant_params[key_bias_1]

# shape: (trees, leaves, nodes)
mat_2 = quant_params[key_mat_2]
# shape: (leaves, 1) or (trees * leaves, 1)
bias_2 = quant_params[key_bias_2]

feature = mat_1.shape[1]
noeuds = mat_2.shape[2]
leaves = mat_2.shape[1]

mat_1 = mat_1.reshape(-1, noeuds, feature)
bias_1 = bias_1.reshape(-1, 1, noeuds)
bias_2 = bias_2.reshape(-1, 1, leaves)

stage_1 = (quant_x @ mat_1.transpose(0, 2, 1)) + bias_1
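# Note: numpy.random.randint(1, 2, ...) always draws 1, so matrix_q is effectively a matrix of ones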
matrix_q = numpy.random.randint(1, 2, size=(stage_1.shape))

stage_2 = ((matrix_q @ mat_2.transpose(0, 2, 1)) + bias_2).sum(axis=0)

lsbs_to_remove_1 = update_lsbs_if_overflow_detected(stage_1, get_bitwidth(stage_1))
lsbs_to_remove_2 = update_lsbs_if_overflow_detected(stage_2, get_bitwidth(stage_2))

return [lsbs_to_remove_1, lsbs_to_remove_2]
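
To illustrate the bitwidth arithmetic used above, here is a standalone sketch that applies the same formulas as get_bitwidth and update_lsbs_if_overflow_detected to made-up stage-1 values:

import math

import numpy

def bitwidth(array):
    # Signed bitwidth for the largest magnitude: ceil(log2(max + 1)) plus one sign bit
    max_val = numpy.max(numpy.abs(array))
    return math.ceil(math.log2(max_val + 1)) + 1

stage_1 = numpy.array([-100, 42, 97])  # made-up accumulator values
initial = bitwidth(stage_1)            # ceil(log2(101)) + 1 = 8
half = 1 << (initial - 1)              # 128
# If the recentred values still fit in `initial` bits, one fewer LSB is removed;
# here bitwidth(stage_1 - half) == 9 > 8, so lsbs_to_remove stays at 8.
print(initial, bitwidth(stage_1 - half))  # 8 9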
