From 4186506aa0cfe52b21355f2a1e8172fd74aeb34c Mon Sep 17 00:00:00 2001 From: Guillaume Maze Date: Wed, 6 Nov 2024 10:26:12 +0100 Subject: [PATCH] Improve perf and docstrings --- argopy/data_fetchers/erddap_data.py | 30 ++-- argopy/data_fetchers/gdac_data.py | 172 ++++++++++++------- argopy/data_fetchers/gdac_data_processors.py | 40 +++-- argopy/extensions/params_data_mode.py | 5 + argopy/fetchers.py | 2 +- argopy/stores/filesystems.py | 10 +- argopy/utils/casting.py | 5 +- argopy/utils/transform.py | 21 ++- argopy/xarray.py | 4 +- 9 files changed, 184 insertions(+), 105 deletions(-) diff --git a/argopy/data_fetchers/erddap_data.py b/argopy/data_fetchers/erddap_data.py index 191c3efd..cfce36c5 100644 --- a/argopy/data_fetchers/erddap_data.py +++ b/argopy/data_fetchers/erddap_data.py @@ -96,7 +96,7 @@ def __init__( # noqa: C901 Parameters ---------- - ds: str (optional) + ds: str, default = OPTIONS['ds'] Dataset to load: 'phy' or 'ref' or 'bgc-s' cache: bool (optional) Cache data or not (default: False) @@ -128,6 +128,13 @@ def __init__( # noqa: C901 List of BGC essential variables that can't be NaN. If set to 'all', this is an easy way to reduce the size of the :class:`xr.DataSet`` to points where all variables have been measured. Otherwise, provide a simple list of variables. + + Other parameters + ---------------- + server: str, default = OPTIONS['erddap'] + URL to erddap server + mode: str, default = OPTIONS['mode'] + """ timeout = OPTIONS["api_timeout"] if api_timeout == 0 else api_timeout self.definition = "Ifremer erddap Argo data fetcher" @@ -707,11 +714,9 @@ def to_xarray( # noqa: C901 ------- :class:`xarray.Dataset` """ - URI = self.uri # Call it once - # Should we compute (from the index) and add DATA_MODE for BGC variables: - add_dm = self.dataset_id in ["bgc", "bgc-s"] if add_dm is None else bool(add_dm) + # Pre-processor options: preprocess_opts = { "add_dm": False, "URI": URI, @@ -725,7 +730,7 @@ def to_xarray( # noqa: C901 "indexfs": self.indexfs, } - # Download data + # Download and pre-process data: results = [] if not self.parallelize: if len(URI) == 1: @@ -840,21 +845,6 @@ def to_xarray( # noqa: C901 [filtered.append(self.filter_measured(r)) for r in results] results = filtered - # empty = [] - # for v in self._bgc_vlist_measured: - # if v in results and np.count_nonzero(results[v]) != len(results["N_POINTS"]): - # empty.append(v) - # if len(empty) > 0: - # msg = ( - # "After processing, your BGC request returned final data with NaNs (%s). " - # "This may be due to the 'measured' argument ('%s') that imposes a no-NaN constraint " - # "impossible to fulfill for the access point defined (%s)]. " - # "\nUsing the 'measured' argument, you can try to minimize the list of variables to " - # "return without NaNs, or set it to 'None' to return all samples." - # % (",".join(to_list(v)), ",".join(self._bgc_measured), self.cname()) - # ) - # raise ValueError(msg) - if concat and results is not None: results["N_POINTS"] = np.arange(0, len(results["N_POINTS"])) diff --git a/argopy/data_fetchers/gdac_data.py b/argopy/data_fetchers/gdac_data.py index ccf67be8..862620a0 100644 --- a/argopy/data_fetchers/gdac_data.py +++ b/argopy/data_fetchers/gdac_data.py @@ -12,12 +12,13 @@ import warnings import getpass import logging +from typing import Literal from ..utils.format import argo_split_path from ..utils.decorators import deprecated from ..options import OPTIONS, check_gdac_path, PARALLEL_SETUP from ..errors import DataNotFound -from ..stores import ArgoIndex +from ..stores import ArgoIndex, has_distributed, distributed from .proto import ArgoDataFetcherProto from .gdac_data_processors import pre_process_multiprof, filter_points @@ -56,14 +57,13 @@ def init(self, *args, **kwargs): ### def __init__( self, - gdac: str = "", ds: str = "", cache: bool = False, cachedir: str = "", - dimension: str = "point", - errors: str = "raise", parallel: bool = False, progress: bool = False, + dimension: Literal['point', 'profile'] = "point", + errors: str = "raise", api_timeout: int = 0, **kwargs ): @@ -71,9 +71,7 @@ def __init__( Parameters ---------- - gdac: str (optional) - Path to the local or remote directory where the 'dac' folder is located - ds: str (optional) + ds: str, default = OPTIONS['ds'] Dataset to load: 'phy' or 'bgc' cache: bool (optional) Cache data or not (default: False) @@ -97,12 +95,19 @@ def __init__( Show a progress bar or not when fetching data. api_timeout: int (optional) Server request time out in seconds. Set to OPTIONS['api_timeout'] by default. + + Other parameters + ---------------- + gdac: str, default = OPTIONS['gdac'] + Path to the local or remote directory where the 'dac' folder is located """ self.timeout = OPTIONS["api_timeout"] if api_timeout == 0 else api_timeout self.dataset_id = OPTIONS["ds"] if ds == "" else ds self.user_mode = kwargs["mode"] if "mode" in kwargs else OPTIONS["mode"] - self.server = OPTIONS["gdac"] if gdac == "" else gdac + self.server = kwargs["gdac"] if "gdac" in kwargs else OPTIONS["gdac"] + self.errors = errors + self.dimension = dimension # Validate server, raise GdacPathError if not valid. check_gdac_path(self.server, errors="raise") @@ -111,7 +116,7 @@ def __init__( if self.dataset_id in ["bgc-s", "bgc-b"]: index_file = self.dataset_id - # Validation of self.server is done by the ArgoIndex: + # Validation of self.server is done by the ArgoIndex instance: self.indexfs = ArgoIndex( host=self.server, index_file=index_file, @@ -146,10 +151,14 @@ def __repr__(self): else: summary.append("📕 Index: %s (not loaded)" % self.indexfs.index_file) if hasattr(self.indexfs, "search"): - match = "matches" if self.N_FILES > 1 else "match" + match = "matches" if self.indexfs.N_MATCH > 1 else "match" summary.append( "📸 Index searched: True (%i %s, %0.4f%%)" - % (self.N_FILES, match, self.N_FILES * 100 / self.N_RECORDS) + % ( + self.indexfs.N_MATCH, + match, + self.indexfs.N_MATCH * 100 / self.N_RECORDS, + ) ) else: summary.append("📷 Index searched: False") @@ -307,7 +316,13 @@ def _preprocess_multiprof(self, ds): def pre_process(self, ds, *args, **kwargs): return pre_process_multiprof(ds, *args, **kwargs) - def to_xarray(self, errors: str = "ignore"): + def to_xarray( + self, + errors: str = "ignore", + concat: bool = True, + concat_method: Literal["drop", "fill"] = "fill", + dimension: Literal['point', 'profile'] = "", + ): """Load Argo data and return a :class:`xarray.Dataset` Parameters @@ -323,8 +338,11 @@ def to_xarray(self, errors: str = "ignore"): ------- :class:`xarray.Dataset` """ + URI = self.uri # Call it once + dimension = self.dimension if dimension == "" else dimension + if ( - len(self.uri) > 50 + len(URI) > 50 and not self.parallelize and self.parallel_method == "sequential" ): @@ -332,9 +350,10 @@ def to_xarray(self, errors: str = "ignore"): "Found more than 50 files to load, this may take a while to process sequentially ! " "Consider using another data source (eg: 'erddap') or the 'parallel=True' option to improve processing time." ) - elif len(self.uri) == 0: + elif len(URI) == 0: raise DataNotFound("No data found for: %s" % self.indexfs.cname) + # Pre-processor options: if hasattr(self, "BOX"): access_point = "BOX" access_point_opts = {"BOX": self.BOX} @@ -344,59 +363,88 @@ def to_xarray(self, errors: str = "ignore"): elif hasattr(self, "WMO"): access_point = "WMO" access_point_opts = {"WMO": self.WMO} + preprocess_opts = { + "access_point": access_point, + "access_point_opts": access_point_opts, + "pre_filter_points": self._post_filter_points, + "dimension": dimension, + } # Download and pre-process data: - ds = self.fs.open_mfdataset( - self.uri, - method=self.parallel_method, - concat_dim="N_POINTS", - concat=True, - preprocess=pre_process_multiprof, - preprocess_opts={ - "access_point": access_point, - "access_point_opts": access_point_opts, - "pre_filter_points": self._post_filter_points, - }, - progress=self.progress, - errors=errors, - open_dataset_opts={ + opts = { + "progress": self.progress, + "errors": errors, + "concat": concat, + "concat_dim": "N_POINTS", + "preprocess": pre_process_multiprof, + "preprocess_opts": preprocess_opts, + } + if self.parallel_method in ["thread"]: + opts["method"] = "thread" + opts["open_dataset_opts"] = { "xr_opts": {"decode_cf": 1, "use_cftime": 0, "mask_and_scale": 1} - }, - ) + } - # Meta-data processing: - ds["N_POINTS"] = np.arange( - 0, len(ds["N_POINTS"]) - ) # Re-index to avoid duplicate values - ds = ds.set_coords("N_POINTS") - ds = ds.sortby("TIME") - - # Remove netcdf file attributes and replace them with simplified argopy ones: - if "Fetched_from" not in ds.attrs: - raw_attrs = ds.attrs - ds.attrs = {} - ds.attrs.update({"raw_attrs": raw_attrs}) - if self.dataset_id == "phy": - ds.attrs["DATA_ID"] = "ARGO" - if self.dataset_id in ["bgc", "bgc-s"]: - ds.attrs["DATA_ID"] = "ARGO-BGC" - ds.attrs["DOI"] = "http://doi.org/10.17882/42182" - ds.attrs["Fetched_from"] = self.server - try: - ds.attrs["Fetched_by"] = getpass.getuser() - except: # noqa: E722 - ds.attrs["Fetched_by"] = "anonymous" - ds.attrs["Fetched_date"] = pd.to_datetime("now", utc=True).strftime( - "%Y/%m/%d" - ) + elif (self.parallel_method in ["process"]) | ( + has_distributed + and isinstance(self.parallel_method, distributed.client.Client) + ): + opts["method"] = self.parallel_method + opts["open_dataset_opts"] = { + "errors": "ignore", + "download_url_opts": {"errors": "ignore"}, + } + opts["progress"] = False + + results = self.fs.open_mfdataset(URI, **opts) + + if concat and results is not None: + if self.progress: + print("Final post-processing of the merged dataset ...") + # results = pre_process_multiprof(results, **preprocess_opts) + results = results.argo.cast_types(overwrite=False) + + # Meta-data processing for a single merged dataset: + results = results.assign_coords({'N_POINTS': np.arange(0, len(results['N_POINTS']))}) + results = results.sortby("TIME") + + # Remove netcdf file attributes and replace them with simplified argopy ones: + if "Fetched_from" not in results.attrs: + raw_attrs = results.attrs + + results.attrs = {} + if "Processing_history" in raw_attrs: + results.attrs.update({"Processing_history": raw_attrs["Processing_history"]}) + raw_attrs.pop("Processing_history") + results.argo.add_history("URI merged with '%s'" % concat_method) + + results.attrs.update({"raw_attrs": raw_attrs}) + if self.dataset_id == "phy": + results.attrs["DATA_ID"] = "ARGO" + if self.dataset_id in ["bgc", "bgc-s"]: + results.attrs["DATA_ID"] = "ARGO-BGC" + results.attrs["DOI"] = "http://doi.org/10.17882/42182" + results.attrs["Fetched_from"] = self.server + try: + results.attrs["Fetched_by"] = getpass.getuser() + except: # noqa: E722 + results.attrs["Fetched_by"] = "anonymous" + results.attrs["Fetched_date"] = pd.to_datetime( + "now", utc=True + ).strftime("%Y/%m/%d") + + results.attrs["Fetched_constraints"] = self.cname() + if len(self.uri) == 1: + results.attrs["Fetched_uri"] = self.uri[0] + else: + results.attrs["Fetched_uri"] = ";".join(self.uri) - ds.attrs["Fetched_constraints"] = self.cname() - if len(self.uri) == 1: - ds.attrs["Fetched_uri"] = self.uri[0] + if concat: + results.attrs = dict(sorted(results.attrs.items())) else: - ds.attrs["Fetched_uri"] = ";".join(self.uri) - - return ds + for ds in results: + ds.attrs = dict(sorted(ds.attrs.items())) + return results @deprecated( "Not serializable, please use 'gdac_data_processors.filter_points'", @@ -563,4 +611,4 @@ def uri(self): self._list_of_argo_files = URIs self.N_FILES = len(self._list_of_argo_files) - return self._list_of_argo_files \ No newline at end of file + return self._list_of_argo_files diff --git a/argopy/data_fetchers/gdac_data_processors.py b/argopy/data_fetchers/gdac_data_processors.py index 763e9b96..c4d76bcc 100644 --- a/argopy/data_fetchers/gdac_data_processors.py +++ b/argopy/data_fetchers/gdac_data_processors.py @@ -1,5 +1,6 @@ import numpy as np import xarray as xr +from typing import Literal def pre_process_multiprof( @@ -7,6 +8,7 @@ def pre_process_multiprof( access_point: str, access_point_opts: {}, pre_filter_points: bool = False, + dimension: Literal['point', 'profile'] = 'point', # dataset_id: str = "phy", # cname: str = '?', ) -> xr.Dataset: @@ -24,10 +26,10 @@ def pre_process_multiprof( if ds is None: return None - # Remove raw netcdf file attributes and replace them with argopy ones: - raw_attrs = ds.attrs - ds.attrs = {} - ds.attrs.update({"raw_attrs": raw_attrs}) + # # Remove raw netcdf file attributes and replace them with argopy ones: + # raw_attrs = ds.attrs + # ds.attrs = {} + # ds.attrs.update({"raw_attrs": raw_attrs}) # Rename JULD and JULD_QC to TIME and TIME_QC ds = ds.rename( @@ -38,6 +40,13 @@ def pre_process_multiprof( "standard_name": "time", } + # Ensure N_PROF is a coordinate + # ds = ds.assign_coords(N_PROF=np.arange(0, len(ds["N_PROF"]))) + ds = ds.reset_coords() + coords = ("LATITUDE", "LONGITUDE", "TIME", "N_PROF") + ds = ds.assign_coords({'N_PROF': np.arange(0, len(ds['N_PROF']))}) + ds = ds.set_coords(coords) + # Cast data types: ds = ds.argo.cast_types() @@ -52,9 +61,10 @@ def pre_process_multiprof( if len(list(ds[v].dims)) == 0: ds = ds.drop_vars(v) - ds = ( - ds.argo.profile2point() - ) # Default output is a collection of points, along N_POINTS + if dimension == 'point': + ds = ( + ds.argo.profile2point() + ) # Default output is a collection of points, along N_POINTS ds = ds[np.sort(ds.data_vars)] @@ -70,6 +80,11 @@ def filter_points(ds: xr.Dataset, access_point: str = None, **kwargs) -> xr.Data This may be necessary if for download performance improvement we had to work with multi instead of mono profile files: we loaded and merged multi-profile files, and then we need to make sure to retain only profiles requested. """ + dim = "N_PROF" if "N_PROF" in ds.dims else "N_POINTS" + ds = ds.assign_coords({dim: np.arange(0, len(ds[dim]))}) + if 'N_LEVELS' in ds.dims: + ds = ds.assign_coords({'N_LEVELS': np.arange(0, len(ds['N_LEVELS']))}) + if access_point == "BOX": BOX = kwargs["BOX"] # - box = [lon_min, lon_max, lat_min, lat_max, pres_min, pres_max] @@ -89,15 +104,14 @@ def filter_points(ds: xr.Dataset, access_point: str = None, **kwargs) -> xr.Data if access_point == "CYC": this_mask = xr.DataArray( - np.zeros_like(ds["N_POINTS"]), - dims=["N_POINTS"], - coords={"N_POINTS": ds["N_POINTS"]}, + np.zeros_like(ds[dim]), + dims=[dim], + coords={dim: ds[dim]}, ) for cyc in kwargs["CYC"]: this_mask += ds["CYCLE_NUMBER"] == cyc this_mask = this_mask >= 1 # any ds = ds.where(this_mask, drop=True) - ds["N_POINTS"] = np.arange(0, len(ds["N_POINTS"])) - - return ds + ds = ds.assign_coords({dim: np.arange(0, len(ds[dim]))}) + return ds \ No newline at end of file diff --git a/argopy/extensions/params_data_mode.py b/argopy/extensions/params_data_mode.py index 9892b495..46000bff 100644 --- a/argopy/extensions/params_data_mode.py +++ b/argopy/extensions/params_data_mode.py @@ -306,6 +306,11 @@ def filter( # Determine the list of variables to filter: params = to_list(params) + + if len(params) == 0: + this.argo.add_history("Found no variables to select according to DATA_MODE") + return this + if params[0] == "all": if "DATA_MODE" in this.data_vars: params = ["PRES", "TEMP"] diff --git a/argopy/fetchers.py b/argopy/fetchers.py index b0e9306c..a86e70b7 100755 --- a/argopy/fetchers.py +++ b/argopy/fetchers.py @@ -336,7 +336,7 @@ def index(self): not isinstance(self._index, pd.core.frame.DataFrame) or self._request != self.__repr__() ): - self.load() + self._index = self.to_index() return self._index @property diff --git a/argopy/stores/filesystems.py b/argopy/stores/filesystems.py index 2cc59cc6..c79a493a 100644 --- a/argopy/stores/filesystems.py +++ b/argopy/stores/filesystems.py @@ -35,7 +35,7 @@ import tempfile import logging from packaging import version -from typing import Union, Any, List +from typing import Union, Any, List, Literal from collections.abc import Callable from urllib.parse import urlparse, parse_qs from functools import lru_cache @@ -1133,10 +1133,11 @@ def open_mfdataset( progress: Union[bool, str] = False, concat: bool = True, concat_dim: str = "row", + concat_method: Literal["drop", "fill"] = "drop", preprocess: Callable = None, preprocess_opts: dict = {}, open_dataset_opts: dict = {}, - errors: str = "ignore", + errors: Literal['ignore', 'raise', 'silent'] = "ignore", compute_details: bool = False, *args, **kwargs, @@ -1380,7 +1381,10 @@ def open_mfdataset( if len(results) > 0: if concat: # ds = xr.concat(results, dim=concat_dim, data_vars='all', coords='all', compat='override') - results = drop_variables_not_in_all_datasets(results) + if concat_method == 'drop': + results = drop_variables_not_in_all_datasets(results) + elif concat_method == 'fill': + results = fill_variables_not_in_all_datasets(results) ds = xr.concat( results, dim=concat_dim, diff --git a/argopy/utils/casting.py b/argopy/utils/casting.py index 166c60aa..25ac5c8f 100644 --- a/argopy/utils/casting.py +++ b/argopy/utils/casting.py @@ -176,7 +176,10 @@ def cast_this_da(da, v): return da for v in ds.variables: - if overwrite or ("casted" in ds[v].attrs and ds[v].attrs["casted"] == 0): + if (overwrite or + ("casted" in ds[v].attrs and ds[v].attrs["casted"] == 0) or + (not overwrite and "casted" in ds[v].attrs and ds[v].attrs["casted"] == 1 and ds[v].dtype == 'O') + ): try: ds[v] = cast_this_da(ds[v], v) except Exception: diff --git a/argopy/utils/transform.py b/argopy/utils/transform.py index 1224f523..87d55f31 100644 --- a/argopy/utils/transform.py +++ b/argopy/utils/transform.py @@ -27,6 +27,10 @@ def drop_variables_not_in_all_datasets( Returns ------- List[xarray.Dataset] + + See Also + -------- + :func:`argopy.utils.fill_variables_not_in_all_datasets` """ # List all possible data variables: @@ -45,6 +49,7 @@ def drop_variables_not_in_all_datasets( # List of dataset with missing variables: # ir_missing = np.sum(ishere, axis=0) < len(vlist) + # List of variables missing in some dataset: iv_missing = np.sum(ishere, axis=1) < len(ds_collection) if len(iv_missing) > 0: @@ -82,6 +87,10 @@ def fill_variables_not_in_all_datasets( Returns ------- List[xarray.Dataset] + + See Also + -------- + :func:`argopy.utils.drop_variables_not_in_all_datasets` """ def first_variable_with_concat_dim(this_ds, concat_dim="rows"): @@ -111,14 +120,14 @@ def fillvalue(da): for res in ds_collection: [vlist.append(v) for v in list(res.variables) if concat_dim in res[v].dims] vlist = np.unique(vlist) - # log.debug('variables', vlist) + log.debug("variables: %s" % vlist) # List all possible coordinates: clist = [] for res in ds_collection: [clist.append(c) for c in list(res.coords) if concat_dim in res[c].dims] clist = np.unique(clist) - # log.debu('coordinates', clist) + log.debug("coordinates: %s" % clist) # Get the first occurrence of each variable, to be used as a template for attributes and dtype meta = {} @@ -130,7 +139,7 @@ def fillvalue(da): "dtype": ds[v].dtype, "fill_value": fillvalue(ds[v]), } - # [log.debug(meta[m]) for m in meta.keys()] + [log.debug(meta[m]) for m in meta.keys()] # Add missing variables to dataset datasets = [ds.copy() for ds in ds_collection] @@ -354,6 +363,12 @@ def split_data_mode(ds: xr.Dataset) -> xr.Dataset: """ if "STATION_PARAMETERS" in ds and "PARAMETER_DATA_MODE" in ds: + # Ensure N_PROF is a coordinate + # otherwise, the ``ds[name] = da`` line below will fail when a PARAMETER is not + # available in all profiles, hence da['N_PROF'] != ds['N_PROF'] + if "N_PROF" in ds.dims and "N_PROF" not in ds.coords: + ds = ds.assign_coords(N_PROF=np.arange(0, len(ds["N_PROF"]))) + u64 = lambda s: "%s%s" % (s, " " * (64 - len(s))) # noqa: E731 params = [p.strip() for p in np.unique(ds["STATION_PARAMETERS"])] diff --git a/argopy/xarray.py b/argopy/xarray.py index a64f0b7c..900d8e5c 100644 --- a/argopy/xarray.py +++ b/argopy/xarray.py @@ -615,7 +615,7 @@ def profile2point(self) -> xr.Dataset: ds = ds.where(~np.isnan(ds["PRES"]), drop=1) ds = ds.sortby("TIME") if "TIME" in ds else ds.sortby("JULD") ds["N_POINTS"] = np.arange(0, len(ds["N_POINTS"])) - ds = cast_Argo_variable_type(ds) + ds = cast_Argo_variable_type(ds, overwrite=False) ds = ds[np.sort(ds.data_vars)] ds.encoding = self.encoding # Preserve low-level encoding information ds.argo.add_history("Transformed with 'profile2point'") @@ -691,7 +691,7 @@ def filter_qc( # noqa: C901 ) if len(QC_fields) == 0: - this.argo._add_history( + this.argo.add_history( "Variables selected according to QC (but found no QC variables)" ) return this