Skip to content

Commit

Permalink
Interim work for async testing
Browse files Browse the repository at this point in the history
  • Loading branch information
jschaff committed Oct 4, 2024
1 parent 99ed0f9 commit 72e8319
Show file tree
Hide file tree
Showing 7 changed files with 483 additions and 68 deletions.
57 changes: 26 additions & 31 deletions biothings_client/client/asynchronous.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ async def _repeated_query(self, query_fn, query_li, verbose=True, **fn_kwargs):
if verbose:
logger.info("querying {0}-{1}...".format(i + 1, cnt))
i = cnt
from_cache, query_result = query_fn(batch, **fn_kwargs)
from_cache, query_result = await query_fn(batch, **fn_kwargs)
yield query_result
if verbose:
cache_str = " {0}".format(self._from_cache_notification) if from_cache else ""
Expand All @@ -226,7 +226,7 @@ async def _metadata(self, verbose=True, **kwargs):
Return a dictionary of Biothing metadata.
"""
_url = self.url + self._metadata_endpoint
from_cache, ret = self._get(_url, params=kwargs, verbose=verbose)
from_cache, ret = await self._get(_url, params=kwargs, verbose=verbose)
if verbose and from_cache:
logger.info(self._from_cache_notification)
return ret
Expand All @@ -246,7 +246,7 @@ async def _get_fields(self, search_term=None, verbose=True):
params = {"search": search_term}
else:
params = {}
from_cache, ret = self._get(_url, params=params, verbose=verbose)
from_cache, ret = await self._get(_url, params=params, verbose=verbose)
for k, v in ret.items():
del k
# Get rid of the notes column information
Expand All @@ -270,9 +270,9 @@ async def _getannotation(self, _id, fields=None, **kwargs):
verbose = kwargs.pop("verbose", True)
if fields:
kwargs["fields"] = fields
kwargs = self._handle_common_kwargs(kwargs)
kwargs = await self._handle_common_kwargs(kwargs)
_url = self.url + self._annotation_endpoint + str(_id)
from_cache, ret = self._get(_url, kwargs, none_on_404=True, verbose=verbose)
from_cache, ret = await self._get(_url, kwargs, none_on_404=True, verbose=verbose)
if verbose and from_cache:
logger.info(self._from_cache_notification)
return ret
Expand All @@ -288,8 +288,8 @@ async def _annotations_generator(self, query_fn, ids, verbose=True, **kwargs):
"""
Function to yield a batch of hits one at a time.
"""
async for hits in await self._repeated_query(query_fn, ids, verbose=verbose):
async for hit in hits:
for hits in self._repeated_query(query_fn, ids, verbose=verbose):
for hit in hits:
yield hit

async def _getannotations(self, ids, fields=None, **kwargs):
Expand Down Expand Up @@ -325,7 +325,7 @@ async def _getannotations(self, ids, fields=None, **kwargs):
raise ValueError('input "ids" must be a list, tuple or iterable.')
if fields:
kwargs["fields"] = fields
kwargs = self._handle_common_kwargs(kwargs)
kwargs = await self._handle_common_kwargs(kwargs)
verbose = kwargs.pop("verbose", True)
dataframe = kwargs.pop("as_dataframe", None)
df_index = kwargs.pop("df_index", True)
Expand All @@ -338,21 +338,21 @@ async def _getannotations(self, ids, fields=None, **kwargs):
if return_raw:
dataframe = None

def query_fn(ids):
return self._getannotations_inner(ids, verbose=verbose, **kwargs)
async def query_fn(ids):
return await self._getannotations_inner(ids, verbose=verbose, **kwargs)

if generator:
return self._annotations_generator(query_fn, ids, verbose=verbose, **kwargs)
out = []
for hits in self._repeated_query(query_fn, ids, verbose=verbose):
async for hits in self._repeated_query(query_fn, ids, verbose=verbose):
if return_raw:
out.append(hits) # hits is the raw response text
else:
out.extend(hits)
if return_raw and len(out) == 1:
out = out[0]
if dataframe:
out = self._dataframe(out, dataframe, df_index=df_index)
out = await self._dataframe(out, dataframe, df_index=df_index)
return out

async def _query(self, q: str, **kwargs):
Expand Down Expand Up @@ -388,7 +388,7 @@ async def _query(self, q: str, **kwargs):
"""
_url = self.url + self._query_endpoint
verbose = kwargs.pop("verbose", True)
kwargs = self._handle_common_kwargs(kwargs)
kwargs = await self._handle_common_kwargs(kwargs)
kwargs.update({"q": q})
fetch_all = kwargs.get("fetch_all")
if fetch_all in [True, 1]:
Expand All @@ -397,30 +397,24 @@ async def _query(self, q: str, **kwargs):
"Ignored 'as_dataframe' because 'fetch_all' is specified. "
"Too many documents to return as a Dataframe."
)
return self._fetch_all(url=_url, verbose=verbose, **kwargs)
return await self._fetch_all(url=_url, verbose=verbose, **kwargs)
dataframe = kwargs.pop("as_dataframe", None)
if dataframe in [True, 1]:
dataframe = 1
elif dataframe != 2:
dataframe = None
from_cache, out = self._get(_url, kwargs, verbose=verbose)
from_cache, out = await self._get(_url, kwargs, verbose=verbose)
if verbose and from_cache:
logger.info(self._from_cache_notification)
if dataframe:
out = self._dataframe(out, dataframe, df_index=False)
out = await self._dataframe(out, dataframe, df_index=False)
return out

