diff --git a/README.md b/README.md
index 507c51c..74d6238 100644
--- a/README.md
+++ b/README.md
@@ -27,6 +27,8 @@ EOS provides:
 * Device and sample container allocation system to prevent conflicts
 * Result aggregation such as automatic output file storage
 
+Documentation is available at [https://unc-robotics.github.io/eos/](https://unc-robotics.github.io/eos/).
+
 ## Installation
 
 ### 1. Install PDM
diff --git a/docker/.env.example b/docker/.env.example
index 3d6b778..2b57dda 100644
--- a/docker/.env.example
+++ b/docker/.env.example
@@ -1,16 +1,12 @@
 # EOS #####################################
 COMPOSE_PROJECT_NAME=eos
 
-# MongoDB root username
+# MongoDB admin credentials
 EOS_MONGO_INITDB_ROOT_USERNAME=
-
-# MongoDB root user password
 EOS_MONGO_INITDB_ROOT_PASSWORD=
 
-# MinIO root username
+# MinIO admin credentials
 EOS_MINIO_ROOT_USER=
-
-# MinIO root user password
 EOS_MINIO_ROOT_PASSWORD=
 
 # Budibase ################################
diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml
index 67e3067..d535c39 100644
--- a/docker/docker-compose.yml
+++ b/docker/docker-compose.yml
@@ -1,7 +1,14 @@
 services:
   eos-mongodb:
-    image: mongo:noble
+    build:
+      context: .
+      dockerfile: mongodb/Dockerfile
+      args:
+        - MONGO_INITDB_ROOT_USERNAME=${EOS_MONGO_INITDB_ROOT_USERNAME}
+        - MONGO_INITDB_ROOT_PASSWORD=${EOS_MONGO_INITDB_ROOT_PASSWORD}
+    image: eos-mongodb:latest
     container_name: eos-mongodb
+    hostname: eos-mongodb
     restart: unless-stopped
     environment:
       MONGO_INITDB_ROOT_USERNAME: ${EOS_MONGO_INITDB_ROOT_USERNAME}
@@ -12,10 +19,12 @@ services:
       - eos_network
     volumes:
       - mongodb_data:/data/db
+    command: ["-f", "/etc/mongod.conf"]
 
   eos-minio:
     image: minio/minio:RELEASE.2024-10-02T17-50-41Z
     container_name: eos-minio
+    hostname: eos-minio
     restart: unless-stopped
     environment:
       MINIO_ROOT_USER: ${EOS_MINIO_ROOT_USER}
@@ -32,9 +41,8 @@
   eos-budibase:
     image: budibase/budibase:2.32.12-sqs
     container_name: eos-budibase
+    hostname: eos-budibase
     restart: unless-stopped
-    ports:
-      - "8080:80"
     environment:
       JWT_SECRET: ${BB_JWT_SECRET}
       MINIO_ACCESS_KEY: ${BB_MINIO_ACCESS_KEY}
@@ -45,6 +53,8 @@
       INTERNAL_API_KEY: ${BB_INTERNAL_API_KEY}
       BB_ADMIN_USER_EMAIL: ${BB_ADMIN_USER_EMAIL}
       BB_ADMIN_USER_PASSWORD: ${BB_ADMIN_USER_PASSWORD}
+    ports:
+      - "8080:80"
     networks:
       - eos_network
     extra_hosts:
diff --git a/docker/mongodb/Dockerfile b/docker/mongodb/Dockerfile
new file mode 100644
index 0000000..e412e19
--- /dev/null
+++ b/docker/mongodb/Dockerfile
@@ -0,0 +1,9 @@
+FROM mongo:noble
+
+COPY mongodb/generate_keyfile.sh /root/generate_keyfile.sh
+RUN /bin/bash /root/generate_keyfile.sh
+
+COPY mongodb/mongod.conf /etc/mongod.conf
+COPY mongodb/init_mongodb.js /docker-entrypoint-initdb.d/init_mongodb.js
+
+CMD ["mongod", "-f", "/etc/mongod.conf"]
diff --git a/docker/mongodb/generate_keyfile.sh b/docker/mongodb/generate_keyfile.sh
new file mode 100755
index 0000000..a1d32f7
--- /dev/null
+++ b/docker/mongodb/generate_keyfile.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+MONGO_KEYFILE="/data/mongo-keyfile"
+
+if [ ! -f "$MONGO_KEYFILE" ]; then
+    echo "Generating keyfile..."
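+    # A MongoDB keyfile must hold 6-1024 base64 characters; 756 random bytes
+    # encode to exactly 1008, and mongod rejects keyfiles readable by group/others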
+    openssl rand -base64 756 > "$MONGO_KEYFILE"
+    chmod 400 "$MONGO_KEYFILE"
+    chown mongodb:mongodb "$MONGO_KEYFILE"
+fi
diff --git a/docker/mongodb/init_mongodb.js b/docker/mongodb/init_mongodb.js
new file mode 100644
index 0000000..8a412e9
--- /dev/null
+++ b/docker/mongodb/init_mongodb.js
@@ -0,0 +1,22 @@
+// Replica set configuration
+var config = {
+    "_id": "rs0",
+    "members": [
+        { "_id": 0, "host": "localhost:27017" }
+    ]
+}
+
+rs.initiate(config)
+
+// Wait until this node has been elected primary
+while (!rs.isMaster().ismaster) {
+    sleep(1000)
+}
+
+// Create the admin user
+var adminDb = db.getSiblingDB('admin');
+adminDb.createUser({
+    user: process.env["MONGO_INITDB_ROOT_USERNAME"],
+    pwd: process.env["MONGO_INITDB_ROOT_PASSWORD"],
+    roles: [{ role: 'root', db: 'admin' }]
+});
diff --git a/docker/mongodb/mongod.conf b/docker/mongodb/mongod.conf
new file mode 100644
index 0000000..4739bce
--- /dev/null
+++ b/docker/mongodb/mongod.conf
@@ -0,0 +1,13 @@
+storage:
+  dbPath: /data/db
+
+net:
+  port: 27017
+  bindIp: localhost,eos-mongodb
+
+security:
+  authorization: enabled
+  keyFile: /data/mongo-keyfile
+
+replication:
+  replSetName: rs0
diff --git a/docs/conf.py b/docs/conf.py
index bb7a298..920e1b8 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -9,7 +9,7 @@
 project = "eos"
 copyright = "2024, UNC Robotics"
 author = "Angelos Angelopoulos"
-release = "0.3.0"
+release = "0.4.0"
 
 extensions = [
     "sphinx.ext.autosectionlabel",
diff --git a/eos/campaigns/campaign_executor.py b/eos/campaigns/campaign_executor.py
index 27703e9..cbe3445 100644
--- a/eos/campaigns/campaign_executor.py
+++ b/eos/campaigns/campaign_executor.py
@@ -33,6 +33,7 @@ def __init__(
         self._campaign_id = campaign_id
         self._experiment_type = experiment_type
         self._execution_parameters = execution_parameters
+        self._campaign_manager = campaign_manager
         self._campaign_optimizer_manager = campaign_optimizer_manager
         self._task_manager = task_manager
 
@@ -46,7 +47,7 @@ def __init__(
 
         self._campaign_status: CampaignStatus | None = None
 
-    def _setup_optimizer(self) -> None:
+    async def _setup_optimizer(self) -> None:
         if self._optimizer:
             return
 
@@ -56,7 +57,7 @@ def _setup_optimizer(self) -> None:
             self._execution_parameters.optimizer_computer_ip,
         )
         self._optimizer_input_names, self._optimizer_output_names = (
-            self._campaign_optimizer_manager.get_input_and_output_names(self._campaign_id)
+            await self._campaign_optimizer_manager.get_input_and_output_names(self._campaign_id)
         )
 
     def cleanup(self) -> None:
@@ -70,13 +71,13 @@ async def start_campaign(self) -> None:
         """
         Start the campaign or handle an existing campaign.
         """
-        campaign = self._campaign_manager.get_campaign(self._campaign_id)
+        campaign = await self._campaign_manager.get_campaign(self._campaign_id)
         if campaign:
             await self._handle_existing_campaign(campaign)
         else:
-            self._create_new_campaign()
+            await self._create_new_campaign()
 
-        self._campaign_manager.start_campaign(self._campaign_id)
+        await self._campaign_manager.start_campaign(self._campaign_id)
         self._campaign_status = CampaignStatus.RUNNING
 
         log.info(f"Started campaign '{self._campaign_id}'.")
@@ -87,7 +88,6 @@ async def _handle_existing_campaign(self, campaign: Campaign) -> None:
         self._campaign_status = campaign.status
 
         if not self._execution_parameters.resume:
-
             def _raise_error(status: str) -> None:
                 raise EosCampaignExecutionError(
                     f"Cannot start campaign '{self._campaign_id}' as it already exists and is '{status}'. "
                 )
@@ -104,27 +104,27 @@ def _raise_error(status: str) -> None:
 
         await self._resume_campaign()
 
-    def _create_new_campaign(self) -> None:
+    async def _create_new_campaign(self) -> None:
         """
         Create a new campaign.
         """
-        self._campaign_manager.create_campaign(
+        await self._campaign_manager.create_campaign(
             campaign_id=self._campaign_id,
             experiment_type=self._experiment_type,
             execution_parameters=self._execution_parameters,
         )
 
         if self._execution_parameters.do_optimization:
-            self._setup_optimizer()
+            await self._setup_optimizer()
 
     async def _resume_campaign(self) -> None:
         """
         Resume an existing campaign.
         """
-        self._campaign_manager.delete_current_campaign_experiments(self._campaign_id)
+        await self._campaign_manager.delete_current_campaign_experiments(self._campaign_id)
 
         if self._execution_parameters.do_optimization:
-            self._setup_optimizer()
+            await self._setup_optimizer()
             await self._restore_optimizer_state()
 
         log.info(f"Campaign '{self._campaign_id}' resumed.")
@@ -133,7 +133,7 @@ async def _restore_optimizer_state(self) -> None:
         """
         Restore the optimizer state for a resumed campaign.
         """
-        completed_experiment_ids = self._campaign_manager.get_campaign_experiment_ids(
+        completed_experiment_ids = await self._campaign_manager.get_campaign_experiment_ids(
             self._campaign_id, status=ExperimentStatus.COMPLETED
         )
 
@@ -150,7 +150,7 @@ async def cancel_campaign(self) -> None:
         """
         Cancel the campaign and all running experiments.
         """
-        campaign = self._campaign_manager.get_campaign(self._campaign_id)
+        campaign = await self._campaign_manager.get_campaign(self._campaign_id)
         if not campaign or campaign.status != CampaignStatus.RUNNING:
             raise EosCampaignExecutionError(
                 f"Cannot cancel campaign '{self._campaign_id}' with status "
@@ -158,7 +158,7 @@ async def cancel_campaign(self) -> None:
             )
 
         log.warning(f"Cancelling campaign '{self._campaign_id}'...")
-        self._campaign_manager.cancel_campaign(self._campaign_id)
+        await self._campaign_manager.cancel_campaign(self._campaign_id)
         self._campaign_status = CampaignStatus.CANCELLED
 
         await self._cancel_running_experiments()
@@ -194,18 +194,18 @@ async def progress_campaign(self) -> bool:
 
             await self._progress_experiments()
 
-            campaign = self._campaign_manager.get_campaign(self._campaign_id)
+            campaign = await self._campaign_manager.get_campaign(self._campaign_id)
             if self._is_campaign_completed(campaign):
                 if self._execution_parameters.do_optimization:
                     await self._compute_pareto_solutions()
-                self._campaign_manager.complete_campaign(self._campaign_id)
+                await self._campaign_manager.complete_campaign(self._campaign_id)
                 return True
 
             await self._create_experiments(campaign)
 
             return False
         except EosExperimentExecutionError as e:
-            self._campaign_manager.fail_campaign(self._campaign_id)
+            await self._campaign_manager.fail_campaign(self._campaign_id)
             self._campaign_status = CampaignStatus.FAILED
             raise EosCampaignExecutionError(f"Error executing campaign '{self._campaign_id}'") from e
 
@@ -225,8 +225,8 @@ async def _progress_experiments(self) -> None:
 
         for experiment_id in completed_experiments:
             del self._experiment_executors[experiment_id]
-            self._campaign_manager.delete_campaign_experiment(self._campaign_id, experiment_id)
-            self._campaign_manager.increment_iteration(self._campaign_id)
+            await self._campaign_manager.delete_campaign_experiment(self._campaign_id, experiment_id)
+            await self._campaign_manager.increment_iteration(self._campaign_id)
 
     async def _process_completed_experiments(self, completed_experiments: list[str]) -> None:
         """
@@ -234,7 +234,7 @@ async def _process_completed_experiments(self, completed_experiments: list[str])
         """
         inputs_df, outputs_df = await self._collect_experiment_results(completed_experiments)
         await self._optimizer.report.remote(inputs_df, outputs_df)
-        self._campaign_optimizer_manager.record_campaign_samples(
+        await self._campaign_optimizer_manager.record_campaign_samples(
             self._campaign_id, completed_experiments, inputs_df, outputs_df
         )
 
@@ -248,11 +248,12 @@ async def _collect_experiment_results(self, experiment_ids: list[str]) -> tuple[
         for experiment_id in experiment_ids:
             for input_name in self._optimizer_input_names:
                 reference_task_id, parameter_name = input_name.split(".")
-                task = self._task_manager.get_task(experiment_id, reference_task_id)
+                task = await self._task_manager.get_task(experiment_id, reference_task_id)
                 inputs[input_name].append(float(task.input.parameters[parameter_name]))
             for output_name in self._optimizer_output_names:
                 reference_task_id, parameter_name = output_name.split(".")
-                output_parameters = self._task_manager.get_task_output(experiment_id, reference_task_id).parameters
+                task_output = await self._task_manager.get_task_output(experiment_id, reference_task_id)
+                output_parameters = task_output.parameters
                 outputs[output_name].append(float(output_parameters[parameter_name]))
 
         return pd.DataFrame(inputs), pd.DataFrame(outputs)
@@ -271,9 +272,9 @@ async def _create_experiments(self, campaign: Campaign) -> None:
             experiment_executor = self._experiment_executor_factory.create(
                 new_experiment_id, self._experiment_type, experiment_execution_parameters
             )
-            self._campaign_manager.add_campaign_experiment(self._campaign_id, new_experiment_id)
+            await self._campaign_manager.add_campaign_experiment(self._campaign_id, new_experiment_id)
             self._experiment_executors[new_experiment_id] = experiment_executor
-            experiment_executor.start_experiment(experiment_dynamic_parameters)
+            await experiment_executor.start_experiment(experiment_dynamic_parameters)
 
     async def _get_experiment_parameters(self, iteration: int) -> dict[str, Any]:
         """
@@ -324,7 +325,7 @@ async def _compute_pareto_solutions(self) -> None:
         try:
             pareto_solutions_df = await self._optimizer.get_optimal_solutions.remote()
             pareto_solutions = pareto_solutions_df.to_dict(orient="records")
-            self._campaign_manager.set_pareto_solutions(self._campaign_id, pareto_solutions)
+            await self._campaign_manager.set_pareto_solutions(self._campaign_id, pareto_solutions)
         except Exception as e:
             raise EosCampaignExecutionError(f"CMP '{self._campaign_id}' - Error computing Pareto solutions.") from e
 
diff --git a/eos/campaigns/campaign_manager.py b/eos/campaigns/campaign_manager.py
index 3cc076a..5349072 100644
--- a/eos/campaigns/campaign_manager.py
+++ b/eos/campaigns/campaign_manager.py
@@ -1,3 +1,4 @@
+import asyncio
 from datetime import datetime, timezone
 from typing import Any
 
@@ -8,7 +9,7 @@
 from eos.experiments.entities.experiment import ExperimentStatus
 from eos.experiments.repositories.experiment_repository import ExperimentRepository
 from eos.logging.logger import log
-from eos.persistence.db_manager import DbManager
+from eos.persistence.async_mongodb_interface import AsyncMongoDbInterface
 from eos.tasks.repositories.task_repository import TaskRepository
 
 
@@ -17,16 +18,26 @@ class CampaignManager:
     Responsible for managing the state of all experiment campaigns in EOS and tracking their execution.
""" - def __init__(self, configuration_manager: ConfigurationManager, db_manager: DbManager): + def __init__(self, configuration_manager: ConfigurationManager, db_interface: AsyncMongoDbInterface): self._configuration_manager = configuration_manager - self._campaigns = CampaignRepository("campaigns", db_manager) - self._campaigns.create_indices([("id", 1)], unique=True) - self._experiments = ExperimentRepository("experiments", db_manager) - self._tasks = TaskRepository("tasks", db_manager) + self._session_factory = db_interface.session_factory + self._campaigns = None + self._experiments = None + self._tasks = None + + async def initialize(self, db_interface: AsyncMongoDbInterface) -> None: + self._campaigns = CampaignRepository(db_interface) + await self._campaigns.initialize() + + self._experiments = ExperimentRepository(db_interface) + await self._experiments.initialize() + + self._tasks = TaskRepository(db_interface) + await self._tasks.initialize() log.debug("Campaign manager initialized.") - def create_campaign( + async def create_campaign( self, campaign_id: str, experiment_type: str, @@ -41,7 +52,7 @@ def create_campaign( :param execution_parameters: Parameters for the execution of the campaign. :param metadata: Additional metadata to be stored with the campaign. """ - if self._campaigns.get_one(id=campaign_id): + if await self._campaigns.get_one(id=campaign_id): raise EosCampaignStateError(f"Campaign '{campaign_id}' already exists.") experiment_config = self._configuration_manager.experiments.get(experiment_type) @@ -54,68 +65,68 @@ def create_campaign( execution_parameters=execution_parameters, metadata=metadata or {}, ) - self._campaigns.create(campaign.model_dump()) + await self._campaigns.create(campaign.model_dump()) log.info(f"Created campaign '{campaign_id}'.") - def delete_campaign(self, campaign_id: str) -> None: + async def delete_campaign(self, campaign_id: str) -> None: """ Delete a campaign. """ if not self._campaigns.exists(id=campaign_id): raise EosCampaignStateError(f"Campaign '{campaign_id}' does not exist.") - self._campaigns.delete(id=campaign_id) + await self._campaigns.delete_one(id=campaign_id) log.info(f"Deleted campaign '{campaign_id}'.") - def start_campaign(self, campaign_id: str) -> None: + async def start_campaign(self, campaign_id: str) -> None: """ Start a campaign. """ - self._set_campaign_status(campaign_id, CampaignStatus.RUNNING) + await self._set_campaign_status(campaign_id, CampaignStatus.RUNNING) - def complete_campaign(self, campaign_id: str) -> None: + async def complete_campaign(self, campaign_id: str) -> None: """ Complete a campaign. """ - self._set_campaign_status(campaign_id, CampaignStatus.COMPLETED) + await self._set_campaign_status(campaign_id, CampaignStatus.COMPLETED) - def cancel_campaign(self, campaign_id: str) -> None: + async def cancel_campaign(self, campaign_id: str) -> None: """ Cancel a campaign. """ - self._set_campaign_status(campaign_id, CampaignStatus.CANCELLED) + await self._set_campaign_status(campaign_id, CampaignStatus.CANCELLED) - def suspend_campaign(self, campaign_id: str) -> None: + async def suspend_campaign(self, campaign_id: str) -> None: """ Suspend a campaign. """ - self._set_campaign_status(campaign_id, CampaignStatus.SUSPENDED) + await self._set_campaign_status(campaign_id, CampaignStatus.SUSPENDED) - def fail_campaign(self, campaign_id: str) -> None: + async def fail_campaign(self, campaign_id: str) -> None: """ Fail a campaign. 
""" - self._set_campaign_status(campaign_id, CampaignStatus.FAILED) + await self._set_campaign_status(campaign_id, CampaignStatus.FAILED) - def get_campaign(self, campaign_id: str) -> Campaign | None: + async def get_campaign(self, campaign_id: str) -> Campaign | None: """ Get a campaign. """ - campaign = self._campaigns.get_one(id=campaign_id) + campaign = await self._campaigns.get_one(id=campaign_id) return Campaign(**campaign) if campaign else None - def get_campaigns(self, **query: dict[str, Any]) -> list[Campaign]: + async def get_campaigns(self, **query: dict[str, Any]) -> list[Campaign]: """ Query campaigns with arbitrary parameters. :param query: Dictionary of query parameters. """ - campaigns = self._campaigns.get_all(**query) + campaigns = await self._campaigns.get_all(**query) return [Campaign(**campaign) for campaign in campaigns] - def _set_campaign_status(self, campaign_id: str, new_status: CampaignStatus) -> None: + async def _set_campaign_status(self, campaign_id: str, new_status: CampaignStatus) -> None: """ Set the status of a campaign. """ @@ -129,39 +140,41 @@ def _set_campaign_status(self, campaign_id: str, new_status: CampaignStatus) -> ]: update_fields["end_time"] = datetime.now(tz=timezone.utc) - self._campaigns.update(update_fields, id=campaign_id) + await self._campaigns.update_one(update_fields, id=campaign_id) - def increment_iteration(self, campaign_id: str) -> None: + async def increment_iteration(self, campaign_id: str) -> None: """ Increment the iteration count of a campaign. """ - self._campaigns.increment_campaign_iteration(campaign_id) + await self._campaigns.increment_campaign_iteration(campaign_id) - def add_campaign_experiment(self, campaign_id: str, experiment_id: str) -> None: + async def add_campaign_experiment(self, campaign_id: str, experiment_id: str) -> None: """ Add an experiment to a campaign. """ - self._campaigns.add_current_experiment(campaign_id, experiment_id) + await self._campaigns.add_current_experiment(campaign_id, experiment_id) - def delete_campaign_experiment(self, campaign_id: str, experiment_id: str) -> None: + async def delete_campaign_experiment(self, campaign_id: str, experiment_id: str) -> None: """ Remove an experiment from a campaign. """ - self._campaigns.remove_current_experiment(campaign_id, experiment_id) + await self._campaigns.remove_current_experiment(campaign_id, experiment_id) - def delete_current_campaign_experiments(self, campaign_id: str) -> None: + async def delete_current_campaign_experiments(self, campaign_id: str) -> None: """ Delete all current experiments from a campaign. """ - campaign = self.get_campaign(campaign_id) + campaign = await self.get_campaign(campaign_id) for experiment_id in campaign.current_experiment_ids: - self._experiments.delete(id=experiment_id) - self._tasks.delete(experiment_id=experiment_id) + await asyncio.gather( + self._experiments.delete_one(id=experiment_id), + self._tasks.delete_many(experiment_id=experiment_id), + ) - self._campaigns.clear_current_experiments(campaign_id) + await self._campaigns.clear_current_experiments(campaign_id) - def get_campaign_experiment_ids(self, campaign_id: str, status: ExperimentStatus | None = None) -> list[str]: + async def get_campaign_experiment_ids(self, campaign_id: str, status: ExperimentStatus | None = None) -> list[str]: """ Get all experiment IDs of a campaign with an optional status filter. 
@@ -169,10 +182,10 @@ def get_campaign_experiment_ids(self, campaign_id: str, status: ExperimentStatus
         :param status: Optional status to filter experiments.
         :return: A list of experiment IDs.
         """
-        return self._experiments.get_experiment_ids_by_campaign(campaign_id, status)
+        return await self._experiments.get_experiment_ids_by_campaign(campaign_id, status)
 
-    def set_pareto_solutions(self, campaign_id: str, pareto_solutions: dict[str, Any]) -> None:
+    async def set_pareto_solutions(self, campaign_id: str, pareto_solutions: dict[str, Any]) -> None:
         """
         Set the Pareto solutions for a campaign.
         """
-        self._campaigns.update({"pareto_solutions": pareto_solutions}, id=campaign_id)
+        await self._campaigns.update_one({"pareto_solutions": pareto_solutions}, id=campaign_id)
diff --git a/eos/campaigns/campaign_optimizer_manager.py b/eos/campaigns/campaign_optimizer_manager.py
index 6ce37d9..61907b8 100644
--- a/eos/campaigns/campaign_optimizer_manager.py
+++ b/eos/campaigns/campaign_optimizer_manager.py
@@ -1,13 +1,15 @@
+import asyncio
+
 import pandas as pd
 import ray
 from ray.actor import ActorHandle
 
 from eos.campaigns.entities.campaign import CampaignSample
+from eos.campaigns.repositories.campaign_samples_repository import CampaignSamplesRepository
 from eos.configuration.configuration_manager import ConfigurationManager
 from eos.logging.logger import log
 from eos.optimization.sequential_optimizer_actor import SequentialOptimizerActor
-from eos.persistence.db_manager import DbManager
-from eos.persistence.mongo_repository import MongoRepository
+from eos.persistence.async_mongodb_interface import AsyncMongoDbInterface
 
 
 class CampaignOptimizerManager:
@@ -15,13 +17,15 @@ class CampaignOptimizerManager:
     Responsible for managing the optimizers associated with experiment campaigns.
     """
 
-    def __init__(self, configuration_manager: ConfigurationManager, db_manager: DbManager):
-        self._campaign_samples = MongoRepository("campaign_samples", db_manager)
-        self._campaign_samples.create_indices([("campaign_id", 1), ("experiment_id", 1)], unique=True)
-
+    def __init__(self, configuration_manager: ConfigurationManager, db_interface: AsyncMongoDbInterface):
+        self._session_factory = db_interface.session_factory
         self._campaign_optimizer_plugin_registry = configuration_manager.campaign_optimizers
-
         self._optimizer_actors: dict[str, ActorHandle] = {}
+        self._campaign_samples = None
+
+    async def initialize(self, db_interface: AsyncMongoDbInterface) -> None:
+        self._campaign_samples = CampaignSamplesRepository(db_interface)
+        await self._campaign_samples.initialize()
 
         log.debug("Campaign optimizer manager initialized.")
 
@@ -67,7 +71,7 @@ def get_campaign_optimizer_actor(self, campaign_id: str) -> ActorHandle:
         """
         return self._optimizer_actors[campaign_id]
 
-    def get_input_and_output_names(self, campaign_id: str) -> tuple[list[str], list[str]]:
+    async def get_input_and_output_names(self, campaign_id: str) -> tuple[list[str], list[str]]:
         """
         Get the input and output names from an optimizer associated with a campaign.
 
@@ -76,13 +80,13 @@ def get_input_and_output_names(self, campaign_id: str) -> tuple[list[str], list[
         """
         optimizer_actor = self._optimizer_actors[campaign_id]
 
-        input_names, output_names = ray.get(
-            [optimizer_actor.get_input_names.remote(), optimizer_actor.get_output_names.remote()]
+        input_names, output_names = await asyncio.gather(
+            optimizer_actor.get_input_names.remote(), optimizer_actor.get_output_names.remote()
         )
 
         return input_names, output_names
 
-    def record_campaign_samples(
+    async def record_campaign_samples(
         self,
         campaign_id: str,
         experiment_ids: list[str],
@@ -112,12 +116,12 @@ def record_campaign_samples(
         ]
 
         for campaign_sample in campaign_samples:
-            self._campaign_samples.create(campaign_sample.model_dump())
+            await self._campaign_samples.create(campaign_sample.model_dump())
 
-    def delete_campaign_samples(self, campaign_id: str) -> None:
+    async def delete_campaign_samples(self, campaign_id: str) -> None:
         """
         Delete all campaign samples for a campaign.
 
         :param campaign_id: The ID of the campaign.
         """
-        self._campaign_samples.delete(campaign_id=campaign_id)
+        await self._campaign_samples.delete_many(campaign_id=campaign_id)
diff --git a/eos/campaigns/repositories/campaign_repository.py b/eos/campaigns/repositories/campaign_repository.py
index f7acb9a..83aa2a8 100644
--- a/eos/campaigns/repositories/campaign_repository.py
+++ b/eos/campaigns/repositories/campaign_repository.py
@@ -1,30 +1,44 @@
+from motor.core import AgnosticClientSession
+
 from eos.campaigns.exceptions import EosCampaignStateError
-from eos.persistence.mongo_repository import MongoRepository
+from eos.persistence.async_mongodb_interface import AsyncMongoDbInterface
+from eos.persistence.mongodb_async_repository import MongoDbAsyncRepository
+
+
+class CampaignRepository(MongoDbAsyncRepository):
+    def __init__(self, db_interface: AsyncMongoDbInterface):
+        super().__init__("campaigns", db_interface)
 
-
-class CampaignRepository(MongoRepository):
-    def increment_campaign_iteration(self, campaign_id: str) -> None:
-        result = self._collection.update_one({"id": campaign_id}, {"$inc": {"experiments_completed": 1}})
+    async def initialize(self) -> None:
+        await self.create_indices([("id", 1)], unique=True)
+
+    async def increment_campaign_iteration(
+        self, campaign_id: str, session: AgnosticClientSession | None = None
+    ) -> None:
+        result = await self._collection.update_one(
+            {"id": campaign_id}, {"$inc": {"experiments_completed": 1}}, session=session
+        )
         if result.matched_count == 0:
             raise EosCampaignStateError(
                 f"Cannot increment the iteration of campaign '{campaign_id}' as it does not exist."
             )
 
-    def add_current_experiment(self, campaign_id: str, experiment_id: str) -> None:
-        self._collection.update_one(
-            {"id": campaign_id},
-            {"$addToSet": {"current_experiment_ids": experiment_id}},
+    async def add_current_experiment(
+        self, campaign_id: str, experiment_id: str, session: AgnosticClientSession | None = None
+    ) -> None:
+        await self._collection.update_one(
+            {"id": campaign_id}, {"$addToSet": {"current_experiment_ids": experiment_id}}, session=session
         )
 
-    def remove_current_experiment(self, campaign_id: str, experiment_id: str) -> None:
-        self._collection.update_one(
-            {"id": campaign_id},
-            {"$pull": {"current_experiment_ids": experiment_id}},
+    async def remove_current_experiment(
+        self, campaign_id: str, experiment_id: str, session: AgnosticClientSession | None = None
+    ) -> None:
+        await self._collection.update_one(
+            {"id": campaign_id}, {"$pull": {"current_experiment_ids": experiment_id}}, session=session
         )
 
-    def clear_current_experiments(self, campaign_id: str) -> None:
-        self._collection.update_one(
-            {"id": campaign_id},
-            {"$set": {"current_experiment_ids": []}},
+    async def clear_current_experiments(self, campaign_id: str, session: AgnosticClientSession | None = None) -> None:
+        await self._collection.update_one(
+            {"id": campaign_id}, {"$set": {"current_experiment_ids": []}}, session=session
         )
diff --git a/eos/campaigns/repositories/campaign_samples_repository.py b/eos/campaigns/repositories/campaign_samples_repository.py
new file mode 100644
index 0000000..752c246
--- /dev/null
+++ b/eos/campaigns/repositories/campaign_samples_repository.py
@@ -0,0 +1,10 @@
+from eos.persistence.async_mongodb_interface import AsyncMongoDbInterface
+from eos.persistence.mongodb_async_repository import MongoDbAsyncRepository
+
+
+class CampaignSamplesRepository(MongoDbAsyncRepository):
+    def __init__(self, db_interface: AsyncMongoDbInterface):
+        super().__init__("campaign_samples", db_interface)
+
+    async def initialize(self) -> None:
+        await self.create_indices([("campaign_id", 1), ("experiment_id", 1)], unique=True)
diff --git a/eos/cli/orchestrator_cli.py b/eos/cli/orchestrator_cli.py
index 8d59d6f..779b5c7 100644
--- a/eos/cli/orchestrator_cli.py
+++ b/eos/cli/orchestrator_cli.py
@@ -24,43 +24,31 @@ from eos.web_api.orchestrator.controllers.task_controller import TaskController
 from eos.web_api.orchestrator.exception_handling import global_exception_handler
 
-default_config = {
-    "user_dir": "./user",
-    "labs": [],
-    "experiments": [],
-    "log_level": "INFO",
-    "web_api": {
-        "host": "localhost",
-        "port": 8070,
-    },
-    "db": {
-        "host": "localhost",
-        "port": 27017,
-        "username": None,
-        "password": None,
-    },
-    "file_db": {
-        "host": "localhost",
-        "port": 9004,
-        "username": None,
-        "password": None,
-    },
-}
-
-eos_banner = r"""The Experiment Orchestration System
- ▄▄▄▄▄▄▄▄▄▄▄ ▄▄▄▄▄▄▄▄▄▄▄ ▄▄▄▄▄▄▄▄▄▄▄ 
-▐░░░░░░░░░░░▌▐░░░░░░░░░░░▌▐░░░░░░░░░░░▌
-▐░█▀▀▀▀▀▀▀▀▀ ▐░█▀▀▀▀▀▀▀█░▌▐░█▀▀▀▀▀▀▀▀▀ 
-▐░█▄▄▄▄▄▄▄▄▄ ▐░▌       ▐░▌▐░█▄▄▄▄▄▄▄▄▄ 
-▐░░░░░░░░░░░▌▐░▌       ▐░▌▐░░░░░░░░░░░▌
-▐░█▀▀▀▀▀▀▀▀▀ ▐░▌       ▐░▌ ▀▀▀▀▀▀▀▀▀█░▌
-▐░█▄▄▄▄▄▄▄▄▄ ▐░█▄▄▄▄▄▄▄█░▌ ▄▄▄▄▄▄▄▄▄█░▌
-▐░░░░░░░░░░░▌▐░░░░░░░░░░░▌▐░░░░░░░░░░░▌
- ▀▀▀▀▀▀▀▀▀▀▀  ▀▀▀▀▀▀▀▀▀▀▀  ▀▀▀▀▀▀▀▀▀▀▀ 
-"""
-
 
 def load_config(config_file: str) -> DictConfig:
+    default_config = {
+        "user_dir": "./user",
+        "labs": [],
+        "experiments": [],
+        "log_level": "INFO",
+        "web_api": {
+            "host": "localhost",
+            "port": 8070,
+        },
+        "db": {
+            "host": "localhost",
+            "port": 27017,
+            "username": None,
+            "password": None,
+        },
+        "file_db": {
+            "host": "localhost",
+            "port": 9004,
+            "username": None,
+            "password": None,
+        },
+    }
+
     if not Path(config_file).exists():
         raise FileNotFoundError(f"Config file '{config_file}' does not exist")
     return OmegaConf.merge(OmegaConf.create(default_config), OmegaConf.load(config_file))
@@ -100,17 +88,56 @@ def signal_handler(*_) -> None:
         await web_api_server.shutdown()
 
         log.info("Shutting down the orchestrator...")
-        orchestrator.terminate()
+        await orchestrator.terminate()
 
         log.info("EOS shut down.")
 
 
-async def run_all(orchestrator: Orchestrator, web_api_server: uvicorn.Server) -> None:
-    async with handle_shutdown(orchestrator, web_api_server):
-        orchestrator_task = asyncio.create_task(orchestrator.spin())
-        web_server_task = asyncio.create_task(web_api_server.serve())
+async def setup_orchestrator(config: DictConfig) -> Orchestrator:
+    db_credentials = ServiceCredentials(**config.db)
+    file_db_credentials = ServiceCredentials(**config.file_db)
+
+    orchestrator = Orchestrator(config.user_dir, db_credentials, file_db_credentials)
+    await orchestrator.initialize()
+    await orchestrator.load_labs(config.labs)
+    orchestrator.load_experiments(config.experiments)
+
+    return orchestrator
+
+
+def setup_web_api(orchestrator: Orchestrator, config: DictConfig) -> uvicorn.Server:
+    litestar_logging_config = LoggingConfig(
+        configure_root_logger=False,
+        loggers={"litestar": {"level": "CRITICAL"}},
+    )
+    os.environ["LITESTAR_WARN_IMPLICIT_SYNC_TO_THREAD"] = "0"
+
+    api_router = Router(
+        path="/api",
+        route_handlers=[TaskController, ExperimentController, CampaignController, LabController, FileController],
+        dependencies={"orchestrator": Provide(lambda: orchestrator)},
+        exception_handlers={Exception: global_exception_handler},
+    )
+
+    web_api_app = Litestar(
+        route_handlers=[api_router],
+        logging_config=litestar_logging_config,
+        exception_handlers={Exception: global_exception_handler},
+    )
 
-    await asyncio.gather(orchestrator_task, web_server_task)
+    uv_config = uvicorn.Config(web_api_app, host=config.web_api.host, port=config.web_api.port, log_level="critical")
+
+    return uvicorn.Server(uv_config)
+
+
+async def run_eos(config: DictConfig) -> None:
+    orchestrator = await setup_orchestrator(config)
+    web_api_server = setup_web_api(orchestrator, config)
+
+    log.info("EOS initialized.")
+
+    async with handle_shutdown(orchestrator, web_api_server):
+        await asyncio.gather(orchestrator.spin(), web_api_server.serve())
 
 
 def start_orchestrator(
@@ -132,55 +159,31 @@ def start_orchestrator(
     ) = None,
     log_level: Annotated[LogLevel, typer.Option("--log-level", "-v", help="Logging level")] = None,
 ) -> None:
-
-    typer.echo(eos_banner)
+    typer.echo(EOS_BANNER)
 
     file_config = load_config(config_file)
 
-    cli_config = {}
-    if user_dir is not None:
-        cli_config["user_dir"] = user_dir
-    if labs is not None:
-        cli_config["labs"] = parse_list_arg(labs)
-    if experiments is not None:
-        cli_config["experiments"] = parse_list_arg(experiments)
-    if log_level is not None:
-        cli_config["log_level"] = log_level.value
+    cli_config = {
+        "user_dir": user_dir,
+        "labs": parse_list_arg(labs) if labs else None,
+        "experiments": parse_list_arg(experiments) if experiments else None,
+        "log_level": log_level.value if log_level else None,
+    }
+    cli_config = {k: v for k, v in cli_config.items() if v is not None}
 
     config = OmegaConf.merge(file_config, OmegaConf.create(cli_config))
     log.set_level(config.log_level)
 
-    # Set up the orchestrator
-    db_credentials = ServiceCredentials(**config.db)
-    file_db_credentials = ServiceCredentials(**config.file_db)
-    orchestrator = Orchestrator(config.user_dir, db_credentials, file_db_credentials)
-    orchestrator.load_labs(config.labs)
-    orchestrator.load_experiments(config.experiments)
-    log.info("EOS initialized.")
-
-    # Set up the web API server
-    logging_config = LoggingConfig(
-        configure_root_logger=False,
-        loggers={
-            "litestar": {"level": "CRITICAL"},
-        },
-    )
-    os.environ["LITESTAR_WARN_IMPLICIT_SYNC_TO_THREAD"] = "0"
-
-    def orchestrator_provider() -> Orchestrator:
-        return orchestrator
+    asyncio.run(run_eos(config))
 
-    api_router = Router(
-        path="/api",
-        route_handlers=[TaskController, ExperimentController, CampaignController, LabController, FileController],
-        dependencies={"orchestrator": Provide(orchestrator_provider)},
-        exception_handlers={Exception: global_exception_handler},
-    )
-    web_api_app = Litestar(
-        route_handlers=[api_router],
-        logging_config=logging_config,
-        exception_handlers={Exception: global_exception_handler},
-    )
-    config = uvicorn.Config(web_api_app, host=config.web_api.host, port=config.web_api.port, log_level="critical")
-    web_api_server = uvicorn.Server(config)
 
-    asyncio.run(run_all(orchestrator, web_api_server))
+EOS_BANNER = r"""The Experiment Orchestration System
+ ▄▄▄▄▄▄▄▄▄▄▄ ▄▄▄▄▄▄▄▄▄▄▄ ▄▄▄▄▄▄▄▄▄▄▄ 
+▐░░░░░░░░░░░▌▐░░░░░░░░░░░▌▐░░░░░░░░░░░▌
+▐░█▀▀▀▀▀▀▀▀▀ ▐░█▀▀▀▀▀▀▀█░▌▐░█▀▀▀▀▀▀▀▀▀ 
+▐░█▄▄▄▄▄▄▄▄▄ ▐░▌       ▐░▌▐░█▄▄▄▄▄▄▄▄▄ 
+▐░░░░░░░░░░░▌▐░▌       ▐░▌▐░░░░░░░░░░░▌
+▐░█▀▀▀▀▀▀▀▀▀ ▐░▌       ▐░▌ ▀▀▀▀▀▀▀▀▀█░▌
+▐░█▄▄▄▄▄▄▄▄▄ ▐░█▄▄▄▄▄▄▄█░▌ ▄▄▄▄▄▄▄▄▄█░▌
+▐░░░░░░░░░░░▌▐░░░░░░░░░░░▌▐░░░░░░░░░░░▌
+ ▀▀▀▀▀▀▀▀▀▀▀  ▀▀▀▀▀▀▀▀▀▀▀  ▀▀▀▀▀▀▀▀▀▀▀ 
+"""
diff --git a/eos/cli/web_api_cli.py b/eos/cli/web_api_cli.py
index 99f108e..a0f94d3 100644
--- a/eos/cli/web_api_cli.py
+++ b/eos/cli/web_api_cli.py
@@ -10,7 +10,7 @@
 
 
 def start_web_api(
-    host: Annotated[str, typer.Option("--host", help="Host for the EOS web API server")] = "0.0.0.0",
+    host: Annotated[str, typer.Option("--host", help="Host for the EOS web API server")] = "localhost",
     port: Annotated[int, typer.Option("--port", help="Port for the EOS web API server")] = 8000,
     orchestrator_host: Annotated[
         str, typer.Option("--orchestrator-host", help="Host for the EOS orchestrator server")
diff --git a/eos/configuration/configuration_manager.py b/eos/configuration/configuration_manager.py
index 5351c6a..291e00a 100644
--- a/eos/configuration/configuration_manager.py
+++ b/eos/configuration/configuration_manager.py
@@ -222,4 +222,4 @@ def _unload_experiments_associated_with_labs(self, lab_names: set[str]) -> None:
 
         for experiment_name in experiments_to_remove:
             self.unload_experiment(experiment_name)
-            log.info(f"Unloaded experiment '{experiment_name}' as it was associated with lab(s) {lab_names}")
+            log.debug(f"Unloaded experiment '{experiment_name}' as it was associated with lab(s) {lab_names}")
diff --git a/eos/configuration/entities/parameters.py b/eos/configuration/entities/parameters.py
index 9c31b7c..7b80a20 100644
--- a/eos/configuration/entities/parameters.py
+++ b/eos/configuration/entities/parameters.py
@@ -10,200 +10,200 @@
 
 
 def is_dynamic_parameter(parameter: AllowedParameterTypes) -> bool:
-    return isinstance(parameter, str) and parameter.lower() == "eos_dynamic"
+    return isinstance(parameter, str) and parameter.lower() == "eos_dynamic"
 
 
 class ParameterType(Enum):
-    integer = "integer"
-    decimal = "decimal"
-    string = "string"
-    boolean = "boolean"
-    choice = "choice"
-    list = "list"
-    dictionary = "dictionary"
-
-    def python_type(self) -> type:
-        mapping = {
-            "integer": int,
-            "decimal": float,
-            "string": str,
-            "boolean": bool,
-            "choice": str,
-            "list": list,
-            "dictionary": dict,
-        }
-        return mapping[self.value]
+    integer = "integer"
+    decimal = "decimal"
+    string = "string"
+    boolean = "boolean"
+    choice = "choice"
+    list = "list"
+    dictionary = "dictionary"
+
+    def python_type(self) -> type:
+        mapping = {
+            "integer": int,
+            "decimal": float,
+            "string": str,
+            "boolean": bool,
+            "choice": str,
+            "list": list,
+            "dictionary": dict,
+        }
+        return mapping[self.value]
 
 
 @dataclass(kw_only=True)
 class Parameter:
-    type: ParameterType
-    description: str
-    value: Any | None = None
+    type: ParameterType
+    description: str
+    value: Any | None = None
 
-    def __post_init__(self):
-        self._validate_type()
+    def __post_init__(self):
+        self._validate_type()
 
-    def _validate_type(self) -> None:
-        try:
-            self.type = ParameterType(self.type)
-        except ValueError as e:
-            raise EosConfigurationError(f"Invalid task parameter type '{self.type}'") from e
+    def _validate_type(self) -> None:
+        try:
+            self.type = ParameterType(self.type)
+        except ValueError as e:
+            raise EosConfigurationError(f"Invalid task parameter type '{self.type}'") from e
 
 
 @dataclass(kw_only=True)
 class NumericParameter(Parameter):
-    unit: str
-    min: int | float | None = None
-    max: int | float | None = None
+    unit: str
+    min: int | float | None = None
+    max: int | float | None = None
 
-    def __post_init__(self):
-        super().__post_init__()
-        self._validate_unit()
-        self._validate_min_max()
-        self._validate_value_range()
+    def __post_init__(self):
+        super().__post_init__()
+        self._validate_unit()
+        self._validate_min_max()
+        self._validate_value_range()
 
-    def _validate_unit(self) -> None:
-        if not self.unit:
-            raise EosConfigurationError("Task parameter type is numeric but no unit is specified.")
+    def _validate_unit(self) -> None:
+        if not self.unit:
+            raise EosConfigurationError("Task parameter type is numeric but no unit is specified.")
 
-    def _validate_min_max(self) -> None:
-        if self.min is not None and self.max is not None and self.min >= self.max:
-            raise EosConfigurationError("Task parameter 'min' is greater than or equal to 'max'.")
+    def _validate_min_max(self) -> None:
+        if self.min is not None and self.max is not None and self.min >= self.max:
+            raise EosConfigurationError("Task parameter 'min' is greater than or equal to 'max'.")
 
-    def _validate_value_range(self) -> None:
-        if self.value is None or is_dynamic_parameter(self.value):
-            return
+    def _validate_value_range(self) -> None:
+        if self.value is None or is_dynamic_parameter(self.value):
+            return
 
-        if not isinstance(self.value, int | float):
-            raise EosConfigurationError("Task parameter value is not numerical.")
-        if self.min is not None and self.value < self.min:
-            raise EosConfigurationError("Task parameter value is less than 'min'.")
-        if self.max is not None and self.value > self.max:
-            raise EosConfigurationError("Task parameter value is greater than 'max'.")
+        if not isinstance(self.value, int | float):
+            raise EosConfigurationError("Task parameter value is not numerical.")
+        if self.min is not None and self.value < self.min:
+            raise EosConfigurationError("Task parameter value is less than 'min'.")
+        if self.max is not None and self.value > self.max:
+            raise EosConfigurationError("Task parameter value is greater than 'max'.")
 
 
 @dataclass(kw_only=True)
 class BooleanParameter(Parameter):
-    def __post_init__(self):
-        super().__post_init__()
-        self._validate_value()
+    def __post_init__(self):
+        super().__post_init__()
+        self._validate_value()
 
-    def _validate_value(self) -> None:
-        if not isinstance(self.value, bool) and not is_dynamic_parameter(self.value):
-            raise EosConfigurationError(
-                f"Task parameter value '{self.value}' is not true/false but the declared type is 'boolean'."
-            )
+    def _validate_value(self) -> None:
+        if not isinstance(self.value, bool) and not is_dynamic_parameter(self.value):
+            raise EosConfigurationError(
+                f"Task parameter value '{self.value}' is not true/false but the declared type is 'boolean'."
+            )
 
 
 @dataclass(kw_only=True)
 class ChoiceParameter(Parameter):
-    choices: list[str]
+    choices: list[str]
 
-    def __post_init__(self):
-        super().__post_init__()
-        self._validate_choices()
+    def __post_init__(self):
+        super().__post_init__()
+        self._validate_choices()
 
-    def _validate_choices(self) -> None:
-        if not self.choices:
-            raise EosConfigurationError("Task parameter choices are not specified when the type is 'choice'.")
+    def _validate_choices(self) -> None:
+        if not self.choices:
+            raise EosConfigurationError("Task parameter choices are not specified when the type is 'choice'.")
 
-        if (
-            not self.value
-            or len(self.value) == 0
-            or self.value not in self.choices
-            and not is_dynamic_parameter(self.value)
-        ):
-            raise EosConfigurationError(
-                f"Task parameter value '{self.value}' is not one of the choices {self.choices}."
-            )
+        if (
+            not self.value
+            or len(self.value) == 0
+            or self.value not in self.choices
+            and not is_dynamic_parameter(self.value)
+        ):
+            raise EosConfigurationError(
+                f"Task parameter value '{self.value}' is not one of the choices {self.choices}."
+            )
 
 
 @dataclass(kw_only=True)
 class ListParameter(Parameter):
-    element_type: ParameterType
-    length: int | None = None
-    min: list[int | float] | None = None
-    max: list[int | float] | None = None
-
-    def __post_init__(self):
-        super().__post_init__()
-        self._validate_element_type()
-        self._validate_list_attributes()
-        self._validate_elements_within_bounds()
-
-    def _validate_element_type(self) -> None:
-        if isinstance(self.element_type, str):
-            try:
-                self.element_type = ParameterType[self.element_type]
-            except KeyError as e:
-                raise EosConfigurationError(f"Invalid list parameter element type '{self.element_type}'") from e
-        if self.element_type == ParameterType.list:
-            raise EosConfigurationError("List parameter element type cannot be 'list'. Nested lists are not supported.")
-
-    def _validate_list_attributes(self) -> None:
-        for attr_name in ["value", "min", "max"]:
-            attr_value = getattr(self, attr_name)
-            if attr_value is None:
-                continue
-
-            if not isinstance(attr_value, list) and not isinstance(attr_value, ListConfig):
-                raise EosConfigurationError(
-                    f"List parameter '{attr_name}' must be a list for 'list' type parameters.",
-                    EosConfigurationError,
-                )
-            if not all(isinstance(item, self.element_type.python_type()) for item in attr_value):
-                raise EosConfigurationError(
-                    f"All elements of list parameter '{attr_name}' must be of the same type as specified "
-                    f"by 'element_type'."
-                )
-            if self.length is not None and len(attr_value) != self.length:
-                raise EosConfigurationError(f"List parameter '{attr_name}' length must be {self.length}.")
-
-    def _validate_elements_within_bounds(self) -> None:
-        if self.value is None or is_dynamic_parameter(self.value) or self.min is None and self.max is None:
-            return
-
-        if self.length is None and (self.min is not None or self.max is not None):
-            raise EosConfigurationError(
-                "List parameter 'min' and 'max' can only be specified when 'length' is specified."
-            )
-
-        _min = self.min or [float("-inf")] * self.length
-        _max = self.max or [float("inf")] * self.length
-        for i, val in enumerate(self.value):
-            if not _min[i] <= val <= _max[i]:
-                raise EosConfigurationError(
-                    f"Element {i} of the list with value {val} is not within the the bounds [{_min[i]}, {_max[i]}]."
-                )
+    element_type: ParameterType
+    length: int | None = None
+    min: list[int | float] | None = None
+    max: list[int | float] | None = None
+
+    def __post_init__(self):
+        super().__post_init__()
+        self._validate_element_type()
+        self._validate_list_attributes()
+        self._validate_elements_within_bounds()
+
+    def _validate_element_type(self) -> None:
+        if isinstance(self.element_type, str):
+            try:
+                self.element_type = ParameterType[self.element_type]
+            except KeyError as e:
+                raise EosConfigurationError(f"Invalid list parameter element type '{self.element_type}'") from e
+        if self.element_type == ParameterType.list:
+            raise EosConfigurationError("List parameter element type cannot be 'list'. Nested lists are not supported.")
+
+    def _validate_list_attributes(self) -> None:
+        for attr_name in ["value", "min", "max"]:
+            attr_value = getattr(self, attr_name)
+            if attr_value is None:
+                continue
+
+            if not isinstance(attr_value, list) and not isinstance(attr_value, ListConfig):
+                raise EosConfigurationError(
+                    f"List parameter '{attr_name}' must be a list for 'list' type parameters.",
+                    EosConfigurationError,
+                )
+            if not all(isinstance(item, self.element_type.python_type()) for item in attr_value):
+                raise EosConfigurationError(
+                    f"All elements of list parameter '{attr_name}' must be of the same type as specified "
+                    f"by 'element_type'."
+                )
+            if self.length is not None and len(attr_value) != self.length:
+                raise EosConfigurationError(f"List parameter '{attr_name}' length must be {self.length}.")
+
+    def _validate_elements_within_bounds(self) -> None:
+        if self.value is None or is_dynamic_parameter(self.value) or self.min is None and self.max is None:
+            return
+
+        if self.length is None and (self.min is not None or self.max is not None):
+            raise EosConfigurationError(
+                "List parameter 'min' and 'max' can only be specified when 'length' is specified."
+            )
+
+        _min = self.min or [float("-inf")] * self.length
+        _max = self.max or [float("inf")] * self.length
+        for i, val in enumerate(self.value):
+            if not _min[i] <= val <= _max[i]:
+                raise EosConfigurationError(
+                    f"Element {i} of the list with value {val} is not within the bounds [{_min[i]}, {_max[i]}]."
+                )
 
 
 @dataclass(kw_only=True)
 class DictionaryParameter(Parameter):
-    pass
+    pass
 
 
 class ParameterFactory:
-    _TYPE_MAPPING: ClassVar = {
-        ParameterType.integer: NumericParameter,
-        ParameterType.decimal: NumericParameter,
-        ParameterType.string: Parameter,
-        ParameterType.boolean: BooleanParameter,
-        ParameterType.choice: ChoiceParameter,
-        ParameterType.list: ListParameter,
-        ParameterType.dictionary: DictionaryParameter,
-    }
-
-    @staticmethod
-    def create_parameter(parameter_type: ParameterType | str, **kwargs) -> Parameter:
-        if isinstance(parameter_type, str):
-            parameter_type = ParameterType(parameter_type)
-
-        parameter_class = ParameterFactory._TYPE_MAPPING.get(parameter_type)
-        if not parameter_class:
-            raise EosConfigurationError(f"Unsupported parameter type: {parameter_type}")
-
-        if "type" not in kwargs:
-            kwargs["type"] = parameter_type
-
-        return parameter_class(**kwargs)
+    _TYPE_MAPPING: ClassVar = {
+        ParameterType.integer: NumericParameter,
+        ParameterType.decimal: NumericParameter,
+        ParameterType.string: Parameter,
+        ParameterType.boolean: BooleanParameter,
+        ParameterType.choice: ChoiceParameter,
+        ParameterType.list: ListParameter,
+        ParameterType.dictionary: DictionaryParameter,
+    }
+
+    @staticmethod
+    def create_parameter(parameter_type: ParameterType | str, **kwargs) -> Parameter:
+        if isinstance(parameter_type, str):
+            parameter_type = ParameterType(parameter_type)
+
+        parameter_class = ParameterFactory._TYPE_MAPPING.get(parameter_type)
+        if not parameter_class:
+            raise EosConfigurationError(f"Unsupported parameter type: {parameter_type}")
+
+        if "type" not in kwargs:
+            kwargs["type"] = parameter_type
+
+        return parameter_class(**kwargs)
diff --git a/eos/configuration/validation/lab_validator.py b/eos/configuration/validation/lab_validator.py
index e5ae2e1..3f0db53 100644
--- a/eos/configuration/validation/lab_validator.py
+++ b/eos/configuration/validation/lab_validator.py
@@ -5,160 +5,160 @@
 from eos.configuration.exceptions import EosLabConfigurationError
 from eos.configuration.spec_registries.device_specification_registry import DeviceSpecificationRegistry
 from eos.configuration.spec_registries.task_specification_registry import (
-    TaskSpecificationRegistry,
+    TaskSpecificationRegistry,
 )
 from eos.logging.batch_error_logger import batch_error, raise_batched_errors
 
 
 class LabValidator:
-    """
-    Validates the configuration of a lab. It validates the locations, devices, and containers defined in the
-    lab configuration.
-    """
-
-    def __init__(self, config_dir: str, lab_config: LabConfig):
-        self._lab_config = lab_config
-        self._lab_config_dir = Path(config_dir) / LABS_DIR / lab_config.type.lower()
-        self._tasks = TaskSpecificationRegistry()
-        self._devices = DeviceSpecificationRegistry()
-
-    def validate(self) -> None:
-        self._validate_lab_folder_name_matches_lab_type()
-        self._validate_locations()
-        self._validate_computers()
-        self._validate_devices()
-        self._validate_containers()
-
-    def _validate_locations(self) -> None:
-        self._validate_device_locations()
-        self._validate_container_locations()
-
-    def _validate_lab_folder_name_matches_lab_type(self) -> None:
-        if Path(self._lab_config_dir).name != self._lab_config.type:
-            raise EosLabConfigurationError(
-                f"Lab folder name '{Path(self._lab_config_dir).name}' does not match lab type "
-                f"'{self._lab_config.type}'."
-            )
-
-    def _validate_device_locations(self) -> None:
-        locations = self._lab_config.locations
-        for device_name, device in self._lab_config.devices.items():
-            if device.location and device.location not in locations:
-                batch_error(
-                    f"Device '{device_name}' has invalid location '{device.location}'.",
-                    EosLabConfigurationError,
-                )
-        raise_batched_errors(EosLabConfigurationError)
-
-    def _validate_container_locations(self) -> None:
-        locations = self._lab_config.locations
-        for container in self._lab_config.containers:
-            if container.location not in locations:
-                raise EosLabConfigurationError(
-                    f"Container of type '{container.type}' has invalid location '{container.location}'."
-                )
-
-    def _validate_computers(self) -> None:
-        self._validate_computer_unique_ips()
-        self._validate_eos_computer_not_specified()
-
-    def _validate_computer_unique_ips(self) -> None:
-        ip_addresses = set()
-
-        for computer_name, computer in self._lab_config.computers.items():
-            if computer.ip in ip_addresses:
-                batch_error(
-                    f"Computer '{computer_name}' has a duplicate IP address '{computer.ip}'.",
-                    EosLabConfigurationError,
-                )
-            ip_addresses.add(computer.ip)
-
-        raise_batched_errors(EosLabConfigurationError)
-
-    def _validate_eos_computer_not_specified(self) -> None:
-        for computer_name, computer in self._lab_config.computers.items():
-            if computer_name.lower() == EOS_COMPUTER_NAME:
-                batch_error(
-                    "Computer name 'eos_computer' is reserved and cannot be used.",
-                    EosLabConfigurationError,
-                )
-            if computer.ip in ["127.0.0.1", "localhost"]:
-                batch_error(
-                    f"Computer '{computer_name}' cannot use the reserved IP '127.0.0.1' or 'localhost'.",
-                    EosLabConfigurationError,
-                )
-        raise_batched_errors(EosLabConfigurationError)
-
-    def _validate_devices(self) -> None:
-        self._validate_devices_have_computers()
-        self._validate_device_initialization_parameters()
-
-    def _validate_devices_have_computers(self) -> None:
-        for device_name, device in self._lab_config.devices.items():
-            if device.computer.lower() == EOS_COMPUTER_NAME:
-                continue
-            if device.computer not in self._lab_config.computers:
-                batch_error(
-                    f"Device '{device_name}' has invalid computer '{device.computer}'.",
-                    EosLabConfigurationError,
-                )
-        raise_batched_errors(EosLabConfigurationError)
-
-    def _validate_device_initialization_parameters(self) -> None:
-        for device_name, device in self._lab_config.devices.items():
-            device_spec = self._devices.get_spec_by_config(device)
-            if not device_spec:
-                batch_error(
-                    f"No specification found for device type '{device.type}' of device '{device_name}'.",
-                    EosLabConfigurationError,
-                )
-                continue
-
-            if device.initialization_parameters:
-                spec_params = device_spec.initialization_parameters or {}
-                for param_name in device.initialization_parameters:
-                    if param_name not in spec_params:
-                        batch_error(
-                            f"Invalid initialization parameter '{param_name}' for device '{device_name}' "
-                            f"of type '{device.type}' in lab type '{self._lab_config.type}'. "
-                            f"Valid parameters are: {', '.join(spec_params.keys())}",
-                            EosLabConfigurationError,
-                        )
-
-        raise_batched_errors(EosLabConfigurationError)
-
-    def _validate_containers(self) -> None:
-        self._validate_container_unique_types()
-        self._validate_container_unique_ids()
-
-    def _validate_container_unique_types(self) -> None:
-        container_types = []
-        for container in self._lab_config.containers:
-            container_types.append(container.type)
-
-        unique_container_types = set(container_types)
-
-        for container_type in unique_container_types:
-            if container_types.count(container_type) > 1:
-                batch_error(
-                    f"Container type '{container_type}' already defined."
-                    f" Please add more ids to the existing container definition.",
-                    EosLabConfigurationError,
-                )
-        raise_batched_errors(EosLabConfigurationError)
-
-    def _validate_container_unique_ids(self) -> None:
-        container_ids = set()
-        duplicate_ids = set()
-        for container in self._lab_config.containers:
-            for container_id in container.ids:
-                if container_id in container_ids:
-                    duplicate_ids.add(container_id)
-                else:
-                    container_ids.add(container_id)
-
-        if duplicate_ids:
-            duplicate_ids_str = "\n  ".join(duplicate_ids)
-            raise EosLabConfigurationError(
-                f"Containers must have unique IDs. The following are not unique:\n  {duplicate_ids_str}"
-            )
+    """
+    Validates the configuration of a lab. It validates the locations, devices, and containers defined in the
+    lab configuration.
+    """
+
+    def __init__(self, config_dir: str, lab_config: LabConfig):
+        self._lab_config = lab_config
+        self._lab_config_dir = Path(config_dir) / LABS_DIR / lab_config.type.lower()
+        self._tasks = TaskSpecificationRegistry()
+        self._devices = DeviceSpecificationRegistry()
+
+    def validate(self) -> None:
+        self._validate_lab_folder_name_matches_lab_type()
+        self._validate_locations()
+        self._validate_computers()
+        self._validate_devices()
+        self._validate_containers()
+
+    def _validate_locations(self) -> None:
+        self._validate_device_locations()
+        self._validate_container_locations()
+
+    def _validate_lab_folder_name_matches_lab_type(self) -> None:
+        if Path(self._lab_config_dir).name != self._lab_config.type:
+            raise EosLabConfigurationError(
+                f"Lab folder name '{Path(self._lab_config_dir).name}' does not match lab type "
+                f"'{self._lab_config.type}'."
+            )
+
+    def _validate_device_locations(self) -> None:
+        locations = self._lab_config.locations
+        for device_name, device in self._lab_config.devices.items():
+            if device.location and device.location not in locations:
+                batch_error(
+                    f"Device '{device_name}' has invalid location '{device.location}'.",
+                    EosLabConfigurationError,
+                )
+        raise_batched_errors(EosLabConfigurationError)
+
+    def _validate_container_locations(self) -> None:
+        locations = self._lab_config.locations
+        for container in self._lab_config.containers:
+            if container.location not in locations:
+                raise EosLabConfigurationError(
+                    f"Container of type '{container.type}' has invalid location '{container.location}'."
+ ) + + def _validate_computers(self) -> None: + self._validate_computer_unique_ips() + self._validate_eos_computer_not_specified() + + def _validate_computer_unique_ips(self) -> None: + ip_addresses = set() + + for computer_name, computer in self._lab_config.computers.items(): + if computer.ip in ip_addresses: + batch_error( + f"Computer '{computer_name}' has a duplicate IP address '{computer.ip}'.", + EosLabConfigurationError, + ) + ip_addresses.add(computer.ip) + + raise_batched_errors(EosLabConfigurationError) + + def _validate_eos_computer_not_specified(self) -> None: + for computer_name, computer in self._lab_config.computers.items(): + if computer_name.lower() == EOS_COMPUTER_NAME: + batch_error( + "Computer name 'eos_computer' is reserved and cannot be used.", + EosLabConfigurationError, + ) + if computer.ip in ["127.0.0.1", "localhost"]: + batch_error( + f"Computer '{computer_name}' cannot use the reserved IP '127.0.0.1' or 'localhost'.", + EosLabConfigurationError, + ) + raise_batched_errors(EosLabConfigurationError) + + def _validate_devices(self) -> None: + self._validate_devices_have_computers() + self._validate_device_initialization_parameters() + + def _validate_devices_have_computers(self) -> None: + for device_name, device in self._lab_config.devices.items(): + if device.computer.lower() == EOS_COMPUTER_NAME: + continue + if device.computer not in self._lab_config.computers: + batch_error( + f"Device '{device_name}' has invalid computer '{device.computer}'.", + EosLabConfigurationError, + ) + raise_batched_errors(EosLabConfigurationError) + + def _validate_device_initialization_parameters(self) -> None: + for device_name, device in self._lab_config.devices.items(): + device_spec = self._devices.get_spec_by_config(device) + if not device_spec: + batch_error( + f"No specification found for device type '{device.type}' of device '{device_name}'.", + EosLabConfigurationError, + ) + continue + + if device.initialization_parameters: + spec_params = device_spec.initialization_parameters or {} + for param_name in device.initialization_parameters: + if param_name not in spec_params: + batch_error( + f"Invalid initialization parameter '{param_name}' for device '{device_name}' " + f"of type '{device.type}' in lab type '{self._lab_config.type}'. " + f"Valid parameters are: {', '.join(spec_params.keys())}", + EosLabConfigurationError, + ) + + raise_batched_errors(EosLabConfigurationError) + + def _validate_containers(self) -> None: + self._validate_container_unique_types() + self._validate_container_unique_ids() + + def _validate_container_unique_types(self) -> None: + container_types = [] + for container in self._lab_config.containers: + container_types.append(container.type) + + unique_container_types = set(container_types) + + for container_type in unique_container_types: + if container_types.count(container_type) > 1: + batch_error( + f"Container type '{container_type}' already defined." + f" Please add more ids to the existing container definition.", + EosLabConfigurationError, + ) + raise_batched_errors(EosLabConfigurationError) + + def _validate_container_unique_ids(self) -> None: + container_ids = set() + duplicate_ids = set() + for container in self._lab_config.containers: + for container_id in container.ids: + if container_id in container_ids: + duplicate_ids.add(container_id) + else: + container_ids.add(container_id) + + if duplicate_ids: + duplicate_ids_str = "\n ".join(duplicate_ids) + raise EosLabConfigurationError( + f"Containers must have unique IDs. 
The following are not unique:\n {duplicate_ids_str}" + ) diff --git a/eos/containers/container_manager.py b/eos/containers/container_manager.py index b4e990f..dcd8f1c 100644 --- a/eos/containers/container_manager.py +++ b/eos/containers/container_manager.py @@ -1,4 +1,4 @@ -import threading +import asyncio from collections import defaultdict from typing import Any @@ -7,7 +7,8 @@ from eos.containers.exceptions import EosContainerStateError from eos.containers.repositories.container_repository import ContainerRepository from eos.logging.logger import log -from eos.persistence.db_manager import DbManager +from eos.persistence.async_mongodb_interface import AsyncMongoDbInterface +from eos.utils.async_rlock import AsyncRLock class ContainerManager: @@ -15,116 +16,117 @@ class ContainerManager: The container manager provides methods for interacting with containers in a lab. """ - def __init__(self, configuration_manager: ConfigurationManager, db_manager: DbManager): + def __init__(self, configuration_manager: ConfigurationManager, db_interface: AsyncMongoDbInterface): self._configuration_manager = configuration_manager - - self._containers = ContainerRepository("containers", db_manager) - self._containers.create_indices([("id", 1)], unique=True) - self._locks = defaultdict(threading.RLock) - - self._create_containers() + self._session_factory = db_interface.session_factory + self._locks = defaultdict(AsyncRLock) + self._containers = None + + async def initialize(self, db_interface: AsyncMongoDbInterface) -> None: + self._containers = ContainerRepository(db_interface) + await self._containers.initialize() + await self._create_containers() log.debug("Container manager initialized.") - def get_container(self, container_id: str) -> Container: + async def get_container(self, container_id: str) -> Container: """ Get a copy of the container with the specified ID. """ - container = self._containers.get_one(id=container_id) + container = await self._containers.get_one(id=container_id) if container: return Container(**container) raise EosContainerStateError(f"Container '{container_id}' does not exist.") - def get_containers(self, **query: dict[str, Any]) -> list[Container]: + async def get_containers(self, **query: dict[str, Any]) -> list[Container]: """ Query containers with arbitrary parameters. :param query: Dictionary of query parameters. """ - containers = self._containers.get_all(**query) + containers = await self._containers.get_all(**query) return [Container(**container) for container in containers] - def set_location(self, container_id: str, location: str) -> None: + async def set_location(self, container_id: str, location: str) -> None: """ Set the location of a container. """ - with self._get_lock(container_id): - self._containers.update({"location": location}, id=container_id) + async with self._get_lock(container_id): + await self._containers.update_one({"location": location}, id=container_id) - def set_lab(self, container_id: str, lab: str) -> None: + async def set_lab(self, container_id: str, lab: str) -> None: """ Set the lab of a container. """ - with self._get_lock(container_id): - self._containers.update({"lab": lab}, id=container_id) + async with self._get_lock(container_id): + await self._containers.update_one({"lab": lab}, id=container_id) - def set_metadata(self, container_id: str, metadata: dict[str, Any]) -> None: + async def set_metadata(self, container_id: str, metadata: dict[str, Any]) -> None: """ Set metadata for a container. 
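
        This replaces the container's existing metadata wholesale; use add_metadata to merge instead.

        Example (illustrative; the container ID and metadata values are hypothetical):

            await container_manager.set_metadata("flask_01", {"volume_ml": 50})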
""" - with self._get_lock(container_id): - self._containers.update({"metadata": metadata}, id=container_id) + async with self._get_lock(container_id): + await self._containers.update_one({"metadata": metadata}, id=container_id) - def add_metadata(self, container_id: str, metadata: dict[str, Any]) -> None: + async def add_metadata(self, container_id: str, metadata: dict[str, Any]) -> None: """ Add metadata to a container. """ - container = self.get_container(container_id) + container = await self.get_container(container_id) container.metadata.update(metadata) - with self._get_lock(container_id): - self._containers.update({"metadata": container.metadata}, id=container_id) + async with self._get_lock(container_id): + await self._containers.update_one({"metadata": container.metadata}, id=container_id) - def remove_metadata(self, container_id: str, metadata_keys: list[str]) -> None: + async def remove_metadata(self, container_id: str, metadata_keys: list[str]) -> None: """ Remove metadata from a container. """ - container = self.get_container(container_id) + container = await self.get_container(container_id) for key in metadata_keys: container.metadata.pop(key, None) - with self._get_lock(container_id): - self._containers.update({"metadata": container.metadata}, id=container_id) + async with self._get_lock(container_id): + await self._containers.update_one({"metadata": container.metadata}, id=container_id) - def update_container(self, container: Container) -> None: + async def update_container(self, container: Container) -> None: """ Update a container in the database. """ - self._containers.update(container.model_dump(), id=container.id) + await self._containers.update_one(container.model_dump(), id=container.id) - def update_containers(self, loaded_labs: set[str] | None = None, unloaded_labs: set[str] | None = None) -> None: + async def update_containers( + self, loaded_labs: set[str] | None = None, unloaded_labs: set[str] | None = None + ) -> None: """ Update containers based on loaded and unloaded labs. """ if unloaded_labs: - for lab_id in unloaded_labs: - self._remove_containers_for_lab(lab_id) + await asyncio.gather(*[self._remove_containers_for_lab(lab_id) for lab_id in unloaded_labs]) if loaded_labs: - for lab_id in loaded_labs: - self._create_containers_for_lab(lab_id) + await asyncio.gather(*[self._create_containers_for_lab(lab_id) for lab_id in loaded_labs]) log.debug("Containers have been updated.") - def _remove_containers_for_lab(self, lab_id: str) -> None: + async def _remove_containers_for_lab(self, lab_id: str) -> None: """ Remove containers associated with an unloaded lab. """ - containers_to_remove = self.get_containers(lab=lab_id) - for container in containers_to_remove: - self._containers.delete(id=container.id) + containers_to_remove = await self.get_containers(lab=lab_id) + await asyncio.gather(*[self._containers.delete_one(id=container.id) for container in containers_to_remove]) log.debug(f"Removed containers for lab '{lab_id}'") - def _create_containers_for_lab(self, lab_id: str) -> None: + async def _create_containers_for_lab(self, lab_id: str) -> None: """ Create containers for a loaded lab. 
""" lab_config = self._configuration_manager.labs[lab_id] for container_config in lab_config.containers: for container_id in container_config.ids: - existing_container = self._containers.get_one(id=container_id) - if not existing_container: + container_exists = await self._containers.exists(id=container_id) + if not container_exists: container = Container( id=container_id, type=container_config.type, @@ -132,10 +134,10 @@ def _create_containers_for_lab(self, lab_id: str) -> None: location=container_config.location, metadata=container_config.metadata, ) - self._containers.update(container.model_dump(), id=container_id) + await self._containers.update_one(container.model_dump(), id=container_id) log.debug(f"Created containers for lab '{lab_id}'") - def _create_containers(self) -> None: + async def _create_containers(self) -> None: """ Create containers from the lab configuration and add them to the database. """ @@ -149,10 +151,10 @@ def _create_containers(self) -> None: location=container_config.location, metadata=container_config.metadata, ) - self._containers.update(container.model_dump(), id=container_id) + await self._containers.update_one(container.model_dump(), id=container_id) log.debug("Created containers") - def _get_lock(self, container_id: str) -> threading.RLock: + def _get_lock(self, container_id: str) -> AsyncRLock: """ Get the lock for a specific container. """ diff --git a/eos/containers/repositories/container_repository.py b/eos/containers/repositories/container_repository.py index cf3bb3a..a6cad51 100644 --- a/eos/containers/repositories/container_repository.py +++ b/eos/containers/repositories/container_repository.py @@ -1,5 +1,10 @@ -from eos.persistence.mongo_repository import MongoRepository +from eos.persistence.async_mongodb_interface import AsyncMongoDbInterface +from eos.persistence.mongodb_async_repository import MongoDbAsyncRepository -class ContainerRepository(MongoRepository): - pass +class ContainerRepository(MongoDbAsyncRepository): + def __init__(self, db_interface: AsyncMongoDbInterface): + super().__init__("containers", db_interface) + + async def initialize(self) -> None: + await self.create_indices([("id", 1)], unique=True) diff --git a/eos/devices/base_device.py b/eos/devices/base_device.py index ce39e3d..ef56613 100644 --- a/eos/devices/base_device.py +++ b/eos/devices/base_device.py @@ -1,3 +1,4 @@ +import atexit import threading from abc import ABC, abstractmethod, ABCMeta from enum import Enum @@ -71,15 +72,9 @@ def __init__( self._lock = threading.Lock() + atexit.register(self.cleanup) self.initialize(initialization_parameters) - def __del__(self): - if "_status" not in self.__dict__: - return - if self._status and self._status != DeviceStatus.DISABLED: - self._status = DeviceStatus.DISABLED - self.cleanup() - def initialize(self, initialization_parameters: dict[str, Any]) -> None: """ Initialize the device. After calling this method, the device is ready to be used for tasks @@ -104,6 +99,9 @@ def cleanup(self) -> None: DISABLED. """ with self._lock: + if self._status == DeviceStatus.DISABLED: + return + if self._status == DeviceStatus.BUSY: raise EosDeviceCleanupError( f"Device {self._device_id} is busy. 
Cannot perform cleanup.", diff --git a/eos/devices/device_manager.py b/eos/devices/device_manager.py index 082746b..6388b8e 100644 --- a/eos/devices/device_manager.py +++ b/eos/devices/device_manager.py @@ -1,3 +1,5 @@ +import asyncio +import itertools from typing import Any import ray @@ -8,10 +10,10 @@ from eos.configuration.constants import EOS_COMPUTER_NAME from eos.devices.entities.device import Device, DeviceStatus from eos.devices.exceptions import EosDeviceStateError, EosDeviceInitializationError +from eos.devices.repositories.device_repository import DeviceRepository from eos.logging.batch_error_logger import batch_error, raise_batched_errors from eos.logging.logger import log -from eos.persistence.db_manager import DbManager -from eos.persistence.mongo_repository import MongoRepository +from eos.persistence.async_mongodb_interface import AsyncMongoDbInterface class DeviceManager: @@ -19,44 +21,50 @@ class DeviceManager: Provides methods for interacting with the devices in a lab. """ - def __init__(self, configuration_manager: ConfigurationManager, db_manager: DbManager): + def __init__(self, configuration_manager: ConfigurationManager, db_interface: AsyncMongoDbInterface): self._configuration_manager = configuration_manager - - self._devices = MongoRepository("devices", db_manager) - self._devices.create_indices([("lab_id", 1), ("id", 1)], unique=True) + self._session_factory = db_interface.session_factory + self._devices = None self._device_plugin_registry = configuration_manager.devices self._device_actor_handles: dict[str, ActorHandle] = {} self._device_actor_computer_ips: dict[str, str] = {} + async def initialize(self, db_interface: AsyncMongoDbInterface) -> None: + self._devices = DeviceRepository(db_interface) + await self._devices.initialize() + log.debug("Device manager initialized.") - def get_device(self, lab_id: str, device_id: str) -> Device | None: + async def get_device(self, lab_id: str, device_id: str) -> Device | None: """ - Get a device by its ID. + Get a device by its lab and device ID. + + :param lab_id: The ID of the lab the device is in. + :param device_id: The ID of the device in the lab. """ - device = self._devices.get_one(lab_id=lab_id, id=device_id) + device = await self._devices.get_one(lab_id=lab_id, id=device_id) if not device: return None return Device(**device) - def get_devices(self, **query: dict[str, Any]) -> list[Device]: + async def get_devices(self, **query: dict[str, Any]) -> list[Device]: """ - Query devices with arbitrary parameters. + Query devices with arbitrary parameters and return a list of matching devices. :param query: Dictionary of query parameters. """ - devices = self._devices.get_all(**query) + devices = await self._devices.get_all(**query) return [Device(**device) for device in devices] - def set_device_status(self, lab_id: str, device_id: str, status: DeviceStatus) -> None: + async def set_device_status(self, lab_id: str, device_id: str, status: DeviceStatus) -> None: """ Set the status of a device. 
""" - if not self._devices.exists(lab_id=lab_id, id=device_id): + if not await self._devices.exists(lab_id=lab_id, id=device_id): raise EosDeviceStateError(f"Device '{device_id}' in lab '{lab_id}' does not exist.") - self._devices.update({"status": status.value}, lab_id=lab_id, id=device_id) + await self._devices.update_one({"status": status.value}, lab_id=lab_id, id=device_id) def get_device_actor(self, lab_id: str, device_id: str) -> ActorHandle: """ @@ -68,41 +76,56 @@ def get_device_actor(self, lab_id: str, device_id: str) -> ActorHandle: return self._device_actor_handles.get(actor_id) - def update_devices(self, loaded_labs: set[str] | None = None, unloaded_labs: set[str] | None = None) -> None: + async def update_devices(self, loaded_labs: set[str] | None = None, unloaded_labs: set[str] | None = None) -> None: if unloaded_labs: - for lab_id in unloaded_labs: - self._remove_devices_for_lab(lab_id) + await self.cleanup_device_actors(lab_ids=list(unloaded_labs)) if loaded_labs: - for lab_id in loaded_labs: - self._create_devices_for_lab(lab_id) + creation_tasks = [self._create_devices_for_lab(lab_id) for lab_id in loaded_labs] + await asyncio.gather(*creation_tasks) self._check_device_actors_healthy() log.debug("Devices have been updated.") - def cleanup_device_actors(self) -> None: - for actor in self._device_actor_handles.values(): - ray.kill(actor) - self._device_actor_handles.clear() - self._device_actor_computer_ips.clear() - self._devices.delete() - log.info("All device actors have been cleaned up.") - - def _remove_devices_for_lab(self, lab_id: str) -> None: - devices_to_remove = self.get_devices(lab_id=lab_id) - for device in devices_to_remove: - actor_id = device.get_actor_id() + async def cleanup_device_actors(self, lab_ids: list[str] | None = None) -> None: + """ + Terminate device actors, optionally for specific labs. + + :param lab_ids: If provided, cleanup devices for these labs. + If None, cleanup all devices. 
+ """ + if lab_ids: + devices_by_lab = await self._devices.get_devices_by_lab_ids(lab_ids) + devices_to_remove = list(itertools.chain(*devices_by_lab.values())) + actor_ids = [Device(**device).get_actor_id() for device in devices_to_remove] + else: + actor_ids = list(self._device_actor_handles.keys()) + + async def cleanup_device(actor_id: str) -> None: if actor_id in self._device_actor_handles: + await self._device_actor_handles[actor_id].cleanup.remote() ray.kill(self._device_actor_handles[actor_id]) del self._device_actor_handles[actor_id] del self._device_actor_computer_ips[actor_id] - self._devices.delete(lab_id=lab_id) - log.debug(f"Removed devices for lab '{lab_id}'") - def _create_devices_for_lab(self, lab_id: str) -> None: + await asyncio.gather(*[cleanup_device(actor_id) for actor_id in actor_ids]) + + if lab_ids: + await self._devices.delete_devices_by_lab_ids(lab_ids) + log.debug(f"Cleaned up devices for lab(s): {', '.join(lab_ids)}") + else: + await self._devices.delete_all() + log.info("All devices have been cleaned up.") + + async def _create_devices_for_lab(self, lab_id: str) -> None: lab_config = self._configuration_manager.labs[lab_id] + + existing_devices = {device["id"]: Device(**device) for device in await self._devices.get_all(lab_id=lab_id)} + + devices_to_upsert: list[Device] = [] + for device_id, device_config in lab_config.devices.items(): - device = self.get_device(lab_id, device_id) + device = existing_devices.get(device_id) if device and device.get_actor_id() in self._device_actor_handles: continue @@ -110,17 +133,20 @@ def _create_devices_for_lab(self, lab_id: str) -> None: if device and device.actor_handle: self._restore_device_actor(device) else: - device = Device( + new_device = Device( lab_id=lab_id, id=device_id, type=device_config.type, location=device_config.location, computer=device_config.computer, ) - self._devices.update(device.model_dump(), lab_id=lab_id, id=device_id) - self._create_device_actor(device) + devices_to_upsert.append(new_device) + self._create_device_actor(new_device) + + if devices_to_upsert: + await self._devices.bulk_upsert([device.model_dump() for device in devices_to_upsert]) - log.debug(f"Created devices for lab '{lab_id}'") + log.debug(f"Updated devices for lab '{lab_id}'") def _restore_device_actor(self, device: Device) -> None: """ diff --git a/eos/devices/repositories/__init__.py b/eos/devices/repositories/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/eos/devices/repositories/device_repository.py b/eos/devices/repositories/device_repository.py new file mode 100644 index 0000000..38213e4 --- /dev/null +++ b/eos/devices/repositories/device_repository.py @@ -0,0 +1,61 @@ +from typing import Any + +from motor.core import AgnosticClientSession +from pymongo import UpdateOne +from pymongo.results import DeleteResult, BulkWriteResult + +from eos.persistence.async_mongodb_interface import AsyncMongoDbInterface +from eos.persistence.mongodb_async_repository import MongoDbAsyncRepository + + +class DeviceRepository(MongoDbAsyncRepository): + def __init__(self, db_interface: AsyncMongoDbInterface): + super().__init__("devices", db_interface) + + async def initialize(self) -> None: + await self.create_indices([("lab_id", 1), ("id", 1)], unique=True) + + async def delete_devices_by_lab_ids( + self, lab_ids: list[str], session: AgnosticClientSession | None = None + ) -> DeleteResult: + """ + Delete all devices associated with the given lab IDs in a single operation. 
+ + :param lab_ids: List of lab_ids for which to delete devices. + :return: The result of the delete operation. + """ + return await self._collection.delete_many({"lab_id": {"$in": lab_ids}}, session=session) + + async def get_devices_by_lab_ids( + self, lab_ids: list[str], session: AgnosticClientSession | None = None + ) -> dict[str, list[dict[str, Any]]]: + """ + Get all devices associated with the given lab IDs in a single operation. + + :param lab_ids: List of lab_ids for which to fetch devices. + :return: A dictionary with lab_ids as keys and lists of devices as values. + """ + cursor = self._collection.find({"lab_id": {"$in": lab_ids}}, session=session) + devices = await cursor.to_list(length=None) + + # Group devices by lab_id + devices_by_lab = {lab_id: [] for lab_id in lab_ids} + for device in devices: + devices_by_lab[device["lab_id"]].append(device) + + return devices_by_lab + + async def bulk_upsert( + self, devices: list[dict[str, Any]], session: AgnosticClientSession | None = None + ) -> BulkWriteResult: + """ + Perform a bulk upsert operation for multiple devices. + + :param devices: List of device dictionaries to upsert. + :return: The result of the bulk write operation. + """ + operations = [ + UpdateOne({"lab_id": device["lab_id"], "id": device["id"]}, {"$set": device}, upsert=True) + for device in devices + ] + return await self._collection.bulk_write(operations, session=session) diff --git a/eos/experiments/experiment_executor.py b/eos/experiments/experiment_executor.py index 3c32549..942815c 100644 --- a/eos/experiments/experiment_executor.py +++ b/eos/experiments/experiment_executor.py @@ -16,7 +16,7 @@ from eos.scheduling.entities.scheduled_task import ScheduledTask from eos.tasks.entities.task import TaskOutput from eos.tasks.entities.task_execution_parameters import TaskExecutionParameters -from eos.tasks.exceptions import EosTaskExecutionError +from eos.tasks.exceptions import EosTaskExecutionError, EosTaskCancellationError from eos.tasks.task_executor import TaskExecutor from eos.tasks.task_input_resolver import TaskInputResolver from eos.tasks.task_manager import TaskManager @@ -41,6 +41,7 @@ def __init__( self._experiment_type = experiment_type self._execution_parameters = execution_parameters self._experiment_graph = experiment_graph + self._experiment_manager = experiment_manager self._task_manager = task_manager self._container_manager = container_manager @@ -52,7 +53,7 @@ def __init__( self._task_output_futures: dict[str, asyncio.Task] = {} self._experiment_status = None - def start_experiment( + async def start_experiment( self, dynamic_parameters: dict[str, dict[str, Any]] | None = None, metadata: dict[str, Any] | None = None, @@ -60,11 +61,11 @@ def start_experiment( """ Start the experiment and register the executor with the scheduler. 
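
        Example (illustrative; task names and parameter values are hypothetical):

            await executor.start_experiment(
                dynamic_parameters={"mix_task": {"duration_sec": 60}},
                metadata={"operator": "jdoe"},
            )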
""" - experiment = self._experiment_manager.get_experiment(self._experiment_id) + experiment = await self._experiment_manager.get_experiment(self._experiment_id) if experiment: - self._handle_existing_experiment(experiment) + await self._handle_existing_experiment(experiment) else: - self._create_new_experiment(dynamic_parameters, metadata) + await self._create_new_experiment(dynamic_parameters, metadata) self._scheduler.register_experiment( experiment_id=self._experiment_id, @@ -72,12 +73,12 @@ def start_experiment( experiment_graph=self._experiment_graph, ) - self._experiment_manager.start_experiment(self._experiment_id) + await self._experiment_manager.start_experiment(self._experiment_id) self._experiment_status = ExperimentStatus.RUNNING log.info(f"{'Resumed' if self._execution_parameters.resume else 'Started'} experiment '{self._experiment_id}'.") - def _handle_existing_experiment(self, experiment: Experiment) -> None: + async def _handle_existing_experiment(self, experiment: Experiment) -> None: """ Handle cases when the experiment already exists. """ @@ -99,13 +100,13 @@ def _raise_error(status: str) -> None: } status_handlers.get(self._experiment_status, lambda: None)() else: - self._resume_experiment() + await self._resume_experiment() async def cancel_experiment(self) -> None: """ Cancel the experiment. """ - experiment = self._experiment_manager.get_experiment(self._experiment_id) + experiment = await self._experiment_manager.get_experiment(self._experiment_id) if not experiment or experiment.status != ExperimentStatus.RUNNING: raise EosExperimentCancellationError( f"Cannot cancel experiment '{self._experiment_id}' with status '{experiment.status}'. " @@ -114,10 +115,12 @@ async def cancel_experiment(self) -> None: log.warning(f"Cancelling experiment '{self._experiment_id}'...") self._experiment_status = ExperimentStatus.CANCELLED - self._experiment_manager.cancel_experiment(self._experiment_id) - self._scheduler.unregister_experiment(self._experiment_id) - await self._cancel_running_tasks() + await asyncio.gather( + self._experiment_manager.cancel_experiment(self._experiment_id), + self._scheduler.unregister_experiment(self._experiment_id), + self._cancel_running_tasks(), + ) log.warning(f"Cancelled experiment '{self._experiment_id}'.") async def progress_experiment(self) -> bool: @@ -130,32 +133,34 @@ async def progress_experiment(self) -> bool: if self._experiment_status != ExperimentStatus.RUNNING: return self._experiment_status == ExperimentStatus.CANCELLED - if self._scheduler.is_experiment_completed(self._experiment_id): - self._complete_experiment() + if await self._scheduler.is_experiment_completed(self._experiment_id): + await self._complete_experiment() return True - self._process_completed_tasks() + await self._process_completed_tasks() await self._execute_tasks() return False except Exception as e: - self._fail_experiment() + await self._fail_experiment() raise EosExperimentExecutionError(f"Error executing experiment '{self._experiment_id}'") from e - def _resume_experiment(self) -> None: + async def _resume_experiment(self) -> None: """ Resume an existing experiment. 
""" - self._experiment_manager.delete_non_completed_tasks(self._experiment_id) + await self._experiment_manager.delete_non_completed_tasks(self._experiment_id) log.info(f"Experiment '{self._experiment_id}' resumed.") - def _create_new_experiment(self, dynamic_parameters: dict[str, dict[str, Any]], metadata: dict[str, Any]) -> None: + async def _create_new_experiment( + self, dynamic_parameters: dict[str, dict[str, Any]], metadata: dict[str, Any] + ) -> None: """ Create a new experiment with the given parameters. """ dynamic_parameters = dynamic_parameters or {} self._validate_dynamic_parameters(dynamic_parameters) - self._experiment_manager.create_experiment( + await self._experiment_manager.create_experiment( experiment_id=self._experiment_id, experiment_type=self._experiment_type, execution_parameters=self._execution_parameters, @@ -173,36 +178,44 @@ async def _cancel_running_tasks(self) -> None: ] try: await asyncio.wait_for(asyncio.gather(*cancellation_futures), timeout=30) + except EosTaskCancellationError as e: + raise EosExperimentExecutionError( + f"Error cancelling tasks of experiment {self._experiment_id}. Some tasks may not have been cancelled." + ) from e except asyncio.TimeoutError as e: raise EosExperimentExecutionError( f"Timeout while cancelling experiment {self._experiment_id}. Some tasks may not have been cancelled." ) from e - def _complete_experiment(self) -> None: + async def _complete_experiment(self) -> None: """ Complete the experiment and clean up. """ - self._scheduler.unregister_experiment(self._experiment_id) - self._experiment_manager.complete_experiment(self._experiment_id) + await asyncio.gather( + self._scheduler.unregister_experiment(self._experiment_id), + self._experiment_manager.complete_experiment(self._experiment_id), + ) self._experiment_status = ExperimentStatus.COMPLETED - def _fail_experiment(self) -> None: + async def _fail_experiment(self) -> None: """ Fail the experiment. """ - self._scheduler.unregister_experiment(self._experiment_id) - self._experiment_manager.fail_experiment(self._experiment_id) + await asyncio.gather( + self._scheduler.unregister_experiment(self._experiment_id), + self._experiment_manager.fail_experiment(self._experiment_id), + ) self._experiment_status = ExperimentStatus.FAILED - def _process_completed_tasks(self) -> None: + async def _process_completed_tasks(self) -> None: """ Process the output of completed tasks. """ completed_tasks = [task_id for task_id, future in self._task_output_futures.items() if future.done()] for task_id in completed_tasks: - self._process_task_output(task_id) + await self._process_task_output(task_id) - def _process_task_output(self, task_id: str) -> None: + async def _process_task_output(self, task_id: str) -> None: """ Process the output of a single completed task. 
""" @@ -210,9 +223,9 @@ def _process_task_output(self, task_id: str) -> None: result = self._task_output_futures[task_id].result() if result: output_parameters, output_containers, output_files = result - self._update_containers(output_containers) - self._add_task_output(task_id, output_parameters, output_containers, output_files) - self._task_manager.complete_task(self._experiment_id, task_id) + await self._update_containers(output_containers) + await self._add_task_output(task_id, output_parameters, output_containers, output_files) + await self._task_manager.complete_task(self._experiment_id, task_id) log.info(f"EXP '{self._experiment_id}' - Completed task '{task_id}'.") except EosTaskExecutionError as e: raise EosExperimentTaskExecutionError( @@ -222,14 +235,14 @@ def _process_task_output(self, task_id: str) -> None: del self._task_output_futures[task_id] del self._current_task_execution_parameters[task_id] - def _update_containers(self, output_containers: dict[str, Any]) -> None: + async def _update_containers(self, output_containers: dict[str, Any]) -> None: """ Update containers with task output. """ for container in output_containers.values(): - self._container_manager.update_container(container) + await self._container_manager.update_container(container) - def _add_task_output( + async def _add_task_output( self, task_id: str, output_parameters: dict[str, Any], @@ -248,7 +261,7 @@ def _add_task_output( ) for file_name, file_data in output_files.items(): self._task_manager.add_task_output_file(self._experiment_id, task_id, file_name, file_data) - self._task_manager.add_task_output(self._experiment_id, task_id, task_output) + await self._task_manager.add_task_output(self._experiment_id, task_id, task_output) async def _execute_tasks(self) -> None: """ @@ -264,7 +277,7 @@ async def _execute_task(self, scheduled_task: ScheduledTask) -> None: Execute a single task. """ task_config = self._experiment_graph.get_task_config(scheduled_task.id) - task_config = self._task_input_resolver.resolve_task_inputs(self._experiment_id, task_config) + task_config = await self._task_input_resolver.resolve_task_inputs(self._experiment_id, task_config) task_execution_parameters = TaskExecutionParameters( task_id=scheduled_task.id, experiment_id=self._experiment_id, diff --git a/eos/experiments/experiment_manager.py b/eos/experiments/experiment_manager.py index 8e2b5d1..467de27 100644 --- a/eos/experiments/experiment_manager.py +++ b/eos/experiments/experiment_manager.py @@ -1,3 +1,4 @@ +import asyncio from datetime import datetime, timezone from typing import Any @@ -6,8 +7,7 @@ from eos.experiments.exceptions import EosExperimentStateError from eos.experiments.repositories.experiment_repository import ExperimentRepository from eos.logging.logger import log -from eos.persistence.db_manager import DbManager -from eos.tasks.entities.task import TaskStatus +from eos.persistence.async_mongodb_interface import AsyncMongoDbInterface from eos.tasks.repositories.task_repository import TaskRepository @@ -16,15 +16,21 @@ class ExperimentManager: Responsible for managing the state of all experiments in EOS and tracking their execution. 
""" - def __init__(self, configuration_manager: ConfigurationManager, db_manager: DbManager): + def __init__(self, configuration_manager: ConfigurationManager, db_interface: AsyncMongoDbInterface): self._configuration_manager = configuration_manager - self._experiments = ExperimentRepository("experiments", db_manager) - self._experiments.create_indices([("id", 1)], unique=True) - self._tasks = TaskRepository("tasks", db_manager) + self._session_factory = db_interface.session_factory + self._experiments = None + self._tasks = None + async def initialize(self, db_interface: AsyncMongoDbInterface) -> None: + self._experiments = ExperimentRepository(db_interface) + await self._experiments.initialize() + + self._tasks = TaskRepository(db_interface) + await self._tasks.initialize() log.debug("Experiment manager initialized.") - def create_experiment( + async def create_experiment( self, experiment_id: str, experiment_type: str, @@ -41,7 +47,7 @@ def create_experiment( :param execution_parameters: Parameters for the execution of the experiment. :param metadata: Additional metadata to be stored with the experiment. """ - if self._experiments.get_one(id=experiment_id): + if await self._experiments.exists(id=experiment_id): raise EosExperimentStateError(f"Experiment '{experiment_id}' already exists.") experiment_config = self._configuration_manager.experiments.get(experiment_type) @@ -58,107 +64,108 @@ def create_experiment( dynamic_parameters=dynamic_parameters or {}, metadata=metadata or {}, ) - self._experiments.create(experiment.model_dump()) + await self._experiments.create(experiment.model_dump()) log.info(f"Created experiment '{experiment_id}'.") - def delete_experiment(self, experiment_id: str) -> None: + async def delete_experiment(self, experiment_id: str) -> None: """ Delete an experiment. """ - if not self._experiments.exists(id=experiment_id): + if not await self._experiments.exists(id=experiment_id): raise EosExperimentStateError(f"Experiment '{experiment_id}' does not exist.") - self._experiments.delete(id=experiment_id) - self._tasks.delete(experiment_id=experiment_id) + await self._experiments.delete_one(id=experiment_id) + await self._tasks.delete_many(experiment_id=experiment_id) log.info(f"Deleted experiment '{experiment_id}'.") - def start_experiment(self, experiment_id: str) -> None: + async def start_experiment(self, experiment_id: str) -> None: """ Start an experiment. """ - self._set_experiment_status(experiment_id, ExperimentStatus.RUNNING) + await self._set_experiment_status(experiment_id, ExperimentStatus.RUNNING) - def complete_experiment(self, experiment_id: str) -> None: + async def complete_experiment(self, experiment_id: str) -> None: """ Complete an experiment. """ - self._set_experiment_status(experiment_id, ExperimentStatus.COMPLETED) + await self._set_experiment_status(experiment_id, ExperimentStatus.COMPLETED) - def cancel_experiment(self, experiment_id: str) -> None: + async def cancel_experiment(self, experiment_id: str) -> None: """ Cancel an experiment. """ - self._set_experiment_status(experiment_id, ExperimentStatus.CANCELLED) + await self._set_experiment_status(experiment_id, ExperimentStatus.CANCELLED) - def suspend_experiment(self, experiment_id: str) -> None: + async def suspend_experiment(self, experiment_id: str) -> None: """ Suspend an experiment. 
""" - self._set_experiment_status(experiment_id, ExperimentStatus.SUSPENDED) + await self._set_experiment_status(experiment_id, ExperimentStatus.SUSPENDED) - def fail_experiment(self, experiment_id: str) -> None: + async def fail_experiment(self, experiment_id: str) -> None: """ Fail an experiment. """ - self._set_experiment_status(experiment_id, ExperimentStatus.FAILED) + await self._set_experiment_status(experiment_id, ExperimentStatus.FAILED) - def get_experiment(self, experiment_id: str) -> Experiment | None: + async def get_experiment(self, experiment_id: str) -> Experiment | None: """ Get an experiment. """ - experiment = self._experiments.get_one(id=experiment_id) + experiment = await self._experiments.get_one(id=experiment_id) return Experiment(**experiment) if experiment else None - def get_experiments(self, **query: dict[str, Any]) -> list[Experiment]: + async def get_experiments(self, **query: dict[str, Any]) -> list[Experiment]: """ Get experiments with a custom query. :param query: Dictionary of query parameters. """ - experiments = self._experiments.get_all(**query) + experiments = await self._experiments.get_all(**query) return [Experiment(**experiment) for experiment in experiments] - def get_lab_experiments(self, lab: str) -> list[Experiment]: + async def get_lab_experiments(self, lab: str) -> list[Experiment]: """ Get all experiments associated with a lab. """ - experiments = self._experiments.get_experiments_by_lab(lab) + experiments = await self._experiments.get_experiments_by_lab(lab) return [Experiment(**experiment) for experiment in experiments] - def get_running_tasks(self, experiment_id: str | None) -> set[str]: + async def get_running_tasks(self, experiment_id: str | None) -> set[str]: """ Get the list of currently running tasks constrained by experiment ID. """ - experiment = self._experiments.get_one(id=experiment_id) + experiment = await self._experiments.get_one(id=experiment_id) return set(experiment.get("running_tasks", {})) if experiment else {} - def get_completed_tasks(self, experiment_id: str) -> set[str]: + async def get_completed_tasks(self, experiment_id: str) -> set[str]: """ Get the list of completed tasks constrained by experiment ID. """ - experiment = self._experiments.get_one(id=experiment_id) + experiment = await self._experiments.get_one(id=experiment_id) return set(experiment.get("completed_tasks", {})) if experiment else {} - def delete_non_completed_tasks(self, experiment_id: str) -> None: + async def delete_non_completed_tasks(self, experiment_id: str) -> None: """ Delete all tasks that are not completed in the given experiment. 
""" - experiment = self.get_experiment(experiment_id) - - for task_id in experiment.running_tasks: - self._tasks.delete(experiment_id=experiment_id, id=task_id) - self._experiments.clear_running_tasks(experiment_id) + experiment = await self.get_experiment(experiment_id) - self._tasks.delete(experiment_id=experiment_id, status=TaskStatus.FAILED.value) - self._tasks.delete(experiment_id=experiment_id, status=TaskStatus.CANCELLED.value) + async with self._session_factory() as session: + await asyncio.gather( + self._tasks.delete_running_tasks(experiment_id, experiment.running_tasks, session=session), + self._experiments.clear_running_tasks(experiment_id, session=session), + self._tasks.delete_failed_and_cancelled_tasks(experiment_id, session=session), + ) + await session.commit_transaction() - def _set_experiment_status(self, experiment_id: str, new_status: ExperimentStatus) -> None: + async def _set_experiment_status(self, experiment_id: str, new_status: ExperimentStatus) -> None: """ Set the status of an experiment. """ - if not self._experiments.exists(id=experiment_id): + if not await self._experiments.exists(id=experiment_id): raise EosExperimentStateError(f"Experiment '{experiment_id}' does not exist.") update_fields = {"status": new_status.value} @@ -171,4 +178,4 @@ def _set_experiment_status(self, experiment_id: str, new_status: ExperimentStatu ]: update_fields["end_time"] = datetime.now(tz=timezone.utc) - self._experiments.update(update_fields, id=experiment_id) + await self._experiments.update_one(update_fields, id=experiment_id) diff --git a/eos/experiments/repositories/experiment_repository.py b/eos/experiments/repositories/experiment_repository.py index d925d52..ad26fcb 100644 --- a/eos/experiments/repositories/experiment_repository.py +++ b/eos/experiments/repositories/experiment_repository.py @@ -1,45 +1,61 @@ +from motor.core import AgnosticClientSession + from eos.experiments.entities.experiment import ExperimentStatus -from eos.persistence.mongo_repository import MongoRepository +from eos.persistence.async_mongodb_interface import AsyncMongoDbInterface +from eos.persistence.mongodb_async_repository import MongoDbAsyncRepository + +class ExperimentRepository(MongoDbAsyncRepository): + def __init__(self, db_interface: AsyncMongoDbInterface): + super().__init__("experiments", db_interface) -class ExperimentRepository(MongoRepository): - def get_experiments_by_lab(self, lab_type: str) -> list[dict]: - return self._collection.find({"labs": {"$in": [lab_type]}}) + async def initialize(self) -> None: + await self.create_indices([("id", 1)], unique=True) - def add_running_task(self, experiment_id: str, task_id: str) -> None: - self._collection.update_one( + async def get_experiments_by_lab(self, lab_type: str, session: AgnosticClientSession | None = None) -> list[dict]: + return await self._collection.find({"labs": {"$in": [lab_type]}}, session=session).to_list(length=None) + + async def add_running_task( + self, experiment_id: str, task_id: str, session: AgnosticClientSession | None = None + ) -> None: + await self._collection.update_one( {"id": experiment_id}, {"$addToSet": {"running_tasks": task_id}}, + session=session, ) - def delete_running_task(self, experiment_id: str, task_id: str) -> None: - self._collection.update_one( + async def delete_running_task( + self, experiment_id: str, task_id: str, session: AgnosticClientSession | None = None + ) -> None: + await self._collection.update_one( {"id": experiment_id}, {"$pull": {"running_tasks": task_id}}, + session=session, ) - 
def clear_running_tasks(self, experiment_id: str) -> None: - self._collection.update_one( + async def clear_running_tasks(self, experiment_id: str, session: AgnosticClientSession | None = None) -> None: + await self._collection.update_one( {"id": experiment_id}, {"$set": {"running_tasks": []}}, + session=session, ) - def move_task_queue(self, experiment_id: str, task_id: str, source: str, target: str) -> None: - self._collection.update_one( + async def move_task_queue( + self, experiment_id: str, task_id: str, source: str, target: str, session: AgnosticClientSession | None = None + ) -> None: + await self._collection.update_one( {"id": experiment_id}, {"$pull": {source: task_id}, "$addToSet": {target: task_id}}, + session=session, ) - def get_experiment_ids_by_campaign(self, campaign_id: str, status: ExperimentStatus | None = None) -> list[str]: - """ - Get all experiment IDs of a campaign with an optional status filter. - - :param campaign_id: The ID of the campaign. - :param status: Optional status to filter experiments. - :return: A list of experiment IDs. - """ + async def get_experiment_ids_by_campaign( + self, campaign_id: str, status: ExperimentStatus | None = None, session: AgnosticClientSession | None = None + ) -> list[str]: query = {"id": {"$regex": f"^{campaign_id}"}} if status: query["status"] = status.value - return [doc["id"] for doc in self._collection.find(query, {"id": 1})] + return [ + doc["id"] for doc in await self._collection.find(query, {"id": 1}, session=session).to_list(length=None) + ] diff --git a/eos/monitoring/graceful_termination_monitor.py b/eos/monitoring/graceful_termination_monitor.py index 549ece0..e8fa7f1 100644 --- a/eos/monitoring/graceful_termination_monitor.py +++ b/eos/monitoring/graceful_termination_monitor.py @@ -1,6 +1,6 @@ from eos.logging.logger import log -from eos.persistence.db_manager import DbManager -from eos.persistence.mongo_repository import MongoRepository +from eos.monitoring.repositories.global_repository import GlobalRepository +from eos.persistence.async_mongodb_interface import AsyncMongoDbInterface from eos.utils.singleton import Singleton @@ -9,26 +9,29 @@ class GracefulTerminationMonitor(metaclass=Singleton): The graceful termination monitor is responsible for tracking whether EOS has been terminated gracefully. 
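
    Example (illustrative startup check; assumes an initialized AsyncMongoDbInterface):

        monitor = GracefulTerminationMonitor(db_interface)
        await monitor.initialize()
        if not await monitor.previously_terminated_gracefully():
            ...  # the previous run did not shut down cleanly; review system state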
""" - def __init__(self, db_manager: DbManager): - self._globals = MongoRepository("globals", db_manager) - self._globals.create_indices([("key", 1)], unique=True) + def __init__(self, db_interface: AsyncMongoDbInterface): + self._globals = GlobalRepository(db_interface) + self._terminated_gracefully = False - graceful_termination = self._globals.get_one(key="graceful_termination") + async def initialize(self) -> None: + await self._globals.initialize() + + graceful_termination = await self._globals.get_one(key="graceful_termination") if not graceful_termination: - self._globals.create({"key": "graceful_termination", "terminated_gracefully": False}) + await self._globals.create({"key": "graceful_termination", "terminated_gracefully": False}) self._terminated_gracefully = False else: self._terminated_gracefully = graceful_termination["terminated_gracefully"] if not self._terminated_gracefully: log.warning("EOS did not terminate gracefully!") - def previously_terminated_gracefully(self) -> bool: + async def previously_terminated_gracefully(self) -> bool: return self._terminated_gracefully - def terminated_gracefully(self) -> None: - self._set_terminated_gracefully(True) + async def set_terminated_gracefully(self) -> None: + await self._set_terminated_gracefully(True) log.debug("EOS terminated gracefully.") - def _set_terminated_gracefully(self, value: bool) -> None: + async def _set_terminated_gracefully(self, value: bool) -> None: self._terminated_gracefully = value - self._globals.update({"terminated_gracefully": value}, key="graceful_termination") + await self._globals.update_one({"terminated_gracefully": value}, key="graceful_termination") diff --git a/eos/monitoring/repositories/__init__.py b/eos/monitoring/repositories/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/eos/monitoring/repositories/global_repository.py b/eos/monitoring/repositories/global_repository.py new file mode 100644 index 0000000..69a1b7c --- /dev/null +++ b/eos/monitoring/repositories/global_repository.py @@ -0,0 +1,10 @@ +from eos.persistence.async_mongodb_interface import AsyncMongoDbInterface +from eos.persistence.mongodb_async_repository import MongoDbAsyncRepository + + +class GlobalRepository(MongoDbAsyncRepository): + def __init__(self, db_interface: AsyncMongoDbInterface): + super().__init__("globals", db_interface) + + async def initialize(self) -> None: + await self.create_indices([("key", 1)], unique=True) diff --git a/eos/orchestration/orchestrator.py b/eos/orchestration/orchestrator.py index 855a264..f1af601 100644 --- a/eos/orchestration/orchestrator.py +++ b/eos/orchestration/orchestrator.py @@ -1,5 +1,4 @@ import asyncio -import atexit import traceback from asyncio import Lock as AsyncLock from collections.abc import AsyncIterable @@ -29,16 +28,16 @@ from eos.orchestration.exceptions import ( EosExperimentTypeInUseError, EosExperimentDoesNotExistError, - EosError, ) -from eos.persistence.db_manager import DbManager -from eos.persistence.file_db_manager import FileDbManager +from eos.persistence.async_mongodb_interface import AsyncMongoDbInterface +from eos.persistence.file_db_interface import FileDbInterface from eos.persistence.service_credentials import ServiceCredentials from eos.resource_allocation.resource_allocation_manager import ( ResourceAllocationManager, ) from eos.scheduling.greedy_scheduler import GreedyScheduler from eos.tasks.entities.task import Task, TaskStatus +from eos.tasks.exceptions import EosTaskCancellationError from eos.tasks.on_demand_task_executor 
import OnDemandTaskExecutor from eos.tasks.task_executor import TaskExecutor from eos.tasks.task_manager import TaskManager @@ -63,12 +62,37 @@ def __init__( self._user_dir = user_dir self._db_credentials = db_credentials self._file_db_credentials = file_db_credentials + self._initialized = False - self.initialize() - atexit.register(self.terminate) + self._configuration_manager: ConfigurationManager | None = None + self._db_interface: AsyncMongoDbInterface | None = None + self._file_db_interface: FileDbInterface | None = None + self._graceful_termination_monitor: GracefulTerminationMonitor | None = None + self._device_manager: DeviceManager | None = None + self._container_manager: ContainerManager | None = None + self._resource_allocation_manager: ResourceAllocationManager | None = None + self._task_manager: TaskManager | None = None + self._experiment_manager: ExperimentManager | None = None + self._campaign_manager: CampaignManager | None = None + self._campaign_optimizer_manager: CampaignOptimizerManager | None = None + self._task_executor: TaskExecutor | None = None + self._on_demand_task_executor: OnDemandTaskExecutor | None = None + self._scheduler: GreedyScheduler | None = None + self._experiment_executor_factory: ExperimentExecutorFactory | None = None + self._campaign_executor_factory: CampaignExecutorFactory | None = None + + self._campaign_submission_lock = AsyncLock() + self._submitted_campaigns: dict[str, CampaignExecutor] = {} + self._experiment_submission_lock = AsyncLock() + self._submitted_experiments: dict[str, ExperimentExecutor] = {} + + self._campaign_cancellation_queue = asyncio.Queue(maxsize=100) + self._experiment_cancellation_queue = asyncio.Queue(maxsize=100) + + self._loading_lock = AsyncLock() - def initialize(self) -> None: + async def initialize(self) -> None: """ Prepare the orchestrator. This is required before any other operations can be performed. 
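
        Example (illustrative startup sequence; the lab type is hypothetical):

            orchestrator = Orchestrator(user_dir, db_credentials, file_db_credentials)
            await orchestrator.initialize()
            await orchestrator.load_labs({"small_lab"})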
""" @@ -84,20 +108,33 @@ def initialize(self) -> None: self._configuration_manager = ConfigurationManager(self._user_dir) # Persistence ############################################# - self._db_manager = DbManager(self._db_credentials) - self._file_db_manager = FileDbManager(self._file_db_credentials) - - # Monitoring ############################################## - self._graceful_termination_monitor = GracefulTerminationMonitor(self._db_manager) + self._db_interface = AsyncMongoDbInterface(self._db_credentials) + self._file_db_interface = FileDbInterface(self._file_db_credentials) # State management ######################################## - self._device_manager = DeviceManager(self._configuration_manager, self._db_manager) - self._container_manager = ContainerManager(self._configuration_manager, self._db_manager) - self._resource_allocation_manager = ResourceAllocationManager(self._configuration_manager, self._db_manager) - self._task_manager = TaskManager(self._configuration_manager, self._db_manager, self._file_db_manager) - self._experiment_manager = ExperimentManager(self._configuration_manager, self._db_manager) - self._campaign_manager = CampaignManager(self._configuration_manager, self._db_manager) - self._campaign_optimizer_manager = CampaignOptimizerManager(self._configuration_manager, self._db_manager) + self._graceful_termination_monitor = GracefulTerminationMonitor(self._db_interface) + await self._graceful_termination_monitor.initialize() + + self._device_manager = DeviceManager(self._configuration_manager, self._db_interface) + await self._device_manager.initialize(self._db_interface) + + self._container_manager = ContainerManager(self._configuration_manager, self._db_interface) + await self._container_manager.initialize(self._db_interface) + + self._resource_allocation_manager = ResourceAllocationManager(self._db_interface) + await self._resource_allocation_manager.initialize(self._configuration_manager, self._db_interface) + + self._task_manager = TaskManager(self._configuration_manager, self._db_interface, self._file_db_interface) + await self._task_manager.initialize(self._db_interface) + + self._experiment_manager = ExperimentManager(self._configuration_manager, self._db_interface) + await self._experiment_manager.initialize(self._db_interface) + + self._campaign_manager = CampaignManager(self._configuration_manager, self._db_interface) + await self._campaign_manager.initialize(self._db_interface) + + self._campaign_optimizer_manager = CampaignOptimizerManager(self._configuration_manager, self._db_interface) + await self._campaign_optimizer_manager.initialize(self._db_interface) # Execution ############################################### self._task_executor = TaskExecutor( @@ -133,39 +170,29 @@ def initialize(self) -> None: self._experiment_executor_factory, ) - self._campaign_submission_lock = AsyncLock() - self._submitted_campaigns: dict[str, CampaignExecutor] = {} - self._experiment_submission_lock = AsyncLock() - self._submitted_experiments: dict[str, ExperimentExecutor] = {} - - self._campaign_cancellation_queue = asyncio.Queue(maxsize=100) - self._experiment_cancellation_queue = asyncio.Queue(maxsize=100) - - self._loading_lock = AsyncLock() - - self._fail_all_running_work() + await self._fail_all_running_work() self._initialized = True - def _fail_all_running_work(self) -> None: + async def _fail_all_running_work(self) -> None: """ When the orchestrator starts, fail all running tasks, experiments, and campaigns. 
        This is a safety measure: if the orchestrator was terminated while work was running, the state of
        the system may be unknown. We want to force manual review of the state of the system and explicitly
        require re-submission of any work that was running.
        """
-        running_tasks = self._task_manager.get_tasks(status=TaskStatus.RUNNING.value)
+        running_tasks = await self._task_manager.get_tasks(status=TaskStatus.RUNNING.value)
        for task in running_tasks:
-            self._task_manager.fail_task(task.experiment_id, task.id)
+            await self._task_manager.fail_task(task.experiment_id, task.id)
            log.warning(f"EXP '{task.experiment_id}' - Failed task '{task.id}'.")

-        running_experiments = self._experiment_manager.get_experiments(status=ExperimentStatus.RUNNING.value)
+        running_experiments = await self._experiment_manager.get_experiments(status=ExperimentStatus.RUNNING.value)
        for experiment in running_experiments:
-            self._experiment_manager.fail_experiment(experiment.id)
+            await self._experiment_manager.fail_experiment(experiment.id)

-        running_campaigns = self._campaign_manager.get_campaigns(status=CampaignStatus.RUNNING.value)
+        running_campaigns = await self._campaign_manager.get_campaigns(status=CampaignStatus.RUNNING.value)
        for campaign in running_campaigns:
-            self._campaign_manager.fail_campaign(campaign.id)
+            await self._campaign_manager.fail_campaign(campaign.id)

        if running_tasks:
            log.warning("All running tasks have been marked as failed. Please review the state of the system.")
@@ -182,7 +209,7 @@ def _fail_all_running_work(self) -> None:
                "with resume=True."
            )

-    def terminate(self) -> None:
+    async def terminate(self) -> None:
        """
        Terminate the orchestrator. After this, no other operations can be performed.
        This should be called before the program exits.
@@ -190,27 +217,27 @@
        if not self._initialized:
            return
        log.info("Cleaning up device actors...")
-        self._device_manager.cleanup_device_actors()
+        await self._device_manager.cleanup_device_actors()
        log.info("Shutting down Ray cluster...")
        ray.shutdown()
-        self._graceful_termination_monitor.terminated_gracefully()
+        await self._graceful_termination_monitor.set_terminated_gracefully()
        self._initialized = False

-    def load_labs(self, labs: set[str]) -> None:
+    async def load_labs(self, labs: set[str]) -> None:
        """
        Load one or more labs into the orchestrator.
        """
        self._configuration_manager.load_labs(labs)
-        self._device_manager.update_devices(loaded_labs=labs)
-        self._container_manager.update_containers(loaded_labs=labs)
+        await self._device_manager.update_devices(loaded_labs=labs)
+        await self._container_manager.update_containers(loaded_labs=labs)

-    def unload_labs(self, labs: set[str]) -> None:
+    async def unload_labs(self, labs: set[str]) -> None:
        """
        Unload one or more labs from the orchestrator.
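
        Example (illustrative; the lab type is hypothetical):

            await orchestrator.unload_labs({"small_lab"})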
""" self._configuration_manager.unload_labs(labs) - self._device_manager.update_devices(unloaded_labs=labs) - self._container_manager.update_containers(unloaded_labs=labs) + await self._device_manager.update_devices(unloaded_labs=labs) + await self._container_manager.update_containers(unloaded_labs=labs) async def reload_labs(self, lab_types: set[str]) -> None: """ @@ -219,22 +246,23 @@ async def reload_labs(self, lab_types: set[str]) -> None: async with self._loading_lock: experiments_to_reload = set() for lab_type in lab_types: - existing_experiments = self._experiment_manager.get_experiments(status=ExperimentStatus.RUNNING.value) + existing_experiments = await self._experiment_manager.get_experiments( + status=ExperimentStatus.RUNNING.value + ) for experiment in existing_experiments: experiment_config = self._configuration_manager.experiments[experiment.type] if lab_type in experiment_config.labs: - raise EosExperimentTypeInUseError( - f"Cannot reload lab type '{lab_type}' as there are running experiments that use it." - ) + log.error(f"Cannot reload lab type '{lab_type}' as there are running experiments that use it.") + raise EosExperimentTypeInUseError # Determine experiments to reload for this lab type for experiment_type, experiment_config in self._configuration_manager.experiments.items(): if lab_type in experiment_config.labs: experiments_to_reload.add(experiment_type) try: - self.unload_labs(lab_types) - self.load_labs(lab_types) + await self.unload_labs(lab_types) + await self.load_labs(lab_types) self.load_experiments(experiments_to_reload) except EosConfigurationError: log.error(f"Error reloading labs: {traceback.format_exc()}") @@ -254,18 +282,19 @@ async def update_loaded_labs(self, lab_types: set[str]) -> None: to_load = lab_types - currently_loaded for lab_type in to_unload: - existing_experiments = self._experiment_manager.get_experiments(status=ExperimentStatus.RUNNING.value) + existing_experiments = await self._experiment_manager.get_experiments( + status=ExperimentStatus.RUNNING.value + ) for experiment in existing_experiments: experiment_config = self._configuration_manager.experiments[experiment.type] if lab_type in experiment_config.labs: - raise EosExperimentTypeInUseError( - f"Cannot unload lab type '{lab_type}' as there are running experiments that use it." - ) + log.error(f"Cannot unload lab type '{lab_type}' as there are running experiments that use it.") + raise EosExperimentTypeInUseError try: - self.unload_labs(to_unload) - self.load_labs(to_load) + await self.unload_labs(to_unload) + await self.load_labs(to_load) except EosConfigurationError: log.error(f"Error updating loaded labs: {traceback.format_exc()}") raise @@ -294,14 +323,15 @@ async def reload_experiments(self, experiment_types: set[str]) -> None: """ async with self._loading_lock: for experiment_type in experiment_types: - existing_experiments = self._experiment_manager.get_experiments( + existing_experiments = await self._experiment_manager.get_experiments( status=ExperimentStatus.RUNNING.value, type=experiment_type ) if existing_experiments: - raise EosExperimentTypeInUseError( + log.error( f"Cannot reload experiment type '{experiment_type}' as there are running experiments of this " f"type." 
) + raise EosExperimentTypeInUseError try: self.unload_experiments(experiment_types) self.load_experiments(experiment_types) @@ -323,14 +353,15 @@ async def update_loaded_experiments(self, experiment_types: set[str]) -> None: to_load = experiment_types - currently_loaded for experiment_type in to_unload: - existing_experiments = self._experiment_manager.get_experiments( + existing_experiments = await self._experiment_manager.get_experiments( status=ExperimentStatus.RUNNING.value, type=experiment_type ) if existing_experiments: - raise EosExperimentTypeInUseError( + log.error( f"Cannot unload experiment type '{experiment_type}' as there are running experiments of this " f"type." ) + raise EosExperimentTypeInUseError try: self.unload_experiments(to_unload) @@ -388,9 +419,9 @@ async def get_task(self, experiment_id: str, task_id: str) -> Task: :param task_id: The unique identifier of the task. :return: The task entity. """ - return self._task_manager.get_task(experiment_id, task_id) + return await self._task_manager.get_task(experiment_id, task_id) - async def submit_task( + def submit_task( self, task_config: TaskConfig, resource_allocation_priority: int = 1, @@ -406,7 +437,7 @@ async def submit_task( error. :return: The output of the task. """ - await self._on_demand_task_executor.submit_task( + self._on_demand_task_executor.submit_task( task_config, resource_allocation_priority, resource_allocation_timeout ) @@ -417,10 +448,13 @@ async def cancel_task(self, task_id: str, experiment_id: str = "on_demand") -> N :param task_id: The unique identifier of the task. :param experiment_id: The unique identifier of the experiment. """ - if experiment_id == "on_demand": - await self._on_demand_task_executor.cancel_task(task_id) - else: - await self._task_executor.request_task_cancellation(experiment_id, task_id) + try: + if experiment_id == "on_demand": + await self._on_demand_task_executor.request_task_cancellation(task_id) + else: + await self._task_executor.request_task_cancellation(experiment_id, task_id) + except EosTaskCancellationError: + log.error(f"Failed to cancel task '{task_id}'.") async def get_task_types(self) -> list[str]: """ @@ -434,7 +468,7 @@ async def get_task_spec(self, task_type: str) -> TaskSpecification | None: """ task_spec = self._configuration_manager.task_specs.get_spec_by_type(task_type) if not task_spec: - raise EosError(f"Task type '{task_type}' does not exist.") + log.error(f"Task type '{task_type}' does not exist.") return task_spec @@ -459,7 +493,7 @@ async def get_experiment(self, experiment_id: str) -> Experiment | None: :param experiment_id: The unique identifier of the experiment. :return: The experiment entity. """ - return self._experiment_manager.get_experiment(experiment_id) + return await self._experiment_manager.get_experiment(experiment_id) async def submit_experiment( self, @@ -491,7 +525,7 @@ async def submit_experiment( ) try: - experiment_executor.start_experiment(dynamic_parameters, metadata) + await experiment_executor.start_experiment(dynamic_parameters, metadata) self._submitted_experiments[experiment_id] = experiment_executor except EosExperimentExecutionError: log.error(f"Failed to submit experiment '{experiment_id}': {traceback.format_exc()}") @@ -542,7 +576,7 @@ async def get_campaign(self, campaign_id: str) -> Campaign | None: :param campaign_id: The unique identifier of the campaign. :return: The campaign entity. 
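+
+        Example (illustrative; the campaign ID is hypothetical):
+
+            campaign = await orchestrator.get_campaign("color_mixing_campaign")
+            if campaign:
+                print(campaign.status)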
""" - return self._campaign_manager.get_campaign(campaign_id) + return await self._campaign_manager.get_campaign(campaign_id) async def submit_campaign( self, @@ -580,46 +614,56 @@ async def cancel_campaign(self, campaign_id: str) -> None: if campaign_id in self._submitted_campaigns: await self._campaign_cancellation_queue.put(campaign_id) - async def spin(self, rate_hz: int = 10) -> None: + async def spin(self, rate_hz: int = 5) -> None: """ - Spin the orchestrator at a given rate in Hz. + Spin the orchestrator at a given rate in Hz. Process submitted work. - :param rate_hz: The processing rate in Hz. This is the rate in which the orchestrator will check for progress in - submitted experiments and campaigns. + :param rate_hz: The processing rate in Hz. This is the rate in which the orchestrator updates. """ while True: - await self._process_experiment_and_campaign_cancellations() + await self._process_experiment_cancellations() + await self._process_campaign_cancellations() await asyncio.gather( self._process_on_demand_tasks(), self._process_experiments(), self._process_campaigns(), ) - self._resource_allocation_manager.process_active_requests() + + await self._resource_allocation_manager.process_active_requests() await asyncio.sleep(1 / rate_hz) - async def _process_experiment_and_campaign_cancellations(self) -> None: + async def _process_experiment_cancellations(self) -> None: + experiment_ids = [] while not self._experiment_cancellation_queue.empty(): - experiment_id = await self._experiment_cancellation_queue.get() + experiment_ids.append(await self._experiment_cancellation_queue.get()) - log.warning(f"Attempting to cancel experiment '{experiment_id}'.") - try: - await self._submitted_experiments[experiment_id].cancel_experiment() - finally: - del self._submitted_experiments[experiment_id] - log.warning(f"Cancelled experiment '{experiment_id}'.") + if experiment_ids: + log.warning(f"Attempting to cancel experiments: {experiment_ids}") + experiment_cancel_tasks = [ + self._submitted_experiments[exp_id].cancel_experiment() for exp_id in experiment_ids + ] + await asyncio.gather(*experiment_cancel_tasks) + + for exp_id in experiment_ids: + del self._submitted_experiments[exp_id] + log.warning(f"Cancelled experiments: {experiment_ids}") + async def _process_campaign_cancellations(self) -> None: + campaign_ids = [] while not self._campaign_cancellation_queue.empty(): - campaign_id = await self._campaign_cancellation_queue.get() + campaign_ids.append(await self._campaign_cancellation_queue.get()) - log.warning(f"Attempting to cancel campaign '{campaign_id}'.") - try: - await self._submitted_campaigns[campaign_id].cancel_campaign() - finally: + if campaign_ids: + log.warning(f"Attempting to cancel campaigns: {campaign_ids}") + campaign_cancel_tasks = [self._submitted_campaigns[camp_id].cancel_campaign() for camp_id in campaign_ids] + await asyncio.gather(*campaign_cancel_tasks) + + for campaign_id in campaign_ids: self._submitted_campaigns[campaign_id].cleanup() del self._submitted_campaigns[campaign_id] - log.warning(f"Cancelled campaign '{campaign_id}'.") + log.warning(f"Cancelled campaigns: {campaign_ids}") async def _process_experiments(self) -> None: to_remove_completed = [] @@ -680,42 +724,5 @@ async def _process_on_demand_tasks(self) -> None: def _validate_experiment_type_exists(self, experiment_type: str) -> None: if experiment_type not in self._configuration_manager.experiments: - raise EosExperimentDoesNotExistError( - f"Cannot submit experiment of type '{experiment_type}' as it 
does not exist." - ) - - @property - def configuration_manager(self) -> ConfigurationManager: - return self._configuration_manager - - @property - def db_manager(self) -> DbManager: - return self._db_manager - - @property - def device_manager(self) -> DeviceManager: - return self._device_manager - - @property - def container_manager(self) -> ContainerManager: - return self._container_manager - - @property - def resource_allocation_manager(self) -> ResourceAllocationManager: - return self._resource_allocation_manager - - @property - def task_manager(self) -> TaskManager: - return self._task_manager - - @property - def experiment_manager(self) -> ExperimentManager: - return self._experiment_manager - - @property - def campaign_manager(self) -> CampaignManager: - return self._campaign_manager - - @property - def task_executor(self) -> TaskExecutor: - return self._task_executor + log.error(f"Cannot submit experiment of type '{experiment_type}' as it does not exist.") + raise EosExperimentDoesNotExistError diff --git a/eos/persistence/abstract_async_repository.py b/eos/persistence/abstract_async_repository.py new file mode 100644 index 0000000..518c8b8 --- /dev/null +++ b/eos/persistence/abstract_async_repository.py @@ -0,0 +1,44 @@ +from abc import ABC, abstractmethod +from typing import Any + + +class AbstractAsyncRepository(ABC): + """ + Abstract class for a repository that provides CRUD operations for a collection of entities. + """ + + @abstractmethod + async def create(self, entity: dict) -> None: + pass + + @abstractmethod + async def count(self, **query: dict) -> int: + pass + + @abstractmethod + async def exists(self, count: int = 1, **query: dict) -> bool: + pass + + @abstractmethod + async def get_one(self, **query: dict) -> dict: + pass + + @abstractmethod + async def get_all(self, **query: dict) -> list[dict]: + pass + + @abstractmethod + async def update_one(self, updated_entity: dict[str, Any], **kwargs) -> None: + pass + + @abstractmethod + async def delete_one(self, **query: dict) -> None: + pass + + @abstractmethod + async def delete_many(self, **query: dict) -> None: + pass + + @abstractmethod + async def delete_all(self) -> None: + pass diff --git a/eos/persistence/abstract_repository.py b/eos/persistence/abstract_repository.py deleted file mode 100644 index c9ffa7f..0000000 --- a/eos/persistence/abstract_repository.py +++ /dev/null @@ -1,35 +0,0 @@ -from abc import ABC, abstractmethod - - -class AbstractRepository(ABC): - """ - Abstract class for a repository that provides CRUD operations for a collection of entities. 
- """ - - @abstractmethod - def create(self, entity: dict) -> None: - pass - - @abstractmethod - def count(self, **query: dict) -> int: - pass - - @abstractmethod - def exists(self, count: int = 1, **query: dict) -> bool: - pass - - @abstractmethod - def get_one(self, **query: dict) -> dict: - pass - - @abstractmethod - def get_all(self, **query: dict) -> list[dict]: - pass - - @abstractmethod - def update(self, entity_id: str, entity: dict) -> None: - pass - - @abstractmethod - def delete(self, entity_id: str) -> None: - pass diff --git a/eos/persistence/async_mongodb_interface.py b/eos/persistence/async_mongodb_interface.py new file mode 100644 index 0000000..662bb06 --- /dev/null +++ b/eos/persistence/async_mongodb_interface.py @@ -0,0 +1,39 @@ +from motor.core import AgnosticDatabase +from motor.motor_asyncio import AsyncIOMotorClient + +from eos.logging.logger import log +from eos.persistence.async_mongodb_session_factory import AsyncMongoDbSessionFactory +from eos.persistence.service_credentials import ServiceCredentials + + +class AsyncMongoDbInterface: + """ + Gives asynchronous access to a MongoDB database. + """ + + def __init__( + self, + db_credentials: ServiceCredentials, + db_name: str = "eos", + ): + self._db_credentials = db_credentials + + self._db_client = AsyncIOMotorClient( + f"mongodb://{self._db_credentials.username}:{self._db_credentials.password}" + f"@{self._db_credentials.host}:{self._db_credentials.port}" + ) + + self._db: AgnosticDatabase = self._db_client[db_name] + self.session_factory = AsyncMongoDbSessionFactory(self._db_client) + + log.debug(f"Async Db manager initialized with database '{db_name}'.") + + def get_db(self) -> AgnosticDatabase: + """Get the database.""" + return self._db + + async def clean_db(self) -> None: + """Clean the database.""" + collections = await self._db.list_collection_names() + for collection in collections: + await self._db[collection].drop() diff --git a/eos/persistence/async_mongodb_session_factory.py b/eos/persistence/async_mongodb_session_factory.py new file mode 100644 index 0000000..b20e2ec --- /dev/null +++ b/eos/persistence/async_mongodb_session_factory.py @@ -0,0 +1,29 @@ +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager + +from motor.core import AgnosticClientSession +from motor.motor_asyncio import AsyncIOMotorClient + + +class AsyncMongoDbSessionFactory: + def __init__(self, db_client: AsyncIOMotorClient): + self._db_client = db_client + + @asynccontextmanager + async def __call__(self) -> AsyncGenerator[AgnosticClientSession, None]: + """ + Async context manager for MongoDB sessions with transactions. + Usage: + async with db_manager.transaction_session_factory() as session: + # Perform operations within the session and transaction + """ + session = await self._db_client.start_session() + try: + async with session.start_transaction(): + try: + yield session + except Exception: + await session.abort_transaction() + raise + finally: + await session.end_session() diff --git a/eos/persistence/db_manager.py b/eos/persistence/db_manager.py deleted file mode 100644 index f2dbbd4..0000000 --- a/eos/persistence/db_manager.py +++ /dev/null @@ -1,55 +0,0 @@ -from pymongo import MongoClient -from pymongo.client_session import ClientSession -from pymongo.database import Database - -from eos.logging.logger import log -from eos.persistence.service_credentials import ServiceCredentials - - -class DbManager: - """ - Responsible for giving access to a MongoDB database. 
- """ - - def __init__( - self, - db_credentials: ServiceCredentials, - db_name: str = "eos", - ): - self._db_credentials = db_credentials - - self._db_client = MongoClient( - host=self._db_credentials.host, - port=self._db_credentials.port, - username=self._db_credentials.username, - password=self._db_credentials.password, - serverSelectionTimeoutMS=10000, - ) - - self._db: Database = self._db_client[db_name] - - log.debug(f"Db manager initialized with database '{db_name}'.") - - def get_db(self) -> Database: - """Get the database.""" - return self._db - - def create_collection_index(self, collection: str, index: list[tuple[str, int]], unique: bool = False) -> None: - """ - Create an index for a collection in the database if it doesn't already exist. - :param collection: The collection name. - :param index: The index to create. A list of tuples of the field names and index orders. - :param unique: Whether the index should be unique. - """ - index_name = "_".join(f"{field}_{order}" for field, order in index) - if index_name not in self._db[collection].index_information(): - self._db[collection].create_index(index, unique=unique, name=index_name) - - def start_session(self) -> ClientSession: - """Start a new client session.""" - return self._db_client.start_session() - - def clean_db(self) -> None: - """Clean the database.""" - for collection in self._db.list_collection_names(): - self._db[collection].drop() diff --git a/eos/persistence/file_db_manager.py b/eos/persistence/file_db_interface.py similarity index 97% rename from eos/persistence/file_db_manager.py rename to eos/persistence/file_db_interface.py index f848aab..9203f95 100644 --- a/eos/persistence/file_db_manager.py +++ b/eos/persistence/file_db_interface.py @@ -8,9 +8,9 @@ from eos.persistence.service_credentials import ServiceCredentials -class FileDbManager: +class FileDbInterface: """ - Responsible for storing and retrieving files from a MinIO server. + Provides access to a MinIO server for storing and retrieving files. """ def __init__(self, file_db_credentials: ServiceCredentials, bucket_name: str = "eos"): diff --git a/eos/persistence/mongo_repository.py b/eos/persistence/mongo_repository.py deleted file mode 100644 index e2b0bcc..0000000 --- a/eos/persistence/mongo_repository.py +++ /dev/null @@ -1,91 +0,0 @@ -from typing import Any - -from pymongo.results import DeleteResult, UpdateResult, InsertOneResult - -from eos.persistence.abstract_repository import AbstractRepository -from eos.persistence.db_manager import DbManager - - -class MongoRepository(AbstractRepository): - """ - Provides CRUD operations for a MongoDB collection. - """ - - def __init__(self, collection_name: str, db_manager: DbManager): - self._collection = db_manager.get_db().get_collection(collection_name) - - def create_indices(self, indices: list[tuple[str, int]], unique: bool = False) -> None: - """ - Create indices on the collection. - - :param indices: List of tuples of field names and order (1 for ascending, -1 for descending). - :param unique: Whether the index should be unique. - """ - index_name = "_".join(f"{field}_{order}" for field, order in indices) - if index_name not in self._collection.index_information(): - self._collection.create_index(indices, unique=unique, name=index_name) - - def create(self, entity: dict[str, Any]) -> InsertOneResult: - """ - Create a new entity in the collection. - - :param entity: The entity to create. - :return: The result of the insert operation. 
- """ - return self._collection.insert_one(entity) - - def count(self, **kwargs) -> int: - """ - Count the number of entities that match the query in the collection. - - :param kwargs: Query parameters. - :return: The number of entities. - """ - return self._collection.count_documents(kwargs) - - def exists(self, count: int = 1, **kwargs) -> bool: - """ - Check if the number of entities that match the query exist in the collection. - - :param count: The number of entities to check for. - :param kwargs: Query parameters. - :return: Whether the entity exists. - """ - return self.count(**kwargs) >= count - - def get_one(self, **kwargs) -> dict[str, Any]: - """ - Get a single entity from the collection. - - :param kwargs: Query parameters. - :return: The entity as a dictionary. - """ - return self._collection.find_one(kwargs) - - def get_all(self, **kwargs) -> list[dict[str, Any]]: - """ - Get all entities from the collection. - - :param kwargs: Query parameters. - :return: List of entities as dictionaries. - """ - return list(self._collection.find(kwargs)) - - def update(self, entity: dict[str, Any], **kwargs) -> UpdateResult: - """ - Update an entity in the collection. - - :param entity: The updated entity (or some of its fields). - :param kwargs: Query parameters. - :return: The result of the update operation. - """ - return self._collection.update_one(kwargs, {"$set": entity}, upsert=True) - - def delete(self, **kwargs) -> DeleteResult: - """ - Delete entities from the collection. - - :param kwargs: Query parameters. - :return: The result of the delete operation. - """ - return self._collection.delete_many(kwargs) diff --git a/eos/persistence/mongodb_async_repository.py b/eos/persistence/mongodb_async_repository.py new file mode 100644 index 0000000..207d80a --- /dev/null +++ b/eos/persistence/mongodb_async_repository.py @@ -0,0 +1,120 @@ +from typing import Any + +from motor.core import AgnosticClientSession +from pymongo.results import DeleteResult, UpdateResult, InsertOneResult + +from eos.persistence.abstract_async_repository import AbstractAsyncRepository +from eos.persistence.async_mongodb_interface import AsyncMongoDbInterface + + +class MongoDbAsyncRepository(AbstractAsyncRepository): + """ + Provides CRUD operations for a MongoDB collection. + """ + + def __init__(self, collection_name: str, db_interface: AsyncMongoDbInterface): + self._collection = db_interface.get_db().get_collection(collection_name) + + async def create_indices(self, indices: list[tuple[str, int]], unique: bool = False) -> None: + """ + Create indices on the collection synchronously. + + :param indices: List of tuples of field names and order (1 for ascending, -1 for descending). + :param unique: Whether the index should be unique. + """ + index_name = "_".join(f"{field}_{order}" for field, order in indices) + + if index_name not in await self._collection.index_information(): + await self._collection.create_index(indices, unique=unique, name=index_name) + + async def create(self, entity: dict[str, Any], session: AgnosticClientSession | None = None) -> InsertOneResult: + """ + Create a new entity in the collection. + + :param entity: The entity to create. + :param session: The optional session to use for the operation. + :return: The result of the insert operation. + """ + return await self._collection.insert_one(entity, session=session) + + async def count(self, session: AgnosticClientSession | None = None, **kwargs) -> int: + """ + Count the number of entities that match the query in the collection. 
+ + :param session: The optional session to use for the operation. + :param kwargs: Query parameters. + :return: The number of entities. + """ + return await self._collection.count_documents(session=session, filter=kwargs) + + async def exists(self, count: int = 1, session: AgnosticClientSession | None = None, **kwargs) -> bool: + """ + Check if the number of entities that match the query exist in the collection. + + :param count: The number of entities to check for. + :param session: The optional session to use for the operation. + :param kwargs: Query parameters. + :return: Whether the entity exists. + """ + return await self.count(session=session, **kwargs) >= count + + async def get_one(self, session: AgnosticClientSession | None = None, **kwargs) -> dict[str, Any]: + """ + Get a single entity from the collection. + + :param session: The optional session to use for the operation. + :param kwargs: Query parameters. + :return: The entity as a dictionary. + """ + return await self._collection.find_one(kwargs, session=session) + + async def get_all(self, session: AgnosticClientSession | None = None, **kwargs) -> list[dict[str, Any]]: + """ + Get all entities from the collection, optionally filtered by query parameters. + + :param session: The optional session to use for the operation. + :param kwargs: Query parameters. + :return: List of entities as dictionaries. + """ + return await self._collection.find(kwargs, session=session).to_list(None) + + async def update_one(self, updated_entity: dict[str, Any], session: AgnosticClientSession | None = None, + **kwargs) -> UpdateResult: + """ + Update an entity in the collection. + + :param updated_entity: The updated entity (or some of its fields). + :param session: The optional session to use for the operation. + :param kwargs: Query parameters. + :return: The result of the update operation. + """ + return await self._collection.update_one(kwargs, {"$set": updated_entity}, upsert=True, session=session) + + async def delete_one(self, session: AgnosticClientSession | None = None, **kwargs) -> DeleteResult: + """ + Delete an entity from the collection, optionally filtered by query parameters. + + :param session: The optional session to use for the operation. + :param kwargs: Query parameters. + :return: The result of the delete operation. + """ + return await self._collection.delete_one(kwargs, session=session) + + async def delete_many(self, session: AgnosticClientSession | None = None, **kwargs) -> DeleteResult: + """ + Delete multiple entities from the collection, optionally filtered by query parameters. + + :param session: The optional session to use for the operation. + :param kwargs: Query parameters. + :return: The result of the delete operation. + """ + return await self._collection.delete_many(kwargs, session=session) + + async def delete_all(self, session: AgnosticClientSession | None = None) -> DeleteResult: + """ + Delete all entities from the collection. + + :param session: The optional session to use for the operation. + :return: The result of the delete operation. 
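+
+        Example (an illustrative sketch; `repo` stands for any initialized repository):
+
+            result = await repo.delete_all()
+            print(result.deleted_count)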
+ """ + return await self._collection.delete_many({}, session=session) diff --git a/eos/resource_allocation/container_allocation_manager.py b/eos/resource_allocation/container_allocator.py similarity index 67% rename from eos/resource_allocation/container_allocation_manager.py rename to eos/resource_allocation/container_allocator.py index fd0645a..15925c4 100644 --- a/eos/resource_allocation/container_allocation_manager.py +++ b/eos/resource_allocation/container_allocator.py @@ -2,8 +2,7 @@ from eos.configuration.configuration_manager import ConfigurationManager from eos.logging.logger import log -from eos.persistence.db_manager import DbManager -from eos.persistence.mongo_repository import MongoRepository +from eos.persistence.async_mongodb_interface import AsyncMongoDbInterface from eos.resource_allocation.entities.container_allocation import ( ContainerAllocation, ) @@ -11,9 +10,10 @@ EosContainerAllocatedError, EosContainerNotFoundError, ) +from eos.resource_allocation.repositories.container_allocation_repository import ContainerAllocationRepository -class ContainerAllocationManager: +class ContainerAllocator: """ Responsible for allocating containers to "owners". An owner may be an experiment task, a human, etc. A container can only be held by one owner at a time. @@ -22,19 +22,22 @@ class ContainerAllocationManager: def __init__( self, configuration_manager: ConfigurationManager, - db_manager: DbManager, + db_interface: AsyncMongoDbInterface, ): self._configuration_manager = configuration_manager - self._allocations = MongoRepository("container_allocations", db_manager) - self._allocations.create_indices([("id", 1)], unique=True) + self._session_factory = db_interface.session_factory + self._allocations = None + async def initialize(self, db_interface: AsyncMongoDbInterface) -> None: + self._allocations = ContainerAllocationRepository(db_interface) + await self._allocations.initialize() log.debug("Container allocator initialized.") - def allocate(self, container_id: str, owner: str, experiment_id: str | None = None) -> None: + async def allocate(self, container_id: str, owner: str, experiment_id: str | None = None) -> None: """ Allocate a container to an owner. """ - if self.is_allocated(container_id): + if await self.is_allocated(container_id): raise EosContainerAllocatedError(f"Container '{container_id}' is already allocated.") container_config = self._get_container_config(container_id) @@ -45,45 +48,45 @@ def allocate(self, container_id: str, owner: str, experiment_id: str | None = No lab=container_config["lab"], experiment_id=experiment_id, ) - self._allocations.create(allocation.model_dump()) + await self._allocations.create(allocation.model_dump()) - def deallocate(self, container_id: str) -> None: + async def deallocate(self, container_id: str) -> None: """ Deallocate a container. """ - result = self._allocations.delete(id=container_id) + result = await self._allocations.delete_one(id=container_id) if result.deleted_count == 0: log.warning(f"Container '{container_id}' is not allocated. No action taken.") else: log.debug(f"Deallocated container '{container_id}'.") - def is_allocated(self, container_id: str) -> bool: + async def is_allocated(self, container_id: str) -> bool: """ Check if a container is allocated. 
""" self._get_container_config(container_id) - return self._allocations.get_one(id=container_id) is not None + return await self._allocations.exists(id=container_id) - def get_allocation(self, container_id: str) -> ContainerAllocation | None: + async def get_allocation(self, container_id: str) -> ContainerAllocation | None: """ Get the allocation details of a container. """ self._get_container_config(container_id) - allocation = self._allocations.get_one(id=container_id) + allocation = await self._allocations.get_one(id=container_id) return ContainerAllocation(**allocation) if allocation else None - def get_allocations(self, **query: dict[str, Any]) -> list[ContainerAllocation]: + async def get_allocations(self, **query: dict[str, Any]) -> list[ContainerAllocation]: """ Query allocations with arbitrary parameters. """ - allocations = self._allocations.get_all(**query) + allocations = await self._allocations.get_all(**query) return [ContainerAllocation(**allocation) for allocation in allocations] - def get_all_unallocated(self) -> list[str]: + async def get_all_unallocated(self) -> list[str]: """ Get all unallocated containers. """ - allocated_containers = [allocation.id for allocation in self.get_allocations()] + allocated_containers = [allocation.id for allocation in await self.get_allocations()] all_containers = [ container_id for lab_config in self._configuration_manager.labs.values() @@ -92,18 +95,18 @@ def get_all_unallocated(self) -> list[str]: ] return list(set(all_containers) - set(allocated_containers)) - def deallocate_all(self) -> None: + async def deallocate_all(self) -> None: """ Deallocate all containers. """ - result = self._allocations.delete() + result = await self._allocations.delete_all() log.debug(f"Deallocated all {result.deleted_count} containers.") - def deallocate_all_by_owner(self, owner: str) -> None: + async def deallocate_all_by_owner(self, owner: str) -> None: """ Deallocate all containers allocated to an owner. """ - result = self._allocations.delete(owner=owner) + result = await self._allocations.delete_many(owner=owner) if result.deleted_count == 0: log.warning(f"Owner '{owner}' has no containers allocated. No action taken.") else: diff --git a/eos/resource_allocation/device_allocation_manager.py b/eos/resource_allocation/device_allocator.py similarity index 64% rename from eos/resource_allocation/device_allocation_manager.py rename to eos/resource_allocation/device_allocator.py index 3e6c75c..249874f 100644 --- a/eos/resource_allocation/device_allocation_manager.py +++ b/eos/resource_allocation/device_allocator.py @@ -2,8 +2,7 @@ from eos.configuration.configuration_manager import ConfigurationManager from eos.logging.logger import log -from eos.persistence.db_manager import DbManager -from eos.persistence.mongo_repository import MongoRepository +from eos.persistence.async_mongodb_interface import AsyncMongoDbInterface from eos.resource_allocation.entities.device_allocation import ( DeviceAllocation, ) @@ -11,9 +10,10 @@ EosDeviceAllocatedError, EosDeviceNotFoundError, ) +from eos.resource_allocation.repositories.device_allocation_repository import DeviceAllocationRepository -class DeviceAllocationManager: +class DeviceAllocator: """ Responsible for allocating devices to "owners". An owner may be an experiment task, a human, etc. A device can only be held by one owner at a time. 
@@ -22,19 +22,22 @@ class DeviceAllocationManager: def __init__( self, configuration_manager: ConfigurationManager, - db_manager: DbManager, + db_interface: AsyncMongoDbInterface, ): self._configuration_manager = configuration_manager - self._allocations = MongoRepository("device_allocations", db_manager) - self._allocations.create_indices([("lab_id", 1), ("id", 1)], unique=True) + self._session_factory = db_interface.session_factory + self._allocations = None + async def initialize(self, db_interface: AsyncMongoDbInterface) -> None: + self._allocations = DeviceAllocationRepository(db_interface) + await self._allocations.initialize() log.debug("Device allocator initialized.") - def allocate(self, lab_id: str, device_id: str, owner: str, experiment_id: str | None = None) -> None: + async def allocate(self, lab_id: str, device_id: str, owner: str, experiment_id: str | None = None) -> None: """ Allocate a device to an owner. """ - if self.is_allocated(lab_id, device_id): + if await self.is_allocated(lab_id, device_id): raise EosDeviceAllocatedError(f"Device '{device_id}' in lab '{lab_id}' is already allocated.") device_config = self._get_device_config(lab_id, device_id) @@ -45,65 +48,65 @@ def allocate(self, lab_id: str, device_id: str, owner: str, experiment_id: str | device_type=device_config["type"], experiment_id=experiment_id, ) - self._allocations.create(allocation.model_dump()) + await self._allocations.create(allocation.model_dump()) - def deallocate(self, lab_id: str, device_id: str) -> None: + async def deallocate(self, lab_id: str, device_id: str) -> None: """ Deallocate a device. """ - result = self._allocations.delete(lab_id=lab_id, id=device_id) + result = await self._allocations.delete_one(lab_id=lab_id, id=device_id) if result.deleted_count == 0: log.warning(f"Device '{device_id}' in lab '{lab_id}' is not allocated. No action taken.") else: log.debug(f"Deallocated device '{device_id}' in lab '{lab_id}'.") - def is_allocated(self, lab_id: str, device_id: str) -> bool: + async def is_allocated(self, lab_id: str, device_id: str) -> bool: """ Check if a device is allocated. """ self._get_device_config(lab_id, device_id) - return self._allocations.get_one(lab_id=lab_id, id=device_id) is not None + return await self._allocations.exists(lab_id=lab_id, id=device_id) - def get_allocation(self, lab_id: str, device_id: str) -> DeviceAllocation | None: + async def get_allocation(self, lab_id: str, device_id: str) -> DeviceAllocation | None: """ Get the allocation details of a device. """ self._get_device_config(lab_id, device_id) - allocation = self._allocations.get_one(lab_id=lab_id, id=device_id) + allocation = await self._allocations.get_one(lab_id=lab_id, id=device_id) return DeviceAllocation(**allocation) if allocation else None - def get_allocations(self, **query: dict[str, Any]) -> list[DeviceAllocation]: + async def get_allocations(self, **query: dict[str, Any]) -> list[DeviceAllocation]: """ Query device allocations with arbitrary parameters. """ - allocations = self._allocations.get_all(**query) + allocations = await self._allocations.get_all(**query) return [DeviceAllocation(**allocation) for allocation in allocations] - def get_all_unallocated(self) -> list[str]: + async def get_all_unallocated(self) -> list[str]: """ Get all unallocated devices. 
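+
+        Example (illustrative):
+
+            free_devices = await allocator.get_all_unallocated()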
""" - allocated_devices = [allocation.id for allocation in self.get_allocations()] + allocated_devices = [allocation.id for allocation in await self.get_allocations()] all_devices = [ device_id for lab_config in self._configuration_manager.labs.values() for device_id in lab_config.devices ] return list(set(all_devices) - set(allocated_devices)) - def deallocate_all_by_owner(self, owner: str) -> None: + async def deallocate_all_by_owner(self, owner: str) -> None: """ Deallocate all devices allocated to an owner. """ - result = self._allocations.delete(owner=owner) + result = await self._allocations.delete_many(owner=owner) if result.deleted_count == 0: log.warning(f"Owner '{owner}' has no devices allocated. No action taken.") else: log.debug(f"Deallocated {result.deleted_count} devices for owner '{owner}'.") - def deallocate_all(self) -> None: + async def deallocate_all(self) -> None: """ Deallocate all devices. """ - result = self._allocations.delete() + result = await self._allocations.delete_all() log.debug(f"Deallocated all {result.deleted_count} devices.") def _get_device_config(self, lab_id: str, device_id: str) -> dict[str, Any]: diff --git a/eos/resource_allocation/repositories/container_allocation_repository.py b/eos/resource_allocation/repositories/container_allocation_repository.py new file mode 100644 index 0000000..bf97d2a --- /dev/null +++ b/eos/resource_allocation/repositories/container_allocation_repository.py @@ -0,0 +1,10 @@ +from eos.persistence.async_mongodb_interface import AsyncMongoDbInterface +from eos.persistence.mongodb_async_repository import MongoDbAsyncRepository + + +class ContainerAllocationRepository(MongoDbAsyncRepository): + def __init__(self, db_interface: AsyncMongoDbInterface): + super().__init__("container_allocations", db_interface) + + async def initialize(self) -> None: + await self.create_indices([("id", 1)], unique=True) diff --git a/eos/resource_allocation/repositories/device_allocation_repository.py b/eos/resource_allocation/repositories/device_allocation_repository.py new file mode 100644 index 0000000..b1f6639 --- /dev/null +++ b/eos/resource_allocation/repositories/device_allocation_repository.py @@ -0,0 +1,10 @@ +from eos.persistence.async_mongodb_interface import AsyncMongoDbInterface +from eos.persistence.mongodb_async_repository import MongoDbAsyncRepository + + +class DeviceAllocationRepository(MongoDbAsyncRepository): + def __init__(self, db_interface: AsyncMongoDbInterface): + super().__init__("device_allocations", db_interface) + + async def initialize(self) -> None: + await self.create_indices([("lab_id", 1), ("id", 1)], unique=True) diff --git a/eos/resource_allocation/repositories/resource_request_repository.py b/eos/resource_allocation/repositories/resource_request_repository.py index 9998349..077969a 100644 --- a/eos/resource_allocation/repositories/resource_request_repository.py +++ b/eos/resource_allocation/repositories/resource_request_repository.py @@ -1,15 +1,24 @@ -from eos.persistence.mongo_repository import MongoRepository +from motor.core import AgnosticClientSession + +from eos.persistence.async_mongodb_interface import AsyncMongoDbInterface +from eos.persistence.mongodb_async_repository import MongoDbAsyncRepository from eos.resource_allocation.entities.resource_request import ( ResourceAllocationRequest, ResourceRequestAllocationStatus, ) -class ResourceRequestRepository(MongoRepository): - def get_requests_prioritized(self, status: ResourceRequestAllocationStatus) -> list[dict]: - return 
self._collection.find({"status": status.value}).sort("request.priority", 1) +class ResourceRequestRepository(MongoDbAsyncRepository): + def __init__(self, db_interface: AsyncMongoDbInterface): + super().__init__("resource_requests", db_interface) + + async def get_requests_prioritized(self, status: ResourceRequestAllocationStatus, + session: AgnosticClientSession | None = None) -> list[dict]: + return await self._collection.find({"status": status.value}, session=session).sort("request.priority", + 1).to_list() - def get_existing_request(self, request: ResourceAllocationRequest) -> dict: + async def get_existing_request(self, request: ResourceAllocationRequest, + session: AgnosticClientSession | None = None) -> dict: query = { "request.resources": [r.model_dump() for r in request.resources], "request.requester": request.requester, @@ -21,10 +30,10 @@ def get_existing_request(self, request: ResourceAllocationRequest) -> dict: }, } - return self._collection.find_one(query) + return await self._collection.find_one(query, session=session) - def clean_requests(self) -> None: - self._collection.delete_many( + async def clean_requests(self, session: AgnosticClientSession | None = None) -> None: + await self._collection.delete_many( { "status": { "$in": [ @@ -32,5 +41,6 @@ def clean_requests(self) -> None: ResourceRequestAllocationStatus.ABORTED.value, ] } - } + }, + session=session, ) diff --git a/eos/resource_allocation/resource_allocation_manager.py b/eos/resource_allocation/resource_allocation_manager.py index 882c7ad..7387786 100644 --- a/eos/resource_allocation/resource_allocation_manager.py +++ b/eos/resource_allocation/resource_allocation_manager.py @@ -1,14 +1,14 @@ +import asyncio from collections.abc import Callable from datetime import datetime, timezone -from threading import Lock from bson import ObjectId from eos.configuration.configuration_manager import ConfigurationManager from eos.logging.logger import log -from eos.persistence.db_manager import DbManager -from eos.resource_allocation.container_allocation_manager import ContainerAllocationManager -from eos.resource_allocation.device_allocation_manager import DeviceAllocationManager +from eos.persistence.async_mongodb_interface import AsyncMongoDbInterface +from eos.resource_allocation.container_allocator import ContainerAllocator +from eos.resource_allocation.device_allocator import DeviceAllocator from eos.resource_allocation.entities.resource_request import ( ResourceAllocationRequest, ActiveResourceAllocationRequest, @@ -28,24 +28,32 @@ class ResourceAllocationManager: def __init__( self, - configuration_manager: ConfigurationManager, - db_manager: DbManager, + db_interface: AsyncMongoDbInterface, ): - self._device_allocation_manager = DeviceAllocationManager(configuration_manager, db_manager) - self._container_allocation_manager = ContainerAllocationManager(configuration_manager, db_manager) - self._active_requests = ResourceRequestRepository("resource_requests", db_manager) + self._active_requests = ResourceRequestRepository(db_interface) + self._device_allocator = None + self._container_allocator = None # Callbacks for when resource allocation requests are processed self._request_callbacks: dict[ObjectId, Callable[[ActiveResourceAllocationRequest], None]] = {} - self._lock = Lock() + self._lock = asyncio.Lock() - self._delete_all_requests() - self._delete_all_allocations() + async def initialize( + self, configuration_manager: ConfigurationManager, db_interface: AsyncMongoDbInterface + ) -> None: + self._device_allocator = 
DeviceAllocator(configuration_manager, db_interface) + await self._device_allocator.initialize(db_interface) + + self._container_allocator = ContainerAllocator(configuration_manager, db_interface) + await self._container_allocator.initialize(db_interface) + + await self._delete_all_requests() + await self._delete_all_allocations() log.debug("Resource allocation manager initialized.") - def request_resources( + async def request_resources( self, request: ResourceAllocationRequest, callback: Callable[[ActiveResourceAllocationRequest], None], @@ -59,8 +67,8 @@ def request_resources( :param callback: Callback function to be called when the resource allocation request is processed. :return: List of active resource allocation requests. """ - with self._lock: - existing_request = self._find_existing_request(request) + async with self._lock: + existing_request = await self._find_existing_request(request) if existing_request: if existing_request.status in [ ResourceRequestAllocationStatus.PENDING, @@ -70,71 +78,73 @@ def request_resources( return existing_request active_request = ActiveResourceAllocationRequest(request=request) - result = self._active_requests.create(active_request.model_dump(by_alias=True)) + result = await self._active_requests.create(active_request.model_dump(by_alias=True)) active_request.id = result.inserted_id self._request_callbacks[active_request.id] = callback return active_request - def release_resources(self, active_request: ActiveResourceAllocationRequest) -> None: + async def release_resources(self, active_request: ActiveResourceAllocationRequest) -> None: """ Release the resources allocated for an active resource allocation request. :param active_request: The active resource allocation request. """ - with self._lock: + async with self._lock: for resource in active_request.request.resources: if resource.resource_type == ResourceType.DEVICE: - self._device_allocation_manager.deallocate(resource.lab_id, resource.id) + await self._device_allocator.deallocate(resource.lab_id, resource.id) elif resource.resource_type == ResourceType.CONTAINER: - self._container_allocation_manager.deallocate(resource.id) + await self._container_allocator.deallocate(resource.id) else: raise EosResourceRequestError(f"Unknown resource type: {resource.resource_type}") - self._update_request_status(active_request.id, ResourceRequestAllocationStatus.COMPLETED) + await self._update_request_status(active_request.id, ResourceRequestAllocationStatus.COMPLETED) - def process_active_requests(self) -> None: - with self._lock: - self._clean_completed_and_aborted_requests() + async def process_active_requests(self) -> None: + async with self._lock: + await self._clean_completed_and_aborted_requests() - active_requests = self._get_all_active_requests_prioritized() + active_requests = await self._get_all_active_requests_prioritized() for active_request in active_requests: if active_request.status != ResourceRequestAllocationStatus.PENDING: continue - allocation_success = self._try_allocate(active_request) + allocation_success = await self._try_allocate(active_request) if allocation_success: self._invoke_request_callback(active_request) - def abort_active_request(self, request_id: ObjectId) -> None: + async def abort_active_request(self, request_id: ObjectId) -> None: """ Abort an active resource allocation request. 
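+
+        Example (an illustrative sketch; assumes `manager` is an initialized
+        ResourceAllocationManager and `active_request` was returned by `request_resources`):
+
+            await manager.abort_active_request(active_request.id)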
""" - with self._lock: - request = self.get_active_request(request_id) + async with self._lock: + request = await self.get_active_request(request_id) for resource in request.request.resources: if resource.resource_type == ResourceType.DEVICE: - self._device_allocation_manager.deallocate(resource.lab_id, resource.id) + await self._device_allocator.deallocate(resource.lab_id, resource.id) elif resource.resource_type == ResourceType.CONTAINER: - self._container_allocation_manager.deallocate(resource.id) - self._update_request_status(request_id, ResourceRequestAllocationStatus.ABORTED) - active_request = self.get_active_request(request_id) + await self._container_allocator.deallocate(resource.id) + await self._update_request_status(request_id, ResourceRequestAllocationStatus.ABORTED) + active_request = await self.get_active_request(request_id) self._invoke_request_callback(active_request) - def _get_all_active_requests_prioritized(self) -> list[ActiveResourceAllocationRequest]: + async def _get_all_active_requests_prioritized(self) -> list[ActiveResourceAllocationRequest]: """ Get all active resource allocation requests prioritized by the request priority in ascending order. """ active_requests = [] - active_requests_count = self._active_requests.count(status=ResourceRequestAllocationStatus.PENDING.value) + active_requests_count = await self._active_requests.count(status=ResourceRequestAllocationStatus.PENDING.value) if active_requests_count > 0: - active_requests = self._active_requests.get_requests_prioritized(ResourceRequestAllocationStatus.PENDING) + active_requests = await self._active_requests.get_requests_prioritized( + ResourceRequestAllocationStatus.PENDING + ) return [ActiveResourceAllocationRequest(**request) for request in active_requests] - def get_all_active_requests( + async def get_all_active_requests( self, requester: str | None = None, lab_id: str | None = None, @@ -156,25 +166,17 @@ def get_all_active_requests( query["request.experiment_id"] = experiment_id if status: query["status"] = status.value - active_requests = self._active_requests.get_all(**query) + active_requests = await self._active_requests.get_all(**query) return [ActiveResourceAllocationRequest(**request) for request in active_requests] - def get_active_request(self, request_id: ObjectId) -> ActiveResourceAllocationRequest | None: + async def get_active_request(self, request_id: ObjectId) -> ActiveResourceAllocationRequest | None: """ Get an active resource allocation request by ID. If the request does not exist, returns None. """ - request = self._active_requests.get_one(_id=request_id) + request = await self._active_requests.get_one(_id=request_id) return ActiveResourceAllocationRequest(**request) if request else None - @property - def device_allocation_manager(self) -> DeviceAllocationManager: - return self._device_allocation_manager - - @property - def container_allocation_manager(self) -> ContainerAllocationManager: - return self._container_allocation_manager - - def _update_request_status(self, request_id: ObjectId, status: ResourceRequestAllocationStatus) -> None: + async def _update_request_status(self, request_id: ObjectId, status: ResourceRequestAllocationStatus) -> None: """ Update the status of an active resource allocation request. 
""" @@ -182,13 +184,15 @@ def _update_request_status(self, request_id: ObjectId, status: ResourceRequestAl if status == ResourceRequestAllocationStatus.ALLOCATED: update_data["allocated_at"] = datetime.now(tz=timezone.utc) - self._active_requests.update(update_data, _id=request_id) + await self._active_requests.update_one(update_data, _id=request_id) - def _find_existing_request(self, request: ResourceAllocationRequest) -> ActiveResourceAllocationRequest | None: + async def _find_existing_request( + self, request: ResourceAllocationRequest + ) -> ActiveResourceAllocationRequest | None: """ Find an existing active resource allocation request that matches the given request. """ - existing_request = self._active_requests.get_existing_request(request) + existing_request = await self._active_requests.get_existing_request(request) return ActiveResourceAllocationRequest(**existing_request) if existing_request else None def _invoke_request_callback(self, active_request: ActiveResourceAllocationRequest) -> None: @@ -199,19 +203,19 @@ def _invoke_request_callback(self, active_request: ActiveResourceAllocationReque if callback: callback(active_request) - def _try_allocate(self, active_request: ActiveResourceAllocationRequest) -> bool: + async def _try_allocate(self, active_request: ActiveResourceAllocationRequest) -> bool: temp_allocations = [] all_available = True for resource in active_request.request.resources: if resource.resource_type == ResourceType.DEVICE: - if not self._device_allocation_manager.is_allocated(resource.lab_id, resource.id): + if not await self._device_allocator.is_allocated(resource.lab_id, resource.id): temp_allocations.append(("device", resource.lab_id, resource.id)) else: all_available = False break elif resource.resource_type == ResourceType.CONTAINER: - if not self._container_allocation_manager.is_allocated(resource.id): + if not await self._container_allocator.is_allocated(resource.id): temp_allocations.append(("container", resource.id)) else: all_available = False @@ -222,40 +226,49 @@ def _try_allocate(self, active_request: ActiveResourceAllocationRequest) -> bool if all_available: for allocation in temp_allocations: if allocation[0] == "device": - self._device_allocation_manager.allocate( + await self._device_allocator.allocate( allocation[1], allocation[2], active_request.request.requester, experiment_id=active_request.request.experiment_id, ) else: # container - self._container_allocation_manager.allocate( + await self._container_allocator.allocate( allocation[1], active_request.request.requester, experiment_id=active_request.request.experiment_id, ) - self._update_request_status(active_request.id, ResourceRequestAllocationStatus.ALLOCATED) + await self._update_request_status(active_request.id, ResourceRequestAllocationStatus.ALLOCATED) active_request.status = ResourceRequestAllocationStatus.ALLOCATED return True return False - def _clean_completed_and_aborted_requests(self) -> None: + async def _clean_completed_and_aborted_requests(self) -> None: """ Remove completed or aborted active resource allocation requests. """ - self._active_requests.clean_requests() + await self._active_requests.clean_requests() - def _delete_all_requests(self) -> None: + async def _delete_all_requests(self) -> None: """ Delete all active resource allocation requests. """ - self._active_requests.delete() + await self._active_requests.delete_all() - def _delete_all_allocations(self) -> None: + async def _delete_all_allocations(self) -> None: """ Delete all device and container allocations. 
""" - self._device_allocation_manager.deallocate_all() - self._container_allocation_manager.deallocate_all() + await asyncio.gather( + self._device_allocator.deallocate_all(), self._container_allocator.deallocate_all() + ) + + @property + def device_allocator(self) -> DeviceAllocator: + return self._device_allocator + + @property + def container_allocator(self) -> ContainerAllocator: + return self._container_allocator diff --git a/eos/scheduling/abstract_scheduler.py b/eos/scheduling/abstract_scheduler.py index b253140..01c869d 100644 --- a/eos/scheduling/abstract_scheduler.py +++ b/eos/scheduling/abstract_scheduler.py @@ -16,7 +16,7 @@ def register_experiment(self, experiment_id: str, experiment_type: str, experime """ @abstractmethod - def unregister_experiment(self, experiment_id: str) -> None: + async def unregister_experiment(self, experiment_id: str) -> None: """ Unregister an experiment from the scheduler. @@ -33,7 +33,7 @@ async def request_tasks(self, experiment_id: str) -> list[ScheduledTask]: """ @abstractmethod - def is_experiment_completed(self, experiment_id: str) -> bool: + async def is_experiment_completed(self, experiment_id: str) -> bool: """ Check if an experiment has been completed. diff --git a/eos/scheduling/greedy_scheduler.py b/eos/scheduling/greedy_scheduler.py index f6c8054..4a079af 100644 --- a/eos/scheduling/greedy_scheduler.py +++ b/eos/scheduling/greedy_scheduler.py @@ -44,8 +44,8 @@ def __init__( self._device_manager = device_manager self._resource_allocation_manager = resource_allocation_manager - self._device_allocation_manager = self._resource_allocation_manager.device_allocation_manager - self._container_allocation_manager = self._resource_allocation_manager.container_allocation_manager + self._device_allocator = self._resource_allocation_manager.device_allocator + self._container_allocator = self._resource_allocation_manager.container_allocator self._registered_experiments = {} self._allocated_resources: dict[str, dict[str, ActiveResourceAllocationRequest]] = {} @@ -66,7 +66,7 @@ def register_experiment(self, experiment_id: str, experiment_type: str, experime self._registered_experiments[experiment_id] = (experiment_type, experiment_graph) log.debug("Experiment '%s' registered for scheduling.", experiment_id) - def unregister_experiment(self, experiment_id: str) -> None: + async def unregister_experiment(self, experiment_id: str) -> None: """ Unregister an experiment from the scheduler. The scheduler will no longer consider this experiment when tasks are requested. @@ -74,7 +74,7 @@ def unregister_experiment(self, experiment_id: str) -> None: with self._lock: if experiment_id in self._registered_experiments: del self._registered_experiments[experiment_id] - self._release_experiment_resources(experiment_id) + await self._release_experiment_resources(experiment_id) else: raise EosSchedulerRegistrationError( f"Cannot unregister experiment {experiment_id} from the scheduler as it is not registered." 
@@ -97,13 +97,15 @@ async def request_tasks(self, experiment_id: str) -> list[ScheduledTask]: experiment_type, experiment_graph = self._registered_experiments[experiment_id] all_tasks = experiment_graph.get_topologically_sorted_tasks() - completed_tasks = self._experiment_manager.get_completed_tasks(experiment_id) + completed_tasks = await self._experiment_manager.get_completed_tasks(experiment_id) pending_tasks = [task_id for task_id in all_tasks if task_id not in completed_tasks] # Release resources for completed tasks - for task_id in completed_tasks: - if task_id in self._allocated_resources.get(experiment_id, {}): - self._release_task_resources(experiment_id, task_id) + await asyncio.gather(*[ + self._release_task_resources(experiment_id, task_id) + for task_id in completed_tasks + if task_id in self._allocated_resources.get(experiment_id, {}) + ]) scheduled_tasks = [] for task_id in pending_tasks: @@ -111,13 +113,15 @@ async def request_tasks(self, experiment_id: str) -> list[ScheduledTask]: continue task_config = experiment_graph.get_task_config(task_id) - task_config = self._task_input_resolver.resolve_input_container_references(experiment_id, task_config) + task_config = await self._task_input_resolver.resolve_input_container_references(experiment_id, task_config) - if not all(self._check_device_available(device) for device in task_config.devices): + device_checks = [self._check_device_available(device) for device in task_config.devices] + if not all(await asyncio.gather(*device_checks)): continue - if not all( + container_checks = [ self._check_container_available(container_id) for container_id in task_config.containers.values() - ): + ] + if not all(await asyncio.gather(*container_checks)): continue try: @@ -184,19 +188,19 @@ def resource_request_callback(request: ActiveResourceAllocationRequest) -> None: active_request = request allocation_event.set() - active_resource_request = self._resource_allocation_manager.request_resources( + active_resource_request = await self._resource_allocation_manager.request_resources( resource_request, resource_request_callback ) if active_resource_request.status == ResourceRequestAllocationStatus.ALLOCATED: return active_resource_request - self._resource_allocation_manager.process_active_requests() + await self._resource_allocation_manager.process_active_requests() try: await asyncio.wait_for(allocation_event.wait(), timeout) except asyncio.TimeoutError as e: - self._resource_allocation_manager.abort_active_request(active_resource_request.id) + await self._resource_allocation_manager.abort_active_request(active_resource_request.id) raise EosSchedulerResourceAllocationError( f"Resource allocation timed out after {timeout} seconds for task '{resource_request.requester}' " f"while trying to schedule it. 
" @@ -210,19 +214,19 @@ def resource_request_callback(request: ActiveResourceAllocationRequest) -> None: return active_request - def _release_task_resources(self, experiment_id: str, task_id: str) -> None: + async def _release_task_resources(self, experiment_id: str, task_id: str) -> None: active_request = self._allocated_resources[experiment_id].pop(task_id, None) if active_request: try: - self._resource_allocation_manager.release_resources(active_request) - self._resource_allocation_manager.process_active_requests() + await self._resource_allocation_manager.release_resources(active_request) + await self._resource_allocation_manager.process_active_requests() except EosResourceRequestError as e: log.error(f"Error releasing resources for task '{task_id}' in experiment '{experiment_id}': {e}") - def _release_experiment_resources(self, experiment_id: str) -> None: + async def _release_experiment_resources(self, experiment_id: str) -> None: task_ids = list(self._allocated_resources.get(experiment_id, {}).keys()) for task_id in task_ids: - self._release_task_resources(experiment_id, task_id) + await self._release_task_resources(experiment_id, task_id) if experiment_id in self._allocated_resources: del self._allocated_resources[experiment_id] @@ -237,28 +241,29 @@ def _check_task_dependencies_met( dependencies = experiment_graph.get_task_dependencies(task_id) return all(dep in completed_tasks for dep in dependencies) - def _check_device_available(self, task_device: TaskDeviceConfig) -> bool: + async def _check_device_available(self, task_device: TaskDeviceConfig) -> bool: """ Check if a device is available for a task. A device is available if it is active, not allocated by the device allocation manager. """ - if self._device_manager.get_device(task_device.lab_id, task_device.id).status == DeviceStatus.INACTIVE: + device = await self._device_manager.get_device(task_device.lab_id, task_device.id) + if device.status == DeviceStatus.INACTIVE: log.warning( f"Device {task_device.id} in lab {task_device.lab_id} is inactive but is requested by task " f"{task_device.id}." ) return False - return not self._device_allocation_manager.is_allocated(task_device.lab_id, task_device.id) + return not await self._device_allocator.is_allocated(task_device.lab_id, task_device.id) - def _check_container_available(self, container_id: str) -> bool: + async def _check_container_available(self, container_id: str) -> bool: """ Check if a container is available for a task. A device is available if not allocated by the container allocation manager. """ - return not self._container_allocation_manager.is_allocated(container_id) + return not await self._container_allocator.is_allocated(container_id) - def is_experiment_completed(self, experiment_id: str) -> bool: + async def is_experiment_completed(self, experiment_id: str) -> bool: """ Check if an experiment has been completed. The scheduler should consider the completed tasks from the task manager to determine if the experiment has been completed. 
@@ -270,6 +275,6 @@ def is_experiment_completed(self, experiment_id: str) -> bool: experiment_type, experiment_graph = self._registered_experiments[experiment_id] all_tasks = experiment_graph.get_task_graph().nodes - completed_tasks = self._experiment_manager.get_completed_tasks(experiment_id) + completed_tasks = await self._experiment_manager.get_completed_tasks(experiment_id) return all(task in completed_tasks for task in all_tasks) diff --git a/eos/tasks/entities/task.py b/eos/tasks/entities/task.py index 095f489..c414bb0 100644 --- a/eos/tasks/entities/task.py +++ b/eos/tasks/entities/task.py @@ -10,77 +10,77 @@ class TaskStatus(Enum): - CREATED = "CREATED" - RUNNING = "RUNNING" - COMPLETED = "COMPLETED" - FAILED = "FAILED" - CANCELLED = "CANCELLED" + CREATED = "CREATED" + RUNNING = "RUNNING" + COMPLETED = "COMPLETED" + FAILED = "FAILED" + CANCELLED = "CANCELLED" class TaskContainer(BaseModel): - id: str + id: str class TaskInput(BaseModel): - parameters: dict[str, Any] | None = None - containers: dict[str, Container] | None = None + parameters: dict[str, Any] | None = None + containers: dict[str, Container] | None = None - class Config: - arbitrary_types_allowed = True + class Config: + arbitrary_types_allowed = True - @field_serializer("parameters") - def serialize_parameters(self, parameters: dict[str, Any] | None, _info) -> Any: - if parameters is None: - return None - return omegaconf_serializer(parameters) + @field_serializer("parameters") + def serialize_parameters(self, parameters: dict[str, Any] | None, _info) -> Any: + if parameters is None: + return None + return omegaconf_serializer(parameters) class TaskOutput(BaseModel): - parameters: dict[str, Any] | None = None - containers: dict[str, Container] | None = None - file_names: list[str] | None = None + parameters: dict[str, Any] | None = None + containers: dict[str, Container] | None = None + file_names: list[str] | None = None - @field_serializer("parameters") - def serialize_parameters(self, parameters: dict[str, Any] | None, _info) -> Any: - if parameters is None: - return None - return omegaconf_serializer(parameters) + @field_serializer("parameters") + def serialize_parameters(self, parameters: dict[str, Any] | None, _info) -> Any: + if parameters is None: + return None + return omegaconf_serializer(parameters) def omegaconf_serializer(obj: Any) -> Any: - if isinstance(obj, ListConfig | DictConfig): - return OmegaConf.to_object(obj) - if isinstance(obj, dict): - return {k: omegaconf_serializer(v) for k, v in obj.items()} - if isinstance(obj, list): - return [omegaconf_serializer(v) for v in obj] - return obj + if isinstance(obj, ListConfig | DictConfig): + return OmegaConf.to_object(obj) + if isinstance(obj, dict): + return {k: omegaconf_serializer(v) for k, v in obj.items()} + if isinstance(obj, list): + return [omegaconf_serializer(v) for v in obj] + return obj class Task(BaseModel): - id: str - type: str - experiment_id: str + id: str + type: str + experiment_id: str - devices: list[TaskDeviceConfig] = [] - input: TaskInput = TaskInput() - output: TaskOutput = TaskInput() + devices: list[TaskDeviceConfig] = [] + input: TaskInput = TaskInput() + output: TaskOutput = TaskOutput() - status: TaskStatus = TaskStatus.CREATED + status: TaskStatus = TaskStatus.CREATED - metadata: dict[str, Any] = {} - start_time: datetime | None = None - end_time: datetime | None = None + metadata: dict[str, Any] = {} + start_time: datetime | None = None + end_time: datetime | None = None - created_at: datetime =
datetime.now(tz=timezone.utc) + created_at: datetime = datetime.now(tz=timezone.utc) - class Config: - arbitrary_types_allowed = True - json_encoders: ClassVar = { - ListConfig: lambda v: omegaconf_serializer(v), - DictConfig: lambda v: omegaconf_serializer(v), - } + class Config: + arbitrary_types_allowed = True + json_encoders: ClassVar = { + ListConfig: lambda v: omegaconf_serializer(v), + DictConfig: lambda v: omegaconf_serializer(v), + } - @field_serializer("status") - def status_enum_to_string(self, v: TaskStatus) -> str: - return v.value + @field_serializer("status") + def status_enum_to_string(self, v: TaskStatus) -> str: + return v.value diff --git a/eos/tasks/exceptions.py b/eos/tasks/exceptions.py index bdee178..ce1f515 100644 --- a/eos/tasks/exceptions.py +++ b/eos/tasks/exceptions.py @@ -22,5 +22,9 @@ class EosTaskExecutionError(EosTaskError): pass +class EosTaskCancellationError(EosTaskError): + pass + + class EosTaskResourceAllocationError(EosTaskError): pass diff --git a/eos/tasks/on_demand_task_executor.py b/eos/tasks/on_demand_task_executor.py index 23cb45b..72e1f26 100644 --- a/eos/tasks/on_demand_task_executor.py +++ b/eos/tasks/on_demand_task_executor.py @@ -29,12 +29,13 @@ def __init__(self, task_executor: TaskExecutor, task_manager: TaskManager, conta log.debug("On-demand task executor initialized.") - async def submit_task( + def submit_task( self, task_config: TaskConfig, resource_allocation_priority: int = 90, resource_allocation_timeout: int = 3600, ) -> None: + """Submit an on-demand task for execution.""" task_id = task_config.id task_execution_parameters = TaskExecutionParameters( experiment_id=self.EXPERIMENT_ID, @@ -48,24 +49,28 @@ async def submit_task( ) log.info(f"Submitted on-demand task '{task_id}'.") - async def cancel_task(self, task_id: str) -> None: + async def request_task_cancellation(self, task_id: str) -> None: + """Request cancellation of an on-demand task.""" if task_id not in self._task_futures: raise EosTaskExecutionError(f"Cannot cancel non-existent on-demand task '{task_id}'.") - future = self._task_futures[task_id] - future.cancel() await self._task_executor.request_task_cancellation(self.EXPERIMENT_ID, task_id) + self._task_futures[task_id].cancel() del self._task_futures[task_id] log.info(f"Cancelled on-demand task '{task_id}'.") async def process_tasks(self) -> None: + """ + Process the on-demand tasks that have been submitted. + This should be called periodically to check for task completion. 
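+ For example, a host loop might drive it (a sketch; the one-second poll
+ interval and the executor name are assumptions, not part of this patch):
+     while True:
+         await executor.process_tasks()
+         await asyncio.sleep(1.0)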
+ """ completed_tasks = [] for task_id, future in self._task_futures.items(): if future.done(): try: output = await future - self._process_task_output(task_id, *output) + await self._process_task_output(task_id, *output) except asyncio.CancelledError: log.info(f"On-demand task '{task_id}' was cancelled.") except (EosTaskExecutionError, EosTaskValidationError, EosTaskStateError): @@ -76,15 +81,16 @@ async def process_tasks(self) -> None: for task_id in completed_tasks: del self._task_futures[task_id] - def _process_task_output( + async def _process_task_output( self, task_id: str, output_parameters: dict[str, Any], output_containers: dict[str, Container], output_files: dict[str, bytes], ) -> None: - for container in output_containers.values(): - self._container_manager.update_container(container) + await asyncio.gather( + *[self._container_manager.update_container(container) for container in output_containers.values()] + ) task_output = TaskOutput( experiment_id=self.EXPERIMENT_ID, @@ -97,6 +103,6 @@ def _process_task_output( for file_name, file_data in output_files.items(): self._task_manager.add_task_output_file(self.EXPERIMENT_ID, task_id, file_name, file_data) - self._task_manager.add_task_output(self.EXPERIMENT_ID, task_id, task_output) - self._task_manager.complete_task(self.EXPERIMENT_ID, task_id) + await self._task_manager.add_task_output(self.EXPERIMENT_ID, task_id, task_output) + await self._task_manager.complete_task(self.EXPERIMENT_ID, task_id) log.info(f"EXP '{self.EXPERIMENT_ID}' - Completed task '{task_id}'.") diff --git a/eos/tasks/repositories/task_repository.py b/eos/tasks/repositories/task_repository.py index dce767b..a8ca2ac 100644 --- a/eos/tasks/repositories/task_repository.py +++ b/eos/tasks/repositories/task_repository.py @@ -1,5 +1,41 @@ -from eos.persistence.mongo_repository import MongoRepository +from motor.core import AgnosticClientSession +from eos.persistence.async_mongodb_interface import AsyncMongoDbInterface +from eos.persistence.mongodb_async_repository import MongoDbAsyncRepository +from eos.tasks.entities.task import TaskStatus -class TaskRepository(MongoRepository): - pass + +class TaskRepository(MongoDbAsyncRepository): + def __init__(self, db_interface: AsyncMongoDbInterface): + super().__init__("tasks", db_interface) + + async def initialize(self) -> None: + await self.create_indices([("experiment_id", 1), ("id", 1)], unique=True) + + async def delete_running_tasks( + self, experiment_id: str, task_ids: list[str], session: AgnosticClientSession | None = None + ) -> None: + """ + Delete all running tasks for a given experiment in a single operation. + """ + await self._collection.delete_many({"experiment_id": experiment_id, "id": {"$in": task_ids}}, session=session) + + async def delete_failed_and_cancelled_tasks( + self, experiment_id: str, session: AgnosticClientSession | None = None + ) -> None: + """ + Delete all non-completed tasks for a given experiment in a single operation. + This includes tasks with FAILED, CANCELLED, and any other non-completed status. 
+ """ + await self._collection.delete_many( + { + "experiment_id": experiment_id, + "status": { + "$in": [ + TaskStatus.FAILED.value, + TaskStatus.CANCELLED.value, + ] + }, + }, + session=session, + ) diff --git a/eos/tasks/task_executor.py b/eos/tasks/task_executor.py index ba366fe..a74bd90 100644 --- a/eos/tasks/task_executor.py +++ b/eos/tasks/task_executor.py @@ -28,7 +28,7 @@ EosTaskResourceAllocationError, EosTaskExecutionError, EosTaskValidationError, - EosTaskExistsError, + EosTaskExistsError, EosTaskCancellationError, ) from eos.tasks.task_input_parameter_caster import TaskInputParameterCaster from eos.tasks.task_manager import TaskManager @@ -61,18 +61,31 @@ def __init__( self._task_validator = TaskValidator() self._task_input_parameter_caster = TaskInputParameterCaster() - self._active_tasks: dict[str, TaskExecutionContext] = {} + self._active_tasks: dict[tuple[str, str], TaskExecutionContext] = {} log.debug("Task executor initialized.") async def request_task_execution( self, task_parameters: TaskExecutionParameters, scheduled_task: ScheduledTask | None = None ) -> BaseTask.OutputType | None: + """ + Request the execution of a task. Resources will first be requested to be allocated (if not pre-allocated) + and then the task will be executed. + + :param task_parameters: Parameters for task execution + :param scheduled_task: Scheduled task information, if applicable. This is populated by the EOS scheduler. + :return: Output of the executed task + + :raises EosTaskExecutionError: If there's an error during task execution + :raises EosTaskValidationError: If the task fails validation + :raises EosTaskResourceAllocationError: If resource allocation fails + """ context = TaskExecutionContext(task_parameters.experiment_id, task_parameters.task_config.id) - self._active_tasks[context.task_id] = context + task_key = (context.experiment_id, context.task_id) + self._active_tasks[task_key] = context try: - containers = self._prepare_containers(task_parameters) + containers = await self._prepare_containers(task_parameters) await self._initialize_task(task_parameters, containers) self._task_validator.validate(task_parameters.task_config) @@ -83,58 +96,69 @@ async def request_task_execution( else await self._allocate_resources(task_parameters) ) - context.task_ref = self._execute_task(task_parameters, containers) + context.task_ref = await self._execute_task(task_parameters, containers) return await context.task_ref except EosTaskExistsError as e: raise EosTaskExecutionError( f"Error executing task '{context.task_id}' in experiment '{context.experiment_id}'" ) from e except EosTaskValidationError as e: - self._task_manager.fail_task(context.experiment_id, context.task_id) + await self._task_manager.fail_task(context.experiment_id, context.task_id) log.warning(f"EXP '{context.experiment_id}' - Failed task '{context.task_id}'.") raise EosTaskValidationError( f"Validation error for task '{context.task_id}' in experiment '{context.experiment_id}'" ) from e except EosTaskResourceAllocationError as e: - self._task_manager.fail_task(context.experiment_id, context.task_id) + await self._task_manager.fail_task(context.experiment_id, context.task_id) log.warning(f"EXP '{context.experiment_id}' - Failed task '{context.task_id}'.") raise EosTaskResourceAllocationError( f"Failed to allocate resources for task '{context.task_id}' in experiment '{context.experiment_id}'" ) from e except Exception as e: - self._task_manager.fail_task(context.experiment_id, context.task_id) + await 
self._task_manager.fail_task(context.experiment_id, context.task_id) log.warning(f"EXP '{context.experiment_id}' - Failed task '{context.task_id}'.") raise EosTaskExecutionError( f"Error executing task '{context.task_id}' in experiment '{context.experiment_id}'" ) from e finally: if context.active_resource_request and not scheduled_task: - self._release_resources(context.active_resource_request) + # We only release resources if they were allocated by the task executor and not the scheduler + await self._release_resources(context.active_resource_request) - if context.task_id in self._active_tasks: - del self._active_tasks[context.task_id] + if task_key in self._active_tasks: + del self._active_tasks[task_key] async def request_task_cancellation(self, experiment_id: str, task_id: str) -> None: - context = self._active_tasks.get(task_id) + """ + Request the cancellation of a running task. + + :param experiment_id: ID of the experiment + :param task_id: ID of the task to cancel + """ + task_key = (experiment_id, task_id) + context = self._active_tasks.get(task_key) if not context: - return + raise EosTaskCancellationError( + f"Cannot cancel task '{task_id}' in experiment '{experiment_id}' as it does not exist.") if context.task_ref: ray.cancel(context.task_ref, recursive=True) if context.active_resource_request: - self._resource_allocation_manager.abort_active_request(context.active_resource_request.id) - self._resource_allocation_manager.process_active_requests() + await self._resource_allocation_manager.abort_active_request(context.active_resource_request.id) + await self._resource_allocation_manager.process_active_requests() - self._task_manager.cancel_task(experiment_id, task_id) - del self._active_tasks[task_id] + await self._task_manager.cancel_task(experiment_id, task_id) + del self._active_tasks[task_key] log.warning(f"EXP '{experiment_id}' - Cancelled task '{task_id}'.") - def _prepare_containers(self, execution_parameters: TaskExecutionParameters) -> dict[str, Container]: - return { - container_name: self._container_manager.get_container(container_id) - for container_name, container_id in execution_parameters.task_config.containers.items() - } + async def _prepare_containers(self, execution_parameters: TaskExecutionParameters) -> dict[str, Container]: + containers = execution_parameters.task_config.containers + fetched_containers = await asyncio.gather( + *[self._container_manager.get_container(container_id) for container_id in containers.values()] + ) + + return dict(zip(containers.keys(), fetched_containers, strict=True)) async def _initialize_task( self, execution_parameters: TaskExecutionParameters, containers: dict[str, Container] @@ -142,13 +166,13 @@ async def _initialize_task( experiment_id, task_id = execution_parameters.experiment_id, execution_parameters.task_config.id log.debug(f"Execution of task '{task_id}' for experiment '{experiment_id}' has been requested") - task = self._task_manager.get_task(experiment_id, task_id) + task = await self._task_manager.get_task(experiment_id, task_id) if task and task.status == TaskStatus.RUNNING: log.warning(f"Found running task '{task_id}' for experiment '{experiment_id}'. 
Restarting it.") await self.request_task_cancellation(experiment_id, task_id) - self._task_manager.delete_task(experiment_id, task_id) + await self._task_manager.delete_task(experiment_id, task_id) - self._task_manager.create_task( + await self._task_manager.create_task( experiment_id=experiment_id, task_id=task_id, task_type=execution_parameters.task_config.type, @@ -174,7 +198,7 @@ def _get_device_actor_references(self, task_parameters: TaskExecutionParameters) for device in task_parameters.task_config.devices ] - def _execute_task( + async def _execute_task( self, task_execution_parameters: TaskExecutionParameters, containers: dict[str, Container], @@ -202,7 +226,7 @@ def _ray_execute_task( devices = DeviceActorWrapperRegistry(_devices_actor_references) return task.execute(devices, _parameters, _containers) - self._task_manager.start_task(experiment_id, task_id) + await self._task_manager.start_task(experiment_id, task_id) log.info(f"EXP '{experiment_id}' - Started task '{task_id}'.") return _ray_execute_task.options(name=f"{experiment_id}.{task_id}").remote( @@ -246,19 +270,19 @@ def resource_request_callback(request: ActiveResourceAllocationRequest) -> None: active_request = request allocation_event.set() - active_resource_request = self._resource_allocation_manager.request_resources( + active_resource_request = await self._resource_allocation_manager.request_resources( resource_request, resource_request_callback ) if active_resource_request.status == ResourceRequestAllocationStatus.ALLOCATED: return active_resource_request - self._resource_allocation_manager.process_active_requests() + await self._resource_allocation_manager.process_active_requests() try: await asyncio.wait_for(allocation_event.wait(), timeout) except asyncio.TimeoutError as e: - self._resource_allocation_manager.abort_active_request(active_resource_request.id) + await self._resource_allocation_manager.abort_active_request(active_resource_request.id) raise EosTaskResourceAllocationError( f"Resource allocation timed out after {timeout} seconds for task '{resource_request.requester}'. " f"Aborting all resource allocations for this task." @@ -269,9 +293,9 @@ def resource_request_callback(request: ActiveResourceAllocationRequest) -> None: return active_request - def _release_resources(self, active_request: ActiveResourceAllocationRequest) -> None: + async def _release_resources(self, active_request: ActiveResourceAllocationRequest) -> None: try: - self._resource_allocation_manager.release_resources(active_request) - self._resource_allocation_manager.process_active_requests() + await self._resource_allocation_manager.release_resources(active_request) + await self._resource_allocation_manager.process_active_requests() except EosResourceRequestError as e: raise EosTaskExecutionError(f"Error releasing task '{active_request.request.requester}' resources") from e diff --git a/eos/tasks/task_input_resolver.py b/eos/tasks/task_input_resolver.py index c24508a..00bfaa6 100644 --- a/eos/tasks/task_input_resolver.py +++ b/eos/tasks/task_input_resolver.py @@ -1,5 +1,4 @@ import copy -import functools from typing import Protocol from eos.configuration.entities.task import TaskConfig @@ -9,8 +8,8 @@ from eos.tasks.task_manager import TaskManager -class Resolver(Protocol): - def __call__(self, experiment_id: str, task_config: TaskConfig) -> TaskConfig: ... +class AsyncResolver(Protocol): + async def __call__(self, experiment_id: str, task_config: TaskConfig) -> TaskConfig: ... 
class TaskInputResolver: @@ -23,11 +22,11 @@ def __init__(self, task_manager: TaskManager, experiment_manager: ExperimentMana self._task_manager = task_manager self._experiment_manager = experiment_manager - def resolve_task_inputs(self, experiment_id: str, task_config: TaskConfig) -> TaskConfig: + async def resolve_task_inputs(self, experiment_id: str, task_config: TaskConfig) -> TaskConfig: """ Resolve all input references for a task. """ - return self._apply_resolvers( + return await self._apply_resolvers( experiment_id, task_config, [ @@ -37,34 +36,37 @@ def resolve_task_inputs(self, experiment_id: str, task_config: TaskConfig) -> Ta ], ) - def resolve_dynamic_parameters(self, experiment_id: str, task_config: TaskConfig) -> TaskConfig: + async def _apply_resolvers( + self, experiment_id: str, task_config: TaskConfig, resolvers: list[AsyncResolver] + ) -> TaskConfig: """ - Resolve dynamic parameters for a task. + Apply a list of async resolver functions to the task config. """ - return self._apply_resolvers(experiment_id, task_config, [self._resolve_dynamic_parameters]) + config = copy.deepcopy(task_config) + for resolver in resolvers: + config = await resolver(experiment_id, config) + return config - def resolve_input_parameter_references(self, experiment_id: str, task_config: TaskConfig) -> TaskConfig: + async def resolve_dynamic_parameters(self, experiment_id: str, task_config: TaskConfig) -> TaskConfig: """ - Resolve input parameter references for a task. + Resolve dynamic parameters for a task. """ - return self._apply_resolvers(experiment_id, task_config, [self._resolve_input_parameter_references]) + return await self._apply_resolvers(experiment_id, task_config, [self._resolve_dynamic_parameters]) - def resolve_input_container_references(self, experiment_id: str, task_config: TaskConfig) -> TaskConfig: + async def resolve_input_parameter_references(self, experiment_id: str, task_config: TaskConfig) -> TaskConfig: """ - Resolve input container references for a task. + Resolve input parameter references for a task. """ - return self._apply_resolvers(experiment_id, task_config, [self._resolve_input_container_references]) + return await self._apply_resolvers(experiment_id, task_config, [self._resolve_input_parameter_references]) - def _apply_resolvers(self, experiment_id: str, task_config: TaskConfig, resolvers: list[Resolver]) -> TaskConfig: + async def resolve_input_container_references(self, experiment_id: str, task_config: TaskConfig) -> TaskConfig: """ - Apply a list of resolver functions to the task config. + Resolve input container references for a task. 
""" - return functools.reduce( - lambda config, resolver: resolver(experiment_id, config), resolvers, copy.deepcopy(task_config) - ) + return await self._apply_resolvers(experiment_id, task_config, [self._resolve_input_container_references]) - def _resolve_dynamic_parameters(self, experiment_id: str, task_config: TaskConfig) -> TaskConfig: - experiment = self._experiment_manager.get_experiment(experiment_id) + async def _resolve_dynamic_parameters(self, experiment_id: str, task_config: TaskConfig) -> TaskConfig: + experiment = await self._experiment_manager.get_experiment(experiment_id) task_dynamic_parameters = experiment.dynamic_parameters.get(task_config.id, {}) task_config.parameters.update(task_dynamic_parameters) @@ -80,13 +82,13 @@ def _resolve_dynamic_parameters(self, experiment_id: str, task_config: TaskConfi return task_config - def _resolve_input_parameter_references(self, experiment_id: str, task_config: TaskConfig) -> TaskConfig: + async def _resolve_input_parameter_references(self, experiment_id: str, task_config: TaskConfig) -> TaskConfig: for param_name, param_value in task_config.parameters.items(): if not validation_utils.is_parameter_reference(param_value): continue ref_task_id, ref_param_name = param_value.split(".") - resolved_value = self._resolve_reference(experiment_id, ref_task_id, ref_param_name, "parameter") + resolved_value = await self._resolve_reference(experiment_id, ref_task_id, ref_param_name, "parameter") if resolved_value is not None: task_config.parameters[param_name] = resolved_value @@ -97,13 +99,13 @@ def _resolve_input_parameter_references(self, experiment_id: str, task_config: T return task_config - def _resolve_input_container_references(self, experiment_id: str, task_config: TaskConfig) -> TaskConfig: + async def _resolve_input_container_references(self, experiment_id: str, task_config: TaskConfig) -> TaskConfig: for container_name, container_id in task_config.containers.items(): if not validation_utils.is_container_reference(container_id): continue ref_task_id, ref_container_name = container_id.split(".") - resolved_value = self._resolve_reference(experiment_id, ref_task_id, ref_container_name, "container") + resolved_value = await self._resolve_reference(experiment_id, ref_task_id, ref_container_name, "container") if resolved_value is not None: task_config.containers[container_name] = resolved_value @@ -114,13 +116,15 @@ def _resolve_input_container_references(self, experiment_id: str, task_config: T return task_config - def _resolve_reference(self, experiment_id: str, ref_task_id: str, ref_name: str, ref_type: str) -> str | None: - ref_task_output = self._task_manager.get_task_output(experiment_id, ref_task_id) + async def _resolve_reference( + self, experiment_id: str, ref_task_id: str, ref_name: str, ref_type: str + ) -> str | None: + ref_task_output = await self._task_manager.get_task_output(experiment_id, ref_task_id) if ref_type == "parameter": if ref_name in (ref_task_output.parameters or {}): return ref_task_output.parameters[ref_name] - ref_task = self._task_manager.get_task(experiment_id, ref_task_id) + ref_task = await self._task_manager.get_task(experiment_id, ref_task_id) if ref_name in (ref_task.input.parameters or {}): return ref_task.input.parameters[ref_name] elif ref_type == "container": diff --git a/eos/tasks/task_manager.py b/eos/tasks/task_manager.py index ad10117..e2350a4 100644 --- a/eos/tasks/task_manager.py +++ b/eos/tasks/task_manager.py @@ -1,3 +1,4 @@ +import asyncio from collections.abc import AsyncIterable from 
datetime import datetime, timezone from typing import Any @@ -7,8 +8,8 @@ from eos.containers.entities.container import Container from eos.experiments.repositories.experiment_repository import ExperimentRepository from eos.logging.logger import log -from eos.persistence.db_manager import DbManager -from eos.persistence.file_db_manager import FileDbManager +from eos.persistence.async_mongodb_interface import AsyncMongoDbInterface +from eos.persistence.file_db_interface import FileDbInterface from eos.tasks.entities.task import Task, TaskStatus, TaskInput, TaskOutput from eos.tasks.exceptions import EosTaskStateError, EosTaskExistsError from eos.tasks.repositories.task_repository import TaskRepository @@ -22,19 +23,25 @@ class TaskManager: def __init__( self, configuration_manager: ConfigurationManager, - db_manager: DbManager, - file_db_manager: FileDbManager, + db_interface: AsyncMongoDbInterface, + file_db_interface: FileDbInterface, ): self._configuration_manager = configuration_manager - self._db_manager = db_manager - self._file_db_manager = file_db_manager - self._tasks = TaskRepository("tasks", db_manager) - self._tasks.create_indices([("experiment_id", 1), ("id", 1)], unique=True) - self._experiments = ExperimentRepository("experiments", db_manager) + self._file_db_interface = file_db_interface + self._session_factory = db_interface.session_factory + self._tasks = None + self._experiments = None + + async def initialize(self, db_interface: AsyncMongoDbInterface) -> None: + self._tasks = TaskRepository(db_interface) + await self._tasks.initialize() + + self._experiments = ExperimentRepository(db_interface) + await self._experiments.initialize() log.debug("Task manager initialized.") - def create_task( + async def create_task( self, experiment_id: str, task_id: str, @@ -55,7 +62,7 @@ def create_task( :param containers: The input containers for the task. :param metadata: Additional metadata to be stored with the task. """ - if self._tasks.get_one(experiment_id=experiment_id, id=task_id): + if await self._tasks.get_one(experiment_id=experiment_id, id=task_id): raise EosTaskExistsError(f"Cannot create task '{task_id}' as a task with that ID already exists.") task_spec = self._configuration_manager.task_specs.get_spec_by_type(task_type) @@ -72,81 +79,95 @@ def create_task( input=task_input, metadata=metadata or {}, ) - self._tasks.create(task.model_dump()) + await self._tasks.create(task.model_dump()) - def delete_task(self, experiment_id: str, task_id: str) -> None: + async def delete_task(self, experiment_id: str, task_id: str) -> None: """ Delete an experiment task instance. """ - self._validate_task_exists(experiment_id, task_id) + await self._validate_task_exists(experiment_id, task_id) - self._tasks.delete(experiment_id=experiment_id, id=task_id) + await asyncio.gather( + self._experiments.delete_running_task(experiment_id, task_id), + self._tasks.delete_one(experiment_id=experiment_id, id=task_id), + ) - self._experiments.delete_running_task(experiment_id, task_id) log.info(f"Deleted task '{task_id}' from experiment '{experiment_id}'.") - def start_task(self, experiment_id: str, task_id: str) -> None: + async def start_task(self, experiment_id: str, task_id: str) -> None: """ Add a task to the running tasks list and update its status to running. 
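Both updates below are issued concurrently via asyncio.gather; they are two independent writes rather than a single transaction.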
""" - self._validate_task_exists(experiment_id, task_id) - self._experiments.add_running_task(experiment_id, task_id) - self._set_task_status(experiment_id, task_id, TaskStatus.RUNNING) + await self._validate_task_exists(experiment_id, task_id) - def complete_task(self, experiment_id: str, task_id: str) -> None: + await asyncio.gather( + self._experiments.add_running_task(experiment_id, task_id), + self._set_task_status(experiment_id, task_id, TaskStatus.RUNNING), + ) + + async def complete_task(self, experiment_id: str, task_id: str) -> None: """ Remove a task from the running tasks list and add it to the completed tasks list. """ - self._validate_task_exists(experiment_id, task_id) - self._experiments.move_task_queue(experiment_id, task_id, "running_tasks", "completed_tasks") - self._set_task_status(experiment_id, task_id, TaskStatus.COMPLETED) + await self._validate_task_exists(experiment_id, task_id) - def fail_task(self, experiment_id: str, task_id: str) -> None: + await asyncio.gather( + self._experiments.move_task_queue(experiment_id, task_id, "running_tasks", "completed_tasks"), + self._set_task_status(experiment_id, task_id, TaskStatus.COMPLETED), + ) + + async def fail_task(self, experiment_id: str, task_id: str) -> None: """ Remove a task from the running tasks list and do not add it to the executed tasks list. Update the task status to failed. """ - self._validate_task_exists(experiment_id, task_id) - self._experiments.delete_running_task(experiment_id, task_id) - self._set_task_status(experiment_id, task_id, TaskStatus.FAILED) + await self._validate_task_exists(experiment_id, task_id) + + await asyncio.gather( + self._experiments.delete_running_task(experiment_id, task_id), + self._set_task_status(experiment_id, task_id, TaskStatus.FAILED), + ) - def cancel_task(self, experiment_id: str, task_id: str) -> None: + async def cancel_task(self, experiment_id: str, task_id: str) -> None: """ Remove a task from the running tasks list and do not add it to the executed tasks list. Update the task status to cancelled. """ - self._validate_task_exists(experiment_id, task_id) - self._experiments.delete_running_task(experiment_id, task_id) - self._set_task_status(experiment_id, task_id, TaskStatus.CANCELLED) + await self._validate_task_exists(experiment_id, task_id) + + await asyncio.gather( + self._experiments.delete_running_task(experiment_id, task_id), + self._set_task_status(experiment_id, task_id, TaskStatus.CANCELLED), + ) log.warning(f"EXP '{experiment_id}' - Cancelled task '{task_id}'.") - def get_task(self, experiment_id: str, task_id: str) -> Task | None: + async def get_task(self, experiment_id: str, task_id: str) -> Task | None: """ Get a task by its ID and experiment ID. """ - task = self._tasks.get_one(experiment_id=experiment_id, id=task_id) + task = await self._tasks.get_one(experiment_id=experiment_id, id=task_id) return Task(**task) if task else None - def get_tasks(self, **query: dict[str, Any]) -> list[Task]: + async def get_tasks(self, **query: dict[str, Any]) -> list[Task]: """ Query tasks with arbitrary parameters. :param query: Dictionary of query parameters. """ - tasks = self._tasks.get_all(**query) + tasks = await self._tasks.get_all(**query) return [Task(**task) for task in tasks] - def add_task_output(self, experiment_id: str, task_id: str, task_output: TaskOutput) -> None: + async def add_task_output(self, experiment_id: str, task_id: str, task_output: TaskOutput) -> None: """ Add the output of a task to the database. 
""" - self._tasks.update({"output": task_output.model_dump()}, experiment_id=experiment_id, id=task_id) + await self._tasks.update_one({"output": task_output.model_dump()}, experiment_id=experiment_id, id=task_id) - def get_task_output(self, experiment_id: str, task_id: str) -> TaskOutput | None: + async def get_task_output(self, experiment_id: str, task_id: str) -> TaskOutput | None: """ Get the output of a task by its ID and experiment ID. """ - result = self._tasks.get_one(experiment_id=experiment_id, id=task_id) + result = await self._tasks.get_one(experiment_id=experiment_id, id=task_id) if not result: return None @@ -161,14 +182,14 @@ def add_task_output_file(self, experiment_id: str, task_id: str, file_name: str, Add a file output from a task to the file database. """ path = f"{experiment_id}/{task_id}/{file_name}" - self._file_db_manager.store_file(path, file_data) + self._file_db_interface.store_file(path, file_data) def get_task_output_file(self, experiment_id: str, task_id: str, file_name: str) -> bytes: """ Get a file output from a task from the file database. """ path = f"{experiment_id}/{task_id}/{file_name}" - return self._file_db_manager.get_file(path) + return self._file_db_interface.get_file(path) def stream_task_output_file( self, experiment_id: str, task_id: str, file_name: str, chunk_size: int = 3 * 1024 * 1024 @@ -177,27 +198,27 @@ def stream_task_output_file( Stream a file output from a task from the file database. """ path = f"{experiment_id}/{task_id}/{file_name}" - return self._file_db_manager.stream_file(path, chunk_size) + return self._file_db_interface.stream_file(path, chunk_size) def list_task_output_files(self, experiment_id: str, task_id: str) -> list[str]: """ List all file outputs from a task in the file database. """ prefix = f"{experiment_id}/{task_id}/" - return self._file_db_manager.list_files(prefix) + return self._file_db_interface.list_files(prefix) def delete_task_output_file(self, experiment_id: str, task_id: str, file_name: str) -> None: """ Delete a file output from a task in the file database. """ path = f"{experiment_id}/{task_id}/{file_name}" - self._file_db_manager.delete_file(path) + self._file_db_interface.delete_file(path) - def _set_task_status(self, experiment_id: str, task_id: str, new_status: TaskStatus) -> None: + async def _set_task_status(self, experiment_id: str, task_id: str, new_status: TaskStatus) -> None: """ Update the status of a task. """ - self._validate_task_exists(experiment_id, task_id) + await self._validate_task_exists(experiment_id, task_id) update_fields = {"status": new_status.value} if new_status == TaskStatus.RUNNING: @@ -205,11 +226,11 @@ def _set_task_status(self, experiment_id: str, task_id: str, new_status: TaskSta elif new_status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED]: update_fields["end_time"] = datetime.now(tz=timezone.utc) - self._tasks.update(update_fields, experiment_id=experiment_id, id=task_id) + await self._tasks.update_one(update_fields, experiment_id=experiment_id, id=task_id) - def _validate_task_exists(self, experiment_id: str, task_id: str) -> None: + async def _validate_task_exists(self, experiment_id: str, task_id: str) -> None: """ Check if a task exists in an experiment. 
""" - if not self._tasks.exists(experiment_id=experiment_id, id=task_id): + if not await self._tasks.exists(experiment_id=experiment_id, id=task_id): raise EosTaskStateError(f"Task '{task_id}' does not exist in experiment '{experiment_id}'.") diff --git a/eos/utils/async_rlock.py b/eos/utils/async_rlock.py new file mode 100644 index 0000000..2320e75 --- /dev/null +++ b/eos/utils/async_rlock.py @@ -0,0 +1,36 @@ +import asyncio + + +class AsyncRLock: + def __init__(self): + self._lock = asyncio.Lock() + self._owner: asyncio.Task | None = None + self._count = 0 + + async def acquire(self) -> bool: + current_task = asyncio.current_task() + if self._owner is current_task: + self._count += 1 + return True + + await self._lock.acquire() + self._owner = current_task + self._count = 1 + return True + + def release(self) -> None: + current_task = asyncio.current_task() + if self._owner is not current_task: + raise RuntimeError("Cannot release an un-acquired lock") + + self._count -= 1 + if self._count == 0: + self._owner = None + self._lock.release() + + async def __aenter__(self): + await self.acquire() + return self + + async def __aexit__(self, exc_type, exc, tb): + self.release() diff --git a/eos/web_api/orchestrator/controllers/task_controller.py b/eos/web_api/orchestrator/controllers/task_controller.py index 023f84d..0e5f50d 100644 --- a/eos/web_api/orchestrator/controllers/task_controller.py +++ b/eos/web_api/orchestrator/controllers/task_controller.py @@ -20,7 +20,7 @@ async def get_task(self, experiment_id: str, task_id: str, orchestrator: Orchest @post("/submit") @handle_exceptions("Failed to submit task") async def submit_task(self, data: SubmitTaskRequest, orchestrator: Orchestrator) -> Response: - await orchestrator.submit_task( + orchestrator.submit_task( data.task_config, data.resource_allocation_priority, data.resource_allocation_timeout ) return Response(content=None, status_code=HTTP_201_CREATED) diff --git a/pdm.lock b/pdm.lock index 0d6439e..7825cea 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "dev", "docs"] strategy = ["inherit_metadata"] lock_version = "4.5.0" -content_hash = "sha256:8c7e0b866c8de9954f521425fe4a268163c441bfce9aea93f717e5393ae11116" +content_hash = "sha256:f11de7d602ce06bb8d7d0c62e185925068eacb0e790be5334c7ddfda9ceefa16" [[metadata.targets]] requires_python = ">=3.10" @@ -1341,7 +1341,7 @@ files = [ [[package]] name = "litestar" -version = "2.11.0" +version = "2.12.1" requires_python = "<4.0,>=3.8" summary = "Litestar - A production-ready, highly performant, extensible ASGI API Framework" groups = ["default"] @@ -1361,13 +1361,13 @@ dependencies = [ "typing-extensions", ] files = [ - {file = "litestar-2.11.0-py3-none-any.whl", hash = "sha256:6d677ccdc00a0b4ce54cff5172531890358a27d6da1a054c8cab6a7e2119823e"}, - {file = "litestar-2.11.0.tar.gz", hash = "sha256:6c8cf2b60c352e6b8e08e6a995d2a66ddc26ec53bc2f1df7214d26abcc1d00c2"}, + {file = "litestar-2.12.1-py3-none-any.whl", hash = "sha256:74915e3731c200caa099c416a1c3b3079ffacdd6e6393974e0284f8919606f9c"}, + {file = "litestar-2.12.1.tar.gz", hash = "sha256:d2cc43157060a06dac8a77e9dc6ba2936238beada61e272e8842c21fca23fcee"}, ] [[package]] name = "litestar" -version = "2.11.0" +version = "2.12.1" extras = ["standard"] requires_python = "<4.0,>=3.8" summary = "Litestar - A production-ready, highly performant, extensible ASGI API Framework" @@ -1376,13 +1376,13 @@ dependencies = [ "fast-query-parsers>=1.0.2", "jinja2", "jsbeautifier", - "litestar==2.11.0", + "litestar==2.12.1", 
"uvicorn[standard]", "uvloop>=0.18.0; sys_platform != \"win32\"", ] files = [ - {file = "litestar-2.11.0-py3-none-any.whl", hash = "sha256:6d677ccdc00a0b4ce54cff5172531890358a27d6da1a054c8cab6a7e2119823e"}, - {file = "litestar-2.11.0.tar.gz", hash = "sha256:6c8cf2b60c352e6b8e08e6a995d2a66ddc26ec53bc2f1df7214d26abcc1d00c2"}, + {file = "litestar-2.12.1-py3-none-any.whl", hash = "sha256:74915e3731c200caa099c416a1c3b3079ffacdd6e6393974e0284f8919606f9c"}, + {file = "litestar-2.12.1.tar.gz", hash = "sha256:d2cc43157060a06dac8a77e9dc6ba2936238beada61e272e8842c21fca23fcee"}, ] [[package]] @@ -1538,6 +1538,20 @@ files = [ {file = "minio-7.2.8.tar.gz", hash = "sha256:f8af2dafc22ebe1aef3ac181b8e217037011c430aa6da276ed627e55aaf7c815"}, ] +[[package]] +name = "motor" +version = "3.6.0" +requires_python = ">=3.8" +summary = "Non-blocking MongoDB driver for Tornado or asyncio" +groups = ["default"] +dependencies = [ + "pymongo<4.10,>=4.9", +] +files = [ + {file = "motor-3.6.0-py3-none-any.whl", hash = "sha256:9f07ed96f1754963d4386944e1b52d403a5350c687edc60da487d66f98dbf894"}, + {file = "motor-3.6.0.tar.gz", hash = "sha256:0ef7f520213e852bf0eac306adf631aabe849227d8aec900a2612512fb9c5b8d"}, +] + [[package]] name = "mpmath" version = "1.3.0" @@ -1719,13 +1733,13 @@ files = [ [[package]] name = "networkx" -version = "3.3" +version = "3.4.1" requires_python = ">=3.10" summary = "Python package for creating and manipulating graphs and networks" groups = ["default"] files = [ - {file = "networkx-3.3-py3-none-any.whl", hash = "sha256:28575580c6ebdaf4505b22c6256a2b9de86b316dc63ba9e93abde3d78dfdbcf2"}, - {file = "networkx-3.3.tar.gz", hash = "sha256:0c127d8b2f4865f59ae9cb8aafcd60b5c70f3241ebd66f7defad7c4ab90126c9"}, + {file = "networkx-3.4.1-py3-none-any.whl", hash = "sha256:e30a87b48c9a6a7cc220e732bffefaee585bdb166d13377734446ce1a0620eed"}, + {file = "networkx-3.4.1.tar.gz", hash = "sha256:f9df45e85b78f5bd010993e897b4f1fdb242c11e015b101bd951e5c0e29982d8"}, ] [[package]] @@ -2019,7 +2033,7 @@ files = [ [[package]] name = "pandas" -version = "2.2.2" +version = "2.2.3" requires_python = ">=3.9" summary = "Powerful data structures for data analysis, time series, and statistics" groups = ["default"] @@ -2032,28 +2046,41 @@ dependencies = [ "tzdata>=2022.7", ] files = [ - {file = "pandas-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:90c6fca2acf139569e74e8781709dccb6fe25940488755716d1d354d6bc58bce"}, - {file = "pandas-2.2.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c7adfc142dac335d8c1e0dcbd37eb8617eac386596eb9e1a1b77791cf2498238"}, - {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4abfe0be0d7221be4f12552995e58723c7422c80a659da13ca382697de830c08"}, - {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8635c16bf3d99040fdf3ca3db669a7250ddf49c55dc4aa8fe0ae0fa8d6dcc1f0"}, - {file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:40ae1dffb3967a52203105a077415a86044a2bea011b5f321c6aa64b379a3f51"}, - {file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8e5a0b00e1e56a842f922e7fae8ae4077aee4af0acb5ae3622bd4b4c30aedf99"}, - {file = "pandas-2.2.2-cp310-cp310-win_amd64.whl", hash = "sha256:ddf818e4e6c7c6f4f7c8a12709696d193976b591cc7dc50588d3d1a6b5dc8772"}, - {file = "pandas-2.2.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:696039430f7a562b74fa45f540aca068ea85fa34c244d0deee539cb6d70aa288"}, - {file = "pandas-2.2.2-cp311-cp311-macosx_11_0_arm64.whl", hash = 
"sha256:8e90497254aacacbc4ea6ae5e7a8cd75629d6ad2b30025a4a8b09aa4faf55151"}, - {file = "pandas-2.2.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:58b84b91b0b9f4bafac2a0ac55002280c094dfc6402402332c0913a59654ab2b"}, - {file = "pandas-2.2.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d2123dc9ad6a814bcdea0f099885276b31b24f7edf40f6cdbc0912672e22eee"}, - {file = "pandas-2.2.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:2925720037f06e89af896c70bca73459d7e6a4be96f9de79e2d440bd499fe0db"}, - {file = "pandas-2.2.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0cace394b6ea70c01ca1595f839cf193df35d1575986e484ad35c4aeae7266c1"}, - {file = "pandas-2.2.2-cp311-cp311-win_amd64.whl", hash = "sha256:873d13d177501a28b2756375d59816c365e42ed8417b41665f346289adc68d24"}, - {file = "pandas-2.2.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:9dfde2a0ddef507a631dc9dc4af6a9489d5e2e740e226ad426a05cabfbd7c8ef"}, - {file = "pandas-2.2.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:e9b79011ff7a0f4b1d6da6a61aa1aa604fb312d6647de5bad20013682d1429ce"}, - {file = "pandas-2.2.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1cb51fe389360f3b5a4d57dbd2848a5f033350336ca3b340d1c53a1fad33bcad"}, - {file = "pandas-2.2.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eee3a87076c0756de40b05c5e9a6069c035ba43e8dd71c379e68cab2c20f16ad"}, - {file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:3e374f59e440d4ab45ca2fffde54b81ac3834cf5ae2cdfa69c90bc03bde04d76"}, - {file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:43498c0bdb43d55cb162cdc8c06fac328ccb5d2eabe3cadeb3529ae6f0517c32"}, - {file = "pandas-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:d187d355ecec3629624fccb01d104da7d7f391db0311145817525281e2804d23"}, - {file = "pandas-2.2.2.tar.gz", hash = "sha256:9e79019aba43cb4fda9e4d983f8e88ca0373adbb697ae9c6c43093218de28b54"}, + {file = "pandas-2.2.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:1948ddde24197a0f7add2bdc4ca83bf2b1ef84a1bc8ccffd95eda17fd836ecb5"}, + {file = "pandas-2.2.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:381175499d3802cde0eabbaf6324cce0c4f5d52ca6f8c377c29ad442f50f6348"}, + {file = "pandas-2.2.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d9c45366def9a3dd85a6454c0e7908f2b3b8e9c138f5dc38fed7ce720d8453ed"}, + {file = "pandas-2.2.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:86976a1c5b25ae3f8ccae3a5306e443569ee3c3faf444dfd0f41cda24667ad57"}, + {file = "pandas-2.2.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:b8661b0238a69d7aafe156b7fa86c44b881387509653fdf857bebc5e4008ad42"}, + {file = "pandas-2.2.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:37e0aced3e8f539eccf2e099f65cdb9c8aa85109b0be6e93e2baff94264bdc6f"}, + {file = "pandas-2.2.3-cp310-cp310-win_amd64.whl", hash = "sha256:56534ce0746a58afaf7942ba4863e0ef81c9c50d3f0ae93e9497d6a41a057645"}, + {file = "pandas-2.2.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:66108071e1b935240e74525006034333f98bcdb87ea116de573a6a0dccb6c039"}, + {file = "pandas-2.2.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7c2875855b0ff77b2a64a0365e24455d9990730d6431b9e0ee18ad8acee13dbd"}, + {file = "pandas-2.2.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cd8d0c3be0515c12fed0bdbae072551c8b54b7192c7b1fda0ba56059a0179698"}, + {file = 
"pandas-2.2.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c124333816c3a9b03fbeef3a9f230ba9a737e9e5bb4060aa2107a86cc0a497fc"}, + {file = "pandas-2.2.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:63cc132e40a2e084cf01adf0775b15ac515ba905d7dcca47e9a251819c575ef3"}, + {file = "pandas-2.2.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:29401dbfa9ad77319367d36940cd8a0b3a11aba16063e39632d98b0e931ddf32"}, + {file = "pandas-2.2.3-cp311-cp311-win_amd64.whl", hash = "sha256:3fc6873a41186404dad67245896a6e440baacc92f5b716ccd1bc9ed2995ab2c5"}, + {file = "pandas-2.2.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b1d432e8d08679a40e2a6d8b2f9770a5c21793a6f9f47fdd52c5ce1948a5a8a9"}, + {file = "pandas-2.2.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a5a1595fe639f5988ba6a8e5bc9649af3baf26df3998a0abe56c02609392e0a4"}, + {file = "pandas-2.2.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5de54125a92bb4d1c051c0659e6fcb75256bf799a732a87184e5ea503965bce3"}, + {file = "pandas-2.2.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fffb8ae78d8af97f849404f21411c95062db1496aeb3e56f146f0355c9989319"}, + {file = "pandas-2.2.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6dfcb5ee8d4d50c06a51c2fffa6cff6272098ad6540aed1a76d15fb9318194d8"}, + {file = "pandas-2.2.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:062309c1b9ea12a50e8ce661145c6aab431b1e99530d3cd60640e255778bd43a"}, + {file = "pandas-2.2.3-cp312-cp312-win_amd64.whl", hash = "sha256:59ef3764d0fe818125a5097d2ae867ca3fa64df032331b7e0917cf5d7bf66b13"}, + {file = "pandas-2.2.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f00d1345d84d8c86a63e476bb4955e46458b304b9575dcf71102b5c705320015"}, + {file = "pandas-2.2.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:3508d914817e153ad359d7e069d752cdd736a247c322d932eb89e6bc84217f28"}, + {file = "pandas-2.2.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:22a9d949bfc9a502d320aa04e5d02feab689d61da4e7764b62c30b991c42c5f0"}, + {file = "pandas-2.2.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3a255b2c19987fbbe62a9dfd6cff7ff2aa9ccab3fc75218fd4b7530f01efa24"}, + {file = "pandas-2.2.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:800250ecdadb6d9c78eae4990da62743b857b470883fa27f652db8bdde7f6659"}, + {file = "pandas-2.2.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6374c452ff3ec675a8f46fd9ab25c4ad0ba590b71cf0656f8b6daa5202bca3fb"}, + {file = "pandas-2.2.3-cp313-cp313-win_amd64.whl", hash = "sha256:61c5ad4043f791b61dd4752191d9f07f0ae412515d59ba8f005832a532f8736d"}, + {file = "pandas-2.2.3-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:3b71f27954685ee685317063bf13c7709a7ba74fc996b84fc6821c59b0f06468"}, + {file = "pandas-2.2.3-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:38cf8125c40dae9d5acc10fa66af8ea6fdf760b2714ee482ca691fc66e6fcb18"}, + {file = "pandas-2.2.3-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ba96630bc17c875161df3818780af30e43be9b166ce51c9a18c1feae342906c2"}, + {file = "pandas-2.2.3-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1db71525a1538b30142094edb9adc10be3f3e176748cd7acc2240c2f2e5aa3a4"}, + {file = "pandas-2.2.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:15c0e1e02e93116177d29ff83e8b1619c93ddc9c49083f237d4312337a61165d"}, + {file = "pandas-2.2.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = 
"sha256:ad5b65698ab28ed8d7f18790a0dc58005c7629f227be9ecc1072aa74c0c1d43a"}, + {file = "pandas-2.2.3.tar.gz", hash = "sha256:4f18ba62b61d7e192368b84517265a99b4d7ee8912f8708660fb4a366cc82667"}, ] [[package]] @@ -2240,24 +2267,24 @@ files = [ [[package]] name = "pydantic" -version = "2.9.1" +version = "2.9.2" requires_python = ">=3.8" summary = "Data validation using Python type hints" groups = ["default"] dependencies = [ "annotated-types>=0.6.0", - "pydantic-core==2.23.3", + "pydantic-core==2.23.4", "typing-extensions>=4.12.2; python_version >= \"3.13\"", "typing-extensions>=4.6.1; python_version < \"3.13\"", ] files = [ - {file = "pydantic-2.9.1-py3-none-any.whl", hash = "sha256:7aff4db5fdf3cf573d4b3c30926a510a10e19a0774d38fc4967f78beb6deb612"}, - {file = "pydantic-2.9.1.tar.gz", hash = "sha256:1363c7d975c7036df0db2b4a61f2e062fbc0aa5ab5f2772e0ffc7191a4f4bce2"}, + {file = "pydantic-2.9.2-py3-none-any.whl", hash = "sha256:f048cec7b26778210e28a0459867920654d48e5e62db0958433636cde4254f12"}, + {file = "pydantic-2.9.2.tar.gz", hash = "sha256:d155cef71265d1e9807ed1c32b4c8deec042a44a50a4188b25ac67ecd81a9c0f"}, ] [[package]] name = "pydantic-core" -version = "2.23.3" +version = "2.23.4" requires_python = ">=3.8" summary = "Core functionality for Pydantic validation and serialization" groups = ["default"] @@ -2265,63 +2292,63 @@ dependencies = [ "typing-extensions!=4.7.0,>=4.6.0", ] files = [ - {file = "pydantic_core-2.23.3-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:7f10a5d1b9281392f1bf507d16ac720e78285dfd635b05737c3911637601bae6"}, - {file = "pydantic_core-2.23.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3c09a7885dd33ee8c65266e5aa7fb7e2f23d49d8043f089989726391dd7350c5"}, - {file = "pydantic_core-2.23.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6470b5a1ec4d1c2e9afe928c6cb37eb33381cab99292a708b8cb9aa89e62429b"}, - {file = "pydantic_core-2.23.3-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9172d2088e27d9a185ea0a6c8cebe227a9139fd90295221d7d495944d2367700"}, - {file = "pydantic_core-2.23.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:86fc6c762ca7ac8fbbdff80d61b2c59fb6b7d144aa46e2d54d9e1b7b0e780e01"}, - {file = "pydantic_core-2.23.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f0cb80fd5c2df4898693aa841425ea1727b1b6d2167448253077d2a49003e0ed"}, - {file = "pydantic_core-2.23.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:03667cec5daf43ac4995cefa8aaf58f99de036204a37b889c24a80927b629cec"}, - {file = "pydantic_core-2.23.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:047531242f8e9c2db733599f1c612925de095e93c9cc0e599e96cf536aaf56ba"}, - {file = "pydantic_core-2.23.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:5499798317fff7f25dbef9347f4451b91ac2a4330c6669821c8202fd354c7bee"}, - {file = "pydantic_core-2.23.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:bbb5e45eab7624440516ee3722a3044b83fff4c0372efe183fd6ba678ff681fe"}, - {file = "pydantic_core-2.23.3-cp310-none-win32.whl", hash = "sha256:8b5b3ed73abb147704a6e9f556d8c5cb078f8c095be4588e669d315e0d11893b"}, - {file = "pydantic_core-2.23.3-cp310-none-win_amd64.whl", hash = "sha256:2b603cde285322758a0279995b5796d64b63060bfbe214b50a3ca23b5cee3e83"}, - {file = "pydantic_core-2.23.3-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:c889fd87e1f1bbeb877c2ee56b63bb297de4636661cc9bbfcf4b34e5e925bc27"}, - {file = 
"pydantic_core-2.23.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ea85bda3189fb27503af4c45273735bcde3dd31c1ab17d11f37b04877859ef45"}, - {file = "pydantic_core-2.23.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a7f7f72f721223f33d3dc98a791666ebc6a91fa023ce63733709f4894a7dc611"}, - {file = "pydantic_core-2.23.3-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2b2b55b0448e9da68f56b696f313949cda1039e8ec7b5d294285335b53104b61"}, - {file = "pydantic_core-2.23.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c24574c7e92e2c56379706b9a3f07c1e0c7f2f87a41b6ee86653100c4ce343e5"}, - {file = "pydantic_core-2.23.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f2b05e6ccbee333a8f4b8f4d7c244fdb7a979e90977ad9c51ea31261e2085ce0"}, - {file = "pydantic_core-2.23.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e2c409ce1c219c091e47cb03feb3c4ed8c2b8e004efc940da0166aaee8f9d6c8"}, - {file = "pydantic_core-2.23.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d965e8b325f443ed3196db890d85dfebbb09f7384486a77461347f4adb1fa7f8"}, - {file = "pydantic_core-2.23.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:f56af3a420fb1ffaf43ece3ea09c2d27c444e7c40dcb7c6e7cf57aae764f2b48"}, - {file = "pydantic_core-2.23.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5b01a078dd4f9a52494370af21aa52964e0a96d4862ac64ff7cea06e0f12d2c5"}, - {file = "pydantic_core-2.23.3-cp311-none-win32.whl", hash = "sha256:560e32f0df04ac69b3dd818f71339983f6d1f70eb99d4d1f8e9705fb6c34a5c1"}, - {file = "pydantic_core-2.23.3-cp311-none-win_amd64.whl", hash = "sha256:c744fa100fdea0d000d8bcddee95213d2de2e95b9c12be083370b2072333a0fa"}, - {file = "pydantic_core-2.23.3-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:e0ec50663feedf64d21bad0809f5857bac1ce91deded203efc4a84b31b2e4305"}, - {file = "pydantic_core-2.23.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:db6e6afcb95edbe6b357786684b71008499836e91f2a4a1e55b840955b341dbb"}, - {file = "pydantic_core-2.23.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:98ccd69edcf49f0875d86942f4418a4e83eb3047f20eb897bffa62a5d419c8fa"}, - {file = "pydantic_core-2.23.3-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a678c1ac5c5ec5685af0133262103defb427114e62eafeda12f1357a12140162"}, - {file = "pydantic_core-2.23.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:01491d8b4d8db9f3391d93b0df60701e644ff0894352947f31fff3e52bd5c801"}, - {file = "pydantic_core-2.23.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fcf31facf2796a2d3b7fe338fe8640aa0166e4e55b4cb108dbfd1058049bf4cb"}, - {file = "pydantic_core-2.23.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7200fd561fb3be06827340da066df4311d0b6b8eb0c2116a110be5245dceb326"}, - {file = "pydantic_core-2.23.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:dc1636770a809dee2bd44dd74b89cc80eb41172bcad8af75dd0bc182c2666d4c"}, - {file = "pydantic_core-2.23.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:67a5def279309f2e23014b608c4150b0c2d323bd7bccd27ff07b001c12c2415c"}, - {file = "pydantic_core-2.23.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:748bdf985014c6dd3e1e4cc3db90f1c3ecc7246ff5a3cd4ddab20c768b2f1dab"}, - {file = "pydantic_core-2.23.3-cp312-none-win32.whl", hash = 
"sha256:255ec6dcb899c115f1e2a64bc9ebc24cc0e3ab097775755244f77360d1f3c06c"}, - {file = "pydantic_core-2.23.3-cp312-none-win_amd64.whl", hash = "sha256:40b8441be16c1e940abebed83cd006ddb9e3737a279e339dbd6d31578b802f7b"}, - {file = "pydantic_core-2.23.3-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:6daaf5b1ba1369a22c8b050b643250e3e5efc6a78366d323294aee54953a4d5f"}, - {file = "pydantic_core-2.23.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:d015e63b985a78a3d4ccffd3bdf22b7c20b3bbd4b8227809b3e8e75bc37f9cb2"}, - {file = "pydantic_core-2.23.3-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a3fc572d9b5b5cfe13f8e8a6e26271d5d13f80173724b738557a8c7f3a8a3791"}, - {file = "pydantic_core-2.23.3-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f6bd91345b5163ee7448bee201ed7dd601ca24f43f439109b0212e296eb5b423"}, - {file = "pydantic_core-2.23.3-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fc379c73fd66606628b866f661e8785088afe2adaba78e6bbe80796baf708a63"}, - {file = "pydantic_core-2.23.3-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fbdce4b47592f9e296e19ac31667daed8753c8367ebb34b9a9bd89dacaa299c9"}, - {file = "pydantic_core-2.23.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc3cf31edf405a161a0adad83246568647c54404739b614b1ff43dad2b02e6d5"}, - {file = "pydantic_core-2.23.3-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:8e22b477bf90db71c156f89a55bfe4d25177b81fce4aa09294d9e805eec13855"}, - {file = "pydantic_core-2.23.3-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:0a0137ddf462575d9bce863c4c95bac3493ba8e22f8c28ca94634b4a1d3e2bb4"}, - {file = "pydantic_core-2.23.3-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:203171e48946c3164fe7691fc349c79241ff8f28306abd4cad5f4f75ed80bc8d"}, - {file = "pydantic_core-2.23.3-cp313-none-win32.whl", hash = "sha256:76bdab0de4acb3f119c2a4bff740e0c7dc2e6de7692774620f7452ce11ca76c8"}, - {file = "pydantic_core-2.23.3-cp313-none-win_amd64.whl", hash = "sha256:37ba321ac2a46100c578a92e9a6aa33afe9ec99ffa084424291d84e456f490c1"}, - {file = "pydantic_core-2.23.3-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:f399e8657c67313476a121a6944311fab377085ca7f490648c9af97fc732732d"}, - {file = "pydantic_core-2.23.3-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:6b5547d098c76e1694ba85f05b595720d7c60d342f24d5aad32c3049131fa5c4"}, - {file = "pydantic_core-2.23.3-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0dda0290a6f608504882d9f7650975b4651ff91c85673341789a476b1159f211"}, - {file = "pydantic_core-2.23.3-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:65b6e5da855e9c55a0c67f4db8a492bf13d8d3316a59999cfbaf98cc6e401961"}, - {file = "pydantic_core-2.23.3-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:09e926397f392059ce0afdcac920df29d9c833256354d0c55f1584b0b70cf07e"}, - {file = "pydantic_core-2.23.3-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:87cfa0ed6b8c5bd6ae8b66de941cece179281239d482f363814d2b986b79cedc"}, - {file = "pydantic_core-2.23.3-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:e61328920154b6a44d98cabcb709f10e8b74276bc709c9a513a8c37a18786cc4"}, - {file = "pydantic_core-2.23.3-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:ce3317d155628301d649fe5e16a99528d5680af4ec7aa70b90b8dacd2d725c9b"}, - {file = "pydantic_core-2.23.3.tar.gz", hash = 
"sha256:3cb0f65d8b4121c1b015c60104a685feb929a29d7cf204387c7f2688c7974690"}, + {file = "pydantic_core-2.23.4-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:b10bd51f823d891193d4717448fab065733958bdb6a6b351967bd349d48d5c9b"}, + {file = "pydantic_core-2.23.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4fc714bdbfb534f94034efaa6eadd74e5b93c8fa6315565a222f7b6f42ca1166"}, + {file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:63e46b3169866bd62849936de036f901a9356e36376079b05efa83caeaa02ceb"}, + {file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ed1a53de42fbe34853ba90513cea21673481cd81ed1be739f7f2efb931b24916"}, + {file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cfdd16ab5e59fc31b5e906d1a3f666571abc367598e3e02c83403acabc092e07"}, + {file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:255a8ef062cbf6674450e668482456abac99a5583bbafb73f9ad469540a3a232"}, + {file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4a7cd62e831afe623fbb7aabbb4fe583212115b3ef38a9f6b71869ba644624a2"}, + {file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f09e2ff1f17c2b51f2bc76d1cc33da96298f0a036a137f5440ab3ec5360b624f"}, + {file = "pydantic_core-2.23.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:e38e63e6f3d1cec5a27e0afe90a085af8b6806ee208b33030e65b6516353f1a3"}, + {file = "pydantic_core-2.23.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:0dbd8dbed2085ed23b5c04afa29d8fd2771674223135dc9bc937f3c09284d071"}, + {file = "pydantic_core-2.23.4-cp310-none-win32.whl", hash = "sha256:6531b7ca5f951d663c339002e91aaebda765ec7d61b7d1e3991051906ddde119"}, + {file = "pydantic_core-2.23.4-cp310-none-win_amd64.whl", hash = "sha256:7c9129eb40958b3d4500fa2467e6a83356b3b61bfff1b414c7361d9220f9ae8f"}, + {file = "pydantic_core-2.23.4-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:77733e3892bb0a7fa797826361ce8a9184d25c8dffaec60b7ffe928153680ba8"}, + {file = "pydantic_core-2.23.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1b84d168f6c48fabd1f2027a3d1bdfe62f92cade1fb273a5d68e621da0e44e6d"}, + {file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:df49e7a0861a8c36d089c1ed57d308623d60416dab2647a4a17fe050ba85de0e"}, + {file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ff02b6d461a6de369f07ec15e465a88895f3223eb75073ffea56b84d9331f607"}, + {file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:996a38a83508c54c78a5f41456b0103c30508fed9abcad0a59b876d7398f25fd"}, + {file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d97683ddee4723ae8c95d1eddac7c192e8c552da0c73a925a89fa8649bf13eea"}, + {file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:216f9b2d7713eb98cb83c80b9c794de1f6b7e3145eef40400c62e86cee5f4e1e"}, + {file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6f783e0ec4803c787bcea93e13e9932edab72068f68ecffdf86a99fd5918878b"}, + {file = "pydantic_core-2.23.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d0776dea117cf5272382634bd2a5c1b6eb16767c223c6a5317cd3e2a757c61a0"}, + {file = 
"pydantic_core-2.23.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d5f7a395a8cf1621939692dba2a6b6a830efa6b3cee787d82c7de1ad2930de64"}, + {file = "pydantic_core-2.23.4-cp311-none-win32.whl", hash = "sha256:74b9127ffea03643e998e0c5ad9bd3811d3dac8c676e47db17b0ee7c3c3bf35f"}, + {file = "pydantic_core-2.23.4-cp311-none-win_amd64.whl", hash = "sha256:98d134c954828488b153d88ba1f34e14259284f256180ce659e8d83e9c05eaa3"}, + {file = "pydantic_core-2.23.4-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:f3e0da4ebaef65158d4dfd7d3678aad692f7666877df0002b8a522cdf088f231"}, + {file = "pydantic_core-2.23.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f69a8e0b033b747bb3e36a44e7732f0c99f7edd5cea723d45bc0d6e95377ffee"}, + {file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:723314c1d51722ab28bfcd5240d858512ffd3116449c557a1336cbe3919beb87"}, + {file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bb2802e667b7051a1bebbfe93684841cc9351004e2badbd6411bf357ab8d5ac8"}, + {file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d18ca8148bebe1b0a382a27a8ee60350091a6ddaf475fa05ef50dc35b5df6327"}, + {file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:33e3d65a85a2a4a0dc3b092b938a4062b1a05f3a9abde65ea93b233bca0e03f2"}, + {file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:128585782e5bfa515c590ccee4b727fb76925dd04a98864182b22e89a4e6ed36"}, + {file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:68665f4c17edcceecc112dfed5dbe6f92261fb9d6054b47d01bf6371a6196126"}, + {file = "pydantic_core-2.23.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:20152074317d9bed6b7a95ade3b7d6054845d70584216160860425f4fbd5ee9e"}, + {file = "pydantic_core-2.23.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:9261d3ce84fa1d38ed649c3638feefeae23d32ba9182963e465d58d62203bd24"}, + {file = "pydantic_core-2.23.4-cp312-none-win32.whl", hash = "sha256:4ba762ed58e8d68657fc1281e9bb72e1c3e79cc5d464be146e260c541ec12d84"}, + {file = "pydantic_core-2.23.4-cp312-none-win_amd64.whl", hash = "sha256:97df63000f4fea395b2824da80e169731088656d1818a11b95f3b173747b6cd9"}, + {file = "pydantic_core-2.23.4-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:7530e201d10d7d14abce4fb54cfe5b94a0aefc87da539d0346a484ead376c3cc"}, + {file = "pydantic_core-2.23.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:df933278128ea1cd77772673c73954e53a1c95a4fdf41eef97c2b779271bd0bd"}, + {file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0cb3da3fd1b6a5d0279a01877713dbda118a2a4fc6f0d821a57da2e464793f05"}, + {file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:42c6dcb030aefb668a2b7009c85b27f90e51e6a3b4d5c9bc4c57631292015b0d"}, + {file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:696dd8d674d6ce621ab9d45b205df149399e4bb9aa34102c970b721554828510"}, + {file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2971bb5ffe72cc0f555c13e19b23c85b654dd2a8f7ab493c262071377bfce9f6"}, + {file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8394d940e5d400d04cad4f75c0598665cbb81aecefaca82ca85bd28264af7f9b"}, + 
{file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:0dff76e0602ca7d4cdaacc1ac4c005e0ce0dcfe095d5b5259163a80d3a10d327"}, + {file = "pydantic_core-2.23.4-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:7d32706badfe136888bdea71c0def994644e09fff0bfe47441deaed8e96fdbc6"}, + {file = "pydantic_core-2.23.4-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:ed541d70698978a20eb63d8c5d72f2cc6d7079d9d90f6b50bad07826f1320f5f"}, + {file = "pydantic_core-2.23.4-cp313-none-win32.whl", hash = "sha256:3d5639516376dce1940ea36edf408c554475369f5da2abd45d44621cb616f769"}, + {file = "pydantic_core-2.23.4-cp313-none-win_amd64.whl", hash = "sha256:5a1504ad17ba4210df3a045132a7baeeba5a200e930f57512ee02909fc5c4cb5"}, + {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:f455ee30a9d61d3e1a15abd5068827773d6e4dc513e795f380cdd59932c782d5"}, + {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:1e90d2e3bd2c3863d48525d297cd143fe541be8bbf6f579504b9712cb6b643ec"}, + {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2e203fdf807ac7e12ab59ca2bfcabb38c7cf0b33c41efeb00f8e5da1d86af480"}, + {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e08277a400de01bc72436a0ccd02bdf596631411f592ad985dcee21445bd0068"}, + {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f220b0eea5965dec25480b6333c788fb72ce5f9129e8759ef876a1d805d00801"}, + {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:d06b0c8da4f16d1d1e352134427cb194a0a6e19ad5db9161bf32b2113409e728"}, + {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:ba1a0996f6c2773bd83e63f18914c1de3c9dd26d55f4ac302a7efe93fb8e7433"}, + {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:9a5bce9d23aac8f0cf0836ecfc033896aa8443b501c58d0602dbfd5bd5b37753"}, + {file = "pydantic_core-2.23.4.tar.gz", hash = "sha256:2584f7cf844ac4d970fba483a717dbe10c1c1c96a969bf65d61ffe94df1b2863"}, ] [[package]] @@ -2358,7 +2385,7 @@ files = [ [[package]] name = "pymongo" -version = "4.8.0" +version = "4.9.2" requires_python = ">=3.8" summary = "Python driver for MongoDB " groups = ["default"] @@ -2366,34 +2393,43 @@ dependencies = [ "dnspython<3.0.0,>=1.16.0", ] files = [ - {file = "pymongo-4.8.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f2b7bec27e047e84947fbd41c782f07c54c30c76d14f3b8bf0c89f7413fac67a"}, - {file = "pymongo-4.8.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3c68fe128a171493018ca5c8020fc08675be130d012b7ab3efe9e22698c612a1"}, - {file = "pymongo-4.8.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:920d4f8f157a71b3cb3f39bc09ce070693d6e9648fb0e30d00e2657d1dca4e49"}, - {file = "pymongo-4.8.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:52b4108ac9469febba18cea50db972605cc43978bedaa9fea413378877560ef8"}, - {file = "pymongo-4.8.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:180d5eb1dc28b62853e2f88017775c4500b07548ed28c0bd9c005c3d7bc52526"}, - {file = "pymongo-4.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aec2b9088cdbceb87e6ca9c639d0ff9b9d083594dda5ca5d3c4f6774f4c81b33"}, - {file = 
"pymongo-4.8.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d0cf61450feadca81deb1a1489cb1a3ae1e4266efd51adafecec0e503a8dcd84"}, - {file = "pymongo-4.8.0-cp310-cp310-win32.whl", hash = "sha256:8b18c8324809539c79bd6544d00e0607e98ff833ca21953df001510ca25915d1"}, - {file = "pymongo-4.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:e5df28f74002e37bcbdfdc5109799f670e4dfef0fb527c391ff84f078050e7b5"}, - {file = "pymongo-4.8.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6b50040d9767197b77ed420ada29b3bf18a638f9552d80f2da817b7c4a4c9c68"}, - {file = "pymongo-4.8.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:417369ce39af2b7c2a9c7152c1ed2393edfd1cbaf2a356ba31eb8bcbd5c98dd7"}, - {file = "pymongo-4.8.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bf821bd3befb993a6db17229a2c60c1550e957de02a6ff4dd0af9476637b2e4d"}, - {file = "pymongo-4.8.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9365166aa801c63dff1a3cb96e650be270da06e3464ab106727223123405510f"}, - {file = "pymongo-4.8.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cc8b8582f4209c2459b04b049ac03c72c618e011d3caa5391ff86d1bda0cc486"}, - {file = "pymongo-4.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:16e5019f75f6827bb5354b6fef8dfc9d6c7446894a27346e03134d290eb9e758"}, - {file = "pymongo-4.8.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3b5802151fc2b51cd45492c80ed22b441d20090fb76d1fd53cd7760b340ff554"}, - {file = "pymongo-4.8.0-cp311-cp311-win32.whl", hash = "sha256:4bf58e6825b93da63e499d1a58de7de563c31e575908d4e24876234ccb910eba"}, - {file = "pymongo-4.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:b747c0e257b9d3e6495a018309b9e0c93b7f0d65271d1d62e572747f4ffafc88"}, - {file = "pymongo-4.8.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:e6a720a3d22b54183352dc65f08cd1547204d263e0651b213a0a2e577e838526"}, - {file = "pymongo-4.8.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:31e4d21201bdf15064cf47ce7b74722d3e1aea2597c6785882244a3bb58c7eab"}, - {file = "pymongo-4.8.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c6b804bb4f2d9dc389cc9e827d579fa327272cdb0629a99bfe5b83cb3e269ebf"}, - {file = "pymongo-4.8.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f2fbdb87fe5075c8beb17a5c16348a1ea3c8b282a5cb72d173330be2fecf22f5"}, - {file = "pymongo-4.8.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cd39455b7ee70aabee46f7399b32ab38b86b236c069ae559e22be6b46b2bbfc4"}, - {file = "pymongo-4.8.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:940d456774b17814bac5ea7fc28188c7a1338d4a233efbb6ba01de957bded2e8"}, - {file = "pymongo-4.8.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:236bbd7d0aef62e64caf4b24ca200f8c8670d1a6f5ea828c39eccdae423bc2b2"}, - {file = "pymongo-4.8.0-cp312-cp312-win32.whl", hash = "sha256:47ec8c3f0a7b2212dbc9be08d3bf17bc89abd211901093e3ef3f2adea7de7a69"}, - {file = "pymongo-4.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:e84bc7707492f06fbc37a9f215374d2977d21b72e10a67f1b31893ec5a140ad8"}, - {file = "pymongo-4.8.0.tar.gz", hash = "sha256:454f2295875744dc70f1881e4b2eb99cdad008a33574bc8aaf120530f66c0cde"}, + {file = "pymongo-4.9.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = 
"sha256:ab8d54529feb6e29035ba8f0570c99ad36424bc26486c238ad7ce28597bc43c8"}, + {file = "pymongo-4.9.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f928bdc152a995cbd0b563fab201b2df873846d11f7a41d1f8cc8a01b35591ab"}, + {file = "pymongo-4.9.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b6e7251d59fa3dcbb1399a71a3aec63768cebc6b22180b671601c2195fe1f90a"}, + {file = "pymongo-4.9.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0e759ed0459e7264a11b6896016f616341a8e4c6ab7f71ae651bd21ffc7e9524"}, + {file = "pymongo-4.9.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f3fc60f242191840ccf02b898bc615b5141fbb70064f38f7e60fcaa35d3b5efd"}, + {file = "pymongo-4.9.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c798351666ac97a0ddaa823689061c3af949c2d6acf7fb2d9ab0a7f465ced79"}, + {file = "pymongo-4.9.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:aac78b5fdd49ed8cae49adf76befacb02293a23b412676775c4715148e166d85"}, + {file = "pymongo-4.9.2-cp310-cp310-win32.whl", hash = "sha256:bf77bf175c315e299a91332c2bbebc097c4d4fcc8713e513a9861684aa39023a"}, + {file = "pymongo-4.9.2-cp310-cp310-win_amd64.whl", hash = "sha256:c42b5aad8971256365bfd0a545fb1c7a199c93db80decd298ea2f987419e2a6d"}, + {file = "pymongo-4.9.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:99e40f44877b32bf4b3c46ceed2228f08c222cf7dec8a4366dd192a1429143fa"}, + {file = "pymongo-4.9.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6f6834d575ed87edc7dfcab4501d961b6a423b3839edd29ecb1382eee7736777"}, + {file = "pymongo-4.9.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3010018f5672e5b7e8d096dea9f1ea6545b05345ff0eb1754f6ee63785550773"}, + {file = "pymongo-4.9.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:69394ee9f0ce38ff71266bad01b7e045cd75e58500ebad5d72187cbabf2e652a"}, + {file = "pymongo-4.9.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:87b18094100f21615d9db99c255dcd9e93e476f10fb03c1d3632cf4b82d201d2"}, + {file = "pymongo-4.9.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3039e093d28376d6a54bdaa963ca12230c8a53d7b19c8e6368e19bcfbd004176"}, + {file = "pymongo-4.9.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ab42d9ee93fe6b90020c42cba5bfb43a2b4660951225d137835efc21940da48"}, + {file = "pymongo-4.9.2-cp311-cp311-win32.whl", hash = "sha256:a663ca60e187a248d370c58961e40f5463077d2b43831eb92120ea28a79ecf96"}, + {file = "pymongo-4.9.2-cp311-cp311-win_amd64.whl", hash = "sha256:24e7b6887bbfefd05afed26a99a2c69459e2daa351a43a410de0d6c0ee3cce4e"}, + {file = "pymongo-4.9.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:8083bbe8cb10bb33dca4d93f8223dd8d848215250bb73867374650bac5fe69e1"}, + {file = "pymongo-4.9.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a1b8c636bf557c7166e3799bbf1120806ca39e3f06615b141c88d9c9ceae4d8c"}, + {file = "pymongo-4.9.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8aac5dce28454f47576063fbad31ea9789bba67cab86c95788f97aafd810e65b"}, + {file = "pymongo-4.9.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d1d5e7123af1fddf15b2b53e58f20bf5242884e671bcc3860f5e954fe13aeddd"}, + {file = "pymongo-4.9.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:fe97c847b56d61e533a7af0334193d6b28375b9189effce93129c7e4733794a9"}, + {file = "pymongo-4.9.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:96ad54433a996e2d1985a9cd8fc82538ca8747c95caae2daf453600cc8c317f9"}, + {file = "pymongo-4.9.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:98b9cade40f5b13e04492a42ae215c3721099be1014ddfe0fbd23f27e4f62c0c"}, + {file = "pymongo-4.9.2-cp312-cp312-win32.whl", hash = "sha256:dde6068ae7c62ea8ee2c5701f78c6a75618cada7e11f03893687df87709558de"}, + {file = "pymongo-4.9.2-cp312-cp312-win_amd64.whl", hash = "sha256:e1ab6cd7cd2d38ffc7ccdc79fdc166c7a91a63f844a96e3e6b2079c054391c68"}, + {file = "pymongo-4.9.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:1ad79d6a74f439a068caf9a1e2daeabc20bf895263435484bbd49e90fbea7809"}, + {file = "pymongo-4.9.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:877699e21703717507cbbea23e75b419f81a513b50b65531e1698df08b2d7094"}, + {file = "pymongo-4.9.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bc9322ce7cf116458a637ac10517b0c5926a8211202be6dbdc51dab4d4a9afc8"}, + {file = "pymongo-4.9.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cca029f46acf475504eedb33c7839f030c4bc4f946dcba12d9a954cc48850b79"}, + {file = "pymongo-4.9.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2c8c861e77527eec5a4b7363c16030dd0374670b620b08a5300f97594bbf5a40"}, + {file = "pymongo-4.9.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1fc70326ae71b3c7b8d6af82f46bb71dafdba3c8f335b29382ae9cf263ef3a5c"}, + {file = "pymongo-4.9.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ba9d2f6df977fee24437f82f7412460b0628cd6b961c4235c9cff71577a5b61f"}, + {file = "pymongo-4.9.2-cp313-cp313-win32.whl", hash = "sha256:b3254769e708bc4aa634745c262081d13c841a80038eff3afd15631540a1d227"}, + {file = "pymongo-4.9.2-cp313-cp313-win_amd64.whl", hash = "sha256:169b85728cc17800344ba17d736375f400ef47c9fbb4c42910c4b3e7c0247382"}, + {file = "pymongo-4.9.2.tar.gz", hash = "sha256:3e63535946f5df7848307b9031aa921f82bb0cbe45f9b0c3296f2173f9283eb0"}, ] [[package]] @@ -2583,8 +2619,8 @@ files = [ [[package]] name = "ray" -version = "2.35.0" -requires_python = ">=3.8" +version = "2.37.0" +requires_python = ">=3.9" summary = "Ray provides a simple, universal API for building distributed applications." 
groups = ["default"] dependencies = [ @@ -2600,28 +2636,28 @@ dependencies = [ "requests", ] files = [ - {file = "ray-2.35.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:1e7e2d2e987be728a81821b6fd2bccb23e4d8a6cca8417db08b24f06a08d8476"}, - {file = "ray-2.35.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8bd48be4c362004d31e5df072fd58b929efc67adfefc0adece41483b15f84539"}, - {file = "ray-2.35.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:ef41e9254f3e18a90a8cf13fac9e35ac086eb778079ab6c76a37d3a6059186c5"}, - {file = "ray-2.35.0-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:1994aaf9996ffc45019856545e817d527ad572762f1af76ad669ae4e786fcfd6"}, - {file = "ray-2.35.0-cp310-cp310-win_amd64.whl", hash = "sha256:d3b7a7d73f818e249064460ffa95402ebd852bf97d9ec6167b8b0d95be03da9f"}, - {file = "ray-2.35.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:e29754fac4b69a9cb0d089841af59ec6fb10b5d4a248b7c579d319ca2ed1c96f"}, - {file = "ray-2.35.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d7a606c8ca53c64fc496703e9fd15d1a1ffb50e6b457a33d3622be2f13fc30a5"}, - {file = "ray-2.35.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:ac561e20a62ce941b74d02a0b92b7765c6ba87cc22e24f34f64ded2c454ba64e"}, - {file = "ray-2.35.0-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:587af570cbe5f6cedca854f15107740e63c67207bee900713cb2ee38f6ebf20f"}, - {file = "ray-2.35.0-cp311-cp311-win_amd64.whl", hash = "sha256:8e406cce41679790146d4d2b1b0cb0b413ca35276e43b68ee796366169c1dbde"}, - {file = "ray-2.35.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:eb86355a3a0e794e2f1dbd5a84805dddfca64921ad0999b7fa5276e40d243692"}, - {file = "ray-2.35.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7b746913268d5ea5e19bff0eb6bdc7e0538036892a8b57c08411787481195df2"}, - {file = "ray-2.35.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:e2ccfd144180f03d38b02a81afdac2b437f27e46736bf2653a1f0e8d67ea56cd"}, - {file = "ray-2.35.0-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:2ca1a0de41d4462fd764598a5981cf55fc955599f38f9a1ae10868e94c6dd80d"}, - {file = "ray-2.35.0-cp312-cp312-win_amd64.whl", hash = "sha256:c5600f745bb0e4df840a5cd51e82b1acf517f73505df9869fe3e369966956129"}, + {file = "ray-2.37.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:96366285038fe0c47e975ffd64eb891f70fb863a80be91c0be64f2ab0cf16d9c"}, + {file = "ray-2.37.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:31c55de41b7e1899a62f2dd6a693ffca0a4cb52633aa66617e3816d48b70aac3"}, + {file = "ray-2.37.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:aee7ff189fd52530d020b13c5e7e6da55e65456193a349d39635a72981e521db"}, + {file = "ray-2.37.0-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:29932441e68ab7dad35b276c763670bf42ebf721cddc4f4de8200bd92ac05c58"}, + {file = "ray-2.37.0-cp310-cp310-win_amd64.whl", hash = "sha256:8a96139143584558507b7bca05581962d92ff86fdd0c58210ed53adc7340ec98"}, + {file = "ray-2.37.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:fa642e9b34e88c6a7edb17b291201351d44f063e04ba9f1e83e42aaf492fc14a"}, + {file = "ray-2.37.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c53ee350a009bab6b811254f8407387812de9a290269e32dbf7c3f0dce6c93c9"}, + {file = "ray-2.37.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:60298e199d9938d3be7418e0645aae312f1283e31123991053d36d0ff1e4ec43"}, + {file = "ray-2.37.0-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:b420279ca14f02cc27fc592ff1f28da9aa08b962316bf65ddf370db877082e91"}, + {file = "ray-2.37.0-cp311-cp311-win_amd64.whl", hash 
= "sha256:7faff20ea7a06612d3cd860a61d2736aa9f82d0d2bcef0917717ced67c8b51c5"}, + {file = "ray-2.37.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:860f3d45438c3daad30f034f107e3fed05a710c7251e10714f942be598715bd2"}, + {file = "ray-2.37.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0b8c23ced4186040dee37e982227e3b1296e2fcbd4c520e4399e5d99ed3c641d"}, + {file = "ray-2.37.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:75cd9a1f6f332ac00d77154b24bd38f4b46a4e600cd02a2440e69b918273b475"}, + {file = "ray-2.37.0-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:0268c7bc2e8bb6ef9bb8969299deb5857bf672bfcb59da95db7495a8a502f8ba"}, + {file = "ray-2.37.0-cp312-cp312-win_amd64.whl", hash = "sha256:4132f79902160c650eaffe1ed1265e5b88d461ff5f3a777a16a750beeed7de1e"}, ] [[package]] name = "ray" -version = "2.35.0" +version = "2.37.0" extras = ["default"] -requires_python = ">=3.8" +requires_python = ">=3.9" summary = "Ray provides a simple, universal API for building distributed applications." groups = ["default"] dependencies = [ @@ -2635,27 +2671,27 @@ dependencies = [ "prometheus-client>=0.7.1", "py-spy>=0.2.0", "pydantic!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.*,!=2.4.*,<3", - "ray==2.35.0", + "ray==2.37.0", "requests", "smart-open", "virtualenv!=20.21.1,>=20.0.24", ] files = [ - {file = "ray-2.35.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:1e7e2d2e987be728a81821b6fd2bccb23e4d8a6cca8417db08b24f06a08d8476"}, - {file = "ray-2.35.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8bd48be4c362004d31e5df072fd58b929efc67adfefc0adece41483b15f84539"}, - {file = "ray-2.35.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:ef41e9254f3e18a90a8cf13fac9e35ac086eb778079ab6c76a37d3a6059186c5"}, - {file = "ray-2.35.0-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:1994aaf9996ffc45019856545e817d527ad572762f1af76ad669ae4e786fcfd6"}, - {file = "ray-2.35.0-cp310-cp310-win_amd64.whl", hash = "sha256:d3b7a7d73f818e249064460ffa95402ebd852bf97d9ec6167b8b0d95be03da9f"}, - {file = "ray-2.35.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:e29754fac4b69a9cb0d089841af59ec6fb10b5d4a248b7c579d319ca2ed1c96f"}, - {file = "ray-2.35.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d7a606c8ca53c64fc496703e9fd15d1a1ffb50e6b457a33d3622be2f13fc30a5"}, - {file = "ray-2.35.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:ac561e20a62ce941b74d02a0b92b7765c6ba87cc22e24f34f64ded2c454ba64e"}, - {file = "ray-2.35.0-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:587af570cbe5f6cedca854f15107740e63c67207bee900713cb2ee38f6ebf20f"}, - {file = "ray-2.35.0-cp311-cp311-win_amd64.whl", hash = "sha256:8e406cce41679790146d4d2b1b0cb0b413ca35276e43b68ee796366169c1dbde"}, - {file = "ray-2.35.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:eb86355a3a0e794e2f1dbd5a84805dddfca64921ad0999b7fa5276e40d243692"}, - {file = "ray-2.35.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7b746913268d5ea5e19bff0eb6bdc7e0538036892a8b57c08411787481195df2"}, - {file = "ray-2.35.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:e2ccfd144180f03d38b02a81afdac2b437f27e46736bf2653a1f0e8d67ea56cd"}, - {file = "ray-2.35.0-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:2ca1a0de41d4462fd764598a5981cf55fc955599f38f9a1ae10868e94c6dd80d"}, - {file = "ray-2.35.0-cp312-cp312-win_amd64.whl", hash = "sha256:c5600f745bb0e4df840a5cd51e82b1acf517f73505df9869fe3e369966956129"}, + {file = "ray-2.37.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = 
"sha256:96366285038fe0c47e975ffd64eb891f70fb863a80be91c0be64f2ab0cf16d9c"}, + {file = "ray-2.37.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:31c55de41b7e1899a62f2dd6a693ffca0a4cb52633aa66617e3816d48b70aac3"}, + {file = "ray-2.37.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:aee7ff189fd52530d020b13c5e7e6da55e65456193a349d39635a72981e521db"}, + {file = "ray-2.37.0-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:29932441e68ab7dad35b276c763670bf42ebf721cddc4f4de8200bd92ac05c58"}, + {file = "ray-2.37.0-cp310-cp310-win_amd64.whl", hash = "sha256:8a96139143584558507b7bca05581962d92ff86fdd0c58210ed53adc7340ec98"}, + {file = "ray-2.37.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:fa642e9b34e88c6a7edb17b291201351d44f063e04ba9f1e83e42aaf492fc14a"}, + {file = "ray-2.37.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c53ee350a009bab6b811254f8407387812de9a290269e32dbf7c3f0dce6c93c9"}, + {file = "ray-2.37.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:60298e199d9938d3be7418e0645aae312f1283e31123991053d36d0ff1e4ec43"}, + {file = "ray-2.37.0-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:b420279ca14f02cc27fc592ff1f28da9aa08b962316bf65ddf370db877082e91"}, + {file = "ray-2.37.0-cp311-cp311-win_amd64.whl", hash = "sha256:7faff20ea7a06612d3cd860a61d2736aa9f82d0d2bcef0917717ced67c8b51c5"}, + {file = "ray-2.37.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:860f3d45438c3daad30f034f107e3fed05a710c7251e10714f942be598715bd2"}, + {file = "ray-2.37.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0b8c23ced4186040dee37e982227e3b1296e2fcbd4c520e4399e5d99ed3c641d"}, + {file = "ray-2.37.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:75cd9a1f6f332ac00d77154b24bd38f4b46a4e600cd02a2440e69b918273b475"}, + {file = "ray-2.37.0-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:0268c7bc2e8bb6ef9bb8969299deb5857bf672bfcb59da95db7495a8a502f8ba"}, + {file = "ray-2.37.0-cp312-cp312-win_amd64.whl", hash = "sha256:4132f79902160c650eaffe1ed1265e5b88d461ff5f3a777a16a750beeed7de1e"}, ] [[package]] diff --git a/pyproject.toml b/pyproject.toml index b3e13d7..11fedfa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "eos" -version = "0.3.0" +version = "0.4.0" description = "The Experiment Orchestration System (EOS) is a comprehensive software framework and runtime for laboratory automation." 
keywords = ["automation", "science", "lab", "experiment", "orchestration", "distributed", "infrastructure"] authors = [ @@ -22,19 +22,20 @@ classifiers = [ readme = "README.md" requires-python = ">=3.10" dependencies = [ - "ray[default]~=2.35.0", + "ray[default]~=2.37.0", "typer~=0.12.5", "rich~=13.8.1", "omegaconf~=2.3.0", "jinja2~=3.1.4", "PyYAML~=6.0.2", - "networkx~=3.3.0", - "pymongo~=4.8.0", - "pydantic~=2.9.1", + "networkx~=3.4.1", + "pymongo~=4.9.2", + "motor~=3.6.0", + "pydantic~=2.9.2", "bofire[optimization]~=0.0.13", - "pandas~=2.2.2", + "pandas~=2.2.3", "numpy~=1.26.2", - "litestar[standard]~=2.11.0", + "litestar[standard]~=2.12.1", "minio~=7.2.8", ] @@ -85,6 +86,7 @@ testpaths = [ markers = [ "slow: mark tests as slow (deselect with '-m \"not slow\"')", ] +asyncio_mode = "auto" [tool.ruff] include = [ @@ -151,4 +153,6 @@ lint.ignore = ["I001", "ANN001", "ANN002", "ANN003", "ANN101", "ANN204", "ANN401 "PT023", "PLR0913", "PLR2004", + "F401", + "F811", ] diff --git a/tests/conftest.py b/tests/conftest.py index d8351ee..e11b487 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,18 @@ +import asyncio + +import pytest + + +@pytest.fixture(scope="session") +def event_loop(): + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + yield loop + loop.close() + + def pytest_collection_modifyitems(items): """Sort tests by the slow marker. Tests with the slow marker will be executed last.""" diff --git a/tests/fixtures.py b/tests/fixtures.py index 76d1c51..1506b0b 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -18,39 +18,35 @@ from eos.experiments.experiment_executor_factory import ExperimentExecutorFactory from eos.experiments.experiment_manager import ExperimentManager from eos.logging.logger import log -from eos.persistence.db_manager import DbManager -from eos.persistence.file_db_manager import FileDbManager +from eos.persistence.async_mongodb_interface import AsyncMongoDbInterface +from eos.persistence.file_db_interface import FileDbInterface from eos.persistence.service_credentials import ServiceCredentials -from eos.resource_allocation.container_allocation_manager import ContainerAllocationManager -from eos.resource_allocation.device_allocation_manager import DeviceAllocationManager +from eos.resource_allocation.container_allocator import ContainerAllocator +from eos.resource_allocation.device_allocator import DeviceAllocator from eos.resource_allocation.resource_allocation_manager import ( ResourceAllocationManager, ) from eos.scheduling.greedy_scheduler import GreedyScheduler +from eos.tasks.on_demand_task_executor import OnDemandTaskExecutor from eos.tasks.task_executor import TaskExecutor from eos.tasks.task_manager import TaskManager log.set_level("INFO") -def load_test_config(config_name): - config_path = Path(__file__).resolve().parent / "test_config.yaml" +def load_test_config(): + config_path = Path(__file__).resolve().parent / "test_config.yml" if not config_path.exists(): raise FileNotFoundError(f"Test config file not found at {config_path}") with Path(config_path).open("r") as file: - config = yaml.safe_load(file) - - if config_name not in config: - raise KeyError(f"Config key {config_name} not found in test config file") - - return config.get(config_name) + return yaml.safe_load(file) @pytest.fixture(scope="session") def configuration_manager(): - config = load_test_config("configuration_manager") + config = load_test_config() root_dir = Path(__file__).resolve().parent.parent user_dir = root_dir / 
config["user_dir"] os.chdir(root_dir) @@ -64,43 +60,31 @@ def task_specification_registry(configuration_manager): @pytest.fixture def user_dir(): - config = load_test_config("configuration_manager") + config = load_test_config() root_dir = Path(__file__).resolve().parent.parent return root_dir / config["user_dir"] @pytest.fixture(scope="session") -def db_manager(): - config = load_test_config("db_manager") - - db_credentials_config = config["db_credentials"] - db_credentials = ServiceCredentials( - host=db_credentials_config["host"], - port=db_credentials_config["port"], - username=db_credentials_config["username"], - password=db_credentials_config["password"], - ) +def db_interface(): + config = load_test_config() + + db_credentials = ServiceCredentials(**config["db"]) - return DbManager(db_credentials, "test-eos") + return AsyncMongoDbInterface(db_credentials, "test-eos") @pytest.fixture(scope="session") -def file_db_manager(db_manager): - config = load_test_config("file_db_manager") - - file_db_credentials_config = config["file_db_credentials"] - file_db_credentials = ServiceCredentials( - host=file_db_credentials_config["host"], - port=file_db_credentials_config["port"], - username=file_db_credentials_config["username"], - password=file_db_credentials_config["password"], - ) +def file_db_interface(db_interface): + config = load_test_config() + + file_db_credentials = ServiceCredentials(**config["file_db"]) - return FileDbManager(file_db_credentials, bucket_name="test-eos") + return FileDbInterface(file_db_credentials, bucket_name="test-eos") @pytest.fixture -def setup_lab_experiment(request, configuration_manager, db_manager): +def setup_lab_experiment(request, configuration_manager, db_interface): lab_name, experiment_name = request.param if lab_name not in configuration_manager.labs: @@ -124,46 +108,60 @@ def experiment_graph(setup_lab_experiment): @pytest.fixture -def clean_db(db_manager): - db_manager.clean_db() +async def clean_db(db_interface): + await db_interface.clean_db() @pytest.fixture -def container_manager(setup_lab_experiment, configuration_manager, db_manager, clean_db): - return ContainerManager(configuration_manager, db_manager) +async def container_manager(setup_lab_experiment, configuration_manager, db_interface, clean_db): + container_manager = ContainerManager(configuration_manager=configuration_manager, db_interface=db_interface) + await container_manager.initialize(db_interface) + return container_manager @pytest.fixture -def device_manager(setup_lab_experiment, configuration_manager, db_manager, clean_db): - device_manager = DeviceManager(configuration_manager, db_manager) - device_manager.update_devices(loaded_labs=set(configuration_manager.labs.keys())) +async def device_manager(setup_lab_experiment, configuration_manager, db_interface, clean_db): + device_manager = DeviceManager(configuration_manager, db_interface) + await device_manager.initialize(db_interface) + + await device_manager.update_devices(loaded_labs=set(configuration_manager.labs.keys())) yield device_manager - device_manager.cleanup_device_actors() + await device_manager.cleanup_device_actors() @pytest.fixture -def experiment_manager(setup_lab_experiment, configuration_manager, db_manager, clean_db): - return ExperimentManager(configuration_manager, db_manager) +async def experiment_manager(setup_lab_experiment, configuration_manager, db_interface, clean_db): + experiment_manager = ExperimentManager(configuration_manager, db_interface) + await experiment_manager.initialize(db_interface) + 
 
 
 @pytest.fixture
-def container_allocator(setup_lab_experiment, configuration_manager, db_manager, clean_db):
-    return ContainerAllocationManager(configuration_manager, db_manager)
+async def container_allocator(setup_lab_experiment, configuration_manager, db_interface, clean_db):
+    container_allocator = ContainerAllocator(configuration_manager, db_interface)
+    await container_allocator.initialize(db_interface)
+    return container_allocator
 
 
 @pytest.fixture
-def device_allocator(setup_lab_experiment, configuration_manager, db_manager, clean_db):
-    return DeviceAllocationManager(configuration_manager, db_manager)
+async def device_allocator(setup_lab_experiment, configuration_manager, db_interface, clean_db):
+    device_allocator = DeviceAllocator(configuration_manager, db_interface)
+    await device_allocator.initialize(db_interface)
+    return device_allocator
 
 
 @pytest.fixture
-def resource_allocation_manager(setup_lab_experiment, configuration_manager, db_manager, clean_db):
-    return ResourceAllocationManager(configuration_manager, db_manager)
+async def resource_allocation_manager(setup_lab_experiment, configuration_manager, db_interface, clean_db):
+    resource_allocation_manager = ResourceAllocationManager(db_interface)
+    await resource_allocation_manager.initialize(configuration_manager, db_interface)
+    return resource_allocation_manager
 
 
 @pytest.fixture
-def task_manager(setup_lab_experiment, configuration_manager, db_manager, file_db_manager, clean_db):
-    return TaskManager(configuration_manager, db_manager, file_db_manager)
+async def task_manager(setup_lab_experiment, configuration_manager, db_interface, file_db_interface, clean_db):
+    task_manager = TaskManager(configuration_manager, db_interface, file_db_interface)
+    await task_manager.initialize(db_interface)
+    return task_manager
 
 
 @pytest.fixture(scope="session", autouse=True)
@@ -187,6 +185,18 @@ def task_executor(
     )
 
 
+@pytest.fixture
+def on_demand_task_executor(
+    setup_lab_experiment,
+    task_executor,
+    task_manager,
+    container_manager,
+):
+    return OnDemandTaskExecutor(
+        task_executor, task_manager, container_manager
+    )
+
+
 @pytest.fixture
 def greedy_scheduler(
     setup_lab_experiment,
@@ -246,19 +256,23 @@ def experiment_executor_factory(
 
 @pytest.fixture
-def campaign_manager(
+async def campaign_manager(
     configuration_manager,
-    db_manager,
+    db_interface,
 ):
-    return CampaignManager(configuration_manager, db_manager)
+    campaign_manager = CampaignManager(configuration_manager, db_interface)
+    await campaign_manager.initialize(db_interface)
+    return campaign_manager
 
 
 @pytest.fixture
-def campaign_optimizer_manager(
+async def campaign_optimizer_manager(
    configuration_manager,
-    db_manager,
+    db_interface,
 ):
-    return CampaignOptimizerManager(configuration_manager, db_manager)
+    campaign_optimizer_manager = CampaignOptimizerManager(configuration_manager, db_interface)
+    await campaign_optimizer_manager.initialize(db_interface)
+    return campaign_optimizer_manager
 
 
 @pytest.fixture
diff --git a/tests/test_base_device.py b/tests/test_base_device.py
index e6934d9..9404ead 100644
--- a/tests/test_base_device.py
+++ b/tests/test_base_device.py
@@ -2,6 +2,7 @@
 from unittest.mock import Mock
 
 import pytest
+import ray
 
 from eos.devices.base_device import BaseDevice, DeviceStatus
 from eos.devices.exceptions import EosDeviceError, EosDeviceCleanupError, EosDeviceInitializationError
@@ -87,8 +88,3 @@ def test_double_initialization(self, mock_device):
         with pytest.raises(EosDeviceInitializationError):
             mock_device.initialize({})
         assert mock_device.status == DeviceStatus.IDLE
-
-    def test_del_method(self, mock_device):
-        mock_device.__del__()
-        assert mock_device.status == DeviceStatus.DISABLED
-        assert mock_device.mock_resource is None
diff --git a/tests/test_bayesian_sequential_optimizer.py b/tests/test_bayesian_sequential_optimizer.py
index 55521f6..3e93354 100644
--- a/tests/test_bayesian_sequential_optimizer.py
+++ b/tests/test_bayesian_sequential_optimizer.py
@@ -30,7 +30,7 @@ def test_single_objective_optimization(self):
         optimal_solutions = optimizer.get_optimal_solutions()
 
         assert len(optimal_solutions) == 1
-        assert abs(optimal_solutions["y"].to_numpy()[0] - 4) < 0.01
+        assert abs(optimal_solutions["y"].to_numpy()[0] - 4) < 0.02
 
     @pytest.mark.slow
     def test_competing_multi_objective_optimization(self):
@@ -48,7 +48,7 @@ def test_competing_multi_objective_optimization(self):
             initial_sampling_method=SamplingMethodEnum.SOBOL,
         )
 
-        for _ in range(30):
+        for _ in range(20):
             parameters = optimizer.sample()
             results = pd.DataFrame()
             results["y1"] = -((parameters["x"] - 2) ** 2) + 4  # Objective 1: Maximize y1
@@ -74,8 +74,8 @@
         for true_solution in true_pareto_front:
             assert any(
-                abs(solution["x"] - true_solution["x"]) < 0.7
-                and abs(solution["y1"] - true_solution["y1"]) < 0.7
-                and abs(solution["y2"] - true_solution["y2"]) < 0.7
+                abs(solution["x"] - true_solution["x"]) < 0.8
+                and abs(solution["y1"] - true_solution["y1"]) < 0.8
+                and abs(solution["y2"] - true_solution["y2"]) < 0.8
                 for _, solution in optimal_solutions.iterrows()
             )
diff --git a/tests/test_campaign_executor.py b/tests/test_campaign_executor.py
index 4e6b2a3..5e2be14 100644
--- a/tests/test_campaign_executor.py
+++ b/tests/test_campaign_executor.py
@@ -8,7 +8,7 @@
 LAB_ID = "multiplication_lab"
 CAMPAIGN_ID = "optimize_multiplication_campaign"
 EXPERIMENT_TYPE = "optimize_multiplication"
-MAX_EXPERIMENTS = 40
+MAX_EXPERIMENTS = 30
 DO_OPTIMIZATION = True
@@ -28,7 +28,7 @@ class TestCampaignExecutor:
     async def test_start_campaign(self, campaign_executor, campaign_manager):
         await campaign_executor.start_campaign()
 
-        campaign = campaign_manager.get_campaign(CAMPAIGN_ID)
+        campaign = await campaign_manager.get_campaign(CAMPAIGN_ID)
         assert campaign is not None
         assert campaign.id == CAMPAIGN_ID
         assert campaign.status == CampaignStatus.RUNNING
@@ -46,7 +46,7 @@ async def test_progress_campaign(self, campaign_executor, campaign_manager, camp
         solutions = await campaign_executor.optimizer.get_optimal_solutions.remote()
         assert not solutions.empty
         assert len(solutions) == 1
-        assert solutions["compute_multiplication_objective.objective"].iloc[0] / 100 <= 80
+        assert solutions["compute_multiplication_objective.objective"].iloc[0] / 100 <= 120
 
     @pytest.mark.slow
     @pytest.mark.asyncio
@@ -69,7 +69,7 @@ async def mock_progress_experiment(*args, **kwargs):
         assert campaign_executor._campaign_status == CampaignStatus.FAILED
 
         # Verify that the campaign manager has marked the campaign as failed
-        campaign = campaign_manager.get_campaign(CAMPAIGN_ID)
+        campaign = await campaign_manager.get_campaign(CAMPAIGN_ID)
         assert campaign.status == CampaignStatus.FAILED
 
     @pytest.mark.slow
@@ -81,7 +81,7 @@ async def test_campaign_cancellation(self, campaign_executor, campaign_manager):
         completed_experiments = 0
         while completed_experiments < 2:
             await campaign_executor.progress_campaign()
-            campaign = campaign_manager.get_campaign(CAMPAIGN_ID)
+            campaign = await campaign_manager.get_campaign(CAMPAIGN_ID)
             completed_experiments = campaign.experiments_completed
             await asyncio.sleep(0.1)
 
@@ -90,7 +90,7 @@
         await campaign_executor.cancel_campaign()
 
-        campaign = campaign_manager.get_campaign(CAMPAIGN_ID)
+        campaign = await campaign_manager.get_campaign(CAMPAIGN_ID)
         assert campaign.status == CampaignStatus.CANCELLED
 
         # Try to progress the campaign after cancellation
@@ -112,15 +112,16 @@ async def test_campaign_resuming(
         completed_experiments = 0
         while completed_experiments < 3:
             await campaign_executor.progress_campaign()
-            campaign = campaign_manager.get_campaign(CAMPAIGN_ID)
+            campaign = await campaign_manager.get_campaign(CAMPAIGN_ID)
             completed_experiments = campaign.experiments_completed
             await asyncio.sleep(0.1)
 
-        initial_campaign = campaign_manager.get_campaign(CAMPAIGN_ID)
+        initial_campaign = await campaign_manager.get_campaign(CAMPAIGN_ID)
         num_initial_reported_samples = ray.get(campaign_executor.optimizer.get_num_samples_reported.remote())
 
         await campaign_executor.cancel_campaign()
-        assert campaign_manager.get_campaign(CAMPAIGN_ID).status == CampaignStatus.CANCELLED
+        campaign = await campaign_manager.get_campaign(CAMPAIGN_ID)
+        assert campaign.status == CampaignStatus.CANCELLED
         campaign_executor.cleanup()
 
         # Create a new campaign executor to resume the campaign
@@ -137,7 +138,7 @@
             experiment_executor_factory,
         )
         await new_campaign_executor.start_campaign()
-        resumed_campaign = campaign_manager.get_campaign(CAMPAIGN_ID)
+        resumed_campaign = await campaign_manager.get_campaign(CAMPAIGN_ID)
         assert resumed_campaign.status == CampaignStatus.RUNNING
 
         # Verify that the number of completed experiments is preserved
@@ -162,7 +163,7 @@ async def test_campaign_cancellation_timeout(self, campaign_executor, campaign_m
         # Run until one experiment is complete
         while (
-            campaign_manager.get_campaign(CAMPAIGN_ID).experiments_completed < 1
+            (await campaign_manager.get_campaign(CAMPAIGN_ID)).experiments_completed < 1
             or len(campaign_executor._experiment_executors) < 1
         ):
             await campaign_executor.progress_campaign()
@@ -183,5 +184,5 @@ async def cancel_experiment(self):
         await campaign_executor.cancel_campaign()
         assert "Timed out while cancelling experiments" in str(exc_info.value)
 
-        campaign = campaign_manager.get_campaign(CAMPAIGN_ID)
+        campaign = await campaign_manager.get_campaign(CAMPAIGN_ID)
         assert campaign.status == CampaignStatus.CANCELLED
diff --git a/tests/test_config.yaml b/tests/test_config.yaml
deleted file mode 100644
index 3fe82e7..0000000
--- a/tests/test_config.yaml
+++ /dev/null
@@ -1,16 +0,0 @@
-configuration_manager:
-  user_dir: tests/user
-
-db_manager:
-  db_credentials:
-    host: localhost
-    port: 27017
-    username: eos-user
-    password: eos-password
-
-file_db_manager:
-  file_db_credentials:
-    host: localhost
-    port: 9004
-    username: eos-user
-    password: eos-password
diff --git a/tests/test_config.yml b/tests/test_config.yml
new file mode 100644
index 0000000..eb9ef3a
--- /dev/null
+++ b/tests/test_config.yml
@@ -0,0 +1,13 @@
+user_dir: tests/user
+
+db:
+  host: localhost
+  port: 27017
+  username: eos-user
+  password: eos-password
+
+file_db:
+  host: localhost
+  port: 9004
+  username: eos-user
+  password: eos-password
diff --git a/tests/test_container_allocator.py b/tests/test_container_allocator.py
index acd1bde..bfc86f3 100644
--- a/tests/test_container_allocator.py
+++ b/tests/test_container_allocator.py
@@ -9,76 +9,86 @@
     "setup_lab_experiment", [("small_lab", "water_purification")], indirect=True
 )
 class TestContainerAllocator:
-    def test_allocate_container(self, container_allocator):
+    @pytest.mark.asyncio
+    async def test_allocate_container(self, container_allocator):
         container_id = "ec1ca48cd5d14c0c8cde376476e0d98d"
-        container_allocator.allocate(container_id, "owner", "water_purification_1")
-        container = container_allocator.get_allocation(container_id)
+        await container_allocator.allocate(container_id, "owner", "water_purification_1")
+        container = await container_allocator.get_allocation(container_id)
 
         assert container.id == container_id
         assert container.owner == "owner"
         assert container.experiment_id == "water_purification_1"
 
-    def test_allocate_container_already_allocated(self, container_allocator):
+    @pytest.mark.asyncio
+    async def test_allocate_container_already_allocated(self, container_allocator):
         container_id = "ec1ca48cd5d14c0c8cde376476e0d98d"
-        container_allocator.allocate(container_id, "owner", "water_purification_1")
+        await container_allocator.allocate(container_id, "owner", "water_purification_1")
 
         with pytest.raises(EosContainerAllocatedError):
-            container_allocator.allocate(container_id, "owner", "water_purification_1")
+            await container_allocator.allocate(container_id, "owner", "water_purification_1")
 
-    def test_allocate_nonexistent_container(self, container_allocator):
+    @pytest.mark.asyncio
+    async def test_allocate_nonexistent_container(self, container_allocator):
         container_id = "nonexistent_container_id"
 
         with pytest.raises(EosContainerNotFoundError):
-            container_allocator.allocate(container_id, "owner", "water_purification_1")
+            await container_allocator.allocate(container_id, "owner", "water_purification_1")
 
-    def test_deallocate_container(self, container_allocator):
+    @pytest.mark.asyncio
+    async def test_deallocate_container(self, container_allocator):
         container_id = "ec1ca48cd5d14c0c8cde376476e0d98d"
-        container_allocator.allocate(container_id, "owner", "water_purification_1")
+        await container_allocator.allocate(container_id, "owner", "water_purification_1")
 
-        container_allocator.deallocate(container_id)
-        container = container_allocator.get_allocation(container_id)
+        await container_allocator.deallocate(container_id)
+        container = await container_allocator.get_allocation(container_id)
 
         assert container is None
 
-    def test_deallocate_container_not_allocated(self, container_allocator):
+    @pytest.mark.asyncio
+    async def test_deallocate_container_not_allocated(self, container_allocator):
         container_id = "ec1ca48cd5d14c0c8cde376476e0d98d"
-        container_allocator.deallocate(container_id)
-        assert container_allocator.get_allocation(container_id) is None
+        await container_allocator.deallocate(container_id)
 
-    def test_is_allocated(self, container_allocator):
+        allocation = await container_allocator.get_allocation(container_id)
+        assert allocation is None
+
+    @pytest.mark.asyncio
+    async def test_is_allocated(self, container_allocator):
         container_id = "ec1ca48cd5d14c0c8cde376476e0d98d"
-        assert not container_allocator.is_allocated(container_id)
+        assert not await container_allocator.is_allocated(container_id)
 
-        container_allocator.allocate(container_id, "owner", "water_purification_1")
-        assert container_allocator.is_allocated(container_id)
+        await container_allocator.allocate(container_id, "owner", "water_purification_1")
+        assert await container_allocator.is_allocated(container_id)
 
-    def test_get_allocations_by_owner(self, container_allocator):
+    @pytest.mark.asyncio
+    async def test_get_allocations_by_owner(self, container_allocator):
         container_id_1 = "ec1ca48cd5d14c0c8cde376476e0d98d"
"ec1ca48cd5d14c0c8cde376476e0d98d" container_id_2 = "84eb17d61e884ffd9d1fdebcbad1532b" container_id_3 = "a3b958aea8bd435386cdcbab20a2d3ec" - container_allocator.allocate(container_id_1, "owner", "water_purification_1") - container_allocator.allocate(container_id_2, "owner", "water_purification_1") - container_allocator.allocate(container_id_3, "another_owner", "water_purification_1") + await container_allocator.allocate(container_id_1, "owner", "water_purification_1") + await container_allocator.allocate(container_id_2, "owner", "water_purification_1") + await container_allocator.allocate(container_id_3, "another_owner", "water_purification_1") - allocations = container_allocator.get_allocations(owner="owner") + allocations = await container_allocator.get_allocations(owner="owner") assert allocations[0].id == container_id_1 assert allocations[1].id == container_id_2 assert len(allocations) == 2 - allocations = container_allocator.get_allocations(owner="another_owner") + allocations = await container_allocator.get_allocations(owner="another_owner") assert allocations[0].id == container_id_3 assert len(allocations) == 1 - def test_get_all_allocations(self, container_allocator): + @pytest.mark.asyncio + async def test_get_all_allocations(self, container_allocator): container_id_1 = "ec1ca48cd5d14c0c8cde376476e0d98d" container_id_2 = "84eb17d61e884ffd9d1fdebcbad1532b" container_id_3 = "a3b958aea8bd435386cdcbab20a2d3ec" - container_allocator.allocate(container_id_1, "owner", "water_purification_1") - container_allocator.allocate(container_id_2, "owner", "water_purification_1") - container_allocator.allocate(container_id_3, "another_owner", "water_purification_1") + await container_allocator.allocate(container_id_1, "owner", "water_purification_1") + await container_allocator.allocate(container_id_2, "owner", "water_purification_1") + await container_allocator.allocate(container_id_3, "another_owner", "water_purification_1") - allocations = container_allocator.get_allocations() + allocations = await container_allocator.get_allocations() assert len(allocations) == 3 assert {allocation.id for allocation in allocations} == { container_id_1, @@ -86,50 +96,53 @@ def test_get_all_allocations(self, container_allocator): container_id_3, } - def test_get_all_unallocated_containers(self, container_allocator): + @pytest.mark.asyncio + async def test_get_all_unallocated_containers(self, container_allocator): container_id_1 = "ec1ca48cd5d14c0c8cde376476e0d98d" container_id_2 = "84eb17d61e884ffd9d1fdebcbad1532b" container_id_3 = "a3b958aea8bd435386cdcbab20a2d3ec" - initial_unallocated_containers = container_allocator.get_all_unallocated() + initial_unallocated_containers = await container_allocator.get_all_unallocated() - container_allocator.allocate(container_id_1, "owner1", "water_purification_1") - container_allocator.allocate(container_id_2, "owner2", "water_purification_1") + await container_allocator.allocate(container_id_1, "owner1", "water_purification_1") + await container_allocator.allocate(container_id_2, "owner2", "water_purification_1") - new_unallocated_containers = container_allocator.get_all_unallocated() + new_unallocated_containers = await container_allocator.get_all_unallocated() assert len(new_unallocated_containers) == len(initial_unallocated_containers) - 2 assert container_id_1 not in new_unallocated_containers assert container_id_2 not in new_unallocated_containers assert container_id_3 in new_unallocated_containers - def test_deallocate_all_containers(self, container_allocator): + 
+    async def test_deallocate_all_containers(self, container_allocator):
         container_id_1 = "ec1ca48cd5d14c0c8cde376476e0d98d"
         container_id_2 = "84eb17d61e884ffd9d1fdebcbad1532b"
         container_id_3 = "a3b958aea8bd435386cdcbab20a2d3ec"
 
-        container_allocator.allocate(container_id_1, "owner1", "water_purification_1")
-        container_allocator.allocate(container_id_2, "owner2", "water_purification_1")
-        container_allocator.allocate(container_id_3, "owner3", "water_purification_1")
+        await container_allocator.allocate(container_id_1, "owner1", "water_purification_1")
+        await container_allocator.allocate(container_id_2, "owner2", "water_purification_1")
+        await container_allocator.allocate(container_id_3, "owner3", "water_purification_1")
 
-        assert container_allocator.get_allocations() != []
+        assert await container_allocator.get_allocations() != []
 
-        container_allocator.deallocate_all()
+        await container_allocator.deallocate_all()
 
-        assert container_allocator.get_allocations() == []
+        assert await container_allocator.get_allocations() == []
 
-    def test_deallocate_all_containers_by_owner(self, container_allocator):
+    @pytest.mark.asyncio
+    async def test_deallocate_all_containers_by_owner(self, container_allocator):
         container_id_1 = "ec1ca48cd5d14c0c8cde376476e0d98d"
         container_id_2 = "84eb17d61e884ffd9d1fdebcbad1532b"
         container_id_3 = "a3b958aea8bd435386cdcbab20a2d3ec"
 
-        container_allocator.allocate(container_id_1, "owner1", "water_purification_1")
-        container_allocator.allocate(container_id_2, "owner2", "water_purification_1")
-        container_allocator.allocate(container_id_3, "owner2", "water_purification_1")
+        await container_allocator.allocate(container_id_1, "owner1", "water_purification_1")
+        await container_allocator.allocate(container_id_2, "owner2", "water_purification_1")
+        await container_allocator.allocate(container_id_3, "owner2", "water_purification_1")
 
-        container_allocator.deallocate_all_by_owner("owner2")
+        await container_allocator.deallocate_all_by_owner("owner2")
 
-        owner2_allocations = container_allocator.get_allocations(owner="owner2")
+        owner2_allocations = await container_allocator.get_allocations(owner="owner2")
         assert owner2_allocations == []
-        assert container_allocator.get_allocations() == [
-            container_allocator.get_allocation(container_id_1)
+        assert await container_allocator.get_allocations() == [
+            await container_allocator.get_allocation(container_id_1)
         ]
diff --git a/tests/test_container_manager.py b/tests/test_container_manager.py
index 3a7522d..bdb0ebc 100644
--- a/tests/test_container_manager.py
+++ b/tests/test_container_manager.py
@@ -1,46 +1,52 @@
 from tests.fixtures import *
 
 
-@pytest.fixture
-def container_manager(configuration_manager, setup_lab_experiment, db_manager, clean_db):
-    return ContainerManager(configuration_manager, db_manager)
-
-
 @pytest.mark.parametrize("setup_lab_experiment", [("small_lab", "water_purification")], indirect=True)
 class TestContainerManager:
-    def test_set_container_location(self, container_manager):
-        container_id = "acf829f859e04fee80d54a1ee918555d"
-        container_manager.set_location(container_id, "new_location")
-
-        assert container_manager.get_container(container_id).location == "new_location"
-
-    def test_set_container_lab(self, container_manager):
-        container_id = "acf829f859e04fee80d54a1ee918555d"
-        container_manager.set_lab(container_id, "new_lab")
-
-        assert container_manager.get_container(container_id).lab == "new_lab"
-
-    def test_set_container_metadata(self, container_manager):
-        container_id = "acf829f859e04fee80d54a1ee918555d"
"acf829f859e04fee80d54a1ee918555d" - container_manager.set_metadata(container_id, {"substance": "water"}) - container_manager.set_metadata(container_id, {"temperature": "cold"}) - - assert container_manager.get_container(container_id).metadata == {"temperature": "cold"} - - def test_add_container_metadata(self, container_manager): - container_id = "acf829f859e04fee80d54a1ee918555d" - container_manager.add_metadata(container_id, {"substance": "water"}) - container_manager.add_metadata(container_id, {"temperature": "cold"}) - - assert container_manager.get_container(container_id).metadata == { - "capacity": 500, - "substance": "water", - "temperature": "cold", - } - - def test_remove_container_metadata(self, container_manager): - container_id = "acf829f859e04fee80d54a1ee918555d" - container_manager.add_metadata(container_id, {"substance": "water", "temperature": "cold", "color": "blue"}) - container_manager.remove_metadata(container_id, ["color", "temperature"]) - - assert container_manager.get_container(container_id).metadata == {"capacity": 500, "substance": "water"} + @pytest.mark.asyncio + async def test_set_container_location(self, container_manager): + container_id = "acf829f859e04fee80d54a1ee918555d" + await container_manager.set_location(container_id, "new_location") + + container = await container_manager.get_container(container_id) + assert container.location == "new_location" + + @pytest.mark.asyncio + async def test_set_container_lab(self, container_manager): + container_id = "acf829f859e04fee80d54a1ee918555d" + await container_manager.set_lab(container_id, "new_lab") + + container = await container_manager.get_container(container_id) + assert container.lab == "new_lab" + + @pytest.mark.asyncio + async def test_set_container_metadata(self, container_manager): + container_id = "acf829f859e04fee80d54a1ee918555d" + await container_manager.set_metadata(container_id, {"substance": "water"}) + await container_manager.set_metadata(container_id, {"temperature": "cold"}) + + container = await container_manager.get_container(container_id) + assert container.metadata == {"temperature": "cold"} + + @pytest.mark.asyncio + async def test_add_container_metadata(self, container_manager): + container_id = "acf829f859e04fee80d54a1ee918555d" + await container_manager.add_metadata(container_id, {"substance": "water"}) + await container_manager.add_metadata(container_id, {"temperature": "cold"}) + + container = await container_manager.get_container(container_id) + assert container.metadata == { + "capacity": 500, + "substance": "water", + "temperature": "cold", + } + + @pytest.mark.asyncio + async def test_remove_container_metadata(self, container_manager): + container_id = "acf829f859e04fee80d54a1ee918555d" + await container_manager.add_metadata(container_id, + {"substance": "water", "temperature": "cold", "color": "blue"}) + await container_manager.remove_metadata(container_id, ["color", "temperature"]) + + container = await container_manager.get_container(container_id) + assert container.metadata == {"capacity": 500, "substance": "water"} diff --git a/tests/test_device_allocator.py b/tests/test_device_allocator.py index 9d69da4..097b291 100644 --- a/tests/test_device_allocator.py +++ b/tests/test_device_allocator.py @@ -9,11 +9,12 @@ @pytest.mark.parametrize("setup_lab_experiment", [(LAB_ID, "water_purification")], indirect=True) class TestDeviceAllocator: - def test_allocate_device(self, device_allocator): + @pytest.mark.asyncio + async def test_allocate_device(self, device_allocator): device_id 
= "magnetic_mixer" - device_allocator.allocate(LAB_ID, device_id, "owner", "water_purification_1") + await device_allocator.allocate(LAB_ID, device_id, "owner", "water_purification_1") - allocation = device_allocator.get_allocation(LAB_ID, device_id) + allocation = await device_allocator.get_allocation(LAB_ID, device_id) assert allocation.id == device_id assert allocation.lab_id == LAB_ID @@ -21,115 +22,125 @@ def test_allocate_device(self, device_allocator): assert allocation.owner == "owner" assert allocation.experiment_id == "water_purification_1" - def test_allocate_device_already_allocated(self, device_allocator): + @pytest.mark.asyncio + async def test_allocate_device_already_allocated(self, device_allocator): device_id = "magnetic_mixer" - device_allocator.allocate(LAB_ID, device_id, "owner", "water_purification_1") + await device_allocator.allocate(LAB_ID, device_id, "owner", "water_purification_1") with pytest.raises(EosDeviceAllocatedError): - device_allocator.allocate(LAB_ID, device_id, "owner", "water_purification_1") + await device_allocator.allocate(LAB_ID, device_id, "owner", "water_purification_1") - def test_allocate_nonexistent_device(self, device_allocator): + @pytest.mark.asyncio + async def test_allocate_nonexistent_device(self, device_allocator): device_id = "nonexistent_device_id" with pytest.raises(EosDeviceNotFoundError): - device_allocator.allocate(LAB_ID, device_id, "owner", "water_purification_1") + await device_allocator.allocate(LAB_ID, device_id, "owner", "water_purification_1") - def test_deallocate_device(self, device_allocator): + @pytest.mark.asyncio + async def test_deallocate_device(self, device_allocator): device_id = "magnetic_mixer" - device_allocator.allocate(LAB_ID, device_id, "owner", "water_purification_1") + await device_allocator.allocate(LAB_ID, device_id, "owner", "water_purification_1") - device_allocator.deallocate(LAB_ID, device_id) - allocation = device_allocator.get_allocation(LAB_ID, device_id) + await device_allocator.deallocate(LAB_ID, device_id) + allocation = await device_allocator.get_allocation(LAB_ID, device_id) assert allocation is None - def test_deallocate_device_not_allocated(self, device_allocator): + @pytest.mark.asyncio + async def test_deallocate_device_not_allocated(self, device_allocator): device_id = "magnetic_mixer" - device_allocator.deallocate(LAB_ID, device_id) - assert device_allocator.get_allocation(LAB_ID, device_id) is None + await device_allocator.deallocate(LAB_ID, device_id) + assert await device_allocator.get_allocation(LAB_ID, device_id) is None - def test_is_allocated(self, device_allocator): + @pytest.mark.asyncio + async def test_is_allocated(self, device_allocator): device_id = "magnetic_mixer" - assert not device_allocator.is_allocated(LAB_ID, device_id) + assert not await device_allocator.is_allocated(LAB_ID, device_id) - device_allocator.allocate(LAB_ID, device_id, "owner", "water_purification_1") - assert device_allocator.is_allocated(LAB_ID, device_id) + await device_allocator.allocate(LAB_ID, device_id, "owner", "water_purification_1") + assert await device_allocator.is_allocated(LAB_ID, device_id) - def test_get_allocations_by_owner(self, device_allocator): + @pytest.mark.asyncio + async def test_get_allocations_by_owner(self, device_allocator): device_id_1 = "magnetic_mixer" device_id_2 = "evaporator" device_id_3 = "substance_fridge" - device_allocator.allocate(LAB_ID, device_id_1, "owner1", "water_purification_1") - device_allocator.allocate(LAB_ID, device_id_2, "owner1", 
"water_purification_1") - device_allocator.allocate(LAB_ID, device_id_3, "owner2", "water_purification_1") + await device_allocator.allocate(LAB_ID, device_id_1, "owner1", "water_purification_1") + await device_allocator.allocate(LAB_ID, device_id_2, "owner1", "water_purification_1") + await device_allocator.allocate(LAB_ID, device_id_3, "owner2", "water_purification_1") - allocations = device_allocator.get_allocations(owner="owner1") + allocations = await device_allocator.get_allocations(owner="owner1") assert len(allocations) == 2 assert device_id_1 in [allocation.id for allocation in allocations] assert device_id_2 in [allocation.id for allocation in allocations] - def test_get_all_allocations(self, device_allocator): + @pytest.mark.asyncio + async def test_get_all_allocations(self, device_allocator): device_id_1 = "magnetic_mixer" device_id_2 = "evaporator" device_id_3 = "substance_fridge" - device_allocator.allocate(LAB_ID, device_id_1, "owner", "water_purification_1") - device_allocator.allocate(LAB_ID, device_id_2, "owner", "water_purification_1") - device_allocator.allocate(LAB_ID, device_id_3, "owner", "water_purification_1") + await device_allocator.allocate(LAB_ID, device_id_1, "owner", "water_purification_1") + await device_allocator.allocate(LAB_ID, device_id_2, "owner", "water_purification_1") + await device_allocator.allocate(LAB_ID, device_id_3, "owner", "water_purification_1") - allocations = device_allocator.get_allocations() + allocations = await device_allocator.get_allocations() assert len(allocations) == 3 assert device_id_1 in [allocation.id for allocation in allocations] assert device_id_2 in [allocation.id for allocation in allocations] assert device_id_3 in [allocation.id for allocation in allocations] - def test_get_all_unallocated(self, device_allocator): + @pytest.mark.asyncio + async def test_get_all_unallocated(self, device_allocator): device_id_1 = "magnetic_mixer" device_id_2 = "evaporator" device_id_3 = "substance_fridge" - initial_unallocated_devices = device_allocator.get_all_unallocated() + initial_unallocated_devices = await device_allocator.get_all_unallocated() - device_allocator.allocate(LAB_ID, device_id_1, "owner", "water_purification_1") - device_allocator.allocate(LAB_ID, device_id_2, "owner", "water_purification_1") + await device_allocator.allocate(LAB_ID, device_id_1, "owner", "water_purification_1") + await device_allocator.allocate(LAB_ID, device_id_2, "owner", "water_purification_1") - new_unallocated_devices = device_allocator.get_all_unallocated() + new_unallocated_devices = await device_allocator.get_all_unallocated() assert len(new_unallocated_devices) == len(initial_unallocated_devices) - 2 assert device_id_1 not in new_unallocated_devices assert device_id_2 not in new_unallocated_devices assert device_id_3 in new_unallocated_devices - def test_deallocate_all(self, device_allocator): + @pytest.mark.asyncio + async def test_deallocate_all(self, device_allocator): device_id_1 = "magnetic_mixer" device_id_2 = "evaporator" device_id_3 = "substance_fridge" - device_allocator.allocate(LAB_ID, device_id_1, "owner", "water_purification_1") - device_allocator.allocate(LAB_ID, device_id_2, "owner", "water_purification_1") - device_allocator.allocate(LAB_ID, device_id_3, "owner", "water_purification_1") + await device_allocator.allocate(LAB_ID, device_id_1, "owner", "water_purification_1") + await device_allocator.allocate(LAB_ID, device_id_2, "owner", "water_purification_1") + await device_allocator.allocate(LAB_ID, device_id_3, "owner", 
"water_purification_1") - assert device_allocator.get_allocations() != [] + assert await device_allocator.get_allocations() != [] - device_allocator.deallocate_all() + await device_allocator.deallocate_all() - assert device_allocator.get_allocations() == [] + assert await device_allocator.get_allocations() == [] - def test_deallocate_all_by_owner(self, device_allocator): + @pytest.mark.asyncio + async def test_deallocate_all_by_owner(self, device_allocator): device_id_1 = "magnetic_mixer" device_id_2 = "evaporator" device_id_3 = "substance_fridge" - device_allocator.allocate(LAB_ID, device_id_1, "owner1", "water_purification_1") - device_allocator.allocate(LAB_ID, device_id_2, "owner2", "water_purification_1") - device_allocator.allocate(LAB_ID, device_id_3, "owner2", "water_purification_1") + await device_allocator.allocate(LAB_ID, device_id_1, "owner1", "water_purification_1") + await device_allocator.allocate(LAB_ID, device_id_2, "owner2", "water_purification_1") + await device_allocator.allocate(LAB_ID, device_id_3, "owner2", "water_purification_1") - device_allocator.deallocate_all_by_owner("owner2") + await device_allocator.deallocate_all_by_owner("owner2") - owner2_allocations = device_allocator.get_allocations(owner="owner2") + owner2_allocations = await device_allocator.get_allocations(owner="owner2") assert owner2_allocations == [] - assert device_allocator.get_allocations() == [ - device_allocator.get_allocation(LAB_ID, device_id_1) + assert await device_allocator.get_allocations() == [ + await device_allocator.get_allocation(LAB_ID, device_id_1) ] diff --git a/tests/test_device_manager.py b/tests/test_device_manager.py index c281257..e21ee40 100644 --- a/tests/test_device_manager.py +++ b/tests/test_device_manager.py @@ -7,31 +7,37 @@ @pytest.mark.parametrize("setup_lab_experiment", [(LAB_ID, "water_purification")], indirect=True) class TestDeviceManager: - def test_get_device(self, device_manager): - device = device_manager.get_device(LAB_ID, "substance_fridge") + @pytest.mark.asyncio + async def test_get_device(self, device_manager): + device = await device_manager.get_device(LAB_ID, "substance_fridge") assert device.id == "substance_fridge" assert device.lab_id == LAB_ID assert device.type == "fridge" assert device.location == "substance_fridge" - def test_get_device_nonexistent(self, device_manager): - device = device_manager.get_device(LAB_ID, "nonexistent_device") + @pytest.mark.asyncio + async def test_get_device_nonexistent(self, device_manager): + device = await device_manager.get_device(LAB_ID, "nonexistent_device") assert device is None - def test_get_all_devices(self, device_manager): - devices = device_manager.get_devices(lab_id=LAB_ID) + @pytest.mark.asyncio + async def test_get_all_devices(self, device_manager): + devices = await device_manager.get_devices(lab_id=LAB_ID) assert len(devices) == 5 - def test_get_devices_by_type(self, device_manager): - devices = device_manager.get_devices(lab_id=LAB_ID, type="magnetic_mixer") + @pytest.mark.asyncio + async def test_get_devices_by_type(self, device_manager): + devices = await device_manager.get_devices(lab_id=LAB_ID, type="magnetic_mixer") assert len(devices) == 2 assert all(device.type == "magnetic_mixer" for device in devices) - def test_set_device_status(self, device_manager): - device_manager.set_device_status(LAB_ID, "evaporator", DeviceStatus.ACTIVE) - device = device_manager.get_device(LAB_ID, "evaporator") + @pytest.mark.asyncio + async def test_set_device_status(self, device_manager): + await 
device_manager.set_device_status(LAB_ID, "evaporator", DeviceStatus.ACTIVE) + device = await device_manager.get_device(LAB_ID, "evaporator") assert device.status == DeviceStatus.ACTIVE - def test_set_device_status_nonexistent(self, device_manager): + @pytest.mark.asyncio + async def test_set_device_status_nonexistent(self, device_manager): with pytest.raises(EosDeviceStateError): - device_manager.set_device_status(LAB_ID, "nonexistent_device", DeviceStatus.INACTIVE) + await device_manager.set_device_status(LAB_ID, "nonexistent_device", DeviceStatus.INACTIVE) diff --git a/tests/test_experiment_executor.py b/tests/test_experiment_executor.py index dd75983..e8a2de3 100644 --- a/tests/test_experiment_executor.py +++ b/tests/test_experiment_executor.py @@ -33,10 +33,11 @@ indirect=True, ) class TestExperimentExecutor: - def test_start_experiment(self, experiment_executor, experiment_manager): - experiment_executor.start_experiment(DYNAMIC_PARAMETERS) + @pytest.mark.asyncio + async def test_start_experiment(self, experiment_executor, experiment_manager): + await experiment_executor.start_experiment(DYNAMIC_PARAMETERS) - experiment = experiment_manager.get_experiment(EXPERIMENT_ID) + experiment = await experiment_manager.get_experiment(EXPERIMENT_ID) assert experiment is not None assert experiment.id == EXPERIMENT_ID assert experiment.status == ExperimentStatus.RUNNING @@ -44,7 +45,7 @@ def test_start_experiment(self, experiment_executor, experiment_manager): @pytest.mark.slow @pytest.mark.asyncio async def test_progress_experiment(self, experiment_executor, experiment_manager, task_manager): - experiment_executor.start_experiment(DYNAMIC_PARAMETERS) + await experiment_executor.start_experiment(DYNAMIC_PARAMETERS) experiment_completed = await experiment_executor.progress_experiment() assert not experiment_completed @@ -52,48 +53,48 @@ async def test_progress_experiment(self, experiment_executor, experiment_manager experiment_completed = await experiment_executor.progress_experiment() assert not experiment_completed - task = task_manager.get_task(EXPERIMENT_ID, "mixing") + task = await task_manager.get_task(EXPERIMENT_ID, "mixing") assert task is not None assert task.status == TaskStatus.COMPLETED await experiment_executor._task_output_futures["evaporation"] experiment_completed = await experiment_executor.progress_experiment() - task = task_manager.get_task(EXPERIMENT_ID, "evaporation") + task = await task_manager.get_task(EXPERIMENT_ID, "evaporation") assert task.status == TaskStatus.COMPLETED assert not experiment_completed # Final progress experiment_completed = await experiment_executor.progress_experiment() assert experiment_completed - experiment = experiment_manager.get_experiment(EXPERIMENT_ID) + experiment = await experiment_manager.get_experiment(EXPERIMENT_ID) assert experiment.status == ExperimentStatus.COMPLETED @pytest.mark.asyncio async def test_task_output_registration(self, experiment_executor, task_manager): - experiment_executor.start_experiment(DYNAMIC_PARAMETERS) + await experiment_executor.start_experiment(DYNAMIC_PARAMETERS) experiment_completed = False while not experiment_completed: experiment_completed = await experiment_executor.progress_experiment() await asyncio.sleep(0.1) - mixing_output = task_manager.get_task_output(EXPERIMENT_ID, "mixing") + mixing_output = await task_manager.get_task_output(EXPERIMENT_ID, "mixing") assert mixing_output is not None assert mixing_output.parameters["mixing_time"] == DYNAMIC_PARAMETERS["mixing"]["time"] @pytest.mark.asyncio async def 
test_resolve_input_parameter_references_and_dynamic_parameters(self, experiment_executor, task_manager): - experiment_executor.start_experiment(DYNAMIC_PARAMETERS) + await experiment_executor.start_experiment(DYNAMIC_PARAMETERS) experiment_completed = False while not experiment_completed: experiment_completed = await experiment_executor.progress_experiment() await asyncio.sleep(0.1) - mixing_task = task_manager.get_task(EXPERIMENT_ID, "mixing") - mixing_result = task_manager.get_task_output(EXPERIMENT_ID, "mixing") + mixing_task = await task_manager.get_task(EXPERIMENT_ID, "mixing") + mixing_result = await task_manager.get_task_output(EXPERIMENT_ID, "mixing") - evaporation_task = task_manager.get_task(EXPERIMENT_ID, "evaporation") + evaporation_task = await task_manager.get_task(EXPERIMENT_ID, "evaporation") # Check the dynamic parameter for input mixing time assert mixing_task.input.parameters["time"] == DYNAMIC_PARAMETERS["mixing"]["time"] @@ -110,11 +111,12 @@ async def test_resolve_input_parameter_references_and_dynamic_parameters(self, e ExperimentStatus.RUNNING, ], ) - def test_handle_existing_experiment(self, experiment_executor, experiment_manager, experiment_status): - experiment_manager.create_experiment( + @pytest.mark.asyncio + async def test_handle_existing_experiment(self, experiment_executor, experiment_manager, experiment_status): + await experiment_manager.create_experiment( EXPERIMENT_ID, EXPERIMENT_TYPE, experiment_executor._execution_parameters, {}, {} ) - experiment_manager._set_experiment_status(EXPERIMENT_ID, experiment_status) + await experiment_manager._set_experiment_status(EXPERIMENT_ID, experiment_status) experiment_executor._execution_parameters.resume = False with patch.object(experiment_executor, "_resume_experiment") as mock_resume: @@ -125,16 +127,19 @@ def test_handle_existing_experiment(self, experiment_executor, experiment_manage ExperimentStatus.FAILED, ]: with pytest.raises(EosExperimentExecutionError) as exc_info: - experiment_executor._handle_existing_experiment(experiment_manager.get_experiment(EXPERIMENT_ID)) + experiment = await experiment_manager.get_experiment(EXPERIMENT_ID) + await experiment_executor._handle_existing_experiment(experiment) assert experiment_status.name.lower() in str(exc_info.value) mock_resume.assert_not_called() else: - experiment_executor._handle_existing_experiment(experiment_manager.get_experiment(EXPERIMENT_ID)) + experiment = await experiment_manager.get_experiment(EXPERIMENT_ID) + await experiment_executor._handle_existing_experiment(experiment) mock_resume.assert_not_called() experiment_executor._execution_parameters.resume = True with patch.object(experiment_executor, "_resume_experiment") as mock_resume: - experiment_executor._handle_existing_experiment(experiment_manager.get_experiment(EXPERIMENT_ID)) + experiment = await experiment_manager.get_experiment(EXPERIMENT_ID) + await experiment_executor._handle_existing_experiment(experiment) mock_resume.assert_called_once() assert experiment_executor._experiment_status == experiment_status diff --git a/tests/test_experiment_manager.py b/tests/test_experiment_manager.py index 2f297c6..7301d94 100644 --- a/tests/test_experiment_manager.py +++ b/tests/test_experiment_manager.py @@ -7,87 +7,97 @@ @pytest.mark.parametrize("setup_lab_experiment", [("small_lab", EXPERIMENT_ID)], indirect=True) class TestExperimentManager: - def test_create_experiment(self, experiment_manager): - experiment_manager.create_experiment("test_experiment", EXPERIMENT_ID) - 
experiment_manager.create_experiment("test_experiment_2", EXPERIMENT_ID) - - assert experiment_manager.get_experiment("test_experiment").id == "test_experiment" - assert experiment_manager.get_experiment("test_experiment_2").id == "test_experiment_2" - - def test_create_experiment_nonexistent_type(self, experiment_manager): + @pytest.mark.asyncio + async def test_create_experiment(self, experiment_manager): + await experiment_manager.create_experiment("test_experiment", EXPERIMENT_ID) + await experiment_manager.create_experiment("test_experiment_2", EXPERIMENT_ID) + + experiment1 = await experiment_manager.get_experiment("test_experiment") + assert experiment1.id == "test_experiment" + experiment2 = await experiment_manager.get_experiment("test_experiment_2") + assert experiment2.id == "test_experiment_2" + + @pytest.mark.asyncio + async def test_create_experiment_nonexistent_type(self, experiment_manager): with pytest.raises(EosExperimentStateError): - experiment_manager.create_experiment("test_experiment", "nonexistent_type") + await experiment_manager.create_experiment("test_experiment", "nonexistent_type") - def test_create_existing_experiment(self, experiment_manager): - experiment_manager.create_experiment("test_experiment", EXPERIMENT_ID) + @pytest.mark.asyncio + async def test_create_existing_experiment(self, experiment_manager): + await experiment_manager.create_experiment("test_experiment", EXPERIMENT_ID) with pytest.raises(EosExperimentStateError): - experiment_manager.create_experiment("test_experiment", EXPERIMENT_ID) + await experiment_manager.create_experiment("test_experiment", EXPERIMENT_ID) - def test_delete_experiment(self, experiment_manager): - experiment_manager.create_experiment("test_experiment", EXPERIMENT_ID) + @pytest.mark.asyncio + async def test_delete_experiment(self, experiment_manager): + await experiment_manager.create_experiment("test_experiment", EXPERIMENT_ID) - assert experiment_manager.get_experiment("test_experiment").id == "test_experiment" + experiment = await experiment_manager.get_experiment("test_experiment") + assert experiment.id == "test_experiment" - experiment_manager.delete_experiment("test_experiment") + await experiment_manager.delete_experiment("test_experiment") - assert experiment_manager.get_experiment("test_experiment") is None + experiment = await experiment_manager.get_experiment("test_experiment") + assert experiment is None - def test_delete_nonexisting_experiment(self, experiment_manager): + @pytest.mark.asyncio + async def test_delete_nonexisting_experiment(self, experiment_manager): with pytest.raises(EosExperimentStateError): - experiment_manager.delete_experiment("non_existing_experiment") + await experiment_manager.delete_experiment("non_existing_experiment") - def test_get_experiments_by_status(self, experiment_manager): - experiment_manager.create_experiment("test_experiment", EXPERIMENT_ID) - experiment_manager.create_experiment("test_experiment_2", EXPERIMENT_ID) - experiment_manager.create_experiment("test_experiment_3", EXPERIMENT_ID) + @pytest.mark.asyncio + async def test_get_experiments_by_status(self, experiment_manager): + await experiment_manager.create_experiment("test_experiment", EXPERIMENT_ID) + await experiment_manager.create_experiment("test_experiment_2", EXPERIMENT_ID) + await experiment_manager.create_experiment("test_experiment_3", EXPERIMENT_ID) - experiment_manager.start_experiment("test_experiment") - experiment_manager.start_experiment("test_experiment_2") - 
experiment_manager.complete_experiment("test_experiment_3") + await experiment_manager.start_experiment("test_experiment") + await experiment_manager.start_experiment("test_experiment_2") + await experiment_manager.complete_experiment("test_experiment_3") - running_experiments = experiment_manager.get_experiments( + running_experiments = await experiment_manager.get_experiments( status=ExperimentStatus.RUNNING.value ) - completed_experiments = experiment_manager.get_experiments( + completed_experiments = await experiment_manager.get_experiments( status=ExperimentStatus.COMPLETED.value ) assert running_experiments == [ - experiment_manager.get_experiment("test_experiment"), - experiment_manager.get_experiment("test_experiment_2"), + await experiment_manager.get_experiment("test_experiment"), + await experiment_manager.get_experiment("test_experiment_2"), ] - assert completed_experiments == [experiment_manager.get_experiment("test_experiment_3")] + assert completed_experiments == [await experiment_manager.get_experiment("test_experiment_3")] - def test_set_experiment_status(self, experiment_manager): - experiment_manager.create_experiment("test_experiment", EXPERIMENT_ID) - assert ( - experiment_manager.get_experiment("test_experiment").status == ExperimentStatus.CREATED - ) + @pytest.mark.asyncio + async def test_set_experiment_status(self, experiment_manager): + await experiment_manager.create_experiment("test_experiment", EXPERIMENT_ID) + experiment = await experiment_manager.get_experiment("test_experiment") + assert experiment.status == ExperimentStatus.CREATED - experiment_manager.start_experiment("test_experiment") - assert ( - experiment_manager.get_experiment("test_experiment").status == ExperimentStatus.RUNNING - ) + await experiment_manager.start_experiment("test_experiment") + experiment = await experiment_manager.get_experiment("test_experiment") + assert experiment.status == ExperimentStatus.RUNNING - experiment_manager.complete_experiment("test_experiment") - assert ( - experiment_manager.get_experiment("test_experiment").status - == ExperimentStatus.COMPLETED - ) + await experiment_manager.complete_experiment("test_experiment") + experiment = await experiment_manager.get_experiment("test_experiment") + assert experiment.status == ExperimentStatus.COMPLETED - def test_set_experiment_status_nonexistent_experiment(self, experiment_manager): + @pytest.mark.asyncio + async def test_set_experiment_status_nonexistent_experiment(self, experiment_manager): with pytest.raises(EosExperimentStateError): - experiment_manager.start_experiment("nonexistent_experiment") - - def test_get_all_experiments(self, experiment_manager): - experiment_manager.create_experiment("test_experiment", EXPERIMENT_ID) - experiment_manager.create_experiment("test_experiment_2", EXPERIMENT_ID) - experiment_manager.create_experiment("test_experiment_3", EXPERIMENT_ID) - - assert experiment_manager.get_experiments() == [ - experiment_manager.get_experiment("test_experiment"), - experiment_manager.get_experiment("test_experiment_2"), - experiment_manager.get_experiment("test_experiment_3"), + await experiment_manager.start_experiment("nonexistent_experiment") + + @pytest.mark.asyncio + async def test_get_all_experiments(self, experiment_manager): + await experiment_manager.create_experiment("test_experiment", EXPERIMENT_ID) + await experiment_manager.create_experiment("test_experiment_2", EXPERIMENT_ID) + await experiment_manager.create_experiment("test_experiment_3", EXPERIMENT_ID) + + experiments = await 
experiment_manager.get_experiments() + assert experiments == [ + await experiment_manager.get_experiment("test_experiment"), + await experiment_manager.get_experiment("test_experiment_2"), + await experiment_manager.get_experiment("test_experiment_3"), ] diff --git a/tests/test_greedy_scheduler.py b/tests/test_greedy_scheduler.py index 050be64..15ff627 100644 --- a/tests/test_greedy_scheduler.py +++ b/tests/test_greedy_scheduler.py @@ -17,17 +17,18 @@ def test_register_experiment(self, greedy_scheduler, experiment_graph, configura experiment_graph, ) - def test_unregister_experiment(self, greedy_scheduler, experiment_graph): + @pytest.mark.asyncio + async def test_unregister_experiment(self, greedy_scheduler, experiment_graph): greedy_scheduler.register_experiment("experiment_1", "abstract_experiment", experiment_graph) - greedy_scheduler.unregister_experiment("experiment_1") + await greedy_scheduler.unregister_experiment("experiment_1") assert "experiment_1" not in greedy_scheduler._registered_experiments @pytest.mark.asyncio async def test_correct_schedule(self, greedy_scheduler, experiment_graph, experiment_manager, task_manager): - def complete_task(task_id, task_type): - task_manager.create_task("experiment_1", task_id, task_type, []) - task_manager.start_task("experiment_1", task_id) - task_manager.complete_task("experiment_1", task_id) + async def complete_task(task_id, task_type): + await task_manager.create_task("experiment_1", task_id, task_type, []) + await task_manager.start_task("experiment_1", task_id) + await task_manager.complete_task("experiment_1", task_id) def get_task_if_exists(tasks, task_id): return next((task for task in tasks if task.id == task_id), None) @@ -37,36 +38,36 @@ def assert_task(task, task_id, device_lab_id, device_id): assert task.devices[0].lab_id == device_lab_id assert task.devices[0].id == device_id - def process_and_assert(tasks, expected_tasks): + async def process_and_assert(tasks, expected_tasks): assert len(tasks) == len(expected_tasks) for task_id, device_lab_id, device_id in expected_tasks: task = get_task_if_exists(tasks, task_id) assert_task(task, task_id, device_lab_id, device_id) - complete_task(task_id, "Noop") + await complete_task(task_id, "Noop") - experiment_manager.create_experiment("experiment_1", "abstract_experiment") - experiment_manager.start_experiment("experiment_1") + await experiment_manager.create_experiment("experiment_1", "abstract_experiment") + await experiment_manager.start_experiment("experiment_1") greedy_scheduler.register_experiment("experiment_1", "abstract_experiment", experiment_graph) tasks = await greedy_scheduler.request_tasks("experiment_1") - process_and_assert(tasks, [("A", "abstract_lab", "D2")]) + await process_and_assert(tasks, [("A", "abstract_lab", "D2")]) tasks = await greedy_scheduler.request_tasks("experiment_1") - process_and_assert(tasks, [("B", "abstract_lab", "D1"), ("C", "abstract_lab", "D3")]) + await process_and_assert(tasks, [("B", "abstract_lab", "D1"), ("C", "abstract_lab", "D3")]) tasks = await greedy_scheduler.request_tasks("experiment_1") - process_and_assert( + await process_and_assert( tasks, [("D", "abstract_lab", "D1"), ("E", "abstract_lab", "D3"), ("F", "abstract_lab", "D2")], ) tasks = await greedy_scheduler.request_tasks("experiment_1") - process_and_assert(tasks, [("G", "abstract_lab", "D5")]) + await process_and_assert(tasks, [("G", "abstract_lab", "D5")]) tasks = await greedy_scheduler.request_tasks("experiment_1") - process_and_assert(tasks, [("H", "abstract_lab", "D6")]) + 
await process_and_assert(tasks, [("H", "abstract_lab", "D6")]) - assert greedy_scheduler.is_experiment_completed("experiment_1") + assert await greedy_scheduler.is_experiment_completed("experiment_1") tasks = await greedy_scheduler.request_tasks("experiment_1") assert len(tasks) == 0 diff --git a/tests/test_mongodb_async_repository.py b/tests/test_mongodb_async_repository.py new file mode 100644 index 0000000..1bd0593 --- /dev/null +++ b/tests/test_mongodb_async_repository.py @@ -0,0 +1,137 @@ +import asyncio + +import pytest +from pymongo.errors import DuplicateKeyError + +from eos.persistence.mongodb_async_repository import MongoDbAsyncRepository +from tests.fixtures import db_interface + + +class TestMongoDbAsyncRepository: + @pytest.fixture(scope="class") + def repository(self, db_interface): + return MongoDbAsyncRepository("test_collection", db_interface) + + @pytest.mark.asyncio + async def test_create_and_get_one(self, repository): + entity = {"name": "Test Entity", "value": 42} + result = await repository.create(entity) + assert result.acknowledged + + retrieved = await repository.get_one(name="Test Entity") + assert retrieved["name"] == "Test Entity" + assert retrieved["value"] == 42 + + @pytest.mark.asyncio + async def test_update_one(self, repository): + entity = {"name": "Update Test", "value": 10} + await repository.create(entity) + + updated_entity = {"value": 20} + result = await repository.update_one(updated_entity, name="Update Test") + assert result.modified_count == 1 + + retrieved = await repository.get_one(name="Update Test") + assert retrieved["value"] == 20 + + @pytest.mark.asyncio + async def test_delete_one(self, repository): + entity = {"name": "Delete Test", "value": 30} + await repository.create(entity) + + result = await repository.delete_one(name="Delete Test") + assert result.deleted_count == 1 + + retrieved = await repository.get_one(name="Delete Test") + assert retrieved is None + + @pytest.mark.asyncio + async def test_get_all(self, repository): + entities = [ + {"name": "Entity 1", "value": 1}, + {"name": "Entity 2", "value": 2}, + {"name": "Entity 3", "value": 3}, + ] + await asyncio.gather(*[repository.create(entity) for entity in entities]) + + retrieved = await repository.get_all() + assert len(retrieved) >= 3 + assert all(entity["name"] in [e["name"] for e in retrieved] for entity in entities) + + @pytest.mark.asyncio + async def test_count_and_exists(self, repository): + await repository.delete_all() + + entities = [ + {"name": "Count 1", "value": 1}, + {"name": "Count 2", "value": 2}, + ] + await asyncio.gather(*[repository.create(entity) for entity in entities]) + + count, exists, exists_more = await asyncio.gather( + repository.count(), repository.exists(count=2), repository.exists(count=3) + ) + + assert count == 2 + assert exists + assert not exists_more + + @pytest.mark.asyncio + async def test_delete_many(self, repository): + await repository.delete_all() + + entities = [ + {"name": "Delete Many 1", "value": 1}, + {"name": "Delete Many 2", "value": 2}, + {"name": "Keep", "value": 3}, + ] + await asyncio.gather(*[repository.create(entity) for entity in entities]) + + result = await repository.delete_many(name={"$regex": "Delete Many"}) + assert result.deleted_count == 2 + + remaining = await repository.get_all() + assert len(remaining) == 1 + assert remaining[0]["name"] == "Keep" + + @pytest.mark.asyncio + async def test_create_indices(self, repository): + indices = [("name", 1), ("value", -1)] + await repository.create_indices(indices, 
unique=True) + await repository.delete_all() + + # Verify that the index was created + info = await repository._collection.index_information() + assert "name_1_value_-1" in info + + # Test uniqueness constraint + await repository.create({"name": "Unique Test", "value": 300}) + with pytest.raises(DuplicateKeyError): + await repository.create({"name": "Unique Test", "value": 300}) + + @pytest.mark.asyncio + async def test_transaction_commit(self, repository, db_interface): + await repository.delete_all() + + async with db_interface.session_factory() as session: + entity = {"name": "Transaction Test", "value": 100} + await repository.create(entity, session=session) + await session.commit_transaction() + + retrieved = await repository.get_one(name="Transaction Test") + assert retrieved is not None + retrieved.pop("_id") + assert retrieved == {"name": "Transaction Test", "value": 100} + + @pytest.mark.asyncio + async def test_transaction_abort(self, repository, db_interface): + try: + async with db_interface.session_factory() as session: + entity = {"name": "Abort Test", "value": 200} + await repository.create(entity, session=session) + raise Exception("Simulated error") + except Exception: + pass # The transaction will be automatically aborted + + retrieved = await repository.get_one(name="Abort Test") + assert retrieved is None diff --git a/tests/test_on_demand_task_executor.py b/tests/test_on_demand_task_executor.py new file mode 100644 index 0000000..d91f014 --- /dev/null +++ b/tests/test_on_demand_task_executor.py @@ -0,0 +1,88 @@ +import asyncio + +from eos.configuration.entities.task import TaskConfig +from eos.tasks.entities.task import TaskStatus +from tests.fixtures import * + + +@pytest.mark.parametrize( + "setup_lab_experiment", + [("small_lab", "water_purification")], + indirect=True, +) +class TestOnDemandTaskExecutor: + @pytest.mark.asyncio + async def test_execute_on_demand_task(self, on_demand_task_executor, task_manager): + task_config = TaskConfig( + id="mixing", + type="Magnetic Mixing", + description="Mixing task", + parameters={"time": 5}, + ) + + on_demand_task_executor.submit_task(task_config) + await on_demand_task_executor.process_tasks() + + while True: + await on_demand_task_executor.process_tasks() + task = await task_manager.get_task("on_demand", "mixing") + if task and task.status == TaskStatus.COMPLETED: + break + await asyncio.sleep(0.5) + + assert task.status == TaskStatus.COMPLETED + assert task.output.parameters["mixing_time"] == 5 + + @pytest.mark.asyncio + async def test_on_demand_task_output(self, on_demand_task_executor, task_manager): + task_config = TaskConfig( + id="file_gen", + type="File Generation", + description="File generation task", + parameters={"content_length": 32}, + ) + + on_demand_task_executor.submit_task(task_config) + await on_demand_task_executor.process_tasks() + + while True: + await on_demand_task_executor.process_tasks() + task = await task_manager.get_task("on_demand", "file_gen") + if task and task.status == TaskStatus.COMPLETED: + break + await asyncio.sleep(0.5) + + assert task.status == TaskStatus.COMPLETED + file = task_manager.get_task_output_file("on_demand", "file_gen", "file.txt") + + assert len(file) == 32 + + @pytest.mark.asyncio + async def test_request_task_cancellation(self, on_demand_task_executor, task_manager): + task_config = TaskConfig( + id="sleep", + type="Sleep", + description="Sleeping task", + parameters={"time": 20}, + ) + + on_demand_task_executor.submit_task(task_config) + await
on_demand_task_executor.process_tasks() + + iterations = 0 + while True: + await on_demand_task_executor.process_tasks() + task = await task_manager.get_task("on_demand", "sleep") + if task and task.status != TaskStatus.RUNNING: + break + await asyncio.sleep(0.5) + iterations += 1 + + if iterations > 5: + await on_demand_task_executor.request_task_cancellation("sleep") + + if iterations > 20: + raise Exception("Task did not cancel in time") + + task = await task_manager.get_task("on_demand", "sleep") + assert task.status == TaskStatus.CANCELLED diff --git a/tests/test_resource_allocation_manager.py b/tests/test_resource_allocation_manager.py index cc4a62a..d5fbc3d 100644 --- a/tests/test_resource_allocation_manager.py +++ b/tests/test_resource_allocation_manager.py @@ -14,7 +14,8 @@ @pytest.mark.parametrize("setup_lab_experiment", [(LAB_ID, "water_purification")], indirect=True) class TestResourceAllocationManager: - def test_request_resources(self, resource_allocation_manager): + @pytest.mark.asyncio + async def test_request_resources(self, resource_allocation_manager): request = ResourceAllocationRequest( requester="test_requester", reason="Needed for experiment", @@ -29,14 +30,15 @@ def callback(active_request: ActiveResourceAllocationRequest): assert any(r.id == "magnetic_mixer" for r in active_request.request.resources) assert any(r.id == "026749f8f40342b38157f9824ae2f512" for r in active_request.request.resources) - active_request = resource_allocation_manager.request_resources(request, callback) + active_request = await resource_allocation_manager.request_resources(request, callback) assert active_request.request == request assert active_request.status == ResourceRequestAllocationStatus.PENDING - resource_allocation_manager.process_active_requests() + await resource_allocation_manager.process_active_requests() - def test_request_resources_priority(self, resource_allocation_manager): + @pytest.mark.asyncio + async def test_request_resources_priority(self, resource_allocation_manager): requests = [ ResourceAllocationRequest( requester=f"test_requester{i}", @@ -49,41 +51,42 @@ def test_request_resources_priority(self, resource_allocation_manager): for request in requests: request.add_resource("magnetic_mixer", LAB_ID, ResourceType.DEVICE) - active_requests = [resource_allocation_manager.request_resources(req, lambda x: None) for req in requests] - resource_allocation_manager.process_active_requests() + active_requests = [await resource_allocation_manager.request_resources(req, lambda x: None) for req in requests] + await resource_allocation_manager.process_active_requests() # Ensure that requests[0] is allocated and the rest are pending - active_request_3 = resource_allocation_manager.get_active_request(active_requests[2].id) + active_request_3 = await resource_allocation_manager.get_active_request(active_requests[2].id) assert active_request_3.status == ResourceRequestAllocationStatus.PENDING assert active_request_3.request.requester == "test_requester3" assert active_request_3.request.priority == 103 - active_request_2 = resource_allocation_manager.get_active_request(active_requests[1].id) + active_request_2 = await resource_allocation_manager.get_active_request(active_requests[1].id) assert active_request_2.status == ResourceRequestAllocationStatus.PENDING assert active_request_2.request.requester == "test_requester2" assert active_request_2.request.priority == 102 - active_request_1 = resource_allocation_manager.get_active_request(active_requests[0].id) + active_request_1 = await 
resource_allocation_manager.get_active_request(active_requests[0].id) assert active_request_1.status == ResourceRequestAllocationStatus.ALLOCATED assert active_request_1.request.requester == "test_requester1" assert active_request_1.request.priority == 101 - resource_allocation_manager.release_resources(active_request_1) + await resource_allocation_manager.release_resources(active_request_1) - resource_allocation_manager.process_active_requests() + await resource_allocation_manager.process_active_requests() # Ensure that requests[1] is now allocated and requests[2] is still pending - active_request_3 = resource_allocation_manager.get_active_request(active_requests[2].id) + active_request_3 = await resource_allocation_manager.get_active_request(active_requests[2].id) assert active_request_3.status == ResourceRequestAllocationStatus.PENDING assert active_request_3.request.requester == "test_requester3" assert active_request_3.request.priority == 103 - active_request_2 = resource_allocation_manager.get_active_request(active_requests[1].id) + active_request_2 = await resource_allocation_manager.get_active_request(active_requests[1].id) assert active_request_2.status == ResourceRequestAllocationStatus.ALLOCATED assert active_request_2.request.requester == "test_requester2" assert active_request_2.request.priority == 102 - def test_release_resources(self, resource_allocation_manager): + @pytest.mark.asyncio + async def test_release_resources(self, resource_allocation_manager): request = ResourceAllocationRequest( requester="test_requester", reason="Needed for experiment", @@ -93,18 +96,17 @@ def test_release_resources(self, resource_allocation_manager): request.add_resource("magnetic_mixer", LAB_ID, ResourceType.DEVICE) request.add_resource("026749f8f40342b38157f9824ae2f512", "", ResourceType.CONTAINER) - active_request = resource_allocation_manager.request_resources(request, lambda x: None) + active_request = await resource_allocation_manager.request_resources(request, lambda x: None) - resource_allocation_manager.process_active_requests() + await resource_allocation_manager.process_active_requests() - resource_allocation_manager.release_resources(active_request) + await resource_allocation_manager.release_resources(active_request) - assert ( - resource_allocation_manager.get_active_request(active_request.id).status - == ResourceRequestAllocationStatus.COMPLETED - ) + active_request = await resource_allocation_manager.get_active_request(active_request.id) + assert active_request.status == ResourceRequestAllocationStatus.COMPLETED - def test_process_active_requests(self, resource_allocation_manager): + @pytest.mark.asyncio + async def test_process_active_requests(self, resource_allocation_manager): requests = [ ResourceAllocationRequest( requester=f"test_requester{i}", @@ -116,20 +118,18 @@ def test_process_active_requests(self, resource_allocation_manager): for request in requests: request.add_resource("magnetic_mixer", LAB_ID, ResourceType.DEVICE) - active_requests = [resource_allocation_manager.request_resources(req, lambda x: None) for req in requests] + active_requests = [await resource_allocation_manager.request_resources(req, lambda x: None) for req in requests] - resource_allocation_manager.process_active_requests() + await resource_allocation_manager.process_active_requests() - assert ( - resource_allocation_manager.get_active_request(active_requests[0].id).status - == ResourceRequestAllocationStatus.ALLOCATED - ) - assert ( - 
resource_allocation_manager.get_active_request(active_requests[1].id).status - == ResourceRequestAllocationStatus.PENDING - ) + active_request = await resource_allocation_manager.get_active_request(active_requests[0].id) + assert active_request.status == ResourceRequestAllocationStatus.ALLOCATED - def test_abort_active_request(self, resource_allocation_manager): + active_request = await resource_allocation_manager.get_active_request(active_requests[1].id) + assert active_request.status == ResourceRequestAllocationStatus.PENDING + + @pytest.mark.asyncio + async def test_abort_active_request(self, resource_allocation_manager): request = ResourceAllocationRequest( requester="test_requester", reason="Needed for experiment", @@ -138,18 +138,18 @@ def test_abort_active_request(self, resource_allocation_manager): request.add_resource("magnetic_mixer", LAB_ID, ResourceType.DEVICE) request.add_resource("magnetic_mixer_2", LAB_ID, ResourceType.DEVICE) - active_request = resource_allocation_manager.request_resources(request, lambda x: None) + active_request = await resource_allocation_manager.request_resources(request, lambda x: None) - resource_allocation_manager.abort_active_request(active_request.id) + await resource_allocation_manager.abort_active_request(active_request.id) - assert resource_allocation_manager.get_active_request(active_request.id).status == ( - ResourceRequestAllocationStatus.ABORTED - ) + active_request = await resource_allocation_manager.get_active_request(active_request.id) + assert active_request.status == ResourceRequestAllocationStatus.ABORTED - assert not resource_allocation_manager._device_allocation_manager.is_allocated(LAB_ID, "magnetic_mixer") - assert not resource_allocation_manager._device_allocation_manager.is_allocated(LAB_ID, "magnetic_mixer_2") + assert not await resource_allocation_manager._device_allocator.is_allocated(LAB_ID, "magnetic_mixer") + assert not await resource_allocation_manager._device_allocator.is_allocated(LAB_ID, "magnetic_mixer_2") - def test_get_all_active_requests(self, resource_allocation_manager): + @pytest.mark.asyncio + async def test_get_all_active_requests(self, resource_allocation_manager): requests = [ ResourceAllocationRequest( requester=f"test_requester{i}", @@ -162,18 +162,20 @@ def test_get_all_active_requests(self, resource_allocation_manager): requests[1].add_resource("026749f8f40342b38157f9824ae2f512", "", ResourceType.CONTAINER) for request in requests: - resource_allocation_manager.request_resources(request, lambda x: None) + await resource_allocation_manager.request_resources(request, lambda x: None) - all_active_requests = resource_allocation_manager.get_all_active_requests() + all_active_requests = await resource_allocation_manager.get_all_active_requests() assert len(all_active_requests) == 2 assert all_active_requests[0].request == requests[0] assert all_active_requests[1].request == requests[1] - def test_get_active_request_nonexistent(self, resource_allocation_manager): + @pytest.mark.asyncio + async def test_get_active_request_nonexistent(self, resource_allocation_manager): nonexistent_id = ObjectId() - assert resource_allocation_manager.get_active_request(nonexistent_id) is None + assert await resource_allocation_manager.get_active_request(nonexistent_id) is None - def test_clean_requests(self, resource_allocation_manager): + @pytest.mark.asyncio + async def test_clean_requests(self, resource_allocation_manager): request = ResourceAllocationRequest( requester="test_requester", reason="Needed for experiment", @@ 
-181,20 +183,19 @@ def test_clean_requests(self, resource_allocation_manager): ) request.add_resource("magnetic_mixer", LAB_ID, ResourceType.DEVICE) - active_request = resource_allocation_manager.request_resources(request, lambda x: None) - resource_allocation_manager.process_active_requests() - resource_allocation_manager.release_resources(active_request) + active_request = await resource_allocation_manager.request_resources(request, lambda x: None) + await resource_allocation_manager.process_active_requests() + await resource_allocation_manager.release_resources(active_request) - assert ( - resource_allocation_manager.get_active_request(active_request.id).status - == ResourceRequestAllocationStatus.COMPLETED - ) + active_request = await resource_allocation_manager.get_active_request(active_request.id) + assert active_request.status == ResourceRequestAllocationStatus.COMPLETED - resource_allocation_manager._clean_completed_and_aborted_requests() + await resource_allocation_manager._clean_completed_and_aborted_requests() - assert len(resource_allocation_manager.get_all_active_requests()) == 0 + assert len(await resource_allocation_manager.get_all_active_requests()) == 0 - def test_all_or_nothing_allocation(self, resource_allocation_manager): + @pytest.mark.asyncio + async def test_all_or_nothing_allocation(self, resource_allocation_manager): request = ResourceAllocationRequest( requester="test_requester", reason="Needed for experiment", @@ -204,13 +205,13 @@ def test_all_or_nothing_allocation(self, resource_allocation_manager): request.add_resource("nonexistent_device", LAB_ID, ResourceType.DEVICE) with pytest.raises(EosDeviceNotFoundError): - active_request = resource_allocation_manager.request_resources(request, lambda x: None) - resource_allocation_manager.process_active_requests() + active_request = await resource_allocation_manager.request_resources(request, lambda x: None) + await resource_allocation_manager.process_active_requests() assert active_request.status == ResourceRequestAllocationStatus.PENDING # Verify that neither resource was allocated - assert not resource_allocation_manager._device_allocation_manager.is_allocated(LAB_ID, "magnetic_mixer") + assert not await resource_allocation_manager._device_allocator.is_allocated(LAB_ID, "magnetic_mixer") with pytest.raises(EosDeviceNotFoundError): - assert not resource_allocation_manager._device_allocation_manager.is_allocated(LAB_ID, "nonexistent_device") + assert not await resource_allocation_manager._device_allocator.is_allocated(LAB_ID, "nonexistent_device") diff --git a/tests/test_task_executor.py b/tests/test_task_executor.py index d034fe2..c7ef03e 100644 --- a/tests/test_task_executor.py +++ b/tests/test_task_executor.py @@ -23,7 +23,7 @@ async def test_request_task_execution( experiment_manager, experiment_graph, ): - experiment_manager.create_experiment("water_purification", "water_purification") + await experiment_manager.create_experiment("water_purification", "water_purification") task_config = experiment_graph.get_task_config("mixing") task_config.parameters["time"] = 5 @@ -56,10 +56,10 @@ async def test_request_task_execution_resource_request_timeout( requester="tester", ) request.add_resource("magnetic_mixer", "small_lab", ResourceType.DEVICE) - active_request = resource_allocation_manager.request_resources(request, lambda requests: None) - resource_allocation_manager.process_active_requests() + active_request = await resource_allocation_manager.request_resources(request, lambda requests: None) + await 
resource_allocation_manager.process_active_requests() - experiment_manager.create_experiment("water_purification", "water_purification") + await experiment_manager.create_experiment("water_purification", "water_purification") task_config = experiment_graph.get_task_config("mixing") task_config.parameters["time"] = 5 @@ -72,29 +72,38 @@ async def test_request_task_execution_resource_request_timeout( with pytest.raises(EosTaskResourceAllocationError): await task_executor.request_task_execution(task_parameters) - resource_allocation_manager.release_resources(active_request) + await resource_allocation_manager.release_resources(active_request) @pytest.mark.asyncio async def test_request_task_cancellation(self, task_executor, experiment_manager): - experiment_manager.create_experiment("water_purification", "water_purification") + await experiment_manager.create_experiment("water_purification", "water_purification") sleep_config = TaskConfig( id="sleep_task", type="Sleep", devices=[TaskDeviceConfig(lab_id="small_lab", id="general_computer")], - parameters={"sleep_time": 2}, + parameters={"time": 5}, # 5 seconds to ensure it's still running when we cancel; the Sleep task now reads "time" ) task_parameters = TaskExecutionParameters( experiment_id="water_purification", task_config=sleep_config, ) - tasks = set() - task = asyncio.create_task(task_executor.request_task_execution(task_parameters)) - tasks.add(task) - await asyncio.sleep(1) + async def run_task(): + return await task_executor.request_task_execution(task_parameters) - await task_executor.request_task_cancellation(task_parameters.experiment_id, task_parameters.task_config.id) + async def cancel_task(): + await asyncio.sleep(2) # Wait for 2 seconds before cancelling + assert task_executor._active_tasks == {"water_purification": {"sleep_task": task_parameters}} + await task_executor.request_task_cancellation(task_parameters.experiment_id, task_parameters.task_config.id) - assert True + # Use asyncio.gather to run both coroutines concurrently + task_result, _ = await asyncio.gather( + run_task(), + cancel_task(), + return_exceptions=True # exceptions are returned as results, so one failure cannot mask the other + ) + + # The cancelled task should no longer be tracked as active + assert task_executor._active_tasks == {} diff --git a/tests/test_task_manager.py b/tests/test_task_manager.py index 2b74b44..6eb24c0 100644 --- a/tests/test_task_manager.py +++ b/tests/test_task_manager.py @@ -6,94 +6,111 @@ @pytest.fixture -def experiment_manager(configuration_manager, db_manager): - experiment_manager = ExperimentManager(configuration_manager, db_manager) - experiment_manager.create_experiment(EXPERIMENT_ID, "water_purification") +async def experiment_manager(configuration_manager, db_interface): + experiment_manager = ExperimentManager(configuration_manager, db_interface) + await experiment_manager.initialize(db_interface) + await experiment_manager.create_experiment(EXPERIMENT_ID, "water_purification") return experiment_manager @pytest.mark.parametrize("setup_lab_experiment", [("small_lab", "water_purification")], indirect=True) class TestTaskManager: - def test_create_task(self, task_manager, experiment_manager): - task_manager.create_task(EXPERIMENT_ID, "mixing", "Magnetic Mixing", []) + @pytest.mark.asyncio + async def test_create_task(self, task_manager, experiment_manager): + await task_manager.create_task(EXPERIMENT_ID, "mixing", "Magnetic Mixing", []) - task = task_manager.get_task(EXPERIMENT_ID, "mixing") + task = await task_manager.get_task(EXPERIMENT_ID, "mixing") assert task.id == "mixing" assert task.type == "Magnetic 
Mixing" - def test_create_task_nonexistent(self, task_manager, experiment_manager): + @pytest.mark.asyncio + async def test_create_task_nonexistent(self, task_manager, experiment_manager): with pytest.raises(EosTaskStateError): - task_manager.create_task(EXPERIMENT_ID, "nonexistent", "nonexistent", []) + await task_manager.create_task(EXPERIMENT_ID, "nonexistent", "nonexistent", []) - def test_create_task_nonexistent_task_type(self, task_manager, experiment_manager): + @pytest.mark.asyncio + async def test_create_task_nonexistent_task_type(self, task_manager, experiment_manager): with pytest.raises(EosTaskStateError): - task_manager.create_task(EXPERIMENT_ID, "nonexistent_task", "Nonexistent", []) + await task_manager.create_task(EXPERIMENT_ID, "nonexistent_task", "Nonexistent", []) - def test_create_existing_task(self, task_manager, experiment_manager): - task_manager.create_task(EXPERIMENT_ID, "mixing", "Magnetic Mixing", []) + @pytest.mark.asyncio + async def test_create_existing_task(self, task_manager, experiment_manager): + await task_manager.create_task(EXPERIMENT_ID, "mixing", "Magnetic Mixing", []) with pytest.raises(EosTaskExistsError): - task_manager.create_task(EXPERIMENT_ID, "mixing", "Magnetic Mixing", []) + await task_manager.create_task(EXPERIMENT_ID, "mixing", "Magnetic Mixing", []) - def test_delete_task(self, task_manager): - task_manager.create_task(EXPERIMENT_ID, "mixing", "Magnetic Mixing", []) + @pytest.mark.asyncio + async def test_delete_task(self, task_manager): + await task_manager.create_task(EXPERIMENT_ID, "mixing", "Magnetic Mixing", []) - task_manager.delete_task(EXPERIMENT_ID, "mixing") + await task_manager.delete_task(EXPERIMENT_ID, "mixing") - assert task_manager.get_task(EXPERIMENT_ID, "mixing") is None + assert await task_manager.get_task(EXPERIMENT_ID, "mixing") is None - def test_delete_nonexistent_task(self, task_manager, experiment_manager): + @pytest.mark.asyncio + async def test_delete_nonexistent_task(self, task_manager, experiment_manager): with pytest.raises(EosTaskStateError): - task_manager.delete_task(EXPERIMENT_ID, "nonexistent_task") + await task_manager.delete_task(EXPERIMENT_ID, "nonexistent_task") - def test_get_all_tasks_by_status(self, task_manager, experiment_manager): - task_manager.create_task(EXPERIMENT_ID, "mixing", "Magnetic Mixing", []) - task_manager.create_task(EXPERIMENT_ID, "purification", "Purification", []) + @pytest.mark.asyncio + async def test_get_all_tasks_by_status(self, task_manager, experiment_manager): + await task_manager.create_task(EXPERIMENT_ID, "mixing", "Magnetic Mixing", []) + await task_manager.create_task(EXPERIMENT_ID, "purification", "Purification", []) - task_manager.start_task(EXPERIMENT_ID, "mixing") - task_manager.complete_task(EXPERIMENT_ID, "purification") + await task_manager.start_task(EXPERIMENT_ID, "mixing") + await task_manager.complete_task(EXPERIMENT_ID, "purification") - assert len(task_manager.get_tasks(experiment_id=EXPERIMENT_ID, status=TaskStatus.RUNNING.value)) == 1 - assert len(task_manager.get_tasks(experiment_id=EXPERIMENT_ID, status=TaskStatus.COMPLETED.value)) == 1 + assert len(await task_manager.get_tasks(experiment_id=EXPERIMENT_ID, status=TaskStatus.RUNNING.value)) == 1 + assert len(await task_manager.get_tasks(experiment_id=EXPERIMENT_ID, status=TaskStatus.COMPLETED.value)) == 1 - def test_set_task_status(self, task_manager, experiment_manager): - task_manager.create_task(EXPERIMENT_ID, "mixing", "Magnetic Mixing", []) + @pytest.mark.asyncio + async def 
diff --git a/tests/user/testing/tasks/file_generation_task/task.py b/tests/user/testing/tasks/file_generation_task/task.py
new file mode 100644
index 0000000..8e7a945
--- /dev/null
+++ b/tests/user/testing/tasks/file_generation_task/task.py
@@ -0,0 +1,18 @@
+import random
+
+from eos.tasks.base_task import BaseTask
+
+
+class FileGenerationTask(BaseTask):
+    def _execute(
+        self,
+        devices: BaseTask.DevicesType,
+        parameters: BaseTask.ParametersType,
+        containers: BaseTask.ContainersType,
+    ) -> BaseTask.OutputType:
+        content_length = parameters["content_length"]
+
+        file_content = "".join(
+            random.choices("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789", k=content_length))
+
+        return None, None, {"file.txt": bytes(file_content, "utf-8")}
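The task returns a three-tuple whose last element maps output file names to raw bytes, which EOS then stores as task output files. A small sketch of just the content-generation logic and that output convention, independent of `BaseTask` (the helper name is hypothetical):

```python
import random
import string


def generate_file_output(content_length: int) -> dict[str, bytes]:
    # string.ascii_letters + string.digits is the same alphanumeric
    # alphabet the task above spells out literally.
    alphabet = string.ascii_letters + string.digits
    content = "".join(random.choices(alphabet, k=content_length))
    # Output-file convention: file name -> raw bytes.
    return {"file.txt": content.encode("utf-8")}


files = generate_file_output(10)
assert len(files["file.txt"]) == 10
```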
diff --git a/tests/user/testing/tasks/file_generation_task/task.yml b/tests/user/testing/tasks/file_generation_task/task.yml
new file mode 100644
index 0000000..6731387
--- /dev/null
+++ b/tests/user/testing/tasks/file_generation_task/task.yml
@@ -0,0 +1,13 @@
+type: File Generation
+description: Generates a file with random data.
+
+device_types:
+  - general_computer
+
+input_parameters:
+  content_length:
+    type: integer
+    unit: n/a
+    value: 10
+    min: 0
+    description: How many characters to generate in the file.
diff --git a/tests/user/testing/tasks/sleep/task.py b/tests/user/testing/tasks/sleep/task.py
index da50d47..190d35b 100644
--- a/tests/user/testing/tasks/sleep/task.py
+++ b/tests/user/testing/tasks/sleep/task.py
@@ -12,7 +12,7 @@ def _execute(
     ) -> BaseTask.OutputType:
         self.cancel_requested = False

-        sleep_time = parameters["sleep_time"]
+        sleep_time = parameters["time"]

         start_time = time.time()
         elapsed = 0
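Only the first lines of the sleep loop fall inside this hunk. Presumably the task sleeps in small increments and polls `self.cancel_requested` between them, which is what lets `test_request_task_cancellation` interrupt it mid-sleep. A hypothetical sketch of such a polling loop, an assumption since the loop body lies outside the hunk:

```python
import time
from typing import Callable


def sleep_with_cancellation(sleep_time: float, cancel_requested: Callable[[], bool]) -> bool:
    """Sleep up to sleep_time seconds; return False if cancelled early."""
    start_time = time.time()
    elapsed = 0.0
    while elapsed < sleep_time:
        if cancel_requested():  # in the task this would read self.cancel_requested
            return False
        time.sleep(0.1)  # short polling interval keeps cancellation prompt
        elapsed = time.time() - start_time
    return True
```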