Skip to content

Commit

Permalink
fix: make sure netcdf exporters can handle list of timesteps (#369)
Browse files Browse the repository at this point in the history
  • Loading branch information
RubenImhoff committed Jul 11, 2024
1 parent 3121cbf commit 953f799
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 16 deletions.
32 changes: 24 additions & 8 deletions pysteps/io/exporters.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,9 +392,10 @@ def initialize_forecast_exporter_netcdf(
Start date of the forecast.
timestep: int
Time step of the forecast (minutes).
n_timesteps: int
Number of time steps in the forecast this argument is ignored if
incremental is set to 'timestep'.
n_timesteps: int or list of integers
Number of time steps to forecast or a list of time steps for which the
forecasts are computed (relative to the input time step). The elements of
the list are required to be in ascending order.
shape: tuple of int
Two-element tuple defining the shape (height,width) of the forecast
grids.
Expand Down Expand Up @@ -460,8 +461,14 @@ def initialize_forecast_exporter_netcdf(
+ "'timestep' or 'member'"
)

n_timesteps_is_list = isinstance(n_timesteps, list)
if n_timesteps_is_list:
num_timesteps = len(n_timesteps)
else:
num_timesteps = n_timesteps

if incremental == "timestep":
n_timesteps = None
num_timesteps = None
elif incremental == "member":
n_ens_members = None
elif incremental is not None:
Expand Down Expand Up @@ -498,7 +505,7 @@ def initialize_forecast_exporter_netcdf(
h, w = shape

ncf.createDimension("ens_number", size=n_ens_members)
ncf.createDimension("time", size=n_timesteps)
ncf.createDimension("time", size=num_timesteps)
ncf.createDimension("y", size=h)
ncf.createDimension("x", size=w)

Expand Down Expand Up @@ -585,7 +592,10 @@ def initialize_forecast_exporter_netcdf(

var_time = ncf.createVariable("time", int, dimensions=("time",))
if incremental != "timestep":
var_time[:] = [i * timestep * 60 for i in range(1, n_timesteps + 1)]
if n_timesteps_is_list:
var_time[:] = np.array(n_timesteps) * timestep * 60
else:
var_time[:] = [i * timestep * 60 for i in range(1, n_timesteps + 1)]
var_time.long_name = "forecast time"
startdate_str = datetime.strftime(startdate, "%Y-%m-%d %H:%M:%S")
var_time.units = "seconds since %s" % startdate_str
Expand Down Expand Up @@ -635,7 +645,8 @@ def initialize_forecast_exporter_netcdf(
exporter["timestep"] = timestep
exporter["metadata"] = metadata
exporter["incremental"] = incremental
exporter["num_timesteps"] = n_timesteps
exporter["num_timesteps"] = num_timesteps
exporter["timesteps"] = n_timesteps
exporter["num_ens_members"] = n_ens_members
exporter["shape"] = shape

Expand Down Expand Up @@ -853,7 +864,12 @@ def _export_netcdf(field, exporter):
else:
var_f[var_f.shape[0], :, :] = field
var_time = exporter["var_time"]
var_time[len(var_time) - 1] = len(var_time) * exporter["timestep"] * 60
if isinstance(exporter["timesteps"], list):
var_time[len(var_time) - 1] = (
exporter["timesteps"][len(var_time) - 1] * exporter["timestep"] * 60
)
else:
var_time[len(var_time) - 1] = len(var_time) * exporter["timestep"] * 60
else:
var_f[var_f.shape[0], :, :, :] = field
var_ens_num = exporter["var_ens_num"]
Expand Down
22 changes: 14 additions & 8 deletions pysteps/tests/test_exporters.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,17 @@
"fill_value",
"scale_factor",
"offset",
"n_timesteps",
)

exporter_arg_values = [
(1, None, np.float32, None, None, None),
(1, "timestep", np.float32, 65535, None, None),
(2, None, np.float32, 65535, None, None),
(2, "timestep", np.float32, None, None, None),
(2, "member", np.float64, None, 0.01, 1.0),
(1, None, np.float32, None, None, None, 3),
(1, "timestep", np.float32, 65535, None, None, 3),
(2, None, np.float32, 65535, None, None, 3),
(2, None, np.float32, 65535, None, None, [1, 2, 4]),
(2, "timestep", np.float32, None, None, None, 3),
(2, "timestep", np.float32, None, None, None, [1, 2, 4]),
(2, "member", np.float64, None, 0.01, 1.0, 3),
]


Expand All @@ -54,7 +57,7 @@ def test_get_geotiff_filename():

@pytest.mark.parametrize(exporter_arg_names, exporter_arg_values)
def test_io_export_netcdf_one_member_one_time_step(
n_ens_members, incremental, datatype, fill_value, scale_factor, offset
n_ens_members, incremental, datatype, fill_value, scale_factor, offset, n_timesteps
):
"""
Test the export netcdf.
Expand All @@ -75,7 +78,6 @@ def test_io_export_netcdf_one_member_one_time_step(
file_path = os.path.join(outpath, outfnprefix + ".nc")
startdate = metadata["timestamps"][0]
timestep = metadata["accutime"]
n_timesteps = 3
shape = precip.shape[1:]

exporter = initialize_forecast_exporter_netcdf(
Expand All @@ -100,7 +102,11 @@ def test_io_export_netcdf_one_member_one_time_step(
if incremental == None:
export_forecast_dataset(precip, exporter)
if incremental == "timestep":
for t in range(n_timesteps):
if isinstance(n_timesteps, list):
timesteps = len(n_timesteps)
else:
timesteps = n_timesteps
for t in range(timesteps):
if n_ens_members > 1:
export_forecast_dataset(precip[:, t, :, :], exporter)
else:
Expand Down

0 comments on commit 953f799

Please sign in to comment.