async def _fetch_all(self, url, verbose=True, **kwargs):
"""
Function that returns a generator to results. Assumes that 'q' is in kwargs.
"""

# function to get the next batch of results, automatically disables cache if we are caching
async def _batch():
from_cache, ret = self._get(url, params=kwargs, verbose=verbose)
return ret

batch = _batch()
from_cache, batch = await self._get(url, params=kwargs, verbose=verbose)
if verbose:
logger.info("Fetching {0} {1} . . .".format(batch["total"], self._optionally_plural_object_type))
for key in ["q", "fetch_all"]:
Expand All @@ -434,10 +428,11 @@ async def _batch():
for hit in batch["hits"]:
yield hit
kwargs.update({"scroll_id": batch["_scroll_id"]})
batch = _batch()
from_cache, batch = await self._get(url, params=kwargs, verbose=verbose)

async def _querymany_inner(self, qterms, verbose=True, **kwargs):
_kwargs = {"q": self._format_list(qterms)}
query_term_collection = await self._format_list(qterms)
_kwargs = {"q": query_term_collection}
_kwargs.update(kwargs)
_url = self.url + self._query_endpoint
return await self._post(_url, params=_kwargs, verbose=verbose)
Expand Down Expand Up @@ -474,8 +469,8 @@ async def _querymany(self, qterms, scopes=None, **kwargs):
raise ValueError('input "qterms" must be a list, tuple or iterable.')

if scopes:
kwargs["scopes"] = self._format_list(scopes, quoted=False)
kwargs = self._handle_common_kwargs(kwargs)
kwargs["scopes"] = await self._format_list(scopes, quoted=False)
kwargs = await self._handle_common_kwargs(kwargs)
returnall = kwargs.pop("returnall", False)
verbose = kwargs.pop("verbose", True)
dataframe = kwargs.pop("as_dataframe", None)
Expand All @@ -494,9 +489,9 @@ async def _querymany(self, qterms, scopes=None, **kwargs):
li_query = []

async def query_fn(qterms):
return self._querymany_inner(qterms, verbose=verbose, **kwargs)
return await self._querymany_inner(qterms, verbose=verbose, **kwargs)

for hits in await self._repeated_query(query_fn, qterms, verbose=verbose):
async for hits in self._repeated_query(query_fn, qterms, verbose=verbose):
if return_raw:
out.append(hits) # hits is the raw response text
else:
Expand All @@ -520,7 +515,7 @@ async def query_fn(qterms):
del li_query

if dataframe:
out = self._dataframe(out, dataframe, df_index=df_index)
out = await self._dataframe(out, dataframe, df_index=df_index)
li_dup_df = DataFrame.from_records(li_dup, columns=["query", "duplicate hits"])
li_missing_df = DataFrame(li_missing, columns=["query"])

