Skip to content

Commit

Permalink
Add binning to custom label prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
johnbradley committed Oct 8, 2024
1 parent 7f2041b commit 830f9e8
Show file tree
Hide file tree
Showing 6 changed files with 190 additions and 33 deletions.
54 changes: 50 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,29 @@ fish 2.932403668845507e-12
bear 1.0
```

### Predict from a list of classes with binning
```python
from bioclip import CustomLabelsBinningClassifier
classifier = CustomLabelsBinningClassifier(cls_to_bin={
'dog': 'small',
'fish': 'small',
'bear': 'big',
})
predictions = classifier.predict("Ursus-arctos.jpeg")
for prediction in predictions:
print(prediction["classification"], prediction["score"])
```
Output:
```console
big 0.99992835521698
small 7.165559509303421e-05
```

## Command Line Usage
```
bioclip predict [-h] [--format {table,csv}] [--output OUTPUT] [--rank {kingdom,phylum,class,order,family,genus,species}] [--k K] [--cls CLS] [--device DEVICE] image_file [image_file ...]
bioclip predict [-h] [--format {table,csv}] [--output OUTPUT]
[--rank {kingdom,phylum,class,order,family,genus,species} | --cls CLS | --bins BINS]
[--k K] [--device DEVICE] image_file [image_file ...]
bioclip embed [-h] [--device=DEVICE] [--output=OUTPUT] [IMAGE_FILE...]
Commands:
Expand All @@ -117,9 +137,13 @@ Arguments:
Options:
-h --help
--format=FORMAT format of the output (table or csv) for predict mode [default: csv]
--rank=RANK rank of the classification (kingdom, phylum, class, order, family, genus, species) [default: species]
--k=K number of top predictions to show [default: 5]
--cls=CLS classes to predict: either a comma separated list or a path to a text file of classes (one per line), when specified the --rank argument is not allowed.
--rank {kingdom,phylum,class,order,family,genus,species}
rank of the classification, default: species (when --cls and --bins are not specified)
--cls CLS classes to predict: either a comma separated list or a path to a text file of classes (one per line), when specified the
--rank and --bins arguments are not allowed.
--bins BINS path to CSV file with two columns with the first being classes and second being bin names, when specified the --cls and
--rank arguments are not allowed.
--k K number of top predictions to show, default: 5
--device=DEVICE device to use matrix math (cpu or cuda or mps) [default: cpu]
--output=OUTFILE print output to file OUTFILE [default: stdout]
```
Expand Down Expand Up @@ -195,6 +219,28 @@ Ursus-arctos.jpeg,bird,3.051998476166773e-08
Ursus-arctos.jpeg,bear,0.9999998807907104
```

### Predict from a binning CSV
Create predictions for 3 classes (cat, bird, and bear) with 2 bins (one, two) for image `Ursus-arctos.jpeg`:

Create a CSV file named `bins.csv` with the following contents:
```
cls,bin
cat,one
bird,one
bear,two
```

Run predict command:
```console
bioclip predict --bins bins.csv Ursus-arctos.jpeg
```

Output:
```
Ursus-arctos.jpeg,two,0.9999998807907104
Ursus-arctos.jpeg,one,7.633736487377973e-08
```

### Create embeddings

#### Create embedding for an image
Expand Down
4 changes: 2 additions & 2 deletions src/bioclip/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# SPDX-FileCopyrightText: 2024-present John Bradley <johnbradley2008@gmail.com>
#
# SPDX-License-Identifier: MIT
from bioclip.predict import TreeOfLifeClassifier, Rank, CustomLabelsClassifier
from bioclip.predict import TreeOfLifeClassifier, Rank, CustomLabelsClassifier, CustomLabelsBinningClassifier

__all__ = ["TreeOfLifeClassifier", "Rank", "CustomLabelsClassifier"]
__all__ = ["TreeOfLifeClassifier", "Rank", "CustomLabelsClassifier", "CustomLabelsBinningClassifier"]
34 changes: 24 additions & 10 deletions src/bioclip/__main__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from bioclip import TreeOfLifeClassifier, Rank, CustomLabelsClassifier
from bioclip import TreeOfLifeClassifier, Rank, CustomLabelsClassifier, CustomLabelsBinningClassifier
from .predict import BIOCLIP_MODEL_STR
import open_clip as oc
import os
Expand Down Expand Up @@ -32,17 +32,32 @@ def write_results_to_file(df, format, outfile):
raise ValueError(f"Invalid format: {format}")


