diff --git a/pina/__init__.py b/pina/__init__.py
index 3bc28ae6..b415dbab 100644
--- a/pina/__init__.py
+++ b/pina/__init__.py
@@ -1,10 +1,10 @@
 __all__ = [
     "Trainer", "LabelTensor", "Plotter", "Condition", "SamplePointDataset",
-    "PinaDataModule", "PinaDataLoader", 'TorchOptimizer', 'Graph'
+    "PinaDataModule", "PinaDataLoader", 'TorchOptimizer', 'Graph', 'LabelParameter'
 ]
 
 from .meta import *
-from .label_tensor import LabelTensor
+from .label_tensor import LabelTensor, LabelParameter
 from .solvers.solver import SolverInterface
 from .trainer import Trainer
 from .plotter import Plotter
diff --git a/pina/condition/data_condition.py b/pina/condition/data_condition.py
index 05c543eb..2adf1c82 100644
--- a/pina/condition/data_condition.py
+++ b/pina/condition/data_condition.py
@@ -19,7 +19,7 @@ class DataConditionInterface(ConditionInterface):
 
     def __init__(self, input_points, conditional_variables=None):
         """
-        TODO
+        TODO: add docstring
         """
         super().__init__()
         self.input_points = input_points
diff --git a/pina/condition/domain_equation_condition.py b/pina/condition/domain_equation_condition.py
index 53e07621..e0e7f916 100644
--- a/pina/condition/domain_equation_condition.py
+++ b/pina/condition/domain_equation_condition.py
@@ -16,7 +16,7 @@ class DomainEquationCondition(ConditionInterface):
     condition_type = ['physics']
     def __init__(self, domain, equation):
         """
-        TODO
+        TODO: add docstring
         """
         super().__init__()
         self.domain = domain
diff --git a/pina/condition/input_equation_condition.py b/pina/condition/input_equation_condition.py
index 2a7f4647..2c376a16 100644
--- a/pina/condition/input_equation_condition.py
+++ b/pina/condition/input_equation_condition.py
@@ -17,7 +17,7 @@ class InputPointsEquationCondition(ConditionInterface):
     condition_type = ['physics']
     def __init__(self, input_points, equation):
         """
-        TODO
+        TODO: add docstring
         """
         super().__init__()
         self.input_points = input_points
diff --git a/pina/condition/input_output_condition.py b/pina/condition/input_output_condition.py
index e9c34bea..de92926d 100644
--- a/pina/condition/input_output_condition.py
+++ b/pina/condition/input_output_condition.py
@@ -16,7 +16,7 @@ class InputOutputPointsCondition(ConditionInterface):
     condition_type = ['supervised']
     def __init__(self, input_points, output_points):
         """
-        TODO
+        TODO: add docstring
        """
         super().__init__()
         self.input_points = input_points
diff --git a/pina/data/base_dataset.py b/pina/data/base_dataset.py
index d05784f8..b70b0c25 100644
--- a/pina/data/base_dataset.py
+++ b/pina/data/base_dataset.py
@@ -3,9 +3,10 @@
 """
 import torch
 import logging
-
+import math
 from torch.utils.data import Dataset
 from ..label_tensor import LabelTensor
+from .pina_subset import PinaSubset
 
 
 class BaseDataset(Dataset):
@@ -40,7 +41,7 @@ def __init__(self, problem=None, device=torch.device('cpu')):
         super().__init__()
         self.empty = True
         self.problem = problem
-        self.device = device
+        self.device = torch.device('cpu')  # data stays on CPU; batches are moved to the target device on access
         self.condition_indices = None
         for slot in self.__slots__:
             setattr(self, slot, [])
@@ -52,7 +53,7 @@ def __init__(self, problem=None, device=torch.device('cpu')):
 
     def _init_from_problem(self, collector_dict):
         """
-        TODO
+        TODO: add docstring
         """
         for name, data in collector_dict.items():
             keys = list(data.keys())
@@ -151,6 +152,50 @@ def apply_shuffle(self, indices):
             if slot != 'equation':
                 attribute = getattr(self, slot)
                 if isinstance(attribute, (LabelTensor, torch.Tensor)):
-                    setattr(self, 'slot', attribute[[indices]])
+                    setattr(self, slot, attribute[[indices]].detach())
                 if isinstance(attribute, list):
-                    setattr(self, 'slot', [attribute[i] for i in indices])
+                    setattr(self, slot, [attribute[i] for i in indices])
+        self.condition_indices = self.condition_indices[indices]
+
+    def eval_splitting_lengths(self, lengths):
+        """Convert fractional split lengths into integer dataset lengths."""
+        if sum(lengths) - 1 < 1e-3:
+            len_dataset = len(self)
+            lengths = [
+                int(math.floor(len_dataset * length)) for length in lengths
+            ]
+            remainder = len(self) - sum(lengths)
+            for i in range(remainder):
+                lengths[i % len(lengths)] += 1
+        else:
+            raise ValueError(
+                f"Sum of lengths is {sum(lengths)}, which is greater than 1")
+        return lengths
+
+    def dataset_split(self, lengths, seed=None, shuffle=True):
+        """
+        Perform the splitting of the dataset
+        :param lengths: lengths of elements in dataset
+        :param seed: random seed
+        :param shuffle: shuffle dataset
+        :return: split dataset
+        :rtype: PinaSubset
+        """
+        lengths = self.eval_splitting_lengths(lengths)
+
+        if shuffle:
+            if seed is not None:
+                generator = torch.Generator()
+                generator.manual_seed(seed)
+                indices = torch.randperm(sum(lengths), generator=generator)
+            else:
+                indices = torch.randperm(sum(lengths))
+            self.apply_shuffle(indices)
+
+        offsets = [
+            sum(lengths[:i]) if i > 0 else 0 for i in range(len(lengths))
+        ]
+        return [
+            PinaSubset(self, slice(offset, offset + length))
+            for offset, length in zip(offsets, lengths)
+        ]
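# --- Illustrative sketch (not part of the patch) ---
# How `eval_splitting_lengths` resolves fractional lengths; the dataset size
# (10) and the fractions are hypothetical.
import math

lengths, len_dataset = [0.7, 0.2, 0.1], 10
lengths = [int(math.floor(len_dataset * length)) for length in lengths]
for i in range(len_dataset - sum(lengths)):  # redistribute any remainder
    lengths[i % len(lengths)] += 1
assert lengths == [7, 2, 1]  # [0.5, 0.3, 0.3] would raise instead: sum > 1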
diff --git a/pina/data/data_module.py b/pina/data/data_module.py
index b09fb54a..eb625d85 100644
--- a/pina/data/data_module.py
+++ b/pina/data/data_module.py
@@ -2,15 +2,12 @@
 This module provide basic data management functionalities
 """
 
-import math
-import torch
 import logging
 from pytorch_lightning import LightningDataModule
 from .sample_dataset import SamplePointDataset
 from .supervised_dataset import SupervisedDataset
 from .unsupervised_dataset import UnsupervisedDataset
 from .pina_dataloader import PinaDataLoader
-from .pina_subset import PinaSubset
 
 
 class PinaDataModule(LightningDataModule):
@@ -101,9 +98,9 @@ def setup(self, stage=None):
         if stage == 'fit' or stage is None:
             for dataset in self.datasets:
                 if len(dataset) > 0:
-                    splits = self.dataset_split(dataset,
-                                                self.split_length,
-                                                shuffle=self.shuffle)
+                    splits = dataset.dataset_split(
+                        self.split_length,
+                        shuffle=self.shuffle)
                     for i in range(len(self.split_length)):
                         self.splits[self.split_names[i]][
                             dataset.data_type] = splits[i]
@@ -116,45 +113,6 @@ def setup(self, stage=None):
         else:
             raise ValueError("stage must be either 'fit' or 'test'")
 
-    @staticmethod
-    def dataset_split(dataset, lengths, seed=None, shuffle=True):
-        """
-        Perform the splitting of the dataset
-        :param dataset: dataset object we wanted to split
-        :param lengths: lengths of elements in dataset
-        :param seed: random seed
-        :param shuffle: shuffle dataset
-        :return: split dataset
-        :rtype: PinaSubset
-        """
-        if sum(lengths) - 1 < 1e-3:
-            len_dataset = len(dataset)
-            lengths = [
-                int(math.floor(len_dataset * length)) for length in lengths
-            ]
-            remainder = len(dataset) - sum(lengths)
-            for i in range(remainder):
-                lengths[i % len(lengths)] += 1
-        elif sum(lengths) - 1 >= 1e-3:
-            raise ValueError(f"Sum of lengths is {sum(lengths)} less than 1")
-
-        if shuffle:
-            if seed is not None:
-                generator = torch.Generator()
-                generator.manual_seed(seed)
-                indices = torch.randperm(sum(lengths), generator=generator)
-            else:
-                indices = torch.randperm(sum(lengths))
-            dataset.apply_shuffle(indices)
-
-        offsets = [
-            sum(lengths[:i]) if i > 0 else 0 for i in range(len(lengths))
-        ]
-        return [
-            PinaSubset(dataset, slice(offset, offset + length))
-            for offset, length in zip(offsets, lengths)
-        ]
-
     def _create_datasets(self):
         """
         Create the dataset objects putting data
@@ -189,25 +147,25 @@ def val_dataloader(self):
         Create the validation dataloader
         """
         return PinaDataLoader(self.splits['val'], self.batch_size,
-                              self.condition_names)
+                              self.condition_names, device=self.device)
 
     def train_dataloader(self):
         """
         Create the training dataloader
         """
         return PinaDataLoader(self.splits['train'], self.batch_size,
-                              self.condition_names)
+                              self.condition_names, device=self.device)
 
     def test_dataloader(self):
         """
         Create the testing dataloader
         """
         return PinaDataLoader(self.splits['test'], self.batch_size,
-                              self.condition_names)
+                              self.condition_names, device=self.device)
 
     def predict_dataloader(self):
         """
         Create the prediction dataloader
         """
         return PinaDataLoader(self.splits['predict'], self.batch_size,
-                              self.condition_names)
+                              self.condition_names, device=self.device)
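# --- Illustrative sketch (not part of the patch) ---
# Splitting is now delegated to the dataset itself, so subclasses can
# override `dataset_split` (SamplePointDataset below does). Hypothetical use,
# where `dataset` is any BaseDataset instance:
train, val, test = dataset.dataset_split([0.7, 0.2, 0.1], shuffle=True)
# each element is a PinaSubset viewing a contiguous slice of the dataset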
diff --git a/pina/data/pina_batch.py b/pina/data/pina_batch.py
index e43e1108..0aac85b7 100644
--- a/pina/data/pina_batch.py
+++ b/pina/data/pina_batch.py
@@ -3,7 +3,6 @@
 """
 import torch
 from ..label_tensor import LabelTensor
-
 from .pina_subset import PinaSubset
 
 
@@ -13,8 +12,10 @@ class Batch:
     optimization.
     """
 
-    def __init__(self, dataset_dict, idx_dict, require_grad=True):
+    def __init__(self, dataset_dict, idx_dict, require_grad=True,
+                 device=torch.device('cpu')):
         self.attributes = []
+        self.require_grad = require_grad
         for k, v in dataset_dict.items():
             index = idx_dict[k]
             if isinstance(v, PinaSubset):
@@ -23,10 +24,12 @@ def __init__(self, dataset_dict, idx_dict, require_grad=True):
                 index = slice(dataset_index.start + index.start,
                               min(dataset_index.start + index.stop,
                                   dataset_index.stop))
-            setattr(self, k, PinaSubset(v.dataset, index,
-                                        require_grad=require_grad))
+            data = PinaSubset(v.dataset, index, require_grad=require_grad)
+            setattr(self, k, data)
+            setattr(self, k + '_data', self.prepare_data(data))
             self.attributes.append(k)
         self.require_grad = require_grad
+        self.device = device
 
     def __len__(self):
         """
@@ -46,11 +49,10 @@ def __getattr__(self, item):
             return self.__getattribute__(item)
         raise AttributeError(f"'Batch' object has no attribute '{item}'")
 
-    def get_data(self, batch_name=None):
+    def prepare_data(self, data):
         """
-        # TODO
+        Prepare the data for the batch
         """
-        data = getattr(self, batch_name)
         to_return_list = []
         if isinstance(data, PinaSubset):
             items = data.dataset.__slots__
@@ -69,6 +71,23 @@ def get_data(self, batch_name=None):
                                  i == condition_idx[k]])
                     temp.append(i)
                 to_return_list.append(temp)
+
+        return to_return_list
+
+    def get_data(self, batch_name=None):
+        """
+        Return the cached batch data, moved to the target device
+        """
+        data = getattr(self, batch_name + '_data')
+        to_return_list = [
+            [
+                i.detach().to(self.device).requires_grad_()
+                if isinstance(i, (torch.Tensor, LabelTensor)) else i
+                for i in points
+            ]
+            for points in data
+        ]
         return to_return_list
 
     def get_supervised_data(self):
diff --git a/pina/data/pina_dataloader.py b/pina/data/pina_dataloader.py
index a28ca6c6..04bcee07 100644
--- a/pina/data/pina_dataloader.py
+++ b/pina/data/pina_dataloader.py
@@ -15,7 +15,8 @@ class PinaDataLoader:
     :vartype condition_names: list[str]
     """
 
-    def __init__(self, dataset_dict, batch_size, condition_names) -> None:
+    def __init__(self, dataset_dict, batch_size, condition_names,
+                 device) -> None:
         """
         Initialize local variables
         :param dataset_dict: Dictionary of datasets
@@ -28,6 +29,7 @@ def __init__(self, dataset_dict, batch_size, condition_names) -> None:
         self.condition_names = condition_names
         self.dataset_dict = dataset_dict
         self.batch_size = batch_size
+        self.device = device
         self._init_batches(batch_size)
 
     def _init_batches(self, batch_size=None):
@@ -79,7 +81,8 @@ def _init_batches(self, batch_size=None):
                 actual_indices[k] = actual_indices[k] + v
                 total_length += v
             self.batches.append(
-                Batch(idx_dict=temp_dict, dataset_dict=self.dataset_dict))
+                Batch(idx_dict=temp_dict, dataset_dict=self.dataset_dict,
+                      device=self.device))
 
     def __iter__(self):
         """
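# --- Illustrative sketch (not part of the patch) ---
# Batch tuples are assembled once in `prepare_data`; `get_data` only detaches
# and moves the cached tensors, so no autograd graph survives across
# iterations. Hypothetical loop over a supervised batch (`model` is a
# stand-in network):
for input_pts, output_pts, condition_id in batch.get_data('supervised'):
    # tensors now live on `batch.device` and have requires_grad=True
    prediction = model(input_pts)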
diff --git a/pina/data/pina_subset.py b/pina/data/pina_subset.py
index b5b74a68..bcd9c1e0 100644
--- a/pina/data/pina_subset.py
+++ b/pina/data/pina_subset.py
@@ -7,13 +7,13 @@
 
 class PinaSubset:
     """
-    TODO
+    TODO: add docstring
     """
     __slots__ = ['dataset', 'indices', 'require_grad']
 
     def __init__(self, dataset, indices, require_grad=False):
         """
-        TODO
+        TODO: add docstring
         """
         self.dataset = dataset
         self.indices = indices
@@ -21,8 +21,9 @@
     def __len__(self):
         """
-        TODO
+        TODO: add docstring
         """
+
         if isinstance(self.indices, slice):
             return self.indices.stop - self.indices.start
         return len(self.indices)
diff --git a/pina/data/sample_dataset.py b/pina/data/sample_dataset.py
index bc3bca33..e8c4e3f1 100644
--- a/pina/data/sample_dataset.py
+++ b/pina/data/sample_dataset.py
@@ -1,9 +1,12 @@
 """
 Sample dataset module
 """
+import warnings
 from copy import deepcopy
+import torch
 from .base_dataset import BaseDataset
 from ..condition import InputPointsEquationCondition
+from .pina_subset import PinaSubset
 
 
 class SamplePointDataset(BaseDataset):
@@ -20,6 +23,7 @@ def add_points(self, data_dict, condition_idx, batching_dim=0):
         super().add_points(data_dict, condition_idx)
 
     def _init_from_problem(self, collector_dict):
+        self.size = 0
         for name, data in collector_dict.items():
             keys = list(data.keys())
             if set(self.__slots__) == set(keys):
@@ -33,3 +37,44 @@
                 ]
                 self.conditions_idx.append(idx)
         self.initialize()
+
+    def dataset_split(self, lengths, seed=None, shuffle=False):
+        """
+        Split the dataset in different parts
+        :param lengths: list of the lengths of the splits
+        :type lengths: list
+        :param seed: seed for the random generator
+        :type seed: int
+        :param shuffle: shuffle the dataset before splitting
+        :type shuffle: bool
+        :return: list of the splits
+        :rtype: list
+        """
+        lengths = self.eval_splitting_lengths(lengths)
+        if shuffle is True:
+            warnings.warn('Shuffling is not applied to SamplePointDataset')
+
+        non_train_length = sum(lengths[1:])
+
+        if seed is not None:
+            generator = torch.Generator()
+            generator.manual_seed(seed)
+            indices_splits = torch.randperm(sum(lengths), generator=generator)
+        else:
+            indices_splits = torch.randperm(sum(lengths))
+        indices_splits = torch.sort(indices_splits[:non_train_length])[0]
+        indices = torch.arange(len(self))
+        mask = torch.isin(indices, indices_splits)
+        indices = indices[~mask]
+        indices = torch.cat([indices, indices_splits])
+
+        self.input_points = self.input_points[[indices]].detach()
+        self.condition_indices = self.condition_indices[indices]
+
+        offsets = [
+            sum(lengths[:i]) if i > 0 else 0 for i in range(len(lengths))
+        ]
+        return [
+            PinaSubset(self, slice(offset, offset + length))
+            for offset, length in zip(offsets, lengths)
+        ]
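# --- Illustrative sketch (not part of the patch) ---
# SamplePointDataset keeps the training block contiguous: the randomly drawn
# validation/test indices are sorted and moved to the tail. Hypothetical
# 10-point dataset with lengths [8, 1, 1]:
import torch

eval_idx = torch.sort(torch.randperm(10)[:2])[0]
idx = torch.arange(10)
order = torch.cat([idx[~torch.isin(idx, eval_idx)], eval_idx])
# order[:8] indexes the training slice, order[8:9] validation, order[9:] test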
diff --git a/pina/label_tensor.py b/pina/label_tensor.py
index 58dc8b71..c46e369c 100644
--- a/pina/label_tensor.py
+++ b/pina/label_tensor.py
@@ -4,9 +4,9 @@
 import torch
 from torch import Tensor
 
-full_labels = True
-MATH_MODULES = {torch.sin, torch.cos, torch.exp, torch.tan, torch.log,
-                torch.sqrt}
+full_labels = False
+MATH_FUNCTIONS = {torch.sin, torch.cos}
+GRAD_FUNCTIONS = {torch.autograd.grad}
 
 
 def issubset(a, b):
@@ -26,6 +26,7 @@ class LabelTensor(torch.Tensor):
     @staticmethod
     def __new__(cls, x, labels, *args, **kwargs):
         full = kwargs.pop("full", full_labels)
+
         if isinstance(x, LabelTensor):
             x.full = full
             return x
@@ -54,56 +55,78 @@ def __init__(self, x, labels, **kwargs):
     def __torch_function__(cls, func, types, args=(), kwargs=None):
         if kwargs is None:
             kwargs = {}
-        if func in MATH_MODULES:
+        if func in MATH_FUNCTIONS:
             str_labels = func.__name__
-            labels = copy(args[0].stored_labels)
             lt = super().__torch_function__(func, types, args=args,
                                             kwargs=kwargs)
-            lt_shape = lt.shape
-
-            if len(lt_shape) - 1 in labels.keys():
-                labels.update({
-                    len(lt_shape) - 1: {
-                        'dof': [f'{str_labels}({i})' for i in
-                                labels[len(lt_shape) - 1]['dof']],
-                        'name': len(lt_shape) - 1
-                    }
-                })
-            lt._labels = labels
-            return lt
+            if hasattr(args[0], '_labels'):
+                labels = {k: copy(v) for k, v in args[0].stored_labels.items()}
+                lt._labels = labels
+                lt.dim_names = args[0].dim_names
+
+                lt_shape = lt.shape
+
+                if len(lt_shape) - 1 in labels.keys():
+                    labels.update({
+                        len(lt_shape) - 1: {
+                            'dof': [f'{str_labels}({i})' for i in
+                                    labels[len(lt_shape) - 1]['dof']],
+                            'name': len(lt_shape) - 1
+                        }
+                    })
+                    lt._labels = labels
+
+            return lt
+        if func in GRAD_FUNCTIONS:
+            # TODO: implement label propagation for torch.autograd.grad
+            pass
         return super().__torch_function__(func, types, args=args,
                                           kwargs=kwargs)
 
     def __mul__(self, other):
+        lt = super().__mul__(other)
+        if not hasattr(self, '_labels'):
+            return lt
         if isinstance(other, (int, float)):
             if hasattr(self, '_labels'):
                 lt._labels = self._labels
+                lt.dim_names = self.dim_names
+
         if isinstance(other, LabelTensor):
             lt_shape = lt.shape
-            labels = copy(self.stored_labels)
-            other_labels = other.stored_labels
-            for (k, v), (ko, vo) in zip(sorted(labels.items()),
-                                        sorted(other_labels.items())):
-                if k != ko:
-                    raise ValueError('Labels must be the same')
-                if k != len(lt_shape) - 1:
-                    if vo != v:
-                        raise ValueError('Labels must be the same')
-                else:
-                    check = True
-            if check:
-                labels.update({
-                    len(lt_shape) - 1: {'dof': [f'{i}{j}' for i, j in
-                                                zip(self.stored_labels[
-                                                    len(lt_shape) - 1]['dof'],
-                                                    other.stored_labels[
-                                                    len(lt_shape) - 1]['dof'])],
-                                        'name': self.stored_labels[
-                                            len(lt_shape) - 1]['name']}
-                })
+            check = False
+            if self.ndim in (0, 1):
+                labels = copy(other.stored_labels)
+            else:
+                labels = copy(self.stored_labels)
+                other_labels = copy(other.stored_labels)
+                for (k, v), (ko, vo) in zip(sorted(labels.items()),
+                                            sorted(other_labels.items())):
+                    if k != ko:
+                        raise ValueError('Labels must be the same')
+                    if k != len(lt_shape) - 1:
+                        if vo != v:
+                            raise ValueError('Labels must be the same')
+                    else:
+                        check = True
+                if check:
+                    labels.update({
+                        len(lt_shape) - 1: {
+                            'dof': [f'{i}{j}' for i, j in
+                                    zip(self.stored_labels[
+                                            len(lt_shape) - 1]['dof'],
+                                        other.stored_labels[
+                                            len(lt_shape) - 1]['dof'])],
+                            'name': self.stored_labels[
+                                len(lt_shape) - 1]['name']}
+                    })
             lt._labels = labels
+            lt.dim_names = self.dim_names
+
         return lt
 
     @classmethod
@@ -408,7 +431,7 @@ def summation(tensors):
         raise RuntimeError('Tensors must have the same shape and labels')
 
     last_dim_labels = []
-    data = torch.zeros(tensors[0].tensor.shape)
+    data = torch.zeros(tensors[0].tensor.shape, device=tensors[0].device)
     for tensor in tensors:
         data += tensor.tensor
         last_dim_labels.append(tensor.labels)
@@ -462,6 +485,7 @@ def __getitem__(self, index):
        :param index:
        :return:
        """
+
        if isinstance(index, str) or (isinstance(
                index, (tuple, list)) and all(
                    isinstance(a, str) for a in index)):
@@ -509,16 +533,19 @@ def _update_single_label(old_labels, to_update_labels, index, dim):
         old_dof = old_labels[dim]['dof']
         if isinstance(index, torch.Tensor) and index.ndim == 0:
             index = int(index)
-        if (not isinstance(
-                index,
-                (int, slice)) and len(index) == len(old_dof) and isinstance(
-                    old_dof, range)):
-            return
-
         if isinstance(index, torch.Tensor):
             if isinstance(old_dof, range):
                 to_update_labels.update({
                     dim: {
-                        'dof': index.tolist(),
+                        'dof': index.tolist() if not (
+                            torch.diff(index) == 1).all() else
+                        range(old_dof[index[0]], old_dof[index[-1]] + 1),
                         'name': old_labels[dim]['name']
                     }
                 })
@@ -567,3 +594,21 @@ def permute(self, *dims):
             for k in stored_labels.keys()
         }
         return LabelTensor.__internal_init__(tensor, labels, self.dim_names)
+
+    def detach(self):
+        lt = super().detach()
+        lt._labels = self.stored_labels
+        lt.dim_names = self.dim_names
+        return lt
+
+
+class LabelParameter(torch.nn.Parameter, LabelTensor):
+    """A class that combines torch.nn.Parameter with LabelTensor behavior."""
+
+    def __new__(cls, x, labels=None, requires_grad=True):
+        instance = torch.nn.Parameter.__new__(cls, data=x,
+                                              requires_grad=requires_grad)
+        return instance
+
+    def __init__(self, x, labels=None, requires_grad=True):
+        LabelTensor.__init__(self, x, labels)
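# --- Illustrative sketch (not part of the patch) ---
# Label propagation after this change, with hypothetical labels:
import torch
from pina import LabelTensor, LabelParameter

x = LabelTensor(torch.rand(4, 2), ['u', 'v'])
y = torch.sin(x)   # in MATH_FUNCTIONS: dofs become ['sin(u)', 'sin(v)']
z = x * x          # __mul__ pairs the dofs: ['uu', 'vv']
p = LabelParameter(torch.tensor([[1.0]]), labels=['mu'])  # trainable + labelled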
diff --git a/pina/model/network.py b/pina/model/network.py
index 6fde8039..5d8abfd2 100644
--- a/pina/model/network.py
+++ b/pina/model/network.py
@@ -67,16 +67,14 @@ def forward(self, x):
         # in case `input_variables = []` all points are used
         if self._input_variables:
             x = x.extract(self._input_variables)
-
         # extract features and append
         for feature in self._extra_features:
             x = x.append(feature(x))
 
         # perform forward pass + converting to LabelTensor
-        output = self._model(x).as_subclass(LabelTensor)
-
-        # set the labels for LabelTensor
-        output.labels = self._output_variables
+        out = self._model(x.as_subclass(torch.Tensor))
+        output = LabelTensor(out, self._output_variables)
 
         return output
 
@@ -97,15 +95,9 @@ def forward_map(self, x):
         This function does not extract the input variables, all the variables
         are used for both tensors. Output variables are correctly applied.
         """
-        # convert LabelTensor s to torch.Tensor s
-        x = list(map(lambda x: x.as_subclass(torch.Tensor), x))
-
         # perform forward pass (using torch.Tensor) + converting to LabelTensor
-        output = self._model(x).as_subclass(LabelTensor)
-
-        # set the labels for LabelTensor
-        output.labels = self._output_variables
-
+        output = LabelTensor(self._model(x.tensor), self._output_variables)
         return output
 
     @property
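# --- Illustrative sketch (not part of the patch) ---
# The forward pass now strips labels before calling the wrapped model and
# re-attaches them on the way out. Hypothetical shapes and labels, where
# `net` is a Network wrapping an MLP mapping 2 inputs to 1 output:
x = LabelTensor(torch.rand(16, 2), ['x', 'y'])
y = net(x)
# y is a LabelTensor whose labels equal net._output_variables, e.g. ['u']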
""" - # convert LabelTensor s to torch.Tensor s - x = list(map(lambda x: x.as_subclass(torch.Tensor), x)) # perform forward pass (using torch.Tensor) + converting to LabelTensor - output = self._model(x).as_subclass(LabelTensor) - - # set the labels for LabelTensor - output.labels = self._output_variables - + output = LabelTensor(self._model(x.tensor), self._output_variables) return output @property diff --git a/pina/operators.py b/pina/operators.py index 0b306dfb..ef389a64 100644 --- a/pina/operators.py +++ b/pina/operators.py @@ -63,11 +63,9 @@ def grad_scalar_output(output_, input_, d): retain_graph=True, allow_unused=True, )[0] - - gradients.labels = input_.labels - gradients = gradients.extract(d) + gradients.labels = input_.stored_labels + gradients = gradients[..., [input_.labels.index(i) for i in d]] gradients.labels = [f"d{output_fieldname}d{i}" for i in d] - return gradients if not isinstance(input_, LabelTensor): @@ -190,7 +188,9 @@ def laplacian(output_, input_, components=None, d=None, method="std"): to_append_tensors = [] for i, label in enumerate(grad_output.labels): gg = grad(grad_output, input_, d=d, components=[label]) - to_append_tensors.append(gg.extract([gg.labels[i]])) + gg = gg.extract([gg.labels[i]]) + + to_append_tensors.append(gg) labels = [f"dd{components[0]}"] result = LabelTensor.summation(tensors=to_append_tensors) result.labels = labels diff --git a/pina/solvers/pinns/basepinn.py b/pina/solvers/pinns/basepinn.py index 2bca1823..0739bcc5 100644 --- a/pina/solvers/pinns/basepinn.py +++ b/pina/solvers/pinns/basepinn.py @@ -99,8 +99,11 @@ def __init__( self._model = self._pina_models[0] self._optimizer = self._pina_optimizers[0] self._scheduler = self._pina_schedulers[0] + self.validation_condition_losses = { + k: {'loss': [], + 'count': []} for k in self.problem.conditions.keys()} - def training_step(self, batch, _): + def training_step(self, batch): """ The Physics Informed Solver Training Step. This function takes care of the physics informed training step, and it must not be override @@ -113,11 +116,17 @@ def training_step(self, batch, _): :return: The sum of the loss functions. 
diff --git a/pina/solvers/pinns/basepinn.py b/pina/solvers/pinns/basepinn.py
index 2bca1823..0739bcc5 100644
--- a/pina/solvers/pinns/basepinn.py
+++ b/pina/solvers/pinns/basepinn.py
@@ -99,8 +99,11 @@ def __init__(
         self._model = self._pina_models[0]
         self._optimizer = self._pina_optimizers[0]
         self._scheduler = self._pina_schedulers[0]
+        self.validation_condition_losses = {
+            k: {'loss': [],
+                'count': []} for k in self.problem.conditions.keys()}
 
-    def training_step(self, batch, _):
+    def training_step(self, batch):
         """
         The Physics Informed Solver Training Step. This function takes care
         of the physics informed training step, and it must not be override
@@ -113,11 +116,17 @@ def training_step(self, batch, _):
         :return: The sum of the loss functions.
         :rtype: LabelTensor
         """
+
         condition_losses = []
-        batches = batch.get_supervised_data()
+        try:
+            batches = batch.get_supervised_data()
+        except AttributeError:
+            batches = []
+
         for points in batches:
             input_pts, output_pts, condition_id = points
-            condition_name = self._dataloader.condition_names[condition_id]
+            condition_name = self.trainer.data_module.condition_names[
+                condition_id]
             self.__logged_metric = condition_name
             loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
             condition_losses.append(loss_.as_subclass(torch.Tensor))
@@ -125,7 +134,8 @@ def training_step(self, batch, _):
         batches = batch.get_physics_data()
         for points in batches:
             input_pts, condition_id = points
-            condition_name = self._dataloader.condition_names[condition_id]
+            condition_name = self.trainer.data_module.condition_names[
+                condition_id]
             condition = self.problem.conditions[condition_name]
             self.__logged_metric = condition_name
             loss_ = self.loss_phys(input_pts, condition.equation)
@@ -138,21 +148,103 @@ def training_step(self, batch, _):
 
         return loss
 
+    def validation_step(self, batch):
+        """
+        TODO: add docstring
+        """
+        try:
+            batches = batch.get_supervised_data()
+        except AttributeError:
+            batches = []
+        for points in batches:
+            input_pts, output_pts, condition_id = points
+            condition_name = self.trainer.data_module.condition_names[
+                condition_id]
+            self.__logged_metric = condition_name
+            loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
+            self.validation_condition_losses[condition_name]['loss'].append(
+                loss_)
+            self.validation_condition_losses[condition_name]['count'].append(
+                len(input_pts))
+
+        batches = batch.get_physics_data()
+        for points in batches:
+            input_pts, condition_id = points
+            condition_name = self.trainer.data_module.condition_names[
+                condition_id]
+            condition = self.problem.conditions[condition_name]
+            self.__logged_metric = condition_name
+
+            # gradients w.r.t. the inputs are needed even during validation
+            with torch.set_grad_enabled(True):
+                loss_ = self.loss_phys(input_pts, condition.equation)
+            # add condition losses for each epoch
+            self.validation_condition_losses[condition_name]['loss'].append(
+                loss_)
+            self.validation_condition_losses[condition_name]['count'].append(
+                len(input_pts))
+
+        # clamp unknown parameters in InverseProblem (if needed)
+        self._clamp_params()
+
+    def on_validation_epoch_end(self):
+        """
+        Solver validation epoch end.
+        """
+        total_loss = []
+        total_count = []
+        for k, v in self.validation_condition_losses.items():
+            if len(v['loss']) != 0:
+                local_counter = torch.tensor(v['count']).to(self.device)
+                n_elements = torch.sum(local_counter)
+                loss = torch.sum(
+                    torch.stack(v['loss']) * local_counter) / n_elements
+                loss = loss.as_subclass(torch.Tensor)
+                total_loss.append(loss)
+                total_count.append(n_elements)
+                self.log(
+                    k + "_loss",
+                    loss,
+                    prog_bar=True,
+                    logger=True,
+                    on_epoch=True,
+                    on_step=False,
+                    batch_size=self.trainer.data_module.batch_size,
+                )
+        total_count = (torch.tensor(total_count, dtype=torch.float32).
+                       to(self.device))
+        mean_loss = (torch.sum(torch.stack(total_loss) * total_count) /
+                     torch.sum(total_count))
+        self.log(
+            "val_loss",
+            mean_loss,
+            prog_bar=True,
+            logger=True,
+            on_epoch=True,
+            on_step=False,
+            batch_size=self.trainer.data_module.batch_size,
+        )
+        for key in self.validation_condition_losses.keys():
+            self.validation_condition_losses[key]['loss'] = []
+            self.validation_condition_losses[key]['count'] = []
+
     def loss_data(self, input_pts, output_pts):
         """
         The data loss for the PINN solver. It computes the loss between
         the network output against the true solution. This function
         should not be override if not intentionally.
 
-        :param LabelTensor input_tensor: The input to the neural networks.
-        :param LabelTensor output_tensor: The true solution to compare the
+        :param LabelTensor input_pts: The input to the neural networks.
+        :param LabelTensor output_pts: The true solution to compare the
             network solution.
         :return: The residual loss averaged on the input coordinates
         :rtype: torch.Tensor
         """
         return self._loss(self.forward(input_pts), output_pts)
 
-    @abstractmethod
     def loss_phys(self, samples, equation):
         """
@@ -202,6 +294,9 @@ def store_log(self, loss_value):
         :param str name: The name of the loss.
         :param torch.Tensor loss_value: The value of the loss.
         """
+        # fall back to a placeholder batch size for logging when batching
+        # is disabled (batch_size=None)
+        batch_size = self.trainer.data_module.batch_size \
+            if self.trainer.data_module.batch_size is not None else 999
+
         self.log(
             self.__logged_metric + "_loss",
             loss_value,
@@ -209,7 +304,7 @@ def store_log(self, loss_value):
             logger=True,
             on_epoch=True,
             on_step=False,
-            batch_size=self._dataloader.batch_size,
+            batch_size=batch_size,
         )
         self.__logged_res_losses.append(loss_value)
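# --- Illustrative sketch (not part of the patch) ---
# The epoch-level condition loss is a count-weighted mean. Hypothetical
# per-batch losses [0.2, 0.4] accumulated over [10, 30] points:
import torch

losses = torch.stack([torch.tensor(0.2), torch.tensor(0.4)])
counts = torch.tensor([10., 30.])
loss = torch.sum(losses * counts) / torch.sum(counts)  # (2 + 12) / 40 = 0.35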
diff --git a/pina/solvers/solver.py b/pina/solvers/solver.py
index fe9c897e..64eedfff 100644
--- a/pina/solvers/solver.py
+++ b/pina/solvers/solver.py
@@ -93,7 +93,7 @@ def forward(self, *args, **kwargs):
         pass
 
     @abstractmethod
-    def training_step(self, batch, batch_idx):
+    def training_step(self, batch):
         pass
 
     @abstractmethod
diff --git a/pina/solvers/supervised.py b/pina/solvers/supervised.py
index 049518f1..36bb19f1 100644
--- a/pina/solvers/supervised.py
+++ b/pina/solvers/supervised.py
@@ -120,41 +120,27 @@ def training_step(self, batch):
         :return: The sum of the loss functions.
         :rtype: LabelTensor
         """
-        condition_loss = []
         batches = batch.get_supervised_data()
+        condition_loss = []
         for points in batches:
             input_pts, output_pts, _ = points
             loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
             condition_loss.append(loss_.as_subclass(torch.Tensor))
         loss = sum(condition_loss)
-        self.log("mean_loss", float(loss), prog_bar=True, logger=True,
-                 on_epoch=True,
-                 on_step=False, batch_size=self.trainer.data_module.batch_size)
+        self.log("mean_loss", loss, prog_bar=True, logger=True)
         return loss
 
     def validation_step(self, batch):
         """
         Solver validation step.
         """
-        batch = batch.supervised
-        condition_idx = batch.condition_indices
-        for i in range(condition_idx.min(), condition_idx.max() + 1):
-            condition_name = self.trainer.data_module.condition_names[i]
-            condition = self.problem.conditions[condition_name]
-            pts = batch.input_points
-            out = batch.output_points
-            if condition_name not in self.problem.conditions:
-                raise RuntimeError("Something wrong happened.")
-
-            # for data driven mode
-            if not hasattr(condition, "output_points"):
-                raise NotImplementedError(
-                    f"{type(self).__name__} works only in data-driven mode.")
-
-            output_pts = out[condition_idx == i]
-            input_pts = pts[condition_idx == i]
-
+        batches = batch.get_supervised_data()
+        for points in batches:
+            input_pts, output_pts, cond_id = points
+            condition_name = self.trainer.data_module.condition_names[cond_id]
             loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts)
             self.validation_condition_losses[condition_name]['loss'].append(
                 loss_)
""" - - batch = batch.supervised - condition_idx = batch.condition_indices - for i in range(condition_idx.min(), condition_idx.max() + 1): - condition_name = self.trainer.data_module.condition_names[i] - condition = self.problem.conditions[condition_name] - pts = batch.input_points - out = batch.output_points - if condition_name not in self.problem.conditions: - raise RuntimeError("Something wrong happened.") - - # for data driven mode - if not hasattr(condition, "output_points"): - raise NotImplementedError( - f"{type(self).__name__} works only in data-driven mode.") - - output_pts = out[condition_idx == i] - input_pts = pts[condition_idx == i] - + batches = batch.get_supervised_data() + condition_loss = [] + for points in batches: + input_pts, output_pts, cond_id = points + loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts) + condition_loss.append(loss_.as_subclass(torch.Tensor)) + condition_name = self.trainer.data_module.condition_names[cond_id] loss_ = self.loss_data(input_pts=input_pts, output_pts=output_pts) self.validation_condition_losses[condition_name]['loss'].append( loss_) diff --git a/pina/trainer.py b/pina/trainer.py index f5ea5513..af3dfaa4 100644 --- a/pina/trainer.py +++ b/pina/trainer.py @@ -86,7 +86,8 @@ def _create_loader(self): val_size=self.val_size, predict_size=self.predict_size, batch_size=self.batch_size, ) - self.data_module.setup() + if self.batch_size is None: + self.data_module.setup() def train(self, **kwargs): """