Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize model run listing #520

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions dev/airflow_sample_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@
dag_id="RD-WATCH-AIRFLOW-DEMO-DAG",
description="Test DAG",
params={
"region_id": Param(
default="BR_R002", type="string", pattern=r"^.{1,255}$"
),
"region_id": Param(default="BR_R002", type="string", pattern=r"^.{1,255}$"),
"model_run_title": Param(default="test_run", type="string"),
},
start_date=datetime(2022, 3, 1),
Expand Down
6 changes: 6 additions & 0 deletions rdwatch/core/db/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,12 @@ def __init__(self, field):
return super().__init__(json_str, JSONField()) # noqa: B037


class AsGeoJSONDeserialized(Cast):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there's an easier way of getting AsGeoJSON's result into a deserialized python object, I'd love to know it.

def __init__(self, field):
json_str = AsGeoJSON(field)
return super().__init__(json_str, JSONField()) # noqa: B037


class TimeRangeJSON(NullIf):
"""Represents the min/max time of a field as JSON"""

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Generated by Django 5.0.3 on 2024-10-17 08:19

import django.contrib.gis.db.models.fields
from django.db import migrations, models


def compute_model_run_aggregate_stats(apps, schema_editor):
from rdwatch.core.models import ModelRun

ModelRun.compute_all_aggregate_stats(recompute_all=True)


class Migration(migrations.Migration):
dependencies = [
('core', '0037_alter_animationmodelrunexport_user_and_more'),
]

operations = [
migrations.AddField(
model_name='modelrun',
name='cached_bbox',
field=django.contrib.gis.db.models.fields.PolygonField(
blank=True,
help_text='The bounding box of all evaluations for this model run',
null=True,
spatial_index=False,
srid=4326,
),
),
migrations.AddField(
model_name='modelrun',
name='cached_numsites',
field=models.PositiveIntegerField(
blank=True,
help_text='The number of distinct evaluations for this model run',
null=True,
),
),
migrations.AddField(
model_name='modelrun',
name='cached_score',
field=models.FloatField(
blank=True,
help_text='The average score of all evaluations for this model run',
null=True,
),
),
migrations.AddField(
model_name='modelrun',
name='cached_timerange_max',
field=models.DateTimeField(
blank=True,
help_text='The latest timestamp of any evaluation for this model run',
null=True,
),
),
migrations.AddField(
model_name='modelrun',
name='cached_timerange_min',
field=models.DateTimeField(
blank=True,
help_text='The earliest timestamp of any evaluation for this model run',
null=True,
),
),
migrations.AddField(
model_name='modelrun',
name='cached_timestamp',
field=models.DateTimeField(
blank=True,
help_text='The timestamp of the most recent evaluation for this model run',
null=True,
),
),
migrations.RunPython(
code=compute_model_run_aggregate_stats,
reverse_code=migrations.RunPython.noop,
),
]
12 changes: 12 additions & 0 deletions rdwatch/core/migrations/0041_merge_20241021_1458.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Generated by Django 5.0.3 on 2024-10-21 14:58

from django.db import migrations


class Migration(migrations.Migration):
dependencies = [
('core', '0038_modelrun_cached_bbox_modelrun_cached_numsites_and_more'),
('core', '0040_remove_siteimage_aws_location'),
]

operations = []
94 changes: 94 additions & 0 deletions rdwatch/core/models/model_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
from django_extensions.db.models import CreationDateTimeField

from django.contrib.auth.models import User
from django.contrib.gis.db.models import PolygonField
from django.db import models

from rdwatch.core.db.functions import BoundingBoxPolygon


class ModelRun(models.Model):
id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
Expand Down Expand Up @@ -54,5 +57,96 @@ class ProposalStatus(models.TextChoices):
owner = models.ForeignKey(User, on_delete=models.CASCADE, null=True, blank=True)
public = models.BooleanField(default=True)

# The below fields are the denormalized aggregate statistics for this model run.
cached_numsites = models.PositiveIntegerField(
blank=True,
null=True,
help_text='The number of distinct evaluations for this model run',
)
cached_score = models.FloatField(
blank=True,
null=True,
help_text='The average score of all evaluations for this model run',
)
cached_timestamp = models.DateTimeField(
blank=True,
null=True,
help_text='The timestamp of the most recent evaluation for this model run',
)
cached_timerange_min = models.DateTimeField(
blank=True,
null=True,
help_text='The earliest timestamp of any evaluation for this model run',
)
cached_timerange_max = models.DateTimeField(
blank=True,
null=True,
help_text='The latest timestamp of any evaluation for this model run',
)
cached_bbox = PolygonField(
srid=4326,
blank=True,
null=True,
help_text='The bounding box of all evaluations for this model run',
spatial_index=False,
)

def __str__(self) -> str:
return str(self.pk)

def compute_aggregate_stats(self):
"""
Compute denormalized aggregate stats and persist them on this model run.

Many of the fields we report for each model run are expensive to compute, and need to be
returned to the client often. Calling this will perform the expensive computations and
persist the results on the model run for fast retrieval.

This should be called whenever a model run is fully ingested, and again whenever any of the
underlying data corresponding to the model changes such that these summary statistics might
be affected.
"""
stats = ModelRun.objects.filter(pk=self.pk).aggregate(
cached_bbox=BoundingBoxPolygon('evaluations__geom'),
cached_numsites=models.Count('evaluations__pk', distinct=True),
cached_score=models.Avg('evaluations__score'),
cached_timerange_min=models.Min('evaluations__start_date'),
cached_timerange_max=models.Max('evaluations__end_date'),
cached_timestamp=models.Max('evaluations__timestamp'),
)
self.cached_bbox = stats['cached_bbox']
self.cached_numsites = stats['cached_numsites']
self.cached_score = stats['cached_score']
self.cached_timerange_min = stats['cached_timerange_min']
self.cached_timerange_max = stats['cached_timerange_max']
self.cached_timestamp = stats['cached_timestamp']

self.save(
update_fields=[
'cached_bbox',
'cached_numsites',
'cached_score',
'cached_timerange_min',
'cached_timerange_max',
'cached_timestamp',
]
)

@classmethod
def compute_all_aggregate_stats(cls, recompute_all: bool = False):
"""
Computes aggregate stats on all model runs that are missing them.

If `recompute_all` is True, all model runs will have their aggregate stats recomputed
regardless of whether they are already populated.
"""
if recompute_all:
qs = cls.objects.all()
else:
qs = cls.objects.filter(cached_numsites=None)

# Unfortunately, Django ORM doesn't support joins in update queries, so doing this in a
# single join query would require raw SQL that would be a pain to maintain.
# Since this should only rarely run, we don't mind looping here.
for model_run in qs:
model_run.compute_aggregate_stats()
2 changes: 2 additions & 0 deletions rdwatch/core/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,6 +972,8 @@ def process_model_run_upload(model_run_upload: ModelRunUpload):
if region_model:
SiteEvaluation.bulk_create_from_region_model(region_model, model_run)

model_run.compute_aggregate_stats()


@shared_task(bind=True)
def process_model_run_upload_task(task, upload_id: UUID):
Expand Down
Loading
Loading