Skip to content

Commit

Permalink
Fix creation of mask in grid_create functions
Browse files Browse the repository at this point in the history
The previous logic (involving the use of `.any()`) was causing an error
in recent versions of NumPy. On inspection, the use of `.any()` looked
wrong, and not what was intended. In addition, the use of `.data` in
other functions caused errors; I think these functions that used `.data`
were never being called, so these errors weren't showing up. This commit
fixes the logic to what I think was intended.

This fix caused a failure in `test_field_regrid_zeroregion`. Bob Oehmke
and I looked at the test together and felt that the test was incorrect
and should actually be asserting that values are -100 where the mask is
0 (rather than having 0 values there), so I have made this change as
well. This was passing previously because the mask was never 0 due to
the above bug.

Resolves #329
  • Loading branch information
billsacks committed Dec 11, 2024
1 parent 523ca69 commit 12b248a
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 12 deletions.
5 changes: 3 additions & 2 deletions src/addon/esmpy/src/esmpy/test/test_api/test_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,11 +836,12 @@ def test_field_regrid_zeroregion(self):
dst_mask_values=np.atleast_1d(np.array([0])))
dstfield = rh(srcfield, dstfield, zero_region=Region.SELECT)

# validate that the masked values were not zeroed out
# validate that the masked values were not zeroed out (zero_region=Region.SELECT
# should leave the masked values at their original values)
for i in range(dstfield.data.shape[x]):
for j in range(dstfield.data.shape[y]):
if dstfield.grid.mask[StaggerLoc.CENTER][i, j] == 0:
assert(dstfield[i, j] == 0)
assert(dstfield.data[i, j] == -100)

@pytest.mark.skipif(_ESMF_PIO==False, reason="PIO required in ESMF build")
@pytest.mark.skipif(_ESMF_NETCDF==False, reason="NetCDF required in ESMF build")
Expand Down
20 changes: 10 additions & 10 deletions src/addon/esmpy/src/esmpy/util/grid_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ def grid_create_from_coordinates(xcoords, ycoords, xcorners=False, ycorners=Fals
if domask:
mask = grid.add_item(esmpy.GridItem.MASK)
mask[:] = 1
mask[np.where((1.75 <= gridXCenter.any() < 2.25) &
(1.75 <= gridYCenter.any() < 2.25))] = 0
mask[np.where((1.75 <= gridXCenter) & (gridXCenter < 2.25) &
(1.75 <= gridYCenter) & (gridYCenter < 2.25))] = 0

# add arbitrary areas values
if doarea:
Expand Down Expand Up @@ -180,8 +180,8 @@ def grid_create_from_coordinates_periodic(longitudes, latitudes, lon_corners=Fal
if domask:
mask = grid.add_item(esmpy.GridItem.MASK)
mask[:] = 1
mask[np.where((1.75 <= gridXCenter.any() < 2.25) &
(1.75 <= gridYCenter.any() < 2.25))] = 0
mask[np.where((1.75 <= gridXCenter) & (gridXCenter < 2.25) &
(1.75 <= gridYCenter) & (gridYCenter < 2.25))] = 0

return grid

Expand Down Expand Up @@ -286,9 +286,9 @@ def grid_create_from_coordinates_3d(xcoords, ycoords, zcoords, xcorners=False, y
if domask:
mask = grid.add_item(esmpy.GridItem.MASK)
mask[:] = 1
mask[np.where((1.75 < gridXCenter.data < 2.25) &
(1.75 < gridYCenter.data < 2.25) &
(1.75 < gridZCenter.data < 2.25))] = 0
mask[np.where((1.75 <= gridXCenter) & (gridXCenter < 2.25) &
(1.75 <= gridYCenter) & (gridYCenter < 2.25) &
(1.75 <= gridZCenter) & (gridZCenter < 2.25))] = 0

# add arbitrary areas values
if doarea:
Expand Down Expand Up @@ -394,9 +394,9 @@ def grid_create_from_coordinates_periodic_3d(longitudes, latitudes, heights,
if domask:
mask = grid.add_item(esmpy.GridItem.MASK)
mask[:] = 1
mask[np.where((1.75 <= gridXCenter.data < 2.25) &
(1.75 <= gridYCenter.data < 2.25) &
(1.75 <= gridZCenter.data < 2.25))] = 0
mask[np.where((1.75 <= gridXCenter) & (gridXCenter < 2.25) &
(1.75 <= gridYCenter) & (gridYCenter < 2.25) &
(1.75 <= gridZCenter) & (gridZCenter < 2.25))] = 0

return grid

Expand Down

0 comments on commit 12b248a

Please sign in to comment.