Skip to content

Commit

Permalink
Wrapper to match grid_labels (#161)
Browse files Browse the repository at this point in the history
* First tests working

* Added whats new

* Add xesmf to ci envs

* Dont use pip for xesmf. Bad idea

* added docs

* some docs typos

* some minor coverage fix

* fix doc headings

* fix doc headings again

* Delete Untitled.ipynb

* Update whats-new.rst

* Clean up
  • Loading branch information
jbusecke authored Jul 9, 2021
1 parent a9ca67d commit 446935d
Show file tree
Hide file tree
Showing 8 changed files with 2,674 additions and 57 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ target/
.ipynb_checkpoints

**/dask-worker-space/
mydask.png

.vscode
.mypy_cache
Expand Down
3 changes: 2 additions & 1 deletion ci/environment-upstream-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ dependencies:
- cftime
- dask
- pip
- cartopy
- cartopy #installing this without conda is a nightmare, so ill leave it here
- xesmf # same here
- pip:
- codecov
- pytest-cov
Expand Down
1 change: 1 addition & 0 deletions ci/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ dependencies:
- xgcm
- cftime
- regionmask
- xesmf
- xarrayutils
- pytest-cov
- pytest-xdist
Expand Down
200 changes: 200 additions & 0 deletions cmip6_preprocessing/postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@
from cmip6_preprocessing.utils import _key_from_attrs, cmip6_dataset_id


try:
import xesmf
except ImportError:
xesmf = None


# define the attrs that are needed to get an 'exact' match
exact_attrs = [
"source_id",
Expand Down Expand Up @@ -217,6 +223,200 @@ def concat_members(
)


### Matching wrapper specific to combining grid labels via interpolation with xesmf
def requires_xesmf(func):
@functools.wraps(func)
def wrapper_requires_xesmf(*args, **kwargs):

# Check if xesmf is installed
if xesmf is None:
raise ValueError(
"This function needs the optional package xesmf. Please install with `conda install -c conda-forge xesmf`."
)
return func(*args, **kwargs)

return wrapper_requires_xesmf


def _pick_grid_label(ds_list, preferred_grid_label):
"""From a list of datasets this picks the dataset with `ds.attrs['grid_label']==preferred_label`
if available"""
# TODO: Implement an option to pick the highest/lowest resolution if grid label dows not match.
matches = [ds for ds in ds_list if ds.attrs["grid_label"] == preferred_grid_label]
if len(matches) > 0:
return matches[0]
else:
return ds_list[0]


def _drop_duplicate_grid_labels(ddict, preferred_grid_label):
"""Eliminate multiple grid labels for the same dataset, by dropping multiples
with a preference for `preferred_grid_label`"""
match_attrs = [ma for ma in exact_attrs if ma not in ["version", "grid_label"]] + [
"variable_id"
]
return combine_datasets(
ddict,
_pick_grid_label,
combine_func_args=(preferred_grid_label,),
match_attrs=match_attrs,
)


def _clean_regridder(ds_source, ds_target, method, **xesmf_kwargs):
def _clean(ds):
# remove all unnecessary stuff for the regridding
ds = ds.isel(time=0, lev=0, rho=0, missing_dims="ignore")
for coord in [
co for co in ds.coords if co not in ["lon", "lat"]
]: # , "lat_bounds""lon_bounds",
ds = ds.drop_vars(coord)

# Ugly Hack to elminate 'seam' when regridding from gr to native grids
# There is something in xesmf that causes problems with broadcasted regular lon/lat values
if "gr" in ds.attrs["grid_label"]:
# actually revert the convention and return the 1d coordinates. ?Should I change that behavior in general?
ds = ds.assign(lon=ds.lon.isel(y=0))
ds = ds.assign(lat=ds.lat.isel(x=0))

# for now just eliminate the attrs here
# I can solve this more elegantly when I parse proper cf-attributes
for coord in ds.coords:
ds[coord].attrs = {}
# ds = ds.rename({'lon_bounds':'lon_b', 'lat_bounds':'lat_b'})

# TODO: Make this work out of the box with lon/lat bounds and method='conservative'
# TODO: Maybe erase the need for this completely with cf-xarray
return ds

ds_source = _clean(ds_source)
ds_target = _clean(ds_target)
return xesmf.Regridder(ds_source, ds_target, method, **xesmf_kwargs)


def _regrid_to_target(ds_source, ds_target, regridder):
ds_regridded = regridder(ds_source, keep_attrs=True)

# remove all coordinates that involve x and y (these will be merged from the native grid dataset)
ds_regridded = ds_regridded.reset_coords(drop=True)

# modify attributes
ds_regridded.attrs["grid_label"] = ds_target.attrs["grid_label"]
# identify the variables that are regridded
for var in ds_regridded.data_vars:
ds_regridded[var].attrs[
"cmip6_preprocessing_regrid_method"
] = ds_regridded.attrs["regrid_method"]

return ds_regridded


@requires_xesmf
def interpolate_grid_label(
ds_dict,
target_grid_label="gn",
method="bilinear",
xesmf_kwargs={},
merge_kwargs={},
verbose=False,
):
"""Combines different grid labels via interpolation with xesmf
Parameters
----------
ds_dict : dict
dictonary of input datasets
target_grid_label : str, optional
preferred grid_label value. If at least one dataset has this grid_label, otherse are interpolated to it.
Dataset with this grid label are not modified, by default "gn"
method : str, optional
interpolation method for xesmf, by default "bilinear"
xesmf_kwargs : dict, optional
optional arguments for building xesmf regridder, by default {}
merge_kwargs : dict, optional
optional arguments for the merging of interpolated datasets, by default {}
verbose : bool, optional
print output while creating regridder, by default False
Returns
-------
dict
dictionary of combined datasets (usually will combine across different variable ids)
"""
match_attrs = [
ma for ma in exact_attrs if ma not in ["grid_label", "version"]
] # does this need to be more flexible?

xesmf_kwargs.setdefault("ignore_degenerate", True)
xesmf_kwargs.setdefault("periodic", True)

merge_kwargs.setdefault("combine_attrs", "drop_conflicts")

# first drop the datasets that might have both the target and another grid_label present
ds_dict = _drop_duplicate_grid_labels(ds_dict, target_grid_label)

def combine_func(ds_list, **kwargs):
grid_labels = np.unique([dss.attrs["grid_label"] for dss in ds_list])
target_grid = [
dss for dss in ds_list if target_grid_label == dss.attrs["grid_label"]
]
if len(target_grid) < 1:
raise ValueError(
f"Could not find any variable with the target_grid_label{target_grid_label}. Found these instead: {grid_labels}"
)
else:
# just take the first one with a matching grid?
target_grid = target_grid[0]

# Construct a regridder for each other grid_label
regridder_dict = {}
if verbose:
print(
f'Constructing regridders for source_id {target_grid.attrs["source_id"]} ...'
)
for gl in grid_labels:
if gl != target_grid_label:
if verbose:
print(gl)
source_grid = [dss for dss in ds_list if dss.attrs["grid_label"] == gl][
0
] # again just take the first one available
regridder_dict[gl] = _clean_regridder(
source_grid, target_grid, method, **xesmf_kwargs
)
if verbose:
print("FINISHED")

# Now regrid all datasets in the list (dont do anything if already on the target_grid)
ds_list_new = []
for ds_raw in ds_list:
if ds_raw.attrs["grid_label"] != target_grid_label:
if verbose:
print(f"regridding {cmip6_dataset_id(ds_raw)}")
ds_regridded = _regrid_to_target(
ds_raw, target_grid, regridder_dict[ds_raw.attrs["grid_label"]]
)
ds_list_new.append(ds_regridded)
else:
ds_list_new.append(ds_raw)

# check that horizontal dimensions are compatible.
xy_dimensions = [(len(ds.x), len(ds.y)) for ds in ds_list_new]
if not all([xy == xy_dimensions[0] for xy in xy_dimensions]):
raise ValueError(
f"Regridded datasets do not have the same dimensions. Found({xy_dimensions}). This will cause broadcasting problems during merge."
)
if verbose:
print("Merging regridded files")
return xr.merge(ds_list_new, **merge_kwargs)

return combine_datasets(
ds_dict,
combine_func,
match_attrs=match_attrs,
)


### Matching wrapper specific to metric datasets


Expand Down
Loading

0 comments on commit 446935d

Please sign in to comment.