Skip to content

Commit

Permalink
Make tests dynamically respond to xesmf presence in environment (chec…
Browse files Browse the repository at this point in the history
…king for ImportError in the right places if no xesmf)
  • Loading branch information
ks905383 committed Feb 13, 2024
1 parent 405d9be commit 83a3fc9
Showing 1 changed file with 37 additions and 27 deletions.
64 changes: 37 additions & 27 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,24 +51,25 @@ def test_process_weights_basic():
xr.testing.assert_allclose(ds_compare,ds_t)
# (weights_info isn't currently used by anything)

if _has_xesmf:
def test_process_weights_regrid_weights():
# Now, test with a weights array twice the resolution as the
# ds, so it needs to be regridded
ds = xr.Dataset(coords={'lat':(['lat'],np.array([0,1])),
'lon':(['lon'],np.array([0,1])),
})

# Synthetic weights grid, with double the resolution, and shifted
# by a half degree. Should regrid to the same weights grid as above
weights = xr.DataArray(data=np.array([[-0.5,0.5,0.5,1.5],
[0.5,-0.5,1.5,0.5],
[1.5,2.5,2.5,3.5],
[2.5,1.5,3.5,2.5]]),
dims=['lat','lon'],
coords=[np.array([-0.25,0.25,0.75,1.25]),
np.array([-0.25,0.25,0.75,1.25])])
def test_process_weights_regrid_weights():
# Now, test with a weights array twice the resolution as the
# ds, so it needs to be regridded
ds = xr.Dataset(coords={'lat':(['lat'],np.array([0,1])),
'lon':(['lon'],np.array([0,1])),
})

# Synthetic weights grid, with double the resolution, and shifted
# by a half degree. Should regrid to the same weights grid as above
weights = xr.DataArray(data=np.array([[-0.5,0.5,0.5,1.5],
[0.5,-0.5,1.5,0.5],
[1.5,2.5,2.5,3.5],
[2.5,1.5,3.5,2.5]]),
dims=['lat','lon'],
coords=[np.array([-0.25,0.25,0.75,1.25]),
np.array([-0.25,0.25,0.75,1.25])])

if _has_xesmf:
ds_t,weights_info = process_weights(ds,weights=weights)

ds_compare = xr.Dataset({'weights':(('lat','lon'),np.array([[0,1],[2,3]]))},
Expand All @@ -78,6 +79,10 @@ def test_process_weights_regrid_weights():

# Check if weights were correctly added to ds
xr.testing.assert_allclose(ds_compare,ds_t)
else:
# Should raise ImportError in the no-xesmf environment
with pytest.raises(ImportError):
ds_t,weights_info = process_weights(ds,weights=weights)

def test_process_weights_close_weights():
# Make sure weights that are within `np.allclose` but not exactly
Expand Down Expand Up @@ -142,7 +147,7 @@ def test_create_raster_polygons_with_weights():
'lon':(['lon'],np.array([0,1])),
'bnds':(['bnds'],np.array([0,1]))})

# Synethetic weights grid
# Synthetic weights grid that requires regridding
weights = xr.DataArray(data=np.array([[-0.5,0.5,0.5,1.5],
[0.5,-0.5,1.5,0.5],
[1.5,2.5,2.5,3.5],
Expand All @@ -151,19 +156,24 @@ def test_create_raster_polygons_with_weights():
coords=[np.array([-0.25,0.25,0.75,1.25]),
np.array([-0.25,0.25,0.75,1.25])])

pix_agg = create_raster_polygons(ds,weights=weights)
if _has_xesmf:
pix_agg = create_raster_polygons(ds,weights=weights)

compare_series = pd.Series(data=[np.array(v) for v in [0.,1.,2.,3.]],
index=[0,1,2,3],
name='weights')
compare_series = pd.Series(data=[np.array(v) for v in [0.,1.,2.,3.]],
index=[0,1,2,3],
name='weights')


# There's an issue here in pd.testing.assert_series_equal...
#pd.testing.assert_series_equal(pix_agg['gdf_pixels'].weights,
# compare_series)
#np.testing.assert_allclose(compare_series,pix_agg['gdf_pixels'].weights)
#assert np.allclose(compare_series,pix_agg['gdf_pixels'].weights)
assert np.allclose([float(v) for v in compare_series],[float(v) for v in pix_agg['gdf_pixels'].weights])
# There's an issue here in pd.testing.assert_series_equal...
#pd.testing.assert_series_equal(pix_agg['gdf_pixels'].weights,
# compare_series)
#np.testing.assert_allclose(compare_series,pix_agg['gdf_pixels'].weights)
#assert np.allclose(compare_series,pix_agg['gdf_pixels'].weights)
assert np.allclose([float(v) for v in compare_series],[float(v) for v in pix_agg['gdf_pixels'].weights])
else:
# Should raise ImportError in the no-xesmf environment
with pytest.raises(ImportError):
pix_agg = create_raster_polygons(ds,weights=weights)

def test_create_raster_polygons_at180():
# Make sure raster polygons are correctly built at the 180/-180
Expand Down

0 comments on commit 83a3fc9

Please sign in to comment.