Skip to content

Commit

Permalink
feat: Adding chat messages to the DB
Browse files Browse the repository at this point in the history
  • Loading branch information
rijuma committed Oct 30, 2024
1 parent 5f5c618 commit f4e9a81
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 4 deletions.
7 changes: 6 additions & 1 deletion learning_assistant/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,12 @@ class LearningAssistantMessage(TimeStampedModel):
.. pii_retirement: third_party
"""

USER_ROLE = 'user'
ASSISTANT_ROLE = 'assistant'

Roles = (USER_ROLE, ASSISTANT_ROLE)

course_id = CourseKeyField(max_length=255, db_index=True)
user = models.ForeignKey(USER_MODEL, db_index=True, on_delete=models.CASCADE)
role = models.CharField(max_length=64)
role = models.CharField(choices=Roles, max_length=64)
content = models.TextField()
4 changes: 2 additions & 2 deletions learning_assistant/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Serializers for the learning-assistant API.
"""
from rest_framework import serializers

from learning_assistant.models import LearningAssistantMessage

class MessageSerializer(serializers.Serializer): # pylint: disable=abstract-method
"""
Expand All @@ -16,7 +16,7 @@ def validate_role(self, value):
"""
Validate that role is one of two acceptable values.
"""
valid_roles = ['user', 'assistant']
valid_roles = [LearningAssistantMessage.USER_ROLE, LearningAssistantMessage.ASSISTANT_ROLE]
if value not in valid_roles:
raise serializers.ValidationError('Must be valid role.')
return value
29 changes: 28 additions & 1 deletion learning_assistant/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging

from django.conf import settings
from django.contrib.auth import get_user_model
from edx_rest_framework_extensions.auth.jwt.authentication import JwtAuthentication
from opaque_keys import InvalidKeyError
from opaque_keys.edx.keys import CourseKey
Expand All @@ -12,7 +13,6 @@
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from rest_framework.views import APIView

try:
from common.djangoapps.course_modes.models import CourseMode
from common.djangoapps.student.models import CourseEnrollment
Expand All @@ -21,11 +21,13 @@
pass

from learning_assistant.api import get_course_id, learning_assistant_enabled, render_prompt_template
from learning_assistant.models import LearningAssistantMessage
from learning_assistant.serializers import MessageSerializer
from learning_assistant.utils import get_chat_response, user_role_is_staff

log = logging.getLogger(__name__)

User = get_user_model()

class CourseChatView(APIView):
"""
Expand Down Expand Up @@ -76,6 +78,15 @@ def post(self, request, course_run_id):
unit_id = request.query_params.get('unit_id')

message_list = request.data

# Check that the last message in the list corresponds to a user
new_user_message = message_list[-1]
if (new_user_message.role != 'user'):
return Response(
status=http_status.HTTP_400_BAD_REQUEST,
data={'detail': "Expects user role on last message."}
)

serializer = MessageSerializer(data=message_list, many=True)

# serializer will not be valid in the case that the message list contains any roles other than
Expand Down Expand Up @@ -103,6 +114,22 @@ def post(self, request, course_run_id):
)
status_code, message = get_chat_response(prompt_template, message_list)

user = User.objects.get(id=request.user.id) # Based on the previous code, user exists.

# Save the user message to the database.
LearningAssistantMessage.objects.create(
user=user,
role=LearningAssistantMessage.USER_ROLE,
content=new_user_message.content,
)

# Save the assistant response to the database.
LearningAssistantMessage.objects.create(
user=user,
role=LearningAssistantMessage.ASSISTANT_ROLE,
content=message,
)

return Response(status=status_code, data=message)


Expand Down

0 comments on commit f4e9a81

Please sign in to comment.