Skip to content

Commit

Permalink
chore: refactor setting step description, etc, to common code
Browse files Browse the repository at this point in the history
  • Loading branch information
bengerman13 committed Sep 21, 2024
1 parent 205f2aa commit 8edc129
Show file tree
Hide file tree
Showing 10 changed files with 159 additions and 398 deletions.
64 changes: 13 additions & 51 deletions broker/tasks/alb.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@

from sqlalchemy import and_, select, func, null
from sqlalchemy.orm import aliased
from sqlalchemy.orm.attributes import flag_modified


from broker.aws import alb
from broker.extensions import config, db
Expand All @@ -14,7 +12,7 @@
Certificate,
Operation,
)
from broker.tasks import huey
from broker.tasks.huey import pipeline_operation

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -111,16 +109,10 @@ def get_lowest_dedicated_alb(service_instance, db):
db.session.commit()


@huey.retriable_task
def select_dedicated_alb(operation_id, **kwargs):
operation = db.session.get(Operation, operation_id)
@pipeline_operation("Selecting load balancer")
def select_dedicated_alb(operation_id, *, operation, db, **kwargs):
service_instance = operation.service_instance

operation.step_description = "Selecting load balancer"
flag_modified(operation, "step_description")
db.session.add(operation)
db.session.commit()

