Skip to content

Commit

Permalink
Merge pull request #360 from nanglo123/pydantic2
Browse files Browse the repository at this point in the history
Update code to be compatible with Pydantic2
  • Loading branch information
bgyori authored Sep 25, 2024
2 parents 51d0c98 + 3a0b723 commit 44256f9
Show file tree
Hide file tree
Showing 46 changed files with 2,122 additions and 1,414 deletions.
40 changes: 20 additions & 20 deletions mira/dkg/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,18 @@
class RelationQuery(BaseModel):
"""A query for relations in the domain knowledge graph."""

source_type: Optional[str] = Field(description="The source type (i.e., prefix)", example="vo")
source_type: Optional[str] = Field(None, description="The source type (i.e., prefix)", examples=["vo"])
source_curie: Optional[str] = Field(
description="The source compact URI (CURIE)", example="doid:946"
None, description="The source compact URI (CURIE)", examples=["doid:946"]
)
target_type: Optional[str] = Field(
description="The target type (i.e., prefix)", example="ncbitaxon"
None, description="The target type (i.e., prefix)", examples=["ncbitaxon"]
)
target_curie: Optional[str] = Field(
description="The target compact URI (CURIE)", example="ncbitaxon:10090"
None, description="The target compact URI (CURIE)", examples=["ncbitaxon:10090"]
)
relations: Union[None, str, List[str]] = Field(
description="A relation string or list of relation strings", example="vo:0001243"
None, description="A relation string or list of relation strings", examples=["vo:0001243"]
)
relation_direction: Literal["right", "left", "both"] = Field(
"right", description="The direction of the relationship"
Expand All @@ -50,7 +50,7 @@ class RelationQuery(BaseModel):
ge=0,
)
limit: Optional[int] = Field(
description="A limit on the number of records returned", example=50, ge=0
None, description="A limit on the number of records returned", examples=[50], ge=0
)
full: bool = Field(
False,
Expand Down Expand Up @@ -178,12 +178,12 @@ def get_transitive_closure(
class RelationResponse(BaseModel):
"""A triple (or multi-predicate triple) with abbreviated data."""

subject: str = Field(description="The CURIE of the subject of the triple", example="doid:96")
subject: str = Field(description="The CURIE of the subject of the triple", examples=["doid:96"])
predicate: Union[str, List[str]] = Field(
description="A predicate or list of predicates as CURIEs",
example="ro:0002452",
examples=["ro:0002452"],
)
object: str = Field(description="The CURIE of the object of the triple", example="symp:0000001")
object: str = Field(description="The CURIE of the object of the triple", examples=["symp:0000001"])


class FullRelationResponse(BaseModel):
Expand All @@ -203,7 +203,7 @@ def get_relations(
request: Request,
relation_query: RelationQuery = Body(
...,
examples={
examples=[{
"source type query": {
"summary": "Query relations with a given source node type",
"value": {
Expand Down Expand Up @@ -285,7 +285,7 @@ def get_relations(
"full": True,
},
},
},
}],
),
):
"""Get relations based on the query sent.
Expand Down Expand Up @@ -392,10 +392,10 @@ def add_resources(
request: Request,
resource_prefix_list: List[str] = Body(
...,
description="A of resources to add to the DKG",
description="A list of resources to add to the DKG",
title="Resource Prefixes",
example=["probonto", "wikidata", "eiffel", "geonames", "ncit",
"nbcbitaxon"],
examples=[["probonto", "wikidata", "eiffel", "geonames", "ncit",
"nbcbitaxon"]],
)
):
"""From a list of resource prefixes, add a list of nodes and edges
Expand All @@ -418,10 +418,10 @@ class IsOntChildResult(BaseModel):
"""Result of a query to /is_ontological_child"""

child_curie: str = Field(...,
example="vo:0001113",
examples=["vo:0001113"],
description="The child CURIE")
parent_curie: str = Field(...,
example="obi:0000047",
examples=["obi:0000047"],
description="The parent CURIE")
is_child: bool = Field(
...,
Expand All @@ -446,7 +446,7 @@ def is_ontological_child(
request: Request,
query: IsOntChildQuery = Body(
...,
example={"child_curie": "vo:0001113", "parent_curie": "obi:0000047"},
examples=[{"child_curie": "vo:0001113", "parent_curie": "obi:0000047"}],
)
):
"""Check if one CURIE is an ontological child of another CURIE"""
Expand Down Expand Up @@ -490,7 +490,7 @@ def search(
labels: Optional[str] = Query(
default=None,
description="A comma-separated list of labels",
examples={
examples=[{
"no label filter": {
"summary": "Don't filter by label",
"value": None,
Expand All @@ -499,7 +499,7 @@ def search(
"summary": "Search for units, which are labeled as `unit`",
"value": "unit",
},
},
}],
),
wikidata_fallback: bool = Query(
default=False,
Expand Down Expand Up @@ -530,7 +530,7 @@ class ParentQuery(BaseModel):
def common_parent(
request: Request,
query: ParentQuery = Body(
..., example={"curie1": "ido:0000566", "curie2": "ido:0000567"}
..., examples=[{"curie1": "ido:0000566", "curie2": "ido:0000567"}]
),
):
"""Get the common parent of two CURIEs"""
Expand Down
9 changes: 5 additions & 4 deletions mira/dkg/askemo/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def get_askemo_terms() -> Mapping[str, Term]:
"""Load the epi ontology JSON."""
rv = {}
for obj in json.loads(ONTOLOGY_PATH.read_text()):
term = Term.parse_obj(obj)
term = Term.model_validate(obj)
rv[term.id] = term
return rv

Expand All @@ -79,22 +79,23 @@ def get_askemosw_terms() -> Mapping[str, Term]:
"""Load the space weather ontology JSON."""
rv = {}
for obj in json.loads(SW_ONTOLOGY_PATH.read_text()):
term = Term.parse_obj(obj)
term = Term.model_validate(obj)
rv[term.id] = term
return rv

def get_askem_climate_ontology_terms() -> Mapping[str, Term]:
"""Load the climate ontology JSON."""
rv = {}
for obj in json.loads(CLIMATE_ONTOLOGY_PATH.read_text()):
term = Term.parse_obj(obj)
term = Term.model_validate(obj)
rv[term.id] = term
return rv


def write(ontology: Mapping[str, Term], path: Path) -> None:
terms = [
term.dict(exclude_unset=True, exclude_defaults=True, exclude_none=True)
term.model_dump(exclude_unset=True, exclude_defaults=True,
exclude_none=True)
for _curie, term in sorted(ontology.items())
]
path.write_text(
Expand Down
60 changes: 30 additions & 30 deletions mira/dkg/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import pystow
import requests
from neo4j import GraphDatabase, Transaction, unit_of_work
from pydantic import BaseModel, Field, validator
from pydantic import BaseModel, Field, field_validator
from tqdm import tqdm
from typing_extensions import Literal, TypeAlias

Expand Down Expand Up @@ -44,82 +44,82 @@
class Relation(BaseModel):
"""A relationship between two entities in the DKG"""
source_curie: str = Field(
description="The curie of the source node", example="probonto:k0000000"
description="The curie of the source node", examples=["probonto:k0000000"]
)
target_curie: str = Field(
description="The curie of the target node", example="probonto:k0000007"
description="The curie of the target node", examples=["probonto:k0000007"]
)
type: str = Field(
description="The type of the relation", example="has_parameter"
description="The type of the relation", examples=["has_parameter"]
)
pred: str = Field(
description="The curie of the relation type",
example="probonto:c0000062"
examples=["probonto:c0000062"]
)
source: str = Field(
description="The prefix of the relation curie", example="probonto"
description="The prefix of the relation curie", examples=["probonto"]
)
graph: str = Field(
description="The URI of the relation",
example="https://raw.githubusercontent.com/probonto"
"/ontology/master/probonto4ols.owl"
examples=["https://raw.githubusercontent.com/probonto"
"/ontology/master/probonto4ols.owl"]
)
version: str = Field(
description="The version number", example="2.5"
description="The version number", examples=["2.5"]
)


class Entity(BaseModel):
"""An entity in the domain knowledge graph."""

id: str = Field(
..., title="Compact URI", description="The CURIE of the entity", example="ido:0000511"
..., title="Compact URI", description="The CURIE of the entity", examples=["ido:0000511"]
)
name: Optional[str] = Field(description="The name of the entity", example="infected population")
type: EntityType = Field(..., description="The type of the entity", example="class")
obsolete: bool = Field(..., description="Is the entity marked obsolete?", example=False)
name: Optional[str] = Field(None, description="The name of the entity", examples=["infected population"])
type: EntityType = Field(..., description="The type of the entity", examples=["class"])
obsolete: bool = Field(..., description="Is the entity marked obsolete?", examples=[False])
description: Optional[str] = Field(
description="The description of the entity.",
example="An organism population whose members have an infection.",
None, description="The description of the entity.",
examples=["An organism population whose members have an infection."],
)
synonyms: List[Synonym] = Field(
default_factory=list, description="A list of string synonyms", example=[]
default_factory=list, description="A list of string synonyms", examples=[[]]
)
alts: List[str] = Field(
title="Alternative Identifiers",
default_factory=list,
example=[],
examples=[[]],
description="A list of alternative identifiers, given as CURIE strings.",
)
xrefs: List[Xref] = Field(
title="Database Cross-references",
default_factory=list,
example=[],
examples=[[]],
description="A list of database cross-references, given as CURIE strings.",
)
labels: List[str] = Field(
default_factory=list,
example=["ido"],
examples=[["ido"]],
description="A list of Neo4j labels assigned to the entity.",
)
properties: Dict[str, List[str]] = Field(
default_factory=dict,
description="A mapping of properties to their values",
example={},
examples=[{}],
)
# Gets auto-populated
link: Optional[str] = None

@validator("link")
def set_link(cls, value, values):
@field_validator("link")
def set_link(cls, value, validation_info):
"""
Set the value of the ``link`` field based on the value of the ``id``
field. This gets run as a post-init hook by Pydantic
See also:
https://stackoverflow.com/questions/54023782/pydantic-make-field-none-in-validator-based-on-other-fields-value
"""
curie = values["id"]
curie = validation_info.data["id"]
return f"{METAREGISTRY_BASE}/{curie}"

@property
Expand Down Expand Up @@ -209,7 +209,7 @@ def as_askem_entity(self):
raise ValueError(f"can only call as_askem_entity() on ASKEM ontology terms")
if isinstance(self, AskemEntity):
return self
data = self.dict()
data = self.model_dump()
return AskemEntity(
**data,
physical_min=self._get_single_property(
Expand All @@ -235,12 +235,12 @@ class AskemEntity(Entity):
"""An extended entity with more ASKEM stuff loaded in."""

# TODO @ben please write descriptions for these
physical_min: Optional[float] = Field(description="")
physical_max: Optional[float] = Field(description="")
suggested_data_type: Optional[str] = Field(description="")
suggested_unit: Optional[str] = Field(description="")
typical_min: Optional[float] = Field(description="")
typical_max: Optional[float] = Field(description="")
physical_min: Optional[float] = Field(None, description="")
physical_max: Optional[float] = Field(None, description="")
suggested_data_type: Optional[str] = Field(None, description="")
suggested_unit: Optional[str] = Field(None, description="")
typical_min: Optional[float] = Field(None, description="")
typical_max: Optional[float] = Field(None, description="")


class Neo4jClient:
Expand Down
7 changes: 6 additions & 1 deletion mira/dkg/construct.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,12 @@ def main(
):
"""Generate the node and edge files."""
if Path(use_case).is_file():
config = DKGConfig.parse_file(use_case)
with open(use_case, 'r') as file:
file_content = file.read()
if use_case.lower().endswith(".json"):
config = DKGConfig.model_validate_json(file_content)
else:
config = DKGConfig.model_validate(file_content)
use_case = config.use_case
else:
config = None
Expand Down
11 changes: 9 additions & 2 deletions mira/dkg/construct_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,13 @@ def _construct_registry(
edges_path: Optional[Path] = None,
upload: bool = False,
):
config = Config.parse_file(config_path)
with config_path.open("r") as file:
file_content = file.read()

if config_path.suffix.lower() == ".json":
config = Config.model_validate_json(file_content)
else:
config = Config.model_validate(file_content)

prefixes = get_prefixes(nodes_path=nodes_path, edges_path=edges_path)
manager = Manager(
Expand All @@ -140,7 +146,8 @@ def _construct_registry(
)

output_path.write_text(
json.dumps(new_config.dict(exclude_none=True, exclude_unset=True), indent=2)
json.dumps(new_config.model_dump(exclude_none=True,
exclude_unset=True), indent=2)
)
if upload:
upload_s3(output_path, use_case="epi")
Expand Down
Loading

0 comments on commit 44256f9

Please sign in to comment.