From e2d8e12adafba5b94e049857b300a970cec00b56 Mon Sep 17 00:00:00 2001
From: Gonéri Le Bouder
Date: Sun, 19 May 2024 18:13:02 -0400
Subject: [PATCH] WCA: add playbook explanation and generation end-points

- Add the new endpoints
- Add a `request` parameter to `generate_playbook()` and
  `explain_playbook()`. We need it to get the user's `organization_id`.
---
 ansible_wisdom/ai/api/model_client/base.py     |  4 +-
 .../ai/api/model_client/dummy_client.py        |  4 +-
 .../ai/api/model_client/langchain.py           |  4 +-
 .../model_client/tests/test_dummy_client.py    |  6 +--
 .../api/model_client/tests/test_langchain.py   |  5 ++-
 .../model_client/tests/test_model_client.py    |  4 +-
 .../api/model_client/tests/test_wca_client.py  | 29 ++++++++++++++
 .../ai/api/model_client/wca_client.py          | 39 +++++++++++++++++++
 ansible_wisdom/ai/api/views.py                 |  4 +-
 9 files changed, 84 insertions(+), 15 deletions(-)

diff --git a/ansible_wisdom/ai/api/model_client/base.py b/ansible_wisdom/ai/api/model_client/base.py
index 91406d74e..4c3625c05 100644
--- a/ansible_wisdom/ai/api/model_client/base.py
+++ b/ansible_wisdom/ai/api/model_client/base.py
@@ -54,11 +54,11 @@ def get_chat_model(self, model_id):
         raise NotImplementedError

     def generate_playbook(
-        self, text: str = "", create_outline: bool = False, outline: str = ""
+        self, request, text: str = "", create_outline: bool = False, outline: str = ""
     ) -> tuple[str, str]:
         raise NotImplementedError

-    def explain_playbook(self, content) -> str:
+    def explain_playbook(self, request, content) -> str:
         raise NotImplementedError

     def self_test(self) -> HealthCheckSummary:
diff --git a/ansible_wisdom/ai/api/model_client/dummy_client.py b/ansible_wisdom/ai/api/model_client/dummy_client.py
index f7f0f78f8..bf9a47a06 100644
--- a/ansible_wisdom/ai/api/model_client/dummy_client.py
+++ b/ansible_wisdom/ai/api/model_client/dummy_client.py
@@ -87,11 +87,11 @@ def infer(self, model_input, model_id="", suggestion_id=None) -> Dict[str, Any]:
         return response_body

     def generate_playbook(
-        self, text: str = "", create_outline: bool = False, outline: str = ""
+        self, request, text: str = "", create_outline: bool = False, outline: str = ""
     ) -> tuple[str, str]:
         if create_outline:
             return PLAYBOOK, OUTLINE
         return PLAYBOOK, ""

-    def explain_playbook(self, content) -> str:
+    def explain_playbook(self, request, content) -> str:
         return EXPLANATION
diff --git a/ansible_wisdom/ai/api/model_client/langchain.py b/ansible_wisdom/ai/api/model_client/langchain.py
index 929d6c432..695fb7bd9 100644
--- a/ansible_wisdom/ai/api/model_client/langchain.py
+++ b/ansible_wisdom/ai/api/model_client/langchain.py
@@ -121,7 +121,7 @@ def infer(self, model_input, model_id="", suggestion_id=None) -> Dict[str, Any]:
         raise ModelTimeoutError

     def generate_playbook(
-        self, text: str = "", create_outline: bool = False, outline: str = ""
+        self, request, text: str = "", create_outline: bool = False, outline: str = ""
     ) -> tuple[str, str]:
         SYSTEM_MESSAGE_TEMPLATE = """
         You are an Ansible expert.
@@ -178,7 +178,7 @@ def generate_playbook(

         return playbook, outline

-    def explain_playbook(self, content) -> str:
+    def explain_playbook(self, request, content) -> str:
         SYSTEM_MESSAGE_TEMPLATE = """
         You're an Ansible expert.
         You format your output with Markdown.
diff --git a/ansible_wisdom/ai/api/model_client/tests/test_dummy_client.py b/ansible_wisdom/ai/api/model_client/tests/test_dummy_client.py
index 92730853b..b1a61214a 100644
--- a/ansible_wisdom/ai/api/model_client/tests/test_dummy_client.py
+++ b/ansible_wisdom/ai/api/model_client/tests/test_dummy_client.py
@@ -58,20 +58,20 @@ def test_infer_without_jitter(self, loads, sleep):

     def test_generate_playbook(self):
         client = DummyClient(inference_url="https://ibm.com")
-        playbook, outline = client.generate_playbook(text="foo", create_outline=False)
+        playbook, outline = client.generate_playbook(None, text="foo", create_outline=False)
         self.assertTrue(isinstance(playbook, str))
         self.assertTrue(isinstance(outline, str))
         self.assertEqual(outline, "")

     def test_generate_playbook_with_outline(self):
         client = DummyClient(inference_url="https://ibm.com")
-        playbook, outline = client.generate_playbook(text="foo", create_outline=True)
+        playbook, outline = client.generate_playbook(None, text="foo", create_outline=True)
         self.assertTrue(isinstance(playbook, str))
         self.assertTrue(isinstance(outline, str))
         self.assertTrue(outline)

     def test_explain_playbook(self):
         client = DummyClient(inference_url="https://ibm.com")
-        explanation = client.explain_playbook("ëoo")
+        explanation = client.explain_playbook(None, "ëoo")
         self.assertTrue(isinstance(explanation, str))
         self.assertTrue(explanation)
diff --git a/ansible_wisdom/ai/api/model_client/tests/test_langchain.py b/ansible_wisdom/ai/api/model_client/tests/test_langchain.py
index 8f80eb298..6f9d4cfd8 100644
--- a/ansible_wisdom/ai/api/model_client/tests/test_langchain.py
+++ b/ansible_wisdom/ai/api/model_client/tests/test_langchain.py
@@ -117,16 +117,17 @@ def fake_get_chat_mode(self, model_id=None):

     def test_generate_playbook(self):
         playbook, outline = self.my_client.generate_playbook(
+            None,
             text="foo",
         )
         self.assertEqual(playbook, "my_playbook")
         self.assertEqual(outline, "")

     def test_generate_playbook_with_outline(self):
-        playbook, outline = self.my_client.generate_playbook(text="foo", create_outline=True)
+        playbook, outline = self.my_client.generate_playbook(None, text="foo", create_outline=True)
         self.assertEqual(playbook, "my_playbook")
         self.assertEqual(outline, "my outline")

     def test_explain_playbook(self):
-        explanation = self.my_client.explain_playbook(content="foo")
+        explanation = self.my_client.explain_playbook(None, content="foo")
         self.assertTrue(explanation)
diff --git a/ansible_wisdom/ai/api/model_client/tests/test_model_client.py b/ansible_wisdom/ai/api/model_client/tests/test_model_client.py
index e320a61c6..69eae7f5c 100644
--- a/ansible_wisdom/ai/api/model_client/tests/test_model_client.py
+++ b/ansible_wisdom/ai/api/model_client/tests/test_model_client.py
@@ -42,6 +42,6 @@ def test_not_implemented(self):
         with self.assertRaises(NotImplementedError):
             c.get_chat_model("a")
         with self.assertRaises(NotImplementedError):
-            c.generate_playbook("a")
+            c.generate_playbook(None, "a")
         with self.assertRaises(NotImplementedError):
-            c.explain_playbook("a")
+            c.explain_playbook(None, "a")
diff --git a/ansible_wisdom/ai/api/model_client/tests/test_wca_client.py b/ansible_wisdom/ai/api/model_client/tests/test_wca_client.py
index 62879f471..4d66b9a7e 100644
--- a/ansible_wisdom/ai/api/model_client/tests/test_wca_client.py
+++ b/ansible_wisdom/ai/api/model_client/tests/test_wca_client.py
@@ -240,6 +240,35 @@ def test_fatal_exception(self):
         b = WCAClient.fatal_exception(exc)
         self.assertTrue(b)

+    def test_playbook_gen(self):
+        wca_client = WCAClient(inference_url='http://example.com/')
+        wca_client.get_api_key = Mock(return_value="some-key")
+        wca_client.get_token = Mock(return_value={"access_token": "a-token"})
+        wca_client.get_model_id = Mock(return_value="a-random-model")
+        wca_client.session = Mock()
+        response = Mock
+        response.text = '{"playbook": "Oh!", "outline": "Ahh!"}'
+        wca_client.session.post.return_value = response
+        request = Mock()
+        playbook, outline = wca_client.generate_playbook(
+            request, text="Install Wordpress", create_outline=True
+        )
+        self.assertEqual(playbook, "Oh!")
+        self.assertEqual(outline, "Ahh!")
+
+    def test_playbook_exp(self):
+        wca_client = WCAClient(inference_url='http://example.com/')
+        wca_client.get_api_key = Mock(return_value="some-key")
+        wca_client.get_token = Mock(return_value={"access_token": "a-token"})
+        wca_client.get_model_id = Mock(return_value="a-random-model")
+        wca_client.session = Mock()
+        response = Mock
+        response.text = '{"explanation": "!Óh¡"}'
+        wca_client.session.post.return_value = response
+        request = Mock()
+        explanation = wca_client.explain_playbook(request, content="Some playbook")
+        self.assertEqual(explanation, "!Óh¡")
+

 @override_settings(ANSIBLE_WCA_RETRY_COUNT=1)
 @override_settings(WCA_SECRET_BACKEND_TYPE="dummy")
diff --git a/ansible_wisdom/ai/api/model_client/wca_client.py b/ansible_wisdom/ai/api/model_client/wca_client.py
index 662d7b032..abb798ac5 100644
--- a/ansible_wisdom/ai/api/model_client/wca_client.py
+++ b/ansible_wisdom/ai/api/model_client/wca_client.py
@@ -470,6 +470,45 @@ def self_test(self) -> HealthCheckSummary:

         return summary

+    def generate_playbook(
+        self, request, text: str = "", create_outline: bool = False, outline: str = ""
+    ) -> tuple[str, str]:
+        api_key = self.get_api_key(request.user.organization.id)
+        model_id = self.get_model_id(request.user.organization.id)
+
+        headers = self._get_base_headers(api_key)
+        data = {
+            "model_id": model_id,
+            "text": text,
+            "create_outline": create_outline,
+        }
+        if outline:
+            data["outline"] = outline
+        result = self.session.post(
+            f"{self._inference_url}/v1/wca/codegen/ansible/playbook",
+            headers=headers,
+            json=data,
+        )
+        response = json.loads(result.text)
+        return response["playbook"], response["outline"]
+
+    def explain_playbook(self, request, content: str) -> str:
+        api_key = self.get_api_key(request.user.organization.id)
+        model_id = self.get_model_id(request.user.organization.id)
+
+        headers = self._get_base_headers(api_key)
+        data = {
+            "model_id": model_id,
+            "playbook": content,
+        }
+        result = self.session.post(
+            f"{self._inference_url}/v1/wca/explain/ansible/playbook",
+            headers=headers,
+            json=data,
+        )
+        response = json.loads(result.text)
+        return response["explanation"]
+

 class WCAOnPremClient(BaseWCAClient):
     def __init__(self, inference_url):
diff --git a/ansible_wisdom/ai/api/views.py b/ansible_wisdom/ai/api/views.py
index 825843ed5..8411d757e 100644
--- a/ansible_wisdom/ai/api/views.py
+++ b/ansible_wisdom/ai/api/views.py
@@ -738,7 +738,7 @@ def post(self, request) -> Response:
         playbook = request_serializer.validated_data.get("content")

         llm = apps.get_app_config("ai").model_mesh_client
-        explanation = llm.explain_playbook(playbook)
+        explanation = llm.explain_playbook(request, playbook)

         answer = {"content": explanation, "format": "markdown", "explanationId": explanation_id}

@@ -789,7 +789,7 @@ def post(self, request) -> Response:
         text = request_serializer.validated_data["text"]

         llm = apps.get_app_config("ai").model_mesh_client
-        playbook, outline = llm.generate_playbook(text, create_outline, outline)
+        playbook, outline = llm.generate_playbook(request, text, create_outline, outline)

         answer = {
             "playbook": playbook,
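
Note for reviewers: the `request` parameter is only used by the WCA client to read
`request.user.organization.id` so it can resolve the per-organization API key and model id,
which is why the dummy/langchain tests simply pass `None` or a `Mock()`. For anyone who wants
to exercise the two new end-points outside of the Django stack, here is a minimal sketch of
the HTTP exchange that `generate_playbook()` and `explain_playbook()` perform. The endpoint
paths and payload/response keys mirror the additions to `wca_client.py` above; the base URL,
token header, and model id are placeholders, and the real headers are whatever
`_get_base_headers(api_key)` builds.

    # Standalone sketch (assumptions: WCA_BASE_URL, the Authorization header,
    # and MODEL_ID are placeholders; the production client derives the key and
    # model id per organization via get_api_key()/get_model_id()).
    import requests

    WCA_BASE_URL = "https://example.com"            # stand-in for self._inference_url
    HEADERS = {"Authorization": "Bearer <token>"}   # stand-in for _get_base_headers(api_key)
    MODEL_ID = "<model-id>"

    # Playbook generation: POST /v1/wca/codegen/ansible/playbook
    gen_payload = {
        "model_id": MODEL_ID,
        "text": "Install Wordpress",
        "create_outline": True,
        # "outline": "...",  # only included when an existing outline is passed in
    }
    gen = requests.post(
        f"{WCA_BASE_URL}/v1/wca/codegen/ansible/playbook", headers=HEADERS, json=gen_payload
    ).json()
    playbook, outline = gen["playbook"], gen["outline"]

    # Playbook explanation: POST /v1/wca/explain/ansible/playbook
    exp_payload = {"model_id": MODEL_ID, "playbook": playbook}
    exp = requests.post(
        f"{WCA_BASE_URL}/v1/wca/explain/ansible/playbook", headers=HEADERS, json=exp_payload
    ).json()
    print(exp["explanation"])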