Skip to content

Commit

Permalink
remove mlem
Browse files Browse the repository at this point in the history
  • Loading branch information
vemonet committed Apr 10, 2024
1 parent bb85873 commit 9c32ea5
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 23 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ A package to help create and deploy Translator Reasoner APIs (TRAPI) from any pr
The **TRAPI Predict Kit** helps data scientists to build, and **publish prediction models** in a [FAIR](https://www.go-fair.org/fair-principles/) and reproducible manner. It provides helpers for various steps of the process:

* A template to help user quickly bootstrap a new prediction project with the recommended structure ([MaastrichtU-IDS/cookiecutter-trapi-predict-kit](https://github.com/MaastrichtU-IDS/cookiecutter-trapi-predict-kit/))
* Helper function to easily save a generated model, its metadata, and the data used to generate it. It uses tools such as [`dvc`](https://dvc.org/) and [`mlem`](https://mlem.ai/) to store large model outside of the git repository.
* Helper function to easily save a generated model, its metadata, and the data used to generate it. It uses tools such as [`dvc`](https://dvc.org/) to store large model outside of the git repository.
* Deploy API endpoints for retrieving predictions, which comply with the NCATS Biomedical Data Translator standards ([Translator Reasoner API](https://github.com/NCATSTranslator/ReasonerAPI) and [BioLink model](https://github.com/biolink/biolink-model)), using a decorator `@trapi_predict` to simply annotate the function that produces predicted associations for a given input entity

Predictions are usually generated from machine learning models (e.g. predict disease treated by drug), but it can adapt to generic python function, as long as the input params and return object follow the expected structure.
Expand Down Expand Up @@ -182,7 +182,7 @@ uvicorn trapi.main:app --port 8808 --reload

### 💾 Save a generated model

Helper function to easily save a generated model, its metadata, and the data used to generate it. It uses tools such as [`dvc`](https://dvc.org/) and [`mlem`](https://mlem.ai/) to store large model outside of the git repository.
Helper function to easily save a generated model, its metadata, and the data used to generate it. It uses tools such as [`dvc`](https://dvc.org/) to store large model outside of the git repository.

```python
from trapi_predict_kit import save
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ dependencies = [
"fastapi >=0.68.1",
"rdflib >=6.1.1",
"reasoner-pydantic >=3.0.1",
"mlem",
# "mlem",
"dvc",
"bmt",
# "fairworkflows @ git+https://github.com/vemonet/fairworkflows.git",
Expand Down
22 changes: 11 additions & 11 deletions src/trapi_predict_kit/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dataclasses import dataclass
from typing import Any, Optional

from mlem import api as mlem
# from mlem import api as mlem
from rdflib import Graph

# from mlem.api import save as mlem_save, load as mlem_load
Expand Down Expand Up @@ -42,11 +42,11 @@ def save(

# mlem_model = MlemModel.from_obj(model, sample_data=sample_data)
# mlem_model.dump(path)
if method == "mlem":
mlem.save(model, path, sample_data=sample_data)
else:
with open(path, "wb") as f:
pickle.dump(model, f)
# if method == "mlem":
# mlem.save(model, path, sample_data=sample_data)
# else:
with open(path, "wb") as f:
pickle.dump(model, f)

g = get_run_metadata(scores, sample_data, hyper_params, model_name)
g.serialize(f"{path}.ttl", format="ttl")
Expand All @@ -63,11 +63,11 @@ def save(

def load(path: str, method: str = "pickle") -> LoadedModel:
log.info(f"💽 Loading model from {path} using {method}")
if method == "mlem":
model = mlem.load(path)
else:
with open(path, "rb") as f:
model = pickle.load(f)
# if method == "mlem":
# model = mlem.load(path)
# else:
with open(path, "rb") as f:
model = pickle.load(f)

g = Graph()
g.parse(f"{path}.ttl", format="ttl")
Expand Down
18 changes: 9 additions & 9 deletions tests/test_save.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ def test_save_pickle():
shutil.rmtree(tmp_path)


def test_save_mlem():
"""Test to save and load a basic model with mlem"""
save(model, model_path, sample_data=data, scores=scores, hyper_params=hyper_params, method="mlem")
assert Path(model_path).is_file()
assert Path(f"{model_path}.mlem").is_file()
assert Path(f"{model_path}.ttl").is_file()
loaded_model = load(model_path, method="mlem")
assert loaded_model.model is not None
shutil.rmtree(tmp_path)
# def test_save_mlem():
# """Test to save and load a basic model with mlem"""
# save(model, model_path, sample_data=data, scores=scores, hyper_params=hyper_params, method="mlem")
# assert Path(model_path).is_file()
# assert Path(f"{model_path}.mlem").is_file()
# assert Path(f"{model_path}.ttl").is_file()
# loaded_model = load(model_path, method="mlem")
# assert loaded_model.model is not None
# shutil.rmtree(tmp_path)

0 comments on commit 9c32ea5

Please sign in to comment.