Skip to content

Commit

Permalink
FEAT: Check for Parallel ECS Tasks (#11)
Browse files Browse the repository at this point in the history
FEAT: Check for Parallel ECS Tasks

Although unlikely because of the configuration in our infrastructure,
the dmap import task could be run more than once in parallel. If they
were both writing to the RDS at the same time, chaotic behavior could
result.

To prevent this, check for parallel tasks before starting the main
process using a boto3 ecs client.
  • Loading branch information
mzappitello authored Nov 20, 2023
1 parent c6bf458 commit 80d3f2a
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 9 deletions.
12 changes: 10 additions & 2 deletions src/dmap_import/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import os
from typing import List, Optional

from dmap_import.util_rds import alembic_upgrade_to_head
from dmap_import.api_job_list import produce_job_list
from dmap_import.api_copy_job import run_api_copy
from dmap_import.api_job_list import produce_job_list
from dmap_import.util_aws import running_in_aws, check_for_parallel_tasks
from dmap_import.util_logging import ProcessLogger
from dmap_import.util_rds import alembic_upgrade_to_head


def validate_environment(
required_variables: List[str],
private_variables: Optional[List[str]] = None,
aws_variables: Optional[List[str]] = None,
validate_db: bool = False,
) -> None:
"""
Expand All @@ -25,6 +27,9 @@ def validate_environment(
# every pipeline needs a service name for logging
required_variables.append("SERVICE_NAME")

if aws_variables and running_in_aws():
required_variables += aws_variables

# add required database variables
if validate_db:
required_variables += [
Expand Down Expand Up @@ -110,9 +115,12 @@ def main() -> None:
"CONTROLLED_KEY",
"PUBLIC_KEY",
],
aws_variables=["ECS_CLUSTER", "ECS_TASK_GROUP"],
validate_db=True,
)

check_for_parallel_tasks()

start()


Expand Down
49 changes: 49 additions & 0 deletions src/dmap_import/util_aws.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import os
import boto3

from dmap_import.util_logging import ProcessLogger


def running_in_aws() -> bool:
    """
    Detect whether this process is executing inside AWS.

    AWS runtimes populate the AWS_DEFAULT_REGION environment variable;
    its presence (non-empty) is used as the signal. Returns True on AWS,
    False otherwise.
    """
    region = os.environ.get("AWS_DEFAULT_REGION")
    return region is not None and region != ""


def check_for_parallel_tasks() -> None:
    """
    Check that this task is not already running in parallel on ECS.

    Lists all tasks on the cluster named by the ECS_CLUSTER environment
    variable and counts those whose task group matches ECS_TASK_GROUP.
    Because this task itself appears in the listing, a count greater than
    one means another copy is running and an Exception is raised,
    terminating the process. Outside of AWS the check is a no-op.

    Raises:
        Exception: if more than one task in the group is running.
        KeyError: if ECS_CLUSTER or ECS_TASK_GROUP is unset while on AWS.
    """
    if not running_in_aws():
        return

    process_logger = ProcessLogger("check_for_tasks")
    process_logger.log_start()

    client = boto3.client("ecs")
    dmap_ecs_cluster = os.environ["ECS_CLUSTER"]
    dmap_ecs_task_group = os.environ["ECS_TASK_GROUP"]

    # count tasks on the cluster belonging to our task group; paginate so
    # clusters with more than 100 tasks are fully inspected (list_tasks
    # caps each response at 100 ARNs), and describe per page because
    # describe_tasks accepts at most 100 ARNs per call
    match_count = 0
    paginator = client.get_paginator("list_tasks")
    for page in paginator.paginate(cluster=dmap_ecs_cluster):
        task_arns = page["taskArns"]
        if not task_arns:
            continue
        running_tasks = client.describe_tasks(
            cluster=dmap_ecs_cluster, tasks=task_arns
        )["tasks"]
        for task in running_tasks:
            if dmap_ecs_task_group == task["group"]:
                match_count += 1

    # this task is one of the matches, so >1 means a parallel copy exists;
    # raise to terminate the process before touching the RDS
    if match_count > 1:
        exception = Exception("Multiple Tasks Running")
        process_logger.log_failure(exception)
        raise exception

    process_logger.log_complete()
8 changes: 1 addition & 7 deletions src/dmap_import/util_rds.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from alembic.config import Config
from alembic import command

from dmap_import.util_aws import running_in_aws
from dmap_import.util_logging import ProcessLogger


Expand All @@ -27,13 +28,6 @@ def running_in_docker() -> bool:
)


def running_in_aws() -> bool:
"""
return True if running on aws, else False
"""
return bool(os.getenv("AWS_DEFAULT_REGION"))


def get_db_host() -> str:
"""
get current db_host string
Expand Down

0 comments on commit 80d3f2a

Please sign in to comment.