From 829b68b5b993049f5e177c1cb94d0c753ea3e80b Mon Sep 17 00:00:00 2001 From: Luis Montero Date: Tue, 25 Jun 2024 14:01:10 +0200 Subject: [PATCH] fix: dynamic import of `transformers` in hybrid model We currently have an import of `transformers` in the hybrid model source code but we don't have the library as a requirement, only a development dependency. Importing the library dynamically should fix this. --- src/concrete/ml/torch/hybrid_model.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/concrete/ml/torch/hybrid_model.py b/src/concrete/ml/torch/hybrid_model.py index e1b5e4e17..4586297b5 100644 --- a/src/concrete/ml/torch/hybrid_model.py +++ b/src/concrete/ml/torch/hybrid_model.py @@ -18,7 +18,6 @@ from brevitas.quant_tensor import QuantTensor from concrete.fhe import Configuration from torch import nn -from transformers import Conv1D from ..common.utils import MAX_BITWIDTH_BACKWARD_COMPATIBLE from ..deployment.fhe_client_server import FHEModelClient, FHEModelDev, FHEModelServer @@ -76,6 +75,11 @@ def convert_conv1d_to_linear(layer_or_module): nn.Module or nn.Linear: The updated module with Conv1D layers converted to Linear layers, or the Conv1D layer converted to a Linear layer. """ + try: + from transformers import Conv1D # pylint: disable=import-outside-toplevel + except ImportError: # pragma: no cover + return layer_or_module + if isinstance(layer_or_module, Conv1D): # Get the weight size weight_size = layer_or_module.weight.size()