diff --git a/gplately/grids.py b/gplately/grids.py index e83686bc..a2ea4595 100644 --- a/gplately/grids.py +++ b/gplately/grids.py @@ -37,10 +37,12 @@ import math import warnings from multiprocessing import cpu_count +from typing import Union import matplotlib.colors import matplotlib.pyplot as plt import numpy as np +import netCDF4 import pygplates from cartopy.crs import PlateCarree as _PlateCarree from cartopy.mpl.geoaxes import GeoAxes as _GeoAxes @@ -126,7 +128,8 @@ def fill_raster(data, invalid=None): return data[tuple(ind)] -def realign_grid(array, lons, lats): +def _realign_grid(array, lons, lats): + """realigns grid to -180/180 and flips the array if the latitudinal coordinates are decreasing.""" mask_lons = lons > 180 # realign to -180/180 @@ -142,7 +145,29 @@ def realign_grid(array, lons, lats): return array, lons, lats -def read_netcdf_grid(filename, return_grids=False, realign=False, resample=None, resize=None): +def _guess_data_variable_name(cdf: netCDF4.Dataset, x_name: str, y_name: str) -> Union[str, None]: # type: ignore + """best effort to find out the data variable name""" + vars = cdf.variables.keys() + for var in vars: + dimensions = cdf.variables[var].dimensions + if len(dimensions) != 2: # only consider two-dimensional data + continue + else: + if dimensions[0] == y_name and dimensions[1] == x_name: + return var + return None + + +def read_netcdf_grid( + filename, + return_grids=False, + realign=False, + resample=None, + resize=None, + x_dimension_name: str = "", + y_dimension_name: str = "", + data_variable_name: str = "", +): """Read a `netCDF` (.nc) grid from a given `filename` and return its data as a `MaskedArray`. @@ -178,6 +203,20 @@ def read_netcdf_grid(filename, return_grids=False, realign=False, resample=None, If passed as `resample = (resX, resY)`, the given `netCDF` grid is resized to the number of columns (resX) and rows (resY). + x_dimension_name : str, optional, default="" + If the grid file uses comman names, such as "x", "lon", "lons" or "longitude", you need not set this parameter. + Otherwise, you need to tell us what the x dimension name is. + + y_dimension_name : str, optional, default="" + If the grid file uses comman names, such as "y", "lat", "lats" or "latitude", you need not set this parameter. + Otherwise, you need to tell us what the y dimension name is. + + data_variable_name : str, optional, default="" + The program will try its best to determine the data variable name. + However, it would be better if you could tell us what the data variable name is. + Otherwise, the program will guess. The result may/may not be correct. + + Returns ------- grid_z : MaskedArray @@ -198,8 +237,6 @@ def find_label(keys, labels): return label return None - import netCDF4 - # possible permutations of lon/lat/z label_lon = ["lon", "lons", "longitude", "x", "east", "easting", "eastings"] label_lat = ["lat", "lats", "latitude", "y", "north", "northing", "northings"] @@ -227,22 +264,48 @@ def find_label(keys, labels): keys = cdf.variables.keys() # find the names of variables - key_z = find_label(keys, label_z) - key_lon = find_label(keys, label_lon) - key_lat = find_label(keys, label_lat) + if data_variable_name: + key_z = data_variable_name + else: + key_z = find_label(keys, label_z) + if x_dimension_name: + key_lon = x_dimension_name + else: + key_lon = find_label(keys, label_lon) + if y_dimension_name: + key_lat = y_dimension_name + else: + key_lat = find_label(keys, label_lat) if key_lon is None or key_lat is None: - raise ValueError("Cannot find x,y or lon/lat coordinates in netcdf") + raise ValueError( + f"Cannot find x,y or lon/lat coordinates in netcdf. The dimensions in the file are {cdf.dimensions.keys()}" + ) + + if key_z is None: + key_z = _guess_data_variable_name(cdf, key_lon, key_lat) + if key_z is None: - raise ValueError("Cannot find z data in netcdf") + raise ValueError( + f"Cannot find z data in netcdf. The variables in the file are {cdf.variables.keys()}" + ) # extract data from cdf variables + # TODO: the dimensions of data may not be (lat, lon). It is possible(but unlikely?) that the dimensions are(lon, lat). + # just note you may need numpy.swapaxes() here. + if len(cdf[key_z].dimensions) != 2: + raise Exception( + f"The data in the netcdf file is not two-dimensional. This function can only handle two-dimensional data." + + f"The dimensions in the file are {cdf[key_z].dimensions.keys()}" + ) cdf_grid = cdf[key_z][:] cdf_lon = cdf[key_lon][:] cdf_lat = cdf[key_lat][:] # fill missing values - if hasattr(cdf[key_z], 'missing_value') and np.issubdtype(cdf_grid.dtype, np.floating): + if hasattr(cdf[key_z], "missing_value") and np.issubdtype( + cdf_grid.dtype, np.floating + ): fill_value = cdf[key_z].missing_value cdf_grid[np.isclose(cdf_grid, fill_value, rtol=0.1)] = np.nan @@ -250,12 +313,12 @@ def find_label(keys, labels): if np.issubdtype(cdf_grid.dtype, np.integer): unique_grid = np.unique(cdf_grid) if len(unique_grid) == 2: - if (unique_grid == [0,1]).all(): + if (unique_grid == [0, 1]).all(): cdf_grid = cdf_grid.astype(bool) if realign: # realign longitudes to -180/180 dateline - cdf_grid_z, cdf_lon, cdf_lat = realign_grid(cdf_grid, cdf_lon, cdf_lat) + cdf_grid_z, cdf_lon, cdf_lat = _realign_grid(cdf_grid, cdf_lon, cdf_lat) else: cdf_grid_z = cdf_grid @@ -323,9 +386,12 @@ def find_label(keys, labels): return cdf_grid_z, cdf_lon, cdf_lat else: return cdf_grid_z - -def write_netcdf_grid(filename, grid, extent="global", significant_digits=None, fill_value=np.nan): - """ Write geological data contained in a `grid` to a netCDF4 grid with a specified `filename`. + + +def write_netcdf_grid( + filename, grid, extent="global", significant_digits=None, fill_value=np.nan +): + """Write geological data contained in a `grid` to a netCDF4 grid with a specified `filename`. Notes ----- @@ -350,9 +416,9 @@ def write_netcdf_grid(filename, grid, extent="global", significant_digits=None, the data's latitudes, while the columns correspond to the data's longitudes. extent : list, default=[-180,180,-90,90] - Four elements that specify the [min lon, max lon, min lat, max lat] to constrain the lat and lon - variables of the netCDF grid to. If no extents are supplied, full global extent `[-180, 180, -90, 90]` - is assumed. + Four elements that specify the [min lon, max lon, min lat, max lat] to constrain the lat and lon + variables of the netCDF grid to. If no extents are supplied, full global extent `[-180, 180, -90, 90]` + is assumed. significant_digits : int Applies lossy data compression up to a specified number of significant digits. @@ -369,24 +435,24 @@ def write_netcdf_grid(filename, grid, extent="global", significant_digits=None, import netCDF4 from gplately import __version__ as _version - if extent == 'global': + if extent == "global": extent = [-180, 180, -90, 90] else: assert len(extent) == 4, "specify the [min lon, max lon, min lat, max lat]" - + nrows, ncols = np.shape(grid) lon_grid = np.linspace(extent[0], extent[1], ncols) lat_grid = np.linspace(extent[2], extent[3], nrows) - data_kwds = {'compression': 'zlib', 'complevel': 6} - - with netCDF4.Dataset(filename, 'w', driver=None) as cdf: + data_kwds = {"compression": "zlib", "complevel": 6} + + with netCDF4.Dataset(filename, "w", driver=None) as cdf: cdf.title = "Grid produced by gplately " + str(_version) - cdf.createDimension('lon', lon_grid.size) - cdf.createDimension('lat', lat_grid.size) - cdf_lon = cdf.createVariable('lon', lon_grid.dtype, ('lon',), **data_kwds) - cdf_lat = cdf.createVariable('lat', lat_grid.dtype, ('lat',), **data_kwds) + cdf.createDimension("lon", lon_grid.size) + cdf.createDimension("lat", lat_grid.size) + cdf_lon = cdf.createVariable("lon", lon_grid.dtype, ("lon",), **data_kwds) + cdf_lat = cdf.createVariable("lat", lat_grid.dtype, ("lat",), **data_kwds) cdf_lon[:] = lon_grid cdf_lat[:] = lat_grid @@ -399,9 +465,9 @@ def write_netcdf_grid(filename, grid, extent="global", significant_digits=None, cdf_lat.actual_range = [lat_grid[0], lat_grid[-1]] # create container variable for CRS: lon/lat WGS84 datum - crso = cdf.createVariable('crs','i4') - crso.long_name = 'Lon/Lat Coords in WGS84' - crso.grid_mapping_name='latitude_longitude' + crso = cdf.createVariable("crs", "i4") + crso.long_name = "Lon/Lat Coords in WGS84" + crso.grid_mapping_name = "latitude_longitude" crso.longitude_of_prime_meridian = 0.0 crso.semi_major_axis = 6378137.0 crso.inverse_flattening = 298.257223563 @@ -410,16 +476,16 @@ def write_netcdf_grid(filename, grid, extent="global", significant_digits=None, # add more keyword arguments for quantizing data if significant_digits: # significant_digits needs to be >= 2 so that NaNs are preserved - data_kwds['significant_digits'] = max(2, int(significant_digits)) - data_kwds['quantize_mode'] = 'GranularBitRound' + data_kwds["significant_digits"] = max(2, int(significant_digits)) + data_kwds["quantize_mode"] = "GranularBitRound" # boolean arrays need to be converted to integers # no such thing as a mask on a boolean array if grid.dtype is np.dtype(bool): - grid = grid.astype('i1') + grid = grid.astype("i1") fill_value = None - cdf_data = cdf.createVariable('z', grid.dtype, ('lat','lon'), **data_kwds) + cdf_data = cdf.createVariable("z", grid.dtype, ("lat", "lon"), **data_kwds) # netCDF4 uses the missing_value attribute as the default _FillValue # without this, _FillValue defaults to 9.969209968386869e+36 @@ -427,20 +493,23 @@ def write_netcdf_grid(filename, grid, extent="global", significant_digits=None, cdf_data.missing_value = fill_value grid_mask = grid != fill_value - cdf_data.actual_range = [np.nanmin(grid[grid_mask]), np.nanmax(grid[grid_mask])] + cdf_data.actual_range = [ + np.nanmin(grid[grid_mask]), + np.nanmax(grid[grid_mask]), + ] else: # ensure min and max z values are properly registered cdf_data.actual_range = [np.nanmin(grid), np.nanmax(grid)] - cdf_data.standard_name = 'z' + cdf_data.standard_name = "z" # cdf_data.add_offset = 0.0 - cdf_data.grid_mapping = 'crs' + cdf_data.grid_mapping = "crs" # cdf_data.set_auto_maskandscale(False) # write data - cdf_data[:,:] = grid + cdf_data[:, :] = grid class RegularGridInterpolator(_RGI): @@ -804,6 +873,9 @@ def reconstruct_grid( fill_value=None, threads=1, anchor_plate_id=0, + x_dimension_name: str = "", + y_dimension_name: str = "", + data_variable_name: str = "", ): """Reconstruct a gridded dataset to a given reconstruction time. @@ -845,6 +917,16 @@ def reconstruct_grid( Number of threads to use for certain computationally heavy routines. anchor_plate_id : int, default 0 ID of the anchored plate. + x_dimension_name : str, optional, default="" + If the grid file uses comman names, such as "x", "lon", "lons" or "longitude", you need not set this parameter. + Otherwise, you need to tell us what the x dimension name is. + y_dimension_name : str, optional, default="" + If the grid file uses comman names, such as "y", "lat", "lats" or "latitude", you need not set this parameter. + Otherwise, you need to tell us what the y dimension name is. + data_variable_name : str, optional, default="" + The program will try its best to determine the data variable name. + However, it would be better if you could tell us what the data variable name is. + Otherwise, the program will guess. The result may/may not be correct. Returns ------- @@ -868,7 +950,14 @@ def reconstruct_grid( (0, 0, 0, 0). """ try: - grid = np.array(read_netcdf_grid(grid)) # load grid data from file + grid = np.array( + read_netcdf_grid( + grid, + x_dimension_name=x_dimension_name, + y_dimension_name=y_dimension_name, + data_variable_name=data_variable_name, + ) + ) # load grid data from file except Exception: grid = np.array(grid) # copy grid data to array if to_time == from_time: @@ -1561,6 +1650,9 @@ def __init__( resize=None, time=0.0, origin=None, + x_dimension_name: str = "", + y_dimension_name: str = "", + data_variable_name: str = "", **kwargs, ): """Constructs all necessary attributes for the raster object. @@ -1600,6 +1692,19 @@ def __init__( When `data` is an array, use this parameter to specify the origin (upper left or lower left) of the data (overriding `extent`). + x_dimension_name : str, optional, default="" + If the grid file uses comman names, such as "x", "lon", "lons" or "longitude", you need not set this parameter. + Otherwise, you need to tell us what the x dimension name is. + + y_dimension_name : str, optional, default="" + If the grid file uses comman names, such as "y", "lat", "lats" or "latitude", you need not set this parameter. + Otherwise, you need to tell us what the y dimension name is. + + data_variable_name : str, optional, default="" + The program will try its best to determine the data variable name. + However, it would be better if you could tell us what the data variable name is. + Otherwise, the program will guess. The result may/may not be correct. + **kwargs Handle deprecated arguments such as `PlateReconstruction_object`, `filename`, and `array`. @@ -1663,6 +1768,9 @@ def __init__( realign=realign, resample=resample, resize=resize, + x_dimension_name=x_dimension_name, + y_dimension_name=y_dimension_name, + data_variable_name=data_variable_name, ) self._lons = lons self._lats = lats @@ -1677,7 +1785,7 @@ def __init__( self._lats = np.linspace(extent[2], extent[3], self.data.shape[0]) if realign: # realign to -180,180 and flip grid - self._data, self._lons, self._lats = realign_grid( + self._data, self._lons, self._lats = _realign_grid( self._data, self._lons, self._lats ) @@ -2032,7 +2140,9 @@ def fill_NaNs(self, inplace=False, return_array=False): def save_to_netcdf4(self, filename, significant_digits=None, fill_value=np.nan): """Saves the grid attributed to the `Raster` object to the given `filename` (including the ".nc" extension) in netCDF4 format.""" - write_netcdf_grid(str(filename), self.data, self.extent, significant_digits, fill_value) + write_netcdf_grid( + str(filename), self.data, self.extent, significant_digits, fill_value + ) def reconstruct( self, diff --git a/tests-dir/debug_utils/github_issue_287.py b/tests-dir/debug_utils/github_issue_287.py new file mode 100644 index 00000000..c9230064 --- /dev/null +++ b/tests-dir/debug_utils/github_issue_287.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 + +# this debugging script was created from NW's Jupyter notebook. Thank NW for providing us the code. + +import gplately +from pathlib import Path +import os +import requests +from plate_model_manager import PlateModelManager, PlateModel +import xarray as xr + +workdir = "debug-folder-github-issue-287" +Path(workdir).mkdir(parents=True, exist_ok=True) + +miocene_file = os.path.join(workdir, "miocene_topo_pollard_antscape_dolan_0.5x0.5.nc") + +if not os.path.isfile(miocene_file): + r = requests.get( + "https://repo.gplates.org/webdav/gplately-test-data/issue-287/miocene_topo_pollard_antscape_dolan_0.5x0.5.nc", + allow_redirects=True, + ) + open(miocene_file, "wb").write(r.content) + +model_name = "merdith2021" +try: + plate_model = PlateModelManager().get_model(model_name, data_dir=workdir) +except: + plate_model = PlateModel(model_name, data_dir=workdir, readonly=True) + +if not plate_model: + raise Exception(f"Unable to get model({model_name})") + +rotation_files = plate_model.get_rotation_model() +topology_files = plate_model.get_topologies() +continent_files = plate_model.get_layer("ContinentalPolygons") + +r_model = gplately.PlateReconstruction( + rotation_model=rotation_files, + topology_features=topology_files, + static_polygons=plate_model.get_layer("StaticPolygons"), +) + +raster = gplately.Raster( + data=miocene_file, time=15, plate_reconstruction=r_model, realign=True +) +reconstructed_raster = raster.reconstruct( + 0, threads=4, partitioning_features=plate_model.get_layer("StaticPolygons") +) +outfile = os.path.join(workdir, "reconstructed_grid.nc") +gplately.grids.write_netcdf_grid( + filename=outfile, + grid=reconstructed_raster._data, + extent=[ + reconstructed_raster._lons.min(), + reconstructed_raster._lons.max(), + reconstructed_raster._lats.min(), + reconstructed_raster._lats.max(), + ], +) +# mio_nc2 = xr.open_dataset(outfile) +# mio_nc2.z.plot()