From 94feaaf92e83ba69715f6d19e6822415fe3bfbff Mon Sep 17 00:00:00 2001
From: Jordan Fréry
Date: Thu, 27 Jun 2024 16:34:47 +0200
Subject: [PATCH] chore: improvement hybrid deployment (#766)

---
 src/concrete/ml/deployment/fhe_client_server.py      | 2 +-
 src/concrete/ml/torch/hybrid_model.py                | 2 +-
 use_case_examples/hybrid_model/compile_hybrid_llm.py | 6 ++++--
 use_case_examples/hybrid_model/requirements.txt      | 1 +
 4 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/src/concrete/ml/deployment/fhe_client_server.py b/src/concrete/ml/deployment/fhe_client_server.py
index 7018524e2..6d28541c3 100644
--- a/src/concrete/ml/deployment/fhe_client_server.py
+++ b/src/concrete/ml/deployment/fhe_client_server.py
@@ -239,7 +239,7 @@ def _export_model_to_json(self, is_training: bool = False) -> Path:
 
         return json_path
 
-    def save(self, mode: DeploymentMode = DeploymentMode.INFERENCE, via_mlir: bool = False):
+    def save(self, mode: DeploymentMode = DeploymentMode.INFERENCE, via_mlir: bool = True):
         """Export all needed artifacts for the client and server.
 
         Arguments:
diff --git a/src/concrete/ml/torch/hybrid_model.py b/src/concrete/ml/torch/hybrid_model.py
index 4586297b5..3b54bca42 100644
--- a/src/concrete/ml/torch/hybrid_model.py
+++ b/src/concrete/ml/torch/hybrid_model.py
@@ -539,7 +539,7 @@ def _save_fhe_circuit(self, path: Path, via_mlir=False):
         )
         model_dev.save(via_mlir=via_mlir)
 
-    def save_and_clear_private_info(self, path: Path, via_mlir=False):
+    def save_and_clear_private_info(self, path: Path, via_mlir=True):
         """Save the PyTorch model to the provided path and also saves the corresponding FHE circuit.
 
         Args:
diff --git a/use_case_examples/hybrid_model/compile_hybrid_llm.py b/use_case_examples/hybrid_model/compile_hybrid_llm.py
index d1e3cb6eb..100b0a33b 100644
--- a/use_case_examples/hybrid_model/compile_hybrid_llm.py
+++ b/use_case_examples/hybrid_model/compile_hybrid_llm.py
@@ -34,7 +34,7 @@ def compile_model(
     models_dir.mkdir(exist_ok=True)
     model_dir = models_dir / model_name
     print(f"Saving to {model_dir}")
-    via_mlir = bool(int(os.environ.get("VIA_MLIR", 0)))
+    via_mlir = bool(int(os.environ.get("VIA_MLIR", 1)))
     hybrid_model.save_and_clear_private_info(model_dir, via_mlir=via_mlir)
 
 
@@ -101,7 +101,9 @@ def module_names_parser(string: str) -> List[str]:
     max_context_size = 20
     num_samples = 50
 
-    dataset = load_dataset("wikipedia", "20220301.en")
+    dataset = load_dataset(
+        "wikimedia/wikipedia", "20231101.en", revision="b04c8d1ceb2f5cd4588862100d08de323dccfbaa"
+    )
     print(model)
     models_dir = Path(__file__).parent / os.environ.get("MODELS_DIR_NAME", "compiled_models")
     models_dir.mkdir(exist_ok=True)
diff --git a/use_case_examples/hybrid_model/requirements.txt b/use_case_examples/hybrid_model/requirements.txt
index 7522bdbb9..c20d74837 100644
--- a/use_case_examples/hybrid_model/requirements.txt
+++ b/use_case_examples/hybrid_model/requirements.txt
@@ -1,6 +1,7 @@
 concrete-ml
 accelerate
 datasets
+transformers
 apache_beam==2.49.0
 mwparserfromhell==0.6.4
 loguru==0.7.0
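
For reference, a minimal sketch of how the new defaults look from the caller's side. It follows the flow of compile_hybrid_llm.py, but the model name, remote module name, n_bits value and calibration shape below are illustrative assumptions, not values taken from this patch:

    import os
    from pathlib import Path

    import torch
    from transformers import AutoModelForCausalLM
    from concrete.ml.torch.hybrid_model import HybridFHEModel

    # Illustrative model and remote module; compile_hybrid_llm.py selects these via CLI arguments.
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    hybrid_model = HybridFHEModel(model, ["transformer.h.0.mlp"])

    # Calibration inputs; the shape and n_bits are assumptions for this sketch only.
    inputs = torch.randint(0, model.config.vocab_size, (4, 20))
    hybrid_model.compile_model(inputs, n_bits=8)

    # VIA_MLIR now defaults to 1, and save_and_clear_private_info() exports via MLIR by
    # default; set VIA_MLIR=0 to restore the previous serialization path.
    via_mlir = bool(int(os.environ.get("VIA_MLIR", 1)))
    hybrid_model.save_and_clear_private_info(Path("compiled_models") / "gpt2", via_mlir=via_mlir)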
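
The dataset change pins both the config and a specific revision of wikimedia/wikipedia, so the calibration samples stay reproducible even if the hub dataset is updated. A small sketch of the pinned load; the split and streaming flag are assumptions added to keep the example light, not part of the patch:

    from datasets import load_dataset

    # Same repository, config and revision as in compile_hybrid_llm.py above.
    dataset = load_dataset(
        "wikimedia/wikipedia",
        "20231101.en",
        revision="b04c8d1ceb2f5cd4588862100d08de323dccfbaa",
        split="train",    # assumption: the script's split choice may differ
        streaming=True,   # assumption: avoids downloading the full dump for a quick check
    )
    print(next(iter(dataset))["text"][:200])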