diff --git a/README.md b/README.md index b35fa40..f3daba5 100644 --- a/README.md +++ b/README.md @@ -102,9 +102,29 @@ fish 2.932403668845507e-12 bear 1.0 ``` +### Predict from a list of classes with binning +```python +from bioclip import CustomLabelsBinningClassifier +classifier = CustomLabelsBinningClassifier(cls_to_bin={ + 'dog': 'small', + 'fish': 'small', + 'bear': 'big', +}) +predictions = classifier.predict("Ursus-arctos.jpeg") +for prediction in predictions: + print(prediction["classification"], prediction["score"]) +``` +Output: +```console +big 0.99992835521698 +small 7.165559509303421e-05 +``` + ## Command Line Usage ``` -bioclip predict [-h] [--format {table,csv}] [--output OUTPUT] [--rank {kingdom,phylum,class,order,family,genus,species}] [--k K] [--cls CLS] [--device DEVICE] image_file [image_file ...] +bioclip predict [-h] [--format {table,csv}] [--output OUTPUT] + [--rank {kingdom,phylum,class,order,family,genus,species} | --cls CLS | --bins BINS] + [--k K] [--cls CLS] [--device DEVICE] image_file [image_file ...] bioclip embed [-h] [--device=DEVICE] [--output=OUTPUT] [IMAGE_FILE...] Commands: @@ -117,9 +137,13 @@ Arguments: Options: -h --help --format=FORMAT format of the output (table or csv) for predict mode [default: csv] - --rank=RANK rank of the classification (kingdom, phylum, class, order, family, genus, species) [default: species] - --k=K number of top predictions to show [default: 5] - --cls=CLS classes to predict: either a comma separated list or a path to a text file of classes (one per line), when specified the --rank argument is not allowed. + --rank {kingdom,phylum,class,order,family,genus,species} + rank of the classification, default: species (when) + --cls CLS classes to predict: either a comma separated list or a path to a text file of classes (one per line), when specified the + --rank and --bins arguments are not allowed. + --bins BINS path to CSV file with two columns with the first being classes and second being bin names, when specified the --cls and + --bins arguments are not allowed. + --k K number of top predictions to show, default: 5 --device=DEVICE device to use matrix math (cpu or cuda or mps) [default: cpu] --output=OUTFILE print output to file OUTFILE [default: stdout] ``` @@ -195,6 +219,28 @@ Ursus-arctos.jpeg,bird,3.051998476166773e-08 Ursus-arctos.jpeg,bear,0.9999998807907104 ``` +### Predict from a binning CSV +Create predictions for 3 classes (cat, bird, and bear) with 2 bins (one, two) for image `Ursus-arctos.jpeg`: + +Create a CSV file named `bins.csv` with the following contents: +``` +cls,bin +cat,one +bird,one +bear,two +``` + +Run predict command: +```console +bioclip predict --bins bins.csv Ursus-arctos.jpeg +``` + +Output: +``` +Ursus-arctos.jpeg,two,0.9999998807907104 +Ursus-arctos.jpeg,one,7.633736487377973e-08 +``` + ### Create embeddings #### Create embedding for an image diff --git a/src/bioclip/__init__.py b/src/bioclip/__init__.py index 2ddf567..2c50324 100644 --- a/src/bioclip/__init__.py +++ b/src/bioclip/__init__.py @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: 2024-present John Bradley # # SPDX-License-Identifier: MIT -from bioclip.predict import TreeOfLifeClassifier, Rank, CustomLabelsClassifier +from bioclip.predict import TreeOfLifeClassifier, Rank, CustomLabelsClassifier, CustomLabelsBinningClassifier -__all__ = ["TreeOfLifeClassifier", "Rank", "CustomLabelsClassifier"] +__all__ = ["TreeOfLifeClassifier", "Rank", "CustomLabelsClassifier", "CustomLabelsBinningClassifier"] diff --git a/src/bioclip/__main__.py b/src/bioclip/__main__.py index 525f65a..de189bd 100644 --- a/src/bioclip/__main__.py +++ b/src/bioclip/__main__.py @@ -1,4 +1,4 @@ -from bioclip import TreeOfLifeClassifier, Rank, CustomLabelsClassifier +from bioclip import TreeOfLifeClassifier, Rank, CustomLabelsClassifier, CustomLabelsBinningClassifier from .predict import BIOCLIP_MODEL_STR import open_clip as oc import os @@ -32,17 +32,32 @@ def write_results_to_file(df, format, outfile): raise ValueError(f"Invalid format: {format}") +def parse_bins_csv(bins_path): + if not os.path.exists(bins_path): + raise FileNotFoundError(f"File not found: {bins_path}") + bin_df = pd.read_csv(bins_path, index_col=0) + if len(bin_df.columns) == 0: + raise ValueError("CSV file must have at least two columns.") + return bin_df[bin_df.columns[0]].to_dict() + + def predict(image_file: list[str], format: str, output: str, cls_str: str, rank: Rank, + bins_path: str, k: int, **kwargs): if cls_str: classifier = CustomLabelsClassifier(cls_ary=cls_str.split(','), **kwargs) predictions = classifier.predict(image_paths=image_file, k=k) write_results(predictions, format, output) + elif bins_path: + cls_to_bin = parse_bins_csv(bins_path) + classifier = CustomLabelsBinningClassifier(cls_to_bin=cls_to_bin, **kwargs) + predictions = classifier.predict(image_paths=image_file, k=k) + write_results(predictions, format, output) else: classifier = TreeOfLifeClassifier(**kwargs) predictions = classifier.predict(image_paths=image_file, rank=rank, k=k) @@ -81,11 +96,13 @@ def create_parser(): predict_parser.add_argument('image_file', nargs='+', help='input image file(s)') predict_parser.add_argument('--format', choices=['table', 'csv'], default='csv', help='format of the output, default: csv') predict_parser.add_argument('--output', **output_arg) - predict_parser.add_argument('--rank', choices=['kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species'], + cls_group = predict_parser.add_mutually_exclusive_group(required=False) + cls_group.add_argument('--rank', choices=['kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species'], help='rank of the classification, default: species (when)') + cls_help = "classes to predict: either a comma separated list or a path to a text file of classes (one per line), when specified the --rank and --bins arguments are not allowed." + cls_group.add_argument('--cls', help=cls_help) + cls_group.add_argument('--bins', help='path to CSV file with two columns with the first being classes and second being bin names, when specified the --cls and --bins arguments are not allowed.') predict_parser.add_argument('--k', type=int, help='number of top predictions to show, default: 5') - cls_help = "classes to predict: either a comma separated list or a path to a text file of classes (one per line), when specified the --rank argument is not allowed." - predict_parser.add_argument('--cls', help=cls_help) predict_parser.add_argument('--device', **device_arg) predict_parser.add_argument('--model', **model_arg) @@ -115,11 +132,7 @@ def create_parser(): def parse_args(input_args=None): args = create_parser().parse_args(input_args) if args.command == 'predict': - if args.cls: - # custom class list mode - if args.rank: - raise ValueError("Cannot use --cls with --rank") - else: + if not args.cls and not args.bins: # tree of life class list mode if args.model or args.pretrained: raise ValueError("Custom model or checkpoints currently not supported for Tree-of-Life prediction") @@ -155,6 +168,7 @@ def main(): output=args.output, cls_str=cls_str, rank=args.rank, + bins_path=args.bins, k=args.k, device=args.device, model_str=args.model, @@ -167,7 +181,7 @@ def main(): for model_str in oc.list_models(): print(f"\t{model_str}") else: - raise ValueError("Invalid command") + create_parser().print_help() if __name__ == '__main__': diff --git a/src/bioclip/predict.py b/src/bioclip/predict.py index 615f859..283b777 100644 --- a/src/bioclip/predict.py +++ b/src/bioclip/predict.py @@ -253,13 +253,39 @@ def predict(self, image_paths: List[str] | str, k: int = None) -> dict[str, floa img_probs = probs[image_path] if not k or k > len(self.classes): k = len(self.classes) - topk = img_probs.topk(k) - for i, prob in zip(topk.indices, topk.values): - result.append({ - PRED_FILENAME_KEY: image_path, - PRED_CLASSICATION_KEY: self.classes[i], - PRED_SCORE_KEY: prob.item() - }) + result.extend(self.group_probs(image_path, img_probs, k)) + return result + + def group_probs(self, image_path: str, img_probs: torch.Tensor, k: int = None) -> List[dict[str, float]]: + result = [] + topk = img_probs.topk(k) + for i, prob in zip(topk.indices, topk.values): + result.append({ + PRED_FILENAME_KEY: image_path, + PRED_CLASSICATION_KEY: self.classes[i], + PRED_SCORE_KEY: prob.item() + }) + return result + + +class CustomLabelsBinningClassifier(CustomLabelsClassifier): + def __init__(self, cls_to_bin: dict, **kwargs): + super().__init__(cls_ary=cls_to_bin.keys(), **kwargs) + self.cls_to_bin = cls_to_bin + + def group_probs(self, image_path: str, img_probs: torch.Tensor, k: int = None) -> List[dict[str, float]]: + result = [] + output = collections.defaultdict(float) + for i in range(len(self.classes)): + name = self.cls_to_bin[self.classes[i]] + output[name] += img_probs[i] + topk_names = heapq.nlargest(k, output, key=output.get) + for name in topk_names: + result.append({ + PRED_FILENAME_KEY: image_path, + PRED_CLASSICATION_KEY: name, + PRED_SCORE_KEY: output[name].item() + }) return result diff --git a/tests/test_main.py b/tests/test_main.py index a3e7865..e55e321 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,7 +1,8 @@ import unittest from unittest.mock import mock_open, patch import argparse -from bioclip.__main__ import parse_args, Rank, create_classes_str, main +import pandas as pd +from bioclip.__main__ import parse_args, Rank, create_classes_str, main, parse_bins_csv class TestParser(unittest.TestCase): @@ -15,6 +16,7 @@ def test_parse_args(self): self.assertEqual(args.rank, Rank.SPECIES) self.assertEqual(args.k, 5) self.assertEqual(args.cls, None) + self.assertEqual(args.bins, None) self.assertEqual(args.device, 'cpu') args = parse_args(['predict', 'image.jpg', 'image2.png']) @@ -41,12 +43,29 @@ def test_parse_args(self): self.assertEqual(args.rank, None) # default ignored for the --cls variation self.assertEqual(args.k, None) self.assertEqual(args.cls, 'class1,class2') + self.assertEqual(args.bins, None) + self.assertEqual(args.device, 'cuda') + + # test binning version of predict + args = parse_args(['predict', 'image.jpg', '--format', 'table', '--output', 'output.csv', '--bins', 'bins.csv', '--device', 'cuda']) + self.assertEqual(args.command, 'predict') + self.assertEqual(args.image_file, ['image.jpg']) + self.assertEqual(args.format, 'table') + self.assertEqual(args.output, 'output.csv') + self.assertEqual(args.rank, None) # default ignored for the --cls variation + self.assertEqual(args.k, None) + self.assertEqual(args.cls, None) + self.assertEqual(args.bins, 'bins.csv') self.assertEqual(args.device, 'cuda') # test error when using --cls with --rank - with self.assertRaises(ValueError): + with self.assertRaises(SystemExit): parse_args(['predict', 'image.jpg', '--cls', 'class1,class2', '--rank', 'genus']) + # test error when using --cls with --bins + with self.assertRaises(SystemExit): + parse_args(['predict', 'image.jpg', '--cls', 'class1,class2', '--bins', 'somefile.csv', 'genus']) + # not an error when using --cls with --k args = parse_args(['predict', 'image.jpg', '--cls', 'class1,class2', '--k', '10']) self.assertEqual(args.k, 10) @@ -77,10 +96,10 @@ def test_create_classes_str(self): def test_predict_no_class(self, mock_parse_args, mock_predict): mock_parse_args.return_value = argparse.Namespace(command='predict', image_file='image.jpg', format='csv', output='stdout', rank=Rank.SPECIES, k=5, cls=None, device='cpu', - model=None, pretrained=None) + model=None, pretrained=None, bins=None) main() - mock_predict.assert_called_with('image.jpg', format='csv', output='stdout', cls_str=None, rank=Rank.SPECIES, k=5, - device='cpu', model_str=None, pretrained_str=None) + mock_predict.assert_called_with('image.jpg', format='csv', output='stdout', cls_str=None, rank=Rank.SPECIES, + bins_path=None, k=5, device='cpu', model_str=None, pretrained_str=None) @patch('bioclip.__main__.predict') @patch('bioclip.__main__.parse_args') @@ -89,10 +108,10 @@ def test_predict_class_list(self, mock_os, mock_parse_args, mock_predict): mock_os.path.exists.return_value = False mock_parse_args.return_value = argparse.Namespace(command='predict', image_file='image.jpg', format='csv', output='stdout', rank=Rank.SPECIES, k=5, cls='dog,fish,bird', - device='cpu', model=None, pretrained=None) + device='cpu', model=None, pretrained=None, bins=None) main() mock_predict.assert_called_with('image.jpg', format='csv', output='stdout', cls_str='dog,fish,bird', rank=Rank.SPECIES, - k=5, device='cpu', model_str=None, pretrained_str=None) + bins_path=None, k=5, device='cpu', model_str=None, pretrained_str=None) @patch('bioclip.__main__.predict') @patch('bioclip.__main__.parse_args') @@ -101,8 +120,38 @@ def test_predict_class_file(self, mock_os, mock_parse_args, mock_predict): mock_os.path.exists.return_value = True mock_parse_args.return_value = argparse.Namespace(command='predict', image_file='image.jpg', format='csv', output='stdout', rank=Rank.SPECIES, k=5, cls='somefile.txt', - device='cpu', model=None, pretrained=None) + device='cpu', model=None, pretrained=None, bins=None) with patch("builtins.open", mock_open(read_data='dog\nfish\nbird')) as mock_file: main() mock_predict.assert_called_with('image.jpg', format='csv', output='stdout', cls_str='dog,fish,bird', rank=Rank.SPECIES, - k=5, device='cpu', model_str=None, pretrained_str=None) + bins_path=None, k=5, device='cpu', model_str=None, pretrained_str=None) + + @patch('bioclip.__main__.predict') + @patch('bioclip.__main__.parse_args') + @patch('bioclip.__main__.os') + def test_predict_bins(self, mock_os, mock_parse_args, mock_predict): + mock_os.path.exists.return_value = True + mock_parse_args.return_value = argparse.Namespace(command='predict', image_file='image.jpg', format='csv', + output='stdout', rank=None, k=5, cls=None, + device='cpu', model=None, pretrained=None, + bins='some.csv') + with patch("builtins.open", mock_open(read_data='dog\nfish\nbird')) as mock_file: + main() + mock_predict.assert_called_with('image.jpg', format='csv', output='stdout', cls_str=None, rank=None, + bins_path='some.csv', k=5, device='cpu', model_str=None, pretrained_str=None) + @patch('bioclip.__main__.os.path') + def test_parse_bins_csv_file_missing(self, mock_path): + mock_path.exists.return_value = False + with self.assertRaises(FileNotFoundError) as raised_exception: + parse_bins_csv("somefile.csv") + self.assertEqual(str(raised_exception.exception), 'File not found: somefile.csv') + + @patch('bioclip.__main__.pd') + @patch('bioclip.__main__.os.path') + def test_parse_bins_csv(self, mock_path, mock_pd): + mock_path.exists.return_value = True + data = {'bin': ['a', 'b']} + mock_pd.read_csv.return_value = pd.DataFrame(data=data, index=['dog', 'cat']) + with patch("builtins.open", mock_open(read_data='dog\nfish\nbird')) as mock_file: + cls_to_bin = parse_bins_csv("somefile.csv") + self.assertEqual(cls_to_bin, {'cat': 'b', 'dog': 'a'}) diff --git a/tests/test_predict.py b/tests/test_predict.py index 2a2b20e..41e5700 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -1,6 +1,7 @@ import unittest from bioclip.predict import TreeOfLifeClassifier, Rank from bioclip.predict import CustomLabelsClassifier +from bioclip.predict import CustomLabelsBinningClassifier import os import torch @@ -81,13 +82,34 @@ def test_custom_labels_classifier_ary_multiple(self): {'file_name': EXAMPLE_CAT_IMAGE2, 'classification': 'dog', 'score': unittest.mock.ANY}, ]) - def test_predict_with_rgba_image(self): # Ensure that the classifier can handle RGBA images classifier = TreeOfLifeClassifier() prediction_ary = classifier.predict(image_paths=[EXAMPLE_CAT_IMAGE2], rank=Rank.SPECIES) self.assertEqual(len(prediction_ary), 5) + def test_predict_with_bins(self): + classifier = CustomLabelsBinningClassifier(cls_to_bin={ + 'cat': 'one', + 'mouse': 'two', + 'fish': 'two', + }) + prediction_ary = classifier.predict(image_paths=[EXAMPLE_CAT_IMAGE2]) + self.assertEqual(len(prediction_ary), 2) + self.assertEqual(prediction_ary[0]['file_name'], EXAMPLE_CAT_IMAGE2) + names = set([pred['classification'] for pred in prediction_ary]) + self.assertEqual(names, set(['one', 'two'])) + + classifier = CustomLabelsBinningClassifier(cls_to_bin={ + 'cat': 'one', + 'mouse': 'two', + 'fish': 'three', + }) + prediction_ary = classifier.predict(image_paths=[EXAMPLE_CAT_IMAGE2]) + self.assertEqual(len(prediction_ary), 3) + self.assertEqual(prediction_ary[0]['file_name'], EXAMPLE_CAT_IMAGE2) + names = set([pred['classification'] for pred in prediction_ary]) + self.assertEqual(names, set(['one', 'two', 'three'])) class TestEmbed(unittest.TestCase): def test_get_image_features(self):