From fec578e664ceb6c86ddb73294da319f12123fe1d Mon Sep 17 00:00:00 2001 From: drduhe Date: Fri, 4 Oct 2024 09:25:33 -0600 Subject: [PATCH] chore: refactoring app.py - phase 1 --- src/aws/osml/model_runner/app.py | 132 ++++-------------- src/aws/osml/model_runner/app_config.py | 70 +++++++--- src/aws/osml/model_runner/common/__init__.py | 1 - .../osml/model_runner/common/metrics_utils.py | 14 -- .../osml/model_runner/database/exceptions.py | 4 + .../model_runner/database/feature_table.py | 28 ++++ .../database/region_request_table.py | 41 +++++- test/test_app.py | 15 +- 8 files changed, 153 insertions(+), 152 deletions(-) delete mode 100755 src/aws/osml/model_runner/common/metrics_utils.py diff --git a/src/aws/osml/model_runner/app.py b/src/aws/osml/model_runner/app.py index 73ebf182..678b5f2e 100755 --- a/src/aws/osml/model_runner/app.py +++ b/src/aws/osml/model_runner/app.py @@ -18,14 +18,8 @@ from osgeo import gdal from osgeo.gdal import Dataset -from aws.osml.gdal import ( - GDALConfigEnv, - GDALDigitalElevationModelTileFactory, - get_image_extension, - load_gdal_dataset, - set_gdal_default_configuration, -) -from aws.osml.photogrammetry import DigitalElevationModel, ElevationModel, ImageCoordinate, SensorModel, SRTMTileSet +from aws.osml.gdal import GDALConfigEnv, get_image_extension, load_gdal_dataset, set_gdal_default_configuration +from aws.osml.photogrammetry import ImageCoordinate, SensorModel from .api import VALID_MODEL_HOSTING_OPTIONS, ImageRequest, InvalidImageRequestException, RegionRequest, SinkMode from .app_config import MetricLabels, ServiceConfig @@ -38,7 +32,6 @@ RequestStatus, ThreadingLocalContextFilter, Timer, - build_embedded_metrics_config, get_credentials_for_assumed_role, mr_post_processing_options_factory, ) @@ -50,7 +43,6 @@ InvalidImageURLException, LoadImageException, ProcessImageException, - ProcessRegionException, RetryableJobException, SelfThrottledRegionException, UnsupportedModelException, @@ -61,12 +53,13 @@ from .status import ImageStatusMonitor, RegionStatusMonitor from .tile_worker import TilingStrategy, VariableOverlapTilingStrategy, process_tiles, setup_tile_workers -# Set up metrics configuration -build_embedded_metrics_config() - # Set up logging configuration logger = logging.getLogger(__name__) +# GDAL 4.0 will begin using exceptions as the default; at this point the software is written to assume +# no exceptions so we call this explicitly until the software can be updated to match. 
+gdal.UseExceptions() + class ModelRunner: """ @@ -81,17 +74,17 @@ def __init__(self, tiling_strategy: TilingStrategy = VariableOverlapTilingStrate :param tiling_strategy: class defining how a larger image will be broken into chunks for processing """ + self.config = ServiceConfig() self.tiling_strategy = tiling_strategy - self.image_request_queue = RequestQueue(ServiceConfig.image_queue, wait_seconds=0) + self.image_request_queue = RequestQueue(self.config.image_queue, wait_seconds=0) self.image_requests_iter = iter(self.image_request_queue) - self.job_table = JobTable(ServiceConfig.job_table) - self.region_request_table = RegionRequestTable(ServiceConfig.region_request_table) - self.endpoint_statistics_table = EndpointStatisticsTable(ServiceConfig.endpoint_statistics_table) - self.region_request_queue = RequestQueue(ServiceConfig.region_queue, wait_seconds=10) + self.job_table = JobTable(self.config.job_table) + self.region_request_table = RegionRequestTable(self.config.region_request_table) + self.endpoint_statistics_table = EndpointStatisticsTable(self.config.endpoint_statistics_table) + self.region_request_queue = RequestQueue(self.config.region_queue, wait_seconds=10) self.region_requests_iter = iter(self.region_request_queue) - self.image_status_monitor = ImageStatusMonitor(ServiceConfig.image_status_topic) - self.region_status_monitor = RegionStatusMonitor(ServiceConfig.region_status_topic) - self.elevation_model = ModelRunner.create_elevation_model() + self.image_status_monitor = ImageStatusMonitor(self.config.image_status_topic) + self.region_status_monitor = RegionStatusMonitor(self.config.region_status_topic) self.endpoint_utils = EndpointUtils() self.running = False @@ -111,24 +104,6 @@ def stop(self) -> None: """ self.running = False - @staticmethod - def create_elevation_model() -> Optional[ElevationModel]: - """ - Create an elevation model if the relevant options are set in the service configuration. - - :return: Optional[ElevationModel] = the elevation model or None if not configured - """ - if ServiceConfig.elevation_data_location: - return DigitalElevationModel( - SRTMTileSet( - version=ServiceConfig.elevation_data_version, - format_extension=ServiceConfig.elevation_data_extension, - ), - GDALDigitalElevationModelTileFactory(ServiceConfig.elevation_data_location), - ) - - return None - def monitor_work_queues(self) -> None: """ Monitors SQS queues for ImageRequest and RegionRequest The region work queue is checked first and will wait for @@ -201,7 +176,7 @@ def monitor_work_queues(self) -> None: except SelfThrottledRegionException: self.region_request_queue.reset_request( receipt_handle, - visibility_timeout=int(ServiceConfig.throttling_retry_timeout), + visibility_timeout=int(self.config.throttling_retry_timeout), ) except Exception as err: logger.error(f"There was a problem processing the region request: {err}") @@ -260,7 +235,7 @@ def process_image_request(self, image_request: ImageRequest) -> None: """ image_request_item = None try: - if ServiceConfig.self_throttling: + if self.config.self_throttling: max_regions = self.endpoint_utils.calculate_max_regions( image_request.model_name, image_request.model_invocation_role ) @@ -302,7 +277,7 @@ def process_image_request(self, image_request: ImageRequest) -> None: if sensor_model is None: logging.warning( - f"Dataset {image_request_item.image_id} did not have a geo transform. Results are not geo-referenced." + f"Dataset {image_request_item.image_id} has no geo transform. Results are not geo-referenced." 
) # If we got valid outputs @@ -460,7 +435,7 @@ def process_region_request( } ) - if ServiceConfig.self_throttling: + if self.config.self_throttling: max_regions = self.endpoint_utils.calculate_max_regions( region_request.model_name, region_request.model_invocation_role ) @@ -485,7 +460,7 @@ def process_region_request( metrics_logger=metrics, ): # Set up our threaded tile worker pool - tile_queue, tile_workers = setup_tile_workers(region_request, sensor_model, self.elevation_model) + tile_queue, tile_workers = setup_tile_workers(region_request, sensor_model, self.config.elevation_model) # Process all our tiles total_tile_count, failed_tile_count = process_tiles( @@ -520,11 +495,11 @@ def process_region_request( logger.error(failed_msg) # update the table to take in that exception region_request_item.message = failed_msg - return self.fail_region_request(region_request_item, metrics) + return self.region_request_table.fail_region_request(region_request_item, self.job_table, metrics) finally: # Decrement the endpoint region counter - if ServiceConfig.self_throttling: + if self.config.self_throttling: self.endpoint_statistics_table.decrement_region_count(region_request.model_name) def load_image_request( @@ -579,7 +554,7 @@ def load_image_request( # Calculate a set of ML engine-sized regions that we need to process for this image # Region size chosen to break large images into pieces that can be handled by a # single tile worker - region_size: ImageDimensions = ast.literal_eval(ServiceConfig.region_size) + region_size: ImageDimensions = ast.literal_eval(self.config.region_size) tile_size: ImageDimensions = ast.literal_eval(image_request_item.tile_size) if not image_request_item.tile_overlap: minimum_overlap = (0, 0) @@ -644,9 +619,9 @@ def complete_image_request( logger.debug(f"Processing boundary from {roi} is {processing_bounds}") # Set up our feature table to work with the region quest - feature_table = FeatureTable(ServiceConfig.feature_table, region_request.tile_size, region_request.tile_overlap) + feature_table = FeatureTable(self.config.feature_table, region_request.tile_size, region_request.tile_overlap) # Aggregate all the features from our job - features = self.aggregate_features(image_request_item, feature_table) + features = feature_table.aggregate_features(image_request_item) features = self.select_features(image_request_item, features, processing_bounds) features = self.add_properties_to_features(image_request_item, features) @@ -704,32 +679,6 @@ def generate_image_processing_metrics( if image_request_item.region_error > 0: metrics.put_metric(MetricLabels.ERRORS, 1, str(Unit.COUNT.value)) - def fail_region_request( - self, - region_request_item: RegionRequestItem, - metrics: MetricsLogger = None, - ) -> JobItem: - """ - Fails a region if it failed to process successfully and updates the table accordingly before - raising an exception - - :param region_request_item: RegionRequestItem = the region request to update - :param metrics: MetricsLogger = the metrics logger to use to report metrics. 
- - :return: None - """ - if isinstance(metrics, MetricsLogger): - metrics.put_metric(MetricLabels.ERRORS, 1, str(Unit.COUNT.value)) - try: - region_status = RequestStatus.FAILED - region_request_item = self.region_request_table.complete_region_request(region_request_item, region_status) - self.region_status_monitor.process_event(region_request_item, region_status, "Completed region processing") - return self.job_table.complete_region_request(region_request_item.image_id, error=True) - except Exception as status_error: - logger.error("Unable to update region status in job table") - logger.exception(status_error) - raise ProcessRegionException("Failed to process image region!") - def validate_model_hosting(self, image_request: JobItem): """ Validates that the image request is valid. If not, raises an exception. @@ -747,37 +696,6 @@ def validate_model_hosting(self, image_request: JobItem): ) raise UnsupportedModelException(error) - @staticmethod - @metric_scope - def aggregate_features( - image_request_item: JobItem, feature_table: FeatureTable, metrics: MetricsLogger = None - ) -> List[Feature]: - """ - For a given image processing job - aggregate all the features that were collected for it and - put them in the correct output sink locations. - - :param image_request_item: JobItem = the image request - :param feature_table: FeatureTable = the table storing features from all completed regions - :param metrics: the current metrics scope - - :return: List[geojson.Feature] = the list of features - """ - if isinstance(metrics, MetricsLogger): - metrics.set_dimensions() - metrics.put_dimensions( - { - MetricLabels.OPERATION_DIMENSION: MetricLabels.FEATURE_AGG_OPERATION, - } - ) - - with Timer( - task_str="Aggregating Features", metric_name=MetricLabels.DURATION, logger=logger, metrics_logger=metrics - ): - features = feature_table.get_features(image_request_item.image_id) - logger.debug(f"Total features aggregated: {len(features)}") - - return features - @metric_scope def select_features( self, @@ -826,7 +744,7 @@ def select_features( feature_distillation_option = FeatureDistillationDeserializer().deserialize(feature_distillation_option_dict) feature_selector = FeatureSelector(feature_distillation_option) - region_size = ast.literal_eval(ServiceConfig.region_size) + region_size = ast.literal_eval(self.config.region_size) tile_size = ast.literal_eval(image_request_item.tile_size) overlap = ast.literal_eval(image_request_item.tile_overlap) deduped_features = self.tiling_strategy.cleanup_duplicate_features( diff --git a/src/aws/osml/model_runner/app_config.py b/src/aws/osml/model_runner/app_config.py index 4e93a343..8c4f0e27 100755 --- a/src/aws/osml/model_runner/app_config.py +++ b/src/aws/osml/model_runner/app_config.py @@ -1,12 +1,16 @@ # Copyright 2023-2024 Amazon.com, Inc. or its affiliates. import os -from dataclasses import dataclass +from dataclasses import dataclass, field from enum import Enum from typing import Optional +from aws_embedded_metrics.config import Configuration, get_config from botocore.config import Config +from aws.osml.gdal import GDALDigitalElevationModelTileFactory +from aws.osml.photogrammetry import DigitalElevationModel, ElevationModel, SRTMTileSet + @dataclass class ServiceConfig: @@ -15,21 +19,9 @@ class ServiceConfig: operate that are provided through ENV variables. Note that required env parameters are enforced by the implied schema validation as os.environ[] is used to fetch the values. Optional parameters are fetched using, os.getenv(), which returns None. 
- - The data schema is defined as follows: - region: (str) The AWS region where the Model Runner is deployed. - job_table: (str) The name of the job processing DDB table - region_request_table: (str) The name of the region request processing DDB table - feature_table: (str) The name of the feature aggregation DDB table - image_queue: (str) The name of the image processing SQS queue - region_queue: (str) The name of the region processing SQS queue - workers_per_cpu: (int) The number of workers to launch per CPU - image_timeout: (int) The number of seconds to wait for an image to be processed - region_timeout: (int) The number of seconds to wait for a region to be processed - cp_api_endpoint: (str) The URL of the control plane API endpoint """ - # required env configuration + # Required env configuration aws_region: str = os.environ["AWS_DEFAULT_REGION"] job_table: str = os.environ["JOB_TABLE"] region_request_table: str = os.environ["REGION_REQUEST_TABLE"] @@ -40,12 +32,13 @@ class ServiceConfig: workers_per_cpu: str = os.environ["WORKERS_PER_CPU"] workers: str = os.environ["WORKERS"] - # optional elevation data + # Optional elevation data elevation_data_location: Optional[str] = os.getenv("ELEVATION_DATA_LOCATION") elevation_data_extension: str = os.getenv("ELEVATION_DATA_EXTENSION", ".tif") elevation_data_version: str = os.getenv("ELEVATION_DATA_VERSION", "1arc_v3") + elevation_model: Optional[ElevationModel] = field(init=False, default=None) - # optional env configuration + # Optional env configuration image_status_topic: Optional[str] = os.getenv("IMAGE_STATUS_TOPIC") region_status_topic: Optional[str] = os.getenv("REGION_STATUS_TOPIC") cp_api_endpoint: Optional[str] = os.getenv("API_ENDPOINT") @@ -53,19 +46,56 @@ class ServiceConfig: os.getenv("SM_SELF_THROTTLING", "False") == "True" or os.getenv("SM_SELF_THROTTLING", "False") == "true" ) - # optional + defaulted configuration + # Optional + defaulted configuration region_size: str = os.getenv("REGION_SIZE", "(10240, 10240)") throttling_vcpu_scale_factor: str = os.getenv("THROTTLING_SCALE_FACTOR", "10") - # Time in seconds to set region request visibility timeout when a request - # is self throttled throttling_retry_timeout: str = os.getenv("THROTTLING_RETRY_TIMEOUT", "10") - # constant configuration + # Constant configuration kinesis_max_record_per_batch: str = "500" kinesis_max_record_size_batch: str = "5242880" # 5 MB in bytes kinesis_max_record_size: str = "1048576" # 1 MB in bytes ddb_max_item_size: str = "200000" + # Metrics configuration + metrics_config: Configuration = field(init=False, default=None) + + def __post_init__(self): + """ + Post-initialization method to set up the elevation model. + """ + self.elevation_model = self.create_elevation_model() + self.metrics_config = self.configure_metrics() + + def create_elevation_model(self) -> Optional[ElevationModel]: + """ + Create an elevation model if the relevant options are set in the service configuration. 
+ + :return: Optional[ElevationModel] = the elevation model or None if not configured + """ + if self.elevation_data_location: + return DigitalElevationModel( + SRTMTileSet( + version=self.elevation_data_version, + format_extension=self.elevation_data_extension, + ), + GDALDigitalElevationModelTileFactory(self.elevation_data_location), + ) + return None + + @staticmethod + def configure_metrics(): + """ + Embedded metrics configuration + """ + metrics_config = get_config() + metrics_config.service_name = "OSML" + metrics_config.log_group_name = "/aws/OSML/MRService" + metrics_config.namespace = "OSML/ModelRunner" + metrics_config.environment = "local" + + return metrics_config + @dataclass class BotoConfig: diff --git a/src/aws/osml/model_runner/common/__init__.py b/src/aws/osml/model_runner/common/__init__.py index 91985d46..cd6d30f7 100755 --- a/src/aws/osml/model_runner/common/__init__.py +++ b/src/aws/osml/model_runner/common/__init__.py @@ -10,7 +10,6 @@ from .exceptions import InvalidAssumedRoleException from .feature_utils import get_feature_image_bounds from .log_context import ThreadingLocalContextFilter -from .metrics_utils import build_embedded_metrics_config from .mr_post_processing import ( FeatureDistillationAlgorithm, FeatureDistillationAlgorithmType, diff --git a/src/aws/osml/model_runner/common/metrics_utils.py b/src/aws/osml/model_runner/common/metrics_utils.py deleted file mode 100755 index 58545052..00000000 --- a/src/aws/osml/model_runner/common/metrics_utils.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2023-2024 Amazon.com, Inc. or its affiliates. - -from aws_embedded_metrics.config import get_config - - -def build_embedded_metrics_config(): - """ - Embedded metrics configuration - """ - metrics_config = get_config() - metrics_config.service_name = "OSML" - metrics_config.log_group_name = "/aws/OSML/MRService" - metrics_config.namespace = "OSML/ModelRunner" - metrics_config.environment = "local" diff --git a/src/aws/osml/model_runner/database/exceptions.py b/src/aws/osml/model_runner/database/exceptions.py index 7bac3799..cee22e93 100755 --- a/src/aws/osml/model_runner/database/exceptions.py +++ b/src/aws/osml/model_runner/database/exceptions.py @@ -48,3 +48,7 @@ class UpdateRegionException(Exception): class CompleteRegionException(Exception): pass + + +class ProcessRegionException(Exception): + pass diff --git a/src/aws/osml/model_runner/database/feature_table.py b/src/aws/osml/model_runner/database/feature_table.py index 34fefb41..0c88dea3 100755 --- a/src/aws/osml/model_runner/database/feature_table.py +++ b/src/aws/osml/model_runner/database/feature_table.py @@ -21,6 +21,7 @@ from .ddb_helper import DDBHelper, DDBItem, DDBKey from .exceptions import AddFeaturesException +from .job_table import JobItem logger = logging.getLogger(__name__) @@ -245,3 +246,30 @@ def generate_tile_key(self, feature: Feature) -> str: min_y_index -= 1 return f"{feature['properties']['image_id']}-region-{min_x_index}:{max_x_index}:{min_y_index}:{max_y_index}" + + @metric_scope + def aggregate_features(self, image_request_item: JobItem, metrics: MetricsLogger = None) -> List[Feature]: + """ + For a given image processing job - aggregate all the features that were collected for it and + put them in the correct output sink locations. 
+ + :param image_request_item: JobItem = the image request + :param metrics: the current metrics scope + + :return: List[geojson.Feature] = the list of features + """ + if isinstance(metrics, MetricsLogger): + metrics.set_dimensions() + metrics.put_dimensions( + { + MetricLabels.OPERATION_DIMENSION: MetricLabels.FEATURE_AGG_OPERATION, + } + ) + + with Timer( + task_str="Aggregating Features", metric_name=MetricLabels.DURATION, logger=logger, metrics_logger=metrics + ): + features = self.get_features(image_request_item.image_id) + logger.debug(f"Total features aggregated: {len(features)}") + + return features diff --git a/src/aws/osml/model_runner/database/region_request_table.py b/src/aws/osml/model_runner/database/region_request_table.py index 9d29b37e..63cffc14 100755 --- a/src/aws/osml/model_runner/database/region_request_table.py +++ b/src/aws/osml/model_runner/database/region_request_table.py @@ -5,13 +5,24 @@ from dataclasses import dataclass from typing import Any, List, Optional +from aws_embedded_metrics import MetricsLogger +from aws_embedded_metrics.unit import Unit from dacite import from_dict from aws.osml.model_runner.api import RegionRequest from aws.osml.model_runner.common import ImageRegion, RequestStatus, TileState +from ..app_config import MetricLabels +from . import JobTable from .ddb_helper import DDBHelper, DDBItem, DDBKey -from .exceptions import CompleteRegionException, GetRegionRequestItemException, StartRegionException, UpdateRegionException +from .exceptions import ( + CompleteRegionException, + GetRegionRequestItemException, + ProcessRegionException, + StartRegionException, + UpdateRegionException, +) +from .job_table import JobItem logger = logging.getLogger(__name__) @@ -233,3 +244,31 @@ def add_tile(self, image_id: str, region_id: str, tile: ImageRegion, state: Tile except Exception as err: logger.error(f"Failed to append {state.value} {tile} to item region_id={region_id}: {str(err)}") raise UpdateRegionException(f"Failed to append {state.value} {tile} to item region_id={region_id}.") from err + + def fail_region_request( + self, + region_request_item: RegionRequestItem, + job_table: JobTable, + metrics: MetricsLogger = None, + ) -> JobItem: + """ + Fails a region if it failed to process successfully and updates the table accordingly before + raising an exception + + :param region_request_item: RegionRequestItem = the region request to update. + :param job_table: JobTable = the job table tracking image requests. + :param metrics: MetricsLogger = the metrics logger to use to report metrics. 
+ + :return: None + """ + if isinstance(metrics, MetricsLogger): + metrics.put_metric(MetricLabels.ERRORS, 1, str(Unit.COUNT.value)) + try: + region_status = RequestStatus.FAILED + region_request_item = self.table.complete_region_request(region_request_item, region_status) + self.table.process_event(region_request_item, region_status, "Completed region processing") + return job_table.complete_region_request(region_request_item.image_id, error=True) + except Exception as status_error: + logger.error("Unable to update region status in job table") + logger.exception(status_error) + raise ProcessRegionException("Failed to process image region!") diff --git a/test/test_app.py b/test/test_app.py index e00a9a0f..052cb52f 100755 --- a/test/test_app.py +++ b/test/test_app.py @@ -347,14 +347,14 @@ def test_process_region_request_exception(self, mock_process_tiles: Mock, mock_s self.model_runner.region_request_table = Mock(RegionRequestTable, autospec=True) self.model_runner.endpoint_statistics_table = Mock(EndpointStatisticsTable, autospec=True) self.model_runner.endpoint_statistics_table.current_in_progress_regions.return_value = 0 - self.model_runner.fail_region_request = Mock() + self.model_runner.region_request_table.fail_region_request = Mock() mock_setup_tile_workers.return_value = (Queue(), []) mock_process_tiles.side_effect = Exception("Mock processing exception") self.model_runner.process_region_request(self.region_request, region_request_item, raster_dataset, sensor_model) mock_process_tiles.assert_called_once() - self.model_runner.fail_region_request.assert_called_once() + self.model_runner.region_request_table.fail_region_request.assert_called_once() model = TEST_CONFIG["MODEL_ENDPOINT"] self.model_runner.endpoint_statistics_table.increment_region_count.assert_called_once_with(model) @@ -455,16 +455,14 @@ def test_create_elevation_model(self) -> None: import aws.osml.model_runner.app_config reload(aws.osml.model_runner.app_config) - reload(aws.osml.model_runner.app) from aws.osml.gdal.gdal_dem_tile_factory import GDALDigitalElevationModelTileFactory - from aws.osml.model_runner.app import ModelRunner from aws.osml.model_runner.app_config import ServiceConfig from aws.osml.photogrammetry.digital_elevation_model import DigitalElevationModel from aws.osml.photogrammetry.srtm_dem_tile_set import SRTMTileSet assert ServiceConfig.elevation_data_location == TEST_CONFIG["ELEVATION_DATA_LOCATION"] - - elevation_model = ModelRunner.create_elevation_model() + config = ServiceConfig() + elevation_model = config.create_elevation_model() assert elevation_model assert isinstance(elevation_model, DigitalElevationModel) assert isinstance(elevation_model.tile_set, SRTMTileSet) @@ -485,13 +483,12 @@ def test_create_elevation_model_disabled(self) -> None: import aws.osml.model_runner.app_config reload(aws.osml.model_runner.app_config) - reload(aws.osml.model_runner.app) - from aws.osml.model_runner.app import ModelRunner from aws.osml.model_runner.app_config import ServiceConfig assert ServiceConfig.elevation_data_location is None + config = ServiceConfig() + elevation_model = config.create_elevation_model() - elevation_model = ModelRunner.create_elevation_model() assert not elevation_model @staticmethod
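
A minimal sketch of how the refactored, instance-based ServiceConfig is exercised after this change. The environment values below are placeholders, and the remaining required variables (feature table, image/region queue names, and other values not visible in these hunks) must be set the same way before the import, because the class reads os.environ[] when app_config is imported:

import os

# Placeholder values; the variable names below appear in app_config.py in this patch.
# The other required variables (feature table, image/region queue names, etc.) must
# also be present before importing app_config, since os.environ[] is evaluated then.
os.environ["AWS_DEFAULT_REGION"] = "us-west-2"
os.environ["JOB_TABLE"] = "mr-job-table"
os.environ["REGION_REQUEST_TABLE"] = "mr-region-request-table"
os.environ["WORKERS_PER_CPU"] = "1"
os.environ["WORKERS"] = "4"
# Optional: when set, __post_init__ builds a DigitalElevationModel via create_elevation_model()
os.environ["ELEVATION_DATA_LOCATION"] = "s3://example-dem-bucket/srtm/"

from aws.osml.model_runner.app_config import ServiceConfig

config = ServiceConfig()

# The elevation model is now owned by the config, replacing ModelRunner.create_elevation_model()
print(type(config.elevation_model).__name__)  # DigitalElevationModel, or NoneType if unset

# configure_metrics() replaces the deleted common/metrics_utils.py helper
print(config.metrics_config.namespace)  # OSML/ModelRunner

ModelRunner then consumes the instance (config.image_queue, config.job_table, config.elevation_model, ...) instead of reading ServiceConfig class attributes directly.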
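The two helpers that moved off ModelRunner are now invoked on the table classes. A short sketch of the new call shapes, using the signatures shown in this patch; the wrapper functions and their argument names here are illustrative only:

from typing import List

from aws_embedded_metrics import MetricsLogger
from geojson import Feature

from aws.osml.model_runner.database.feature_table import FeatureTable
from aws.osml.model_runner.database.job_table import JobItem, JobTable
from aws.osml.model_runner.database.region_request_table import RegionRequestItem, RegionRequestTable


def collect_job_features(feature_table: FeatureTable, image_request_item: JobItem) -> List[Feature]:
    # Feature aggregation (and its Timer/metrics instrumentation) now lives on FeatureTable,
    # so complete_image_request simply asks the table for the job's features.
    return feature_table.aggregate_features(image_request_item)


def handle_failed_region(
    region_request_table: RegionRequestTable,
    job_table: JobTable,
    region_request_item: RegionRequestItem,
    metrics: MetricsLogger = None,
) -> JobItem:
    # Region failure handling now lives on RegionRequestTable and takes the JobTable
    # explicitly, since it also marks the parent image request as having a region error.
    return region_request_table.fail_region_request(region_request_item, job_table, metrics)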