Skip to content

Commit

Permalink
Reimplementation of data management classes, fix bugs and improve eff…
Browse files Browse the repository at this point in the history
…iciency of LabelTensor
  • Loading branch information
FilippoOlivo committed Nov 20, 2024
1 parent e851c33 commit b18fe04
Show file tree
Hide file tree
Showing 25 changed files with 563 additions and 967 deletions.
10 changes: 5 additions & 5 deletions pina/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
__all__ = [
"Trainer", "LabelTensor", "Plotter", "Condition", "SamplePointDataset",
"PinaDataModule", "PinaDataLoader", 'TorchOptimizer', 'Graph'
"Trainer", "LabelTensor", "Plotter", "Condition",
"PinaDataModule", 'TorchOptimizer', 'Graph', 'LabelParameter'
]

from .meta import *
from .label_tensor import LabelTensor
from .label_tensor import LabelTensor, LabelParameter
from .solvers.solver import SolverInterface
from .trainer import Trainer
from .plotter import Plotter
from .condition.condition import Condition
from .data import SamplePointDataset

from .data import PinaDataModule
from .data import PinaDataLoader

from .optim import TorchOptimizer
from .optim import TorchScheduler
from .graph import Graph
9 changes: 5 additions & 4 deletions pina/collector.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from . import LabelTensor
from .utils import check_consistency, merge_tensors


Expand Down Expand Up @@ -68,7 +69,7 @@ def store_sample_domains(self, n, mode, variables, sample_locations):
condition = self.problem.conditions[loc]
keys = ["input_points", "equation"]
# if the condition is not ready, we get and store the data
if (not self._is_conditions_ready[loc]):
if not self._is_conditions_ready[loc]:
# if it is the first time we sample
if not self.data_collections[loc]:
already_sampled = []
Expand All @@ -87,7 +88,7 @@ def store_sample_domains(self, n, mode, variables, sample_locations):
condition.domain.sample(n=n, mode=mode, variables=variables)
] + already_sampled
pts = merge_tensors(samples)
if (set(pts.labels).issubset(sorted(self.problem.input_variables))):
if set(pts.labels).issubset(sorted(self.problem.input_variables)):
pts = pts.sort_labels()
if sorted(pts.labels) == sorted(self.problem.input_variables):
self._is_conditions_ready[loc] = True
Expand All @@ -110,5 +111,5 @@ def add_points(self, new_points_dict):
if not self._is_conditions_ready[k]:
raise RuntimeError(
'Cannot add points on a non sampled condition')
self.data_collections[k]['input_points'] = self.data_collections[k][
'input_points'].vstack(v)
self.data_collections[k]['input_points'] = LabelTensor.vstack([self.data_collections[k][
'input_points'], v])
2 changes: 1 addition & 1 deletion pina/condition/data_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class DataConditionInterface(ConditionInterface):

def __init__(self, input_points, conditional_variables=None):
"""
TODO
TODO : add docstring
"""
super().__init__()
self.input_points = input_points
Expand Down
2 changes: 1 addition & 1 deletion pina/condition/domain_equation_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class DomainEquationCondition(ConditionInterface):
condition_type = ['physics']
def __init__(self, domain, equation):
"""
TODO
TODO : add docstring
"""
super().__init__()
self.domain = domain
Expand Down
2 changes: 1 addition & 1 deletion pina/condition/input_equation_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class InputPointsEquationCondition(ConditionInterface):
condition_type = ['physics']
def __init__(self, input_points, equation):
"""
TODO
TODO : add docstring
"""
super().__init__()
self.input_points = input_points
Expand Down
2 changes: 1 addition & 1 deletion pina/condition/input_output_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class InputOutputPointsCondition(ConditionInterface):
condition_type = ['supervised']
def __init__(self, input_points, output_points):
"""
TODO
TODO : add docstring
"""
super().__init__()
self.input_points = input_points
Expand Down
15 changes: 6 additions & 9 deletions pina/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,11 @@
Import data classes
"""
__all__ = [
'PinaDataLoader', 'SupervisedDataset', 'SamplePointDataset',
'UnsupervisedDataset', 'Batch', 'PinaDataModule', 'BaseDataset'
'PinaDataModule',
'PinaDataset'
]

from .pina_dataloader import PinaDataLoader
from .supervised_dataset import SupervisedDataset
from .sample_dataset import SamplePointDataset
from .unsupervised_dataset import UnsupervisedDataset
from .pina_batch import Batch
from .data_module import PinaDataModule
from .base_dataset import BaseDataset


from .data_management import PinaDataModule
from .data_management import PinaDataset
156 changes: 0 additions & 156 deletions pina/data/base_dataset.py

This file was deleted.

Loading

0 comments on commit b18fe04

Please sign in to comment.