Skip to content

Commit

Permalink
Fix changes after merge
Browse files Browse the repository at this point in the history
  • Loading branch information
johnbradley committed Nov 20, 2024
1 parent 5d932a4 commit 3aee346
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
8 changes: 4 additions & 4 deletions src/bioclip/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,18 +245,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
img_features = self.model.encode_image(x)
img_features = F.normalize(img_features, dim=-1)
return self.create_probabilities(img_features, self.txt_features)
return self.create_probabilities(img_features, self.txt_embeddings)


class CustomLabelsClassifier(BaseClassifier):
def __init__(self, cls_ary: List[str], **kwargs):
super().__init__(**kwargs)
self.tokenizer = create_bioclip_tokenizer(self.model_str)
self.classes = [cls.strip() for cls in cls_ary]
self.txt_features = self._get_txt_features(self.classes)
self.txt_embeddings = self._get_txt_embeddings(self.classes)

@torch.no_grad()
def _get_txt_features(self, classnames):
def _get_txt_embeddings(self, classnames):
all_features = []
for classname in classnames:
txts = [template(classname) for template in OPENA_AI_IMAGENET_TEMPLATE]
Expand All @@ -272,7 +272,7 @@ def _get_txt_features(self, classnames):
def predict(self, images: List[str] | str | List[PIL.Image.Image], k: int = None) -> dict[str, float]:
if isinstance(images, str):
images = [images]
probs = self.create_probabilities_for_images(images, self.txt_features)
probs = self.create_probabilities_for_images(images, self.txt_embeddings)
result = []
for i, image in enumerate(images):
key = self.make_key(image, i)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def test_apply_filter(self):

def test_forward(self):
classifier = TreeOfLifeClassifier()
img = classifier.open_image(EXAMPLE_CAT_IMAGE)
img = classifier.ensure_rgb_image(EXAMPLE_CAT_IMAGE)
img_features = torch.stack([classifier.preprocess(img)])
result = classifier.forward(x=img_features)
self.assertEqual(result.shape, torch.Size([1, len(classifier.txt_names)]))
Expand Down

0 comments on commit 3aee346

Please sign in to comment.