diff --git a/CHANGELOG.md b/CHANGELOG.md index 7b4c4a3d..587853d6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ This version improves behaviour of loading revisions and loading datasets from l - warn user to reauthenticate when password is None in silent mode - always force authentication when password passed, even when token cached - bugfix: negative indexing of paginated response objects now functions correctly +- deprecate one.util.ensure_list; moved to iblutil.util.ensure_list ### Added diff --git a/one/api.py b/one/api.py index 1156e965..c375f735 100644 --- a/one/api.py +++ b/one/api.py @@ -20,7 +20,7 @@ import packaging.version from iblutil.io import parquet, hashfile -from iblutil.util import Bunch, flatten +from iblutil.util import Bunch, flatten, ensure_list import one.params import one.webclient as wc @@ -503,7 +503,7 @@ def sort_fcn(itm): return ([], None) if details else [] # String fields elif key in ('subject', 'task_protocol', 'laboratory', 'projects'): - query = '|'.join(util.ensure_list(value)) + query = '|'.join(ensure_list(value)) key = 'lab' if key == 'laboratory' else key mask = sessions[key].str.contains(query, regex=self.wildcards) sessions = sessions[mask.astype(bool, copy=False)] @@ -512,7 +512,7 @@ def sort_fcn(itm): session_date = pd.to_datetime(sessions['date']) sessions = sessions[(session_date >= start) & (session_date <= end)] elif key == 'number': - query = util.ensure_list(value) + query = ensure_list(value) sessions = sessions[sessions[key].isin(map(int, query))] # Dataset/QC check is biggest so this should be done last elif key == 'dataset' or (key == 'dataset_qc_lte' and 'dataset' not in queries): @@ -520,7 +520,7 @@ def sort_fcn(itm): qc = QC.validate(queries.get('dataset_qc_lte', 'FAIL')).name # validate value has_dset = sessions.index.isin(datasets.index.get_level_values('eid')) datasets = datasets.loc[(sessions.index.values[has_dset], ), :] - query = util.ensure_list(value if key == 'dataset' else '') + query = ensure_list(value if key == 'dataset' else '') # For each session check any dataset both contains query and exists mask = ( (datasets @@ -2271,7 +2271,7 @@ def search(self, details=False, query_type=None, **kwargs): def _add_date(records): """Add date field for compatibility with One.search output.""" - for s in util.ensure_list(records): + for s in ensure_list(records): s['date'] = datetime.fromisoformat(s['start_time']).date() return records @@ -2343,7 +2343,7 @@ def _download_aws(self, dsets, update_exists=True, keep_uuid=None, **_) -> List[ assert self.mode != 'local' # Get all dataset URLs dsets = list(dsets) # Ensure not generator - uuids = [util.ensure_list(x.name)[-1] for x in dsets] + uuids = [ensure_list(x.name)[-1] for x in dsets] # If number of UUIDs is too high, fetch in loop to avoid 414 HTTP status code remote_records = [] N = 100 # Number of UUIDs per query @@ -2417,7 +2417,7 @@ def _dset2url(self, dset, update_cache=True): did = dset['id'] elif 'file_records' not in dset: # Convert dataset Series to alyx dataset dict url = self.record2url(dset) # NB: URL will always be returned but may not exist - did = util.ensure_list(dset.name)[-1] + did = ensure_list(dset.name)[-1] else: # from datasets endpoint repo = getattr(getattr(self._web_client, '_par', None), 'HTTP_DATA_SERVER', None) url = next( @@ -2431,7 +2431,7 @@ def _dset2url(self, dset, update_cache=True): _logger.debug('Updating cache') # NB: This will be considerably easier when IndexSlice supports Ellipsis idx = [slice(None)] * int(self._cache['datasets'].index.nlevels / 2) - self._cache['datasets'].loc[(*idx, *util.ensure_list(did)), 'exists'] = False + self._cache['datasets'].loc[(*idx, *ensure_list(did)), 'exists'] = False self._cache['_meta']['modified_time'] = datetime.now() return url @@ -2525,7 +2525,7 @@ def _download_file(self, url, target_dir, keep_uuid=None, file_size=None, hash=N """ assert not self.offline # Ensure all target directories exist - [Path(x).mkdir(parents=True, exist_ok=True) for x in set(util.ensure_list(target_dir))] + [Path(x).mkdir(parents=True, exist_ok=True) for x in set(ensure_list(target_dir))] # download file(s) from url(s), returns file path(s) with UUID local_path, md5 = self.alyx.download_file(url, target_dir=target_dir, return_md5=True) @@ -2534,7 +2534,7 @@ def _download_file(self, url, target_dir, keep_uuid=None, file_size=None, hash=N if isinstance(url, (tuple, list)): assert (file_size is None) or len(file_size) == len(url) assert (hash is None) or len(hash) == len(url) - for args in zip(*map(util.ensure_list, (file_size, md5, hash, local_path, url))): + for args in zip(*map(ensure_list, (file_size, md5, hash, local_path, url))): self._check_hash_and_file_size_mismatch(*args) # check if we are keeping the uuid on the list of file names diff --git a/one/registration.py b/one/registration.py index ae85eb31..900a813e 100644 --- a/one/registration.py +++ b/one/registration.py @@ -23,14 +23,13 @@ import requests.exceptions from iblutil.io import hashfile -from iblutil.util import Bunch +from iblutil.util import Bunch, ensure_list import one.alf.io as alfio from one.alf.files import session_path_parts, get_session_path, folder_parts, filename_parts from one.alf.spec import is_valid import one.alf.exceptions as alferr from one.api import ONE -from one.util import ensure_list from one.webclient import no_cache _logger = logging.getLogger(__name__) diff --git a/one/remote/globus.py b/one/remote/globus.py index 161ef310..f9626eb5 100644 --- a/one/remote/globus.py +++ b/one/remote/globus.py @@ -97,12 +97,12 @@ from globus_sdk import TransferAPIError, GlobusAPIError, NetworkError, GlobusTimeoutError, \ GlobusConnectionError, GlobusConnectionTimeoutError, GlobusSDKUsageError, NullAuthorizer from iblutil.io import params as iopar +from iblutil.util import ensure_list from one.alf.spec import is_uuid from one.alf.files import remove_uuid_string import one.params from one.webclient import AlyxClient -from one.util import ensure_list from .base import DownloadClient, load_client_params, save_client_params __all__ = ['Globus', 'get_lab_from_endpoint_id', 'as_globus_path'] diff --git a/one/tests/test_one.py b/one/tests/test_one.py index b4040d51..c88351be 100644 --- a/one/tests/test_one.py +++ b/one/tests/test_one.py @@ -48,7 +48,7 @@ from one.api import ONE, One, OneAlyx from one.util import ( ses2records, validate_date_range, index_last_before, filter_datasets, _collection_spec, - filter_revision_last_before, parse_id, autocomplete, LazyId, datasets2records + filter_revision_last_before, parse_id, autocomplete, LazyId, datasets2records, ensure_list ) import one.params import one.alf.exceptions as alferr @@ -1564,7 +1564,7 @@ def test_search_insertions(self): dataset=['wheel.times'], query_type='remote') def test_search_terms(self): - """Test OneAlyx.search_terms""" + """Test OneAlyx.search_terms.""" assert self.one.mode != 'remote' search1 = self.one.search_terms() self.assertIn('dataset', search1) @@ -1582,7 +1582,7 @@ def test_search_terms(self): self.assertIn('model', search5) def test_load_dataset(self): - """Test OneAlyx.load_dataset""" + """Test OneAlyx.load_dataset.""" file = self.one.load_dataset(self.eid, '_spikeglx_sync.channels.npy', collection='raw_ephys_data', query_type='remote', download_only=True) @@ -1598,7 +1598,7 @@ def test_load_dataset(self): collection='alf', query_type='remote') def test_load_object(self): - """Test OneAlyx.load_object""" + """Test OneAlyx.load_object.""" files = self.one.load_object(self.eid, 'wheel', collection='alf', query_type='remote', download_only=True) @@ -1608,7 +1608,7 @@ def test_load_object(self): ) def test_get_details(self): - """Test OneAlyx.get_details""" + """Test OneAlyx.get_details.""" det = self.one.get_details(self.eid, query_type='remote') self.assertIsInstance(det, dict) self.assertEqual('SWC_043', det['subject']) @@ -1625,7 +1625,7 @@ def test_get_details(self): @unittest.skipIf(OFFLINE_ONLY, 'online only test') class TestOneDownload(unittest.TestCase): - """Test downloading datasets using OpenAlyx""" + """Test downloading datasets using OpenAlyx.""" tempdir = None one = None @@ -1639,7 +1639,7 @@ def setUp(self) -> None: self.eid = 'aad23144-0e52-4eac-80c5-c4ee2decb198' def test_download_datasets(self): - """Test OneAlyx._download_dataset, _download_file and _dset2url""" + """Test OneAlyx._download_dataset, _download_file and _dset2url.""" det = self.one.get_details(self.eid, True) rec = next(x for x in det['data_dataset_session_related'] if 'channels.brainLocation' in x['dataset_type']) @@ -1782,6 +1782,7 @@ def test_download_aws(self): def test_tag_mismatched_file_record(self): """Test for OneAlyx._tag_mismatched_file_record. + This method is also tested in test_download_datasets. """ did = '4a1500c2-60f3-418f-afa2-c752bb1890f0' @@ -1804,7 +1805,7 @@ def tearDown(self) -> None: class TestOneSetup(unittest.TestCase): - """Test parameter setup upon ONE instantiation and calling setup methods""" + """Test parameter setup upon ONE instantiation and calling setup methods.""" def setUp(self) -> None: self.tempdir = tempfile.TemporaryDirectory() self.addCleanup(self.tempdir.cleanup) @@ -1815,7 +1816,7 @@ def setUp(self) -> None: self.addCleanup(patch.stop) def test_local_cache_setup_prompt(self): - """Test One.setup""" + """Test One.setup.""" path = Path(self.tempdir.name).joinpath('subject', '2020-01-01', '1', 'spikes.times.npy') path.parent.mkdir(parents=True) path.touch() @@ -1839,6 +1840,7 @@ def test_local_cache_setup_prompt(self): def test_setup_silent(self): """Test setting up parameters with silent flag. + - Mock getfile to return temp dir as param file location - Mock input function as fail safe in case function erroneously prompts user for input """ @@ -1869,6 +1871,7 @@ def test_setup_silent(self): def test_setup_username(self): """Test setting up parameters with a provided username. + - Mock getfile to return temp dir as param file location - Mock input function as fail safe in case function erroneously prompts user for input - Mock requests.post returns a fake user authentication response @@ -1897,14 +1900,14 @@ def test_setup_username(self): @unittest.skipIf(OFFLINE_ONLY, 'online only test') def test_static_setup(self): - """Test OneAlyx.setup""" + """Test OneAlyx.setup.""" with mock.patch('iblutil.io.params.getfile', new=self.get_file), \ mock.patch('one.webclient.getpass', return_value='international'): one_obj = OneAlyx.setup(silent=True) self.assertEqual(one_obj.alyx.base_url, one.params.default().ALYX_URL) def test_setup(self): - """Test one.params.setup""" + """Test one.params.setup.""" url = TEST_DB_1['base_url'] def mock_input(prompt): @@ -1948,7 +1951,7 @@ def mock_input(prompt): self.assertTrue(getattr(mock_input, 'conflict_warn', False)) def test_patch_params(self): - """Test patching legacy params to the new location""" + """Test patching legacy params to the new location.""" # Save some old-style params old_pars = one.params.default().set('HTTP_DATA_SERVER', 'openalyx.org') # Save a REST query in the old location @@ -1965,7 +1968,7 @@ def test_patch_params(self): self.assertTrue(any(one_obj.alyx.cache_dir.joinpath('.rest').glob('*'))) def test_one_factory(self): - """Tests the ONE class factory""" + """Tests the ONE class factory.""" with mock.patch('iblutil.io.params.getfile', new=self.get_file), \ mock.patch('one.params.input', new=self.assertFalse): # Cache dir not in client cache map; use One (light) @@ -1999,9 +2002,9 @@ def test_one_factory(self): class TestOneMisc(unittest.TestCase): - """Test functions in one.util""" + """Test functions in one.util.""" def test_validate_date_range(self): - """Test one.util.validate_date_range""" + """Test one.util.validate_date_range.""" # Single string date actual = validate_date_range('2020-01-01') # On this day expected = (pd.Timestamp('2020-01-01 00:00:00'), @@ -2044,7 +2047,7 @@ def test_validate_date_range(self): validate_date_range(['2020-01-01', '2019-09-06', '2021-10-04']) def test_index_last_before(self): - """Test one.util.index_last_before""" + """Test one.util.index_last_before.""" revisions = ['2021-01-01', '2020-08-01', '', '2020-09-30'] verifiable = index_last_before(revisions, '2021-01-01') self.assertEqual(0, verifiable) @@ -2061,7 +2064,7 @@ def test_index_last_before(self): self.assertEqual(0, verifiable, 'should return most recent') def test_collection_spec(self): - """Test one.util._collection_spec""" + """Test one.util._collection_spec.""" # Test every combination of input inputs = [] _collection = {None: '({collection}/)?', '': '', '-': '{collection}/'} @@ -2075,7 +2078,7 @@ def test_collection_spec(self): self.assertEqual(expected, verifiable) def test_revision_last_before(self): - """Test one.util.filter_revision_last_before""" + """Test one.util.filter_revision_last_before.""" datasets = util.revisions_datasets_table() df = datasets[datasets.rel_path.str.startswith('alf/probe00')].copy() verifiable = filter_revision_last_before(df, revision='2020-09-01', assert_unique=False) @@ -2111,7 +2114,7 @@ def test_revision_last_before(self): filter_revision_last_before(df.copy(), assert_unique=True) def test_parse_id(self): - """Test one.util.parse_id""" + """Test one.util.parse_id.""" obj = unittest.mock.MagicMock() # Mock object to decorate obj.to_eid.return_value = 'parsed_id' # Method to be called input = 'subj/date/num' # Input id to pass to `to_eid` @@ -2125,7 +2128,7 @@ def test_parse_id(self): parse_id(obj.method)(obj, input) def test_autocomplete(self): - """Test one.util.autocomplete""" + """Test one.util.autocomplete.""" search_terms = ('subject', 'date_range', 'dataset', 'dataset_type') self.assertEqual('subject', autocomplete('Subj', search_terms)) self.assertEqual('dataset', autocomplete('dataset', search_terms)) @@ -2135,7 +2138,7 @@ def test_autocomplete(self): autocomplete('dat', search_terms) def test_LazyID(self): - """Test one.util.LazyID""" + """Test one.util.LazyID.""" uuids = [ 'c1a2758d-3ce5-4fa7-8d96-6b960f029fa9', '0780da08-a12b-452a-b936-ebc576aa7670', @@ -2149,3 +2152,12 @@ def test_LazyID(self): self.assertEqual(ez[0:2], uuids[0:2]) ez = LazyId([{'id': x} for x in uuids]) self.assertCountEqual(map(str, ez), uuids) + + def test_ensure_list(self): + """Test one.util.ensure_list. + + This function has now moved therefore we simply check for deprecation warning. + """ + with self.assertWarns(DeprecationWarning): + self.assertEqual(['123'], ensure_list('123')) + self.assertIs(x := ['123'], ensure_list(x)) diff --git a/one/util.py b/one/util.py index 8c5a3719..056da194 100644 --- a/one/util.py +++ b/one/util.py @@ -11,6 +11,7 @@ import pandas as pd from iblutil.io import parquet +from iblutil.util import ensure_list as _ensure_list import numpy as np from packaging import version @@ -96,7 +97,7 @@ def datasets2records(datasets, additional=None) -> pd.DataFrame: """ records = [] - for d in ensure_list(datasets): + for d in _ensure_list(datasets): file_record = next((x for x in d['file_records'] if x['data_url'] and x['exists']), None) if not file_record: continue # Ignore files that are not accessible @@ -396,7 +397,7 @@ def filter_datasets( exclude_slash = partial(re.sub, fr'^({flagless_token})?\.\*', r'\g<1>[^/]*') spec_str += '|'.join( exclude_slash(fnmatch.translate(x)) if wildcards else x + '$' - for x in ensure_list(filename) + for x in _ensure_list(filename) ) # If matching revision name, add to regex string @@ -408,7 +409,7 @@ def filter_datasets( continue if wildcards: # Convert to regex, remove \\Z which asserts end of string - v = (fnmatch.translate(x).replace('\\Z', '') for x in ensure_list(v)) + v = (fnmatch.translate(x).replace('\\Z', '') for x in _ensure_list(v)) if not isinstance(v, str): regex_args[k] = '|'.join(v) # logical OR @@ -592,7 +593,10 @@ def autocomplete(term, search_terms) -> str: def ensure_list(value): """Ensure input is a list.""" - return [value] if isinstance(value, (str, dict)) or not isinstance(value, Iterable) else value + warnings.warn( + 'one.util.ensure_list is deprecated, use iblutil.util.ensure_list instead', + DeprecationWarning) + return _ensure_list(value) class LazyId(Mapping): diff --git a/one/webclient.py b/one/webclient.py index 155beeca..d6b749d7 100644 --- a/one/webclient.py +++ b/one/webclient.py @@ -54,7 +54,7 @@ import one.params from iblutil.io import hashfile from iblutil.io.params import set_hidden -from one.util import ensure_list +from iblutil.util import ensure_list import concurrent.futures _logger = logging.getLogger(__name__) diff --git a/requirements.txt b/requirements.txt index 6ea03889..1c3f9b7f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ numpy>=1.18 pandas>=1.5.0 tqdm>=4.32.1 requests>=2.22.0 -iblutil>=1.1.0 +iblutil>=1.13.0 packaging boto3 pyyaml