if (
service_instance.alb_arn
and operation.action == Operation.Actions.PROVISION.value
Expand All @@ -131,16 +123,10 @@ def select_dedicated_alb(operation_id, **kwargs):
return get_lowest_dedicated_alb(service_instance, db)


@huey.retriable_task
def select_alb(operation_id, **kwargs):
operation = db.session.get(Operation, operation_id)
@pipeline_operation("Selecting load balancer")
def select_alb(operation_id, *, operation, db, **kwargs):
service_instance = operation.service_instance

operation.step_description = "Selecting load balancer"
flag_modified(operation, "step_description")
db.session.add(operation)
db.session.commit()

if (
service_instance.alb_arn
and operation.action == Operation.Actions.PROVISION.value
Expand All @@ -156,17 +142,11 @@ def select_alb(operation_id, **kwargs):
db.session.commit()


@huey.retriable_task
def add_certificate_to_alb(operation_id, **kwargs):
operation = db.session.get(Operation, operation_id)
@pipeline_operation("Adding SSL certificate to load balancer")
def add_certificate_to_alb(operation_id, *, operation, db, **kwargs):
service_instance = operation.service_instance
certificate = service_instance.new_certificate

operation.step_description = "Adding SSL certificate to load balancer"
flag_modified(operation, "step_description")
db.session.add(operation)
db.session.commit()

alb.add_listener_certificates(
ListenerArn=service_instance.alb_listener_arn,
Certificates=[{"CertificateArn": certificate.iam_server_certificate_arn}],
Expand All @@ -184,16 +164,10 @@ def add_certificate_to_alb(operation_id, **kwargs):
db.session.commit()


@huey.retriable_task
def remove_certificate_from_alb(operation_id, **kwargs):
operation = db.session.get(Operation, operation_id)
@pipeline_operation("Removing SSL certificate from load balancer")
def remove_certificate_from_alb(operation_id, *, operation, db, **kwargs):
service_instance = operation.service_instance

operation.step_description = "Removing SSL certificate from load balancer"
flag_modified(operation, "step_description")
db.session.add(operation)
db.session.commit()

if service_instance.alb_listener_arn is not None:
alb.remove_listener_certificates(
ListenerArn=service_instance.alb_listener_arn,
Expand All @@ -208,9 +182,8 @@ def remove_certificate_from_alb(operation_id, **kwargs):
time.sleep(config.IAM_CERTIFICATE_PROPAGATION_TIME)


@huey.retriable_task
def remove_certificate_from_previous_alb(operation_id, **kwargs):
operation = db.session.get(Operation, operation_id)
@pipeline_operation("Removing SSL certificate from load balancer")
def remove_certificate_from_previous_alb(operation_id, *, operation, db, **kwargs):
service_instance = operation.service_instance
remove_certificate = Certificate.query.filter(
and_(
Expand All @@ -219,11 +192,6 @@ def remove_certificate_from_previous_alb(operation_id, **kwargs):
)
).first()

operation.step_description = "Removing SSL certificate from load balancer"
flag_modified(operation, "step_description")
db.session.add(operation)
db.session.commit()

if service_instance.previous_alb_listener_arn is not None:
time.sleep(config.ALB_OVERLAP_SLEEP_TIME)
alb.remove_listener_certificates(
Expand Down Expand Up @@ -263,19 +231,13 @@ def remove_certificate_from_previous_alb(operation_id, **kwargs):
db.session.commit()


@huey.retriable_task
@pipeline_operation("Removing certificate from previous load balancer")
def remove_certificate_from_previous_alb_during_update_to_dedicated(
operation_id, **kwargs
operation_id, *, operation, db, **kwargs
):
operation = db.session.get(Operation, operation_id)
service_instance = operation.service_instance
remove_certificate = service_instance.current_certificate

operation.step_description = "Removing SSL certificate from load balancer"
flag_modified(operation, "step_description")
db.session.add(operation)
db.session.commit()

if service_instance.previous_alb_listener_arn is not None:
time.sleep(config.ALB_OVERLAP_SLEEP_TIME)
alb.remove_listener_certificates(
Expand Down
92 changes: 24 additions & 68 deletions broker/tasks/cloudfront.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import logging
import time

from sqlalchemy.orm.attributes import flag_modified

from broker.aws import cloudfront
from broker.extensions import config, db
from broker.extensions import config
from broker.lib.tags import add_tag
from broker.models import Operation, CDNServiceInstance, CDNDedicatedWAFServiceInstance
from broker.tasks import huey
from broker.models import CDNServiceInstance, CDNDedicatedWAFServiceInstance
from broker.tasks.huey import pipeline_operation

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -62,17 +60,11 @@ def get_custom_error_responses(service_instance):
return {"Quantity": 0}


@huey.retriable_task
def create_distribution(operation_id: int, **kwargs):
operation = db.session.get(Operation, operation_id)
@pipeline_operation("Creating CloudFront distribution")
def create_distribution(operation_id: int, *, operation, db, **kwargs):
service_instance = operation.service_instance
certificate = service_instance.new_certificate

operation.step_description = "Creating CloudFront distribution"
flag_modified(operation, "step_description")
db.session.add(operation)
db.session.commit()

if service_instance.cloudfront_distribution_id:
try:
cloudfront.get_distribution(Id=service_instance.cloudfront_distribution_id)
Expand Down Expand Up @@ -183,16 +175,10 @@ def create_distribution(operation_id: int, **kwargs):
db.session.commit()


@huey.retriable_task
def disable_distribution(operation_id: int, **kwargs):
operation = db.session.get(Operation, operation_id)
@pipeline_operation("Disabling CloudFront distribution")
def disable_distribution(operation_id: int, *, operation, db, **kwargs):
service_instance = operation.service_instance

operation.step_description = "Disabling CloudFront distribution"
flag_modified(operation, "step_description")
db.session.add(operation)
db.session.commit()

if service_instance.cloudfront_distribution_id is None:
return

Expand All @@ -210,16 +196,10 @@ def disable_distribution(operation_id: int, **kwargs):
return


@huey.retriable_task
def wait_for_distribution_disabled(operation_id: int, **kwargs):
operation = db.session.get(Operation, operation_id)
@pipeline_operation("Waiting for CloudFront distribution to disable")
def wait_for_distribution_disabled(operation_id: int, *, operation, db, **kwargs):
service_instance = operation.service_instance

operation.step_description = "Waiting for CloudFront distribution to disable"
flag_modified(operation, "step_description")
db.session.add(operation)
db.session.commit()

if service_instance.cloudfront_distribution_id is None:
return

Expand All @@ -244,21 +224,15 @@ def wait_for_distribution_disabled(operation_id: int, **kwargs):
except cloudfront.exceptions.NoSuchDistribution:
return
distribution_disabled = (
status["Distribution"]["DistributionConfig"]["Enabled"] == False
not status["Distribution"]["DistributionConfig"]["Enabled"]
and status["Distribution"]["Status"] == "Deployed"
)


@huey.retriable_task
def delete_distribution(operation_id: int, **kwargs):
operation = db.session.get(Operation, operation_id)
@pipeline_operation("Deleting CloudFront distribution")
def delete_distribution(operation_id: int, *, operation, db, **kwargs):
service_instance = operation.service_instance

operation.step_description = "Deleting CloudFront distribution"
flag_modified(operation, "step_description")
db.session.add(operation)
db.session.commit()

if service_instance.cloudfront_distribution_id is None:
return

Expand All @@ -273,16 +247,10 @@ def delete_distribution(operation_id: int, **kwargs):
return


@huey.retriable_task
def wait_for_distribution(operation_id: str, **kwargs):
operation = db.session.get(Operation, operation_id)
@pipeline_operation("Waiting for CloudFront distribution")
def wait_for_distribution(operation_id: str, *, operation, db, **kwargs):
service_instance = operation.service_instance

operation.step_description = "Waiting for CloudFront distribution"
flag_modified(operation, "step_description")
db.session.add(operation)
db.session.commit()

waiter = cloudfront.get_waiter("distribution_deployed")
waiter.wait(
Id=service_instance.cloudfront_distribution_id,
Expand All @@ -293,16 +261,10 @@ def wait_for_distribution(operation_id: str, **kwargs):
)


@huey.retriable_task
def update_certificate(operation_id: str, **kwargs):
operation = db.session.get(Operation, operation_id)
@pipeline_operation("Updating CloudFront distribution certificate")
def update_certificate(operation_id: str, *, operation, db, **kwargs):
service_instance = operation.service_instance

operation.step_description = "Updating CloudFront distribution certificate"
flag_modified(operation, "step_description")
db.session.add(operation)
db.session.commit()

config = cloudfront.get_distribution_config(
Id=service_instance.cloudfront_distribution_id
)
Expand All @@ -320,17 +282,11 @@ def update_certificate(operation_id: str, **kwargs):
db.session.commit()


@huey.retriable_task
def update_distribution(operation_id: str, **kwargs):
operation = db.session.get(Operation, operation_id)
@pipeline_operation("Updating CloudFront distribution")
def update_distribution(operation_id: str, *, operation, db, **kwargs):
service_instance = operation.service_instance
certificate = service_instance.new_certificate

operation.step_description = "Updating CloudFront distribution"
flag_modified(operation, "step_description")
db.session.add(operation)
db.session.commit()

config_response = cloudfront.get_distribution_config(
Id=service_instance.cloudfront_distribution_id
)
Expand Down Expand Up @@ -369,9 +325,10 @@ def update_distribution(operation_id: str, **kwargs):
db.session.commit()


@huey.retriable_task
def remove_s3_bucket_from_cdn_broker_instance(operation_id: str, **kwargs):
operation = db.session.get(Operation, operation_id)
@pipeline_operation("Removing s3 bucket binding")
def remove_s3_bucket_from_cdn_broker_instance(
operation_id: str, *, operation, db, **kwargs
):
service_instance = operation.service_instance
config_response = cloudfront.get_distribution_config(
Id=service_instance.cloudfront_distribution_id
Expand Down Expand Up @@ -414,9 +371,8 @@ def remove_s3_bucket_from_cdn_broker_instance(operation_id: str, **kwargs):
)


@huey.retriable_task
def add_logging_to_bucket(operation_id: str, **kwargs):
operation = db.session.get(Operation, operation_id)
@pipeline_operation("Adding logging to Cloudfront distribution")
def add_logging_to_bucket(operation_id: str, *, operation, db, **kwargs):
service_instance = operation.service_instance
config_response = cloudfront.get_distribution_config(
Id=service_instance.cloudfront_distribution_id
Expand Down
Loading

0 comments on commit 8edc129

Please sign in to comment.