Skip to content

Commit

Permalink
Merge pull request #927 from hcs-t4sg/master
Browse files Browse the repository at this point in the history
Add a new tracking module to enable the monitoring of runs via Splunk
  • Loading branch information
miquelduranfrigola authored Dec 21, 2023
2 parents 497b56c + 97000e6 commit 6db1f10
Show file tree
Hide file tree
Showing 8 changed files with 463 additions and 9 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ secrets.json
development.ipynb
development.py
tmp/
.idea/

# Sphinx related
/doctrees
Expand Down Expand Up @@ -156,3 +157,6 @@ dmypy.json

# Cython debug symbols
cython_debug/

# Don't commit .csv output files
*.csv
6 changes: 6 additions & 0 deletions ersilia/cli/commands/close.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import datetime
import os
from . import ersilia_cli
from .. import echo
from ... import ErsiliaModel
from ...core.session import Session
from ...core.tracking import close_persistent_file


def close_cmd():
Expand All @@ -17,3 +20,6 @@ def close():
mdl = ErsiliaModel(model_id, service_class=service_class)
mdl.close()
echo(":no_entry: Model {0} closed".format(mdl.model_id), fg="green")

# Close our persistent tracking file
close_persistent_file()
18 changes: 15 additions & 3 deletions ersilia/cli/commands/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,16 @@ def run_cmd():
@click.option(
"-b", "--batch_size", "batch_size", required=False, default=100, type=click.INT
)
@click.option(
"-t/", "--track_run/--no_track_run", "track_run", required=False, default=False
)
@click.option(
"--standard",
is_flag=True,
default=False,
help="Assume that the run is standard and, therefore, do not do so many checks.",
)
def run(input, output, batch_size, standard):
def run(input, output, batch_size, track_run, standard):
session = Session(config_json=None)
model_id = session.current_model_id()
service_class = session.current_service_class()
Expand All @@ -36,9 +39,18 @@ def run(input, output, batch_size, standard):
fg="red",
)
return
mdl = ErsiliaModel(model_id, service_class=service_class, config_json=None)
mdl = ErsiliaModel(
model_id,
service_class=service_class,
config_json=None,
track_runs=track_run,
)
result = mdl.run(
input=input, output=output, batch_size=batch_size, try_standard=standard
input=input,
output=output,
batch_size=batch_size,
track_run=track_run,
try_standard=standard,
)
if isinstance(result, types.GeneratorType):
for result in mdl.run(input=input, output=output, batch_size=batch_size):
Expand Down
15 changes: 14 additions & 1 deletion ersilia/cli/commands/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .. import echo
from ... import ErsiliaModel
from ..messages import ModelNotFound
from ...core.tracking import open_persistent_file


def serve_cmd():
Expand All @@ -20,7 +21,15 @@ def serve_cmd():
type=click.INT,
help="Preferred port to use (integer)",
)
def serve(model, lake, docker, port):
# Add the new flag for tracking the serve session
@click.option(
"-t/",
"--track_serve/--no_track_serve",
"track_serve",
required=False,
default=False,
)
def serve(model, lake, docker, port, track_serve):
if docker:
service_class = "docker"
else:
Expand Down Expand Up @@ -54,3 +63,7 @@ def serve(model, lake, docker, port):
echo("")
echo(":person_tipping_hand: Information:", fg="blue")
echo(" - info", fg="blue")

# Setup persistent tracking
if track_serve:
open_persistent_file(mdl.model_id)
32 changes: 28 additions & 4 deletions ersilia/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .base import ErsiliaBase
from .modelbase import ModelBase
from .session import Session, RunLogger
from .tracking import RunTracker
from ..serve.autoservice import AutoService
from ..serve.schema import ApiSchema
from ..serve.api import Api
Expand Down Expand Up @@ -48,6 +49,7 @@ def __init__(
fetch_if_not_available=True,
preferred_port=None,
log_runs=True,
track_runs=False,
):
ErsiliaBase.__init__(
self, config_json=config_json, credentials_json=credentials_json
Expand Down Expand Up @@ -129,6 +131,12 @@ def __init__(
)
else:
self._run_logger = None

if track_runs:
self._run_tracker = RunTracker()
else:
self._run_tracker = None

self.logger.info("Done with initialization!")

def __enter__(self):
Expand All @@ -142,7 +150,8 @@ def is_valid(self):
return self._is_valid

def _set_api(self, api_name):
if api_name == "run":
# Don't want to override apis we explicitly write
if hasattr(self, api_name):
return

def _method(input=None, output=None, batch_size=DEFAULT_BATCH_SIZE):
Expand Down Expand Up @@ -427,13 +436,21 @@ def close(self):
def get_apis(self):
return self.autoservice.get_apis()

def _run(self, input=None, output=None, batch_size=DEFAULT_BATCH_SIZE):
def _run(
self, input=None, output=None, batch_size=DEFAULT_BATCH_SIZE, track_run=False
):
# Init some tracking before the run starts
if self._run_tracker is not None and track_run:
self._run_tracker.start_tracking()

api_name = self.get_apis()[0]
result = self.api(
api_name=api_name, input=input, output=output, batch_size=batch_size
)
if self._run_logger is not None:
self._run_logger.log(result=result, meta=self._model_info)
if self._run_tracker is not None and track_run:
self._run_tracker.track(input=input, result=result, meta=self._model_info)
return result

def _standard_run(self, input=None, output=None):
Expand All @@ -450,7 +467,12 @@ def _standard_run(self, input=None, output=None):
return result, status_ok

def run(
self, input=None, output=None, batch_size=DEFAULT_BATCH_SIZE, try_standard=True
self,
input=None,
output=None,
batch_size=DEFAULT_BATCH_SIZE,
track_run=False,
try_standard=True,
):
self.logger.info("Starting runner")
standard_status_ok = False
Expand All @@ -471,7 +493,9 @@ def run(
return result
else:
self.logger.debug("Trying conventional run")
result = self._run(input=input, output=output, batch_size=batch_size)
result = self._run(
input=input, output=output, batch_size=batch_size, track_run=track_run
)
return result

@property
Expand Down
Loading

0 comments on commit 6db1f10

Please sign in to comment.