Skip to content

Commit

Permalink
Improve Dataset, DataLoader + Implement validation for PINN + create …
Browse files Browse the repository at this point in the history
…LabelParameter class (equivalent of LabelTensor for torch.nn.Parameters)
  • Loading branch information
FilippoOlivo committed Nov 13, 2024
1 parent e851c33 commit bea50b4
Show file tree
Hide file tree
Showing 18 changed files with 350 additions and 160 deletions.
4 changes: 2 additions & 2 deletions pina/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
__all__ = [
"Trainer", "LabelTensor", "Plotter", "Condition", "SamplePointDataset",
"PinaDataModule", "PinaDataLoader", 'TorchOptimizer', 'Graph'
"PinaDataModule", "PinaDataLoader", 'TorchOptimizer', 'Graph', 'LabelParameter'
]

from .meta import *
from .label_tensor import LabelTensor
from .label_tensor import LabelTensor, LabelParameter
from .solvers.solver import SolverInterface
from .trainer import Trainer
from .plotter import Plotter
Expand Down
2 changes: 1 addition & 1 deletion pina/condition/data_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class DataConditionInterface(ConditionInterface):

def __init__(self, input_points, conditional_variables=None):
"""
TODO
TODO : add docstring
"""
super().__init__()
self.input_points = input_points
Expand Down
2 changes: 1 addition & 1 deletion pina/condition/domain_equation_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class DomainEquationCondition(ConditionInterface):
condition_type = ['physics']
def __init__(self, domain, equation):
"""
TODO
TODO : add docstring
"""
super().__init__()
self.domain = domain
Expand Down
2 changes: 1 addition & 1 deletion pina/condition/input_equation_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class InputPointsEquationCondition(ConditionInterface):
condition_type = ['physics']
def __init__(self, input_points, equation):
"""
TODO
TODO : add docstring
"""
super().__init__()
self.input_points = input_points
Expand Down
2 changes: 1 addition & 1 deletion pina/condition/input_output_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class InputOutputPointsCondition(ConditionInterface):
condition_type = ['supervised']
def __init__(self, input_points, output_points):
"""
TODO
TODO : add docstring
"""
super().__init__()
self.input_points = input_points
Expand Down
53 changes: 49 additions & 4 deletions pina/data/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
"""
import torch
import logging

import math
from torch.utils.data import Dataset
from ..label_tensor import LabelTensor
from .pina_subset import PinaSubset


class BaseDataset(Dataset):
Expand Down Expand Up @@ -40,7 +41,7 @@ def __init__(self, problem=None, device=torch.device('cpu')):
super().__init__()
self.empty = True
self.problem = problem
self.device = device
self.device = torch.device('cpu')
self.condition_indices = None
for slot in self.__slots__:
setattr(self, slot, [])
Expand All @@ -52,7 +53,7 @@ def __init__(self, problem=None, device=torch.device('cpu')):

def _init_from_problem(self, collector_dict):
"""
TODO
TODO : Add docstring
"""
for name, data in collector_dict.items():
keys = list(data.keys())
Expand Down Expand Up @@ -151,6 +152,50 @@ def apply_shuffle(self, indices):
if slot != 'equation':
attribute = getattr(self, slot)
if isinstance(attribute, (LabelTensor, torch.Tensor)):
setattr(self, 'slot', attribute[[indices]])
setattr(self, 'slot', attribute[[indices]].detach())
if isinstance(attribute, list):
setattr(self, 'slot', [attribute[i] for i in indices])
self.condition_indices = self.condition_indices[indices]

def eval_splitting_lengths(self, lengths):
    """
    Convert a list of split fractions into integer split sizes.

    :param lengths: fractions of the dataset assigned to each split;
        they must sum to 1 (within a 1e-3 tolerance).
    :return: integer lengths, one per split, summing to ``len(self)``.
    :rtype: list[int]
    :raises ValueError: if the fractions do not sum to 1.
    """
    total = sum(lengths)
    # Accept only fractions summing (approximately) to 1. The previous
    # check (`total - 1 < 1e-3`) silently accepted totals below 1 and
    # then inflated the splits to cover the whole dataset anyway; its
    # error message ("less than 1") was also wrong, since it fired for
    # totals greater than 1.
    if abs(total - 1) >= 1e-3:
        raise ValueError(f"Sum of lengths is {total}, but it must be 1")
    len_dataset = len(self)
    lengths = [
        int(math.floor(len_dataset * length)) for length in lengths
    ]
    # Flooring may drop a few elements; hand them out round-robin so
    # the splits cover the dataset exactly.
    remainder = len_dataset - sum(lengths)
    for i in range(remainder):
        lengths[i % len(lengths)] += 1
    return lengths

def dataset_split(self, lengths, seed=None, shuffle=True):
    """
    Split this dataset into contiguous PinaSubset views.

    :param lengths: fractions of the dataset assigned to each split.
    :param seed: optional random seed used when shuffling.
    :param shuffle: whether to shuffle the dataset before splitting.
    :return: one PinaSubset per requested split.
    :rtype: list[PinaSubset]
    """
    split_sizes = self.eval_splitting_lengths(lengths)

    if shuffle:
        if seed is None:
            permutation = torch.randperm(sum(split_sizes))
        else:
            rng = torch.Generator()
            rng.manual_seed(seed)
            permutation = torch.randperm(sum(split_sizes), generator=rng)
        self.apply_shuffle(permutation)

    # Build one contiguous slice per split, tracking the running offset.
    subsets = []
    start = 0
    for size in split_sizes:
        subsets.append(PinaSubset(self, slice(start, start + size)))
        start += size
    return subsets
56 changes: 7 additions & 49 deletions pina/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,12 @@
This module provide basic data management functionalities
"""

