Skip to content

Commit

Permalink
Add tests for bins predictions
Browse files Browse the repository at this point in the history
  • Loading branch information
johnbradley committed Oct 3, 2024
1 parent e08a0dc commit 4ad42f7
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 9 deletions.
18 changes: 11 additions & 7 deletions src/bioclip/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@ def write_results_to_file(df, format, outfile):
raise ValueError(f"Invalid format: {format}")


def parse_bins_csv(bins_path):
if not os.path.exists(bins_path):
raise FileNotFoundError(f"File not found: {bins_path}")
bin_df = pd.read_csv(bins_path, index_col=0)
if 'bin' not in bin_df.columns:
raise ValueError("CSV file must have a column named 'bin'")
return bin_df.bin.to_dict()


def predict(image_file: list[str],
format: str,
output: str,
Expand All @@ -45,8 +54,7 @@ def predict(image_file: list[str],
predictions = classifier.predict(image_paths=image_file, k=k)
write_results(predictions, format, output)
elif bins_path:
bin_df = pd.read_csv(bins_path, index_col=0)
cls_to_bin = bin_df.bin.to_dict()
cls_to_bin = parse_bins_csv(bins_path)
classifier = CustomLabelsBinningClassifier(cls_to_bin=cls_to_bin, **kwargs)
predictions = classifier.predict(image_paths=image_file, k=k)
write_results(predictions, format, output)
Expand Down Expand Up @@ -124,11 +132,7 @@ def create_parser():
def parse_args(input_args=None):
args = create_parser().parse_args(input_args)
if args.command == 'predict':
if args.cls:
# custom class list mode
if args.rank:
raise ValueError("Cannot use --cls with --rank")
else:
if not args.cls and not args.bins:
# tree of life class list mode
if args.model or args.pretrained:
raise ValueError("Custom model or checkpoints currently not supported for Tree-of-Life prediction")
Expand Down
51 changes: 50 additions & 1 deletion tests/test_main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import unittest
from unittest.mock import mock_open, patch
import argparse
from bioclip.__main__ import parse_args, Rank, create_classes_str, main
import pandas as pd
from bioclip.__main__ import parse_args, Rank, create_classes_str, main, parse_bins_csv


class TestParser(unittest.TestCase):
Expand All @@ -15,6 +16,7 @@ def test_parse_args(self):
self.assertEqual(args.rank, Rank.SPECIES)
self.assertEqual(args.k, 5)
self.assertEqual(args.cls, None)
self.assertEqual(args.bins, None)
self.assertEqual(args.device, 'cpu')

args = parse_args(['predict', 'image.jpg', 'image2.png'])
Expand All @@ -41,12 +43,29 @@ def test_parse_args(self):
self.assertEqual(args.rank, None) # default ignored for the --cls variation
self.assertEqual(args.k, None)
self.assertEqual(args.cls, 'class1,class2')
self.assertEqual(args.bins, None)
self.assertEqual(args.device, 'cuda')

# test binning version of predict
args = parse_args(['predict', 'image.jpg', '--format', 'table', '--output', 'output.csv', '--bins', 'bins.csv', '--device', 'cuda'])
self.assertEqual(args.command, 'predict')
self.assertEqual(args.image_file, ['image.jpg'])
self.assertEqual(args.format, 'table')
self.assertEqual(args.output, 'output.csv')
self.assertEqual(args.rank, None) # default ignored for the --cls variation
self.assertEqual(args.k, None)
self.assertEqual(args.cls, None)
self.assertEqual(args.bins, 'bins.csv')
self.assertEqual(args.device, 'cuda')

# test error when using --cls with --rank
with self.assertRaises(SystemExit):
parse_args(['predict', 'image.jpg', '--cls', 'class1,class2', '--rank', 'genus'])

# test error when using --cls with --bins
with self.assertRaises(SystemExit):
parse_args(['predict', 'image.jpg', '--cls', 'class1,class2', '--bins', 'somefile.csv', 'genus'])

# not an error when using --cls with --k
args = parse_args(['predict', 'image.jpg', '--cls', 'class1,class2', '--k', '10'])
self.assertEqual(args.k, 10)
Expand Down Expand Up @@ -106,3 +125,33 @@ def test_predict_class_file(self, mock_os, mock_parse_args, mock_predict):
main()
mock_predict.assert_called_with('image.jpg', format='csv', output='stdout', cls_str='dog,fish,bird', rank=Rank.SPECIES,
bins_path=None, k=5, device='cpu', model_str=None, pretrained_str=None)

@patch('bioclip.__main__.predict')
@patch('bioclip.__main__.parse_args')
@patch('bioclip.__main__.os')
def test_predict_bins(self, mock_os, mock_parse_args, mock_predict):
mock_os.path.exists.return_value = True
mock_parse_args.return_value = argparse.Namespace(command='predict', image_file='image.jpg', format='csv',
output='stdout', rank=None, k=5, cls=None,
device='cpu', model=None, pretrained=None,
bins='some.csv')
with patch("builtins.open", mock_open(read_data='dog\nfish\nbird')) as mock_file:
main()
mock_predict.assert_called_with('image.jpg', format='csv', output='stdout', cls_str=None, rank=None,
bins_path='some.csv', k=5, device='cpu', model_str=None, pretrained_str=None)
@patch('bioclip.__main__.os.path')
def test_parse_bins_csv_file_missing(self, mock_path):
mock_path.exists.return_value = False
with self.assertRaises(FileNotFoundError) as raised_exception:
parse_bins_csv("somefile.csv")
self.assertEqual(str(raised_exception.exception), 'File not found: somefile.csv')

@patch('bioclip.__main__.pd')
@patch('bioclip.__main__.os.path')
def test_parse_bins_csv(self, mock_path, mock_pd):
mock_path.exists.return_value = True
data = {'bin': ['a', 'b']}
mock_pd.read_csv.return_value = pd.DataFrame(data=data, index=['dog', 'cat'])
with patch("builtins.open", mock_open(read_data='dog\nfish\nbird')) as mock_file:
cls_to_bin = parse_bins_csv("somefile.csv")
self.assertEqual(cls_to_bin, {'cat': 'b', 'dog': 'a'})
11 changes: 10 additions & 1 deletion tests/test_predict.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import unittest
from bioclip.predict import TreeOfLifeClassifier, Rank
from bioclip.predict import CustomLabelsClassifier
from bioclip.predict import CustomLabelsBinningClassifier
import os
import torch

Expand Down Expand Up @@ -81,13 +82,21 @@ def test_custom_labels_classifier_ary_multiple(self):
{'file_name': EXAMPLE_CAT_IMAGE2, 'classification': 'dog', 'score': unittest.mock.ANY},
])


def test_predict_with_rgba_image(self):
# Ensure that the classifier can handle RGBA images
classifier = TreeOfLifeClassifier()
prediction_ary = classifier.predict(image_paths=[EXAMPLE_CAT_IMAGE2], rank=Rank.SPECIES)
self.assertEqual(len(prediction_ary), 5)

def test_predict_with_bins(self):
classifier = CustomLabelsBinningClassifier(cls_to_bin={
'cat': 'one',
'mouse': 'two',
'fish': 'two',
})
prediction_ary = classifier.predict(image_paths=[EXAMPLE_CAT_IMAGE2])
self.assertEqual(len(prediction_ary), 2)


class TestEmbed(unittest.TestCase):
def test_get_image_features(self):
Expand Down

0 comments on commit 4ad42f7

Please sign in to comment.