Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fastapi integration #1199

Merged
merged 21 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
531321b
work in progress fastapi integration
miquelduranfrigola Jun 7, 2024
94a1492
WIP
miquelduranfrigola Jul 12, 2024
7092852
Merge branch 'master' into fastapi-integration
miquelduranfrigola Jul 12, 2024
6a08ff7
Merge branch 'master' into fastapi-integration
miquelduranfrigola Jul 12, 2024
45df8d9
fastapi integration in serve and fetch
miquelduranfrigola Jul 12, 2024
7880743
blackened
miquelduranfrigola Jul 12, 2024
d709575
added conda environment in the ersilia pack functions
miquelduranfrigola Jul 15, 2024
502f14a
tabular formatting using fastapi
miquelduranfrigola Jul 15, 2024
1045bf1
add pack method constants
DhanshreeA Jul 16, 2024
834d5b7
Add logging, and replace string literals
DhanshreeA Jul 16, 2024
22740e8
replace string literals, replace info.json with information.json for …
DhanshreeA Jul 16, 2024
04e11ab
correctly read input from card based on how it's structured
DhanshreeA Jul 16, 2024
f63025b
replace string literal
DhanshreeA Jul 16, 2024
f430df2
remove print
DhanshreeA Jul 16, 2024
60f9a2a
remove string literals
DhanshreeA Jul 16, 2024
7aa6f8e
use information.json instead of metadata.json for reading model card …
DhanshreeA Jul 16, 2024
9f250d0
bypass bentoml code with template resolver
DhanshreeA Jul 16, 2024
276496a
Merge branch 'master' into fastapi-integration
DhanshreeA Jul 16, 2024
edc63b1
Fix LocalCard class; fix constant imports
DhanshreeA Jul 17, 2024
945db9b
fix file loading in LocalCard
DhanshreeA Jul 17, 2024
b6bf68d
Format with black; fix broken tests
DhanshreeA Jul 17, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 32 additions & 17 deletions .github/scripts/convert_airtable_to_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,35 +6,50 @@
import requests

AIRTABLE_MODEL_HUB_BASE_ID = "appgxpCzCDNyGjWc8"
AIRTABLE_TABLE_ID = 'tblZGe2a2XeBxrEHP'
AIRTABLE_TABLE_ID = "tblZGe2a2XeBxrEHP"
AWS_ACCOUNT_REGION = "eu-central-1"
ERSILIA_MODEL_HUB_S3_BUCKET= 'ersilia-model-hub'
ERSILIA_MODEL_HUB_S3_BUCKET = "ersilia-model-hub"

def convert_airtable_to_json(airtable_api_key, aws_access_key_id, aws_secret_access_key):

headers = {'Authorization': f'Bearer {airtable_api_key}'}
response= requests.get(f'https://api.airtable.com/v0/{AIRTABLE_MODEL_HUB_BASE_ID}/{AIRTABLE_TABLE_ID}', headers=headers)

data=response.json()
records_models= [record['fields'] for record in data['records']]
models_json=json.dumps(records_models, indent=4)
def convert_airtable_to_json(
airtable_api_key, aws_access_key_id, aws_secret_access_key
):
headers = {"Authorization": f"Bearer {airtable_api_key}"}
response = requests.get(
f"https://api.airtable.com/v0/{AIRTABLE_MODEL_HUB_BASE_ID}/{AIRTABLE_TABLE_ID}",
headers=headers,
)

#Load JSON in AWS S3 bucket
s3 = boto3.client('s3',aws_access_key_id=aws_access_key_id,aws_secret_access_key=aws_secret_access_key,region_name=AWS_ACCOUNT_REGION)
data = response.json()
records_models = [record["fields"] for record in data["records"]]
models_json = json.dumps(records_models, indent=4)

