From 80d3f2a3f9c4b9ba1c84299153ce61ab22a406d2 Mon Sep 17 00:00:00 2001 From: mike zappitello Date: Mon, 20 Nov 2023 12:15:55 -0500 Subject: [PATCH] FEAT: Check for Parallel ECS Tasks (#11) FEAT: Check for Parallel ECS Tasks Although unlikely because of the configuration in our infrastructure, the dmap import task could be run more than once in parallel. If they were both writing to the RDS at the same time, chaotic behavior could result. To prevent this, check for parallel tasks before starting the main process using a boto3 ecs client. --- src/dmap_import/pipeline.py | 12 +++++++-- src/dmap_import/util_aws.py | 49 +++++++++++++++++++++++++++++++++++++ src/dmap_import/util_rds.py | 8 +----- 3 files changed, 60 insertions(+), 9 deletions(-) create mode 100644 src/dmap_import/util_aws.py diff --git a/src/dmap_import/pipeline.py b/src/dmap_import/pipeline.py index 991dbee..7827026 100644 --- a/src/dmap_import/pipeline.py +++ b/src/dmap_import/pipeline.py @@ -1,15 +1,17 @@ import os from typing import List, Optional -from dmap_import.util_rds import alembic_upgrade_to_head -from dmap_import.api_job_list import produce_job_list from dmap_import.api_copy_job import run_api_copy +from dmap_import.api_job_list import produce_job_list +from dmap_import.util_aws import running_in_aws, check_for_parallel_tasks from dmap_import.util_logging import ProcessLogger +from dmap_import.util_rds import alembic_upgrade_to_head def validate_environment( required_variables: List[str], private_variables: Optional[List[str]] = None, + aws_variables: Optional[List[str]] = None, validate_db: bool = False, ) -> None: """ @@ -25,6 +27,9 @@ def validate_environment( # every pipeline needs a service name for logging required_variables.append("SERVICE_NAME") + if aws_variables and running_in_aws(): + required_variables += aws_variables + # add required database variables if validate_db: required_variables += [ @@ -110,9 +115,12 @@ def main() -> None: "CONTROLLED_KEY", "PUBLIC_KEY", ], 
aws_variables=["ECS_CLUSTER", "ECS_TASK_GROUP"], validate_db=True, ) + check_for_parallel_tasks() + start() diff --git a/src/dmap_import/util_aws.py b/src/dmap_import/util_aws.py new file mode 100644 index 0000000..91a5910 --- /dev/null +++ b/src/dmap_import/util_aws.py @@ -0,0 +1,49 @@ +import os +import boto3 + +from dmap_import.util_logging import ProcessLogger + + +def running_in_aws() -> bool: + """ + return True if running on aws, else False + """ + return bool(os.getenv("AWS_DEFAULT_REGION")) + + +def check_for_parallel_tasks() -> None: + """ + Check that this task is not already running on ECS + """ + if not running_in_aws(): + return + + process_logger = ProcessLogger("check_for_tasks") + process_logger.log_start() + + client = boto3.client("ecs") + dmap_ecs_cluster = os.environ["ECS_CLUSTER"] + dmap_ecs_task_group = os.environ["ECS_TASK_GROUP"] + + # get all of the tasks running on the cluster + task_arns = client.list_tasks(cluster=dmap_ecs_cluster)["taskArns"] + + # if tasks are running on the cluster, get their descriptions and count + # how many match the ecs task group. 
+ match_count = 0 + if task_arns: + running_tasks = client.describe_tasks( + cluster=dmap_ecs_cluster, tasks=task_arns + )["tasks"] + + for task in running_tasks: + if dmap_ecs_task_group == task["group"]: + match_count += 1 + + # if more than one task matches the group, raise to terminate the process + if match_count > 1: + exception = Exception("Multiple Tasks Running") + process_logger.log_failure(exception) + raise exception + + process_logger.log_complete() diff --git a/src/dmap_import/util_rds.py b/src/dmap_import/util_rds.py index 2ee1a59..859ea95 100644 --- a/src/dmap_import/util_rds.py +++ b/src/dmap_import/util_rds.py @@ -12,6 +12,7 @@ from alembic.config import Config from alembic import command +from dmap_import.util_aws import running_in_aws from dmap_import.util_logging import ProcessLogger @@ -27,13 +28,6 @@ def running_in_docker() -> bool: ) -def running_in_aws() -> bool: - """ - return True if running on aws, else False - """ - return bool(os.getenv("AWS_DEFAULT_REGION")) - - def get_db_host() -> str: """ get current db_host string