def parse_bins_csv(bins_path):
    """Load a class-to-bin mapping from a CSV file.

    The first CSV column supplies the class names and the second column the
    bin each class belongs to. The first row is read as a header and any
    columns after the second are ignored.

    Args:
        bins_path: path to the CSV file.

    Returns:
        dict mapping each class name to its bin name.

    Raises:
        FileNotFoundError: if ``bins_path`` does not exist.
        ValueError: if the file has fewer than two columns, or if any row
            is missing its bin value.
    """
    if not os.path.exists(bins_path):
        raise FileNotFoundError(f"File not found: {bins_path}")
    # The first column becomes the index so that it ends up as the dict keys.
    bin_df = pd.read_csv(bins_path, index_col=0)
    if len(bin_df.columns) == 0:
        raise ValueError("CSV file must have at least two columns.")
    bins = bin_df[bin_df.columns[0]]
    # An empty cell is read as NaN; passing it on would create a bogus "nan"
    # bin in the predictions, so reject the file instead.
    if bins.isna().any():
        raise ValueError("CSV file must not contain empty bin values.")
    return bins.to_dict()


def predict(image_file: list[str],
format: str,
output: str,
cls_str: str,
rank: Rank,
bins_path: str,
k: int,
**kwargs):
if cls_str:
classifier = CustomLabelsClassifier(cls_ary=cls_str.split(','), **kwargs)
predictions = classifier.predict(image_paths=image_file, k=k)
write_results(predictions, format, output)
elif bins_path:
cls_to_bin = parse_bins_csv(bins_path)
classifier = CustomLabelsBinningClassifier(cls_to_bin=cls_to_bin, **kwargs)
predictions = classifier.predict(image_paths=image_file, k=k)
write_results(predictions, format, output)
else:
classifier = TreeOfLifeClassifier(**kwargs)
predictions = classifier.predict(image_paths=image_file, rank=rank, k=k)
Expand Down Expand Up @@ -81,11 +96,13 @@ def create_parser():
predict_parser.add_argument('image_file', nargs='+', help='input image file(s)')
predict_parser.add_argument('--format', choices=['table', 'csv'], default='csv', help='format of the output, default: csv')
predict_parser.add_argument('--output', **output_arg)
predict_parser.add_argument('--rank', choices=['kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species'],
cls_group = predict_parser.add_mutually_exclusive_group(required=False)
cls_group.add_argument('--rank', choices=['kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species'],
help='rank of the classification, default: species (when)')
cls_help = "classes to predict: either a comma separated list or a path to a text file of classes (one per line), when specified the --rank and --bins arguments are not allowed."
cls_group.add_argument('--cls', help=cls_help)
cls_group.add_argument('--bins', help='path to CSV file with two columns with the first being classes and second being bin names, when specified the --cls and --bins arguments are not allowed.')
predict_parser.add_argument('--k', type=int, help='number of top predictions to show, default: 5')
cls_help = "classes to predict: either a comma separated list or a path to a text file of classes (one per line), when specified the --rank argument is not allowed."
predict_parser.add_argument('--cls', help=cls_help)

predict_parser.add_argument('--device', **device_arg)
predict_parser.add_argument('--model', **model_arg)
Expand Down Expand Up @@ -115,11 +132,7 @@ def create_parser():
def parse_args(input_args=None):
args = create_parser().parse_args(input_args)
if args.command == 'predict':
if args.cls:
# custom class list mode
if args.rank:
raise ValueError("Cannot use --cls with --rank")
else:
if not args.cls and not args.bins:
# tree of life class list mode
if args.model or args.pretrained:
raise ValueError("Custom model or checkpoints currently not supported for Tree-of-Life prediction")
Expand Down Expand Up @@ -155,6 +168,7 @@ def main():
output=args.output,
cls_str=cls_str,
rank=args.rank,
bins_path=args.bins,
k=args.k,
device=args.device,
model_str=args.model,
Expand All @@ -167,7 +181,7 @@ def main():
for model_str in oc.list_models():
print(f"\t{model_str}")
else:
raise ValueError("Invalid command")
create_parser().print_help()


if __name__ == '__main__':
Expand Down
40 changes: 33 additions & 7 deletions src/bioclip/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,13 +253,39 @@ def predict(self, image_paths: List[str] | str, k: int = None) -> dict[str, floa
img_probs = probs[image_path]
if not k or k > len(self.classes):
k = len(self.classes)
topk = img_probs.topk(k)
for i, prob in zip(topk.indices, topk.values):
result.append({
PRED_FILENAME_KEY: image_path,
PRED_CLASSICATION_KEY: self.classes[i],
PRED_SCORE_KEY: prob.item()
})
result.extend(self.group_probs(image_path, img_probs, k))
return result

def group_probs(self, image_path: str, img_probs: torch.Tensor, k: int = None) -> List[dict[str, float]]:
    """Build prediction records for the ``k`` highest-scoring classes.

    Args:
        image_path: file name stored in every returned record.
        img_probs: scores for one image, indexed in the same order as
            ``self.classes``.
        k: number of top entries to keep (passed straight to ``topk``).

    Returns:
        One dict per selected class holding the file name, the class name
        and its score as a Python float, in the order ``topk`` returns them.
    """
    top_scores, top_positions = img_probs.topk(k)
    return [
        {
            PRED_FILENAME_KEY: image_path,
            PRED_CLASSICATION_KEY: self.classes[position],
            PRED_SCORE_KEY: score.item(),
        }
        for position, score in zip(top_positions, top_scores)
    ]


class CustomLabelsBinningClassifier(CustomLabelsClassifier):
    """Custom-label classifier that reports scores per bin instead of per class.

    Every class is assigned to a named bin; the score of a bin is the sum of
    the scores of the classes assigned to it.
    """

    def __init__(self, cls_to_bin: dict, **kwargs):
        """Set up the classifier.

        Args:
            cls_to_bin: mapping of class name to bin name; its keys are the
                classes handed to the parent classifier.
            **kwargs: forwarded unchanged to ``CustomLabelsClassifier``.
        """
        super().__init__(cls_ary=cls_to_bin.keys(), **kwargs)
        self.cls_to_bin = cls_to_bin

    def group_probs(self, image_path: str, img_probs: torch.Tensor, k: int = None) -> List[dict[str, float]]:
        """Sum class scores per bin and return the ``k`` best bins.

        Args:
            image_path: file name stored in every returned record.
            img_probs: scores for one image, indexed in the same order as
                ``self.classes``.
            k: maximum number of bins to return.

        Returns:
            One dict per selected bin holding the file name, the bin name and
            the summed score, ordered from highest to lowest score.
        """
        bin_totals = collections.defaultdict(float)
        for position, cls_name in enumerate(self.classes):
            bin_name = self.cls_to_bin[cls_name]
            bin_totals[bin_name] += img_probs[position]
        # nlargest ranks the bin names by their accumulated score.
        best_bins = heapq.nlargest(k, bin_totals, key=bin_totals.get)
        return [
            {
                PRED_FILENAME_KEY: image_path,
                PRED_CLASSICATION_KEY: bin_name,
                PRED_SCORE_KEY: bin_totals[bin_name].item(),
            }
            for bin_name in best_bins
        ]


Expand Down
67 changes: 58 additions & 9 deletions tests/test_main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import unittest
from unittest.mock import mock_open, patch
import argparse
from bioclip.__main__ import parse_args, Rank, create_classes_str, main
import pandas as pd
from bioclip.__main__ import parse_args, Rank, create_classes_str, main, parse_bins_csv


class TestParser(unittest.TestCase):
Expand All @@ -15,6 +16,7 @@ def test_parse_args(self):
self.assertEqual(args.rank, Rank.SPECIES)
self.assertEqual(args.k, 5)
self.assertEqual(args.cls, None)
self.assertEqual(args.bins, None)
self.assertEqual(args.device, 'cpu')

args = parse_args(['predict', 'image.jpg', 'image2.png'])
Expand All @@ -41,12 +43,29 @@ def test_parse_args(self):
self.assertEqual(args.rank, None) # default ignored for the --cls variation
self.assertEqual(args.k, None)
self.assertEqual(args.cls, 'class1,class2')
self.assertEqual(args.bins, None)
self.assertEqual(args.device, 'cuda')

# test binning version of predict
args = parse_args(['predict', 'image.jpg', '--format', 'table', '--output', 'output.csv', '--bins', 'bins.csv', '--device', 'cuda'])
self.assertEqual(args.command, 'predict')
self.assertEqual(args.image_file, ['image.jpg'])
self.assertEqual(args.format, 'table')
self.assertEqual(args.output, 'output.csv')
self.assertEqual(args.rank, None) # default ignored for the --cls variation
self.assertEqual(args.k, None)
self.assertEqual(args.cls, None)
self.assertEqual(args.bins, 'bins.csv')
self.assertEqual(args.device, 'cuda')

# test error when using --cls with --rank
with self.assertRaises(ValueError):
with self.assertRaises(SystemExit):
parse_args(['predict', 'image.jpg', '--cls', 'class1,class2', '--rank', 'genus'])

# test error when using --cls with --bins
with self.assertRaises(SystemExit):
parse_args(['predict', 'image.jpg', '--cls', 'class1,class2', '--bins', 'somefile.csv', 'genus'])

# not an error when using --cls with --k
args = parse_args(['predict', 'image.jpg', '--cls', 'class1,class2', '--k', '10'])
self.assertEqual(args.k, 10)
Expand Down Expand Up @@ -77,10 +96,10 @@ def test_create_classes_str(self):
def test_predict_no_class(self, mock_parse_args, mock_predict):
mock_parse_args.return_value = argparse.Namespace(command='predict', image_file='image.jpg', format='csv',
output='stdout', rank=Rank.SPECIES, k=5, cls=None, device='cpu',
model=None, pretrained=None)
model=None, pretrained=None, bins=None)
main()
mock_predict.assert_called_with('image.jpg', format='csv', output='stdout', cls_str=None, rank=Rank.SPECIES, k=5,
device='cpu', model_str=None, pretrained_str=None)
mock_predict.assert_called_with('image.jpg', format='csv', output='stdout', cls_str=None, rank=Rank.SPECIES,
bins_path=None, k=5, device='cpu', model_str=None, pretrained_str=None)

@patch('bioclip.__main__.predict')
@patch('bioclip.__main__.parse_args')
Expand All @@ -89,10 +108,10 @@ def test_predict_class_list(self, mock_os, mock_parse_args, mock_predict):
mock_os.path.exists.return_value = False
mock_parse_args.return_value = argparse.Namespace(command='predict', image_file='image.jpg', format='csv',
output='stdout', rank=Rank.SPECIES, k=5, cls='dog,fish,bird',
device='cpu', model=None, pretrained=None)
device='cpu', model=None, pretrained=None, bins=None)
main()
mock_predict.assert_called_with('image.jpg', format='csv', output='stdout', cls_str='dog,fish,bird', rank=Rank.SPECIES,
k=5, device='cpu', model_str=None, pretrained_str=None)
bins_path=None, k=5, device='cpu', model_str=None, pretrained_str=None)

