Skip to content

Commit

Permalink
chore: Add knn to classifier comparison
Browse files Browse the repository at this point in the history
  • Loading branch information
kcelia authored Sep 27, 2023
1 parent 6e94c09 commit 09500b6
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 22 deletions.
55 changes: 49 additions & 6 deletions docs/advanced_examples/ClassifierComparison.ipynb

Large diffs are not rendered by default.

Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
89 changes: 73 additions & 16 deletions docs/advanced_examples/utils/classifier_comparison_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
import warnings

warnings.simplefilter(action="ignore", category=FutureWarning)

import time
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
Expand All @@ -21,30 +21,35 @@
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

from concrete.ml.common.utils import get_model_name
from concrete.ml.sklearn import DecisionTreeClassifier

ALWAYS_USE_SIM = False

# pylint: disable=too-many-locals,too-many-statements,too-many-branches
def make_classifier_comparison(title, classifiers, decision_level, verbose=False):
# pylint: disable=too-many-locals,too-many-statements,too-many-branches,invalid-name
def make_classifier_comparison(
title, classifiers, decision_level, verbose=False, show_score=False, save_plot=False
):

h = 0.04 # Step size in the mesh
n_samples = 25 if get_model_name(classifiers[0][0]) == "KNeighborsClassifier" else 200

X, y = make_classification(
n_samples=200,
n_samples=n_samples,
n_features=2,
n_redundant=0,
n_informative=2,
random_state=1,
n_clusters_per_class=1,
)
# pylint: disable-next=no-member
rng = np.random.RandomState(2)
X += 2 * rng.uniform(size=X.shape)
linearly_separable = (X, y)

datasets = [
make_moons(n_samples=200, noise=0.3, random_state=0),
make_circles(n_samples=200, noise=0.2, factor=0.5, random_state=1),
make_moons(n_samples=n_samples, noise=0.2, random_state=0),
make_circles(n_samples=n_samples, noise=0.2, factor=0.5, random_state=1),
linearly_separable,
]

Expand All @@ -53,6 +58,7 @@ def make_classifier_comparison(title, classifiers, decision_level, verbose=False
fig, axs = plt.subplots(len(datasets), 2 * len(classifiers) + 1, figsize=(32, 16))
fig.suptitle(title, fontsize=20)
fig.patch.set_facecolor("white")
plt.subplots_adjust(top=0.9)

# Iterate over data-sets
for i, dataset in enumerate(datasets):
Expand All @@ -69,6 +75,7 @@ def make_classifier_comparison(title, classifiers, decision_level, verbose=False
y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))