Expand Down Expand Up @@ -705,7 +700,7 @@ def _pluralize(s, optional=True):
}
)
return {
"class_name": "My" + biothing_type.title() + "Info",
"class_name": "Async" + "My" + biothing_type.title() + "Info",
"class_kwargs": _kwargs,
"mixins": [],
"attr_aliases": _aliases,
Expand Down
46 changes: 46 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ jsonld = ["PyLD>=0.7.2"]
async = [
"hishel>=0.0.31",
]
tests = [
"pytest>=8.3.2",
"pytest-asyncio>=0.23.8"
]



Expand All @@ -82,3 +86,45 @@ profile = "black"
combine_as_imports = true
line_length = 120
src_paths = ["."]

# pytest configuration
[tool.pytest.ini_options]
minversion = "6.2.5"
pythonpath = ["."]


# Options
addopts = [
    "-rA",
    "-vv",
    "--doctest-modules",
    "--setup-show",
    "--capture=no",
    "--tb=line",
    "--durations=0",
    "--showlocals",
    "--strict-markers",
    "--color=yes",
    "--code-highlight=yes"
]

# Path
norecursedirs = [
    ".svn",
    ".git",
    "_build",
    "tmp*",
    "lib",
    "lib64",
]
# The test modules (conftest.py, test.py, test_async.py) live in the "tests" directory
testpaths = ["tests"]

markers = [
    "unit",
]

# Logging
log_cli = true
log_cli_level = "DEBUG"
# Every %(name) placeholder needs a conversion type; "%(levelname)" without the trailing "s" is a malformed format
log_cli_format = "%(asctime)s [%(levelname)s] %(message)s"
log_cli_date_format = "%Y-%m-%d %H:%M:%S"
18 changes: 18 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
"""
Client fixtures for usage across the biothings_client testing
"""

import pytest

from biothings_client import get_async_client
from biothings_client.client.definitions import AsyncMyGeneInfo


@pytest.fixture(scope="session")
def async_gene_client() -> AsyncMyGeneInfo:
    """
    Session-scoped fixture supplying an asynchronous client for the "gene" biothing type.

    The instance is obtained from ``get_async_client`` and returned to the tests unchanged.
    """
    return get_async_client("gene")
2 changes: 1 addition & 1 deletion tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def test_get_client(self):
self.assertEqual(type(geneset_client).__name__, "MyGenesetInfo")

def test_generate_settings_from_url(self):
client_settings = biothings_client._generate_settings("geneset", url="https://mygeneset.info/v1")
client_settings = biothings_client.client.base.generate_settings("geneset", url="https://mygeneset.info/v1")
self.assertEqual(client_settings["class_kwargs"]["_default_url"], "https://mygeneset.info/v1")
self.assertEqual(client_settings["class_name"], "MyGenesetInfo")

Expand Down
84 changes: 84 additions & 0 deletions tests/test_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""
Test suite for the async client
"""

from typing import List

import pytest

import biothings_client


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "client_name,client_url,class_name",
    [
        (["gene"], "https://mygene.info/v3", "AsyncMyGeneInfo"),
        (["variant"], "https://myvariant.info/v1", "AsyncMyVariantInfo"),
        (["chem", "drug"], "https://mychem.info/v1", "AsyncMyChemInfo"),
        (["disease"], "https://mydisease.info/v1", "AsyncMyDiseaseInfo"),
        (["taxon"], "https://t.biothings.info/v1", "AsyncMyTaxonInfo"),
        (["geneset"], "https://mygeneset.info/v1", "AsyncMyGenesetInfo"),
    ],
)
async def test_get_async_client(client_name: List[str], client_url: str, class_name: str):
    """
    Verify async client creation both by name and by URL.

    Every alias listed in ``client_name`` and the service ``client_url`` must each
    yield a client instance whose class is named ``class_name``.
    """
    # Build the by-name clients first, then the by-url client
    built_from_names = []
    for alias in client_name:
        built_from_names.append(biothings_client.get_async_client(alias))
    built_from_url = biothings_client.get_async_client(url=client_url)

    # The url-derived client is checked first, followed by each name-derived client
    for candidate in (built_from_url, *built_from_names):
        assert type(candidate).__name__ == class_name


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "client_name,client_url,class_name",
    [
        ("gene", "https://mygene.info/v3", "AsyncMyGeneInfo"),
        # variant is paired with its own service URL (previously the mychem URL by copy/paste)
        ("variant", "https://myvariant.info/v1", "AsyncMyVariantInfo"),
        ("chem", "https://mychem.info/v1", "AsyncMyChemInfo"),
        ("disease", "https://mydisease.info/v1", "AsyncMyDiseaseInfo"),
        ("taxon", "https://t.biothings.info/v1", "AsyncMyTaxonInfo"),
        ("geneset", "https://mygeneset.info/v1", "AsyncMyGenesetInfo"),
    ],
)
async def test_generate_async_settings(client_name: str, client_url: str, class_name: str):
    """
    Tests the settings generated for an async client from a biothing type and url

    The generated settings must carry the supplied url as the ``_default_url``
    class keyword argument and must name the class ``class_name``.
    """
    client_settings = biothings_client.client.asynchronous.generate_async_settings(client_name, url=client_url)
    assert client_settings["class_kwargs"]["_default_url"] == client_url
    assert client_settings["class_name"] == class_name


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "client_name",
    (
        "gene",
        "variant",
        "chem",
        "disease",
        "taxon",
        "geneset",
    ),
)
async def test_url_protocol(client_name: str):
    """
    Tests that our HTTP protocol modification methods work
    as expected when transforming to either HTTP or HTTPS
    """
    client_instance = biothings_client.get_async_client(client_name)

    http_protocol = "http://"
    https_protocol = "https://"

    # DEFAULT: HTTPS
    assert client_instance.url.startswith(https_protocol)

    # Transform to HTTP
    await client_instance.use_http()
    assert client_instance.url.startswith(http_protocol)

    # Transform back to HTTPS
    await client_instance.use_https()
    # This check must be asserted; a bare expression would be evaluated and silently discarded
    assert client_instance.url.startswith(https_protocol)
36 changes: 0 additions & 36 deletions tests/test_async_client.py

This file was deleted.

Loading

0 comments on commit 72e8319

Please sign in to comment.