Skip to content

Commit

Permalink
Refactor ResultReader filters, moving into Filter object owned by Res1D
Browse files Browse the repository at this point in the history
  • Loading branch information
ryan-kipawa committed Dec 13, 2024
1 parent 9e5c5b9 commit ac3d2d5
Show file tree
Hide file tree
Showing 11 changed files with 221 additions and 189 deletions.
30 changes: 18 additions & 12 deletions mikeio1d/res1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@
from .result_reader_writer import ResultReaderCreator
from .result_reader_writer import ResultReaderType
from .result_reader_writer import ResultWriter
from .result_reader_writer.filter import Filter
from .result_reader_writer.filter import TimeFilter
from .result_reader_writer.filter import NameFilter
from .result_reader_writer.filter import StepEveryFilter

from .query import QueryDataCatchment # noqa: F401
from .query import QueryDataNode # noqa: F401
Expand Down Expand Up @@ -147,19 +151,21 @@ def __init__(

# endregion deprecation

self.filter = Filter(
[
NameFilter(reaches, nodes, catchments),
TimeFilter(time),
StepEveryFilter(step_every),
]
)

self.reader = ResultReaderCreator.create(
result_reader_type,
self,
file_path,
lazy_load,
header_load,
reaches,
nodes,
catchments,
col_name_delimiter,
put_chainage_in_col_name,
time=time,
step_every=step_every,
result_reader_type=result_reader_type,
res1d=self,
file_path=file_path,
col_name_delimiter=col_name_delimiter,
put_chainage_in_col_name=put_chainage_in_col_name,
filter=self.filter,
)

self.network = ResultNetwork(self)
Expand Down
9 changes: 9 additions & 0 deletions mikeio1d/result_reader_writer/filter/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""Filter module for building Filter objects."""

from .filter import Filter
from .filter import SubFilter
from .name_filter import NameFilter
from .time_filter import TimeFilter
from .step_every_filter import StepEveryFilter

__all__ = ["Filter", "SubFilter", "NameFilter", "TimeFilter", "StepEveryFilter"]
54 changes: 54 additions & 0 deletions mikeio1d/result_reader_writer/filter/filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""Filter Class."""

from __future__ import annotations

from typing import Protocol
from typing import TYPE_CHECKING

if TYPE_CHECKING: # pragma: no cover
from DHI.Mike1D.ResultDataAccess import ResultData

from DHI.Mike1D.ResultDataAccess import Filter as Mike1DFilter


class Filter:
"""Wrapper class for applying subfilters to a Filter object."""

def __init__(
self,
sub_filters: list[SubFilter],
):
self._filter = Mike1DFilter()
self.sub_filters = sub_filters

def use_filter(self) -> bool:
"""Whether the filter should be applied."""
return any([f.use_filter() for f in self.sub_filters])

def apply(self, result_data: ResultData):
"""Apply filter."""
for sub_filter in self.sub_filters:
sub_filter.apply(self._filter, result_data)
result_data.Parameters.Filter = self._filter

@property
def m1d_filter(self) -> Filter:
""".NET DHI.Mike1D.ResultDataAccess.Filter object."""
return self._filter

@property
def filtered_reaches(self) -> list[str]:
"""List of filtered reach names."""
return self._filter.FilteredReaches


class SubFilter(Protocol):
"""Class for configuring Filter objects."""

def apply(self, filter: Filter, result_data: ResultData | None) -> None:
"""Apply the filter to the provided Filter object."""
pass

def use_filter(self) -> bool:
"""Check if the filter should be used."""
pass
53 changes: 53 additions & 0 deletions mikeio1d/result_reader_writer/filter/name_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""Module for the NameFilter class."""

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING: # pragma: no cover
from DHI.Mike1D.ResultDataAccess import ResultData
from DHI.Mike1D.ResultDataAccess import Filter


from .filter import SubFilter

from DHI.Mike1D.ResultDataAccess import DataItemFilterName


class NameFilter(SubFilter):
"""Wrapper class for applying time filters to a Filter object."""

def __init__(
self,
reaches: None | list[str],
nodes: None | list[str],
catchments: None | list[str],
):
self._reaches = reaches if reaches else []
self._nodes = nodes if nodes else []
self._catchments = catchments if catchments else []

def use_filter(self) -> bool:
"""Check if the filter should be used."""
return any((self._reaches, self._nodes, self._catchments))

def apply(self, filter: Filter, result_data: ResultData | None):
"""Apply the filter to the provided Filter object."""
if not self.use_filter():
return

data_item_filter = self.create_data_item_filter(result_data)
filter.AddDataItemFilter(data_item_filter)

def create_data_item_filter(self, result_data: ResultData) -> DataItemFilterName:
"""Create DataItemFilterName object."""
data_item_filter = DataItemFilterName(result_data)

for reach in self._reaches:
data_item_filter.Reaches.Add(reach)
for node in self._nodes:
data_item_filter.Nodes.Add(node)
for catchment in self._catchments:
data_item_filter.Catchments.Add(catchment)

return data_item_filter
25 changes: 25 additions & 0 deletions mikeio1d/result_reader_writer/filter/step_every_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""Module for the TimeFilter class."""

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING: # pragma: no cover
from DHI.Mike1D.ResultDataAccess import ResultData
from DHI.Mike1D.ResultDataAccess import Filter


class StepEveryFilter:
"""Wrapper class for applying step every filter to a Filter object."""

def __init__(self, step_every: int | None):
self._step_every = step_every

def use_filter(self) -> bool:
"""Check if the filter should be used."""
return self._step_every is not None

def apply(self, filter: Filter, result_data: ResultData | None = None):
"""Apply the filter to the provided Filter object."""
if self.use_filter():
filter.LoadStep = self._step_every
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING: # pragma: no cover
from typing import Union
from DHI.Mike1D.ResultDataAccess import ResultData
from DHI.Mike1D.ResultDataAccess import Filter

from datetime import datetime

Expand All @@ -14,22 +15,27 @@
from System import DateTime
from DHI.Mike1D.ResultDataAccess import Period

from ..dotnet import to_dotnet_datetime
from .filter import SubFilter
from ...dotnet import to_dotnet_datetime


class TimeFilter:
class TimeFilter(SubFilter):
"""Wrapper class for applying time filters to a Filter object."""

def __init__(self, filter):
self._filter = filter
def __init__(self, time: None | slice | tuple | list):
self._time = time

def setup_from_user_params(self, *, time: Union[None, slice, tuple, list]):
"""Set up the filter using a user supplied parameters."""
if time is None:
def use_filter(self) -> bool:
"""Check if the filter should be used."""
return self._time is not None

def apply(self, filter: Filter, result_data: ResultData | None = None):
"""Apply the filter to the provided Filter object."""
if not self.use_filter():
return

time = self._time
start, end = None, None

if isinstance(time, slice):
start = time.start
end = time.stop
Expand All @@ -44,21 +50,17 @@ def setup_from_user_params(self, *, time: Union[None, slice, tuple, list]):
if end is not None:
end = pd.to_datetime(end)

self.add_period(start, end)

def add_period(self, start: Union[None, datetime], end: Union[None, datetime]):
"""Add a period to the filter."""
if start is None and end is None:
raise ValueError("Either start or end must be provided")
period = self.create_period(start, end)
filter.Periods.Add(period)

def create_period(self, start: None | datetime, end: None | datetime) -> Period:
"""Create a DHI.Mike1D.ResultDataAccess.Period object."""
start = to_dotnet_datetime(start) if start else DateTime.MinValue
end = to_dotnet_datetime(end) if end else DateTime.MaxValue

start, end = self._adjust_start_and_end(start, end)

period = Period(start, end)

self._filter.Periods.Add(period)
return Period(start, end)

def _adjust_start_and_end(self, start, end):
"""Adjust start and end times to conservatively ensure they are inclusive."""
Expand Down
72 changes: 15 additions & 57 deletions mikeio1d/result_reader_writer/result_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import List
from typing import Optional
from ..res1d import Res1D
from .filter import Filter

import warnings

Expand All @@ -23,14 +24,11 @@
from ..dotnet import pythonnet_implementation as impl
from ..various import NAME_DELIMITER
from ..quantities import TimeSeriesId
from .time_filter import TimeFilter
from .step_every_filter import StepEverFilter
from ..result_network import ResultNetwork

from DHI.Mike1D.ResultDataAccess import ResultData
from DHI.Mike1D.ResultDataAccess import ResultDataQuery
from DHI.Mike1D.ResultDataAccess import ResultDataSearch
from DHI.Mike1D.ResultDataAccess import Filter
from DHI.Mike1D.ResultDataAccess import DataItemFilterName
from DHI.Mike1D.ResultDataAccess import ResultTypes

Expand Down Expand Up @@ -83,31 +81,18 @@ def __init__(
self,
res1d,
file_path=None,
lazy_load=False,
header_load=False,
reaches=None,
nodes=None,
catchments=None,
col_name_delimiter=NAME_DELIMITER,
put_chainage_in_col_name=True,
time=None,
step_every=None,
filter: Filter = None,
):
self.res1d: Res1D = res1d

self.file_path = file_path
self.file_extension = os.path.splitext(file_path)[-1]

self.lazy_load = lazy_load

self._reaches = reaches if reaches else []
self._nodes = nodes if nodes else []
self._catchments = catchments if catchments else []
self._time = time
self._step_every = step_every

self.use_filter = any([reaches, nodes, catchments, time, step_every])
self.lazy_load = False

self.filter = filter
self._loaded = False

self._load_header()
Expand Down Expand Up @@ -141,21 +126,14 @@ def _load_header(self):
if self.lazy_load:
self.data.Connection.BridgeName = "res1dlazy"

if self.use_filter:
if self.filter.use_filter():
self.data.LoadHeader(True, self.diagnostics)
else:
self.data.LoadHeader(self.diagnostics)

def _load_file(self):
if self.use_filter:
self._setup_filter()

for reach in self._reaches:
self._add_reach(reach)
for node in self._nodes:
self._add_node(node)
for catchment in self._catchments:
self._add_catchment(catchment)
if self.filter.use_filter():
self.filter.apply(self.data)

if self.file_extension.lower() in [".resx", ".crf", ".prf", ".xrf"]:
self.data.Load(self.diagnostics)
Expand All @@ -170,30 +148,6 @@ def load_dynamic_data(self):
self._load_file()
self._loaded = True

def _setup_filter(self):
"""Set up the filter for result data object."""
if not self.use_filter:
return

self.data_filter = Filter()
self.time_filter = TimeFilter(self.data_filter)
self.time_filter.setup_from_user_params(time=self._time)
self.step_every_filter = StepEverFilter(self.data_filter)
self.step_every_filter.setup_from_user_params(step_every=self._step_every)
self.data_subfilter = DataItemFilterName(self.data)
self.data_filter.AddDataItemFilter(self.data_subfilter)

self.data.Parameters.Filter = self.data_filter

def _add_reach(self, reach_id):
self.data_subfilter.Reaches.Add(reach_id)

def _add_node(self, node_id):
self.data_subfilter.Nodes.Add(node_id)

def _add_catchment(self, catchment_id):
self.data_subfilter.Catchments.Add(catchment_id)

# endregion File loading

@abstractmethod
Expand Down Expand Up @@ -237,10 +191,14 @@ def read_all(self, column_mode: Optional[str | ColumnMode]) -> pd.DataFrame:

def is_data_set_included(self, data_set):
"""Skip filtered data sets."""
name = self.get_data_set_name(data_set)
if self.use_filter and name not in self._catchments + self._reaches + self._nodes:
return False
return True
if not self.filter.use_filter():
return True

m1d_filter = self.filter.m1d_filter
for data_item in data_set.DataItems:
if m1d_filter.Include(data_item):
return True
return False

@property
def query(self):
Expand Down
Loading

0 comments on commit ac3d2d5

Please sign in to comment.