Run mypy through tox and address type check errors (#202)
caufieldjh authored Sep 19, 2023
2 parents e9cb6ff + 92d370e commit ba40fee
Showing 45 changed files with 609 additions and 438 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/qc.yml
@@ -33,8 +33,8 @@ jobs:
         run: poetry run tox -e flake8
       # - name: Check package metadata with Pyroma
       #   run: poetry run tox -e pyroma
-      # - name: Check static typing with MyPy
-      #   run: poetry run tox -e mypy
+      - name: Check static typing with MyPy
+        run: poetry run tox -e mypy
 
       - name: Unit tests only
         run: poetry run python -m unittest discover tests.unit
2 changes: 1 addition & 1 deletion src/ontogpt/__init__.py
@@ -8,7 +8,7 @@
 rel_path = Path(__file__).resolve()
 models_path = rel_path.parent / "models.yaml"
 
-with open(models_path, 'r') as models_file:
+with open(models_path, "r") as models_file:
     MODELS = (safe_load(models_file))["models"]
     for model in MODELS:
         if "is_default" in model:
35 changes: 19 additions & 16 deletions src/ontogpt/cli.py
@@ -8,7 +8,7 @@
 from dataclasses import dataclass
 from io import BytesIO, TextIOWrapper
 from pathlib import Path
-from typing import List, Optional
+from typing import List, Optional, Union
 
 import click
 import jsonlines
@@ -31,8 +31,8 @@
 from ontogpt.engines.embedding_similarity_engine import SimilarityEngine
 from ontogpt.engines.enrichment import EnrichmentEngine
 from ontogpt.engines.generic_engine import GenericEngine, QuestionCollection
-from ontogpt.engines.gpt4all_engine import GPT4AllEngine
-from ontogpt.engines.halo_engine import HALOEngine
+from ontogpt.engines.gpt4all_engine import GPT4AllEngine  # type: ignore
+from ontogpt.engines.halo_engine import HALOEngine  # type: ignore
 
 # from ontogpt.engines.hfhub_engine import HFHubEngine
 from ontogpt.engines.knowledge_engine import KnowledgeEngine
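A note on the `# type: ignore` comments added above: when an imported module ships without annotations, stubs, or a py.typed marker, mypy reports an error on the import line itself, and the comment suppresses it while typing the imported names as Any. A minimal sketch of the pattern, where `untyped_package` is a hypothetical stand-in:

    # `untyped_package` stands in for any dependency without type stubs.
    from untyped_package import helper  # type: ignore

    # mypy now treats `helper` as Any, so this call type-checks even
    # though the checker knows nothing about the real signature.
    result = helper("some input")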
@@ -43,8 +43,7 @@
 from ontogpt.engines.synonym_engine import SynonymEngine
 from ontogpt.evaluation.enrichment.eval_enrichment import EvalEnrichment
 from ontogpt.evaluation.resolver import create_evaluator
-from ontogpt.io.csv_wrapper import write_obj_as_csv
-from ontogpt.io.csv_wrapper import output_parser
+from ontogpt.io.csv_wrapper import output_parser, write_obj_as_csv
 from ontogpt.io.html_exporter import HTMLExporter
 from ontogpt.io.markdown_exporter import MarkdownExporter
 from ontogpt.utils.gene_set_utils import (
@@ -86,12 +85,13 @@ def _as_text_writer(f):
 def write_extraction(
     results: ExtractionResult,
     output: BytesIO,
-    output_format: str = None,
-    knowledge_engine: KnowledgeEngine = None,
+    output_format: str,
+    knowledge_engine: KnowledgeEngine,
 ):
     """Write results of extraction to a given output stream."""
     # Check if this result contains anything writable first
     if results.extracted_object:
+        exporter: Union[MarkdownExporter, HTMLExporter, RDFExporter, OWLExporter]
         if output_format == "pickle":
             output.write(pickle.dumps(results))
         elif output_format == "md":
@@ -100,11 +100,11 @@ def write_extraction(
             exporter.export(results, output)
         elif output_format == "html":
             output = _as_text_writer(output)
-            exporter = HTMLExporter()
+            exporter = HTMLExporter(output=output)
             exporter.export(results, output)
         elif output_format == "yaml":
             output = _as_text_writer(output)
-            output.write(dump_minimal_yaml(results))
+            output.write(dump_minimal_yaml(results).encode("utf-8"))
         elif output_format == "turtle":
             output = _as_text_writer(output)
             exporter = RDFExporter()
@@ -117,11 +117,13 @@ def write_extraction(
             # output = _as_text_writer(output)
             # output.write(write_obj_as_csv(results))
             output = _as_text_writer(output)
-            output.write(dump_minimal_yaml(results))
-            output.write(output_parser(results))
+            output.write(dump_minimal_yaml(results).encode("utf-8"))
+            with open("output.kgx.tsv") as secondoutput:
+                for line in output_parser(obj=results, file=output):
+                    secondoutput.write(line)
     else:
         output = _as_text_writer(output)
-        output.write(dump_minimal_yaml(results))
+        output.write(dump_minimal_yaml(results).encode("utf-8"))
 
 
 def get_model_by_name(modelname: str):
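The new `exporter` annotation in write_extraction illustrates a pattern mypy enforces: a variable assigned in several branches needs one declared type covering every assignment, otherwise the checker pins it to the type from the first branch and flags the rest. Declaring a Union up front keeps each branch valid. A self-contained sketch of the same idea (the writer classes are hypothetical):

    from typing import Union

    class CsvWriter:
        def export(self, data: str) -> None:
            print("csv:", data)

    class JsonWriter:
        def export(self, data: str) -> None:
            print("json:", data)

    def write(data: str, fmt: str) -> None:
        # Without the annotation below, mypy infers `writer` from the
        # first branch (CsvWriter) and flags the JsonWriter assignment.
        writer: Union[CsvWriter, JsonWriter]
        if fmt == "csv":
            writer = CsvWriter()
        else:
            writer = JsonWriter()
        writer.export(data)

    write("a,b,c", "csv")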
@@ -1477,16 +1479,17 @@ def fill(model, template, object: str, examples, output, output_format, show_pro
     """Fill in missing values."""
     logging.info(f"Creating for {template}")
 
+    ke: KnowledgeEngine
+
     # Choose model based on input, or use the default
     if not model:
         model = DEFAULT_MODEL
     selectmodel = get_model_by_name(model)
     model_source = selectmodel["provider"]
 
     if model_source == "OpenAI":
-        ke = SPIRESEngine(template, **kwargs)
-
-    elif model_source == "GPT4All":
+        ke = SPIRESEngine(template=template, **kwargs)
+    else:
         model_name = selectmodel["alternative_names"][0]
         ke = GPT4AllEngine(template=template, model=model_name, **kwargs)
 
@@ -1495,7 +1498,7 @@ def fill(model, template, object: str, examples, output, output_format, show_pro
         logging.info(f"Loading {examples}")
         examples = yaml.safe_load(examples)
     logging.debug(f"Input object: {object}")
-    results = ke.generalize(object, examples, show_prompt)
+    results = ke.generalize(object=object, examples=examples, show_prompt=show_prompt)
 
     output.write(yaml.dump(results.dict()))
 
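The `ke: KnowledgeEngine` declaration added to `fill` is the same technique with a shared base class instead of a Union: `SPIRESEngine` and `GPT4AllEngine` presumably both derive from `KnowledgeEngine`, so annotating `ke` once lets either branch assign to it. A minimal sketch under that assumption (all classes hypothetical):

    class Engine:
        def run(self) -> str:
            return "base"

    class LocalEngine(Engine):
        def run(self) -> str:
            return "local"

    class RemoteEngine(Engine):
        def run(self) -> str:
            return "remote"

    def build_and_run(use_local: bool) -> str:
        engine: Engine  # one annotation covers both subclass assignments
        if use_local:
            engine = LocalEngine()
        else:
            engine = RemoteEngine()
        return engine.run()

    print(build_and_run(True))  # local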
20 changes: 11 additions & 9 deletions src/ontogpt/clients/hfhub_client.py
@@ -1,13 +1,14 @@
"""HuggingFace Hub Client."""
import logging

from dataclasses import dataclass

from langchain import HuggingFaceHub, LLMChain, PromptTemplate
from oaklib.utilities.apikey_manager import get_apikey_value
from langchain import HuggingFaceHub, PromptTemplate, LLMChain

# Note: See https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads
# for all relevant models


@dataclass
class HFHubClient:
"""A client for the HuggingFace Hub API."""
@@ -23,15 +24,16 @@ def get_model(self, modelname: str) -> HuggingFaceHub:
         Returns a model object of type
         langchain.llms.huggingface_hub.HuggingFaceHub
         """
-        model = HuggingFaceHub(repo_id=modelname,
-                               verbose=True,
-                               model_kwargs={"temperature": 0.2, "max_length": 500},
-                               huggingfacehub_api_token=self.api_key,
-                               task="text-generation"
-                               )
+        model = HuggingFaceHub(
+            repo_id=modelname,
+            verbose=True,
+            model_kwargs={"temperature": 0.2, "max_length": 500},
+            huggingfacehub_api_token=self.api_key,
+            task="text-generation",
+        )
 
         return model
 
     def query_hf_model(self, llm, prompt_text):
         """Interact with a GPT4All model."""
         logging.info(f"Complete: prompt[{len(prompt_text)}]={prompt_text[0:100]}...")
17 changes: 9 additions & 8 deletions src/ontogpt/clients/openai_client.py
@@ -6,7 +6,7 @@
 from dataclasses import dataclass, field
 from pathlib import Path
 from time import sleep
-from typing import Iterator, Tuple
+from typing import Iterator, Optional, Tuple
 
 import numpy as np
 import openai
@@ -20,9 +20,9 @@
 class OpenAIClient:
     # max_tokens: int = field(default_factory=lambda: 3000)
     model: str = field(default_factory=lambda: "gpt-3.5-turbo")
-    cache_db_path: str = None
-    api_key: str = None
-    interactive: bool = None
+    cache_db_path: str = ""
+    api_key: str = ""
+    interactive: Optional[bool] = None
 
     def __post_init__(self):
         if not self.api_key:
@@ -101,7 +101,9 @@ def db_connection(self):
             cur.execute("CREATE TABLE cache (prompt, engine, payload)")
         return cur
 
-    def _interactive_completion(self, prompt: str, engine: str, max_tokens: int = None, **kwargs):
+    def _interactive_completion(
+        self, prompt: str, engine: str, max_tokens: Optional[int], **kwargs
+    ):
         print("Please use the ChatGPT interface to complete the following prompt:")
         print(f"IMPORTANT: make sure model == {engine}")
         print(f"Note: max_tokens == {max_tokens}")
@@ -119,7 +121,7 @@ def _interactive_completion(self, prompt: str, engine: str, max_tokens: int = No
         return self._interactive_completion(prompt, engine, max_tokens, **kwargs)
 
     def cached_completions(
-        self, search_term: str = None, engine: str = None
+        self, search_term: str = "", engine: str = ""
     ) -> Iterator[Tuple[str, str, str]]:
         if search_term:
             search_term = search_term.lower()
@@ -142,8 +144,7 @@ def _must_use_chat_api(self) -> bool:
             return False
         return True
 
-    def embeddings(self, text: str, model: str = None):
-
+    def embeddings(self, text: str, model: str = ""):
         text = str(text)
 
         if model is None:
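Several changes in this file follow one pattern: mypy rejects a default like `api_key: str = None` with 'Incompatible types in assignment (expression has type "None", variable has type "str")', because None is not a str. The commit either substitutes a default of the declared type (here, `""`) or widens the annotation to Optional. A small sketch of both options:

    from typing import Optional

    # Rejected by mypy: None is not a str.
    # api_key: str = None

    # Option 1: a real default of the declared type.
    api_key: str = ""

    # Option 2: allow None explicitly and narrow before use.
    max_tokens: Optional[int] = None
    if max_tokens is not None:
        print(max_tokens + 1)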
17 changes: 10 additions & 7 deletions src/ontogpt/clients/pubmed_client.py
@@ -2,7 +2,7 @@
 import logging
 import time
 from dataclasses import dataclass
-from typing import List, Tuple, Union
+from typing import Generator, List, Tuple, Union
 from urllib import parse
 
 import inflection
@@ -337,6 +337,7 @@ def text(
         )
 
         txt = []
+        onetxt = ""
         for doc in these_docs:
             if len(doc) > self.max_text_length and not raw:
                 logging.warning(
@@ -348,9 +349,11 @@
                 txt.append(doc)
         if singledoc and not pubmedcental:
             onetxt = txt[0]
-            txt = onetxt
 
-        return txt
+        if len(onetxt) > 0:
+            return onetxt
+        else:
+            return txt
 
     def pmc_text(self, pmc_id: str) -> str:
         """Get the text of one PubMed Central entry.
@@ -398,14 +401,14 @@ def pmc_text(self, pmc_id: str) -> str:
 
         return xml_data
 
-    def search(self, term: str, keywords: List[str] = None) -> List[PMID]:
+    def search(self, term: str, keywords: List[str]) -> Generator[PMID, None, None]:
         """Get the quality-scored text of PubMed papers relating to a search term and keywords.
 
         This generator yields PMIDs. Note this uses the MAX_PMIDS value
         to determine how many documents to collect.
 
         :param term: search term, a string
         :param keywords: keywords, a list of strings
-        :return: a list of PMIDs corresponding to the search term and keywords
+        :return: PMIDs corresponding to the search term and keywords
         """
         if keywords:
             keywords = [_normalize(kw) for kw in keywords]
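`search` uses yield, and mypy requires a generator function's return annotation to be Generator (or a supertype such as Iterator), not the List the old signature claimed. The three type parameters are the yield, send, and return types. For example:

    from typing import Generator

    def squares(limit: int) -> Generator[int, None, None]:
        # Generator[YieldType, SendType, ReturnType]: this function only
        # yields ints, so the send and return types are both None.
        for i in range(limit):
            yield i * i

    print(list(squares(5)))  # [0, 1, 4, 9, 16]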
Expand Down Expand Up @@ -485,7 +488,7 @@ def parse_pmxml(self, xml: str, raw: bool, autoformat: bool, pubmedcentral: bool
ab = ""
if pa.find("Abstract"): # Document may not have abstract
ab = pa.find("Abstract").text
kw = ""
kw = [""]
if pa.find("KeywordList"): # Document may not have MeSH terms or keywords
kw = [tag.text for tag in pa.find_all("Keyword")]
txt = f"Title: {ti}\nKeywords: {'; '.join(kw)}\nPMID: {pmid}\nAbstract: {ab}"
@@ -502,7 +505,7 @@
                 ti = pa.find("ArticleTitle").text
                 if pa.find("Abstract"):  # Document may not have abstract
                     body = pa.find("Abstract").text + body
-                kw = ""
+                kw = [""]
                 if pa.find("KeywordList"):  # Document may not have MeSH terms or keywords
                     kw = [tag.text for tag in pa.find_all("Keyword")]
 
5 changes: 3 additions & 2 deletions src/ontogpt/converters/ontology_converter.py
@@ -123,8 +123,9 @@ def from_obograph(self, graph: Graph) -> Ontology:
                     f"{' and '.join(genus_elts)} and {' and '.join(differentia)}"
                 )
                 logging.info(f"Equiv[{element.name}] = {element.equivalent_to}")
-        for element in element_index.values():
-            ontology.elements.append(element)
+        if ontology.elements is not None:
+            for element in element_index.values():
+                ontology.elements.append(element)
         return ontology
 
     def node_to_name(self, curie: str, label: Optional[str] = None) -> str:
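The guard added here is mypy's standard Optional narrowing: `ontology.elements` is presumably declared Optional in the model, so calling `.append` on it directly is flagged because the value may be None. The explicit check narrows the type for the block below it. A minimal sketch, assuming an Optional list field (classes hypothetical):

    from typing import List, Optional

    class Ontology:
        def __init__(self) -> None:
            self.elements: Optional[List[str]] = None

    def add_element(ontology: Ontology, name: str) -> None:
        # Without the guard, mypy reports that `elements` may be None and
        # has no attribute "append"; the check narrows it to List[str].
        if ontology.elements is not None:
            ontology.elements.append(name)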
8 changes: 4 additions & 4 deletions src/ontogpt/engines/embedding_similarity_engine.py
@@ -14,10 +14,10 @@
 
 @dataclass
 class EmbeddingSimilarity:
-    subject_id: str = None
-    subject_label: str = None
-    object_id: str = None
-    object_label: str = None
+    subject_id: str = ""
+    subject_label: str = ""
+    object_id: str = ""
+    object_label: str = ""
     embedding_cosine_similarity: float = None
     object_rank_for_subject: int = None
 
Expand Down
[Diffs for the remaining 37 changed files not shown.]
