Skip to content

Commit

Permalink
chore: refactoring app.py - phase 1
Browse files Browse the repository at this point in the history
  • Loading branch information
drduhe committed Oct 4, 2024
1 parent 620d368 commit fec578e
Show file tree
Hide file tree
Showing 8 changed files with 153 additions and 152 deletions.
132 changes: 25 additions & 107 deletions src/aws/osml/model_runner/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,8 @@
from osgeo import gdal
from osgeo.gdal import Dataset

from aws.osml.gdal import (
GDALConfigEnv,
GDALDigitalElevationModelTileFactory,
get_image_extension,
load_gdal_dataset,
set_gdal_default_configuration,
)
from aws.osml.photogrammetry import DigitalElevationModel, ElevationModel, ImageCoordinate, SensorModel, SRTMTileSet
from aws.osml.gdal import GDALConfigEnv, get_image_extension, load_gdal_dataset, set_gdal_default_configuration
from aws.osml.photogrammetry import ImageCoordinate, SensorModel

from .api import VALID_MODEL_HOSTING_OPTIONS, ImageRequest, InvalidImageRequestException, RegionRequest, SinkMode
from .app_config import MetricLabels, ServiceConfig
Expand All @@ -38,7 +32,6 @@
RequestStatus,
ThreadingLocalContextFilter,
Timer,
build_embedded_metrics_config,
get_credentials_for_assumed_role,
mr_post_processing_options_factory,
)
Expand All @@ -50,7 +43,6 @@
InvalidImageURLException,
LoadImageException,
ProcessImageException,
ProcessRegionException,
RetryableJobException,
SelfThrottledRegionException,
UnsupportedModelException,
Expand All @@ -61,12 +53,13 @@
from .status import ImageStatusMonitor, RegionStatusMonitor
from .tile_worker import TilingStrategy, VariableOverlapTilingStrategy, process_tiles, setup_tile_workers

# Set up metrics configuration
build_embedded_metrics_config()

# Set up logging configuration
logger = logging.getLogger(__name__)

# GDAL 4.0 will begin using exceptions as the default; at this point the software is written to assume
# no exceptions so we call this explicitly until the software can be updated to match.
gdal.UseExceptions()


class ModelRunner:
"""
Expand All @@ -81,17 +74,17 @@ def __init__(self, tiling_strategy: TilingStrategy = VariableOverlapTilingStrate
:param tiling_strategy: class defining how a larger image will be broken into chunks for processing
"""
self.config = ServiceConfig()
self.tiling_strategy = tiling_strategy
self.image_request_queue = RequestQueue(ServiceConfig.image_queue, wait_seconds=0)
self.image_request_queue = RequestQueue(self.config.image_queue, wait_seconds=0)
self.image_requests_iter = iter(self.image_request_queue)
self.job_table = JobTable(ServiceConfig.job_table)
self.region_request_table = RegionRequestTable(ServiceConfig.region_request_table)
self.endpoint_statistics_table = EndpointStatisticsTable(ServiceConfig.endpoint_statistics_table)
self.region_request_queue = RequestQueue(ServiceConfig.region_queue, wait_seconds=10)
self.job_table = JobTable(self.config.job_table)
self.region_request_table = RegionRequestTable(self.config.region_request_table)
self.endpoint_statistics_table = EndpointStatisticsTable(self.config.endpoint_statistics_table)
self.region_request_queue = RequestQueue(self.config.region_queue, wait_seconds=10)
self.region_requests_iter = iter(self.region_request_queue)
self.image_status_monitor = ImageStatusMonitor(ServiceConfig.image_status_topic)
self.region_status_monitor = RegionStatusMonitor(ServiceConfig.region_status_topic)
self.elevation_model = ModelRunner.create_elevation_model()
self.image_status_monitor = ImageStatusMonitor(self.config.image_status_topic)
self.region_status_monitor = RegionStatusMonitor(self.config.region_status_topic)
self.endpoint_utils = EndpointUtils()
self.running = False

