From fd2151217e9ccb760cfa7bf0cc66ce24e0dbf54e Mon Sep 17 00:00:00 2001 From: Zach Mullen Date: Wed, 16 Oct 2024 18:16:07 -0400 Subject: [PATCH 1/4] Low-hanging fruit optimizations for model run listing * Remove unnecessary aliases * Remove ground truth aggregate computation and rely on ModelRun's ground_truth field instead * Short-circuit matching Region lookup (not in hot path) * Remove caching infrastructure since this branch should obviate it --- rdwatch/core/views/model_run.py | 101 +++++++++----------------------- 1 file changed, 28 insertions(+), 73 deletions(-) diff --git a/rdwatch/core/views/model_run.py b/rdwatch/core/views/model_run.py index 358f0048..4abc6d95 100644 --- a/rdwatch/core/views/model_run.py +++ b/rdwatch/core/views/model_run.py @@ -14,7 +14,6 @@ from django.contrib.gis.db.models.functions import Transform from django.contrib.postgres.aggregates import JSONBAgg from django.core import signing -from django.core.cache import cache from django.core.files.storage import default_storage from django.db import transaction from django.db.models import ( @@ -205,31 +204,14 @@ def get_queryset(): return ( ModelRun.objects.select_related('performer') - # Get minimum score and performer so that we can tell which runs - # are ground truth - .annotate( - min_score=Min('evaluations__score'), - ) - # Label ground truths as such. A ground truth is defined as a model run - # with a min_score of 1 and a performer of "TE" - .alias( - groundtruth=Case( - When(min_score=1, performer__short_code__iexact='TE', then=True), - default=False, - ) - ) # Order queryset so that ground truths are first - .order_by('groundtruth', '-created') - .alias( - evaluation_configuration=F('evaluations__configuration'), - proposal_val=F('proposal'), - ) + .order_by('ground_truth', '-created') .annotate( region_name=F('region__name'), downloading=Coalesce( Subquery( SatelliteFetching.objects.filter( - site__configuration_id=OuterRef('evaluation_configuration'), + site__configuration_id=OuterRef('pk'), status=SatelliteFetching.Status.RUNNING, ) .annotate(count=Func(F('id'), function='Count')) @@ -271,7 +253,7 @@ def get_queryset(): ), adjudicated=Case( When( - ~Q(proposal_val=None), # When proposal has a value + ~Q(proposal=None), # When proposal has a value then=JSONObject( proposed=Coalesce(Subquery(proposed_count_subquery), 0), other=Coalesce(Subquery(other_count_subquery), 0), @@ -306,45 +288,31 @@ def paginate_queryset( filters: ModelRunFilterSchema, **params, ) -> dict[str, Any]: - # TODO: remove caching after model runs endpoint is - # refactored to be more efficient. - model_runs = None - cache_key = _get_model_runs_cache_key(filters.dict() | pagination.dict()) - - # If we have a cache miss, execute the query and save the results to cache - # before returning. 
- if model_runs is None: - qs = super().paginate_queryset(queryset, pagination, **params) - aggregate_kwargs = { - 'timerange': JSONObject( - min=ExtractEpoch(Min('evaluations__start_date')), - max=ExtractEpoch(Max('evaluations__end_date')), - ), - } - if filters.region: - aggregate_kwargs['bbox'] = Coalesce( - BoundingBoxGeoJSON('region__geom'), - BoundingBoxGeoJSON('evaluations__geom'), - ) - - model_runs = qs | queryset.aggregate(**aggregate_kwargs) - if filters.region: - if qs['count'] == 0: # No model runs we set bbox to Region bbox - region_filter = filters.dict()['region'] - regions = [ - obj - for obj in Region.objects.all() - if obj.value == region_filter - ] - if len(regions) > 0: - bbox = regions[0].geojson - model_runs['bbox'] = bbox - cache.set( - key=cache_key, - value=model_runs, - timeout=timedelta(days=30).total_seconds(), + qs = super().paginate_queryset(queryset, pagination, **params) + aggregate_kwargs = { + 'timerange': JSONObject( + min=ExtractEpoch(Min('evaluations__start_date')), + max=ExtractEpoch(Max('evaluations__end_date')), + ), + } + if filters.region: + aggregate_kwargs['bbox'] = Coalesce( + BoundingBoxGeoJSON('region__geom'), + BoundingBoxGeoJSON('evaluations__geom'), ) + model_runs = qs | queryset.aggregate(**aggregate_kwargs) + if filters.region: + if qs['count'] == 0: # No model runs; we set bbox to Region bbox + region_filter = filters.dict()['region'] + # We do a full table scan because this isn't a common case and isn't a bottleneck; + # the equivalent database query is really annoying due to the logic of the + # `Region.value` property. + for region in Region.objects.all(): + if region.value == region_filter: + model_runs['bbox'] = region.geojson + break + return model_runs @@ -388,27 +356,14 @@ def create_model_run( } -def _get_model_runs_cache_key(params: dict) -> str: - """Generate a cache key for the model runs listing endpoint.""" - most_recent_evaluation = ModelRun.objects.aggregate( - latest_eval=Max('evaluations__timestamp') - )['latest_eval'] - - return '|'.join( - [ - ModelRun.__name__, - str(most_recent_evaluation), - *[f'{key}={value}' for key, value in params.items()], - ] - ).replace(' ', '_') - - @router.get('/', response={200: list[ModelRunListSchema]}) @paginate(ModelRunPagination, page_size=MODEL_RUN_PAGE_SIZE) def list_model_runs( request: HttpRequest, filters: ModelRunFilterSchema = Query(...), # noqa: B008 ): + # TODO maybe the aggregate stats should have a separate endpoint since they only need to + # be queried when the filters change, not every time a new page is loaded return filters.filter(get_queryset()) From 8d72baf9ebd2caf91fdf4dc36abb67f9d9de6333 Mon Sep 17 00:00:00 2001 From: Zach Mullen Date: Thu, 17 Oct 2024 12:39:37 -0400 Subject: [PATCH 2/4] Cache expensive summary statistics on each model run --- dev/airflow_sample_dag.py | 4 +- rdwatch/core/db/functions.py | 6 ++ ..._bbox_modelrun_cached_numsites_and_more.py | 79 ++++++++++++++++ rdwatch/core/models/model_run.py | 94 +++++++++++++++++++ rdwatch/core/views/model_run.py | 29 +++--- rdwatch/scoring/views/site_image.py | 34 ++++--- 6 files changed, 215 insertions(+), 31 deletions(-) create mode 100644 rdwatch/core/migrations/0038_modelrun_cached_bbox_modelrun_cached_numsites_and_more.py diff --git a/dev/airflow_sample_dag.py b/dev/airflow_sample_dag.py index 45bee811..18997612 100644 --- a/dev/airflow_sample_dag.py +++ b/dev/airflow_sample_dag.py @@ -8,9 +8,7 @@ dag_id="RD-WATCH-AIRFLOW-DEMO-DAG", description="Test DAG", params={ - "region_id": Param( - 
default="BR_R002", type="string", pattern=r"^.{1,255}$" - ), + "region_id": Param(default="BR_R002", type="string", pattern=r"^.{1,255}$"), "model_run_title": Param(default="test_run", type="string"), }, start_date=datetime(2022, 3, 1), diff --git a/rdwatch/core/db/functions.py b/rdwatch/core/db/functions.py index 88bd4df6..dfef8b22 100644 --- a/rdwatch/core/db/functions.py +++ b/rdwatch/core/db/functions.py @@ -78,6 +78,12 @@ def __init__(self, field): return super().__init__(json_str, JSONField()) # noqa: B037 +class AsGeoJSONDeserialized(Cast): + def __init__(self, field): + json_str = AsGeoJSON(field) + return super().__init__(json_str, JSONField()) # noqa: B037 + + class TimeRangeJSON(NullIf): """Represents the min/max time of a field as JSON""" diff --git a/rdwatch/core/migrations/0038_modelrun_cached_bbox_modelrun_cached_numsites_and_more.py b/rdwatch/core/migrations/0038_modelrun_cached_bbox_modelrun_cached_numsites_and_more.py new file mode 100644 index 00000000..2871e211 --- /dev/null +++ b/rdwatch/core/migrations/0038_modelrun_cached_bbox_modelrun_cached_numsites_and_more.py @@ -0,0 +1,79 @@ +# Generated by Django 5.0.3 on 2024-10-17 08:19 + +import django.contrib.gis.db.models.fields +from django.db import migrations, models + + +def compute_model_run_aggregate_stats(apps, schema_editor): + from rdwatch.core.models import ModelRun + + ModelRun.compute_all_aggregate_stats(recompute_all=True) + + +class Migration(migrations.Migration): + dependencies = [ + ('core', '0037_alter_animationmodelrunexport_user_and_more'), + ] + + operations = [ + migrations.AddField( + model_name='modelrun', + name='cached_bbox', + field=django.contrib.gis.db.models.fields.PolygonField( + blank=True, + help_text='The bounding box of all evaluations for this model run', + null=True, + spatial_index=False, + srid=4326, + ), + ), + migrations.AddField( + model_name='modelrun', + name='cached_numsites', + field=models.PositiveIntegerField( + blank=True, + help_text='The number of distinct evaluations for this model run', + null=True, + ), + ), + migrations.AddField( + model_name='modelrun', + name='cached_score', + field=models.FloatField( + blank=True, + help_text='The average score of all evaluations for this model run', + null=True, + ), + ), + migrations.AddField( + model_name='modelrun', + name='cached_timerange_max', + field=models.DateTimeField( + blank=True, + help_text='The latest timestamp of any evaluation for this model run', + null=True, + ), + ), + migrations.AddField( + model_name='modelrun', + name='cached_timerange_min', + field=models.DateTimeField( + blank=True, + help_text='The earliest timestamp of any evaluation for this model run', + null=True, + ), + ), + migrations.AddField( + model_name='modelrun', + name='cached_timestamp', + field=models.DateTimeField( + blank=True, + help_text='The timestamp of the most recent evaluation for this model run', + null=True, + ), + ), + migrations.RunPython( + code=compute_model_run_aggregate_stats, + reverse_code=migrations.RunPython.noop, + ), + ] diff --git a/rdwatch/core/models/model_run.py b/rdwatch/core/models/model_run.py index 7949e372..663115d4 100644 --- a/rdwatch/core/models/model_run.py +++ b/rdwatch/core/models/model_run.py @@ -3,8 +3,11 @@ from django_extensions.db.models import CreationDateTimeField from django.contrib.auth.models import User +from django.contrib.gis.db.models import PolygonField from django.db import models +from rdwatch.core.db.functions import BoundingBoxPolygon + class ModelRun(models.Model): id = 
models.UUIDField(primary_key=True, default=uuid4, editable=False) @@ -54,5 +57,96 @@ class ProposalStatus(models.TextChoices): owner = models.ForeignKey(User, on_delete=models.CASCADE, null=True, blank=True) public = models.BooleanField(default=True) + # The below fields are the denormalized aggregate statistics for this model run. + cached_numsites = models.PositiveIntegerField( + blank=True, + null=True, + help_text='The number of distinct evaluations for this model run', + ) + cached_score = models.FloatField( + blank=True, + null=True, + help_text='The average score of all evaluations for this model run', + ) + cached_timestamp = models.DateTimeField( + blank=True, + null=True, + help_text='The timestamp of the most recent evaluation for this model run', + ) + cached_timerange_min = models.DateTimeField( + blank=True, + null=True, + help_text='The earliest timestamp of any evaluation for this model run', + ) + cached_timerange_max = models.DateTimeField( + blank=True, + null=True, + help_text='The latest timestamp of any evaluation for this model run', + ) + cached_bbox = PolygonField( + srid=4326, + blank=True, + null=True, + help_text='The bounding box of all evaluations for this model run', + spatial_index=False, + ) + def __str__(self) -> str: return str(self.pk) + + def compute_aggregate_stats(self): + """ + Compute denormalized aggregate stats and persist them on this model run. + + Many of the fields we report for each model run are expensive to compute, and need to be + returned to the client often. Calling this will perform the expensive computations and + persist the results on the model run for fast retrieval. + + This should be called whenever a model run is fully ingested, and again whenever any of the + underlying data corresponding to the model changes such that these summary statistics might + be affected. + """ + stats = ModelRun.objects.filter(pk=self.pk).aggregate( + cached_bbox=BoundingBoxPolygon('evaluations__geom'), + cached_numsites=models.Count('evaluations__pk', distinct=True), + cached_score=models.Avg('evaluations__score'), + cached_timerange_min=models.Min('evaluations__start_date'), + cached_timerange_max=models.Max('evaluations__end_date'), + cached_timestamp=models.Max('evaluations__timestamp'), + ) + self.cached_bbox = stats['cached_bbox'] + self.cached_numsites = stats['cached_numsites'] + self.cached_score = stats['cached_score'] + self.cached_timerange_min = stats['cached_timerange_min'] + self.cached_timerange_max = stats['cached_timerange_max'] + self.cached_timestamp = stats['cached_timestamp'] + + self.save( + update_fields=[ + 'cached_bbox', + 'cached_numsites', + 'cached_score', + 'cached_timerange_min', + 'cached_timerange_max', + 'cached_timestamp', + ] + ) + + @classmethod + def compute_all_aggregate_stats(cls, recompute_all: bool = False): + """ + Computes aggregate stats on all model runs that are missing them. + + If `recompute_all` is True, all model runs will have their aggregate stats recomputed + regardless of whether they are already populated. + """ + if recompute_all: + qs = cls.objects.all() + else: + qs = cls.objects.filter(cached_numsites=None) + + # Unfortunately, Django ORM doesn't support joins in update queries, so doing this in a + # single join query would require raw SQL that would be a pain to maintain. + # Since this should only rarely run, we don't mind looping here. 
+ for model_run in qs: + model_run.compute_aggregate_stats() diff --git a/rdwatch/core/views/model_run.py b/rdwatch/core/views/model_run.py index 4abc6d95..f5ecce5e 100644 --- a/rdwatch/core/views/model_run.py +++ b/rdwatch/core/views/model_run.py @@ -17,7 +17,6 @@ from django.core.files.storage import default_storage from django.db import transaction from django.db.models import ( - Avg, Case, Count, Exists, @@ -36,7 +35,12 @@ from django.shortcuts import get_object_or_404 from django.views.decorators.csrf import csrf_exempt -from rdwatch.core.db.functions import BoundingBox, BoundingBoxGeoJSON, ExtractEpoch +from rdwatch.core.db.functions import ( + AsGeoJSONDeserialized, + BoundingBox, + BoundingBoxGeoJSON, + ExtractEpoch, +) from rdwatch.core.models import ( AnnotationExport, ModelRun, @@ -134,9 +138,9 @@ class ModelRunDetailSchema(Schema): region: str = Field(alias='region_name') performer: PerformerSchema parameters: dict - numsites: int | None = 0 + numsites: int | None = Field(0, alias='cached_score') downloading: int | None = None - score: float | None = None + score: float | None = Field(None, alias='cached_score') timestamp: int | None = None timerange: TimeRangeSchema | None = None bbox: dict | None @@ -155,9 +159,9 @@ class ModelRunListSchema(Schema): region: str = Field(..., alias='region_name') performer: PerformerSchema parameters: dict - numsites: int | None = 0 + numsites: int | None = Field(0, alias='cached_numsites') downloading: int | None = None - score: float | None = None + score: float | None = Field(None, alias='cached_score') timestamp: int | None = None timerange: TimeRangeSchema | None = None bbox: dict | None @@ -205,8 +209,7 @@ def get_queryset(): return ( ModelRun.objects.select_related('performer') # Order queryset so that ground truths are first - .order_by('ground_truth', '-created') - .annotate( + .order_by('ground_truth', '-created').annotate( region_name=F('region__name'), downloading=Coalesce( Subquery( @@ -219,14 +222,12 @@ def get_queryset(): ), 0, # Default value when evaluations are None ), - numsites=Count('evaluations__pk', distinct=True), - score=Avg('evaluations__score'), - timestamp=ExtractEpoch(Max('evaluations__timestamp')), + timestamp=ExtractEpoch('cached_timestamp'), timerange=JSONObject( - min=ExtractEpoch(Min('evaluations__start_date')), - max=ExtractEpoch(Max('evaluations__end_date')), + min=ExtractEpoch('cached_timerange_min'), + max=ExtractEpoch('cached_timerange_max'), ), - bbox=BoundingBoxGeoJSON('evaluations__geom'), + bbox=AsGeoJSONDeserialized('cached_bbox'), groundTruthLink=Case( When( Q(ground_truth=False), diff --git a/rdwatch/scoring/views/site_image.py b/rdwatch/scoring/views/site_image.py index 173a4eaa..f791326f 100644 --- a/rdwatch/scoring/views/site_image.py +++ b/rdwatch/scoring/views/site_image.py @@ -99,14 +99,16 @@ def site_images(request: HttpRequest, id: UUID4): geom_queryset = observations.aggregate( results=JSONBAgg( JSONObject( - label=Func( - F(observation_db_model_cols['phase']), - Value(', '), - function='array_to_string', - output=CharField(), - ) - if proposal - else observation_db_model_cols['phase'], + label=( + Func( + F(observation_db_model_cols['phase']), + Value(', '), + function='array_to_string', + output=CharField(), + ) + if proposal + else observation_db_model_cols['phase'] + ), timestamp=ExtractEpoch(observation_db_model_cols['date']), geoJSON=Func( F('geometry'), @@ -132,9 +134,11 @@ def site_images(request: HttpRequest, id: UUID4): .annotate( json=JSONObject( 
label=F(site_db_model_cols['status']), - status=F(site_db_model_cols['proposal_status']) - if site_db_model_cols['proposal_status'] - else Value(''), + status=( + F(site_db_model_cols['proposal_status']) + if site_db_model_cols['proposal_status'] + else Value('') + ), evaluationGeoJSON=Func( F(site_db_model_cols['geometry']), 4326, @@ -149,9 +153,11 @@ def site_images(request: HttpRequest, id: UUID4): output_field=GeometryField(), ) ), - notes=F(site_db_model_cols['notes']) - if site_db_model_cols['notes'] - else Value(''), # TODO + notes=( + F(site_db_model_cols['notes']) + if site_db_model_cols['notes'] + else Value('') + ), # TODO ) )[0] ) From 2e345e9abffac2226d0aeb37f916df6106c57ee8 Mon Sep 17 00:00:00 2001 From: Zach Mullen Date: Mon, 21 Oct 2024 15:50:52 -0400 Subject: [PATCH 3/4] Compute aggregate stats on newly-ingested model runs --- rdwatch/core/tasks/__init__.py | 2 ++ rdwatch/core/views/model_run.py | 13 +++++++++++++ 2 files changed, 15 insertions(+) diff --git a/rdwatch/core/tasks/__init__.py b/rdwatch/core/tasks/__init__.py index 7bfc9bfe..8d4fd8b2 100644 --- a/rdwatch/core/tasks/__init__.py +++ b/rdwatch/core/tasks/__init__.py @@ -888,6 +888,8 @@ def process_model_run_upload(model_run_upload: ModelRunUpload): if region_model: SiteEvaluation.bulk_create_from_region_model(region_model, model_run) + model_run.compute_aggregate_stats() + @shared_task(bind=True) def process_model_run_upload_task(task, upload_id: UUID): diff --git a/rdwatch/core/views/model_run.py b/rdwatch/core/views/model_run.py index f5ecce5e..b1f65d17 100644 --- a/rdwatch/core/views/model_run.py +++ b/rdwatch/core/views/model_run.py @@ -368,6 +368,19 @@ def list_model_runs( return filters.filter(get_queryset()) +@router.post( + '/{id}/finalization/', + response={200: ModelRunDetailSchema}, + auth=[ModelRunAuth()], +) +# this is safe because we're using a nonstandard header w/ API Key for auth +@csrf_exempt +def finalize_model_run(request: HttpRequest, id: UUID4): + model_run = get_object_or_404(ModelRun, pk=id) + model_run.compute_aggregate_stats() + return model_run + + @router.get('/{id}/', response={200: ModelRunDetailSchema}) def get_model_run(request: HttpRequest, id: UUID4): return get_object_or_404(get_queryset(), id=id) From c8cef93ba430c5e924e01de8a0ba3150742eb13d Mon Sep 17 00:00:00 2001 From: Zach Mullen Date: Mon, 21 Oct 2024 15:59:26 -0400 Subject: [PATCH 4/4] Merge migrations --- rdwatch/core/migrations/0041_merge_20241021_1458.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100644 rdwatch/core/migrations/0041_merge_20241021_1458.py diff --git a/rdwatch/core/migrations/0041_merge_20241021_1458.py b/rdwatch/core/migrations/0041_merge_20241021_1458.py new file mode 100644 index 00000000..884e51b7 --- /dev/null +++ b/rdwatch/core/migrations/0041_merge_20241021_1458.py @@ -0,0 +1,12 @@ +# Generated by Django 5.0.3 on 2024-10-21 14:58 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ('core', '0038_modelrun_cached_bbox_modelrun_cached_numsites_and_more'), + ('core', '0040_remove_siteimage_aws_location'), + ] + + operations = []