import math
import torch
import logging
from pytorch_lightning import LightningDataModule
from .sample_dataset import SamplePointDataset
from .supervised_dataset import SupervisedDataset
from .unsupervised_dataset import UnsupervisedDataset
from .pina_dataloader import PinaDataLoader
from .pina_subset import PinaSubset


class PinaDataModule(LightningDataModule):
Expand Down Expand Up @@ -101,9 +98,9 @@ def setup(self, stage=None):
if stage == 'fit' or stage is None:
for dataset in self.datasets:
if len(dataset) > 0:
splits = self.dataset_split(dataset,
self.split_length,
shuffle=self.shuffle)
splits = dataset.dataset_split(
self.split_length,
shuffle=self.shuffle)
for i in range(len(self.split_length)):
self.splits[self.split_names[i]][
dataset.data_type] = splits[i]
Expand All @@ -116,45 +113,6 @@ def setup(self, stage=None):
else:
raise ValueError("stage must be either 'fit' or 'test'")

@staticmethod
def dataset_split(dataset, lengths, seed=None, shuffle=True):
    """
    Split ``dataset`` into contiguous PinaSubset views.

    :param dataset: dataset object to split; must support ``len`` and
        ``apply_shuffle``.
    :param lengths: fractions of the dataset assigned to each split;
        they must sum to 1 (within a 1e-3 tolerance).
    :param seed: optional random seed used when shuffling.
    :param shuffle: whether to shuffle the dataset before splitting.
    :return: one PinaSubset per requested split.
    :rtype: list[PinaSubset]
    :raises ValueError: if the fractions do not sum to 1.
    """
    total = sum(lengths)
    # Accept only fractions summing (approximately) to 1. The previous
    # check (`total - 1 < 1e-3`) silently accepted totals below 1 and
    # then inflated the splits to cover the whole dataset anyway; its
    # error message ("less than 1") was also wrong, since it fired for
    # totals greater than 1.
    if abs(total - 1) >= 1e-3:
        raise ValueError(f"Sum of lengths is {total}, but it must be 1")
    len_dataset = len(dataset)
    lengths = [
        int(math.floor(len_dataset * length)) for length in lengths
    ]
    # Flooring may drop a few elements; hand them out round-robin so
    # the splits cover the dataset exactly.
    remainder = len_dataset - sum(lengths)
    for i in range(remainder):
        lengths[i % len(lengths)] += 1

    if shuffle:
        if seed is not None:
            generator = torch.Generator()
            generator.manual_seed(seed)
            indices = torch.randperm(sum(lengths), generator=generator)
        else:
            indices = torch.randperm(sum(lengths))
        dataset.apply_shuffle(indices)

    offsets = [sum(lengths[:i]) for i in range(len(lengths))]
    return [
        PinaSubset(dataset, slice(offset, offset + length))
        for offset, length in zip(offsets, lengths)
    ]

def _create_datasets(self):
"""
Create the dataset objects putting data
Expand Down Expand Up @@ -189,25 +147,25 @@ def val_dataloader(self):
Create the validation dataloader
"""
return PinaDataLoader(self.splits['val'], self.batch_size,
self.condition_names)
self.condition_names, device=self.device)

def train_dataloader(self):
"""
Create the training dataloader
"""
return PinaDataLoader(self.splits['train'], self.batch_size,
self.condition_names)
self.condition_names, device=self.device)

def test_dataloader(self):
"""
Create the testing dataloader
"""
return PinaDataLoader(self.splits['test'], self.batch_size,
self.condition_names)
self.condition_names, device=self.device)

def predict_dataloader(self):
"""
Create the prediction dataloader
"""
return PinaDataLoader(self.splits['predict'], self.batch_size,
self.condition_names)
self.condition_names, device=self.device)
33 changes: 26 additions & 7 deletions pina/data/pina_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
"""
import torch
from ..label_tensor import LabelTensor

from .pina_subset import PinaSubset


Expand All @@ -13,8 +12,10 @@ class Batch:
optimization.
"""

