Skip to content

Commit

Permalink
refactor: break out request helpers into new classes from app.py
Browse files Browse the repository at this point in the history
  • Loading branch information
drduhe committed Oct 7, 2024
1 parent ceb7168 commit 88777dc
Show file tree
Hide file tree
Showing 19 changed files with 531 additions and 477 deletions.
2 changes: 1 addition & 1 deletion bin/oversightml-mr-entry-point.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from codeguru_profiler_agent import Profiler
from pythonjsonlogger import jsonlogger

from aws.osml.model_runner.app import ModelRunner
from aws.osml.model_runner.common import ThreadingLocalContextFilter
from aws.osml.model_runner.model_runner import ModelRunner


def handler_stop_signals(signal_num: int, frame: Optional[FrameType], model_runner: ModelRunner) -> None:
Expand Down
2 changes: 1 addition & 1 deletion src/aws/osml/model_runner/api/image_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import shapely.wkt
from shapely.geometry.base import BaseGeometry

from aws.osml.model_runner.app_config import BotoConfig
from aws.osml.model_runner.common import (
FeatureDistillationAlgorithm,
FeatureDistillationNMS,
Expand All @@ -21,6 +20,7 @@
deserialize_post_processing_list,
get_credentials_for_assumed_role,
)
from aws.osml.model_runner.config import BotoConfig

from .exceptions import InvalidS3ObjectException
from .inference import ModelInvokeMode
Expand Down
2 changes: 1 addition & 1 deletion src/aws/osml/model_runner/common/credentials_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import boto3

from aws.osml.model_runner.app_config import BotoConfig
from aws.osml.model_runner.config import BotoConfig

from .exceptions import InvalidAssumedRoleException

Expand Down
2 changes: 1 addition & 1 deletion src/aws/osml/model_runner/common/endpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import boto3
from cachetools import TTLCache, cachedmethod

from aws.osml.model_runner.app_config import BotoConfig, ServiceConfig
from aws.osml.model_runner.config import BotoConfig, ServiceConfig

from .credentials_utils import get_credentials_for_assumed_role

Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion src/aws/osml/model_runner/database/ddb_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import boto3
from boto3.dynamodb.conditions import Key

from aws.osml.model_runner.app_config import BotoConfig
from aws.osml.model_runner.config import BotoConfig

from .exceptions import DDBBatchWriteException, DDBUpdateException

Expand Down
2 changes: 1 addition & 1 deletion src/aws/osml/model_runner/database/feature_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
from dacite import from_dict
from geojson import Feature

from aws.osml.model_runner.app_config import MetricLabels, ServiceConfig
from aws.osml.model_runner.common import ImageDimensions, Timer, get_feature_image_bounds
from aws.osml.model_runner.config import MetricLabels, ServiceConfig

from .ddb_helper import DDBHelper, DDBItem, DDBKey
from .exceptions import AddFeaturesException
Expand Down
593 changes: 133 additions & 460 deletions src/aws/osml/model_runner/app.py → ...sml/model_runner/image_request_handler.py
100755 → 100644

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/aws/osml/model_runner/inference/http_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
from urllib3.util.retry import Retry

from aws.osml.model_runner.api import ModelInvokeMode
from aws.osml.model_runner.app_config import MetricLabels
from aws.osml.model_runner.common import Timer
from aws.osml.model_runner.config import MetricLabels

from .detector import Detector
from .endpoint_builder import FeatureEndpointBuilder
Expand Down
2 changes: 1 addition & 1 deletion src/aws/osml/model_runner/inference/sm_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from geojson import FeatureCollection

from aws.osml.model_runner.api import ModelInvokeMode
from aws.osml.model_runner.app_config import BotoConfig, MetricLabels
from aws.osml.model_runner.common import Timer
from aws.osml.model_runner.config import BotoConfig, MetricLabels

from .detector import Detector
from .endpoint_builder import FeatureEndpointBuilder
Expand Down
179 changes: 179 additions & 0 deletions src/aws/osml/model_runner/model_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
# Copyright 2023-2024 Amazon.com, Inc. or its affiliates.

import logging

from osgeo import gdal

from aws.osml.gdal import load_gdal_dataset, set_gdal_default_configuration

from .api import ImageRequest, InvalidImageRequestException, RegionRequest
from .common import EndpointUtils, ThreadingLocalContextFilter
from .config import ServiceConfig
from .database import EndpointStatisticsTable, JobItem, JobTable, RegionRequestItem, RegionRequestTable
from .exceptions import RetryableJobException, SelfThrottledRegionException
from .image_request_handler import ImageRequestHandler
from .queue import RequestQueue
from .region_request_handler import RegionRequestHandler
from .status import ImageStatusMonitor, RegionStatusMonitor
from .tile_worker import TilingStrategy, VariableOverlapTilingStrategy

# Set up logging configuration
logger = logging.getLogger(__name__)

# GDAL 4.0 will begin using exceptions as the default; at this point the software is written to assume
# no exceptions so we call this explicitly until the software can be updated to match.
gdal.UseExceptions()


class ModelRunner:
"""
Main class for operating the ModelRunner application. It monitors input queues for processing requests,
decomposes the image into a set of smaller regions and tiles, invokes an ML model endpoint with each tile, and
finally aggregates all the results into a single output which can be deposited into the desired output sinks.
"""

