diff --git a/nucliadb/tests/nucliadb/integration/test_ask.py b/nucliadb/tests/nucliadb/integration/test_ask.py index 3b3522aa54..9952379aee 100644 --- a/nucliadb/tests/nucliadb/integration/test_ask.py +++ b/nucliadb/tests/nucliadb/integration/test_ask.py @@ -51,7 +51,15 @@ async def test_ask( context = [{"author": "USER", "text": "query"}] resp = await nucliadb_reader.post( - f"/kb/{knowledgebox}/ask", json={"query": "query", "context": context} + f"/kb/{knowledgebox}/ask", + json={ + "query": "query", + "context": context, + "answer_json_schema": { + "type": "object", + "properties": {"answer": {"type": "string"}, "confidence": {"type": "number"}}, + }, + }, ) assert resp.status_code == 200 diff --git a/nucliadb_models/src/nucliadb_models/search.py b/nucliadb_models/src/nucliadb_models/search.py index 5209413594..4e30f3fadc 100644 --- a/nucliadb_models/src/nucliadb_models/search.py +++ b/nucliadb_models/src/nucliadb_models/search.py @@ -17,7 +17,6 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . # -import json from dataclasses import dataclass from datetime import datetime from enum import Enum @@ -41,36 +40,34 @@ _T = TypeVar("_T") -ANSWER_JSON_SCHEMA_EXAMPLE = json.dumps( - { - "name": "structred_response", - "description": "Structured response with custom fields", - "parameters": { - "type": "object", - "properties": { - "answer": { +ANSWER_JSON_SCHEMA_EXAMPLE = { + "name": "structred_response", + "description": "Structured response with custom fields", + "parameters": { + "type": "object", + "properties": { + "answer": { + "type": "string", + "description": "Text responding to the user's query with the given context.", + }, + "confidence": { + "type": "integer", + "description": "The confidence level of the response, on a scale from 0 to 5.", + "minimum": 0, + "maximum": 5, + }, + "machinery_mentioned": { + "type": "array", + "items": { "type": "string", - "description": "Text responding to the user's query with the given context.", - }, - "confidence": { - "type": "integer", - "description": "The confidence level of the response, on a scale from 0 to 5.", - "minimum": 0, - "maximum": 5, - }, - "machinery_mentioned": { - "type": "array", - "items": { - "type": "string", - "description": "A list of machinery mentioned in the response, if any. Use machine IDs if possible.", - }, - "description": "Optional field listing any machinery mentioned in the response.", + "description": "A list of machinery mentioned in the response, if any. Use machine IDs if possible.", }, + "description": "Optional field listing any machinery mentioned in the response.", }, - "required": ["answer", "confidence"], }, - } -) + "required": ["answer", "confidence"], + }, +} class ModelParamDefaults: @@ -862,7 +859,7 @@ class ChatModel(BaseModel): default=False, description="If set to true, the response will be in markdown format", ) - json_schema: Optional[str] = Field( + json_schema: Optional[Dict[str, Any]] = Field( default=None, description="The JSON schema to use for the generative model answers", ) @@ -1360,13 +1357,13 @@ def validate_facets(facets): class AskRequest(ChatRequest): - answer_json_schema: Optional[str] = Field( + answer_json_schema: Optional[Dict[str, Any]] = Field( default=None, title="Answer JSON schema", description="""Desired JSON schema of the desired LLM answer. This schema is passed to the LLM so that it answers in a scructured format following the schema. If not provided, textual response is returned. Note that when using this parameter, the answer in the generative response will not be returned in chunks, the whole response text will be returned instead. -Using this feature also disables the `citations` parameter. +Using this feature also disables the `citations` parameter. For maximal accuracy, please include a `description` for each field of the schema. """, examples=[ANSWER_JSON_SCHEMA_EXAMPLE], ) diff --git a/nucliadb_sdk/tests/test_ask.py b/nucliadb_sdk/tests/test_ask.py index 617693034e..e2c7c6b426 100644 --- a/nucliadb_sdk/tests/test_ask.py +++ b/nucliadb_sdk/tests/test_ask.py @@ -56,6 +56,10 @@ def test_ask_on_kb(docs_dataset, sdk: nucliadb_sdk.NucliaDB): ], # Control the number of AI tokens used for every request max_tokens=MaxTokens(context=100, answer=50), + answer_json_schema={ + "type": "object", + "properties": {"answer": {"type": "string"}, "confidence": {"type": "number"}}, + }, ) assert result.learning_id == "00" assert result.answer == "valid answer to"