Expand All @@ -111,24 +104,6 @@ def stop(self) -> None:
"""
self.running = False

@staticmethod
def create_elevation_model() -> Optional[ElevationModel]:
    """
    Build a DEM-backed elevation model when an elevation data location is present
    in the service configuration; otherwise no elevation model is used.

    :return: Optional[ElevationModel] = the elevation model or None if not configured
    """
    # No data location configured means elevation lookups are disabled entirely.
    if not ServiceConfig.elevation_data_location:
        return None

    tile_set = SRTMTileSet(
        version=ServiceConfig.elevation_data_version,
        format_extension=ServiceConfig.elevation_data_extension,
    )
    tile_factory = GDALDigitalElevationModelTileFactory(ServiceConfig.elevation_data_location)
    return DigitalElevationModel(tile_set, tile_factory)

def monitor_work_queues(self) -> None:
"""
Monitors SQS queues for ImageRequest and RegionRequest. The region work queue is checked first and will wait for
Expand Down Expand Up @@ -201,7 +176,7 @@ def monitor_work_queues(self) -> None:
except SelfThrottledRegionException:
self.region_request_queue.reset_request(
receipt_handle,
visibility_timeout=int(ServiceConfig.throttling_retry_timeout),
visibility_timeout=int(self.config.throttling_retry_timeout),
)
except Exception as err:
logger.error(f"There was a problem processing the region request: {err}")
Expand Down Expand Up @@ -260,7 +235,7 @@ def process_image_request(self, image_request: ImageRequest) -> None:
"""
image_request_item = None
try:
if ServiceConfig.self_throttling:
if self.config.self_throttling:
max_regions = self.endpoint_utils.calculate_max_regions(
image_request.model_name, image_request.model_invocation_role
)
Expand Down Expand Up @@ -302,7 +277,7 @@ def process_image_request(self, image_request: ImageRequest) -> None:

if sensor_model is None:
logging.warning(
f"Dataset {image_request_item.image_id} did not have a geo transform. Results are not geo-referenced."
f"Dataset {image_request_item.image_id} has no geo transform. Results are not geo-referenced."
)

# If we got valid outputs
Expand Down Expand Up @@ -460,7 +435,7 @@ def process_region_request(
}
)

if ServiceConfig.self_throttling:
if self.config.self_throttling:
max_regions = self.endpoint_utils.calculate_max_regions(
region_request.model_name, region_request.model_invocation_role
)
Expand All @@ -485,7 +460,7 @@ def process_region_request(
metrics_logger=metrics,
):
# Set up our threaded tile worker pool
tile_queue, tile_workers = setup_tile_workers(region_request, sensor_model, self.elevation_model)
tile_queue, tile_workers = setup_tile_workers(region_request, sensor_model, self.config.elevation_model)

# Process all our tiles
total_tile_count, failed_tile_count = process_tiles(
Expand Down Expand Up @@ -520,11 +495,11 @@ def process_region_request(
logger.error(failed_msg)
# update the table to take in that exception
region_request_item.message = failed_msg
return self.fail_region_request(region_request_item, metrics)
return self.region_request_table.fail_region_request(region_request_item, self.job_table, metrics)

finally:
# Decrement the endpoint region counter
if ServiceConfig.self_throttling:
if self.config.self_throttling:
self.endpoint_statistics_table.decrement_region_count(region_request.model_name)

def load_image_request(
Expand Down Expand Up @@ -579,7 +554,7 @@ def load_image_request(
# Calculate a set of ML engine-sized regions that we need to process for this image
# Region size chosen to break large images into pieces that can be handled by a
# single tile worker
region_size: ImageDimensions = ast.literal_eval(ServiceConfig.region_size)
region_size: ImageDimensions = ast.literal_eval(self.config.region_size)
tile_size: ImageDimensions = ast.literal_eval(image_request_item.tile_size)
if not image_request_item.tile_overlap:
minimum_overlap = (0, 0)
Expand Down Expand Up @@ -644,9 +619,9 @@ def complete_image_request(
logger.debug(f"Processing boundary from {roi} is {processing_bounds}")

# Set up our feature table to work with the region request
feature_table = FeatureTable(ServiceConfig.feature_table, region_request.tile_size, region_request.tile_overlap)
feature_table = FeatureTable(self.config.feature_table, region_request.tile_size, region_request.tile_overlap)
# Aggregate all the features from our job
features = self.aggregate_features(image_request_item, feature_table)
features = feature_table.aggregate_features(image_request_item)
features = self.select_features(image_request_item, features, processing_bounds)
features = self.add_properties_to_features(image_request_item, features)

Expand Down Expand Up @@ -704,32 +679,6 @@ def generate_image_processing_metrics(
if image_request_item.region_error > 0:
metrics.put_metric(MetricLabels.ERRORS, 1, str(Unit.COUNT.value))

def fail_region_request(
    self,
    region_request_item: RegionRequestItem,
    metrics: MetricsLogger = None,
) -> JobItem:
    """
    Mark a region request as FAILED: emit an error metric, record the terminal
    status in the region request table, publish a status event, and register the
    failed region against the parent image job.

    :param region_request_item: RegionRequestItem = the region request to update
    :param metrics: MetricsLogger = the metrics logger to use to report metrics.
    :return: JobItem = the parent image job item updated with the region failure
    :raises ProcessRegionException: if any of the status/table updates fail
    """
    if isinstance(metrics, MetricsLogger):
        metrics.put_metric(MetricLabels.ERRORS, 1, str(Unit.COUNT.value))
    try:
        region_status = RequestStatus.FAILED
        region_request_item = self.region_request_table.complete_region_request(region_request_item, region_status)
        # NOTE(review): the event message says "Completed region processing" even though the
        # status is FAILED — confirm this wording is intentional for downstream consumers.
        self.region_status_monitor.process_event(region_request_item, region_status, "Completed region processing")
        # error=True counts this region as failed on the image-level job record.
        return self.job_table.complete_region_request(region_request_item.image_id, error=True)
    except Exception as status_error:
        logger.error("Unable to update region status in job table")
        logger.exception(status_error)
        # NOTE(review): consider "raise ... from status_error" to preserve the cause chain.
        raise ProcessRegionException("Failed to process image region!")

def validate_model_hosting(self, image_request: JobItem):
"""
Validates that the image request is valid. If not, raises an exception.
Expand All @@ -747,37 +696,6 @@ def validate_model_hosting(self, image_request: JobItem):
)
raise UnsupportedModelException(error)

