diff --git a/src/aws/osml/model_runner/app_config.py b/src/aws/osml/model_runner/app_config.py index 25112afa..4e93a343 100755 --- a/src/aws/osml/model_runner/app_config.py +++ b/src/aws/osml/model_runner/app_config.py @@ -65,8 +65,6 @@ class ServiceConfig: kinesis_max_record_size_batch: str = "5242880" # 5 MB in bytes kinesis_max_record_size: str = "1048576" # 1 MB in bytes ddb_max_item_size: str = "200000" - noop_bounds_model_name: str = "NOOP_BOUNDS_MODEL_NAME" - noop_geom_model_name: str = "NOOP_GEOM_MODEL_NAME" @dataclass diff --git a/src/aws/osml/model_runner/inference/http_detector.py b/src/aws/osml/model_runner/inference/http_detector.py index 7fc4f4ab..2bf46250 100644 --- a/src/aws/osml/model_runner/inference/http_detector.py +++ b/src/aws/osml/model_runner/inference/http_detector.py @@ -16,12 +16,11 @@ from urllib3.util.retry import Retry from aws.osml.model_runner.api import ModelInvokeMode -from aws.osml.model_runner.app_config import MetricLabels, ServiceConfig +from aws.osml.model_runner.app_config import MetricLabels from aws.osml.model_runner.common import Timer from .detector import Detector from .endpoint_builder import FeatureEndpointBuilder -from .feature_utils import create_mock_feature_collection logger = logging.getLogger(__name__) @@ -153,21 +152,16 @@ def find_features(self, payload: BufferedReader, metrics: MetricsLogger) -> Feat logger=logger, metrics_logger=metrics, ): - if self.endpoint == ServiceConfig.noop_geom_model_name: - return create_mock_feature_collection(payload, geom=True) - elif self.endpoint == ServiceConfig.noop_bounds_model_name: - return create_mock_feature_collection(payload) - else: - response = self.http_pool.request( - method="POST", - url=self.endpoint, - body=payload, - ) - retry_count = self.retry.retry_counts - if isinstance(metrics, MetricsLogger): - metrics.put_metric(MetricLabels.RETRIES, retry_count, str(Unit.COUNT.value)) - - return geojson.loads(response.data.decode("utf-8")) + response = self.http_pool.request( + method="POST", + url=self.endpoint, + body=payload, + ) + retry_count = self.retry.retry_counts + if isinstance(metrics, MetricsLogger): + metrics.put_metric(MetricLabels.RETRIES, retry_count, str(Unit.COUNT.value)) + + return geojson.loads(response.data.decode("utf-8")) except RetryError as err: if isinstance(metrics, MetricsLogger): diff --git a/src/aws/osml/model_runner/inference/sm_detector.py b/src/aws/osml/model_runner/inference/sm_detector.py index 7df7069b..093518b0 100644 --- a/src/aws/osml/model_runner/inference/sm_detector.py +++ b/src/aws/osml/model_runner/inference/sm_detector.py @@ -14,12 +14,11 @@ from geojson import FeatureCollection from aws.osml.model_runner.api import ModelInvokeMode -from aws.osml.model_runner.app_config import BotoConfig, MetricLabels, ServiceConfig +from aws.osml.model_runner.app_config import BotoConfig, MetricLabels from aws.osml.model_runner.common import Timer from .detector import Detector from .endpoint_builder import FeatureEndpointBuilder -from .feature_utils import create_mock_feature_collection logger = logging.getLogger(__name__) @@ -100,20 +99,14 @@ def find_features(self, payload: BufferedReader, metrics: MetricsLogger) -> Feat logger=logger, metrics_logger=metrics, ): - # Handle mock models for testing purposes - if self.endpoint == ServiceConfig.noop_bounds_model_name: - return create_mock_feature_collection(payload) - elif self.endpoint == ServiceConfig.noop_geom_model_name: - return create_mock_feature_collection(payload, geom=True) - else: - # Invoke the real SageMaker model endpoint - model_response = self.sm_client.invoke_endpoint(EndpointName=self.endpoint, Body=payload) - retry_count = model_response.get("ResponseMetadata", {}).get("RetryAttempts", 0) - if isinstance(metrics, MetricsLogger): - metrics.put_metric(MetricLabels.RETRIES, retry_count, str(Unit.COUNT.value)) - - # Parse the model's response as a geojson FeatureCollection - return geojson.loads(model_response.get("Body").read()) + # Invoke the real SageMaker model endpoint + model_response = self.sm_client.invoke_endpoint(EndpointName=self.endpoint, Body=payload) + retry_count = model_response.get("ResponseMetadata", {}).get("RetryAttempts", 0) + if isinstance(metrics, MetricsLogger): + metrics.put_metric(MetricLabels.RETRIES, retry_count, str(Unit.COUNT.value)) + + # Parse the model's response as a geojson FeatureCollection + return geojson.loads(model_response.get("Body").read()) except ClientError as ce: error_code = ce.response.get("Error", {}).get("Code") diff --git a/src/aws/osml/model_runner/tile_worker/tile_worker.py b/src/aws/osml/model_runner/tile_worker/tile_worker.py index b412ed9c..7b6badfa 100755 --- a/src/aws/osml/model_runner/tile_worker/tile_worker.py +++ b/src/aws/osml/model_runner/tile_worker/tile_worker.py @@ -108,7 +108,7 @@ def process_tile(self, image_info: Dict, metrics: MetricsLogger = None) -> None: ) except Exception as e: self.failed_tile_count += 1 - logging.error(f"Failed to process region tile with error: {e.with_traceback()}") + logging.error(f"Failed to process region tile with error: {e}", exc_info=True) self.region_request_table.add_tile( image_info.get("image_id"), image_info.get("region_id"), image_info.get("region"), TileState.FAILED ) diff --git a/test/aws/osml/model_runner/inference/test_feature_utils.py b/test/aws/osml/model_runner/inference/test_feature_utils.py index 0c4ee670..5209a00c 100755 --- a/test/aws/osml/model_runner/inference/test_feature_utils.py +++ b/test/aws/osml/model_runner/inference/test_feature_utils.py @@ -8,6 +8,11 @@ import numpy as np import pytest import shapely +from osgeo import gdal + +# GDAL 4.0 will begin using exceptions as the default; at this point the software is written to assume +# no exceptions so we call this explicitly until the software can be updated to match. +gdal.DontUseExceptions() class TestFeatureUtils(unittest.TestCase): @@ -100,12 +105,13 @@ def test_calculate_processing_bounds_full_image(self): chip_lr = sensor_model.image_to_world(ImageCoordinate([101, 101])) min_vals = np.minimum(chip_ul.coordinate, chip_lr.coordinate) max_vals = np.maximum(chip_ul.coordinate, chip_lr.coordinate) - polygon_coords = [] - polygon_coords.append([degrees(min_vals[0]), degrees(min_vals[1])]) - polygon_coords.append([degrees(min_vals[0]), degrees(max_vals[1])]) - polygon_coords.append([degrees(max_vals[0]), degrees(max_vals[1])]) - polygon_coords.append([degrees(max_vals[0]), degrees(min_vals[1])]) - polygon_coords.append([degrees(min_vals[0]), degrees(min_vals[1])]) + polygon_coords = [ + [degrees(min_vals[0]), degrees(min_vals[1])], + [degrees(min_vals[0]), degrees(max_vals[1])], + [degrees(max_vals[0]), degrees(max_vals[1])], + [degrees(max_vals[0]), degrees(min_vals[1])], + [degrees(min_vals[0]), degrees(min_vals[1])], + ] roi = shapely.geometry.Polygon(polygon_coords) processing_bounds = calculate_processing_bounds(ds, roi, sensor_model) @@ -122,12 +128,13 @@ def test_calculate_processing_bounds_intersect(self): chip_lr = sensor_model.image_to_world(ImageCoordinate([50, 50])) min_vals = np.minimum(chip_ul.coordinate, chip_lr.coordinate) max_vals = np.maximum(chip_ul.coordinate, chip_lr.coordinate) - polygon_coords = [] - polygon_coords.append([degrees(min_vals[0]), degrees(min_vals[1])]) - polygon_coords.append([degrees(min_vals[0]), degrees(max_vals[1])]) - polygon_coords.append([degrees(max_vals[0]), degrees(max_vals[1])]) - polygon_coords.append([degrees(max_vals[0]), degrees(min_vals[1])]) - polygon_coords.append([degrees(min_vals[0]), degrees(min_vals[1])]) + polygon_coords = [ + [degrees(min_vals[0]), degrees(min_vals[1])], + [degrees(min_vals[0]), degrees(max_vals[1])], + [degrees(max_vals[0]), degrees(max_vals[1])], + [degrees(max_vals[0]), degrees(min_vals[1])], + [degrees(min_vals[0]), degrees(min_vals[1])], + ] roi = shapely.geometry.Polygon(polygon_coords) processing_bounds = calculate_processing_bounds(ds, roi, sensor_model) diff --git a/test/aws/osml/model_runner/inference/test_sm_detector.py b/test/aws/osml/model_runner/inference/test_sm_detector.py index f6887061..c5b4213c 100755 --- a/test/aws/osml/model_runner/inference/test_sm_detector.py +++ b/test/aws/osml/model_runner/inference/test_sm_detector.py @@ -12,7 +12,7 @@ from botocore.exceptions import ClientError from botocore.stub import ANY, Stubber -MOCK_RESPONSE = { +MOCK_MODEL_RESPONSE = { "Body": io.StringIO( json.dumps( { @@ -21,12 +21,13 @@ { "type": "Feature", "id": "1cc5e6d6-e12f-430d-adf0-8d2276ce8c5a", - "geometry": {"type": "Point", "coordinates": [0.0, 0.0]}, + "geometry": {"type": "Point", "coordinates": [-43.679691, -22.941953]}, "properties": { "bounds_imcoords": [429, 553, 440, 561], - "feature_types": {"ground_motor_passenger_vehicle": 0.2961518168449402}, + "geom_imcoords": [[429, 553], [429, 561], [440, 561], [440, 553], [429, 553]], + "featureClasses": [{"iri": "ground_motor_passenger_vehicle", "score": 0.2961518168449402}], "detection_score": 0.2961518168449402, - "image_id": "test-image-id", + "image_id": "2pp5e6d6-e12f-430d-adf0-8d2276ceadf0", }, } ], @@ -68,7 +69,7 @@ def test_find_features(self): sm_runtime_stub.add_response( "invoke_endpoint", expected_params={"EndpointName": "test-endpoint", "Body": ANY}, - service_response=MOCK_RESPONSE, + service_response=MOCK_MODEL_RESPONSE, ) sm_runtime_stub.activate() @@ -87,7 +88,7 @@ def test_find_features_throw_json_exception(self): sm_runtime_stub.add_response( "invoke_endpoint", expected_params={"EndpointName": "test-endpoint", "Body": ANY}, - service_response=MOCK_RESPONSE, + service_response=MOCK_MODEL_RESPONSE, ) sm_runtime_stub.add_client_error(str(JSONDecodeError)) sm_runtime_stub.activate() @@ -106,7 +107,7 @@ def test_find_features_throw_client_exception(self): sm_client_stub.add_response( "invoke_endpoint", expected_params={"EndpointName": "test-endpoint", "Body": ANY}, - service_response=MOCK_RESPONSE, + service_response=MOCK_MODEL_RESPONSE, ) sm_client_stub.add_client_error(str(ClientError({"Error": {"Code": 500, "Message": "ClientError"}}, "update_item"))) feature_detector.sm_client.invoke_endpoint = Mock( diff --git a/test/aws/osml/model_runner/status/test_image_status_monitor.py b/test/aws/osml/model_runner/status/test_image_status_monitor.py index c0ace464..072630a8 100644 --- a/test/aws/osml/model_runner/status/test_image_status_monitor.py +++ b/test/aws/osml/model_runner/status/test_image_status_monitor.py @@ -2,7 +2,6 @@ import os import unittest -from decimal import Decimal import boto3 from moto import mock_aws @@ -29,10 +28,10 @@ def setUp(self): self.test_job_item = JobItem( job_id="test-job", image_id="test-image", - processing_duration=Decimal(1000), - region_success=Decimal(5), - region_error=Decimal(0), - region_count=Decimal(5), + processing_duration=1000, + region_success=5, + region_error=0, + region_count=5, ) def test_process_event_success(self): @@ -56,9 +55,9 @@ def test_process_event_failure(self): job_id=None, image_id="test-image", processing_duration=None, - region_success=Decimal(0), - region_error=Decimal(5), - region_count=Decimal(5), + region_success=0, + region_error=5, + region_count=5, ) status = RequestStatus.FAILED message = "Processing failed." diff --git a/test/aws/osml/model_runner/status/test_region_status_monitor.py b/test/aws/osml/model_runner/status/test_region_status_monitor.py index aafd80f4..603dc48b 100644 --- a/test/aws/osml/model_runner/status/test_region_status_monitor.py +++ b/test/aws/osml/model_runner/status/test_region_status_monitor.py @@ -2,7 +2,6 @@ import os import unittest -from decimal import Decimal import boto3 from moto import mock_aws @@ -30,12 +29,12 @@ def setUp(self): job_id="test-job", image_id="test-image", region_id="test-region", - processing_duration=Decimal(1000), - failed_tile_count=Decimal(0), + processing_duration=1000, + failed_tile_count=0, failed_tiles=[], - succeeded_tile_count=Decimal(0), + succeeded_tile_count=0, succeeded_tiles=[], - total_tiles=Decimal(10), + total_tiles=10, ) def test_process_event_success(self): @@ -61,7 +60,7 @@ def test_process_event_failure(self): region_id="test-region", processing_duration=None, # Required field failed_tiles=[], - total_tiles=Decimal(10), + total_tiles=10, ) status = RequestStatus.FAILED message = "Processing failed." diff --git a/test/configuration.py b/test/configuration.py index 86c87f2a..70eb5c72 100755 --- a/test/configuration.py +++ b/test/configuration.py @@ -1,68 +1,80 @@ # Copyright 2023-2024 Amazon.com, Inc. or its affiliates. +import io +import json +from unittest.mock import Mock -# Environment Variable Configuration -TEST_ENV_CONFIG = { - # ModelRunner test config - "AWS_DEFAULT_REGION": "us-west-2", - "WORKERS": "4", - "WORKERS_PER_CPU": "1", - "JOB_TABLE": "TEST-JOB-TABLE", - "ENDPOINT_TABLE": "TEST-ENDPOINT-STATS-TABLE", - "FEATURE_TABLE": "TEST-FEATURE-TABLE", - "REGION_REQUEST_TABLE": "TEST-REGION-REQUEST-TABLE", - "IMAGE_QUEUE": "TEST-IMAGE-QUEUE", - "REGION_QUEUE": "TEST-REGION-QUEUE", - "IMAGE_STATUS_TOPIC": "TEST-IMAGE-STATUS-TOPIC", - "REGION_STATUS_TOPIC": "TEST-REGION-STATUS-TOPIC", - # Fake cred info for MOTO - "AWS_ACCESS_KEY_ID": "testing", - "AWS_SECRET_ACCESS_KEY": "testing", - "AWS_SECURITY_TOKEN": "testing", - "AWS_SESSION_TOKEN": "testing", - "SM_SELF_THROTTLING": "true", -} -# Fake account ID -TEST_ACCOUNT = "123456789123" - -# DDB Configurations -TEST_JOB_TABLE_KEY_SCHEMA = [{"AttributeName": "image_id", "KeyType": "HASH"}] -TEST_JOB_TABLE_ATTRIBUTE_DEFINITIONS = [{"AttributeName": "image_id", "AttributeType": "S"}] - -TEST_ENDPOINT_TABLE_KEY_SCHEMA = [{"AttributeName": "endpoint", "KeyType": "HASH"}] -TEST_ENDPOINT_TABLE_ATTRIBUTE_DEFINITIONS = [ - {"AttributeName": "endpoint", "AttributeType": "S"}, -] - -TEST_FEATURE_TABLE_KEY_SCHEMA = [ - {"AttributeName": "hash_key", "KeyType": "HASH"}, - {"AttributeName": "range_key", "KeyType": "RANGE"}, -] -TEST_FEATURE_TABLE_ATTRIBUTE_DEFINITIONS = [ - {"AttributeName": "hash_key", "AttributeType": "S"}, - {"AttributeName": "range_key", "AttributeType": "S"}, -] +from botocore.exceptions import ClientError -TEST_REGION_REQUEST_TABLE_KEY_SCHEMA = [ - {"AttributeName": "region_id", "KeyType": "HASH"}, - {"AttributeName": "image_id", "KeyType": "RANGE"}, -] -TEST_REGION_REQUEST_TABLE_ATTRIBUTE_DEFINITIONS = [ - {"AttributeName": "region_id", "AttributeType": "S"}, - {"AttributeName": "image_id", "AttributeType": "S"}, -] - -# S3 Configurations -TEST_RESULTS_BUCKET = "test-results-bucket" -TEST_IMAGE_FILE = "./test/data/small.ntf" -TEST_IMAGE_BUCKET = "test-image-bucket" -TEST_IMAGE_KEY = "small.ntf" -TEST_S3_FULL_BUCKET_PATH = "s3://test-results-bucket/test/data/small.ntf" - -TEST_RESULTS_STREAM = "test-results-stream" +MOCK_MODEL_RESPONSE = { + "Body": io.StringIO( + json.dumps( + { + "type": "FeatureCollection", + "features": [ + { + "type": "Feature", + "id": "1cc5e6d6-e12f-430d-adf0-8d2276ce8c5a", + "geometry": {"type": "Point", "coordinates": [-43.679691, -22.941953]}, + "properties": { + "bounds_imcoords": [429, 553, 440, 561], + "geom_imcoords": [[429, 553], [429, 561], [440, 561], [440, 553], [429, 553]], + "featureClasses": [{"iri": "ground_motor_passenger_vehicle", "score": 0.2961518168449402}], + "detection_score": 0.2961518168449402, + "image_id": "2pp5e6d6-e12f-430d-adf0-8d2276ceadf0", + }, + } + ], + } + ) + ) +} -TEST_IMAGE_ID = "test-image-id" -TEST_IMAGE_EXTENSION = "NITF" -TEST_JOB_ID = "test-job-id" -TEST_REGION_ID = "test-region-id" -TEST_ELEVATION_DATA_LOCATION = "s3://TEST-BUCKET/ELEVATION-DATA-LOCATION" +TEST_CONFIG = { + "ACCOUNT_ID": "123456789123", + "ELEVATION_DATA_LOCATION": "s3://test-bucket/elevation-data", + "ENDPOINT_PRODUCTION_VARIANTS": [ + {"VariantName": "Primary", "ModelName": "TestModel", "InitialInstanceCount": 1, "InstanceType": "ml.m5.12xlarge"} + ], + "ENDPOINT_TABLE_ATTRIBUTE_DEFINITIONS": [{"AttributeName": "endpoint", "AttributeType": "S"}], + "ENDPOINT_TABLE_KEY_SCHEMA": [{"AttributeName": "endpoint", "KeyType": "HASH"}], + "FEATURE_TABLE_ATTRIBUTE_DEFINITIONS": [ + {"AttributeName": "hash_key", "AttributeType": "S"}, + {"AttributeName": "range_key", "AttributeType": "S"}, + ], + "FEATURE_TABLE_KEY_SCHEMA": [ + {"AttributeName": "hash_key", "KeyType": "HASH"}, + {"AttributeName": "range_key", "KeyType": "RANGE"}, + ], + "IMAGE_BUCKET": "test-image-bucket", + "IMAGE_EXTENSION": "NITF", + "IMAGE_FILE": "./test/data/small.ntf", + "IMAGE_ID": "test-image-id", + "IMAGE_KEY": "small.ntf", + "JOB_ID": "test-job-id", + "JOB_NAME": "test-job-name", + "JOB_TABLE_ATTRIBUTE_DEFINITIONS": [{"AttributeName": "image_id", "AttributeType": "S"}], + "JOB_TABLE_KEY_SCHEMA": [{"AttributeName": "image_id", "KeyType": "HASH"}], + "MOCK_PUT_EXCEPTION": Mock(side_effect=ClientError({"Error": {"Code": 500, "Message": "ClientError"}}, "put_item")), + "MOCK_UPDATE_EXCEPTION": Mock( + side_effect=ClientError({"Error": {"Code": 500, "Message": "ClientError"}}, "update_item") + ), + "MODEL_ENDPOINT": "TestEndpoint", + "MODEL_NAME": "TestModel", + "REGION_ID": "test-region-id", + "REGION_REQUEST_TABLE_ATTRIBUTE_DEFINITIONS": [ + {"AttributeName": "region_id", "AttributeType": "S"}, + {"AttributeName": "image_id", "AttributeType": "S"}, + ], + "REGION_REQUEST_TABLE_KEY_SCHEMA": [ + {"AttributeName": "region_id", "KeyType": "HASH"}, + {"AttributeName": "image_id", "KeyType": "RANGE"}, + ], + "RESULTS_BUCKET": "test-results-bucket", + "RESULTS_STREAM": "test-results-stream", + "S3_FULL_BUCKET_PATH": "s3://test-results-bucket/test/data/small.ntf", + "SM_MODEL_CONTAINER": { + "Image": "123456789123.dkr.ecr.us-east-1.amazonaws.com/test:1", + "ModelDataUrl": "s3://test-bucket/model.tar.gz", + }, +} diff --git a/test/test_app.py b/test/test_app.py index ddcd4216..e00a9a0f 100755 --- a/test/test_app.py +++ b/test/test_app.py @@ -3,89 +3,32 @@ import os from importlib import reload from queue import Queue +from test.configuration import MOCK_MODEL_RESPONSE, TEST_CONFIG from unittest import TestCase, main from unittest.mock import Mock, patch import boto3 import geojson -from botocore.exceptions import ClientError +from botocore.stub import ANY, Stubber from moto import mock_aws from osgeo import gdal -TEST_MOCK_PUT_EXCEPTION = Mock(side_effect=ClientError({"Error": {"Code": 500, "Message": "ClientError"}}, "put_item")) -TEST_MOCK_UPDATE_EXCEPTION = Mock(side_effect=ClientError({"Error": {"Code": 500, "Message": "ClientError"}}, "update_item")) - -TEST_ACCOUNT_ID = "123456789123" -TEST_IMAGE_ID = "test-image-id" -TEST_IMAGE_EXTENSION = "NITF" -TEST_JOB_ID = "test-job-id" -TEST_RANDOM_KEY = "test-random-key" -TEST_ELEVATION_DATA_LOCATION = "s3://TEST-BUCKET/ELEVATION-DATA-LOCATION" -TEST_MODEL_ENDPOINT = "NOOP_BOUNDS_MODEL_NAME" -TEST_MODEL_NAME = "FakeCVModel" -TEST_RESULTS_BUCKET = "test-results-bucket" -TEST_IMAGE_FILE = "./test/data/small.ntf" -TEST_IMAGE_BUCKET = "test-image-bucket" -TEST_IMAGE_KEY = "small.ntf" -TEST_S3_FULL_BUCKET_PATH = "s3://test-results-bucket/test/data/small.ntf" -TEST_RESULTS_STREAM = "test-results-stream" -TEST_SM_MODEL_CONTAINER = { - "Image": "123456789123.dkr.ecr.us-east-1.amazonaws.com/test:1", - "ModelDataUrl": "s3://MyBucket/model.tar.gz", -} -TEST_ENDPOINT_PRODUCTION_VARIANTS = [ - { - "VariantName": "Primary", - "ModelName": TEST_MODEL_NAME, - "InitialInstanceCount": 1, - "InstanceType": "ml.m5.12xlarge", - }, -] - -# DDB Configurations -TEST_JOB_TABLE_KEY_SCHEMA = [{"AttributeName": "image_id", "KeyType": "HASH"}] -TEST_JOB_TABLE_ATTRIBUTE_DEFINITIONS = [{"AttributeName": "image_id", "AttributeType": "S"}] - -TEST_ENDPOINT_TABLE_KEY_SCHEMA = [{"AttributeName": "endpoint", "KeyType": "HASH"}] -TEST_ENDPOINT_TABLE_ATTRIBUTE_DEFINITIONS = [ - {"AttributeName": "endpoint", "AttributeType": "S"}, -] - -TEST_FEATURE_TABLE_KEY_SCHEMA = [ - {"AttributeName": "hash_key", "KeyType": "HASH"}, - {"AttributeName": "range_key", "KeyType": "RANGE"}, -] -TEST_FEATURE_TABLE_ATTRIBUTE_DEFINITIONS = [ - {"AttributeName": "hash_key", "AttributeType": "S"}, - {"AttributeName": "range_key", "AttributeType": "S"}, -] - -TEST_REGION_REQUEST_TABLE_KEY_SCHEMA = [ - {"AttributeName": "region_id", "KeyType": "HASH"}, - {"AttributeName": "image_id", "KeyType": "RANGE"}, -] -TEST_REGION_REQUEST_TABLE_ATTRIBUTE_DEFINITIONS = [ - {"AttributeName": "region_id", "AttributeType": "S"}, - {"AttributeName": "image_id", "AttributeType": "S"}, -] - - -class RegionRequestMatcher: - def __init__(self, region_request): - self.region_request = region_request - - def __eq__(self, other): - if other is None: - return self.region_request is None - else: - return other["region"] == self.region_request["region"] and other["image_id"] == self.region_request["image_id"] - @mock_aws class TestModelRunner(TestCase): - def setUp(self): + """ + Unit tests for the ModelRunner application. + + This test suite covers different functionalities of the ModelRunner, + such as processing image and region requests, AWS resource interactions, + and handling exceptions during tile processing. + """ + + def setUp(self) -> None: """ - Set up virtual AWS resources for use by our unit tests + Set up virtual AWS resources for use in unit tests. + Creates DynamoDB tables, S3 buckets, SNS topics, SQS queues, and + mock SageMaker endpoints required for the tests. """ from aws.osml.model_runner.api import RegionRequest from aws.osml.model_runner.api.image_request import ImageRequest @@ -97,11 +40,10 @@ def setUp(self): from aws.osml.model_runner.database.region_request_table import RegionRequestTable from aws.osml.model_runner.status import ImageStatusMonitor, RegionStatusMonitor - # GDAL 4.0 will begin using exceptions as the default; at this point the software is written to assume - # no exceptions so we call this explicitly until the software can be updated to match. + # Required to avoid warnings from GDAL gdal.DontUseExceptions() - # Create custom properties to be passed into the image request + # Set default custom feature properties that should exist on the output features= self.test_custom_feature_properties = { "modelMetadata": { "modelName": "test-model-name", @@ -111,10 +53,10 @@ def setUp(self): } } - # This is the expected results for the source property derived from the small test image + # Set default feature properties that should exist on the output features self.test_feature_source_property = [ { - "location": TEST_IMAGE_FILE, + "location": TEST_CONFIG["IMAGE_FILE"], "format": "NITF", "category": "VIS", "sourceId": "Checks an uncompressed 1024x1024 8 bit mono image with GEOcentric data. Airfield", @@ -122,75 +64,72 @@ def setUp(self): } ] + # Build a mock region request for testing self.region_request = RegionRequest( { "tile_size": (10, 10), "tile_overlap": (1, 1), "tile_format": "NITF", - "image_id": TEST_IMAGE_ID, - "image_url": TEST_IMAGE_FILE, + "image_id": TEST_CONFIG["IMAGE_ID"], + "image_url": TEST_CONFIG["IMAGE_FILE"], "region_bounds": ((0, 0), (50, 50)), - "model_name": TEST_MODEL_ENDPOINT, + "model_name": TEST_CONFIG["MODEL_ENDPOINT"], "model_invoke_mode": "SM_ENDPOINT", - "image_extension": TEST_IMAGE_EXTENSION, + "image_extension": TEST_CONFIG["IMAGE_EXTENSION"], } ) - # Build fake image request to work with + # Build a mock image request for testing self.image_request = ImageRequest.from_external_message( { - "jobName": TEST_IMAGE_ID, - "jobId": TEST_IMAGE_ID, - "imageUrls": [TEST_IMAGE_FILE], + "jobName": TEST_CONFIG["IMAGE_ID"], + "jobId": TEST_CONFIG["JOB_ID"], + "imageUrls": [TEST_CONFIG["IMAGE_FILE"]], "outputs": [ - {"type": "S3", "bucket": TEST_RESULTS_BUCKET, "prefix": f"{TEST_IMAGE_ID}/"}, - {"type": "Kinesis", "stream": TEST_RESULTS_STREAM, "batchSize": 1000}, + {"type": "S3", "bucket": TEST_CONFIG["RESULTS_BUCKET"], "prefix": f"{TEST_CONFIG['IMAGE_ID']}"}, + {"type": "Kinesis", "stream": TEST_CONFIG["RESULTS_STREAM"], "batchSize": 1000}, ], "featureProperties": [self.test_custom_feature_properties], - "imageProcessor": {"name": TEST_MODEL_ENDPOINT, "type": "SM_ENDPOINT"}, + "imageProcessor": {"name": TEST_CONFIG["MODEL_ENDPOINT"], "type": "SM_ENDPOINT"}, "imageProcessorTileSize": 2048, "imageProcessorTileOverlap": 50, "imageProcessorTileFormat": "NITF", "imageProcessorTileCompression": "JPEG", + "randomKey": "random-value", } ) - # Prepare something ahead of all tests - # Create virtual DDB tables to write test data into + # Build the required virtual DDB tables self.ddb = boto3.resource("dynamodb", config=BotoConfig.default) - # Job tracking table self.image_request_ddb = self.ddb.create_table( TableName=os.environ["JOB_TABLE"], - KeySchema=TEST_JOB_TABLE_KEY_SCHEMA, - AttributeDefinitions=TEST_JOB_TABLE_ATTRIBUTE_DEFINITIONS, + KeySchema=TEST_CONFIG["JOB_TABLE_KEY_SCHEMA"], + AttributeDefinitions=TEST_CONFIG["JOB_TABLE_ATTRIBUTE_DEFINITIONS"], BillingMode="PAY_PER_REQUEST", ) self.job_table = JobTable(os.environ["JOB_TABLE"]) - # Region Request tracking table self.image_request_ddb = self.ddb.create_table( TableName=os.environ["REGION_REQUEST_TABLE"], - KeySchema=TEST_REGION_REQUEST_TABLE_KEY_SCHEMA, - AttributeDefinitions=TEST_REGION_REQUEST_TABLE_ATTRIBUTE_DEFINITIONS, + KeySchema=TEST_CONFIG["REGION_REQUEST_TABLE_KEY_SCHEMA"], + AttributeDefinitions=TEST_CONFIG["REGION_REQUEST_TABLE_ATTRIBUTE_DEFINITIONS"], BillingMode="PAY_PER_REQUEST", ) self.region_request_table = RegionRequestTable(os.environ["REGION_REQUEST_TABLE"]) - # Endpoint statistics table self.endpoint_statistics_ddb = self.ddb.create_table( TableName=os.environ["ENDPOINT_TABLE"], - KeySchema=TEST_ENDPOINT_TABLE_KEY_SCHEMA, - AttributeDefinitions=TEST_ENDPOINT_TABLE_ATTRIBUTE_DEFINITIONS, + KeySchema=TEST_CONFIG["ENDPOINT_TABLE_KEY_SCHEMA"], + AttributeDefinitions=TEST_CONFIG["ENDPOINT_TABLE_ATTRIBUTE_DEFINITIONS"], BillingMode="PAY_PER_REQUEST", ) self.endpoint_statistics_table = EndpointStatisticsTable(os.environ["ENDPOINT_TABLE"]) - # Feature tracking table self.feature_ddb = self.ddb.create_table( TableName=os.environ["FEATURE_TABLE"], - KeySchema=TEST_FEATURE_TABLE_KEY_SCHEMA, - AttributeDefinitions=TEST_FEATURE_TABLE_ATTRIBUTE_DEFINITIONS, + KeySchema=TEST_CONFIG["FEATURE_TABLE_KEY_SCHEMA"], + AttributeDefinitions=TEST_CONFIG["FEATURE_TABLE_ATTRIBUTE_DEFINITIONS"], BillingMode="PAY_PER_REQUEST", ) self.feature_table = FeatureTable( @@ -199,35 +138,31 @@ def setUp(self): self.image_request.tile_overlap, ) - # Create fake buckets for images and results + # Build a virtual S3 and Kinesis output sink self.s3 = boto3.client("s3", config=BotoConfig.default) - # Create a fake bucket to store images self.image_bucket = self.s3.create_bucket( - Bucket=TEST_IMAGE_BUCKET, + Bucket=TEST_CONFIG["IMAGE_BUCKET"], CreateBucketConfiguration={"LocationConstraint": os.environ["AWS_DEFAULT_REGION"]}, ) - # Load our test image into our bucket - with open(TEST_IMAGE_FILE, "rb") as data: - self.s3.upload_fileobj(data, TEST_IMAGE_BUCKET, TEST_IMAGE_KEY) - # Create a fake bucket to store results in + with open(TEST_CONFIG["IMAGE_FILE"], "rb") as data: + self.s3.upload_fileobj(data, TEST_CONFIG["IMAGE_BUCKET"], TEST_CONFIG["IMAGE_KEY"]) + self.results_bucket = self.s3.create_bucket( - Bucket=TEST_RESULTS_BUCKET, + Bucket=TEST_CONFIG["RESULTS_BUCKET"], CreateBucketConfiguration={"LocationConstraint": os.environ["AWS_DEFAULT_REGION"]}, ) - # Create a fake stream to store results in - self.kinesis = boto3.client("kinesis", region_name=os.environ["AWS_DEFAULT_REGION"]) + self.kinesis = boto3.client("kinesis", config=BotoConfig.default) self.results_stream = self.kinesis.create_stream( - StreamName=TEST_RESULTS_STREAM, StreamModeDetails={"StreamMode": "ON_DEMAND"} + StreamName=TEST_CONFIG["RESULTS_STREAM"], StreamModeDetails={"StreamMode": "ON_DEMAND"} ) - # Create a fake image status sns topic for reporting job status + # Build a virtual image status topic and queue self.sns = boto3.client("sns", config=BotoConfig.default) image_status_topic_arn = self.sns.create_topic(Name=os.environ["IMAGE_STATUS_TOPIC"]).get("TopicArn") - # Create a fake sqs queue to consume the image status sns topic events self.sqs = boto3.client("sqs", config=BotoConfig.default) image_status_queue_url = self.sqs.create_queue(QueueName="mock_queue").get("QueueUrl") image_status_queue_attributes = self.sqs.get_queue_attributes( @@ -235,43 +170,35 @@ def setUp(self): ) image_status_queue_arn = image_status_queue_attributes.get("Attributes").get("QueueArn") - # Subscribe our sns topic to the queue self.sns.subscribe(TopicArn=image_status_topic_arn, Protocol="sqs", Endpoint=image_status_queue_arn) - - # Set up our status monitor for the image status queue self.image_status_monitor = ImageStatusMonitor(image_status_topic_arn) - # Create a fake region status sns topic for reporting job status + # Build a virtual region status topic and queue region_status_topic_arn = self.sns.create_topic(Name=os.environ["REGION_STATUS_TOPIC"]).get("TopicArn") self.region_status_monitor = RegionStatusMonitor(region_status_topic_arn) - # Create a fake sqs queue to consume the region status sns topic events region_status_queue_url = self.sqs.create_queue(QueueName="mock_region_queue").get("QueueUrl") region_status_queue_attributes = self.sqs.get_queue_attributes( QueueUrl=region_status_queue_url, AttributeNames=["QueueArn"] ) region_status_queue_arn = region_status_queue_attributes.get("Attributes").get("QueueArn") - # Subscribe our sns topic to the queue self.sns.subscribe(TopicArn=region_status_topic_arn, Protocol="sqs", Endpoint=region_status_queue_arn) - # Create a fake bounds model + # Build a virtual SageMaker endpoint self.sm = boto3.client("sagemaker", config=BotoConfig.default) self.sm.create_model( - ModelName=TEST_MODEL_NAME, - PrimaryContainer=TEST_SM_MODEL_CONTAINER, - ExecutionRoleArn=f"arn:aws:iam::{TEST_ACCOUNT_ID}:role/FakeRole", + ModelName=TEST_CONFIG["MODEL_NAME"], + PrimaryContainer=TEST_CONFIG["SM_MODEL_CONTAINER"], + ExecutionRoleArn=f"arn:aws:iam::{TEST_CONFIG['ACCOUNT_ID']}:role/FakeRole", ) - # Create a fake endpoint config - config_name = "UnitTestConfig" - production_variants = TEST_ENDPOINT_PRODUCTION_VARIANTS + + config_name = "TestConfig" + production_variants = TEST_CONFIG["ENDPOINT_PRODUCTION_VARIANTS"] self.sm.create_endpoint_config(EndpointConfigName=config_name, ProductionVariants=production_variants) - # Create a fake endpoint - self.sm.create_endpoint(EndpointName=TEST_MODEL_ENDPOINT, EndpointConfigName=config_name) - # Create a fake geom model endpoint - self.sm.create_endpoint(EndpointName="NOOP_GEOM_MODEL_NAME", EndpointConfigName=config_name) + self.sm.create_endpoint(EndpointName=TEST_CONFIG["MODEL_ENDPOINT"], EndpointConfigName=config_name) - # Build our model runner and plug in fake resources + # Plug in the required virtual resources to our ModelRunner instance self.model_runner = ModelRunner() self.model_runner.job_table = self.job_table self.model_runner.region_request_table = self.region_request_table @@ -279,9 +206,9 @@ def setUp(self): self.model_runner.image_status_monitor = self.image_status_monitor self.model_runner.region_status_monitor = self.region_status_monitor - def tearDown(self): + def tearDown(self) -> None: """ - Delete virtual resources after each test + Delete virtual AWS resources after each test. """ self.image_request_ddb.delete() self.endpoint_statistics_ddb.delete() @@ -298,160 +225,85 @@ def tearDown(self): self.image_status_monitor = None self.region_status_monitor = None - def test_aws_osml_model_runner_importable(self): + def test_aws_osml_model_runner_importable(self) -> None: + """ + Ensure that aws.osml.model_runner can be imported without errors. + """ import aws.osml.model_runner # noqa: F401 - def test_run(self): + def test_run(self) -> None: + """ + Test that the run method in ModelRunner initiates the work queue monitoring process. + """ self.model_runner.monitor_work_queues = Mock() - self.model_runner.run() - self.model_runner.monitor_work_queues.assert_called_once() - def test_stop(self): + def test_stop(self) -> None: + """ + Test that the stop method stops the ModelRunner. + """ self.model_runner.running = True self.model_runner.stop() assert self.model_runner.running is False - def test_process_bounds_image_request(self): - self.model_runner.process_image_request(self.image_request) - - # Check to make sure the job was marked as complete - image_request_item = self.job_table.get_image_request(self.image_request.image_id) - assert image_request_item.region_success == 1 - - # Check that we created the right amount of features - features = self.feature_table.get_features(self.image_request.image_id) - assert len(features) == 1 - - # Check to make sure the feature was assigned a real geo coordinate - assert features[0]["geometry"]["type"] == "Polygon" - - # Grab the feature results from virtual S3 bucket - results_key = self.s3.list_objects(Bucket=TEST_RESULTS_BUCKET)["Contents"][0]["Key"] - - results_contents = self.s3.get_object( - Bucket=TEST_RESULTS_BUCKET, - Key=results_key, - )["Body"].read() - - # Load them into memory as geojson - results_features = geojson.loads(results_contents.decode("utf-8"))["features"] - assert len(results_features) > 0 - - # Check that the provided custom feature property was added - assert results_features[0]["properties"]["modelMetadata"] == self.test_custom_feature_properties.get("modelMetadata") - - # Check we got the correct source data for the small.ntf file - assert results_features[0]["properties"]["sourceMetadata"] == self.test_feature_source_property - - # Check that we calculated the max in progress regions - # Test instance type is set to m5.12xl with 48 vcpus. Default - # scale factor is set to 10 and workers per cpu is 1 so: - # floor((10 * 1 * 48) / 1) = 480 - assert 480 == self.model_runner.endpoint_utils.calculate_max_regions(endpoint_name=TEST_MODEL_ENDPOINT) - - def test_process_geom_image_request(self): - from aws.osml.model_runner.api.image_request import ImageRequest - - self.image_request = ImageRequest.from_external_message( - { - "jobName": TEST_IMAGE_ID, - "jobId": TEST_IMAGE_ID, - "imageUrls": [TEST_IMAGE_FILE], - "outputs": [ - {"type": "S3", "bucket": TEST_RESULTS_BUCKET, "prefix": f"{TEST_IMAGE_ID}/"}, - {"type": "Kinesis", "stream": TEST_RESULTS_STREAM, "batchSize": 1000}, - ], - "featureProperties": [self.test_custom_feature_properties], - "imageProcessor": {"name": "NOOP_GEOM_MODEL_NAME", "type": "SM_ENDPOINT"}, - "imageProcessorTileSize": 2048, - "imageProcessorTileOverlap": 50, - "imageProcessorTileFormat": "NITF", - "imageProcessorTileCompression": "JPEG", - } - ) - self.model_runner.process_image_request(self.image_request) - - # Check to make sure the job was marked as complete - image_request_item = self.job_table.get_image_request(self.image_request.image_id) - assert image_request_item.region_success == 1 - - # Check that we created the right amount of features - features = self.feature_table.get_features(self.image_request.image_id) - assert len(features) == 1 - - # Check to make sure the feature was assigned a real geo coordinate - assert features[0]["geometry"]["type"] == "Polygon" - - # Grab the feature results from virtual S3 bucket - results_key = self.s3.list_objects(Bucket=TEST_RESULTS_BUCKET)["Contents"][0]["Key"] - - results_contents = self.s3.get_object( - Bucket=TEST_RESULTS_BUCKET, - Key=results_key, - )["Body"].read() - - # Load them into memory as geojson - results_features = geojson.loads(results_contents.decode("utf-8"))["features"] - assert len(results_features) > 0 - - # Check that the provided custom feature property was added - assert results_features[0]["properties"]["modelMetadata"] == self.test_custom_feature_properties.get("modelMetadata") - - # Check we got the correct source data for the small.ntf file - assert results_features[0]["properties"]["sourceMetadata"] == self.test_feature_source_property - - # Check that we calculated the max in progress regions - # Test instance type is set to m5.12xl with 48 vcpus. Default - # scale factor is set to 10 and workers per cpu is 1 so: - # floor((10 * 1 * 48) / 1) = 480 - assert 480 == self.model_runner.endpoint_utils.calculate_max_regions(endpoint_name=TEST_MODEL_ENDPOINT) - - def test_process_additional_attributes_image_request(self): - from aws.osml.model_runner.api.image_request import ImageRequest - - self.image_request = ImageRequest.from_external_message( - { - "jobName": TEST_IMAGE_ID, - "jobId": TEST_IMAGE_ID, - "imageUrls": [TEST_IMAGE_FILE], - "outputs": [ - {"type": "S3", "bucket": TEST_RESULTS_BUCKET, "prefix": f"{TEST_IMAGE_ID}/"}, - {"type": "Kinesis", "stream": TEST_RESULTS_STREAM, "batchSize": 1000}, - ], - "featureProperties": [self.test_custom_feature_properties], - "imageProcessor": {"name": "NOOP_GEOM_MODEL_NAME", "type": "SM_ENDPOINT"}, - "imageProcessorTileSize": 2048, - "imageProcessorTileOverlap": 50, - "imageProcessorTileFormat": "NITF", - "imageProcessorTileCompression": "JPEG", - "testRandomKey": TEST_RANDOM_KEY, - } - ) - self.model_runner.process_image_request(self.image_request) - image_request_item = self.job_table.get_image_request(self.image_request.image_id) - assert image_request_item.region_success == 1 + def test_process_image_request(self) -> None: + """ + Test the process of handling an image request, ensuring that jobs are marked as complete, + features are created, and the correct metadata is stored in S3. Checks that we calculated + the max in progress regions with the test instance type is set to m5.12xl with 48 vcpus. + """ + with patch("aws.osml.model_runner.inference.sm_detector.boto3") as mock_boto3: + # Build stubbed model client for ModelRunner to interact with + mock_boto3.client.return_value = self.get_stubbed_sm_client() + self.model_runner.process_image_request(self.image_request) + + # Ensure that the single region was processed successfully + image_request_item = self.job_table.get_image_request(self.image_request.image_id) + assert image_request_item.region_success == 1 + + # Ensure that the detection outputs arrived in our DDB table + features = self.feature_table.get_features(self.image_request.image_id) + assert len(features) == 1 + assert features[0]["geometry"]["type"] == "Polygon" + + # Ensure that the detection outputs arrived in our output bucket + results_key = self.s3.list_objects(Bucket=TEST_CONFIG["RESULTS_BUCKET"])["Contents"][0]["Key"] + results_contents = self.s3.get_object(Bucket=TEST_CONFIG["RESULTS_BUCKET"], Key=results_key)["Body"].read() + results_features = geojson.loads(results_contents.decode("utf-8"))["features"] + assert len(results_features) > 0 + + # Test that we get the correct model metadata appended to our feature outputs + actual_model_metadata = results_features[0]["properties"]["modelMetadata"] + expected_model_metadata = self.test_custom_feature_properties.get("modelMetadata") + assert actual_model_metadata == expected_model_metadata + + # Test that we get the correct source metadata appended to our feature outputs + actual_source_metadata = results_features[0]["properties"]["sourceMetadata"] + expected_source_metadata = self.test_feature_source_property + assert actual_source_metadata == expected_source_metadata + + # Default scale factor set to 10 and workers per cpu is 1 so: floor((10 * 1 * 48) / 1) = 480 + regions = self.model_runner.endpoint_utils.calculate_max_regions(endpoint_name=TEST_CONFIG["MODEL_ENDPOINT"]) + assert 480 == regions - # Remember that with multiple patch decorators the order of the mocks in the parameter list is - # reversed (i.e. the first mock parameter is the last decorator defined). Also note that the - # pytest fixtures must come at the end. @patch("aws.osml.model_runner.app.setup_tile_workers") @patch("aws.osml.model_runner.app.process_tiles") - def test_process_region_request( - self, - mock_process_tiles, - mock_setup_tile_workers, - ): + def test_process_region_request(self, mock_process_tiles: Mock, mock_setup_tile_workers: Mock) -> None: + """ + Test the process of handling a region request, ensuring that tiles are processed correctly, + region counts are updated, and errors are properly handled. We're testing a single region + here so expecting a single call to both increment and decrement for the model associated with + the region + """ from aws.osml.gdal.gdal_utils import load_gdal_dataset from aws.osml.model_runner.database.endpoint_statistics_table import EndpointStatisticsTable from aws.osml.model_runner.database.job_table import JobTable from aws.osml.model_runner.database.region_request_table import RegionRequestItem, RegionRequestTable from aws.osml.model_runner.status import ImageStatusMonitor, RegionStatusMonitor - region_request_item = RegionRequestItem(image_id=TEST_IMAGE_ID, region_id="test-region-id") + region_request_item = RegionRequestItem(image_id=TEST_CONFIG["IMAGE_ID"], region_id="test-region-id") - # Load up our test image raster_dataset, sensor_model = load_gdal_dataset(self.region_request.image_url) self.model_runner.job_table = Mock(JobTable, autospec=True) @@ -470,26 +322,25 @@ def test_process_region_request( mock_process_tiles.assert_called_once() self.model_runner.fail_region_request.assert_not_called() - # We're testing a single region here so expecting a single call to both increment and - # decrement for the model associated with the region - self.model_runner.endpoint_statistics_table.increment_region_count.assert_called_once_with(TEST_MODEL_ENDPOINT) - self.model_runner.endpoint_statistics_table.decrement_region_count.assert_called_once_with(TEST_MODEL_ENDPOINT) + model = TEST_CONFIG["MODEL_ENDPOINT"] + self.model_runner.endpoint_statistics_table.increment_region_count.assert_called_once_with(model) + self.model_runner.endpoint_statistics_table.decrement_region_count.assert_called_once_with(model) @patch("aws.osml.model_runner.app.setup_tile_workers") @patch("aws.osml.model_runner.app.process_tiles") - def test_process_region_request_exception( - self, - mock_process_tiles, - mock_setup_tile_workers, - ): + def test_process_region_request_exception(self, mock_process_tiles: Mock, mock_setup_tile_workers: Mock) -> None: + """ + Test that an exception during tile processing in a region request is properly handled. + We're testing a single region here so expecting a single call to both increment and decrement + for the model associated with the region + """ from aws.osml.gdal.gdal_utils import load_gdal_dataset from aws.osml.model_runner.database.endpoint_statistics_table import EndpointStatisticsTable from aws.osml.model_runner.database.job_table import JobTable from aws.osml.model_runner.database.region_request_table import RegionRequestItem, RegionRequestTable - region_request_item = RegionRequestItem(image_id=TEST_IMAGE_ID, region_id="test-region-id") + region_request_item = RegionRequestItem(image_id=TEST_CONFIG["IMAGE_ID"], region_id="test-region-id") - # Load up our test image raster_dataset, sensor_model = load_gdal_dataset(self.region_request.image_url) self.model_runner.job_table = Mock(JobTable, autospec=True) @@ -505,18 +356,18 @@ def test_process_region_request_exception( mock_process_tiles.assert_called_once() self.model_runner.fail_region_request.assert_called_once() - # We're testing a single region here so expecting a single call to both increment and - # decrement for the model associated with the region - self.model_runner.endpoint_statistics_table.increment_region_count.assert_called_once_with(TEST_MODEL_ENDPOINT) - self.model_runner.endpoint_statistics_table.decrement_region_count.assert_called_once_with(TEST_MODEL_ENDPOINT) + model = TEST_CONFIG["MODEL_ENDPOINT"] + self.model_runner.endpoint_statistics_table.increment_region_count.assert_called_once_with(model) + self.model_runner.endpoint_statistics_table.decrement_region_count.assert_called_once_with(model) @patch("aws.osml.model_runner.app.setup_tile_workers") @patch("aws.osml.model_runner.app.process_tiles") - def test_process_region_request_invalid( - self, - mock_process_tiles, - mock_setup_tile_workers, - ): + def test_process_region_request_invalid(self, mock_process_tiles: Mock, mock_setup_tile_workers: Mock) -> None: + """ + Test handling of an invalid region request, ensuring the appropriate exception is raised. + We're testing a single region here so expecting a single call to both increment and decrement + for the model associated with the region + """ from aws.osml.gdal.gdal_utils import load_gdal_dataset from aws.osml.model_runner.api import RegionRequest from aws.osml.model_runner.database.endpoint_statistics_table import EndpointStatisticsTable @@ -528,17 +379,16 @@ def test_process_region_request_invalid( "tile_size": (10, 10), "tile_overlap": (1, 1), "tile_format": "NITF", - "image_id": TEST_IMAGE_ID, - "image_url": TEST_IMAGE_FILE, + "image_id": TEST_CONFIG["IMAGE_ID"], + "image_url": TEST_CONFIG["IMAGE_FILE"], "region_bounds": ((0, 0), (50, 50)), "model_invoke_mode": "SM_ENDPOINT", - "image_extension": TEST_IMAGE_EXTENSION, + "image_extension": TEST_CONFIG["IMAGE_EXTENSION"], } ) - region_request_item = RegionRequestItem(image_id=TEST_IMAGE_ID, region_id="test-region-id") + region_request_item = RegionRequestItem(image_id=TEST_CONFIG["IMAGE_ID"], region_id="test-region-id") - # Load up our test image raster_dataset, sensor_model = load_gdal_dataset(self.region_request.image_url) self.model_runner.job_table = Mock(JobTable, autospec=True) @@ -553,8 +403,6 @@ def test_process_region_request_invalid( invalid_region_request, region_request_item, raster_dataset, sensor_model ) - # We're testing a single region here so expecting a single call to both increment and - # decrement for the model associated with the region self.model_runner.endpoint_statistics_table.increment_region_count.assert_not_called() self.model_runner.endpoint_statistics_table.decrement_region_count.assert_not_called() @@ -563,21 +411,20 @@ def test_process_region_request_invalid( @patch("aws.osml.model_runner.tile_worker.tile_worker_utils.TileWorker", autospec=True) @patch("aws.osml.model_runner.tile_worker.tile_worker_utils.Queue", autospec=True) def test_process_region_request_throttled( - self, - mock_queue, - mock_tile_worker, - mock_feature_table, - mock_feature_detector, - ): + self, mock_queue: Mock, mock_tile_worker: Mock, mock_feature_table: Mock, mock_feature_detector: Mock + ) -> None: + """ + Test handling of a self-throttled region request due to resource limits, ensuring + the request is not processed and appropriate counts are not incremented. + """ from aws.osml.gdal.gdal_utils import load_gdal_dataset from aws.osml.model_runner.database.endpoint_statistics_table import EndpointStatisticsTable from aws.osml.model_runner.database.job_table import JobTable from aws.osml.model_runner.database.region_request_table import RegionRequestItem, RegionRequestTable from aws.osml.model_runner.exceptions import SelfThrottledRegionException - region_request_item = RegionRequestItem(image_id=TEST_IMAGE_ID, region_id="test-region-id") + region_request_item = RegionRequestItem(image_id=TEST_CONFIG["IMAGE_ID"], region_id="test-region-id") - # Load up our test image raster_dataset, sensor_model = load_gdal_dataset(self.region_request.image_url) self.model_runner.job_table = Mock(JobTable, autospec=True) @@ -598,10 +445,13 @@ def test_process_region_request_throttled( # Check to make sure a queue was created and populated with appropriate region requests mock_queue.assert_not_called() - @patch.dict("os.environ", values={"ELEVATION_DATA_LOCATION": TEST_ELEVATION_DATA_LOCATION}) - def test_create_elevation_model(self): - # These imports/reloads are necessary to force the ServiceConfig instance used by model runner - # to have the patched environment variables + @patch.dict("os.environ", values={"ELEVATION_DATA_LOCATION": TEST_CONFIG["ELEVATION_DATA_LOCATION"]}) + def test_create_elevation_model(self) -> None: + """ + Test that the ModelRunner correctly creates an elevation model based on the SRTM DEM tile set. + The import and reload statements are necessary to force the ServiceConfig to update with the + patched environment variables. + """ import aws.osml.model_runner.app_config reload(aws.osml.model_runner.app_config) @@ -612,7 +462,7 @@ def test_create_elevation_model(self): from aws.osml.photogrammetry.digital_elevation_model import DigitalElevationModel from aws.osml.photogrammetry.srtm_dem_tile_set import SRTMTileSet - assert ServiceConfig.elevation_data_location == TEST_ELEVATION_DATA_LOCATION + assert ServiceConfig.elevation_data_location == TEST_CONFIG["ELEVATION_DATA_LOCATION"] elevation_model = ModelRunner.create_elevation_model() assert elevation_model @@ -624,11 +474,14 @@ def test_create_elevation_model(self): assert elevation_model.tile_set.prefix == "" assert elevation_model.tile_set.version == "1arc_v3" - assert elevation_model.tile_factory.tile_directory == TEST_ELEVATION_DATA_LOCATION + assert elevation_model.tile_factory.tile_directory == TEST_CONFIG["ELEVATION_DATA_LOCATION"] - def test_create_elevation_model_disabled(self): - # These imports/reloads are necessary to force the ServiceConfig instance used by model runner - # to have the patched environment variables + def test_create_elevation_model_disabled(self) -> None: + """ + Test that no elevation model is created when ELEVATION_DATA_LOCATION is not set in the environment. + The import and reload statements are necessary to force the ServiceConfig to update with the + patched environment variables. + """ import aws.osml.model_runner.app_config reload(aws.osml.model_runner.app_config) @@ -636,19 +489,41 @@ def test_create_elevation_model_disabled(self): from aws.osml.model_runner.app import ModelRunner from aws.osml.model_runner.app_config import ServiceConfig - # Check to make sure that excluding the ELEVATION_DATA_LOCATION env variable results in no elevation model assert ServiceConfig.elevation_data_location is None elevation_model = ModelRunner.create_elevation_model() assert not elevation_model @staticmethod - def get_dataset_and_camera(): + def get_dataset_and_camera() -> tuple: + """ + Retrieve a GDAL dataset and sensor model for testing. + + :return: A tuple of GDAL dataset and sensor model. + """ from aws.osml.gdal.gdal_utils import load_gdal_dataset ds, sensor_model = load_gdal_dataset("./test/data/GeogToWGS84GeoKey5.tif") return ds, sensor_model + @staticmethod + def get_stubbed_sm_client() -> boto3.client: + """ + Get a stubbed SageMaker client for use in testing. + + :return: A stubbed SageMaker Runtime client. + """ + sm_client = boto3.client("sagemaker-runtime") + sm_runtime_stub = Stubber(sm_client) + sm_runtime_stub.add_response( + "invoke_endpoint", + expected_params={"EndpointName": TEST_CONFIG["MODEL_ENDPOINT"], "Body": ANY}, + service_response=MOCK_MODEL_RESPONSE, + ) + sm_runtime_stub.activate() + + return sm_client + if __name__ == "__main__": main()