Skip to content

Commit

Permalink
chore: fix coverage
Browse files Browse the repository at this point in the history
+ fix wrong Conv1D transformer import
  • Loading branch information
jfrery committed Nov 25, 2024
1 parent ef2127d commit 7ced6f4
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 142 deletions.
34 changes: 24 additions & 10 deletions src/concrete/ml/quantization/linear_op_glwe_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,18 @@
from ..common.utils import HybridFHEMode, to_tuple
from .quantized_module import QuantizedModule

try:
import concrete_ml_extensions as fhext

_HAS_GLWE_BACKEND = True
except ImportError: # pragma: no cover
fhext = None
_HAS_GLWE_BACKEND = False
def has_glwe_backend():
"""Check if the GLWE backend is installed.
Returns:
bool: True if the GLWE backend is installed, False otherwise.
"""
try:
__import__("concrete_ml_extensions")
return True
except ImportError:
return False


class GLWELinearLayerExecutor:
Expand All @@ -24,6 +29,13 @@ def __init__(
private_key=None,
compression_key=None,
):
if not has_glwe_backend():
raise RuntimeError("GLWE backend not installed")

import concrete_ml_extensions as fhext

self.fhext = fhext

self.compression_key = compression_key
self.private_key = private_key

Expand All @@ -39,7 +51,9 @@ def __init__(
def keygen(self):
"""Generate private and compression key."""
# pylint: disable-next=no-member
self.private_key, self.compression_key = fhext.create_private_key(self.glwe_crypto_params)
self.private_key, self.compression_key = self.fhext.create_private_key(
self.glwe_crypto_params
)

def forward(
self, x: numpy.ndarray, q_module: QuantizedModule, fhe: HybridFHEMode
Expand Down Expand Up @@ -123,15 +137,15 @@ def forward(

for idx, q_x_sample in enumerate(q_x):

ciphertext = fhext.encrypt_matrix( # pylint: disable=no-member
ciphertext = self.fhext.encrypt_matrix( # pylint: disable=no-member
pkey=self.private_key, crypto_params=self.glwe_crypto_params, data=q_x_sample
)
encrypted_result = fhext.matrix_multiplication( # pylint: disable=no-member
encrypted_result = self.fhext.matrix_multiplication( # pylint: disable=no-member
encrypted_matrix=ciphertext,
data=q_weight.astype(numpy.uint64),
compression_key=self.compression_key,
)
q_result = fhext.decrypt_matrix( # pylint: disable=no-member
q_result = self.fhext.decrypt_matrix( # pylint: disable=no-member
encrypted_result,
self.private_key,
self.glwe_crypto_params,
Expand Down
26 changes: 13 additions & 13 deletions src/concrete/ml/torch/hybrid_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from ..common.utils import MAX_BITWIDTH_BACKWARD_COMPATIBLE, HybridFHEMode
from ..deployment.fhe_client_server import FHEModelClient, FHEModelDev, FHEModelServer
from ..quantization.linear_op_glwe_backend import _HAS_GLWE_BACKEND, GLWELinearLayerExecutor
from ..quantization.linear_op_glwe_backend import GLWELinearLayerExecutor, has_glwe_backend
from .compile import (
QuantizedModule,
build_quantized_module,
Expand Down Expand Up @@ -69,7 +69,7 @@ def convert_conv1d_to_linear(layer_or_module):
or the Conv1D layer converted to a Linear layer.
"""
try:
from transformers import Conv1D # pylint: disable=import-outside-toplevel
from transformers.modeling_utils import Conv1D # pylint: disable=import-outside-toplevel
except ImportError: # pragma: no cover
return layer_or_module

Expand Down Expand Up @@ -412,13 +412,14 @@ def _replace_modules(self):
if is_pure_linear_layer:
module = self.private_modules[module_name]
# Use weight shape instead of in/out_features
if hasattr(module, "weight"):
input_dim = module.weight.shape[
1
] # Input dimension is second dimension for Linear layers
output_dim = module.weight.shape[0] # Output dimension is first dimension
else:
input_dim = output_dim = 0
input_dim, output_dim = (
(
module.weight.shape[1],
module.weight.shape[0],
)
if hasattr(module, "weight")
else (0, 0)
)

is_pure_linear_layer = (
is_pure_linear_layer and input_dim >= 512 and output_dim >= 512
Expand Down Expand Up @@ -465,7 +466,7 @@ def forward(self, x: torch.Tensor, fhe: str = "disable") -> torch.Tensor:
# Validate the FHE mode
fhe_mode = HybridFHEMode(fhe)

if _HAS_GLWE_BACKEND and self._has_only_large_linear_layers:
if has_glwe_backend() and self._has_only_large_linear_layers:
if fhe_mode == HybridFHEMode.SIMULATE:
raise AssertionError(
"When the HybridFHEModel is instantiated with only "
Expand All @@ -474,8 +475,7 @@ def forward(self, x: torch.Tensor, fhe: str = "disable") -> torch.Tensor:

if fhe_mode in (HybridFHEMode.EXECUTE, HybridFHEMode.REMOTE, HybridFHEMode.DISABLE):
# Initialize executor only if not already done
if self.executor is None:
self.executor = GLWELinearLayerExecutor()
self.executor = self.executor or GLWELinearLayerExecutor()

# Generate keys only if needed and not already done
if fhe_mode != HybridFHEMode.DISABLE and self.executor.private_key is None:
Expand Down Expand Up @@ -589,7 +589,7 @@ def compile_model(
# If all layers are linear and the GLWE backend is available
# then simply quantize the model without compiling with
# Concrete Python.
if self._has_only_large_linear_layers and _HAS_GLWE_BACKEND:
if self._has_only_large_linear_layers and has_glwe_backend():
self.executor = GLWELinearLayerExecutor()
self.private_q_modules[name] = build_quantized_module(
self.private_modules[name],
Expand Down
Loading

0 comments on commit 7ced6f4

Please sign in to comment.