From 7b031814a759851a0e9a13aa0920181b5fd28597 Mon Sep 17 00:00:00 2001
From: FilippoOlivo
Date: Mon, 11 Nov 2024 09:58:54 +0100
Subject: [PATCH] Add validation_step and improve data management

---
 pina/data/data_module.py   |  64 ++++++++++++++++++------
 pina/solvers/solver.py     |   1 -
 pina/solvers/supervised.py | 100 ++++++++++++++++++++++++++++++++++---
 pina/trainer.py            |  31 +++++++++---
 4 files changed, 166 insertions(+), 30 deletions(-)

diff --git a/pina/data/data_module.py b/pina/data/data_module.py
index ea6a802c..d11a70f7 100644
--- a/pina/data/data_module.py
+++ b/pina/data/data_module.py
@@ -6,6 +6,9 @@
 import torch
 import logging
 from pytorch_lightning import LightningDataModule
+from pytorch_lightning.utilities.types import EVAL_DATALOADERS, \
+    TRAIN_DATALOADERS
+
 from .sample_dataset import SamplePointDataset
 from .supervised_dataset import SupervisedDataset
 from .unsupervised_dataset import UnsupervisedDataset
@@ -61,30 +64,31 @@ def __init__(self,
         if train_size > 0:
             self.split_names.append('train')
             self.split_length.append(train_size)
-            self.loader_functions['train_dataloader'] = lambda \
-                x: PinaDataLoader(self.splits['train'], self.batch_size,
-                                  self.condition_names)
+        else:
+            self.train_dataloader = super().train_dataloader
+
         if test_size > 0:
             self.split_length.append(test_size)
             self.split_names.append('test')
-            self.loader_functions['test_dataloader'] = lambda x: PinaDataLoader(
-                self.splits['test'], self.batch_size, self.condition_names)
+        else:
+            self.test_dataloader = super().test_dataloader
+
         if val_size > 0:
             self.split_length.append(val_size)
             self.split_names.append('val')
-            self.loader_functions['val_dataloader'] = lambda x: PinaDataLoader(
-                self.splits['val'], self.batch_size, self.condition_names)
+        else:
+            self.val_dataloader = super().val_dataloader
+
         if predict_size > 0:
             self.split_length.append(predict_size)
             self.split_names.append('predict')
-            self.loader_functions[
-                'predict_dataloader'] = lambda x: PinaDataLoader(
-                self.splits['predict'], self.batch_size, self.condition_names)
+        else:
+            self.predict_dataloader = super().predict_dataloader
+
         self.splits = {k: {} for k in self.split_names}
         self.shuffle = shuffle
-
-        for k, v in self.loader_functions.items():
-            setattr(self, k, v.__get__(self, PinaDataModule))
+        self.has_setup_fit = False
+        self.has_setup_test = False
 
     def prepare_data(self):
         if self.datasets is None:
@@ -106,8 +110,12 @@ def setup(self, stage=None):
                 for i in range(len(self.split_length)):
                     self.splits[self.split_names[i]][
                         dataset.data_type] = splits[i]
+            self.has_setup_fit = True
         elif stage == 'test':
-            raise NotImplementedError("Testing pipeline not implemented yet")
+            if not self.has_setup_fit:
+                raise RuntimeError(
+                    "setup must be called with stage='fit' "
+                    "before stage='test'")
         else:
             raise ValueError("stage must be either 'fit' or 'test'")
 
@@ -178,3 +186,31 @@ def _create_datasets(self):
             dataset.initialize()
             datasets.append(dataset)
         self.datasets = datasets
+
+    def val_dataloader(self):
+        """
+        Create the validation dataloader.
+        """
+        return PinaDataLoader(self.splits['val'], self.batch_size,
+                              self.condition_names)
+
+    def train_dataloader(self):
+        """
+        Create the training dataloader.
+        """
+        return PinaDataLoader(self.splits['train'], self.batch_size,
+                              self.condition_names)
+
+    def test_dataloader(self):
+        """
+        Create the testing dataloader.
+        """
+        return PinaDataLoader(self.splits['test'], self.batch_size,
+                              self.condition_names)
+
+    def predict_dataloader(self):
+        """
+        Create the prediction dataloader.
+        """
+        return PinaDataLoader(self.splits['predict'], self.batch_size,
+                              self.condition_names)
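
A minimal usage sketch of the split/fallback wiring above (illustrative only, not part of the patch; `problem` and the constructor keywords are assumed from the `_create_loader` hunk in trainer.py below):

    # Hypothetical illustration of the new PinaDataModule behaviour.
    dm = PinaDataModule(problem,
                        train_size=0.7, test_size=0.2, val_size=0.1,
                        predict_size=0.0, batch_size=32)
    dm.setup(stage='fit')                  # populates dm.splits, sets has_setup_fit
    train_loader = dm.train_dataloader()   # PinaDataLoader over dm.splits['train']
    # With predict_size == 0, __init__ rebinds predict_dataloader to the plain
    # LightningDataModule implementation instead of the override added above.
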
diff --git a/pina/solvers/solver.py b/pina/solvers/solver.py
index b622546e..fe9c897e 100644
--- a/pina/solvers/solver.py
+++ b/pina/solvers/solver.py
@@ -83,7 +83,6 @@ def __init__(self,
                          " optimizers.")
 
         # extra features handling
-
         self._pina_models = models
         self._pina_optimizers = optimizers
         self._pina_schedulers = schedulers
diff --git a/pina/solvers/supervised.py b/pina/solvers/supervised.py
index a2be1102..ff4153a6 100644
--- a/pina/solvers/supervised.py
+++ b/pina/solvers/supervised.py
@@ -1,5 +1,6 @@
 """ Module for SupervisedSolver """
 import torch
+from pytorch_lightning.utilities.types import STEP_OUTPUT
 from torch.nn.modules.loss import _Loss
 from ..optim import TorchOptimizer, TorchScheduler
 from .solver import SolverInterface
@@ -75,11 +76,15 @@ def __init__(self,
                          extra_features=extra_features)
 
         # check consistency
-        check_consistency(loss, (LossInterface, _Loss), subclass=False)
+        check_consistency(loss, (LossInterface, _Loss),
+                          subclass=False)
         self._loss = loss
         self._model = self._pina_models[0]
         self._optimizer = self._pina_optimizers[0]
         self._scheduler = self._pina_schedulers[0]
+        self.validation_condition_losses = {
+            k: {'loss': [],
+                'count': []} for k in self.problem.conditions.keys()}
 
     def forward(self, x):
         """Forward pass implementation for the solver.
@@ -105,7 +110,7 @@ def configure_optimizers(self):
         return ([self._optimizer.optimizer_instance],
                 [self._scheduler.scheduler_instance])
 
-    def training_step(self, batch, batch_idx):
+    def training_step(self, batch):
         """Solver training step.
 
         :param batch: The batch element in the dataloader.
@@ -117,12 +122,14 @@
         """
         condition_idx = batch.supervised.condition_indices
-        loss = torch.tensor(0, dtype=torch.float32)
+        loss = torch.tensor(0, dtype=torch.float32).to(self.device)
+        batch = batch.supervised
         for condition_id in range(condition_idx.min(), condition_idx.max() + 1):
-            condition_name = self._dataloader.condition_names[condition_id]
+            condition_name = self.trainer.data_module.condition_names[
+                condition_id]
             condition = self.problem.conditions[condition_name]
-            pts = batch.supervised.input_points
-            out = batch.supervised.output_points
+            pts = batch.input_points
+            out = batch.output_points
             if condition_name not in self.problem.conditions:
                 raise RuntimeError("Something wrong happened.")
@@ -134,13 +141,90 @@
             output_pts = out[condition_idx == condition_id]
             input_pts = pts[condition_idx == condition_id]
-
             loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
             loss += loss_.as_subclass(torch.Tensor)
 
-        self.log("mean_loss", float(loss), prog_bar=True, logger=True)
+        self.log("mean_loss", float(loss), prog_bar=True, logger=True,
+                 on_epoch=True,
+                 on_step=False, batch_size=self.trainer.data_module.batch_size)
         return loss
 
+    def validation_step(self, batch):
+        """
+        Solver validation step.
+        """
+
+        batch = batch.supervised
+        condition_idx = batch.condition_indices
+        for i in range(condition_idx.min(), condition_idx.max() + 1):
+            condition_name = self.trainer.data_module.condition_names[i]
+            condition = self.problem.conditions[condition_name]
+            pts = batch.input_points
+            out = batch.output_points
+            if condition_name not in self.problem.conditions:
+                raise RuntimeError(f"Unknown condition {condition_name}.")
+
+            # for data driven mode
+            if not hasattr(condition, "output_points"):
+                raise NotImplementedError(
+                    f"{type(self).__name__} works only in data-driven mode.")
+
+            output_pts = out[condition_idx == i]
+            input_pts = pts[condition_idx == i]
+
+            loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
+            self.validation_condition_losses[condition_name]['loss'].append(
+                loss_)
+            self.validation_condition_losses[condition_name]['count'].append(
+                len(input_pts))
+
+    def on_validation_epoch_end(self):
+        """
+        Aggregate and log the count-weighted validation losses.
+        """
+        total_loss = []
+        total_count = []
+        for k, v in self.validation_condition_losses.items():
+            local_counter = torch.tensor(v['count']).to(self.device)
+            n_elements = torch.sum(local_counter)
+            loss = torch.sum(
+                torch.stack(v['loss']) * local_counter) / n_elements
+            loss = loss.as_subclass(torch.Tensor)
+            total_loss.append(loss)
+            total_count.append(n_elements)
+            self.log(
+                k + "_loss",
+                loss,
+                prog_bar=True,
+                logger=True,
+                on_epoch=True,
+                on_step=False,
+                batch_size=self.trainer.data_module.batch_size,
+            )
+        total_count = (torch.tensor(total_count, dtype=torch.float32).
+                       to(self.device))
+        mean_loss = (torch.sum(torch.stack(total_loss) * total_count) /
+                     torch.sum(total_count))
+        self.log(
+            "val_loss",
+            mean_loss,
+            prog_bar=True,
+            logger=True,
+            on_epoch=True,
+            on_step=False,
+            batch_size=self.trainer.data_module.batch_size,
+        )
+        for key in self.validation_condition_losses.keys():
+            self.validation_condition_losses[key]['loss'] = []
+            self.validation_condition_losses[key]['count'] = []
+
+    def test_step(self, batch, batch_idx) -> STEP_OUTPUT:
+        """
+        Solver test step.
+        """
+
+        raise NotImplementedError("Test step not implemented.")
+
     def loss_data(self, input_pts, output_pts):
         """
         The data loss for the Supervised solver. It computes the loss between
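
As a quick numerical check of the count-weighted aggregation performed in `on_validation_epoch_end` above (values are made up for illustration; note the division by the total count, fixed from the original draft which divided by the per-condition count tensor):

    import torch

    losses = torch.tensor([0.50, 0.20])   # mean loss recorded per condition
    counts = torch.tensor([10., 30.])     # points contributed per condition
    val_loss = torch.sum(losses * counts) / torch.sum(counts)
    # (0.50 * 10 + 0.20 * 30) / 40 = 0.275, not the unweighted mean 0.35
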
diff --git a/pina/trainer.py b/pina/trainer.py
index 49461166..46d26471 100644
--- a/pina/trainer.py
+++ b/pina/trainer.py
@@ -1,5 +1,5 @@
 """ Trainer module. """
-
+import warnings
 import torch
 import pytorch_lightning
 from .utils import check_consistency
@@ -15,6 +15,7 @@ def __init__(self,
                  train_size=.7,
                  test_size=.2,
                  val_size=.1,
+                 predict_size=.0,
                  **kwargs):
         """
         PINA Trainer class for customizing every aspect of training via flags.
@@ -30,8 +31,8 @@ def __init__(self,
             and can be chosen from the `pytorch-lightning
             Trainer API `_
         """
-
-        super().__init__(**kwargs)
+        log_every_n_steps = kwargs.pop('log_every_n_steps', 0)
+        super().__init__(log_every_n_steps=log_every_n_steps, **kwargs)
 
         # check inheritance consistency for solver and batch size
         check_consistency(solver, SolverInterface)
@@ -40,9 +41,9 @@ def __init__(self,
         self.train_size = train_size
         self.test_size = test_size
         self.val_size = val_size
+        self.predict_size = predict_size
         self.solver = solver
         self.batch_size = batch_size
-        self._create_loader()
         self._move_to_device()
         self.data_module = None
@@ -83,6 +84,7 @@ def _create_loader(self):
             train_size=self.train_size,
             test_size=self.test_size,
             val_size=self.val_size,
+            predict_size=self.predict_size,
             batch_size=self.batch_size,
         )
         self.data_module.setup()
@@ -91,11 +93,27 @@
     def train(self, **kwargs):
         """
         Train the solver method.
         """
         self._create_loader()
-        return super().fit(self.solver,
-                           datamodule=self.data_module,
-                           **kwargs)
+        with warnings.catch_warnings():
+            warnings.filterwarnings(
+                "ignore",
+                message="You defined a `validation_step` but have no "
+                        "`val_dataloader`",
+                category=UserWarning
+            )
+            return super().fit(self.solver,
+                               datamodule=self.data_module,
+                               **kwargs)
+
+    def test(self, **kwargs):
+        """
+        Test the solver method.
+        """
+        self._create_loader()
+        return super().test(self.solver,
+                            datamodule=self.data_module,
+                            **kwargs)
 
     @property
     def solver(self):
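
For completeness, a sketch of the end-to-end flow this patch enables (illustrative only; the `problem` and `model` objects and the exact SupervisedSolver constructor signature are assumed, not defined by this patch):

    solver = SupervisedSolver(problem=problem, model=model)
    trainer = Trainer(solver,
                      batch_size=32,
                      train_size=0.7, test_size=0.2, val_size=0.1,
                      predict_size=0.0,
                      max_epochs=50)
    trainer.train()   # builds the PinaDataModule, then delegates to lightning fit
    trainer.test()    # delegates to lightning test; note SupervisedSolver.test_step
                      # is still a stub that raises NotImplementedError
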