From 5568cca305cfdc6e0ae12a079bbc675fe8662d5f Mon Sep 17 00:00:00 2001 From: Kevin Schwarzwald Date: Mon, 3 Jun 2024 18:33:55 -0400 Subject: [PATCH] Implement global defaults class Implements `OPTIONS`, `set_options()`, and `get_options()`, which will allow global setting of `silent` and `impl`, in addition to setting them in "with" blocks. Sets stage to eventually get rid of "silent" and "impl" as function options, to be replaced with with blocks and `set_options()`. --- .gitignore | 2 + readthedocs.yml | 5 --- xagg/__init__.py | 3 +- xagg/classes.py | 28 +++++++----- xagg/core.py | 54 ++++++++++++----------- xagg/options.py | 110 +++++++++++++++++++++++++++++++++++++++++++++++ xagg/wrappers.py | 13 ++++-- 7 files changed, 172 insertions(+), 43 deletions(-) create mode 100644 xagg/options.py diff --git a/.gitignore b/.gitignore index f293e6d..a1db86b 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,8 @@ __pycache__/ # xagg /wm/ directories created during docs processing wm/ docs/notebooks/wm/ +wm_export_test/ + # C extensions *.so diff --git a/readthedocs.yml b/readthedocs.yml index 36e8bb3..5d03196 100644 --- a/readthedocs.yml +++ b/readthedocs.yml @@ -16,11 +16,6 @@ build: conda: environment: docs/docenvironment.yml -#python: -# version: 3.12 -# install: -# - method: setuptools -# path: package sphinx: fail_on_warning: False configuration: docs/source/conf.py \ No newline at end of file diff --git a/xagg/__init__.py b/xagg/__init__.py index 006ad31..eb6a1ea 100644 --- a/xagg/__init__.py +++ b/xagg/__init__.py @@ -4,4 +4,5 @@ # two functions) from .wrappers import pixel_overlaps from .auxfuncs import (normalize,fix_ds,get_bnds,subset_find) -from .core import (aggregate,read_wm) \ No newline at end of file +from .core import (aggregate,read_wm) +from .options import get_options, set_options \ No newline at end of file diff --git a/xagg/classes.py b/xagg/classes.py index 351fa36..920a377 100644 --- a/xagg/classes.py +++ b/xagg/classes.py @@ -2,6 +2,7 @@ import warnings import os import re +from .options import get_options try: import cartopy @@ -91,56 +92,63 @@ def to_dataframe(self,loc_dim='poly_idx'): return df_out # Export functions - def to_netcdf(self,fn,loc_dim='poly_idx',silent=False): + def to_netcdf(self,fn,loc_dim='poly_idx',silent=None): """ Save as netcdf Parameters ----------------- - fn : str + fn : :py:class:`str` The target filename - loc_dim : str, by default `'poly_idx'` + loc_dim : :py:class:`str`, by default `'poly_idx'` What to name the polygon dimension - silent : bool, by default False + silent : :py:class:`bool`, by default False If `True`, silences standard out """ + if silent is None: + silent = get_options()['silent'] + output_data(self, output_format = 'netcdf', output_fn = fn, loc_dim = loc_dim, silent = silent) - def to_csv(self,fn,silent=False): + def to_csv(self,fn,silent=None): """ Save as csv Parameters ----------------- - fn : str + fn : :py:class:`str` The target filename - silent : bool, by default False + silent : :py:class:`bool`, by default False If `True`, silences standard out """ + if silent is None: + silent = get_options()['silent'] output_data(self, output_format = 'csv', output_fn = fn, silent=silent) - def to_shp(self,fn,silent=False): + def to_shp(self,fn,silent=None): """ Save as shapefile - fn : str + fn : :py:class:`str` The target filename - silent : bool, by default False + silent : :py:class:`bool`, by default False If `True`, silences standard out """ + if silent is None: + silent = get_options()['silent'] output_data(self, output_format = 'shp', output_fn = fn, diff --git a/xagg/core.py b/xagg/core.py index f7a1bd7..bd9a036 100644 --- a/xagg/core.py +++ b/xagg/core.py @@ -15,6 +15,7 @@ from . auxfuncs import (find_rel_area,normalize,fix_ds,get_bnds,subset_find,list_or_first) from . classes import (weightmap,aggregated) +from . options import get_options class NoOverlapError(Exception): """ Exception for when there's no overlap between pixels and polygons """ @@ -92,7 +93,7 @@ def read_wm(path): return wm -def process_weights(ds,weights=None,target='ds',silent=False): +def process_weights(ds,weights=None,target='ds',silent=None): """ Process weights - including regridding If ``target == 'ds'``, regrid `weights` to `ds`. If ``target == 'weights'``, @@ -112,7 +113,7 @@ def process_weights(ds,weights=None,target='ds',silent=False): default) or vice-versa (not yet supported, returns ``NotImplementedError``) - silent : :py:class:`bool`, default = `False` + silent : :py:class:`bool`, default = `False` (set by :py:meth:`xa.set_options`) if True, then no status updates are printed to std out Returns @@ -129,6 +130,10 @@ def process_weights(ds,weights=None,target='ds',silent=False): - ``ds_grid``: a dictionary with the grid ``{"lat":ds.lat,"lon",ds.lon}`` - ``weights_grid``: a dictionary with the grid ``{"lat":weights.lat,"lon":weights.lon}`` """ + + if silent is None: + silent = get_options()['silent'] + if weights is None: # (for robustness against running this without an extra if statement @@ -227,7 +232,7 @@ def create_raster_polygons(ds, mask=None,subset_bbox=None, weights=None,weights_target='ds', wrap_around_thresh=5, - silent=False): + silent=None): """ Create polygons for each pixel in a raster Note: @@ -273,6 +278,9 @@ def create_raster_polygons(ds, the input `ds`) """ + + if silent is None: + silent = get_options()['silent'] # Standardize inputs (including lat/lon order) ds = fix_ds(ds) @@ -366,7 +374,7 @@ def create_raster_polygons(ds, return pix_agg -def get_pixel_overlaps(gdf_in,pix_agg,impl='for_loop'): +def get_pixel_overlaps(gdf_in,pix_agg,impl=None): """ Get, for each polygon, the pixels that overlap and their area of overlap Finds, for each polygon in `gdf_in`, which pixels intersect it, and by how much. @@ -394,13 +402,11 @@ def get_pixel_overlaps(gdf_in,pix_agg,impl='for_loop'): ``[da.lat,da.lon]`` of the grid used to create the pixel polygons - impl : :py:class:`str` + impl : :py:class:`str` (set by :py:meth:`xa.set_options`) whether the output will be used for the dot-product aggregation calculation (needs a slightly different format), either of: - ``'for_loop'`` (default behavior) - - ``'dot_product'`` (to set up for ``impl='dot_product'`` in - ``xagg.core.aggregate``) - + - ``'dot_product'`` (to set up for ``impl='dot_product'`` in ``xagg.core.aggregate``) Returns --------------- @@ -411,24 +417,19 @@ def get_pixel_overlaps(gdf_in,pix_agg,impl='for_loop'): a dataframe containing all the fields of ``gdf_in`` (except geometry) and the additional columns: - - ``coords``: the lat/lon coordiates of all pixels that overlap - the polygon of that row - - ``pix_idxs``: the linear indices of those pixels within the - ``gdf_pixels`` grid - - ``rel_area``: the relative area of each of the overlaps between - the pixels and the polygon (summing to 1 - e.g. - if the polygon is exactly the size and location of - two pixels, their rel_areas would be 0.5 each) + - ``coords``: the lat/lon coordiates of all pixels that overlap the polygon of that row + - ``pix_idxs``: the linear indices of those pixels within the ``gdf_pixels`` grid + - ``rel_area``: the relative area of each of the overlaps between the pixels and the polygon (summing to 1 - e.g. if the polygon is exactly the size and location of two pixels, their rel_areas would be 0.5 each) - ``'source_grid'`` - a dictionary with keys 'lat' and 'lon' giving the - original lat/lon grid whose overlaps with the polygons - was calculated + a dictionary with keys 'lat' and 'lon' giving the original lat/lon grid whose overlaps with the polygons was calculated - ``'geometry'`` just the polygons from ``gdf_in`` - """ + + if impl is None: + impl = get_options()['impl'] # Add an index for each polygon as a column to make indexing easier #if 'poly_idx' not in gdf_in.columns: @@ -530,7 +531,7 @@ def get_pixel_overlaps(gdf_in,pix_agg,impl='for_loop'): return wm_out -def aggregate(ds,wm,impl='for_loop',silent=False): +def aggregate(ds,wm,impl=None,silent=None): """ Aggregate raster variable(s) to polygon(s) Aggregates (N-D) raster variables in `ds` to the polygons @@ -572,7 +573,7 @@ def aggregate(ds,wm,impl='for_loop',silent=False): were calculated (and on which the linear indices are based) - impl : :class:str (def: ``'for_loop'``) + impl : :class:str (def: ``'for_loop'``) (set by :py:meth:`xa.set_options`) which aggregation calculation method to use, either of: - ``'for_loop'`` @@ -583,7 +584,7 @@ def aggregate(ds,wm,impl='for_loop',silent=False): requires much more memory (due to broadcasting of variables) but may be faster in certain circumstances - silent : :py:class:`bool`, default = `False` + silent : :py:class:`bool`, default = `False` (set by :py:meth:`xa.set_options`) if True, then no status updates are printed to std out Returns @@ -592,6 +593,11 @@ def aggregate(ds,wm,impl='for_loop',silent=False): an :class:`xagg.classes.aggregated` object with the aggregated variables """ + if impl is None: + impl = get_options()['impl'] + if silent is None: + silent = get_options()['silent'] + # Make sure pixel_overlaps was correctly run if using dot product if (impl=='dot_product') and (wm.overlap_da is None): raise ValueError("no 'overlap_da' was found in the `wm` input - since you're using the dot product implementation, "+ @@ -613,7 +619,7 @@ def aggregate(ds,wm,impl='for_loop',silent=False): ds = ds.stack(loc=('lat','lon')) # Adjust grid of [ds] if necessary to match - ds = subset_find(ds,wm.source_grid) + ds = subset_find(ds,wm.source_grid,silent=silent) # Set weights; or replace with ones if no additional weight information #if wm.weights != 'nowghts': diff --git a/xagg/options.py b/xagg/options.py new file mode 100644 index 0000000..3456e5a --- /dev/null +++ b/xagg/options.py @@ -0,0 +1,110 @@ +# Partially adapted from xarray's xr.core.options +# NB: In a future version, setting `impl` and `silent` +# in individual function calls should be deprecated +# in favor of setting global defaults or using +# with blocks. +from typing import TypedDict + +# Create class specifying the needed +# type of each option +class T_Options(TypedDict): + silent : bool + impl : str + +# Set default options. Defining it in +# the module makes it global to the module +# (as opposed to within a function, where +# it should only be local to the function) +OPTIONS: T_Options = { + 'silent' : False, + 'impl' : 'for_loop' +} + + +# Options for the backend implementation +_IMPL_OPTIONS = frozenset(['for_loop','dot_product']) + +# Each item of this dictionary is a test for whether a +# modification for the corresponding option was correctly +# set. I.e., "silent" can only be True or False, so the +# 'silent' dict option here tests for whether it's a bool. +_VALIDATORS = { + 'silent': lambda value: isinstance(value, bool), + 'impl': _IMPL_OPTIONS.__contains__, +} + +# Define options class +class set_options: + """ Set options for xagg. + + Parameters + ---------- + silent : bool, by default ``False`` + If True, then status updates are suppressed + + impl : str, by default ``"for_loop"`` + Sets backend algorithm, can be + + * ``for_loop``: slower, but lower memory use + * ``dot_product``: faster, but higher memory use + + """ + + def __init__(self, **kwargs): + # Keep track of changed options, to be able to change + # them back if used in a `with` block (see __exit__ below) + self.old = {} + + for k, v in kwargs.items(): + # Check to make sure the option you're looking to change + # is an option changeable by set_options() + if k not in OPTIONS: + raise ValueError( + f"argument name {k!r} is not in the set of valid options {set(OPTIONS)!r}" + ) + + # Check to make sure the new value of the option is + # acceptable + if k in _VALIDATORS and not _VALIDATORS[k](v): + if k == "impl": + expected = f"Expected one of {_IMPL_OPTIONS!r}" + else: + expected = "" + raise ValueError( + f"option {k!r} given an invalid value: {v!r}. " + expected + ) + # Note original value of changed options, to reset + # defaults in the __exit__ block below + self.old[k] = OPTIONS[k] + # Update OPTIONS + self._apply_update(kwargs) + + def _apply_update(self, options_dict): + """ Update OPTIONS """ + OPTIONS.update(options_dict) + + # This allows use with "with set_options(...):" + # See e.g., https://stackoverflow.com/questions/1984325/explaining-pythons-enter-and-exit + def __enter__(self): + return + # Resets to the original OPTIONS at the end + # of a `with` block + def __exit__(self, type, value, traceback): + self._apply_update(self.old) + + +def get_options(): + """ + Get module-wide options for xagg. + + See Also + ---------- + :py:meth:`set_options` + + """ + # Returning a copy, to make sure no unintended changes to + # the output of get_options() trickle down to anything + # else that uses the variable + return OPTIONS.copy() + + \ No newline at end of file diff --git a/xagg/wrappers.py b/xagg/wrappers.py index 805593e..494ce17 100644 --- a/xagg/wrappers.py +++ b/xagg/wrappers.py @@ -3,13 +3,14 @@ import copy from . core import (create_raster_polygons,get_pixel_overlaps) +from . options import get_options def pixel_overlaps(ds,gdf_in, weights=None,weights_target='ds', subset_bbox = True, - impl='for_loop', - silent=False): + impl=None, + silent=None): """ Wrapper function for determining overlaps between grid and polygon For a geodataframe `gdf_in`, takes an `xarray` structure `ds` (Dataset or @@ -50,11 +51,12 @@ def pixel_overlaps(ds,gdf_in, a `NotImplementedError`) impl : :class:`str`, optional, default = ``'for_loop'`` + (set by :py:meth:`xa.set_options`) whether to use the ``'for_loop'`` or ``'dot_product'`` methods for aggregating; the former uses less memory, the latter may be faster in certain circumstances - silent : :py:class:`bool`, default = `False` + silent : :py:class:`bool`, default = `False` (set by :py:meth:`xa.set_options`) if True, then no status updates are printed to std out Returns @@ -64,6 +66,11 @@ def pixel_overlaps(ds,gdf_in, input into :func:`xagg.core.aggregate`. """ + if impl is None: + impl = get_options()['impl'] + if silent is None: + silent = get_options()['silent'] + # Create deep copy of gdf to ensure the input gdf doesn't # get modified gdf_in = copy.deepcopy(gdf_in)