
Commit

Merge branch 'dev' of github.com:ENSTA-U2IS-AI/torch-uncertainty into dev
alafage committed Sep 24, 2024
2 parents afd9b00 + 54faee6 commit d9b5a88
Showing 10 changed files with 36 additions and 22 deletions.
13 changes: 0 additions & 13 deletions pyrightconfig.json

This file was deleted.

3 changes: 3 additions & 0 deletions torch_uncertainty/baselines/classification/resnet.py
@@ -1,6 +1,7 @@
from typing import Literal

from torch import nn
from torch.optim import Optimizer

from torch_uncertainty.models import mc_dropout
from torch_uncertainty.models.resnet import (
@@ -55,6 +56,7 @@ def __init__(
normalization_layer: type[nn.Module] = nn.BatchNorm2d,
num_estimators: int = 1,
dropout_rate: float = 0.0,
optim_recipe: dict | Optimizer | None = None,
mixup_params: dict | None = None,
last_layer_dropout: bool = False,
width_multiplier: float = 1.0,
@@ -229,6 +231,7 @@ def __init__(
model=model,
loss=loss,
is_ensemble=version in ENSEMBLE_METHODS,
optim_recipe=optim_recipe,
format_batch_fn=format_batch_fn,
mixup_params=mixup_params,
eval_ood=eval_ood,
7 changes: 7 additions & 0 deletions torch_uncertainty/baselines/classification/vgg.py
@@ -1,6 +1,7 @@
from typing import Literal

from torch import nn
from torch.optim import Optimizer

from torch_uncertainty.models import mc_dropout
from torch_uncertainty.models.vgg import (
@@ -32,6 +33,7 @@ def __init__(
num_estimators: int = 1,
dropout_rate: float = 0.0,
last_layer_dropout: bool = False,
optim_recipe: dict | Optimizer | None = None,
mixup_params: dict | None = None,
groups: int = 1,
alpha: int | None = None,
@@ -52,6 +54,10 @@ def __init__(
num_classes (int): Number of classes to predict.
in_channels (int): Number of input channels.
loss (nn.Module): Training loss.
optim_recipe (dict | Optimizer | None): the optimization recipe,
corresponding to what the `LightningModule.configure_optimizers()
<https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#configure-optimizers>`_
method expects as a return value.
version (str):
Determines which VGG version to use:
@@ -164,6 +170,7 @@ def __init__(
loss=loss,
is_ensemble=version in ENSEMBLE_METHODS,
format_batch_fn=format_batch_fn,
optim_recipe=optim_recipe,
mixup_params=mixup_params,
eval_ood=eval_ood,
ood_criterion=ood_criterion,
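The recipe itself is whatever `LightningModule.configure_optimizers()` accepts as a return value. A minimal sketch of the dict form, assuming a plain nn.Module as a stand-in for the network the baseline builds internally (in practice the optimizer must be created on that model's parameters); the names model, optimizer, and scheduler are illustrative:

    import torch
    from torch import nn

    model = nn.Linear(32, 10)  # stand-in for the network built by the baseline
    optimizer = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[25, 50], gamma=0.1
    )

    # Dict form matching the configure_optimizers() contract; per the
    # signature above, a bare Optimizer or None are also accepted.
    optim_recipe = {"optimizer": optimizer, "lr_scheduler": scheduler}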
3 changes: 3 additions & 0 deletions torch_uncertainty/baselines/classification/wideresnet.py
@@ -1,6 +1,7 @@
from typing import Literal

from torch import nn
from torch.optim import Optimizer

from torch_uncertainty.models import mc_dropout
from torch_uncertainty.models.wideresnet import (
@@ -39,6 +40,7 @@ def __init__(
style: str = "imagenet",
num_estimators: int = 1,
dropout_rate: float = 0.0,
optim_recipe: dict | Optimizer | None = None,
mixup_params: dict | None = None,
groups: int = 1,
last_layer_dropout: bool = False,
@@ -186,6 +188,7 @@ def __init__(
loss=loss,
is_ensemble=version in ENSEMBLE_METHODS,
format_batch_fn=format_batch_fn,
optim_recipe=optim_recipe,
mixup_params=mixup_params,
eval_ood=eval_ood,
eval_grouping_loss=eval_grouping_loss,
1 change: 0 additions & 1 deletion torch_uncertainty/datamodules/segmentation/cityscapes.py
@@ -105,7 +105,6 @@ def __init__(
pin_memory=pin_memory,
persistent_workers=persistent_workers,
)

self.dataset = Cityscapes
self.mode = "fine"
self.crop_size = _pair(crop_size)
4 changes: 2 additions & 2 deletions torch_uncertainty/metrics/classification/risk_coverage.py
@@ -146,7 +146,7 @@ def plot(
ax.set_xlabel("Coverage (%)", fontsize=16)
ax.set_ylabel("Risk - Error Rate (%)", fontsize=16)
ax.set_xlim(0, 100)
ax.set_ylim(0, 100)
ax.set_ylim(0, min(100, np.ceil(error_rates.max() * 100)))
ax.set_aspect("equal", "box")
ax.legend(loc="upper right")
fig.tight_layout()
@@ -269,7 +269,7 @@ def plot(
ax.set_xlabel("Coverage (%)", fontsize=16)
ax.set_ylabel("Generalized Risk (%)", fontsize=16)
ax.set_xlim(0, 100)
ax.set_ylim(0, 100)
ax.set_ylim(0, min(100, np.ceil(error_rates.max() * 100)))
ax.set_aspect("equal", "box")
ax.legend(loc="upper right")
fig.tight_layout()
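With this change both plots clip the y-axis to the largest observed error rate instead of a fixed 0-100 span. A small worked example with assumed values, where error_rates stands in for the per-coverage error rates computed in plot():

    import numpy as np

    error_rates = np.array([0.021, 0.084, 0.168, 0.237])  # assumed values
    upper = min(100, np.ceil(error_rates.max() * 100))    # ceil(23.7) -> 24.0

    # ax.set_ylim(0, upper) lets the curve fill the axes rather than sit
    # compressed at the bottom of a 0-100 range; min() preserves the old
    # behavior whenever the error rate actually approaches 100%.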
12 changes: 10 additions & 2 deletions torch_uncertainty/routines/classification.py
@@ -392,7 +392,7 @@ def training_step(
loss = self.loss(logits, target, self.current_epoch)
if self.needs_step_update:
self.model.update_wrapper(self.current_epoch)
self.log("train_loss", loss)
self.log("train_loss", loss, prog_bar=True, logger=True)
return loss

def validation_step(
@@ -501,7 +501,15 @@ def test_step(
self.ood_logit_storage.append(logits.detach().cpu())

def on_validation_epoch_end(self) -> None:
self.log_dict(self.val_cls_metrics.compute(), sync_dist=True)
self.log_dict(
self.val_cls_metrics.compute(), logger=True, sync_dist=True
)
self.log(
"Acc%",
self.val_cls_metrics["cls/Acc"].compute() * 100,
prog_bar=True,
logger=False,
)
self.val_cls_metrics.reset()

if self.eval_grouping_loss:
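The prog_bar and logger flags of self.log pick the destination for each value: the new Acc% entry is a progress-bar convenience, while the full metric dict still goes to the experiment logger. A minimal sketch of this split, assuming the Lightning >= 2.0 import path; ToyModule and its placeholder metric are illustrative, not part of torch-uncertainty:

    import torch
    from lightning.pytorch import LightningModule

    class ToyModule(LightningModule):
        def __init__(self) -> None:
            super().__init__()
            self.layer = torch.nn.Linear(4, 2)

        def training_step(self, batch, batch_idx):
            loss = self.layer(batch).sum()
            # Shown on the progress bar and forwarded to the logger.
            self.log("train_loss", loss, prog_bar=True, logger=True)
            return loss

        def on_validation_epoch_end(self) -> None:
            acc = torch.tensor(0.91)  # placeholder for the computed accuracy
            # logger=False: display the scaled value on the progress bar only;
            # the unscaled metrics are already recorded via log_dict.
            self.log("Acc%", acc * 100, prog_bar=True, logger=False)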
2 changes: 1 addition & 1 deletion torch_uncertainty/routines/pixel_regression.py
@@ -191,7 +191,7 @@ def training_step(

if self.needs_step_update:
self.model.update_wrapper(self.current_epoch)
self.log("train_loss", loss)
self.log("train_loss", loss, prog_bar=True, logger=True)
return loss

def validation_step(
2 changes: 1 addition & 1 deletion torch_uncertainty/routines/regression.py
@@ -167,7 +167,7 @@ def training_step(

if self.needs_step_update:
self.model.update_wrapper(self.current_epoch)
self.log("train_loss", loss)
self.log("train_loss", loss, prog_bar=True, logger=True)
return loss

def validation_step(
11 changes: 9 additions & 2 deletions torch_uncertainty/routines/segmentation.py
@@ -170,7 +170,7 @@ def training_step(
loss = self.loss(logits[valid_mask], target[valid_mask])
if self.needs_step_update:
self.model.update_wrapper(self.current_epoch)
self.log("train_loss", loss)
self.log("train_loss", loss, prog_bar=True, logger=True)
return loss

def validation_step(
@@ -214,7 +214,7 @@ def test_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> None:
self.test_sbsmpl_seg_metrics.update(*self.subsample(probs, targets))

def on_validation_epoch_end(self) -> None:
self.log_dict(self.val_seg_metrics.compute(), sync_dist=True)
self.log_dict(
self.val_seg_metrics.compute(), logger=True, sync_dist=True
)
self.log(
"mIoU%",
self.val_seg_metrics["seg/mIoU"].compute() * 100,
prog_bar=True,
)
self.log_dict(self.val_sbsmpl_seg_metrics.compute(), sync_dist=True)
self.val_seg_metrics.reset()
self.val_sbsmpl_seg_metrics.reset()
