chore: allow users to pass schema in encrypted data-frames (#676)
RomanBredehoft authored Jun 4, 2024
1 parent f937c9f commit ccd6641
Showing 3 changed files with 396 additions and 36 deletions.
195 changes: 167 additions & 28 deletions src/concrete/ml/pandas/_processing.py
@@ -2,53 +2,70 @@

import copy
from collections import defaultdict
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple

import numpy
import pandas

from concrete.ml.pandas._development import get_min_max_allowed
from concrete.ml.quantization.quantizers import STABILITY_CONST

SCHEMA_FLOAT_KEYS = ["min", "max"]

def compute_scale_zero_point(column: pandas.Series, q_min: int, q_max: int) -> Tuple[float, float]:

def is_str_or_none(column: pandas.Series) -> bool:
"""Determine if the data-frames only contains string and None values or not.
Args:
column (pandas.Series): The data-frame to consider.
Returns:
bool: If the data-frames only contains string and None values or not.
"""
return column.apply(lambda x: isinstance(x, str) or not pandas.notna(x)).all()
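
As a quick illustration (made-up inputs, with is_str_or_none in scope):

import numpy
import pandas

# A column mixing strings, None and NaN qualifies; any other value type does not
assert is_str_or_none(pandas.Series(["a", None, numpy.nan]))
assert not is_str_or_none(pandas.Series(["a", 1]))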


def compute_scale_zero_point(
f_min: float, f_max: float, q_min: int, q_max: int
) -> Tuple[float, float]:
"""Compute the scale and zero point to use for quantizing / de-quantizing the given column.
Note that the scale and zero point are computed so that values are quantized uniformly from
range [column.min(), column.max()] (float) to range [q_min, q_max] (int).
range [f_min, f_max] (float) to range [q_min, q_max] (int).
Args:
column (pandas.Series): The column to consider.
q_min (int): The minimum quantized value to consider.
q_max (int): The maximum quantized value to consider.
f_min (float): The minimum float value observed.
f_max (float): The maximum float value observed.
q_min (int): The minimum quantized value to target.
q_max (int): The maximum quantized value to target.
Returns:
Tuple[float, float]: The scale and zero-point.
"""
values_min, values_max = column.min(), column.max()

# If there is a single float value in the column, the scale and zero-point need to be handled
# differently
if values_max - values_min < STABILITY_CONST:
if f_max - f_min < STABILITY_CONST:

# If this single float value is 0, make sure it is not quantized to 0
if numpy.abs(values_max) < STABILITY_CONST:
if numpy.abs(f_max) < STABILITY_CONST:
scale = 1.0
zero_point = -q_min

# Else, quantize it to 1
else:
scale = 1 / values_max
scale = 1 / f_max
zero_point = 0

else:
scale = (q_max - q_min) / (values_max - values_min)
scale = (q_max - q_min) / (f_max - f_min)

# Zero-point should be rounded once NaN values are no longer represented by 0
# The issue is that we currently need to prevent quantized values from reaching 0, but
# rounding both here and in the 'quant' method can make this happen.
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4342
zero_point = values_min * scale - q_min
# Disable mypy until it is fixed
zero_point = f_min * scale - q_min # type: ignore[assignment]

return scale, zero_point
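
For instance, a minimal sanity check of the affine mapping above (illustrative values; the rounding step mirrors how the 'quant' helper is assumed to apply the scale and zero-point):

f_min, f_max, q_min, q_max = -1.5, 3.5, 1, 15

scale = (q_max - q_min) / (f_max - f_min)  # 14 / 5 = 2.8
zero_point = f_min * scale - q_min         # -1.5 * 2.8 - 1 = -5.2

# Both float bounds land exactly on the quantized bounds
assert round(f_min * scale - zero_point) == q_min  # 1
assert round(f_max * scale - zero_point) == q_max  # 15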

@@ -86,9 +103,49 @@ def dequant(
return x.astype(dtype)


# Provide a way for users to pass string mappings
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4342
def pre_process_dtypes(pandas_dataframe: pandas.DataFrame) -> Tuple[pandas.DataFrame, Dict]:
def check_schema_format(pandas_dataframe: pandas.DataFrame, schema: Optional[Dict] = None) -> None:
"""Check that the given schema has a proper expected format.
Args:
pandas_dataframe (pandas.DataFrame): The data-frame associated to the given schema.
schema (Optional[Dict]): The schema to check, which can be None. Default to None.
Raises:
ValueError: If the given schema is not a dict.
ValueError: If the given schema contains column names that do not appear in the data-frame.
ValueError: If one of the columns' mapping is not a dict.
"""
if schema is None:
return

if not isinstance(schema, dict):
raise ValueError(
"When set, parameter 'schema' must be a dictionary that associates some of the "
f"data-frame's column names to their value mappings. Got {type(schema)=}"
)

column_names = list(pandas_dataframe.columns)

for column_name, column_mapping in schema.items():
if column_name not in column_names:
# TODO: Is this check actually relevant? Couldn't the schema provide more columns than the
# ones found in the data-frame?
raise ValueError(
f"Column name '{column_name}' found in the given schema cannot be found in the "
f"input data-frame. Expected one of {column_names}"
)

if not isinstance(column_mapping, dict):
raise ValueError(
f"Mapping for column '{column_name}' is not a dictionary. Got "
f"{type(column_mapping)=}"
)
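
For reference, a schema that passes this check could look as follows (made-up data; float columns map to {"min", "max"} bounds and string columns map each value to an integer, as enforced further below):

import pandas

df = pandas.DataFrame({"age": [20.5, 33.0], "job": ["doctor", "artist"]})

schema = {
    "age": {"min": 0.0, "max": 120.0},   # float column: quantization bounds
    "job": {"doctor": 1, "artist": 2},   # string column: explicit integer codes
}
check_schema_format(df, schema)  # passes silently

# A non-dict schema is rejected
try:
    check_schema_format(df, [("age", 0.0)])
except ValueError:
    pass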


# pylint: disable=too-many-branches, too-many-statements
def pre_process_dtypes(
pandas_dataframe: pandas.DataFrame, schema: Optional[Dict] = None
) -> Tuple[pandas.DataFrame, Dict]:
"""Pre-process the Pandas data-frame and check that input dtypes and ranges are supported.
Currently, three input dtypes are supported: integers (within a specific range), floating
@@ -97,6 +154,7 @@ def pre_process_dtypes(pandas_dataframe: pandas.DataFrame) -> Tuple[pandas.DataF
Args:
pandas_dataframe (pandas.DataFrame): The Pandas data-frame to pre-process.
schema (Optional[Dict]): The input schema to consider. Defaults to None.
Raises:
ValueError: If the values of a column with an integer dtype are out of bounds.
@@ -118,6 +176,7 @@ def pre_process_dtypes(pandas_dataframe: pandas.DataFrame) -> Tuple[pandas.DataF

# Avoid sending column names to server, instead use hashes
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4342
# pylint: disable=too-many-nested-blocks
for column_name in pandas_dataframe.columns:
column = pandas_dataframe[column_name]
column_dtype = column.dtype
@@ -127,9 +186,13 @@ def pre_process_dtypes(pandas_dataframe: pandas.DataFrame) -> Tuple[pandas.DataF

# If the column contains integers, make sure they are not out of bounds
if numpy.issubdtype(column_dtype, numpy.integer):
out_of_bounds = (column < q_min).any() or (column > q_max).any()
if schema is not None and column_name in schema:
raise ValueError(
f"Column '{column_name}' contains integer values and therefore does not "
"require any mappings. Please remove it"
)

if out_of_bounds:
if column.min() < q_min or column.max() > q_max:
raise ValueError(
f"Column '{column_name}' (dtype={column_dtype}) contains values that are out "
f"of bounds. Expected values to be in interval [min={q_min}, max={q_max}], but "
@@ -138,7 +201,33 @@ def pre_process_dtypes(pandas_dataframe: pandas.DataF

# If the column contains floats, quantize the values
elif numpy.issubdtype(column_dtype, numpy.floating):
scale, zero_point = compute_scale_zero_point(column, q_min, q_max)
if schema is not None and column_name in schema:
float_min_max = schema[column_name]

if sorted(float_min_max.keys()) != sorted(SCHEMA_FLOAT_KEYS):
raise ValueError(
f"Column '{column_name}' contains float values but the associated mapping "
f"does not contain proper keys. Expected {sorted(SCHEMA_FLOAT_KEYS)}, but "
f"got {sorted(float_min_max.keys())}"
)

f_min, f_max = float_min_max["min"], float_min_max["max"]

if column.min() < f_min or column.max() > f_max:
raise ValueError(
f"Column '{column_name}' (dtype={column_dtype}) contains values that are "
f"out of bounds. Expected values to be in interval [min={f_min}, "
f"max={f_max}], as determined by the given schema, but found "
f"[min={column.min()}, max={column.max()}]."
)

else:
f_min, f_max = column.min(), column.max()

scale, zero_point = compute_scale_zero_point(f_min, f_max, q_min, q_max)

q_column = quant(column, scale, zero_point)

@@ -150,14 +239,58 @@ def pre_process_dtypes(pandas_dataframe: pandas.DataF

# If the column contains objects, make sure it is only made of strings or NaN values
elif column_dtype == "object":
is_str = column.apply(lambda x: isinstance(x, str) or not pandas.notna(x)).all()

if is_str:

# Build a mapping between the unique string values and integers
str_to_int = {
str_value: i + 1 for i, str_value in enumerate(column.dropna().unique())
}
if is_str_or_none(column):
if schema is not None and column_name in schema:
str_to_int = schema[column_name]

column_values = set(column.values) - set([None, numpy.NaN])
string_mapping_keys = set(str_to_int.keys())

# Allow custom mapping for NaN values once they are not represented by 0 anymore
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4342
if numpy.NaN in string_mapping_keys:
raise ValueError(
f"String mapping for column '{column_name}' contains numpy.NaN as a "
"key, which is currently forbidden"
)

forgotten_string_values = column_values - string_mapping_keys

if forgotten_string_values:
raise ValueError(
f"String mapping keys for column '{column_name}' are not considering "
"all values from the data-frame. Missing values: "
f"{sorted(forgotten_string_values)}"
)

for string_mapping_key, string_mapping_value in str_to_int.items():
if not isinstance(string_mapping_value, int):
raise ValueError(
f"String mapping values for column '{column_name}' must be "
f"integers. Got {type(string_mapping_value)} for key "
f"{string_mapping_key}"
)

if string_mapping_value < q_min or string_mapping_value > q_max:
raise ValueError(
f"String mapping values for column '{column_name}' are out of "
f"bounds. Expected values to be in interval [min={q_min}, "
f"max={q_max}] but got {string_mapping_value} for key "
f"{string_mapping_key}"
)

if len(str_to_int.values()) != len(set(str_to_int.values())):
raise ValueError(
f"String mapping values for column '{column_name}' must be unique. Got "
f"{str_to_int.values()}"
)

else:

# Build a mapping between the unique string values and integers
str_to_int = {
str_value: i + 1 for i, str_value in enumerate(column.dropna().unique())
}

# Make sure the number of unique values does not go over the maximum integer value
# allowed in an encrypted data-frame
@@ -189,14 +322,20 @@ def pre_process_dtypes(pandas_dataframe: pandas.DataF
"supported."
)

# TODO: Should all non-integer columns be considered by the schema when it is not None?
# Currently, mappings are computed automatically if the schema is not set

return pandas_dataframe, dtype_mappings
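
A short sketch of this function with an explicit schema (column names and values are made up); passing a schema fixes the quantization range and string codes ahead of time instead of inferring them from the data:

import pandas

df = pandas.DataFrame({"score": [0.1, 0.9], "grade": ["A", "B"]})

schema = {
    "score": {"min": 0.0, "max": 1.0},  # must enclose the column's values
    "grade": {"A": 1, "B": 2},          # must cover every string value
}

q_df, dtype_mappings = pre_process_dtypes(df, schema=schema)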


def pre_process_from_pandas(pandas_dataframe: pandas.DataFrame) -> Tuple[numpy.ndarray, Dict]:
def pre_process_from_pandas(
pandas_dataframe: pandas.DataFrame, schema: Optional[Dict] = None
) -> Tuple[numpy.ndarray, Dict]:
"""Pre-process the Pandas data-frame.
Args:
pandas_dataframe (pandas.DataFrame): The Pandas data-frame to pre-process.
schema (Optional[Dict]): The input schema to consider. Defaults to None.
Raises:
ValueError: If the data-frame's index has not been reset (meaning the index is not a
@@ -217,7 +356,7 @@ def pre_process_from_pandas(pandas_dataframe: pandas.DataFrame) -> Tuple[numpy.n
)

# Check that values are supported and build the mappings
q_pandas_dataframe, dtype_mappings = pre_process_dtypes(pandas_dataframe)
q_pandas_dataframe, dtype_mappings = pre_process_dtypes(pandas_dataframe, schema=schema)

# Replace NaN values with 0
# Remove this once NaN values are no longer represented by 0
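
A usage sketch (illustrative; the data-frame must keep its default range index, e.g. via reset_index(drop=True)):

import pandas

df = pandas.DataFrame({"score": [0.1, 0.9]}).reset_index(drop=True)
array, dtype_mappings = pre_process_from_pandas(
    df, schema={"score": {"min": 0.0, "max": 1.0}}
)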
18 changes: 14 additions & 4 deletions src/concrete/ml/pandas/client_engine.py
@@ -1,13 +1,17 @@
"""Define the framework used for managing keys (encrypt, decrypt) for encrypted data-frames."""

from pathlib import Path
from typing import Optional, Union
from typing import Dict, Optional, Union

import pandas

from concrete import fhe
from concrete.ml.pandas._development import CLIENT_PATH, get_encrypt_config
from concrete.ml.pandas._processing import post_process_to_pandas, pre_process_from_pandas
from concrete.ml.pandas._processing import (
check_schema_format,
post_process_to_pandas,
pre_process_from_pandas,
)
from concrete.ml.pandas._utils import decrypt_elementwise, encrypt_elementwise, encrypt_value
from concrete.ml.pandas.dataframe import EncryptedDataFrame

@@ -37,16 +41,22 @@ def keygen(self, keys_path: Optional[Union[Path, str]] = None):
else:
self.client.keygen(True)

def encrypt_from_pandas(self, pandas_dataframe: pandas.DataFrame) -> EncryptedDataFrame:
def encrypt_from_pandas(
self, pandas_dataframe: pandas.DataFrame, schema: Optional[Dict] = None
) -> EncryptedDataFrame:
"""Encrypt a Pandas data-frame using the loaded client.
Args:
pandas_dataframe (DataFrame): The Pandas data-frame to encrypt.
schema (Optional[Dict]): The input schema to consider. Defaults to None.
Returns:
EncryptedDataFrame: The encrypted data-frame.
"""
pandas_array, dtype_mappings = pre_process_from_pandas(pandas_dataframe)

check_schema_format(pandas_dataframe, schema)

pandas_array, dtype_mappings = pre_process_from_pandas(pandas_dataframe, schema=schema)

# Inputs need to be encrypted element-wise in order to be able to use a composable circuit
# Once multi-operator is supported, better handle encryption configuration parameters
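
Putting the new parameter together, a usage sketch based on the signatures in this diff (the ClientEngine constructor arguments and keys path are assumptions, not shown here):

import pandas
from concrete.ml.pandas import ClientEngine

df = pandas.DataFrame({"age": [28.0, 45.5], "job": ["doctor", "artist"]})

schema = {
    "age": {"min": 18.0, "max": 100.0},
    "job": {"doctor": 1, "artist": 2},
}

client = ClientEngine(keys_path="keys")  # assumed constructor usage
encrypted_df = client.encrypt_from_pandas(df, schema=schema)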