Skip to content

Commit

Permalink
Add tmp_zarrfile fixture (#1784)
Browse files Browse the repository at this point in the history
* Add tmp_zarrfile fixture

And specify warning in test_file_warnings()

* Review feedback
  • Loading branch information
VeckoTheGecko authored Dec 3, 2024
1 parent 1995da0 commit 0d14640
Show file tree
Hide file tree
Showing 7 changed files with 58 additions and 72 deletions.
7 changes: 7 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import pytest


@pytest.fixture()
def tmp_zarrfile(tmp_path, request):
    """Yield a per-test path for zarr output inside pytest's ``tmp_path``.

    The file name is built from the requesting test's node name followed by
    ``-output.zarr``. Only the path is constructed here; nothing is created
    on disk by this fixture.
    """
    test_name = request.node.name
    zarr_path = tmp_path / f"{test_name}-output.zarr"
    yield zarr_path
7 changes: 3 additions & 4 deletions tests/test_advection.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ def test_analyticalAgrid(mode):
@pytest.mark.parametrize("v", [1, -0.3, 0, -1])
@pytest.mark.parametrize("w", [None, 1, -0.3, 0, -1])
@pytest.mark.parametrize("direction", [1, -1])
def test_uniform_analytical(mode, u, v, w, direction, tmpdir):
def test_uniform_analytical(mode, u, v, w, direction, tmp_zarrfile):
lon = np.arange(0, 15, dtype=np.float32)
lat = np.arange(0, 15, dtype=np.float32)
if w is not None:
Expand All @@ -625,15 +625,14 @@ def test_uniform_analytical(mode, u, v, w, direction, tmpdir):
x0, y0, z0 = 6.1, 6.2, 20
pset = ParticleSet(fieldset, pclass=ptype[mode], lon=x0, lat=y0, depth=z0)

outfile_path = tmpdir.join("uniformanalytical.zarr")
outfile = pset.ParticleFile(name=outfile_path, outputdt=1, chunks=(1, 1))
outfile = pset.ParticleFile(name=tmp_zarrfile, outputdt=1, chunks=(1, 1))
pset.execute(AdvectionAnalytical, runtime=4, dt=direction, output_file=outfile)
assert np.abs(pset.lon - x0 - pset.time * u) < 1e-6
assert np.abs(pset.lat - y0 - pset.time * v) < 1e-6
if w is not None:
assert np.abs(pset.depth - z0 - pset.time * w) < 1e-4

ds = xr.open_zarr(outfile_path)
ds = xr.open_zarr(tmp_zarrfile)
times = (direction * ds["time"][:]).values.astype("timedelta64[s]")[0]
timeref = np.arange(1, 5).astype("timedelta64[s]")
assert np.allclose(times, timeref, atol=np.timedelta64(1, "ms"))
Expand Down
7 changes: 3 additions & 4 deletions tests/test_fieldset.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,8 +655,7 @@ def SampleUV2(particle, fieldset, time):
assert abs(pset.lat[0] - 0.5) < 1e-9


def test_fieldset_write(tmpdir):
filepath = tmpdir.join("fieldset_write.zarr")
def test_fieldset_write(tmp_zarrfile):
xdim, ydim = 3, 4
lon = np.linspace(0.0, 10.0, xdim, dtype=np.float32)
lat = np.linspace(0.0, 10.0, ydim, dtype=np.float32)
Expand All @@ -674,12 +673,12 @@ def UpdateU(particle, fieldset, time):
fieldset.U.grid.time[0] = time

pset = ParticleSet(fieldset, pclass=ScipyParticle, lon=5, lat=5)
ofile = pset.ParticleFile(name=filepath, outputdt=2.0)
ofile = pset.ParticleFile(name=tmp_zarrfile, outputdt=2.0)
pset.execute(UpdateU, dt=1, runtime=10, output_file=ofile)

assert fieldset.U.data[0, 1, 0] == 11

da = xr.open_dataset(str(filepath).replace(".zarr", "_0005U.nc"))
da = xr.open_dataset(str(tmp_zarrfile).replace(".zarr", "_0005U.nc"))
assert np.allclose(fieldset.U.data, da["U"].values, atol=1.0)


Expand Down
7 changes: 3 additions & 4 deletions tests/test_fieldset_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -891,7 +891,7 @@ def Recover(particle, fieldset, time):


@pytest.mark.parametrize("mode", ["jit", "scipy"])
def test_fieldset_sampling_updating_order(mode, tmpdir):
def test_fieldset_sampling_updating_order(mode, tmp_zarrfile):
def calc_p(t, y, x):
return 10 * t + x + 0.2 * y

Expand Down Expand Up @@ -923,11 +923,10 @@ def SampleP(particle, fieldset, time):

kernels = [AdvectionRK4, SampleP]

filename = tmpdir.join("interpolation_offset.zarr")
pfile = pset.ParticleFile(filename, outputdt=1)
pfile = pset.ParticleFile(tmp_zarrfile, outputdt=1)
pset.execute(kernels, endtime=1, dt=1, output_file=pfile)

ds = xr.open_zarr(filename)
ds = xr.open_zarr(tmp_zarrfile)
for t in range(len(ds["obs"])):
for i in range(len(ds["trajectory"])):
assert np.isclose(
Expand Down
88 changes: 36 additions & 52 deletions tests/test_particlefile.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,12 @@ def fieldset():


@pytest.mark.parametrize("mode", ["scipy", "jit"])
def test_metadata(fieldset, mode, tmpdir):
filepath = tmpdir.join("pfile_metadata.zarr")
def test_metadata(fieldset, mode, tmp_zarrfile):
pset = ParticleSet(fieldset, pclass=ptype[mode], lon=0, lat=0)

pset.execute(DoNothing, runtime=1, output_file=pset.ParticleFile(filepath))
pset.execute(DoNothing, runtime=1, output_file=pset.ParticleFile(tmp_zarrfile))

ds = xr.open_zarr(filepath)
ds = xr.open_zarr(tmp_zarrfile)
assert ds.attrs["parcels_kernels"].lower() == f"{mode}ParticleDoNothing".lower()


Expand All @@ -57,38 +56,36 @@ def test_pfile_array_write_zarr_memorystore(fieldset, mode):


@pytest.mark.parametrize("mode", ["scipy", "jit"])
def test_pfile_array_remove_particles(fieldset, mode, tmpdir):
def test_pfile_array_remove_particles(fieldset, mode, tmp_zarrfile):
npart = 10
filepath = tmpdir.join("pfile_array_remove_particles.zarr")
pset = ParticleSet(fieldset, pclass=ptype[mode], lon=np.linspace(0, 1, npart), lat=0.5 * np.ones(npart), time=0)
pfile = pset.ParticleFile(filepath)
pfile = pset.ParticleFile(tmp_zarrfile)
pfile.write(pset, 0)
pset.remove_indices(3)
for p in pset:
p.time = 1
pfile.write(pset, 1)

ds = xr.open_zarr(filepath)
ds = xr.open_zarr(tmp_zarrfile)
timearr = ds["time"][:]
assert (np.isnat(timearr[3, 1])) and (np.isfinite(timearr[3, 0]))
ds.close()


@pytest.mark.parametrize("mode", ["scipy", "jit"])
def test_pfile_set_towrite_False(fieldset, mode, tmpdir):
def test_pfile_set_towrite_False(fieldset, mode, tmp_zarrfile):
npart = 10
filepath = tmpdir.join("pfile_set_towrite_False.zarr")
pset = ParticleSet(fieldset, pclass=ptype[mode], lon=np.linspace(0, 1, npart), lat=0.5 * np.ones(npart))
pset.set_variable_write_status("depth", False)
pset.set_variable_write_status("lat", False)
pfile = pset.ParticleFile(filepath, outputdt=1)
pfile = pset.ParticleFile(tmp_zarrfile, outputdt=1)

def Update_lon(particle, fieldset, time):
particle_dlon += 0.1 # noqa

pset.execute(Update_lon, runtime=10, output_file=pfile)

ds = xr.open_zarr(filepath)
ds = xr.open_zarr(tmp_zarrfile)
assert "time" in ds
assert "z" not in ds
assert "lat" not in ds
Expand All @@ -101,19 +98,18 @@ def Update_lon(particle, fieldset, time):

@pytest.mark.parametrize("mode", ["scipy", "jit"])
@pytest.mark.parametrize("chunks_obs", [1, None])
def test_pfile_array_remove_all_particles(fieldset, mode, chunks_obs, tmpdir):
def test_pfile_array_remove_all_particles(fieldset, mode, chunks_obs, tmp_zarrfile):
npart = 10
filepath = tmpdir.join("pfile_array_remove_particles.zarr")
pset = ParticleSet(fieldset, pclass=ptype[mode], lon=np.linspace(0, 1, npart), lat=0.5 * np.ones(npart), time=0)
chunks = (npart, chunks_obs) if chunks_obs else None
pfile = pset.ParticleFile(filepath, chunks=chunks)
pfile = pset.ParticleFile(tmp_zarrfile, chunks=chunks)
pfile.write(pset, 0)
for _ in range(npart):
pset.remove_indices(-1)
pfile.write(pset, 1)
pfile.write(pset, 2)

ds = xr.open_zarr(filepath)
ds = xr.open_zarr(tmp_zarrfile)
assert np.allclose(ds["time"][:, 0], np.timedelta64(0, "s"), atol=np.timedelta64(1, "ms"))
if chunks_obs is not None:
assert ds["time"][:].shape == chunks
Expand All @@ -124,26 +120,22 @@ def test_pfile_array_remove_all_particles(fieldset, mode, chunks_obs, tmpdir):


@pytest.mark.parametrize("mode", ["scipy", "jit"])
def test_variable_write_double(fieldset, mode, tmpdir):
filepath = tmpdir.join("pfile_variable_write_double.zarr")

def test_variable_write_double(fieldset, mode, tmp_zarrfile):
def Update_lon(particle, fieldset, time):
particle_dlon += 0.1 # noqa

pset = ParticleSet(fieldset, pclass=ptype[mode], lon=[0], lat=[0], lonlatdepth_dtype=np.float64)
ofile = pset.ParticleFile(name=filepath, outputdt=0.00001)
ofile = pset.ParticleFile(name=tmp_zarrfile, outputdt=0.00001)
pset.execute(pset.Kernel(Update_lon), endtime=0.001, dt=0.00001, output_file=ofile)

ds = xr.open_zarr(filepath)
ds = xr.open_zarr(tmp_zarrfile)
lons = ds["lon"][:]
assert isinstance(lons.values[0, 0], np.float64)
ds.close()


@pytest.mark.parametrize("mode", ["scipy", "jit"])
def test_write_dtypes_pfile(fieldset, mode, tmpdir):
filepath = tmpdir.join("pfile_dtypes.zarr")

def test_write_dtypes_pfile(fieldset, mode, tmp_zarrfile):
dtypes = [np.float32, np.float64, np.int32, np.uint32, np.int64, np.uint64]
if mode == "scipy":
dtypes.extend([np.bool_, np.int8, np.uint8, np.int16, np.uint16])
Expand All @@ -152,21 +144,19 @@ def test_write_dtypes_pfile(fieldset, mode, tmpdir):
MyParticle = ptype[mode].add_variables(extra_vars)

pset = ParticleSet(fieldset, pclass=MyParticle, lon=0, lat=0, time=0)
pfile = pset.ParticleFile(name=filepath, outputdt=1)
pfile = pset.ParticleFile(name=tmp_zarrfile, outputdt=1)
pfile.write(pset, 0)

ds = xr.open_zarr(
filepath, mask_and_scale=False
tmp_zarrfile, mask_and_scale=False
) # Note masking issue at https://stackoverflow.com/questions/68460507/xarray-loading-int-data-as-float
for d in dtypes:
assert ds[f"v_{d.__name__}"].dtype == d


@pytest.mark.parametrize("mode", ["scipy", "jit"])
@pytest.mark.parametrize("npart", [1, 2, 5])
def test_variable_written_once(fieldset, mode, tmpdir, npart):
filepath = tmpdir.join("pfile_once_written_variables.zarr")

def test_variable_written_once(fieldset, mode, tmp_zarrfile, npart):
def Update_v(particle, fieldset, time):
particle.v_once += 1.0
particle.age += particle.dt
Expand All @@ -181,11 +171,11 @@ def Update_v(particle, fieldset, time):
lat = np.linspace(1, 0, npart)
time = np.arange(0, npart / 10.0, 0.1, dtype=np.float64)
pset = ParticleSet(fieldset, pclass=MyParticle, lon=lon, lat=lat, time=time, v_once=time)
ofile = pset.ParticleFile(name=filepath, outputdt=0.1)
ofile = pset.ParticleFile(name=tmp_zarrfile, outputdt=0.1)
pset.execute(pset.Kernel(Update_v), endtime=1, dt=0.1, output_file=ofile)

assert np.allclose(pset.v_once - time - pset.age * 10, 1, atol=1e-5)
ds = xr.open_zarr(filepath)
ds = xr.open_zarr(tmp_zarrfile)
vfile = np.ma.filled(ds["v_once"][:], np.nan)
assert vfile.shape == (npart,)
ds.close()
Expand All @@ -196,7 +186,7 @@ def Update_v(particle, fieldset, time):
@pytest.mark.parametrize("repeatdt", range(1, 3))
@pytest.mark.parametrize("dt", [-1, 1])
@pytest.mark.parametrize("maxvar", [2, 4, 10])
def test_pset_repeated_release_delayed_adding_deleting(type, fieldset, mode, repeatdt, tmpdir, dt, maxvar):
def test_pset_repeated_release_delayed_adding_deleting(type, fieldset, mode, repeatdt, tmp_zarrfile, dt, maxvar):
runtime = 10
fieldset.maxvar = maxvar
pset = None
Expand All @@ -211,8 +201,7 @@ def test_pset_repeated_release_delayed_adding_deleting(type, fieldset, mode, rep
pset = ParticleSet(
fieldset, lon=np.zeros(runtime), lat=np.zeros(runtime), pclass=MyParticle, time=list(range(runtime))
)
outfilepath = tmpdir.join("pfile_repeated_release.zarr")
pfile = pset.ParticleFile(outfilepath, outputdt=abs(dt), chunks=(1, 1))
pfile = pset.ParticleFile(tmp_zarrfile, outputdt=abs(dt), chunks=(1, 1))

def IncrLon(particle, fieldset, time):
particle.sample_var += 1.0
Expand All @@ -222,7 +211,7 @@ def IncrLon(particle, fieldset, time):
for _ in range(runtime):
pset.execute(IncrLon, dt=dt, runtime=1.0, output_file=pfile)

ds = xr.open_zarr(outfilepath)
ds = xr.open_zarr(tmp_zarrfile)
samplevar = ds["sample_var"][:]
if type == "repeatdt":
assert samplevar.shape == (runtime // repeatdt, min(maxvar + 1, runtime))
Expand All @@ -232,51 +221,47 @@ def IncrLon(particle, fieldset, time):
# test whether samplevar[:, k] = k + 1
for k in range(samplevar.shape[1]):
assert np.allclose([p for p in samplevar[:, k] if np.isfinite(p)], k + 1)
filesize = os.path.getsize(str(outfilepath))
filesize = os.path.getsize(str(tmp_zarrfile))
assert filesize < 1024 * 65 # test that chunking leads to filesize less than 65KB
ds.close()


@pytest.mark.parametrize("mode", ["scipy", "jit"])
@pytest.mark.parametrize("repeatdt", [1, 2])
@pytest.mark.parametrize("nump", [1, 10])
def test_pfile_chunks_repeatedrelease(fieldset, mode, repeatdt, nump, tmpdir):
def test_pfile_chunks_repeatedrelease(fieldset, mode, repeatdt, nump, tmp_zarrfile):
runtime = 8
pset = ParticleSet(
fieldset, pclass=ptype[mode], lon=np.zeros((nump, 1)), lat=np.zeros((nump, 1)), repeatdt=repeatdt
)
outfilepath = tmpdir.join("pfile_chunks_repeatedrelease.zarr")
chunks = (20, 10)
pfile = pset.ParticleFile(outfilepath, outputdt=1, chunks=chunks)
pfile = pset.ParticleFile(tmp_zarrfile, outputdt=1, chunks=chunks)

def DoNothing(particle, fieldset, time):
pass

pset.execute(DoNothing, dt=1, runtime=runtime, output_file=pfile)
ds = xr.open_zarr(outfilepath)
ds = xr.open_zarr(tmp_zarrfile)
assert ds["time"].shape == (int(nump * runtime / repeatdt), chunks[1])


@pytest.mark.parametrize("mode", ["scipy", "jit"])
def test_write_timebackward(fieldset, mode, tmpdir):
outfilepath = tmpdir.join("pfile_write_timebackward.zarr")

def test_write_timebackward(fieldset, mode, tmp_zarrfile):
def Update_lon(particle, fieldset, time):
particle_dlon -= 0.1 * particle.dt # noqa

pset = ParticleSet(fieldset, pclass=ptype[mode], lat=np.linspace(0, 1, 3), lon=[0, 0, 0], time=[1, 2, 3])
pfile = pset.ParticleFile(name=outfilepath, outputdt=1.0)
pfile = pset.ParticleFile(name=tmp_zarrfile, outputdt=1.0)
pset.execute(pset.Kernel(Update_lon), runtime=4, dt=-1.0, output_file=pfile)
ds = xr.open_zarr(outfilepath)
ds = xr.open_zarr(tmp_zarrfile)
trajs = ds["trajectory"][:]
assert trajs.values.dtype == "int64"
assert np.all(np.diff(trajs.values) < 0) # all particles written in order of release
ds.close()


@pytest.mark.parametrize("mode", ["scipy", "jit"])
def test_write_xiyi(fieldset, mode, tmpdir):
outfilepath = tmpdir.join("pfile_xiyi.zarr")
def test_write_xiyi(fieldset, mode, tmp_zarrfile):
fieldset.U.data[:] = 1 # set a non-zero zonal velocity
fieldset.add_field(Field(name="P", data=np.zeros((3, 20)), lon=np.linspace(0, 1, 20), lat=[-2, 0, 2]))
dt = 3600
Expand Down Expand Up @@ -304,10 +289,10 @@ def SampleP(particle, fieldset, time):
_ = fieldset.P[particle] # To trigger sampling of the P field

pset = ParticleSet(fieldset, pclass=XiYiParticle, lon=[0, 0.2], lat=[0.2, 1], lonlatdepth_dtype=np.float64)
pfile = pset.ParticleFile(name=outfilepath, outputdt=dt)
pfile = pset.ParticleFile(name=tmp_zarrfile, outputdt=dt)
pset.execute([SampleP, Get_XiYi, AdvectionRK4], endtime=10 * dt, dt=dt, output_file=pfile)

ds = xr.open_zarr(outfilepath)
ds = xr.open_zarr(tmp_zarrfile)
pxi0 = ds["pxi0"][:].values.astype(np.int32)
pxi1 = ds["pxi1"][:].values.astype(np.int32)
lons = ds["lon"][:].values
Expand Down Expand Up @@ -335,16 +320,15 @@ def test_set_calendar():


@pytest.mark.parametrize("mode", ["scipy", "jit"])
def test_reset_dt(fieldset, mode, tmpdir):
def test_reset_dt(fieldset, mode, tmp_zarrfile):
# Assert that p.dt gets reset when a write_time is not a multiple of dt
# for p.dt=0.02 to reach outputdt=0.05 and endtime=0.1, the steps should be [0.02, 0.02, 0.01, 0.02, 0.02, 0.01], resulting in 6 kernel executions
filepath = tmpdir.join("pfile_reset_dt.zarr")

def Update_lon(particle, fieldset, time):
particle_dlon += 0.1 # noqa

pset = ParticleSet(fieldset, pclass=ptype[mode], lon=[0], lat=[0], lonlatdepth_dtype=np.float64)
ofile = pset.ParticleFile(name=filepath, outputdt=0.05)
ofile = pset.ParticleFile(name=tmp_zarrfile, outputdt=0.05)
pset.execute(pset.Kernel(Update_lon), endtime=0.12, dt=0.02, output_file=ofile)

assert np.allclose(pset.lon, 0.6)
Expand Down
7 changes: 3 additions & 4 deletions tests/test_particlesets.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,7 @@ def test_pset_create_list_with_customvariable(fieldset, mode):

@pytest.mark.parametrize("mode", ["scipy", "jit"])
@pytest.mark.parametrize("restart", [True, False])
def test_pset_create_fromparticlefile(fieldset, mode, restart, tmpdir):
filename = tmpdir.join("pset_fromparticlefile.zarr")
def test_pset_create_fromparticlefile(fieldset, mode, restart, tmp_zarrfile):
lon = np.linspace(0, 1, 10, dtype=np.float32)
lat = np.linspace(1, 0, 10, dtype=np.float32)

Expand All @@ -89,7 +88,7 @@ def test_pset_create_fromparticlefile(fieldset, mode, restart, tmpdir):
TestParticle = TestParticle.add_variable("p3", np.float64, to_write="once")

pset = ParticleSet(fieldset, lon=lon, lat=lat, depth=[4] * len(lon), pclass=TestParticle, p3=np.arange(len(lon)))
pfile = pset.ParticleFile(filename, outputdt=1)
pfile = pset.ParticleFile(tmp_zarrfile, outputdt=1)

def Kernel(particle, fieldset, time):
particle.p = 2.0
Expand All @@ -99,7 +98,7 @@ def Kernel(particle, fieldset, time):
pset.execute(Kernel, runtime=2, dt=1, output_file=pfile)

pset_new = ParticleSet.from_particlefile(
fieldset, pclass=TestParticle, filename=filename, restart=restart, repeatdt=1
fieldset, pclass=TestParticle, filename=tmp_zarrfile, restart=restart, repeatdt=1
)

for var in ["lon", "lat", "depth", "time", "p", "p2", "p3"]:
Expand Down
Loading

0 comments on commit 0d14640

Please sign in to comment.