Skip to content

Commit

Permalink
feat: remove GPT model field in post request (#107)
Browse files Browse the repository at this point in the history
  • Loading branch information
alangsto authored Sep 10, 2024
1 parent bb7c1cb commit 3c60659
Show file tree
Hide file tree
Showing 7 changed files with 12 additions and 82 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ Change Log
Unreleased
**********

4.3.1 - 2024-09-10
******************
* Remove the GPT model field from the POST request sent to the Xpert backend

4.3.0 - 2024-07-01
******************
* Adds optional parameter to use updated prompt and model for the chat response.
Expand Down
2 changes: 1 addition & 1 deletion learning_assistant/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
Plugin for a learning assistant backend, intended for use within edx-platform.
"""

__version__ = '4.3.0'
__version__ = '4.3.1'

default_app_config = 'learning_assistant.apps.LearningAssistantConfig' # pylint: disable=invalid-name
10 changes: 0 additions & 10 deletions learning_assistant/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,3 @@
"html": "TEXT",
"video": "VIDEO",
}


class GptModels:
GPT_3_5_TURBO = 'gpt-3.5-turbo'
GPT_3_5_TURBO_0125 = 'gpt-3.5-turbo-0125'
GPT_4o = 'gpt-4o'


class ResponseVariations:
GPT4_UPDATED_PROMPT = 'updated_prompt'
7 changes: 3 additions & 4 deletions learning_assistant/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,19 +52,18 @@ def get_reduced_message_list(prompt_template, message_list):
return [system_message] + new_message_list


def create_request_body(prompt_template, message_list, gpt_model):
def create_request_body(prompt_template, message_list):
"""
Form request body to be passed to the chat endpoint.
"""
response_body = {
'message_list': get_reduced_message_list(prompt_template, message_list),
'model': gpt_model,
}

return response_body


def get_chat_response(prompt_template, message_list, gpt_model):
def get_chat_response(prompt_template, message_list):
"""
Pass message list to chat endpoint, as defined by the CHAT_COMPLETION_API setting.
"""
Expand All @@ -75,7 +74,7 @@ def get_chat_response(prompt_template, message_list, gpt_model):
connect_timeout = getattr(settings, 'CHAT_COMPLETION_API_CONNECT_TIMEOUT', 1)
read_timeout = getattr(settings, 'CHAT_COMPLETION_API_READ_TIMEOUT', 15)

body = create_request_body(prompt_template, message_list, gpt_model)
body = create_request_body(prompt_template, message_list)

try:
response = requests.post(
Expand Down
11 changes: 2 additions & 9 deletions learning_assistant/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
pass

from learning_assistant.api import get_course_id, learning_assistant_enabled, render_prompt_template
from learning_assistant.constants import GptModels, ResponseVariations
from learning_assistant.serializers import MessageSerializer
from learning_assistant.utils import get_chat_response, user_role_is_staff

Expand Down Expand Up @@ -75,7 +74,6 @@ def post(self, request, course_run_id):
)

unit_id = request.query_params.get('unit_id')
response_variation = request.query_params.get('response_variation')

message_list = request.data
serializer = MessageSerializer(data=message_list, many=True)
Expand All @@ -98,17 +96,12 @@ def post(self, request, course_run_id):

course_id = get_course_id(course_run_id)

if response_variation == ResponseVariations.GPT4_UPDATED_PROMPT:
gpt_model = GptModels.GPT_4o
template_string = getattr(settings, 'LEARNING_ASSISTANT_EXPERIMENTAL_PROMPT_TEMPLATE', '')
else:
gpt_model = GptModels.GPT_3_5_TURBO_0125
template_string = getattr(settings, 'LEARNING_ASSISTANT_PROMPT_TEMPLATE', '')
template_string = getattr(settings, 'LEARNING_ASSISTANT_EXPERIMENTAL_PROMPT_TEMPLATE', '')

prompt_template = render_prompt_template(
request, request.user.id, course_run_id, unit_id, course_id, template_string
)
status_code, message = get_chat_response(prompt_template, message_list, gpt_model)
status_code, message = get_chat_response(prompt_template, message_list)

return Response(status=status_code, data=message)

Expand Down
3 changes: 1 addition & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def setUp(self):
self.course_id = 'edx+test'

def get_response(self):
return get_chat_response(self.prompt_template, self.message_list, 'gpt-version-test')
return get_chat_response(self.prompt_template, self.message_list)

@override_settings(CHAT_COMPLETION_API=None)
def test_no_endpoint_setting(self):
Expand Down Expand Up @@ -90,7 +90,6 @@ def test_post_request_structure(self, mock_requests):

response_body = {
'message_list': [{'role': 'system', 'content': self.prompt_template}] + self.message_list,
'model': 'gpt-version-test',
}

self.get_response()
Expand Down
57 changes: 1 addition & 56 deletions tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
from django.test.client import Client
from django.urls import reverse

from learning_assistant.constants import GptModels, ResponseVariations

User = get_user_model()


Expand Down Expand Up @@ -158,7 +156,7 @@ def test_invalid_messages(self, mock_role, mock_waffle, mock_render):
@patch('learning_assistant.views.get_user_role')
@patch('learning_assistant.views.CourseEnrollment.get_enrollment')
@patch('learning_assistant.views.CourseMode')
@override_settings(LEARNING_ASSISTANT_PROMPT_TEMPLATE='This is the default template')
@override_settings(LEARNING_ASSISTANT_EXPERIMENTAL_PROMPT_TEMPLATE='This is the default template')
def test_chat_response_default(
self, mock_mode, mock_enrollment, mock_role, mock_waffle, mock_chat_response, mock_render
):
Expand Down Expand Up @@ -189,59 +187,6 @@ def test_chat_response_default(
mock_chat_response.assert_called_with(
'Rendered template mock',
test_data,
GptModels.GPT_3_5_TURBO_0125
)

@ddt.data(ResponseVariations.GPT4_UPDATED_PROMPT, 'invalid-variation')
@patch('learning_assistant.views.render_prompt_template')
@patch('learning_assistant.views.get_chat_response')
@patch('learning_assistant.views.learning_assistant_enabled')
@patch('learning_assistant.views.get_user_role')
@patch('learning_assistant.views.CourseEnrollment.get_enrollment')
@patch('learning_assistant.views.CourseMode')
@override_settings(LEARNING_ASSISTANT_EXPERIMENTAL_PROMPT_TEMPLATE='This is a template for GPT-4o variation')
@override_settings(LEARNING_ASSISTANT_PROMPT_TEMPLATE='This is the default template')
def test_chat_response_variation(
self, variation, mock_mode, mock_enrollment, mock_role, mock_waffle, mock_chat_response, mock_render
):
mock_waffle.return_value = True
mock_role.return_value = 'student'
mock_mode.VERIFIED_MODES = ['verified']
mock_enrollment.return_value = MagicMock(mode='verified')
mock_chat_response.return_value = (200, {'role': 'assistant', 'content': 'Something else'})
mock_render.return_value = 'Rendered template mock'
test_unit_id = 'test-unit-id'

test_data = [
{'role': 'user', 'content': 'What is 2+2?'},
{'role': 'assistant', 'content': 'It is 4'}
]

response = self.client.post(
reverse(
'chat',
kwargs={'course_run_id': self.course_id}
)+f'?unit_id={test_unit_id}&response_variation={variation}',
data=json.dumps(test_data),
content_type='application/json',
)
self.assertEqual(response.status_code, 200)

if variation == ResponseVariations.GPT4_UPDATED_PROMPT:
expected_template = 'This is a template for GPT-4o variation'
expected_model = GptModels.GPT_4o
else:
expected_template = 'This is the default template'
expected_model = GptModels.GPT_3_5_TURBO_0125

render_args = mock_render.call_args.args
self.assertIn(test_unit_id, render_args)
self.assertIn(expected_template, render_args)

mock_chat_response.assert_called_with(
'Rendered template mock',
test_data,
expected_model,
)


Expand Down

0 comments on commit 3c60659

Please sign in to comment.