Deprecate one.util.ensure_list; moved to iblutil.util.ensure_list
k1o0 committed Oct 14, 2024
1 parent 20aaa20 commit 1efa119
Showing 8 changed files with 56 additions and 40 deletions.
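For downstream code the practical change is an import-path update: ensure_list now lives in iblutil, and the old one.util path keeps working for the moment but emits a DeprecationWarning when called. A minimal migration sketch (it assumes an iblutil release that exports ensure_list, which the new imports below rely on; the behaviour shown matches the removed one.util implementation and the new unit test):

    # Before (deprecated; warns at call time):
    from one.util import ensure_list
    # After:
    from iblutil.util import ensure_list

    ensure_list('spikes.times')   # -> ['spikes.times']  (a string is wrapped in a list)
    ensure_list(['a', 'b'])       # -> ['a', 'b']         (an existing list is returned unchanged)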
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -14,6 +14,7 @@ This version improves behaviour of loading revisions and loading datasets from l
- warn user to reauthenticate when password is None in silent mode
- always force authentication when password passed, even when token cached
- bugfix: negative indexing of paginated response objects now functions correctly
+ - deprecate one.util.ensure_list; moved to iblutil.util.ensure_list

### Added

20 changes: 10 additions & 10 deletions one/api.py
@@ -20,7 +20,7 @@
import packaging.version

from iblutil.io import parquet, hashfile
- from iblutil.util import Bunch, flatten
+ from iblutil.util import Bunch, flatten, ensure_list

import one.params
import one.webclient as wc
@@ -503,7 +503,7 @@ def sort_fcn(itm):
return ([], None) if details else []
# String fields
elif key in ('subject', 'task_protocol', 'laboratory', 'projects'):
- query = '|'.join(util.ensure_list(value))
+ query = '|'.join(ensure_list(value))
key = 'lab' if key == 'laboratory' else key
mask = sessions[key].str.contains(query, regex=self.wildcards)
sessions = sessions[mask.astype(bool, copy=False)]
@@ -512,15 +512,15 @@ def sort_fcn(itm):
session_date = pd.to_datetime(sessions['date'])
sessions = sessions[(session_date >= start) & (session_date <= end)]
elif key == 'number':
- query = util.ensure_list(value)
+ query = ensure_list(value)
sessions = sessions[sessions[key].isin(map(int, query))]
# Dataset/QC check is biggest so this should be done last
elif key == 'dataset' or (key == 'dataset_qc_lte' and 'dataset' not in queries):
datasets = self._cache['datasets']
qc = QC.validate(queries.get('dataset_qc_lte', 'FAIL')).name # validate value
has_dset = sessions.index.isin(datasets.index.get_level_values('eid'))
datasets = datasets.loc[(sessions.index.values[has_dset], ), :]
- query = util.ensure_list(value if key == 'dataset' else '')
+ query = ensure_list(value if key == 'dataset' else '')
# For each session check any dataset both contains query and exists
mask = (
(datasets
@@ -2271,7 +2271,7 @@ def search(self, details=False, query_type=None, **kwargs):

def _add_date(records):
"""Add date field for compatibility with One.search output."""
- for s in util.ensure_list(records):
+ for s in ensure_list(records):
s['date'] = datetime.fromisoformat(s['start_time']).date()
return records

@@ -2343,7 +2343,7 @@ def _download_aws(self, dsets, update_exists=True, keep_uuid=None, **_) -> List[
assert self.mode != 'local'
# Get all dataset URLs
dsets = list(dsets) # Ensure not generator
- uuids = [util.ensure_list(x.name)[-1] for x in dsets]
+ uuids = [ensure_list(x.name)[-1] for x in dsets]
# If number of UUIDs is too high, fetch in loop to avoid 414 HTTP status code
remote_records = []
N = 100 # Number of UUIDs per query
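The rest of this hunk is collapsed, but the comment above states the intent: fetch the dataset records from Alyx in batches of N UUIDs so the request URL stays short enough to avoid an HTTP 414. Roughly, the batching pattern looks like this sketch (fetch_batch is a hypothetical stand-in for the actual REST call, which is not visible in this hunk):

    # Illustration only; the real loop sits in the collapsed part of the diff.
    remote_records = []
    for i in range(0, len(uuids), N):               # N = 100 UUIDs per request
        batch = uuids[i:i + N]
        remote_records.extend(fetch_batch(batch))   # hypothetical helper, one short URL per call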
Expand Down Expand Up @@ -2417,7 +2417,7 @@ def _dset2url(self, dset, update_cache=True):
did = dset['id']
elif 'file_records' not in dset: # Convert dataset Series to alyx dataset dict
url = self.record2url(dset) # NB: URL will always be returned but may not exist
- did = util.ensure_list(dset.name)[-1]
+ did = ensure_list(dset.name)[-1]
else: # from datasets endpoint
repo = getattr(getattr(self._web_client, '_par', None), 'HTTP_DATA_SERVER', None)
url = next(
@@ -2431,7 +2431,7 @@
_logger.debug('Updating cache')
# NB: This will be considerably easier when IndexSlice supports Ellipsis
idx = [slice(None)] * int(self._cache['datasets'].index.nlevels / 2)
- self._cache['datasets'].loc[(*idx, *util.ensure_list(did)), 'exists'] = False
+ self._cache['datasets'].loc[(*idx, *ensure_list(did)), 'exists'] = False
self._cache['_meta']['modified_time'] = datetime.now()

return url
@@ -2525,7 +2525,7 @@ def _download_file(self, url, target_dir, keep_uuid=None, file_size=None, hash=N
"""
assert not self.offline
# Ensure all target directories exist
- [Path(x).mkdir(parents=True, exist_ok=True) for x in set(util.ensure_list(target_dir))]
+ [Path(x).mkdir(parents=True, exist_ok=True) for x in set(ensure_list(target_dir))]

# download file(s) from url(s), returns file path(s) with UUID
local_path, md5 = self.alyx.download_file(url, target_dir=target_dir, return_md5=True)
@@ -2534,7 +2534,7 @@ def _download_file(self, url, target_dir, keep_uuid=None, file_size=None, hash=N
if isinstance(url, (tuple, list)):
assert (file_size is None) or len(file_size) == len(url)
assert (hash is None) or len(hash) == len(url)
- for args in zip(*map(util.ensure_list, (file_size, md5, hash, local_path, url))):
+ for args in zip(*map(ensure_list, (file_size, md5, hash, local_path, url))):
self._check_hash_and_file_size_mismatch(*args)

# check if we are keeping the uuid on the list of file names
3 changes: 1 addition & 2 deletions one/registration.py
@@ -23,14 +23,13 @@
import requests.exceptions

from iblutil.io import hashfile
- from iblutil.util import Bunch
+ from iblutil.util import Bunch, ensure_list

import one.alf.io as alfio
from one.alf.files import session_path_parts, get_session_path, folder_parts, filename_parts
from one.alf.spec import is_valid
import one.alf.exceptions as alferr
from one.api import ONE
- from one.util import ensure_list
from one.webclient import no_cache

_logger = logging.getLogger(__name__)
2 changes: 1 addition & 1 deletion one/remote/globus.py
@@ -97,12 +97,12 @@
from globus_sdk import TransferAPIError, GlobusAPIError, NetworkError, GlobusTimeoutError, \
GlobusConnectionError, GlobusConnectionTimeoutError, GlobusSDKUsageError, NullAuthorizer
from iblutil.io import params as iopar
+ from iblutil.util import ensure_list

from one.alf.spec import is_uuid
from one.alf.files import remove_uuid_string
import one.params
from one.webclient import AlyxClient
- from one.util import ensure_list
from .base import DownloadClient, load_client_params, save_client_params

__all__ = ['Globus', 'get_lab_from_endpoint_id', 'as_globus_path']
54 changes: 33 additions & 21 deletions one/tests/test_one.py
@@ -48,7 +48,7 @@
from one.api import ONE, One, OneAlyx
from one.util import (
ses2records, validate_date_range, index_last_before, filter_datasets, _collection_spec,
- filter_revision_last_before, parse_id, autocomplete, LazyId, datasets2records
+ filter_revision_last_before, parse_id, autocomplete, LazyId, datasets2records, ensure_list
)
import one.params
import one.alf.exceptions as alferr
@@ -1564,7 +1564,7 @@ def test_search_insertions(self):
dataset=['wheel.times'], query_type='remote')

def test_search_terms(self):
- """Test OneAlyx.search_terms"""
+ """Test OneAlyx.search_terms."""
assert self.one.mode != 'remote'
search1 = self.one.search_terms()
self.assertIn('dataset', search1)
@@ -1582,7 +1582,7 @@ def test_search_terms(self):
self.assertIn('model', search5)

def test_load_dataset(self):
- """Test OneAlyx.load_dataset"""
+ """Test OneAlyx.load_dataset."""
file = self.one.load_dataset(self.eid, '_spikeglx_sync.channels.npy',
collection='raw_ephys_data', query_type='remote',
download_only=True)
@@ -1598,7 +1598,7 @@ def test_load_dataset(self):
collection='alf', query_type='remote')

def test_load_object(self):
- """Test OneAlyx.load_object"""
+ """Test OneAlyx.load_object."""
files = self.one.load_object(self.eid, 'wheel',
collection='alf', query_type='remote',
download_only=True)
@@ -1608,7 +1608,7 @@ def test_load_object(self):
)

def test_get_details(self):
- """Test OneAlyx.get_details"""
+ """Test OneAlyx.get_details."""
det = self.one.get_details(self.eid, query_type='remote')
self.assertIsInstance(det, dict)
self.assertEqual('SWC_043', det['subject'])
@@ -1625,7 +1625,7 @@ def test_get_details(self):

@unittest.skipIf(OFFLINE_ONLY, 'online only test')
class TestOneDownload(unittest.TestCase):
- """Test downloading datasets using OpenAlyx"""
+ """Test downloading datasets using OpenAlyx."""
tempdir = None
one = None

@@ -1639,7 +1639,7 @@ def setUp(self) -> None:
self.eid = 'aad23144-0e52-4eac-80c5-c4ee2decb198'

def test_download_datasets(self):
- """Test OneAlyx._download_dataset, _download_file and _dset2url"""
+ """Test OneAlyx._download_dataset, _download_file and _dset2url."""
det = self.one.get_details(self.eid, True)
rec = next(x for x in det['data_dataset_session_related']
if 'channels.brainLocation' in x['dataset_type'])
@@ -1782,6 +1782,7 @@ def test_download_aws(self):

def test_tag_mismatched_file_record(self):
"""Test for OneAlyx._tag_mismatched_file_record.
This method is also tested in test_download_datasets.
"""
did = '4a1500c2-60f3-418f-afa2-c752bb1890f0'
@@ -1804,7 +1805,7 @@ def tearDown(self) -> None:


class TestOneSetup(unittest.TestCase):
- """Test parameter setup upon ONE instantiation and calling setup methods"""
+ """Test parameter setup upon ONE instantiation and calling setup methods."""
def setUp(self) -> None:
self.tempdir = tempfile.TemporaryDirectory()
self.addCleanup(self.tempdir.cleanup)
@@ -1815,7 +1816,7 @@ def setUp(self) -> None:
self.addCleanup(patch.stop)

def test_local_cache_setup_prompt(self):
- """Test One.setup"""
+ """Test One.setup."""
path = Path(self.tempdir.name).joinpath('subject', '2020-01-01', '1', 'spikes.times.npy')
path.parent.mkdir(parents=True)
path.touch()
@@ -1839,6 +1840,7 @@ def test_local_cache_setup_prompt(self):

def test_setup_silent(self):
"""Test setting up parameters with silent flag.
- Mock getfile to return temp dir as param file location
- Mock input function as fail safe in case function erroneously prompts user for input
"""
@@ -1869,6 +1871,7 @@ def test_setup_silent(self):

def test_setup_username(self):
"""Test setting up parameters with a provided username.
- Mock getfile to return temp dir as param file location
- Mock input function as fail safe in case function erroneously prompts user for input
- Mock requests.post returns a fake user authentication response
@@ -1897,14 +1900,14 @@ def test_setup_username(self):

@unittest.skipIf(OFFLINE_ONLY, 'online only test')
def test_static_setup(self):
- """Test OneAlyx.setup"""
+ """Test OneAlyx.setup."""
with mock.patch('iblutil.io.params.getfile', new=self.get_file), \
mock.patch('one.webclient.getpass', return_value='international'):
one_obj = OneAlyx.setup(silent=True)
self.assertEqual(one_obj.alyx.base_url, one.params.default().ALYX_URL)

def test_setup(self):
- """Test one.params.setup"""
+ """Test one.params.setup."""
url = TEST_DB_1['base_url']

def mock_input(prompt):
@@ -1948,7 +1951,7 @@ def mock_input(prompt):
self.assertTrue(getattr(mock_input, 'conflict_warn', False))

def test_patch_params(self):
- """Test patching legacy params to the new location"""
+ """Test patching legacy params to the new location."""
# Save some old-style params
old_pars = one.params.default().set('HTTP_DATA_SERVER', 'openalyx.org')
# Save a REST query in the old location
@@ -1965,7 +1968,7 @@ def test_patch_params(self):
self.assertTrue(any(one_obj.alyx.cache_dir.joinpath('.rest').glob('*')))

def test_one_factory(self):
- """Tests the ONE class factory"""
+ """Tests the ONE class factory."""
with mock.patch('iblutil.io.params.getfile', new=self.get_file), \
mock.patch('one.params.input', new=self.assertFalse):
# Cache dir not in client cache map; use One (light)
@@ -1999,9 +2002,9 @@ def test_one_factory(self):


class TestOneMisc(unittest.TestCase):
- """Test functions in one.util"""
+ """Test functions in one.util."""
def test_validate_date_range(self):
- """Test one.util.validate_date_range"""
+ """Test one.util.validate_date_range."""
# Single string date
actual = validate_date_range('2020-01-01') # On this day
expected = (pd.Timestamp('2020-01-01 00:00:00'),
@@ -2044,7 +2047,7 @@ def test_validate_date_range(self):
validate_date_range(['2020-01-01', '2019-09-06', '2021-10-04'])

def test_index_last_before(self):
- """Test one.util.index_last_before"""
+ """Test one.util.index_last_before."""
revisions = ['2021-01-01', '2020-08-01', '', '2020-09-30']
verifiable = index_last_before(revisions, '2021-01-01')
self.assertEqual(0, verifiable)
@@ -2061,7 +2064,7 @@ def test_index_last_before(self):
self.assertEqual(0, verifiable, 'should return most recent')

def test_collection_spec(self):
- """Test one.util._collection_spec"""
+ """Test one.util._collection_spec."""
# Test every combination of input
inputs = []
_collection = {None: '({collection}/)?', '': '', '-': '{collection}/'}
@@ -2075,7 +2078,7 @@ def test_collection_spec(self):
self.assertEqual(expected, verifiable)

def test_revision_last_before(self):
- """Test one.util.filter_revision_last_before"""
+ """Test one.util.filter_revision_last_before."""
datasets = util.revisions_datasets_table()
df = datasets[datasets.rel_path.str.startswith('alf/probe00')].copy()
verifiable = filter_revision_last_before(df, revision='2020-09-01', assert_unique=False)
@@ -2111,7 +2114,7 @@ def test_revision_last_before(self):
filter_revision_last_before(df.copy(), assert_unique=True)

def test_parse_id(self):
- """Test one.util.parse_id"""
+ """Test one.util.parse_id."""
obj = unittest.mock.MagicMock() # Mock object to decorate
obj.to_eid.return_value = 'parsed_id' # Method to be called
input = 'subj/date/num' # Input id to pass to `to_eid`
@@ -2125,7 +2128,7 @@ def test_parse_id(self):
parse_id(obj.method)(obj, input)

def test_autocomplete(self):
- """Test one.util.autocomplete"""
+ """Test one.util.autocomplete."""
search_terms = ('subject', 'date_range', 'dataset', 'dataset_type')
self.assertEqual('subject', autocomplete('Subj', search_terms))
self.assertEqual('dataset', autocomplete('dataset', search_terms))
@@ -2135,7 +2138,7 @@ def test_autocomplete(self):
autocomplete('dat', search_terms)

def test_LazyID(self):
- """Test one.util.LazyID"""
+ """Test one.util.LazyID."""
uuids = [
'c1a2758d-3ce5-4fa7-8d96-6b960f029fa9',
'0780da08-a12b-452a-b936-ebc576aa7670',
@@ -2149,3 +2152,12 @@ def test_LazyID(self):
self.assertEqual(ez[0:2], uuids[0:2])
ez = LazyId([{'id': x} for x in uuids])
self.assertCountEqual(map(str, ez), uuids)

+     def test_ensure_list(self):
+         """Test one.util.ensure_list.
+         This function has now moved therefore we simply check for deprecation warning.
+         """
+         with self.assertWarns(DeprecationWarning):
+             self.assertEqual(['123'], ensure_list('123'))
+         self.assertIs(x := ['123'], ensure_list(x))
12 changes: 8 additions & 4 deletions one/util.py
@@ -11,6 +11,7 @@

import pandas as pd
from iblutil.io import parquet
+ from iblutil.util import ensure_list as _ensure_list
import numpy as np
from packaging import version

@@ -96,7 +97,7 @@ def datasets2records(datasets, additional=None) -> pd.DataFrame:
"""
records = []

- for d in ensure_list(datasets):
+ for d in _ensure_list(datasets):
file_record = next((x for x in d['file_records'] if x['data_url'] and x['exists']), None)
if not file_record:
continue # Ignore files that are not accessible
@@ -396,7 +397,7 @@ def filter_datasets(
exclude_slash = partial(re.sub, fr'^({flagless_token})?\.\*', r'\g<1>[^/]*')
spec_str += '|'.join(
exclude_slash(fnmatch.translate(x)) if wildcards else x + '$'
- for x in ensure_list(filename)
+ for x in _ensure_list(filename)
)

# If matching revision name, add to regex string
@@ -408,7 +409,7 @@
continue
if wildcards:
# Convert to regex, remove \\Z which asserts end of string
- v = (fnmatch.translate(x).replace('\\Z', '') for x in ensure_list(v))
+ v = (fnmatch.translate(x).replace('\\Z', '') for x in _ensure_list(v))
if not isinstance(v, str):
regex_args[k] = '|'.join(v) # logical OR

@@ -592,7 +593,7 @@ def autocomplete(term, search_terms) -> str:

def ensure_list(value):
"""Ensure input is a list."""
-     return [value] if isinstance(value, (str, dict)) or not isinstance(value, Iterable) else value
+     warnings.warn(
+         'one.util.ensure_list is deprecated, use iblutil.util.ensure_list instead',
+         DeprecationWarning)
+     return _ensure_list(value)


class LazyId(Mapping):
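With the shim above in place, existing callers keep working while the warning nudges them toward the new import. A short sketch of how a caller can surface the warning, or silence just this one until its imports are updated (standard-library warnings machinery only, mirroring the new unit test):

    import warnings
    from one.util import ensure_list   # deprecated alias of iblutil.util.ensure_list

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always', DeprecationWarning)
        assert ensure_list('123') == ['123']           # still returns the wrapped value
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)

    # Temporarily suppress only this warning until the import is migrated:
    warnings.filterwarnings('ignore', message='one.util.ensure_list is deprecated')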