diff --git a/biothings_client/__init__.py b/biothings_client/__init__.py
index 5c3d979..5ac8c7d 100644
--- a/biothings_client/__init__.py
+++ b/biothings_client/__init__.py
@@ -5,6 +5,15 @@
 from biothings_client.client.asynchronous import AsyncBiothingClient, get_async_client
 from biothings_client.client.base import BiothingClient, get_client
 from biothings_client.__version__ import __version__
+from biothings_client._dependencies import _CACHING, _PANDAS
 
-__all__ = ["AsyncBiothingClient", "BiothingClient", "get_client", "get_async_client", "__version__"]
+__all__ = [
+    "AsyncBiothingClient",
+    "BiothingClient",
+    "get_client",
+    "get_async_client",
+    "__version__",
+    "_CACHING",
+    "_PANDAS",
+]
diff --git a/biothings_client/_dependencies.py b/biothings_client/_dependencies.py
new file mode 100644
index 0000000..1444af3
--- /dev/null
+++ b/biothings_client/_dependencies.py
@@ -0,0 +1,4 @@
+import importlib.util
+
+_PANDAS = importlib.util.find_spec("pandas") is not None
+_CACHING = importlib.util.find_spec("hishel") is not None and importlib.util.find_spec("anysqlite") is not None
diff --git a/tests/__init__.py b/biothings_client/cache/__init__.py
similarity index 100%
rename from tests/__init__.py
rename to biothings_client/cache/__init__.py
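The new `_dependencies` module gates optional features on `importlib.util.find_spec`, which reports whether a package is importable without actually importing it. A minimal sketch of the intended consumption pattern (the `to_dataframe` helper below is illustrative, not part of this diff):

    import importlib.util

    # True only when the package can be imported; nothing is imported yet
    _PANDAS = importlib.util.find_spec("pandas") is not None

    if _PANDAS:
        import pandas

    def to_dataframe(hits):
        # Fail fast with an actionable message instead of an ImportError deep inside a query call
        if not _PANDAS:
            raise RuntimeError("pandas is required: pip install biothings_client[dataframe]")
        return pandas.json_normalize(hits)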
diff --git a/biothings_client/cache/storage.py b/biothings_client/cache/storage.py
new file mode 100644
index 0000000..40092f7
--- /dev/null
+++ b/biothings_client/cache/storage.py
@@ -0,0 +1,265 @@
+"""
+Custom cache storage handling for the biothings-client
+"""
+
+from pathlib import Path
+from typing import Optional, Union
+import logging
+import sqlite3
+
+from biothings_client._dependencies import _CACHING
+
+if _CACHING:
+    import anysqlite
+    import hishel
+
+
+logger = logging.getLogger("biothings.client")
+logger.setLevel(logging.INFO)
+
+
+class BiothingsClientSqlite3Cache(hishel.SQLiteStorage):
+    """
+    Overridden sqlite3 storage for some extra functionality
+
+    We have two main properties that we want from this overridden
+    class:
+
+    1) The ability to get the cache location. This is accessed via
+    the `cache_filepath` property
+    2) The ability to clear the cache. This can be performed via the
+    `clear_cache` method call
+    """
+
+    def __init__(
+        self,
+        serializer: Optional[hishel.BaseSerializer] = None,
+        connection: Optional[anysqlite.Connection] = None,
+        ttl: Optional[Union[int, float]] = None,
+    ) -> None:
+        self._cache_filepath: Optional[Path] = None
+        self._connection: Optional[sqlite3.Connection] = connection or None
+        self._setup_completed: bool = False
+        super().__init__(serializer, connection, ttl)
+
+    def setup_database_connection(self, cache_filepath: Union[str, Path] = None) -> None:
+        """
+        Establishes the sqlite3 database connection if it hasn't been
+        created yet
+
+        Override of the _setup method so that we can specify the database
+        file path. Exposed publicly so that the user can specify the path
+        themselves, and for use during biothings_client testing
+        """
+        if not self._setup_completed:
+            if cache_filepath is None:
+                home_directory = Path.home()
+                cache_directory = home_directory.joinpath(".cache")
+                cache_directory.mkdir(parents=True, exist_ok=True)
+                cache_filepath = cache_directory.joinpath(".hishel.sqlite")
+            self._cache_filepath = Path(cache_filepath).resolve().absolute()
+
+            with self._setup_lock:
+                if not self._connection:  # pragma: no cover
+                    self._connection = sqlite3.connect(self._cache_filepath, check_same_thread=False)
+                table_creation_command = "CREATE TABLE IF NOT EXISTS cache(key TEXT, data BLOB, date_created REAL)"
+                self._connection.execute(table_creation_command)
+                self._connection.commit()
+                self._setup_completed = True
+
+    def clear_cache(self) -> None:
+        """
+        Clears the sqlite3 cache
+
+        1) Performs a DELETE to remove all rows without
+        dropping the table
+        2) Resets the auto-increment counter
+        3) Performs a VACUUM operation
+        """
+        cache_table_name = "cache"
+        with self._setup_lock:
+            try:
+                delete_rows_command = f"DELETE FROM {cache_table_name}"
+                self._connection.execute(delete_rows_command)
+                self._connection.commit()
+            except sqlite3.OperationalError as operational_error:
+                logger.exception(operational_error)
+                exception_message = operational_error.args[0]
+                missing_cache_table_message = f"no such table: {cache_table_name}"
+                if exception_message == missing_cache_table_message:
+                    logger.debug("No table [%s] to clear. Skipping ...", cache_table_name)
+                else:
+                    raise operational_error
+            except Exception as gen_exc:
+                logger.exception(gen_exc)
+                raise gen_exc
+
+            autoincrement_table_name = "SQLITE_SEQUENCE"
+            try:
+                reset_autoincrement_command = (
+                    f"UPDATE {autoincrement_table_name} SET seq = 0 WHERE name = '{cache_table_name}'"
+                )
+                self._connection.execute(reset_autoincrement_command)
+                self._connection.commit()
+            except sqlite3.OperationalError as operational_error:
+                logger.exception(operational_error)
+                exception_message = operational_error.args[0]
+                missing_autoincrement_table_message = f"no such table: {autoincrement_table_name}"
+                if exception_message == missing_autoincrement_table_message:
+                    logger.debug("No table [%s] to update. Skipping ...", autoincrement_table_name)
Skipping ...", autoincrement_table_name) + else: + raise operational_error + except Exception as gen_exc: + logger.exception(gen_exc) + raise gen_exc + + try: + vacuum_command = "VACUUM" + self._connection.execute(vacuum_command) + self._connection.commit() + except sqlite3.OperationalError as operational_error: + logger.exception(operational_error) + raise operational_error + except Exception as gen_exc: + logger.exception(gen_exc) + raise gen_exc + + @property + def cache_filepath(self) -> Path: + """ + Returns the filepath for the sqlite3 cache database + + We have either stored it because we generated it ourselves + via `BiothingsClientSqlite3Storage.database_connection` or we + have to look it up in the database via the following PRAGMA: + https://www.sqlite.org/pragma.html#pragma_database_list + """ + self.setup_database_connection() + if self._cache_filepath is None: + pragma_command = "PRAGMA database_list" + for _, name, filename in self._connection.execute(pragma_command): + if name == "main" and filename is not None: + self._cache_filepath = Path(filename).resolve().absolute() + break + return self._cache_filepath + + +class AsyncBiothingsClientSqlite3Cache(hishel.AsyncSQLiteStorage): + """ + Overriden sqlite3 client for some extra functionality + + We have two main properties that we want from this overridden + class: + + 1) The ability to get the cache location. This is accessed via + the `cache_filepath` property + 2) The ability to clear the cache. This can be performed via the + `clear_cache` methodcall + """ + + def __init__( + self, + serializer: Optional[hishel.BaseSerializer] = None, + connection: Optional[anysqlite.Connection] = None, + ttl: Optional[Union[int, float]] = None, + ) -> None: + self._cache_filepath = None + super().__init__(serializer, connection, ttl) + + async def setup_database_connection(self, cache_filepath: Union[str, Path] = None) -> None: + """ + Establishes the sqlite3 database connection if it hasn't been + created yet + + Override of the _setup method so that we can specify the database + file path. 
+        if not self._setup_completed:
+            if cache_filepath is None:
+                home_directory = Path.home()
+                cache_directory = home_directory.joinpath(".cache")
+                cache_directory.mkdir(parents=True, exist_ok=True)
+                cache_filepath = cache_directory.joinpath(".hishel.sqlite")
+            self._cache_filepath = Path(cache_filepath).resolve().absolute()
+
+            async with self._setup_lock:
+                if not self._connection:  # pragma: no cover
+                    self._connection = await anysqlite.connect(self._cache_filepath, check_same_thread=False)
+                table_creation_command = "CREATE TABLE IF NOT EXISTS cache(key TEXT, data BLOB, date_created REAL)"
+                await self._connection.execute(table_creation_command)
+                await self._connection.commit()
+                self._setup_completed = True
+
+    async def clear_cache(self) -> None:
+        """
+        Clears the sqlite3 cache
+
+        1) Performs a DELETE to remove all rows without
+        dropping the table
+        2) Resets the auto-increment counter
+        3) Performs a VACUUM operation
+        """
+        async with self._setup_lock:
+            cache_table_name = "cache"
+            try:
+                delete_rows_command = f"DELETE FROM {cache_table_name}"
+                await self._connection.execute(delete_rows_command)
+                await self._connection.commit()
+            except anysqlite.OperationalError as operational_error:
+                logger.exception(operational_error)
+                exception_message = operational_error.args[0]
+                missing_cache_table_message = f"no such table: {cache_table_name}"
+                if exception_message == missing_cache_table_message:
+                    logger.debug("No table [%s] to clear. Skipping ...", cache_table_name)
+                else:
+                    raise operational_error
+            except Exception as gen_exc:
+                logger.exception(gen_exc)
+                raise gen_exc
+
+            autoincrement_table_name = "SQLITE_SEQUENCE"
+            try:
+                reset_autoincrement_command = (
+                    f"UPDATE {autoincrement_table_name} SET seq = 0 WHERE name = '{cache_table_name}'"
+                )
+                await self._connection.execute(reset_autoincrement_command)
+                await self._connection.commit()
+            except anysqlite.OperationalError as operational_error:
+                logger.exception(operational_error)
+                exception_message = operational_error.args[0]
+                missing_autoincrement_table_message = f"no such table: {autoincrement_table_name}"
+                if exception_message == missing_autoincrement_table_message:
+                    logger.debug("No table [%s] to update. Skipping ...", autoincrement_table_name)
+                else:
+                    raise operational_error
+            except Exception as gen_exc:
+                logger.exception(gen_exc)
+                raise gen_exc
+
+            try:
+                vacuum_command = "VACUUM"
+                await self._connection.execute(vacuum_command)
+                await self._connection.commit()
+            except anysqlite.OperationalError as operational_error:
+                logger.exception(operational_error)
+                raise operational_error
+            except Exception as gen_exc:
+                logger.exception(gen_exc)
+                raise gen_exc
+
+    @property
+    async def cache_filepath(self) -> Path:
+        """
+        Returns the filepath for the sqlite3 cache database
+
+        We have either stored it because we generated it ourselves
+        via `setup_database_connection` or we have to look it up in
+        the database via the following PRAGMA:
+        https://www.sqlite.org/pragma.html#pragma_database_list
+        """
+        if self._cache_filepath is None:
+            pragma_command = "PRAGMA database_list"
+            cursor = await self._connection.execute(pragma_command)
+            for _, name, filename in await cursor.fetchall():
+                if name == "main" and filename is not None:
+                    self._cache_filepath = Path(filename).resolve().absolute()
+                    break
+        return self._cache_filepath
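Wiring these storage classes into a hishel cache client mirrors what `_build_cache_http_client` does later in this diff. A hedged, standalone sketch of the synchronous flavor (the database filename is illustrative):

    import hishel

    from biothings_client.cache.storage import BiothingsClientSqlite3Cache

    storage = BiothingsClientSqlite3Cache()
    storage.setup_database_connection("biothings_cache.sqlite")  # None defaults to ~/.cache/.hishel.sqlite
    controller = hishel.Controller(cacheable_methods=["GET", "POST"])
    client = hishel.CacheClient(controller=controller, storage=storage)

    client.get("https://mygene.info/v3/metadata")  # first call hits the network, later calls hit sqlite3
    print(storage.cache_filepath)                  # resolved path of the backing database
    storage.clear_cache()                          # drop all cached rows without dropping the table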
Skipping ...", autoincrement_table_name) + else: + raise operational_error + except Exception as gen_exc: + logger.exception(gen_exc) + raise gen_exc + + try: + vacuum_command = "VACUUM" + await self._connection.execute(vacuum_command) + await self._connection.commit() + except sqlite3.OperationalError as operational_error: + logger.exception(operational_error) + raise operational_error + except Exception as gen_exc: + logger.exception(gen_exc) + raise gen_exc + + @property + async def cache_filepath(self) -> Path: + """ + Returns the filepath for the sqlite3 cache database + + We have either stored it because we generated it ourselves + via `BiothingsClientSqlite3Storage.database_connection` or we + have to look it up in the database via the following PRAGMA: + https://www.sqlite.org/pragma.html#pragma_database_list + """ + if self._cache_filepath is None: + pragma_command = "PRAGMA database_list" + async for _, name, filename in self._connection.execute(pragma_command): + if name == "main" and filename is not None: + self._cache_filepath = Path(filename).resolve().absolute() + break + return self._cache_filepath diff --git a/biothings_client/client/asynchronous.py b/biothings_client/client/asynchronous.py index c260762..13d9227 100644 --- a/biothings_client/client/asynchronous.py +++ b/biothings_client/client/asynchronous.py @@ -4,30 +4,15 @@ from collections.abc import Iterable from copy import copy -from typing import Tuple -import asyncio +from pathlib import Path +from typing import Union, Tuple import logging import platform import warnings -import httpx - - -try: - from pandas import DataFrame, json_normalize - df_avail = True -except ImportError: - df_avail = False - -from biothings_client.utils.iteration import ( - iter_n, - list_itemcnt, - safe_str, -) -from biothings_client.utils.copy import copy_func +import httpx -from biothings_client.__version__ import __version__ from biothings_client.client.settings import ( COMMON_ALIASES, COMMON_KWARGS, @@ -44,9 +29,20 @@ MYVARIANT_ALIASES, MYVARIANT_KWARGS, ) +from biothings_client.__version__ import __version__ +from biothings_client._dependencies import _CACHING, _PANDAS +from biothings_client.client.exceptions import OptionalDependencyImportError from biothings_client.mixins.gene import MyGeneClientMixin from biothings_client.mixins.variant import MyVariantClientMixin +from biothings_client.utils.copy import copy_func +from biothings_client.utils.iteration import iter_n, list_itemcnt, concatenate_list + +if _PANDAS: + import pandas +if _CACHING: + import hishel + from biothings_client.cache.storage import AsyncBiothingsClientSqlite3Cache logger = logging.getLogger("biothings.client") logger.setLevel(logging.INFO) @@ -58,31 +54,29 @@ # the business logic can be more concise and more readable. class AsyncBiothingClient: """ - This is the asynchronous client for a biothing web service. + async http client class for accessing the biothings web services """ - def __init__(self, url: str = None, enable_cache: bool = None): + def __init__(self, url: str = None): if url is None: url = self._default_url self.url = url if self.url[-1] == "/": self.url = self.url[:-1] - if enable_cache is None: - enable_cache = False - self.enable_cache = enable_cache - self.max_query = self._max_query # delay and step attributes are for batch queries. self.delay = self._delay # delay is ignored when requests made from cache. 
self.step = self._step + self.scroll_size = self._scroll_size # raise httpx.HTTPError for status_code > 400 # > but not for 404 on getvariant # > set to False to suppress the exceptions. self.raise_for_status = True + self.default_user_agent = ( "{package_header}/{client_version} (" "python:{python_version} " "httpx:{httpx_version}" ")" ).format( @@ -93,15 +87,83 @@ def __init__(self, url: str = None, enable_cache: bool = None): "httpx_version": httpx.__version__, } ) + self.http_client = None + self.cache_storage = None + self.http_client_setup = False + self.caching_enabled = False + + async def _build_http_client(self, cache_db: Union[str, Path] = None) -> None: + """ + Builds the async client instance for usage through the lifetime + of the biothings_client + + This modifies the state of the BiothingsClient instance + to set the values for the http_client property + + Inputs: + :param cache_db: pathlike object to the local sqlite3 cache database file - async def use_http(self): + Outputs: + :return: None + """ + if not self.http_client_setup: + http_transport = httpx.AsyncHTTPTransport() + self.http_client = httpx.AsyncClient(transport=http_transport) + self.http_client_setup = True + self.http_cache_client_setup = False + + async def _build_cache_http_client(self, cache_db: Union[str, Path] = None) -> None: + """ + Builds the client instance used for caching biothings requests. + We rebuild the client whenever we enable to caching to ensure + that we don't create a database files unless the user explicitly + wants to leverage request caching + + This modifies the state of the BiothingsClient instance + to set the values for the http_client property and the cache_storage property + + Inputs: + :param cache_db: pathlike object to the local sqlite3 cache database file + + Outputs: + :return: None + """ + if not self.http_client_setup: + if cache_db is None: + cache_db = self._default_cache_file + cache_db = Path(cache_db).resolve().absolute() + + self.cache_storage = AsyncBiothingsClientSqlite3Cache() + await self.cache_storage.setup_database_connection(cache_db) + cache_transport = hishel.AsyncCacheTransport( + transport=httpx.AsyncHTTPTransport(), storage=self.cache_storage + ) + cache_controller = hishel.Controller(cacheable_methods=["GET", "POST"]) + self.http_client = hishel.AsyncCacheClient( + controller=cache_controller, transport=cache_transport, storage=self.cache_storage + ) + self.http_client_setup = True + + async def __del__(self): + """ + Destructor for the client to ensure that we close any potential + connections to the cache database + """ + try: + if self.http_client is not None: + await self.http_client.aclose() + except Exception as gen_exc: + logger.exception(gen_exc) + logger.error("Unable to close the httpx client instance %s", self.http_client) + + def use_http(self): """ Use http instead of https for API calls. """ if self.url: self.url = self.url.replace("https://", "http://") - async def use_https(self): + def use_https(self): """ Use https instead of http for API calls. This is the default. 
""" @@ -113,24 +175,29 @@ async def _dataframe(obj, dataframe, df_index=True): """ Converts object to DataFrame (pandas) """ - if not df_avail: - raise RuntimeError("pandas module must be installed (or upgraded) for as_dataframe option.") - # if dataframe not in ["by_source", "normal"]: - if dataframe not in [1, 2]: - raise ValueError("dataframe must be either 1 (using json_normalize) or 2 (using DataFrame.from_dict") - if "hits" in obj: - if dataframe == 1: - df = json_normalize(obj["hits"]) + if _PANDAS: + # if dataframe not in ["by_source", "normal"]: + if dataframe not in [1, 2]: + raise ValueError("dataframe must be either 1 (using json_normalize) or 2 (using DataFrame.from_dict") + + if "hits" in obj: + if dataframe == 1: + df = pandas.json_normalize(obj["hits"]) + else: + df = pandas.DataFrame.from_dict(obj) else: - df = DataFrame.from_dict(obj) + if dataframe == 1: + df = pandas.json_normalize(obj) + else: + df = pandas.DataFrame.from_dict(obj) + if df_index: + df = df.set_index("query") + return df else: - if dataframe == 1: - df = json_normalize(obj) - else: - df = DataFrame.from_dict(obj) - if df_index: - df = df.set_index("query") - return df + dataframe_library_error = OptionalDependencyImportError( + optional_function_access="enable dataframe conversion", optional_group="dataframe", libraries=["pandas"] + ) + raise dataframe_library_error async def _get( self, url: str, params: dict = None, none_on_404: bool = False, verbose: bool = True @@ -138,59 +205,73 @@ async def _get( """ Wrapper around the httpx.get method """ + await self._build_http_client() if params is None: params = {} + debug = params.pop("debug", False) return_raw = params.pop("return_raw", False) - headers = {"user-agent": self.default_user_agent} - async with httpx.AsyncClient() as client: - res = await client.get(url, params=params, headers=headers) - from_cache = getattr(res, "from_cache", False) - - if debug: - return from_cache, res - if none_on_404 and res.status_code == 404: - return from_cache, None - if self.raise_for_status: - res.raise_for_status() # raise httpx._exceptions.HTTPStatusError - if return_raw: - return from_cache, res.text - ret = res.json() - return from_cache, ret + response = await self.http_client.get( + url=url, params=params, headers=headers, extensions={"cache_disabled": not self.caching_enabled} + ) + + response_extensions = response.extensions + from_cache = response_extensions.get("from_cache", False) + + if from_cache: + logger.debug("Cached response %s from %s", response, url) + + if response.is_success: + if debug or return_raw: + get_response = (from_cache, response) + else: + get_response = (from_cache, response.json()) + else: + if none_on_404 and response.status_code == 404: + get_response = (from_cache, None) + elif self.raise_for_status: + response.raise_for_status() # raise httpx._exceptions.HTTPStatusError + return get_response - async def _post(self, url: str, params: dict, verbose: bool = True): + async def _post(self, url: str, params: dict, verbose: bool = True) -> Tuple[bool, httpx.Response]: """ Wrapper around the httpx.post method """ + await self._build_http_client() + + if params is None: + params = {} return_raw = params.pop("return_raw", False) headers = {"user-agent": self.default_user_agent} - async with httpx.AsyncClient() as client: - res = await client.post(url, data=params, headers=headers) - from_cache = getattr(res, "from_cache", False) - if self.raise_for_status: - res.raise_for_status() # raise httpx._exceptions.HTTPStatusError - if return_raw: 
- return from_cache, res - ret = res.json() - return from_cache, ret + response = await self.http_client.post( + url=url, data=params, headers=headers, extensions={"cache_disabled": not self.caching_enabled} + ) - @staticmethod - async def _format_list(a_list, sep=",", quoted=True): - if isinstance(a_list, (list, tuple)): - if quoted: - _out = sep.join(['"{}"'.format(safe_str(x)) for x in a_list]) + response_extensions = response.extensions + from_cache = response_extensions.get("from_cache", False) + + if from_cache: + logger.debug("Cached response %s from %s", response, url) + + if response.is_success: + if return_raw: + post_response = (from_cache, response) else: - _out = sep.join(["{}".format(safe_str(x)) for x in a_list]) + response.read() + post_response = (from_cache, response.json()) else: - _out = a_list # a_list is already a comma separated string - return _out + if self.raise_for_status: + response.raise_for_status() + else: + post_response = (from_cache, response) + return post_response async def _handle_common_kwargs(self, kwargs): # handle these common parameters accept field names as the value for kw in ["fields", "always_list", "allow_null"]: if kw in kwargs: - kwargs[kw] = await self._format_list(kwargs[kw], quoted=False) + kwargs[kw] = concatenate_list(kwargs[kw], quoted=False) return kwargs async def _repeated_query(self, query_fn, query_li, verbose=True, **fn_kwargs): @@ -205,32 +286,118 @@ async def _repeated_query(self, query_fn, query_li, verbose=True, **fn_kwargs): if verbose: logger.info("querying {0}-{1}...".format(i + 1, cnt)) i = cnt - from_cache, query_result = await query_fn(batch, **fn_kwargs) + _, query_result = await query_fn(batch, **fn_kwargs) yield query_result - if verbose: - cache_str = " {0}".format(self._from_cache_notification) if from_cache else "" - logger.info("done.{0}".format(cache_str)) - if not from_cache and self.delay: - # no need to delay if requests are from cache. - await asyncio.sleep(self.delay) - - @property - async def _from_cache_notification(self): - """ - Notification to alert user that a cached result is being returned. - """ - return "[ from cache ]" async def _metadata(self, verbose=True, **kwargs): """ Return a dictionary of Biothing metadata. """ _url = self.url + self._metadata_endpoint - from_cache, ret = await self._get(_url, params=kwargs, verbose=verbose) - if verbose and from_cache: - logger.info(self._from_cache_notification) + _, ret = await self._get(_url, params=kwargs, verbose=verbose) return ret + async def _set_caching(self, cache_db: Union[str, Path] = None, **kwargs) -> None: + """ + Enable the client caching and creates a local cache database + for all future requests + + If caching is already enabled then we no-opt + + Inputs: + :param cache_db: pathlike object to the local sqlite3 cache database file + + Outputs: + :return: None + """ + if _CACHING: + if not self.caching_enabled: + try: + self.caching_enabled = True + self.http_client_setup = False + await self._build_cache_http_client() + logger.debug("Reset the HTTP client to leverage caching %s", self.http_client) + logger.info( + ( + "Enabled client caching: %s\n" 'Future queries will be cached in "%s"', + self, + self.cache_storage.cache_filepath, + ) + ) + except Exception as gen_exc: + logger.exception(gen_exc) + logger.error("Unable to enable caching") + raise gen_exc + else: + logger.warning("Caching already enabled. 
Skipping for now ...") + else: + caching_library_error = OptionalDependencyImportError( + optional_function_access="enable biothings-client caching", + optional_group="caching", + libraries=["anysqlite", "hishel"], + ) + raise caching_library_error + + async def _stop_caching(self) -> None: + """ + Disable client caching. The local cache database will be maintained, + but we will disable cache access when sending requests + + If caching is already disabled then we no-opt + + Inputs: + :param None + + Outputs: + :return: None + """ + if _CACHING: + if self.caching_enabled: + try: + await self.cache_storage.clear_cache() + except Exception as gen_exc: + logger.exception(gen_exc) + logger.error("Error attempting to clear the local cache database") + raise gen_exc + + self.caching_enabled = False + self.http_client_setup = False + self._build_http_client() + logger.debug("Reset the HTTP client to disable caching %s", self.http_client) + logger.info("Disabled client caching: %s", self) + else: + logger.warning("Caching already disabled. Skipping for now ...") + else: + caching_library_error = OptionalDependencyImportError( + optional_function_access="disable biothings-client caching", + optional_group="caching", + libraries=["anysqlite", "hishel"], + ) + raise caching_library_error + + async def _clear_cache(self) -> None: + """ + Clear the globally installed cache. Caching will stil be enabled, + but the data stored in the cache stored will be dropped + """ + if _CACHING: + if self.caching_enabled: + try: + await self.cache_storage.clear_cache() + except Exception as gen_exc: + logger.exception(gen_exc) + logger.error("Error attempting to clear the local cache database") + raise gen_exc + else: + logger.warning("Caching already disabled. No local cache database to clear. Skipping for now ...") + else: + caching_library_error = OptionalDependencyImportError( + optional_function_access="clear biothings-client cache", + optional_group="caching", + libraries=["anysqlite", "hishel"], + ) + raise caching_library_error + async def _get_fields(self, search_term=None, verbose=True): """ Wrapper for /metadata/fields @@ -246,18 +413,17 @@ async def _get_fields(self, search_term=None, verbose=True): params = {"search": search_term} else: params = {} - from_cache, ret = await self._get(_url, params=params, verbose=verbose) + _, ret = await self._get(_url, params=params, verbose=verbose) for k, v in ret.items(): del k # Get rid of the notes column information if "notes" in v: del v["notes"] - if verbose and from_cache: - logger.info(self._from_cache_notification) return ret async def _getannotation(self, _id, fields=None, **kwargs): - """Return the object given id. + """ + Return the object given id. This is a wrapper for GET query of the biothings annotation service. :param _id: an entity id. 
@@ -272,13 +438,11 @@ async def _getannotation(self, _id, fields=None, **kwargs): kwargs["fields"] = fields kwargs = await self._handle_common_kwargs(kwargs) _url = self.url + self._annotation_endpoint + str(_id) - from_cache, ret = await self._get(_url, kwargs, none_on_404=True, verbose=verbose) - if verbose and from_cache: - logger.info(self._from_cache_notification) + _, ret = await self._get(_url, kwargs, none_on_404=True, verbose=verbose) return ret async def _getannotations_inner(self, ids, verbose=True, **kwargs): - id_collection = await self._format_list(ids) + id_collection = concatenate_list(ids) _kwargs = {"ids": id_collection} _kwargs.update(kwargs) _url = self.url + self._annotation_endpoint @@ -293,7 +457,8 @@ async def _annotations_generator(self, query_fn, ids, verbose=True, **kwargs): yield hit async def _getannotations(self, ids, fields=None, **kwargs): - """Return the list of annotation objects for the given list of ids. + """ + Return the list of annotation objects for the given list of ids. This is a wrapper for POST query of the biothings annotation service. :param ids: a list/tuple/iterable or a string of ids. @@ -404,9 +569,7 @@ async def _query(self, q: str, **kwargs): dataframe = 1 elif dataframe != 2: dataframe = None - from_cache, out = await self._get(_url, kwargs, verbose=verbose) - if verbose and from_cache: - logger.info(self._from_cache_notification) + _, out = await self._get(_url, kwargs, verbose=verbose) if dataframe: out = await self._dataframe(out, dataframe, df_index=False) return out @@ -414,8 +577,24 @@ async def _query(self, q: str, **kwargs): async def _fetch_all(self, url: str, verbose: bool = True, **kwargs): """ Function that returns a generator to results. Assumes that 'q' is in kwargs. + Implicitly disables caching to ensure we actually hit the endpoint rather than + pulling from local cache """ - from_cache, batch = await self._get(url, params=kwargs, verbose=verbose) + logger.warning("fetch_all implicitly disables HTTP request caching") + restore_caching = False + if self.caching_enabled: + restore_caching = True + try: + await self.stop_caching() + except OptionalDependencyImportError as optional_import_error: + logger.exception(optional_import_error) + logger.debug("No cache to disable for fetch all. Continuing ...") + except Exception as gen_exc: + logger.exception(gen_exc) + logger.error("Unknown error occured while attempting to disable caching") + raise gen_exc + + _, batch = await self._get(url, params=kwargs, verbose=verbose) if verbose: logger.info("Fetching {0} {1} . . .".format(batch["total"], self._optionally_plural_object_type)) for key in ["q", "fetch_all"]: @@ -429,10 +608,22 @@ async def _fetch_all(self, url: str, verbose: bool = True, **kwargs): for hit in batch["hits"]: yield hit kwargs.update({"scroll_id": batch["_scroll_id"]}) - from_cache, batch = await self._get(url, params=kwargs, verbose=verbose) + _, batch = await self._get(url, params=kwargs, verbose=verbose) + + if restore_caching: + logger.debug("re-enabling the client HTTP caching") + try: + await self.set_caching() + except OptionalDependencyImportError as optional_import_error: + logger.exception(optional_import_error) + logger.debug("No cache to disable for fetch all. 
Continuing ...") + except Exception as gen_exc: + logger.exception(gen_exc) + logger.error("Unknown error occured while attempting to disable caching") + raise gen_exc async def _querymany_inner(self, qterms, verbose=True, **kwargs): - query_term_collection = await self._format_list(qterms) + query_term_collection = concatenate_list(qterms) _kwargs = {"q": query_term_collection} _kwargs.update(kwargs) _url = self.url + self._query_endpoint @@ -471,7 +662,7 @@ async def _querymany(self, qterms, scopes=None, **kwargs): raise ValueError('input "qterms" must be a list, tuple or iterable.') if scopes: - kwargs["scopes"] = await self._format_list(scopes, quoted=False) + kwargs["scopes"] = concatenate_list(scopes, quoted=False) kwargs = await self._handle_common_kwargs(kwargs) returnall = kwargs.pop("returnall", False) verbose = kwargs.pop("verbose", True) @@ -518,8 +709,8 @@ async def query_fn(qterms): if dataframe: out = await self._dataframe(out, dataframe, df_index=df_index) - li_dup_df = DataFrame.from_records(li_dup, columns=["query", "duplicate hits"]) - li_missing_df = DataFrame(li_missing, columns=["query"]) + li_dup_df = pandas.DataFrame.from_records(li_dup, columns=["query", "duplicate hits"]) + li_missing_df = pandas.DataFrame(li_missing, columns=["query"]) if verbose: if li_dup: @@ -572,10 +763,8 @@ def get_async_client(biothing_type: str = None, instance: bool = True, *args, ** else: biothing_type = biothing_type.lower() if biothing_type not in ASYNC_CLIENT_SETTINGS and not kwargs.get("url", False): - raise Exception( - "No client named '{0}', currently available clients are: {1}".format( - biothing_type, list(ASYNC_CLIENT_SETTINGS.keys()) - ) + raise TypeError( + f"No client named '{biothing_type}', currently available clients are: {list(ASYNC_CLIENT_SETTINGS.keys())}" ) _settings = ( ASYNC_CLIENT_SETTINGS[biothing_type] @@ -588,7 +777,6 @@ def get_async_client(biothing_type: str = None, instance: bool = True, *args, ** for src_attr, target_attr in _settings["attr_aliases"].items(): if getattr(_class, src_attr, False): setattr(_class, target_attr, copy_func(getattr(_class, src_attr), name=target_attr)) - for _name, _docstring in _settings["class_kwargs"]["_docstring_obj"].items(): _func = getattr(_class, _name, None) if _func: diff --git a/biothings_client/client/base.py b/biothings_client/client/base.py index 1dd51f5..ec394ec 100644 --- a/biothings_client/client/base.py +++ b/biothings_client/client/base.py @@ -4,19 +4,15 @@ from collections.abc import Iterable from copy import copy +from pathlib import Path +from typing import Union, Tuple import logging -import os import platform import time import warnings -import requests +import httpx -from biothings_client.utils.iteration import ( - iter_n, - list_itemcnt, - safe_str, -) from biothings_client.client.settings import ( COMMON_ALIASES, COMMON_KWARGS, @@ -33,29 +29,26 @@ MYVARIANT_ALIASES, MYVARIANT_KWARGS, ) +from biothings_client.__version__ import __version__ +from biothings_client._dependencies import _CACHING, _PANDAS +from biothings_client.client.exceptions import OptionalDependencyImportError from biothings_client.mixins.gene import MyGeneClientMixin from biothings_client.mixins.variant import MyVariantClientMixin from biothings_client.utils.copy import copy_func -from biothings_client.__version__ import __version__ - -try: - from pandas import DataFrame, json_normalize +from biothings_client.utils.iteration import iter_n, list_itemcnt, concatenate_list - df_avail = True -except ImportError: - df_avail = False +if 
_PANDAS: + import pandas -try: - import requests_cache - - caching_avail = True -except ImportError: - caching_avail = False +if _CACHING: + import hishel + from biothings_client.cache.storage import BiothingsClientSqlite3Cache logger = logging.getLogger("biothings.client") logger.setLevel(logging.INFO) + # Future work: # Consider use "verbose" settings to control default logging output level # by doing this instead of using branching throughout the application, @@ -64,139 +57,214 @@ class BiothingClient: """ - This is the client for a biothing web service. + sync http client class for accessing the biothings web services """ - def __init__(self, url=None): + def __init__(self, url: str = None): if url is None: url = self._default_url self.url = url if self.url[-1] == "/": self.url = self.url[:-1] + self.max_query = self._max_query + # delay and step attributes are for batch queries. self.delay = self._delay # delay is ignored when requests made from cache. self.step = self._step + self.scroll_size = self._scroll_size - # raise requests.exceptions.HTTPError for status_code > 400 + + # raise httpx.HTTPError for status_code > 400 # but not for 404 on getvariant # set to False to suppress the exceptions. self.raise_for_status = True + self.default_user_agent = ( - "{package_header}/{client_version} (" "python:{python_version} " "requests:{requests_version}" ")" + "{package_header}/{client_version} (" "python:{python_version} " "httpx:{httpx_version}" ")" ).format( **{ "package_header": self._pkg_user_agent_header, "client_version": __version__, "python_version": platform.python_version(), - "requests_version": requests.__version__, + "httpx_version": httpx.__version__, } ) - self._cached = False + + self.http_client = None + self.http_client_setup = False + self.cache_storage = None + self.caching_enabled = False + + def _build_http_client(self, cache_db: Union[str, Path] = None) -> None: + """ + Builds the client instance for usage through the lifetime + of the biothings_client + + This modifies the state of the BiothingsClient instance + to set the values for the http_client property + """ + if not self.http_client_setup: + http_transport = httpx.HTTPTransport() + self.http_client = httpx.Client(transport=http_transport) + self.http_client_setup = True + self.http_cache_client_setup = False + + def _build_cache_http_client(self, cache_db: Union[str, Path] = None) -> None: + """ + Builds the client instance used for caching biothings requests. 
+ We rebuild the client whenever we enable to caching to ensure + that we don't create a database files unless the user explicitly + wants to leverage request caching + + This modifies the state of the BiothingsClient instance + to set the values for the http_client property and the cache_storage property + """ + if not self.http_client_setup: + if cache_db is None: + cache_db = self._default_cache_file + cache_db = Path(cache_db).resolve().absolute() + + self.cache_storage = BiothingsClientSqlite3Cache() + self.cache_storage.setup_database_connection(cache_db) + cache_transport = hishel.CacheTransport(transport=httpx.HTTPTransport(), storage=self.cache_storage) + cache_controller = hishel.Controller(cacheable_methods=["GET", "POST"]) + self.http_client = hishel.CacheClient( + controller=cache_controller, transport=cache_transport, storage=self.cache_storage + ) + self.http_client_setup = True + + def __del__(self): + """ + Destructor for the client to ensure that we close any potential + connections to the cache database + """ + try: + if self.http_client is not None: + self.http_client.close() + except Exception as gen_exc: + logger.exception(gen_exc) + logger.error("Unable to close the httpx client instance %s", self.http_client) def use_http(self): - """Use http instead of https for API calls.""" + """ + Use http instead of https for API calls. + """ if self.url: self.url = self.url.replace("https://", "http://") def use_https(self): - """Use https instead of http for API calls. This is the default.""" + """ + Use https instead of http for API calls. This is the default. + """ if self.url: self.url = self.url.replace("http://", "https://") @staticmethod def _dataframe(obj, dataframe, df_index=True): - """Converts object to DataFrame (pandas)""" - if not df_avail: - raise RuntimeError("Error: pandas module must be installed " "(or upgraded) for as_dataframe option.") - # if dataframe not in ["by_source", "normal"]: - if dataframe not in [1, 2]: - raise ValueError("dataframe must be either 1 (using json_normalize) " "or 2 (using DataFrame.from_dict") - if "hits" in obj: - if dataframe == 1: - df = json_normalize(obj["hits"]) + """ + Converts object to DataFrame (pandas) + """ + if _PANDAS: + # if dataframe not in ["by_source", "normal"]: + if dataframe not in [1, 2]: + raise ValueError("dataframe must be either 1 (using json_normalize) " "or 2 (using DataFrame.from_dict") + + if "hits" in obj: + if dataframe == 1: + df = pandas.json_normalize(obj["hits"]) + else: + df = pandas.DataFrame.from_dict(obj) else: - df = DataFrame.from_dict(obj) + if dataframe == 1: + df = pandas.json_normalize(obj) + else: + df = pandas.DataFrame.from_dict(obj) + if df_index: + df = df.set_index("query") + return df else: - if dataframe == 1: - df = json_normalize(obj) - else: - df = DataFrame.from_dict(obj) - if df_index: - df = df.set_index("query") - return df + dataframe_library_error = OptionalDependencyImportError( + optional_function_access="enable dataframe conversion", optional_group="dataframe", libraries=["pandas"] + ) + raise dataframe_library_error + + def _get( + self, url: str, params: dict = None, none_on_404: bool = False, verbose: bool = True + ) -> Tuple[bool, httpx.Response]: + """ + Wrapper around the httpx.get method + """ + self._build_http_client() + if params is None: + params = {} - def _get(self, url, params=None, none_on_404=False, verbose=True): - params = params or {} debug = params.pop("debug", False) return_raw = params.pop("return_raw", False) headers = {"user-agent": 
self.default_user_agent} - res = requests.get(url, params=params, headers=headers) - from_cache = getattr(res, "from_cache", False) - if debug: - return from_cache, res - if none_on_404 and res.status_code == 404: - return from_cache, None - if self.raise_for_status: - # raise requests.exceptions.HTTPError if not 200 - res.raise_for_status() - if return_raw: - return from_cache, res.text - ret = res.json() - return from_cache, ret + response = self.http_client.get( + url=url, params=params, headers=headers, extensions={"cache_disabled": not self.caching_enabled} + ) - def _post(self, url, params, verbose=True): + response_extensions = response.extensions + from_cache = response_extensions.get("from_cache", False) + if from_cache: + logger.debug("Cached response %s from %s", response, url) + + if response.is_success: + if debug or return_raw: + get_response = (from_cache, response) + else: + get_response = (from_cache, response.json()) + else: + if none_on_404 and response.status_code == 404: + get_response = (from_cache, None) + elif self.raise_for_status: + response.raise_for_status() # raise httpx._exceptions.HTTPStatusError + return get_response + + def _post(self, url: str, params: dict = None, verbose: bool = True) -> Tuple[bool, httpx.Response]: + """ + Wrapper around the httpx.post method + """ + self._build_http_client() + if params is None: + params = {} return_raw = params.pop("return_raw", False) headers = {"user-agent": self.default_user_agent} - res = requests.post(url, data=params, headers=headers) - from_cache = getattr(res, "from_cache", False) - if self.raise_for_status: - # raise requests.exceptions.HTTPError if not 200 - res.raise_for_status() - if return_raw: - return from_cache, res - ret = res.json() - return from_cache, ret + response = self.http_client.post( + url=url, data=params, headers=headers, extensions={"cache_disabled": not self.caching_enabled} + ) - @staticmethod - def _format_list(a_list, sep=",", quoted=True): - if isinstance(a_list, (list, tuple)): - if quoted: - _out = sep.join(['"{}"'.format(safe_str(x)) for x in a_list]) + response_extensions = response.extensions + from_cache = response_extensions.get("from_cache", False) + + if from_cache: + logger.debug("Cached response %s from %s", response, url) + + if response.is_success: + if return_raw: + post_response = (from_cache, response) else: - _out = sep.join(["{}".format(safe_str(x)) for x in a_list]) + response.read() + post_response = (from_cache, response.json()) else: - _out = a_list # a_list is already a comma separated string - return _out + if self.raise_for_status: + response.raise_for_status() + else: + post_response = (from_cache, response) + return post_response def _handle_common_kwargs(self, kwargs): # handle these common parameters accept field names as the value for kw in ["fields", "always_list", "allow_null"]: if kw in kwargs: - kwargs[kw] = self._format_list(kwargs[kw], quoted=False) + kwargs[kw] = concatenate_list(kwargs[kw], quoted=False) return kwargs - def _repeated_query_old(self, query_fn, query_li, verbose=True, **fn_kwargs): - """This is deprecated, query_li can only be a list""" - step = min(self.step, self.max_query) - if len(query_li) <= step: - # No need to do series of batch queries, turn off verbose output - verbose = False - for i in range(0, len(query_li), step): - is_last_loop = i + step >= len(query_li) - if verbose: - logger.info("querying {0}-{1}...".format(i + 1, min(i + step, len(query_li)))) - query_result = query_fn(query_li[i : i + step], **fn_kwargs) - - 
yield query_result - - if verbose: - logger.info("done.") - if not is_last_loop and self.delay: - time.sleep(self.delay) - def _repeated_query(self, query_fn, query_li, verbose=True, **fn_kwargs): - """Run query_fn for input query_li in a batch (self.step). + """ + Run query_fn for input query_li in a batch (self.step). return a generator of query_result in each batch. input query_li can be a list/tuple/iterable """ @@ -204,67 +272,136 @@ def _repeated_query(self, query_fn, query_li, verbose=True, **fn_kwargs): i = 0 for batch, cnt in iter_n(query_li, step, with_cnt=True): if verbose: - logger.info("querying {0}-{1}...".format(i + 1, cnt)) + logger.info("querying %s-%s ...", i + 1, cnt) i = cnt from_cache, query_result = query_fn(batch, **fn_kwargs) yield query_result - if verbose: - cache_str = " {0}".format(self._from_cache_notification) if from_cache else "" - logger.info("done.{0}".format(cache_str)) + if not from_cache and self.delay: # no need to delay if requests are from cache. time.sleep(self.delay) - @property - def _from_cache_notification(self): - """Notification to alert user that a cached result is being returned.""" - return "[ from cache ]" - def _metadata(self, verbose=True, **kwargs): - """Return a dictionary of Biothing metadata.""" + """ + Return a dictionary of Biothing metadata. + """ _url = self.url + self._metadata_endpoint - from_cache, ret = self._get(_url, params=kwargs, verbose=verbose) - if verbose and from_cache: - logger.info(self._from_cache_notification) + _, ret = self._get(_url, params=kwargs, verbose=verbose) return ret - def _set_caching(self, cache_db=None, verbose=True, **kwargs): - """Installs a local cache for all requests. + def _set_caching(self, cache_db: Union[str, Path] = None, **kwargs) -> None: + """ + Enable the client caching and creates a local cache database + for all future requests - **cache_db** is the path to the local sqlite cache database.""" - if caching_avail: - if cache_db is None: - cache_db = self._default_cache_file - requests_cache.install_cache(cache_name=cache_db, allowable_methods=("GET", "POST"), **kwargs) - self._cached = True - if verbose: - logger.info('[ Future queries will be cached in "{0}" ]'.format(os.path.abspath(cache_db + ".sqlite"))) + If caching is already enabled then we no-opt + + Inputs: + :param cache_db: pathlike object to the local sqlite3 cache database file + + Outputs: + :return: None + """ + if _CACHING: + if not self.caching_enabled: + try: + self.caching_enabled = True + self.http_client_setup = False + self._build_cache_http_client() + logger.debug("Reset the HTTP client to leverage caching %s", self.http_client) + logger.info( + ( + "Enabled client caching: %s\n" 'Future queries will be cached in "%s"', + self, + self.cache_storage.cache_filepath, + ) + ) + except Exception as gen_exc: + logger.exception(gen_exc) + logger.error("Unable to enable caching") + raise gen_exc + else: + logger.warning("Caching already enabled. Skipping for now ...") else: - raise RuntimeError( - "The requests_cache python module is required to use request caching. 
See " - "https://requests-cache.readthedocs.io/en/latest/user_guide.html#installation" + caching_library_error = OptionalDependencyImportError( + optional_function_access="enable biothings-client caching", + optional_group="caching", + libraries=["anysqlite", "hishel"], ) + raise caching_library_error - def _stop_caching(self): - """Stop caching.""" - if self._cached and caching_avail: - requests_cache.uninstall_cache() - self._cached = False - return + def _stop_caching(self) -> None: + """ + Disable client caching. The local cache database will be maintained, + but we will disable cache access when sending requests - def _clear_cache(self): - """Clear the globally installed cache.""" - try: - requests_cache.clear() - except AttributeError: - # requests_cache is not enabled - logger.warning("requests_cache is not enabled. Nothing to clear.") + If caching is already disabled then we no-opt + + Inputs: + :param None + + Outputs: + :return: None + """ + if _CACHING: + if self.caching_enabled: + try: + self.cache_storage.clear_cache() + except Exception as gen_exc: + logger.exception(gen_exc) + logger.error("Error attempting to clear the local cache database") + raise gen_exc + + self.caching_enabled = False + self.http_client_setup = False + self._build_http_client() + logger.debug("Reset the HTTP client to disable caching %s", self.http_client) + logger.info("Disabled client caching: %s", self) + else: + logger.warning("Caching already disabled. Skipping for now ...") + else: + caching_library_error = OptionalDependencyImportError( + optional_function_access="disable biothings-client caching", + optional_group="caching", + libraries=["anysqlite", "hishel"], + ) + raise caching_library_error + + def _clear_cache(self) -> None: + """ + Clear the globally installed cache. Caching will stil be enabled, + but the data stored in the cache stored will be dropped + + Inputs: + :param None + + Outputs: + :return: None + """ + if _CACHING: + if self.caching_enabled: + try: + self.cache_storage.clear_cache() + except Exception as gen_exc: + logger.exception(gen_exc) + logger.error("Error attempting to clear the local cache database") + raise gen_exc + else: + logger.warning("Caching already disabled. No local cache database to clear. Skipping for now ...") + else: + caching_library_error = OptionalDependencyImportError( + optional_function_access="clear biothings-client cache", + optional_group="caching", + libraries=["anysqlite", "hishel"], + ) + raise caching_library_error def _get_fields(self, search_term=None, verbose=True): - """Wrapper for /metadata/fields + """ + Wrapper for /metadata/fields - **search_term** is a case insensitive string to search for in available field names. - If not provided, all available fields will be returned. + **search_term** is a case insensitive string to search for in available field names. + If not provided, all available fields will be returned. .. Hint:: This is useful to find out the field names you need to pass to **fields** parameter of other methods. 
@@ -274,18 +411,17 @@ def _get_fields(self, search_term=None, verbose=True): params = {"search": search_term} else: params = {} - from_cache, ret = self._get(_url, params=params, verbose=verbose) + _, ret = self._get(_url, params=params, verbose=verbose) for k, v in ret.items(): del k # Get rid of the notes column information if "notes" in v: del v["notes"] - if verbose and from_cache: - logger.info(self._from_cache_notification) return ret def _getannotation(self, _id, fields=None, **kwargs): - """Return the object given id. + """ + Return the object given id. This is a wrapper for GET query of the biothings annotation service. :param _id: an entity id. @@ -300,25 +436,26 @@ def _getannotation(self, _id, fields=None, **kwargs): kwargs["fields"] = fields kwargs = self._handle_common_kwargs(kwargs) _url = self.url + self._annotation_endpoint + str(_id) - from_cache, ret = self._get(_url, kwargs, none_on_404=True, verbose=verbose) - if verbose and from_cache: - logger.info(self._from_cache_notification) + _, ret = self._get(_url, kwargs, none_on_404=True, verbose=verbose) return ret def _getannotations_inner(self, ids, verbose=True, **kwargs): - _kwargs = {"ids": self._format_list(ids)} + id_collection = concatenate_list(ids) + _kwargs = {"ids": id_collection} _kwargs.update(kwargs) _url = self.url + self._annotation_endpoint return self._post(_url, _kwargs, verbose=verbose) def _annotations_generator(self, query_fn, ids, verbose=True, **kwargs): - """Function to yield a batch of hits one at a time.""" + """ + Function to yield a batch of hits one at a time + """ for hits in self._repeated_query(query_fn, ids, verbose=verbose): - for hit in hits: - yield hit + yield from hits def _getannotations(self, ids, fields=None, **kwargs): - """Return the list of annotation objects for the given list of ids. + """ + Return the list of annotation objects for the given list of ids. This is a wrapper for POST query of the biothings annotation service. :param ids: a list/tuple/iterable or a string of ids. @@ -381,7 +518,8 @@ def query_fn(ids): return out def _query(self, q, **kwargs): - """Return the query result. + """ + Return the query result. This is a wrapper for GET query of biothings query service. :param q: a query string. @@ -428,52 +566,69 @@ def _query(self, q, **kwargs): dataframe = 1 elif dataframe != 2: dataframe = None - from_cache, out = self._get(_url, kwargs, verbose=verbose) - if verbose and from_cache: - logger.info(self._from_cache_notification) + _, out = self._get(_url, kwargs, verbose=verbose) if dataframe: out = self._dataframe(out, dataframe, df_index=False) return out def _fetch_all(self, url, verbose=True, **kwargs): - """Function that returns a generator to results. Assumes that 'q' is in kwargs.""" - - # function to get the next batch of results, automatically disables cache if we are caching - def _batch(): - if caching_avail and self._cached: - self._cached = False - with requests_cache.disabled(): - from_cache, ret = self._get(url, params=kwargs, verbose=verbose) - del from_cache - self._cached = True - else: - from_cache, ret = self._get(url, params=kwargs, verbose=verbose) - return ret + """ + Function that returns a generator to results. Assumes that 'q' is in kwargs. 
- batch = _batch() + Implicitly disables caching to ensure we actually hit the endpoint rather than + pulling from local cache + """ + logger.warning("fetch_all implicitly disables HTTP request caching") + restore_caching = False + if self.caching_enabled: + restore_caching = True + try: + self.stop_caching() + except OptionalDependencyImportError as optional_import_error: + logger.exception(optional_import_error) + logger.debug("No cache to disable for fetch all. Continuing ...") + except Exception as gen_exc: + logger.exception(gen_exc) + logger.error("Unknown error occured while attempting to disable caching") + raise gen_exc + + _, response = self._get(url, params=kwargs, verbose=verbose) if verbose: - logger.info("Fetching {0} {1} . . .".format(batch["total"], self._optionally_plural_object_type)) + logger.info("Fetching {0} {1} . . .".format(response["total"], self._optionally_plural_object_type)) for key in ["q", "fetch_all"]: kwargs.pop(key) - while not batch.get("error", "").startswith("No results to return"): - if "error" in batch: - logger.error(batch["error"]) + while not response.get("error", "").startswith("No results to return"): + if "error" in response: + logger.error(response["error"]) break - if "_warning" in batch and verbose: - logger.warning(batch["_warning"]) - for hit in batch["hits"]: - yield hit - kwargs.update({"scroll_id": batch["_scroll_id"]}) - batch = _batch() + if "_warning" in response and verbose: + logger.warning(response["_warning"]) + yield from response["hits"] + kwargs.update({"scroll_id": response["_scroll_id"]}) + _, response = self._get(url, params=kwargs, verbose=verbose) + + if restore_caching: + logger.debug("re-enabling the client HTTP caching") + try: + self.set_caching() + except OptionalDependencyImportError as optional_import_error: + logger.exception(optional_import_error) + logger.debug("No cache to disable for fetch all. Continuing ...") + except Exception as gen_exc: + logger.exception(gen_exc) + logger.error("Unknown error occured while attempting to disable caching") + raise gen_exc def _querymany_inner(self, qterms, verbose=True, **kwargs): - _kwargs = {"q": self._format_list(qterms)} + query_term_collection = concatenate_list(qterms) + _kwargs = {"q": query_term_collection} _kwargs.update(kwargs) _url = self.url + self._query_endpoint return self._post(_url, params=_kwargs, verbose=verbose) def _querymany(self, qterms, scopes=None, **kwargs): - """Return the batch query result. + """ + Return the batch query result. This is a wrapper for POST query of "/query" service. :param qterms: a list/tuple/iterable of query terms, or a string of comma-separated query terms. 
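Both batch paths keep their public behavior through this refactor; only the transport and list formatting changed. A hedged usage sketch (query strings are illustrative):

    from biothings_client import get_client

    gene_client = get_client("gene")

    # fetch_all: scroll through every hit; caching is temporarily disabled as implemented above
    for hit in gene_client.query("taxid:9606 AND symbol:CDK2", fetch_all=True):
        print(hit["_id"])

    # querymany: POSTs in batches of self.step, joining qterms with concatenate_list
    results = gene_client.querymany(["1017", "1018"], scopes="entrezgene")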
@@ -504,7 +659,7 @@ def _querymany(self, qterms, scopes=None, **kwargs): raise ValueError('input "qterms" must be a list, tuple or iterable.') if scopes: - kwargs["scopes"] = self._format_list(scopes, quoted=False) + kwargs["scopes"] = concatenate_list(scopes, quoted=False) kwargs = self._handle_common_kwargs(kwargs) returnall = kwargs.pop("returnall", False) verbose = kwargs.pop("verbose", True) @@ -551,8 +706,8 @@ def query_fn(qterms): if dataframe: out = self._dataframe(out, dataframe, df_index=df_index) - li_dup_df = DataFrame.from_records(li_dup, columns=["query", "duplicate hits"]) - li_missing_df = DataFrame(li_missing, columns=["query"]) + li_dup_df = pandas.DataFrame.from_records(li_dup, columns=["query", "duplicate hits"]) + li_missing_df = pandas.DataFrame(li_missing, columns=["query"]) if verbose: if li_dup: @@ -589,7 +744,7 @@ def get_client(biothing_type=None, instance=True, *args, **kwargs): raise RuntimeError("No biothings_type or url specified.") try: url += "metadata" if url.endswith("/") else "/metadata" - res = requests.get(url) + res = httpx.get(url) dic = res.json() biothing_type = dic.get("biothing_type") if isinstance(biothing_type, list): @@ -600,15 +755,13 @@ raise RuntimeError("Biothing_type in metadata url is not unique.") if not isinstance(biothing_type, str): raise RuntimeError("Biothing_type in metadata url is not a valid string.") - except requests.RequestException as request_error: + except httpx.RequestError as request_error: raise RuntimeError("Cannot access metadata url to determine biothing_type.") from request_error else: biothing_type = biothing_type.lower() if biothing_type not in CLIENT_SETTINGS and not kwargs.get("url", False): - raise Exception( - "No client named '{0}', currently available clients are: {1}".format( - biothing_type, list(CLIENT_SETTINGS.keys()) - ) + raise TypeError( + f"No client named '{biothing_type}', currently available clients are: {list(CLIENT_SETTINGS.keys())}" ) _settings = ( CLIENT_SETTINGS[biothing_type] @@ -722,9 +875,9 @@ def _pluralize(s, optional=True): _kwargs.update( { "_default_url": url, - "_annotation_endpoint": "/" + biothing_type.lower() + "/", + "_annotation_endpoint": f"/{biothing_type.lower()}/", "_optionally_plural_object_type": _pluralize(biothing_type.lower()), - "_default_cache_file": "my" + biothing_type.lower() + "_cache", + "_default_cache_file": f"my{biothing_type.lower()}_cache", } ) _aliases.update( diff --git a/biothings_client/client/exceptions.py b/biothings_client/client/exceptions.py new file mode 100644 index 0000000..c20fa66 --- /dev/null +++ b/biothings_client/client/exceptions.py @@ -0,0 +1,15 @@ +""" +Custom exceptions for the clients (async and sync) +""" + +from typing import List + + +class OptionalDependencyImportError(Exception): + def __init__(self, optional_function_access: str, optional_group: str, libraries: List[str]): + pip_command = f"`pip install biothings_client[{optional_group}]`" + message = ( + f"To {optional_function_access}, the {libraries} library(ies) must be installed. 
" + f"To install, run the following command: {pip_command}" + ) + super().__init__(message) diff --git a/biothings_client/client/settings.py b/biothings_client/client/settings.py index 1b689a1..f144186 100644 --- a/biothings_client/client/settings.py +++ b/biothings_client/client/settings.py @@ -18,13 +18,13 @@ # *********************************************** # Function aliases common to all clients COMMON_ALIASES = { - "_metadata": "metadata", - "_set_caching": "set_caching", - "_stop_caching": "stop_caching", "_clear_cache": "clear_cache", "_get_fields": "get_fields", + "_metadata": "metadata", "_query": "query", "_querymany": "querymany", + "_set_caching": "set_caching", + "_stop_caching": "stop_caching", } # Set project specific aliases @@ -51,79 +51,79 @@ # *********************************************** # Object creation kwargs common to all clients COMMON_KWARGS = { - "_pkg_user_agent_header": "biothings_client.py", - "_query_endpoint": "/query/", + "_delay": 1, + "_docstring_obj": {}, + "_max_query": 1000, "_metadata_endpoint": "/metadata", "_metadata_fields_endpoint": "/metadata/fields", - "_top_level_jsonld_uris": [], - "_docstring_obj": {}, - "_delay": 1, - "_step": 1000, + "_pkg_user_agent_header": "biothings_client.py", + "_query_endpoint": "/query/", "_scroll_size": 1000, - "_max_query": 1000, + "_step": 1000, + "_top_level_jsonld_uris": [], } # project specific kwargs MYGENE_KWARGS = copy(COMMON_KWARGS) MYGENE_KWARGS.update( { - "_default_url": "https://mygene.info/v3", "_annotation_endpoint": "/gene/", - "_optionally_plural_object_type": "gene(s)", "_default_cache_file": "mygene_cache", - "_entity": "gene", + "_default_url": "https://mygene.info/v3", "_docstring_obj": GENE_DOCSTRING, + "_entity": "gene", + "_optionally_plural_object_type": "gene(s)", } ) MYVARIANT_KWARGS = copy(COMMON_KWARGS) MYVARIANT_KWARGS.update( { - "_default_url": "https://myvariant.info/v1", "_annotation_endpoint": "/variant/", - "_optionally_plural_object_type": "variant(s)", "_default_cache_file": "myvariant_cache", + "_default_url": "https://myvariant.info/v1", + "_docstring_obj": VARIANT_DOCSTRING, "_entity": "variant", + "_optionally_plural_object_type": "variant(s)", "_top_level_jsonld_uris": MYVARIANT_TOP_LEVEL_JSONLD_URIS, - "_docstring_obj": VARIANT_DOCSTRING, } ) MYCHEM_KWARGS = copy(COMMON_KWARGS) MYCHEM_KWARGS.update( { - "_default_url": "https://mychem.info/v1", "_annotation_endpoint": "/chem/", - "_optionally_plural_object_type": "chem(s)", - "_entity": "chem", "_default_cache_file": "mychem_cache", + "_default_url": "https://mychem.info/v1", "_docstring_obj": CHEM_DOCSTRING, + "_entity": "chem", + "_optionally_plural_object_type": "chem(s)", } ) MYDISEASE_KWARGS = copy(COMMON_KWARGS) MYDISEASE_KWARGS.update( { - "_default_url": "https://mydisease.info/v1", "_annotation_endpoint": "/disease/", - "_optionally_plural_object_type": "disease(s)", - "_entity": "disease", "_default_cache_file": "mydisease_cache", + "_default_url": "https://mydisease.info/v1", + "_entity": "disease", + "_optionally_plural_object_type": "disease(s)", } ) MYTAXON_KWARGS = copy(COMMON_KWARGS) MYTAXON_KWARGS.update( { - "_default_url": "https://t.biothings.io/v1", "_annotation_endpoint": "/taxon/", - "_optionally_plural_object_type": "taxon/taxa", - "_entity": "taxon", "_default_cache_file": "mytaxon_cache", + "_default_url": "https://t.biothings.io/v1", + "_entity": "taxon", + "_optionally_plural_object_type": "taxon/taxa", } ) MYGENESET_KWARGS = copy(COMMON_KWARGS) MYGENESET_KWARGS.update( { - "_default_url": 
"https://mygeneset.info/v1", "_annotation_endpoint": "/geneset/", - "_optionally_plural_object_type": "geneset(s)", - "_entity": "geneset", "_default_cache_file": "mygeneset_cache", + "_default_url": "https://mygeneset.info/v1", + "_entity": "geneset", + "_optionally_plural_object_type": "geneset(s)", } ) diff --git a/biothings_client/utils/iteration.py b/biothings_client/utils/iteration.py index 6298bcc..578e875 100644 --- a/biothings_client/utils/iteration.py +++ b/biothings_client/utils/iteration.py @@ -1,5 +1,11 @@ from collections import Counter from itertools import islice +from typing import Iterable +import logging + + +logger = logging.getLogger("biothings.client") +logger.setLevel(logging.INFO) def safe_str(s, encoding="utf-8"): @@ -12,7 +18,9 @@ def safe_str(s, encoding="utf-8"): def list_itemcnt(li): - """Return number of occurrence for each item in the list.""" + """ + Return number of occurrence for each item in the list. + """ return list(Counter(li).items()) @@ -33,3 +41,29 @@ def iter_n(iterable, n, with_cnt=False): yield (chunk, cnt) else: yield chunk + + +def concatenate_list(sequence: Iterable, sep: str = ",", quoted: bool = True) -> str: + """ + Combine all the elements of a list into a string + + Inputs: + sequence: iterable data structure + sep: delimiter for joining the entries in `sequence` + quoted: boolean indicating to quote the elements from sequence + while concatenating all the elements + Output: + string value representing the concatenated values + """ + if isinstance(sequence, (list, tuple)): + if quoted: + string_transform = sep.join(['"{}"'.format(safe_str(x)) for x in sequence]) + else: + string_transform = sep.join(["{}".format(safe_str(x)) for x in sequence]) + elif isinstance(sequence, str): + logger.warning("Input sequence provided is already in string format. No operation performed") + string_transform = sequence + else: + logger.warning("Input sequence non-iterable %s. Unable to perform concatenation operation", sequence) + string_transform = sequence + return string_transform diff --git a/pyproject.toml b/pyproject.toml index 3152cd8..aba8eac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,6 +7,7 @@ where = ["."] # list of folders that the contains the packages. 
We need it to be include = ["biothings_client*"] # package names should match these glob patterns exclude = ["tests*"] # exclude packages matching these glob patterns namespaces = false # to disable scanning PEP 420 namespaces (true by default) + [project] name="biothings_client" authors = [ @@ -67,18 +68,17 @@ classifiers = [ dependencies = [ "httpx>0.24.1; python_version>='3.8'", "httpx<=0.24.1; python_version=='3.7'", - "requests>=2.32.0; python_version>='3.8'", - "requests<=2.31.0; python_version=='3.7'", "importlib-metadata; python_version<'3.8'", ] version = "0.4.0" [project.optional-dependencies] -dataframe = ["pandas>=1.3.0"] -jsonld = ["PyLD>=0.7.2"] caching = [ - "requests_cache>=0.4.13" + "anysqlite; python_version>='3.8'", + "hishel[sqlite]; python_version >='3.8'", ] +dataframe = ["pandas>=1.3.0"] +jsonld = ["PyLD>=0.7.2"] tests = [ "pytest>=8.3.3; python_version>='3.8'", "pytest>=7.4.4; python_version=='3.7'", @@ -108,6 +108,7 @@ src_paths = ["."] # pytest configuration [tool.pytest.ini_options] minversion = "6.2.5" +asyncio_mode = "auto" pythonpath = ["."] diff --git a/tests/conftest.py b/tests/conftest.py index aaa92a4..84461e3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,14 +1,49 @@ """ -Client fixtures for usage across the biothings_client testing +Fixtures for the biothings_client testing """ import pytest -from biothings_client import get_async_client -from biothings_client.client.definitions import AsyncMyChemInfo, AsyncMyGeneInfo, AsyncMyGenesetInfo, AsyncMyVariantInfo +from biothings_client import get_client, get_async_client +from biothings_client.client.definitions import ( + AsyncMyChemInfo, + AsyncMyGeneInfo, + AsyncMyGenesetInfo, + AsyncMyVariantInfo, + MyChemInfo, + MyGeneInfo, + MyGenesetInfo, + MyVariantInfo, +) + + +# --- CLIENTS --- +# sync: +# >>> MyGeneInfo client +# >>> MyChemInfo client +# >>> MyVariantInfo client +# >>> MyGenesetInfo client + +# **NOTE** The asynchronous clients must have a scope of `function` because it appears +# that between tests the client is closed and the asyncio loop gets closed, causing issues +# async: +# >>> AsyncMyGeneInfo client +# >>> AsyncMyChemInfo client +# >>> AsyncMyVariantInfo client +# >>> AsyncMyGenesetInfo client @pytest.fixture(scope="session") +def gene_client() -> MyGeneInfo: + """ + Fixture for generating a synchronous mygene client + """ + client = "gene" + gene_client = get_client(client) + return gene_client + + +@pytest.fixture(scope="function") def async_gene_client() -> AsyncMyGeneInfo: """ Fixture for generating an asynchronous mygene client @@ -19,6 +54,16 @@ @pytest.fixture(scope="session") +def chem_client() -> MyChemInfo: + """ + Fixture for generating a synchronous mychem client + """ + client = "chem" + chem_client = get_client(client) + return chem_client + + +@pytest.fixture(scope="function") def async_chem_client() -> AsyncMyChemInfo: """ Fixture for generating an asynchronous mychem client @@ -29,6 +74,16 @@ @pytest.fixture(scope="session") +def variant_client() -> MyVariantInfo: + """ + Fixture for generating a synchronous myvariant client + """ + client = "variant" + variant_client = get_client(client) + return variant_client + + +@pytest.fixture(scope="function") def async_variant_client() -> AsyncMyVariantInfo: """ Fixture for generating an asynchronous myvariant client @@ -39,6 +94,16 @@ @pytest.fixture(scope="session") +def geneset_client() -> 
MyGenesetInfo: + """ + Fixture for generating a synchronous mygeneset client + """ + client = "geneset" + geneset_client = get_client(client) + return geneset_client + + +@pytest.fixture(scope="function") def async_geneset_client() -> AsyncMyGenesetInfo: """ Fixture for generating an asynchronous mygeneset client diff --git a/tests/test_async.py b/tests/test_async.py index e504c83..9736f4b 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -76,9 +76,9 @@ async def test_url_protocol(client_name: str): assert client_instance.url.startswith(https_protocol) # Transform to HTTP - await client_instance.use_http() + client_instance.use_http() assert client_instance.url.startswith(http_protocol) # Transform back to HTTPS - await client_instance.use_https() + client_instance.use_https() client_instance.url.startswith(https_protocol) diff --git a/tests/test_async_chem.py b/tests/test_async_chem.py index 38a21c0..45b8d2f 100644 --- a/tests/test_async_chem.py +++ b/tests/test_async_chem.py @@ -4,7 +4,6 @@ > implemented in pytest for asyncio marker """ -import importlib.util import json import logging import types @@ -12,13 +11,11 @@ import pytest +import biothings_client from biothings_client.client.definitions import AsyncMyChemInfo from biothings_client.utils.score import descore -pandas_available = importlib.util.find_spec("pandas") is not None -requests_cache_available = importlib.util.find_spec("requests_cache") is not None - logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -410,7 +407,7 @@ async def test_querymany_notfound(async_chem_client: AsyncMyChemInfo): @pytest.mark.asyncio -@pytest.mark.skipif(not pandas_available, reason="requires the pandas library") +@pytest.mark.skipif(not biothings_client._PANDAS, reason="requires the pandas library") async def test_querymany_dataframe(async_chem_client: AsyncMyChemInfo): from pandas import DataFrame diff --git a/tests/test_async_gene.py b/tests/test_async_gene.py index 4391f92..a2b6e7e 100644 --- a/tests/test_async_gene.py +++ b/tests/test_async_gene.py @@ -4,18 +4,19 @@ > implemented in pytest for asyncio marker """ -import importlib.util +import logging import types import pytest +import biothings_client from biothings_client.client.definitions import AsyncMyGeneInfo from biothings_client.utils.score import descore -pandas_available = importlib.util.find_spec("pandas") is not None -requests_cache_available = importlib.util.find_spec("requests_cache") is not None +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) @pytest.mark.asyncio @@ -260,7 +261,7 @@ async def test_querymany_notfound(async_gene_client: AsyncMyGeneInfo): @pytest.mark.asyncio -@pytest.mark.skipif(not pandas_available, reason="requires the pandas library") +@pytest.mark.skipif(not biothings_client._PANDAS, reason="requires the pandas library") async def test_querymany_dataframe(async_gene_client: AsyncMyGeneInfo): from pandas import DataFrame diff --git a/tests/test_async_geneset.py b/tests/test_async_geneset.py index 28a9695..dd42688 100644 --- a/tests/test_async_geneset.py +++ b/tests/test_async_geneset.py @@ -1,4 +1,10 @@ -import importlib.util +""" +Mirror of the geneset tests but two main differences: +> asynchronous +> implemented in pytest for asyncio marker +""" + +import logging import types import pytest @@ -8,7 +14,8 @@ from biothings_client.utils.score import descore -requests_cache_available = importlib.util.find_spec("requests_cache") is not None +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) 
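With asyncio_mode = "auto" now set in pyproject.toml and the function-scoped fixtures above, an async test needs no explicit event-loop plumbing. A minimal sketch (hypothetical fixture and test names, assuming a live mygene.info backend):

    import pytest

    from biothings_client import get_async_client
    from biothings_client.client.definitions import AsyncMyGeneInfo


    @pytest.fixture(scope="function")
    def my_gene_client() -> AsyncMyGeneInfo:
        # a fresh client per test, so a closed asyncio loop is never reused
        return get_async_client("gene")


    async def test_cdk2_symbol(my_gene_client: AsyncMyGeneInfo):
        # pytest-asyncio picks this up automatically under asyncio_mode = "auto"
        gene = await my_gene_client.getgene("1017")
        assert gene["symbol"] == "CDK2"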
@pytest.mark.asyncio diff --git a/tests/test_async_variant.py b/tests/test_async_variant.py index 04728e8..b53fc27 100644 --- a/tests/test_async_variant.py +++ b/tests/test_async_variant.py @@ -1,14 +1,22 @@ -import importlib.util +""" +Mirror of the variant tests but two main differences: +> asynchronous +> implemented in pytest for asyncio marker +""" + +import logging import types import pytest +import biothings_client from biothings_client.client.definitions import AsyncMyVariantInfo from biothings_client.utils.score import descore -pandas_available = importlib.util.find_spec("pandas") is not None -requests_cache_available = importlib.util.find_spec("requests_cache") is not None + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) @pytest.mark.asyncio @@ -216,7 +224,7 @@ async def test_querymany_notfound(async_variant_client: AsyncMyVariantInfo): @pytest.mark.asyncio -@pytest.mark.skipif(not pandas_available, reason="requires the pandas library") +@pytest.mark.skipif(not biothings_client._PANDAS, reason="requires the pandas library") async def test_querymany_dataframe(async_variant_client: AsyncMyVariantInfo): from pandas import DataFrame diff --git a/tests/test_caching.py b/tests/test_caching.py new file mode 100644 index 0000000..bd4dd32 --- /dev/null +++ b/tests/test_caching.py @@ -0,0 +1,212 @@ +""" +Tests the client caching functionality +""" + +import datetime +import logging +from typing import Callable + +import pytest + +import biothings_client +from biothings_client.client.asynchronous import get_async_client +from biothings_client.client.base import get_client + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + + +@pytest.mark.skipif(not biothings_client._CACHING, reason="caching libraries not installed") +@pytest.mark.parametrize( + "method,client,function,function_arguments", + [ + ("GET", "gene", "getgene", {"_id": "1017", "return_raw": True}), + ("POST", "gene", "getgenes", {"ids": ["1017", "1018"], "return_raw": True}), + ("GET", "gene", "query", {"q": "cdk2", "return_raw": True}), + ("POST", "gene", "querymany", {"qterms": ["1017", "695"], "return_raw": True}), + ( + "GET", + "chem", + "getdrug", + {"_id": "ZRALSGWEFCBTJO-UHFFFAOYSA-N", "return_raw": True}, + ), + ( + "POST", + "chem", + "getdrugs", + {"ids": ["ZRALSGWEFCBTJO-UHFFFAOYSA-N", "RRUDCFGSUDOHDG-UHFFFAOYSA-N"], "return_raw": True}, + ), + ( + "GET", + "chem", + "query", + {"q": "chebi.name:albendazole", "size": 5, "return_raw": True}, + ), + ( + "POST", + "chem", + "querymany", + {"qterms": ["CHEBI:31690", "CHEBI:15365"], "scopes": "chebi.id", "return_raw": True}, + ), + ("GET", "variant", "getvariant", {"_id": "chr9:g.107620835G>A", "return_raw": True}), + ("POST", "variant", "getvariants", {"ids": ["chr9:g.107620835G>A", "chr1:g.876664G>A"], "return_raw": True}), + ("GET", "variant", "query", {"q": "dbsnp.rsid:rs58991260", "return_raw": True}), + ( + "POST", + "variant", + "querymany", + {"qterms": ["rs58991260", "rs2500"], "scopes": "dbsnp.rsid", "return_raw": True}, + ), + ("GET", "geneset", "getgeneset", {"_id": "WP100", "return_raw": True}), + ("POST", "geneset", "getgenesets", {"ids": ["WP100", "WP101"], "return_raw": True}), + ("GET", "geneset", "query", {"q": "wnt", "fields": "name,count,source,taxid", "return_raw": True}), + ( + "POST", + "geneset", + "querymany", + {"qterms": ["wnt", "jak-stat"], "fields": "name,count,source,taxid", "return_raw": True}, + ), + ], +) +def test_sync_caching(method: str, client: str, function: str, function_arguments: dict): 
+ """ + Tests sync caching methods for a variety of data + """ + try: + client_instance = get_client(client) + client_instance._build_http_client() + client_instance.set_caching() + assert client_instance.caching_enabled + client_instance.clear_cache() + + client_callback = getattr(client_instance, function) + assert isinstance(client_callback, Callable) + cold_response = client_callback(**function_arguments) + hot_response = client_callback(**function_arguments) + + client_instance.stop_caching() + forced_cold_response = client_callback(**function_arguments) + client_instance.set_caching() + cold_response2 = client_callback(**function_arguments) + client_instance.clear_cache() + forced_cold_response2 = client_callback(**function_arguments) + hot_response2 = client_callback(**function_arguments) + + cold_responses = [cold_response, cold_response2, forced_cold_response, forced_cold_response2] + hot_responses = [hot_response, hot_response2] + for cold_response in cold_responses: + assert cold_response.status_code == 200 + assert not cold_response.extensions.get("from_cache", False) + assert not cold_response.extensions.get("revalidated", False) + + for hot_response in hot_responses: + assert hot_response.status_code == 200 + assert hot_response.extensions["from_cache"] + assert not hot_response.extensions["revalidated"] + assert isinstance(hot_response.extensions["cache_metadata"]["created_at"], datetime.datetime) + assert hot_response.extensions["cache_metadata"]["number_of_uses"] >= 1 + assert all(hot_response.elapsed < cold_response.elapsed for cold_response in cold_responses) + + except Exception as gen_exc: + client_instance.clear_cache() + logger.exception(gen_exc) + raise gen_exc + + +@pytest.mark.asyncio +@pytest.mark.skipif(not biothings_client._CACHING, reason="caching libraries not installed") +@pytest.mark.parametrize( + "method,client,function,function_arguments", + [ + ("GET", "gene", "getgene", {"_id": "1017", "return_raw": True}), + ("POST", "gene", "getgenes", {"ids": ["1017", "1018"], "return_raw": True}), + ("GET", "gene", "query", {"q": "cdk2", "return_raw": True}), + ("POST", "gene", "querymany", {"qterms": ["1017", "695"], "return_raw": True}), + ( + "GET", + "chem", + "getdrug", + {"_id": "ZRALSGWEFCBTJO-UHFFFAOYSA-N", "return_raw": True}, + ), + ( + "POST", + "chem", + "getdrugs", + {"ids": ["ZRALSGWEFCBTJO-UHFFFAOYSA-N", "RRUDCFGSUDOHDG-UHFFFAOYSA-N"], "return_raw": True}, + ), + ( + "GET", + "chem", + "query", + {"q": "chebi.name:albendazole", "size": 5, "return_raw": True}, + ), + ( + "POST", + "chem", + "querymany", + {"qterms": ["CHEBI:31690", "CHEBI:15365"], "scopes": "chebi.id", "return_raw": True}, + ), + ("GET", "variant", "getvariant", {"_id": "chr9:g.107620835G>A", "return_raw": True}), + ("POST", "variant", "getvariants", {"ids": ["chr9:g.107620835G>A", "chr1:g.876664G>A"], "return_raw": True}), + ("GET", "variant", "query", {"q": "dbsnp.rsid:rs58991260", "return_raw": True}), + ( + "POST", + "variant", + "querymany", + {"qterms": ["rs58991260", "rs2500"], "scopes": "dbsnp.rsid", "return_raw": True}, + ), + ("GET", "geneset", "getgeneset", {"_id": "WP100", "return_raw": True}), + ("POST", "geneset", "getgenesets", {"ids": ["WP100", "WP101"], "return_raw": True}), + ("GET", "geneset", "query", {"q": "wnt", "fields": "name,count,source,taxid", "return_raw": True}), + ( + "POST", + "geneset", + "querymany", + {"qterms": ["wnt", "jak-stat"], "fields": "name,count,source,taxid", "return_raw": True}, + ), + ], +) +async def test_async_caching(method: str, 
client: str, function: str, function_arguments: dict): + """ + Tests async caching methods for a variety of data + """ + try: + client_instance = get_async_client(client) + await client_instance._build_http_client() + await client_instance.set_caching() + assert client_instance.caching_enabled + await client_instance.clear_cache() + + client_callback = getattr(client_instance, function) + assert isinstance(client_callback, Callable) + cold_response = await client_callback(**function_arguments) + hot_response = await client_callback(**function_arguments) + + await client_instance.stop_caching() + forced_cold_response = await client_callback(**function_arguments) + await client_instance.set_caching() + cold_response2 = await client_callback(**function_arguments) + await client_instance.clear_cache() + forced_cold_response2 = await client_callback(**function_arguments) + hot_response2 = await client_callback(**function_arguments) + + cold_responses = [cold_response, cold_response2, forced_cold_response, forced_cold_response2] + hot_responses = [hot_response, hot_response2] + for cold_response in cold_responses: + assert cold_response.status_code == 200 + assert not cold_response.extensions.get("from_cache", False) + assert not cold_response.extensions.get("revalidated", False) + + for hot_response in hot_responses: + assert hot_response.status_code == 200 + assert hot_response.extensions["from_cache"] + assert not hot_response.extensions["revalidated"] + assert isinstance(hot_response.extensions["cache_metadata"]["created_at"], datetime.datetime) + assert hot_response.extensions["cache_metadata"]["number_of_uses"] >= 1 + assert all(hot_response.elapsed < cold_response.elapsed for cold_response in cold_responses) + + except Exception as gen_exc: + await client_instance.clear_cache() + logger.exception(gen_exc) + raise gen_exc diff --git a/tests/test_chem.py b/tests/test_chem.py index 0bf2f64..751e9ba 100644 --- a/tests/test_chem.py +++ b/tests/test_chem.py @@ -1,501 +1,428 @@ -import importlib.util -import logging +""" +Tests for exercising the synchronous biothings_client for mychem +""" + import json -import os -import sys +import logging import types -import unittest -sys.path.insert(0, os.path.split(os.path.split(os.path.abspath(__file__))[0])[0]) +import pytest -from biothings_client.utils.cache import cache_request -from biothings_client.utils.score import descore -import biothings_client - -sys.stdout.write( - '"biothings_client {0}" loaded from "{1}"\n'.format(biothings_client.__version__, biothings_client.__file__) -) -pandas_available = importlib.util.find_spec("pandas") is not None -requests_cache_available = importlib.util.find_spec("requests_cache") is not None +import biothings_client +from biothings_client.client.definitions import MyChemInfo +from biothings_client.utils.score import descore logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) -class TestChemClient(unittest.TestCase): - def setUp(self): - self.mc = biothings_client.get_client("chem") - self.query_list1 = [ - "QCYGXOCMWHSXSU-UHFFFAOYSA-N", - "ADFOMBKCPIMCOO-BTVCFUMJSA-N", - "DNUTZBZXLPWRJG-UHFFFAOYSA-N", - "DROLRDZYPMOKLM-BIVLZKPYSA-N", - "KPBZROQVTHLCDU-GOSISDBHSA-N", - "UTUUIUQHGDRVPU-UHFFFAOYSA-K", - "WZWDUEKBAIXVCC-IGHBBLSQSA-N", - "IAJIIJBMBCZPSW-BDAKNGLRSA-N", - "NASIOHFAYPRIAC-JTQLQIEISA-N", - "VGWIQFDQAFSSKA-UHFFFAOYSA-N", - ] - - def test_metadata(self): - meta = self.mc.metadata() - self.assertTrue("src" in meta) - self.assertTrue("stats" in meta) - self.assertTrue("total" in 
meta["stats"]) - - def test_getchem(self): - c = self.mc.getchem("ZRALSGWEFCBTJO-UHFFFAOYSA-N") - self.assertEqual(c["_id"], "ZRALSGWEFCBTJO-UHFFFAOYSA-N") - self.assertEqual(c["chebi"]["name"], "guanidine") - - def test_getchem_with_fields(self): - c = self.mc.getchem("7AXV542LZ4", fields="chebi.name,chembl.inchi_key,pubchem.cid") - self.assertTrue("_id" in c) - self.assertTrue("chebi" in c) - self.assertTrue("name" in c["chebi"]) - self.assertTrue("chembl" in c) - self.assertTrue("inchi_key" in c["chembl"]) - self.assertTrue("pubchem" in c) - self.assertTrue("cid" in c["pubchem"]) - - def test_curie_id_query(self): - """ - Tests the annotation endpoint support for the biolink CURIE ID. - - If support is enabled then we should retrieve the exact same document for all the provided - queries - """ - - curie_id_testing_collection = [ - ( - "UCMIRNVEIXFBKS-UHFFFAOYSA-N", - "CHEMBL297569", - "CHEMBL.COMPOUND:CHEMBL297569", - "chembl.compound:CHEMBL297569", - "cHEmbl.ComPOUND:CHEMBL297569", - "chembl.molecule_chembl_id:CHEMBL297569", - ), - ( - "AKUPVPKIFATOBM-UHFFFAOYSA-N", - "120933777", - 120933777, - "PUBCHEM.COMPOUND:120933777", - "pubchem.compound:120933777", - "PuBcHEm.COMPound:120933777", - "pubchem.cid:120933777", - ), - ( - "UCMIRNVEIXFBKS-UHFFFAOYSA-N", - "CHEBI:CHEBI:57966", - "chebi:CHEBI:57966", - "CheBi:CHEBI:57966", - "chebi.id:CHEBI:57966", - ), - ( - "UCMIRNVEIXFBKS-UHFFFAOYSA-N", - "11P2JDE17B", - "UNII:11P2JDE17B", - "unii:11P2JDE17B", - "uNIi:11P2JDE17B", - "unii.unii:11P2JDE17B", - ), - ( - "UCMIRNVEIXFBKS-UHFFFAOYSA-N", - "dB03107", - "DRUGBANK:dB03107", - "drugbank:dB03107", - "DrugBaNK:dB03107", - "drugbank.id:dB03107", - ), - ] - - aggregation_query_groups = [] - for query_collection in curie_id_testing_collection: - query_result_storage = [] - for similar_query in query_collection: - query_result = self.mc.getchem(_id=similar_query) - query_result_storage.append(query_result) - - results_aggregation = [query == query_result_storage[0] for query in query_result_storage[1:]] - - query_result_mapping = dict(zip(query_collection[1:], results_aggregation)) - logger.debug( - "Comparison to first term %s ->\n%s", query_collection[0], json.dumps(query_result_mapping, indent=4) - ) - - if all(results_aggregation): - logger.info("Query group %s succeeded", query_collection) - else: - logger.error("Query group %s failed", query_collection) - - aggregation_query_groups.append(all(results_aggregation)) - assert all(aggregation_query_groups) - - def test_multiple_curie_id_query(self): - """ - Tests the annotations endpoint support for the biolink CURIE ID. 
- - Batch query testing against the POST endpoint to verify that the CURIE ID can work with - multiple - - If support is enabled then we should retrieve the exact same document for all the provided - queries - """ - curie_id_testing_collection = [ - ( - "UCMIRNVEIXFBKS-UHFFFAOYSA-N", - "CHEMBL297569", - "CHEMBL.COMPOUND:CHEMBL297569", - "chembl.compound:CHEMBL297569", - "cHEmbl.ComPOUND:CHEMBL297569", - "chembl.molecule_chembl_id:CHEMBL297569", - ), - ( - "AKUPVPKIFATOBM-UHFFFAOYSA-N", - "120933777", - 120933777, - "PUBCHEM.COMPOUND:120933777", - "pubchem.compound:120933777", - "PuBcHEm.COMPound:120933777", - "pubchem.cid:120933777", - ), - ( - "UCMIRNVEIXFBKS-UHFFFAOYSA-N", - "CHEBI:CHEBI:57966", - "chebi:CHEBI:57966", - "CheBi:CHEBI:57966", - "chebi.id:CHEBI:57966", - ), - ( - "UCMIRNVEIXFBKS-UHFFFAOYSA-N", - "11P2JDE17B", - "UNII:11P2JDE17B", - "unii:11P2JDE17B", - "uNIi:11P2JDE17B", - "unii.unii:11P2JDE17B", - ), - ( - "UCMIRNVEIXFBKS-UHFFFAOYSA-N", - "dB03107", - "DRUGBANK:dB03107", - "drugbank:dB03107", - "DrugBaNK:dB03107", - "drugbank.id:dB03107", - ), - ] - - results_aggregation = [] - for query_collection in curie_id_testing_collection: - base_result = self.mc.getchem(_id=query_collection[0]) - query_results = self.mc.getchems(ids=query_collection) - assert len(query_results) == len(query_collection) - - batch_result = [] - for query_result, query_entry in zip(query_results, query_collection): - return_query_field = query_result.pop("query") - assert return_query_field == str(query_entry) - batch_result.append(base_result == query_result) - - aggregate_result = all(results_aggregation) - - query_result_mapping = dict(zip(query_collection[1:], results_aggregation)) - logger.debug( - "Comparison to first term %s ->\n%s", query_collection[0], json.dumps(query_result_mapping, indent=4) - ) - - if aggregate_result: - logger.info("Query group %s succeeded", query_collection) - else: - logger.error("Query group %s failed", query_collection) - - results_aggregation.append(aggregate_result) - assert all(results_aggregation) - - @unittest.expectedFailure - def get_getdrug(self): - c = self.mc.getdrug("CHEMBL1308") - self.assertEqual(c["_id"], "ZRALSGWEFCBTJO-UHFFFAOYSA-N") - c = self.mc.getdrug("7AXV542LZ4") - self.assertEqual(c["_id"], "ZRALSGWEFCBTJO-UHFFFAOYSA-N") - c = self.mc.getdrug("CHEBI:6431") - self.assertEqual(c["_id"], "ZRALSGWEFCBTJO-UHFFFAOYSA-N") - - # PubChem CID - # not working yet - c = self.mc.getdrug("CID:1990") - self.assertEqual(c["_id"], "ZRALSGWEFCBTJO-UHFFFAOYSA-N") - c = self.mc.getdrug("1990") - self.assertEqual(c["_id"], "ZRALSGWEFCBTJO-UHFFFAOYSA-N") - - def test_getchems(self): - c_li = self.mc.getchems( - ["KTUFNOKKBVMGRW-UHFFFAOYSA-N", "HXHWSAZORRCQMX-UHFFFAOYSA-N", "DQMZLTXERSFNPB-UHFFFAOYSA-N"] - ) - self.assertEqual(len(c_li), 3) - self.assertEqual(c_li[0]["_id"], "KTUFNOKKBVMGRW-UHFFFAOYSA-N") - self.assertEqual(c_li[1]["_id"], "HXHWSAZORRCQMX-UHFFFAOYSA-N") - self.assertEqual(c_li[2]["_id"], "DQMZLTXERSFNPB-UHFFFAOYSA-N") - - def test_query(self): - qres = self.mc.query("chebi.name:albendazole", size=5) - self.assertTrue("hits" in qres) - self.assertEqual(len(qres["hits"]), 5) - - @unittest.skip("Drugbank was removed") - def test_query_drugbank(self): - qres = self.mc.query("drugbank.id:DB00536") - self.assertTrue("hits" in qres) - self.assertEqual(len(qres["hits"]), 1) - self.assertEqual(qres["hits"][0]["_id"], "ZRALSGWEFCBTJO-UHFFFAOYSA-N") - - def test_query_chebi(self): - qres = self.mc.query(r"chebi.id:CHEBI\:42820") - self.assertTrue("hits" in 
qres) - self.assertEqual(len(qres["hits"]), 1) - self.assertEqual(qres["hits"][0]["_id"], "ZRALSGWEFCBTJO-UHFFFAOYSA-N") - - def test_query_chembl(self): - qres = self.mc.query('chembl.smiles:"CC(=O)NO"') - self.assertTrue("hits" in qres) - self.assertEqual(len(qres["hits"]), 1) - self.assertEqual(qres["hits"][0]["_id"], "RRUDCFGSUDOHDG-UHFFFAOYSA-N") - - @unittest.expectedFailure - def test_query_drugcentral(self): - qres = self.mc.query("drugcentral.drug_use.contraindication.umls_cui:C0023530", fields="drugcentral", size=50) - self.assertTrue("hits" in qres) - self.assertEqual(len(qres["hits"]), 50) - - # not working yet - qres = self.mc.query("drugcentral.xrefs.kegg_drug:D00220") - self.assertTrue("hits" in qres) - self.assertEqual(len(qres["hits"]), 1) - self.assertEqual(qres["hits"][0]["_id"], "ZRALSGWEFCBTJO-UHFFFAOYSA-N") - - def test_query_pubchem(self): - qres = self.mc.query("pubchem.molecular_formula:C2H5NO2", fields="pubchem", size=20) - self.assertTrue("hits" in qres) - self.assertEqual(len(qres["hits"]), 20) - - qres = self.mc.query('pubchem.inchi:"InChI=1S/CH5N3/c2-1(3)4/h(H5,2,3,4)"') - self.assertTrue("hits" in qres) - self.assertEqual(len(qres["hits"]), 1) - self.assertEqual(qres["hits"][0]["_id"], "ZRALSGWEFCBTJO-UHFFFAOYSA-N") - - def test_query_ginas(self): - qres = self.mc.query("ginas.approvalID:JU58VJ6Y3B") - self.assertTrue("hits" in qres) - self.assertEqual(len(qres["hits"]), 1) - self.assertEqual(qres["hits"][0]["_id"], "ZRALSGWEFCBTJO-UHFFFAOYSA-N") - - def test_query_pharmgkb(self): - qres = self.mc.query("pharmgkb.id:PA164781028") - self.assertTrue("hits" in qres) - self.assertEqual(len(qres["hits"]), 1) - self.assertEqual(qres["hits"][0]["_id"], "ZRALSGWEFCBTJO-UHFFFAOYSA-N") - - def test_query_ndc(self): - qres = self.mc.query('ndc.productndc:"27437-051"') - self.assertTrue("hits" in qres) - self.assertEqual(len(qres["hits"]), 1) - self.assertEqual(qres["hits"][0]["_id"], "KPQZUUQMTUIKBP-UHFFFAOYSA-N") - - @unittest.expectedFailure - def test_query_sider(self): - qres = self.mc.query("sider.meddra.umls_id:C0232487", fields="sider", size=5) - self.assertTrue("hits" in qres) - self.assertEqual(len(qres["hits"]), 5) - # Temp disable this check till we fix the data issue - self.assertEqual(qres["hits"][0]["_id"], "ZRALSGWEFCBTJO-UHFFFAOYSA-N") - - def test_query_unii(self): - qres = self.mc.query("unii.unii:JU58VJ6Y3B") - self.assertTrue("hits" in qres) - self.assertEqual(len(qres["hits"]), 1) - self.assertEqual(qres["hits"][0]["_id"], "ZRALSGWEFCBTJO-UHFFFAOYSA-N") - - def test_query_aeolus(self): - qres = self.mc.query("aeolus.rxcui:50675") - self.assertTrue("hits" in qres) - self.assertEqual(len(qres["hits"]), 1) - self.assertEqual(qres["hits"][0]["_id"], "ZRALSGWEFCBTJO-UHFFFAOYSA-N") - - def test_query_fetch_all(self): - # fetch_all won't work when caching is used. 
- self.mc.stop_caching() - q = "drugcentral.drug_use.contraindication.umls_cui:C0023530" - qres = self.mc.query(q, size=0) - total = qres["total"] - - qres = self.mc.query(q, fields="drugcentral.drug_use", fetch_all=True) - self.assertTrue(isinstance(qres, types.GeneratorType)) - self.assertEqual(total, len(list(qres))) - - def test_querymany(self): - qres = self.mc.querymany(["ZRALSGWEFCBTJO-UHFFFAOYSA-N", "RRUDCFGSUDOHDG-UHFFFAOYSA-N"], verbose=False) - self.assertEqual(len(qres), 2) - - qres = self.mc.querymany("ZRALSGWEFCBTJO-UHFFFAOYSA-N,RRUDCFGSUDOHDG-UHFFFAOYSA-N", verbose=False) - self.assertEqual(len(qres), 2) - - def test_querymany_with_scopes(self): - qres = self.mc.querymany(["CHEBI:31690", "CHEBI:15365"], scopes="chebi.id", verbose=False) - self.assertEqual(len(qres), 2) - - qres = self.mc.querymany( - ["CHEMBL374515", "4RZ82L2GY5"], scopes="chembl.molecule_chembl_id,unii.unii", verbose=False +def test_metadata(chem_client: MyChemInfo): + meta = chem_client.metadata() + assert "src" in meta + assert "stats" in meta + assert "total" in meta["stats"] + + +def test_getchem(chem_client: MyChemInfo): + c = chem_client.getchem("ZRALSGWEFCBTJO-UHFFFAOYSA-N") + assert c["_id"] == "ZRALSGWEFCBTJO-UHFFFAOYSA-N" + assert c["chebi"]["name"] == "guanidine" + + +def test_getchem_with_fields(chem_client: MyChemInfo): + c = chem_client.getchem("7AXV542LZ4", fields="chebi.name,chembl.inchi_key,pubchem.cid") + assert "_id" in c + assert "chebi" in c + assert "name" in c["chebi"] + assert "chembl" in c + assert "inchi_key" in c["chembl"] + assert "pubchem" in c + assert "cid" in c["pubchem"] + + +def test_curie_id_query(chem_client: MyChemInfo): + """ + Tests the annotation endpoint support for the biolink CURIE ID. + + If support is enabled then we should retrieve the exact same document for all the provided + queries + """ + + curie_id_testing_collection = [ + ( + "UCMIRNVEIXFBKS-UHFFFAOYSA-N", + "CHEMBL297569", + "CHEMBL.COMPOUND:CHEMBL297569", + "chembl.compound:CHEMBL297569", + "cHEmbl.ComPOUND:CHEMBL297569", + "chembl.molecule_chembl_id:CHEMBL297569", + ), + ( + "AKUPVPKIFATOBM-UHFFFAOYSA-N", + "120933777", + 120933777, + "PUBCHEM.COMPOUND:120933777", + "pubchem.compound:120933777", + "PuBcHEm.COMPound:120933777", + "pubchem.cid:120933777", + ), + ( + "UCMIRNVEIXFBKS-UHFFFAOYSA-N", + "CHEBI:CHEBI:57966", + "chebi:CHEBI:57966", + "CheBi:CHEBI:57966", + "chebi.id:CHEBI:57966", + ), + ( + "UCMIRNVEIXFBKS-UHFFFAOYSA-N", + "11P2JDE17B", + "UNII:11P2JDE17B", + "unii:11P2JDE17B", + "uNIi:11P2JDE17B", + "unii.unii:11P2JDE17B", + ), + ( + "UCMIRNVEIXFBKS-UHFFFAOYSA-N", + "dB03107", + "DRUGBANK:dB03107", + "drugbank:dB03107", + "DrugBaNK:dB03107", + "drugbank.id:dB03107", + ), + ] + + aggregation_query_groups = [] + for query_collection in curie_id_testing_collection: + query_result_storage = [] + for similar_query in query_collection: + query_result = chem_client.getchem(_id=similar_query) + query_result_storage.append(query_result) + + results_aggregation = [query == query_result_storage[0] for query in query_result_storage[1:]] + + query_result_mapping = dict(zip(query_collection[1:], results_aggregation)) + logger.debug( + "Comparison to first term %s ->\n%s", query_collection[0], json.dumps(query_result_mapping, indent=4) ) - self.assertTrue(len(qres) >= 2) - - def test_querymany_fields(self): - qres1 = self.mc.querymany( - ["CHEBI:31690", "CHEBI:15365"], - scopes="chebi.id", - fields=["chebi.name", "unii.registry_number"], - verbose=False, - ) - self.assertEqual(len(qres1), 2) - qres2 = 
self.mc.querymany( - ["CHEBI:31690", "CHEBI:15365"], scopes="chebi.id", fields="chebi.name,unii.registry_number", verbose=False + if all(results_aggregation): + logger.info("Query group %s succeeded", query_collection) + else: + logger.error("Query group %s failed", query_collection) + + aggregation_query_groups.append(all(results_aggregation)) + assert all(aggregation_query_groups) + + +def test_multiple_curie_id_query(chem_client: MyChemInfo): + """ + Tests the annotations endpoint support for the biolink CURIE ID. + + Batch query testing against the POST endpoint to verify that the CURIE ID can work with + multiple + + If support is enabled then we should retrieve the exact same document for all the provided + queries + """ + curie_id_testing_collection = [ + ( + "UCMIRNVEIXFBKS-UHFFFAOYSA-N", + "CHEMBL297569", + "CHEMBL.COMPOUND:CHEMBL297569", + "chembl.compound:CHEMBL297569", + "cHEmbl.ComPOUND:CHEMBL297569", + "chembl.molecule_chembl_id:CHEMBL297569", + ), + ( + "AKUPVPKIFATOBM-UHFFFAOYSA-N", + "120933777", + 120933777, + "PUBCHEM.COMPOUND:120933777", + "pubchem.compound:120933777", + "PuBcHEm.COMPound:120933777", + "pubchem.cid:120933777", + ), + ( + "UCMIRNVEIXFBKS-UHFFFAOYSA-N", + "CHEBI:CHEBI:57966", + "chebi:CHEBI:57966", + "CheBi:CHEBI:57966", + "chebi.id:CHEBI:57966", + ), + ( + "UCMIRNVEIXFBKS-UHFFFAOYSA-N", + "11P2JDE17B", + "UNII:11P2JDE17B", + "unii:11P2JDE17B", + "uNIi:11P2JDE17B", + "unii.unii:11P2JDE17B", + ), + ( + "UCMIRNVEIXFBKS-UHFFFAOYSA-N", + "dB03107", + "DRUGBANK:dB03107", + "drugbank:dB03107", + "DrugBaNK:dB03107", + "drugbank.id:dB03107", + ), + ] + + results_aggregation = [] + for query_collection in curie_id_testing_collection: + base_result = chem_client.getchem(_id=query_collection[0]) + query_results = chem_client.getchems(ids=query_collection) + assert len(query_results) == len(query_collection) + + batch_result = [] + for query_result, query_entry in zip(query_results, query_collection): + return_query_field = query_result.pop("query") + assert return_query_field == str(query_entry) + batch_result.append(base_result == query_result) + + aggregate_result = all(batch_result) + + query_result_mapping = dict(zip(query_collection[1:], batch_result)) + logger.debug( + "Comparison to first term %s ->\n%s", query_collection[0], json.dumps(query_result_mapping, indent=4) ) - self.assertEqual(len(qres2), 2) - - self.assertEqual(descore(qres1), descore(qres2)) - - def test_querymany_notfound(self): - qres = self.mc.querymany(["CHEBI:31690", "CHEBI:15365", "NA_TEST"], scopes="chebi.id") - self.assertEqual(len(qres), 3) - self.assertEqual(qres[2], {"query": "NA_TEST", "notfound": True}) - - @unittest.skipIf(not pandas_available, "pandas not available") - def test_querymany_dataframe(self): - from pandas import DataFrame - - qres = self.mc.querymany(self.query_list1, scopes="_id", fields="pubchem", as_dataframe=True) - self.assertTrue(isinstance(qres, DataFrame)) - self.assertTrue("pubchem.inchi" in qres.columns) - self.assertEqual(set(self.query_list1), set(qres.index)) - - def test_querymany_step(self): - qres1 = self.mc.querymany(self.query_list1, scopes="_id", fields="pubchem") - default_step = self.mc.step - self.mc.step = 3 - qres2 = self.mc.querymany(self.query_list1, scopes="_id", fields="pubchem") - self.mc.step = default_step - qres1 = descore(sorted(qres1, key=lambda doc: doc["_id"])) - qres2 = descore(sorted(qres2, key=lambda doc: doc["_id"])) - self.assertEqual(qres1, qres2) - - def test_get_fields(self): - fields = self.mc.get_fields() - 
self.assertTrue("chembl.inchi_key" in fields.keys()) - self.assertTrue("pharmgkb.trade_names" in fields.keys()) - - fields = self.mc.get_fields("unii") - self.assertTrue("unii.molecular_formula" in fields.keys()) - - @unittest.skipIf(not requests_cache_available, "requests_cache not available") - def test_caching(self): - def _getchem(): - return self.mc.getchem("ZRALSGWEFCBTJO-UHFFFAOYSA-N") - - def _getdrugs(): - return self.mc.getdrugs(["ZRALSGWEFCBTJO-UHFFFAOYSA-N", "RRUDCFGSUDOHDG-UHFFFAOYSA-N"]) - - def _query(): - return self.mc.query("chebi.name:albendazole", size=5) - - def _querymany(): - return self.mc.querymany(["CHEBI:31690", "CHEBI:15365"], scopes="chebi.id") - - try: - from_cache, pre_cache_r = cache_request(_getchem) - self.assertFalse(from_cache) - - cache_name = "mcc" - cache_file = cache_name + ".sqlite" - - if os.path.exists(cache_file): - os.remove(cache_file) - self.mc.set_caching(cache_name) - - # populate cache - from_cache, cache_fill_r = cache_request(_getchem) - self.assertTrue(os.path.exists(cache_file)) - self.assertFalse(from_cache) - # is it from the cache? - from_cache, cached_r = cache_request(_getchem) - self.assertTrue(from_cache) - - self.mc.stop_caching() - # same query should be live - not cached - from_cache, post_cache_r = cache_request(_getchem) - self.assertFalse(from_cache) - - self.mc.set_caching(cache_name) - - # same query should still be sourced from cache - from_cache, recached_r = cache_request(_getchem) - self.assertTrue(from_cache) - - self.mc.clear_cache() - # cache was cleared, same query should be live - from_cache, clear_cached_r = cache_request(_getchem) - self.assertFalse(from_cache) - - # all requests should be identical except the _score, which can vary slightly - for x in [pre_cache_r, cache_fill_r, cached_r, post_cache_r, recached_r, clear_cached_r]: - x.pop("_score", None) - - self.assertTrue( - all( - [ - x == pre_cache_r - for x in [pre_cache_r, cache_fill_r, cached_r, post_cache_r, recached_r, clear_cached_r] - ] - ) - ) - - # test getvariants POST caching - from_cache, first_getgenes_r = cache_request(_getdrugs) - del first_getgenes_r - self.assertFalse(from_cache) - # should be from cache this time - from_cache, second_getgenes_r = cache_request(_getdrugs) - del second_getgenes_r - self.assertTrue(from_cache) - - # test query GET caching - from_cache, first_query_r = cache_request(_query) - del first_query_r - self.assertFalse(from_cache) - # should be from cache this time - from_cache, second_query_r = cache_request(_query) - del second_query_r - self.assertTrue(from_cache) - - # test querymany POST caching - from_cache, first_querymany_r = cache_request(_querymany) - del first_querymany_r - self.assertFalse(from_cache) - # should be from cache this time - from_cache, second_querymany_r = cache_request(_querymany) - del second_querymany_r - self.assertTrue(from_cache) - - finally: - self.mc.stop_caching() - if os.path.exists(cache_file): - os.remove(cache_file) - - -def suite(): - return unittest.defaultTestLoader.loadTestsFromTestCase(TestChemClient) - - -if __name__ == "__main__": - unittest.TextTestRunner().run(suite()) + + if aggregate_result: + logger.info("Query group %s succeeded", query_collection) + else: + logger.error("Query group %s failed", query_collection) + + results_aggregation.append(aggregate_result) + assert all(results_aggregation) + + +@pytest.mark.xfail(reason="pubchem CID not working yet") +def get_getdrug(chem_client: MyChemInfo): + c = chem_client.getdrug("CHEMBL1308") + assert c["_id"] == 
"ZRALSGWEFCBTJO-UHFFFAOYSA-N" + c = chem_client.getdrug("7AXV542LZ4") + assert c["_id"] == "ZRALSGWEFCBTJO-UHFFFAOYSA-N" + c = chem_client.getdrug("CHEBI:6431") + assert c["_id"] == "ZRALSGWEFCBTJO-UHFFFAOYSA-N" + + # PubChem CID not working yet + c = chem_client.getdrug("CID:1990") + assert c["_id"] == "ZRALSGWEFCBTJO-UHFFFAOYSA-N" + c = chem_client.getdrug("1990") + assert c["_id"] == "ZRALSGWEFCBTJO-UHFFFAOYSA-N" + + +def test_getchems(chem_client: MyChemInfo): + c_li = chem_client.getchems( + ["KTUFNOKKBVMGRW-UHFFFAOYSA-N", "HXHWSAZORRCQMX-UHFFFAOYSA-N", "DQMZLTXERSFNPB-UHFFFAOYSA-N"] + ) + assert len(c_li) == 3 + assert c_li[0]["_id"] == "KTUFNOKKBVMGRW-UHFFFAOYSA-N" + assert c_li[1]["_id"] == "HXHWSAZORRCQMX-UHFFFAOYSA-N" + assert c_li[2]["_id"] == "DQMZLTXERSFNPB-UHFFFAOYSA-N" + + +def test_query(chem_client: MyChemInfo): + qres = chem_client.query("chebi.name:albendazole", size=5) + assert "hits" in qres + assert len(qres["hits"]) == 5 + + +@pytest.mark.xfail(reason="Drugbank was removed") +def test_query_drugbank(chem_client: MyChemInfo): + qres = chem_client.query("drugbank.id:DB00536") + assert "hits" in qres + assert len(qres["hits"]) == 1 + assert qres["hits"][0]["_id"] == "ZRALSGWEFCBTJO-UHFFFAOYSA-N" + + +def test_query_chebi(chem_client: MyChemInfo): + qres = chem_client.query(r"chebi.id:CHEBI\:42820") + assert "hits" in qres + assert len(qres["hits"]) == 1 + assert qres["hits"][0]["_id"] == "ZRALSGWEFCBTJO-UHFFFAOYSA-N" + + +def test_query_chembl(chem_client: MyChemInfo): + qres = chem_client.query('chembl.smiles:"CC(=O)NO"') + assert "hits" in qres + assert len(qres["hits"]) == 1 + assert qres["hits"][0]["_id"] == "RRUDCFGSUDOHDG-UHFFFAOYSA-N" + + +@pytest.mark.xfail(reason="drugcentral query not working yet") +def test_query_drugcentral(chem_client: MyChemInfo): + qres = chem_client.query("drugcentral.drug_use.contraindication.umls_cui:C0023530", fields="drugcentral", size=50) + assert "hits" in qres + assert len(qres["hits"]) == 50 + + # not working yet + qres = chem_client.query("drugcentral.xrefs.kegg_drug:D00220") + assert "hits" in qres + assert len(qres["hits"]) == 1 + assert qres["hits"][0]["_id"] == "ZRALSGWEFCBTJO-UHFFFAOYSA-N" + + +def test_query_pubchem(chem_client: MyChemInfo): + qres = chem_client.query("pubchem.molecular_formula:C2H5NO2", fields="pubchem", size=20) + assert "hits" in qres + assert len(qres["hits"]) == 20 + + qres = chem_client.query('pubchem.inchi:"InChI=1S/CH5N3/c2-1(3)4/h(H5,2,3,4)"') + assert "hits" in qres + assert len(qres["hits"]) == 1 + assert qres["hits"][0]["_id"] == "ZRALSGWEFCBTJO-UHFFFAOYSA-N" + + +def test_query_ginas(chem_client: MyChemInfo): + qres = chem_client.query("ginas.approvalID:JU58VJ6Y3B") + assert "hits" in qres + assert len(qres["hits"]) == 1 + assert qres["hits"][0]["_id"] == "ZRALSGWEFCBTJO-UHFFFAOYSA-N" + + +def test_query_pharmgkb(chem_client: MyChemInfo): + qres = chem_client.query("pharmgkb.id:PA164781028") + assert "hits" in qres + assert len(qres["hits"]) == 1 + assert qres["hits"][0]["_id"] == "ZRALSGWEFCBTJO-UHFFFAOYSA-N" + + +def test_query_ndc(chem_client: MyChemInfo): + qres = chem_client.query('ndc.productndc:"27437-051"') + assert "hits" in qres + assert len(qres["hits"]) == 1 + assert qres["hits"][0]["_id"] == "KPQZUUQMTUIKBP-UHFFFAOYSA-N" + + +@pytest.mark.xfail +def test_query_sider(chem_client: MyChemInfo): + qres = chem_client.query("sider.meddra.umls_id:C0232487", fields="sider", size=5) + assert "hits" in qres + assert len(qres["hits"]) == 5 + + # Temp disable this check till we fix the data 
issue + assert qres["hits"][0]["_id"] == "ZRALSGWEFCBTJO-UHFFFAOYSA-N" + + +def test_query_unii(chem_client: MyChemInfo): + qres = chem_client.query("unii.unii:JU58VJ6Y3B") + assert "hits" in qres + assert len(qres["hits"]) == 1 + assert qres["hits"][0]["_id"] == "ZRALSGWEFCBTJO-UHFFFAOYSA-N" + + +def test_query_aeolus(chem_client: MyChemInfo): + qres = chem_client.query("aeolus.rxcui:50675") + assert "hits" in qres + assert len(qres["hits"]) == 1 + assert qres["hits"][0]["_id"] == "ZRALSGWEFCBTJO-UHFFFAOYSA-N" + + +def test_query_fetch_all(chem_client: MyChemInfo): + # fetch_all won't work when caching is used. + q = "drugcentral.drug_use.contraindication.umls_cui:C0023530" + qres = chem_client.query(q, size=0) + total = qres["total"] + + qres = chem_client.query(q, fields="drugcentral.drug_use", fetch_all=True) + assert isinstance(qres, types.GeneratorType) + assert total == len(list(qres)) + + +def test_querymany(chem_client: MyChemInfo): + qres = chem_client.querymany(["ZRALSGWEFCBTJO-UHFFFAOYSA-N", "RRUDCFGSUDOHDG-UHFFFAOYSA-N"], verbose=False) + assert len(qres) == 2 + + qres = chem_client.querymany("ZRALSGWEFCBTJO-UHFFFAOYSA-N,RRUDCFGSUDOHDG-UHFFFAOYSA-N", verbose=False) + assert len(qres) == 2 + + +def test_querymany_with_scopes(chem_client: MyChemInfo): + qres = chem_client.querymany(["CHEBI:31690", "CHEBI:15365"], scopes="chebi.id", verbose=False) + assert len(qres) == 2 + + qres = chem_client.querymany( + ["CHEMBL374515", "4RZ82L2GY5"], scopes="chembl.molecule_chembl_id,unii.unii", verbose=False + ) + assert len(qres) >= 2 + + +def test_querymany_fields(chem_client: MyChemInfo): + qres1 = chem_client.querymany( + ["CHEBI:31690", "CHEBI:15365"], + scopes="chebi.id", + fields=["chebi.name", "unii.registry_number"], + verbose=False, + ) + assert len(qres1) == 2 + + qres2 = chem_client.querymany( + ["CHEBI:31690", "CHEBI:15365"], scopes="chebi.id", fields="chebi.name,unii.registry_number", verbose=False + ) + assert len(qres2) == 2 + + assert descore(qres1) == descore(qres2) + + +def test_querymany_notfound(chem_client: MyChemInfo): + qres = chem_client.querymany(["CHEBI:31690", "CHEBI:15365", "NA_TEST"], scopes="chebi.id") + assert len(qres) == 3 + assert qres[2] == {"query": "NA_TEST", "notfound": True} + + +@pytest.mark.skipif(not biothings_client._PANDAS, reason="pandas library not installed") +def test_querymany_dataframe(chem_client: MyChemInfo): + from pandas import DataFrame + + query_list1 = [ + "QCYGXOCMWHSXSU-UHFFFAOYSA-N", + "ADFOMBKCPIMCOO-BTVCFUMJSA-N", + "DNUTZBZXLPWRJG-UHFFFAOYSA-N", + "DROLRDZYPMOKLM-BIVLZKPYSA-N", + "KPBZROQVTHLCDU-GOSISDBHSA-N", + "UTUUIUQHGDRVPU-UHFFFAOYSA-K", + "WZWDUEKBAIXVCC-IGHBBLSQSA-N", + "IAJIIJBMBCZPSW-BDAKNGLRSA-N", + "NASIOHFAYPRIAC-JTQLQIEISA-N", + "VGWIQFDQAFSSKA-UHFFFAOYSA-N", + ] + + qres = chem_client.querymany(query_list1, scopes="_id", fields="pubchem", as_dataframe=True) + assert isinstance(qres, DataFrame) + assert "pubchem.inchi" in qres.columns + assert set(query_list1) == set(qres.index) + + +def test_querymany_step(chem_client: MyChemInfo): + + query_list1 = [ + "QCYGXOCMWHSXSU-UHFFFAOYSA-N", + "ADFOMBKCPIMCOO-BTVCFUMJSA-N", + "DNUTZBZXLPWRJG-UHFFFAOYSA-N", + "DROLRDZYPMOKLM-BIVLZKPYSA-N", + "KPBZROQVTHLCDU-GOSISDBHSA-N", + "UTUUIUQHGDRVPU-UHFFFAOYSA-K", + "WZWDUEKBAIXVCC-IGHBBLSQSA-N", + "IAJIIJBMBCZPSW-BDAKNGLRSA-N", + "NASIOHFAYPRIAC-JTQLQIEISA-N", + "VGWIQFDQAFSSKA-UHFFFAOYSA-N", + ] + qres1 = chem_client.querymany(query_list1, scopes="_id", fields="pubchem") + default_step = chem_client.step + chem_client.step = 3 + 
qres2 = chem_client.querymany(query_list1, scopes="_id", fields="pubchem") + chem_client.step = default_step + qres1 = descore(sorted(qres1, key=lambda doc: doc["_id"])) + qres2 = descore(sorted(qres2, key=lambda doc: doc["_id"])) + assert qres1 == qres2 + + +def test_get_fields(chem_client: MyChemInfo): + fields = chem_client.get_fields() + assert "chembl.inchi_key" in fields.keys() + assert "pharmgkb.trade_names" in fields.keys() + + fields = chem_client.get_fields("unii") + assert "unii.molecular_formula" in fields.keys() diff --git a/tests/test_gene.py b/tests/test_gene.py index f3842fd..7258bd9 100644 --- a/tests/test_gene.py +++ b/tests/test_gene.py @@ -1,377 +1,292 @@ -import importlib.util -import os -import sys +""" +Tests for exercising the synchronous biothings_client for mygene +""" + +import logging import types -import unittest -sys.path.insert(0, os.path.split(os.path.split(os.path.abspath(__file__))[0])[0]) +import pytest -from biothings_client.utils.cache import cache_request -from biothings_client.utils.score import descore import biothings_client +from biothings_client.client.definitions import MyGeneInfo +from biothings_client.utils.score import descore -sys.stdout.write( - '"biothings_client {0}" loaded from "{1}"\n'.format(biothings_client.__version__, biothings_client.__file__) -) - -pandas_available = importlib.util.find_spec("pandas") is not None -requests_cache_available = importlib.util.find_spec("requests_cache") is not None - - -class TestGeneClient(unittest.TestCase): - def setUp(self): - self.mg = biothings_client.get_client("gene") - self.query_list1 = [ - "1007_s_at", - "1053_at", - "117_at", - "121_at", - "1255_g_at", - "1294_at", - "1316_at", - "1320_at", - "1405_i_at", - "1431_at", - ] - - def test_http(self): - # this is the default - self.mg.url.startswith("https://") - # switch to http - self.mg.use_http() - self.mg.url.startswith("http://") - # reset to default - self.mg.use_https() - self.mg.url.startswith("https://") - - def test_metadata(self): - meta = self.mg.metadata() - self.assertTrue("stats" in meta) - self.assertTrue("total_genes" in meta["stats"]) - - def test_getgene(self): - g = self.mg.getgene("1017") - self.assertEqual(g["_id"], "1017") - self.assertEqual(g["symbol"], "CDK2") - - def test_getgene_with_fields(self): - g = self.mg.getgene("1017", fields="name,symbol,refseq") - self.assertTrue("_id" in g) - self.assertTrue("name" in g) - self.assertTrue("symbol" in g) - self.assertTrue("refseq" in g) - self.assertFalse("summary" in g) - - def test_curie_id_query(self): - """ - Tests the annotation endpoint support for the biolink CURIE ID. 
- - If support is enabled then we should retrieve the exact same document for all the provided - queries - """ - curie_id_testing_collection = [ - ("1017", "entrezgene:1017", "NCBIgene:1017"), - (1017, "entrezgene:1017", "ncbigene:1017"), - ("1017", "entrezgene:1017", "NCBIGENE:1017"), - ("1018", "ensembl.gene:ENSG00000250506", "ENSEMBL:ENSG00000250506"), - (1018, "ensembl.gene:ENSG00000250506", "ensembl:ENSG00000250506"), - ("5995", "uniprot.Swiss-Prot:P47804", "UniProtKB:P47804"), - (5995, "uniprot.Swiss-Prot:P47804", "UNIPROTKB:P47804"), - ("5995", "uniprot.Swiss-Prot:P47804", "uniprotkb:P47804"), - ] - - results_aggregation = [] - for id_query, biothings_query, biolink_query in curie_id_testing_collection: - id_query_result = self.mg.getgene(_id=id_query) - biothings_term_query_result = self.mg.getgene(_id=biothings_query) - biolink_term_query_result = self.mg.getgene(_id=biolink_query) - results_aggregation.append( - ( - id_query_result == biothings_term_query_result, - id_query_result == biolink_term_query_result, - biothings_term_query_result == biolink_term_query_result, - ) - ) - results_validation = [] - failure_messages = [] - for result, test_query in zip(results_aggregation, curie_id_testing_collection): - cumulative_result = all(result) - if not cumulative_result: - failure_messages.append(f"Query Failure: {test_query} | Results: {result}") - results_validation.append(cumulative_result) - - self.assertTrue(all(results_validation), msg="\n".join(failure_messages)) - - def test_multiple_curie_id_query(self): - """ - Tests the annotations endpoint support for the biolink CURIE ID. - - Batch query testing against the POST endpoint to verify that the CURIE ID can work with - multiple - - If support is enabled then we should retrieve the exact same document for all the provided - queries - """ - curie_id_testing_collection = [ - ("1017", "entrezgene:1017", "NCBIgene:1017"), - (1017, "entrezgene:1017", "ncbigene:1017"), - ("1017", "entrezgene:1017", "NCBIGENE:1017"), - ("1018", "ensembl.gene:ENSG00000250506", "ENSEMBL:ENSG00000250506"), - (1018, "ensembl.gene:ENSG00000250506", "ensembl:ENSG00000250506"), - ("5995", "uniprot.Swiss-Prot:P47804", "UniProtKB:P47804"), - (5995, "uniprot.Swiss-Prot:P47804", "UNIPROTKB:P47804"), - ("5995", "uniprot.Swiss-Prot:P47804", "uniprotkb:P47804"), - ] - - results_aggregation = [] - for id_query, biothings_query, biolink_query in curie_id_testing_collection: - base_result = self.mg.getgene(_id=id_query) - - batch_query = [id_query, biothings_query, biolink_query] - query_results = self.mg.getgenes(ids=batch_query) - assert len(query_results) == len(batch_query) - - batch_id_query = query_results[0] - batch_biothings_query = query_results[1] - batch_biolink_query = query_results[2] - - batch_id_query_return_value = batch_id_query.pop("query") - assert batch_id_query_return_value == str(id_query) - - batch_biothings_query_return_value = batch_biothings_query.pop("query") - assert batch_biothings_query_return_value == str(biothings_query) - - batch_biolink_query_return_value = batch_biolink_query.pop("query") - assert batch_biolink_query_return_value == str(biolink_query) - - batch_result = ( - base_result == batch_id_query, - base_result == batch_biothings_query, - base_result == batch_biolink_query, +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + + +def test_metadata(gene_client: MyGeneInfo): + meta = gene_client.metadata() + assert "stats" in meta + assert "total_genes" in meta["stats"] + + +def test_getgene(gene_client: 
MyGeneInfo): + g = gene_client.getgene("1017") + assert g["_id"] == "1017" + assert g["symbol"] == "CDK2" + + +def test_getgene_with_fields(gene_client: MyGeneInfo): + g = gene_client.getgene("1017", fields="name,symbol,refseq") + assert "_id" in g + assert "name" in g + assert "symbol" in g + assert "refseq" in g + assert "summary" not in g + + +def test_curie_id_query(gene_client: MyGeneInfo): + """ + Tests the annotation endpoint support for the biolink CURIE ID. + + If support is enabled then we should retrieve the exact same document for all the provided + queries + """ + curie_id_testing_collection = [ + ("1017", "entrezgene:1017", "NCBIgene:1017"), + (1017, "entrezgene:1017", "ncbigene:1017"), + ("1017", "entrezgene:1017", "NCBIGENE:1017"), + ("1018", "ensembl.gene:ENSG00000250506", "ENSEMBL:ENSG00000250506"), + (1018, "ensembl.gene:ENSG00000250506", "ensembl:ENSG00000250506"), + ("5995", "uniprot.Swiss-Prot:P47804", "UniProtKB:P47804"), + (5995, "uniprot.Swiss-Prot:P47804", "UNIPROTKB:P47804"), + ("5995", "uniprot.Swiss-Prot:P47804", "uniprotkb:P47804"), + ] + + results_aggregation = [] + for id_query, biothings_query, biolink_query in curie_id_testing_collection: + id_query_result = gene_client.getgene(_id=id_query) + biothings_term_query_result = gene_client.getgene(_id=biothings_query) + biolink_term_query_result = gene_client.getgene(_id=biolink_query) + results_aggregation.append( + ( + id_query_result == biothings_term_query_result, + id_query_result == biolink_term_query_result, + biothings_term_query_result == biolink_term_query_result, ) - results_aggregation.append(batch_result) - - results_validation = [] - failure_messages = [] - for result, test_query in zip(results_aggregation, curie_id_testing_collection): - cumulative_result = all(result) - if not cumulative_result: - failure_messages.append(f"Query Failure: {test_query} | Results: {result}") - results_validation.append(cumulative_result) - - self.assertTrue(all(results_validation), msg="\n".join(failure_messages)) - - def test_getgene_with_fields_as_list(self): - g1 = self.mg.getgene("1017", fields="name,symbol,refseq") - g2 = self.mg.getgene("1017", fields=["name", "symbol", "refseq"]) - self.assertEqual(g1, g2) - - def test_getgenes(self): - g_li = self.mg.getgenes([1017, 1018, "ENSG00000148795"]) - self.assertEqual(len(g_li), 3) - self.assertEqual(g_li[0]["_id"], "1017") - self.assertEqual(g_li[1]["_id"], "1018") - self.assertEqual(g_li[2]["_id"], "1586") - - def test_query(self): - qres = self.mg.query("cdk2", size=5) - self.assertTrue("hits" in qres) - self.assertEqual(len(qres["hits"]), 5) - - def test_query_with_fields_as_list(self): - qres1 = self.mg.query("entrezgene:1017", fields="name,symbol,refseq") - qres2 = self.mg.query("entrezgene:1017", fields=["name", "symbol", "refseq"]) - self.assertTrue("hits" in qres1) - self.assertEqual(len(qres1["hits"]), 1) - self.assertEqual(descore(qres1["hits"]), descore(qres2["hits"])) - - def test_query_reporter(self): - qres = self.mg.query("reporter:1000_at") - self.assertTrue("hits" in qres) - self.assertEqual(len(qres["hits"]), 1) - self.assertEqual(qres["hits"][0]["_id"], "5595") - - def test_query_symbol(self): - qres = self.mg.query("symbol:cdk2", species="mouse") - self.assertTrue("hits" in qres) - self.assertEqual(len(qres["hits"]), 1) - self.assertEqual(qres["hits"][0]["_id"], "12566") - - def test_query_fetch_all(self): - # fetch_all won't work when caching is used. 
- self.mg.stop_caching() - qres = self.mg.query("_exists_:pdb") - total = qres["total"] - - qres = self.mg.query("_exists_:pdb", fields="pdb", fetch_all=True) - self.assertTrue(isinstance(qres, types.GeneratorType)) - self.assertEqual(total, len(list(qres))) - - def test_querymany(self): - qres = self.mg.querymany([1017, "695"], verbose=False) - self.assertEqual(len(qres), 2) - - qres = self.mg.querymany("1017,695", verbose=False) - self.assertEqual(len(qres), 2) - - def test_querymany_with_scopes(self): - qres = self.mg.querymany([1017, "695"], scopes="entrezgene", verbose=False) - self.assertEqual(len(qres), 2) - - qres = self.mg.querymany([1017, "BTK"], scopes="entrezgene,symbol", verbose=False) - self.assertTrue(len(qres) >= 2) - - def test_querymany_species(self): - qres = self.mg.querymany([1017, "695"], scopes="entrezgene", species="human", verbose=False) - self.assertEqual(len(qres), 2) - - qres = self.mg.findgenes([1017, "695"], scopes="entrezgene", species=9606, verbose=False) - self.assertEqual(len(qres), 2) - - qres = self.mg.findgenes([1017, "CDK2"], scopes="entrezgene,symbol", species=9606, verbose=False) - self.assertEqual(len(qres), 2) - - def test_querymany_fields(self): - qres1 = self.mg.findgenes( - [1017, "CDK2"], scopes="entrezgene,symbol", fields=["uniprot", "unigene"], species=9606, verbose=False ) - self.assertEqual(len(qres1), 2) - qres2 = self.mg.findgenes( - "1017,CDK2", scopes="entrezgene,symbol".split(","), fields="uniprot,unigene", species=9606, verbose=False + results_validation = [] + failure_messages = [] + for result, test_query in zip(results_aggregation, curie_id_testing_collection): + cumulative_result = all(result) + if not cumulative_result: + failure_messages.append(f"Query Failure: {test_query} | Results: {result}") + results_validation.append(cumulative_result) + + assert all(results_validation), "\n".join(failure_messages) + + +def test_multiple_curie_id_query(gene_client: MyGeneInfo): + """ + Tests the annotations endpoint support for the biolink CURIE ID. 
+ + Batch query testing against the POST endpoint to verify that the CURIE ID can work with + multiple + + If support is enabled then we should retrieve the exact same document for all the provided + queries + """ + curie_id_testing_collection = [ + ("1017", "entrezgene:1017", "NCBIgene:1017"), + (1017, "entrezgene:1017", "ncbigene:1017"), + ("1017", "entrezgene:1017", "NCBIGENE:1017"), + ("1018", "ensembl.gene:ENSG00000250506", "ENSEMBL:ENSG00000250506"), + (1018, "ensembl.gene:ENSG00000250506", "ensembl:ENSG00000250506"), + ("5995", "uniprot.Swiss-Prot:P47804", "UniProtKB:P47804"), + (5995, "uniprot.Swiss-Prot:P47804", "UNIPROTKB:P47804"), + ("5995", "uniprot.Swiss-Prot:P47804", "uniprotkb:P47804"), + ] + + results_aggregation = [] + for id_query, biothings_query, biolink_query in curie_id_testing_collection: + base_result = gene_client.getgene(_id=id_query) + + batch_query = [id_query, biothings_query, biolink_query] + query_results = gene_client.getgenes(ids=batch_query) + assert len(query_results) == len(batch_query) + + batch_id_query = query_results[0] + batch_biothings_query = query_results[1] + batch_biolink_query = query_results[2] + + batch_id_query_return_value = batch_id_query.pop("query") + assert batch_id_query_return_value == str(id_query) + + batch_biothings_query_return_value = batch_biothings_query.pop("query") + assert batch_biothings_query_return_value == str(biothings_query) + + batch_biolink_query_return_value = batch_biolink_query.pop("query") + assert batch_biolink_query_return_value == str(biolink_query) + + batch_result = ( + base_result == batch_id_query, + base_result == batch_biothings_query, + base_result == batch_biolink_query, ) - self.assertEqual(len(qres2), 2) - - self.assertEqual(descore(qres1), descore(qres2)) - - def test_querymany_notfound(self): - qres = self.mg.findgenes([1017, "695", "NA_TEST"], scopes="entrezgene", species=9606) - self.assertEqual(len(qres), 3) - self.assertEqual(qres[2], {"query": "NA_TEST", "notfound": True}) - - @unittest.skipIf(not pandas_available, "pandas not available") - def test_querymany_dataframe(self): - from pandas import DataFrame - - qres = self.mg.querymany(self.query_list1, scopes="reporter", as_dataframe=True) - self.assertTrue(isinstance(qres, DataFrame)) - self.assertTrue("name" in qres.columns) - self.assertEqual(set(self.query_list1), set(qres.index)) - - def test_querymany_step(self): - qres1 = self.mg.querymany(self.query_list1, scopes="reporter") - default_step = self.mg.step - self.mg.step = 3 - qres2 = self.mg.querymany(self.query_list1, scopes="reporter") - self.mg.step = default_step - qres1 = descore(sorted(qres1, key=lambda doc: doc["_id"])) - qres2 = descore(sorted(qres2, key=lambda doc: doc["_id"])) - self.assertEqual(qres1, qres2) - - def test_get_fields(self): - fields = self.mg.get_fields() - self.assertTrue("uniprot" in fields.keys()) - self.assertTrue("exons" in fields.keys()) - - fields = self.mg.get_fields("kegg") - self.assertTrue("pathway.kegg" in fields.keys()) - - @unittest.skipIf(not requests_cache_available, "requests_cache not available") - def test_caching(self): - def _getgene(): - return self.mg.getgene("1017") - - def _getgenes(): - return self.mg.getgenes(["1017", "1018"]) - - def _query(): - return self.mg.query("cdk2") - - def _querymany(): - return self.mg.querymany(["1017", "695"]) - - try: - from_cache, pre_cache_r = cache_request(_getgene) - self.assertFalse(from_cache) - - cache_name = "mgc" - cache_file = cache_name + ".sqlite" - if os.path.exists(cache_file): - 
os.remove(cache_file) - self.mg.set_caching(cache_name) - - # populate cache - from_cache, cache_fill_r = cache_request(_getgene) - self.assertTrue(os.path.exists(cache_file)) - self.assertFalse(from_cache) - # is it from the cache? - from_cache, cached_r = cache_request(_getgene) - self.assertTrue(from_cache) - - self.mg.stop_caching() - # same query should be live - not cached - from_cache, post_cache_r = cache_request(_getgene) - self.assertFalse(from_cache) - - self.mg.set_caching(cache_name) - - # same query should still be sourced from cache - from_cache, recached_r = cache_request(_getgene) - self.assertTrue(from_cache) - - self.mg.clear_cache() - # cache was cleared, same query should be live - from_cache, clear_cached_r = cache_request(_getgene) - self.assertFalse(from_cache) - - # all requests should be identical except the _score, which can vary slightly - for x in [pre_cache_r, cache_fill_r, cached_r, post_cache_r, recached_r, clear_cached_r]: - x.pop("_score", None) - - self.assertTrue( - all( - x == pre_cache_r - for x in [pre_cache_r, cache_fill_r, cached_r, post_cache_r, recached_r, clear_cached_r] - ) - ) + results_aggregation.append(batch_result) + + results_validation = [] + failure_messages = [] + for result, test_query in zip(results_aggregation, curie_id_testing_collection): + cumulative_result = all(result) + if not cumulative_result: + failure_messages.append(f"Query Failure: {test_query} | Results: {result}") + results_validation.append(cumulative_result) + + assert all(results_validation), "\n".join(failure_messages) + + +def test_getgene_with_fields_as_list(gene_client: MyGeneInfo): + g1 = gene_client.getgene("1017", fields="name,symbol,refseq") + g2 = gene_client.getgene("1017", fields=["name", "symbol", "refseq"]) + assert g1 == g2 + + +def test_getgenes(gene_client: MyGeneInfo): + g_li = gene_client.getgenes([1017, 1018, "ENSG00000148795"]) + assert len(g_li) == 3 + assert g_li[0]["_id"] == "1017" + assert g_li[1]["_id"] == "1018" + assert g_li[2]["_id"] == "1586" + + +def test_query(gene_client: MyGeneInfo): + qres = gene_client.query("cdk2", size=5) + assert "hits" in qres + assert len(qres["hits"]) == 5 + + +def test_query_with_fields_as_list(gene_client: MyGeneInfo): + qres1 = gene_client.query("entrezgene:1017", fields="name,symbol,refseq") + qres2 = gene_client.query("entrezgene:1017", fields=["name", "symbol", "refseq"]) + assert "hits" in qres1 + assert len(qres1["hits"]) == 1 + assert descore(qres1["hits"]) == descore(qres2["hits"]) + + +def test_query_reporter(gene_client: MyGeneInfo): + qres = gene_client.query("reporter:1000_at") + assert "hits" in qres + assert len(qres["hits"]) == 1 + assert qres["hits"][0]["_id"] == "5595" + + +def test_query_symbol(gene_client: MyGeneInfo): + qres = gene_client.query("symbol:cdk2", species="mouse") + assert "hits" in qres + assert len(qres["hits"]) == 1 + assert qres["hits"][0]["_id"] == "12566" + + +def test_query_fetch_all(gene_client: MyGeneInfo): + qres = gene_client.query("_exists_:pdb") + total = qres["total"] + + qres = gene_client.query("_exists_:pdb", fields="pdb", fetch_all=True) + assert isinstance(qres, types.GeneratorType) + assert total == len(list(qres)) + + +def test_querymany(gene_client: MyGeneInfo): + qres = gene_client.querymany([1017, "695"], verbose=False) + assert len(qres) == 2 + + qres = gene_client.querymany("1017,695", verbose=False) + assert len(qres) == 2 + + +def test_querymany_with_scopes(gene_client: MyGeneInfo): + qres = gene_client.querymany([1017, "695"], scopes="entrezgene", 
verbose=False) + assert len(qres) == 2 + + qres = gene_client.querymany([1017, "BTK"], scopes="entrezgene,symbol", verbose=False) + assert len(qres) >= 2 + + +def test_querymany_species(gene_client: MyGeneInfo): + qres = gene_client.querymany([1017, "695"], scopes="entrezgene", species="human", verbose=False) + assert len(qres) == 2 + + qres = gene_client.findgenes([1017, "695"], scopes="entrezgene", species=9606, verbose=False) + assert len(qres) == 2 + + qres = gene_client.findgenes([1017, "CDK2"], scopes="entrezgene,symbol", species=9606, verbose=False) + assert len(qres) == 2 + + +def test_querymany_fields(gene_client: MyGeneInfo): + qres1 = gene_client.findgenes( + [1017, "CDK2"], scopes="entrezgene,symbol", fields=["uniprot", "unigene"], species=9606, verbose=False + ) + assert len(qres1) == 2 + + qres2 = gene_client.findgenes( + "1017,CDK2", scopes="entrezgene,symbol".split(","), fields="uniprot,unigene", species=9606, verbose=False + ) + assert len(qres2) == 2 + + assert descore(qres1) == descore(qres2) + + +def test_querymany_notfound(gene_client: MyGeneInfo): + qres = gene_client.findgenes([1017, "695", "NA_TEST"], scopes="entrezgene", species=9606) + assert len(qres) == 3 + assert qres[2] == {"query": "NA_TEST", "notfound": True} + + +@pytest.mark.skipif(not biothings_client._PANDAS, reason="pandas library not installed") +def test_querymany_dataframe(gene_client: MyGeneInfo): + from pandas import DataFrame + + query_list1 = [ + "1007_s_at", + "1053_at", + "117_at", + "121_at", + "1255_g_at", + "1294_at", + "1316_at", + "1320_at", + "1405_i_at", + "1431_at", + ] + + qres = gene_client.querymany(query_list1, scopes="reporter", as_dataframe=True) + assert isinstance(qres, DataFrame) + assert "name" in qres.columns + assert set(query_list1) == set(qres.index) + + +def test_querymany_step(gene_client: MyGeneInfo): + query_list1 = [ + "1007_s_at", + "1053_at", + "117_at", + "121_at", + "1255_g_at", + "1294_at", + "1316_at", + "1320_at", + "1405_i_at", + "1431_at", + ] + qres1 = gene_client.querymany(query_list1, scopes="reporter") + default_step = gene_client.step + gene_client.step = 3 + qres2 = gene_client.querymany(query_list1, scopes="reporter") + gene_client.step = default_step + qres1 = descore(sorted(qres1, key=lambda doc: doc["_id"])) + qres2 = descore(sorted(qres2, key=lambda doc: doc["_id"])) + assert qres1 == qres2 + + +def test_get_fields(gene_client: MyGeneInfo): + fields = gene_client.get_fields() + assert "uniprot" in fields.keys() + assert "exons" in fields.keys() - # test getvariants POST caching - from_cache, first_getgenes_r = cache_request(_getgenes) - del first_getgenes_r - self.assertFalse(from_cache) - # should be from cache this time - from_cache, second_getgenes_r = cache_request(_getgenes) - del second_getgenes_r - self.assertTrue(from_cache) - - # test query GET caching - from_cache, first_query_r = cache_request(_query) - del first_query_r - self.assertFalse(from_cache) - # should be from cache this time - from_cache, second_query_r = cache_request(_query) - del second_query_r - self.assertTrue(from_cache) - - # test querymany POST caching - from_cache, first_querymany_r = cache_request(_querymany) - del first_querymany_r - self.assertFalse(from_cache) - # should be from cache this time - from_cache, second_querymany_r = cache_request(_querymany) - del second_querymany_r - self.assertTrue(from_cache) - - finally: - self.mg.stop_caching() - if os.path.exists(cache_file): - os.remove(cache_file) - - -def suite(): - return 
unittest.defaultTestLoader.loadTestsFromTestCase(TestGeneClient)
-
-
-if __name__ == "__main__":
-    unittest.TextTestRunner().run(suite())
+    fields = gene_client.get_fields("kegg")
+    assert "pathway.kegg" in fields.keys()
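Several assertions in these rewritten suites only compare results after passing them through
descore. The real helper ships in biothings_client.utils.score and is not shown in this diff; a
minimal sketch of the behavior the tests rely on (assumed, not the library's actual
implementation) would be:

    from typing import Any, Dict, List

    def descore(hits: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        # Drop the ephemeral Elasticsearch-style relevance score so two result
        # sets can be compared for document equality; the score can vary
        # slightly from request to request.
        for hit in hits:
            hit.pop("_score", None)
        return hits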
diff --git a/tests/test_geneset.py b/tests/test_geneset.py
index 3ff59ce..5a018b9 100644
--- a/tests/test_geneset.py
+++ b/tests/test_geneset.py
@@ -1,251 +1,163 @@
-import importlib.util
-import os
-import sys
+"""
+Tests for exercising the synchronous biothings_client for mygeneset
+"""
+
+import logging
 import types
-import unittest
 
-sys.path.insert(0, os.path.split(os.path.split(os.path.abspath(__file__))[0])[0])
+import pytest
 
-from biothings_client.utils.cache import cache_request
-from biothings_client.utils.score import descore
 import biothings_client
+from biothings_client.utils.score import descore
+from biothings_client.client.definitions import MyGenesetInfo
+
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+
+def test_metadata(geneset_client: MyGenesetInfo):
+    meta = geneset_client.metadata()
+    assert "src" in meta
+    assert "stats" in meta
+    assert "total" in meta["stats"]
+
+
+def test_getgeneset(geneset_client: MyGenesetInfo):
+    gs = geneset_client.getgeneset("WP100")
+    assert gs["_id"] == "WP100"
+    assert gs["name"] == "Glutathione metabolism"
+    assert gs["source"] == "wikipathways"
+    assert gs["taxid"] == "9606"
+    assert len(gs["genes"]) >= 19
+    assert gs["count"] == len(gs["genes"])
+
+    assert "wikipathways" in gs
+    assert gs["wikipathways"]["id"] == "WP100"
+    assert gs["wikipathways"]["pathway_name"] == "Glutathione metabolism"
+    assert gs["wikipathways"]["url"] == "https://www.wikipathways.org/instance/WP100"
+    assert gs["wikipathways"]["_license"] == "https://www.wikipathways.org/terms.html"
+
+    assert any((gene.get("mygene_id") == "2937" and gene.get("symbol") == "GSS") for gene in gs["genes"])
+
+
+def test_query_fetch_all(geneset_client: MyGenesetInfo):
+    """
+    pdb --> reactome
+    q = source:reactome
+    _exists_:pdb ---> source:reactome
+    """
+    qres = geneset_client.query("source:reactome")
+    total = qres["total"]
+
+    qres = geneset_client.query("source:reactome", fields="source,count,name", fetch_all=True)
+    assert isinstance(qres, types.GeneratorType)
+    assert total == len(list(qres))
+
+
+def test_query_with_fields_as_list(geneset_client: MyGenesetInfo):
+    qres1 = geneset_client.query("genes.ncbigene:1017", fields="name,source,taxid")
+    qres2 = geneset_client.query("genes.ncbigene:1017", fields=["name", "source", "taxid"])
+    assert "hits" in qres1
+    assert len(qres1["hits"]) == 10
+    assert descore(qres1["hits"]) == descore(qres2["hits"])
+
+
+def test_getgeneset_with_fields(geneset_client: MyGenesetInfo):
+    gs = geneset_client.getgeneset("WP100", fields="name,source,taxid,genes.mygene_id,genes.symbol")
+
+    assert "_id" in gs
+    assert "name" in gs
+    assert "source" in gs
+    assert "taxid" in gs
+
+    assert any((gene.get("mygene_id") == "2937" and gene.get("symbol") == "GSS") for gene in gs["genes"])
+    assert not any(gene.get("name") for gene in gs["genes"])
+
+
+def test_getgenesets(geneset_client: MyGenesetInfo):
+    gs_li = geneset_client.getgenesets(["WP100", "WP101", "WP103"])
+
+    assert len(gs_li) == 3
+    assert gs_li[0]["_id"] == "WP100"
+    assert gs_li[1]["_id"] == "WP101"
+    assert gs_li[2]["_id"] == "WP103"
+
+
+def test_query(geneset_client: MyGenesetInfo):
+    qres = geneset_client.query("genes.mygene_id:2937", size=5)
+    assert "hits" in qres
+    assert len(qres["hits"]) == 5
+
+
+def test_query_default_fields(geneset_client: MyGenesetInfo):
+    geneset_client.query(q="glucose")
+
+
+def test_query_field(geneset_client: MyGenesetInfo):
+    geneset_client.query(q="genes.ncbigene:1017")
+
+
+def test_species_filter_plus_query(geneset_client: MyGenesetInfo):
+    dog = geneset_client.query(q="glucose", species="9615")
+    assert dog["hits"][0]["taxid"] == "9615"
+
+
+def test_query_by_id(geneset_client: MyGenesetInfo):
+    query = geneset_client.query(q="_id:WP100")
+    assert query["hits"][0]["_id"] == "WP100"
+
+
+def test_query_by_name(geneset_client: MyGenesetInfo):
+    kinase = geneset_client.query(q="name:kinase")
+    assert "kinase" in kinase["hits"][0]["name"].lower()
+
+
+def test_query_by_description(geneset_client: MyGenesetInfo):
+    desc = geneset_client.query(q="description:cytosine deamination")
+    assert "cytosine" in desc["hits"][0]["description"].lower()
+    assert "deamination" in desc["hits"][0]["description"].lower()
+
+
+def test_query_by_source_go(geneset_client: MyGenesetInfo):
+    go = geneset_client.query(q="source:go", fields="all")
+    assert "go" in go["hits"][0].keys()
+    assert go["hits"][0]["source"] == "go"
+
+
+def test_query_by_source_ctd(geneset_client: MyGenesetInfo):
+    ctd = geneset_client.query(q="source:ctd", fields="all")
+    assert "ctd" in ctd["hits"][0].keys()
+    assert ctd["hits"][0]["source"] == "ctd"
+
+
+def test_query_by_source_msigdb(geneset_client: MyGenesetInfo):
+    msigdb = geneset_client.query(q="source:msigdb", fields="all")
+    assert "msigdb" in msigdb["hits"][0].keys()
+    assert msigdb["hits"][0]["source"] == "msigdb"
+
+
+@pytest.mark.xfail(reason="We removed kegg data source for now")
+def test_query_by_source_kegg(geneset_client: MyGenesetInfo):
+    kegg = geneset_client.query(q="source:kegg", fields="all")
+    assert "kegg" in kegg["hits"][0].keys()
+    assert kegg["hits"][0]["source"] == "kegg"
+
+
+def test_query_by_source_do(geneset_client: MyGenesetInfo):
+    do = geneset_client.query(q="source:do", fields="all")
+    assert "do" in do["hits"][0].keys()
+    assert do["hits"][0]["source"] == "do"
 
-sys.stdout.write(
-    '"biothings_client {0}" loaded from "{1}"\n'.format(biothings_client.__version__, biothings_client.__file__)
-)
-
-requests_cache_available = importlib.util.find_spec("requests_cache") is not None
 
+def test_query_by_source_reactome(geneset_client: MyGenesetInfo):
+    reactome = geneset_client.query(q="source:reactome", fields="all")
+    assert "reactome" in reactome["hits"][0].keys()
+    assert reactome["hits"][0]["source"] == "reactome"
 
-class TestGenesetClient(unittest.TestCase):
-    def setUp(self):
-        self.mgs = biothings_client.get_client("geneset")
-
-    def test_metadata(self):
-        meta = self.mgs.metadata()
-        self.assertTrue("src" in meta)
-        self.assertTrue("stats" in meta)
-        self.assertTrue("total" in meta["stats"])
 
-    def test_getgeneset(self):
-        gs = self.mgs.getgeneset("WP100")
-        self.assertEqual(gs["_id"], "WP100")
-        self.assertEqual(gs["name"], "Glutathione metabolism")
-        self.assertEqual(gs["source"], "wikipathways")
-        self.assertEqual(gs["taxid"], "9606")
-        self.assertGreaterEqual(len(gs["genes"]), 19)
-        self.assertEqual(gs["count"], len(gs["genes"]))
-
-        self.assertTrue("wikipathways" in gs)
-        self.assertEqual(gs["wikipathways"]["id"], "WP100")
-        self.assertEqual(gs["wikipathways"]["pathway_name"], "Glutathione metabolism")
-        self.assertEqual(gs["wikipathways"]["url"], "https://www.wikipathways.org/instance/WP100")
-        self.assertEqual(gs["wikipathways"]["_license"], "https://www.wikipathways.org/terms.html")
-
-        
self.assertTrue(any((gene.get("mygene_id") == "2937" and gene.get("symbol") == "GSS") for gene in gs["genes"])) - - def test_query_fetch_all(self): - # pdb-->reactome - # q=source:reactome - # _exists_:pdb ---> source:reactome - - # fetch_all won't work when caching is used. - self.mgs.stop_caching() - qres = self.mgs.query("source:reactome") - total = qres["total"] - - qres = self.mgs.query("source:reactome", fields="source,count,name", fetch_all=True) - self.assertTrue(isinstance(qres, types.GeneratorType)) - self.assertEqual(total, len(list(qres))) - - def test_query_with_fields_as_list(self): - qres1 = self.mgs.query("genes.ncbigene:1017", fields="name,source,taxid") - qres2 = self.mgs.query("genes.ncbigene:1017", fields=["name", "source", "taxid"]) - self.assertTrue("hits" in qres1) - self.assertEqual(len(qres1["hits"]), 10) - self.assertEqual(descore(qres1["hits"]), descore(qres2["hits"])) - - def test_getgeneset_with_fields(self): - gs = self.mgs.getgeneset("WP100", fields="name,source,taxid,genes.mygene_id,genes.symbol") - - self.assertTrue("_id" in gs) - self.assertTrue("name" in gs) - self.assertTrue("source" in gs) - self.assertTrue("taxid" in gs) - - self.assertTrue(any((gene.get("mygene_id") == "2937" and gene.get("symbol") == "GSS") for gene in gs["genes"])) - self.assertFalse(any(gene.get("name") for gene in gs["genes"])) - - def test_getgenesets(self): - gs_li = self.mgs.getgenesets(["WP100", "WP101", "WP103"]) - - self.assertEqual(len(gs_li), 3) - self.assertEqual(gs_li[0]["_id"], "WP100") - self.assertEqual(gs_li[1]["_id"], "WP101") - self.assertEqual(gs_li[2]["_id"], "WP103") - - def test_query(self): - qres = self.mgs.query("genes.mygene_id:2937", size=5) - self.assertTrue("hits" in qres) - self.assertEqual(len(qres["hits"]), 5) - - def test_query_default_fields(self): - self.mgs.query(q="glucose") - - def test_query_field(self): - self.mgs.query(q="genes.ncbigene:1017") - - def test_species_filter_plus_query(self): - dog = self.mgs.query(q="glucose", species="9615") - self.assertEqual(dog["hits"][0]["taxid"], "9615") - - def test_query_by_id(self): - query = self.mgs.query(q="_id:WP100") - self.assertEqual(query["hits"][0]["_id"], "WP100") - - def test_query_by_name(self): - kinase = self.mgs.query(q="name:kinase") - self.assertIn("kinase", kinase["hits"][0]["name"].lower()) - - def test_query_by_description(self): - desc = self.mgs.query(q="description:cytosine deamination") - self.assertIn("cytosine", desc["hits"][0]["description"].lower()) - self.assertIn("deamination", desc["hits"][0]["description"].lower()) - - def test_query_by_source_go(self): - go = self.mgs.query(q="source:go", fields="all") - self.assertIn("go", go["hits"][0].keys()) - self.assertEqual(go["hits"][0]["source"], "go") - - def test_query_by_source_ctd(self): - ctd = self.mgs.query(q="source:ctd", fields="all") - self.assertIn("ctd", ctd["hits"][0].keys()) - self.assertEqual(ctd["hits"][0]["source"], "ctd") - - def test_query_by_source_msigdb(self): - msigdb = self.mgs.query(q="source:msigdb", fields="all") - self.assertIn("msigdb", msigdb["hits"][0].keys()) - self.assertEqual(msigdb["hits"][0]["source"], "msigdb") - - @unittest.skip("We removed kegg data source for now") - def test_query_by_source_kegg(self): - kegg = self.mgs.query(q="source:kegg", fields="all") - self.assertIn("kegg", kegg["hits"][0].keys()) - self.assertEqual(kegg["hits"][0]["source"], "kegg") - - def test_query_by_source_do(self): - do = self.mgs.query(q="source:do", fields="all") - self.assertIn("do", do["hits"][0].keys()) 
- self.assertEqual(do["hits"][0]["source"], "do") - - def test_query_by_source_reactome(self): - reactome = self.mgs.query(q="source:reactome", fields="all") - self.assertIn("reactome", reactome["hits"][0].keys()) - self.assertEqual(reactome["hits"][0]["source"], "reactome") - - def test_query_by_source_smpdb(self): - smpdb = self.mgs.query(q="source:smpdb", fields="all") - self.assertIn("smpdb", smpdb["hits"][0].keys()) - self.assertEqual(smpdb["hits"][0]["source"], "smpdb") - - @unittest.skipIf(not requests_cache_available, "requests_cache not available") - def test_caching(self): - def _getgeneset(): - return self.mgs.getgeneset("WP100") - - def _getgenesets(): - return self.mgs.getgenesets(["WP100", "WP101"]) - - def _query(): - return self.mgs.query("wnt", fields="name,count,source,taxid") - - def _querymany(): - return self.mgs.querymany(["wnt", "jak-stat"], fields="name,count,source,taxid") - - try: - from_cache, pre_cache_r = cache_request(_getgeneset) - self.assertFalse(from_cache) - - cache_name = "mgsc" - cache_file = cache_name + ".sqlite" - - if os.path.exists(cache_file): - os.remove(cache_file) - self.mgs.set_caching(cache_name) - # populate cache - from_cache, cache_fill_r = cache_request(_getgeneset) - self.assertTrue(os.path.exists(cache_file)) - self.assertFalse(from_cache) - # is it from the cache? - from_cache, cached_r = cache_request(_getgeneset) - self.assertTrue(from_cache) - - self.mgs.stop_caching() - # same query should be live - not cached - from_cache, post_cache_r = cache_request(_getgeneset) - self.assertFalse(from_cache) - - self.mgs.set_caching(cache_name) - - # same query should still be sourced from cache - from_cache, recached_r = cache_request(_getgeneset) - self.assertTrue(from_cache) - - self.mgs.clear_cache() - # cache was cleared, same query should be live - from_cache, clear_cached_r = cache_request(_getgeneset) - self.assertFalse(from_cache) - - # all requests should be identical except the _score, which can vary slightly - for x in [pre_cache_r, cache_fill_r, cached_r, post_cache_r, recached_r, clear_cached_r]: - x.pop("_score", None) - - self.assertTrue( - all( - [ - x == pre_cache_r - for x in [pre_cache_r, cache_fill_r, cached_r, post_cache_r, recached_r, clear_cached_r] - ] - ) - ) - - # test getvariants POST caching - from_cache, first_getgenes_r = cache_request(_getgenesets) - del first_getgenes_r - self.assertFalse(from_cache) - # should be from cache this time - from_cache, second_getgenes_r = cache_request(_getgenesets) - del second_getgenes_r - self.assertTrue(from_cache) - - # test query GET caching - from_cache, first_query_r = cache_request(_query) - del first_query_r - self.assertFalse(from_cache) - # should be from cache this time - from_cache, second_query_r = cache_request(_query) - del second_query_r - self.assertTrue(from_cache) - - # test querymany POST caching - from_cache, first_querymany_r = cache_request(_querymany) - del first_querymany_r - self.assertFalse(from_cache) - # should be from cache this time - from_cache, second_querymany_r = cache_request(_querymany) - del second_querymany_r - self.assertTrue(from_cache) - - finally: - self.mgs.stop_caching() - if os.path.exists(cache_file): - os.remove(cache_file) - - -def suite(): - return unittest.defaultTestLoader.loadTestsFromTestCase(TestGenesetClient) - - -if __name__ == "__main__": - unittest.TextTestRunner().run(suite()) +def test_query_by_source_smpdb(geneset_client: MyGenesetInfo): + smpdb = geneset_client.query(q="source:smpdb", fields="all") + assert "smpdb" in 
smpdb["hits"][0].keys()
+    assert smpdb["hits"][0]["source"] == "smpdb"
diff --git a/tests/test_sync.py b/tests/test_sync.py
index 5e69147..8506742 100644
--- a/tests/test_sync.py
+++ b/tests/test_sync.py
@@ -1,51 +1,81 @@
-import os
-import sys
-import unittest
+"""
+Test suite for the sync client
+"""
 
-import biothings_client
-
-sys.path.insert(0, os.path.split(os.path.split(os.path.abspath(__file__))[0])[0])
-
-
-sys.stdout.write(
-    '"biothings_client {0}" loaded from "{1}"\n'.format(biothings_client.__version__, biothings_client.__file__)
-)
-
-
-class TestBiothingsClient(unittest.TestCase):
-    def setUp(self):
-        pass
-
-    def tearDown(self):
-        pass
-
-    def test_get_client(self):
-        gene_client = biothings_client.get_client("gene")
-        self.assertEqual(type(gene_client).__name__, "MyGeneInfo")
+from typing import List
 
-        variant_client = biothings_client.get_client("variant")
-        self.assertEqual(type(variant_client).__name__, "MyVariantInfo")
+import pytest
 
-        chem_client = biothings_client.get_client("chem")
-        self.assertEqual(type(chem_client).__name__, "MyChemInfo")
-
-        # drug_client as an alias of chem_client
-        drug_client = biothings_client.get_client("drug")
-        self.assertEqual(type(drug_client).__name__, "MyChemInfo")
-
-        disease_client = biothings_client.get_client("disease")
-        self.assertEqual(type(disease_client).__name__, "MyDiseaseInfo")
-
-        taxon_client = biothings_client.get_client("taxon")
-        self.assertEqual(type(taxon_client).__name__, "MyTaxonInfo")
-
-        geneset_client = biothings_client.get_client("geneset")
-        self.assertEqual(type(geneset_client).__name__, "MyGenesetInfo")
+import biothings_client
 
-        geneset_client = biothings_client.get_client(url="https://mygeneset.info/v1")
-        self.assertEqual(type(geneset_client).__name__, "MyGenesetInfo")
 
-    def test_generate_settings_from_url(self):
-        client_settings = biothings_client.client.base.generate_settings("geneset", url="https://mygeneset.info/v1")
-        self.assertEqual(client_settings["class_kwargs"]["_default_url"], "https://mygeneset.info/v1")
-        self.assertEqual(client_settings["class_name"], "MyGenesetInfo")
+@pytest.mark.parametrize(
+    "client_name,client_url,class_name",
+    [
+        (["gene"], "https://mygene.info/v3", "MyGeneInfo"),
+        (["variant"], "https://myvariant.info/v1", "MyVariantInfo"),
+        (["chem", "drug"], "https://mychem.info/v1", "MyChemInfo"),
+        (["disease"], "https://mydisease.info/v1", "MyDiseaseInfo"),
+        (["taxon"], "https://t.biothings.io/v1", "MyTaxonInfo"),
+        (["geneset"], "https://mygeneset.info/v1", "MyGenesetInfo"),
+    ],
+)
+def test_get_client(client_name: List[str], client_url: str, class_name: str):
+    """
+    Tests our ability to generate sync clients
+    """
+    client_name_instances = [biothings_client.get_client(name) for name in client_name]
+    client_url_instance = biothings_client.get_client(url=client_url)
+    clients = [client_url_instance, *client_name_instances]
+    for client in clients:
+        assert type(client).__name__ == class_name
+
+
+@pytest.mark.parametrize(
+    "client_name,client_url,class_name",
+    [
+        ("gene", "https://mygene.info/v3", "MyGeneInfo"),
+        ("variant", "https://myvariant.info/v1", "MyVariantInfo"),
+        ("chem", "https://mychem.info/v1", "MyChemInfo"),
+        ("disease", "https://mydisease.info/v1", "MyDiseaseInfo"),
+        ("taxon", "https://t.biothings.io/v1", "MyTaxonInfo"),
+        ("geneset", "https://mygeneset.info/v1", "MyGenesetInfo"),
+    ],
+)
+def test_generate_settings(client_name: str, client_url: str, class_name: str):
+    client_settings = biothings_client.client.base.generate_settings(client_name, url=client_url)
+    assert client_settings["class_kwargs"]["_default_url"] == client_url
+    assert client_settings["class_name"] == class_name
+
+
+@pytest.mark.parametrize(
+    "client_name",
+    (
+        "gene",
+        "variant",
+        "chem",
+        "disease",
+        "taxon",
+        "geneset",
+    ),
+)
+def test_url_protocol(client_name: str):
+    """
+    Tests that our HTTP protocol modification methods work
+    as expected when transforming to either HTTP or HTTPS
+    """
+    client_instance = biothings_client.get_client(client_name)
+
+    http_protocol = "http://"
+    https_protocol = "https://"
+
+    # DEFAULT: HTTPS
+    assert client_instance.url.startswith(https_protocol)
+
+    # Transform to HTTP
+    client_instance.use_http()
+    assert client_instance.url.startswith(http_protocol)
+
+    # Transform back to HTTPS
+    client_instance.use_https()
+    assert client_instance.url.startswith(https_protocol)
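The rewritten suites take gene_client, geneset_client, and variant_client as test arguments, so
they rely on pytest fixtures that this diff does not show, presumably defined in a shared
conftest.py. A plausible sketch (fixture names taken from the test signatures; session scope is
an assumption):

    # conftest.py (assumed; not part of this diff)
    import pytest

    import biothings_client

    @pytest.fixture(scope="session")
    def gene_client():
        # One shared MyGeneInfo client for the whole test session
        return biothings_client.get_client("gene")

    @pytest.fixture(scope="session")
    def geneset_client():
        return biothings_client.get_client("geneset")

    @pytest.fixture(scope="session")
    def variant_client():
        return biothings_client.get_client("variant")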
diff --git a/tests/test_variant.py b/tests/test_variant.py
index f17c654..024385d 100644
--- a/tests/test_variant.py
+++ b/tests/test_variant.py
@@ -1,322 +1,255 @@
-import importlib.util
-import os
-import sys
+"""
+Tests for exercising the synchronous biothings_client for myvariant
+"""
+
+import logging
 import types
-import unittest
 
-sys.path.insert(0, os.path.split(os.path.split(os.path.abspath(__file__))[0])[0])
+import pytest
 
-from biothings_client.utils.cache import cache_request
-from biothings_client.utils.score import descore
 import biothings_client
+from biothings_client.utils.score import descore
+from biothings_client.client.definitions import MyVariantInfo
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+
+def test_format_hgvs(variant_client: MyVariantInfo):
+    assert variant_client.format_hgvs("1", 35366, "C", "T") == "chr1:g.35366C>T"
+    assert variant_client.format_hgvs("chr2", 17142, "G", "GA") == "chr2:g.17142_17143insA"
+    assert variant_client.format_hgvs("1", 10019, "TA", "T") == "chr1:g.10020del"
+    assert variant_client.format_hgvs("MT", 8270, "CACCCCCTCT", "C") == "chrMT:g.8271_8279del"
+    assert variant_client.format_hgvs("7", 15903, "G", "GC") == "chr7:g.15903_15904insC"
+    assert variant_client.format_hgvs("X", 107930849, "GGA", "C") == "chrX:g.107930849_107930851delinsC"
+    assert variant_client.format_hgvs("20", 1234567, "GTC", "GTCT") == "chr20:g.1234569_1234570insT"
 
-sys.stderr.write(
-    '"biothings_client {0}" loaded from "{1}"\n'.format(biothings_client.__version__, biothings_client.__file__)
-)
-
-pandas_available = importlib.util.find_spec("pandas") is not None
-requests_cache_available = importlib.util.find_spec("requests_cache") is not None
-
-
-class TestVariantClient(unittest.TestCase):
-    def setUp(self):
-        self.mv = biothings_client.get_client("variant")
-        self.query_list1 = [
-            "chr1:g.866422C>T",
-            "chr1:g.876664G>A",
-            "chr1:g.69635G>C",
-            "chr1:g.69869T>A",
-            "chr1:g.881918G>A",
-            "chr1:g.865625G>A",
-            "chr1:g.69892T>C",
-            "chr1:g.879381C>T",
-            "chr1:g.878330C>G",
-        ]
-        self.query_list2 = [
-            "rs374802787",
-            "rs1433078",
-            "rs1433115",
-            "rs377266517",
-            "rs587640013",
-            "rs137857980",
-            "rs199710579",
-            "rs186823979",
-            # 'rs2276240',
-            "rs34521797",
-            "rs372452565",
-        ]
-
-    def test_format_hgvs(self):
-        self.assertEqual(self.mv.format_hgvs("1", 35366, "C", "T"), "chr1:g.35366C>T")
-        self.assertEqual(self.mv.format_hgvs("chr2", 17142, "G", "GA"), "chr2:g.17142_17143insA")
-        self.assertEqual(self.mv.format_hgvs("1", 10019, "TA", "T"), "chr1:g.10020del")
-        self.assertEqual(self.mv.format_hgvs("MT", 8270, "CACCCCCTCT", "C"), "chrMT:g.8271_8279del")
-        self.assertEqual(self.mv.format_hgvs("7", 
15903, "G", "GC"), "chr7:g.15903_15904insC") - self.assertEqual(self.mv.format_hgvs("X", 107930849, "GGA", "C"), "chrX:g.107930849_107930851delinsC") - self.assertEqual(self.mv.format_hgvs("20", 1234567, "GTC", "GTCT"), "chr20:g.1234569_1234570insT") - - def test_metadata(self): - meta = self.mv.metadata() - self.assertTrue("stats" in meta) - self.assertTrue("total" in meta["stats"]) - - def test_getvariant(self): - v = self.mv.getvariant("chr9:g.107620835G>A") - self.assertEqual(v["_id"], "chr9:g.107620835G>A") - self.assertEqual(v["snpeff"]["ann"]["genename"], "ABCA1") - - v = self.mv.getvariant("'chr1:g.1A>C'") # something does not exist - self.assertEqual(v, None) - - def test_getvariant_with_fields(self): - v = self.mv.getvariant("chr9:g.107620835G>A", fields="dbnsfp,cadd,cosmic") - self.assertTrue("_id" in v) - self.assertTrue("dbnsfp" in v) - self.assertTrue("cadd" in v) - self.assertTrue("cosmic" in v) - - def test_getvariants(self): - v_li = self.mv.getvariants(self.query_list1) - self.assertEqual(len(v_li), 9) - self.assertEqual(v_li[0]["_id"], self.query_list1[0]) - self.assertEqual(v_li[1]["_id"], self.query_list1[1]) - self.assertEqual(v_li[2]["_id"], self.query_list1[2]) - - self.mv.step = 4 - # test input is a string of comma-separated ids - v_li2 = self.mv.getvariants(",".join(self.query_list1)) - self.assertEqual(v_li, v_li2) - # test input is a tuple - v_li2 = self.mv.getvariants(tuple(self.query_list1)) - self.assertEqual(v_li, v_li2) - - # test input is a generator - def _input(li): - for x in li: - yield x - - v_li2 = self.mv.getvariants(_input(self.query_list1)) - self.assertEqual(v_li, v_li2) - self.mv.step = 1000 - - def test_query(self): - qres = self.mv.query("dbnsfp.genename:cdk2", size=5) - self.assertTrue("hits" in qres) - self.assertEqual(len(qres["hits"]), 5) - - def test_query_hgvs(self): - qres = self.mv.query('"NM_000048.4:c.566A>G"', size=5) - # should match clinvar.hgvs.coding field from variant "chr7:g.65551772A>G" - # sometime we need to update ".4" part if clinvar data updated. - self.assertTrue("hits" in qres) - self.assertEqual(len(qres["hits"]), 1) - - def test_query_rsid(self): - qres = self.mv.query("dbsnp.rsid:rs58991260") - self.assertTrue("hits" in qres) - self.assertEqual(len(qres["hits"]), 1) - self.assertEqual(qres["hits"][0]["_id"], "chr1:g.218631822G>A") - qres2 = self.mv.query("rs58991260") - # exclude _score field before comparison - qres["hits"][0].pop("_score") - qres2["hits"][0].pop("_score") - self.assertEqual(qres["hits"], qres2["hits"]) - - def test_query_symbol(self): - qres = self.mv.query("snpeff.ann.genename:cdk2") - self.assertTrue("hits" in qres) - self.assertTrue(qres["total"] > 5000) - self.assertEqual(qres["hits"][0]["snpeff"]["ann"][0]["genename"], "CDK2") - - def test_query_genomic_range(self): - qres = self.mv.query("chr1:69000-70000") - self.assertTrue("hits" in qres) - self.assertTrue(qres["total"] >= 3) - - def test_query_fetch_all(self): - # fetch_all won't work when caching is used. 
- self.mv.stop_caching() - qres = self.mv.query("chr1:69500-70000", fields="chrom") - total = qres["total"] - - qres = self.mv.query("chr1:69500-70000", fields="chrom", fetch_all=True) - self.assertTrue(isinstance(qres, types.GeneratorType)) - self.assertEqual(total, len(list(qres))) - - def test_querymany(self): - qres = self.mv.querymany(self.query_list1, verbose=False) - self.assertEqual(len(qres), 9) - - self.mv.step = 4 - # test input as a string - qres2 = self.mv.querymany(",".join(self.query_list1), verbose=False) - self.assertEqual(qres, qres2) - # test input as a tuple - qres2 = self.mv.querymany(tuple(self.query_list1), verbose=False) - self.assertEqual(qres, qres2) - # test input as a iterator - qres2 = self.mv.querymany(iter(self.query_list1), verbose=False) - self.assertEqual(qres, qres2) - self.mv.step = 1000 - - def test_querymany_with_scopes(self): - qres = self.mv.querymany(["rs58991260", "rs2500"], scopes="dbsnp.rsid", verbose=False) - self.assertEqual(len(qres), 2) - - qres = self.mv.querymany( - ["RCV000083620", "RCV000083611", "RCV000083584"], scopes="clinvar.rcv_accession", verbose=False - ) - self.assertEqual(len(qres), 3) - - qres = self.mv.querymany( - ["rs2500", "RCV000083611", "COSM1392449"], - scopes="clinvar.rcv_accession,dbsnp.rsid,cosmic.cosmic_id", - verbose=False, - ) - self.assertEqual(len(qres), 3) - - def test_querymany_fields(self): - ids = ["COSM1362966", "COSM990046", "COSM1392449"] - qres1 = self.mv.querymany( - ids, scopes="cosmic.cosmic_id", fields=["cosmic.tumor_site", "cosmic.cosmic_id"], verbose=False - ) - self.assertEqual(len(qres1), 3) - - qres2 = self.mv.querymany( - ids, scopes="cosmic.cosmic_id", fields="cosmic.tumor_site,cosmic.cosmic_id", verbose=False - ) - self.assertEqual(len(qres2), 3) - - self.assertEqual(descore(qres1), descore(qres2)) - - def test_querymany_notfound(self): - qres = self.mv.querymany(["rs58991260", "rs2500", "NA_TEST"], scopes="dbsnp.rsid", verbose=False) - self.assertEqual(len(qres), 3) - self.assertEqual(qres[2], {"query": "NA_TEST", "notfound": True}) - - @unittest.skipIf(not pandas_available, "pandas not available") - def test_querymany_dataframe(self): - from pandas import DataFrame - - qres = self.mv.querymany( - self.query_list2, scopes="dbsnp.rsid", fields="dbsnp", as_dataframe=True, verbose=False - ) - self.assertTrue(isinstance(qres, DataFrame)) - self.assertTrue("dbsnp.vartype" in qres.columns) - self.assertEqual(set(self.query_list2), set(qres.index)) - - def test_querymany_step(self): - qres1 = self.mv.querymany(self.query_list2, scopes="dbsnp.rsid", fields="dbsnp.rsid", verbose=False) - default_step = self.mv.step - self.mv.step = 3 - qres2 = self.mv.querymany(self.query_list2, scopes="dbsnp.rsid", fields="dbsnp.rsid", verbose=False) - self.mv.step = default_step - # self.assertEqual(qres1, qres2, (qres1, qres2)) - self.assertEqual(descore(qres1), descore(qres2)) - - def test_get_fields(self): - fields = self.mv.get_fields() - self.assertTrue("dbsnp.chrom" in fields.keys()) - self.assertTrue("clinvar.chrom" in fields.keys()) - - @unittest.skipIf(not requests_cache_available, "requests_cache not available") - def test_caching(self): - def _getvariant(): - return self.mv.getvariant("chr9:g.107620835G>A") - - def _getvariants(): - return self.mv.getvariants(["chr9:g.107620835G>A", "chr1:g.876664G>A"]) - - def _query(): - return self.mv.query("dbsnp.rsid:rs58991260") - - def _querymany(): - return self.mv.querymany(["rs58991260", "rs2500"], scopes="dbsnp.rsid") - - try: - from_cache, pre_cache_r = 
cache_request(_getvariant) - self.assertFalse(from_cache) - - cache_name = "mvc" - cache_file = cache_name + ".sqlite" - - if os.path.exists(cache_file): - os.remove(cache_file) - self.mv.set_caching(cache_name) - - # populate cache - from_cache, cache_fill_r = cache_request(_getvariant) - self.assertTrue(os.path.exists(cache_file)) - self.assertFalse(from_cache) - # is it from the cache? - from_cache, cached_r = cache_request(_getvariant) - self.assertTrue(from_cache) - - self.mv.stop_caching() - # same query should be live - not cached - from_cache, post_cache_r = cache_request(_getvariant) - self.assertFalse(from_cache) - - self.mv.set_caching(cache_name) - # same query should still be sourced from cache - from_cache, recached_r = cache_request(_getvariant) - self.assertTrue(from_cache) - - self.mv.clear_cache() - # cache was cleared, same query should be live - from_cache, clear_cached_r = cache_request(_getvariant) - self.assertFalse(from_cache) - - # all requests should be identical except their _score, which can vary slightly - for x in [pre_cache_r, cache_fill_r, cached_r, post_cache_r, recached_r, clear_cached_r]: - x.pop("_score", None) - - self.assertTrue( - all( - [ - x == pre_cache_r - for x in [pre_cache_r, cache_fill_r, cached_r, post_cache_r, recached_r, clear_cached_r] - ] - ) - ) - - # test getvariants POST caching - from_cache, first_getvariants_r = cache_request(_getvariants) - del first_getvariants_r - self.assertFalse(from_cache) - # should be from cache this time - from_cache, second_getvariants_r = cache_request(_getvariants) - del second_getvariants_r - self.assertTrue(from_cache) - - # test query GET caching - from_cache, first_query_r = cache_request(_query) - del first_query_r - self.assertFalse(from_cache) - # should be from cache this time - from_cache, second_query_r = cache_request(_query) - del second_query_r - self.assertTrue(from_cache) - - # test querymany POST caching - from_cache, first_querymany_r = cache_request(_querymany) - del first_querymany_r - self.assertFalse(from_cache) - # should be from cache this time - from_cache, second_querymany_r = cache_request(_querymany) - del second_querymany_r - self.assertTrue(from_cache) - - finally: - self.mv.stop_caching() - if os.path.exists(cache_file): - os.remove(cache_file) - - -def suite(): - return unittest.defaultTestLoader.loadTestsFromTestCase(TestVariantClient) - - -if __name__ == "__main__": - unittest.TextTestRunner().run(suite()) +def test_metadata(variant_client: MyVariantInfo): + meta = variant_client.metadata() + assert "stats" in meta + assert "total" in meta["stats"] + + +def test_getvariant(variant_client: MyVariantInfo): + v = variant_client.getvariant("chr9:g.107620835G>A") + assert v["_id"] == "chr9:g.107620835G>A" + assert v["snpeff"]["ann"]["genename"] == "ABCA1" + + v = variant_client.getvariant("'chr1:g.1A>C'") # something does not exist + assert v is None + + +def test_getvariant_with_fields(variant_client: MyVariantInfo): + v = variant_client.getvariant("chr9:g.107620835G>A", fields="dbnsfp,cadd,cosmic") + assert "_id" in v + assert "dbnsfp" in v + assert "cadd" in v + assert "cosmic" in v + + +def test_getvariants(variant_client: MyVariantInfo): + query_list1 = [ + "chr1:g.866422C>T", + "chr1:g.876664G>A", + "chr1:g.69635G>C", + "chr1:g.69869T>A", + "chr1:g.881918G>A", + "chr1:g.865625G>A", + "chr1:g.69892T>C", + "chr1:g.879381C>T", + "chr1:g.878330C>G", + ] + v_li = variant_client.getvariants(query_list1) + assert len(v_li) == 9 + assert v_li[0]["_id"] == query_list1[0] + assert 
v_li[1]["_id"] == query_list1[1]
+    assert v_li[2]["_id"] == query_list1[2]
+
+    variant_client.step = 4
+    # test input is a string of comma-separated ids
+    v_li2 = variant_client.getvariants(",".join(query_list1))
+    assert v_li == v_li2
+    # test input is a tuple
+    v_li2 = variant_client.getvariants(tuple(query_list1))
+    assert v_li == v_li2
+
+    # test input is a generator
+    def _input(li):
+        for x in li:
+            yield x
+
+    v_li2 = variant_client.getvariants(_input(query_list1))
+    assert v_li == v_li2
+    variant_client.step = 1000
+
+
+def test_query(variant_client: MyVariantInfo):
+    qres = variant_client.query("dbnsfp.genename:cdk2", size=5)
+    assert "hits" in qres
+    assert len(qres["hits"]) == 5
+
+
+def test_query_hgvs(variant_client: MyVariantInfo):
+    qres = variant_client.query('"NM_000048.4:c.566A>G"', size=5)
+    # should match clinvar.hgvs.coding field from variant "chr7:g.65551772A>G"
+    # sometimes the ".4" part needs updating when the ClinVar data is updated.
+    assert "hits" in qres
+    assert len(qres["hits"]) == 1
+
+
+def test_query_rsid(variant_client: MyVariantInfo):
+    qres = variant_client.query("dbsnp.rsid:rs58991260")
+    assert "hits" in qres
+    assert len(qres["hits"]) == 1
+    assert qres["hits"][0]["_id"] == "chr1:g.218631822G>A"
+    qres2 = variant_client.query("rs58991260")
+
+    # exclude _score field before comparison
+    qres["hits"][0].pop("_score")
+    qres2["hits"][0].pop("_score")
+    assert qres["hits"] == qres2["hits"]
+
+
+def test_query_symbol(variant_client: MyVariantInfo):
+    qres = variant_client.query("snpeff.ann.genename:cdk2")
+    assert "hits" in qres
+    assert qres["total"] > 5000
+    assert qres["hits"][0]["snpeff"]["ann"][0]["genename"] == "CDK2"
+
+
+def test_query_genomic_range(variant_client: MyVariantInfo):
+    qres = variant_client.query("chr1:69000-70000")
+    assert "hits" in qres
+    assert qres["total"] >= 3
+
+
+def test_query_fetch_all(variant_client: MyVariantInfo):
+    qres = variant_client.query("chr1:69500-70000", fields="chrom")
+    total = qres["total"]
+
+    qres = variant_client.query("chr1:69500-70000", fields="chrom", fetch_all=True)
+    assert isinstance(qres, types.GeneratorType)
+    assert total == len(list(qres))
+
+
+def test_querymany(variant_client: MyVariantInfo):
+    query_list1 = [
+        "chr1:g.866422C>T",
+        "chr1:g.876664G>A",
+        "chr1:g.69635G>C",
+        "chr1:g.69869T>A",
+        "chr1:g.881918G>A",
+        "chr1:g.865625G>A",
+        "chr1:g.69892T>C",
+        "chr1:g.879381C>T",
+        "chr1:g.878330C>G",
+    ]
+    qres = variant_client.querymany(query_list1, verbose=False)
+    assert len(qres) == 9
+
+    variant_client.step = 4
+
+    # test input as a string
+    qres2 = variant_client.querymany(",".join(query_list1), verbose=False)
+    assert qres == qres2
+
+    # test input as a tuple
+    qres2 = variant_client.querymany(tuple(query_list1), verbose=False)
+    assert qres == qres2
+
+    # test input as an iterator
+    qres2 = variant_client.querymany(iter(query_list1), verbose=False)
+    assert qres == qres2
+    variant_client.step = 1000
+
+
+def test_querymany_with_scopes(variant_client: MyVariantInfo):
+    qres = variant_client.querymany(["rs58991260", "rs2500"], scopes="dbsnp.rsid", verbose=False)
+    assert len(qres) == 2
+
+    qres = variant_client.querymany(
+        ["RCV000083620", "RCV000083611", "RCV000083584"], scopes="clinvar.rcv_accession", verbose=False
+    )
+    assert len(qres) == 3
+
+    qres = variant_client.querymany(
+        ["rs2500", "RCV000083611", "COSM1392449"],
+        scopes="clinvar.rcv_accession,dbsnp.rsid,cosmic.cosmic_id",
+        verbose=False,
+    )
+    assert len(qres) == 3
+
+
+def test_querymany_fields(variant_client: MyVariantInfo):
+    ids = ["COSM1362966", "COSM990046", "COSM1392449"]
+    qres1 = variant_client.querymany(
+        ids, scopes="cosmic.cosmic_id", fields=["cosmic.tumor_site", "cosmic.cosmic_id"], verbose=False
+    )
+    assert len(qres1) == 3
+
+    qres2 = variant_client.querymany(
+        ids, scopes="cosmic.cosmic_id", fields="cosmic.tumor_site,cosmic.cosmic_id", verbose=False
+    )
+    assert len(qres2) == 3
+
+    assert descore(qres1) == descore(qres2)
+
+
+def test_querymany_notfound(variant_client: MyVariantInfo):
+    qres = variant_client.querymany(["rs58991260", "rs2500", "NA_TEST"], scopes="dbsnp.rsid", verbose=False)
+    assert len(qres) == 3
+    assert qres[2] == {"query": "NA_TEST", "notfound": True}
+
+
+@pytest.mark.skipif(not biothings_client._PANDAS, reason="pandas library not installed")
+def test_querymany_dataframe(variant_client: MyVariantInfo):
+    from pandas import DataFrame
+
+    query_list2 = [
+        "rs374802787",
+        "rs1433078",
+        "rs1433115",
+        "rs377266517",
+        "rs587640013",
+        "rs137857980",
+        "rs199710579",
+        "rs186823979",
+        # 'rs2276240',
+        "rs34521797",
+        "rs372452565",
+    ]
+
+    qres = variant_client.querymany(query_list2, scopes="dbsnp.rsid", fields="dbsnp", as_dataframe=True, verbose=False)
+    assert isinstance(qres, DataFrame)
+    assert "dbsnp.vartype" in qres.columns
+    assert set(query_list2) == set(qres.index)
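test_querymany_step below shrinks the client's step attribute, which caps how many terms go into
each POST batch, and restores it afterwards; if the restore should also survive a failing
assertion, a small context manager would do it (hypothetical refactor, not part of this diff):

    from contextlib import contextmanager

    @contextmanager
    def client_step(client, step):
        # Temporarily override the batch size, restoring the previous value
        # even when the body raises.
        previous = client.step
        client.step = step
        try:
            yield client
        finally:
            client.step = previous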
+
+
+def test_querymany_step(variant_client: MyVariantInfo):
+    query_list2 = [
+        "rs374802787",
+        "rs1433078",
+        "rs1433115",
+        "rs377266517",
+        "rs587640013",
+        "rs137857980",
+        "rs199710579",
+        "rs186823979",
+        # 'rs2276240',
+        "rs34521797",
+        "rs372452565",
+    ]
+    qres1 = variant_client.querymany(query_list2, scopes="dbsnp.rsid", fields="dbsnp.rsid", verbose=False)
+    default_step = variant_client.step
+    variant_client.step = 3
+    qres2 = variant_client.querymany(query_list2, scopes="dbsnp.rsid", fields="dbsnp.rsid", verbose=False)
+    variant_client.step = default_step
+    # assert qres1 == qres2, (qres1, qres2)
+    assert descore(qres1) == descore(qres2)
+
+
+def test_get_fields(variant_client: MyVariantInfo):
+    fields = variant_client.get_fields()
+    assert "dbsnp.chrom" in fields.keys()
+    assert "clinvar.chrom" in fields.keys()
diff --git a/tox.ini b/tox.ini
deleted file mode 100644
index 2023482..0000000
--- a/tox.ini
+++ /dev/null
@@ -1,36 +0,0 @@
-[tox]
-requires =
-    tox>=4
-env_list = py{27,34,35,36,36-pandas,36-caching,37,38,39,310,311}
-
-[testenv:py36-pandas]
-deps=
-    pytest
-    pandas
-commands=
-    pytest -v {tox_root}/tests/gene.py::TestGeneClient::test_querymany_dataframe
-    pytest -v {tox_root}/tests/variant.py::TestVariantClient::test_querymany_dataframe
-
-[testenv:py36-caching]
-deps=
-    pytest
-    requests_cache
-commands=
-    pytest -v {tox_root}/tests/gene.py::TestGeneClient::test_caching
-    pytest -v {tox_root}/tests/variant.py::TestVariantClient::test_caching
-
-[testenv:py311]
-description = python3.11 pytest test suite execution
-deps=
-    pandas
-    pytest
-    requests_cache
-commands=
-    pytest -v \
-        "{tox_root}{/}tests{/}chem.py" \
-        "{tox_root}{/}tests{/}gene.py" \
-        "{tox_root}{/}tests{/}geneset.py" \
-        "{tox_root}{/}tests{/}test.py" \
-        "{tox_root}{/}tests{/}utils.py" \
-        "{tox_root}{/}tests{/}variant.py" \
-        {posargs}
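With tox.ini removed, optional-dependency gating now happens inside the tests themselves, as in
the _PANDAS skipif markers above. The same pattern would extend to a caching extra, assuming the
package exposes a companion _CACHING flag alongside _PANDAS (the marker names below are
illustrative, not part of this diff):

    import pytest

    import biothings_client

    # Reusable markers for tests that need an optional extra installed.
    requires_pandas = pytest.mark.skipif(not biothings_client._PANDAS, reason="pandas library not installed")
    requires_caching = pytest.mark.skipif(not biothings_client._CACHING, reason="caching backend not installed")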