@staticmethod
@metric_scope
def aggregate_features(
    image_request_item: JobItem, feature_table: FeatureTable, metrics: MetricsLogger = None
) -> List[Feature]:
    """
    Collect every feature recorded for an image processing job by querying the
    feature table with the job's image id. Delivery to output sinks happens
    elsewhere; this method only gathers and returns the features.

    :param image_request_item: JobItem = the image request
    :param feature_table: FeatureTable = the table storing features from all completed regions
    :param metrics: the current metrics scope
    :return: List[geojson.Feature] = the list of features
    """
    if isinstance(metrics, MetricsLogger):
        # Reset then tag the metric scope so timing/errors are attributed to feature aggregation.
        metrics.set_dimensions()
        metrics.put_dimensions(
            {
                MetricLabels.OPERATION_DIMENSION: MetricLabels.FEATURE_AGG_OPERATION,
            }
        )

    # Timer reports the aggregation duration to the metrics scope above.
    with Timer(
        task_str="Aggregating Features", metric_name=MetricLabels.DURATION, logger=logger, metrics_logger=metrics
    ):
        features = feature_table.get_features(image_request_item.image_id)
        logger.debug(f"Total features aggregated: {len(features)}")

    return features

@metric_scope
def select_features(
self,
Expand Down Expand Up @@ -826,7 +744,7 @@ def select_features(
feature_distillation_option = FeatureDistillationDeserializer().deserialize(feature_distillation_option_dict)
feature_selector = FeatureSelector(feature_distillation_option)

region_size = ast.literal_eval(ServiceConfig.region_size)
region_size = ast.literal_eval(self.config.region_size)
tile_size = ast.literal_eval(image_request_item.tile_size)
overlap = ast.literal_eval(image_request_item.tile_overlap)
deduped_features = self.tiling_strategy.cleanup_duplicate_features(
Expand Down
70 changes: 50 additions & 20 deletions src/aws/osml/model_runner/app_config.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
# Copyright 2023-2024 Amazon.com, Inc. or its affiliates.

import os
from dataclasses import dataclass
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional

from aws_embedded_metrics.config import Configuration, get_config
from botocore.config import Config

from aws.osml.gdal import GDALDigitalElevationModelTileFactory
from aws.osml.photogrammetry import DigitalElevationModel, ElevationModel, SRTMTileSet


@dataclass
class ServiceConfig:
Expand All @@ -15,21 +19,9 @@ class ServiceConfig:
operate that are provided through ENV variables. Note that required env parameters are enforced by the implied
schema validation as os.environ[] is used to fetch the values. Optional parameters are fetched using, os.getenv(),
which returns None.
The data schema is defined as follows:
region: (str) The AWS region where the Model Runner is deployed.
job_table: (str) The name of the job processing DDB table
region_request_table: (str) The name of the region request processing DDB table
feature_table: (str) The name of the feature aggregation DDB table
image_queue: (str) The name of the image processing SQS queue
region_queue: (str) The name of the region processing SQS queue
workers_per_cpu: (int) The number of workers to launch per CPU
image_timeout: (int) The number of seconds to wait for an image to be processed
region_timeout: (int) The number of seconds to wait for a region to be processed
cp_api_endpoint: (str) The URL of the control plane API endpoint
"""

# required env configuration
# Required env configuration
aws_region: str = os.environ["AWS_DEFAULT_REGION"]
job_table: str = os.environ["JOB_TABLE"]
region_request_table: str = os.environ["REGION_REQUEST_TABLE"]
Expand All @@ -40,32 +32,70 @@ class ServiceConfig:
workers_per_cpu: str = os.environ["WORKERS_PER_CPU"]
workers: str = os.environ["WORKERS"]

# optional elevation data
# Optional elevation data
elevation_data_location: Optional[str] = os.getenv("ELEVATION_DATA_LOCATION")
elevation_data_extension: str = os.getenv("ELEVATION_DATA_EXTENSION", ".tif")
elevation_data_version: str = os.getenv("ELEVATION_DATA_VERSION", "1arc_v3")
elevation_model: Optional[ElevationModel] = field(init=False, default=None)

# optional env configuration
# Optional env configuration
image_status_topic: Optional[str] = os.getenv("IMAGE_STATUS_TOPIC")
region_status_topic: Optional[str] = os.getenv("REGION_STATUS_TOPIC")
cp_api_endpoint: Optional[str] = os.getenv("API_ENDPOINT")
self_throttling: bool = (
os.getenv("SM_SELF_THROTTLING", "False") == "True" or os.getenv("SM_SELF_THROTTLING", "False") == "true"
)

# optional + defaulted configuration
# Optional + defaulted configuration
region_size: str = os.getenv("REGION_SIZE", "(10240, 10240)")
throttling_vcpu_scale_factor: str = os.getenv("THROTTLING_SCALE_FACTOR", "10")
# Time in seconds to set region request visibility timeout when a request
# is self throttled
throttling_retry_timeout: str = os.getenv("THROTTLING_RETRY_TIMEOUT", "10")

# constant configuration
# Constant configuration
kinesis_max_record_per_batch: str = "500"
kinesis_max_record_size_batch: str = "5242880" # 5 MB in bytes
kinesis_max_record_size: str = "1048576" # 1 MB in bytes
ddb_max_item_size: str = "200000"

# Metrics configuration
metrics_config: Configuration = field(init=False, default=None)

def __post_init__(self):
    """
    Post-initialization hook: derive the fields that depend on the env-provided
    configuration — the optional elevation model and the embedded-metrics config —
    after the dataclass's env-backed fields have been populated.
    """
    self.elevation_model = self.create_elevation_model()
    self.metrics_config = self.configure_metrics()

def create_elevation_model(self) -> Optional[ElevationModel]:
    """
    Build a DEM-backed elevation model when an elevation data location is set on
    this configuration; otherwise elevation lookups are disabled.

    :return: Optional[ElevationModel] = the elevation model or None if not configured
    """
    # Guard clause: without a data location there is nothing to build.
    if not self.elevation_data_location:
        return None

    tile_set = SRTMTileSet(
        version=self.elevation_data_version,
        format_extension=self.elevation_data_extension,
    )
    tile_factory = GDALDigitalElevationModelTileFactory(self.elevation_data_location)
    return DigitalElevationModel(tile_set, tile_factory)

@staticmethod
def configure_metrics():
    """
    Configure the embedded-metrics library for this service and return the
    resulting configuration object.

    :return: the aws_embedded_metrics Configuration singleton, populated with
        the service name, log group, and namespace used by Model Runner
    """
    config = get_config()
    config.service_name = "OSML"
    config.log_group_name = "/aws/OSML/MRService"
    config.namespace = "OSML/ModelRunner"
    # NOTE(review): environment is hard-coded to "local" — confirm this is intended
    # for deployed environments and not just development.
    config.environment = "local"
    return config


@dataclass
class BotoConfig:
Expand Down
1 change: 0 additions & 1 deletion src/aws/osml/model_runner/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from .exceptions import InvalidAssumedRoleException
from .feature_utils import get_feature_image_bounds
from .log_context import ThreadingLocalContextFilter
from .metrics_utils import build_embedded_metrics_config
from .mr_post_processing import (
FeatureDistillationAlgorithm,
FeatureDistillationAlgorithmType,
Expand Down
Loading

0 comments on commit fec578e

Please sign in to comment.