diff --git a/src/dmap_import/pipeline.py b/src/dmap_import/pipeline.py index 991dbee..7827026 100644 --- a/src/dmap_import/pipeline.py +++ b/src/dmap_import/pipeline.py @@ -1,15 +1,17 @@ import os from typing import List, Optional -from dmap_import.util_rds import alembic_upgrade_to_head -from dmap_import.api_job_list import produce_job_list from dmap_import.api_copy_job import run_api_copy +from dmap_import.api_job_list import produce_job_list +from dmap_import.util_aws import running_in_aws, check_for_parallel_tasks from dmap_import.util_logging import ProcessLogger +from dmap_import.util_rds import alembic_upgrade_to_head def validate_environment( required_variables: List[str], private_variables: Optional[List[str]] = None, + aws_variables: Optional[List[str]] = None, validate_db: bool = False, ) -> None: """ @@ -25,6 +27,9 @@ def validate_environment( # every pipeline needs a service name for logging required_variables.append("SERVICE_NAME") + if aws_variables and running_in_aws(): + required_variables += aws_variables + # add required database variables if validate_db: required_variables += [ @@ -110,9 +115,12 @@ def main() -> None: "CONTROLLED_KEY", "PUBLIC_KEY", ], + aws_variables=["ECS_CLUSTER", "ECS_TASK_GROUP"], validate_db=True, ) + check_for_parallel_tasks() + start() diff --git a/src/dmap_import/util_aws.py b/src/dmap_import/util_aws.py new file mode 100644 index 0000000..91a5910 --- /dev/null +++ b/src/dmap_import/util_aws.py @@ -0,0 +1,49 @@ +import os +import boto3 + +from dmap_import.util_logging import ProcessLogger + + +def running_in_aws() -> bool: + """ + return True if running on aws, else False + """ + return bool(os.getenv("AWS_DEFAULT_REGION")) + + +def check_for_parallel_tasks() -> None: + """ + Check that that this task is not already running on ECS + """ + if not running_in_aws(): + return + + process_logger = ProcessLogger("check_for_tasks") + process_logger.log_start() + + client = boto3.client("ecs") + dmap_ecs_cluster = os.environ["ECS_CLUSTER"] + dmap_ecs_task_group = os.environ["ECS_TASK_GROUP"] + + # get all of the tasks running on the cluster + task_arns = client.list_tasks(cluster=dmap_ecs_cluster)["taskArns"] + + # if tasks are running on the cluster, get their descriptions and check to + # count matches the ecs task group. + match_count = 0 + if task_arns: + running_tasks = client.describe_tasks( + cluster=dmap_ecs_cluster, tasks=task_arns + )["tasks"] + + for task in running_tasks: + if dmap_ecs_task_group == task["group"]: + match_count += 1 + + # if the group matches, raise an exception that will terminate the process + if match_count > 1: + exception = Exception("Multiple Tasks Running") + process_logger.log_failure(exception) + raise exception + + process_logger.log_complete() diff --git a/src/dmap_import/util_rds.py b/src/dmap_import/util_rds.py index 2ee1a59..859ea95 100644 --- a/src/dmap_import/util_rds.py +++ b/src/dmap_import/util_rds.py @@ -12,6 +12,7 @@ from alembic.config import Config from alembic import command +from dmap_import.util_aws import running_in_aws from dmap_import.util_logging import ProcessLogger @@ -27,13 +28,6 @@ def running_in_docker() -> bool: ) -def running_in_aws() -> bool: - """ - return True if running on aws, else False - """ - return bool(os.getenv("AWS_DEFAULT_REGION")) - - def get_db_host() -> str: """ get current db_host string