Skip to content

Commit

Permalink
chore: improvement hybrid deployment (#766)
Browse files Browse the repository at this point in the history
  • Loading branch information
jfrery authored Jun 27, 2024
1 parent 829b68b commit 94feaaf
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/concrete/ml/deployment/fhe_client_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def _export_model_to_json(self, is_training: bool = False) -> Path:

return json_path

def save(self, mode: DeploymentMode = DeploymentMode.INFERENCE, via_mlir: bool = False):
def save(self, mode: DeploymentMode = DeploymentMode.INFERENCE, via_mlir: bool = True):
"""Export all needed artifacts for the client and server.
Arguments:
Expand Down
2 changes: 1 addition & 1 deletion src/concrete/ml/torch/hybrid_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ def _save_fhe_circuit(self, path: Path, via_mlir=False):
)
model_dev.save(via_mlir=via_mlir)

def save_and_clear_private_info(self, path: Path, via_mlir=False):
def save_and_clear_private_info(self, path: Path, via_mlir=True):
"""Save the PyTorch model to the provided path and also saves the corresponding FHE circuit.
Args:
Expand Down
6 changes: 4 additions & 2 deletions use_case_examples/hybrid_model/compile_hybrid_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def compile_model(
models_dir.mkdir(exist_ok=True)
model_dir = models_dir / model_name
print(f"Saving to {model_dir}")
via_mlir = bool(int(os.environ.get("VIA_MLIR", 0)))
via_mlir = bool(int(os.environ.get("VIA_MLIR", 1)))
hybrid_model.save_and_clear_private_info(model_dir, via_mlir=via_mlir)


Expand Down Expand Up @@ -101,7 +101,9 @@ def module_names_parser(string: str) -> List[str]:
max_context_size = 20
num_samples = 50

dataset = load_dataset("wikipedia", "20220301.en")
dataset = load_dataset(
"wikimedia/wikipedia", "20231101.en", revision="b04c8d1ceb2f5cd4588862100d08de323dccfbaa"
)
print(model)
models_dir = Path(__file__).parent / os.environ.get("MODELS_DIR_NAME", "compiled_models")
models_dir.mkdir(exist_ok=True)
Expand Down
1 change: 1 addition & 0 deletions use_case_examples/hybrid_model/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
concrete-ml
accelerate
datasets
transformers
apache_beam==2.49.0
mwparserfromhell==0.6.4
loguru==0.7.0

0 comments on commit 94feaaf

Please sign in to comment.