From f4e9a81c9695d2694e42c933cd145f9452a87bf7 Mon Sep 17 00:00:00 2001 From: Marcos Date: Wed, 30 Oct 2024 09:40:54 -0300 Subject: [PATCH] feat: Adding chat messages to the DB --- learning_assistant/models.py | 7 ++++++- learning_assistant/serializers.py | 4 ++-- learning_assistant/views.py | 29 ++++++++++++++++++++++++++++- 3 files changed, 36 insertions(+), 4 deletions(-) diff --git a/learning_assistant/models.py b/learning_assistant/models.py index c890087..cb2d68c 100644 --- a/learning_assistant/models.py +++ b/learning_assistant/models.py @@ -35,7 +35,12 @@ class LearningAssistantMessage(TimeStampedModel): .. pii_retirement: third_party """ + USER_ROLE = 'user' + ASSISTANT_ROLE = 'assistant' + + Roles = (USER_ROLE, ASSISTANT_ROLE) + course_id = CourseKeyField(max_length=255, db_index=True) user = models.ForeignKey(USER_MODEL, db_index=True, on_delete=models.CASCADE) - role = models.CharField(max_length=64) + role = models.CharField(choices=Roles, max_length=64) content = models.TextField() diff --git a/learning_assistant/serializers.py b/learning_assistant/serializers.py index a212654..6a4de75 100644 --- a/learning_assistant/serializers.py +++ b/learning_assistant/serializers.py @@ -2,7 +2,7 @@ Serializers for the learning-assistant API. """ from rest_framework import serializers - +from learning_assistant.models import LearningAssistantMessage class MessageSerializer(serializers.Serializer): # pylint: disable=abstract-method """ @@ -16,7 +16,7 @@ def validate_role(self, value): """ Validate that role is one of two acceptable values. """ - valid_roles = ['user', 'assistant'] + valid_roles = [LearningAssistantMessage.USER_ROLE, LearningAssistantMessage.ASSISTANT_ROLE] if value not in valid_roles: raise serializers.ValidationError('Must be valid role.') return value diff --git a/learning_assistant/views.py b/learning_assistant/views.py index 7739240..56825ae 100644 --- a/learning_assistant/views.py +++ b/learning_assistant/views.py @@ -4,6 +4,7 @@ import logging from django.conf import settings +from django.contrib.auth import get_user_model from edx_rest_framework_extensions.auth.jwt.authentication import JwtAuthentication from opaque_keys import InvalidKeyError from opaque_keys.edx.keys import CourseKey @@ -12,7 +13,6 @@ from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response from rest_framework.views import APIView - try: from common.djangoapps.course_modes.models import CourseMode from common.djangoapps.student.models import CourseEnrollment @@ -21,11 +21,13 @@ pass from learning_assistant.api import get_course_id, learning_assistant_enabled, render_prompt_template +from learning_assistant.models import LearningAssistantMessage from learning_assistant.serializers import MessageSerializer from learning_assistant.utils import get_chat_response, user_role_is_staff log = logging.getLogger(__name__) +User = get_user_model() class CourseChatView(APIView): """ @@ -76,6 +78,15 @@ def post(self, request, course_run_id): unit_id = request.query_params.get('unit_id') message_list = request.data + + # Check that the last message in the list corresponds to a user + new_user_message = message_list[-1] + if (new_user_message.role != 'user'): + return Response( + status=http_status.HTTP_400_BAD_REQUEST, + data={'detail': "Expects user role on last message."} + ) + serializer = MessageSerializer(data=message_list, many=True) # serializer will not be valid in the case that the message list contains any roles other than @@ -103,6 +114,22 @@ def post(self, request, course_run_id): ) status_code, message = get_chat_response(prompt_template, message_list) + user = User.objects.get(id=request.user.id) # Based on the previous code, user exists. + + # Save the user message to the database. + LearningAssistantMessage.objects.create( + user=user, + role=LearningAssistantMessage.USER_ROLE, + content=new_user_message.content, + ) + + # Save the assistant response to the database. + LearningAssistantMessage.objects.create( + user=user, + role=LearningAssistantMessage.ASSISTANT_ROLE, + content=message, + ) + return Response(status=status_code, data=message)