@patch('bioclip.__main__.predict')
@patch('bioclip.__main__.parse_args')
Expand All @@ -101,8 +120,38 @@ def test_predict_class_file(self, mock_os, mock_parse_args, mock_predict):
mock_os.path.exists.return_value = True
mock_parse_args.return_value = argparse.Namespace(command='predict', image_file='image.jpg', format='csv',
output='stdout', rank=Rank.SPECIES, k=5, cls='somefile.txt',
device='cpu', model=None, pretrained=None)
device='cpu', model=None, pretrained=None, bins=None)
with patch("builtins.open", mock_open(read_data='dog\nfish\nbird')) as mock_file:
main()
mock_predict.assert_called_with('image.jpg', format='csv', output='stdout', cls_str='dog,fish,bird', rank=Rank.SPECIES,
k=5, device='cpu', model_str=None, pretrained_str=None)
bins_path=None, k=5, device='cpu', model_str=None, pretrained_str=None)

@patch('bioclip.__main__.predict')
@patch('bioclip.__main__.parse_args')
@patch('bioclip.__main__.os')
def test_predict_bins(self, mock_os, mock_parse_args, mock_predict):
    """main() hands the --bins value to predict() as ``bins_path``."""
    mock_os.path.exists.return_value = True
    parsed_args = argparse.Namespace(
        command='predict',
        image_file='image.jpg',
        format='csv',
        output='stdout',
        rank=None,
        k=5,
        cls=None,
        device='cpu',
        model=None,
        pretrained=None,
        bins='some.csv',
    )
    mock_parse_args.return_value = parsed_args
    with patch("builtins.open", mock_open(read_data='dog\nfish\nbird')):
        main()
    expected_kwargs = dict(
        format='csv',
        output='stdout',
        cls_str=None,
        rank=None,
        bins_path='some.csv',
        k=5,
        device='cpu',
        model_str=None,
        pretrained_str=None,
    )
    mock_predict.assert_called_with('image.jpg', **expected_kwargs)
