Skip to content

Commit

Permalink
Associate artifacts with extracted models (#298)
Browse files Browse the repository at this point in the history
Co-authored-by: Todd Roper <toddroper@me.com>
  • Loading branch information
brandomr and toddroper authored Aug 7, 2023
1 parent 398f911 commit e70471e
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 3 deletions.
3 changes: 2 additions & 1 deletion graph_relations.json
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@
],
"EXTRACTED_FROM":[
["Model","Publication"],
["Dataset","Publication"]
["Dataset","Publication"],
["Model", "Artifact"]
],
"CONTAINS":[
["Project","Publication"],
Expand Down
1 change: 1 addition & 0 deletions tds/db/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class ProvenanceSearchTypes(str, Enum):
parent_model_revisions = "parent_model_revisions"
parent_models = "parent_models"
parent_nodes = "parent_nodes"
models_from_code = "models_from_code"


class RelationType(str, Enum):
Expand Down
23 changes: 22 additions & 1 deletion tds/db/graph/search_provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@ def set_limit_level(limit):
"""
)
response = session.run(query)

return nodes_edges(
response=response,
nodes=payload.get("nodes", True),
Expand Down Expand Up @@ -315,3 +314,25 @@ def concept_counts(self, payload):
for key in response:
counts[key] += 1
return counts

def models_from_code(self, payload):
"""
Identifies the code source artifact from which a model was extracted
"""
if payload.get("root_type") not in ("Model"):
raise HTTPException(
status_code=400,
detail="Code artifacts used for model extraction can "
"only be found by providing a Model",
)
with self.graph_db.session() as session:
model_id = payload["root_id"]

query = """
MATCH (a:Artifact)<-[r:EXTRACTED_FROM]-(m:Model {id: $model_id})
RETURN a
"""

response = session.run(query, {"model_id": model_id})
response_data = [res.data()["a"]["id"] for res in response]
return response_data
5 changes: 5 additions & 0 deletions tds/modules/provenance/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ def search_provenance(
* Requirements: “root_type”, “root_id”
* Allowed root _types are Model *will be expanded.
**models_from_code** - Returns the artifact `id` for the artifact
from which a model was extracted
* Requirements: "root_type", “root_id”
* Allowed `root_types` are Model
## Payload format
The payload for searching needs to match the schema below.
Expand Down
2 changes: 1 addition & 1 deletion tds/modules/provenance/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class ProvenanceSearch(BaseModel):
Provenance Data Model.
"""

root_id: Optional[int]
root_id: Optional[int | str]
root_type: Optional[ProvenanceType]
user_id: Optional[int]
curie: Optional[str]
Expand Down
1 change: 1 addition & 0 deletions tds/schema/provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,5 @@ class ProvenancePayload(BaseModel):
ProvenanceType.Simulation: "Si",
ProvenanceType.Project: "Pr",
ProvenanceType.Concept: "Cn",
ProvenanceType.Artifact: "Ar",
}

0 comments on commit e70471e

Please sign in to comment.