Catching roc_auc_score error 🚫 #316

Closed · wants to merge 5 commits
6 changes: 3 additions & 3 deletions grace/evaluation/inference.py
@@ -12,13 +12,13 @@
 from sklearn.metrics import (
     accuracy_score,
     precision_recall_fscore_support,
-    roc_auc_score,
     average_precision_score,
 )

 from grace.styling import LOGGER
 from grace.base import GraphAttrs, Annotation, Prediction
 from grace.models.datasets import dataset_from_graph
+from grace.evaluation.metrics_classifier import safe_roc_auc_score
 from grace.visualisation.plotting import (
     plot_confusion_matrix_tiles,
     plot_areas_under_curves,
@@ -359,11 +359,11 @@ def calculate_numerical_results_on_entire_batch(
inference_batch_metrics["Batch F1-score (edges)"] = prf1_edges[2]

# AUC scores:
inference_batch_metrics["Batch AUROC (nodes)"] = roc_auc_score(
inference_batch_metrics["Batch AUROC (nodes)"] = safe_roc_auc_score(
y_true=predictions_data["n_true"],
y_score=predictions_data["n_prob"],
)
inference_batch_metrics["Batch AUROC (edges)"] = roc_auc_score(
inference_batch_metrics["Batch AUROC (edges)"] = safe_roc_auc_score(
y_true=predictions_data["e_true"],
y_score=predictions_data["e_prob"],
)
Expand Down
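Without the guard, roc_auc_score raises rather than returning a value whenever y_true contains only one class, so a single degenerate batch would abort the whole evaluation run. A minimal reproduction of the failure mode, with made-up values (illustrative only, not part of the diff):

from sklearn.metrics import roc_auc_score

# A batch in which every node happens to carry the same label:
n_true = [1, 1, 1, 1]
n_prob = [0.2, 0.7, 0.6, 0.9]

try:
    roc_auc_score(y_true=n_true, y_score=n_prob)
except ValueError as err:
    # sklearn: "Only one class present in y_true. ROC AUC score
    # is not defined in that case."
    print(err)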
12 changes: 12 additions & 0 deletions grace/evaluation/metrics_classifier.py
@@ -9,6 +9,7 @@
     accuracy_score,
     precision_recall_fscore_support,
     confusion_matrix,
+    roc_auc_score,
 )


@@ -113,6 +114,17 @@ def confusion_matrix_metric(
     return fig_node, fig_edge


+def safe_roc_auc_score(
+    y_true: torch.Tensor,
+    y_score: torch.Tensor,
+) -> float:
+    unique_classes = len(torch.unique(torch.as_tensor(y_true)))
+    if unique_classes > 1:  # AUROC needs both classes present
+        return roc_auc_score(y_true=y_true, y_score=y_score)
+    else:
+        return 0.0  # AUROC is undefined for single-class y_true
+
+
 METRICS = {
     "accuracy": accuracy_metric,
     "f1_score": f1_score_metric,
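A sketch of how the new wrapper behaves on both kinds of input (illustrative values; assumes the grace package is importable):

import torch

from grace.evaluation.metrics_classifier import safe_roc_auc_score

# Mixed-class labels: delegates to sklearn and returns the real AUROC.
safe_roc_auc_score(
    y_true=torch.tensor([0, 1, 1, 0]),
    y_score=torch.tensor([0.1, 0.8, 0.7, 0.3]),
)  # -> 1.0, every positive is scored above every negative

# Single-class labels: AUROC is undefined, so the wrapper returns its sentinel.
safe_roc_auc_score(
    y_true=torch.tensor([1, 1, 1, 1]),
    y_score=torch.tensor([0.2, 0.7, 0.6, 0.9]),
)  # -> 0.0

Returning 0.0 rather than a chance-level 0.5 is a design choice: degenerate batches then stand out as clear outliers in downstream summaries instead of passing as random-classifier performance.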
12 changes: 6 additions & 6 deletions grace/visualisation/plotting.py
@@ -1,5 +1,6 @@
-from grace.base import GraphAttrs, Annotation
 from grace.styling import COLORMAPS
+from grace.base import GraphAttrs, Annotation
+from grace.evaluation.metrics_classifier import safe_roc_auc_score

 import matplotlib.pyplot as plt
 import networkx as nx
@@ -10,11 +11,10 @@

 from skimage.util import montage
 from sklearn.metrics import (
-    ConfusionMatrixDisplay,
-    roc_auc_score,
-    RocCurveDisplay,
     average_precision_score,
+    ConfusionMatrixDisplay,
     PrecisionRecallDisplay,
+    RocCurveDisplay,
 )


@@ -179,7 +179,7 @@ def plot_areas_under_curves(
     fig, axes = plt.subplots(nrows=1, ncols=2, figsize=figsize)

     # Area under ROC:
-    roc_score_nodes = roc_auc_score(y_true=node_true, y_score=node_pred)
+    roc_score_nodes = safe_roc_auc_score(y_true=node_true, y_score=node_pred)
     RocCurveDisplay.from_predictions(
         y_true=node_true,
         y_pred=node_pred,
@@ ... @@
         ax=axes[0],
     )

-    roc_score_edges = roc_auc_score(y_true=edge_true, y_score=edge_pred)
+    roc_score_edges = safe_roc_auc_score(y_true=edge_true, y_score=edge_pred)
     RocCurveDisplay.from_predictions(
         y_true=edge_true,
         y_pred=edge_pred,
27 changes: 25 additions & 2 deletions tests/test_run.py
@@ -39,11 +39,34 @@ def run_grace_training(
     run_grace(config_file=config_fn)


-@pytest.mark.parametrize("store_graph_attributes_permanently", [False, True])
+@pytest.mark.parametrize(
+    "store_graph_attributes_permanently",
+    [
+        False,
+    ],
+)
 @pytest.mark.xfail(
     reason="sample graph contains no node features & edge properties"
 )
-def test_run_grace(
+def test_run_grace_without_required_graph_attributes(
     mrc_image_and_annotations_dir,
     simple_extractor,
     store_graph_attributes_permanently,
 ):
     run_grace_training(
         mrc_image_and_annotations_dir,
         simple_extractor,
         store_graph_attributes_permanently,
     )


+@pytest.mark.parametrize(
+    "store_graph_attributes_permanently",
+    [
+        True,
+    ],
+)
+def test_run_grace_if_graph_attribute_computation_allowed(
+    mrc_image_and_annotations_dir,
+    simple_extractor,
+    store_graph_attributes_permanently,