Skip to content

Commit

Permalink
Simplify binning classifier
Browse files Browse the repository at this point in the history
  • Loading branch information
johnbradley committed Oct 3, 2024
1 parent 87c17a6 commit 3e8b732
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 7 deletions.
3 changes: 2 additions & 1 deletion src/bioclip/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ def predict(image_file: list[str],
write_results(predictions, format, output)
elif bins_path:
bin_df = pd.read_csv(bins_path, index_col=0)
classifier = CustomLabelsBinningClassifier(bin_df=bin_df, **kwargs)
cls_to_bin = bin_df.bin.to_dict()
classifier = CustomLabelsBinningClassifier(cls_to_bin=cls_to_bin, **kwargs)
predictions = classifier.predict(image_paths=image_file, k=k)
write_results(predictions, format, output)
else:
Expand Down
10 changes: 4 additions & 6 deletions src/bioclip/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from huggingface_hub import hf_hub_download
from typing import Union, List
from enum import Enum
import pandas as pd


HF_DATAFILE_REPO = "imageomics/bioclip-demo"
Expand Down Expand Up @@ -270,16 +269,15 @@ def group_probs(self, image_path: str, img_probs: torch.Tensor, k: int = None) -


class CustomLabelsBinningClassifier(CustomLabelsClassifier):
def __init__(self, bin_df: pd.DataFrame, **kwargs):
classes = list(bin_df.index.values)
super().__init__(cls_ary=classes, **kwargs)
self.bin_df = bin_df
def __init__(self, cls_to_bin: dict, **kwargs):
super().__init__(cls_ary=cls_to_bin.keys(), **kwargs)
self.cls_to_bin = cls_to_bin

def group_probs(self, image_path: str, img_probs: torch.Tensor, k: int = None) -> List[dict[str, float]]:
result = []
output = collections.defaultdict(float)
for i in range(len(self.classes)):
name = self.bin_df.loc[self.classes[i]].bin
name = self.cls_to_bin[self.classes[i]]
output[name] += img_probs[i]
topk_names = heapq.nlargest(k, output, key=output.get)
for name in topk_names:
Expand Down

0 comments on commit 3e8b732

Please sign in to comment.