
Commit

organize metrics
Ramith Hettiarachchi committed Mar 11, 2023
1 parent 2beeb35 commit 75ea653
Showing 2 changed files with 21 additions and 21 deletions.
4 changes: 2 additions & 2 deletions modules/dataloaders.py
@@ -10,7 +10,7 @@
import matplotlib.pyplot as plt


-def get_bacteria_dataloaders(img_size, train_batch_size ,torch_seed=10, label_type = "class", balanced_mode = False, expand_channels = False, data_dir= '/n/holyscratch01/wadduwage_lab/D2NN_QPM_classification/datasets/bacteria_np'):
+def get_bacteria_dataloaders(img_size, train_batch_size ,torch_seed=10, label_type = "class", balanced_mode = False, expand_channels = False, data_dir= '/n/holyscratch01/wadduwage_lab/ramith/bacteria_processed'):
'''
Function to return train, validation QPM dataloaders
Args:
@@ -56,7 +56,7 @@ def get_bacteria_dataloaders(img_size, train_batch_size ,torch_seed=10, label_ty
dataset_sizes = {'train': len(train_loader)*train_batch_size, 'val': len(val_loader)*32, 'test': len(test_loader)*128}
return train_loader, val_loader, test_loader, dataset_sizes

-def get_bacteria_eval_dataloaders(img_size, test_batch_size ,torch_seed=10, label_type = "class", expand_channels = False, data_dir= '/n/holyscratch01/wadduwage_lab/D2NN_QPM_classification/datasets/bacteria_np', isolate_class = False):
+def get_bacteria_eval_dataloaders(img_size, test_batch_size ,torch_seed=10, label_type = "class", expand_channels = False, data_dir= '/n/holyscratch01/wadduwage_lab/ramith/bacteria_processed', isolate_class = False):
'''
Function to return train, validation QPM dataloaders
Args:
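Note: a minimal call sketch against the updated default data_dir, for reference only. The img_size=64 and train_batch_size=32 values are illustrative assumptions, not taken from the repository; the return tuple and dataset_sizes keys follow the function shown above.

# Illustrative sketch only: img_size and train_batch_size values are assumed.
# data_dir now defaults to '/n/holyscratch01/wadduwage_lab/ramith/bacteria_processed'.
from modules.dataloaders import get_bacteria_dataloaders

train_loader, val_loader, test_loader, dataset_sizes = get_bacteria_dataloaders(
    img_size=64,
    train_batch_size=32,
)
print(dataset_sizes)  # {'train': ..., 'val': ..., 'test': ...}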
38 changes: 19 additions & 19 deletions modules/test_utils.py
@@ -146,21 +146,21 @@ def test_model_in_groups(model, data, criterion, n_classes = 0, device = 'cpu',



-if(n_classes == 2): ## Calculate *binary* classification metrics
-    test_accuracy = Accuracy(task="binary", average = None, num_classes = n_classes, compute_on_step=False).to(device)
-
-    test_f1 = F1Score(task="binary", compute_on_step=False).to(device)
-    test_precision = Precision(task="binary", compute_on_step=False).to(device)
-    test_recall = Recall(task="binary", compute_on_step=False).to(device)
-    test_specificity = Specificity(task="binary", compute_on_step=False).to(device)
-    # test_auroc = AUROC(task="binary").to(device)
-else:
-    test_accuracy = Accuracy(task="multiclass", average = None, num_classes = n_classes, compute_on_step=False).to(device)
-
-    test_f1 = F1Score(task="multiclass", num_classes = n_classes, compute_on_step=False, average = None).to(device)
-    test_precision = Precision(task="multiclass", num_classes = n_classes, compute_on_step=False, average = None).to(device)
-    test_recall = Recall(task="multiclass", num_classes = n_classes, compute_on_step=False, average = None).to(device)
-    test_specificity = Specificity(task="multiclass", num_classes = n_classes, compute_on_step=False, average = None).to(device)
+# if(n_classes == 2): ## Calculate *binary* classification metrics
+#     test_accuracy = Accuracy(task="binary", average = None, num_classes = n_classes, compute_on_step=False).to(device)
+
+#     test_f1 = F1Score(task="binary", compute_on_step=False).to(device)
+#     test_precision = Precision(task="binary", compute_on_step=False).to(device)
+#     test_recall = Recall(task="binary", compute_on_step=False).to(device)
+#     test_specificity = Specificity(task="binary", compute_on_step=False).to(device)
+#     # test_auroc = AUROC(task="binary").to(device)
+# else:
+test_accuracy = Accuracy(task="multiclass", average = None, num_classes = n_classes, compute_on_step=False).to(device)
+
+test_f1 = F1Score(task="multiclass", num_classes = n_classes, compute_on_step=False, average = None).to(device)
+test_precision = Precision(task="multiclass", num_classes = n_classes, compute_on_step=False, average = None).to(device)
+test_recall = Recall(task="multiclass", num_classes = n_classes, compute_on_step=False, average = None).to(device)
+test_specificity = Specificity(task="multiclass", num_classes = n_classes, compute_on_step=False, average = None).to(device)
# test_auroc = MulticlassAUROC(num_classes = n_classes, average="macro").to(device)

test_preds = torch.empty([0, ])
@@ -250,10 +250,10 @@ def test_model_in_groups(model, data, criterion, n_classes = 0, device = 'cpu',
# print("Sklearn ROC AUC (ovo)", roc_auc_score(test_labels_.to(dtype = torch.int32), pred_probs, multi_class= 'ovo'))

t_acc = test_accuracy.compute().tolist()
-t_f1 = float(test_f1.compute()) if n_classes == 2 else test_f1.compute().tolist()
-t_precision = float(test_precision.compute()) if n_classes == 2 else test_precision.compute().tolist()
-t_recall = float(test_recall.compute()) if n_classes == 2 else test_recall.compute().tolist()
-t_specificity = float(test_specificity.compute()) if n_classes == 2 else test_specificity.compute().tolist()
+t_f1 = float(test_f1.compute()) if n_classes == 1 else test_f1.compute().tolist()
+t_precision = float(test_precision.compute()) if n_classes == 1 else test_precision.compute().tolist()
+t_recall = float(test_recall.compute()) if n_classes == 1 else test_recall.compute().tolist()
+t_specificity = float(test_specificity.compute()) if n_classes == 1 else test_specificity.compute().tolist()

print("test accuracy",t_acc)
print("test f1",t_f1)
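Note: with the metrics now constructed unconditionally as multiclass with average=None, compute() returns one value per class even when n_classes == 2, which is presumably why the float() fallback above is now gated on n_classes == 1. A minimal standalone sketch of that behaviour follows; it assumes torchmetrics >= 0.11 (where the task argument exists), omits compute_on_step, and uses illustrative n_classes and random tensors rather than anything from the repository.

# Minimal sketch, not the repository's test loop: n_classes and the random data are illustrative.
import torch
from torchmetrics import Accuracy, F1Score, Precision, Recall, Specificity

n_classes, device = 5, "cpu"

# Per-class metrics, mirroring the unconditional multiclass setup in the diff above.
metrics = {
    "accuracy": Accuracy(task="multiclass", num_classes=n_classes, average=None).to(device),
    "f1": F1Score(task="multiclass", num_classes=n_classes, average=None).to(device),
    "precision": Precision(task="multiclass", num_classes=n_classes, average=None).to(device),
    "recall": Recall(task="multiclass", num_classes=n_classes, average=None).to(device),
    "specificity": Specificity(task="multiclass", num_classes=n_classes, average=None).to(device),
}

preds = torch.randint(0, n_classes, (64,))   # stand-in predictions
labels = torch.randint(0, n_classes, (64,))  # stand-in ground truth

for name, metric in metrics.items():
    metric.update(preds, labels)             # accumulate batch statistics
    print(name, metric.compute().tolist())   # average=None -> one value per class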
