Skip to content

Commit

Permalink
Consider splits and merges in tdating (#349)
Browse files Browse the repository at this point in the history
* Ignore 0 in cell match

Since 0 in labels means that no cell was identified in that pixel, we need to
ignore 0 when matching. Otherwise we could have a situation where, e.g., the
cell overlaps 0.55 with 0 and 0.45 with some cell, and it would end up unmatched
(`ID_coverage` would be 0 and a new track would be initialized).

* Get required match overlap frac as parameter

* Store information about splits

If the advected cell overlaps > 10% with more than one cell at next timestep, consider
the advected cell as a split cell. Store information of which IDs the cell split to.
Also for cells at current timestep, mark cells that resulted from splits

* Store information about merges

If the cell at current timestep is overlapped > 10% by more than one advected cell,
consider it merged and store IDs of previous cells.
Also mark cells from the previous timestep if they will merge at the next timestep.

* Add columns to cell dataframes

* Make splits/merges output optional

* Add options to specify fractions required for matching/splitting/merging

* Update tests to account for split/merge output in tdating

* Fix unused variables code check

* Fix match_frac argument in tracking

* Refactor to avoid chained assignment warnings in pandas

* Add short example

---------

Co-authored-by: Daniele Nerini <daniele.nerini@gmail.com>
  • Loading branch information
ritvje and dnerini committed Jul 22, 2024
1 parent d77fe73 commit 3e433ff
Show file tree
Hide file tree
Showing 5 changed files with 321 additions and 59 deletions.
11 changes: 11 additions & 0 deletions examples/thunderstorm_detection_and_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,17 @@
# Properties of one of the identified cells:
print(cells_id.iloc[0])

###############################################################################
# Optionally, one can also ask to consider splits and merges of thunderstorm cells.
# A cell at time t is considered to split if it will overlap more than 10% with more than
# one cell at time t+1. Conversely, a cell is considered to be a merge if more
# than one cell from time t will overlap more than 10% with it.

cells_id, labels = tstorm_detect.detection(
input_image, time=time, output_splits_merges=True
)
print(cells_id.iloc[0])

###############################################################################
# Example of thunderstorm tracking over a timeseries
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
59 changes: 55 additions & 4 deletions pysteps/feature/tstorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def detection(
minmax=41,
mindis=10,
output_feat=False,
output_splits_merges=False,
time="000000000",
):
"""
Expand Down Expand Up @@ -93,6 +94,10 @@ def detection(
smaller distance will be merged. The default is 10 km.
output_feat: bool, optional
Set to True to return only the cell coordinates.
output_splits_merges: bool, optional
Set to True to return additional columns in the dataframe for describing the
splitting and merging of cells. Note that columns are initialized with None,
and the information needs to be analyzed while tracking.
time: string, optional
Date and time as string. Used to label time in the resulting dataframe.
The default is '000000000'.
Expand Down Expand Up @@ -166,7 +171,15 @@ def detection(

areas, lines = breakup(input_image, np.nanmin(input_image.flatten()), maxima_dis)

cells_id, labels = get_profile(areas, binary, input_image, loc_max, time, minref)
cells_id, labels = get_profile(
areas,
binary,
input_image,
loc_max,
time,
minref,
output_splits_merges=output_splits_merges,
)

if max_num_features is not None:
idx = np.argsort(cells_id.area.to_numpy())[::-1]
Expand Down Expand Up @@ -225,10 +238,12 @@ def longdistance(loc_max, mindis):
return new_max


def get_profile(areas, binary, ref, loc_max, time, minref):
def get_profile(areas, binary, ref, loc_max, time, minref, output_splits_merges=False):
"""
This function returns the identified cells in a dataframe including their x,y
locations, location of their maxima, maximum reflectivity and contours.
Optionally, the dataframe can include columns for storing information regarding
splitting and merging of cells.
"""
cells = areas * binary
cell_labels = cells[loc_max]
Expand All @@ -255,11 +270,47 @@ def get_profile(areas, binary, ref, loc_max, time, minref):
"area": len(x),
}
)
if output_splits_merges:
cells_id[-1].update(
{
"splitted": None,
"split_IDs": None,
"merged": None,
"merged_IDs": None,
"results_from_split": None,
"will_merge": None,
}
)
labels[cells == cell_labels[n]] = this_id

columns = [
"ID",
"time",
"x",
"y",
"cen_x",
"cen_y",
"max_ref",
"cont",
"area",
]
if output_splits_merges:
columns.extend(
[
"splitted",
"split_IDs",
"merged",
"merged_IDs",
"results_from_split",
"will_merge",
]
)
cells_id = pd.DataFrame(
data=cells_id,
index=range(len(cell_labels)),
columns=["ID", "time", "x", "y", "cen_x", "cen_y", "max_ref", "cont", "area"],
columns=columns,
)

if output_splits_merges:
cells_id["split_IDs"] = cells_id["split_IDs"].astype("object")
cells_id["merged_IDs"] = cells_id["merged_IDs"].astype("object")
return cells_id, labels
65 changes: 56 additions & 9 deletions pysteps/tests/test_feature_tstorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,29 @@
except ModuleNotFoundError:
pass

arg_names = ("source", "output_feat", "dry_input", "max_num_features")
arg_names = (
"source",
"output_feat",
"dry_input",
"max_num_features",
"output_split_merge",
)

arg_values = [
("mch", False, False, None),
("mch", False, False, 5),
("mch", True, False, None),
("mch", True, False, 5),
("mch", False, True, None),
("mch", False, True, 5),
("mch", False, False, None, False),
("mch", False, False, 5, False),
("mch", True, False, None, False),
("mch", True, False, 5, False),
("mch", False, True, None, False),
("mch", False, True, 5, False),
("mch", False, False, None, True),
]


@pytest.mark.parametrize(arg_names, arg_values)
def test_feature_tstorm_detection(source, output_feat, dry_input, max_num_features):
def test_feature_tstorm_detection(
source, output_feat, dry_input, max_num_features, output_split_merge
):
pytest.importorskip("skimage")
pytest.importorskip("pandas")

Expand All @@ -36,7 +45,11 @@ def test_feature_tstorm_detection(source, output_feat, dry_input, max_num_featur

time = "000"
output = detection(
input, time=time, output_feat=output_feat, max_num_features=max_num_features
input,
time=time,
output_feat=output_feat,
max_num_features=max_num_features,
output_splits_merges=output_split_merge,
)

if output_feat:
Expand All @@ -45,6 +58,40 @@ def test_feature_tstorm_detection(source, output_feat, dry_input, max_num_featur
assert output.shape[1] == 2
if max_num_features is not None:
assert output.shape[0] <= max_num_features
elif output_split_merge:
assert isinstance(output, tuple)
assert len(output) == 2
assert isinstance(output[0], DataFrame)
assert isinstance(output[1], np.ndarray)
if max_num_features is not None:
assert output[0].shape[0] <= max_num_features
assert output[0].shape[1] == 15
assert list(output[0].columns) == [
"ID",
"time",
"x",
"y",
"cen_x",
"cen_y",
"max_ref",
"cont",
"area",
"splitted",
"split_IDs",
"merged",
"merged_IDs",
"results_from_split",
"will_merge",
]
assert (output[0].time == time).all()
assert output[1].ndim == 2
assert output[1].shape == input.shape
if not dry_input:
assert output[0].shape[0] > 0
assert sorted(list(output[0].ID)) == sorted(list(np.unique(output[1]))[1:])
else:
assert output[0].shape[0] == 0
assert output[1].sum() == 0
else:
assert isinstance(output, tuple)
assert len(output) == 2
Expand Down
32 changes: 21 additions & 11 deletions pysteps/tests/test_tracking_tdating.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,24 @@
from pysteps.utils import to_reflectivity
from pysteps.tests.helpers import get_precipitation_fields

arg_names = ("source", "dry_input")
arg_names = ("source", "dry_input", "output_splits_merges")

arg_values = [
("mch", False),
("mch", False),
("mch", True),
("mch", False, False),
("mch", False, False),
("mch", True, False),
("mch", False, True),
]

arg_names_multistep = ("source", "len_timesteps")
arg_names_multistep = ("source", "len_timesteps", "output_splits_merges")
arg_values_multistep = [
("mch", 6),
("mch", 6, False),
("mch", 6, True),
]


@pytest.mark.parametrize(arg_names_multistep, arg_values_multistep)
def test_tracking_tdating_dating_multistep(source, len_timesteps):
def test_tracking_tdating_dating_multistep(source, len_timesteps, output_splits_merges):
pytest.importorskip("skimage")

input_fields, metadata = get_precipitation_fields(
Expand All @@ -37,6 +39,7 @@ def test_tracking_tdating_dating_multistep(source, len_timesteps):
input_fields[0 : len_timesteps // 2],
timelist[0 : len_timesteps // 2],
mintrack=1,
output_splits_merges=output_splits_merges,
)
# Second half of timesteps
tracks_2, cells, _ = dating(
Expand All @@ -46,6 +49,7 @@ def test_tracking_tdating_dating_multistep(source, len_timesteps):
start=2,
cell_list=cells,
label_list=labels,
output_splits_merges=output_splits_merges,
)

# Since we are adding cells, number of tracks should increase
Expand All @@ -67,7 +71,7 @@ def test_tracking_tdating_dating_multistep(source, len_timesteps):


@pytest.mark.parametrize(arg_names, arg_values)
def test_tracking_tdating_dating(source, dry_input):
def test_tracking_tdating_dating(source, dry_input, output_splits_merges):
pytest.importorskip("skimage")
pandas = pytest.importorskip("pandas")

Expand All @@ -80,7 +84,13 @@ def test_tracking_tdating_dating(source, dry_input):

timelist = metadata["timestamps"]

output = dating(input, timelist, mintrack=1)
cell_column_length = 9
if output_splits_merges:
cell_column_length = 15

output = dating(
input, timelist, mintrack=1, output_splits_merges=output_splits_merges
)

# Check output format
assert isinstance(output, tuple)
Expand All @@ -92,12 +102,12 @@ def test_tracking_tdating_dating(source, dry_input):
assert len(output[2]) == input.shape[0]
assert isinstance(output[1][0], pandas.DataFrame)
assert isinstance(output[2][0], np.ndarray)
assert output[1][0].shape[1] == 9
assert output[1][0].shape[1] == cell_column_length
assert output[2][0].shape == input.shape[1:]
if not dry_input:
assert len(output[0]) > 0
assert isinstance(output[0][0], pandas.DataFrame)
assert output[0][0].shape[1] == 9
assert output[0][0].shape[1] == cell_column_length
else:
assert len(output[0]) == 0
assert output[1][0].shape[0] == 0
Expand Down
Loading

0 comments on commit 3e433ff

Please sign in to comment.