287 cannot find variable z data in netcdf (#297)
* add three parameters to function read_netcdf_grid() (see the usage sketch below)

* add the three new parameters to the Raster constructor and the reconstruct_grid() function

* add a function to guess the data variable name

* add debug script

* rename realign_grid to _realign_grid; this function is meant to be used internally
michaelchin authored Oct 23, 2024
1 parent 9d8d4cd commit f795370
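A minimal usage sketch of the new keyword arguments follows; the file name and the dimension/variable names are hypothetical and stand in for any grid whose names are not auto-detected.

from gplately.grids import read_netcdf_grid

# Hypothetical grid whose dimension/variable names are not in the common-name lists.
grid_z, lons, lats = read_netcdf_grid(
    "my_custom_grid.nc",                 # hypothetical path
    return_grids=True,                   # also return the lon/lat coordinate arrays
    x_dimension_name="easting",          # hypothetical x dimension name
    y_dimension_name="northing",         # hypothetical y dimension name
    data_variable_name="elevation",      # skip the variable-name guessing entirely
)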
Showing 2 changed files with 212 additions and 41 deletions.
192 changes: 151 additions & 41 deletions gplately/grids.py
@@ -37,10 +37,12 @@
import math
import warnings
from multiprocessing import cpu_count
from typing import Union

import matplotlib.colors
import matplotlib.pyplot as plt
import numpy as np
import netCDF4
import pygplates
from cartopy.crs import PlateCarree as _PlateCarree
from cartopy.mpl.geoaxes import GeoAxes as _GeoAxes
@@ -126,7 +128,8 @@ def fill_raster(data, invalid=None):
return data[tuple(ind)]


def realign_grid(array, lons, lats):
def _realign_grid(array, lons, lats):
"""realigns grid to -180/180 and flips the array if the latitudinal coordinates are decreasing."""
mask_lons = lons > 180

# realign to -180/180
@@ -142,7 +145,29 @@ def realign_grid(array, lons, lats):
return array, lons, lats


def read_netcdf_grid(filename, return_grids=False, realign=False, resample=None, resize=None):
def _guess_data_variable_name(cdf: netCDF4.Dataset, x_name: str, y_name: str) -> Union[str, None]: # type: ignore
"""best effort to find out the data variable name"""
vars = cdf.variables.keys()
for var in vars:
dimensions = cdf.variables[var].dimensions
if len(dimensions) != 2: # only consider two-dimensional data
continue
else:
if dimensions[0] == y_name and dimensions[1] == x_name:
return var
return None
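For context, here is a minimal sketch of how this internal helper behaves on a tiny in-memory dataset; all names below are made up for illustration.

import netCDF4
import numpy as np

# Build an in-memory dataset: two 1-D coordinate variables plus one 2-D data variable.
with netCDF4.Dataset("guess_demo.nc", "w", diskless=True) as ds:
    ds.createDimension("lon", 4)
    ds.createDimension("lat", 3)
    ds.createVariable("lon", "f4", ("lon",))[:] = np.linspace(-180, 180, 4)
    ds.createVariable("lat", "f4", ("lat",))[:] = np.linspace(-90, 90, 3)
    ds.createVariable("elevation", "f4", ("lat", "lon"))[:] = np.zeros((3, 4))
    # Only "elevation" is two-dimensional with dimensions (lat, lon), so it is returned.
    print(_guess_data_variable_name(ds, "lon", "lat"))  # -> "elevation"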


def read_netcdf_grid(
filename,
return_grids=False,
realign=False,
resample=None,
resize=None,
x_dimension_name: str = "",
y_dimension_name: str = "",
data_variable_name: str = "",
):
"""Read a `netCDF` (.nc) grid from a given `filename` and return its data as a
`MaskedArray`.
@@ -178,6 +203,20 @@ def read_netcdf_grid(filename, return_grids=False, realign=False, resample=None,
If passed as `resample = (resX, resY)`, the given `netCDF` grid is resized
to the number of columns (resX) and rows (resY).
x_dimension_name : str, optional, default=""
If the grid file uses a common name for its x dimension, such as "x", "lon", "lons" or "longitude", you need not set this parameter.
Otherwise, provide the name of the x dimension here.
y_dimension_name : str, optional, default=""
If the grid file uses a common name for its y dimension, such as "y", "lat", "lats" or "latitude", you need not set this parameter.
Otherwise, provide the name of the y dimension here.
data_variable_name : str, optional, default=""
The program will try its best to determine the data variable name.
However, it is better to provide the data variable name explicitly;
otherwise, the program will guess, and the guess may or may not be correct.
Returns
-------
grid_z : MaskedArray
@@ -198,8 +237,6 @@ def find_label(keys, labels):
return label
return None

import netCDF4

# possible permutations of lon/lat/z
label_lon = ["lon", "lons", "longitude", "x", "east", "easting", "eastings"]
label_lat = ["lat", "lats", "latitude", "y", "north", "northing", "northings"]
@@ -227,35 +264,61 @@ def find_label(keys, labels):
keys = cdf.variables.keys()

# find the names of variables
key_z = find_label(keys, label_z)
key_lon = find_label(keys, label_lon)
key_lat = find_label(keys, label_lat)
if data_variable_name:
key_z = data_variable_name
else:
key_z = find_label(keys, label_z)
if x_dimension_name:
key_lon = x_dimension_name
else:
key_lon = find_label(keys, label_lon)
if y_dimension_name:
key_lat = y_dimension_name
else:
key_lat = find_label(keys, label_lat)

if key_lon is None or key_lat is None:
raise ValueError("Cannot find x,y or lon/lat coordinates in netcdf")
raise ValueError(
f"Cannot find x,y or lon/lat coordinates in netcdf. The dimensions in the file are {cdf.dimensions.keys()}"
)

if key_z is None:
key_z = _guess_data_variable_name(cdf, key_lon, key_lat)

if key_z is None:
raise ValueError("Cannot find z data in netcdf")
raise ValueError(
f"Cannot find z data in netcdf. The variables in the file are {cdf.variables.keys()}"
)

# extract data from cdf variables
# TODO: the dimensions of the data may not be (lat, lon). It is possible (but unlikely) that the dimensions are (lon, lat).
# Note that numpy.swapaxes() may be needed in that case.
if len(cdf[key_z].dimensions) != 2:
raise Exception(
f"The data in the netcdf file is not two-dimensional. This function can only handle two-dimensional data."
+ f"The dimensions in the file are {cdf[key_z].dimensions.keys()}"
)
cdf_grid = cdf[key_z][:]
cdf_lon = cdf[key_lon][:]
cdf_lat = cdf[key_lat][:]

# fill missing values
if hasattr(cdf[key_z], 'missing_value') and np.issubdtype(cdf_grid.dtype, np.floating):
if hasattr(cdf[key_z], "missing_value") and np.issubdtype(
cdf_grid.dtype, np.floating
):
fill_value = cdf[key_z].missing_value
cdf_grid[np.isclose(cdf_grid, fill_value, rtol=0.1)] = np.nan

# convert to boolean array
if np.issubdtype(cdf_grid.dtype, np.integer):
unique_grid = np.unique(cdf_grid)
if len(unique_grid) == 2:
if (unique_grid == [0,1]).all():
if (unique_grid == [0, 1]).all():
cdf_grid = cdf_grid.astype(bool)

if realign:
# realign longitudes to -180/180 dateline
cdf_grid_z, cdf_lon, cdf_lat = realign_grid(cdf_grid, cdf_lon, cdf_lat)
cdf_grid_z, cdf_lon, cdf_lat = _realign_grid(cdf_grid, cdf_lon, cdf_lat)
else:
cdf_grid_z = cdf_grid

@@ -323,9 +386,12 @@ def find_label(keys, labels):
return cdf_grid_z, cdf_lon, cdf_lat
else:
return cdf_grid_z

def write_netcdf_grid(filename, grid, extent="global", significant_digits=None, fill_value=np.nan):
""" Write geological data contained in a `grid` to a netCDF4 grid with a specified `filename`.


def write_netcdf_grid(
filename, grid, extent="global", significant_digits=None, fill_value=np.nan
):
"""Write geological data contained in a `grid` to a netCDF4 grid with a specified `filename`.
Notes
-----
Expand All @@ -350,9 +416,9 @@ def write_netcdf_grid(filename, grid, extent="global", significant_digits=None,
the data's latitudes, while the columns correspond to the data's longitudes.
extent : list, default=[-180,180,-90,90]
Four elements that specify the [min lon, max lon, min lat, max lat] to constrain the lat and lon
variables of the netCDF grid to. If no extents are supplied, full global extent `[-180, 180, -90, 90]`
is assumed.
Four elements that specify the [min lon, max lon, min lat, max lat] to constrain the lat and lon
variables of the netCDF grid to. If no extents are supplied, full global extent `[-180, 180, -90, 90]`
is assumed.
significant_digits : int
Applies lossy data compression up to a specified number of significant digits.
@@ -369,24 +435,24 @@ def write_netcdf_grid(filename, grid, extent="global", significant_digits=None,
import netCDF4
from gplately import __version__ as _version

if extent == 'global':
if extent == "global":
extent = [-180, 180, -90, 90]
else:
assert len(extent) == 4, "specify the [min lon, max lon, min lat, max lat]"

nrows, ncols = np.shape(grid)

lon_grid = np.linspace(extent[0], extent[1], ncols)
lat_grid = np.linspace(extent[2], extent[3], nrows)

data_kwds = {'compression': 'zlib', 'complevel': 6}
with netCDF4.Dataset(filename, 'w', driver=None) as cdf:
data_kwds = {"compression": "zlib", "complevel": 6}

with netCDF4.Dataset(filename, "w", driver=None) as cdf:
cdf.title = "Grid produced by gplately " + str(_version)
cdf.createDimension('lon', lon_grid.size)
cdf.createDimension('lat', lat_grid.size)
cdf_lon = cdf.createVariable('lon', lon_grid.dtype, ('lon',), **data_kwds)
cdf_lat = cdf.createVariable('lat', lat_grid.dtype, ('lat',), **data_kwds)
cdf.createDimension("lon", lon_grid.size)
cdf.createDimension("lat", lat_grid.size)
cdf_lon = cdf.createVariable("lon", lon_grid.dtype, ("lon",), **data_kwds)
cdf_lat = cdf.createVariable("lat", lat_grid.dtype, ("lat",), **data_kwds)
cdf_lon[:] = lon_grid
cdf_lat[:] = lat_grid

@@ -399,9 +465,9 @@ def write_netcdf_grid(filename, grid, extent="global", significant_digits=None,
cdf_lat.actual_range = [lat_grid[0], lat_grid[-1]]

# create container variable for CRS: lon/lat WGS84 datum
crso = cdf.createVariable('crs','i4')
crso.long_name = 'Lon/Lat Coords in WGS84'
crso.grid_mapping_name='latitude_longitude'
crso = cdf.createVariable("crs", "i4")
crso.long_name = "Lon/Lat Coords in WGS84"
crso.grid_mapping_name = "latitude_longitude"
crso.longitude_of_prime_meridian = 0.0
crso.semi_major_axis = 6378137.0
crso.inverse_flattening = 298.257223563
@@ -410,37 +476,40 @@ def write_netcdf_grid(filename, grid, extent="global", significant_digits=None,
# add more keyword arguments for quantizing data
if significant_digits:
# significant_digits needs to be >= 2 so that NaNs are preserved
data_kwds['significant_digits'] = max(2, int(significant_digits))
data_kwds['quantize_mode'] = 'GranularBitRound'
data_kwds["significant_digits"] = max(2, int(significant_digits))
data_kwds["quantize_mode"] = "GranularBitRound"

# boolean arrays need to be converted to integers
# no such thing as a mask on a boolean array
if grid.dtype is np.dtype(bool):
grid = grid.astype('i1')
grid = grid.astype("i1")
fill_value = None

cdf_data = cdf.createVariable('z', grid.dtype, ('lat','lon'), **data_kwds)
cdf_data = cdf.createVariable("z", grid.dtype, ("lat", "lon"), **data_kwds)

# netCDF4 uses the missing_value attribute as the default _FillValue
# without this, _FillValue defaults to 9.969209968386869e+36
if fill_value is not None:
cdf_data.missing_value = fill_value
grid_mask = grid != fill_value

cdf_data.actual_range = [np.nanmin(grid[grid_mask]), np.nanmax(grid[grid_mask])]
cdf_data.actual_range = [
np.nanmin(grid[grid_mask]),
np.nanmax(grid[grid_mask]),
]

else:
# ensure min and max z values are properly registered
cdf_data.actual_range = [np.nanmin(grid), np.nanmax(grid)]

cdf_data.standard_name = 'z'
cdf_data.standard_name = "z"

# cdf_data.add_offset = 0.0
cdf_data.grid_mapping = 'crs'
cdf_data.grid_mapping = "crs"
# cdf_data.set_auto_maskandscale(False)

# write data
cdf_data[:,:] = grid
cdf_data[:, :] = grid
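A minimal sketch of writing a global grid and reading it back, assuming a hypothetical file name; the grid written by this function is stored as variable "z" on ("lat", "lon"), so no extra names are needed when reading.

import numpy as np
from gplately.grids import read_netcdf_grid, write_netcdf_grid

# A small synthetic global grid: rows are latitudes, columns are longitudes.
grid = np.linspace(0.0, 1.0, 90 * 180).reshape(90, 180)

write_netcdf_grid("demo_grid.nc", grid, extent="global")  # hypothetical file name
roundtrip = read_netcdf_grid("demo_grid.nc")
print(roundtrip.shape)  # (90, 180)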


class RegularGridInterpolator(_RGI):
@@ -804,6 +873,9 @@ def reconstruct_grid(
fill_value=None,
threads=1,
anchor_plate_id=0,
x_dimension_name: str = "",
y_dimension_name: str = "",
data_variable_name: str = "",
):
"""Reconstruct a gridded dataset to a given reconstruction time.
@@ -845,6 +917,16 @@ def reconstruct_grid(
Number of threads to use for certain computationally heavy routines.
anchor_plate_id : int, default 0
ID of the anchored plate.
x_dimension_name : str, optional, default=""
If the grid file uses a common name for its x dimension, such as "x", "lon", "lons" or "longitude", you need not set this parameter.
Otherwise, provide the name of the x dimension here.
y_dimension_name : str, optional, default=""
If the grid file uses a common name for its y dimension, such as "y", "lat", "lats" or "latitude", you need not set this parameter.
Otherwise, provide the name of the y dimension here.
data_variable_name : str, optional, default=""
The program will try its best to determine the data variable name.
However, it is better to provide the data variable name explicitly;
otherwise, the program will guess, and the guess may or may not be correct.
Returns
-------
@@ -868,7 +950,14 @@ def reconstruct_grid(
(0, 0, 0, 0).
"""
try:
grid = np.array(read_netcdf_grid(grid)) # load grid data from file
grid = np.array(
read_netcdf_grid(
grid,
x_dimension_name=x_dimension_name,
y_dimension_name=y_dimension_name,
data_variable_name=data_variable_name,
)
) # load grid data from file
except Exception:
grid = np.array(grid) # copy grid data to array
if to_time == from_time:
@@ -1561,6 +1650,9 @@ def __init__(
resize=None,
time=0.0,
origin=None,
x_dimension_name: str = "",
y_dimension_name: str = "",
data_variable_name: str = "",
**kwargs,
):
"""Constructs all necessary attributes for the raster object.
@@ -1600,6 +1692,19 @@ def __init__(
When `data` is an array, use this parameter to specify the origin
(upper left or lower left) of the data (overriding `extent`).
x_dimension_name : str, optional, default=""
If the grid file uses a common name for its x dimension, such as "x", "lon", "lons" or "longitude", you need not set this parameter.
Otherwise, provide the name of the x dimension here.
y_dimension_name : str, optional, default=""
If the grid file uses a common name for its y dimension, such as "y", "lat", "lats" or "latitude", you need not set this parameter.
Otherwise, provide the name of the y dimension here.
data_variable_name : str, optional, default=""
The program will try its best to determine the data variable name.
However, it is better to provide the data variable name explicitly;
otherwise, the program will guess, and the guess may or may not be correct.
**kwargs
Handle deprecated arguments such as `PlateReconstruction_object`,
`filename`, and `array`.
@@ -1663,6 +1768,9 @@ def __init__(
realign=realign,
resample=resample,
resize=resize,
x_dimension_name=x_dimension_name,
y_dimension_name=y_dimension_name,
data_variable_name=data_variable_name,
)
self._lons = lons
self._lats = lats
Expand All @@ -1677,7 +1785,7 @@ def __init__(
self._lats = np.linspace(extent[2], extent[3], self.data.shape[0])
if realign:
# realign to -180,180 and flip grid
self._data, self._lons, self._lats = realign_grid(
self._data, self._lons, self._lats = _realign_grid(
self._data, self._lons, self._lats
)

@@ -2032,7 +2140,9 @@ def fill_NaNs(self, inplace=False, return_array=False):
def save_to_netcdf4(self, filename, significant_digits=None, fill_value=np.nan):
"""Saves the grid attributed to the `Raster` object to the given `filename` (including
the ".nc" extension) in netCDF4 format."""
write_netcdf_grid(str(filename), self.data, self.extent, significant_digits, fill_value)
write_netcdf_grid(
str(filename), self.data, self.extent, significant_digits, fill_value
)
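A minimal sketch of how the new keyword arguments flow from the Raster constructor into read_netcdf_grid; the file name and the dimension/variable names are hypothetical.

import gplately

# Hypothetical netCDF grid whose dimension/variable names are not auto-detected.
raster = gplately.Raster(
    data="my_custom_grid.nc",
    realign=True,                        # realign longitudes to -180/180
    x_dimension_name="easting",          # hypothetical x dimension name
    y_dimension_name="northing",         # hypothetical y dimension name
    data_variable_name="elevation",      # hypothetical data variable name
)
raster.save_to_netcdf4("my_custom_grid_copy.nc")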

def reconstruct(
self,