Skip to content

Commit

Permalink
Improve CustomLabelsBinningClassifier tests
Browse files Browse the repository at this point in the history
  • Loading branch information
johnbradley committed Oct 8, 2024
1 parent 413bab6 commit a0b0854
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions tests/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,20 @@ def test_predict_with_bins(self):
})
prediction_ary = classifier.predict(image_paths=[EXAMPLE_CAT_IMAGE2])
self.assertEqual(len(prediction_ary), 2)
self.assertEqual(prediction_ary[0]['filename'], EXAMPLE_CAT_IMAGE2)
names = set([pred['classification'] for pred in prediction_ary])
self.assertEqual(names, set(['one', 'two']))

classifier = CustomLabelsBinningClassifier(cls_to_bin={
'cat': 'one',
'mouse': 'two',
'fish': 'three',
})
prediction_ary = classifier.predict(image_paths=[EXAMPLE_CAT_IMAGE2])
self.assertEqual(len(prediction_ary), 3)
self.assertEqual(prediction_ary[0]['filename'], EXAMPLE_CAT_IMAGE2)
names = set([pred['classification'] for pred in prediction_ary])
self.assertEqual(names, set(['one', 'two', 'three']))

class TestEmbed(unittest.TestCase):
def test_get_image_features(self):
Expand Down

0 comments on commit a0b0854

Please sign in to comment.