def __init__(self, dataset_dict, idx_dict, require_grad=True):
def __init__(self, dataset_dict, idx_dict, require_grad=True,
device=torch.device('cpu')):
self.attributes = []
self.require_grad = require_grad
for k, v in dataset_dict.items():
index = idx_dict[k]
if isinstance(v, PinaSubset):
Expand All @@ -23,10 +24,12 @@ def __init__(self, dataset_dict, idx_dict, require_grad=True):
index = slice(dataset_index.start + index.start,
min(dataset_index.start + index.stop,
dataset_index.stop))
setattr(self, k, PinaSubset(v.dataset, index,
require_grad=require_grad))
data = PinaSubset(v.dataset, index, require_grad=require_grad)
setattr(self, k, data)
setattr(self, k + '_data', self.prepare_data(data))
self.attributes.append(k)
self.require_grad = require_grad
self.device = device

def __len__(self):
"""
Expand All @@ -46,11 +49,10 @@ def __getattr__(self, item):
return self.__getattribute__(item)
raise AttributeError(f"'Batch' object has no attribute '{item}'")

def get_data(self, batch_name=None):
def prepare_data(self, data):
"""
# TODO
Prepare the data for the batch
"""
data = getattr(self, batch_name)
to_return_list = []
if isinstance(data, PinaSubset):
items = data.dataset.__slots__
Expand All @@ -69,6 +71,23 @@ def get_data(self, batch_name=None):
i == condition_idx[k]])
temp.append(i)
to_return_list.append(temp)

return to_return_list

#@profile
def get_data(self, batch_name=None):
    """
    Return the prepared data for the condition ``batch_name``.

    Looks up the ``<batch_name>_data`` attribute (built in ``__init__``
    via ``prepare_data``) and, for every tensor in it, returns a
    detached copy moved to ``self.device`` with gradient tracking
    enabled; non-tensor entries are passed through unchanged.

    :param batch_name: name of the condition whose data is requested.
        NOTE(review): the ``None`` default is unusable — ``None +
        '_data'`` raises ``TypeError`` — callers must always pass a
        condition name.
    :return: list of per-condition lists of tensors/values.
    """
    data = getattr(self, batch_name + '_data')
    to_return_list = [
        [
            # Detach from any stale graph, move to the batch device,
            # then re-enable autograd on the fresh copy.
            i.detach().to(self.device).requires_grad_()
            if isinstance(i, (torch.Tensor, LabelTensor)) else i
            for i in points
        ]
        for points in data
    ]
    return to_return_list

def get_supervised_data(self):
Expand Down
7 changes: 5 additions & 2 deletions pina/data/pina_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ class PinaDataLoader:
:vartype condition_names: list[str]
"""

def __init__(self, dataset_dict, batch_size, condition_names) -> None:
def __init__(self, dataset_dict, batch_size, condition_names,
device) -> None:
"""
Initialize local variables
:param dataset_dict: Dictionary of datasets
Expand All @@ -28,6 +29,7 @@ def __init__(self, dataset_dict, batch_size, condition_names) -> None:
self.condition_names = condition_names
self.dataset_dict = dataset_dict
self.batch_size = batch_size
self.device = device
self._init_batches(batch_size)

def _init_batches(self, batch_size=None):
Expand Down Expand Up @@ -79,7 +81,8 @@ def _init_batches(self, batch_size=None):
actual_indices[k] = actual_indices[k] + v
total_length += v
self.batches.append(
Batch(idx_dict=temp_dict, dataset_dict=self.dataset_dict))
Batch(idx_dict=temp_dict, dataset_dict=self.dataset_dict,
device=self.device))

def __iter__(self):
"""
Expand Down
7 changes: 4 additions & 3 deletions pina/data/pina_subset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,23 @@

class PinaSubset:
"""
TODO
TODO : add docstring
"""
__slots__ = ['dataset', 'indices', 'require_grad']

def __init__(self, dataset, indices, require_grad=False):
    """
    Store a view over ``dataset`` restricted to ``indices``.

    :param dataset: the underlying dataset being wrapped.
    :param indices: a slice (or index collection) selecting the subset.
    :param require_grad: presumably controls whether tensors served
        through this subset get gradient tracking enabled — confirm
        against the accessor methods (not visible here).
    """
    self.dataset = dataset
    self.indices = indices
    self.require_grad = require_grad

def __len__(self):
    """
    Return the number of elements in the subset.

    For slice indices the length is computed as ``stop - start``,
    which assumes a step of 1 and explicit non-negative bounds —
    NOTE(review): confirm slices are always built this way by callers.
    """
    if isinstance(self.indices, slice):
        return self.indices.stop - self.indices.start
    return len(self.indices)
Expand Down
Loading

0 comments on commit bea50b4

Please sign in to comment.