feat: add FHE training deployment (#665)
jfrery authored Jun 3, 2024
1 parent 1b57b9c commit b718629
Showing 18 changed files with 858 additions and 555 deletions.
511 changes: 395 additions & 116 deletions docs/advanced_examples/LogisticRegressionTraining.ipynb

Large diffs are not rendered by default.
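Since the notebook diff is not rendered, here is a rough, hedged sketch of the kind of encrypted-training workflow that notebook covers, assuming Concrete ML's SGDClassifier with fit_encrypted=True; the argument values are illustrative and are not taken from the notebook itself.

# Illustrative sketch only (not from the notebook): encrypted training with
# Concrete ML's SGDClassifier. Argument values are placeholders.
import numpy
from sklearn.datasets import make_classification

from concrete.ml.sklearn import SGDClassifier

X, y = make_classification(n_samples=100, n_features=4, random_state=42)
X = X.astype(numpy.float32)

# parameters_range bounds the model parameters during encrypted training
model = SGDClassifier(
    random_state=42,
    max_iter=20,
    fit_encrypted=True,
    parameters_range=(-1.0, 1.0),
)

# fhe="simulate" trains with the simulated circuit; fhe="execute" runs real FHE
model.fit(X, y, fhe="simulate")
predictions = model.predict(X)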

4 changes: 4 additions & 0 deletions pyproject.toml
@@ -136,6 +136,10 @@ filterwarnings = [
"ignore:The --rsyncdir command line argument and rsyncdirs config variable are deprecated.:DeprecationWarning",
"ignore:Converting a tensor to a NumPy array might cause the trace to be incorrect.",
"ignore:torch.from_numpy results are registered as constants in the trace.",
"ignore:ONNX Preprocess - Removing mutation from node aten*:UserWarning",
"ignore:Liblinear failed to converge,*:sklearn.exceptions.ConvergenceWarning",
"ignore:lbfgs failed to converge,*:sklearn.exceptions.ConvergenceWarning",
"ignore:Maximum number of iteration reached before convergence.*:sklearn.exceptions.ConvergenceWarning",
]

[tool.semantic_release]
38 changes: 30 additions & 8 deletions script/make_utils/check_pytest_determinism.sh
@@ -33,12 +33,35 @@ then
exit 255
fi

set -e

# Exceptions:
# passed in: since it is related to timings
diff "${OUTPUT_DIRECTORY}/one.txt" "${OUTPUT_DIRECTORY}/two.txt" -I "passed in"
echo "Successful determinism check"
# in X.Xs: since it is related to timings
diff_output=$(diff "${OUTPUT_DIRECTORY}/one.txt" "${OUTPUT_DIRECTORY}/two.txt" -I "in [0-9]*\.[0-9]*s")

# If a diff is present, print the tests whose output could not be reproduced, to help debugging.
if [ -n "$diff_output" ]; then
echo "Differences found:"
echo "$diff_output"

# Extract line numbers of differences
diff_lines=$(echo "$diff_output" | grep -E '^[0-9]+(,[0-9]+)?[acd][0-9]+(,[0-9]+)?' | sed -E 's/([0-9]+).*/\1/')

for line in $diff_lines; do

# Find the first line number before the diff that starts with 'tests/seeding/'
start_line_num=$(awk 'NR<'"$line"' && /^tests\/seeding\// {print NR}' "${OUTPUT_DIRECTORY}/one.txt" | tail -n 1)

# Print lines from start_line_num to the diff line
if [ -n "$start_line_num" ]; then
sed -n "${start_line_num},${line}p" "${OUTPUT_DIRECTORY}/one.txt"
fi
done

exit 255
else
echo "Successful determinism check"
fi

set -e

# Now, check that one can reproduce conditions of a bug in a single file
# and test without having to relaunch the full pytest
@@ -58,14 +81,13 @@
# SC2086 is about double-quoting to prevent globbing and word splitting, but here quoting would pass
# an empty argument to pytest, which pytest interprets as "run pytest on all files"
# shellcheck disable=SC2086
poetry run pytest "$x" -xsvv $EXTRA_OPTION --randomly-dont-reset-seed | sed -n -e '/collecting/,$p' | grep -v collecting | grep -v "collected" | grep -v "passed in" | grep -v "PASSED" >> "${OUTPUT_DIRECTORY}/three.txt"
poetry run pytest "$x" -xsvv $EXTRA_OPTION --randomly-dont-reset-seed | sed -n -e '/collecting/,$p' | grep -v collecting | grep -v "collected" | grep -v "PASSED" | grep -v "SKIPPED" | grep -v "in [0-9]*\.[0-9]*s" >> "${OUTPUT_DIRECTORY}/three.txt"

((WHICH+=1))
done

# Clean up one.txt a bit
sed -n -e '/collecting/,$p' "${OUTPUT_DIRECTORY}/one.txt" | grep -v collecting | grep -v "collected" | grep -v "passed in" | grep -v "PASSED" | grep -v "Leaving directory" > "${OUTPUT_DIRECTORY}/one.modified.txt"

sed -n -e '/collecting/,$p' "${OUTPUT_DIRECTORY}/one.txt" | grep -v collecting | grep -v "collected" | grep -v "PASSED" | grep -v "SKIPPED" | grep -v "Leaving directory" | grep -v "in [0-9]*\.[0-9]*s" > "${OUTPUT_DIRECTORY}/one.modified.txt"
echo ""
echo "diff:"
echo ""
178 changes: 130 additions & 48 deletions src/concrete/ml/deployment/fhe_client_server.py
@@ -3,8 +3,9 @@
import json
import sys
import zipfile
from enum import Enum
from pathlib import Path
from typing import Any, Optional
from typing import Any, Optional, Tuple, Union

import numpy

@@ -13,6 +14,7 @@
from ..common.debugging.custom_assert import assert_true
from ..common.serialization.dumpers import dump
from ..common.serialization.loaders import load
from ..common.utils import to_tuple
from ..version import __version__ as CML_VERSION

try:
@@ -25,6 +27,25 @@
from importlib_metadata import version


class DeploymentMode(Enum):
"""Mode for the FHE API."""

INFERENCE = "inference"
TRAINING = "training"

@staticmethod
def is_valid(mode: Union["DeploymentMode", str]) -> bool:
"""Indicate if the given name is a supported mode.
Args:
mode (Union[DeploymentMode, str]): The mode to check.
Returns:
bool: Whether the mode is supported or not.
"""
return mode in {member.value for member in DeploymentMode.__members__.values()}


def check_concrete_versions(zip_path: Path):
"""Check that current versions match the ones used in development.
@@ -105,30 +126,35 @@ def load(self):

def run(
self,
serialized_encrypted_quantized_data: bytes,
serialized_encrypted_quantized_data: Union[bytes, Tuple[bytes, ...]],
serialized_evaluation_keys: bytes,
) -> bytes:
) -> Union[bytes, Tuple[bytes, ...]]:
"""Run the model on the server over encrypted data.
Args:
serialized_encrypted_quantized_data (bytes): the encrypted, quantized
and serialized data
serialized_encrypted_quantized_data (Union[bytes, Tuple[bytes, ...]]): the encrypted,
quantized and serialized data
serialized_evaluation_keys (bytes): the serialized evaluation keys
Returns:
bytes: the result of the model
Union[bytes, Tuple[bytes, ...]]: the result of the model
"""
assert_true(self.server is not None, "Model has not been loaded.")

deserialized_encrypted_quantized_data = fhe.Value.deserialize(
serialized_encrypted_quantized_data
serialized_encrypted_quantized_data = to_tuple(serialized_encrypted_quantized_data)

deserialized_data = tuple(
fhe.Value.deserialize(data) for data in serialized_encrypted_quantized_data
)
deserialized_evaluation_keys = fhe.EvaluationKeys.deserialize(serialized_evaluation_keys)
result = self.server.run(
deserialized_encrypted_quantized_data, evaluation_keys=deserialized_evaluation_keys
deserialized_keys = fhe.EvaluationKeys.deserialize(serialized_evaluation_keys)

result = self.server.run(*deserialized_data, evaluation_keys=deserialized_keys)

return (
tuple(res.serialize() for res in result)
if isinstance(result, tuple)
else result.serialize()
)
serialized_result = result.serialize()
return serialized_result


class FHEModelDev:
@@ -149,17 +175,22 @@ def __init__(self, path_dir: str, model: Any = None):

Path(self.path_dir).mkdir(parents=True, exist_ok=True)

def _export_model_to_json(self) -> Path:
def _export_model_to_json(self, is_training: bool = False) -> Path:
"""Export the quantizers to a json file.
Args:
is_training (bool): If True, we export the training circuit.
Returns:
Path: the path to the json file
"""
module_to_export = self.model.training_quantized_module if is_training else self.model
serialized_processing = {
"model_type": self.model.__class__,
"model_post_processing_params": self.model.post_processing_params,
"input_quantizers": self.model.input_quantizers,
"output_quantizers": self.model.output_quantizers,
"model_type": module_to_export.__class__,
"model_post_processing_params": module_to_export.post_processing_params,
"input_quantizers": module_to_export.input_quantizers,
"output_quantizers": module_to_export.output_quantizers,
"is_training": is_training,
}

# Export the `is_fitted` attribute for built-in models
@@ -173,35 +204,58 @@ def _export_model_to_json(self) -> Path:

return json_path

def save(self, via_mlir: bool = False):
def save(self, mode: DeploymentMode = DeploymentMode.INFERENCE, via_mlir: bool = False):
"""Export all needed artifacts for the client and server.
Arguments:
mode (DeploymentMode): the mode to save the FHE circuit,
either "inference" or "training".
via_mlir (bool): serialize with `via_mlir` option from Concrete-Python.
For more details on the topic please refer to Concrete-Python's documentation.
Raises:
Exception: path_dir is not empty
Exception: path_dir is not empty or training module does not exist
ValueError: if mode is not "inference" or "training"
"""

if isinstance(mode, str):
mode_lower = mode.lower()
if not DeploymentMode.is_valid(mode_lower):
raise ValueError("Mode must be either 'inference' or 'training'")
mode = DeploymentMode(mode_lower)

# Get fhe_circuit based on the mode
if mode == DeploymentMode.TRAINING:

# Check that training FHE circuit exists
assert_true(
hasattr(self.model, "training_quantized_module")
and (self.model.training_quantized_module),
"Training FHE circuit does not exist.",
)
self.model.training_quantized_module.check_model_is_compiled()
fhe_circuit = self.model.training_quantized_module.fhe_circuit
else:
self.model.check_model_is_compiled()
fhe_circuit = self.model.fhe_circuit

# Check if the path_dir is empty with pathlib
listdir = list(Path(self.path_dir).glob("**/*"))
if len(listdir) > 0:
raise Exception(
f"path_dir: {self.path_dir} is not empty."
f"path_dir: {self.path_dir} is not empty. "
"Please delete it before saving a new model."
)
self.model.check_model_is_compiled()

# Export the quantizers
json_path = self._export_model_to_json()
json_path = self._export_model_to_json(is_training=(mode == DeploymentMode.TRAINING))

# First save the circuit for the server
# Save the circuit for the server
path_circuit_server = Path(self.path_dir).joinpath("server.zip")
self.model.fhe_circuit.server.save(path_circuit_server, via_mlir=via_mlir)
fhe_circuit.server.save(path_circuit_server, via_mlir=via_mlir)

# Save the circuit for the client
path_circuit_client = Path(self.path_dir).joinpath("client.zip")
self.model.fhe_circuit.client.save(path_circuit_client)
fhe_circuit.client.save(path_circuit_client)

with zipfile.ZipFile(path_circuit_client, "a") as zip_file:
zip_file.write(filename=json_path, arcname="serialized_processing.json")
@@ -265,6 +319,8 @@ def load(self): # pylint: disable=no-value-for-parameter

# Initialize the model
self.model = serialized_processing["model_type"]()

# Load the quantizers
self.model.input_quantizers = serialized_processing["input_quantizers"]
self.model.output_quantizers = serialized_processing["output_quantizers"]

@@ -300,72 +356,98 @@ def get_serialized_evaluation_keys(self) -> bytes:

return self.client.evaluation_keys.serialize()

def quantize_encrypt_serialize(self, x: numpy.ndarray) -> bytes:
def quantize_encrypt_serialize(
self, x: Union[numpy.ndarray, Tuple[numpy.ndarray, ...]]
) -> Union[bytes, Tuple[bytes, ...]]:
"""Quantize, encrypt and serialize the values.
Args:
x (numpy.ndarray): the values to quantize, encrypt and serialize
x (Union[numpy.ndarray, Tuple[numpy.ndarray, ...]]): the values to quantize,
encrypt and serialize
Returns:
bytes: the quantized, encrypted and serialized values
Union[bytes, Tuple[bytes, ...]]: the quantized, encrypted and serialized values
"""

x = to_tuple(x)

# Quantize the values
quantized_x = self.model.quantize_input(x)
quantized_x = self.model.quantize_input(*x)

quantized_x = to_tuple(quantized_x)

# Encrypt the values
enc_qx = self.client.encrypt(quantized_x)
enc_qx = self.client.encrypt(*quantized_x)

enc_qx = to_tuple(enc_qx)

# Serialize the encrypted values to be sent to the server
serialized_enc_qx = enc_qx.serialize()
return serialized_enc_qx
serialized_enc_qx = tuple(e.serialize() for e in enc_qx)

def deserialize_decrypt(self, serialized_encrypted_quantized_result: bytes) -> numpy.ndarray:
# Return a single value if the original input was a single value
return serialized_enc_qx[0] if len(serialized_enc_qx) == 1 else serialized_enc_qx

def deserialize_decrypt(
self, serialized_encrypted_quantized_result: Union[bytes, Tuple[bytes, ...]]
) -> Union[Any, Tuple[Any, ...]]:
"""Deserialize and decrypt the values.
Args:
serialized_encrypted_quantized_result (bytes): the serialized, encrypted
and quantized result
serialized_encrypted_quantized_result (Union[bytes, Tuple[bytes, ...]]): the
serialized, encrypted and quantized result
Returns:
numpy.ndarray: the decrypted and deserialized values
Union[Any, Tuple[Any, ...]]: the decrypted and deserialized values
"""

serialized_encrypted_quantized_result = to_tuple(serialized_encrypted_quantized_result)

# Deserialize the encrypted values
deserialized_encrypted_quantized_result = fhe.Value.deserialize(
serialized_encrypted_quantized_result
deserialized_encrypted_quantized_result = tuple(
fhe.Value.deserialize(data) for data in serialized_encrypted_quantized_result
)

# Decrypt the values
deserialized_decrypted_quantized_result = self.client.decrypt(
deserialized_encrypted_quantized_result
*deserialized_encrypted_quantized_result
)
assert isinstance(deserialized_decrypted_quantized_result, numpy.ndarray)

return deserialized_decrypted_quantized_result

def deserialize_decrypt_dequantize(
self, serialized_encrypted_quantized_result: bytes
self, serialized_encrypted_quantized_result: Union[bytes, Tuple[bytes, ...]]
) -> numpy.ndarray:
"""Deserialize, decrypt and de-quantize the values.
Args:
serialized_encrypted_quantized_result (bytes): the serialized, encrypted
and quantized result
serialized_encrypted_quantized_result (Union[bytes, Tuple[bytes, ...]]): the
serialized, encrypted and quantized result
Returns:
numpy.ndarray: the decrypted (de-quantized) values
"""
# Ensure the input is a tuple
serialized_encrypted_quantized_result = to_tuple(serialized_encrypted_quantized_result)

# Decrypt and deserialize the values
deserialized_decrypted_quantized_result = self.deserialize_decrypt(
serialized_encrypted_quantized_result
)

deserialized_decrypted_quantized_result = to_tuple(deserialized_decrypted_quantized_result)

# De-quantize the values
deserialized_decrypted_dequantized_result = self.model.dequantize_output(
deserialized_decrypted_quantized_result
*deserialized_decrypted_quantized_result
)

# Apply post-processing the to de-quantized values
deserialized_decrypted_dequantized_result = self.model.post_processing(
deserialized_decrypted_dequantized_result = to_tuple(
deserialized_decrypted_dequantized_result
)

# Apply post-processing to the de-quantized values
deserialized_decrypted_dequantized_result = self.model.post_processing(
*deserialized_decrypted_dequantized_result
)

return deserialized_decrypted_dequantized_result
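
With these additions, a compiled training circuit can be deployed with the same client/server split as inference. The following is a minimal sketch under stated assumptions: `model` is a Concrete ML model whose FHE training circuit has already been compiled (so `training_quantized_module` exists), and the input arrays are placeholders whose number, order, and shapes depend on that circuit.

# Deployment sketch for the new training mode (assumptions noted above).
import numpy

from concrete.ml.deployment import FHEModelClient, FHEModelDev, FHEModelServer
from concrete.ml.deployment.fhe_client_server import DeploymentMode

# Developer side: export client and server artifacts for the training circuit
dev = FHEModelDev(path_dir="deployment", model=model)
dev.save(mode=DeploymentMode.TRAINING)  # a plain "training" string also works

# Client side: generate keys and encrypt the training inputs.
# The tuple below is a placeholder; the real inputs depend on the circuit.
client = FHEModelClient(path_dir="deployment", key_dir="keys")
client.generate_private_and_evaluation_keys()
serialized_evaluation_keys = client.get_serialized_evaluation_keys()
x_batch = numpy.random.rand(8, 4).astype(numpy.float32)  # illustrative shapes
y_batch = numpy.random.randint(0, 2, size=(8, 1))
encrypted_inputs = client.quantize_encrypt_serialize((x_batch, y_batch))

# Server side: load the circuit and run one encrypted training step
server = FHEModelServer(path_dir="deployment")
server.load()
encrypted_outputs = server.run(encrypted_inputs, serialized_evaluation_keys)

# Client side: decrypt and de-quantize the results (e.g., updated parameters)
outputs = client.deserialize_decrypt_dequantize(encrypted_outputs)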