# pylint: disable-next=no-member
cm = plt.cm.RdBu
cm_bright = ListedColormap(["#FF0000", "#0000FF"])
ax = axs[i, 0]
Expand Down Expand Up @@ -116,7 +123,19 @@ def make_classifier_comparison(title, classifiers, decision_level, verbose=False
sklearn_y_pred = sklearn_model.predict(X_test)

# Compile the Concrete ML model
circuit = concrete_model.compile(X_train)
if get_model_name(classifier) == "KNeighborsClassifier":
n_compile_samples = 20
else:
n_compile_samples = X_train.shape[0]

if verbose:
print(f"Inputset has: {n_compile_samples} samples used for the compilationg.\n")

time_begin = time.time()
circuit = concrete_model.compile(X_train[:n_compile_samples])

if verbose:
print(f"Compilation time: {(time.time() - time_begin):.4f} seconds\n")

# If the prediction are done in FHE, generate the key
if not ALWAYS_USE_SIM:
Expand All @@ -139,7 +158,7 @@ def make_classifier_comparison(title, classifiers, decision_level, verbose=False

if verbose:
print(
f"Execution time: {(time.time() - time_begin) / len(X_test):.4f} "
f"FHE Execution time: {(time.time() - time_begin) / len(X_test):.4f} "
"seconds per sample\n"
)

Expand Down Expand Up @@ -171,6 +190,14 @@ def make_classifier_comparison(title, classifiers, decision_level, verbose=False
np.c_[xx.ravel(), yy.ravel()],
fhe="simulate",
)
# `predict_proba` not implemented yet for Concrete KNN
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3962
elif get_model_name(classifier) == "KNeighborsClassifier":
sklearn_Z = sklearn_model.predict(np.c_[xx.ravel(), yy.ravel()].astype(np.float32))
concrete_Z = concrete_model.predict(
np.c_[xx.ravel(), yy.ravel()],
fhe="simulate",
)
else:
sklearn_Z = sklearn_model.predict_proba(
np.c_[xx.ravel(), yy.ravel()].astype(np.float32)
Expand Down Expand Up @@ -222,14 +249,14 @@ def make_classifier_comparison(title, classifiers, decision_level, verbose=False

if i == 0:
ax.set_title(model_name + f" ({framework})", fontsize=font_size_text)

ax.text(
xx.max() - 0.3,
yy.min() + 0.3,
f"{score*100:0.1f}%",
size=font_size_text,
horizontalalignment="right",
)
if show_score:
ax.text(
xx.max() - 0.3,
yy.min() + 0.3,
f"{score*100:0.1f}%",
size=font_size_text,
horizontalalignment="right",
)

if bitwidth and framework == "Concrete ML":
ax.text(
Expand All @@ -240,5 +267,35 @@ def make_classifier_comparison(title, classifiers, decision_level, verbose=False
horizontalalignment="right",
)

if save_plot:
plt.savefig(f"./{title}.png")

plt.tight_layout()
plt.show()


def display_plot(path, figsize=(20, 20)):
    """Display an image plot from the given file path.

    Parameters:
        path (str): The file path to the image (PNG format).
        figsize (tuple): A tuple specifying the width and height of the plot.
            Default is (20, 20).

    Raises:
        FileNotFoundError: If the specified image file does not exist.
    """
    img_path = Path(path)

    # Fail early, with an actionable message, if the plot has not been generated yet
    if not img_path.exists():
        raise FileNotFoundError(
            "The desired plot is not available. "
            "To generate the plot, please run script `utils/knn_comparison.py`."
        )

    img = plt.imread(img_path)
    plt.figure(figsize=figsize)
    plt.imshow(img)
    # The image already contains its own axes, so the outer ones are hidden
    plt.axis("off")
    plt.show()
20 changes: 20 additions & 0 deletions docs/advanced_examples/utils/knn_comparison.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""This file runs `classifier_comparison_utils.py` script for the KNN classifier, because it doesn't
work on Jupyter for now. This file must be deleted once:
https://github.com/zama-ai/concrete-ml-internal/issues/4018 is solved."""

from functools import partial

from classifier_comparison_utils import make_classifier_comparison

from concrete.ml.sklearn import KNeighborsClassifier

if __name__ == "__main__":
    # One (n_bits, n_neighbors, label) entry per KNN variant shown in the comparison
    knn_variants = ((2, 3, "3nn"), (4, 5, "5nn"))

    # Pair each partially-configured Concrete ML KNN constructor with its label
    knn_classifiers = [
        (partial(KNeighborsClassifier, n_bits=n_bits, n_neighbors=n_neighbors), label)
        for n_bits, n_neighbors, label in knn_variants
    ]

    # Run the comparison with verbose output and save the figure instead of showing scores
    # pylint: disable-next=unexpected-keyword-arg
    make_classifier_comparison(
        "KNN_Classifier",
        knn_classifiers,
        0,
        verbose=True,
        show_score=False,
        save_plot=True,
    )
18 changes: 18 additions & 0 deletions script/make_utils/check_all_images_are_used.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,24 @@ set -e
IMAGE_FILES=$(git -C ./docs ls-files "*.png" "*.svg" "*.jpg" "*.jpeg" --cached --others --exclude-standard --full-name)
MD_FILES=$(git -C ./docs ls-files "*.md" --cached --others --exclude-standard --full-name)

# Collect the image files that are mentioned in at least one notebook (".ipynb" file)
MENTIONED_IMG_FILES=()
for img_file in $IMAGE_FILES; do
    # Only the base name of the image is searched for in the notebooks
    img_name=$(basename "$img_file")
    # List the notebooks that contain the image name; the result is empty if there are none
    mentioned_in_specific_ipynb=$(find "./docs" -type f -name "*.ipynb" -exec grep -q -F "$img_name" {} \; -print)
    # Keep the image if its extension is valid and at least one notebook mentions it
    if [[ $img_name =~ \.(jpeg|svg|jpg|png)$ ]] && [ -n "$mentioned_in_specific_ipynb" ]; then
        MENTIONED_IMG_FILES+=("$img_file")
    fi
done

# Remove the image file paths mentioned in notebooks from IMAGE_FILES.
for mentioned_img_filename in "${MENTIONED_IMG_FILES[@]}"; do
    # -x restricts the match to whole lines, so a path cannot remove other entries that
    # merely contain it as a substring.
    # grep exits with status 1 when it selects no line (i.e. every remaining image got
    # filtered out); `|| true` prevents `set -e` from aborting the script in that case.
    IMAGE_FILES=$(echo "$IMAGE_FILES" | grep -vxF "$mentioned_img_filename" || true)
done

# shellcheck disable=SC2086
poetry run python script/doc_utils/check_all_images_are_used.py --images $IMAGE_FILES --files $MD_FILES

0 comments on commit 09500b6

Please sign in to comment.