diff --git a/src/bioclip/predict.py b/src/bioclip/predict.py
index 283b777..22520df 100644
--- a/src/bioclip/predict.py
+++ b/src/bioclip/predict.py
@@ -5,6 +5,7 @@
 import open_clip as oc
 import torch.nn.functional as F
 import numpy as np
+import pandas as pd
 import collections
 import heapq
 import PIL.Image
@@ -272,6 +273,10 @@ class CustomLabelsBinningClassifier(CustomLabelsClassifier):
     def __init__(self, cls_to_bin: dict, **kwargs):
+        # Validate the bin values first so that invalid input is rejected
+        # before the parent constructor is invoked.
+        if any(pd.isna(x) or not x for x in cls_to_bin.values()):
+            raise ValueError("Empty, null, or nan are not allowed for bin values.")
         super().__init__(cls_ary=cls_to_bin.keys(), **kwargs)
         self.cls_to_bin = cls_to_bin
 
     def group_probs(self, image_path: str, img_probs: torch.Tensor, k: int = None) -> List[dict[str, float]]:
         result = []
diff --git a/tests/test_predict.py b/tests/test_predict.py
index 41e5700..8ed7914 100644
--- a/tests/test_predict.py
+++ b/tests/test_predict.py
@@ -4,6 +4,7 @@
 from bioclip.predict import CustomLabelsBinningClassifier
 import os
 import torch
+import pandas as pd
 
 
 DIRNAME = os.path.dirname(os.path.realpath(__file__))
@@ -111,6 +112,19 @@ def test_predict_with_bins(self):
         names = set([pred['classification'] for pred in prediction_ary])
         self.assertEqual(names, set(['one', 'two', 'three']))
 
+    def test_predict_with_bins_bad_values(self):
+        # Each empty, null, or NaN bin value must make the constructor raise.
+        for bad_value in ('', None, pd.NA, float('nan')):
+            with self.subTest(bad_value=bad_value):
+                with self.assertRaises(ValueError) as raised_exceptions:
+                    CustomLabelsBinningClassifier(cls_to_bin={
+                        'cat': 'one',
+                        'mouse': bad_value,
+                        'fish': 'two',
+                    })
+                self.assertEqual(str(raised_exceptions.exception),
+                                 "Empty, null, or nan are not allowed for bin values.")
+
 class TestEmbed(unittest.TestCase):
     def test_get_image_features(self):
         classifier = TreeOfLifeClassifier(device='cpu')