Skip to content

Commit

Permalink
chore: Fixed coverage for CourseChatView
Browse files Browse the repository at this point in the history
  • Loading branch information
rijuma committed Oct 30, 2024
1 parent c91ea21 commit e28918a
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 53 deletions.
8 changes: 4 additions & 4 deletions learning_assistant/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,16 +190,16 @@ def get_course_id(course_run_id):
course_key = course_data['course']
return course_key


def save_chat_message(user_id, chat_role, message):
"""
Saves the chat message to the database.
Save the chat message to the database.
"""

user = None
try:
user = User.objects.get(id=user_id)
except User.DoesNotExist:
raise Exception("User does not exists.")
except User.DoesNotExist as exc:
raise Exception("User does not exists.") from exc

Check failure on line 202 in learning_assistant/api.py

View workflow job for this annotation

GitHub Actions / tests (ubuntu-20.04, 3.11, django42)

Missing coverage

Missing coverage on lines 201-202

# Save the user message to the database.
LearningAssistantMessage.objects.create(
Expand Down
41 changes: 10 additions & 31 deletions learning_assistant/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import logging

from django.conf import settings
from django.contrib.auth import get_user_model
from edx_rest_framework_extensions.auth.jwt.authentication import JwtAuthentication
from opaque_keys import InvalidKeyError
from opaque_keys.edx.keys import CourseKey
Expand All @@ -21,13 +20,13 @@
except ImportError:
pass

from learning_assistant.api import get_course_id, learning_assistant_enabled, render_prompt_template
from learning_assistant.api import get_course_id, learning_assistant_enabled, render_prompt_template, save_chat_message
from learning_assistant.models import LearningAssistantMessage
from learning_assistant.serializers import MessageSerializer
from learning_assistant.toggles import chat_history_enabled
from learning_assistant.utils import get_chat_response, user_role_is_staff

log = logging.getLogger(__name__)
User = get_user_model()


class CourseChatView(APIView):
Expand All @@ -38,27 +37,6 @@ class CourseChatView(APIView):
authentication_classes = (SessionAuthentication, JwtAuthentication,)
permission_classes = (IsAuthenticated,)

def __save_user_interaction(self, user_id, user_message, assistant_message):
"""
Saves the last question/response to the database.
"""
user = User.objects.get(id=user_id)

# Save the user message to the database.
LearningAssistantMessage.objects.create(
user=user,
role=LearningAssistantMessage.USER_ROLE,
content=user_message,
)

# Save the assistant response to the database.
LearningAssistantMessage.objects.create(
user=user,
role=LearningAssistantMessage.ASSISTANT_ROLE,
content=assistant_message,
)


def post(self, request, course_run_id):
"""
Given a course run ID, retrieve a chat response for that course.
Expand Down Expand Up @@ -109,6 +87,12 @@ def post(self, request, course_run_id):
data={'detail': "Expects user role on last message."}
)

course_id = get_course_id(course_run_id)
user_id = request.user.id

if chat_history_enabled(course_id):
save_chat_message(user_id, LearningAssistantMessage.USER_ROLE, new_user_message['content'])

serializer = MessageSerializer(data=message_list, many=True)

# serializer will not be valid in the case that the message list contains any roles other than
Expand All @@ -127,20 +111,15 @@ def post(self, request, course_run_id):
}
)

course_id = get_course_id(course_run_id)

template_string = getattr(settings, 'LEARNING_ASSISTANT_PROMPT_TEMPLATE', '')

prompt_template = render_prompt_template(
request, request.user.id, course_run_id, unit_id, course_id, template_string
)
status_code, message = get_chat_response(prompt_template, message_list)

self.__save_user_interaction(
user_id=request.user.id,
user_message=new_user_message['content'],
assistant_message=message['content']
)
if chat_history_enabled(course_id):
save_chat_message(user_id, LearningAssistantMessage.ASSISTANT_ROLE, message['content'])

return Response(status=status_code, data=message)

Expand Down
3 changes: 2 additions & 1 deletion tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
fake_transcript = 'This is the text version from the transcript'
User = get_user_model()


class FakeChild:
"""Fake child block for testing"""
transcript_download_format = 'txt'
Expand Down Expand Up @@ -232,6 +233,7 @@ def test_render_prompt_template_invalid_unit_key(self, mock_get_content):

self.assertNotIn('The following text is useful.', prompt_text)


@ddt.ddt
class TestLearningAssistantCourseEnabledApi(TestCase):
"""
Expand All @@ -256,7 +258,6 @@ def test_save_chat_message(self, chat_role, message):
self.assertEqual(row.content, message)



@ddt.ddt
class LearningAssistantCourseEnabledApiTests(TestCase):
"""
Expand Down
50 changes: 33 additions & 17 deletions tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import json
import sys
from importlib import import_module
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, call, patch

import ddt
from django.conf import settings
Expand All @@ -19,7 +19,7 @@
User = get_user_model()


class TestClient(Client):
class FakeClient(Client):
"""
Allows for 'fake logins' of a user so we don't need to expose a 'login' HTTP endpoint.
"""
Expand Down Expand Up @@ -65,14 +65,14 @@ def setUp(self):
Setup for tests.
"""
super().setUp()
self.client = TestClient()
self.client = FakeClient()
self.user = User(username='tester', email='tester@test.com')
self.user.save()
self.client.login_user(self.user)


@ddt.ddt
class CourseChatViewTests(LoggedInTestCase):
class TestCourseChatView(LoggedInTestCase):
"""
Test for the CourseChatView
"""
Expand Down Expand Up @@ -152,15 +152,27 @@ def test_invalid_messages(self, mock_role, mock_waffle, mock_render):
)
self.assertEqual(response.status_code, 400)

