diff --git a/notebooks/ViT_AdapterPlus_FineTuning.ipynb b/notebooks/ViT_AdapterPlus_FineTuning.ipynb new file mode 100644 index 000000000..8dfbcd341 --- /dev/null +++ b/notebooks/ViT_AdapterPlus_FineTuning.ipynb @@ -0,0 +1,529 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Fine-Tuning ViT Classification Models Using AdapterPlus\n", + "\n", + "In this vision tutorial, we will show how to fine-tune ViT image models using the `AdapterPlus` config, a bottleneck adapter that uses the parameters defined in the `AdapterPlus` paper. For more information on bottleneck adapters, you can visit our basic [tutorial](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/01_Adapter_Training.ipynb) and our docs [page](https://docs.adapterhub.ml/methods#bottleneck-adapters).\n", + "\n", + "You can find the `AdapterPlus` paper (*Adapters Strike Back*, CVPR 2024) by Steitz and Roth [here](https://openaccess.thecvf.com/content/CVPR2024/papers/Steitz_Adapters_Strike_Back_CVPR_2024_paper.pdf) and their GitHub page [here](https://github.com/visinf/adapter_plus)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Installation\n", + "\n", + "Before we can get started, we need to ensure the proper packages are installed. Here's a breakdown of what we need:\n", + "\n", + "- `adapters` to load the model and the adapter configuration\n", + "- `accelerate` for training optimization\n", + "- `evaluate` for metric computation and model evaluation\n", + "- `datasets` to import the datasets for training and evaluation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -qq -U adapters datasets accelerate torchvision evaluate" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n", + "\n", + "import torch" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Dataset\n", + "\n", + "For this tutorial, we will be using the lightweight `cifar100` image dataset, which contains 60,000 images across 100 classes. We use the `datasets` library to directly import the dataset and split it into its training and evaluation datasets." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from datasets import load_dataset\n", + "\n", + "num_classes = 100\n", + "train_dataset = load_dataset(\"uoft-cs/cifar100\", split = \"train\")\n", + "eval_dataset = load_dataset(\"uoft-cs/cifar100\", split = \"test\")\n", + "\n", + "train_dataset.set_format(\"torch\")\n", + "eval_dataset.set_format(\"torch\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For this tutorial, we will be using the `fine_label` column, which matches the 100 classes used in the paper. Let's inspect the keys of a single example to see which columns are available." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_dataset[0].keys()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We will now initialize our `ViT` image processor to convert the images into the input format the model expects. It resizes and normalizes each image so that it matches the inputs the model was pre-trained on."
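To make those transformations concrete: for `cifar100`, the processor will upscale the small 32x32 images to the resolution ViT was pre-trained at and normalize the pixel values. Below is a rough `torchvision` equivalent, sketched for intuition only; the resize target and the mean/std values are assumptions here, and the authoritative values come from the processor configuration loaded in the next cell.

```python
from torchvision import transforms

# Rough stand-in for what ViTImageProcessor does internally:
# resize to the ViT input resolution and normalize each channel.
# The 224x224 size and the 0.5 mean/std are illustrative assumptions;
# the real values are read from the pretrained processor config below.
manual_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
```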
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "model_name_or_path = 'google/vit-base-patch16-224-in21k'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import ViTImageProcessor\n", + "processor = ViTImageProcessor.from_pretrained(model_name_or_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We'll print out the processor here to get an idea of what types of transformations it applies to the images." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "processor" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Data Preprocessing\n", + "\n", + "We will pre-process every image with the `processor` defined above and add a `label` key containing our labels." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "def preprocess_image(example):\n", + " image = processor(example[\"img\"], return_tensors='pt')\n", + " image[\"label\"] = example[\"fine_label\"]\n", + " return image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_dataset = train_dataset.map(preprocess_image)\n", + "eval_dataset = eval_dataset.map(preprocess_image)\n", + "# remove unnecessary columns\n", + "train_dataset = train_dataset.remove_columns(['img', 'fine_label', 'coarse_label'])\n", + "eval_dataset = eval_dataset.remove_columns(['img', 'fine_label', 'coarse_label'])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Defining a Data Collator\n", + "\n", + "We'll be using a very simple custom data collator to combine multiple data samples into one batch." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Any\n", + "from dataclasses import dataclass\n", + "\n", + "@dataclass\n", + "class DataCollator:\n", + " processor : Any\n", + " def __call__(self, inputs):\n", + "\n", + " pixel_values = [input[\"pixel_values\"].squeeze() for input in inputs]\n", + " labels = [input[\"label\"] for input in inputs]\n", + "\n", + " pixel_values = torch.stack(pixel_values)\n", + " labels = torch.stack(labels)\n", + " return {\n", + " 'pixel_values': pixel_values,\n", + " 'labels': labels,\n", + " }\n", + "\n", + "data_collator = DataCollator(processor = processor)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Loading the `ViT` model and the `AdapterPlusConfig`\n", + "\n", + "Here we load the `vit-base-patch16-224-in21k` model, similar to the one used in the `AdapterPlus` paper. We will load the model using the `adapters` `ViTAdapterModel` and add the corresponding `AdapterPlusConfig`. 
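Before loading the model, it is worth a quick sanity check that the collator produces what the trainer will expect. The snippet below is a hypothetical check, not part of the original notebook, and assumes the datasets have already been mapped with `preprocess_image` as above.

```python
# Collate two processed samples into a dummy batch and inspect the shapes.
batch = data_collator([train_dataset[0], train_dataset[1]])
print(batch["pixel_values"].shape)  # expected: torch.Size([2, 3, 224, 224])
print(batch["labels"].shape)        # expected: torch.Size([2])
```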
To read more about the config, you can check out the docs page [here](https://docs.adapterhub.ml/methods#bottleneck-adapters) under `AdapterPlusConfig`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from adapters import ViTAdapterModel\n", + "from adapters import AdapterPlusConfig\n", + "\n", + "model = ViTAdapterModel.from_pretrained(model_name_or_path)\n", + "config = AdapterPlusConfig(original_ln_after=True)\n", + "\n", + "model.add_adapter(\"adapterplus_config\", config)\n", + "model.add_image_classification_head(\"adapterplus_config\", num_labels=num_classes)\n", + "model.train_adapter(\"adapterplus_config\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "================================================================================\n", + "Name Architecture #Param %Param Active Train\n", + "--------------------------------------------------------------------------------\n", + "adapterplus_config bottleneck 165,984 0.192 1 1\n", + "--------------------------------------------------------------------------------\n", + "Full model 86,389,248 100.000 0\n", + "================================================================================\n" + ] + } + ], + "source": [ + "print(model.adapter_summary())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Evaluation Metrics\n", + "\n", + "We'll use accuracy as our main metric to evaluate the performance of the adapted model on the `cifar100` dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import evaluate\n", + "accuracy = evaluate.load(\"accuracy\")\n", + "\n", + "def compute_metrics(p):\n", + " return accuracy.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Training\n", + "\n", + "Now we are ready to train our model. We re-use the hyperparameters from the original paper, except for the number of training epochs. You can always adjust the number of epochs or any other hyperparameter in the notebook." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "from adapters import AdapterTrainer\n", + "from transformers import TrainingArguments\n", + "\n", + "training_args = TrainingArguments(\n", + " output_dir='./training_results',\n", + " eval_strategy='epoch',\n", + " learning_rate=10e-3,\n", + " per_device_train_batch_size=64,\n", + " per_device_eval_batch_size=64,\n", + " num_train_epochs=5,\n", + " weight_decay=10e-4,\n", + " report_to = \"none\",\n", + " remove_unused_columns=False,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trainer = AdapterTrainer(\n", + " model=model,\n", + " args=training_args,\n", + " data_collator=data_collator,\n", + " train_dataset=train_dataset,\n", + " eval_dataset=eval_dataset,\n", + " tokenizer=processor,\n", + " compute_metrics = compute_metrics\n", + ")\n", + "\n", + "trainer.train()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "
| Epoch | Training Loss | Validation Loss | Accuracy |
|-------|---------------|-----------------|----------|
| 1     | 0.523500      | 0.351386        | 0.897700 |
| 2     | 0.264200      | 0.329842        | 0.906500 |
| 3     | 0.184700      | 0.324673        | 0.911400 |
| 4     | 0.135400      | 0.347025        | 0.909100 |
| 5     | 0.114700      | 0.349541        | 0.910900 |
| 6     | 0.074300      | 0.370867        | 0.909600 |
| 7     | 0.058100      | 0.373732        | 0.912400 |
| 8     | 0.039300      | 0.376220        | 0.913900 |
| 9     | 0.028200      | 0.381238        | 0.913300 |
| 10    | 0.021600      | 0.380825        | 0.912400 |
+ ],
+ "text/plain": [
+ "