diff --git a/reflex_core/aws_rule_interface.py b/reflex_core/aws_rule_interface.py index b2cfa22..c73adad 100644 --- a/reflex_core/aws_rule_interface.py +++ b/reflex_core/aws_rule_interface.py @@ -21,6 +21,10 @@ class AWSRuleInterface: region (str): The AWS region that the event occurred in. service (str): The name of the AWS service that triggered the event. client (boto3.client): A boto3 client for the service that triggered the event. + pre_compliance_check_functions (list): A list of callables (usually functions) + to be run before the resource compliance check occurs. + post_compliance_check_functions (list): A list of callables (usually functions) + to be run after rthe resource compliance check occurs. pre_remediation_functions (list): A list of callables (usually functions) to be run before remediation action occurs. post_remediation_functions (list): A list of callables (usually functions) @@ -45,6 +49,8 @@ def __init__(self, event): self.client = self.get_boto3_client() self.extract_event_data(event) + self.pre_compliance_check_functions = [] + self.post_compliance_check_functions = [] self.pre_remediation_functions = [] self.post_remediation_functions = [] self.notifiers = [] @@ -141,8 +147,14 @@ def run_compliance_rule(self): None """ try: + self.pre_compliance_check() + self.LOGGER.debug("Checking if resource is compliant") - if not self.resource_compliant(): + resource_compliant = self.resource_compliant() + + self.post_compliance_check() + + if not resource_compliant: self.LOGGER.debug("Resource is not compliant") if self.should_remediate(): @@ -188,6 +200,44 @@ def remediate(self): """ raise NotImplementedError("remediate not implemented") + def pre_compliance_check(self): + """Runs all pre-compliance check functions. + + This function executes all functions that have been registered in the + pre_compliance_check_functions list. This is run immediately before the + resource_compliant function is called. Functions are executed in the order + they occur in the pre_compliance_check_functions list. + + Returns: + None + """ + self.LOGGER.debug("Running pre-compliance check functions") + for pre_compliance_check_function in self.pre_compliance_check_functions: + self.LOGGER.debug( + "Running pre-compliance check function %s", + pre_compliance_check_function.__name__, + ) + pre_compliance_check_function() + + def post_compliance_check(self): + """Runs all post-compliance check functions. + + This function executes all functions that have been registered in the + post_compliance_check_functions list. This is run immediately after the + resource_compliant function is called. Functions are executed in the order + they occur in the post_compliance_check_functions list. + + Returns: + None + """ + self.LOGGER.debug("Running post-compliance check functions") + for post_compliance_check_function in self.post_compliance_check_functions: + self.LOGGER.debug( + "Running post-remediation function %s", + post_compliance_check_function.__name__, + ) + post_compliance_check_function() + def pre_remediation(self): """Runs all pre-remediation functions. @@ -276,6 +326,70 @@ def get_remediation_message_subject(self): fixed_subject = " ".join(subject_split) return f"The Reflex {fixed_subject} was triggered." + def add_pre_compliance_check_functions(self, functions): + """Sets a function or list of functions to be run before the resource + compliance check occurs. + + If anything other than a function is present in the list, it will be ignored. + If something other than a function or list is passed, it will be ignored. + + Returns: + None + """ + self._add_functions( + functions, + self.pre_compliance_check_functions, + "pre_compliance_check_functions", + ) + + def remove_pre_compliance_check_functions(self, functions): + """Stop a function or list of functions from being run before the resource + compliance check occurs. + + Takes a function or list of functions and removes them from the list + of pre-compliance check functions. Anything not in the list will be ignored. + + Returns: + None + """ + self._remove_functions( + functions, + self.pre_compliance_check_functions, + "pre_compliance_check_functions", + ) + + def add_post_compliance_check_functions(self, functions): + """Sets a function or list of functions to be run after the resource + compliance check occurs. + + If anything other than a function is present in the list, it will be ignored. + If something other than a function or list is passed, it will be ignored. + + Returns: + None + """ + self._add_functions( + functions, + self.post_compliance_check_functions, + "post_compliance_check_functions", + ) + + def remove_post_compliance_check_functions(self, functions): + """Stop a function or list of functions from being run after the + resource compliance check occurs. + + Takes a function or list of functions and removes them from the list + of post-compliance check functions. Anything not in the list will be ignored. + + Returns: + None + """ + self._remove_functions( + functions, + self.post_compliance_check_functions, + "post_compliance_check_functions", + ) + def add_pre_remediation_functions(self, functions): """Sets a function or list of functions to be run before remediation action occurs. @@ -285,28 +399,9 @@ def add_pre_remediation_functions(self, functions): Returns: None """ - if isinstance(functions, list): - for function in functions: - if callable(function): - self.LOGGER.debug( - "Adding %s to pre-remediation functions", function.__name__ - ) - self.pre_remediation_functions.append(function) - else: - self.LOGGER.warning( - "%s is not a function. Not adding to list of pre-remediation functions.", - function.__name__, - ) - elif callable(functions): - self.LOGGER.debug( - "Adding %s to pre-remediation functions", functions.__name__ - ) - self.pre_remediation_functions.append(functions) - else: - self.LOGGER.warning( - "%s is not a function or list. Not adding to list of pre-remediation functions.", - functions.__name__, - ) + self._add_functions( + functions, self.pre_remediation_functions, "pre_remediation_functions" + ) def remove_pre_remediation_functions(self, functions): """Stop a function or list of functions from being run pre-remediation. @@ -317,29 +412,9 @@ def remove_pre_remediation_functions(self, functions): Returns: None """ - if isinstance(functions, list): - for function in functions: - try: - self.LOGGER.debug( - "Removing %s from pre-remediation functions", function.__name__ - ) - self.pre_remediation_functions.remove(function) - except ValueError: - self.LOGGER.warning( - "%s is not in the list of pre-remediation functions. Skipping", - function.__name__, - ) - else: - try: - self.LOGGER.debug( - "Removing %s from pre-remediation functions", functions.__name__ - ) - self.pre_remediation_functions.remove(functions) - except ValueError: - self.LOGGER.warning( - "%s is not in the list of pre-remediation functions. Skipping", - functions.__name__, - ) + self._remove_functions( + functions, self.pre_remediation_functions, "pre_remediation_functions" + ) def add_post_remediation_functions(self, functions): """Sets a function or list of functions to be run after remediation action occurs. @@ -350,28 +425,9 @@ def add_post_remediation_functions(self, functions): Returns: None """ - if isinstance(functions, list): - for function in functions: - if callable(function): - self.LOGGER.debug( - "Adding %s to post-remediation functions", function.__name__ - ) - self.post_remediation_functions.append(function) - else: - self.LOGGER.warning( - "%s is not a function. Not adding to list of post-remediation functions.", - function.__name__, - ) - elif callable(functions): - self.LOGGER.debug( - "Adding %s to post-remediation functions", functions.__name__ - ) - self.post_remediation_functions.append(functions) - else: - self.LOGGER.warning( - "%s is not a function or list. Not adding to list of post-remediation functions.", - functions.__name__, - ) + self._add_functions( + functions, self.post_remediation_functions, "post_remediation_functions" + ) def remove_post_remediation_functions(self, functions): """Stop a function or list of functions from being run post-remediation. @@ -382,29 +438,9 @@ def remove_post_remediation_functions(self, functions): Returns: None """ - if isinstance(functions, list): - for function in functions: - try: - self.LOGGER.debug( - "Removing %s from post-remediation functions", function.__name__ - ) - self.post_remediation_functions.remove(function) - except ValueError: - self.LOGGER.warning( - "%s is not in the list of post-remediation functions. Skipping", - function.__name__, - ) - else: - try: - self.LOGGER.debug( - "Removing %s from post-remediation functions", functions.__name__ - ) - self.post_remediation_functions.remove(functions) - except ValueError: - self.LOGGER.warning( - "%s is not in the list of post-remediation functions. Skipping", - functions.__name__, - ) + self._remove_functions( + functions, self.post_remediation_functions, "post_remediation_functions" + ) def add_notifiers(self, notifiers): """Sets a Notifier or list of Notifiers to send remediation notifications with. @@ -494,3 +530,62 @@ def should_remediate(self): """ mode = os.environ.get("MODE", "detect").lower() return mode == "remediate" + + def _add_functions(self, functions, function_list, list_name): + """Adds a function or list of functions to the provided function list. + + If anything other than a function is present in the functions list, it will be ignored. + If something other than a function or list is passed, it will be ignored. + + Returns: + None + """ + if isinstance(functions, list): + for function in functions: + if callable(function): + self.LOGGER.debug("Adding %s to %s", function.__name__, list_name) + function_list.append(function) + else: + self.LOGGER.warning( + "%s is not a function. Not adding to %s.", + function.__name__, + list_name, + ) + elif callable(functions): + self.LOGGER.debug("Adding %s to %s", functions.__name__, list_name) + function_list.append(functions) + else: + self.LOGGER.warning( + "%s is not a function or list. Not adding to %s.", + functions.__name__, + list_name, + ) + + def _remove_functions(self, functions, function_list, list_name): + """Remove a function or list of functions from the provided function list. + + Takes a function or list of functions and removes them from function_list. + Anything not in function_list will be ignored. + + Returns: + None + """ + if isinstance(functions, list): + for function in functions: + try: + self.LOGGER.debug( + "Removing %s from %s", function.__name__, list_name + ) + function_list.remove(function) + except ValueError: + self.LOGGER.warning( + "%s is not in %s. Skipping", function.__name__, list_name, + ) + else: + try: + self.LOGGER.debug("Removing %s from %s", functions.__name__, list_name) + function_list.remove(functions) + except ValueError: + self.LOGGER.warning( + "%s is not in %s. Skipping", functions.__name__, list_name, + )