Skip to content

Commit

Permalink
Merge pull request #152 from DHI/minor/filter_multiple_periods
Browse files Browse the repository at this point in the history
Implement possibility to time filter on several periods.
  • Loading branch information
ryan-kipawa authored Dec 19, 2024
2 parents 922b7e1 + 20aaa59 commit 7d24ff6
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 19 deletions.
51 changes: 39 additions & 12 deletions mikeio1d/filter/time_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,24 +34,51 @@ def apply(self, filter: Filter, result_data: ResultData | None = None):
if not self.use_filter():
return

time = self._time
time_intervals = self._determine_time_intervals(self._time)

for time_interval in time_intervals:
start, end = time_interval
period = self.create_period(start, end)
filter.Periods.Add(period)

def _determine_time_intervals(
self, time_intervals: None | slice | tuple | list
) -> list[tuple[pd.Timestamp, pd.Timestamp]]:
# In case of slice convert time_intervals to a list containing that slice.
# Needed to be able to evaluate contains_time_intervals properly.
time_intervals = [time_intervals] if isinstance(time_intervals, slice) else time_intervals

contains_time_intervals = (
all(isinstance(t, slice) for t in time_intervals)
or all(isinstance(t, tuple) for t in time_intervals)
or all(isinstance(t, list) for t in time_intervals)
)

if contains_time_intervals:
return [self._determine_start_and_end_time(t) for t in time_intervals]
else:
return [self._determine_start_and_end_time(time_intervals)]

def _determine_start_and_end_time(
self, time_interval: None | slice | tuple[any, any] | list[any, any]
) -> tuple[pd.Timestamp, pd.Timestamp]:
start, end = None, None
if isinstance(time, slice):
start = time.start
end = time.stop
elif isinstance(time, tuple) or isinstance(time, list):
start, end = time

if isinstance(time_interval, slice):
start = time_interval.start
end = time_interval.stop
elif isinstance(time_interval, tuple) or isinstance(time_interval, list):
start, end = time_interval
else:
raise ValueError("time parameter must be a slice, tuple or list")

if start is not None:
start = pd.to_datetime(start)
start = self._convert_to_datetime(start)
end = self._convert_to_datetime(end)

if end is not None:
end = pd.to_datetime(end)
return (start, end)

period = self.create_period(start, end)
filter.Periods.Add(period)
def _convert_to_datetime(self, time: str | datetime) -> pd.Timestamp:
return pd.to_datetime(time) if time is not None else None

def create_period(self, start: None | datetime, end: None | datetime) -> Period:
"""Create a DHI.Mike1D.ResultDataAccess.Period object."""
Expand Down
2 changes: 1 addition & 1 deletion mikeio1d/result_network/result_reach.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def interpolate_reach_critical_level(self, chainage: float) -> float:
@property
def reaches(self) -> List[IRes1DReach]:
"""List of IRes1DReach corresponding to this result location."""
warnings.warn("The 'reaches' property is deprecated. Use 'm1d_reaches' instead.")
warnings.warn("The 'reaches' property is deprecated. Use 'res1d_reaches' instead.")
return self.res1d_reaches

# endregion
Expand Down
7 changes: 2 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,11 @@


class Helpers:
"""
Class containing helper methods for performing tests.
"""
"""Class containing helper methods for performing tests."""

@staticmethod
def assert_shared_columns_equal(df_ref, df):
"""
Compares columns in df to the ones in df_ref.
"""Compare columns in df to the ones in df_ref.
Note that df_ref typically has more columns than df.
Comparison is performed only in columns of df.
Expand Down
32 changes: 32 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,35 @@ def test_mikeio1d_all_column_modes_basic(extension, column_mode):
assert len(df) > 0
df = res.read(column_mode=column_mode)
assert len(df) > 0


@pytest.mark.slow
@pytest.mark.parametrize(
"time",
[
slice("1994-08-07 16:35:00.000", "1994-08-07 16:37:07.560000"),
slice(None, "1994-08-07 16:37:07.560000"),
slice("1994-08-07 18:32:07.967000", None),
[
slice("1994-08-07 16:35:00.000", "1994-08-07 16:37:07.560000"),
slice("1994-08-07 16:37:07.560000", None),
],
[
slice("1994-08-07 16:48:00.000", "1994-08-07 16:58:12.888000"),
slice("1994-08-07 16:35:00.000", "1994-08-07 16:38:00.000000"),
slice("1994-08-07 16:38:00.000", "1994-08-07 16:48:00.000000"),
],
],
)
def test_mikeio1d_network_res1d_using_time_filters(time, helpers):
for name in testdata_name():
path = getattr(testdata, name)
if not path.endswith("network.res1d"):
continue

df = Res1D(path, time=time).read(column_mode=ColumnMode.STRING)
df_expected = testdata.get_expected_dataframe(name)
if len(df.index) != len(df_expected):
df_expected = df_expected.loc[df.index]
assert_frame_equal(df_expected, df)

37 changes: 36 additions & 1 deletion tests/test_res1d_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,12 @@ def test_res1d_filter_using_flow_split(flow_split_file_path, helpers):
@pytest.mark.parametrize(
"time, expected_len, expected_start, expected_end",
[
(None, 110, "1994-08-07 16:35:00.000", "1994-08-07 18:35:00.000"),
(
None,
110,
"1994-08-07 16:35:00.000",
"1994-08-07 18:35:00.000",
),
(
slice("1994-08-07 16:35:00.000", "1994-08-07 16:37:07.560000"),
3,
Expand Down Expand Up @@ -282,6 +287,36 @@ def test_res1d_filter_using_flow_split(flow_split_file_path, helpers):
"1994-08-07 18:32:07.967000",
"1994-08-07 18:35:00.000",
),
(
[
slice("1994-08-07 16:35:00.000", "1994-08-07 16:38:00.000000"),
slice("1994-08-07 16:38:00.000", "1994-08-07 16:48:00.000000"),
slice("1994-08-07 16:48:00.000", "1994-08-07 16:58:12.888000"),
],
21,
"1994-08-07 16:35:00.000",
"1994-08-07 16:58:12.888000",
),
(
[
slice("1994-08-07 16:35:00.000", "1994-08-07 16:37:07.560000"),
slice("1994-08-07 16:45:00.000", "1994-08-07 16:48:00.000000"),
slice("1994-08-07 16:55:00.000", "1994-08-07 16:58:12.888000"),
],
9,
"1994-08-07 16:35:00.000",
"1994-08-07 16:58:12.888000",
),
(
[
("1994-08-07 16:35:00.000", "1994-08-07 16:37:07.560000"),
("1994-08-07 16:45:00.000", "1994-08-07 16:48:00.000000"),
("1994-08-07 16:55:00.000", "1994-08-07 16:58:12.888000"),
],
9,
"1994-08-07 16:35:00.000",
"1994-08-07 16:58:12.888000",
),
],
)
def test_res1d_filter_time(test_file_path, time, expected_len, expected_start, expected_end):
Expand Down

0 comments on commit 7d24ff6

Please sign in to comment.