Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: refactor setting step description, etc, to common code #389

Merged
merged 3 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 13 additions & 51 deletions broker/tasks/alb.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@

from sqlalchemy import and_, select, func, null
from sqlalchemy.orm import aliased
from sqlalchemy.orm.attributes import flag_modified


from broker.aws import alb
from broker.extensions import config, db
Expand All @@ -14,7 +12,7 @@
Certificate,
Operation,
)
from broker.tasks import huey
from broker.tasks.huey import pipeline_operation

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -111,16 +109,10 @@ def get_lowest_dedicated_alb(service_instance, db):
db.session.commit()


@huey.retriable_task
def select_dedicated_alb(operation_id, **kwargs):
operation = db.session.get(Operation, operation_id)
@pipeline_operation("Selecting load balancer")
def select_dedicated_alb(operation_id, *, operation, db, **kwargs):
service_instance = operation.service_instance

operation.step_description = "Selecting load balancer"
flag_modified(operation, "step_description")
db.session.add(operation)
db.session.commit()

if (
service_instance.alb_arn
and operation.action == Operation.Actions.PROVISION.value
Expand All @@ -131,16 +123,10 @@ def select_dedicated_alb(operation_id, **kwargs):
return get_lowest_dedicated_alb(service_instance, db)


@huey.retriable_task
def select_alb(operation_id, **kwargs):
operation = db.session.get(Operation, operation_id)
@pipeline_operation("Selecting load balancer")
def select_alb(operation_id, *, operation, db, **kwargs):
service_instance = operation.service_instance

operation.step_description = "Selecting load balancer"
flag_modified(operation, "step_description")
db.session.add(operation)
db.session.commit()

if (
service_instance.alb_arn
and operation.action == Operation.Actions.PROVISION.value
Expand All @@ -156,17 +142,11 @@ def select_alb(operation_id, **kwargs):
db.session.commit()


@huey.retriable_task
def add_certificate_to_alb(operation_id, **kwargs):
operation = db.session.get(Operation, operation_id)
@pipeline_operation("Adding SSL certificate to load balancer")
def add_certificate_to_alb(operation_id, *, operation, db, **kwargs):
service_instance = operation.service_instance
certificate = service_instance.new_certificate

operation.step_description = "Adding SSL certificate to load balancer"
flag_modified(operation, "step_description")
db.session.add(operation)
db.session.commit()

alb.add_listener_certificates(
ListenerArn=service_instance.alb_listener_arn,
Certificates=[{"CertificateArn": certificate.iam_server_certificate_arn}],
Expand All @@ -184,16 +164,10 @@ def add_certificate_to_alb(operation_id, **kwargs):
db.session.commit()


@huey.retriable_task
def remove_certificate_from_alb(operation_id, **kwargs):
operation = db.session.get(Operation, operation_id)
@pipeline_operation("Removing SSL certificate from load balancer")
def remove_certificate_from_alb(operation_id, *, operation, db, **kwargs):
service_instance = operation.service_instance

operation.step_description = "Removing SSL certificate from load balancer"
flag_modified(operation, "step_description")
db.session.add(operation)
db.session.commit()

if service_instance.alb_listener_arn is not None:
alb.remove_listener_certificates(
ListenerArn=service_instance.alb_listener_arn,
Expand All @@ -208,9 +182,8 @@ def remove_certificate_from_alb(operation_id, **kwargs):
time.sleep(config.IAM_CERTIFICATE_PROPAGATION_TIME)


@huey.retriable_task
def remove_certificate_from_previous_alb(operation_id, **kwargs):
operation = db.session.get(Operation, operation_id)
@pipeline_operation("Removing SSL certificate from load balancer")
def remove_certificate_from_previous_alb(operation_id, *, operation, db, **kwargs):
service_instance = operation.service_instance
remove_certificate = Certificate.query.filter(
and_(
Expand All @@ -219,11 +192,6 @@ def remove_certificate_from_previous_alb(operation_id, **kwargs):
)
).first()

operation.step_description = "Removing SSL certificate from load balancer"
flag_modified(operation, "step_description")
db.session.add(operation)
db.session.commit()

if service_instance.previous_alb_listener_arn is not None:
time.sleep(config.ALB_OVERLAP_SLEEP_TIME)
alb.remove_listener_certificates(
Expand Down Expand Up @@ -263,19 +231,13 @@ def remove_certificate_from_previous_alb(operation_id, **kwargs):
db.session.commit()


@huey.retriable_task
@pipeline_operation("Removing certificate from previous load balancer")
def remove_certificate_from_previous_alb_during_update_to_dedicated(
operation_id, **kwargs
operation_id, *, operation, db, **kwargs
):
operation = db.session.get(Operation, operation_id)
service_instance = operation.service_instance
remove_certificate = service_instance.current_certificate

operation.step_description = "Removing SSL certificate from load balancer"
flag_modified(operation, "step_description")
db.session.add(operation)
db.session.commit()

if service_instance.previous_alb_listener_arn is not None:
time.sleep(config.ALB_OVERLAP_SLEEP_TIME)
alb.remove_listener_certificates(
Expand Down
92 changes: 24 additions & 68 deletions broker/tasks/cloudfront.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import logging
import time

from sqlalchemy.orm.attributes import flag_modified

from broker.aws import cloudfront
from broker.extensions import config, db
from broker.extensions import config
from broker.lib.tags import add_tag
from broker.models import Operation, CDNServiceInstance, CDNDedicatedWAFServiceInstance
from broker.tasks import huey
from broker.models import CDNServiceInstance, CDNDedicatedWAFServiceInstance
from broker.tasks.huey import pipeline_operation

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -62,17 +60,11 @@ def get_custom_error_responses(service_instance):
return {"Quantity": 0}


@huey.retriable_task
def create_distribution(operation_id: int, **kwargs):
operation = db.session.get(Operation, operation_id)
@pipeline_operation("Creating CloudFront distribution")
def create_distribution(operation_id: int, *, operation, db, **kwargs):
service_instance = operation.service_instance
certificate = service_instance.new_certificate

operation.step_description = "Creating CloudFront distribution"
flag_modified(operation, "step_description")
db.session.add(operation)
db.session.commit()

if service_instance.cloudfront_distribution_id:
try:
cloudfront.get_distribution(Id=service_instance.cloudfront_distribution_id)
Expand Down Expand Up @@ -183,16 +175,10 @@ def create_distribution(operation_id: int, **kwargs):
db.session.commit()


@huey.retriable_task
def disable_distribution(operation_id: int, **kwargs):
operation = db.session.get(Operation, operation_id)
@pipeline_operation("Disabling CloudFront distribution")
def disable_distribution(operation_id: int, *, operation, db, **kwargs):
service_instance = operation.service_instance

operation.step_description = "Disabling CloudFront distribution"
flag_modified(operation, "step_description")
db.session.add(operation)
db.session.commit()

if service_instance.cloudfront_distribution_id is None:
return

Expand All @@ -210,16 +196,10 @@ def disable_distribution(operation_id: int, **kwargs):
return


@huey.retriable_task
def wait_for_distribution_disabled(operation_id: int, **kwargs):
operation = db.session.get(Operation, operation_id)
@pipeline_operation("Waiting for CloudFront distribution to disable")
def wait_for_distribution_disabled(operation_id: int, *, operation, db, **kwargs):
service_instance = operation.service_instance

operation.step_description = "Waiting for CloudFront distribution to disable"
flag_modified(operation, "step_description")
db.session.add(operation)
db.session.commit()

if service_instance.cloudfront_distribution_id is None:
return

Expand All @@ -244,21 +224,15 @@ def wait_for_distribution_disabled(operation_id: int, **kwargs):
except cloudfront.exceptions.NoSuchDistribution:
return
distribution_disabled = (
status["Distribution"]["DistributionConfig"]["Enabled"] == False
not status["Distribution"]["DistributionConfig"]["Enabled"]
and status["Distribution"]["Status"] == "Deployed"
)


@huey.retriable_task
def delete_distribution(operation_id: int, **kwargs):
operation = db.session.get(Operation, operation_id)
@pipeline_operation("Deleting CloudFront distribution")
def delete_distribution(operation_id: int, *, operation, db, **kwargs):
service_instance = operation.service_instance

operation.step_description = "Deleting CloudFront distribution"
flag_modified(operation, "step_description")
db.session.add(operation)
db.session.commit()

if service_instance.cloudfront_distribution_id is None:
return

Expand All @@ -273,16 +247,10 @@ def delete_distribution(operation_id: int, **kwargs):
return


@huey.retriable_task
def wait_for_distribution(operation_id: str, **kwargs):
operation = db.session.get(Operation, operation_id)
@pipeline_operation("Waiting for CloudFront distribution")
def wait_for_distribution(operation_id: str, *, operation, db, **kwargs):
service_instance = operation.service_instance

operation.step_description = "Waiting for CloudFront distribution"
flag_modified(operation, "step_description")
db.session.add(operation)
db.session.commit()

waiter = cloudfront.get_waiter("distribution_deployed")
waiter.wait(
Id=service_instance.cloudfront_distribution_id,
Expand All @@ -293,16 +261,10 @@ def wait_for_distribution(operation_id: str, **kwargs):
)


@huey.retriable_task
def update_certificate(operation_id: str, **kwargs):
operation = db.session.get(Operation, operation_id)
@pipeline_operation("Updating CloudFront distribution certificate")
def update_certificate(operation_id: str, *, operation, db, **kwargs):
service_instance = operation.service_instance

operation.step_description = "Updating CloudFront distribution certificate"
flag_modified(operation, "step_description")
db.session.add(operation)
db.session.commit()

config = cloudfront.get_distribution_config(
Id=service_instance.cloudfront_distribution_id
)
Expand All @@ -320,17 +282,11 @@ def update_certificate(operation_id: str, **kwargs):
db.session.commit()


@huey.retriable_task
def update_distribution(operation_id: str, **kwargs):
operation = db.session.get(Operation, operation_id)
@pipeline_operation("Updating CloudFront distribution")
def update_distribution(operation_id: str, *, operation, db, **kwargs):
service_instance = operation.service_instance
certificate = service_instance.new_certificate

operation.step_description = "Updating CloudFront distribution"
flag_modified(operation, "step_description")
db.session.add(operation)
db.session.commit()

config_response = cloudfront.get_distribution_config(
Id=service_instance.cloudfront_distribution_id
)
Expand Down Expand Up @@ -369,9 +325,10 @@ def update_distribution(operation_id: str, **kwargs):
db.session.commit()


@huey.retriable_task
def remove_s3_bucket_from_cdn_broker_instance(operation_id: str, **kwargs):
operation = db.session.get(Operation, operation_id)
@pipeline_operation("Removing s3 bucket binding")
def remove_s3_bucket_from_cdn_broker_instance(
operation_id: str, *, operation, db, **kwargs
):
service_instance = operation.service_instance
config_response = cloudfront.get_distribution_config(
Id=service_instance.cloudfront_distribution_id
Expand Down Expand Up @@ -414,9 +371,8 @@ def remove_s3_bucket_from_cdn_broker_instance(operation_id: str, **kwargs):
)


@huey.retriable_task
def add_logging_to_bucket(operation_id: str, **kwargs):
operation = db.session.get(Operation, operation_id)
@pipeline_operation("Adding logging to Cloudfront distribution")
def add_logging_to_bucket(operation_id: str, *, operation, db, **kwargs):
service_instance = operation.service_instance
config_response = cloudfront.get_distribution_config(
Id=service_instance.cloudfront_distribution_id
Expand Down
Loading