Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added base scoring program #14

Merged
merged 14 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added baselines/BioCLIP_code_submission/clf.pkl
Binary file not shown.
1 change: 1 addition & 0 deletions baselines/BioCLIP_code_submission/metadata
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
description: Provides prediction model to be executed by the ingestion program
43 changes: 43 additions & 0 deletions baselines/BioCLIP_code_submission/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
'''
Sample predictive model.
The ingestion program will call `predict` to get a prediction for each test image and then save the predictions for scoring. The following two methods are required:
- predict: uses the model to perform predictions.
- load: reloads the model.
'''
from open_clip import create_model
from torchvision import transforms
import torch
import pickle
import os

class Model:
    """BioCLIP image encoder followed by a pickled probabilistic classifier.

    The ingestion program instantiates this class, calls `load()` once and
    then calls `predict()` for each test image.
    """

    def __init__(self):
        # The classifier (like the backbone) is only loaded by load().
        self.clf = None

    def load(self):
        """Load the BioCLIP backbone, the classifier and the preprocessing.

        Sets `self.device`, `self.model`, `self.clf` and
        `self.preprocess_img`.
        """
        # Fall back to the CPU so the submission still runs without a GPU.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        model = create_model("hf-hub:imageomics/bioclip", output_dict=True, require_pretrained=True)
        self.model = model.to(self.device)
        # Inference only: make sure the backbone is in evaluation mode
        # (a no-op if it already is).
        self.model.eval()

        # NOTE: unpickling executes arbitrary code; clf.pkl must come from a
        # trusted source (it is shipped next to this file).
        with open(os.path.join(os.path.dirname(__file__), "clf.pkl"), "rb") as f:
            self.clf = pickle.load(f)

        self.preprocess_img = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
            ]
        )

    def predict(self, datapoint):
        """Return the classifier's class-1 probability for one image.

        `datapoint` is passed straight to `self.preprocess_img`; `load()`
        must have been called first.
        """
        with torch.no_grad():
            image = self.preprocess_img(datapoint).to(self.device)
            # unsqueeze(0) adds the batch dimension expected by the backbone.
            image_feature = self.model(image.unsqueeze(0))['image_features']
        image_feature = image_feature.detach().cpu().numpy()
        # Column 1 of predict_proba, first (and only) row of the batch.
        score = self.clf.predict_proba(image_feature)[:, 1][0]

        return score
