From e28918a6a682719c72533f0435dcd76488ff0e1b Mon Sep 17 00:00:00 2001 From: Marcos Date: Wed, 30 Oct 2024 17:53:09 -0300 Subject: [PATCH] chore: Fixed coverage for CourseChatView --- learning_assistant/api.py | 8 +++--- learning_assistant/views.py | 41 ++++++++---------------------- tests/test_api.py | 3 ++- tests/test_views.py | 50 ++++++++++++++++++++++++------------- 4 files changed, 49 insertions(+), 53 deletions(-) diff --git a/learning_assistant/api.py b/learning_assistant/api.py index 8ba2906..5a92d93 100644 --- a/learning_assistant/api.py +++ b/learning_assistant/api.py @@ -190,16 +190,16 @@ def get_course_id(course_run_id): course_key = course_data['course'] return course_key + def save_chat_message(user_id, chat_role, message): """ - Saves the chat message to the database. + Save the chat message to the database. """ - user = None try: user = User.objects.get(id=user_id) - except User.DoesNotExist: - raise Exception("User does not exists.") + except User.DoesNotExist as exc: + raise Exception("User does not exists.") from exc # Save the user message to the database. LearningAssistantMessage.objects.create( diff --git a/learning_assistant/views.py b/learning_assistant/views.py index f40e517..3088c78 100644 --- a/learning_assistant/views.py +++ b/learning_assistant/views.py @@ -4,7 +4,6 @@ import logging from django.conf import settings -from django.contrib.auth import get_user_model from edx_rest_framework_extensions.auth.jwt.authentication import JwtAuthentication from opaque_keys import InvalidKeyError from opaque_keys.edx.keys import CourseKey @@ -21,13 +20,13 @@ except ImportError: pass -from learning_assistant.api import get_course_id, learning_assistant_enabled, render_prompt_template +from learning_assistant.api import get_course_id, learning_assistant_enabled, render_prompt_template, save_chat_message from learning_assistant.models import LearningAssistantMessage from learning_assistant.serializers import MessageSerializer +from learning_assistant.toggles import chat_history_enabled from learning_assistant.utils import get_chat_response, user_role_is_staff log = logging.getLogger(__name__) -User = get_user_model() class CourseChatView(APIView): @@ -38,27 +37,6 @@ class CourseChatView(APIView): authentication_classes = (SessionAuthentication, JwtAuthentication,) permission_classes = (IsAuthenticated,) - def __save_user_interaction(self, user_id, user_message, assistant_message): - """ - Saves the last question/response to the database. - """ - user = User.objects.get(id=user_id) - - # Save the user message to the database. - LearningAssistantMessage.objects.create( - user=user, - role=LearningAssistantMessage.USER_ROLE, - content=user_message, - ) - - # Save the assistant response to the database. - LearningAssistantMessage.objects.create( - user=user, - role=LearningAssistantMessage.ASSISTANT_ROLE, - content=assistant_message, - ) - - def post(self, request, course_run_id): """ Given a course run ID, retrieve a chat response for that course. @@ -109,6 +87,12 @@ def post(self, request, course_run_id): data={'detail': "Expects user role on last message."} ) + course_id = get_course_id(course_run_id) + user_id = request.user.id + + if chat_history_enabled(course_id): + save_chat_message(user_id, LearningAssistantMessage.USER_ROLE, new_user_message['content']) + serializer = MessageSerializer(data=message_list, many=True) # serializer will not be valid in the case that the message list contains any roles other than @@ -127,8 +111,6 @@ def post(self, request, course_run_id): } ) - course_id = get_course_id(course_run_id) - template_string = getattr(settings, 'LEARNING_ASSISTANT_PROMPT_TEMPLATE', '') prompt_template = render_prompt_template( @@ -136,11 +118,8 @@ def post(self, request, course_run_id): ) status_code, message = get_chat_response(prompt_template, message_list) - self.__save_user_interaction( - user_id=request.user.id, - user_message=new_user_message['content'], - assistant_message=message['content'] - ) + if chat_history_enabled(course_id): + save_chat_message(user_id, LearningAssistantMessage.ASSISTANT_ROLE, message['content']) return Response(status=status_code, data=message) diff --git a/tests/test_api.py b/tests/test_api.py index 85d0a01..60bc1f9 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -29,6 +29,7 @@ fake_transcript = 'This is the text version from the transcript' User = get_user_model() + class FakeChild: """Fake child block for testing""" transcript_download_format = 'txt' @@ -232,6 +233,7 @@ def test_render_prompt_template_invalid_unit_key(self, mock_get_content): self.assertNotIn('The following text is useful.', prompt_text) + @ddt.ddt class TestLearningAssistantCourseEnabledApi(TestCase): """ @@ -256,7 +258,6 @@ def test_save_chat_message(self, chat_role, message): self.assertEqual(row.content, message) - @ddt.ddt class LearningAssistantCourseEnabledApiTests(TestCase): """ diff --git a/tests/test_views.py b/tests/test_views.py index 74728c3..afa98c4 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -4,7 +4,7 @@ import json import sys from importlib import import_module -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, call, patch import ddt from django.conf import settings @@ -19,7 +19,7 @@ User = get_user_model() -class TestClient(Client): +class FakeClient(Client): """ Allows for 'fake logins' of a user so we don't need to expose a 'login' HTTP endpoint. """ @@ -65,14 +65,14 @@ def setUp(self): Setup for tests. """ super().setUp() - self.client = TestClient() + self.client = FakeClient() self.user = User(username='tester', email='tester@test.com') self.user.save() self.client.login_user(self.user) @ddt.ddt -class CourseChatViewTests(LoggedInTestCase): +class TestCourseChatView(LoggedInTestCase): """ Test for the CourseChatView """ @@ -152,15 +152,27 @@ def test_invalid_messages(self, mock_role, mock_waffle, mock_render): ) self.assertEqual(response.status_code, 400) + @ddt.data(True, False) # TODO: Fix this - See below. @patch('learning_assistant.views.render_prompt_template') @patch('learning_assistant.views.get_chat_response') @patch('learning_assistant.views.learning_assistant_enabled') @patch('learning_assistant.views.get_user_role') @patch('learning_assistant.views.CourseEnrollment.get_enrollment') @patch('learning_assistant.views.CourseMode') + @patch('learning_assistant.api.save_chat_message') + @patch('learning_assistant.toggles.chat_history_enabled') @override_settings(LEARNING_ASSISTANT_PROMPT_TEMPLATE='This is the default template') def test_chat_response_default( - self, mock_mode, mock_enrollment, mock_role, mock_waffle, mock_chat_response, mock_render + self, + enabled_flag, + mock_chat_history_enabled, + mock_save_chat_message, + mock_mode, + mock_enrollment, + mock_role, + mock_waffle, + mock_chat_response, + mock_render, ): mock_waffle.return_value = True mock_role.return_value = 'student' @@ -170,6 +182,14 @@ def test_chat_response_default( mock_render.return_value = 'Rendered template mock' test_unit_id = 'test-unit-id' + # TODO: Fix this... + # For some reason this only works the first time. The 2nd time (enabled_flag = False) + # Doesn't actually work since the mocked chat_history_enabled() will return False no matter what. + # Swap the order of the @ddt.data() above by: @ddt.data(False, True) and watch it fail. + # The value for enabled_flag is corrct on this scope, but the mocked method doesn't update. + # It even happens if we split the test cases into two different methods. + mock_chat_history_enabled.return_value = enabled_flag + test_data = [ {'role': 'user', 'content': 'What is 2+2?'}, {'role': 'assistant', 'content': 'It is 4'}, @@ -181,20 +201,8 @@ def test_chat_response_default( data=json.dumps(test_data), content_type='application/json' ) - self.assertEqual(response.status_code, 200) - last_rows = LearningAssistantMessage.objects.all().order_by('-created').values()[:2][::-1] - - user_msg = last_rows[0] - assistant_msg = last_rows[1] - - self.assertEqual(user_msg['role'], LearningAssistantMessage.USER_ROLE) - self.assertEqual(user_msg['content'], test_data[2]['content']) - - self.assertEqual(assistant_msg['role'], LearningAssistantMessage.ASSISTANT_ROLE) - self.assertEqual(assistant_msg['content'], 'Something else') - render_args = mock_render.call_args.args self.assertIn(test_unit_id, render_args) self.assertIn('This is the default template', render_args) @@ -204,6 +212,14 @@ def test_chat_response_default( test_data, ) + if enabled_flag: + mock_save_chat_message.assert_has_calls([ + call(self.user.id, LearningAssistantMessage.USER_ROLE, test_data[-1]['content']), + call(self.user.id, LearningAssistantMessage.ASSISTANT_ROLE, 'Something else') + ]) + else: + mock_save_chat_message.assert_not_called() + @ddt.ddt class LearningAssistantEnabledViewTests(LoggedInTestCase):