@patch('bioclip.__main__.os.path')
def test_parse_bins_csv_file_missing(self, mock_path):
    """parse_bins_csv raises FileNotFoundError naming the missing path."""
    mock_path.exists.return_value = False
    with self.assertRaises(FileNotFoundError) as caught:
        parse_bins_csv("somefile.csv")
    message = str(caught.exception)
    self.assertEqual(message, 'File not found: somefile.csv')

@patch('bioclip.__main__.pd')
@patch('bioclip.__main__.os.path')
def test_parse_bins_csv(self, mock_path, mock_pd):
    """parse_bins_csv turns the frame from read_csv into a class-to-bin dict."""
    mock_path.exists.return_value = True
    # The mocked read_csv returns a real DataFrame indexed by class name.
    fake_frame = pd.DataFrame(data={'bin': ['a', 'b']}, index=['dog', 'cat'])
    mock_pd.read_csv.return_value = fake_frame
    with patch("builtins.open", mock_open(read_data='dog\nfish\nbird')):
        cls_to_bin = parse_bins_csv("somefile.csv")
    self.assertEqual(cls_to_bin, {'cat': 'b', 'dog': 'a'})
24 changes: 23 additions & 1 deletion tests/test_predict.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import unittest
from bioclip.predict import TreeOfLifeClassifier, Rank
from bioclip.predict import CustomLabelsClassifier
from bioclip.predict import CustomLabelsBinningClassifier
import os
import torch

