diff --git a/poetry.lock b/poetry.lock index 1ee4982..892b26a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -272,6 +272,25 @@ files = [ [package.extras] test = ["pytest (>=6)"] +[[package]] +name = "fakeredis" +version = "2.17.0" +description = "Python implementation of redis API, can be used for testing purposes." +optional = false +python-versions = ">=3.7,<4.0" +files = [ + {file = "fakeredis-2.17.0-py3-none-any.whl", hash = "sha256:a99ef6e5642c31e91d36be78809fec3743e2bf7aaa682685b0d65a849fecd148"}, + {file = "fakeredis-2.17.0.tar.gz", hash = "sha256:e304bc7addb2f862c3550cb7db58548418a0fadd4cd78a4de66464c84fbc2195"}, +] + +[package.dependencies] +redis = ">=4" +sortedcontainers = ">=2,<3" + +[package.extras] +json = ["jsonpath-ng (>=1.5,<2.0)"] +lua = ["lupa (>=1.14,<2.0)"] + [[package]] name = "fastapi" version = "0.96.1" @@ -319,6 +338,50 @@ files = [ {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, ] +[[package]] +name = "httpcore" +version = "0.17.3" +description = "A minimal low-level HTTP client." +optional = false +python-versions = ">=3.7" +files = [ + {file = "httpcore-0.17.3-py3-none-any.whl", hash = "sha256:c2789b767ddddfa2a5782e3199b2b7f6894540b17b16ec26b2c4d8e103510b87"}, + {file = "httpcore-0.17.3.tar.gz", hash = "sha256:a6f30213335e34c1ade7be6ec7c47f19f50c56db36abef1a9dfa3815b1cb3888"}, +] + +[package.dependencies] +anyio = ">=3.0,<5.0" +certifi = "*" +h11 = ">=0.13,<0.15" +sniffio = "==1.*" + +[package.extras] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] + +[[package]] +name = "httpx" +version = "0.24.1" +description = "The next generation HTTP client." +optional = false +python-versions = ">=3.7" +files = [ + {file = "httpx-0.24.1-py3-none-any.whl", hash = "sha256:06781eb9ac53cde990577af654bd990a4949de37a28bdb4a230d434f3a30b9bd"}, + {file = "httpx-0.24.1.tar.gz", hash = "sha256:5853a43053df830c20f8110c5e69fe44d035d850b2dfe795e196f00fdb774bdd"}, +] + +[package.dependencies] +certifi = "*" +httpcore = ">=0.15.0,<0.18.0" +idna = "*" +sniffio = "*" + +[package.extras] +brotli = ["brotli", "brotlicffi"] +cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] + [[package]] name = "identify" version = "2.5.26" @@ -925,6 +988,17 @@ files = [ {file = "sniffio-1.3.0.tar.gz", hash = "sha256:e60305c5e5d314f5389259b7f22aaa33d8f7dee49763119234af3755c55b9101"}, ] +[[package]] +name = "sortedcontainers" +version = "2.4.0" +description = "Sorted Containers -- Sorted List, Sorted Dict, Sorted Set" +optional = false +python-versions = "*" +files = [ + {file = "sortedcontainers-2.4.0-py2.py3-none-any.whl", hash = "sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0"}, + {file = "sortedcontainers-2.4.0.tar.gz", hash = "sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88"}, +] + [[package]] name = "starlette" version = "0.27.0" @@ -1033,4 +1107,4 @@ test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "1a2168cc30d52212ad25ebbcc81c7512526ebd46342a2d2e2136cc3f16d91372" +content-hash = "f89a262c979f57e0122d2020e6ace36e84051631da50b47c955eadc4f95142d0" diff --git a/pyproject.toml b/pyproject.toml index 6b090a4..f3e5755 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ uvicorn = "^0.22.0" pika = "1.3.2" pandas = "^2.0.0" filelock = "^3.12.2" +poethepoet = "^0.21.1" # juliacall = { version="^0.9.14", optional = 
true } [tool.poetry.scripts] @@ -30,14 +31,14 @@ pytest = "^7.4.0" pre-commit = "^3.3.3" requests-mock = "^1.11.0" mock = "^5.1.0" -poethepoet = "^0.21.1" +fakeredis = "^2.17.0" +httpx = "^0.24.1" [tool.poe.tasks] install-pyciemss = "pip install --no-cache git+https://github.com/ciemss/pyciemss.git@v0.0.1" [tool.pytest.ini_options] -python_files = ["tests/tests.py"] -markers = ["operation"] +markers = ["example_dir"] pythonpath = "service" diff --git a/service/api.py b/service/api.py index d540a7d..80ef036 100644 --- a/service/api.py +++ b/service/api.py @@ -1,7 +1,7 @@ from __future__ import annotations import logging -from fastapi import FastAPI, HTTPException +from fastapi import FastAPI, Depends, HTTPException from fastapi.middleware.cors import CORSMiddleware from models import ( @@ -14,7 +14,7 @@ StatusSimulationIdGetResponse, ) -from utils.rq_helpers import create_job, fetch_job_status, kill_job +from utils.rq_helpers import get_redis, create_job, fetch_job_status, kill_job Operation = Simulate | Calibrate | EnsembleSimulate | EnsembleCalibrate @@ -55,11 +55,13 @@ def get_ping(): @app.get("/status/{simulation_id}", response_model=StatusSimulationIdGetResponse) -def get_status(simulation_id: str) -> StatusSimulationIdGetResponse: +def get_status( + simulation_id: str, redis_conn=Depends(get_redis) +) -> StatusSimulationIdGetResponse: """ Retrieve the status of a simulation """ - status = fetch_job_status(simulation_id) + status = fetch_job_status(simulation_id, redis_conn) logging.info(status) if not isinstance(status, str): return status @@ -70,11 +72,13 @@ def get_status(simulation_id: str) -> StatusSimulationIdGetResponse: @app.get( "/cancel/{simulation_id}", response_model=StatusSimulationIdGetResponse ) # NOT IN SPEC -def cancel_job(simulation_id: str) -> StatusSimulationIdGetResponse: +def cancel_job( + simulation_id: str, redis_conn=Depends(get_redis) +) -> StatusSimulationIdGetResponse: """ Cancel a simulation """ - status = kill_job(simulation_id) + status = kill_job(simulation_id, redis_conn) logging.info(status) if not isinstance(status, str): return status @@ -83,7 +87,11 @@ def cancel_job(simulation_id: str) -> StatusSimulationIdGetResponse: @app.post("/{operation}", response_model=JobResponse) -def operate(operation: str, body: Operation) -> JobResponse: +def operate( + operation: str, + body: Operation, + redis_conn=Depends(get_redis), +) -> JobResponse: def check(otype): if isinstance(body, otype): return None @@ -103,4 +111,4 @@ def check(otype): check(EnsembleCalibrate) case _: raise HTTPException(status_code=404, detail="Operation not found") - return create_job(body, operation) + return create_job(body, operation, redis_conn) diff --git a/service/models.py b/service/models.py index 407e077..b384d80 100644 --- a/service/models.py +++ b/service/models.py @@ -1,12 +1,14 @@ from __future__ import annotations +import socket +import logging from enum import Enum from typing import ClassVar, Dict, List, Optional from pydantic import BaseModel, Field, Extra -from utils.tds import fetch_dataset, fetch_model from utils.rabbitmq import gen_rabbitmq_hook +from utils.tds import fetch_dataset, fetch_model from settings import settings TDS_CONFIGURATIONS = "/model_configurations/" @@ -199,11 +201,22 @@ def gen_pyciemss_args(self, job_id): dataset_path = fetch_dataset(self.dataset.dict(), TDS_URL, job_id) + # TODO: Test RabbitMQ + try: + hook = gen_rabbitmq_hook(job_id) + except socket.gaierror: + logging.warning( + "%s: Failed to connect to RabbitMQ. 
Unable to log progress", job_id + ) + + def hook(_): + return None + return { "petri_model_or_path": amr_path, "timepoints": timepoints, "data_path": dataset_path, - "progress_hook": gen_rabbitmq_hook(job_id), + "progress_hook": hook, "visual_options": True, **self.extra.dict(), } diff --git a/service/utils/rq_helpers.py b/service/utils/rq_helpers.py index ff04e4a..16f6e3a 100644 --- a/service/utils/rq_helpers.py +++ b/service/utils/rq_helpers.py @@ -2,7 +2,7 @@ from __future__ import annotations import logging -import uuid +from uuid import uuid4 import json import requests @@ -22,9 +22,8 @@ logging.getLogger().setLevel(logging.DEBUG) -# REDIS CONNECTION AND QUEUE OBJECTS -redis = Redis(settings.REDIS_HOST, settings.REDIS_PORT) -queue = Queue(connection=redis, default_timeout=-1) +def get_redis(): + return Redis(settings.REDIS_HOST, settings.REDIS_PORT) def update_status_on_job_fail(job, connection, etype, value, traceback): @@ -41,9 +40,8 @@ def update_status_on_job_fail(job, connection, etype, value, traceback): logging.exception(log_message) -def create_job(request_payload, sim_type): - random_id = str(uuid.uuid4()) - job_id = f"ciemss-{random_id}" +def create_job(request_payload, sim_type, redis_conn): + job_id = f"ciemss-{uuid4()}" post_url = TDS_URL + TDS_SIMULATIONS payload = { @@ -66,6 +64,7 @@ def create_job(request_payload, sim_type): ) logging.info(response.content) + queue = Queue(connection=redis_conn, default_timeout=-1) queue.enqueue_call( func="execute.run", args=[request_payload], @@ -77,7 +76,7 @@ def create_job(request_payload, sim_type): return {"simulation_id": job_id} -def fetch_job_status(job_id): +def fetch_job_status(job_id, redis_conn): """Fetch a job's results from RQ. Args: @@ -89,7 +88,7 @@ def fetch_job_status(job_id): content: contains the job's results. 
""" try: - job = Job.fetch(job_id, connection=redis) + job = Job.fetch(job_id, connection=redis_conn) # r = job.latest_result() # string_res = r.return_value result = job.get_status() @@ -101,9 +100,9 @@ def fetch_job_status(job_id): ) -def kill_job(job_id): +def kill_job(job_id, redis_conn): try: - job = Job.fetch(job_id, connection=redis) + job = Job.fetch(job_id, connection=redis_conn) except NoSuchJobError: return Response( status_code=status.HTTP_404_NOT_FOUND, diff --git a/service/utils/tds.py b/service/utils/tds.py index e83a263..db6a0b8 100644 --- a/service/utils/tds.py +++ b/service/utils/tds.py @@ -126,6 +126,13 @@ def attach_files(output: dict, tds_api, simulation_endpoint, job_id, status="com presigned_upload_url = upload_response.json()["url"] with open(location, "rb") as f: upload_response = requests.put(presigned_upload_url, f) + if upload_response.status_code >= 300: + raise Exception( + ( + "Failed to upload file to TDS " + f"(status: {upload_response.status_code}): {handle}" + ) + ) else: logging.error(f"{job_id} ran into error") diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..f4abbed --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,28 @@ +import json +import os + +import pytest + + +@pytest.fixture +def example_context(request): + ctx = {} + chosen = request.node.get_closest_marker("example_dir").args[0] + path_prefix = f"./tests/examples/{chosen}" + with open(f"{path_prefix}/input/request.json", "r") as file: + ctx["request"] = json.load(file) + with open(f"{path_prefix}/output/tds_simulation.json", "r") as file: + ctx["tds_simulation"] = json.load(file) + + def fetch(handle, return_path=False): + io_dir = ( + "input" if os.path.exists(f"{path_prefix}/input/{handle}") else "output" + ) + path = f"{path_prefix}/{io_dir}/{handle}" + if return_path: + return os.path.abspath(path) + with open(path, "r") as file: + return file.read() + + ctx["fetch"] = fetch + return ctx diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py new file mode 100644 index 0000000..399e9c4 --- /dev/null +++ b/tests/integration/conftest.py @@ -0,0 +1,86 @@ +from urllib.parse import urlparse, parse_qs +import re +import json +import csv +import io +import pytest + +from rq import SimpleWorker, Queue +from fastapi.testclient import TestClient +from fakeredis import FakeStrictRedis + +from service.api import app, get_redis + + +@pytest.fixture +def redis(): + return FakeStrictRedis() + + +@pytest.fixture +def worker(redis): + queue = Queue(connection=redis, default_timeout=-1) + return SimpleWorker([queue], connection=redis) + + +@pytest.fixture +def client(redis): + app.dependency_overrides[get_redis] = lambda: redis + yield TestClient(app) + app.dependency_overrides[get_redis] = get_redis + + +@pytest.fixture +def file_storage(requests_mock): + storage = {} + + def get_filename(url): + return parse_qs(urlparse(url).query)["filename"][0] + + def get_loc(request, _): + filename = get_filename(request.url) + return {"url": f"https://filesave?filename={filename}"} + + def save(request, context): + filename = get_filename(request.url) + storage[filename] = request.body.read().decode("utf-8") + return {"status": "success"} + + def retrieve(filename): + return storage.get(filename, storage) + + get_upload_url = re.compile("upload-url") + requests_mock.get(get_upload_url, json=get_loc) + upload_url = re.compile("filesave") + 
requests_mock.put(upload_url, json=save) + + yield retrieve + + +# NOTE: This probably doesn't need to be a fixture +@pytest.fixture +def file_check(): + def checker(file_type, content): + match file_type: + case "json": + try: + json.loads(content) + except json.JSONDecodeError: + return False + else: + return True + case "csv": + result_csv = csv.reader(io.StringIO(content)) + try: + i = 0 + for _ in result_csv: + i += 1 + if i > 10: + return True + return True + except csv.Error: + return False + case _: + raise NotImplementedError("File type cannot be checked") + + return checker diff --git a/tests/integration/test_calibrate.py b/tests/integration/test_calibrate.py new file mode 100644 index 0000000..543cf02 --- /dev/null +++ b/tests/integration/test_calibrate.py @@ -0,0 +1,67 @@ +import json + +import pytest + +from service.settings import settings + +TDS_URL = settings.TDS_URL + + +@pytest.mark.example_dir("calibrate") +def test_calibrate_example( + example_context, client, worker, file_storage, file_check, requests_mock +): + request = example_context["request"] + config_id = request["model_config_id"] + model = json.loads(example_context["fetch"](config_id + ".json")) + + dataset_id = example_context["request"]["dataset"]["id"] + filename = example_context["request"]["dataset"]["filename"] + dataset = example_context["fetch"](filename, True) + dataset_loc = {"method": "GET", "url": dataset} + requests_mock.get( + f"{TDS_URL}/datasets/{dataset_id}/download-url?filename={filename}", + json=dataset_loc, + ) + + requests_mock.post(f"{TDS_URL}/simulations/", json={"id": None}) + + response = client.post( + "/calibrate", + json=request, + headers={"Content-Type": "application/json"}, + ) + simulation_id = response.json()["simulation_id"] + response = client.get( + f"/status/{simulation_id}", + ) + status = response.json()["status"] + assert status == "queued" + + tds_sim = example_context["tds_simulation"] + tds_sim["id"] = simulation_id + + requests_mock.get(f"{TDS_URL}/simulations/{simulation_id}", json=tds_sim) + requests_mock.put( + f"{TDS_URL}/simulations/{simulation_id}", json={"status": "success"} + ) + requests_mock.get(f"{TDS_URL}/model_configurations/{config_id}", json=model) + + worker.work(burst=True) + + response = client.get( + f"/status/{simulation_id}", + ) + status = response.json()["status"] + result = file_storage("result.csv") + viz = file_storage("visualization.json") + # eval = file_storage("eval.csv") # NOTE: Do we want to check this + + # Checks + assert status == "complete" + + assert result is not None + assert file_check("csv", result) + + assert viz is not None + assert file_check("json", viz) diff --git a/tests/integration/test_ensemble_calibrate.py b/tests/integration/test_ensemble_calibrate.py new file mode 100644 index 0000000..b10b13e --- /dev/null +++ b/tests/integration/test_ensemble_calibrate.py @@ -0,0 +1,71 @@ +import json + +import pytest + +from service.settings import settings + +TDS_URL = settings.TDS_URL + + +@pytest.mark.example_dir("ensemble-calibrate") +def test_ensemble_calibrate_example( + example_context, client, worker, file_storage, file_check, requests_mock +): + request = example_context["request"] + config_ids = [ + config["id"] for config in example_context["request"]["model_configs"] + ] + for config_id in config_ids: + model = json.loads(example_context["fetch"](config_id + ".json")) + requests_mock.get(f"{TDS_URL}/model_configurations/{config_id}", json=model) + + dataset_id = example_context["request"]["dataset"]["id"] + filename 
= example_context["request"]["dataset"]["filename"] + dataset = example_context["fetch"](filename, True) + dataset_loc = {"method": "GET", "url": dataset} + requests_mock.get( + f"{TDS_URL}/datasets/{dataset_id}/download-url?filename={filename}", + json=dataset_loc, + ) + requests_mock.get("http://dataset", text=dataset) + + requests_mock.post(f"{TDS_URL}/simulations/", json={"id": None}) + + response = client.post( + "/ensemble-calibrate", + json=request, + headers={"Content-Type": "application/json"}, + ) + simulation_id = response.json()["simulation_id"] + response = client.get( + f"/status/{simulation_id}", + ) + status = response.json()["status"] + assert status == "queued" + + tds_sim = example_context["tds_simulation"] + tds_sim["id"] = simulation_id + + requests_mock.get(f"{TDS_URL}/simulations/{simulation_id}", json=tds_sim) + requests_mock.put( + f"{TDS_URL}/simulations/{simulation_id}", json={"status": "success"} + ) + + worker.work(burst=True) + + response = client.get( + f"/status/{simulation_id}", + ) + status = response.json()["status"] + result = file_storage("result.csv") + viz = file_storage("visualization.json") + # eval = file_storage("eval.csv") # NOTE: Do we want to check this + + # Checks + assert status == "complete" + + assert result is not None + assert file_check("csv", result) + + assert viz is not None + assert file_check("json", viz) diff --git a/tests/integration/test_ensemble_simulate.py b/tests/integration/test_ensemble_simulate.py new file mode 100644 index 0000000..d145755 --- /dev/null +++ b/tests/integration/test_ensemble_simulate.py @@ -0,0 +1,61 @@ +import json + +import pytest + +from service.settings import settings + +TDS_URL = settings.TDS_URL + + +@pytest.mark.example_dir("ensemble-simulate") +def test_ensemble_simulate_example( + example_context, client, worker, file_storage, file_check, requests_mock +): + request = example_context["request"] + config_ids = [ + config["id"] for config in example_context["request"]["model_configs"] + ] + for config_id in config_ids: + model = json.loads(example_context["fetch"](config_id + ".json")) + requests_mock.get(f"{TDS_URL}/model_configurations/{config_id}", json=model) + + requests_mock.post(f"{TDS_URL}/simulations/", json={"id": None}) + + response = client.post( + "/ensemble-simulate", + json=request, + headers={"Content-Type": "application/json"}, + ) + simulation_id = response.json()["simulation_id"] + response = client.get( + f"/status/{simulation_id}", + ) + status = response.json()["status"] + assert status == "queued" + + tds_sim = example_context["tds_simulation"] + tds_sim["id"] = simulation_id + + requests_mock.get(f"{TDS_URL}/simulations/{simulation_id}", json=tds_sim) + requests_mock.put( + f"{TDS_URL}/simulations/{simulation_id}", json={"status": "success"} + ) + + worker.work(burst=True) + + response = client.get( + f"/status/{simulation_id}", + ) + status = response.json()["status"] + result = file_storage("result.csv") + viz = file_storage("visualization.json") + # eval = file_storage("eval.csv") # NOTE: Do we want to check this + + # Checks + assert status == "complete" + + assert result is not None + assert file_check("csv", result) + + assert viz is not None + assert file_check("json", viz) diff --git a/tests/integration/test_simulate.py b/tests/integration/test_simulate.py new file mode 100644 index 0000000..3748231 --- /dev/null +++ b/tests/integration/test_simulate.py @@ -0,0 +1,58 @@ +import json + +import pytest + +from service.settings import settings + +TDS_URL = 
settings.TDS_URL + + +@pytest.mark.example_dir("simulate") +def test_simulate_example( + example_context, client, worker, file_storage, file_check, requests_mock +): + request = example_context["request"] + config_id = request["model_config_id"] + model = json.loads(example_context["fetch"](config_id + ".json")) + + requests_mock.post(f"{TDS_URL}/simulations/", json={"id": None}) + + response = client.post( + "/simulate", + json=request, + headers={"Content-Type": "application/json"}, + ) + simulation_id = response.json()["simulation_id"] + response = client.get( + f"/status/{simulation_id}", + ) + status = response.json()["status"] + assert status == "queued" + + tds_sim = example_context["tds_simulation"] + tds_sim["id"] = simulation_id + + requests_mock.get(f"{TDS_URL}/simulations/{simulation_id}", json=tds_sim) + requests_mock.put( + f"{TDS_URL}/simulations/{simulation_id}", json={"status": "success"} + ) + requests_mock.get(f"{TDS_URL}/model_configurations/{config_id}", json=model) + + worker.work(burst=True) + + response = client.get( + f"/status/{simulation_id}", + ) + status = response.json()["status"] + result = file_storage("result.csv") + viz = file_storage("visualization.json") + # eval = file_storage("eval.csv") # NOTE: Do we want to check this + + # Checks + assert status == "complete" + + assert result is not None + assert file_check("csv", result) + + assert viz is not None + assert file_check("json", viz) diff --git a/tests/test_conversions.py b/tests/test_conversions.py new file mode 100644 index 0000000..54cd44a --- /dev/null +++ b/tests/test_conversions.py @@ -0,0 +1,130 @@ +import json +from inspect import signature + +import pytest + +from pyciemss.PetriNetODE.interfaces import ( # noqa: F401 + load_and_calibrate_and_sample_petri_model, + load_and_sample_petri_model, +) + +from pyciemss.Ensemble.interfaces import ( # noqa: F401 + load_and_sample_petri_ensemble, + load_and_calibrate_and_sample_ensemble_model, +) + +from service.models import Simulate, Calibrate, EnsembleSimulate, EnsembleCalibrate +from service.settings import settings + +TDS_URL = settings.TDS_URL + + +def is_satisfactory(kwargs, f): + parameters = signature(f).parameters + for key, value in kwargs.items(): + if key in parameters: + # TODO: Check types as well + # param = parameters[key] + # if param.annotation != Signature.empty and not isinstance( + # value, param.annotation + # ): + # return False + continue + return False + return True + + +class TestSimulate: + @pytest.mark.example_dir("simulate") + def test_example_conversion(self, example_context, requests_mock): + job_id = example_context["tds_simulation"]["id"] + + config_id = example_context["request"]["model_config_id"] + model = json.loads(example_context["fetch"](config_id + ".json")) + requests_mock.get(f"{TDS_URL}/model_configurations/{config_id}", json=model) + + ### Act and Assert + + operation_request = Simulate(**example_context["request"]) + kwargs = operation_request.gen_pyciemss_args(job_id) + + assert kwargs.get("visual_options", False) + assert is_satisfactory(kwargs, load_and_sample_petri_model) + + +class TestCalibrate: + @pytest.mark.example_dir("calibrate") + def test_example_conversion(self, example_context, requests_mock): + job_id = example_context["tds_simulation"]["id"] + + config_id = example_context["request"]["model_config_id"] + model = json.loads(example_context["fetch"](config_id + ".json")) + requests_mock.get(f"{TDS_URL}/model_configurations/{config_id}", json=model) + + dataset_id = 
example_context["request"]["dataset"]["id"] + filename = example_context["request"]["dataset"]["filename"] + dataset = example_context["fetch"](filename, True) + dataset_loc = {"method": "GET", "url": dataset} + requests_mock.get( + f"{TDS_URL}/datasets/{dataset_id}/download-url?filename={filename}", + json=dataset_loc, + ) + + ### Act and Assert + operation_request = Calibrate(**example_context["request"]) + kwargs = operation_request.gen_pyciemss_args(job_id) + + assert kwargs.get("visual_options", False) + assert is_satisfactory(kwargs, load_and_calibrate_and_sample_petri_model) + + +class TestEnsembleSimulate: + @pytest.mark.example_dir("ensemble-simulate") + def test_example_conversion(self, example_context, requests_mock): + job_id = example_context["tds_simulation"]["id"] + + config_ids = [ + config["id"] for config in example_context["request"]["model_configs"] + ] + for config_id in config_ids: + model = json.loads(example_context["fetch"](config_id + ".json")) + requests_mock.get(f"{TDS_URL}/model_configurations/{config_id}", json=model) + + ### Act and Assert + + operation_request = EnsembleSimulate(**example_context["request"]) + kwargs = operation_request.gen_pyciemss_args(job_id) + + assert kwargs.get("visual_options", False) + assert is_satisfactory(kwargs, load_and_sample_petri_ensemble) + + +class TestEnsembleCalibrate: + @pytest.mark.example_dir("ensemble-calibrate") + def test_example_conversion(self, example_context, requests_mock): + job_id = example_context["tds_simulation"]["id"] + + config_ids = [ + config["id"] for config in example_context["request"]["model_configs"] + ] + for config_id in config_ids: + model = json.loads(example_context["fetch"](config_id + ".json")) + requests_mock.get(f"{TDS_URL}/model_configurations/{config_id}", json=model) + + dataset_id = example_context["request"]["dataset"]["id"] + filename = example_context["request"]["dataset"]["filename"] + dataset = example_context["fetch"](filename, True) + dataset_loc = {"method": "GET", "url": dataset} + requests_mock.get( + f"{TDS_URL}/datasets/{dataset_id}/download-url?filename={filename}", + json=dataset_loc, + ) + requests_mock.get("http://dataset", text=dataset) + + ### Act and Assert + + operation_request = EnsembleCalibrate(**example_context["request"]) + kwargs = operation_request.gen_pyciemss_args(job_id) + + assert kwargs.get("visual_options", False) + assert is_satisfactory(kwargs, load_and_calibrate_and_sample_ensemble_model) diff --git a/tests/tests.py b/tests/tests.py deleted file mode 100644 index 8044fac..0000000 --- a/tests/tests.py +++ /dev/null @@ -1,158 +0,0 @@ -import json -import os -from inspect import signature - -from mock import patch -import pytest - -from pyciemss.PetriNetODE.interfaces import ( # noqa: F401 - load_and_calibrate_and_sample_petri_model, - load_and_sample_petri_model, -) - -from pyciemss.Ensemble.interfaces import ( # noqa: F401 - load_and_sample_petri_ensemble, - load_and_calibrate_and_sample_ensemble_model, -) - -from service.models import Simulate, Calibrate, EnsembleSimulate, EnsembleCalibrate -from service.settings import settings - -TDS_URL = settings.TDS_URL - - -def is_satisfactory(kwargs, f): - parameters = signature(f).parameters - for key, value in kwargs.items(): - if key in parameters: - # TODO: Check types as well - # param = parameters[key] - # if param.annotation != Signature.empty and not isinstance( - # value, param.annotation - # ): - # return False - continue - return False - return True - - -@pytest.fixture -def 
operation_context(request): - ctx = {} - chosen = request.node.get_closest_marker("operation").args[0] - path_prefix = f"./tests/examples/{chosen}" - with open(f"{path_prefix}/input/request.json", "r") as file: - ctx["request"] = json.load(file) - with open(f"{path_prefix}/output/tds_simulation.json", "r") as file: - ctx["tds_simulation"] = json.load(file) - - def fetch(handle, return_path=False): - io_dir = ( - "input" if os.path.exists(f"{path_prefix}/input/{handle}") else "output" - ) - path = f"{path_prefix}/{io_dir}/{handle}" - if return_path: - return os.path.abspath(path) - with open(path, "r") as file: - return file.read() - - ctx["fetch"] = fetch - return ctx - - -class TestSimulate: - @pytest.mark.operation("simulate") - def test_example_conversion(self, operation_context, requests_mock): - job_id = operation_context["tds_simulation"]["id"] - - config_id = operation_context["request"]["model_config_id"] - model = json.loads(operation_context["fetch"](config_id + ".json")) - requests_mock.get(f"{TDS_URL}/model_configurations/{config_id}", json=model) - - ### Act and Assert - - operation_request = Simulate(**operation_context["request"]) - kwargs = operation_request.gen_pyciemss_args(job_id) - - assert kwargs.get("visual_options", False) - assert is_satisfactory(kwargs, load_and_sample_petri_model) - - -class TestCalibrate: - @pytest.mark.operation("calibrate") - def test_example_conversion(self, operation_context, requests_mock): - job_id = operation_context["tds_simulation"]["id"] - - config_id = operation_context["request"]["model_config_id"] - model = json.loads(operation_context["fetch"](config_id + ".json")) - requests_mock.get(f"{TDS_URL}/model_configurations/{config_id}", json=model) - - dataset_id = operation_context["request"]["dataset"]["id"] - filename = operation_context["request"]["dataset"]["filename"] - dataset = operation_context["fetch"](filename, True) - dataset_loc = {"method": "GET", "url": dataset} - requests_mock.get( - f"{TDS_URL}/datasets/{dataset_id}/download-url?filename={filename}", - json=dataset_loc, - ) - - ### Act and Assert - - with patch("service.models.gen_rabbitmq_hook", return_value=lambda _: None): - operation_request = Calibrate(**operation_context["request"]) - kwargs = operation_request.gen_pyciemss_args(job_id) - - assert kwargs.get("visual_options", False) - assert is_satisfactory(kwargs, load_and_calibrate_and_sample_petri_model) - - -class TestEnsembleSimulate: - @pytest.mark.operation("ensemble-simulate") - def test_example_conversion(self, operation_context, requests_mock): - job_id = operation_context["tds_simulation"]["id"] - - config_ids = [ - config["id"] for config in operation_context["request"]["model_configs"] - ] - for config_id in config_ids: - model = json.loads(operation_context["fetch"](config_id + ".json")) - requests_mock.get(f"{TDS_URL}/model_configurations/{config_id}", json=model) - - ### Act and Assert - - operation_request = EnsembleSimulate(**operation_context["request"]) - kwargs = operation_request.gen_pyciemss_args(job_id) - - assert kwargs.get("visual_options", False) - assert is_satisfactory(kwargs, load_and_sample_petri_ensemble) - - -class TestEnsembleCalibrate: - @pytest.mark.operation("ensemble-calibrate") - def test_example_conversion(self, operation_context, requests_mock): - job_id = operation_context["tds_simulation"]["id"] - - config_ids = [ - config["id"] for config in operation_context["request"]["model_configs"] - ] - for config_id in config_ids: - model = 
json.loads(operation_context["fetch"](config_id + ".json")) - requests_mock.get(f"{TDS_URL}/model_configurations/{config_id}", json=model) - - dataset_id = operation_context["request"]["dataset"]["id"] - filename = operation_context["request"]["dataset"]["filename"] - dataset = operation_context["fetch"](filename, True) - dataset_loc = {"method": "GET", "url": dataset} - requests_mock.get( - f"{TDS_URL}/datasets/{dataset_id}/download-url?filename={filename}", - json=dataset_loc, - ) - requests_mock.get("http://dataset", text=dataset) - - ### Act and Assert - - operation_request = EnsembleCalibrate(**operation_context["request"]) - kwargs = operation_request.gen_pyciemss_args(job_id) - - assert kwargs.get("visual_options", False) - assert is_satisfactory(kwargs, load_and_calibrate_and_sample_ensemble_model)
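
The test suite added above hinges on one pattern: api.py now receives its Redis connection through FastAPI dependency injection (Depends(get_redis)), so tests can substitute an in-memory fakeredis connection and drain the RQ queue with an in-process SimpleWorker. Below is a minimal sketch of that wiring, condensed from tests/integration/conftest.py; the variable names (fake_redis, client, worker) are illustrative only, and all calls (FakeStrictRedis, TestClient, dependency_overrides, Queue, SimpleWorker) are the same ones the new conftest uses.

    from fakeredis import FakeStrictRedis
    from fastapi.testclient import TestClient
    from rq import Queue, SimpleWorker

    from service.api import app, get_redis

    # In-memory stand-in for the Redis server used in production.
    fake_redis = FakeStrictRedis()

    # Every endpoint parameter declared as Depends(get_redis) now resolves to
    # the fake connection instead of Redis(settings.REDIS_HOST, settings.REDIS_PORT).
    app.dependency_overrides[get_redis] = lambda: fake_redis
    client = TestClient(app)

    # Jobs enqueued by create_job() land in this queue; a burst-mode
    # SimpleWorker executes them synchronously in the test process, so no
    # Redis server or separate worker container is needed during tests.
    queue = Queue(connection=fake_redis, default_timeout=-1)
    worker = SimpleWorker([queue], connection=fake_redis)
    worker.work(burst=True)  # returns once the queue has been drained

The same idea motivates removing the module-level redis/queue objects from rq_helpers.py: constructing the connection inside get_redis() (and the Queue inside create_job()) keeps import time free of side effects and makes the connection swappable per request, which is what the integration tests rely on.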