diff --git a/mikeio1d/res1d.py b/mikeio1d/res1d.py
index d00fc7f9..62a75c32 100644
--- a/mikeio1d/res1d.py
+++ b/mikeio1d/res1d.py
@@ -45,6 +45,10 @@
 from .result_reader_writer import ResultReaderCreator
 from .result_reader_writer import ResultReaderType
 from .result_reader_writer import ResultWriter
+from .result_reader_writer.filter import Filter
+from .result_reader_writer.filter import TimeFilter
+from .result_reader_writer.filter import NameFilter
+from .result_reader_writer.filter import StepEveryFilter

 from .query import QueryDataCatchment  # noqa: F401
 from .query import QueryDataNode  # noqa: F401
@@ -147,19 +151,21 @@ def __init__(

         # endregion deprecation

+        self.filter = Filter(
+            [
+                NameFilter(reaches, nodes, catchments),
+                TimeFilter(time),
+                StepEveryFilter(step_every),
+            ]
+        )
+
         self.reader = ResultReaderCreator.create(
-            result_reader_type,
-            self,
-            file_path,
-            lazy_load,
-            header_load,
-            reaches,
-            nodes,
-            catchments,
-            col_name_delimiter,
-            put_chainage_in_col_name,
-            time=time,
-            step_every=step_every,
+            result_reader_type=result_reader_type,
+            res1d=self,
+            file_path=file_path,
+            col_name_delimiter=col_name_delimiter,
+            put_chainage_in_col_name=put_chainage_in_col_name,
+            filter=self.filter,
         )

         self.network = ResultNetwork(self)
diff --git a/mikeio1d/result_reader_writer/filter/__init__.py b/mikeio1d/result_reader_writer/filter/__init__.py
new file mode 100644
index 00000000..8db8a2d9
--- /dev/null
+++ b/mikeio1d/result_reader_writer/filter/__init__.py
@@ -0,0 +1,9 @@
+"""Filter module for building Filter objects."""
+
+from .filter import Filter
+from .filter import SubFilter
+from .name_filter import NameFilter
+from .time_filter import TimeFilter
+from .step_every_filter import StepEveryFilter
+
+__all__ = ["Filter", "SubFilter", "NameFilter", "TimeFilter", "StepEveryFilter"]
diff --git a/mikeio1d/result_reader_writer/filter/filter.py b/mikeio1d/result_reader_writer/filter/filter.py
new file mode 100644
index 00000000..50d915ff
--- /dev/null
+++ b/mikeio1d/result_reader_writer/filter/filter.py
@@ -0,0 +1,54 @@
+"""Filter class."""
+
+from __future__ import annotations
+
+from typing import Protocol
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:  # pragma: no cover
+    from DHI.Mike1D.ResultDataAccess import ResultData
+
+from DHI.Mike1D.ResultDataAccess import Filter as Mike1DFilter
+
+
+class Filter:
+    """Wrapper class for applying subfilters to a Filter object."""
+
+    def __init__(
+        self,
+        sub_filters: list[SubFilter],
+    ):
+        self._filter = Mike1DFilter()
+        self.sub_filters = sub_filters
+
+    def use_filter(self) -> bool:
+        """Whether the filter should be applied."""
+        return any(f.use_filter() for f in self.sub_filters)
+
+    def apply(self, result_data: ResultData):
+        """Apply filter."""
+        for sub_filter in self.sub_filters:
+            sub_filter.apply(self._filter, result_data)
+        result_data.Parameters.Filter = self._filter
+
+    @property
+    def m1d_filter(self) -> Mike1DFilter:
+        """.NET DHI.Mike1D.ResultDataAccess.Filter object."""
+        return self._filter
+
+    @property
+    def filtered_reaches(self) -> list[str]:
+        """List of filtered reach names."""
+        return self._filter.FilteredReaches
+
+
+class SubFilter(Protocol):
+    """Protocol for sub-filters that configure a Filter object."""
+
+    def apply(self, filter: Mike1DFilter, result_data: ResultData | None) -> None:
+        """Apply the filter to the provided Filter object."""
+        pass
+
+    def use_filter(self) -> bool:
+        """Check if the filter should be used."""
+        pass
diff --git a/mikeio1d/result_reader_writer/filter/name_filter.py b/mikeio1d/result_reader_writer/filter/name_filter.py
new file mode 100644
index 00000000..28871f56
--- /dev/null
+++ b/mikeio1d/result_reader_writer/filter/name_filter.py
@@ -0,0 +1,53 @@
+"""Module for the NameFilter class."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:  # pragma: no cover
+    from DHI.Mike1D.ResultDataAccess import ResultData
+    from DHI.Mike1D.ResultDataAccess import Filter
+
+
+from .filter import SubFilter
+
+from DHI.Mike1D.ResultDataAccess import DataItemFilterName
+
+
+class NameFilter(SubFilter):
+    """Wrapper class for applying name filters to a Filter object."""
+
+    def __init__(
+        self,
+        reaches: None | list[str],
+        nodes: None | list[str],
+        catchments: None | list[str],
+    ):
+        self._reaches = reaches if reaches else []
+        self._nodes = nodes if nodes else []
+        self._catchments = catchments if catchments else []
+
+    def use_filter(self) -> bool:
+        """Check if the filter should be used."""
+        return any((self._reaches, self._nodes, self._catchments))
+
+    def apply(self, filter: Filter, result_data: ResultData | None):
+        """Apply the filter to the provided Filter object."""
+        if not self.use_filter():
+            return
+
+        data_item_filter = self.create_data_item_filter(result_data)
+        filter.AddDataItemFilter(data_item_filter)
+
+    def create_data_item_filter(self, result_data: ResultData) -> DataItemFilterName:
+        """Create DataItemFilterName object."""
+        data_item_filter = DataItemFilterName(result_data)
+
+        for reach in self._reaches:
+            data_item_filter.Reaches.Add(reach)
+        for node in self._nodes:
+            data_item_filter.Nodes.Add(node)
+        for catchment in self._catchments:
+            data_item_filter.Catchments.Add(catchment)
+
+        return data_item_filter
diff --git a/mikeio1d/result_reader_writer/filter/step_every_filter.py b/mikeio1d/result_reader_writer/filter/step_every_filter.py
new file mode 100644
index 00000000..60cfa40a
--- /dev/null
+++ b/mikeio1d/result_reader_writer/filter/step_every_filter.py
@@ -0,0 +1,25 @@
+"""Module for the StepEveryFilter class."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:  # pragma: no cover
+    from DHI.Mike1D.ResultDataAccess import ResultData
+    from DHI.Mike1D.ResultDataAccess import Filter
+
+
+class StepEveryFilter:
+    """Wrapper class for applying a step-every filter to a Filter object."""
+
+    def __init__(self, step_every: int | None):
+        self._step_every = step_every
+
+    def use_filter(self) -> bool:
+        """Check if the filter should be used."""
+        return self._step_every is not None
+
+    def apply(self, filter: Filter, result_data: ResultData | None = None):
+        """Apply the filter to the provided Filter object."""
+        if self.use_filter():
+            filter.LoadStep = self._step_every
diff --git a/mikeio1d/result_reader_writer/time_filter.py b/mikeio1d/result_reader_writer/filter/time_filter.py
similarity index 60%
rename from mikeio1d/result_reader_writer/time_filter.py
rename to mikeio1d/result_reader_writer/filter/time_filter.py
index cbf7fd9b..6f9f94db 100644
--- a/mikeio1d/result_reader_writer/time_filter.py
+++ b/mikeio1d/result_reader_writer/filter/time_filter.py
@@ -5,7 +5,8 @@
 from typing import TYPE_CHECKING

 if TYPE_CHECKING:  # pragma: no cover
-    from typing import Union
+    from DHI.Mike1D.ResultDataAccess import ResultData
+    from DHI.Mike1D.ResultDataAccess import Filter

 from datetime import datetime

@@ -14,22 +15,27 @@
 from System import DateTime
 from DHI.Mike1D.ResultDataAccess import Period

-from ..dotnet import to_dotnet_datetime
+from .filter import SubFilter
+from ...dotnet import to_dotnet_datetime


-class TimeFilter:
+class TimeFilter(SubFilter):
     """Wrapper class for applying time filters to a Filter object."""

-    def __init__(self, filter):
-        self._filter = filter
+    def __init__(self, time: None | slice | tuple | list):
+        self._time = time

-    def setup_from_user_params(self, *, time: Union[None, slice, tuple, list]):
-        """Set up the filter using a user supplied parameters."""
-        if time is None:
+    def use_filter(self) -> bool:
+        """Check if the filter should be used."""
+        return self._time is not None
+
+    def apply(self, filter: Filter, result_data: ResultData | None = None):
+        """Apply the filter to the provided Filter object."""
+        if not self.use_filter():
             return

+        time = self._time
         start, end = None, None
-
         if isinstance(time, slice):
             start = time.start
             end = time.stop
@@ -44,21 +50,17 @@
         if end is not None:
             end = pd.to_datetime(end)

-        self.add_period(start, end)
-
-    def add_period(self, start: Union[None, datetime], end: Union[None, datetime]):
-        """Add a period to the filter."""
-        if start is None and end is None:
-            raise ValueError("Either start or end must be provided")
+        period = self.create_period(start, end)
+        filter.Periods.Add(period)

+    def create_period(self, start: None | datetime, end: None | datetime) -> Period:
+        """Create a DHI.Mike1D.ResultDataAccess.Period object."""
         start = to_dotnet_datetime(start) if start else DateTime.MinValue
         end = to_dotnet_datetime(end) if end else DateTime.MaxValue

         start, end = self._adjust_start_and_end(start, end)

-        period = Period(start, end)
-
-        self._filter.Periods.Add(period)
+        return Period(start, end)

     def _adjust_start_and_end(self, start, end):
         """Adjust start and end times to conservatively ensure they are inclusive."""
diff --git a/mikeio1d/result_reader_writer/result_reader.py b/mikeio1d/result_reader_writer/result_reader.py
index 7d305bdd..b496b34b 100644
--- a/mikeio1d/result_reader_writer/result_reader.py
+++ b/mikeio1d/result_reader_writer/result_reader.py
@@ -8,6 +8,7 @@
     from typing import List
     from typing import Optional
     from ..res1d import Res1D
+    from .filter import Filter

 import warnings

@@ -23,14 +24,11 @@
 from ..dotnet import pythonnet_implementation as impl
 from ..various import NAME_DELIMITER
 from ..quantities import TimeSeriesId
-from .time_filter import TimeFilter
-from .step_every_filter import StepEverFilter
 from ..result_network import ResultNetwork

 from DHI.Mike1D.ResultDataAccess import ResultData
 from DHI.Mike1D.ResultDataAccess import ResultDataQuery
 from DHI.Mike1D.ResultDataAccess import ResultDataSearch
-from DHI.Mike1D.ResultDataAccess import Filter
 from DHI.Mike1D.ResultDataAccess import DataItemFilterName
 from DHI.Mike1D.ResultDataAccess import ResultTypes

@@ -83,31 +81,18 @@ def __init__(
         self,
         res1d,
         file_path=None,
-        lazy_load=False,
-        header_load=False,
-        reaches=None,
-        nodes=None,
-        catchments=None,
         col_name_delimiter=NAME_DELIMITER,
         put_chainage_in_col_name=True,
-        time=None,
-        step_every=None,
+        filter: Filter = None,
     ):
         self.res1d: Res1D = res1d
         self.file_path = file_path
         self.file_extension = os.path.splitext(file_path)[-1]
-        self.lazy_load = lazy_load
-
-        self._reaches = reaches if reaches else []
-        self._nodes = nodes if nodes else []
-        self._catchments = catchments if catchments else []
-        self._time = time
-        self._step_every = step_every
-
-        self.use_filter = any([reaches, nodes, catchments, time, step_every])
+        self.lazy_load = False
+        self.filter = filter

         self._loaded = False

         self._load_header()


@@ -141,21 +126,14 @@ def _load_header(self):
         if self.lazy_load:
             self.data.Connection.BridgeName = "res1dlazy"

-        if self.use_filter:
+        if self.filter.use_filter():
             self.data.LoadHeader(True, self.diagnostics)
         else:
             self.data.LoadHeader(self.diagnostics)

     def _load_file(self):
-        if self.use_filter:
-            self._setup_filter()
-
-            for reach in self._reaches:
-                self._add_reach(reach)
-            for node in self._nodes:
-                self._add_node(node)
-            for catchment in self._catchments:
-                self._add_catchment(catchment)
+        if self.filter.use_filter():
+            self.filter.apply(self.data)

         if self.file_extension.lower() in [".resx", ".crf", ".prf", ".xrf"]:
             self.data.Load(self.diagnostics)
@@ -170,30 +148,6 @@ def load_dynamic_data(self):
             self._load_file()
             self._loaded = True

-    def _setup_filter(self):
-        """Set up the filter for result data object."""
-        if not self.use_filter:
-            return
-
-        self.data_filter = Filter()
-        self.time_filter = TimeFilter(self.data_filter)
-        self.time_filter.setup_from_user_params(time=self._time)
-        self.step_every_filter = StepEverFilter(self.data_filter)
-        self.step_every_filter.setup_from_user_params(step_every=self._step_every)
-        self.data_subfilter = DataItemFilterName(self.data)
-        self.data_filter.AddDataItemFilter(self.data_subfilter)
-
-        self.data.Parameters.Filter = self.data_filter
-
-    def _add_reach(self, reach_id):
-        self.data_subfilter.Reaches.Add(reach_id)
-
-    def _add_node(self, node_id):
-        self.data_subfilter.Nodes.Add(node_id)
-
-    def _add_catchment(self, catchment_id):
-        self.data_subfilter.Catchments.Add(catchment_id)
-
     # endregion File loading

     @abstractmethod
@@ -237,10 +191,14 @@ def read_all(self, column_mode: Optional[str | ColumnMode]) -> pd.DataFrame:

     def is_data_set_included(self, data_set):
         """Skip filtered data sets."""
-        name = self.get_data_set_name(data_set)
-        if self.use_filter and name not in self._catchments + self._reaches + self._nodes:
-            return False
-        return True
+        if not self.filter.use_filter():
+            return True
+
+        m1d_filter = self.filter.m1d_filter
+        for data_item in data_set.DataItems:
+            if m1d_filter.Include(data_item):
+                return True
+        return False

     @property
     def query(self):
diff --git a/mikeio1d/result_reader_writer/result_reader_copier.py b/mikeio1d/result_reader_writer/result_reader_copier.py
index 403d960a..90231eb6 100644
--- a/mikeio1d/result_reader_writer/result_reader_copier.py
+++ b/mikeio1d/result_reader_writer/result_reader_copier.py
@@ -33,29 +33,17 @@ def __init__(
         self,
         res1d,
         file_path=None,
-        lazy_load=False,
-        header_load=False,
-        reaches=None,
-        nodes=None,
-        catchments=None,
         col_name_delimiter=NAME_DELIMITER,
         put_chainage_in_col_name=True,
-        time=None,
-        step_every=None,
+        filter=None,
    ):
         ResultReader.__init__(
             self,
-            res1d,
-            file_path,
-            lazy_load,
-            header_load,
-            reaches,
-            nodes,
-            catchments,
-            col_name_delimiter,
-            put_chainage_in_col_name,
-            time=time,
-            step_every=step_every,
+            res1d=res1d,
+            file_path=file_path,
+            col_name_delimiter=col_name_delimiter,
+            put_chainage_in_col_name=put_chainage_in_col_name,
+            filter=filter,
         )

         self.result_data_copier = ResultDataCopier(self.data)
diff --git a/mikeio1d/result_reader_writer/result_reader_creator.py b/mikeio1d/result_reader_writer/result_reader_creator.py
index 6fcf3bcf..acaed7b0 100644
--- a/mikeio1d/result_reader_writer/result_reader_creator.py
+++ b/mikeio1d/result_reader_writer/result_reader_creator.py
@@ -7,6 +7,7 @@
 if TYPE_CHECKING:  # pragma: no cover
     from typing import Dict
     from .result_reader import ResultReader
+    from .filter import Filter

 from ..various import NAME_DELIMITER

@@ -29,15 +30,9 @@ def create(
         result_reader_type,
         res1d,
         file_path=None,
-        lazy_load=False,
-        header_load=False,
-        reaches=None,
-        nodes=None,
-        catchments=None,
         col_name_delimiter=NAME_DELIMITER,
         put_chainage_in_col_name=True,
-        time=None,
-        step_every=None,
+        filter: Filter = None,
     ) -> ResultReader:
         """Create a ResultReader object based on the provided type."""
         reasult_readers: Dict[ResultReaderType, ResultReader] = {
@@ -48,15 +43,9 @@
         reader = reasult_readers.get(result_reader_type, None)

         return reader(
-            res1d,
-            file_path,
-            lazy_load,
-            header_load,
-            reaches,
-            nodes,
-            catchments,
-            col_name_delimiter,
-            put_chainage_in_col_name,
-            time=time,
-            step_every=step_every,
+            res1d=res1d,
+            file_path=file_path,
+            col_name_delimiter=col_name_delimiter,
+            put_chainage_in_col_name=put_chainage_in_col_name,
+            filter=filter,
         )
diff --git a/mikeio1d/result_reader_writer/step_every_filter.py b/mikeio1d/result_reader_writer/step_every_filter.py
deleted file mode 100644
index 185e7f02..00000000
--- a/mikeio1d/result_reader_writer/step_every_filter.py
+++ /dev/null
@@ -1,23 +0,0 @@
-"""Module for the TimeFilter class."""
-
-from __future__ import annotations
-
-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
-    from DHI.Mike1D.ResultDataAccess import Filter
-
-from ..dotnet import to_dotnet_datetime
-
-
-class StepEverFilter:
-    """Wrapper class for applying step every filter to a Filter object."""
-
-    def __init__(self, filter: Filter):
-        self._filter = filter
-
-    def setup_from_user_params(self, step_every: int | None):
-        """Set up the filter using a user supplied parameters."""
-        if step_every is None:
-            return
-        self._filter.LoadStep = step_every
diff --git a/tests/test_derived_quantities.py b/tests/test_derived_quantities.py
index 1d357f4a..2ceaad14 100644
--- a/tests/test_derived_quantities.py
+++ b/tests/test_derived_quantities.py
@@ -43,21 +43,15 @@ def test_access_available_derived_quantities(res1d_network):
     assert isinstance(dq[0], str)


-def test_access_available_derived_quantities_nodes(
-    res1d_network, node_derived_quantity_id
-):
+def test_access_available_derived_quantities_nodes(res1d_network, node_derived_quantity_id):
     assert node_derived_quantity_id in res1d_network.derived_quantities


-def test_access_available_derived_quantities_reaches(
-    res1d_network, reach_derived_quantity_id
-):
+def test_access_available_derived_quantities_reaches(res1d_network, reach_derived_quantity_id):
     assert reach_derived_quantity_id in res1d_network.derived_quantities


-def test_available_derived_quantities_by_locations_nodes(
-    res1d_network, node_derived_quantity_id
-):
+def test_available_derived_quantities_by_locations_nodes(res1d_network, node_derived_quantity_id):
     assert node_derived_quantity_id in res1d_network.nodes.derived_quantities
     assert node_derived_quantity_id not in res1d_network.reaches.derived_quantities

@@ -73,24 +67,17 @@ def test_available_derived_quantities_by_single_location_node(
     res1d_network, node_derived_quantity_id
 ):
     assert node_derived_quantity_id in res1d_network.nodes["1"].derived_quantities
-    assert (
-        node_derived_quantity_id
-        not in res1d_network.reaches["100l1"].derived_quantities
-    )
+    assert node_derived_quantity_id not in res1d_network.reaches["100l1"].derived_quantities


 def test_available_Derived_quantities_by_single_location_reach(
     res1d_network, reach_derived_quantity_id
 ):
-    assert (
-        reach_derived_quantity_id in res1d_network.reaches["100l1"].derived_quantities
-    )
+    assert reach_derived_quantity_id in res1d_network.reaches["100l1"].derived_quantities
     assert reach_derived_quantity_id not in res1d_network.nodes["1"].derived_quantities


-def test_read_all_with_derived_quantities_nodes(
-    res1d_network, node_derived_quantity_id
-):
+def test_read_all_with_derived_quantities_nodes(res1d_network, node_derived_quantity_id):
     df = res1d_network.nodes.read(include_derived=True)
     assert isinstance(df, pd.DataFrame)
     assert len(df) > 0
@@ -98,9 +85,7 @@
     assert node_derived_quantity_id in quantities


-def test_read_all_with_derived_quantities_reaches(
-    res1d_network, reach_derived_quantity_id
-):
+def test_read_all_with_derived_quantities_reaches(res1d_network, reach_derived_quantity_id):
     df = res1d_network.reaches.read(include_derived=True)
     assert isinstance(df, pd.DataFrame)
     assert len(df) > 0
@@ -108,9 +93,7 @@
     assert reach_derived_quantity_id in quantities


-def test_read_derived_quantities_locations_nodes(
-    res1d_network, node_derived_quantity_id
-):
+def test_read_derived_quantities_locations_nodes(res1d_network, node_derived_quantity_id):
     derived_result_quantity = getattr(res1d_network.nodes, node_derived_quantity_id)
     df = derived_result_quantity.read()
     assert df is not None
@@ -119,9 +102,7 @@
     assert (df.columns.get_level_values("quantity") == node_derived_quantity_id).all()


-def test_read_derived_quantities_locations_reaches(
-    res1d_network, reach_derived_quantity_id
-):
+def test_read_derived_quantities_locations_reaches(res1d_network, reach_derived_quantity_id):
     derived_result_quantity = getattr(res1d_network.reaches, reach_derived_quantity_id)
     df = derived_result_quantity.read()
     assert df is not None
@@ -130,12 +111,8 @@
     assert (df.columns.get_level_values("quantity") == reach_derived_quantity_id).all()


-def test_read_derived_quantities_single_location_node(
-    res1d_network, node_derived_quantity_id
-):
-    derived_result_quantity = getattr(
-        res1d_network.nodes["1"], node_derived_quantity_id
-    )
+def test_read_derived_quantities_single_location_node(res1d_network, node_derived_quantity_id):
+    derived_result_quantity = getattr(res1d_network.nodes["1"], node_derived_quantity_id)
     df = derived_result_quantity.read()
     assert df is not None
     assert len(df) > 0
@@ -143,12 +120,8 @@
     assert (df.columns.get_level_values("quantity") == node_derived_quantity_id).all()


-def test_read_derived_quantities_single_location_reach(
-    res1d_network, reach_derived_quantity_id
-):
-    derived_result_quantity = getattr(
-        res1d_network.reaches["100l1"], reach_derived_quantity_id
-    )
+def test_read_derived_quantities_single_location_reach(res1d_network, reach_derived_quantity_id):
+    derived_result_quantity = getattr(res1d_network.reaches["100l1"], reach_derived_quantity_id)
     df = derived_result_quantity.read()
     assert df is not None
     assert len(df) > 0
@@ -241,9 +214,7 @@ def test_custom_derived_quantity_example(res1d_network):
     assert_frame_equal(df_nodes, df_nodes_expected)

     df_reaches = res1d_network.reaches.WaterLevelPlusOne.read()
-    df_reaches_expected = (
-        res1d_network.reaches.WaterLevel.read(column_mode="compact") + 1
-    )
+    df_reaches_expected = res1d_network.reaches.WaterLevel.read(column_mode="compact") + 1
     df_reaches_expected = set_multiindex_level_values(
         df_reaches_expected, "quantity", "WaterLevelPlusOne"
     )
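
Usage note (not part of the patch): a minimal sketch of how the new filter composition behaves after this change. The file name, reach name, dates, and step value below are placeholders, not fixtures shipped with the repository; they only illustrate the keyword arguments that `Res1D.__init__` now forwards into the composite `Filter`.

```python
from mikeio1d import Res1D
from mikeio1d.result_reader_writer.filter import Filter
from mikeio1d.result_reader_writer.filter import NameFilter
from mikeio1d.result_reader_writer.filter import StepEveryFilter
from mikeio1d.result_reader_writer.filter import TimeFilter

# Manual construction, equivalent to what Res1D.__init__ builds internally.
# Unconfigured sub-filters report use_filter() == False and contribute
# nothing when Filter.apply() runs over the ResultData object.
manual_filter = Filter(
    [
        NameFilter(reaches=["reach1"], nodes=None, catchments=None),
        TimeFilter(slice("2020-01-01", "2020-01-02")),
        StepEveryFilter(2),
    ]
)
assert manual_filter.use_filter()

# The public entry point is unchanged: keyword arguments on Res1D.
res = Res1D(
    "some_results.res1d",  # placeholder path
    reaches=["reach1"],
    time=slice("2020-01-01", "2020-01-02"),
    step_every=2,
)
df = res.read()
```

Since each sub-filter implements the `SubFilter` protocol (`use_filter()` plus `apply(filter, result_data)`), adding a new filter kind amounts to one new class plus one entry in the `Filter` composition in `Res1D.__init__`, with no further changes to the reader classes.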