diff --git a/neo4j-app/neo4j_app/app/config.py b/neo4j-app/neo4j_app/app/config.py index bb2a2ff5..2822f28d 100644 --- a/neo4j-app/neo4j_app/app/config.py +++ b/neo4j-app/neo4j_app/app/config.py @@ -29,10 +29,13 @@ "log_level", ] +_DEFAULT_ASYNC_DEPS = "neo4j_app.tasks.ASYNC_APP_LIFESPAN_DEPS" +_DEFAULT_DEPS = "neo4j_app.app.dependencies.HTTP_SERVICE_LIFESPAN_DEPS" + class ServiceConfig(AppConfig): - neo4j_app_async_dependencies: Optional[str] = "neo4j_app.tasks.WORKER_LIFESPAN_DEPS" - neo4j_app_async_app: str = "neo4j_app.tasks.app" + neo4j_app_async_app: Optional[str] = "neo4j_app.tasks.app" + neo4j_app_dependencies: Optional[str] = _DEFAULT_DEPS neo4j_app_gunicorn_workers: int = 1 neo4j_app_host: str = "127.0.0.1" neo4j_app_n_async_workers: int = 1 @@ -50,11 +53,6 @@ def to_worker_config(self, **kwargs) -> WorkerConfig: kwargs = copy(kwargs) for suffix in _SHARED_WITH_NEO4J_WORKER_CONFIG_PREFIXED: kwargs[suffix] = getattr(self, f"neo4j_app_{suffix}") - - if self.test: - from neo4j_app.tests.icij_worker.conftest import MockWorkerConfig - - return MockWorkerConfig(**kwargs) from neo4j_app.icij_worker.worker.neo4j import Neo4jWorkerConfig for k in _SHARED_WITH_NEO4J_WORKER_CONFIG: diff --git a/neo4j-app/neo4j_app/app/dependencies.py b/neo4j-app/neo4j_app/app/dependencies.py index 03db2b7b..3e5b7f49 100644 --- a/neo4j-app/neo4j_app/app/dependencies.py +++ b/neo4j-app/neo4j_app/app/dependencies.py @@ -3,9 +3,8 @@ import os import tempfile from contextlib import asynccontextmanager -from multiprocessing.managers import SyncManager from pathlib import Path -from typing import Optional, cast +from typing import Dict, Optional, cast import neo4j from fastapi import FastAPI @@ -15,12 +14,14 @@ from neo4j_app.icij_worker import ( EventPublisher, Neo4jEventPublisher, + WorkerConfig, ) from neo4j_app.icij_worker.backend.backend import WorkerBackend from neo4j_app.icij_worker.task_manager import TaskManager from neo4j_app.icij_worker.task_manager.neo4j import Neo4JTaskManager from neo4j_app.icij_worker.utils import run_deps from neo4j_app.icij_worker.utils.dependencies import DependencyInjectionError +from neo4j_app.icij_worker.utils.imports import import_variable from neo4j_app.tasks.dependencies import ( config_enter, create_project_registry_db_enter, @@ -40,10 +41,7 @@ _EVENT_PUBLISHER: Optional[EventPublisher] = None _MP_CONTEXT = None _NEO4J_DRIVER: Optional[neo4j.AsyncDriver] = None -_PROCESS_MANAGER: Optional[SyncManager] = None _TASK_MANAGER: Optional[TaskManager] = None -_TEST_DB_FILE: Optional[Path] = None -_TEST_LOCK: Optional[multiprocessing.Lock] = None _WORKER_POOL_IS_RUNNING = False @@ -89,86 +87,16 @@ def lifespan_mp_context(): return _MP_CONTEXT -def test_db_path_enter(**_): - config = cast( - ServiceConfig, - lifespan_config(), - ) - if config.test: - # pylint: disable=consider-using-with - from neo4j_app.tests.icij_worker.conftest import DBMixin - - global _TEST_DB_FILE - _TEST_DB_FILE = tempfile.NamedTemporaryFile(prefix="db", suffix=".json") - - DBMixin.fresh_db(Path(_TEST_DB_FILE.name)) - _TEST_DB_FILE.__enter__() # pylint: disable=unnecessary-dunder-call - - -def test_db_path_exit(exc_type, exc_value, trace): - if _TEST_DB_FILE is not None: - _TEST_DB_FILE.__exit__(exc_type, exc_value, trace) - - -def _lifespan_test_db_path() -> Path: - if _TEST_DB_FILE is None: - raise DependencyInjectionError("test db path") - return Path(_TEST_DB_FILE.name) - - -def test_process_manager_enter(**_): - global _PROCESS_MANAGER - _PROCESS_MANAGER = lifespan_mp_context().Manager() - - -def test_process_manager_exit(exc_type, exc_value, trace): - _PROCESS_MANAGER.__exit__(exc_type, exc_value, trace) - - -def lifespan_test_process_manager() -> SyncManager: - if _PROCESS_MANAGER is None: - raise DependencyInjectionError("process manager") - return _PROCESS_MANAGER - - -def _test_lock_enter(**_): - config = cast( - ServiceConfig, - lifespan_config(), - ) - if config.test: - global _TEST_LOCK - _TEST_LOCK = lifespan_test_process_manager().Lock() - - -def _lifespan_test_lock() -> multiprocessing.Lock: - if _TEST_LOCK is None: - raise DependencyInjectionError("test lock") - return cast(multiprocessing.Lock, _TEST_LOCK) - - def lifespan_worker_pool_is_running() -> bool: return _WORKER_POOL_IS_RUNNING def task_manager_enter(**_): global _TASK_MANAGER - config = cast( - ServiceConfig, - lifespan_config(), + config = cast(ServiceConfig, lifespan_config()) + _TASK_MANAGER = Neo4JTaskManager( + lifespan_neo4j_driver(), max_queue_size=config.neo4j_app_task_queue_size ) - if config.test: - from neo4j_app.tests.icij_worker.conftest import MockManager - - _TASK_MANAGER = MockManager( - _lifespan_test_db_path(), - _lifespan_test_lock(), - max_queue_size=config.neo4j_app_task_queue_size, - ) - else: - _TASK_MANAGER = Neo4JTaskManager( - lifespan_neo4j_driver(), max_queue_size=config.neo4j_app_task_queue_size - ) def lifespan_task_manager() -> TaskManager: @@ -179,18 +107,7 @@ def lifespan_task_manager() -> TaskManager: def event_publisher_enter(**_): global _EVENT_PUBLISHER - config = cast( - ServiceConfig, - lifespan_config(), - ) - if config.test: - from neo4j_app.tests.icij_worker.conftest import MockEventPublisher - - _EVENT_PUBLISHER = MockEventPublisher( - _lifespan_test_db_path(), _lifespan_test_lock() - ) - else: - _EVENT_PUBLISHER = Neo4jEventPublisher(lifespan_neo4j_driver()) + _EVENT_PUBLISHER = Neo4jEventPublisher(lifespan_neo4j_driver()) def lifespan_event_publisher() -> EventPublisher: @@ -200,27 +117,31 @@ def lifespan_event_publisher() -> EventPublisher: @asynccontextmanager -async def run_app_deps(app: FastAPI): +async def run_http_service_deps( + app: FastAPI, + async_app: str, + worker_config: WorkerConfig, + worker_extras: Optional[Dict] = None, +): config = app.state.config n_workers = config.neo4j_app_n_async_workers - async with run_deps( - dependencies=FASTAPI_LIFESPAN_DEPS, ctx="FastAPI HTTP server", config=config - ): + deps = import_variable(config.neo4j_app_dependencies) + async with run_deps(dependencies=deps, ctx="FastAPI HTTP server", config=config): + # Compute the support only once we know the neo4j driver deps has successfully + # completed app.state.config = await config.with_neo4j_support() - worker_extras = {"teardown_dependencies": config.test} - config_extra = dict() - # Forward the past of the app config to load to the async app - async_app_extras = {"config_path": _lifespan_async_app_config_path()} - if config.test: - config_extra["db_path"] = _lifespan_test_db_path() - worker_extras["lock"] = _lifespan_test_lock() - worker_config = config.to_worker_config(**config_extra) + # config_extra = dict() + # # Forward the part of the app config to load to the async app + # async_app_extras = {"config_path": _lifespan_async_app_config_path()} + # if is_test: + # config_extra["db_path"] = _lifespan_test_db_path() + # TODO 1: set the async app config path inside the deps itself + # TODO 3: set the DB path in deps with WorkerBackend.MULTIPROCESSING.run_cm( - config.neo4j_app_async_app, + async_app, n_workers=n_workers, config=worker_config, worker_extras=worker_extras, - app_deps_extras=async_app_extras, ): global _WORKER_POOL_IS_RUNNING _WORKER_POOL_IS_RUNNING = True @@ -228,7 +149,7 @@ async def run_app_deps(app: FastAPI): _WORKER_POOL_IS_RUNNING = False -FASTAPI_LIFESPAN_DEPS = [ +HTTP_SERVICE_LIFESPAN_DEPS = [ ("configuration reading", config_enter, None), ("loggers setup", loggers_enter, None), ( @@ -240,9 +161,6 @@ async def run_app_deps(app: FastAPI): ("neo4j project registry creation", create_project_registry_db_enter, None), ("ES client creation", es_client_enter, es_client_exit), (None, mp_context_enter, None), - (None, test_process_manager_enter, test_process_manager_exit), - (None, test_db_path_enter, test_db_path_exit), - (None, _test_lock_enter, None), ("task manager creation", task_manager_enter, None), ("event publisher creation", event_publisher_enter, None), ("neo4j DB migration", migrate_app_db_enter, None), diff --git a/neo4j-app/neo4j_app/app/utils.py b/neo4j-app/neo4j_app/app/utils.py index aca9271c..704419bb 100644 --- a/neo4j-app/neo4j_app/app/utils.py +++ b/neo4j-app/neo4j_app/app/utils.py @@ -1,3 +1,4 @@ +import functools import logging import traceback from typing import Dict, Iterable, List, Optional @@ -13,7 +14,7 @@ from neo4j_app.app import ServiceConfig from neo4j_app.app.admin import admin_router -from neo4j_app.app.dependencies import run_app_deps +from neo4j_app.app.dependencies import run_http_service_deps from neo4j_app.app.doc import DOCUMENT_TAG, NE_TAG, OTHER_TAG from neo4j_app.app.documents import documents_router from neo4j_app.app.graphs import graphs_router @@ -21,7 +22,7 @@ from neo4j_app.app.named_entities import named_entities_router from neo4j_app.app.projects import projects_router from neo4j_app.app.tasks import tasks_router -from neo4j_app.icij_worker import AsyncApp +from neo4j_app.icij_worker import WorkerConfig INTERNAL_SERVER_ERROR = "Internal Server Error" _REQUEST_VALIDATION_ERROR = "Request Validation Error" @@ -83,15 +84,29 @@ def _debug(): logger.info("im here") -def create_app(config: ServiceConfig, async_app: Optional[AsyncApp] = None) -> FastAPI: +def create_app( + config: ServiceConfig, + async_app: Optional[str] = None, + worker_config: WorkerConfig = None, + worker_extras: Optional[Dict] = None, +) -> FastAPI: + if bool(async_app) == bool(config.neo4j_app_async_app): + raise ValueError("Please provide exactly one config") + async_app = async_app or config.neo4j_app_async_app + if worker_config is None: + worker_config = config.to_worker_config() + lifespan = functools.partial( + run_http_service_deps, + async_app=async_app, + worker_config=worker_config, + worker_extras=worker_extras, + ) app = FastAPI( title=config.doc_app_name, openapi_tags=_make_open_api_tags([DOCUMENT_TAG, NE_TAG, OTHER_TAG]), - lifespan=run_app_deps, + lifespan=lifespan, ) app.state.config = config - if async_app is not None: - app.state.async_app = async_app app.add_exception_handler(RequestValidationError, request_validation_error_handler) app.add_exception_handler(StarletteHTTPException, http_exception_handler) app.add_exception_handler(Exception, internal_exception_handler) diff --git a/neo4j-app/neo4j_app/config.py b/neo4j-app/neo4j_app/config.py index c554f17e..16796c1d 100644 --- a/neo4j-app/neo4j_app/config.py +++ b/neo4j-app/neo4j_app/config.py @@ -22,6 +22,7 @@ from neo4j_app.core.utils.pydantic import ( IgnoreExtraModel, LowerCamelCaseModel, + NoEnumModel, safe_copy, ) @@ -32,7 +33,7 @@ def _es_version() -> str: return ".".join(str(num) for num in elasticsearch.__version__) -class AppConfig(LowerCamelCaseModel, IgnoreExtraModel): +class AppConfig(LowerCamelCaseModel, IgnoreExtraModel, NoEnumModel): elasticsearch_address: str = "http://127.0.0.1:9200" elasticsearch_version: str = Field(default_factory=_es_version, const=True) es_doc_type_field: str = Field(alias="docTypeField", default="type") @@ -44,7 +45,7 @@ class AppConfig(LowerCamelCaseModel, IgnoreExtraModel): es_keep_alive: str = "1m" force_migrations: bool = False neo4j_app_log_level: str = "INFO" - neo4j_app_cancelled_task_refresh_interval_s: int = 2 + neo4j_app_cancelled_tasks_refresh_interval_s: int = 2 neo4j_app_log_in_json: bool = False neo4j_app_max_dumped_documents: Optional[int] = None neo4j_app_max_records_in_memory: int = int(1e6) diff --git a/neo4j-app/neo4j_app/icij_worker/backend/backend.py b/neo4j-app/neo4j_app/icij_worker/backend/backend.py index 6edddf45..8305368d 100644 --- a/neo4j-app/neo4j_app/icij_worker/backend/backend.py +++ b/neo4j-app/neo4j_app/icij_worker/backend/backend.py @@ -20,7 +20,6 @@ def run( n_workers: int, config: WorkerConfig, worker_extras: Optional[Dict] = None, - app_deps_extras: Optional[Dict] = None, ): # This function is meant to be run as the main function of a Python command, # in this case we want th main process to handle signals @@ -30,7 +29,6 @@ def run( config, handle_signals=True, worker_extras=worker_extras, - app_deps_extras=app_deps_extras, ): pass @@ -44,7 +42,6 @@ def run_cm( n_workers: int, config: WorkerConfig, worker_extras: Optional[Dict] = None, - app_deps_extras: Optional[Dict] = None, ): # This usage is meant for when a backend is run from another process which # handles signals by itself @@ -54,7 +51,6 @@ def run_cm( config, handle_signals=False, worker_extras=worker_extras, - app_deps_extras=app_deps_extras, ): yield @@ -67,7 +63,6 @@ def _run_cm( *, handle_signals: bool = False, worker_extras: Optional[Dict] = None, - app_deps_extras: Optional[Dict] = None, ): if self is WorkerBackend.MULTIPROCESSING: with run_workers_with_multiprocessing( @@ -76,7 +71,6 @@ def _run_cm( config, handle_signals=handle_signals, worker_extras=worker_extras, - app_deps_extras=app_deps_extras, ): yield else: diff --git a/neo4j-app/neo4j_app/icij_worker/backend/mp.py b/neo4j-app/neo4j_app/icij_worker/backend/mp.py index 4a0298c9..1c017923 100644 --- a/neo4j-app/neo4j_app/icij_worker/backend/mp.py +++ b/neo4j-app/neo4j_app/icij_worker/backend/mp.py @@ -22,15 +22,12 @@ def _mp_work_forever( worker_id: str, *, worker_extras: Optional[Dict] = None, - app_deps_extras: Optional[Dict] = None, ): - if app_deps_extras is None: - app_deps_extras = dict() if worker_extras is None: worker_extras = dict() # For multiprocessing, lifespan dependencies need to be run once per process app = AsyncApp.load(app) - deps_cm = app.lifetime_dependencies(worker_id=worker_id, **app_deps_extras) + deps_cm = app.lifetime_dependencies(worker_id=worker_id) worker = Worker.from_config(config, app=app, worker_id=worker_id, **worker_extras) # This is ugly, but we have to work around the fact that we can't use asyncio code # here @@ -68,7 +65,6 @@ def run_workers_with_multiprocessing( *, handle_signals: bool = True, worker_extras: Optional[Dict] = None, - app_deps_extras: Optional[Dict] = None, ): logger.info("Creating multiprocessing worker pool with %s workers", n_workers) # Here we set maxtasksperchild to 1. Each worker has a single never ending task @@ -81,17 +77,15 @@ def run_workers_with_multiprocessing( worker_ids = [f"worker-{main_process_id}-{i}" for i in range(n_workers)] kwds = {"app": app, "config": config} kwds["worker_extras"] = worker_extras - kwds["app_deps_extras"] = app_deps_extras pool = mp_ctx.Pool(n_workers, maxtasksperchild=1) logger.debug("Setting up signal handlers...") - tasks = [] if handle_signals: setup_main_process_signal_handlers(pool) try: for w_id in worker_ids: kwds.update({"worker_id": w_id}) logger.info("starting worker %s", w_id) - tasks.append(pool.apply_async(_mp_work_forever, kwds=kwds)) + pool.apply_async(_mp_work_forever, kwds=kwds) yield except KeyboardInterrupt as e: if not handle_signals: diff --git a/neo4j-app/neo4j_app/icij_worker/utils/imports.py b/neo4j-app/neo4j_app/icij_worker/utils/imports.py new file mode 100644 index 00000000..ce6420a3 --- /dev/null +++ b/neo4j-app/neo4j_app/icij_worker/utils/imports.py @@ -0,0 +1,21 @@ +import importlib +from typing import Any + + +class VariableNotFound(ImportError): + pass + + +def import_variable(name: str) -> Any: + parts = name.split(".") + submodule = ".".join(parts[:-1]) + variable_name = parts[-1] + try: + module = importlib.import_module(submodule) + except ModuleNotFoundError as e: + raise VariableNotFound(e.msg) from e + try: + subclass = getattr(module, variable_name) + except AttributeError as e: + raise VariableNotFound(e) from e + return subclass diff --git a/neo4j-app/neo4j_app/icij_worker/utils/registrable.py b/neo4j-app/neo4j_app/icij_worker/utils/registrable.py index 74e2ce5f..ccd5aefe 100644 --- a/neo4j-app/neo4j_app/icij_worker/utils/registrable.py +++ b/neo4j-app/neo4j_app/icij_worker/utils/registrable.py @@ -2,7 +2,6 @@ Simplified implementation of AllenNLP Registrable: https://github.com/allenai/allennlp """ -import importlib import logging from abc import ABC from collections import defaultdict @@ -21,6 +20,7 @@ from pydantic import BaseSettings, Field from neo4j_app.icij_worker.utils.from_config import FromConfig, T +from neo4j_app.icij_worker.utils.imports import VariableNotFound, import_variable logger = logging.getLogger(__name__) @@ -77,25 +77,18 @@ def resolve_class_name(cls: Type[_RegistrableT], name: str) -> Type[_Registrable subclass = Registrable._registry[cls][name] return subclass if "." in name: - # Fully qualified class name - parts = name.split(".") - submodule = ".".join(parts[:-1]) - class_name = parts[-1] - try: - module = importlib.import_module(submodule) + subclass = import_variable(name) except ModuleNotFoundError as e: raise ValueError( f"tried to interpret {name} as a path to a class " - f"but unable to import module {submodule}" + f"but unable to import module {'.'.join(name.split('.')[:-1])}" ) from e - - try: - subclass = getattr(module, class_name) - except AttributeError as e: + except VariableNotFound as e: + split = name.split(".") raise ValueError( f"tried to interpret {name} as a path to a class " - f"but unable to find class {class_name} in {submodule}" + f"but unable to find class {split[-1]} in {split[:-1]}" ) from e return subclass available = "\n-".join(cls.list_available()) diff --git a/neo4j-app/neo4j_app/run/run.py b/neo4j-app/neo4j_app/run/run.py index d9845958..a5540cb7 100644 --- a/neo4j-app/neo4j_app/run/run.py +++ b/neo4j-app/neo4j_app/run/run.py @@ -16,12 +16,6 @@ from neo4j_app.core.utils.logging import DATE_FMT, STREAM_HANDLER_FMT -def debug_app(): - config = ServiceConfig() - app = create_app(config) - return app - - class Formatter(argparse.ArgumentDefaultsHelpFormatter): def __init__(self, prog): super().__init__(prog, max_help_position=35, width=150) diff --git a/neo4j-app/neo4j_app/tasks/__init__.py b/neo4j-app/neo4j_app/tasks/__init__.py index 8f7111df..97c90da8 100644 --- a/neo4j-app/neo4j_app/tasks/__init__.py +++ b/neo4j-app/neo4j_app/tasks/__init__.py @@ -1,2 +1,2 @@ -from .app import WORKER_LIFESPAN_DEPS, app +from .app import ASYNC_APP_LIFESPAN_DEPS, app from .imports import * diff --git a/neo4j-app/neo4j_app/tasks/app.py b/neo4j-app/neo4j_app/tasks/app.py index cc678ec6..0c074c69 100644 --- a/neo4j-app/neo4j_app/tasks/app.py +++ b/neo4j-app/neo4j_app/tasks/app.py @@ -1,9 +1,9 @@ import logging from neo4j_app.icij_worker import AsyncApp -from neo4j_app.tasks.dependencies import WORKER_LIFESPAN_DEPS +from neo4j_app.tasks.dependencies import ASYNC_APP_LIFESPAN_DEPS logger = logging.getLogger(__name__) -app = AsyncApp(name="neo4j-app", dependencies=WORKER_LIFESPAN_DEPS) +app = AsyncApp(name="neo4j-app", dependencies=ASYNC_APP_LIFESPAN_DEPS) diff --git a/neo4j-app/neo4j_app/tasks/dependencies.py b/neo4j-app/neo4j_app/tasks/dependencies.py index af3d867a..85c2658d 100644 --- a/neo4j-app/neo4j_app/tasks/dependencies.py +++ b/neo4j-app/neo4j_app/tasks/dependencies.py @@ -4,6 +4,7 @@ import neo4j +from neo4j_app.app import ServiceConfig from neo4j_app.core.elasticsearch import ESClientABC from neo4j_app.core.neo4j import MIGRATIONS, migrate_db_schemas from neo4j_app.core.neo4j.migrations import delete_all_migrations @@ -13,7 +14,7 @@ logger = logging.getLogger(__name__) -_CONFIG: Optional[AppConfig] = None +_CONFIG: Optional[ServiceConfig] = None _ASYNC_APP_CONFIG: Optional[AppConfig] = None _ES_CLIENT: Optional[ESClientABC] = None _ASYNC_APP_CONFIG_PATH: Optional[Path] = None @@ -41,7 +42,7 @@ async def config_neo4j_support_enter(**_): _CONFIG = await config.with_neo4j_support() -def lifespan_config() -> AppConfig: +def lifespan_config() -> ServiceConfig: if _CONFIG is None: raise DependencyInjectionError("config") return _CONFIG @@ -117,7 +118,7 @@ async def migrate_app_db_enter(**_): ) -WORKER_LIFESPAN_DEPS = [ +ASYNC_APP_LIFESPAN_DEPS = [ ("configuration loading", config_from_path_enter, None), ("loggers setup", loggers_enter, None), ("neo4j driver creation", neo4j_driver_enter, neo4j_driver_exit), diff --git a/neo4j-app/neo4j_app/tests/app/test_tasks.py b/neo4j-app/neo4j_app/tests/app/test_tasks.py index c4bb22f8..00c13392 100644 --- a/neo4j-app/neo4j_app/tests/app/test_tasks.py +++ b/neo4j-app/neo4j_app/tests/app/test_tasks.py @@ -1,5 +1,6 @@ # pylint: disable=redefined-outer-name from functools import partial +from pathlib import Path from typing import Optional import neo4j @@ -7,34 +8,37 @@ from _pytest.fixtures import FixtureRequest from starlette.testclient import TestClient -from neo4j_app.app.utils import create_app from neo4j_app.app.config import ServiceConfig, WorkerType +from neo4j_app.app.utils import create_app from neo4j_app.core.objects import TaskJob from neo4j_app.core.utils.logging import DifferedLoggingMessage from neo4j_app.core.utils.pydantic import safe_copy from neo4j_app.icij_worker import AsyncApp, Task, TaskStatus from neo4j_app.tests.conftest import TEST_PROJECT, test_error_router, true_after +from neo4j_app.tests.icij_worker.conftest import MockServiceConfig, MockWorkerConfig @pytest.fixture(scope="function") def test_client_prod( - test_config: ServiceConfig, - test_async_app: AsyncApp, + test_config: MockServiceConfig, # Wipe neo4j and init project neo4j_app_driver: neo4j.AsyncSession, ) -> TestClient: # pylint: disable=unused-argument - config = safe_copy( - test_config, update={"neo4j_app_worker_type": WorkerType.neo4j, "test": False} - ) - new_async_app = AsyncApp( - name=test_async_app.name, - dependencies=test_async_app._dependencies, # pylint: disable=protected-access - ) - new_async_app._registry = ( # pylint: disable=protected-access - test_async_app.registry + prod_deps = "neo4j_app.app.dependencies.HTTP_SERVICE_LIFESPAN_DEPS" + config_as_dict = test_config.dict(exclude_unset=True) + update = { + "neo4j_app_async_app": None, + "neo4j_app_dependencies": prod_deps, + "neo4j_app_worker_type": WorkerType.neo4j, + } + config_as_dict.update(update) + config = ServiceConfig(**config_as_dict) + app = create_app( + config, + async_app="neo4j_app.tests.conftest.APP", + worker_extras={"teardown_dependencies": False}, ) - app = create_app(config, async_app=new_async_app) # Add a router which generates error in order to test error handling app.include_router(test_error_router()) with TestClient(app) as client: @@ -96,7 +100,10 @@ def test_task_should_return_200_for_existing_task( @pytest.mark.parametrize( "test_client_type", - ["test_client_with_async", "test_client_prod"], + [ + # "test_client_with_async", + "test_client_prod", + ], ) def test_task_integration(test_client_type: str, request: FixtureRequest): # Given @@ -147,11 +154,17 @@ def test_cancel_task(test_client: TestClient): assert cancelled.status is TaskStatus.CANCELLED +_ASYNC_APP_LIMITED_QUEUE = None + + @pytest.fixture(scope="function") def test_client_limited_queue( - test_config: ServiceConfig, test_async_app: AsyncApp + test_config: MockServiceConfig, test_async_app: AsyncApp, mock_db: Path ) -> TestClient: - config = safe_copy(test_config, update={"neo4j_app_task_queue_size": 0}) + config = safe_copy( + test_config, + update={"neo4j_app_task_queue_size": 0, "neo4j_app_async_app": None}, + ) new_async_app = AsyncApp( name=test_async_app.name, dependencies=test_async_app._dependencies, # pylint: disable=protected-access @@ -159,8 +172,17 @@ def test_client_limited_queue( new_async_app._registry = ( # pylint: disable=protected-access test_async_app.registry ) - app = create_app(config, async_app=new_async_app) - # Add a router which generates error in order to test error handling + global _ASYNC_APP_LIMITED_QUEUE + _ASYNC_APP_LIMITED_QUEUE = new_async_app + worker_extras = {"teardown_dependencies": False} + worker_config = MockWorkerConfig(db_path=mock_db) + app = create_app( + config, + worker_config=worker_config, + async_app=f"{__name__}._ASYNC_APP_LIMITED_QUEUE", + worker_extras=worker_extras, + ) + # Add a rout0er which generates error in order to test error handling app.include_router(test_error_router()) with TestClient(app) as client: yield client diff --git a/neo4j-app/neo4j_app/tests/conftest.py b/neo4j-app/neo4j_app/tests/conftest.py index 60245ac1..469175f4 100644 --- a/neo4j-app/neo4j_app/tests/conftest.py +++ b/neo4j-app/neo4j_app/tests/conftest.py @@ -2,8 +2,10 @@ import abc import asyncio import contextlib +import functools import os import random +import tempfile import traceback from copy import copy from datetime import datetime @@ -16,6 +18,7 @@ Callable, Dict, Generator, + List, Optional, Tuple, Union, @@ -34,6 +37,9 @@ from neo4j_app.app.dependencies import ( config_enter, loggers_enter, + mp_context_enter, + write_async_app_config_enter, + write_async_app_config_exit, ) from neo4j_app.app.utils import create_app from neo4j_app.core.elasticsearch import ESClient, ESClientABC @@ -43,6 +49,24 @@ from neo4j_app.core.neo4j.projects import NEO4J_COMMUNITY_DB from neo4j_app.core.utils.pydantic import BaseICIJModel from neo4j_app.icij_worker import AsyncApp, WorkerType +from neo4j_app.icij_worker.typing_ import Dependency +from neo4j_app.tasks.dependencies import ( + config_from_path_enter, + create_project_registry_db_enter, + es_client_enter, + es_client_exit, + lifespan_config, + migrate_app_db_enter, + neo4j_driver_enter, + neo4j_driver_exit, +) +from neo4j_app.tests.icij_worker.conftest import ( + DBMixin, + MockEventPublisher, + MockManager, + MockServiceConfig, + MockWorkerConfig, +) from neo4j_app.typing_ import PercentProgress # TODO: at a high level it's a waste to have to repeat code for each fixture level, @@ -130,6 +154,21 @@ async def _mocked_search(self, **kwargs): pass +@pytest.fixture(scope="session") +def mock_db_session() -> Path: + with tempfile.NamedTemporaryFile(prefix="mock-db", suffix=".json") as f: + db_path = Path(f.name) + DBMixin.fresh_db(db_path) + yield db_path + + +@pytest.fixture +def mock_db(mock_db_session: Path) -> Path: + # Wipe the DB + DBMixin.fresh_db(mock_db_session) + return mock_db_session + + # Define a session level even_loop fixture to overcome limitation explained here: # https://github.com/tortoise/tortoise-orm/issues/638#issuecomment-830124562 @pytest.fixture(scope="session") @@ -140,32 +179,89 @@ def event_loop(): loop.close() +_MOCKED_HTTP_DEPS = None + + @pytest.fixture(scope="session") -def test_config() -> ServiceConfig: - config = ServiceConfig( +def test_config(mock_db_session: Path) -> ServiceConfig: + global _MOCKED_HTTP_DEPS + _MOCKED_HTTP_DEPS = _mock_http_deps(mock_db_session) + config = MockServiceConfig( elasticsearch_address=f"http://127.0.0.1:{ELASTICSEARCH_TEST_PORT}", es_default_page_size=5, + neo4j_app_async_app=f"{__name__}.APP", + neo4j_app_dependencies=f"{__name__}._MOCKED_HTTP_DEPS", neo4j_app_host="127.0.0.1", + neo4j_app_worker_type=WorkerType.mock, + neo4j_password=NEO4J_TEST_PASSWORD, neo4j_port=NEO4J_TEST_PORT, neo4j_user=NEO4J_TEST_USER, - neo4j_password=NEO4J_TEST_PASSWORD, - neo4j_app_worker_type=WorkerType.mock, - test=True, - neo4j_app_async_app=f"{__name__}.APP", - neo4j_app_async_dependencies=f"{__name__}.TEST_WORKER_DEPS", ) return config -TEST_WORKER_DEPS = [ - ("config reading", config_enter, None), - ("loggers setup", loggers_enter, None), -] +def mock_task_manager_enter(db_path: Path, **_): + import neo4j_app.app.dependencies + + config = lifespan_config() + task_manager = MockManager(db_path, config.neo4j_app_task_queue_size) + setattr(neo4j_app.app.dependencies, "_TASK_MANAGER", task_manager) + + +def mock_event_publisher_enter(db_path: Path, **_): + import neo4j_app.app.dependencies + + event_publisher = MockEventPublisher(db_path) + setattr(neo4j_app.app.dependencies, "_EVENT_PUBLISHER", event_publisher) + + +def _mock_http_deps(db_path: Path) -> List[Dependency]: + deps = [ + ("configuration reading", config_enter, None), + ("loggers setup", loggers_enter, None), + ( + "write async config for workers", + write_async_app_config_enter, + write_async_app_config_exit, + ), + ("neo4j driver creation", neo4j_driver_enter, neo4j_driver_exit), + ("neo4j project registry creation", create_project_registry_db_enter, None), + ("neo4j DB migration", migrate_app_db_enter, None), + ("ES client creation", es_client_enter, es_client_exit), + (None, mp_context_enter, None), + ( + "task manager creation", + functools.partial(mock_task_manager_enter, db_path=db_path), + None, + ), + ( + "event publisher creation", + functools.partial(mock_event_publisher_enter, db_path=db_path), + None, + ), + ] + return deps + + +def _mock_async_deps(config_path: Path) -> List[Dependency]: + deps = [ + ( + "configuration loading", + functools.partial(config_from_path_enter, config_path=config_path), + None, + ), + ("loggers setup", loggers_enter, None), + ] + return deps @pytest.fixture(scope="session") -def test_app_session(test_config: ServiceConfig) -> FastAPI: - return create_app(test_config) +def test_app_session(test_config: MockServiceConfig, mock_db_session: Path) -> FastAPI: + worker_extras = {"teardown_dependencies": False} + worker_config = MockWorkerConfig(db_path=mock_db_session) + return create_app( + test_config, worker_config=worker_config, worker_extras=worker_extras + ) @pytest.fixture(scope="session") @@ -192,6 +288,8 @@ def test_client_module( @pytest.fixture() def test_client( test_client_session: TestClient, + # Wipe the mock db + mock_db: Path, # Wipe ES by requiring the "function" level es client es_test_client: ESClient, # Same for neo4j @@ -207,13 +305,16 @@ def test_client_with_async( es_test_client: ESClient, # Same for neo4j neo4j_test_session: neo4j.AsyncSession, - test_async_app: AsyncApp, - test_config: ServiceConfig, + test_config: MockServiceConfig, + mock_db: Path, ) -> Generator[TestClient, None, None]: - # pylint: disable=unused-argument # pylint: disable=unused-argument # Let's recreate the app to wipe the worker pool and queues - app = create_app(test_config, async_app=test_async_app) + worker_extras = {"teardown_dependencies": False} + worker_config = MockWorkerConfig(db_path=mock_db) + app = create_app( + test_config, worker_config=worker_config, worker_extras=worker_extras + ) app.include_router(test_error_router()) with TestClient(app) as client: yield client @@ -651,7 +752,7 @@ async def sleep_for( @pytest.fixture(scope="session") -def test_async_app(test_config: ServiceConfig) -> AsyncApp: +def test_async_app(test_config: MockServiceConfig) -> AsyncApp: return AsyncApp.load(test_config.neo4j_app_async_app) diff --git a/neo4j-app/neo4j_app/tests/icij_worker/conftest.py b/neo4j-app/neo4j_app/tests/icij_worker/conftest.py index 945b9702..e3e8c0ff 100644 --- a/neo4j-app/neo4j_app/tests/icij_worker/conftest.py +++ b/neo4j-app/neo4j_app/tests/icij_worker/conftest.py @@ -3,8 +3,6 @@ import asyncio import json import logging -import multiprocessing -import threading from abc import ABC from datetime import datetime from pathlib import Path @@ -17,8 +15,8 @@ from pydantic import Field from neo4j_app import AppConfig -from neo4j_app.app.dependencies import FASTAPI_LIFESPAN_DEPS -from neo4j_app.core.utils.pydantic import safe_copy +from neo4j_app.app import ServiceConfig +from neo4j_app.core.utils.pydantic import IgnoreExtraModel, safe_copy from neo4j_app.icij_worker import ( AsyncApp, EventPublisher, @@ -70,18 +68,13 @@ class DBMixin(ABC): _error_collection = "errors" _result_collection = "results" - def __init__(self, db_path: Path, lock: threading.Lock | multiprocessing.Lock): + def __init__(self, db_path: Path): self._db_path = db_path - self.__lock = lock @property def db_path(self) -> Path: return self._db_path - @property - def db_lock(self) -> threading.Lock | multiprocessing.Lock: - return self.__lock - def _write(self, data: Dict): self._db_path.write_text(json.dumps(jsonable_encoder(data))) @@ -103,48 +96,40 @@ def fresh_db(cls, db_path: Path): class MockManager(TaskManager, DBMixin): - def __init__( - self, - db_path: Path, - lock: threading.Lock | multiprocessing.Lock, - max_queue_size: int, - ): - super().__init__(db_path, lock) + def __init__(self, db_path: Path, max_queue_size: int): + super().__init__(db_path) self._max_queue_size = max_queue_size async def _enqueue(self, task: Task, project: str) -> Task: key = self._task_key(task_id=task.id, project=project) - with self.db_lock: - db = self._read() - tasks = db[self._task_collection] - n_queued = sum( - 1 for t in tasks.values() if t["status"] == TaskStatus.QUEUED.value - ) - if n_queued > self._max_queue_size: - raise TaskQueueIsFull(self._max_queue_size) - if key in tasks: - raise TaskAlreadyExists(task.id) - update = {"status": TaskStatus.QUEUED} - task = safe_copy(task, update=update) - tasks[key] = task.dict() - self._write(db) - return task + db = self._read() + tasks = db[self._task_collection] + n_queued = sum( + 1 for t in tasks.values() if t["status"] == TaskStatus.QUEUED.value + ) + if n_queued > self._max_queue_size: + raise TaskQueueIsFull(self._max_queue_size) + if key in tasks: + raise TaskAlreadyExists(task.id) + update = {"status": TaskStatus.QUEUED} + task = safe_copy(task, update=update) + tasks[key] = task.dict() + self._write(db) + return task async def _cancel(self, *, task_id: str, project: str) -> Task: key = self._task_key(task_id=task_id, project=project) task_id = await self.get_task(task_id=task_id, project=project) - with self.db_lock: - update = {"status": TaskStatus.CANCELLED} - task_id = safe_copy(task_id, update=update) - db = self._read() - db[self._task_collection][key] = task_id.dict() - self._write(db) - return task_id + update = {"status": TaskStatus.CANCELLED} + task_id = safe_copy(task_id, update=update) + db = self._read() + db[self._task_collection][key] = task_id.dict() + self._write(db) + return task_id async def get_task(self, *, task_id: str, project: str) -> Task: key = self._task_key(task_id=task_id, project=project) - with self.db_lock: - db = self._read() + db = self._read() try: tasks = db[self._task_collection] return Task(**tasks[key]) @@ -153,8 +138,7 @@ async def get_task(self, *, task_id: str, project: str) -> Task: async def get_task_errors(self, task_id: str, project: str) -> List[TaskError]: key = self._task_key(task_id=task_id, project=project) - with self.db_lock: - db = self._read() + db = self._read() errors = db[self._error_collection] errors = errors.get(key, []) errors = [TaskError(**err) for err in errors] @@ -162,8 +146,7 @@ async def get_task_errors(self, task_id: str, project: str) -> List[TaskError]: async def get_task_result(self, task_id: str, project: str) -> TaskResult: key = self._task_key(task_id=task_id, project=project) - with self.db_lock: - db = self._read() + db = self._read() results = db[self._result_collection] try: return TaskResult(**results[key]) @@ -176,8 +159,7 @@ async def get_tasks( task_type: Optional[str] = None, status: Optional[Union[List[TaskStatus], TaskStatus]] = None, ) -> List[Task]: - with self.db_lock: - db = self._read() + db = self._read() tasks = db.values() if status: if isinstance(status, TaskStatus): @@ -190,8 +172,8 @@ async def get_tasks( class MockEventPublisher(DBMixin, EventPublisher): _excluded_from_event_update = {"error"} - def __init__(self, db_path: Path, lock: threading.Lock | multiprocessing.Lock): - super().__init__(db_path, lock) + def __init__(self, db_path: Path): + super().__init__(db_path) self.published_events = [] async def publish_event(self, event: TaskEvent, project: str): @@ -203,33 +185,32 @@ async def publish_event(self, event: TaskEvent, project: str): # Here we choose to reflect the change in the DB since its closer to what will # happen IRL and test integration further key = self._task_key(task_id=event.task_id, project=project) - with self.db_lock: - db = self._read() - try: - task = self._get_db_task(db, task_id=event.task_id, project=project) - task = Task(**task) - except UnknownTask: - task = Task(**Task.mandatory_fields(event, keep_id=True)) - update = task.resolve_event(event) - if update is not None: - task = task.dict(exclude_unset=True, by_alias=True) - update = { - k: v - for k, v in event.dict(by_alias=True, exclude_unset=True).items() - if v is not None - } - if "taskId" in update: - update["id"] = update.pop("taskId") - if "taskType" in update: - update["type"] = update.pop("taskType") - if "error" in update: - update.pop("error") - # The nack is responsible for bumping the retries - if "retries" in update: - update.pop("retries") - task.update(update) - db[self._task_collection][key] = task - self._write(db) + db = self._read() + try: + task = self._get_db_task(db, task_id=event.task_id, project=project) + task = Task(**task) + except UnknownTask: + task = Task(**Task.mandatory_fields(event, keep_id=True)) + update = task.resolve_event(event) + if update is not None: + task = task.dict(exclude_unset=True, by_alias=True) + update = { + k: v + for k, v in event.dict(by_alias=True, exclude_unset=True).items() + if v is not None + } + if "taskId" in update: + update["id"] = update.pop("taskId") + if "taskType" in update: + update["type"] = update.pop("taskType") + if "error" in update: + update.pop("error") + # The nack is responsible for bumping the retries + if "retries" in update: + update.pop("retries") + task.update(update) + db[self._task_collection][key] = task + self._write(db) def _get_db_task(self, db: Dict, task_id: str, project: str) -> Dict: tasks = db[self._task_collection] @@ -239,11 +220,17 @@ def _get_db_task(self, db: Dict, task_id: str, project: str) -> Dict: raise UnknownTask(task_id) from e -class MockWorkerConfig(WorkerConfig): +class MockWorkerConfig(WorkerConfig, IgnoreExtraModel): type: str = Field(const=True, default=WorkerType.mock) + db_path: Path +class MockServiceConfig(ServiceConfig): + def to_worker_config(self, **kwargs) -> WorkerConfig: + return MockWorkerConfig(db_path=kwargs["db_path"]) + + @Worker.register(WorkerType.mock) class MockWorker(Worker, MockEventPublisher): def __init__( @@ -251,11 +238,10 @@ def __init__( app: AsyncApp, worker_id: str, db_path: Path, - lock: Union[threading.Lock, multiprocessing.Lock], **kwargs, ): super().__init__(app, worker_id, **kwargs) - MockEventPublisher.__init__(self, db_path, lock) + MockEventPublisher.__init__(self, db_path) self._worker_id = worker_id self._logger_ = logging.getLogger(__name__) @@ -269,100 +255,92 @@ def _to_config(self) -> MockWorkerConfig: async def _save_result(self, result: TaskResult, project: str): task_key = self._task_key(task_id=result.task_id, project=project) - with self.db_lock: - db = self._read() - db[self._result_collection][task_key] = result - self._write(db) + db = self._read() + db[self._result_collection][task_key] = result + self._write(db) async def _save_error(self, error: TaskError, task: Task, project: str): task_key = self._task_key(task_id=task.id, project=project) - with self.db_lock: - db = self._read() - errors = db[self._error_collection].get(task_key) - if errors is None: - errors = [] - errors.append(error) - db[self._error_collection][task_key] = errors - self._write(db) + db = self._read() + errors = db[self._error_collection].get(task_key) + if errors is None: + errors = [] + errors.append(error) + db[self._error_collection][task_key] = errors + self._write(db) def _get_db_errors(self, task_id: str, project: str) -> List[TaskError]: key = self._task_key(task_id=task_id, project=project) - with self.db_lock: - db = self._read() - errors = db[self._error_collection] - try: - return errors[key] - except KeyError as e: - raise UnknownTask(task_id) from e + db = self._read() + errors = db[self._error_collection] + try: + return errors[key] + except KeyError as e: + raise UnknownTask(task_id) from e def _get_db_result(self, task_id: str, project: str) -> TaskResult: key = self._task_key(task_id=task_id, project=project) - with self.db_lock: - db = self._read() - try: - errors = db[self._result_collection] - return errors[key] - except KeyError as e: - raise UnknownTask(task_id) from e + db = self._read() + try: + errors = db[self._result_collection] + return errors[key] + except KeyError as e: + raise UnknownTask(task_id) from e async def _acknowledge(self, task: Task, project: str, completed_at: datetime): key = self._task_key(task.id, project) - with self.db_lock: - db = self._read() - tasks = db[self._task_collection] - try: - saved_task = tasks[key] - except KeyError as e: - raise UnknownTask(task.id) from e - saved_task = Task(**saved_task) - update = { - "completed_at": completed_at, - "status": TaskStatus.DONE, - "progress": 100.0, - } - tasks[key] = safe_copy(saved_task, update=update) - self._write(db) + db = self._read() + tasks = db[self._task_collection] + try: + saved_task = tasks[key] + except KeyError as e: + raise UnknownTask(task.id) from e + saved_task = Task(**saved_task) + update = { + "completed_at": completed_at, + "status": TaskStatus.DONE, + "progress": 100.0, + } + tasks[key] = safe_copy(saved_task, update=update) + self._write(db) async def _negatively_acknowledge( self, task: Task, project: str, *, requeue: bool ) -> Task: key = self._task_key(task.id, project) - with self.db_lock: - db = self._read() - tasks = db[self._task_collection] - try: - task = tasks[key] - except KeyError as e: - raise UnknownTask(task_id=task.id) from e - task = Task(**task) - if requeue: - update = { - "status": TaskStatus.QUEUED, - "progress": 0.0, - "retries": task.retries or 0 + 1, - } - else: - update = {"status": TaskStatus.ERROR} - task = safe_copy(task, update=update) - tasks[key] = task - self._write(db) - return task + db = self._read() + tasks = db[self._task_collection] + try: + task = tasks[key] + except KeyError as e: + raise UnknownTask(task_id=task.id) from e + task = Task(**task) + if requeue: + update = { + "status": TaskStatus.QUEUED, + "progress": 0.0, + "retries": task.retries or 0 + 1, + } + else: + update = {"status": TaskStatus.ERROR} + task = safe_copy(task, update=update) + tasks[key] = task + self._write(db) + return task async def _refresh_cancelled(self, project: str): - with self.db_lock: - db = self._read() - tasks = db[self._task_collection] - tasks = [Task(**t) for t in tasks.values()] - cancelled = [t.id for t in tasks if t.status is TaskStatus.CANCELLED] - self._cancelled_[project] = set(cancelled) + db = self._read() + tasks = db[self._task_collection] + tasks = [Task(**t) for t in tasks.values()] + cancelled = [t.id for t in tasks if t.status is TaskStatus.CANCELLED] + self._cancelled_[project] = set(cancelled) async def _consume(self) -> Tuple[Task, str]: while "waiting for some task to be available for some project": - with self.db_lock: - db = self._read() - tasks = db[self._task_collection] - tasks = [(k, Task(**t)) for k, t in tasks.items()] - queued = [(k, t) for k, t in tasks if t.status is TaskStatus.QUEUED] + db = self._read() + tasks = db[self._task_collection] + tasks = [(k, Task(**t)) for k, t in tasks.items()] + queued = [(k, t) for k, t in tasks if t.status is TaskStatus.QUEUED] if queued: k, t = min(queued, key=lambda x: x[1].created_at) project = eval(k)[1] # pylint: disable=eval-used @@ -378,7 +356,9 @@ class Recoverable(ValueError): def test_failing_async_app( test_config: AppConfig, # pylint: disable=unused-argument ) -> AsyncApp: - app = AsyncApp(name="test-app", dependencies=FASTAPI_LIFESPAN_DEPS) + from neo4j_app.app.dependencies import HTTP_SERVICE_LIFESPAN_DEPS + + app = AsyncApp(name="test-app", dependencies=HTTP_SERVICE_LIFESPAN_DEPS) already_failed = False @app.task("recovering_task", recover_from=(Recoverable,)) diff --git a/neo4j-app/neo4j_app/tests/icij_worker/worker/conftest.py b/neo4j-app/neo4j_app/tests/icij_worker/worker/conftest.py index 37b7a326..7e9565b1 100644 --- a/neo4j-app/neo4j_app/tests/icij_worker/worker/conftest.py +++ b/neo4j-app/neo4j_app/tests/icij_worker/worker/conftest.py @@ -1,18 +1,16 @@ from __future__ import annotations -import threading from pathlib import Path import pytest -from neo4j_app.app.dependencies import FASTAPI_LIFESPAN_DEPS from neo4j_app.icij_worker import AsyncApp from neo4j_app.tests.icij_worker.conftest import MockWorker @pytest.fixture(scope="module") def test_app() -> AsyncApp: - app = AsyncApp(name="test-app", dependencies=FASTAPI_LIFESPAN_DEPS) + app = AsyncApp(name="test-app", dependencies=[]) @app.task async def hello_word(greeted: str): @@ -22,11 +20,8 @@ async def hello_word(greeted: str): @pytest.fixture(scope="function") -def mock_worker(test_async_app: AsyncApp, tmpdir: Path) -> MockWorker: - db_path = Path(tmpdir) / "db.json" - MockWorker.fresh_db(db_path) - lock = threading.Lock() +def mock_worker(test_async_app: AsyncApp, mock_db: Path) -> MockWorker: worker = MockWorker( - test_async_app, "test-worker", db_path, lock, teardown_dependencies=False + test_async_app, "test-worker", mock_db, teardown_dependencies=False ) return worker diff --git a/neo4j-app/neo4j_app/tests/icij_worker/worker/test_worker.py b/neo4j-app/neo4j_app/tests/icij_worker/worker/test_worker.py index ba64bc9b..a85f9c11 100644 --- a/neo4j-app/neo4j_app/tests/icij_worker/worker/test_worker.py +++ b/neo4j-app/neo4j_app/tests/icij_worker/worker/test_worker.py @@ -3,7 +3,6 @@ import asyncio import logging -import threading from datetime import datetime from pathlib import Path from typing import Any, Dict, Optional @@ -26,11 +25,8 @@ @pytest.fixture(scope="function") -def mock_failing_worker(test_failing_async_app: AsyncApp, tmpdir: Path) -> MockWorker: - db_path = Path(tmpdir) / "db.json" - MockWorker.fresh_db(db_path) - lock = threading.Lock() - worker = MockWorker(test_failing_async_app, "test-worker", db_path, lock) +def mock_failing_worker(test_failing_async_app: AsyncApp, mock_db: Path) -> MockWorker: + worker = MockWorker(test_failing_async_app, "test-worker", mock_db) return worker @@ -40,7 +36,7 @@ def mock_failing_worker(test_failing_async_app: AsyncApp, tmpdir: Path) -> MockW async def test_task_wrapper_run_asyncio_task(mock_worker: MockWorker): # Given worker = mock_worker - task_manager = MockManager(worker.db_path, worker.db_lock, max_queue_size=10) + task_manager = MockManager(worker.db_path, max_queue_size=10) project = TEST_PROJECT created_at = datetime.now() task = Task( @@ -96,7 +92,7 @@ async def test_task_wrapper_run_asyncio_task(mock_worker: MockWorker): async def test_task_wrapper_run_sync_task(mock_worker: MockWorker): # Given worker = mock_worker - task_manager = MockManager(worker.db_path, worker.db_lock, max_queue_size=10) + task_manager = MockManager(worker.db_path, max_queue_size=10) project = TEST_PROJECT created_at = datetime.now() task = Task( @@ -152,7 +148,7 @@ async def test_task_wrapper_should_recover_from_recoverable_error( ): # Given worker = mock_failing_worker - task_manager = MockManager(worker.db_path, worker.db_lock, max_queue_size=10) + task_manager = MockManager(worker.db_path, max_queue_size=10) project = TEST_PROJECT created_at = datetime.now() task = Task( @@ -238,7 +234,7 @@ async def test_task_wrapper_should_handle_non_recoverable_error( ): # Given worker = mock_failing_worker - task_manager = MockManager(worker.db_path, worker.db_lock, max_queue_size=10) + task_manager = MockManager(worker.db_path, max_queue_size=10) project = TEST_PROJECT created_at = datetime.now() task = Task( @@ -303,7 +299,7 @@ async def test_task_wrapper_should_handle_non_recoverable_error( async def test_task_wrapper_should_handle_unregistered_task(mock_worker: MockWorker): # Given worker = mock_worker - task_manager = MockManager(worker.db_path, worker.db_lock, max_queue_size=10) + task_manager = MockManager(worker.db_path, max_queue_size=10) project = TEST_PROJECT created_at = datetime.now() task = Task( @@ -370,7 +366,7 @@ async def test_task_wrapper_should_handle_unregistered_task(mock_worker: MockWor async def test_work_once_should_not_run_cancelled_task(mock_worker: MockWorker, caplog): # Given worker = mock_worker - task_manager = MockManager(worker.db_path, worker.db_lock, max_queue_size=10) + task_manager = MockManager(worker.db_path, max_queue_size=10) caplog.set_level(logging.INFO) project = TEST_PROJECT created_at = datetime.now() @@ -397,7 +393,7 @@ async def test_work_once_should_not_run_cancelled_task(mock_worker: MockWorker, async def test_cancel_running_task(mock_worker: MockWorker): # Given worker = mock_worker - task_manager = MockManager(worker.db_path, worker.db_lock, max_queue_size=10) + task_manager = MockManager(worker.db_path, max_queue_size=10) project = TEST_PROJECT created_at = datetime.now() duration = 10 @@ -471,7 +467,7 @@ async def test_worker_acknowledgment_cm_should_not_raise_for_fatal_error( ): # Given worker = mock_worker - task_manager = MockManager(worker.db_path, worker.db_lock, max_queue_size=10) + task_manager = MockManager(worker.db_path, max_queue_size=10) project = TEST_PROJECT created_at = datetime.now() task = Task(