4 changes: 4 additions & 0 deletions baselines/BioCLIP_code_submission/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
open-clip-torch==2.24.0
torch==2.3.0
torchvision==0.18.0
scikit-learn==1.3.2
Binary file added baselines/DINO_SGD_code_submission/clf.pkl
Binary file not shown.
54 changes: 47 additions & 7 deletions baselines/DINO_SGD_code_submission/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,60 @@
- load: reloads the model.
'''

import os
import pickle

import torch
from torchvision import transforms
from transformers import AutoModel

import numpy as np
import PIL

from sklearn.linear_model import SGDClassifier

class Model:
    """DINOv2 backbone followed by a pickled probabilistic classifier.

    The ingestion program instantiates this class, calls `load()` once and
    then calls `predict()` for each test image.
    """

    def __init__(self):
        self.dino_name = 'facebook/dinov2-base'
        # PIL image -> resized, normalized tensor.
        self.pil_transform_fn = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        self.num_train_samples = 0
        self.num_feat = 1
        self.num_labels = 1
        self.is_trained = False

    def fit(self, X):
        """Placeholder: no training is performed here."""
        # We may just require a dataloader, or a csv file
        pass

    def load(self):
        """Load the DINO backbone (in eval mode) and the pickled classifier."""
        # DINO backbone
        self.model = AutoModel.from_pretrained(self.dino_name)
        self.model.eval()

        # Load Classifier weights
        # NOTE: unpickling executes arbitrary code; clf.pkl must come from a
        # trusted source (it is shipped next to this file).
        with open(os.path.join(os.path.dirname(__file__), "clf.pkl"), "rb") as f:
            self.clf = pickle.load(f)

    def _get_features(self, x: torch.Tensor) -> torch.Tensor:
        """Return the CLS token concatenated with the mean of the other tokens.

        The concatenation is along dim 1, i.e. the result has twice the
        token width per batch item.
        """
        # The model returns an output container, not a tensor; element 0 is
        # the token sequence, as in the upstream code referenced below.
        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/dinov2/modeling_dinov2.py#L707
        hidden_states = self.model(x)[0]
        cls_token = hidden_states[:, 0]
        patch_tokens = hidden_states[:, 1:]
        return torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)

    def _get_clf_prediction(self, features: torch.Tensor) -> float:
        """Return the classifier's class-1 probability for a batch of one."""
        np_features = features.detach().cpu().numpy()  # Convert to numpy for classifier compatibility
        return self.clf.predict_proba(np_features)[0, 1]  # Since a batch of 1, just extract float

    def predict(self, x: "PIL.Image.Image") -> float:
        """Return the class-1 probability for a single PIL image.

        `load()` must have been called first.
        """
        # Normalize() is configured for 3 channels, so force RGB first.
        x_tensor = self.pil_transform_fn(x.convert("RGB")).unsqueeze(0)
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            features = self._get_features(x_tensor)
        return self._get_clf_prediction(features)
egrace479 marked this conversation as resolved.
Show resolved Hide resolved
4 changes: 4 additions & 0 deletions baselines/DINO_SGD_code_submission/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
torch==2.3.0
torchvision==0.18.0
scikit-learn==1.4.2
transformers==4.40.0
85 changes: 57 additions & 28 deletions ingestion_program/ingestion.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,37 +84,39 @@
# Use default location for the input and output data:
# If no arguments to run.py are provided, this is where the data will be found
# and the results written to. Change the root_dir to your local directory.
# Every default location is the root directory followed by a fixed
# sub-directory name; root_dir must therefore end with a path separator.
root_dir = "../"
default_input_dir = f"{root_dir}input_data"
default_output_dir = f"{root_dir}sample_result_submission"
default_program_dir = f"{root_dir}ingestion_program"
default_submission_dir = f"{root_dir}baselines/DINO_SGD_code_submission"

# =============================================================================
# =========================== END USER OPTIONS ================================
# =============================================================================

import os
from sys import argv, path
from sys import argv, path, executable
import subprocess
from PIL import Image
from tqdm import tqdm

def write_results(outfile, scores):
    """Placeholder for writing the scores to the output directory.

    TODO: not implemented yet; both arguments are accepted and ignored.
    """
    return None

# def write_results(path, data_iter):
# """Should write the score to our output directory
# path: path of output file
# data_iter: an iterator of data
# """
# with open(path, 'w') as f:
# for data in data_iter[:-1]:
# f.write(data[0] + " " + data[1] + '\n')
# data = data_iter[-1]
# f.write(data[0] + " " + data[1] + '\n')

# print("Write all the results to: " + path)

if __name__ == "__main__":
#### INPUT/OUTPUT: Get input and output directory names
if len(argv) == 1:
    # No arguments: fall back to the default directories.
    input_dir = default_input_dir
    output_dir = default_output_dir
    program_dir = default_program_dir
    submission_dir = default_submission_dir
elif len(argv) >= 5:
    # Positional arguments: input, output, program and submission directory.
    input_dir, output_dir, program_dir, submission_dir = (
        os.path.abspath(arg) for arg in argv[1:5]
    )
else:
    # Fail with a usage message instead of an IndexError on a missing argument.
    raise SystemExit(
        "usage: " + argv[0] + " [input_dir output_dir program_dir submission_dir]"
    )


if verbose:
print("Using input_dir: " + input_dir)
print("Using output_dir: " + output_dir)
Expand All @@ -123,13 +125,40 @@ def write_results(outfile, scores):

# Make both our own code and the participant's submission importable.
path.extend([program_dir, submission_dir])

# Install the submission's declared dependencies, if any, before importing it.
requirements_file = os.path.join(submission_dir, "requirements.txt")
if os.path.isfile(requirements_file):
    pip_command = [executable, "-m", "pip", "install", "-r", requirements_file]
    subprocess.check_call(pip_command)

# Imported only now so that the freshly installed packages are available.
from model import Model

submit_model = Model()
submit_model.load()



# os.listdir() returns entries in arbitrary order; sort them so that
# predictions.txt is reproducible from run to run.
img_list = sorted(os.listdir(input_dir))
num_of_datapoint = len(img_list)

with open(os.path.join(output_dir, "predictions.txt"), 'w') as f:
    wrote_any = False
    for filename in tqdm(img_list, total=num_of_datapoint):
        image_path = os.path.join(input_dir, filename)

        # Entries that cannot be opened as images are reported and skipped.
        try:
            datapoint = Image.open(image_path)
        except Exception as e:
            print(f"{image_path}: {e}")
            continue

        # Close the underlying file as soon as the image has been scored.
        with datapoint:
            score = submit_model.predict(datapoint)
        # TODO: sanity-check the value returned by the submitted model.

        # Emit the separator *before* every line but the first, so the file
        # never ends with a newline even when the last entries were skipped.
        if wrote_any:
            f.write('\n')
        f.write(filename + " " + str(round(score, 4)))
        wrote_any = True


# Fit/predict flow: builds a separate Model instance, calls fit() and
# predict() with the placeholder paths below, and hands the result to
# write_results() for "scores.txt".
# NOTE(review): both paths are still None and load() is never called on this
# instance - confirm whether this flow is still needed alongside the
# predictions.txt loop.
path_to_csv_for_training = None # TODO
path_to_csv_for_testing = None # TODO
M = Model()
M.fit(path_to_csv_for_training)
scores = M.predict(path_to_csv_for_testing)

write_results(os.path.join(output_dir, "scores.txt"), scores)
# write_results(os.path.join(output_dir, "predictions.txt"), zip(ref['CAMID'].values.tolist(), scorelist))
3 changes: 3 additions & 0 deletions reference_data/predictions.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
CAM123 0.03
CAM124 0.23
CAM125 0.89
5 changes: 5 additions & 0 deletions reference_data/score.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
accuracy: 0.65814696485623
f1_score: 0.036036036036036036
hybrid_precision: 0.16
hybrid_recall: 0.02030456852791878
roc_auc: 0.48567675978843494
3 changes: 3 additions & 0 deletions reference_data/solutions.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
CAM123 0
CAM124 1
CAM125 1
Loading