@ddt.data(True, False) # TODO: Fix this - See below.
@patch('learning_assistant.views.render_prompt_template')
@patch('learning_assistant.views.get_chat_response')
@patch('learning_assistant.views.learning_assistant_enabled')
@patch('learning_assistant.views.get_user_role')
@patch('learning_assistant.views.CourseEnrollment.get_enrollment')
@patch('learning_assistant.views.CourseMode')
@patch('learning_assistant.api.save_chat_message')
@patch('learning_assistant.toggles.chat_history_enabled')
@override_settings(LEARNING_ASSISTANT_PROMPT_TEMPLATE='This is the default template')
def test_chat_response_default(
self, mock_mode, mock_enrollment, mock_role, mock_waffle, mock_chat_response, mock_render
self,
enabled_flag,
mock_chat_history_enabled,
mock_save_chat_message,
mock_mode,
mock_enrollment,
mock_role,
mock_waffle,
mock_chat_response,
mock_render,
):
mock_waffle.return_value = True
mock_role.return_value = 'student'
Expand All @@ -170,6 +182,14 @@ def test_chat_response_default(
mock_render.return_value = 'Rendered template mock'
test_unit_id = 'test-unit-id'

# TODO: Fix this...
# For some reason this only works the first time. The 2nd time (enabled_flag = False)
# Doesn't actually work since the mocked chat_history_enabled() will return False no matter what.
# Swap the order of the @ddt.data() above by: @ddt.data(False, True) and watch it fail.
# The value for enabled_flag is corrct on this scope, but the mocked method doesn't update.
# It even happens if we split the test cases into two different methods.
mock_chat_history_enabled.return_value = enabled_flag

test_data = [
{'role': 'user', 'content': 'What is 2+2?'},
{'role': 'assistant', 'content': 'It is 4'},
Expand All @@ -181,20 +201,8 @@ def test_chat_response_default(
data=json.dumps(test_data),
content_type='application/json'
)

self.assertEqual(response.status_code, 200)

last_rows = LearningAssistantMessage.objects.all().order_by('-created').values()[:2][::-1]

user_msg = last_rows[0]
assistant_msg = last_rows[1]

self.assertEqual(user_msg['role'], LearningAssistantMessage.USER_ROLE)
self.assertEqual(user_msg['content'], test_data[2]['content'])

self.assertEqual(assistant_msg['role'], LearningAssistantMessage.ASSISTANT_ROLE)
self.assertEqual(assistant_msg['content'], 'Something else')

render_args = mock_render.call_args.args
self.assertIn(test_unit_id, render_args)
self.assertIn('This is the default template', render_args)
Expand All @@ -204,6 +212,14 @@ def test_chat_response_default(
test_data,
)

if enabled_flag:
mock_save_chat_message.assert_has_calls([
call(self.user.id, LearningAssistantMessage.USER_ROLE, test_data[-1]['content']),
call(self.user.id, LearningAssistantMessage.ASSISTANT_ROLE, 'Something else')
])
else:
mock_save_chat_message.assert_not_called()


@ddt.ddt
class LearningAssistantEnabledViewTests(LoggedInTestCase):
Expand Down

0 comments on commit e28918a

Please sign in to comment.