Skip to content

Commit

Permalink
Add documentation for criterion
Browse files Browse the repository at this point in the history
  • Loading branch information
Dref360 committed May 27, 2024
1 parent a9939fa commit 982d91f
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 83 deletions.
3 changes: 2 additions & 1 deletion docs/api/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
* [baal.bayesian](./bayesian.md)
* [baal.active](./dataset_management.md)
* [baal.active.heuristics](./heuristics.md)
* [baal.active.stopping_criteria](./stopping_criteria.md)
* [baal.calibration](./calibration.md)
* [baal.utils](./utils.md)

### :material-file-tree: Compatibility

* [baal.utils.pytorch_lightning] (./compatibility/pytorch-lightning)
* [baal.utils.pytorch_lightning](./compatibility/pytorch-lightning)
* [baal.transformers_trainer_wrapper](./compatibility/huggingface)


Expand Down
35 changes: 35 additions & 0 deletions docs/api/stopping_criteria.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Stopping Criteria

Stopping criterion are used to determine when to stop your active learning experiment.

Their usage are simple, but best put in practice with `ActiveExperiment`.

**Example**
```python
from baal.active.stopping_criteria import LabellingBudgetStoppingCriterion
from baal.active.dataset import ActiveLearningDataset

al_dataset: ActiveLearningDataset = ... # len(al_dataset) == 10
criterion = LabellingBudgetStoppingCriterion(al_dataset, labelling_budget=100)

assert not criterion.should_stop({}, [])

# len(al_dataset) == 60
al_dataset.label_randomly(50)
assert not criterion.should_stop({}, [])

# len(al_dataset) == 110, budget exhausted! We've labelled 100 items.
al_dataset.label_randomly(50)
assert criterion.should_stop({}, [])
```


### API

### baal.active.stopping_criteria

::: baal.active.stopping_criteria.LabellingBudgetStoppingCriterion

::: baal.active.stopping_criteria.LowAverageUncertaintyStoppingCriterion

