Skip to content

Commit

Permalink
Warn if RabbitMQ cannot connect
Browse files Browse the repository at this point in the history
  • Loading branch information
fivegrant committed Aug 17, 2023
1 parent 305da42 commit df8cf77
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 25 deletions.
7 changes: 1 addition & 6 deletions service/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,6 @@
logging.getLogger().setLevel(logging.DEBUG)


def enable_progress():
return True


def build_api(*args) -> FastAPI:
api = FastAPI(
title="CIEMSS Service",
Expand Down Expand Up @@ -95,7 +91,6 @@ def operate(
operation: str,
body: Operation,
redis_conn=Depends(get_redis),
progress_enabled=Depends(enable_progress),
) -> JobResponse:
def check(otype):
if isinstance(body, otype):
Expand All @@ -116,4 +111,4 @@ def check(otype):
check(EnsembleCalibrate)
case _:
raise HTTPException(status_code=404, detail="Operation not found")
return create_job(body, operation, redis_conn, progress_enabled)
return create_job(body, operation, redis_conn)
10 changes: 2 additions & 8 deletions service/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

# from juliacall import newmodule
from settings import settings
from utils.rabbitmq import gen_rabbitmq_hook
from utils.tds import (
update_tds_status,
cleanup_job_dir,
Expand Down Expand Up @@ -30,19 +29,14 @@
logging.getLogger().setLevel(logging.DEBUG)


def run(request, *, job_id, progress_enabled):
def run(request, *, job_id):
logging.debug(f"STARTED {job_id} (username: {request.username})")
sim_results_url = TDS_URL + TDS_SIMULATIONS + job_id
update_tds_status(sim_results_url, status="running", start=True)

# if request.engine == "ciemss":
operation_name = request.__class__.pyciemss_lib_function
if progress_enabled:
kwargs = request.gen_pyciemss_args(
job_id, progress_hook=gen_rabbitmq_hook(job_id)
)
else:
kwargs = request.gen_pyciemss_args(job_id)
kwargs = request.gen_pyciemss_args(job_id)
if len(operation_name) == 0:
raise Exception("No operation provided in request")
else:
Expand Down
26 changes: 20 additions & 6 deletions service/models.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from __future__ import annotations
import socket
import logging

from enum import Enum
from typing import ClassVar, Dict, List, Optional
from pydantic import BaseModel, Field, Extra


from utils.rabbitmq import gen_rabbitmq_hook
from utils.tds import fetch_dataset, fetch_model
from settings import settings

Expand Down Expand Up @@ -77,7 +80,7 @@ class OperationRequest(BaseModel):
engine: str = Field("ciemss", example="ciemss")
username: str = Field("not_provided", example="not_provided")

def gen_pyciemss_args(self, job_id, **_):
def gen_pyciemss_args(self, job_id):
raise NotImplementedError("PyCIEMSS cannot handle this operation")

def run_sciml_operation(self, job_id, julia_context):
Expand Down Expand Up @@ -109,7 +112,7 @@ class Simulate(OperationRequest):
description="optional extra system specific arguments for advanced use cases",
)

def gen_pyciemss_args(self, job_id, **_):
def gen_pyciemss_args(self, job_id):
# Get model from TDS
amr_path = fetch_model(
self.model_config_id, TDS_URL, TDS_CONFIGURATIONS, job_id
Expand Down Expand Up @@ -187,7 +190,7 @@ class Calibrate(OperationRequest):
description="optional extra system specific arguments for advanced use cases",
)

def gen_pyciemss_args(self, job_id, *, progress_hook=lambda _: None):
def gen_pyciemss_args(self, job_id):
amr_path = fetch_model(
self.model_config_id, TDS_URL, TDS_CONFIGURATIONS, job_id
)
Expand All @@ -198,11 +201,22 @@ def gen_pyciemss_args(self, job_id, *, progress_hook=lambda _: None):

dataset_path = fetch_dataset(self.dataset.dict(), TDS_URL, job_id)

# TODO: Test RabbitMQ
try:
hook = gen_rabbitmq_hook(job_id)
except socket.gaierror:
logging.warning(
"%s: Failed to connect to RabbitMQ. Unable to log progress", job_id
)

def hook(_):
return None

return {
"petri_model_or_path": amr_path,
"timepoints": timepoints,
"data_path": dataset_path,
"progress_hook": progress_hook,
"progress_hook": hook,
"visual_options": True,
**self.extra.dict(),
}
Expand Down Expand Up @@ -231,7 +245,7 @@ class EnsembleSimulate(OperationRequest):
description="optional extra system specific arguments for advanced use cases",
)

def gen_pyciemss_args(self, job_id, **_):
def gen_pyciemss_args(self, job_id):
weights = [config.weight for config in self.model_configs]
solution_mappings = [config.solution_mappings for config in self.model_configs]
amr_paths = [
Expand Down Expand Up @@ -287,7 +301,7 @@ class EnsembleCalibrate(OperationRequest):
description="optional extra system specific arguments for advanced use cases",
)

def gen_pyciemss_args(self, job_id, **_):
def gen_pyciemss_args(self, job_id):
weights = [config.weight for config in self.model_configs]
solution_mappings = [config.solution_mappings for config in self.model_configs]
amr_paths = [
Expand Down
4 changes: 2 additions & 2 deletions service/utils/rq_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def update_status_on_job_fail(job, connection, etype, value, traceback):
logging.exception(log_message)


def create_job(request_payload, sim_type, redis_conn, progress_enabled=False):
def create_job(request_payload, sim_type, redis_conn):
job_id = f"ciemss-{uuid4()}"

post_url = TDS_URL + TDS_SIMULATIONS
Expand Down Expand Up @@ -68,7 +68,7 @@ def create_job(request_payload, sim_type, redis_conn, progress_enabled=False):
queue.enqueue_call(
func="execute.run",
args=[request_payload],
kwargs={"job_id": job_id, "progress_enabled": progress_enabled},
kwargs={"job_id": job_id},
job_id=job_id,
on_failure=update_status_on_job_fail,
)
Expand Down
4 changes: 1 addition & 3 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from fastapi.testclient import TestClient
from fakeredis import FakeStrictRedis

from service.api import app, get_redis, enable_progress
from service.api import app, get_redis


@pytest.fixture
Expand All @@ -26,10 +26,8 @@ def worker(redis):
@pytest.fixture
def client(redis):
app.dependency_overrides[get_redis] = lambda: redis
app.dependency_overrides[enable_progress] = lambda: False
yield TestClient(app)
app.dependency_overrides[get_redis] = get_redis
app.dependency_overrides[enable_progress] = enable_progress


@pytest.fixture
Expand Down

0 comments on commit df8cf77

Please sign in to comment.