diff --git a/environment_data/api/utils.py b/environment_data/api/utils.py index 4c78d2206..2b69b104b 100644 --- a/environment_data/api/utils.py +++ b/environment_data/api/utils.py @@ -1,55 +1,127 @@ -from datetime import datetime - -from rest_framework.exceptions import ParseError - -from .constants import DATA_TYPES, DATETIME_FORMATS, DAY, HOUR, MONTH, WEEK, YEAR - - -def validate_timestamp(timestamp_str, data_type): - time_format = DATETIME_FORMATS[data_type] - try: - datetime.strptime(timestamp_str, time_format) - except ValueError: - return f"{timestamp_str} invalid format date format, valid format for type {data_type} is {time_format}" - return None - - -def get_start_and_end_and_year(filters, data_type): - start = filters.get("start", None) - end = filters.get("end", None) - year = filters.get("year", None) - - if not start or not end: - raise ParseError("Supply both 'start' and 'end' parameters") - - if YEAR not in data_type and not year: - raise ParseError("Supply 'year' parameter") - - res1 = None - res2 = None - match data_type: - case DATA_TYPES.DAY: - res1 = validate_timestamp(start, DAY) - res2 = validate_timestamp(end, DAY) - case DATA_TYPES.HOUR: - res1 = validate_timestamp(start, HOUR) - res2 = validate_timestamp(end, HOUR) - case DATA_TYPES.WEEK: - res1 = validate_timestamp(start, WEEK) - res2 = validate_timestamp(end, WEEK) - case DATA_TYPES.MONTH: - res1 = validate_timestamp(start, MONTH) - res2 = validate_timestamp(end, MONTH) - case DATA_TYPES.YEAR: - res1 = validate_timestamp(start, YEAR) - res2 = validate_timestamp(end, YEAR) - - if res1: - raise ParseError(res1) - if res2: - raise ParseError(res2) - - if HOUR in data_type or DAY in data_type: - start = f"{year}-{start}" - end = f"{year}-{end}" - return start, end, year +import django_filters + +from environment_data.models import ( + DayData, + HourData, + MonthData, + Station, + WeekData, + YearData, +) + + +class StationFilterSet(django_filters.FilterSet): + geo_id = django_filters.NumberFilter(field_name="geo_id", lookup_expr="exact") + name = django_filters.CharFilter(lookup_expr="icontains") + + class Meta: + model = Station + fields = {"data_type": ["exact"]} + + +class BaseFilterSet(django_filters.FilterSet): + + station_id = django_filters.NumberFilter(field_name="station") + + class Meta: + fields = {"station": ["exact"]} + + def get_date(self, year_number, month_and_day): + return f"{year_number}-{month_and_day}" + + +class YearDataFilterSet(django_filters.FilterSet): + station_id = django_filters.NumberFilter(field_name="station") + start = django_filters.NumberFilter( + field_name="year__year_number", lookup_expr="gte" + ) + end = django_filters.NumberFilter(field_name="year__year_number", lookup_expr="lte") + + class Meta: + model = YearData + fields = {"station": ["exact"]} + + +class MonthDataFilterSet(BaseFilterSet): + def filter_year(self, queryset, field, year): + return queryset.filter(month__year__year_number=year) + + year = django_filters.NumberFilter(method="filter_year") + start = django_filters.NumberFilter( + field_name="month__month_number", lookup_expr="gte" + ) + end = django_filters.NumberFilter( + field_name="month__month_number", lookup_expr="lte" + ) + + class Meta: + model = MonthData + fields = BaseFilterSet.Meta.fields + + +class WeekDataFilterSet(BaseFilterSet): + def filter_year(self, queryset, field, year): + return queryset.filter(week__years__year_number=year) + + year = django_filters.NumberFilter(method="filter_year") + start = django_filters.NumberFilter( + field_name="week__week_number", lookup_expr="gte" + ) + end = django_filters.NumberFilter(field_name="week__week_number", lookup_expr="lte") + + class Meta: + model = WeekData + fields = BaseFilterSet.Meta.fields + + +class DateDataFilterSet(BaseFilterSet): + DATE_MODEL_NAME = None + YEAR_LOOKUP = None + + def filter_year(self, queryset, field, year): + return queryset.filter(**{f"{self.DATE_MODEL_NAME}__year__year_number": year}) + + def filter_start(self, queryset, field, start): + first = queryset.first() + if first: + lookup = first + if self.YEAR_LOOKUP: + lookup = getattr(first, self.YEAR_LOOKUP) + date = self.get_date(lookup.day.year.year_number, start) + return queryset.filter(**{f"{self.DATE_MODEL_NAME}__date__gte": date}) + else: + return queryset.none() + + def filter_end(self, queryset, field, end): + first = queryset.first() + if first: + lookup = first + if self.YEAR_LOOKUP: + lookup = getattr(first, self.YEAR_LOOKUP) + date = self.get_date(lookup.day.year.year_number, end) + return queryset.filter(**{f"{self.DATE_MODEL_NAME}__date__lte": date}) + else: + return queryset.none() + + year = django_filters.NumberFilter(method="filter_year") + start = django_filters.CharFilter(method="filter_start") + end = django_filters.CharFilter(method="filter_end") + + +class DayDataFilterSet(DateDataFilterSet): + + DATE_MODEL_NAME = "day" + + class Meta: + model = DayData + fields = BaseFilterSet.Meta.fields + + +class HourDataFilterSet(DateDataFilterSet): + + DATE_MODEL_NAME = "hour__day" + YEAR_LOOKUP = "hour" + + class Meta: + model = HourData + fields = BaseFilterSet.Meta.fields diff --git a/environment_data/api/views.py b/environment_data/api/views.py index c52a89ce6..da4dc8766 100644 --- a/environment_data/api/views.py +++ b/environment_data/api/views.py @@ -1,12 +1,12 @@ from django.utils.decorators import method_decorator from django.views.decorators.cache import cache_page +from django_filters.rest_framework import DjangoFilterBackend from drf_spectacular.utils import extend_schema, extend_schema_view -from rest_framework import status, viewsets -from rest_framework.response import Response +from rest_framework import viewsets +from rest_framework.exceptions import ValidationError from environment_data.api.constants import ( DATA_TYPES, - DATETIME_FORMATS, ENVIRONMENT_DATA_PARAMS, ENVIRONMENT_STATION_PARAMS, ) @@ -19,7 +19,7 @@ WeekDataSerializer, YearDataSerializer, ) -from environment_data.constants import DATA_TYPES_LIST, VALID_DATA_TYPE_CHOICES +from environment_data.constants import DATA_TYPES_LIST from environment_data.models import ( DayData, HourData, @@ -30,7 +30,14 @@ YearData, ) -from .utils import get_start_and_end_and_year +from .utils import ( + DayDataFilterSet, + HourDataFilterSet, + MonthDataFilterSet, + StationFilterSet, + WeekDataFilterSet, + YearDataFilterSet, +) @extend_schema_view( @@ -42,25 +49,12 @@ class StationViewSet(viewsets.ReadOnlyModelViewSet): queryset = Station.objects.all() serializer_class = StationSerializer + filter_backends = [DjangoFilterBackend] + filterset_class = StationFilterSet @method_decorator(cache_page(60 * 60)) def list(self, request, *args, **kwargs): - queryset = self.queryset - filters = self.request.query_params - data_type = filters.get("data_type", None) - if data_type: - data_type = str(data_type).upper() - if data_type not in DATA_TYPES_LIST: - return Response( - f"Invalid data type, valid types are: {VALID_DATA_TYPE_CHOICES}", - status=status.HTTP_400_BAD_REQUEST, - ) - - queryset = queryset.filter(data_type=data_type) - - page = self.paginate_queryset(queryset) - serializer = self.serializer_class(page, many=True) - return self.get_paginated_response(serializer.data) + return super().list(request, *args, **kwargs) @extend_schema_view( @@ -82,78 +76,64 @@ class ParameterViewSet(viewsets.ReadOnlyModelViewSet): ) ) class DataViewSet(viewsets.GenericViewSet): - queryset = YearData.objects.all() - def list(self, request, *args, **kwargs): - filters = self.request.query_params - station_id = filters.get("station_id", None) - if not station_id: - return Response( - "Supply 'station_id' parameter.", status=status.HTTP_400_BAD_REQUEST - ) - else: - try: - station = Station.objects.get(id=station_id) - except Station.DoesNotExist: - return Response( - f"Station with id {station_id} not found.", - status=status.HTTP_400_BAD_REQUEST, - ) + queryset = [] + serializer_class = None - data_type = filters.get("type", None) - if not data_type: - return Response( - "Supply 'type' parameter", status=status.HTTP_400_BAD_REQUEST - ) - else: - data_type = data_type.lower() + def get_serializer_class(self): + data_type = self.request.query_params.get("type", "").lower() + match data_type: + case DATA_TYPES.HOUR: + return HourDataSerializer + case DATA_TYPES.DAY: + return DayDataSerializer + case DATA_TYPES.WEEK: + return WeekDataSerializer + case DATA_TYPES.MONTH: + return MonthDataSerializer + case DATA_TYPES.YEAR: + return YearDataSerializer + case _: + raise ValidationError( + f"Provide a valid 'type' parameter. Valid types are: {', '.join([f for f in DATA_TYPES_LIST])}", + ) - start, end, year = get_start_and_end_and_year(filters, data_type) + def get_queryset(self): + params = self.request.query_params + data_type = params.get("type", "").lower() + queryset = YearData.objects.all() match data_type: case DATA_TYPES.HOUR: - queryset = HourData.objects.filter( - station=station, - hour__day__year__year_number=year, - hour__day__date__gte=start, - hour__day__date__lte=end, + filter_set = HourDataFilterSet( + data=params, queryset=HourData.objects.all() ) - serializer_class = HourDataSerializer case DATA_TYPES.DAY: - queryset = DayData.objects.filter( - station=station, - day__date__gte=start, - day__date__lte=end, - day__year__year_number=year, + filter_set = DayDataFilterSet( + data=params, queryset=DayData.objects.all() ) - serializer_class = DayDataSerializer case DATA_TYPES.WEEK: - serializer_class = WeekDataSerializer - queryset = WeekData.objects.filter( - week__years__year_number=year, - station=station, - week__week_number__gte=start, - week__week_number__lte=end, + filter_set = WeekDataFilterSet( + data=params, queryset=WeekData.objects.all() ) case DATA_TYPES.MONTH: - serializer_class = MonthDataSerializer - queryset = MonthData.objects.filter( - month__year__year_number=year, - station=station, - month__month_number__gte=start, - month__month_number__lte=end, + filter_set = MonthDataFilterSet( + data=params, queryset=MonthData.objects.all() ) case DATA_TYPES.YEAR: - serializer_class = YearDataSerializer - queryset = YearData.objects.filter( - station=station, - year__year_number__gte=start, - year__year_number__lte=end, + filter_set = YearDataFilterSet( + data=params, queryset=YearData.objects.all() ) case _: - return Response( - f"Provide a valid 'type' parameters. Valid types are: {', '.join([f for f in DATETIME_FORMATS])}", - status=status.HTTP_400_BAD_REQUEST, + raise ValidationError( + f"Provide a valid 'type' parameter. Valid types are: {', '.join([f for f in DATA_TYPES_LIST])}", ) + if filter_set and filter_set.is_valid(): + return filter_set.qs + else: + return queryset.none() + + def list(self, request, *args, **kwargs): + queryset = self.get_queryset() page = self.paginate_queryset(queryset) - serializer = serializer_class(page, many=True) + serializer = self.get_serializer_class()(page, many=True) return self.get_paginated_response(serializer.data) diff --git a/environment_data/tests/conftest.py b/environment_data/tests/conftest.py index 6dad1c53b..fe65f6c48 100644 --- a/environment_data/tests/conftest.py +++ b/environment_data/tests/conftest.py @@ -49,6 +49,8 @@ def stations(parameters): @pytest.fixture def measurements(parameters): Measurement.objects.create(id=1, parameter=Parameter.objects.get(id=1), value=1.5) + Measurement.objects.create(id=2, parameter=Parameter.objects.get(id=2), value=2) + return Measurement.objects.all() @@ -58,7 +60,6 @@ def parameters(): Parameter.objects.create(id=1, name="AQINDEX_PT1H_avg") Parameter.objects.create(id=2, name="NO2_PT1H_avg") Parameter.objects.create(id=3, name="WS_PT1H_avg") - return Parameter.objects.all() @@ -66,6 +67,7 @@ def parameters(): @pytest.fixture def years(): Year.objects.create(id=1, year_number=2023) + Year.objects.create(id=2, year_number=2022) return Year.objects.all() @@ -73,6 +75,7 @@ def years(): @pytest.fixture def months(years): Month.objects.create(month_number=1, year=years[0]) + Month.objects.create(month_number=1, year=years[1]) return Month.objects.all() @@ -81,6 +84,8 @@ def months(years): def weeks(years): week = Week.objects.create(week_number=1) week.years.add(years[0]) + week = Week.objects.create(week_number=1) + week.years.add(years[1]) return Week.objects.all() @@ -93,6 +98,12 @@ def days(years, months, weeks): month=months[0], week=weeks[0], ) + Day.objects.create( + date=parser.parse("2022-01-01 00:00:00"), + year=years[1], + month=months[1], + week=weeks[1], + ) return Day.objects.all() @@ -100,6 +111,7 @@ def days(years, months, weeks): @pytest.fixture def hours(days): Hour.objects.create(day=days[0], hour_number=0) + Hour.objects.create(day=days[1], hour_number=0) return Hour.objects.all() @@ -108,6 +120,8 @@ def hours(days): def year_datas(stations, years, measurements): year_data = YearData.objects.create(station=stations[0], year=years[0]) year_data.measurements.add(measurements[0]) + year_data = YearData.objects.create(station=stations[0], year=years[1]) + year_data.measurements.add(measurements[1]) return YearData.objects.all() @@ -116,6 +130,8 @@ def year_datas(stations, years, measurements): def month_datas(stations, months, measurements): month_data = MonthData.objects.create(station=stations[0], month=months[0]) month_data.measurements.add(measurements[0]) + month_data = MonthData.objects.create(station=stations[0], month=months[1]) + month_data.measurements.add(measurements[1]) return MonthData.objects.all() @@ -124,14 +140,17 @@ def month_datas(stations, months, measurements): def week_datas(stations, weeks, measurements): week_data = WeekData.objects.create(station=stations[0], week=weeks[0]) week_data.measurements.add(measurements[0]) + week_data = WeekData.objects.create(station=stations[0], week=weeks[1]) + week_data.measurements.add(measurements[1]) return WeekData.objects.all() -@pytest.mark.django_db @pytest.fixture def day_datas(stations, days, measurements): day_data = DayData.objects.create(station=stations[0], day=days[0]) day_data.measurements.add(measurements[0]) + day_data = DayData.objects.create(station=stations[0], day=days[1]) + day_data.measurements.add(measurements[1]) return DayData.objects.all() @@ -140,4 +159,6 @@ def day_datas(stations, days, measurements): def hour_datas(stations, hours, measurements): hour_data = HourData.objects.create(station=stations[0], hour=hours[0]) hour_data.measurements.add(measurements[0]) + hour_data = HourData.objects.create(station=stations[0], hour=hours[1]) + hour_data.measurements.add(measurements[1]) return HourData.objects.all() diff --git a/environment_data/tests/test_api.py b/environment_data/tests/test_api.py index 6635fc5c0..aecb46962 100644 --- a/environment_data/tests/test_api.py +++ b/environment_data/tests/test_api.py @@ -48,6 +48,7 @@ def test_day_data(api_client, day_datas, parameters): ) response = api_client.get(url) assert response.status_code == 200 + assert len(response.json()["results"]) == 1 json_data = response.json()["results"][0] assert len(json_data["measurements"]) == 1 assert json_data["measurements"][0]["value"] == 1.5 @@ -55,6 +56,17 @@ def test_day_data(api_client, day_datas, parameters): assert json_data["date"] == "2023-01-01" +@pytest.mark.django_db +def test_day_data_non_existing_year(api_client, day_datas, parameters): + url = ( + reverse("environment_data:data-list") + + "?year=2020&start=01-01&end=02-01&station_id=1&type=day" + ) + response = api_client.get(url) + assert response.status_code == 200 + assert len(response.json()["results"]) == 0 + + @pytest.mark.django_db def test_week_data(api_client, week_datas, parameters): url = ( @@ -63,6 +75,7 @@ def test_week_data(api_client, week_datas, parameters): ) response = api_client.get(url) assert response.status_code == 200 + assert len(response.json()["results"]) == 1 json_data = response.json()["results"][0] assert len(json_data["measurements"]) == 1 assert json_data["measurements"][0]["value"] == 1.5 @@ -78,11 +91,38 @@ def test_month_data(api_client, month_datas, parameters): ) response = api_client.get(url) assert response.status_code == 200 + assert len(response.json()["results"]) == 1 json_data = response.json()["results"][0] assert len(json_data["measurements"]) == 1 assert json_data["measurements"][0]["value"] == 1.5 assert json_data["measurements"][0]["parameter"] == parameters[0].name assert json_data["month_number"] == 1 + url = ( + reverse("environment_data:data-list") + + "?year=2023&start=1&end=1&station_id=411&type=month" + ) + response = api_client.get(url) + assert len(response.json()["results"]) == 0 + + +@pytest.mark.django_db +def test_month_data_non_existing_year(api_client, month_datas, parameters): + url = ( + reverse("environment_data:data-list") + + "?year=2020&start=1&end=1&station_id=411&type=month" + ) + response = api_client.get(url) + assert len(response.json()["results"]) == 0 + + +@pytest.mark.django_db +def test_month_data_chars_in_arguments(api_client, month_datas, parameters): + url = ( + reverse("environment_data:data-list") + + "?year=foo&start=abc&end=dce&station_id=foobar&type=month" + ) + response = api_client.get(url) + assert len(response.json()["results"]) == 0 @pytest.mark.django_db @@ -94,6 +134,7 @@ def test_year_data(api_client, year_datas, parameters): response = api_client.get(url) assert response.status_code == 200 json_data = response.json()["results"][0] + assert len(response.json()["results"]) == 1 assert len(json_data["measurements"]) == 1 assert json_data["measurements"][0]["value"] == 1.5 assert json_data["measurements"][0]["parameter"] == parameters[0].name