# Load JSON in AWS S3 bucket
s3 = boto3.client(
"s3",
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
region_name=AWS_ACCOUNT_REGION,
)
try:
s3.put_object(Body=models_json, Bucket=ERSILIA_MODEL_HUB_S3_BUCKET, Key='models.json', ACL='public-read')
s3.put_object(
Body=models_json,
Bucket=ERSILIA_MODEL_HUB_S3_BUCKET,
Key="models.json",
ACL="public-read",
)
print("file models.json uploaded")
except NoCredentialsError:
logging.error("Unable to upload tracking data to AWS: Credentials not found")
except ClientError as e:
logging.error(e)

if __name__ == "__main__":


if __name__ == "__main__":
print("Getting environmental variables")
airtable_api_key = os.environ.get('AIRTABLE_API_KEY')
aws_access_key_id = os.environ.get('AWS_ACCESS_KEY_ID')
aws_secret_access_key = os.environ.get('AWS_SECRET_ACCESS_KEY')
airtable_api_key = os.environ.get("AIRTABLE_API_KEY")
aws_access_key_id = os.environ.get("AWS_ACCESS_KEY_ID")
aws_secret_access_key = os.environ.get("AWS_SECRET_ACCESS_KEY")

print("Converting AirTable base to JSON file")
convert_airtable_to_json(airtable_api_key, aws_access_key_id, aws_secret_access_key)
2 changes: 1 addition & 1 deletion .github/scripts/place_a_dockerfile_in_current_eos_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@ def download_file(url, filename):
text = text.replace("eos_identifier", model_id)

with open("Dockerfile", "w") as f:
f.write(text)
f.write(text)
6 changes: 4 additions & 2 deletions .github/scripts/update_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,12 @@ def populate_metadata(self):
# Check if model_description is a list
if isinstance(self.json_input["model_description"], list):
# Join the list elements into a single string separated by commas
self.metadata["Description"] = ", ".join(self.json_input["model_description"])
self.metadata["Description"] = ", ".join(
self.json_input["model_description"]
)
else:
# If it's already a string, just assign it directly
self.metadata["Description"] = self.json_input["model_description"]
self.metadata["Description"] = self.json_input["model_description"]
if self.metadata["Publication"] == "":
self.metadata["Publication"] = self.json_input["publication"]
if self.metadata["Source Code"] == "":
Expand Down
2 changes: 2 additions & 0 deletions ersilia/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,13 @@ def get_latest_semver_tag():
return tag
return None


def increment_patch_version(version):
version = version.split(".")
version[2] = str(int(version[2]) + 1)
return ".".join(version)


