diff --git a/exceptional_situations/api/serializers.py b/exceptional_situations/api/serializers.py index 983cde86a..58c5b0c73 100644 --- a/exceptional_situations/api/serializers.py +++ b/exceptional_situations/api/serializers.py @@ -16,6 +16,7 @@ class Meta: class SituationAnnouncementSerializer(serializers.ModelSerializer): location = SituationLocationSerializer() + municipalities = serializers.SerializerMethodField() class Meta: model = SituationAnnouncement @@ -27,8 +28,12 @@ class Meta: "end_time", "additional_info", "location", + "municipalities", ] + def get_municipalities(self, obj): + return [m.id for m in obj.municipalities.all()] + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/exceptional_situations/api/views.py b/exceptional_situations/api/views.py index f14488ba1..6213a7906 100644 --- a/exceptional_situations/api/views.py +++ b/exceptional_situations/api/views.py @@ -1,4 +1,5 @@ import django_filters +from django.db.models import Q from django_filters.rest_framework import DjangoFilterBackend from rest_framework import viewsets @@ -23,6 +24,7 @@ class SituationFilter(django_filters.FilterSet): start_time__lt = django_filters.DateTimeFilter(method="filter_start_time__lt") end_time__gt = django_filters.DateTimeFilter(method="filter_end_time__gt") end_time__lt = django_filters.DateTimeFilter(method="filter_end_time__lt") + municipalities = django_filters.CharFilter(method="filter_municipalities") class Meta: model = Situation @@ -50,14 +52,21 @@ def filter_start_time__lt(self, queryset, fields, start_time): ids = [obj.id for obj in queryset if obj.start_time < start_time] return queryset.filter(id__in=ids) - def filter_end_time__gt(self, queryset, fields, start_time): - ids = [obj.id for obj in queryset if obj.start_time > start_time] + def filter_end_time__gt(self, queryset, fields, end_time): + ids = [obj.id for obj in queryset if obj.end_time > end_time] return queryset.filter(id__in=ids) - def filter_end_time__lt(self, queryset, fields, start_time): - ids = [obj.id for obj in queryset if obj.start_time < start_time] + def filter_end_time__lt(self, queryset, fields, end_time): + ids = [obj.id for obj in queryset if obj.end_time < end_time] return queryset.filter(id__in=ids) + def filter_municipalities(self, queryset, fields, municipalities): + municipalities = municipalities.split(",") + query = Q() + for municiaplity in municipalities: + query |= Q(announcements__municipalities__id__iexact=municiaplity.strip()) + return queryset.filter(query).distinct() + class SituationViewSet(viewsets.ReadOnlyModelViewSet): queryset = Situation.objects.all() diff --git a/exceptional_situations/management/commands/import_traffic_situations.py b/exceptional_situations/management/commands/import_traffic_situations.py index d7123a514..a936a1f9a 100644 --- a/exceptional_situations/management/commands/import_traffic_situations.py +++ b/exceptional_situations/management/commands/import_traffic_situations.py @@ -11,6 +11,7 @@ from django.contrib.gis.geos import GEOSGeometry, Polygon from django.core.management import BaseCommand from django.utils import timezone +from munigeo.models import Municipality from exceptional_situations.models import ( PROJECTION_SRID, @@ -65,6 +66,19 @@ def create_location(self, geometry, announcement_data): } return get_or_create(SituationLocation, filter) + def get_municipality_lower_names(self, location_details): + names = [] + road_address_location = location_details.get("roadAddressLocation", None) + if road_address_location: + primary_point = road_address_location.get("primaryPoint", None) + if primary_point: + names.append(primary_point["municipality"].lower()) + secondary_point = road_address_location.get("secondaryPoint", None) + if secondary_point: + names.append(secondary_point["municipality"].lower()) + + return names + def create_announcement(self, announcement_data, location): title = announcement_data.get("title", "") description = announcement_data["location"].get("description", "") @@ -97,7 +111,19 @@ def create_announcement(self, announcement_data, location): "end_time": end_time, "location": location, } - return get_or_create(SituationAnnouncement, filter) + + announcement = get_or_create(SituationAnnouncement, filter) + location_details = announcement_data.get("locationDetails", None) + if location_details: + announcement.municipalities.clear() + municipality_names = self.get_municipality_lower_names(location_details) + for name in municipality_names: + try: + municipality = Municipality.objects.get(id=name) + announcement.municipalities.add(municipality) + except Municipality.DoesNotExist: + logger.warning(f"Municipality {name} does not exists") + return announcement def save_features(self, features): num_imported = 0 diff --git a/exceptional_situations/migrations/0003_situationannouncement_municipalities.py b/exceptional_situations/migrations/0003_situationannouncement_municipalities.py new file mode 100644 index 000000000..68bc70dfc --- /dev/null +++ b/exceptional_situations/migrations/0003_situationannouncement_municipalities.py @@ -0,0 +1,19 @@ +# Generated by Django 4.1.13 on 2024-06-24 08:07 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("munigeo", "0016_address_modified_at_remove_auto_now"), + ("exceptional_situations", "0002_alter_situationannouncement_location"), + ] + + operations = [ + migrations.AddField( + model_name="situationannouncement", + name="municipalities", + field=models.ManyToManyField(to="munigeo.municipality"), + ), + ] diff --git a/exceptional_situations/models.py b/exceptional_situations/models.py index 1ea08a353..a4c10d616 100644 --- a/exceptional_situations/models.py +++ b/exceptional_situations/models.py @@ -2,6 +2,7 @@ from django.contrib.gis.db import models from django.utils import timezone +from munigeo.models import Municipality PROJECTION_SRID = 4326 @@ -39,6 +40,7 @@ class SituationAnnouncement(models.Model): blank=True, related_name="announcements", ) + municipalities = models.ManyToManyField(Municipality) class Meta: ordering = ["start_time"] diff --git a/exceptional_situations/tests/conftest.py b/exceptional_situations/tests/conftest.py index a91ed5434..18ebb85d0 100644 --- a/exceptional_situations/tests/conftest.py +++ b/exceptional_situations/tests/conftest.py @@ -3,6 +3,7 @@ import pytest from django.contrib.gis.geos import GEOSGeometry from django.utils import timezone +from munigeo.models import Municipality from rest_framework.test import APIClient from exceptional_situations.models import ( @@ -20,6 +21,15 @@ def api_client(): return APIClient() +@pytest.mark.django_db +@pytest.fixture +def municipalities(): + Municipality.objects.create(id="turku", name="Turku") + Municipality.objects.create(id="lieto", name="Lieto") + Municipality.objects.create(id="raisio", name="Raisio") + return Municipality.objects.all() + + @pytest.mark.django_db @pytest.fixture def situation_types(): @@ -48,9 +58,9 @@ def locations(): @pytest.mark.django_db @pytest.fixture -def announcements(locations): +def announcements(locations, municipalities): json_data = {"test_key": "test_value"} - SituationAnnouncement.objects.create( + sa = SituationAnnouncement.objects.create( title="two hours", description="two hours long situation", additional_info=json_data, @@ -58,7 +68,9 @@ def announcements(locations): start_time=NOW - timedelta(hours=1), end_time=NOW + timedelta(hours=1), ) - SituationAnnouncement.objects.create( + sa.municipalities.add(municipalities.filter(id="turku").first()) + sa.municipalities.add(municipalities.filter(id="lieto").first()) + sa = SituationAnnouncement.objects.create( title="two days", description="two days long situation", additional_info=json_data, @@ -66,7 +78,7 @@ def announcements(locations): start_time=NOW - timedelta(days=1), end_time=NOW + timedelta(days=1), ) - + sa.municipalities.add(municipalities.filter(id="raisio").first()) return SituationAnnouncement.objects.all() diff --git a/exceptional_situations/tests/test_api.py b/exceptional_situations/tests/test_api.py index 12c78b227..cc32fc3f9 100644 --- a/exceptional_situations/tests/test_api.py +++ b/exceptional_situations/tests/test_api.py @@ -38,6 +38,7 @@ def test_situations_list(api_client, situations, inactive_situations): "end_time", "additional_info", "location", + "municipalities", } location = announcement["location"] assert location.keys() == {"id", "location", "geometry", "details"} @@ -102,12 +103,12 @@ def test_situation_filter_by_end_time(api_client, situations): SITUATION_LIST_URL + f"?end_time__gt={datetime.strftime(end_time, DATETIME_FORMAT)}" ) - assert response.json()["count"] == 1 + assert response.json()["count"] == 2 response = api_client.get( SITUATION_LIST_URL + f"?end_time__lt={datetime.strftime(end_time, DATETIME_FORMAT)}" ) - assert response.json()["count"] == 1 + assert response.json()["count"] == 0 end_time = timezone.now() - timedelta(days=2) response = api_client.get( @@ -122,6 +123,14 @@ def test_situation_filter_by_end_time(api_client, situations): assert response.json()["count"] == 0 +@pytest.mark.django_db +def test_filter_by_municipalities(api_client, situations): + response = api_client.get(SITUATION_LIST_URL + "?municipalities=raisio,lieto") + assert response.json()["count"] == 2 + response = api_client.get(SITUATION_LIST_URL + "?municipalities=turku") + assert response.json()["count"] == 1 + + @pytest.mark.django_db def test_situation_types_list(api_client, situation_types): response = api_client.get(reverse("exceptional_situations:situation_type-list")) @@ -163,6 +172,7 @@ def test_announcement_list(api_client, announcements): "end_time", "additional_info", "location", + "municipalities", } location = result_data["location"] assert location.keys() == {"id", "location", "geometry", "details"} @@ -186,6 +196,7 @@ def test_announcement_retrieve(api_client, announcements): "end_time", "additional_info", "location", + "municipalities", } assert json_data["id"] == announcements[0].pk diff --git a/exceptional_situations/tests/test_import_traffic_situations.py b/exceptional_situations/tests/test_import_traffic_situations.py index 4bcd9389d..8453a5092 100644 --- a/exceptional_situations/tests/test_import_traffic_situations.py +++ b/exceptional_situations/tests/test_import_traffic_situations.py @@ -336,7 +336,7 @@ def import_command(*args, **kwargs): @pytest.mark.django_db @freeze_time("2024-06-11 12:00:00", tz_offset=2) -def test_import_traffic_situation(): +def test_import_traffic_situation(municipalities): import_command(test_importer=data) assert SituationType.objects.count() == 1 assert SituationType.objects.first().type_name == "ROAD_WORK" @@ -352,6 +352,9 @@ def test_import_traffic_situation(): assert location.details["primaryPoint"]["roadName"] == "Turun kehätie" announcement = SituationAnnouncement.objects.first() assert announcement in situation.announcements.all() + assert announcement.municipalities.filter( + id=municipalities.get(id="turku").id + ).exists() assert announcement.title == "Tie 40, eli Turun kehätie, Turku. Tietyö. " assert "aikutusalue 1,1 km, suuntaan Kärsämäen " in announcement.description assert (