diff --git a/lunch/tests.py b/lunch/tests.py index 1777741..3a5e3b2 100644 --- a/lunch/tests.py +++ b/lunch/tests.py @@ -1,10 +1,10 @@ -# from django.contrib.auth.models import User from django.urls import reverse from rest_framework import status from rest_framework.test import APITestCase +from rest_framework_simplejwt.tokens import RefreshToken +from common.models import Allergy from lunch.models import Lunch, LunchMenu -from menus.models import Menu from menus.serializers import MenuWithDetailSerializer from users.models import User from utils.test_helper import create_menu @@ -12,8 +12,17 @@ class LunchAPITestCase(APITestCase): def setUp(self): - user = User.objects.create_user(username="testuser", email="test@naver.com", password="1234") - self.client.login(username="testuser", password="1234") + self.test_user = User.objects.create_user( + username="testuser", email="test@example.com", password="testpassword", status=2 + ) + + refresh = RefreshToken.for_user(self.test_user) + self.token = str(refresh.access_token) + self.client.credentials(HTTP_AUTHORIZATION=f"Bearer {self.token}") + allergy_names = ["밀", "돼지고기"] + allergies = Allergy.objects.filter(name__in=allergy_names) + + self.test_user.allergies.set(allergies) menu1 = create_menu(name="test_menu1") menu2 = create_menu(name="test_menu2") @@ -101,7 +110,7 @@ def setUp(self): for i in range(1, 11): lunch = Lunch.objects.create( - store=user, + store=self.test_user, name=f"test Lunch{i}", description="test lunch set", image_url="http://example.com/image.jpg", @@ -178,5 +187,17 @@ def test_lunch_delete(self): def test_random_lunch_get(self): url = reverse("lunch-random") - res = self.client.get(url) + + with self.assertNumQueries(4): + res = self.client.get(url) + + self.assertEqual(res.status_code, status.HTTP_200_OK) + + def test_random_lunch_with_allergy_get(self): + base_url = reverse("lunch-random") + url = f"{base_url}?allergy=true" + + with self.assertNumQueries(4): + res = self.client.get(url) + self.assertEqual(res.status_code, status.HTTP_200_OK) diff --git a/lunch/views.py b/lunch/views.py index c40a9ff..e86a4b8 100644 --- a/lunch/views.py +++ b/lunch/views.py @@ -1,7 +1,7 @@ import random +from django.db.models import Prefetch, Subquery from rest_framework import status -from rest_framework.permissions import AllowAny, IsAuthenticated from rest_framework.request import Request from rest_framework.response import Response from rest_framework.views import APIView @@ -80,8 +80,32 @@ def delete(self, request: Request, pk: int) -> Response: class LunchRandomList(APIView): def get(self, request: Request) -> Response: - random_lunch = Lunch.objects.order_by("?")[0:10] - serializer = LunchSerializer(random_lunch, many=True) + allergy = request.GET.get("allergy", "").lower() + lunch_queryset = Lunch.objects.all() + + if allergy == "true": + if not request.user.is_authenticated: + return Response({"message": "로그인 된 유저가 아닙니다"}, status=status.HTTP_403_FORBIDDEN) + + user_allergies = request.user.allergies.all() + + allergen_menus = Menu.objects.filter(menu_details__allergy__in=user_allergies).values("id") + + lunch_queryset = lunch_queryset.exclude(menus__in=Subquery(allergen_menus)) + + total_count = lunch_queryset.count() + if total_count <= 10: + lunch_queryset = lunch_queryset.all() + else: + random_ids = lunch_queryset.values_list("id", flat=True) + random_ids_list = list(random_ids) + random.shuffle(random_ids_list) + selected_ids = random_ids_list[:10] + lunch_queryset = Lunch.objects.filter(id__in=selected_ids) + + lunch_prefetch = Prefetch("lunch_menu", queryset=LunchMenu.objects.select_related("menu")) + lunch_queryset = lunch_queryset.prefetch_related(lunch_prefetch) + serializer = LunchSerializer(lunch_queryset, many=True) return Response( serializer.data, diff --git a/menus/tests.py b/menus/tests.py index ae1c2d5..ffd4f40 100644 --- a/menus/tests.py +++ b/menus/tests.py @@ -18,6 +18,10 @@ def setUp(self) -> None: refresh = RefreshToken.for_user(self.test_user) self.token = str(refresh.access_token) self.client.credentials(HTTP_AUTHORIZATION=f"Bearer {self.token}") + allergy_names = ["밀", "돼지고기"] + allergies = Allergy.objects.filter(name__in=allergy_names) + + self.test_user.allergies.set(allergies) self.menu = create_menu(name="test_menu1") create_menu(name="test_menu2") @@ -199,3 +203,12 @@ def test_create_menu(self) -> None: self.assertEqual(menu_details.count(), 2) self.assertEqual(menu_details.first().allergy.name, "메밀") self.assertIsNone(menu_details.last().allergy) + + def test_get_allergy_menu_list(self): + base_url = reverse("menu-list") + url = f"{base_url}?category=chan&allergy=true" + + with self.assertNumQueries(4): + response = self.client.get(url) + + self.assertEqual(response.status_code, status.HTTP_200_OK) diff --git a/menus/views.py b/menus/views.py index 2bff681..2782aab 100644 --- a/menus/views.py +++ b/menus/views.py @@ -11,15 +11,21 @@ class MenuList(APIView): def get(self, request: Request) -> Response: - page = int(request.GET.get("page", "1")) size = int(request.GET.get("size", "10")) category = request.GET.get("category", "bob").lower() - allergies = request.GET.get("allergy", "").lower().split(",") - allergies = [allergy.strip() for allergy in allergies if allergy.strip()] + allergy = request.GET.get("allergy", "").lower() search = request.GET.get("search", "").lower() offset = (page - 1) * size + allergies: list[int] = [] + + if allergy == "true": + if not request.user.is_authenticated: + return Response({"message": "로그인 된 유저가 아닙니다"}, status=status.HTTP_403_FORBIDDEN) + + allergies = [aller.id for aller in request.user.allergies.all()] + if page < 1: return Response("page input error", status=status.HTTP_400_BAD_REQUEST) @@ -30,7 +36,7 @@ def get(self, request: Request) -> Response: if allergies: menus = menus.annotate( - allergy_count=Count("menu_details__allergy", filter=Q(menu_details__allergy__name__in=allergies)) + allergy_count=Count("menu_details__allergy", filter=Q(menu_details__allergy__id__in=allergies)) ).filter(allergy_count=0) if search: