From 0675785d60aed6a81744cfc14ac73d0ba63dd217 Mon Sep 17 00:00:00 2001 From: Brian Blaylock Date: Fri, 26 Jul 2024 21:50:09 -0700 Subject: [PATCH 1/4] initial use of Type Hints --- environment-dev.yml | 1 + herbie/__init__.py | 6 +-- herbie/core.py | 104 +++++++++++++++++++++++--------------------- 3 files changed, 58 insertions(+), 53 deletions(-) diff --git a/environment-dev.yml b/environment-dev.yml index 5dc1482f..189f494b 100644 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -26,6 +26,7 @@ dependencies: - netcdf4 - numpy>=1.25 - pandas>=2.0 + - polars>=1.0 - pygrib>=2.1.4 - pylint - pyproj>=3.6 diff --git a/herbie/__init__.py b/herbie/__init__.py index 0231793b..07277612 100644 --- a/herbie/__init__.py +++ b/herbie/__init__.py @@ -42,7 +42,7 @@ ## TODO: Will the `_version.py` file *always* be present? ## TODO: What if the person doesn't do "pip install" from ._version import __version__, __version_tuple__ -except: +except Exception: __version__ = "unknown" __version_tuple__ = (999, 999, 999) @@ -50,7 +50,7 @@ ######################################################################## # Overload Path object with my custom `expand` method so the user can # set environment variables in the config file (e.g., ${HOME}). -def _expand(self, resolve=False, absolute=False): +def _expand(self, resolve: bool = False, absolute: bool = False) -> Path: """ Fully expand the Path with the given environment variables. @@ -149,7 +149,7 @@ def template(self): try: # Load the Herbie config file config = toml.load(_config_file) -except: +except Exception: try: # Create the Herbie config file _config_path.mkdir(parents=True, exist_ok=True) diff --git a/herbie/core.py b/herbie/core.py index 4f0d6130..b62b47d0 100644 --- a/herbie/core.py +++ b/herbie/core.py @@ -18,6 +18,7 @@ from datetime import datetime, timedelta from io import StringIO from shutil import which +from typing import Union, Optional, Literal import cfgrib import pandas as pd @@ -59,7 +60,7 @@ ) -def wgrib2_idx(grib2filepath): +def wgrib2_idx(grib2filepath: Union[Path, str]) -> str: """ Produce the GRIB2 inventory index with wgrib2. @@ -81,7 +82,7 @@ def wgrib2_idx(grib2filepath): raise RuntimeError("wgrib2 command was not found.") -def create_index_files(path, overwrite=False): +def create_index_files(path: Union[Path, str], overwrite: bool = False): """Create an index file for all GRIB2 files in a directory. Parameters @@ -91,7 +92,7 @@ def create_index_files(path, overwrite=False): overwrite : bool Overwrite index file if it exists. """ - path = Path(path).expand() + path = Path(path) files = [] if path.is_dir(): # List all GRIB2 files in the directory @@ -168,16 +169,16 @@ class Herbie: def __init__( self, - date=None, + date: Optional[Union[datetime, str]] = None, *, - valid_date=None, - model=config["default"].get("model"), - fxx=config["default"].get("fxx"), - product=config["default"].get("product"), - priority=config["default"].get("priority"), - save_dir=config["default"].get("save_dir"), - overwrite=config["default"].get("overwrite", False), - verbose=config["default"].get("verbose", True), + valid_date: Optional[Union[datetime, str]] = None, + model: str = config["default"].get("model"), + fxx: int = config["default"].get("fxx"), + product: str = config["default"].get("product"), + priority: Union[str, list[str]] = config["default"].get("priority"), + save_dir: Union[Path, str] = config["default"].get("save_dir"), + overwrite: bool = config["default"].get("overwrite", False), + verbose: bool = config["default"].get("verbose", True), **kwargs, ): """Specify model output and find GRIB2 file at one of the sources.""" @@ -282,7 +283,7 @@ def __init__( f"β”Š {ANSI.green}{self.date:%Y-%b-%d %H:%M UTC}{ANSI.bright_green} F{self.fxx:02d}{ANSI.reset}", ) - def __repr__(self): + def __repr__(self) -> str: """Representation in Notebook.""" msg = ( f"{ANSI.herbie} {self.model.upper()} model", @@ -292,13 +293,13 @@ def __repr__(self): ) return " ".join(msg) - def __str__(self): + def __str__(self) -> str: """When Herbie class object is printed, print all properties.""" # * Keep this simple so it runs fast. msg = (f"β•‘HERBIEβ•  {self.model.upper()}:{self.product}",) return " ".join(msg) - def __bool__(self): + def __bool__(self) -> bool: """Herbie evaluated True if the GRIB file exists.""" return bool(self.grib) @@ -375,7 +376,7 @@ def _ping_pando(self): print("πŸ€πŸ»β›” Bad handshake with pando? Am I able to move on?") pass - def _check_grib(self, url, min_content_length=10): + def _check_grib(self, url: str, min_content_length: int = 10) -> bool: """ Check that the GRIB2 URL exist and is of useful length. @@ -398,7 +399,7 @@ def _check_grib(self, url, min_content_length=10): else: return False - def _check_idx(self, url, verbose=False): + def _check_idx(self, url: str, verbose: bool = False) -> tuple[bool, Optional[str]]: """Check if an index file exist for the GRIB2 URL.""" # To check inventory files with slightly different URL structure # we will loop through the IDX_SUFFIX. @@ -427,7 +428,7 @@ def _check_idx(self, url, verbose=False): ) return False, None - def find_grib(self): + def find_grib(self) -> tuple[Optional[Union[Path, str]], Optional[str]]: """Find a GRIB file from the archive sources. Returns @@ -465,15 +466,15 @@ def find_grib(self): grib_url = self.SOURCES[source] if source.startswith("local"): - grib_path = Path(grib_url).expand() + grib_path = Path(grib_url) if grib_path.exists(): - return [grib_path, source] + return (grib_path, source) elif self._check_grib(grib_url): - return [grib_url, source] + return (grib_url, source) - return [None, None] + return (None, None) - def find_idx(self): + def find_idx(self) -> tuple[Optional[Union[Path, str]], Optional[str]]: """Find an index file for the GRIB file.""" # If priority list is set, we want to search SOURCES in that # priority order. If priority is None, then search all SOURCES @@ -498,31 +499,31 @@ def find_idx(self): grib_url = self.SOURCES[source] if source.startswith("local"): - local_grib = Path(grib_url).expand() + local_grib = Path(grib_url) local_idx = local_grib.with_suffix(self.IDX_SUFFIX[0]) if local_idx.exists(): - return [local_idx, "local"] + return (local_idx, "local") else: idx_exists, idx_url = self._check_idx(grib_url) if idx_exists: - return [idx_url, source] + return (idx_url, source) - return [None, None] + return (None, None) @property - def get_remoteFileName(self, source=None): + def get_remoteFileName(self, source: Optional[str] = None) -> str: """Predict remote file name (assumes all sources are named the same).""" if source is None: source = list(self.SOURCES)[0] return self.SOURCES[source].split("/")[-1] @property - def get_localFileName(self): + def get_localFileName(self) -> str: """Predict the local file name.""" return self.LOCALFILE - def get_localFilePath(self, search=None, *, searchString=None): + def get_localFilePath(self, search: Optional[str] = None, *, searchString=None): """Get full path to the local file.""" # TODO: Remove this eventually if searchString is not None: @@ -535,10 +536,7 @@ def get_localFilePath(self, search=None, *, searchString=None): # Predict the localFileName from the first model template SOURCE. localFilePath = ( - self.save_dir.expand() - / self.model - / f"{self.date:%Y%m%d}" - / self.get_localFileName + self.save_dir / self.model / f"{self.date:%Y%m%d}" / self.get_localFileName ) # Check if any sources in a model template are "local" @@ -546,9 +544,9 @@ def get_localFilePath(self, search=None, *, searchString=None): if any([i.startswith("local") for i in self.SOURCES.keys()]): localFilePath = next( ( - Path(self.SOURCES[i]).expand() + Path(self.SOURCES[i]) for i in self.SOURCES - if i.startswith("local") and Path(self.SOURCES[i]).expand().exists() + if i.startswith("local") and Path(self.SOURCES[i]).exists() ), localFilePath, ) @@ -595,7 +593,7 @@ def get_localFilePath(self, search=None, *, searchString=None): return localFilePath @functools.cached_property - def index_as_dataframe(self): + def index_as_dataframe(self) -> pd.DataFrame: """Read and cache the full index file.""" if self.grib_source == "local" and wgrib2: # Generate IDX inventory with wgrib2 @@ -762,7 +760,13 @@ def index_as_dataframe(self): return df - def inventory(self, search=None, *, searchString=None, verbose=None): + def inventory( + self, + search: Optional[str] = None, + *, + searchString=None, + verbose: Optional[bool] = None, + ) -> pd.DataFrame: """ Inspect the GRIB2 file contents by reading the index file. @@ -817,15 +821,15 @@ def inventory(self, search=None, *, searchString=None, verbose=None): def download( self, - search=None, + search: Optional[str] = None, *, searchString=None, - source=None, - save_dir=None, - overwrite=None, - verbose=None, - errors="warn", - ): + source: Optional[str] = None, + save_dir: Optional[Union[str, Path]] = None, + overwrite: Optional[bool] = None, + verbose: Optional[bool] = None, + errors: Literal["warn", "raise"] = "warn", + ) -> Path: """ Download file from source. @@ -1048,13 +1052,13 @@ def subset(search, outFile): def xarray( self, - search=None, + search: Optional[str] = None, *, searchString=None, - backend_kwargs={}, - remove_grib=True, + backend_kwargs: dict = {}, + remove_grib: bool = True, **download_kwargs, - ): + ) -> xr.Dataset: """ Open GRIB2 data as xarray DataSet. @@ -1188,7 +1192,7 @@ def xarray( ) return Hxr - def terrain(self, water_masked=True): + def terrain(self, water_masked: bool = True) -> xr.Dataset: """Shortcut method to return model terrain as an xarray.Dataset.""" ds = self.xarray(":(?:HGT|LAND):surface") if water_masked: From 35faff182e0f198a5cef47abd68e64703bb430b6 Mon Sep 17 00:00:00 2001 From: Brian Blaylock Date: Fri, 26 Jul 2024 21:54:02 -0700 Subject: [PATCH 2/4] add pd.Timestamp as type for date input --- herbie/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/herbie/core.py b/herbie/core.py index b62b47d0..47304f3c 100644 --- a/herbie/core.py +++ b/herbie/core.py @@ -169,9 +169,9 @@ class Herbie: def __init__( self, - date: Optional[Union[datetime, str]] = None, + date: Optional[Union[datetime, pd.Timestamp, str]] = None, *, - valid_date: Optional[Union[datetime, str]] = None, + valid_date: Optional[Union[datetime, pd.Timestamp, str]] = None, model: str = config["default"].get("model"), fxx: int = config["default"].get("fxx"), product: str = config["default"].get("product"), From 0e6352328bac0c65ff5e380fd4e7769b84f4b158 Mon Sep 17 00:00:00 2001 From: Brian Blaylock Date: Sat, 27 Jul 2024 08:33:21 -0700 Subject: [PATCH 3/4] added some more Type Hints --- herbie/accessors.py | 46 ++++++++++---------- herbie/core.py | 24 ++++++----- herbie/fast.py | 48 +++++++++++++-------- herbie/toolbox/cartopy_tools.py | 75 ++++++++++++++++++++------------- 4 files changed, 115 insertions(+), 78 deletions(-) diff --git a/herbie/accessors.py b/herbie/accessors.py index 12879b98..f8f0f4eb 100644 --- a/herbie/accessors.py +++ b/herbie/accessors.py @@ -14,6 +14,7 @@ import re import warnings from pathlib import Path +from typing import Literal, Optional, Union import numpy as np import pandas as pd @@ -23,7 +24,6 @@ import herbie - _level_units = dict( adiabaticCondensation="adiabatic condensation", atmosphere="atmosphere", @@ -54,7 +54,7 @@ ) -def add_proj_info(ds): +def add_proj_info(ds: xr.Dataset): """Add projection info to a Dataset.""" match = re.search(r'"source": "(.*?)"', ds.history) FILE = Path(match.group(1)) @@ -97,7 +97,7 @@ def __init__(self, xarray_obj): self._center = None @property - def center(self): + def center(self) -> tuple[float, float]: """Return the geographic center point of this dataset.""" if self._center is None: # we can use a cache on our accessor objects, because accessors @@ -107,19 +107,18 @@ def center(self): self._center = (float(lon.mean()), float(lat.mean())) return self._center - def to_180(self): + def to_180(self) -> xr.Dataset: """Wrap longitude coordinates as range [-180,180].""" ds = self._obj ds["longitude"] = (ds["longitude"] + 180) % 360 - 180 return ds - def to_360(self): + def to_360(self) -> xr.Dataset: """Wrap longitude coordinates as range [0,360].""" ds = self._obj ds["longitude"] = (ds["longitude"] - 360) % 360 return ds - @functools.cached_property def crs(self): """ @@ -196,7 +195,9 @@ def polygon(self): return domain_polygon, domain_polygon_latlon - def with_wind(self, which="both"): + def with_wind( + self, which: Literal["both", "speed", "direction"] = "both" + ) -> xr.Dataset: """Return Dataset with calculated wind speed and/or direction. Consistent with the eccodes GRIB parameter database, variables @@ -228,7 +229,7 @@ def with_wind(self, which="both"): ds["si10"].attrs["standard_name"] = "wind_speed" ds["si10"].attrs["grid_mapping"] = ds.u10.attrs.get("grid_mapping") n_computed += 1 - + if {"u100", "v100"}.issubset(ds): ds["si100"] = np.sqrt(ds.u100**2 + ds.v100**2) ds["si100"].attrs["GRIB_paramId"] = 228249 @@ -237,7 +238,7 @@ def with_wind(self, which="both"): ds["si100"].attrs["standard_name"] = "wind_speed" ds["si100"].attrs["grid_mapping"] = ds.u100.attrs.get("grid_mapping") n_computed += 1 - + if {"u80", "v80"}.issubset(ds): ds["si80"] = np.sqrt(ds.u80**2 + ds.v80**2) ds["si80"].attrs["long_name"] = "80 metre wind speed" @@ -266,7 +267,7 @@ def with_wind(self, which="both"): ds["wdir10"].attrs["standard_name"] = "wind_from_direction" ds["wdir10"].attrs["grid_mapping"] = ds.u10.attrs.get("grid_mapping") n_computed += 1 - + if {"u100", "v100"}.issubset(ds): ds["wdir100"] = ( (270 - np.rad2deg(np.arctan2(ds.v100, ds.u100))) % 360 @@ -276,7 +277,7 @@ def with_wind(self, which="both"): ds["wdir100"].attrs["standard_name"] = "wind_from_direction" ds["wdir100"].attrs["grid_mapping"] = ds.u100.attrs.get("grid_mapping") n_computed += 1 - + if {"u80", "v80"}.issubset(ds): ds["wdir80"] = ( (270 - np.rad2deg(np.arctan2(ds.v80, ds.u80))) % 360 @@ -305,15 +306,15 @@ def with_wind(self, which="both"): def pick_points( self, - points, - method="nearest", + points: pd.DataFrame, + method: Literal["nearest", "weighted"] = "nearest", *, - k=None, - max_distance=500, - use_cached_tree=True, - tree_name=None, - verbose=False, - ): + k: Optional[int] = None, + max_distance: Union[int, float] = 500, + use_cached_tree: Union[bool, Literal["replant"]] = True, + tree_name: Optional[str] = None, + verbose: bool = False, + ) -> xr.Dataset: """Pick nearest neighbor grid values at selected points. Parameters @@ -384,7 +385,7 @@ def pick_points( "`pip install 'herbie-data[extras]'` for the full functionality." ) - def plant_tree(save_pickle=None): + def plant_tree(save_pickle: Optional[Union[Path, str]] = None): """Grow a new BallTree object from seedling.""" timer = pd.Timestamp("now") print("INFO: 🌱 Growing new BallTree...", end="") @@ -719,9 +720,10 @@ def plot(self, ax=None, common_features_kw={}, vars=None, **kwargs): raise NotImplementedError("Plotting functionality is not working right now.") try: - from herbie.toolbox import EasyMap, pc - from herbie import paint import matplotlib.pyplot as plt + + from herbie import paint + from herbie.toolbox import EasyMap, pc except ModuleNotFoundError: raise ModuleNotFoundError( "cartopy is an 'extra' requirement. Please use " diff --git a/herbie/core.py b/herbie/core.py index 47304f3c..fd5375c0 100644 --- a/herbie/core.py +++ b/herbie/core.py @@ -32,6 +32,8 @@ from herbie.help import _search_help from herbie.misc import ANSI +Datetime = Union[datetime, pd.Timestamp, str] + # NOTE: The config dict values are retrieved from __init__ and read # from the file ${HOME}/.config/herbie/config.toml # Path is imported from __init__ because it has my custom methods. @@ -82,7 +84,7 @@ def wgrib2_idx(grib2filepath: Union[Path, str]) -> str: raise RuntimeError("wgrib2 command was not found.") -def create_index_files(path: Union[Path, str], overwrite: bool = False): +def create_index_files(path: Union[Path, str], overwrite: bool = False) -> None: """Create an index file for all GRIB2 files in a directory. Parameters @@ -169,9 +171,9 @@ class Herbie: def __init__( self, - date: Optional[Union[datetime, pd.Timestamp, str]] = None, + date: Optional[Datetime] = None, *, - valid_date: Optional[Union[datetime, pd.Timestamp, str]] = None, + valid_date: Optional[Datetime] = None, model: str = config["default"].get("model"), fxx: int = config["default"].get("fxx"), product: str = config["default"].get("product"), @@ -303,7 +305,7 @@ def __bool__(self) -> bool: """Herbie evaluated True if the GRIB file exists.""" return bool(self.grib) - def help(self): + def help(self) -> None: """Print help message if available.""" if hasattr(self, "HELP"): HELP = self.HELP.strip().replace("\n", "\nβ”‚ ") @@ -320,7 +322,7 @@ def help(self): print("β”‚") print("╰─────────────────────────────────────────") - def tell_me_everything(self): + def tell_me_everything(self) -> None: """Print all the attributes of the Herbie object.""" msg = [] for i in dir(self): @@ -330,11 +332,11 @@ def tell_me_everything(self): msg = "\n".join(msg) print(msg) - def __logo__(self): + def __logo__(self) -> None: """For Fun, show the Herbie Logo.""" print(ANSI.ascii) - def _validate(self): + def _validate(self) -> None: """Validate the Herbie class input arguments.""" # Accept model alias if self.model.lower() == "alaska": @@ -368,7 +370,7 @@ def _validate(self): if self.date < expired: self.priority.remove("nomads") - def _ping_pando(self): + def _ping_pando(self) -> None: """Pinging the Pando server before downloading can prevent a bad handshake.""" try: requests.head("https://pando-rgw01.chpc.utah.edu/") @@ -523,9 +525,11 @@ def get_localFileName(self) -> str: """Predict the local file name.""" return self.LOCALFILE - def get_localFilePath(self, search: Optional[str] = None, *, searchString=None): + def get_localFilePath( + self, search: Optional[str] = None, *, searchString=None + ) -> Path: """Get full path to the local file.""" - # TODO: Remove this eventually + # TODO: Remove this check for searString eventually if searchString is not None: warnings.warn( "The argument `searchString` was renamed `search`. Please update your scripts.", diff --git a/herbie/fast.py b/herbie/fast.py index fa638944..c0aa6cda 100644 --- a/herbie/fast.py +++ b/herbie/fast.py @@ -8,9 +8,10 @@ """ import logging - -# Multithreading :) from concurrent.futures import ThreadPoolExecutor, as_completed +from datetime import datetime +from typing import Union, Optional +from pathlib import Path import pandas as pd import xarray as xr @@ -19,6 +20,7 @@ log = logging.getLogger(__name__) +Datetime = Union[datetime, pd.Timestamp, str] """ πŸ§΅πŸ€ΉπŸ»β€β™‚οΈ Notice! Multithreading and Multiprocessing is use @@ -30,8 +32,8 @@ """ -def _validate_fxx(fxx): - """Fast Herbie requires fxx as a list-like""" +def _validate_fxx(fxx: Union[int, Union[list[int], range]]) -> Union[list[int], range]: + """Fast Herbie requires fxx as a list-like.""" if isinstance(fxx, int): fxx = [fxx] @@ -41,8 +43,8 @@ def _validate_fxx(fxx): return fxx -def _validate_DATES(DATES): - """Fast Herbie requires DATES as a list-like""" +def _validate_DATES(DATES: Union[Datetime, list[Datetime]]) -> list[Datetime]: + """Fast Herbie requires DATES as a list-like.""" if isinstance(DATES, str): DATES = [pd.to_datetime(DATES)] elif not hasattr(DATES, "__len__"): @@ -56,7 +58,7 @@ def _validate_DATES(DATES): return DATES -def Herbie_latest(n=6, freq="1h", **kwargs): +def Herbie_latest(n: int = 6, freq: str = "1h", **kwargs) -> Herbie: """Search for the most recent GRIB2 file (using multithreading). Parameters @@ -85,7 +87,16 @@ def Herbie_latest(n=6, freq="1h", **kwargs): class FastHerbie: - def __init__(self, DATES, fxx=[0], *, max_threads=50, **kwargs): + """Create many Herbie objects quickly.""" + + def __init__( + self, + DATES: Union[Datetime, list[Datetime]], + fxx: Union[int, list[int]] = [0], + *, + max_threads: int = 50, + **kwargs, + ): """Create many Herbie objects with methods to download or read with xarray. Uses multithreading. @@ -156,10 +167,11 @@ def __init__(self, DATES, fxx=[0], *, max_threads=50, **kwargs): f"Could not find {len(self.file_not_exists)}/{len(self.file_exists)} GRIB files." ) - def __len__(self): + def __len__(self) -> int: + """Return the number of Herbie objects.""" return len(self.objects) - def df(self): + def df(self) -> pd.DataFrame: """Organize Herbie objects into a DataFrame. #? Why is this inefficient? Takes several seconds to display because the __str__ does a lot. @@ -172,7 +184,7 @@ def df(self): ds_list, index=self.DATES, columns=[f"F{i:02d}" for i in self.fxx] ) - def inventory(self, search=None): + def inventory(self, search: Optional[str] = None): """Get combined inventory DataFrame. Useful for data discovery and checking your search before @@ -186,8 +198,10 @@ def inventory(self, search=None): dfs.append(df) return pd.concat(dfs, ignore_index=True) - def download(self, search=None, *, max_threads=20, **download_kwargs): - r"""Download many Herbie objects + def download( + self, search: Optional[str] = None, *, max_threads: int = 20, **download_kwargs + ) -> list[Path]: + r"""Download many Herbie objects. Uses multithreading. @@ -231,11 +245,11 @@ def download(self, search=None, *, max_threads=20, **download_kwargs): def xarray( self, - search, + search: Optional[str], *, - max_threads=None, + max_threads: Optional[int] = None, **xarray_kwargs, - ): + ) -> xr.Dataset: """Read many Herbie objects into an xarray Dataset. # TODO: Sometimes the Jupyter Cell always crashes when I run this. @@ -302,7 +316,7 @@ def xarray( concat_dim=["time", "step"], combine_attrs="drop_conflicts", ) - except: + except Exception: # TODO: I'm not sure why some cases doesn't like the combine_attrs argument ds = xr.combine_nested( ds_list, diff --git a/herbie/toolbox/cartopy_tools.py b/herbie/toolbox/cartopy_tools.py index 5bce2dc6..e81f2943 100644 --- a/herbie/toolbox/cartopy_tools.py +++ b/herbie/toolbox/cartopy_tools.py @@ -36,6 +36,7 @@ from metpy.plots import USCOUNTIES from mpl_toolkits.axes_grid1.inset_locator import InsetPosition +from typing import Literal, Optional, Union from herbie import Path try: @@ -299,7 +300,11 @@ def check_cartopy_axes(ax=None, crs=pc, *, fignum=None, verbose=False): raise TypeError("🌎 Sorry. The `ax` you gave me is not a cartopy axes.") -def get_ETOPO1(top="ice", coarsen=None, thin=None): +def get_ETOPO1( + top: Literal["bedrock", "ice"] = "ice", + coarsen: Optional[int] = None, + thin: Optional[int] = None, +): """ Return the ETOPO1 elevation and bathymetry DataArray. @@ -381,12 +386,12 @@ def _reporthook(a, b, c): def inset_global_map( ax, - x=0.95, - y=0.95, - size=0.3, + x: float = 0.95, + y: float = 0.95, + size: float = 0.3, theme=None, - facecolor="#f88d0083", - kind="area", + facecolor: str = "#f88d0083", + kind: Literal["point", "area"] = "area", ): """Add an inset map showing the location of the main map on the globe. @@ -478,7 +483,12 @@ def inset_global_map( return ax_inset -def state_polygon(state=None, country="USA", county=None, verbose=True): +def state_polygon( + state: Optional[str] = None, + country: str = "USA", + county: Optional[str] = None, + verbose: bool = True, +): """ Return a shapely polygon of US state boundaries or country borders. @@ -528,18 +538,18 @@ class EasyMap: def __init__( self, - scale="110m", + scale: Literal["110m", "50m", "10m"] = "110m", ax=None, crs=pc, *, figsize=None, - fignum=None, - dpi=None, - theme=None, - verbose=False, - add_coastlines=True, - facecolor=None, + fignum: int = None, + dpi: int = None, + theme: Optional[Literal["dark", "grey"]] = None, + add_coastlines: bool = True, + facecolor: Optional[str] = None, coastlines_kw={}, + verbose: bool = False, **kwargs, ): """ @@ -804,10 +814,10 @@ def LAKES(self, **kwargs): # Less commonly needed features def TERRAIN( self, - coarsen=30, + coarsen: int = 30, *, - top="ice", - kind="pcolormesh", + top: Literal["ice", "bedrock"] = "ice", + kind: Literal["pcolormesh", "contourf"] = "pcolormesh", extent=None, **kwargs, ): @@ -885,10 +895,10 @@ def TERRAIN( def BATHYMETRY( self, - coarsen=30, + coarsen: int = 30, *, - top="ice", - kind="pcolormesh", + top: Literal["ice", "bedrock"] = "ice", + kind: Literal["pcolormesh", "contourf"] = "pcolormesh", extent=None, **kwargs, ): @@ -1044,10 +1054,10 @@ def ROADS(self, road_types=None, **kwargs): def PLACES( self, - country="United States", - rank=2, - scatter=True, - labels=True, + country: str = "United States", + rank: int = 2, + scatter: bool = True, + labels: bool = True, label_kw={}, scatter_kw={}, ): @@ -1090,7 +1100,14 @@ def PLACES( # ============ # Tiled images - def STAMEN(self, style="terrain-background", zoom=3, alpha=1): + def STAMEN( + self, + style: Literal[ + "terrain-background", "terrain", "toner-background", "toner", "watercolor" + ] = "terrain-background", + zoom: int = 3, + alpha=1, + ): """ Add Stamen map tiles to background. @@ -1139,7 +1156,7 @@ def STAMEN(self, style="terrain-background", zoom=3, alpha=1): return self - def OSM(self, zoom=1, alpha=1): + def OSM(self, zoom: int = 1, alpha=1): """ Add Open Street Map tiles as background image. @@ -1187,9 +1204,9 @@ def DOMAIN( x, y=None, *, - text=None, - method="cutout", - facealpha=0.25, + text: Optional[str] = None, + method: Literal["fill", "cutout", "border"] = "cutout", + facealpha: Union[Literal[0], Literal[1], float] = 0.25, text_kwargs={}, **kwargs, ): From 9cff05af92dac75a1c579f8bed5ef7f9d715ab10 Mon Sep 17 00:00:00 2001 From: Brian Blaylock Date: Wed, 31 Jul 2024 21:53:14 -0700 Subject: [PATCH 4/4] added type hints to cartopy_tools --- herbie/toolbox/cartopy_tools.py | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/herbie/toolbox/cartopy_tools.py b/herbie/toolbox/cartopy_tools.py index e81f2943..b386ced2 100644 --- a/herbie/toolbox/cartopy_tools.py +++ b/herbie/toolbox/cartopy_tools.py @@ -47,6 +47,8 @@ # ) pass +type ExtentPadding = Union[Literal["auto"], float, dict[str, float]] + pc = ccrs.PlateCarree() pc._threshold = 0.01 # https://github.com/SciTools/cartopy/issues/8 @@ -89,7 +91,12 @@ def to_360(lon): ######################################################################## # Methods attached to axes created by `EasyMap` -def _adjust_extent(self, pad="auto", fraction=0.05, verbose=False): +def _adjust_extent( + self, + pad: ExtentPadding = "auto", + fraction: float = 0.05, + verbose: bool = False, +): """ Adjust the extent of an existing cartopy axes. @@ -148,13 +155,13 @@ def _adjust_extent(self, pad="auto", fraction=0.05, verbose=False): def _center_extent( self, - lon=None, - lat=None, - city=None, - state=None, + lon: Optional[Union[int, float]] = None, + lat: Optional[Union[int, float]] = None, + city: Optional[str] = None, + state: Optional[str] = None, *, - pad="auto", - verbose=False, + pad: ExtentPadding = "auto", + verbose: bool = False, ): """ Change the map extent to be centered on a point and adjust padding. @@ -252,7 +259,9 @@ def _copy_extent(self, src_ax): ######################################################################## # Main Functions -def check_cartopy_axes(ax=None, crs=pc, *, fignum=None, verbose=False): +def check_cartopy_axes( + ax=None, crs=pc, *, fignum: Optional[int] = None, verbose: bool = False +): """ Check if an axes is a cartopy axes, else create a new cartopy axes. @@ -389,7 +398,7 @@ def inset_global_map( x: float = 0.95, y: float = 0.95, size: float = 0.3, - theme=None, + theme: Optional[Literal["dark", "grey"]] = None, facecolor: str = "#f88d0083", kind: Literal["point", "area"] = "area", ):