::: baal.active.stopping_criteria.EarlyStoppingCriterion
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ nav:
- api/calibration.md
- api/dataset_management.md
- api/heuristics.md
- api/stopping_criteria.md
- api/modelwrapper.md
- api/utils.md
- Compatibility:
Expand Down
104 changes: 22 additions & 82 deletions notebooks/production/baal_prod_cls.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,6 @@
"is_executing": false
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train: 5174, Valid: 1725, Num. classes : 8\n"
]
}
],
"source": [
"from glob import glob\n",
"import os\n",
Expand All @@ -52,7 +43,8 @@
"classes = os.listdir('/tmp/natural_images')\n",
"train, test = train_test_split(files, random_state=1337) # Split 75% train, 25% validation\n",
"print(f\"Train: {len(train)}, Valid: {len(test)}, Num. classes : {len(classes)}\")\n"
]
],
"outputs": []
},
{
"cell_type": "markdown",
Expand All @@ -79,7 +71,6 @@
"is_executing": false
}
},
"outputs": [],
"source": [
"from baal.active import FileDataset, ActiveLearningDataset\n",
"from torchvision import transforms\n",
Expand All @@ -101,7 +92,8 @@
"# We use -1 to specify that the data is unlabeled.\n",
"test_dataset = FileDataset(test, [-1] * len(test), test_transform)\n",
"active_learning_ds = ActiveLearningDataset(train_dataset, pool_specifics={'transform': test_transform})\n"
]
],
"outputs": []
},
{
"cell_type": "markdown",
Expand Down Expand Up @@ -129,7 +121,6 @@
"is_executing": false
}
},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn, optim\n",
Expand All @@ -149,7 +140,8 @@
"# ModelWrapper is an object similar to keras.Model.\n",
"baal_model = ModelWrapper(model, criterion)\n",
"\n"
]
],
"outputs": []
},
{
"cell_type": "markdown",
Expand All @@ -170,11 +162,11 @@
"is_executing": false
}
},
"outputs": [],
"source": [
"from baal.active.heuristics import BALD\n",
"heuristic = BALD(shuffle_prop=0.1)\n"
]
],
"outputs": []
},
{
"cell_type": "markdown",
Expand All @@ -193,13 +185,13 @@
"is_executing": false
}
},
"outputs": [],
"source": [
"# This function would do the work that a human would do.\n",
"def get_label(img_path):\n",
" return classes.index(img_path.split('/')[-2])\n",
"\n"
]
],
"outputs": []
},
{
"cell_type": "markdown",
Expand All @@ -223,15 +215,6 @@
"is_executing": false
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Num. labeled: 100/5174\n"
]
}
],
"source": [
"import numpy as np\n",
"# 1. Label all the test set and some samples from the training set.\n",
Expand All @@ -246,7 +229,8 @@
"active_learning_ds.label(train_idxs, labels)\n",
"\n",
"print(f\"Num. labeled: {len(active_learning_ds)}/{len(train_dataset)}\")\n"
]
],
"outputs": []
},
{
"cell_type": "code",
Expand All @@ -256,56 +240,19 @@
"is_executing": false
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[103-MainThread ] [baal.modelwrapper:train_on_dataset:109] 2021-07-28T14:47:48.133213Z [\u001B[32minfo ] Starting training dataset=100 epoch=5\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/lib/python3.9/site-packages/torch/utils/data/dataloader.py:478: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 1, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n",
" warnings.warn(_create_warning_msg(\n",
"/opt/conda/lib/python3.9/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /pytorch/c10/core/TensorImpl.h:1156.)\n",
" return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[103-MainThread ] [baal.modelwrapper:train_on_dataset:119] 2021-07-28T14:48:07.477011Z [\u001B[32minfo ] Training complete train_loss=2.058176279067993\n",
"[103-MainThread ] [baal.modelwrapper:test_on_dataset:147] 2021-07-28T14:48:07.479793Z [\u001B[32minfo ] Starting evaluating dataset=1725\n",
"[103-MainThread ] [baal.modelwrapper:test_on_dataset:156] 2021-07-28T14:48:21.277716Z [\u001B[32minfo ] Evaluation complete test_loss=2.0671451091766357\n",
"Metrics: {'test_loss': 2.0671451091766357, 'train_loss': 2.058176279067993}\n"
]
}
],
"source": [
"# 2. Train the model for a few epoch on the training set.\n",
"baal_model.train_on_dataset(active_learning_ds, optimizer, batch_size=16, epoch=5, use_cuda=USE_CUDA)\n",
"baal_model.test_on_dataset(test_dataset, batch_size=16, use_cuda=USE_CUDA)\n",
"\n",
"print(\"Metrics:\", {k:v.avg for k,v in baal_model.metrics.items()})\n"
]
],
"outputs": []
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[103-MainThread ] [baal.modelwrapper:predict_on_dataset_generator:241] 2021-07-28T14:48:21.291851Z [\u001B[32minfo ] Start Predict dataset=5074\n"
]
}
],
"source": [
"# 3. Select the K-top uncertain samples according to the heuristic.\n",
"pool = active_learning_ds.pool\n",
Expand All @@ -316,29 +263,22 @@
"predictions = baal_model.predict_on_dataset(pool, batch_size=16, iterations=15, use_cuda=USE_CUDA, verbose=False)\n",
"# We will label the 10 most uncertain samples.\n",
"top_uncertainty = heuristic(predictions)[:10]\n"
]
],
"outputs": []
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[(3, 1429), (4, 2971), (2, 1309), (4, 5), (3, 3761), (4, 2708), (6, 4679), (7, 160), (7, 1638), (6, 73)]\n"
]
}
],
"source": [
"# 4. Label those samples.\n",
"oracle_indices = active_learning_ds._pool_to_oracle_index(top_uncertainty)\n",
"labels = [get_label(train_dataset.files[idx]) for idx in oracle_indices]\n",
"print(list(zip(labels, oracle_indices)))\n",
"active_learning_ds.label(top_uncertainty, labels)\n",
"\n"
]
],
"outputs": []
},
{
"cell_type": "code",
Expand All @@ -348,7 +288,6 @@
"is_executing": true
}
},
"outputs": [],
"source": [
"# 5. If not done, go back to 2.\n",
"for step in range(5): # 5 Active Learning step!\n",
Expand All @@ -372,7 +311,8 @@
" active_learning_ds.label(top_uncertainty, labels)\n",
" \n",
" "
]
],
"outputs": []
},
{
"cell_type": "markdown",
Expand All @@ -386,14 +326,14 @@
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"torch.save({\n",
" 'active_dataset': active_learning_ds.state_dict(),\n",
" 'model': baal_model.state_dict(),\n",
" 'metrics': {k:v.avg for k,v in baal_model.metrics.items()}\n",
"}, '/tmp/baal_output.pth')\n"
]
],
"outputs": []
},
{
"cell_type": "markdown",
Expand Down

0 comments on commit 982d91f

Please sign in to comment.