From dbfe95f3caeeedf07c9105137d9b6b339183272e Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Mon, 14 Oct 2024 13:20:20 +0300 Subject: [PATCH] Added util.ensure_list --- CHANGELOG.md | 14 ++++++++++++-- iblutil/__init__.py | 2 +- iblutil/io/net/app.py | 2 +- iblutil/util.py | 28 +++++++++++++++++++++++++++- tests/test_util.py | 20 ++++++++++++++++++++ 5 files changed, 61 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 662f9b5..47a8d2e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,7 +3,17 @@ * Minor releases (X.1.X) are new features such as added functions or small changes that don't cause major compatibility issues. * Major releases (1.X.X) are major new features or changes that break backward compatibility in a big way. -## [Latest](https://github.com/int-brain-lab/iblutil/commits/main) [1.12.1] +## [Latest](https://github.com/int-brain-lab/iblutil/commits/main) [1.13.0] + +### Added + +- util.ensure_list: function returning input wrapped in list if not an iterator or a member of specific iterable classes + +### Modified + +- io.net.app.EchoProtocol.confirmed_send: more informative timeout error message + +## [1.12.1] ### Modified @@ -13,7 +23,7 @@ ### Added -- io.net.base.get_mac: function returning the machine's unique MAC address formatted according to IEEE 802 specifications. +- io.net.base.get_mac: function returning the machine's unique MAC address formatted according to IEEE 802 specifications ## [1.11.0] diff --git a/iblutil/__init__.py b/iblutil/__init__.py index fe70fa2..84c54b7 100644 --- a/iblutil/__init__.py +++ b/iblutil/__init__.py @@ -1 +1 @@ -__version__ = '1.12.1' +__version__ = '1.13.0' diff --git a/iblutil/io/net/app.py b/iblutil/io/net/app.py index 5588556..3b0c2dd 100644 --- a/iblutil/io/net/app.py +++ b/iblutil/io/net/app.py @@ -377,7 +377,7 @@ async def confirmed_send(self, data, addr=None, timeout=None): await asyncio.wait_for(echo_future, timeout=timeout) except asyncio.TimeoutError: self.close() - raise TimeoutError('Failed to receive client response in time') + raise TimeoutError(f'Failed to receive client response in time ({self.name}: {self.server_uri})') except RuntimeError: self.close() raise RuntimeError('Unexpected response from server') diff --git a/iblutil/util.py b/iblutil/util.py index c18b20b..ac9d4bf 100644 --- a/iblutil/util.py +++ b/iblutil/util.py @@ -7,7 +7,7 @@ import copy import logging import sys -from typing import Union +from typing import Union, Iterable import numpy as np @@ -298,3 +298,29 @@ def get_mac() -> str: (e.g., 'BA-DB-AD-C0-FF-EE'). """ return uuid.getnode().to_bytes(6, 'big').hex('-').upper() + + +def ensure_list(value, exclude_type=(str, dict)): + """Ensure input is a list. + + Wraps `value` in a list if not already an iterator or if it is a member of specific + iterable classes. + + To allow users the option of passing a single value or multiple values, this function + will wrap the former in a list and by default will consider str and dict instances as + a single value. This function is useful because it treats tuples, lists, sets, and + generators all as 'lists', but not dictionaries and strings. + + Parameters + ---------- + value : any + Input to ensure list. + exclude_type : tuple, optional + A list of iterable classes to wrap in a list. + + Returns + ------- + Iterable + Either `value` if iterable and not in `exclude_type` list, or `value` wrapped in a list. + """ + return [value] if isinstance(value, exclude_type) or not isinstance(value, Iterable) else value diff --git a/tests/test_util.py b/tests/test_util.py index 50de94f..c46c31e 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -210,5 +210,25 @@ def test_get_mac(self): self.assertEqual(util.get_mac(), 'BA-DB-AD-C0-FF-EE') +class TestEnsureList(unittest.TestCase): + """Test ensure_list function.""" + + def test_ensure_list(self): + """Test ensure_list function.""" + x = [1, 2, 3] + self.assertIs(x, util.ensure_list(x)) + x = tuple(x) + self.assertIs(x, util.ensure_list(x)) + # Shouldn't iterate over strings or dicts + x = '123' + self.assertEqual([x], util.ensure_list(x)) + x = {'1': 1, '2': 2, '3': 3} + self.assertEqual([x], util.ensure_list(x)) + # Check exclude_type param behaviour + x = np.array([1, 2, 3]) + self.assertIs(x, util.ensure_list(x)) + self.assertEqual([x], util.ensure_list(x, exclude_type=(np.ndarray))) + + if __name__ == '__main__': unittest.main(exit=False)