Skip to content

Commit

Permalink
Merge branch 'master' into single_inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
DhanshreeA committed Jul 3, 2024
2 parents e161461 + d2abf0d commit 3ba9b8b
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 5 deletions.
19 changes: 18 additions & 1 deletion ersilia/cli/commands/run.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import click
import json
import types
import time

from . import ersilia_cli
from .. import echo
from ... import ErsiliaModel
from ...core.session import Session
from ...core.tracking import RunTracker


def run_cmd():
Expand All @@ -27,6 +29,7 @@ def run_cmd():
help="Assume that the run is standard and, therefore, do not do so many checks.",
)
def run(input, output, batch_size, standard):
start_time = time.time()
session = Session(config_json=None)
model_id = session.current_model_id()
service_class = session.current_service_class()
Expand All @@ -38,7 +41,7 @@ def run(input, output, batch_size, standard):
fg="red",
)
return

mdl = ErsiliaModel(
model_id,
service_class=service_class,
Expand All @@ -61,3 +64,17 @@ def run(input, output, batch_size, standard):
echo("Something went wrong", fg="red")
else:
echo(result)

if track_runs:
"""
Retrieve the time taken to run the model and update the total.
"""
time_tracker = RunTracker(
model_id=model_id,
config_json=None
)

time_tracker.update_total_time(
model_id=model_id,
start_time=start_time
)
14 changes: 14 additions & 0 deletions ersilia/cli/commands/serve.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import click
import time

from .. import echo
from . import ersilia_cli
from ... import ErsiliaModel
from ..messages import ModelNotFound
from ...core.tracking import write_persistent_file


def serve_cmd():
Expand Down Expand Up @@ -31,6 +33,7 @@ def serve_cmd():
default=False,
)
def serve(model, lake, docker, port, track):
start_time = time.time()
if docker:
service_class = "docker"
else:
Expand All @@ -44,6 +47,8 @@ def serve(model, lake, docker, port, track):
)
if not mdl.is_valid():
ModelNotFound(mdl).echo()


mdl.serve()
if mdl.url is None:
echo("No URL found. Service unsuccessful.", fg="red")
Expand All @@ -68,3 +73,12 @@ def serve(model, lake, docker, port, track):
echo("")
echo(":person_tipping_hand: Information:", fg="blue")
echo(" - info", fg="blue")

if track:
"""
Retrieve the time taken in seconds to serve the Model.
"""
end_time = time.time()
duration = end_time - start_time
content = "Total time taken: {0}\n".format(duration)
write_persistent_file(content, mdl.model_id)
50 changes: 46 additions & 4 deletions ersilia/core/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import tracemalloc
from .session import Session
from datetime import datetime
from datetime import timedelta
from .base import ErsiliaBase
from collections import defaultdict
from ..default import EOS, ERSILIA_RUNS_FOLDER
Expand Down Expand Up @@ -391,6 +392,50 @@ def __init__(self, model_id, config_json):


# return stats


def update_total_time(self, model_id, start_time):
"""
Method to track and update the Total time taken by model.
:Param model_id: The currently running model.
:Param start_time: The start time of the running model.
"""

end_time = time.time()
duration = end_time - start_time
if check_file_exists(model_id):
file_name = get_persistent_file_path(model_id)
with open(file_name, "r") as f:
lines = f.readlines()

updated_lines = []
total_time_found = False

for line in lines:
if "Total time taken" in line and not total_time_found:
try:
total_time_str = line.split(":")[1].strip()
total_time = float(total_time_str)
total_time += duration
formatted_time = str(timedelta(seconds=total_time))
updated_lines.append(f"Total time taken: {formatted_time}\n")
total_time_found = True
except (ValueError, IndexError) as e:
print(f"Error parsing 'Total time taken' value: {e}")
else:
updated_lines.append(line)

if not total_time_found:
updated_lines.append(f"Total time taken: {formatted_duration}\n")

new_content = "".join(updated_lines)
with open(file_name, "w") as f:
f.write(f"{new_content}\n")
else:
new_content = f"Total time: {formatted_duration}\n"
with open(file_name, "w") as f:
f.write(f"{new_content}\n")


def get_file_sizes(self, input_file, output_file):
"""
Expand Down Expand Up @@ -537,7 +582,7 @@ def track(self, input, result, meta):
"""
Tracks the results of a model run.
"""
self.time_start = datetime.now()

self.docker_client = SimpleDocker()
self.data = CsvDataLoader()
json_dict = {}
Expand Down Expand Up @@ -578,9 +623,6 @@ def track(self, input, result, meta):
model_id = meta["metadata"].get("Identifier", "Unknown")
json_dict["model_id"] = model_id

time_taken = datetime.now() - self.time_start
json_dict["time_taken in seconds"] = str(time_taken)

# checking for mismatched types
nan_count = get_nan_counts(result_data)
json_dict["nan_count"] = nan_count
Expand Down

0 comments on commit 3ba9b8b

Please sign in to comment.