Expand Down Expand Up @@ -81,13 +82,34 @@ def test_custom_labels_classifier_ary_multiple(self):
{'file_name': EXAMPLE_CAT_IMAGE2, 'classification': 'dog', 'score': unittest.mock.ANY},
])


def test_predict_with_rgba_image(self):
# Ensure that the classifier can handle RGBA images
classifier = TreeOfLifeClassifier()
prediction_ary = classifier.predict(image_paths=[EXAMPLE_CAT_IMAGE2], rank=Rank.SPECIES)
self.assertEqual(len(prediction_ary), 5)

def test_predict_with_bins(self):
    """Binned prediction yields exactly one record per distinct bin."""
    scenarios = [
        # (class -> bin mapping, bins expected in the output)
        ({'cat': 'one', 'mouse': 'two', 'fish': 'two'}, {'one', 'two'}),
        ({'cat': 'one', 'mouse': 'two', 'fish': 'three'}, {'one', 'two', 'three'}),
    ]
    for cls_to_bin, expected_bins in scenarios:
        classifier = CustomLabelsBinningClassifier(cls_to_bin=cls_to_bin)
        prediction_ary = classifier.predict(image_paths=[EXAMPLE_CAT_IMAGE2])
        self.assertEqual(len(prediction_ary), len(expected_bins))
        self.assertEqual(prediction_ary[0]['file_name'], EXAMPLE_CAT_IMAGE2)
        names = {pred['classification'] for pred in prediction_ary}
        self.assertEqual(names, expected_bins)

class TestEmbed(unittest.TestCase):
def test_get_image_features(self):
Expand Down

0 comments on commit 830f9e8

Please sign in to comment.