
Commit

chore: fix coverage
+ fix wrong Conv1D transformer import
jfrery committed Nov 25, 2024
1 parent 7d69721 commit 046c412
Showing 2 changed files with 97 additions and 127 deletions.
20 changes: 10 additions & 10 deletions src/concrete/ml/torch/hybrid_model.py
@@ -69,7 +69,7 @@ def convert_conv1d_to_linear(layer_or_module):
or the Conv1D layer converted to a Linear layer.
"""
try:
from transformers import Conv1D # pylint: disable=import-outside-toplevel
from transformers.modeling_utils import Conv1D # pylint: disable=import-outside-toplevel
except ImportError: # pragma: no cover
return layer_or_module

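The import guard above degrades gracefully: when transformers is not installed, the layer is returned unchanged instead of raising. A minimal, self-contained sketch of the same optional-dependency pattern, assuming transformers' Conv1D convention of storing its weight transposed relative to nn.Linear (the function name is illustrative, not the library's convert_conv1d_to_linear):

```python
import torch


def convert_if_possible(layer):
    """Return an equivalent nn.Linear for a transformers Conv1D, else the layer."""
    try:
        # Imported lazily so the package still works without transformers
        from transformers.modeling_utils import Conv1D
    except ImportError:  # optional dependency missing: leave the layer as-is
        return layer

    if isinstance(layer, Conv1D):
        # Conv1D stores weight as (in_features, out_features), the transpose
        # of nn.Linear's (out_features, in_features) layout
        in_features, out_features = layer.weight.shape
        linear = torch.nn.Linear(in_features, out_features)
        linear.weight.data = layer.weight.data.T
        linear.bias.data = layer.bias.data
        return linear
    return layer
```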
@@ -412,13 +412,14 @@ def _replace_modules(self):
if is_pure_linear_layer:
module = self.private_modules[module_name]
# Use weight shape instead of in/out_features
if hasattr(module, "weight"):
input_dim = module.weight.shape[
1
] # Input dimension is second dimension for Linear layers
output_dim = module.weight.shape[0] # Output dimension is first dimension
else:
input_dim = output_dim = 0
input_dim, output_dim = (
(
module.weight.shape[1],
module.weight.shape[0],
)
if hasattr(module, "weight")
else (0, 0)
)

is_pure_linear_layer = (
is_pure_linear_layer and input_dim >= 512 and output_dim >= 512
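The refactor collapses the if/else into a single conditional expression while keeping the same shape convention: torch.nn.Linear stores its weight as (out_features, in_features). A quick standalone check of that convention (layer sizes are illustrative):

```python
import torch

# torch.nn.Linear stores weight as (out_features, in_features):
# shape[1] is the input dimension, shape[0] the output dimension
layer = torch.nn.Linear(in_features=768, out_features=3072)
assert layer.weight.shape == (3072, 768)

input_dim, output_dim = (
    (layer.weight.shape[1], layer.weight.shape[0])
    if hasattr(layer, "weight")
    else (0, 0)
)
assert (input_dim, output_dim) == (768, 3072)
```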
@@ -474,8 +475,7 @@ def forward(self, x: torch.Tensor, fhe: str = "disable") -> torch.Tensor:

if fhe_mode in (HybridFHEMode.EXECUTE, HybridFHEMode.REMOTE, HybridFHEMode.DISABLE):
# Initialize executor only if not already done
if self.executor is None:
self.executor = GLWELinearLayerExecutor()
self.executor = self.executor or GLWELinearLayerExecutor()

# Generate keys only if needed and not already done
if fhe_mode != HybridFHEMode.DISABLE and self.executor.private_key is None:
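The last hunk replaces the explicit None check with `self.executor = self.executor or GLWELinearLayerExecutor()`; both forms create the executor only on first use. A minimal sketch of the idiom (the class here is an empty stand-in, not the real executor):

```python
class GLWELinearLayerExecutor:  # illustrative stand-in, not the real class
    pass


executor = None
executor = executor or GLWELinearLayerExecutor()  # first use: instantiate
first = executor
executor = executor or GLWELinearLayerExecutor()  # later uses: keep the same object
assert executor is first
```

Note that `x = x or Factory()` is only equivalent to an `is None` check when the cached value is always truthy, as an executor instance is; for values that can legitimately be falsy (0, empty containers), the explicit `is None` test is safer.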
204 changes: 87 additions & 117 deletions tests/torch/test_hybrid_converter.py
@@ -37,166 +37,138 @@ def test_tuple_serialization(tup):
assert tup == underscore_str_to_tuple(tuple_to_underscore_str(tup))


# pylint: disable=too-many-locals, too-many-branches, too-many-statements
def setup_test_environment():
"""Save the original state of critical modules for restoration."""
original_modules = {}
for module_name in [
"transformers",
"concrete_ml_extensions",
"concrete.ml.quantization.linear_op_glwe_backend",
"concrete.ml.torch.hybrid_model",
]:
original_modules[module_name] = sys.modules.get(module_name)
return original_modules


# pylint: disable=too-many-arguments, too-many-locals, too-many-statements, too-many-branches
def run_hybrid_llm_test(
model: torch.nn.Module,
inputs: torch.Tensor,
module_names: Union[str, List],
expected_accuracy,
module_names: Union[str, List[str]],
expected_accuracy: float,
has_pbs: bool,
has_pbs_reshape: bool,
monkeypatch,
transformers_installed,
glwe_backend_installed,
transformers_installed: bool,
glwe_backend_installed: bool,
monkeypatch: pytest.MonkeyPatch,
):
"""Run the test for any model with its private module names."""

# Multi-parameter strategy is used in order to speed up the FHE executions
# Configure the model
configuration = Configuration(
single_precision=False,
compress_input_ciphertexts=True,
)

logits_simulate = None

with monkeypatch.context() as m:
if not transformers_installed:
m.setitem(sys.modules, "transformers", None)
if has_pbs_reshape:
has_pbs = True

# Patching for GLWE backend
if not glwe_backend_installed:
m.setitem(sys.modules, "concrete_ml_extensions", None)

# Reload the affected modules to ensure the changes take effect
importlib.reload(concrete.ml.quantization.linear_op_glwe_backend)
importlib.reload(concrete.ml.torch.hybrid_model)

hybrid_model = HybridFHEModel(model, module_names)
is_compiled = False
try:
hybrid_model.compile_model(
inputs,
p_error=10e-40, # compare precisely simulate and disable
n_bits=9,
rounding_threshold_bits=8,
configuration=configuration,
)
is_compiled = True
except RuntimeError as error:
# When reshaping adds PBSs we sometimes encounter NoParametersFound
# when compiling. In this case we skip the rest since we can't simulate
# without compilation.
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4183
assert "NoParametersFound" in error.args[0]
pytest.skip(error.args[0])

# Check we can run the simulate locally
if has_pbs or not glwe_backend_installed:
logits_simulate = hybrid_model(inputs, fhe="simulate").logits
# Mock sys.modules to simulate missing modules
if not transformers_installed:
monkeypatch.setitem(sys.modules, "transformers", None)
if not glwe_backend_installed:
monkeypatch.setitem(sys.modules, "concrete_ml_extensions", None)

# Reload affected modules after mocking
importlib.reload(concrete.ml.quantization.linear_op_glwe_backend)
importlib.reload(concrete.ml.torch.hybrid_model)

# Initialize and compile the hybrid model
hybrid_model = HybridFHEModel(model, module_names)

try:
hybrid_model.compile_model(
inputs,
p_error=10e-40,
n_bits=9,
rounding_threshold_bits=8,
configuration=configuration,
)
except RuntimeError as error:
# Skip test if NoParametersFound error occurs
if "NoParametersFound" in str(error):
pytest.skip(str(error))
else:
with pytest.raises(AssertionError, match=".*fhe=simulate is not supported.*"):
hybrid_model(inputs, fhe="simulate")
raise

if has_pbs:
# Check for non-zero programmable bootstrapping
for module in hybrid_model.private_q_modules.values():
assert module.fhe_circuit.statistics["programmable_bootstrap_count"] > 0, (
"Programmable bootstrap count should be greater than 0, "
f"but found {module.fhe_circuit.statistics['programmable_bootstrap_count']}"
)
# Run the model in different modes
logits_simulate = None
if has_pbs or not glwe_backend_installed:
logits_simulate = hybrid_model(inputs, fhe="simulate").logits
else:
# Check for zero programmable bootstrapping
for module in hybrid_model.private_q_modules.values():
# The RemoteModule does not have a circuit if it was optimized
# (in the case of pure linear remote modules)
assert (
not module.fhe_circuit
or module.fhe_circuit.statistics["programmable_bootstrap_count"] == 0
), (
"Programmable bootstrap count should be 0, "
f"but found {module.fhe_circuit.statistics['programmable_bootstrap_count']}"
)
with pytest.raises(AssertionError, match=".*fhe=simulate is not supported.*"):
hybrid_model(inputs, fhe="simulate")

logits_disable = hybrid_model(inputs, fhe="disable").logits
logits_original = hybrid_model(inputs, fhe="torch").logits

# Compare the topk accuracy of the FHE simulate circuit vs. the original.
k = 5

# Check that the topk next tokens are similar for the different FHE modes
# and the original model.
# Check programmable bootstrap counts if not glwe backend
if not glwe_backend_installed:
for module in hybrid_model.private_q_modules.values():
pbs_count = module.fhe_circuit.statistics.get("programmable_bootstrap_count", 0)
if has_pbs:
assert pbs_count > 0, "Expected programmable bootstrap count > 0"
else:
assert pbs_count == 0, "Expected programmable bootstrap count == 0"

# Get the topk indices for logits_disable and logits_simulate
# Compare top-k accuracy
k = 5
topk_disable = logits_disable.topk(k, dim=-1).indices
topk_original = logits_original.topk(k, dim=-1).indices

# Compute accuracy of disable and simulate by checking
# how many labels correspond with the topk_original
accuracy_disable = (topk_disable == topk_original).float().mean().item()
# Ensure logits_disable and logits_original return the same output for the logits
# Assert that both accuracy values are above the expected threshold
assert (
accuracy_disable >= expected_accuracy
), f"Disable accuracy {accuracy_disable:.4f} is below the expected {expected_accuracy:.4f}"
), f"Disable accuracy {accuracy_disable:.4f} is below expected {expected_accuracy:.4f}"

if logits_simulate is not None:
assert torch.allclose(logits_disable, logits_simulate, atol=1e-7), "Outputs do not match!"
assert torch.allclose(logits_disable, logits_simulate, atol=1e-7)
topk_simulate = logits_simulate.topk(k, dim=-1).indices
accuracy_simulate = (topk_simulate == topk_original).float().mean().item()
assert accuracy_simulate >= expected_accuracy, (
f"Simulate accuracy {accuracy_simulate:.4f} is below "
f"the expected {expected_accuracy:.4f}"
)
assert (
accuracy_simulate >= expected_accuracy
), f"Simulate accuracy {accuracy_simulate:.4f} is below expected {expected_accuracy:.4f}"

# Test model saving and deployment
with tempfile.TemporaryDirectory() as temp_dir:
temp_dir_path = Path(temp_dir)

# Get the temp directory path

if not has_pbs and glwe_backend_installed:

if is_compiled:
# Deployment of GLWE backend hybrid models is not yet supported
with pytest.raises(
NotImplementedError, match="GLWE backend deployment is not yet supported"
):
hybrid_model.save_and_clear_private_info(temp_dir_path)
else:
# Check that we get an error when trying to save a non-compiled model
with pytest.raises(
AttributeError,
match="The quantized module is not compiled. Please run compile*",
):
hybrid_model.save_and_clear_private_info(temp_dir_path)
with pytest.raises(
NotImplementedError, match="GLWE backend deployment is not yet supported"
):
hybrid_model.save_and_clear_private_info(temp_dir_path)
else:
hybrid_model.save_and_clear_private_info(temp_dir_path)
# If transformers is not installed, skip the saving test
if not transformers_installed:
pytest.skip("Skipping save test as transformers module is not available")

hybrid_model.save_and_clear_private_info(temp_dir_path)
hybrid_model.set_fhe_mode("remote")

# At this point, the hybrid model does not have
# the parameters necessary to run the module_names
module_names = module_names if isinstance(module_names, list) else [module_names]

# Check that files are there
# Verify saved files
assert (temp_dir_path / "model.pth").exists()
for module_name in module_names:
module_dir_path = temp_dir_path / module_name
module_dir_files = set(str(elt.name) for elt in module_dir_path.glob("**/*"))
for file_name in ["client.zip", "server.zip"]:
assert file_name in module_dir_files
module_names_list = module_names if isinstance(module_names, list) else [module_names]
for module_name in module_names_list:
module_dir = temp_dir_path / module_name
files = {file.name for file in module_dir.glob("**/*")}
assert "client.zip" in files and "server.zip" in files


# Dependency 'huggingface-hub' raises a 'FutureWarning' from version 0.23.0 when calling the
# 'from_pretrained' method
@pytest.mark.filterwarnings("ignore::FutureWarning")
@pytest.mark.parametrize(
"list_or_str_private_modules_names, expected_accuracy, has_pbs, has_pbs_reshape",
"list_or_str_private_modules_names, expected_accuracy, has_pbs",
[
("transformer.h.0.mlp", 0.95, True, False),
(["transformer.h.0.mlp", "transformer.h.1.mlp"], 0.40, True, False),
("transformer.h.0.mlp.c_fc", 1.0, False, True),
("transformer.h.0.mlp", 0.95, True),
(["transformer.h.0.mlp", "transformer.h.1.mlp"], 0.40, True),
("transformer.h.0.mlp.c_fc", 1.0, False),
],
)
@pytest.mark.parametrize("transformers_installed", [True, False])
@@ -205,7 +177,6 @@ def test_gpt2_hybrid_mlp(
list_or_str_private_modules_names,
expected_accuracy,
has_pbs,
has_pbs_reshape,
transformers_installed,
glwe_backend_installed,
monkeypatch,
@@ -227,10 +198,9 @@ def test_gpt2_hybrid_mlp(
list_or_str_private_modules_names,
expected_accuracy,
has_pbs,
has_pbs_reshape,
monkeypatch,
transformers_installed,
glwe_backend_installed,
monkeypatch,
)


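The missing-dependency simulation in the rewritten test relies on two moves: shadowing a module in sys.modules with None (a None entry makes a subsequent import raise ImportError) and reloading the modules that performed the guarded import at load time. A sketch of the pattern, assuming a hypothetical mypkg.feature module that runs a guarded `import transformers` at import time:

```python
import importlib
import sys


def test_feature_without_transformers(monkeypatch):
    # A None entry in sys.modules makes `import transformers`
    # raise ImportError instead of searching the filesystem
    monkeypatch.setitem(sys.modules, "transformers", None)

    # Reload so the module-level import guard re-runs under the mock
    import mypkg.feature
    importlib.reload(mypkg.feature)

    ...  # exercise the fallback path; monkeypatch restores sys.modules at teardown
```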