def get_version_for_setup():
# version = get_latest_semver_tag()
version = increment_patch_version(get_version_from_static())
Expand Down
1 change: 1 addition & 0 deletions ersilia/auth/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Deal with privileges in Ersilia.
Base on GitHub login.
"""

from pathlib import Path
import os
import yaml
Expand Down
4 changes: 2 additions & 2 deletions ersilia/cli/commands/close.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .. import echo
from ... import ErsiliaModel
from ...core.session import Session
from ...core.tracking import check_file_exists, close_persistent_file
from ...core.tracking import check_file_exists, close_persistent_file


def close_cmd():
Expand All @@ -23,4 +23,4 @@ def close():

# Close our persistent tracking file
if check_file_exists(model_id):
close_persistent_file(mdl.model_id)
close_persistent_file(mdl.model_id)
2 changes: 1 addition & 1 deletion ersilia/cli/commands/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def example_cmd():
@click.option("--n_samples", "-n", default=5, type=click.INT)
@click.option("--file_name", "-f", default=None, type=click.STRING)
@click.option("--simple/--complete", "-s/-c", default=True)
@click.option("--predefined/--random", "-p/-r", default=False)
@click.option("--predefined/--random", "-p/-r", default=True)
def example(model, n_samples, file_name, simple, predefined):
if model is not None:
model_id = ModelBase(model).model_id
Expand Down
20 changes: 19 additions & 1 deletion ersilia/cli/commands/fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def _fetch(mf, model_id):
"--from_dir",
default=None,
type=click.STRING,
help="Local path where the model is stored"
help="Local path where the model is stored",
)
@click.option(
"--from_github",
Expand Down Expand Up @@ -60,6 +60,18 @@ def _fetch(mf, model_id):
default=None,
help="Fetch a model based on a URL. This only creates a basic folder structure for the model, the model is not actually downloaded.",
)
@click.option(
"--with_bentoml",
is_flag=True,
default=False,
help="Force fetch using BentoML",
)
@click.option(
"--with_fastapi",
is_flag=True,
default=False,
help="Force fetch using FastAPI",
)
def fetch(
model,
repo_path,
Expand All @@ -72,7 +84,11 @@ def fetch(
from_s3,
from_hosted,
from_url,
with_bentoml,
with_fastapi,
):
if with_bentoml and with_fastapi:
raise Exception("Cannot use both BentoML and FastAPI")
if repo_path is not None:
mdl = ModelBase(repo_path=repo_path)
elif from_dir is not None:
Expand All @@ -94,6 +110,8 @@ def fetch(
force_from_s3=from_s3,
force_from_dockerhub=from_dockerhub,
force_from_hosted=from_hosted,
force_with_bentoml=with_bentoml,
force_with_fastapi=with_fastapi,
hosted_url=from_url,
)
_fetch(mf, model_id)
Expand Down
18 changes: 6 additions & 12 deletions ersilia/cli/commands/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def run(input, output, batch_size, standard):
fg="red",
)
return

mdl = ErsiliaModel(
model_id,
service_class=service_class,
Expand All @@ -64,17 +64,11 @@ def run(input, output, batch_size, standard):
echo("Something went wrong", fg="red")
else:
echo(result)

if track_runs:
"""
Retrieve the time taken to run the model and update the total.
Retrieve the time taken to run the model and update the total.
"""
time_tracker = RunTracker(
model_id=model_id,
config_json=None
)

time_tracker.update_total_time(
model_id=model_id,
start_time=start_time
)
time_tracker = RunTracker(model_id=model_id, config_json=None)

time_tracker.update_total_time(model_id=model_id, start_time=start_time)
5 changes: 2 additions & 3 deletions ersilia/cli/commands/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,7 @@ def serve(model, lake, docker, port, track):
)
if not mdl.is_valid():
ModelNotFound(mdl).echo()



mdl.serve()
if mdl.url is None:
echo("No URL found. Service unsuccessful.", fg="red")
Expand All @@ -73,7 +72,7 @@ def serve(model, lake, docker, port, track):
echo("")
echo(":person_tipping_hand: Information:", fg="blue")
echo(" - info", fg="blue")

if track:
"""
Retrieve the time taken in seconds to serve the Model.
Expand Down
7 changes: 5 additions & 2 deletions ersilia/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import subprocess
from pathlib import Path
from ..utils.config import Config, Credentials
from ..utils.paths import resolve_pack_method
from ..default import EOS
from .. import logger

Expand Down Expand Up @@ -89,8 +90,10 @@ def _get_bundle_location(self, model_id):
else:
return path

@staticmethod
def _get_bento_location(model_id):
def _get_bento_location(self, model_id):
bundle_path = self._get_bundle_location(model_id)
if resolve_pack_method(bundle_path) != "bentoml":
return None
cmd = ["bentoml", "get", "%s:latest" % model_id, "--print-location", "--quiet"]
result = subprocess.run(cmd, stdout=subprocess.PIPE)
result = result.stdout.decode("utf-8").rstrip()
Expand Down
5 changes: 1 addition & 4 deletions ersilia/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,6 @@ def get_apis(self):
def _run(
self, input=None, output=None, batch_size=DEFAULT_BATCH_SIZE, track_run=False
):

api_name = self.get_apis()[0]
result = self.api(
api_name=api_name, input=input, output=output, batch_size=batch_size
Expand Down Expand Up @@ -498,9 +497,7 @@ def run(
)
# Start tracking model run if track flag is used in serve
if self._run_tracker is not None and track_run:
self._run_tracker.track(
input=input, result=result, meta=self._model_info
)
self._run_tracker.track(input=input, result=result, meta=self._model_info)
self._run_tracker.log(result=result, meta=self._model_info)
return result

Expand Down
Loading
Loading