Skip to content

Commit

Permalink
feat: Adding chat messages to the DB
Browse files Browse the repository at this point in the history
  • Loading branch information
rijuma committed Oct 30, 2024
1 parent 5f5c618 commit 4f2529d
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 3 deletions.
10 changes: 9 additions & 1 deletion learning_assistant/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,15 @@ class LearningAssistantMessage(TimeStampedModel):
.. pii_retirement: third_party
"""

USER_ROLE = 'user'
ASSISTANT_ROLE = 'assistant'

Roles = (
(USER_ROLE, USER_ROLE),
(ASSISTANT_ROLE, ASSISTANT_ROLE),
)

course_id = CourseKeyField(max_length=255, db_index=True)
user = models.ForeignKey(USER_MODEL, db_index=True, on_delete=models.CASCADE)
role = models.CharField(max_length=64)
role = models.CharField(choices=Roles, max_length=64)
content = models.TextField()
4 changes: 3 additions & 1 deletion learning_assistant/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
"""
from rest_framework import serializers

from learning_assistant.models import LearningAssistantMessage


class MessageSerializer(serializers.Serializer): # pylint: disable=abstract-method
"""
Expand All @@ -16,7 +18,7 @@ def validate_role(self, value):
"""
Validate that role is one of two acceptable values.
"""
valid_roles = ['user', 'assistant']
valid_roles = [LearningAssistantMessage.USER_ROLE, LearningAssistantMessage.ASSISTANT_ROLE]
if value not in valid_roles:
raise serializers.ValidationError('Must be valid role.')
return value
28 changes: 28 additions & 0 deletions learning_assistant/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging

from django.conf import settings
from django.contrib.auth import get_user_model
from edx_rest_framework_extensions.auth.jwt.authentication import JwtAuthentication
from opaque_keys import InvalidKeyError
from opaque_keys.edx.keys import CourseKey
Expand All @@ -21,10 +22,12 @@
pass

from learning_assistant.api import get_course_id, learning_assistant_enabled, render_prompt_template
from learning_assistant.models import LearningAssistantMessage
from learning_assistant.serializers import MessageSerializer
from learning_assistant.utils import get_chat_response, user_role_is_staff

log = logging.getLogger(__name__)
User = get_user_model()


class CourseChatView(APIView):
Expand Down Expand Up @@ -76,6 +79,15 @@ def post(self, request, course_run_id):
unit_id = request.query_params.get('unit_id')

message_list = request.data

# Check that the last message in the list corresponds to a user
new_user_message = message_list[-1]
if new_user_message['role'] != LearningAssistantMessage.USER_ROLE:
return Response(
status=http_status.HTTP_400_BAD_REQUEST,
data={'detail': "Expects user role on last message."}
)

serializer = MessageSerializer(data=message_list, many=True)

# serializer will not be valid in the case that the message list contains any roles other than
Expand Down Expand Up @@ -103,6 +115,22 @@ def post(self, request, course_run_id):
)
status_code, message = get_chat_response(prompt_template, message_list)

user = User.objects.get(id=request.user.id) # Based on the previous code, user exists.

# Save the user message to the database.
LearningAssistantMessage.objects.create(
user=user,
role=LearningAssistantMessage.USER_ROLE,
content=new_user_message['content'],
)

# Save the assistant response to the database.
LearningAssistantMessage.objects.create(
user=user,
role=LearningAssistantMessage.ASSISTANT_ROLE,
content=message['content'],
)

return Response(status=status_code, data=message)


Expand Down
Empty file added tests/__init__.py
Empty file.
17 changes: 16 additions & 1 deletion tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from django.test.client import Client
from django.urls import reverse

from learning_assistant.models import LearningAssistantMessage

User = get_user_model()


Expand Down Expand Up @@ -170,16 +172,29 @@ def test_chat_response_default(

test_data = [
{'role': 'user', 'content': 'What is 2+2?'},
{'role': 'assistant', 'content': 'It is 4'}
{'role': 'assistant', 'content': 'It is 4'},
{'role': 'user', 'content': 'And what else?'},
]

response = self.client.post(
reverse('chat', kwargs={'course_run_id': self.course_id})+f'?unit_id={test_unit_id}',
data=json.dumps(test_data),
content_type='application/json'
)

self.assertEqual(response.status_code, 200)

last_rows = LearningAssistantMessage.objects.all().order_by('-created').values()[:2][::-1]

user_msg = last_rows[0]
assistant_msg = last_rows[1]

self.assertEqual(user_msg['role'], LearningAssistantMessage.USER_ROLE)
self.assertEqual(user_msg['content'], test_data[2]['content'])

self.assertEqual(assistant_msg['role'], LearningAssistantMessage.ASSISTANT_ROLE)
self.assertEqual(assistant_msg['content'], 'Something else')

render_args = mock_render.call_args.args
self.assertIn(test_unit_id, render_args)
self.assertIn('This is the default template', render_args)
Expand Down

0 comments on commit 4f2529d

Please sign in to comment.