def __init__(self, tiling_strategy: TilingStrategy = VariableOverlapTilingStrategy()) -> None:
"""
Initialize a model runner with the injectable behaviors.
:param tiling_strategy: class defining how a larger image will be broken into chunks for processing
"""
self.config = ServiceConfig()
self.tiling_strategy = tiling_strategy
self.image_request_queue = RequestQueue(self.config.image_queue, wait_seconds=0)
self.image_requests_iter = iter(self.image_request_queue)
self.job_table = JobTable(self.config.job_table)
self.region_request_table = RegionRequestTable(self.config.region_request_table)
self.endpoint_statistics_table = EndpointStatisticsTable(self.config.endpoint_statistics_table)
self.region_request_queue = RequestQueue(self.config.region_queue, wait_seconds=10)
self.region_requests_iter = iter(self.region_request_queue)
self.image_status_monitor = ImageStatusMonitor(self.config.image_status_topic)
self.region_status_monitor = RegionStatusMonitor(self.config.region_status_topic)
self.endpoint_utils = EndpointUtils()
self.running = False

# Pass dependencies into RegionRequestHandler
self.region_request_handler = RegionRequestHandler(
region_request_table=self.region_request_table,
job_table=self.job_table,
region_status_monitor=self.region_status_monitor,
endpoint_statistics_table=self.endpoint_statistics_table,
tiling_strategy=self.tiling_strategy,
region_request_queue=self.region_request_queue,
endpoint_utils=self.endpoint_utils,
config=self.config,
)

# Pass dependencies into ImageRequestHandler
self.image_request_handler = ImageRequestHandler(
job_table=self.job_table,
image_status_monitor=self.image_status_monitor,
endpoint_statistics_table=self.endpoint_statistics_table,
tiling_strategy=self.tiling_strategy,
region_request_queue=self.region_request_queue,
region_request_table=self.region_request_table,
endpoint_utils=self.endpoint_utils,
config=self.config,
region_request_handler=self.region_request_handler,
)

def run(self) -> None:
"""
Starts ModelRunner in a loop that continuously monitors the image work queue and region work queue.
"""
self.monitor_work_queues()

def stop(self) -> None:
"""
Stops ModelRunner by setting the global run variable to False.
"""
self.running = False

def monitor_work_queues(self) -> None:
"""
Monitors SQS queues for ImageRequest and RegionRequest.
"""
self.running = True
set_gdal_default_configuration()

try:
while self.running:
logger.debug("Checking work queue for regions to process ...")
(receipt_handle, region_request_attributes) = next(self.region_requests_iter)
ThreadingLocalContextFilter.set_context(region_request_attributes)

if region_request_attributes is not None:
try:
region_request = RegionRequest(region_request_attributes)

if "s3:/" in region_request.image_url:
ImageRequest.validate_image_path(region_request.image_url, region_request.image_read_role)
image_path = region_request.image_url.replace("s3:/", "/vsis3", 1)
else:
image_path = region_request.image_url

raster_dataset, sensor_model = load_gdal_dataset(image_path)
image_format = str(raster_dataset.GetDriver().ShortName).upper()

region_request_item = self.region_request_table.get_region_request(
region_request.region_id, region_request.image_id
)
if region_request_item is None:
region_request_item = RegionRequestItem.from_region_request(region_request)
self.region_request_table.start_region_request(region_request_item)
logging.debug(
f"Adding region request: image id: {region_request_item.image_id} - "
f"region id: {region_request_item.region_id}"
)

image_request_item = self.region_request_handler.process_region_request(
region_request, region_request_item, raster_dataset, sensor_model
)

if self.job_table.is_image_request_complete(image_request_item):
self.image_request_handler.complete_image_request(
region_request, image_format, raster_dataset, sensor_model
)

self.region_request_queue.finish_request(receipt_handle)
except RetryableJobException:
self.region_request_queue.reset_request(receipt_handle, visibility_timeout=0)
except SelfThrottledRegionException:
self.region_request_queue.reset_request(
receipt_handle, visibility_timeout=int(self.config.throttling_retry_timeout)
)
except Exception as err:
logger.error(f"There was a problem processing the region request: {err}")
self.region_request_queue.finish_request(receipt_handle)
else:
logger.debug("Checking work queue for images to process ...")
(receipt_handle, image_request_message) = next(self.image_requests_iter)

if image_request_message is not None:
image_request = None
try:
image_request = ImageRequest.from_external_message(image_request_message)
ThreadingLocalContextFilter.set_context(image_request.__dict__)

if not image_request.is_valid():
error = f"Invalid image request: {image_request_message}"
logger.exception(error)
raise InvalidImageRequestException(error)

self.image_request_handler.process_image_request(image_request)
self.image_request_queue.finish_request(receipt_handle)
except RetryableJobException:
self.image_request_queue.reset_request(receipt_handle, visibility_timeout=0)
except Exception as err:
logger.error(f"There was a problem processing the image request: {err}")
min_image_id = image_request.image_id if image_request and image_request.image_id else ""
min_job_id = image_request.job_id if image_request and image_request.job_id else ""
minimal_job_item = JobItem(
image_id=min_image_id,
job_id=min_job_id,
processing_duration=0,
)
self.image_request_handler.fail_image_request(minimal_job_item, err)
self.image_request_queue.finish_request(receipt_handle)
finally:
self.running = False
2 changes: 1 addition & 1 deletion src/aws/osml/model_runner/queue/request_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import boto3
from botocore.exceptions import ClientError

from aws.osml.model_runner.app_config import BotoConfig
from aws.osml.model_runner.config import BotoConfig


class RequestQueue:
Expand Down
Loading

0 comments on commit 88777dc

Please sign in to comment.