diff --git a/CHANGELOG.md b/CHANGELOG.md index 228b9cbc..587853d6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,26 @@ # Changelog -## [Latest](https://github.com/int-brain-lab/ONE/commits/main) [2.9.1] +## [Latest](https://github.com/int-brain-lab/ONE/commits/main) [2.10.0] +This version improves the behaviour of loading revisions and loading datasets from list_datasets output. + +### Modified + +- sub-collections no longer captured when filtering with a filename that starts with a wildcard in wildcard mode +- bugfix of spurious error raised when loading a dataset with a revision provided +- default_revisions_only parameter in One.list_datasets filters non-default datasets +- permit data frame input to One.load_datasets and load the exact relative paths provided (instead of default revisions) +- redundant session_path column has been dropped from the datasets cache table +- bugfix in one.params.setup: suggest previous cache dir if available instead of always the default +- bugfix in one.params.setup: remove all extraneous parameters (e.g. TOKEN) when running setup in silent mode +- warn user to reauthenticate when password is None in silent mode +- always force authentication when a password is passed, even when a token is cached +- bugfix: negative indexing of paginated response objects now functions correctly +- deprecate one.util.ensure_list; moved to iblutil.util.ensure_list + +### Added + +- one.alf.exceptions.ALFWarning category allows users to filter warnings relating to mixed revisions + +## [2.9.1] ### Modified @@ -7,7 +28,7 @@ - HOTFIX: Ensure http data server URL does not end in slash - HOTFIX: Handle public aggregate dataset relative paths - HOTFIX: No longer warns in silent mode when no param conflicts present -- Explicit kwargs in load_* methods to avoid user confusion (e.g. no 'namespace' kwarg for `load_dataset`) +- explicit kwargs in load_* methods to avoid user confusion (e.g. no 'namespace' kwarg for `load_dataset`) ## [2.9.0] This version adds a couple of new ALF functions. diff --git a/docs/one_installation.md b/docs/one_installation.md index b4f8ba81..6296523e 100644 --- a/docs/one_installation.md +++ b/docs/one_installation.md @@ -61,7 +61,7 @@ one = ONE() To change your default database, or re-run the setup for a given database, you can use the following ```python -ONE._setup(base_url='https://test.alyx.internationalbrainlab.org', make_default=True) +ONE.setup(base_url='https://test.alyx.internationalbrainlab.org', make_default=True) ``` ## 4. 
Update diff --git a/one/__init__.py b/one/__init__.py index fb65cac4..07a5d384 100644 --- a/one/__init__.py +++ b/one/__init__.py @@ -1,2 +1,2 @@ """The Open Neurophysiology Environment (ONE) API.""" -__version__ = '2.9.1' +__version__ = '2.10.0' diff --git a/one/alf/cache.py b/one/alf/cache.py index 4cfde509..8fccaa95 100644 --- a/one/alf/cache.py +++ b/one/alf/cache.py @@ -30,7 +30,7 @@ from one.alf.io import iter_sessions, iter_datasets from one.alf.files import session_path_parts, get_alf_path from one.converters import session_record2path -from one.util import QC_TYPE +from one.util import QC_TYPE, patch_cache __all__ = ['make_parquet_db', 'remove_missing_datasets', 'DATASETS_COLUMNS', 'SESSIONS_COLUMNS'] _logger = logging.getLogger(__name__) @@ -52,7 +52,6 @@ DATASETS_COLUMNS = ( 'id', # int64 'eid', # int64 - 'session_path', # relative to the root 'rel_path', # relative to the session path, includes the filename 'file_size', # file size in bytes 'hash', # sha1/md5, computed in load function @@ -89,7 +88,6 @@ def _get_dataset_info(full_ses_path, rel_dset_path, ses_eid=None, compute_hash=F return { 'id': Path(rel_ses_path, rel_dset_path).as_posix(), 'eid': str(ses_eid), - 'session_path': str(rel_ses_path), 'rel_path': Path(rel_dset_path).as_posix(), 'file_size': file_size, 'hash': md5(full_dset_path) if compute_hash else None, @@ -297,18 +295,30 @@ def remove_missing_datasets(cache_dir, tables=None, remove_empty_sessions=True, if tables is None: tables = {} for name in ('datasets', 'sessions'): - tables[name], _ = parquet.load(cache_dir / f'{name}.pqt') - to_delete = [] + table, m = parquet.load(cache_dir / f'{name}.pqt') + tables[name] = patch_cache(table, m.get('min_api_version'), name) + + INDEX_KEY = '.?id' + for name in tables: + # Set the appropriate index if none already set + if isinstance(tables[name].index, pd.RangeIndex): + idx_columns = sorted(tables[name].filter(regex=INDEX_KEY).columns) + tables[name].set_index(idx_columns, inplace=True) + + to_delete = set() gen_path = partial(session_record2path, root_dir=cache_dir) - sessions = sorted(map(lambda x: gen_path(x[1]), tables['sessions'].iterrows())) + # map of session path to eid + sessions = {gen_path(rec): eid for eid, rec in tables['sessions'].iterrows()} for session_path in iter_sessions(cache_dir): - rel_session_path = session_path.relative_to(cache_dir).as_posix() - datasets = tables['datasets'][tables['datasets']['session_path'] == rel_session_path] + try: + datasets = tables['datasets'].loc[sessions[session_path]] + except KeyError: + datasets = tables['datasets'].iloc[0:0, :] for dataset in iter_datasets(session_path): if dataset.as_posix() not in datasets['rel_path']: - to_delete.append(session_path.joinpath(dataset)) + to_delete.add(session_path.joinpath(dataset)) if session_path not in sessions and remove_empty_sessions: - to_delete.append(session_path) + to_delete.add(session_path) if dry: print('The following session and datasets would be removed:', end='\n\t') diff --git a/one/alf/exceptions.py b/one/alf/exceptions.py index c5f72fca..d963767b 100644 --- a/one/alf/exceptions.py +++ b/one/alf/exceptions.py @@ -1,4 +1,4 @@ -"""ALyx File related errors. +"""ALyx File related errors and warnings. A set of Alyx and ALF related error classes which provide a more verbose description of the raised issues. @@ -82,3 +82,8 @@ class ALFMultipleRevisionsFound(ALFError): explanation = ('The matching object/file(s) belong to more than one revision. 
' 'Multiple datasets in different revision folders were found with no default ' 'specified.') + + +class ALFWarning(Warning): + """Cautions when loading ALF datasets.""" + pass diff --git a/one/api.py b/one/api.py index 2e39dd2d..c375f735 100644 --- a/one/api.py +++ b/one/api.py @@ -20,7 +20,7 @@ import packaging.version from iblutil.io import parquet, hashfile -from iblutil.util import Bunch, flatten +from iblutil.util import Bunch, flatten, ensure_list import one.params import one.webclient as wc @@ -28,10 +28,10 @@ import one.alf.files as alfiles import one.alf.exceptions as alferr from .alf.cache import make_parquet_db, DATASETS_COLUMNS, SESSIONS_COLUMNS -from .alf.spec import is_uuid_string, QC +from .alf.spec import is_uuid_string, QC, to_alf from . import __version__ from one.converters import ConversionMixin, session_record2path -import one.util as util +from one import util _logger = logging.getLogger(__name__) __all__ = ['ONE', 'One', 'OneAlyx'] @@ -131,7 +131,7 @@ def load_cache(self, tables_dir=None, **kwargs): # Set the appropriate index if none already set if isinstance(cache.index, pd.RangeIndex): - idx_columns = cache.filter(regex=INDEX_KEY).columns.tolist() + idx_columns = sorted(cache.filter(regex=INDEX_KEY).columns) if len(idx_columns) == 0: raise KeyError('Failed to set index') cache.set_index(idx_columns, inplace=True) @@ -287,9 +287,10 @@ def _update_cache_from_records(self, strict=False, **kwargs): if not strict: # Deal with case where there are extra columns in the cache extra_columns = set(self._cache[table].columns) - set(records.columns) - for col in extra_columns: - n = list(self._cache[table].columns).index(col) - records.insert(n, col, np.nan) + column_ids = map(list(self._cache[table].columns).index, extra_columns) + for col, n in sorted(zip(extra_columns, column_ids), key=lambda x: x[1]): + val = records.get('exists', True) if col.startswith('exists_') else np.nan + records.insert(n, col, val) # Drop any extra columns in the records that aren't in cache table to_drop = set(records.columns) - set(self._cache[table].columns) records.drop(to_drop, axis=1, inplace=True) @@ -302,7 +303,8 @@ def _update_cache_from_records(self, strict=False, **kwargs): to_assign = records[~to_update] if isinstance(self._cache[table].index, pd.MultiIndex) and not to_assign.empty: # Concatenate and sort (no other way for non-unique index within MultiIndex) - self._cache[table] = pd.concat([self._cache[table], to_assign]).sort_index() + frames = filter(lambda x: not x.empty, [self._cache[table], to_assign]) + self._cache[table] = pd.concat(frames).sort_index() else: for index, record in to_assign.iterrows(): self._cache[table].loc[index, :] = record[self._cache[table].columns].values @@ -501,7 +503,7 @@ def sort_fcn(itm): return ([], None) if details else [] # String fields elif key in ('subject', 'task_protocol', 'laboratory', 'projects'): - query = '|'.join(util.ensure_list(value)) + query = '|'.join(ensure_list(value)) key = 'lab' if key == 'laboratory' else key mask = sessions[key].str.contains(query, regex=self.wildcards) sessions = sessions[mask.astype(bool, copy=False)] @@ -510,7 +512,7 @@ def sort_fcn(itm): session_date = pd.to_datetime(sessions['date']) sessions = sessions[(session_date >= start) & (session_date <= end)] elif key == 'number': - query = util.ensure_list(value) + query = ensure_list(value) sessions = sessions[sessions[key].isin(map(int, query))] # Dataset/QC check is biggest so this should be done last elif key == 'dataset' or (key == 'dataset_qc_lte' and 
'dataset' not in queries): @@ -518,7 +520,7 @@ def sort_fcn(itm): qc = QC.validate(queries.get('dataset_qc_lte', 'FAIL')).name # validate value has_dset = sessions.index.isin(datasets.index.get_level_values('eid')) datasets = datasets.loc[(sessions.index.values[has_dset], ), :] - query = util.ensure_list(value if key == 'dataset' else '') + query = ensure_list(value if key == 'dataset' else '') # For each session check any dataset both contains query and exists mask = ( (datasets @@ -550,9 +552,8 @@ def _check_filesystem(self, datasets, offline=None, update_exists=True, check_ha Given a set of datasets, check whether records correctly reflect the filesystem. Called by load methods, this returns a list of file paths to load and return. - TODO This needs changing; overload for downloading? - This changes datasets frame, calls _update_cache(sessions=None, datasets=None) to - update and save tables. Download_datasets can also call this function. + This changes datasets frame, calls _update_cache(sessions=None, datasets=None) to + update and save tables. Download_datasets may also call this function. Parameters ---------- @@ -573,12 +574,18 @@ def _check_filesystem(self, datasets, offline=None, update_exists=True, check_ha """ if isinstance(datasets, pd.Series): datasets = pd.DataFrame([datasets]) + assert datasets.index.nlevels <= 2 + idx_names = ['eid', 'id'] if datasets.index.nlevels == 2 else ['id'] + datasets.index.set_names(idx_names, inplace=True) elif not isinstance(datasets, pd.DataFrame): # Cast set of dicts (i.e. from REST datasets endpoint) datasets = util.datasets2records(list(datasets)) + else: + datasets = datasets.copy() indices_to_download = [] # indices of datasets that need (re)downloading files = [] # file path list to return # If the session_path field is missing from the datasets table, fetch from sessions table + # Typically only aggregate frames contain this column if 'session_path' not in datasets.columns: if 'eid' not in datasets.index.names: # Get slice of full frame with eid in index @@ -619,20 +626,6 @@ def _check_filesystem(self, datasets, offline=None, update_exists=True, check_ha files.append(None) # Add this index to list of datasets that need downloading indices_to_download.append(i) - if rec['exists'] != file.exists(): - with warnings.catch_warnings(): - # Suppress future warning: exist column should always be present - msg = '.*indexing on a MultiIndex with a nested sequence of labels.*' - warnings.filterwarnings('ignore', message=msg) - datasets.at[i, 'exists'] = not rec['exists'] - if update_exists: - _logger.debug('Updating exists field') - if isinstance(i, tuple): - self._cache['datasets'].loc[i, 'exists'] = not rec['exists'] - else: # eid index level missing in datasets input - i = pd.IndexSlice[:, i] - self._cache['datasets'].loc[i, 'exists'] = not rec['exists'] - self._cache['_meta']['modified_time'] = datetime.now() # If online and we have datasets to download, call download_datasets with these datasets if not (offline or self.offline) and indices_to_download: @@ -643,14 +636,32 @@ def _check_filesystem(self, datasets, offline=None, update_exists=True, check_ha for i, file in zip(indices_to_download, new_files): files[datasets.index.get_loc(i)] = file + # NB: Currently if not offline and a remote file is missing, an exception will be raised + # before we reach this point. This could change in the future. 
+ exists = list(map(bool, files)) + if not all(datasets['exists'] == exists): + with warnings.catch_warnings(): + # Suppress future warning: exist column should always be present + msg = '.*indexing on a MultiIndex with a nested sequence of labels.*' + warnings.filterwarnings('ignore', message=msg) + datasets['exists'] = exists + if update_exists: + _logger.debug('Updating exists field') + i = datasets.index + if i.nlevels == 1: + # eid index level missing in datasets input + i = pd.IndexSlice[:, i] + self._cache['datasets'].loc[i, 'exists'] = exists + self._cache['_meta']['modified_time'] = datetime.now() + if self.record_loaded: loaded = np.fromiter(map(bool, files), bool) - loaded_ids = np.array(datasets.index.to_list())[loaded] + loaded_ids = datasets.index.get_level_values('id')[loaded].to_numpy() if '_loaded_datasets' not in self._cache: self._cache['_loaded_datasets'] = np.unique(loaded_ids) else: loaded_set = np.hstack([self._cache['_loaded_datasets'], loaded_ids]) - self._cache['_loaded_datasets'] = np.unique(loaded_set, axis=0) + self._cache['_loaded_datasets'] = np.unique(loaded_set) # Return full list of file paths return files @@ -701,7 +712,7 @@ def list_subjects(self) -> List[str]: @util.refresh def list_datasets( self, eid=None, filename=None, collection=None, revision=None, qc=QC.FAIL, - ignore_qc_not_set=False, details=False, query_type=None + ignore_qc_not_set=False, details=False, query_type=None, default_revisions_only=False ) -> Union[np.ndarray, pd.DataFrame]: """ Given an eid, return the datasets for those sessions. @@ -734,6 +745,9 @@ def list_datasets( relative paths (collection/revision/filename) - see one.alf.spec.describe for details. query_type : str Query cache ('local') or Alyx database ('remote'). + default_revisions_only : bool + When true, only matching datasets that are considered default revisions are returned. + If no 'default_revision' column is present, and ALFError is raised. Returns ------- @@ -763,6 +777,11 @@ def list_datasets( >>> datasets = one.list_datasets(eid, {'object': ['wheel', 'trial?']}) """ datasets = self._cache['datasets'] + if default_revisions_only: + if 'default_revision' not in datasets.columns: + raise alferr.ALFError('No default revisions specified') + datasets = datasets[datasets['default_revision']] + filter_args = dict( collection=collection, filename=filename, wildcards=self.wildcards, revision=revision, revision_last_before=False, assert_unique=False, qc=qc, @@ -1003,6 +1022,9 @@ def load_object(self, # For those that don't exist, download them offline = None if query_type == 'auto' else self.mode == 'local' + if datasets.index.nlevels == 1: + # Reinstate eid index + datasets = pd.concat({str(eid): datasets}, names=['eid']) files = self._check_filesystem(datasets, offline=offline, check_hash=check_hash) files = [x for x in files if x] if not files: @@ -1107,6 +1129,9 @@ def load_dataset(self, wildcards=self.wildcards, assert_unique=assert_unique) if len(datasets) == 0: raise alferr.ALFObjectNotFound(f'Dataset "{dataset}" not found') + if datasets.index.nlevels == 1: + # Reinstate eid index + datasets = pd.concat({str(eid): datasets}, names=['eid']) # Check files exist / download remote files offline = None if query_type == 'auto' else self.mode == 'local' @@ -1130,9 +1155,11 @@ def load_datasets(self, download_only: bool = False, check_hash: bool = True) -> Any: """ - Load datasets for a given session id. Returns two lists the length of datasets. 
The - first is the data (or file paths if download_data is false), the second is a list of - meta data Bunches. If assert_present is false, missing data will be returned as None. + Load datasets for a given session id. + + Returns two lists the length of datasets. The first is the data (or file paths if + download_data is false), the second is a list of meta data Bunches. If assert_present is + false, missing data will be returned as None. Parameters ---------- @@ -1164,9 +1191,9 @@ def load_datasets(self, Returns ------- list - A list of data (or file paths) the length of datasets + A list of data (or file paths) the length of datasets. list - A list of meta data Bunches. If assert_present is False, missing data will be None + A list of meta data Bunches. If assert_present is False, missing data will be None. Notes ----- @@ -1178,6 +1205,8 @@ def load_datasets(self, revision as separate keyword arguments. - To ensure you are loading the correct revision, use the revisions kwarg instead of relative paths. + - To load an exact revision (i.e. not the last revision before a given date), pass in + a list of relative paths or a data frame. Raises ------ @@ -1220,8 +1249,25 @@ def _verify_specifiers(specifiers): if isinstance(datasets, str): raise TypeError('`datasets` must be a non-string iterable') - # Check input args - collections, revisions = _verify_specifiers([collections, revisions]) + + # Check if rel paths have been used (e.g. the output of list_datasets) + is_frame = isinstance(datasets, pd.DataFrame) + if is_rel_paths := (is_frame or any('/' in x for x in datasets)): + if not (collections, revisions) == (None, None): + raise ValueError( + 'collection and revision kwargs must be None when dataset is a relative path') + if is_frame: + if 'eid' in datasets.index.names: + assert set(datasets.index.get_level_values('eid')) == {eid} + datasets = datasets['rel_path'].tolist() + datasets = list(map(partial(alfiles.rel_path_parts, as_dict=True), datasets)) + if len(datasets) > 0: + # Extract collection and revision from each of the parsed datasets + # None -> '' ensures exact collections and revisions are used in filter + # NB: f user passes in dicts, any collection/revision keys will be ignored. + collections, revisions = zip( + *((x.pop('collection') or '', x.pop('revision') or '') for x in datasets) + ) # Short circuit query_type = query_type or self.mode @@ -1235,35 +1281,49 @@ def _verify_specifiers(specifiers): if len(datasets) == 0: return None, all_datasets.iloc[0:0] # Return empty - # Filter and load missing - if self.wildcards: # Append extension wildcard if 'object.attribute' string - datasets = [x + ('.*' if isinstance(x, str) and len(x.split('.')) == 2 else '') - for x in datasets] + # More input validation + input_types = [(isinstance(x, str), isinstance(x, dict)) for x in datasets] + if not all(map(any, input_types)) or not any(map(all, zip(*input_types))): + raise ValueError('`datasets` must be iterable of only str or only dicts') + if self.wildcards and input_types[0][0]: # if wildcards and input is iter of str + # Append extension wildcard if 'object.attribute' string + datasets = [ + x + ('.*' if isinstance(x, str) and len(x.split('.')) == 2 else '') + for x in datasets + ] + + # Check input args + collections, revisions = _verify_specifiers([collections, revisions]) + # If collections provided in datasets list, e.g. 
[collection/x.y.z], do not assert unique - validate = not any(('/' if isinstance(d, str) else 'collection') in d for d in datasets) - if not validate and not all(x is None for x in collections + revisions): - raise ValueError( - 'collection and revision kwargs must be None when dataset is a relative path') - ops = dict(wildcards=self.wildcards, assert_unique=validate) + # If not a dataframe, use revision last before (we've asserted no revision in rel_path) + ops = dict( + wildcards=self.wildcards, assert_unique=True, revision_last_before=not is_rel_paths) slices = [util.filter_datasets(all_datasets, x, y, z, **ops) for x, y, z in zip(datasets, collections, revisions)] present = [len(x) == 1 for x in slices] present_datasets = pd.concat(slices) + if present_datasets.index.nlevels == 1: + # Reinstate eid index + present_datasets = pd.concat({str(eid): present_datasets}, names=['eid']) # Check if user is blindly downloading all data and warn of non-default revisions if 'default_revision' in present_datasets and \ - not any(revisions) and not all(present_datasets['default_revision']): + is_rel_paths and not all(present_datasets['default_revision']): old = present_datasets.loc[~present_datasets['default_revision'], 'rel_path'].to_list() warnings.warn( 'The following datasets may have been revised and ' + 'are therefore not recommended for analysis:\n\t' + '\n\t'.join(old) + '\n' - 'To avoid this warning, specify the revision as a kwarg or use load_dataset.' + 'To avoid this warning, specify the revision as a kwarg or use load_dataset.', + alferr.ALFWarning ) if not all(present): - missing_list = ', '.join(x for x, y in zip(datasets, present) if not y) - # FIXME include collection and revision also + missing_list = (x if isinstance(x, str) else to_alf(**x) for x in datasets) + missing_list = ('/'.join(filter(None, [c, f'#{r}#' if r else None, d])) + for c, r, d in zip(collections, revisions, missing_list)) + missing_list = ', '.join(x for x, y in zip(missing_list, present) if not y) message = f'The following datasets are not in the cache: {missing_list}' if assert_present: raise alferr.ALFObjectNotFound(message) @@ -1284,7 +1344,7 @@ def _verify_specifiers(specifiers): # Make list of metadata Bunches out of the table records = (present_datasets - .reset_index(names='id') + .reset_index(names=['eid', 'id']) .to_dict('records', into=Bunch)) # Ensure result same length as input datasets list @@ -1417,6 +1477,9 @@ def load_collection(self, if len(datasets) == 0: raise alferr.ALFObjectNotFound(object or '') parts = [alfiles.rel_path_parts(x) for x in datasets.rel_path] + if datasets.index.nlevels == 1: + # Reinstate eid index + datasets = pd.concat({str(eid): datasets}, names=['eid']) # For those that don't exist, download them offline = None if query_type == 'auto' else self.mode == 'local' @@ -1766,11 +1829,11 @@ def describe_dataset(self, dataset_type=None): @util.refresh def list_datasets( self, eid=None, filename=None, collection=None, revision=None, qc=QC.FAIL, - ignore_qc_not_set=False, details=False, query_type=None + ignore_qc_not_set=False, details=False, query_type=None, default_revisions_only=False ) -> Union[np.ndarray, pd.DataFrame]: filters = dict( - collection=collection, filename=filename, revision=revision, - qc=qc, ignore_qc_not_set=ignore_qc_not_set) + collection=collection, filename=filename, revision=revision, qc=qc, + ignore_qc_not_set=ignore_qc_not_set, default_revisions_only=default_revisions_only) if (query_type or self.mode) != 'remote': return super().list_datasets(eid, 
details=details, query_type=query_type, **filters) elif not eid: @@ -1785,6 +1848,7 @@ def list_datasets( if datasets is None or datasets.empty: return self._cache['datasets'].iloc[0:0] if details else [] # Return empty assert set(datasets.index.unique('eid')) == {eid} + del filters['default_revisions_only'] datasets = util.filter_datasets( datasets.droplevel('eid'), assert_unique=False, wildcards=self.wildcards, **filters) # Return only the relative path @@ -1825,8 +1889,7 @@ def list_aggregates(self, relation: str, identifier: str = None, all_aggregates = self.alyx.rest('datasets', 'list', django=query) records = (util.datasets2records(all_aggregates) .reset_index(level=0) - .drop('eid', axis=1) - .rename_axis(index={'id': 'did'})) + .drop('eid', axis=1)) # Since rel_path for public FI file records starts with 'public/aggregates' instead of just # 'aggregates', we should discard the file path parts before 'aggregates' (if present) records['rel_path'] = records['rel_path'].str.replace( @@ -1847,11 +1910,6 @@ def path2id(p) -> str: # NB: We avoid exact matches as most users will only include subject, not lab/subject records = records[records['identifier'].str.contains(identifier)] - # Add exists_aws field for download method - for i, rec in records.iterrows(): - fr = next(x['file_records'] for x in all_aggregates if x['url'].endswith(i)) - records.loc[i, 'exists_aws'] = any( - x['data_repository'].startswith('aws') and x['exists'] for x in fr) return util.filter_datasets(records, filename=dataset, revision=revision, wildcards=True, assert_unique=assert_unique) @@ -1907,6 +1965,7 @@ def load_aggregate(self, relation: str, identifier: str, raise alferr.ALFObjectNotFound( f'{dataset or "dataset"} not found for {relation}/{identifier}') # update_exists=False because these datasets are not in the cache table + records['session_path'] = '' # explicitly add session path column file, = self._check_filesystem(records, update_exists=False) if not file: raise alferr.ALFObjectNotFound('Dataset file not found on disk') @@ -2212,7 +2271,7 @@ def search(self, details=False, query_type=None, **kwargs): def _add_date(records): """Add date field for compatibility with One.search output.""" - for s in util.ensure_list(records): + for s in ensure_list(records): s['date'] = datetime.fromisoformat(s['start_time']).date() return records @@ -2243,9 +2302,12 @@ def _download_datasets(self, dsets, **kwargs) -> List[Path]: try: if not isinstance(dsets, pd.DataFrame): raise TypeError('Input datasets must be a pandas data frame for AWS download.') - if 'exists_aws' in dsets and np.all(np.equal(dsets['exists_aws'].values, True)): - _logger.info('Downloading from AWS') - return self._download_aws(map(lambda x: x[1], dsets.iterrows()), **kwargs) + assert 'exists_aws' not in dsets or np.all(np.equal(dsets['exists_aws'].values, True)) + _logger.debug('Downloading from AWS') + files = self._download_aws(map(lambda x: x[1], dsets.iterrows()), **kwargs) + # Trigger fallback download of any files missing on AWS + assert all(files), f'{sum(map(bool, files))} datasets not found on AWS' + return files except Exception as ex: _logger.debug(ex) return self._download_dataset(dsets, **kwargs) @@ -2281,7 +2343,7 @@ def _download_aws(self, dsets, update_exists=True, keep_uuid=None, **_) -> List[ assert self.mode != 'local' # Get all dataset URLs dsets = list(dsets) # Ensure not generator - uuids = [util.ensure_list(x.name)[-1] for x in dsets] + uuids = [ensure_list(x.name)[-1] for x in dsets] # If number of UUIDs is too high, fetch in 
loop to avoid 414 HTTP status code remote_records = [] N = 100 # Number of UUIDs per query @@ -2295,15 +2357,18 @@ def _download_aws(self, dsets, update_exists=True, keep_uuid=None, **_) -> List[ # Fetch file record path record = next((x for x in record['file_records'] if x['data_repository'].startswith('aws') and x['exists']), None) - if not record and update_exists and 'exists_aws' in self._cache['datasets']: - _logger.debug('Updating exists field') - self._cache['datasets'].loc[(slice(None), uuid), 'exists_aws'] = False - self._cache['_meta']['modified_time'] = datetime.now() + if not record: + if update_exists and 'exists_aws' in self._cache['datasets']: + _logger.debug('Updating exists field') + self._cache['datasets'].loc[(slice(None), uuid), 'exists_aws'] = False + self._cache['_meta']['modified_time'] = datetime.now() out_files.append(None) continue + assert record['relative_path'].endswith(dset['rel_path']), \ + f'Relative path for dataset {uuid} does not match Alyx record' source_path = PurePosixPath(record['data_repository_path'], record['relative_path']) source_path = alfiles.add_uuid_string(source_path, uuid) - local_path = self.cache_dir.joinpath(dset['session_path'], dset['rel_path']) + local_path = self.cache_dir.joinpath(record['relative_path']) if keep_uuid is True or (keep_uuid is None and self.uuid_filenames is True): local_path = alfiles.add_uuid_string(local_path, uuid) local_path.parent.mkdir(exist_ok=True, parents=True) @@ -2352,7 +2417,7 @@ def _dset2url(self, dset, update_cache=True): did = dset['id'] elif 'file_records' not in dset: # Convert dataset Series to alyx dataset dict url = self.record2url(dset) # NB: URL will always be returned but may not exist - did = util.ensure_list(dset.name)[-1] + did = ensure_list(dset.name)[-1] else: # from datasets endpoint repo = getattr(getattr(self._web_client, '_par', None), 'HTTP_DATA_SERVER', None) url = next( @@ -2366,7 +2431,7 @@ def _dset2url(self, dset, update_cache=True): _logger.debug('Updating cache') # NB: This will be considerably easier when IndexSlice supports Ellipsis idx = [slice(None)] * int(self._cache['datasets'].index.nlevels / 2) - self._cache['datasets'].loc[(*idx, *util.ensure_list(did)), 'exists'] = False + self._cache['datasets'].loc[(*idx, *ensure_list(did)), 'exists'] = False self._cache['_meta']['modified_time'] = datetime.now() return url @@ -2460,7 +2525,7 @@ def _download_file(self, url, target_dir, keep_uuid=None, file_size=None, hash=N """ assert not self.offline # Ensure all target directories exist - [Path(x).mkdir(parents=True, exist_ok=True) for x in set(util.ensure_list(target_dir))] + [Path(x).mkdir(parents=True, exist_ok=True) for x in set(ensure_list(target_dir))] # download file(s) from url(s), returns file path(s) with UUID local_path, md5 = self.alyx.download_file(url, target_dir=target_dir, return_md5=True) @@ -2469,7 +2534,7 @@ def _download_file(self, url, target_dir, keep_uuid=None, file_size=None, hash=N if isinstance(url, (tuple, list)): assert (file_size is None) or len(file_size) == len(url) assert (hash is None) or len(hash) == len(url) - for args in zip(*map(util.ensure_list, (file_size, md5, hash, local_path, url))): + for args in zip(*map(ensure_list, (file_size, md5, hash, local_path, url))): self._check_hash_and_file_size_mismatch(*args) # check if we are keeping the uuid on the list of file names diff --git a/one/converters.py b/one/converters.py index dbfc8c43..20b80ea7 100644 --- a/one/converters.py +++ b/one/converters.py @@ -13,15 +13,14 @@ from uuid import UUID 
from inspect import unwrap from pathlib import Path, PurePosixPath -from urllib.parse import urlsplit from typing import Optional, Union, Mapping, List, Iterable as Iter import pandas as pd from iblutil.util import Bunch from one.alf.spec import is_session_path, is_uuid_string -from one.alf.files import get_session_path, add_uuid_string, session_path_parts, remove_uuid_string -from .util import Listable, ensure_list +from one.alf.files import get_session_path, add_uuid_string, session_path_parts, get_alf_path +from .util import Listable def recurse(func): @@ -232,39 +231,43 @@ def path2record(self, path) -> pd.Series: A cache file record """ is_session = is_session_path(path) - rec = self._cache['sessions' if is_session else 'datasets'] - if rec.empty: - return - # if (rec := self._cache['datasets']).empty: # py 3.8 - # return + if self._cache['sessions' if is_session else 'datasets'].empty: + return # short circuit: no records in the cache if is_session_path(path): lab, subject, date, number = session_path_parts(path) - rec = rec[ - (rec['lab'] == lab) & (rec['subject'] == subject) & - (rec['number'] == int(number)) & - (rec['date'] == datetime.date.fromisoformat(date)) + df = self._cache['sessions'] + rec = df[ + (df['lab'] == lab) & (df['subject'] == subject) & + (df['number'] == int(number)) & + (df['date'] == datetime.date.fromisoformat(date)) ] return None if rec.empty else rec.squeeze() - # Deal with file path - if isinstance(path, str) and path.startswith('http'): - # Remove the UUID from path - path = urlsplit(path).path.strip('/') - path = remove_uuid_string(PurePosixPath(path)) - session_path = get_session_path(path).as_posix() - else: - # No way of knowing root session path parts without cache tables - eid = self.path2eid(path) - session_series = self.list_datasets(eid, details=True).session_path - if not eid or session_series.empty: + # Deal with dataset path + if isinstance(path, str): + path = Path(path) + # If there's a UUID in the path, use that to fetch the record + name_parts = path.stem.split('.') + if is_uuid_string(uuid := name_parts[-1]): + try: + return self._cache['datasets'].loc[pd.IndexSlice[:, uuid], :].squeeze() + except KeyError: return - session_path, *_ = session_series - rec = rec[rec['session_path'] == session_path] - rec = rec[rec['rel_path'].apply(lambda x: path.as_posix().endswith(x))] + # Fetch via session record + eid = self.path2eid(path) + df = self.list_datasets(eid, details=True) + if not eid or df.empty: + return + + # Find row where relative path matches + rec = df[df['rel_path'] == path.relative_to(get_session_path(path)).as_posix()] assert len(rec) < 2, 'Multiple records found' - return None if rec.empty else rec.squeeze() + if rec.empty: + return None + # Convert slice to series and reinstate eid index if dropped + return rec.squeeze().rename(index=(eid, rec.index.get_level_values('id')[0])) @recurse def path2url(self, filepath): @@ -313,19 +316,18 @@ def record2url(self, record): session_spec = '{lab}/Subjects/{subject}/{date}/{number:03d}' url = record.get('session_path') or session_spec.format(**record) return webclient.rel_path2url(url) - uuid = ensure_list(record.name)[-1] # may be (eid, did) or simply did else: raise TypeError( f'record must be pandas.DataFrame or pandas.Series, got {type(record)} instead') - - session_path, rel_path = record[['session_path', 'rel_path']].to_numpy().flatten() - url = PurePosixPath(session_path, rel_path) + assert isinstance(record.name, tuple) and len(record.name) == 2 + eid, uuid = record.name # must 
be (eid, did) + session_path = self.eid2path(eid) + url = PurePosixPath(get_alf_path(session_path), record['rel_path']) return webclient.rel_path2url(add_uuid_string(url, uuid).as_posix()) def record2path(self, dataset) -> Optional[Path]: """ - Given a set of dataset records, checks the corresponding exists flag in the cache - correctly reflects the files system. + Given a set of dataset records, returns the corresponding paths Parameters ---------- @@ -337,13 +339,19 @@ def record2path(self, dataset) -> Optional[Path]: pathlib.Path File path for the record """ - assert isinstance(dataset, pd.Series) or len(dataset) == 1 - session_path, rel_path = dataset[['session_path', 'rel_path']].to_numpy().flatten() - file = Path(self.cache_dir, session_path, rel_path) + if isinstance(dataset, pd.DataFrame): + return [self.record2path(r) for _, r in dataset.iterrows()] + elif not isinstance(dataset, pd.Series): + raise TypeError( + f'record must be pandas.DataFrame or pandas.Series, got {type(dataset)} instead') + assert isinstance(dataset.name, tuple) and len(dataset.name) == 2 + eid, uuid = dataset.name # must be (eid, did) + if not (session_path := self.eid2path(eid)): + raise ValueError(f'Failed to determine session path for eid "{eid}"') + file = session_path / dataset['rel_path'] if self.uuid_filenames: - i = dataset.name if isinstance(dataset, pd.Series) else dataset.index[0] - file = add_uuid_string(file, i[1] if isinstance(i, tuple) else i) - return file # files[0] if len(datasets) == 1 else files + file = add_uuid_string(file, uuid) + return file @recurse def eid2ref(self, eid: Union[str, Iter], as_dict=True, parse=True) \ diff --git a/one/params.py b/one/params.py index 47409f9e..6de042cc 100644 --- a/one/params.py +++ b/one/params.py @@ -158,7 +158,8 @@ def setup(client=None, silent=False, make_default=None, username=None, cache_dir # Prompt for cache directory (default may have changed after prompt) client_key = _key_from_url(par.ALYX_URL) - cache_dir = cache_dir or Path(CACHE_DIR_DEFAULT, client_key) + def_cache_dir = cache_map.CLIENT_MAP.get(client_key) or Path(CACHE_DIR_DEFAULT, client_key) + cache_dir = cache_dir or def_cache_dir prompt = f'Enter the location of the download cache, current value is ["{cache_dir}"]:' cache_dir = input(prompt) or cache_dir @@ -185,7 +186,9 @@ def setup(client=None, silent=False, make_default=None, username=None, cache_dir # Precedence: user provided cache_dir; previously defined; the default location default_cache_dir = Path(CACHE_DIR_DEFAULT, client_key) cache_dir = cache_dir or cache_map.CLIENT_MAP.get(client_key, default_cache_dir) - par = par_current + # Use current params but drop any extras (such as the TOKEN or ALYX_PWD field) + keep_keys = par_default.as_dict().keys() + par = iopar.from_dict({k: v for k, v in par_current.as_dict().items() if k in keep_keys}) if any(v for k, v in cache_map.CLIENT_MAP.items() if k != client_key and v == cache_dir): warnings.warn('Warning: the directory provided is already a cache for another URL.') diff --git a/one/registration.py b/one/registration.py index ae85eb31..900a813e 100644 --- a/one/registration.py +++ b/one/registration.py @@ -23,14 +23,13 @@ import requests.exceptions from iblutil.io import hashfile -from iblutil.util import Bunch +from iblutil.util import Bunch, ensure_list import one.alf.io as alfio from one.alf.files import session_path_parts, get_session_path, folder_parts, filename_parts from one.alf.spec import is_valid import one.alf.exceptions as alferr from one.api import ONE -from one.util 
import ensure_list from one.webclient import no_cache _logger = logging.getLogger(__name__) diff --git a/one/remote/globus.py b/one/remote/globus.py index 161ef310..f9626eb5 100644 --- a/one/remote/globus.py +++ b/one/remote/globus.py @@ -97,12 +97,12 @@ from globus_sdk import TransferAPIError, GlobusAPIError, NetworkError, GlobusTimeoutError, \ GlobusConnectionError, GlobusConnectionTimeoutError, GlobusSDKUsageError, NullAuthorizer from iblutil.io import params as iopar +from iblutil.util import ensure_list from one.alf.spec import is_uuid from one.alf.files import remove_uuid_string import one.params from one.webclient import AlyxClient -from one.util import ensure_list from .base import DownloadClient, load_client_params, save_client_params __all__ = ['Globus', 'get_lab_from_endpoint_id', 'as_globus_path'] diff --git a/one/tests/alf/test_alf_io.py b/one/tests/alf/test_alf_io.py index 21e56e0d..cfe05b33 100644 --- a/one/tests/alf/test_alf_io.py +++ b/one/tests/alf/test_alf_io.py @@ -727,7 +727,7 @@ def setUp(self): self.session_path.joinpath(f'bar.baz_y.{uuid.uuid4()}.npy'), self.session_path.joinpath('#2021-01-01#', f'bar.baz.{uuid.uuid4()}.npy'), self.session_path.joinpath('task_00', 'x.y.z'), - self.session_path.joinpath('x.y.z'), + self.session_path.joinpath('x.y.z') ] for f in self.dsets: f.parent.mkdir(exist_ok=True, parents=True) diff --git a/one/tests/alf/test_cache.py b/one/tests/alf/test_cache.py index f729a06a..05c894c3 100644 --- a/one/tests/alf/test_cache.py +++ b/one/tests/alf/test_cache.py @@ -81,7 +81,6 @@ def test_datasets_df(self): print('Datasets dataframe') print(df) dset_info = df.loc[0].to_dict() - self.assertEqual(dset_info['session_path'], self.rel_ses_path[:-1]) self.assertEqual(dset_info['rel_path'], self.rel_ses_files[0].as_posix()) self.assertTrue(dset_info['file_size'] > 0) self.assertFalse(df.rel_path.str.contains('invalid').any()) @@ -100,7 +99,6 @@ def tests_db(self): df_dsets, metadata2 = parquet.load(fn_dsets) self.assertEqual(metadata2, metadata_exp) dset_info = df_dsets.loc[0].to_dict() - self.assertEqual(dset_info['session_path'], self.rel_ses_path[:-1]) self.assertEqual(dset_info['rel_path'], self.rel_ses_files[0].as_posix()) # Check behaviour when no files found @@ -115,12 +113,9 @@ def tests_db(self): apt.make_parquet_db(self.tmpdir, hash_ids=False, lab='another') # Create some more datasets in a session folder outside of a lab directory - dsets = revisions_datasets_table() with tempfile.TemporaryDirectory() as tdir: - for session_path, rel_path in dsets[['session_path', 'rel_path']].values: - filepath = Path(tdir).joinpath(session_path, rel_path) - filepath.parent.mkdir(exist_ok=True, parents=True) - filepath.touch() + session_path = Path(tdir).joinpath('subject', '1900-01-01', '001') + _ = revisions_datasets_table(touch_path=session_path) # create some files fn_ses, _ = apt.make_parquet_db(tdir, hash_ids=False, lab='another') df_ses, _ = parquet.load(fn_ses) self.assertTrue((df_ses['lab'] == 'another').all()) diff --git a/one/tests/remote/test_globus.py b/one/tests/remote/test_globus.py index 47943d9b..6fa99a76 100644 --- a/one/tests/remote/test_globus.py +++ b/one/tests/remote/test_globus.py @@ -107,7 +107,7 @@ def test_as_globus_path(self): # Only test this on windows if sys.platform == 'win32': actual = globus.as_globus_path('/foo/bar') - self.assertEqual(actual, f'/{Path.cwd().drive[0]}/foo/bar') + self.assertEqual(actual, f'/{Path.cwd().drive[0].upper()}/foo/bar') # On all systems an explicit Windows path should be converted to a POSIX one 
actual = globus.as_globus_path(PureWindowsPath('E:\\FlatIron\\integration')) diff --git a/one/tests/test_alyxclient.py b/one/tests/test_alyxclient.py index 728c033c..09615e6e 100644 --- a/one/tests/test_alyxclient.py +++ b/one/tests/test_alyxclient.py @@ -1,6 +1,7 @@ """Unit tests for the one.webclient module""" import unittest from unittest import mock +import urllib.parse import random import os import one.webclient as wc @@ -51,10 +52,17 @@ def test_authentication(self): ac.authenticate() mock_input.assert_not_called() self.assertTrue(ac.is_logged_in) + + # When password is None and in silent mode, there should be a warning + # followed by a failed login attempt + ac._par = ac._par.set('ALYX_PWD', None) + ac.logout() + with self.assertWarns(UserWarning), self.assertRaises(requests.HTTPError): + self.ac.authenticate(password=None) + # Test using input args ac._par = iopar.from_dict({k: v for k, v in ac._par.as_dict().items() if k not in login_keys}) - ac.logout() with mock.patch('builtins.input') as mock_input: ac.authenticate(TEST_DB_2['username'], TEST_DB_2['password'], cache_token=False) mock_input.assert_not_called() @@ -76,6 +84,16 @@ def test_authentication(self): with mock.patch('one.webclient.getpass', return_value=TEST_DB_2['password']) as mock_pwd: ac.authenticate(cache_token=True, force=True) mock_pwd.assert_called() + # If a password is passed, should always force re-authentication + rep = requests.Response() + rep.status_code = 200 + rep.json = lambda **_: {'token': 'abc'} + assert self.ac.is_logged_in + with mock.patch('one.webclient.requests.post', return_value=rep) as m: + self.ac.authenticate(password='foo', force=False) + expected = {'username': TEST_DB_2['username'], 'password': 'foo'} + m.assert_called_once_with(TEST_DB_2['base_url'] + '/auth-token', data=expected) + # Check non-silent double logout ac.logout() ac.logout() # Shouldn't complain @@ -462,6 +480,82 @@ def test_cache_dir_setter(self): finally: ac._par = ac._par.set('CACHE_DIR', prev_path) + def test_paginated_response(self): + """Test the _PaginatedResponse class.""" + alyx = mock.Mock(spec_set=ac) + N, lim = 2000, 250 # 2000 results, 250 records per page + url = ac.base_url + f'/?foo=bar&offset={lim}&limit={lim}' + res = {'count': N, 'next': url, 'previous': None, 'results': []} + res['results'] = [{'id': i} for i in range(lim)] + alyx._generic_request.return_value = res + # Check initialization + pg = wc._PaginatedResponse(alyx, res, cache_args=dict(clobber=True)) + self.assertEqual(pg.count, N) + self.assertEqual(len(pg), N) + self.assertEqual(pg.limit, lim) + self.assertEqual(len(pg._cache), N) + self.assertEqual(pg._cache[:lim], res['results']) + self.assertTrue(not any(pg._cache[lim:])) + self.assertIs(pg.alyx, alyx) + + # Check fetching cached item with +ve int + self.assertEqual({'id': 1}, pg[1]) + alyx._generic_request.assert_not_called() + # Check fetching cached item with +ve slice + self.assertEqual([{'id': 1}, {'id': 2}], pg[1:3]) + alyx._generic_request.assert_not_called() + # Check fetching cached item with -ve int + self.assertEqual({'id': 100}, pg[-1900]) + alyx._generic_request.assert_not_called() + # Check fetching cached item with -ve slice + self.assertEqual([{'id': 100}, {'id': 101}], pg[-1900:-1898]) + alyx._generic_request.assert_not_called() + # Check fetching uncached item with +ve int + n = offset = lim + res['results'] = [{'id': i} for i in range(offset, offset + lim)] + assert not any(pg._cache[offset:offset + lim]) + self.assertEqual({'id': lim}, pg[n]) + 
self.assertEqual(res['results'], pg._cache[offset:offset + lim]) + alyx._generic_request.assert_called_once_with(requests.get, mock.ANY, clobber=True) + self._check_get_query(alyx._generic_request.call_args, lim, offset) + # Check fetching uncached item with -ve int + offset = lim * 3 + res['results'] = [{'id': i} for i in range(offset, offset + lim)] + n = offset - N + 2 + assert not any(pg._cache[offset:offset + lim]) + self.assertEqual({'id': N + n}, pg[n]) + self.assertEqual(res['results'], pg._cache[offset:offset + lim]) + alyx._generic_request.assert_called_with(requests.get, mock.ANY, clobber=True) + self._check_get_query(alyx._generic_request.call_args, lim, offset) + # Check fetching uncached item with +ve slice + offset = lim * 5 + res['results'] = [{'id': i} for i in range(offset, offset + lim)] + n = offset + 20 + assert not any(pg._cache[offset:offset + lim]) + self.assertEqual([{'id': n}, {'id': n + 1}], pg[n:n + 2]) + self.assertEqual(res['results'], pg._cache[offset:offset + lim]) + alyx._generic_request.assert_called_with(requests.get, mock.ANY, clobber=True) + self._check_get_query(alyx._generic_request.call_args, lim, offset) + # Check fetching uncached item with -ve slice + offset = N - lim + res['results'] = [{'id': i} for i in range(offset, offset + lim)] + assert not any(pg._cache[offset:offset + lim]) + self.assertEqual([{'id': N - 2}, {'id': N - 1}], pg[-2:]) + self.assertEqual(res['results'], pg._cache[offset:offset + lim]) + alyx._generic_request.assert_called_with(requests.get, mock.ANY, clobber=True) + self._check_get_query(alyx._generic_request.call_args, lim, offset) + # At this point, there should be a certain number of None values left + self.assertEqual(expected_calls := 4, alyx._generic_request.call_count) + self.assertEqual((expected_calls + 1) * lim, sum(list(map(bool, pg._cache)))) + + def _check_get_query(self, call_args, limit, offset): + """Check URL get query contains the expected limit and offset params.""" + (_, url), _ = call_args + self.assertTrue(url.startswith(ac.base_url)) + query = urllib.parse.parse_qs(urllib.parse.urlparse(url).query) + expected = {'foo': ['bar'], 'offset': [str(offset)], 'limit': [str(limit)]} + self.assertDictEqual(query, expected) + if __name__ == '__main__': unittest.main(exit=False, verbosity=2) diff --git a/one/tests/test_converters.py b/one/tests/test_converters.py index 27360417..79aa0b78 100644 --- a/one/tests/test_converters.py +++ b/one/tests/test_converters.py @@ -2,7 +2,7 @@ import unittest from unittest import mock from pathlib import Path, PurePosixPath, PureWindowsPath -from uuid import UUID +from uuid import UUID, uuid4 import datetime import pandas as pd @@ -116,11 +116,17 @@ def test_path2record(self): self.assertTrue(file.as_posix().endswith(rec['rel_path'])) # Test URL - parts = add_uuid_string(file, '94285bfd-7500-4583-83b1-906c420cc667').parts[-7:] + uuid = '6cbb724e-c7ec-4eab-b24b-555001502d10' + parts = add_uuid_string(file, uuid).parts[-7:] url = TEST_DB_2['base_url'] + '/'.join(('', *parts)) rec = self.one.path2record(url) self.assertIsInstance(rec, pd.Series) self.assertTrue(file.as_posix().endswith(rec['rel_path'])) + # With a UUID missing from cache, should return None + uuid = '94285bfd-7500-4583-83b1-906c420cc667' + parts = add_uuid_string(file, uuid).parts[-7:] + url = TEST_DB_2['base_url'] + '/'.join(('', *parts)) + self.assertIsNone(self.one.path2record(url)) file = file.parent / '_fake_obj.attr.npy' self.assertIsNone(self.one.path2record(file)) @@ -277,14 +283,18 @@ def 
test_record2path(self): # As pd.DataFrame idx = rec.rel_path == 'alf/probe00/_phy_spikes_subset.channels.npy' path = self.one.record2path(rec[idx]) - self.assertEqual(expected, path) + self.assertEqual([expected], path) + # Test validation + self.assertRaises(AssertionError, self.one.record2path, rec[idx].droplevel(0)) # no eid + self.assertRaises(TypeError, self.one.record2path, rec[idx].to_dict()) + unknown = rec[idx].squeeze().rename(index=(str(uuid4()), data_id)) + self.assertRaises(ValueError, self.one.record2path, unknown) # unknown eid # With UUID in file name try: self.one.uuid_filenames = True expected = expected.with_suffix(f'.{data_id}.npy') - self.assertEqual(expected, self.one.record2path(rec[idx])) # as pd.DataFrame + self.assertEqual([expected], self.one.record2path(rec[idx])) # as pd.DataFrame self.assertEqual(expected, self.one.record2path(rec[idx].squeeze())) # as pd.Series - self.assertEqual(expected, self.one.record2path(rec[idx].droplevel(0))) # no eid finally: self.one.uuid_filenames = False diff --git a/one/tests/test_one.py b/one/tests/test_one.py index e978a2ea..c88351be 100644 --- a/one/tests/test_one.py +++ b/one/tests/test_one.py @@ -48,7 +48,7 @@ from one.api import ONE, One, OneAlyx from one.util import ( ses2records, validate_date_range, index_last_before, filter_datasets, _collection_spec, - filter_revision_last_before, parse_id, autocomplete, LazyId, datasets2records + filter_revision_last_before, parse_id, autocomplete, LazyId, datasets2records, ensure_list ) import one.params import one.alf.exceptions as alferr @@ -293,15 +293,22 @@ def test_filter(self): datasets['rel_path'] = revisions # Should return last revision before date for each collection/dataset + # These comprise mixed revisions which should trigger an ALF warning revision = '2020-09-06' - verifiable = filter_datasets(datasets, None, None, revision, assert_unique=False) + expected_warn = 'Multiple revisions: "2020-08-31", "2020-01-01"' + with self.assertWarnsRegex(alferr.ALFWarning, expected_warn): + verifiable = filter_datasets(datasets, None, None, revision, assert_unique=False) self.assertEqual(2, len(verifiable)) self.assertTrue(all(x.split('#')[1] < revision for x in verifiable['rel_path'])) - # Should return single dataset with last revision when default specified - with self.assertRaises(alferr.ALFMultipleRevisionsFound): - filter_datasets(datasets, '*spikes.times*', 'alf/probe00', None, - assert_unique=True, wildcards=True, revision_last_before=True) + # With no default_revision column there should be a warning about returning the latest + # revision when no revision is provided. 
+ with self.assertWarnsRegex(alferr.ALFWarning, 'No default revision for dataset'): + verifiable = filter_datasets( + datasets, '*spikes.times*', 'alf/probe00', None, + assert_unique=True, wildcards=True, revision_last_before=True) + self.assertEqual(1, len(verifiable)) + self.assertTrue(verifiable['rel_path'].str.contains('#2021-xx-xx#').all()) # Should return matching revision verifiable = filter_datasets(datasets, None, None, r'2020-08-\d{2}', @@ -342,8 +349,8 @@ def test_filter(self): assert_unique=True, wildcards=True, revision_last_before=True) self.assertEqual(verifiable.rel_path.to_list(), ['alf/probe00/#2020-01-01#/spikes.times.npy']) - # When revision_last_before is false, expect multiple objects error - with self.assertRaises(alferr.ALFMultipleObjectsFound): + # When revision_last_before is false, expect multiple revisions error + with self.assertRaises(alferr.ALFMultipleRevisionsFound): filter_datasets(datasets, '*spikes.times*', 'alf/probe00', None, assert_unique=True, wildcards=True, revision_last_before=False) @@ -355,12 +362,29 @@ def test_filter_wildcards(self): assert_unique=False, wildcards=True) self.assertTrue(len(verifiable) == 2) # As dict with list (should act as logical OR) + kwargs = dict(assert_unique=False, revision_last_before=False, wildcards=True) dataset = dict(attribute=['timestamp?', 'raw']) - verifiable = filter_datasets(datasets, dataset, None, None, - assert_unique=False, revision_last_before=False, - wildcards=True) + verifiable = filter_datasets(datasets, dataset, None, None, **kwargs) self.assertEqual(4, len(verifiable)) + # Test handling greedy captures of collection parts when there are wildcards at the start + # of the filename pattern. + + # Add some identical files that exist in collections and sub-collections + # (i.e. raw_ephys_data, raw_ephys_data/probe00, raw_ephys_data/probe01) + all_datasets = self.one._cache.datasets + meta_datasets = all_datasets[all_datasets.rel_path.str.contains('meta')].copy() + datasets = pd.concat([datasets, meta_datasets]) + + # Matching *meta should not capture raw_ephys_data/probe00, etc. 
+ verifiable = filter_datasets(datasets, '*.meta', 'raw_ephys_data', None, **kwargs) + expected = ['raw_ephys_data/_spikeglx_ephysData_g0_t0.nidq.meta'] + self.assertCountEqual(expected, verifiable.rel_path) + verifiable = filter_datasets(datasets, '*.meta', 'raw_ephys_data/probe??', None, **kwargs) + self.assertEqual(2, len(verifiable)) + verifiable = filter_datasets(datasets, '*.meta', 'raw_ephys_data*', None, **kwargs) + self.assertEqual(3, len(verifiable)) + def test_list_datasets(self): """Test One.list_datasets""" # test filename @@ -402,6 +426,25 @@ def test_list_datasets(self): self.assertIsInstance(dsets, list) self.assertTrue(len(dsets) == np.unique(dsets).size) + # Test default_revisions_only=True + with self.assertRaises(alferr.ALFError): # should raise as no 'default_revision' column + self.one.list_datasets('KS005/2019-04-02/001', default_revisions_only=True) + # Add the column and add some alternates + datasets = util.revisions_datasets_table(collections=['alf'], revisions=['', '2023-01-01']) + datasets['default_revision'] = [False, True] * 2 + self.one._cache.datasets['default_revision'] = True + self.one._cache.datasets = pd.concat([self.one._cache.datasets, datasets]).sort_index() + eid, *_ = datasets.index.get_level_values(0) + dsets = self.one.list_datasets(eid, 'spikes.*', default_revisions_only=False) + self.assertEqual(4, len(dsets)) + dsets = self.one.list_datasets(eid, 'spikes.*', default_revisions_only=True) + self.assertEqual(2, len(dsets)) + self.assertTrue(all('#2023-01-01#' in x for x in dsets)) + # Should be the same with details=True + dsets = self.one.list_datasets(eid, 'spikes.*', default_revisions_only=True, details=True) + self.assertEqual(2, len(dsets)) + self.assertTrue(all('#2023-01-01#' in x for x in dsets.rel_path)) + def test_list_collections(self): """Test One.list_collections""" # Test no eid @@ -482,23 +525,46 @@ def test_get_details(self): def test_check_filesystem(self): """Test for One._check_filesystem. + Most is already covered by other tests, this checks that it can deal with dataset frame without eid index and without a session_path column. 
""" # Get two eids to test eids = self.one._cache['datasets'].index.get_level_values(0)[[0, -1]] - datasets = self.one._cache['datasets'].loc[eids].drop('session_path', axis=1) + datasets = self.one._cache['datasets'].loc[eids] files = self.one._check_filesystem(datasets) self.assertEqual(53, len(files)) + # Expect same number of unique session paths as eids session_paths = set(map(lambda p: p.parents[1], files)) self.assertEqual(len(eids), len(session_paths)) expected = map(lambda x: '/'.join(x.parts[-3:]), session_paths) session_parts = self.one._cache['sessions'].loc[eids, ['subject', 'date', 'number']].values self.assertCountEqual(expected, map(lambda x: f'{x[0]}/{x[1]}/{x[2]:03}', session_parts)) - # Attempt the same with the eid index missing - datasets = datasets.droplevel(0).drop('session_path', axis=1) + + # Test a very rare occurence of a missing dataset with eid index missing + # but session_path column present + idx = self.one._cache.datasets.index[(i := 5)] # pick a random dataset to make vanish + _eid2path = { + e: self.one.eid2path(e).relative_to(self.one.cache_dir).as_posix() for e in eids + } + session_paths = list(map(_eid2path.get, datasets.index.get_level_values(0))) + datasets['session_path'] = session_paths + datasets = datasets.droplevel(0) + files[(i := 5)].unlink() + # For this check the current state should be exists==True in the main cache + assert self.one._cache.datasets.loc[idx, 'exists'].all() + _files = self.one._check_filesystem(datasets, update_exists=True) + self.assertIsNone(_files[i]) + self.assertFalse( + self.one._cache.datasets.loc[idx, 'exists'].all(), 'failed to update cache exists') + files[i].touch() # restore file for next check + + # Attempt to load datasets with both eid index + # and session_path column missing (most common) + datasets = datasets.drop('session_path', axis=1) self.assertEqual(files, self.one._check_filesystem(datasets)) + # Test with uuid_filenames as True self.one.uuid_filenames = True try: @@ -611,22 +677,49 @@ def test_load_datasets(self): files, meta = self.one.load_datasets(eid, dsets, download_only=True) self.assertTrue(all(isinstance(x, Path) for x in files)) + # Check behaviour when loading with a data frame (undocumented) + eid = '01390fcc-4f86-4707-8a3b-4d9309feb0a1' + datasets = self.one._cache.datasets.loc[([eid],), :].iloc[:3, :] + files, meta = self.one.load_datasets(eid, datasets, download_only=True) + self.assertTrue(all(isinstance(x, Path) for x in files)) + # Should raise when data frame contains a different eid + self.assertRaises(AssertionError, self.one.load_datasets, uuid4(), datasets) + + # Mix of str and dict + # Check download only + dsets = [ + spec.regex(spec.FILE_SPEC).match('_ibl_wheel.position.npy').groupdict(), + '_ibl_wheel.timestamps.npy' + ] + with self.assertRaises(ValueError): + self.one.load_datasets('KS005/2019-04-02/001', dsets, assert_present=False) + # Loading of non default revisions without using the revision kwarg causes user warning. - # With relative paths provided as input, dataset uniqueness validation is supressed. + # With relative paths provided as input, dataset uniqueness validation is suppressed. 
+ eid = self.one._cache.sessions.iloc[0].name datasets = util.revisions_datasets_table( - revisions=('', '2020-01-08'), attributes=('times',)) + revisions=('', '2020-01-08'), attributes=('times',), touch_path=self.one.eid2path(eid)) datasets['default_revision'] = [False, True] * 3 - eid = datasets.iloc[0].name[0] + datasets.index = datasets.index.set_levels([eid], level=0) self.one._cache.datasets = datasets - with self.assertWarns(UserWarning): - self.one.load_datasets(eid, datasets['rel_path'].to_list(), - download_only=True, assert_present=False) + with self.assertWarns(alferr.ALFWarning): + self.one.load_datasets(eid, datasets['rel_path'].to_list(), download_only=True) # When loading without collections in the dataset list (i.e. just the dataset names) # an exception should be raised when datasets belong to multiple collections. self.assertRaises( alferr.ALFMultipleCollectionsFound, self.one.load_datasets, eid, ['spikes.times']) + # Ensure that when rel paths are passed, a null collection/revision is not interpreted as + # an ANY. NB this means the output of 'spikes.times.npy' will be different depending on + # weather other datasets in list include a collection or revision. + self.one._cache.datasets = datasets.iloc[:2, :].copy() # only two datasets, one default + (file,), _ = self.one.load_datasets(eid, ['spikes.times.npy', ], download_only=True) + self.assertTrue(file.as_posix().endswith('001/#2020-01-08#/spikes.times.npy')) + (file, _), _ = self.one.load_datasets( + eid, ['spikes.times.npy', 'xx/obj.attr.ext'], download_only=True, assert_present=False) + self.assertTrue(file.as_posix().endswith('001/spikes.times.npy')) + def test_load_dataset_from_id(self): """Test One.load_dataset_from_id""" uid = np.array([[-9204203870374650458, -6411285612086772563]]) @@ -836,13 +929,24 @@ def test_update_cache_from_records(self): # Check behaviour when columns don't match datasets.loc[:, 'exists'] = ~datasets.loc[:, 'exists'] datasets['extra_column'] = True + self.one._cache.datasets['foo_bar'] = 12 # this column is missing in our new records self.one._cache.datasets['new_column'] = False - self.addCleanup(self.one._cache.datasets.drop, 'new_column', axis=1, inplace=True) + self.addCleanup(self.one._cache.datasets.drop, 'foo_bar', axis=1, inplace=True) + # An exception is exists_* as the Alyx cache contains exists_aws and exists_flatiron + # These should simply be filled with the values of exists as Alyx won't return datasets + # that don't exist on FlatIron and if they don't exist on AWS it falls back to this. 
+ self.one._cache.datasets['exists_aws'] = False with self.assertRaises(AssertionError): self.one._update_cache_from_records(datasets=datasets, strict=True) self.one._update_cache_from_records(datasets=datasets) verifiable = self.one._cache.datasets.loc[datasets.index.values, 'exists'] self.assertTrue(np.all(verifiable == datasets.loc[:, 'exists'])) + self.one._update_cache_from_records(datasets=datasets) + verifiable = self.one._cache.datasets.loc[datasets.index.values, 'exists_aws'] + self.assertTrue(np.all(verifiable == datasets.loc[:, 'exists'])) + # If the extra column does not start with 'exists' it should be set to NaN + verifiable = self.one._cache.datasets.loc[datasets.index.values, 'foo_bar'] + self.assertTrue(np.isnan(verifiable).all()) # Check fringe cases with self.assertRaises(KeyError): @@ -877,14 +981,14 @@ def test_save_loaded_ids(self): old_cache = self.one._cache['datasets'] try: datasets = [self.one._cache.datasets, dset.to_frame().T] - datasets = pd.concat(datasets).astype(old_cache.dtypes) + datasets = pd.concat(datasets).astype(old_cache.dtypes).sort_index() + datasets.index.set_names(('eid', 'id'), inplace=True) self.one._cache['datasets'] = datasets dsets = [dset['rel_path'], '_ibl_trials.feedback_times.npy'] new_files, rec = self.one.load_datasets(eid, dsets, assert_present=False) loaded = self.one._cache['_loaded_datasets'] # One dataset is already in the list (test for duplicates) and other doesn't exist self.assertEqual(len(files), len(loaded), 'No new UUIDs should have been added') - self.assertIn(rec[1]['id'], loaded) # Already in list self.assertEqual(len(loaded), len(np.unique(loaded))) self.assertNotIn(dset.name[1], loaded) # Wasn't loaded as doesn't exist on disk finally: @@ -1250,9 +1354,6 @@ def test_list_aggregates(self): # Test listing by relation datasets = self.one.list_aggregates('subjects') self.assertTrue(all(datasets['rel_path'].str.startswith('aggregates/Subjects'))) - self.assertIn('exists_aws', datasets.columns) - self.assertIn('session_path', datasets.columns) - self.assertTrue(all(datasets['session_path'] == '')) self.assertTrue(self.one.list_aggregates('foobar').empty) # Test filtering with an identifier datasets = self.one.list_aggregates('subjects', 'ZM_1085') @@ -1463,7 +1564,7 @@ def test_search_insertions(self): dataset=['wheel.times'], query_type='remote') def test_search_terms(self): - """Test OneAlyx.search_terms""" + """Test OneAlyx.search_terms.""" assert self.one.mode != 'remote' search1 = self.one.search_terms() self.assertIn('dataset', search1) @@ -1481,7 +1582,7 @@ def test_search_terms(self): self.assertIn('model', search5) def test_load_dataset(self): - """Test OneAlyx.load_dataset""" + """Test OneAlyx.load_dataset.""" file = self.one.load_dataset(self.eid, '_spikeglx_sync.channels.npy', collection='raw_ephys_data', query_type='remote', download_only=True) @@ -1497,7 +1598,7 @@ def test_load_dataset(self): collection='alf', query_type='remote') def test_load_object(self): - """Test OneAlyx.load_object""" + """Test OneAlyx.load_object.""" files = self.one.load_object(self.eid, 'wheel', collection='alf', query_type='remote', download_only=True) @@ -1507,7 +1608,7 @@ def test_load_object(self): ) def test_get_details(self): - """Test OneAlyx.get_details""" + """Test OneAlyx.get_details.""" det = self.one.get_details(self.eid, query_type='remote') self.assertIsInstance(det, dict) self.assertEqual('SWC_043', det['subject']) @@ -1524,7 +1625,7 @@ def test_get_details(self): @unittest.skipIf(OFFLINE_ONLY, 'online only test') class 
TestOneDownload(unittest.TestCase): - """Test downloading datasets using OpenAlyx""" + """Test downloading datasets using OpenAlyx.""" tempdir = None one = None @@ -1538,7 +1639,7 @@ def setUp(self) -> None: self.eid = 'aad23144-0e52-4eac-80c5-c4ee2decb198' def test_download_datasets(self): - """Test OneAlyx._download_dataset, _download_file and _dset2url""" + """Test OneAlyx._download_dataset, _download_file and _dset2url.""" det = self.one.get_details(self.eid, True) rec = next(x for x in det['data_dataset_session_related'] if 'channels.brainLocation' in x['dataset_type']) @@ -1620,6 +1721,8 @@ def test_download_datasets(self): rec = self.one.list_datasets(self.eid, details=True) rec = rec[rec.rel_path.str.contains('00/pykilosort/channels.brainLocation')] rec['exists_aws'] = False # Ensure we use FlatIron for this + rec = pd.concat({str(self.eid): rec}, names=['eid']) + files = self.one._download_datasets(rec) self.assertFalse(None in files) @@ -1630,8 +1733,11 @@ def test_download_datasets(self): def test_download_aws(self): """Test for OneAlyx._download_aws method.""" # Test download datasets via AWS - dsets = self.one.list_datasets(self.eid, details=True) - file = self.one.cache_dir / dsets['rel_path'].values[0] + dsets = self.one.list_datasets( + self.eid, filename='*wiring.json', collection='raw_ephys_data/probe??', details=True) + dsets = pd.concat({str(self.eid): dsets}, names=['eid']) + + file = self.one.eid2path(self.eid) / dsets['rel_path'].values[0] with mock.patch('one.remote.aws.get_s3_from_alyx', return_value=(None, None)), \ mock.patch('one.remote.aws.s3_download_file', return_value=file) as method: self.one._download_datasets(dsets) @@ -1643,23 +1749,25 @@ def test_download_aws(self): # Check keep_uuid = True self.one._download_datasets(dsets, keep_uuid=True) _, local = method.call_args.args - self.assertIn(dsets.iloc[-1].name, local.name) + self.assertIn(dsets.iloc[-1].name[1], local.name) # Test behaviour when dataset not remotely accessible dsets = dsets[:1].copy() - rec = self.one.alyx.rest('datasets', 'read', id=dsets.index[0]) + rec = self.one.alyx.rest('datasets', 'read', id=dsets.index[0][1]) # need to find the index of matching aws repo, this is not constant across releases iaws = list(map(lambda x: x['data_repository'].startswith('aws'), rec['file_records'])).index(True) rec['file_records'][iaws]['exists'] = False # Set AWS file record to non-existent + self.one._cache.datasets['exists_aws'] = True # Only changes column if exists with mock.patch('one.remote.aws.get_s3_from_alyx', return_value=(None, None)), \ mock.patch.object(self.one.alyx, 'rest', return_value=[rec]), \ self.assertLogs('one.api', logging.DEBUG) as log: - self.assertEqual([None], self.one._download_datasets(dsets)) - self.assertRegex(log.output[-1], 'Updating exists field') + # should still download file via fallback method + self.assertEqual([file], self.one._download_datasets(dsets)) + self.assertRegex(log.output[1], 'Updating exists field') datasets = self.one._cache['datasets'] self.assertFalse( - datasets.loc[pd.IndexSlice[:, dsets.index[0]], 'exists_aws'].any() + datasets.loc[pd.IndexSlice[:, dsets.index[0][1]], 'exists_aws'].any() ) # Check falls back to HTTP when error raised @@ -1674,6 +1782,7 @@ def test_download_aws(self): def test_tag_mismatched_file_record(self): """Test for OneAlyx._tag_mismatched_file_record. + This method is also tested in test_download_datasets. 
""" did = '4a1500c2-60f3-418f-afa2-c752bb1890f0' @@ -1696,7 +1805,7 @@ def tearDown(self) -> None: class TestOneSetup(unittest.TestCase): - """Test parameter setup upon ONE instantiation and calling setup methods""" + """Test parameter setup upon ONE instantiation and calling setup methods.""" def setUp(self) -> None: self.tempdir = tempfile.TemporaryDirectory() self.addCleanup(self.tempdir.cleanup) @@ -1707,7 +1816,7 @@ def setUp(self) -> None: self.addCleanup(patch.stop) def test_local_cache_setup_prompt(self): - """Test One.setup""" + """Test One.setup.""" path = Path(self.tempdir.name).joinpath('subject', '2020-01-01', '1', 'spikes.times.npy') path.parent.mkdir(parents=True) path.touch() @@ -1731,6 +1840,7 @@ def test_local_cache_setup_prompt(self): def test_setup_silent(self): """Test setting up parameters with silent flag. + - Mock getfile to return temp dir as param file location - Mock input function as fail safe in case function erroneously prompts user for input """ @@ -1761,6 +1871,7 @@ def test_setup_silent(self): def test_setup_username(self): """Test setting up parameters with a provided username. + - Mock getfile to return temp dir as param file location - Mock input function as fail safe in case function erroneously prompts user for input - Mock requests.post returns a fake user authentication response @@ -1789,14 +1900,14 @@ def test_setup_username(self): @unittest.skipIf(OFFLINE_ONLY, 'online only test') def test_static_setup(self): - """Test OneAlyx.setup""" + """Test OneAlyx.setup.""" with mock.patch('iblutil.io.params.getfile', new=self.get_file), \ mock.patch('one.webclient.getpass', return_value='international'): one_obj = OneAlyx.setup(silent=True) self.assertEqual(one_obj.alyx.base_url, one.params.default().ALYX_URL) def test_setup(self): - """Test one.params.setup""" + """Test one.params.setup.""" url = TEST_DB_1['base_url'] def mock_input(prompt): @@ -1840,7 +1951,7 @@ def mock_input(prompt): self.assertTrue(getattr(mock_input, 'conflict_warn', False)) def test_patch_params(self): - """Test patching legacy params to the new location""" + """Test patching legacy params to the new location.""" # Save some old-style params old_pars = one.params.default().set('HTTP_DATA_SERVER', 'openalyx.org') # Save a REST query in the old location @@ -1857,7 +1968,7 @@ def test_patch_params(self): self.assertTrue(any(one_obj.alyx.cache_dir.joinpath('.rest').glob('*'))) def test_one_factory(self): - """Tests the ONE class factory""" + """Tests the ONE class factory.""" with mock.patch('iblutil.io.params.getfile', new=self.get_file), \ mock.patch('one.params.input', new=self.assertFalse): # Cache dir not in client cache map; use One (light) @@ -1891,9 +2002,9 @@ def test_one_factory(self): class TestOneMisc(unittest.TestCase): - """Test functions in one.util""" + """Test functions in one.util.""" def test_validate_date_range(self): - """Test one.util.validate_date_range""" + """Test one.util.validate_date_range.""" # Single string date actual = validate_date_range('2020-01-01') # On this day expected = (pd.Timestamp('2020-01-01 00:00:00'), @@ -1936,7 +2047,7 @@ def test_validate_date_range(self): validate_date_range(['2020-01-01', '2019-09-06', '2021-10-04']) def test_index_last_before(self): - """Test one.util.index_last_before""" + """Test one.util.index_last_before.""" revisions = ['2021-01-01', '2020-08-01', '', '2020-09-30'] verifiable = index_last_before(revisions, '2021-01-01') self.assertEqual(0, verifiable) @@ -1953,7 +2064,7 @@ def test_index_last_before(self): 
self.assertEqual(0, verifiable, 'should return most recent') def test_collection_spec(self): - """Test one.util._collection_spec""" + """Test one.util._collection_spec.""" # Test every combination of input inputs = [] _collection = {None: '({collection}/)?', '': '', '-': '{collection}/'} @@ -1967,20 +2078,20 @@ def test_collection_spec(self): self.assertEqual(expected, verifiable) def test_revision_last_before(self): - """Test one.util.filter_revision_last_before""" + """Test one.util.filter_revision_last_before.""" datasets = util.revisions_datasets_table() df = datasets[datasets.rel_path.str.startswith('alf/probe00')].copy() - verifiable = filter_revision_last_before(df, - revision='2020-09-01', assert_unique=False) + verifiable = filter_revision_last_before(df, revision='2020-09-01', assert_unique=False) self.assertTrue(len(verifiable) == 2) - # Test assert unique + # Remove one of the datasets' revisions to test assert unique on mixed revisions + df_mixed = df.drop((df['revision'] == '2020-01-08').idxmax()) with self.assertRaises(alferr.ALFMultipleRevisionsFound): - filter_revision_last_before(df, revision='2020-09-01', assert_unique=True) + filter_revision_last_before(df_mixed, revision='2020-09-01', assert_consistent=True) # Test with default revisions df['default_revision'] = False - with self.assertLogs(logging.getLogger('one.util')): + with self.assertWarnsRegex(alferr.ALFWarning, 'No default revision for dataset'): verifiable = filter_revision_last_before(df.copy(), assert_unique=False) self.assertTrue(len(verifiable) == 2) @@ -1991,8 +2102,10 @@ def test_revision_last_before(self): # Add unique default revisions df.iloc[[0, 4], -1] = True - verifiable = filter_revision_last_before(df.copy(), assert_unique=True) - self.assertTrue(len(verifiable) == 2) + # Should log mixed revisions + with self.assertWarnsRegex(alferr.ALFWarning, 'Multiple revisions'): + verifiable = filter_revision_last_before(df.copy(), assert_unique=True) + self.assertEqual(2, len(verifiable)) self.assertCountEqual(verifiable['rel_path'], df['rel_path'].iloc[[0, 4]]) # Add multiple default revisions @@ -2001,7 +2114,7 @@ def test_revision_last_before(self): filter_revision_last_before(df.copy(), assert_unique=True) def test_parse_id(self): - """Test one.util.parse_id""" + """Test one.util.parse_id.""" obj = unittest.mock.MagicMock() # Mock object to decorate obj.to_eid.return_value = 'parsed_id' # Method to be called input = 'subj/date/num' # Input id to pass to `to_eid` @@ -2015,7 +2128,7 @@ def test_parse_id(self): parse_id(obj.method)(obj, input) def test_autocomplete(self): - """Test one.util.autocomplete""" + """Test one.util.autocomplete.""" search_terms = ('subject', 'date_range', 'dataset', 'dataset_type') self.assertEqual('subject', autocomplete('Subj', search_terms)) self.assertEqual('dataset', autocomplete('dataset', search_terms)) @@ -2025,7 +2138,7 @@ def test_autocomplete(self): autocomplete('dat', search_terms) def test_LazyID(self): - """Test one.util.LazyID""" + """Test one.util.LazyID.""" uuids = [ 'c1a2758d-3ce5-4fa7-8d96-6b960f029fa9', '0780da08-a12b-452a-b936-ebc576aa7670', @@ -2039,3 +2152,12 @@ def test_LazyID(self): self.assertEqual(ez[0:2], uuids[0:2]) ez = LazyId([{'id': x} for x in uuids]) self.assertCountEqual(map(str, ez), uuids) + + def test_ensure_list(self): + """Test one.util.ensure_list. + + This function has now moved therefore we simply check for deprecation warning. 
+ """ + with self.assertWarns(DeprecationWarning): + self.assertEqual(['123'], ensure_list('123')) + self.assertIs(x := ['123'], ensure_list(x)) diff --git a/one/tests/util.py b/one/tests/util.py index fee30011..98286c8d 100644 --- a/one/tests/util.py +++ b/one/tests/util.py @@ -11,6 +11,7 @@ import one.params from one.util import QC_TYPE +from one.converters import session_record2path def set_up_env() -> tempfile.TemporaryDirectory: @@ -68,7 +69,8 @@ def create_file_tree(one): """ # Create dset files from cache - for session_path, rel_path in one._cache.datasets[['session_path', 'rel_path']].values: + for (eid, _), rel_path in one._cache.datasets['rel_path'].items(): + session_path = session_record2path(one._cache.sessions.loc[eid]) filepath = Path(one.cache_dir).joinpath(session_path, rel_path) filepath.parent.mkdir(exist_ok=True, parents=True) filepath.touch() @@ -116,7 +118,8 @@ def setup_test_params(token=False, cache_dir=None): def revisions_datasets_table(collections=('', 'alf/probe00', 'alf/probe01'), revisions=('', '2020-01-08', '2021-07-06'), object='spikes', - attributes=('times', 'waveforems')): + attributes=('times', 'waveforems'), + touch_path=None): """Returns a datasets cache DataFrame containing datasets with revision folders. As there are no revised datasets on the test databases, this function acts as a fixture for @@ -132,11 +135,13 @@ def revisions_datasets_table(collections=('', 'alf/probe00', 'alf/probe01'), An ALF object attributes : tuple A list of ALF attributes + touch_path : pathlib.Path, str + If provided, files are created in this directory. Returns ------- pd.DataFrame - A datasets cache table containing datasets made from the input names + A datasets cache table containing datasets made from the input names. """ rel_path = [] @@ -146,8 +151,7 @@ def revisions_datasets_table(collections=('', 'alf/probe00', 'alf/probe01'), rel_path.append('/'.join(x for x in (collec, rev, f'{object}.{attr}.npy') if x)) d = { 'rel_path': rel_path, - 'session_path': 'subject/1900-01-01/001', - 'file_size': None, + 'file_size': 0, 'hash': None, 'exists': True, 'qc': 'NOT_SET', @@ -155,6 +159,12 @@ def revisions_datasets_table(collections=('', 'alf/probe00', 'alf/probe01'), 'id': map(str, (uuid4() for _ in rel_path)) } + if touch_path: + for p in rel_path: + path = Path(touch_path).joinpath(p) + path.parent.mkdir(parents=True, exist_ok=True) + path.touch() + return pd.DataFrame(data=d).astype({'qc': QC_TYPE}).set_index(['eid', 'id']) diff --git a/one/util.py b/one/util.py index e7602def..056da194 100644 --- a/one/util.py +++ b/one/util.py @@ -1,14 +1,17 @@ """Decorators and small standalone functions for api module.""" +import re import logging import urllib.parse -from functools import wraps +import fnmatch +import warnings +from functools import wraps, partial from typing import Sequence, Union, Iterable, Optional, List from collections.abc import Mapping -import fnmatch from datetime import datetime import pandas as pd from iblutil.io import parquet +from iblutil.util import ensure_list as _ensure_list import numpy as np from packaging import version @@ -58,8 +61,8 @@ def _to_record(d): rec['eid'] = session.name file_path = urllib.parse.urlsplit(d['data_url'], allow_fragments=False).path.strip('/') file_path = get_alf_path(remove_uuid_string(file_path)) - rec['session_path'] = get_session_path(file_path).as_posix() - rec['rel_path'] = file_path[len(rec['session_path']):].strip('/') + session_path = get_session_path(file_path).as_posix() + rec['rel_path'] = 
file_path[len(session_path):].strip('/') rec['default_revision'] = d['default_revision'] == 'True' rec['qc'] = d.get('qc', 'NOT_SET') return rec @@ -94,7 +97,7 @@ def datasets2records(datasets, additional=None) -> pd.DataFrame: """ records = [] - for d in ensure_list(datasets): + for d in _ensure_list(datasets): file_record = next((x for x in d['file_records'] if x['data_url'] and x['exists']), None) if not file_record: continue # Ignore files that are not accessible @@ -104,10 +107,10 @@ def datasets2records(datasets, additional=None) -> pd.DataFrame: data_url = urllib.parse.urlsplit(file_record['data_url'], allow_fragments=False) file_path = get_alf_path(data_url.path.strip('/')) file_path = remove_uuid_string(file_path).as_posix() - rec['session_path'] = get_session_path(file_path) or '' - if rec['session_path']: - rec['session_path'] = rec['session_path'].as_posix() - rec['rel_path'] = file_path[len(rec['session_path']):].strip('/') + session_path = get_session_path(file_path) or '' + if session_path: + session_path = session_path.as_posix() + rec['rel_path'] = file_path[len(session_path):].strip('/') rec['default_revision'] = d['default_dataset'] rec['qc'] = d.get('qc') for field in additional or []: @@ -307,7 +310,9 @@ def filter_datasets( revision_last_before : bool When true and no exact match exists, the (lexicographically) previous revision is used instead. When false the revision string is matched like collection and filename, - with regular expressions permitted. + with regular expressions permitted. NB: When true and `revision` is None the default + revision is returned which may not be the last revision. If no default is defined, the + last revision is returned. qc : str, int, one.alf.spec.QC Returns datasets at or below this QC level. Integer values should correspond to the QC enumeration NOT the qc category column codes in the pandas table. @@ -350,6 +355,26 @@ def filter_datasets( >>> datasets filter_datasets(all_datasets, qc='PASS', ignore_qc_not_set=True) + Raises + ------ + one.alf.exceptions.ALFMultipleCollectionsFound + The matching list of datasets have more than one unique collection and `assert_unique` is + True. + one.alf.exceptions.ALFMultipleRevisionsFound + When `revision_last_before` is false, the matching list of datasets have more than one + unique revision. When `revision_last_before` is true, a 'default_revision' column exists, + and no revision is passed, this error means that one or more matching datasets have + multiple revisions specified as the default. This is typically an error in the cache table + itself as all datasets should have one and only one default revision specified. + one.alf.exceptions.ALFMultipleObjectsFound + The matching list of datasets have more than one unique filename and both `assert_unique` + and `revision_last_before` are true. + one.alf.exceptions.ALFError + When both `assert_unique` and `revision_last_before` is true, and a 'default_revision' + column exists but `revision` is None; one or more matching datasets have no default + revision specified. This is typically an error in the cache table itself as all datasets + should have one and only one default revision specified. + Notes ----- - It is not possible to match datasets that are in a given collection OR NOT in ANY collection. 
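
Editor's note: the Raises and Notes added to the `filter_datasets` docstring above describe the revision-selection rules in prose. Below is a minimal sketch of those rules, built on the `revisions_datasets_table` fixture from `one.tests.util` elsewhere in this patch. The positional order `(datasets, filename, collection, revision)`, the `wildcards` keyword, and the fixture's row ordering are assumptions taken from the tests and docstrings in this diff, not a definitive usage guide.

```python
from one.util import filter_datasets
from one.tests.util import revisions_datasets_table

# One unrevisioned and one revised copy of alf/spikes.times.npy; rows are assumed to be
# ordered (unrevisioned, revised) as in the tests earlier in this patch.
datasets = revisions_datasets_table(
    collections=['alf'], revisions=('', '2020-01-08'), attributes=('times',))
datasets['default_revision'] = [True, False]  # flag the unrevisioned copy as default

# revision=None with revision_last_before=True: the default revision is returned,
# which here is not the lexicographically last one.
match = filter_datasets(datasets, 'spikes.times.npy', 'alf', None, wildcards=True)
assert not any('#' in p for p in match.rel_path)

# Passing a revision bypasses the default and returns the last revision at or before
# it, so a newer revision than the default may come back.
match = filter_datasets(datasets, 'spikes.times.npy', 'alf', '2021-01-01', wildcards=True)
assert all('#2020-01-08#' in p for p in match.rel_path)
```
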
@@ -365,9 +390,15 @@ def filter_datasets( spec_str += _file_spec(**filename) regex_args.update(**filename) else: - # Convert to regex is necessary and assert end of string - filename = [fnmatch.translate(x) if wildcards else x + '$' for x in ensure_list(filename)] - spec_str += '|'.join(filename) + # Convert to regex if necessary and assert end of string + flagless_token = re.escape(r'(?s:') # fnmatch.translate may wrap input in flagless group + # If there is a wildcard at the start of the filename we must exclude capture of slashes to + # avoid capture of collection part, e.g. * -> .* -> [^/]* (one or more non-slash chars) + exclude_slash = partial(re.sub, fr'^({flagless_token})?\.\*', r'\g<1>[^/]*') + spec_str += '|'.join( + exclude_slash(fnmatch.translate(x)) if wildcards else x + '$' + for x in _ensure_list(filename) + ) # If matching revision name, add to regex string if not revision_last_before: @@ -378,7 +409,7 @@ def filter_datasets( continue if wildcards: # Convert to regex, remove \\Z which asserts end of string - v = (fnmatch.translate(x).replace('\\Z', '') for x in ensure_list(v)) + v = (fnmatch.translate(x).replace('\\Z', '') for x in _ensure_list(v)) if not isinstance(v, str): regex_args[k] = '|'.join(v) # logical OR @@ -393,33 +424,37 @@ def filter_datasets( qc_match &= all_datasets['qc'].ne('NOT_SET') # Filter datasets on path and QC - match = all_datasets[path_match & qc_match] + match = all_datasets[path_match & qc_match].copy() if len(match) == 0 or not (revision_last_before or assert_unique): return match - revisions = [rel_path_parts(x)[1] or '' for x in match.rel_path.values] + # Extract revision to separate column + if 'revision' not in match.columns: + match['revision'] = match.rel_path.map(lambda x: rel_path_parts(x)[1] or '') if assert_unique: collections = set(rel_path_parts(x)[0] or '' for x in match.rel_path.values) if len(collections) > 1: _list = '"' + '", "'.join(collections) + '"' raise alferr.ALFMultipleCollectionsFound(_list) if not revision_last_before: - if filename and len(match) > 1: + if len(set(match['revision'])) > 1: + _list = '"' + '", "'.join(set(match['revision'])) + '"' + raise alferr.ALFMultipleRevisionsFound(_list) + if len(match) > 1: _list = '"' + '", "'.join(match['rel_path']) + '"' raise alferr.ALFMultipleObjectsFound(_list) - if len(set(revisions)) > 1: - _list = '"' + '", "'.join(set(revisions)) + '"' - raise alferr.ALFMultipleRevisionsFound(_list) else: return match - elif filename and len(set(revisions)) != len(revisions): - _list = '"' + '", "'.join(match['rel_path']) + '"' - raise alferr.ALFMultipleObjectsFound(_list) - return filter_revision_last_before(match, revision, assert_unique=assert_unique) + match = filter_revision_last_before(match, revision, assert_unique=assert_unique) + if assert_unique and len(match) > 1: + _list = '"' + '", "'.join(match['rel_path']) + '"' + raise alferr.ALFMultipleObjectsFound(_list) + return match -def filter_revision_last_before(datasets, revision=None, assert_unique=True): +def filter_revision_last_before( + datasets, revision=None, assert_unique=True, assert_consistent=False): """ Filter datasets by revision, returning previous revision in ordered list if revision doesn't exactly match. @@ -433,43 +468,80 @@ def filter_revision_last_before(datasets, revision=None, assert_unique=True): assert_unique : bool When true an alferr.ALFMultipleRevisionsFound exception is raised when multiple default revisions are found; an alferr.ALFError when no default revision is found. 
+ assert_consistent : bool + Will raise alferr.ALFMultipleRevisionsFound if matching revision is different between + datasets. Returns ------- pd.DataFrame A datasets DataFrame with 0 or 1 row per unique dataset. + + Raises + ------ + one.alf.exceptions.ALFMultipleRevisionsFound + When the 'default_revision' column exists and no revision is passed, this error means that + one or more matching datasets have multiple revisions specified as the default. This is + typically an error in the cache table itself as all datasets should have one and only one + default revision specified. + When `assert_consistent` is True, this error may mean that the matching datasets have + mixed revisions. + one.alf.exceptions.ALFMultipleObjectsFound + The matching list of datasets have more than one unique filename and both `assert_unique` + and `revision_last_before` are true. + one.alf.exceptions.ALFError + When both `assert_unique` and `revision_last_before` is true, and a 'default_revision' + column exists but `revision` is None; one or more matching datasets have no default + revision specified. This is typically an error in the cache table itself as all datasets + should have one and only one default revision specified. + + Notes + ----- + - When `revision` is not None, the default revision value is not used. If an older revision is + the default one (uncommon), passing in a revision may lead to a newer revision being returned + than if revision is None. + - A view is returned if a revision column is present, otherwise a copy is returned. """ def _last_before(df): """Takes a DataFrame with only one dataset and multiple revisions, returns matching row""" - if revision is None and 'default_revision' in df.columns: - if assert_unique and sum(df.default_revision) > 1: - revisions = df['revision'][df.default_revision.values] - rev_list = '"' + '", "'.join(revisions) + '"' - raise alferr.ALFMultipleRevisionsFound(rev_list) - if sum(df.default_revision) == 1: - return df[df.default_revision] - if len(df) == 1: # This may be the case when called from load_datasets - return df # It's not the default be there's only one available revision - # default_revision column all False; default isn't copied to remote repository + if revision is None: dset_name = df['rel_path'].iloc[0] - if assert_unique: - raise alferr.ALFError(f'No default revision for dataset {dset_name}') - else: - logger.warning(f'No default revision for dataset {dset_name}; using most recent') + if 'default_revision' in df.columns: + if assert_unique and sum(df.default_revision) > 1: + revisions = df['revision'][df.default_revision.values] + rev_list = '"' + '", "'.join(revisions) + '"' + raise alferr.ALFMultipleRevisionsFound(rev_list) + if sum(df.default_revision) == 1: + return df[df.default_revision] + if len(df) == 1: # This may be the case when called from load_datasets + return df # It's not the default but there's only one available revision + # default_revision column all False; default isn't copied to remote repository + if assert_unique: + raise alferr.ALFError(f'No default revision for dataset {dset_name}') + warnings.warn( + f'No default revision for dataset {dset_name}; using most recent', + alferr.ALFWarning) # Compare revisions lexicographically - if assert_unique and len(df['revision'].unique()) > 1: - rev_list = '"' + '", "'.join(df['revision'].unique()) + '"' - raise alferr.ALFMultipleRevisionsFound(rev_list) - # Square brackets forces 1 row DataFrame returned instead of Series idx = index_last_before(df['revision'].tolist(), revision) 
- # return df.iloc[slice(0, 0) if idx is None else [idx], :] + # Square brackets forces 1 row DataFrame returned instead of Series return df.iloc[slice(0, 0) if idx is None else [idx], :] - with pd.option_context('mode.chained_assignment', None): # FIXME Explicitly copy? - datasets['revision'] = [rel_path_parts(x)[1] or '' for x in datasets.rel_path] + # Extract revision to separate column + if 'revision' not in datasets.columns: + with pd.option_context('mode.chained_assignment', None): # FIXME Explicitly copy? + datasets['revision'] = datasets.rel_path.map(lambda x: rel_path_parts(x)[1] or '') + # Group by relative path (sans revision) groups = datasets.rel_path.str.replace('#.*#/', '', regex=True).values grouped = datasets.groupby(groups, group_keys=False) - return grouped.apply(_last_before) + filtered = grouped.apply(_last_before) + # Raise if matching revision is different between datasets + if len(filtered['revision'].unique()) > 1: + rev_list = '"' + '", "'.join(filtered['revision'].unique()) + '"' + if assert_consistent: + raise alferr.ALFMultipleRevisionsFound(rev_list) + else: + warnings.warn(f'Multiple revisions: {rev_list}', alferr.ALFWarning) + return filtered def index_last_before(revisions: List[str], revision: Optional[str]) -> Optional[int]: @@ -521,7 +593,10 @@ def autocomplete(term, search_terms) -> str: def ensure_list(value): """Ensure input is a list.""" - return [value] if isinstance(value, (str, dict)) or not isinstance(value, Iterable) else value + warnings.warn( + 'one.util.ensure_list is deprecated, use iblutil.util.ensure_list instead', + DeprecationWarning) + return _ensure_list(value) class LazyId(Mapping): @@ -606,4 +681,6 @@ def patch_cache(table: pd.DataFrame, min_api_version=None, name=None) -> pd.Data if name == 'datasets' and min_version < version.Version('2.7.0') and 'qc' not in table.columns: qc = pd.Categorical.from_codes(np.zeros(len(table.index), dtype=int), dtype=QC_TYPE) table = table.assign(qc=qc) + if name == 'datasets' and 'session_path' in table.columns: + table = table.drop('session_path', axis=1) return table diff --git a/one/webclient.py b/one/webclient.py index 0782122c..d6b749d7 100644 --- a/one/webclient.py +++ b/one/webclient.py @@ -54,7 +54,7 @@ import one.params from iblutil.io import hashfile from iblutil.io.params import set_hidden -from one.util import ensure_list +from iblutil.util import ensure_list import concurrent.futures _logger = logging.getLogger(__name__) @@ -213,9 +213,12 @@ def __len__(self): def __getitem__(self, item): if isinstance(item, slice): while None in self._cache[item]: - self.populate(item.start + self._cache[item].index(None)) + # If slice start index is -ve, convert to +ve index + i = self.count + item.start if item.start < 0 else item.start + self.populate(i + self._cache[item].index(None)) elif self._cache[item] is None: - self.populate(item) + # If index is -ve, convert to +ve + self.populate(self.count + item if item < 0 else item) return self._cache[item] def populate(self, idx): @@ -642,6 +645,11 @@ def authenticate(self, username=None, password=None, cache_token=True, force=Fal if username is None and not self.silent: username = input('Enter Alyx username:') + # If user passes in a password, force re-authentication even if token cached + if password is not None: + if not force: + _logger.debug('Forcing token request with provided password') + force = True # Check if token cached if not force and getattr(self._par, 'TOKEN', False) and username in self._par.TOKEN: self._token = 
self._par.TOKEN[username] @@ -654,8 +662,17 @@ def authenticate(self, username=None, password=None, cache_token=True, force=Fal # Get password if password is None: password = getattr(self._par, 'ALYX_PWD', None) - if password is None and not self.silent: - password = getpass(f'Enter Alyx password for "{username}":') + if password is None: + if self.silent: + warnings.warn( + 'No password or cached token in silent mode. ' + 'Please run the following to re-authenticate:\n\t' + 'AlyxClient(silent=False).authenticate' + '(username=, force=True)', UserWarning) + else: + password = getpass(f'Enter Alyx password for "{username}":') + # Remove previous token + self._clear_token(username) try: credentials = {'username': username, 'password': password} rep = requests.post(self.base_url + '/auth-token', data=credentials) @@ -692,14 +709,12 @@ def authenticate(self, username=None, password=None, cache_token=True, force=Fal if not self.silent: print(f'Connected to {self.base_url} as user "{self.user}"') - def logout(self): - """Log out from Alyx. - Deletes the cached authentication token for the currently logged-in user. + def _clear_token(self, username): + """Remove auth token from client params. + + Deletes the cached authentication token for a given user. """ - if not self.is_logged_in: - return par = one.params.get(client=self.base_url, silent=True) - username = self.user # Remove token from cache if getattr(par, 'TOKEN', False) and username in par.TOKEN: del par.TOKEN[username] @@ -708,10 +723,20 @@ def logout(self): if getattr(self._par, 'TOKEN', False) and username in self._par.TOKEN: del self._par.TOKEN[username] # Remove token from object - self.user = None self._token = None if self._headers and 'Authorization' in self._headers: del self._headers['Authorization'] + + def logout(self): + """Log out from Alyx. + + Deletes the cached authentication token for the currently logged-in user + and clears the REST cache. + """ + if not self.is_logged_in: + return + self._clear_token(username := self.user) + self.user = None self.clear_rest_cache() if not self.silent: print(f'{username} logged out from {self.base_url}') diff --git a/requirements.txt b/requirements.txt index 6ea03889..1c3f9b7f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ numpy>=1.18 pandas>=1.5.0 tqdm>=4.32.1 requests>=2.22.0 -iblutil>=1.1.0 +iblutil>=1.13.0 packaging boto3 pyyaml
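
Editor's note: the paginated-response fix in one/webclient.py above normalises negative indices against the total record count before deciding which page to fetch. The toy class below is a self-contained stand-in, not the real webclient class (page fetching is simulated with integers); it reproduces just that normalisation so the bugfix can be seen in isolation.

```python
class FakePaginatedResponse:
    """Toy paginated sequence: records are fetched lazily one page at a time."""

    def __init__(self, count, page_size=5):
        self.count = count
        self._page_size = page_size
        self._cache = [None] * count  # None means 'not yet fetched'

    def populate(self, idx):
        # Simulate fetching the page containing `idx`; the real client hits the REST API
        first = (idx // self._page_size) * self._page_size
        for i in range(first, min(first + self._page_size, self.count)):
            self._cache[i] = i

    def __getitem__(self, item):
        if isinstance(item, slice):
            start = item.start or 0
            # Convert a negative slice start to a positive index before paging
            start = self.count + start if start < 0 else start
            while None in self._cache[item]:
                self.populate(start + self._cache[item].index(None))
        elif self._cache[item] is None:
            # Convert a negative scalar index to a positive one
            self.populate(self.count + item if item < 0 else item)
        return self._cache[item]


resp = FakePaginatedResponse(12)
assert resp[-1] == 11             # previously this attempted to fetch a page at index -1
assert resp[-4:-1] == [8, 9, 10]  # negative slice bounds are resolved before paging
```
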