From c20a9207892e89dfaa20afd8df901f425f1af263 Mon Sep 17 00:00:00 2001 From: Marcos Date: Mon, 4 Nov 2024 15:23:31 -0300 Subject: [PATCH] fix: Added course run key to save_chat_message() --- learning_assistant/api.py | 4 +++- learning_assistant/views.py | 4 ++-- tests/test_api.py | 4 +++- tests/test_views.py | 6 ++++-- 4 files changed, 12 insertions(+), 6 deletions(-) diff --git a/learning_assistant/api.py b/learning_assistant/api.py index e6fa870..eaf4b10 100644 --- a/learning_assistant/api.py +++ b/learning_assistant/api.py @@ -191,7 +191,7 @@ def get_course_id(course_run_id): return course_key -def save_chat_message(user_id, chat_role, message): +def save_chat_message(courserun_key, user_id, chat_role, message): """ Save the chat message to the database. """ @@ -203,9 +203,11 @@ def save_chat_message(user_id, chat_role, message): # Save the user message to the database. LearningAssistantMessage.objects.create( + course_id=courserun_key, user=user, role=chat_role, content=message, + ) diff --git a/learning_assistant/views.py b/learning_assistant/views.py index 3e136fe..583aed1 100644 --- a/learning_assistant/views.py +++ b/learning_assistant/views.py @@ -96,7 +96,7 @@ def post(self, request, course_run_id): user_id = request.user.id if chat_history_enabled(courserun_key): - save_chat_message(user_id, LearningAssistantMessage.USER_ROLE, new_user_message['content']) + save_chat_message(courserun_key, user_id, LearningAssistantMessage.USER_ROLE, new_user_message['content']) serializer = MessageSerializer(data=message_list, many=True) @@ -126,7 +126,7 @@ def post(self, request, course_run_id): status_code, message = get_chat_response(prompt_template, message_list) if chat_history_enabled(courserun_key): - save_chat_message(user_id, LearningAssistantMessage.ASSISTANT_ROLE, message['content']) + save_chat_message(courserun_key, user_id, LearningAssistantMessage.ASSISTANT_ROLE, message['content']) return Response(status=status_code, data=message) diff --git a/tests/test_api.py b/tests/test_api.py index 6969af8..0470e32 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -245,6 +245,7 @@ def setUp(self): super().setUp() self.test_user = User.objects.create(username='username', password='password') + self.course_run_key = CourseKey.from_string('course-v1:edx+test+23') @ddt.data( (LearningAssistantMessage.USER_ROLE, 'What is the meaning of life, the universe and everything?'), @@ -252,10 +253,11 @@ def setUp(self): ) @ddt.unpack def test_save_chat_message(self, chat_role, message): - save_chat_message(self.test_user.id, chat_role, message) + save_chat_message(self.course_run_key, self.test_user.id, chat_role, message) row = LearningAssistantMessage.objects.all().last() + self.assertEqual(row.course_id, self.course_run_key) self.assertEqual(row.role, chat_role) self.assertEqual(row.content, message) diff --git a/tests/test_views.py b/tests/test_views.py index 9e07e55..58dab31 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -14,6 +14,7 @@ from django.test.client import Client from django.urls import reverse +from opaque_keys.edx.keys import CourseKey from learning_assistant.models import LearningAssistantMessage User = get_user_model() @@ -85,6 +86,7 @@ class TestCourseChatView(LoggedInTestCase): def setUp(self): super().setUp() self.course_id = 'course-v1:edx+test+23' + self.course_run_key = CourseKey.from_string(self.course_id) self.patcher = patch( 'learning_assistant.api.get_cache_course_run_data', @@ -209,8 +211,8 @@ def test_chat_response_default( if enabled_flag: mock_save_chat_message.assert_has_calls([ - call(self.user.id, LearningAssistantMessage.USER_ROLE, test_data[-1]['content']), - call(self.user.id, LearningAssistantMessage.ASSISTANT_ROLE, 'Something else') + call(self.course_run_key, self.user.id, LearningAssistantMessage.USER_ROLE, test_data[-1]['content']), + call(self.course_run_key, self.user.id, LearningAssistantMessage.ASSISTANT_ROLE, 'Something else') ]) else: mock_save_chat_message.assert_not_called()