Skip to content

Commit

Permalink
chore: update format
Browse files Browse the repository at this point in the history
  • Loading branch information
RomanBredehoft committed Apr 16, 2024
1 parent c1ea4c6 commit 93f1498
Show file tree
Hide file tree
Showing 97 changed files with 135 additions and 40 deletions.
2 changes: 1 addition & 1 deletion benchmarks/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def main():
# Listing
if args.long_list or args.short_list:
already_done_models = {}
for (dataset_i, model_class_i, config_i) in all_classification_tasks:
for dataset_i, model_class_i, config_i in all_classification_tasks:
config_n = json.dumps(config_i).replace("'", '"')
model_name_i = model_class_i.__name__

Expand Down
5 changes: 3 additions & 2 deletions benchmarks/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def run_and_report_classification_metrics(
(f1_score, "f1", "F1Score"),
]

for (metric, metric_id, metric_label) in metric_info:
for metric, metric_id, metric_label in metric_info:
run_and_report_metric(
y_gt,
y_pred,
Expand All @@ -288,7 +288,7 @@ def run_and_report_regression_metrics(y_gt, y_pred, metric_id_prefix, metric_lab
"""Run several metrics and report results to progress tracker with computed name and id"""

metric_info = [(r2_score, "r2_score", "R2Score"), (mean_squared_error, "MSE", "MSE")]
for (metric, metric_id, metric_label) in metric_info:
for metric, metric_id, metric_label in metric_info:
run_and_report_metric(
y_gt,
y_pred,
Expand Down Expand Up @@ -768,6 +768,7 @@ def benchmark_name_generator(
# - The functions support all models
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/1866


# pylint: disable-next=too-many-branches, redefined-outer-name
def benchmark_name_to_config(
benchmark_name: str, joiner: str = "_"
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/deep_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,7 +886,7 @@ def main():
if args.long_list or args.short_list:
# Print the short or long lists if asked and stop
printed_models = set()
for (dataset, cnn_class, config) in all_tasks:
for dataset, cnn_class, config in all_tasks:
configs = json.dumps(config).replace("'", '"')
cnn_name = cnn_class.__name__

Expand Down
10 changes: 8 additions & 2 deletions benchmarks/glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,14 +113,20 @@ def get_preprocessor() -> ColumnTransformer:

def get_train_test_data(data: pandas.DataFrame) -> Tuple[pandas.DataFrame, pandas.DataFrame]:
"""Split the data into a train and test set."""
train_data, test_data, = train_test_split(
(
train_data,
test_data,
) = train_test_split(
data,
test_size=0.2,
random_state=0,
)

# The test set is reduced for faster FHE runs.
_, test_data, = train_test_split(
(
_,
test_data,
) = train_test_split(
test_data,
test_size=500,
random_state=0,
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def main():
# Listing
if args.long_list or args.short_list:
already_done_models = {}
for (dataset_i, model_class_i, config_i) in all_tasks:
for dataset_i, model_class_i, config_i in all_tasks:
config_n = json.dumps(config_i).replace("'", '"')
model_name_i = model_class_i.__name__

Expand Down
1 change: 1 addition & 0 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""PyTest configuration file."""

import hashlib
import json
import random
Expand Down
1 change: 1 addition & 0 deletions docker/release_resources/sanity_check.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Sanity checks, to be sure that our package is usable"""

import argparse
import random
import shutil
Expand Down
8 changes: 5 additions & 3 deletions docs/advanced_examples/KNearestNeighbors.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -392,9 +392,11 @@
"def highlight_diff(row):\n",
" \"\"\"Custom style function to highlight mismatched predictions.\"\"\"\n",
" return [\n",
" \"background-color: yellow\"\n",
" if row[\"Majority vote (Concrete ML)\"] != row[\"Majority vote (scikit-learn)\"]\n",
" else \"\"\n",
" (\n",
" \"background-color: yellow\"\n",
" if row[\"Majority vote (Concrete ML)\"] != row[\"Majority vote (scikit-learn)\"]\n",
" else \"\"\n",
" )\n",
" ] * len(row)\n",
"\n",
"\n",
Expand Down
1 change: 1 addition & 0 deletions script/actions_utils/escape_quotes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Script to escape double quotes within brackets or curly braces."""

import argparse

parser = argparse.ArgumentParser(description="Escape double quotes in a string")
Expand Down
1 change: 1 addition & 0 deletions script/actions_utils/generate_scripts_benchmark.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Script to generate the list of commands to run all benchmarks"""

import argparse
import datetime
import json
Expand Down
1 change: 1 addition & 0 deletions script/actions_utils/json_length.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Script to evaluate the length of a json file"""

import argparse
import json
from pathlib import Path
Expand Down
1 change: 1 addition & 0 deletions script/actions_utils/monitor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Module to generate figure of evolution of Concrete ML-CI time on main for last 4 weeks."""

import argparse
import datetime
import json
Expand Down
1 change: 1 addition & 0 deletions script/actions_utils/pytest_failed_test_report.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Pytest JSON report on failed tests utils."""

import argparse
import json
from pathlib import Path
Expand Down
1 change: 1 addition & 0 deletions script/actions_utils/refresh_notebooks_list.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Update the list of available notebooks for the refresh_one_notebook GitHib action."""

import argparse
from pathlib import Path

Expand Down
1 change: 1 addition & 0 deletions script/actions_utils/run_commands.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Script to run commands from a json file"""

import argparse
import json
import subprocess
Expand Down
1 change: 1 addition & 0 deletions script/doc_utils/gen_supported_ops.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Update list of supported functions in the doc."""

import argparse
from pathlib import Path

Expand Down
1 change: 1 addition & 0 deletions script/make_utils/actionlint_check_with_whitelists.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
""" Check an actionlint log against some whitelists """

import sys
from typing import Set

Expand Down
1 change: 1 addition & 0 deletions script/make_utils/check_headers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Check that headers linked do indeed exist in target markdown files"""

import os
from pathlib import Path

Expand Down
1 change: 1 addition & 0 deletions script/make_utils/check_issues.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Check linked github issues states"""

import json
import re
import subprocess
Expand Down
1 change: 1 addition & 0 deletions script/nbmake_utils/notebook_finalize.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Finalize Jupyter notebooks."""

import argparse
import json
from pathlib import Path
Expand Down
1 change: 1 addition & 0 deletions src/concrete/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Top level import."""

# Do not modify, this is to have a compatible namespace package
# https://packaging.python.org/en/latest/guides/packaging-namespace-packages/
# #pkg-resources-style-namespace-packages
Expand Down
1 change: 1 addition & 0 deletions src/concrete/ml/common/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
"""Module for shared data structures and code."""

from . import check_inputs, debugging, utils
1 change: 1 addition & 0 deletions src/concrete/ml/common/debugging/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
"""Module for debugging."""

from .custom_assert import assert_false, assert_not_reached, assert_true
1 change: 1 addition & 0 deletions src/concrete/ml/common/debugging/custom_assert.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Provide some variants of assert."""

from typing import Type


Expand Down
1 change: 1 addition & 0 deletions src/concrete/ml/common/serialization/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Serialization module."""

import os

from torch.nn.modules import activation
Expand Down
1 change: 1 addition & 0 deletions src/concrete/ml/common/serialization/decoder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Custom decoder for serialization."""

import inspect
import json
from typing import Any, Dict, Type
Expand Down
1 change: 1 addition & 0 deletions src/concrete/ml/common/serialization/dumpers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Dump functions for serialization."""

import json
from typing import Any, TextIO

Expand Down
1 change: 1 addition & 0 deletions src/concrete/ml/common/serialization/encoder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Custom encoder for serialization."""

import inspect
import json
from json.encoder import _make_iterencode # type: ignore[attr-defined]
Expand Down
1 change: 1 addition & 0 deletions src/concrete/ml/common/serialization/loaders.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Load functions for serialization."""

import json
from typing import IO, Any, Union

Expand Down
1 change: 1 addition & 0 deletions src/concrete/ml/deployment/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
"""Module for deployment of the FHE model."""

from .fhe_client_server import FHEModelClient, FHEModelDev, FHEModelServer
1 change: 1 addition & 0 deletions src/concrete/ml/deployment/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
- Check if connection possible
- Wait for connection to be available (with timeout)
"""

import subprocess
import time
from pathlib import Path
Expand Down
1 change: 1 addition & 0 deletions src/concrete/ml/onnx/onnx_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Utils to interpret an ONNX model with numpy."""

# Utils to interpret an ONNX model with numpy.


Expand Down
1 change: 1 addition & 0 deletions src/concrete/ml/pandas/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Public API for encrypted data-frames."""

from pathlib import Path
from typing import Hashable, Optional, Sequence, Tuple, Union

Expand Down
1 change: 1 addition & 0 deletions src/concrete/ml/pandas/_development.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Define development methods for generating client/server files."""

import itertools
from functools import partial
from pathlib import Path
Expand Down
1 change: 1 addition & 0 deletions src/concrete/ml/pandas/_operators.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Implement Pandas operators in FHE using encrypted data-frames."""

from typing import Any, Dict, Hashable, List, Optional, Sequence, Tuple, Union

import numpy
Expand Down
1 change: 1 addition & 0 deletions src/concrete/ml/pandas/_processing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Define pre-processing and post-processing steps for encrypted data-frames."""

import copy
from collections import defaultdict
from typing import Dict, List, Tuple
Expand Down
1 change: 1 addition & 0 deletions src/concrete/ml/pandas/_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Define utility functions for encrypted data-frames."""

import functools
from typing import List, Optional, Tuple, Union

Expand Down
1 change: 1 addition & 0 deletions src/concrete/ml/pandas/client_engine.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Define the framework used for managing keys (encrypt, decrypt) for encrypted data-frames."""

from pathlib import Path
from typing import Optional, Union

Expand Down
1 change: 1 addition & 0 deletions src/concrete/ml/pandas/dataframe.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Define the encrypted data-frame framework."""

import json
from pathlib import Path
from typing import Dict, Hashable, List, Optional, Sequence, Tuple, Union
Expand Down
1 change: 1 addition & 0 deletions src/concrete/ml/pytest/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
"""Module which is used to contain common functions for pytest."""

from . import torch_models, utils
4 changes: 1 addition & 3 deletions src/concrete/ml/pytest/torch_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,9 +867,7 @@ def __init__(self, input_output, activation_function, n_bits=2, disable_bit_chec
n_bits_weights = n_bits

# Generate the pattern 0, 1, ..., 2^N-1, 0, 1, .. 2^N-1, 0, 1..
all_weights = numpy.mod(
numpy.arange(numpy.prod(self.fc1.weight.shape)), 2**n_bits_weights
)
all_weights = numpy.mod(numpy.arange(numpy.prod(self.fc1.weight.shape)), 2**n_bits_weights)

# Shuffle the pattern and reshape to weight shape
numpy.random.shuffle(all_weights)
Expand Down
1 change: 1 addition & 0 deletions src/concrete/ml/pytest/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Common functions or lists for test files, which can't be put in fixtures."""

import copy
import io
from functools import partial
Expand Down
1 change: 1 addition & 0 deletions src/concrete/ml/quantization/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Modules for quantization."""

from .base_quantized_op import QuantizedOp
from .post_training import (
PostTrainingAffineQuantization,
Expand Down
12 changes: 6 additions & 6 deletions src/concrete/ml/quantization/base_quantized_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,14 +237,14 @@ def dump_dict(self) -> Dict:
metadata["_input_idx_to_params_name"] = self._input_idx_to_params_name
metadata["_params_that_are_onnx_inputs"] = self._params_that_are_onnx_inputs
metadata["_params_that_are_onnx_var_inputs"] = self._params_that_are_onnx_var_inputs
metadata[
"_params_that_are_required_onnx_inputs"
] = self._params_that_are_required_onnx_inputs
metadata["_params_that_are_required_onnx_inputs"] = (
self._params_that_are_required_onnx_inputs
)
metadata["_has_attr"] = self._has_attr
metadata["_inputs_not_quantized"] = self._inputs_not_quantized
metadata[
"quantize_inputs_with_model_outputs_precision"
] = self.quantize_inputs_with_model_outputs_precision
metadata["quantize_inputs_with_model_outputs_precision"] = (
self.quantize_inputs_with_model_outputs_precision
)
metadata["produces_graph_output"] = self.produces_graph_output
metadata["produces_raw_output"] = self.produces_raw_output
metadata["error_tracker"] = self.error_tracker
Expand Down
1 change: 1 addition & 0 deletions src/concrete/ml/quantization/qat_quantizers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Custom Quantization Aware Training Brevitas quantizers."""

from brevitas.quant.scaled_int import (
IntQuant,
MaxStatsScaling,
Expand Down
9 changes: 5 additions & 4 deletions src/concrete/ml/quantization/quantized_module.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""QuantizedModule API."""

import copy
import re
from functools import partial
Expand Down Expand Up @@ -145,7 +146,7 @@ def set_reduce_sum_copy(self):
to copy the inputs with a PBS to avoid it.
"""
assert self.quant_layers_dict is not None
for (_, quantized_op) in self.quant_layers_dict.values():
for _, quantized_op in self.quant_layers_dict.values():
if isinstance(quantized_op, QuantizedReduceSum):
quantized_op.copy_inputs = True

Expand Down Expand Up @@ -369,10 +370,10 @@ def forward(
debug_value_tracker: Dict[
str, Dict[Union[int, str], Optional[ONNXOpInputOutputType]]
] = {}
for (_, layer) in self.quant_layers_dict.values():
for _, layer in self.quant_layers_dict.values():
layer.debug_value_tracker = debug_value_tracker
q_y_pred = self.quantized_forward(*q_x, fhe="disable")
for (_, layer) in self.quant_layers_dict.values():
for _, layer in self.quant_layers_dict.values():
layer.debug_value_tracker = None
# De-quantize the output predicted values
y_pred = self.dequantize_output(*to_tuple(q_y_pred))
Expand Down Expand Up @@ -767,7 +768,7 @@ def bitwidth_and_range_report(
return None

op_names_to_report: Dict[str, Dict[str, Union[Tuple[int, ...], int]]] = {}
for (_, op_inst) in self.quant_layers_dict.values():
for _, op_inst in self.quant_layers_dict.values():
# Get the value range of this tag and all its subtags
# The potential tags for this op start with the op instance name
# and are, sometimes, followed by a subtag starting with a period:
Expand Down
5 changes: 3 additions & 2 deletions src/concrete/ml/quantization/quantized_module_passes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Optimization passes for QuantizedModules."""

from collections import defaultdict
from typing import DefaultDict, Dict, List, Optional, Tuple

Expand Down Expand Up @@ -101,7 +102,7 @@ def compute_op_predecessors(self) -> PredecessorsType:
# Initialize the list of predecessors with tensors that are graph inputs
predecessors: PredecessorsType = defaultdict(list)

for (node_inputs, node_op) in self._qmodule.quant_layers_dict.values():
for node_inputs, node_op in self._qmodule.quant_layers_dict.values():
# The first input node contains the encrypted data
enc_input_node = node_inputs[0]

Expand Down Expand Up @@ -168,7 +169,7 @@ def detect_patterns(self, predecessors: PredecessorsType) -> PatternDict:
valid_paths: PatternDict = {}

# pylint: disable-next=too-many-nested-blocks
for (_, node_op) in self._qmodule.quant_layers_dict.values():
for _, node_op in self._qmodule.quant_layers_dict.values():
# Only work with supported nodes that have a single
# encrypted input (not supporting enc x enc matmul)
if (
Expand Down
Loading

0 comments on commit 93f1498

Please sign in to comment.