Skip to content

Commit

Permalink
Prevent invalid bin values
Browse files Browse the repository at this point in the history
  • Loading branch information
johnbradley committed Oct 10, 2024
1 parent 82e4d2b commit 403922f
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/bioclip/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import open_clip as oc
import torch.nn.functional as F
import numpy as np
import pandas as pd
import collections
import heapq
import PIL.Image
Expand Down Expand Up @@ -272,6 +273,8 @@ class CustomLabelsBinningClassifier(CustomLabelsClassifier):
def __init__(self, cls_to_bin: dict, **kwargs):
    """
    Create a classifier whose custom labels are grouped into named bins.

    Args:
        cls_to_bin: Mapping of class label to the bin that label belongs to.
            Its keys are handed to the parent constructor as ``cls_ary``.
        **kwargs: Forwarded unchanged to the parent constructor.

    Raises:
        ValueError: If any bin value is missing according to ``pd.isna``
            (e.g. ``None``, ``pd.NA``, NaN) or is otherwise falsy
            (e.g. an empty string).
    """
    # Validate before delegating to the parent constructor so that bad input
    # is rejected up front instead of after whatever setup the parent performs.
    # pd.isna is checked first so pd.NA never reaches the truthiness test,
    # which it does not support.
    if any(pd.isna(bin_value) or not bin_value for bin_value in cls_to_bin.values()):
        raise ValueError("Empty, null, or nan are not allowed for bin values.")
    super().__init__(cls_ary=cls_to_bin.keys(), **kwargs)
    self.cls_to_bin = cls_to_bin

def group_probs(self, image_path: str, img_probs: torch.Tensor, k: int = None) -> List[dict[str, float]]:
result = []
Expand Down
27 changes: 27 additions & 0 deletions tests/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from bioclip.predict import CustomLabelsBinningClassifier
import os
import torch
import pandas as pd


DIRNAME = os.path.dirname(os.path.realpath(__file__))
Expand Down Expand Up @@ -111,6 +112,32 @@ def test_predict_with_bins(self):
names = set([pred['classification'] for pred in prediction_ary])
self.assertEqual(names, set(['one', 'two', 'three']))

def test_predict_with_bins_bad_values(self):
    """Constructing a binning classifier with an unusable bin value raises ValueError."""
    expected_message = "Empty, null, or nan are not allowed for bin values."
    # One sub-test per invalid value so a failure names the offending input
    # and the remaining values are still exercised.
    for bad_value in ('', None, pd.NA, float('nan')):
        with self.subTest(bad_value=bad_value):
            with self.assertRaises(ValueError) as raised_exceptions:
                CustomLabelsBinningClassifier(cls_to_bin={
                    'cat': 'one',
                    'mouse': bad_value,
                    'fish': 'two',
                })
            self.assertEqual(str(raised_exceptions.exception), expected_message)

class TestEmbed(unittest.TestCase):
def test_get_image_features(self):
classifier = TreeOfLifeClassifier(device='cpu')
Expand Down

0 comments on commit 403922f

Please sign in to comment.