diff --git a/src/concrete/ml/pandas/_processing.py b/src/concrete/ml/pandas/_processing.py
index f77a97952..5133d9abe 100644
--- a/src/concrete/ml/pandas/_processing.py
+++ b/src/concrete/ml/pandas/_processing.py
@@ -2,7 +2,7 @@
 import copy
 from collections import defaultdict
-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple
 
 import numpy
 import pandas
@@ -10,45 +10,62 @@
 from concrete.ml.pandas._development import get_min_max_allowed
 from concrete.ml.quantization.quantizers import STABILITY_CONST
 
+SCHEMA_FLOAT_KEYS = ["min", "max"]
 
-def compute_scale_zero_point(column: pandas.Series, q_min: int, q_max: int) -> Tuple[float, float]:
+
+def is_str_or_none(column: pandas.Series) -> bool:
+    """Determine whether the column only contains string and None values.
+
+    Args:
+        column (pandas.Series): The column to consider.
+
+    Returns:
+        bool: True if the column only contains string and None values, False otherwise.
+    """
+    return column.apply(lambda x: isinstance(x, str) or not pandas.notna(x)).all()
+
+
+def compute_scale_zero_point(
+    f_min: float, f_max: float, q_min: int, q_max: int
+) -> Tuple[float, float]:
     """Compute the scale and zero point to use for quantizing / de-quantizing the given column.
 
     Note that the scale and zero point are computed so that values are quantized uniformly from
-    range [column.min(), column.max()] (float) to range [q_min, q_max] (int).
+    range [f_min, f_max] (float) to range [q_min, q_max] (int).
 
     Args:
-        column (pandas.Series): The column to consider.
-        q_min (int): The minimum quantized value to consider.
-        q_max (int): The maximum quantized value to consider.
+        f_min (float): The minimum float value observed.
+        f_max (float): The maximum float value observed.
+        q_min (int): The minimum quantized value to target.
+        q_max (int): The maximum quantized value to target.
 
     Returns:
         Tuple[float, float]: The scale and zero-point.
     """
-    values_min, values_max = column.min(), column.max()
 
     # If there is a single float value in the column, the scale and zero-point need to be handled
     # differently
-    if values_max - values_min < STABILITY_CONST:
+    if f_max - f_min < STABILITY_CONST:
 
         # If this single float value is 0, make sure it is not quantized to 0
-        if numpy.abs(values_max) < STABILITY_CONST:
+        if numpy.abs(f_max) < STABILITY_CONST:
             scale = 1.0
             zero_point = -q_min
 
         # Else, quantize it to 1
         else:
-            scale = 1 / values_max
+            scale = 1 / f_max
             zero_point = 0
 
     else:
-        scale = (q_max - q_min) / (values_max - values_min)
+        scale = (q_max - q_min) / (f_max - f_min)
 
         # Zero-point must be rounded once NaN values are not represented by 0 anymore
         # The issue is that we currently need to avoid quantized values to reach 0, but having a
         # round here + in the 'quant' method can make this happen.
         # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4342
-        zero_point = values_min * scale - q_min
+        # Disable mypy until it is fixed
+        zero_point = f_min * scale - q_min  # type: ignore[assignment]
 
     return scale, zero_point
 
@@ -86,9 +103,49 @@ def dequant(
     return x.astype(dtype)
 
 
-# Provide a way for users to pass string mappings
-# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4342
-def pre_process_dtypes(pandas_dataframe: pandas.DataFrame) -> Tuple[pandas.DataFrame, Dict]:
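+# A schema maps a column name to its value mapping. Minimal illustrative sketch (the column
+# names below are hypothetical):
+#
+#   schema = {
+#       "float_column": {"min": 0.0, "max": 1.0},  # float column: expected value range
+#       "string_column": {"apple": 1, "pear": 2},  # object column: string-to-integer mapping
+#   }
+#
+# Integer columns are used as-is and must not appear in the schema.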
+def check_schema_format(pandas_dataframe: pandas.DataFrame, schema: Optional[Dict] = None) -> None:
+    """Check that the given schema has the expected format.
+
+    Args:
+        pandas_dataframe (pandas.DataFrame): The data-frame associated with the given schema.
+        schema (Optional[Dict]): The schema to check, which can be None. Defaults to None.
+
+    Raises:
+        ValueError: If the given schema is not a dict.
+        ValueError: If the given schema contains column names that do not appear in the data-frame.
+        ValueError: If one of the column mappings is not a dict.
+    """
+    if schema is None:
+        return
+
+    if not isinstance(schema, dict):
+        raise ValueError(
+            "When set, parameter 'schema' must be a dictionary that associates some of the "
+            f"data-frame's column names to their value mappings. Got {type(schema)=}"
+        )
+
+    column_names = list(pandas_dataframe.columns)
+
+    for column_name, column_mapping in schema.items():
+        if column_name not in column_names:
+            # TODO: Is this check actually relevant? Can't the schema provide more columns than
+            # the ones found in the data-frame?
+            raise ValueError(
+                f"Column name '{column_name}' found in the given schema cannot be found in the "
+                f"input data-frame. Expected one of {column_names}"
+            )
+
+        if not isinstance(column_mapping, dict):
+            raise ValueError(
+                f"Mapping for column '{column_name}' is not a dictionary. Got "
+                f"{type(column_mapping)=}"
+            )
+
+
+# pylint: disable=too-many-branches, too-many-statements
+def pre_process_dtypes(
+    pandas_dataframe: pandas.DataFrame, schema: Optional[Dict] = None
+) -> Tuple[pandas.DataFrame, Dict]:
     """Pre-process the Pandas data-frame and check that input dtypes and ranges are supported.
 
     Currently, three input dtypes are supported: integers (within a specific range), floating
@@ -97,6 +154,7 @@ def pre_process_dtypes(pandas_dataframe: pandas.DataFrame) -> Tuple[pandas.DataF
 
     Args:
         pandas_dataframe (pandas.DataFrame): The Pandas data-frame to pre-process.
+        schema (Optional[Dict]): The input schema to consider. Defaults to None.
 
     Raises:
         ValueError: If the values of a column with an integer dtype are out of bounds.
@@ -118,6 +176,7 @@ def pre_process_dtypes(pandas_dataframe: pandas.DataFrame) -> Tuple[pandas.DataF
 
     # Avoid sending column names to server, instead use hashes
    # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4342
+    # pylint: disable=too-many-nested-blocks
    for column_name in pandas_dataframe.columns:
        column = pandas_dataframe[column_name]
        column_dtype = column.dtype
@@ -127,9 +186,13 @@
 
         # If the column contains integers, make sure they are not out of bounds
         if numpy.issubdtype(column_dtype, numpy.integer):
-            out_of_bounds = (column < q_min).any() or (column > q_max).any()
+            if schema is not None and column_name in schema:
+                raise ValueError(
+                    f"Column '{column_name}' contains integer values and therefore does not "
+                    "require any mapping. Please remove it"
+                )
 
-            if out_of_bounds:
+            if column.min() < q_min or column.max() > q_max:
                 raise ValueError(
                     f"Column '{column_name}' (dtype={column_dtype}) contains values that are out "
                     f"of bounds. Expected values to be in interval [min={q_min}, max={q_max}], but "
@@ -138,7 +201,33 @@ def pre_process_dtypes(pandas_dataframe: pandas.DataFrame) -> Tuple[pandas.DataF
 
         # If the column contains floats, quantize the values
         elif numpy.issubdtype(column_dtype, numpy.floating):
-            scale, zero_point = compute_scale_zero_point(column, q_min, q_max)
+            if schema is not None and column_name in schema:
+                float_min_max = schema[column_name]
+
+                if not all(
+                    float_mapping_key in SCHEMA_FLOAT_KEYS
+                    for float_mapping_key in float_min_max.keys()
+                ):
+                    raise ValueError(
+                        f"Column '{column_name}' contains float values but the associated mapping "
+                        f"does not contain proper keys. Expected {sorted(SCHEMA_FLOAT_KEYS)}, but "
+                        f"got {sorted(float_min_max.keys())}"
+                    )
+
+                f_min, f_max = float_min_max["min"], float_min_max["max"]
+
+                if column.min() < f_min or column.max() > f_max:
+                    raise ValueError(
+                        f"Column '{column_name}' (dtype={column_dtype}) contains values that are "
+                        f"out of bounds. Expected values to be in interval [min={f_min}, "
+                        f"max={f_max}], as determined by the given schema, but found "
+                        f"[min={column.min()}, max={column.max()}]."
+                    )
+
+            else:
+                f_min, f_max = column.min(), column.max()
+
+            scale, zero_point = compute_scale_zero_point(f_min, f_max, q_min, q_max)
 
             q_column = quant(column, scale, zero_point)
 
@@ -150,14 +239,58 @@ def pre_process_dtypes(pandas_dataframe: pandas.DataFrame) -> Tuple[pandas.DataF
 
         # If the column contains objects, make sure it is only made of strings or NaN values
         elif column_dtype == "object":
-            is_str = column.apply(lambda x: isinstance(x, str) or not pandas.notna(x)).all()
-
-            if is_str:
-
-                # Build a mapping between the unique strings values and integers
-                str_to_int = {
-                    str_value: i + 1 for i, str_value in enumerate(column.dropna().unique())
-                }
+            if is_str_or_none(column):
+                if schema is not None and column_name in schema:
+                    str_to_int = schema[column_name]
+
+                    column_values = set(column.values) - set([None, numpy.NaN])
+                    string_mapping_keys = set(str_to_int.keys())
+
+                    # Allow custom mapping for NaN values once they are not represented by 0 anymore
+                    # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4342
+                    if numpy.NaN in string_mapping_keys:
+                        raise ValueError(
+                            f"String mapping for column '{column_name}' contains numpy.NaN as a "
+                            "key, which is currently forbidden"
+                        )
+
+                    forgotten_string_values = column_values - string_mapping_keys
+
+                    if forgotten_string_values:
+                        raise ValueError(
+                            f"String mapping keys for column '{column_name}' are not considering "
+                            "all values from the data-frame. Missing values: "
+                            f"{sorted(forgotten_string_values)}"
+                        )
+
+                    for string_mapping_key, string_mapping_value in str_to_int.items():
+                        if not isinstance(string_mapping_value, int):
+                            raise ValueError(
+                                f"String mapping values for column '{column_name}' must be "
+                                f"integers. Got {type(string_mapping_value)} for key "
+                                f"{string_mapping_key}"
+                            )
+
+                        if string_mapping_value < q_min or string_mapping_value > q_max:
+                            raise ValueError(
+                                f"String mapping values for column '{column_name}' are out of "
+                                f"bounds. Expected values to be in interval [min={q_min}, "
+                                f"max={q_max}] but got {string_mapping_value} for key "
+                                f"{string_mapping_key}"
+                            )
+
+                    if len(str_to_int.values()) != len(set(str_to_int.values())):
+                        raise ValueError(
+                            f"String mapping values for column '{column_name}' must be unique. Got "
+                            f"{str_to_int.values()}"
+                        )
+
+                else:
+
+                    # Build a mapping between the unique string values and integers
+                    str_to_int = {
+                        str_value: i + 1 for i, str_value in enumerate(column.dropna().unique())
+                    }
 
                 # Make sure the number of unique values does not go over the maximum integer value
                 # allowed in an encrypted data-frame
@@ -189,14 +322,20 @@ def pre_process_dtypes(pandas_dataframe: pandas.DataFrame) -> Tuple[pandas.DataF
                 "supported."
             )
 
+    # TODO: Should all non-integer columns be considered by the schema if not None? Currently,
+    # mappings are computed automatically if schema is not set
+
     return pandas_dataframe, dtype_mappings
 
 
-def pre_process_from_pandas(pandas_dataframe: pandas.DataFrame) -> Tuple[numpy.ndarray, Dict]:
+def pre_process_from_pandas(
+    pandas_dataframe: pandas.DataFrame, schema: Optional[Dict] = None
+) -> Tuple[numpy.ndarray, Dict]:
     """Pre-process the Pandas data-frame.
 
     Args:
         pandas_dataframe (pandas.DataFrame): The Pandas data-frame to pre-process.
+        schema (Optional[Dict]): The input schema to consider. Defaults to None.
 
     Raises:
         ValueError: If the data-frame's index has not been reset (meaning the index is not a
@@ -217,7 +356,7 @@ def pre_process_from_pandas(pandas_dataframe: pandas.DataFrame) -> Tuple[numpy.n
         )
 
     # Check that values are supported and build the mappings
-    q_pandas_dataframe, dtype_mappings = pre_process_dtypes(pandas_dataframe)
+    q_pandas_dataframe, dtype_mappings = pre_process_dtypes(pandas_dataframe, schema=schema)
 
     # Replace NaN values with 0
     # Remove this once NaN values are not represented by 0 anymore
diff --git a/src/concrete/ml/pandas/client_engine.py b/src/concrete/ml/pandas/client_engine.py
index 98fd5cbbe..4b2e3297b 100644
--- a/src/concrete/ml/pandas/client_engine.py
+++ b/src/concrete/ml/pandas/client_engine.py
@@ -1,13 +1,17 @@
 """Define the framework used for managing keys (encrypt, decrypt) for encrypted data-frames."""
 from pathlib import Path
-from typing import Optional, Union
+from typing import Dict, Optional, Union
 
 import pandas
 from concrete import fhe
 
 from concrete.ml.pandas._development import CLIENT_PATH, get_encrypt_config
-from concrete.ml.pandas._processing import post_process_to_pandas, pre_process_from_pandas
+from concrete.ml.pandas._processing import (
+    check_schema_format,
+    post_process_to_pandas,
+    pre_process_from_pandas,
+)
 from concrete.ml.pandas._utils import decrypt_elementwise, encrypt_elementwise, encrypt_value
 from concrete.ml.pandas.dataframe import EncryptedDataFrame
 
@@ -37,16 +41,22 @@ def keygen(self, keys_path: Optional[Union[Path, str]] = None):
         else:
             self.client.keygen(True)
 
-    def encrypt_from_pandas(self, pandas_dataframe: pandas.DataFrame) -> EncryptedDataFrame:
+    def encrypt_from_pandas(
+        self, pandas_dataframe: pandas.DataFrame, schema: Optional[Dict] = None
+    ) -> EncryptedDataFrame:
         """Encrypt a Pandas data-frame using the loaded client.
 
         Args:
             pandas_dataframe (DataFrame): The Pandas data-frame to encrypt.
+            schema (Optional[Dict]): The input schema to consider. Defaults to None.
 
         Returns:
            EncryptedDataFrame: The encrypted data-frame.
""" - pandas_array, dtype_mappings = pre_process_from_pandas(pandas_dataframe) + + check_schema_format(pandas_dataframe, schema) + + pandas_array, dtype_mappings = pre_process_from_pandas(pandas_dataframe, schema=schema) # Inputs need to be encrypted element-wise in order to be able to use a composable circuit # Once multi-operator is supported, better handle encryption configuration parameters diff --git a/tests/pandas/test_pandas.py b/tests/pandas/test_pandas.py index 8a5084135..1948f2a89 100644 --- a/tests/pandas/test_pandas.py +++ b/tests/pandas/test_pandas.py @@ -1,5 +1,6 @@ """Tests the encrypted data-frame API abd its coherence with Pandas""" +import copy import re import shutil import tempfile @@ -298,7 +299,7 @@ def test_save_load(): def check_invalid_merge_parameters(): - """Check that unsupported or invalid parameters for merge raise the correct errors.""" + """Check that unsupported or invalid parameters for merge raise correct errors.""" encrypted_df_left, encrypted_df_right = get_two_encrypted_dataframes() unsupported_pandas_parameters_and_values = [ @@ -345,7 +346,7 @@ def check_no_multi_columns_merge(): def check_column_coherence(): - """Check that merging data-frames with unsupported scheme raise the correct errors.""" + """Check that merging data-frames with unsupported scheme raises correct errors.""" index_name = "index" # Test when a selected column has a different dtype than the other one @@ -394,7 +395,7 @@ def check_column_coherence(): def check_unsupported_input_values(): - """Check that initializing a data-frame with unsupported inputs raise the correct errors.""" + """Check that initializing a data-frame with unsupported inputs raises correct errors.""" client = ClientEngine() # Test with integer values that are out of bound @@ -451,7 +452,7 @@ def check_unsupported_input_values(): def check_post_processing_coherence(): - """Check post-processing a data-frame with unsupported scheme raise the correct errors.""" + """Check post-processing a data-frame with unsupported scheme raises correct errors.""" index_name = "index" client = ClientEngine() @@ -480,6 +481,8 @@ def test_error_raises(): check_column_coherence() check_unsupported_input_values() check_post_processing_coherence() + check_invalid_schema_format() + check_invalid_schema_values() def deserialize_client_file(client_path: Union[Path, str]) -> ClientSpecs: @@ -571,3 +574,211 @@ def test_print_and_repr(): assert pandas_dataframe_are_equal( expected_schema, schema, equal_nan=True ), "Expected and retrieved schemas do not match." 
+
+
+def get_input_schema(pandas_dataframe, selected_schema=None):
+    """Get a data-frame's expected input schema."""
+    schema = {}
+    for column_name in pandas_dataframe.columns:
+        column = pandas_dataframe[column_name]
+        if numpy.issubdtype(column.dtype, numpy.floating):
+            schema[column_name] = {
+                "min": column.min(),
+                "max": column.max(),
+            }
+
+        elif column.dtype == "object":
+            unique_values = column.unique()
+
+            # Only take strings into account and thus avoid NaN values
+            schema[column_name] = {
+                val: i for i, val in enumerate(unique_values) if isinstance(val, str)
+            }
+
+    # Update the common column's mapping
+    if selected_schema is not None:
+        schema.update(selected_schema)
+
+    return schema
+
+
+def test_schema_input():
+    """Test that users can properly provide schemas when encrypting data-frames."""
+    selected_column = "index"
+    pandas_kwargs = {"how": "left", "on": selected_column}
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        keys_path = Path(temp_dir) / "keys"
+
+        client_1 = ClientEngine(keys_path=keys_path)
+        client_2 = ClientEngine(keys_path=keys_path)
+
+        indexes_left = ["one", "two", "three", "four"]
+        indexes_right = ["two", "three"]
+
+        schema_index = {selected_column: {"one": 1, "two": 2, "three": 3, "four": 4}}
+
+        pandas_df_left = generate_pandas_dataframe(
+            feat_name="left", index_name=selected_column, indexes=indexes_left, index_position=2
+        )
+        pandas_df_right = generate_pandas_dataframe(
+            feat_name="right", index_name=selected_column, indexes=indexes_right, index_position=1
+        )
+
+        schema_left = get_input_schema(pandas_df_left, selected_schema=schema_index)
+        schema_right = get_input_schema(pandas_df_right, selected_schema=schema_index)
+
+        encrypted_df_left = client_1.encrypt_from_pandas(pandas_df_left, schema=schema_left)
+        encrypted_df_right = client_2.encrypt_from_pandas(pandas_df_right, schema=schema_right)
+
+        pandas_joined_df = pandas_df_left.merge(pandas_df_right, **pandas_kwargs)
+        encrypted_df_joined = encrypted_df_left.merge(encrypted_df_right, **pandas_kwargs)
+
+        clear_df_joined_1 = client_1.decrypt_to_pandas(encrypted_df_joined)
+        clear_df_joined_2 = client_2.decrypt_to_pandas(encrypted_df_joined)
+
+        assert pandas_dataframe_are_equal(
+            clear_df_joined_1, clear_df_joined_2, equal_nan=True
+        ), "Joined encrypted data-frames decrypted by different clients are not equal."
+
+        # Improve the test to avoid the risk of flakiness
+        # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4342
+        assert pandas_dataframe_are_equal(
+            clear_df_joined_1, pandas_joined_df, float_atol=1, equal_nan=True
+        ), "Joined encrypted data-frame does not match Pandas' joined data-frame."
+
+
+def check_invalid_schema_format():
+    """Check that encrypting data-frames with an unsupported schema format raises correct errors."""
+    selected_column = "index"
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        keys_path = Path(temp_dir) / "keys"
+
+        client = ClientEngine(keys_path=keys_path)
+
+        pandas_df = generate_pandas_dataframe(index_name=selected_column)
+
+        with pytest.raises(
+            ValueError,
+            match="When set, parameter 'schema' must be a dictionary.*",
+        ):
+            client.encrypt_from_pandas(pandas_df, schema=[])
+
+        schema_wrong_column = {"wrong_column": None}
+
+        with pytest.raises(
+            ValueError,
+            match="Column name '.*' found in the given schema cannot be found.*",
+        ):
+            client.encrypt_from_pandas(pandas_df, schema=schema_wrong_column)
+
+        schema_wrong_mapping_type = {selected_column: [None]}
+
+        with pytest.raises(
+            ValueError,
+            match="Mapping for column '.*' is not a dictionary. .*",
+        ):
+            client.encrypt_from_pandas(pandas_df, schema=schema_wrong_mapping_type)
+
+
+def check_invalid_schema_values():
+    """Check that encrypting data-frames with unsupported schema values raises correct errors."""
+    selected_column = "index"
+    feat_name = "feat"
+    float_min = -10.0
+    float_max = 10.0
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        keys_path = Path(temp_dir) / "keys"
+
+        client = ClientEngine(keys_path=keys_path)
+
+        pandas_df = generate_pandas_dataframe(
+            feat_name=feat_name, index_name=selected_column, float_min=float_min, float_max=float_max
+        )
+
+        schema_int_column = {f"{feat_name}_int_1": {None: None}}
+
+        with pytest.raises(
+            ValueError,
+            match="Column '.*' contains integer values and therefore does not.*",
+        ):
+            client.encrypt_from_pandas(pandas_df, schema=schema_int_column)
+
+        schema_float_column = {f"{feat_name}_float_1": {"wrong_mapping": 1.0}}
+
+        with pytest.raises(
+            ValueError,
+            match="Column '.*' contains float values but the associated mapping.*",
+        ):
+            client.encrypt_from_pandas(pandas_df, schema=schema_float_column)
+
+        schema_float_oob = {f"{feat_name}_float_1": {"min": float_min // 2, "max": float_max // 2}}
+
+        with pytest.raises(
+            ValueError,
+            match=r"Column '.*' \(dtype=float64\) contains values that are out of bounds.*",
+        ):
+            client.encrypt_from_pandas(pandas_df, schema=schema_float_oob)
+
+        string_column = f"{feat_name}_str_1"
+
+        schema_string_nan = {string_column: {numpy.NaN: 1}}
+
+        with pytest.raises(
+            ValueError,
+            match="String mapping for column '.*' contains numpy.NaN as a key.*",
+        ):
+            client.encrypt_from_pandas(pandas_df, schema=schema_string_nan)
+
+        schema_string_missing_values = {string_column: {"apple": 1}}
+
+        with pytest.raises(
+            ValueError,
+            match="String mapping keys for column '.*' are not considering all values.*",
+        ):
+            client.encrypt_from_pandas(pandas_df, schema=schema_string_missing_values)
+
+        # Retrieve the string column's unique values and create a mapping, except for numpy.NaN
+        # values
+        string_values = pandas_df[string_column].unique()
+        string_values = [
+            string_value for string_value in string_values if isinstance(string_value, str)
+        ]
+        string_mapping = {val: i for i, val in enumerate(string_values)}
+
+        string_mapping_non_int = copy.copy(string_mapping)
+
+        # Disable mypy as this type assignment is expected for the error to be raised
+        string_mapping_non_int[string_values[0]] = "orange"  # type: ignore[assignment]
+
+        schema_string_non_int = {string_column: string_mapping_non_int}
+
+        with pytest.raises(
+            ValueError,
+            match="String mapping values for column '.*' must be integers.*",
+        ):
+            client.encrypt_from_pandas(pandas_df, schema=schema_string_non_int)
+
+        string_mapping_oob = copy.copy(string_mapping)
+        string_mapping_oob[string_values[0]] = -1
+
+        schema_string_oob = {string_column: string_mapping_oob}
+
+        with pytest.raises(
+            ValueError,
+            match="String mapping values for column '.*' are out of bounds.*",
+        ):
+            client.encrypt_from_pandas(pandas_df, schema=schema_string_oob)
+
+        string_mapping_non_unique = copy.copy(string_mapping)
+        string_mapping_non_unique[string_values[0]] = 1
+        string_mapping_non_unique[string_values[1]] = 1
+
+        schema_string_non_unique = {string_column: string_mapping_non_unique}
+
+        with pytest.raises(
+            ValueError,
+            match="String mapping values for column '.*' must be unique.*",
+        ):
+            client.encrypt_from_pandas(pandas_df, schema=schema_string_non_unique)
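
Taken together, the changes above let callers pin quantization parameters before encryption so that independently encrypted data-frames stay compatible. A minimal end-to-end sketch of the new schema API (the data-frame contents, column names, and key directory below are illustrative):

import pandas

from concrete.ml.pandas.client_engine import ClientEngine

# Two short columns: one float (mapped to a "min"/"max" range) and one string
# column (mapped to explicit integer codes)
df = pandas.DataFrame(
    {
        "index": ["one", "two"],
        "score": [0.1, 0.7],
        "fruit": ["apple", "orange"],
    }
)

# Pinning the range and codes up front means another client encrypting with the
# same schema produces compatible quantized values
schema = {
    "score": {"min": 0.0, "max": 1.0},
    "fruit": {"apple": 1, "orange": 2},
}

client = ClientEngine(keys_path="keys")  # illustrative key directory
encrypted_df = client.encrypt_from_pandas(df, schema=schema)
decrypted_df = client.decrypt_to_pandas(encrypted_df)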