diff --git a/examples/FineTuneSVM.ipynb b/examples/FineTuneSVM.ipynb index 923d8e1..56d4d2b 100644 --- a/examples/FineTuneSVM.ipynb +++ b/examples/FineTuneSVM.ipynb @@ -6,7 +6,9 @@ "metadata": {}, "source": [ "# Fine-tune with a SVM\n", - "Combines pybioclip image embeddings with a SVM to create predictions for [Somnath01/Birds_Species](https://huggingface.co/datasets/Somnath01/Birds_Species) dataset." + "Creates a model that combines pybioclip image embeddings with a SVM using images from the [Somnath01/Birds_Species](https://huggingface.co/datasets/Somnath01/Birds_Species) dataset. This dataset contains 1000 train images, 403 test images, and 50 validation images. This notebook only uses the train and test images. This dataset was chosen for convenience. No analysis of the suitability of this dataset has been done.\n", + "\n", + "When running this notebook in COLAB, change the _runtime type_ to a GPU type to speed up processing. The `fsspec` dependency is not directly used by this notebook; this dependency was added to fix an error running pip on COLAB." 
] }, { @@ -16,7 +18,7 @@ "metadata": {}, "outputs": [], "source": [ - "!pip install -q numpy datasets pybioclip scikit-learn matplotlib" + "!pip install -q numpy datasets pybioclip scikit-learn matplotlib fsspec" ] }, { @@ -62,28 +64,60 @@ "LABEL_NAME = 'label'\n", "\n", "# Image embedding settings\n", - "BATCH_SIZE = 30\n", - "DEVICE = \"cpu\"" + "BATCH_SIZE = 30" ] }, { "cell_type": "markdown", - "id": "f2a86696-1a9d-4ab4-92de-cdd7bf194496", + "id": "eb7ab928-0889-4650-8eb0-9f0b831ae473", "metadata": {}, "source": [ - "## Load dataset" + "## Determine GPU or CPU" ] }, { "cell_type": "code", "execution_count": 4, + "id": "800d7bd8-52c3-4ee0-bac7-e128cbd7b186", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using device: mps\n" + ] + } + ], + "source": [ + "if torch.cuda.is_available():\n", + " device = torch.device(\"cuda\")\n", + "elif torch.backends.mps.is_available():\n", + " device = torch.device(\"mps\")\n", + "else:\n", + " device = torch.device(\"cpu\")\n", + "print(f\"Using device: {device}\")" + ] + }, + { + "cell_type": "markdown", + "id": "f2a86696-1a9d-4ab4-92de-cdd7bf194496", + "metadata": {}, + "source": [ + "## Load dataset\n", + "This step takes around 7 minutes to download the images the first time it is run." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, "id": "991b4d11-e952-4284-8fb3-64c070607b1f", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "148ed940a7394416bcc732c95893dbeb", + "model_id": "a28423a411f44507afc2950e21457197", "version_major": 2, "version_minor": 0 }, @@ -97,7 +131,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "1e66b1a8376c4226974ddaafe258413b", + "model_id": "c17b3a374d584a1898389a91507a344e", "version_major": 2, "version_minor": 0 }, @@ -111,7 +145,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "0ab3dcb1f8954526b1c9912e7348bad8", + "model_id": "437ec33244d241508dc9383707b9838e", "version_major": 2, "version_minor": 0 }, @@ -141,7 +175,7 @@ "})" ] }, - "execution_count": 4, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -153,7 +187,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "id": "22f7ee8a-bd9d-47f7-bf07-0fce37574bd6", "metadata": {}, "outputs": [ @@ -163,7 +197,7 @@ "text": [ "Labels: ABBOTTS BABBLER,ABBOTTS BOOBY,ABYSSINIAN GROUND HORNBILL,AFRICAN CROWNED CRANE,AFRICAN EMERALD CUCKOO,AFRICAN FIREFINCH,AFRICAN OYSTER CATCHER,AFRICAN PIED HORNBILL,AFRICAN PYGMY GOOSE,ALBATROSS \n", "\n", - "Example Image: \n" + "Example Image: \n" ] }, { @@ -174,7 +208,7 @@ "" ] }, - "execution_count": 5, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -195,17 +229,17 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "id": "86f687f0-ce79-4abf-8217-08901a491ebd", "metadata": {}, "outputs": [], "source": [ - "classifier = BaseClassifier(device=DEVICE)" + "classifier = BaseClassifier(device=device)" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "id": "cd2d5fbb-ce84-4641-ace5-59872a435431", "metadata": {}, "outputs": [], @@ -218,7 +252,7 @@ }, { "cell_type": "code", - "execution_count": 8, + 
"execution_count": 9, "id": "6d158873-fb85-4e55-81c7-0a217cc09e3b", "metadata": {}, "outputs": [], @@ -244,7 +278,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "id": "74408592-92a2-4372-beb7-c455c230a3dc", "metadata": {}, "outputs": [], @@ -264,12 +298,13 @@ "metadata": {}, "source": [ "## Setup a SVM model\n", - "The `init_svc()` function is from [biobench newt](https://github.com/samuelstevens/biobench/blob/637432bfda2b567d966d49bf8c4b37b339d4dc2a/biobench/newt/__init__.py#L247-L262).\n" + "The `init_svc()` function is copied from [biobench newt](https://github.com/samuelstevens/biobench/blob/637432bfda2b567d966d49bf8c4b37b339d4dc2a/biobench/newt/__init__.py#L247-L262) \n", + "created by [@samuelstevens](https://github.com/samuelstevens)." ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "id": "7663e6f4-4ea9-4236-960b-750d2827b21c", "metadata": {}, "outputs": [], @@ -295,7 +330,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "id": "bac122f5-9c95-4670-a027-5976a085b89a", "metadata": {}, "outputs": [], @@ -312,19 +347,22 @@ "id": "f5d51cfd-fc4a-4e7b-b92f-4388e5aa34ec", "metadata": {}, "source": [ - "## Train the SVM model" + "## Train the SVM model\n", + "Trains the SVM using the train dataset. This step takes ~ 10 minutes when running on CPU and ~1 minute otherwise." ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 13, "id": "12965290-314a-48e3-bc55-a0f47824b96d", - "metadata": {}, + "metadata": { + "scrolled": true + }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "c3a767f2850a4d4c935d58d9137c8056", + "model_id": "acf3e3cc9e9f4cd0ac74a08e0e90c415", "version_major": 2, "version_minor": 0 }, @@ -761,37 +799,37 @@ " background-color: var(--sklearn-color-fitted-level-3);\n", "}\n", "
Pipeline(steps=[('functiontransformer',\n",
-       "                 FunctionTransformer(func=<function create_image_features at 0x33fd4aef0>)),\n",
+       "                 FunctionTransformer(func=<function create_image_features at 0x339733760>)),\n",
        "                ('randomizedsearchcv',\n",
        "                 RandomizedSearchCV(estimator=Pipeline(steps=[('standardscaler',\n",
        "                                                               StandardScaler()),\n",
        "                                                              ('svc', SVC())]),\n",
        "                                    n_iter=100, n_jobs=16,\n",
-       "                                    param_distributions={'svc__C': <scipy.stats._distn_infrastructure.rv_continuous_frozen object at 0x33fd3faf0>,\n",
-       "                                                         'svc__gamma': <scipy.stats._distn_infrastructure.rv_continuous_frozen object at 0x33fd3f4f0>,\n",
+       "                                    param_distributions={'svc__C': <scipy.stats._distn_infrastructure.rv_continuous_frozen object at 0x3397271c0>,\n",
+       "                                                         'svc__gamma': <scipy.stats._distn_infrastructure.rv_continuous_frozen object at 0x339727550>,\n",
        "                                                         'svc__kernel': ['rbf',\n",
        "                                                                         'linear',\n",
        "                                                                         'sigmoid',\n",
        "                                                                         'poly']},\n",
        "                                    random_state=42, verbose=2))])
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.