diff --git a/docs/advanced_examples/KNearestNeighbors.ipynb b/docs/advanced_examples/KNearestNeighbors.ipynb index 081a2496c..fa6d88837 100644 --- a/docs/advanced_examples/KNearestNeighbors.ipynb +++ b/docs/advanced_examples/KNearestNeighbors.ipynb @@ -12,7 +12,7 @@ "\n", "In classification, KNN aims to identify the nearest points by measuring their similarity, often through distance metrics. The new labels are then assigned through majority voting, considering the most frequent labels among the neighboring points.\n", "\n", - "In Fully Homomorphic Encryption (FHE), classification with KNN poses significant computational challenges due to the distance calculations and the sorting algorithms, which is currently a non-stable algorithm.\n", + "In Fully Homomorphic Encryption (FHE), classification with KNN poses significant computational challenges due to the distance calculations and the sorting algorithm, which is currently non-stable (i.e., it does not preserve the relative order of equal elements).\n", "\n", "It is therefore recommended to use it on small datasets (up to dozens of examples) with strong quantization (n_bits <= 4).\n", "\n", @@ -25,7 +25,7 @@ "source": [ "### Import libraries\n", "\n", - "First, import the required packages, the classical KNN regressor and its Concrete ML counterpart." + "First, import the required packages."
] }, { @@ -38,6 +38,7 @@ "\n", "import pandas as pd\n", "from sklearn.datasets import make_classification\n", + "from sklearn.metrics import accuracy_score\n", "from sklearn.model_selection import train_test_split\n", "\n", "from concrete.ml.sklearn import KNeighborsClassifier as ConcreteKNeighborsClassifier" @@ -57,7 +58,7 @@ "outputs": [], "source": [ "X, y = make_classification(\n", - " n_samples=20, n_features=20, n_informative=3, n_redundant=0, n_classes=2, n_clusters_per_class=1\n", + " n_samples=20, n_features=5, n_informative=3, n_redundant=0, n_classes=2, n_clusters_per_class=1\n", ")\n", "# Split the data-set into a train and testing sets\n", "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=42)" @@ -67,7 +68,11 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Model instantiation" + "# Model instantiation\n", + "\n", + "Concrete ML models introduce the following hyperparameters:\n", + "- `n_bits`: the precision used to quantize the input data. 
This quantization step is essential after the training phase, since FHE exclusively operates over integers.\n", + "- `rounding_threshold_bits`: TODO" ] }, { @@ -76,16 +81,13 @@ "metadata": {}, "outputs": [], "source": [ - "# The novel aspect introduced by Concret-ML models is the hyperparameter `n_bits`, which represents\n", - "# the precision for quantizing input data\n", - "# This quantization step is essential after the training phase, since FHE exclusively operates\n", - "# over integers\n", - "\n", "n_neighbors = 3\n", "\n", - "concrete_knn = ConcreteKNeighborsClassifier(n_bits=3, n_neighbors=n_neighbors)\n", + "concrete_knn = ConcreteKNeighborsClassifier(\n", + " n_bits=2, n_neighbors=n_neighbors, rounding_threshold_bits=4\n", + ")\n", "\n", - "# Fit both the Concrete ML and its equivalent float estimators on clear data\n", + "# Fit both the Concrete ML and its equivalent float estimator on clear data\n", "concrete_knn, sklearn_model = concrete_knn.fit_benchmark(X_train, y_train)" ] }, @@ -93,7 +95,15 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Compile the model" + "# Compile the model\n", + "\n", + "\n", + "The compilation step aims to:\n", + "- convert the quantized model to its FHE equivalent\n", + "- create an executable operation graph\n", + "- check the operation graph's compatibility with FHE\n", + "- compute the maximum bit-width needed for model execution\n", + "- determine cryptographic parameters necessary for generating secret keys and evaluation keys" ] }, { @@ -105,18 +115,11 @@ "name": "stdout", "output_type": "stream", "text": [ - "Saved in 47\n", - "Compilation time: 2.56 seconds\n" + "Compilation time: 4.22 seconds\n" ] } ], "source": [ - "# The compilation step aims to:\n", - "# - convert the quantized model to its FHE equivalent\n", - "# - create an executable operation graph\n", - "# - check the operation graph's compatibility with FHE\n", - "# - compute the maximum bit-width needed for model execution\n", - "# - 
determine cryptographic parameters necessary for generating secret keys and evaluation keys\n", "time_begin = time.time()\n", "circuit = concrete_knn.compile(X)\n", "print(f\"Compilation time: {time.time() - time_begin:.2f} seconds\")" @@ -131,19 +134,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "Generating a key for an 8-bit circuit\n" + "Generating a key for an 6-bits circuit\n" ] } ], "source": [ - "print(f\"Generating a key for an {circuit.graph.maximum_integer_bit_width()}-bit circuit\")" + "print(f\"Generating a key for an {circuit.graph.maximum_integer_bit_width()}-bits circuit\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "# Key generation" + "# Key generation\n", + "\n", + "The circuit generated by the compiler is used to generate a set of keys:\n", + "- a _Secret Key_, held exclusively by the user and used for both encryption and decryption\n", + "\n", + "- an _Evaluation Key_, publicly accessible without compromising the security of the scheme, and used to evaluate the circuit on encrypted data" ] }, { @@ -155,14 +163,14 @@ "name": "stdout", "output_type": "stream", "text": [ - "Key generation time: 829.81 seconds\n" + "Key generation time: 40.04 seconds\n" ] } ], "source": [ "# Note that this step may be time-consuming for circuits exceeding 8-bits\n", "time_begin = time.time()\n", - "circuit.client.keygen(force=False)\n", + "circuit.client.keygen()\n", "print(f\"Key generation time: {time.time() - time_begin:.2f} seconds\")" ] }, @@ -172,11 +180,11 @@ "source": [ "# Inference with Concrete ML:\n", "\n", - "a. __clear__: inference on unencrypted quantized data, without any FHE execution \n", + "a. __clear__: inference on non-encrypted quantized data, without any FHE execution \n", "\n", - "b. __Execution in FHE__: inference on encrypted data, using actual FHE execution\n", + "b. 
__Simulation__: inference on non-encrypted quantized data, while simulating all FHE operations, failure probabilities and crypto-parameters. This mode of inference is recommended in the deployment phase. For further information, please consult [this link](https://docs.zama.ai/concrete-ml/advanced-topics/compilation#fhe-simulation)\n", "\n", - "c. __Simulation__: inference on unencrypted quantized data, without secure FHE execution, while simulating the p_error failure probability. For further information, please consult: [TODO]()" + "c. __Execution in FHE__: inference on encrypted data, using actual FHE execution" ] }, { @@ -187,36 +195,22 @@ "source": [ "# scikit-learn inference\n", "predict_sklearn = sklearn_model.predict(X_test)\n", - "score_sklearn = (predict_sklearn == y_test).mean()" + "score_sklearn = accuracy_score(y_test, predict_sklearn)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Time inference: 40.05 seconds per sample\n" - ] - } - ], + "outputs": [], "source": [ "# a- Clear inference\n", "pred_cml_clear = concrete_knn.predict(X_test, fhe=\"disable\")\n", - "score_cml_clear = (pred_cml_clear == y_test).mean()\n", + "score_cml_clear = accuracy_score(y_test, pred_cml_clear)\n", "\n", - "# b- FHE inference\n", - "time_begin = time.time()\n", - "pred_cml_fhe = concrete_knn.predict(X_test[0, None], fhe=\"execute\")\n", - "print(f\"Time inference: {time.time() - time_begin:.2f} seconds per sample\")\n", - "score_cml_fhe = (pred_cml_fhe == y_test[0]).mean()\n", - "\n", - "# c- FHE simulation inference\n", + "# b- FHE simulation inference\n", "pred_cml_simulate = concrete_knn.predict(X_test, fhe=\"simulate\")\n", - "score_cml_simulate = (pred_cml_simulate == y_test).mean()" + "score_cml_simulate = accuracy_score(y_test, pred_cml_simulate)" ] }, { @@ -230,155 +224,164 @@ "text": [ "sckit-learn score: 70.00%\n", "Concrete ML (clear) score: 80.00%\n", - 
"Concrete ML FHE (simulation) score: 80.00%\n", - "Concrete ML FHE score: 100.00%\n" + "Concrete ML (FHE simulation) score: 80.00%\n" ] } ], "source": [ "print(f\"sckit-learn score: {score_sklearn:.2%}\")\n", "print(f\"Concrete ML (clear) score: {score_cml_clear:.2%}\")\n", - "print(f\"Concrete ML FHE (simulation) score: {score_cml_simulate:.2%}\")\n", - "print(f\"Concrete ML FHE score: {score_cml_fhe:.2%}\")" + "print(f\"Concrete ML (FHE simulation) score: {score_cml_simulate:.2%}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "# Concrete KNN vs. scikit-learn KNN" + "### Concrete KNN vs. scikit-learn KNN\n", + "\n", + "Let's compare the top-k labels returned by Concrete and scikit-learn's KNN in the table below, highlighting mismatched predictions." ] }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "distance, topk_args = sklearn_model.kneighbors(X_test)\n", "\n", "topk_sk = y_train[topk_args]\n", - "topk_cml = concrete_knn.topk" + "topk_cml = concrete_knn._topk_labels" ] }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", - "
\n", - " | Distance | \n", - "Top3 (scikit-learn) | \n", - "Majority vote (scikit-learn) | \n", - "Top3 (Concrete ML) | \n", - "Majority vote (Concrete ML) | \n", + "Distance | \n", + "Top3 (scikit-learn) | \n", + "Majority vote (scikit-learn) | \n", + "Top3 (Concrete ML) | \n", + "Majority vote (Concrete ML) | \n", "|
---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n", - "4.620949 | \n", - "[0, 0, 0] | \n", - "0 | \n", - "[0, 0, 0] | \n", - "0 | \n", + "0 | \n", + "2.041796 | \n", + "[1, 0, 0] | \n", + "0 | \n", + "[1, 0, 0] | \n", + "0 | \n", "
1 | \n", - "5.215420 | \n", - "[1, 0, 1] | \n", - "1 | \n", - "[1, 0, 1] | \n", - "1 | \n", + "1 | \n", + "2.514646 | \n", + "[0, 0, 0] | \n", + "0 | \n", + "[0, 0, 0] | \n", + "0 | \n", "
2 | \n", - "3.655300 | \n", - "[0, 0, 0] | \n", - "0 | \n", - "[0, 0, 0] | \n", - "0 | \n", + "2 | \n", + "2.037168 | \n", + "[1, 1, 0] | \n", + "1 | \n", + "[1, 1, 1] | \n", + "1 | \n", "
3 | \n", - "5.601465 | \n", - "[1, 0, 0] | \n", - "0 | \n", - "[1, 0, 1] | \n", - "1 | \n", + "3 | \n", + "1.800107 | \n", + "[1, 1, 1] | \n", + "1 | \n", + "[1, 1, 1] | \n", + "1 | \n", "
4 | \n", - "4.655596 | \n", - "[1, 1, 0] | \n", - "1 | \n", - "[1, 1, 0] | \n", - "1 | \n", + "4 | \n", + "1.380009 | \n", + "[1, 0, 1] | \n", + "1 | \n", + "[0, 0, 1] | \n", + "0 | \n", "
5 | \n", - "3.393518 | \n", - "[0, 1, 0] | \n", - "0 | \n", - "[0, 1, 0] | \n", - "0 | \n", + "5 | \n", + "1.078951 | \n", + "[0, 0, 1] | \n", + "0 | \n", + "[0, 0, 1] | \n", + "0 | \n", "
6 | \n", - "5.437388 | \n", - "[1, 1, 1] | \n", - "1 | \n", - "[1, 1, 0] | \n", - "1 | \n", + "6 | \n", + "1.093890 | \n", + "[1, 1, 0] | \n", + "1 | \n", + "[1, 1, 1] | \n", + "1 | \n", "
7 | \n", - "4.737523 | \n", - "[1, 1, 0] | \n", - "1 | \n", - "[0, 1, 1] | \n", - "1 | \n", + "7 | \n", + "2.419766 | \n", + "[1, 1, 1] | \n", + "1 | \n", + "[1, 1, 1] | \n", + "1 | \n", "
8 | \n", - "5.163767 | \n", - "[1, 1, 0] | \n", - "1 | \n", - "[1, 1, 1] | \n", - "1 | \n", + "8 | \n", + "2.809947 | \n", + "[0, 0, 0] | \n", + "0 | \n", + "[0, 0, 0] | \n", + "0 | \n", "
9 | \n", - "4.256865 | \n", - "[1, 1, 0] | \n", - "1 | \n", - "[1, 1, 0] | \n", - "1 | \n", + "9 | \n", + "1.553386 | \n", + "[1, 1, 1] | \n", + "1 | \n", + "[1, 1, 0] | \n", + "1 | \n", "