From 62dfaa603a0216cecbdd5019a630bcce2aa74922 Mon Sep 17 00:00:00 2001 From: "Dr. Hans-Peter Stoerr" Date: Mon, 25 Mar 2024 13:18:45 +0100 Subject: [PATCH] configurable AI endpoint; add possibility to call Anthropic Claude or local models --- .../sling/nodes/ai/impl/AIServiceImpl.java | 131 ++++++++++++------ .../ai/impl/AIServiceImplClaudeCall.java | 51 +++++++ ...Call.java => AIServiceImplOpenAICall.java} | 6 +- .../nodes/ai/impl/AIServiceImplTest.java | 55 ++++++++ 4 files changed, 200 insertions(+), 43 deletions(-) create mode 100644 console/src/test/java/com/composum/sling/nodes/ai/impl/AIServiceImplClaudeCall.java rename console/src/test/java/com/composum/sling/nodes/ai/impl/{AIServiceImplCall.java => AIServiceImplOpenAICall.java} (89%) create mode 100644 console/src/test/java/com/composum/sling/nodes/ai/impl/AIServiceImplTest.java diff --git a/console/src/main/java/com/composum/sling/nodes/ai/impl/AIServiceImpl.java b/console/src/main/java/com/composum/sling/nodes/ai/impl/AIServiceImpl.java index ab3609bd8..f9fee714e 100644 --- a/console/src/main/java/com/composum/sling/nodes/ai/impl/AIServiceImpl.java +++ b/console/src/main/java/com/composum/sling/nodes/ai/impl/AIServiceImpl.java @@ -14,7 +14,6 @@ import javax.annotation.Nonnull; import javax.annotation.Nullable; -import org.apache.commons.collections4.ListUtils; import org.apache.commons.collections4.MapUtils; import org.apache.commons.io.Charsets; import org.apache.commons.io.FileUtils; @@ -27,6 +26,7 @@ import org.apache.http.impl.client.CloseableHttpClient; import org.apache.http.impl.client.HttpClientBuilder; import org.apache.http.util.EntityUtils; +import org.jetbrains.annotations.NotNull; import org.osgi.framework.Constants; import org.osgi.service.component.annotations.Activate; import org.osgi.service.component.annotations.Component; @@ -53,22 +53,31 @@ public class AIServiceImpl implements AIService { public static final String SERVICE_NAME = "Composum Nodes AI Service"; - private static final Logger LOG = LoggerFactory.getLogger(AIServiceImpl.class); + protected static final Logger LOG = LoggerFactory.getLogger(AIServiceImpl.class); - /** The default model - probably a GPT-4 is needed for complicated stuff like JCR queries. */ + /** + * The default model - probably a GPT-4 is needed for complicated stuff like JCR queries. + */ public static final String DEFAULT_MODEL = "gpt-4-turbo-preview"; protected static final String CHAT_COMPLETION_URL = "https://api.openai.com/v1/chat/completions"; - private Configuration config; - private String openAiApiKey; - private RateLimiter rateLimiter; - private Gson gson = new Gson(); - private CloseableHttpClient httpClient; - private String organizationId; + protected Configuration config; + protected String apiKey; + protected RateLimiter rateLimiter; + protected Gson gson = new Gson(); + protected CloseableHttpClient httpClient; + protected String chatURL; + protected String apiKeyHeader; + protected String additionalHeader; + protected String additionalHeaderValue; @Override public boolean isAvailable() { - return config != null && !config.disabled() && StringUtils.isNotBlank(openAiApiKey); + return config != null && !config.disabled() && StringUtils.isNotBlank(apiKey); + } + + protected boolean isAnthropicClaude() { + return chatURL != null && chatURL.contains("api.anthropic.com"); } @Nonnull @@ -83,7 +92,10 @@ public String prompt(@Nullable String systemmsg, @Nonnull String usermsg, @Nulla Map request = new HashMap<>(); request.put("model", StringUtils.defaultIfBlank(config.defaultModel(), DEFAULT_MODEL)); - if (systemmsg != null) { + if (systemmsg != null && isAnthropicClaude()) { + request.put("system", systemmsg); + request.put("messages", asList(userMessage)); + } else if (systemmsg != null) { Map systemMessage = new LinkedHashMap<>(); systemMessage.put("role", "system"); systemMessage.put("content", systemmsg); @@ -92,7 +104,10 @@ public String prompt(@Nullable String systemmsg, @Nonnull String usermsg, @Nulla request.put("messages", asList(userMessage)); } request.put("temperature", 0); - if (responseFormat == ResponseFormat.JSON) { + if (config.maxTokens() > 0) { + request.put("max_tokens", config.maxTokens()); + } + if (responseFormat == ResponseFormat.JSON && !isAnthropicClaude()) { Map responseFormatMap = new LinkedHashMap<>(); responseFormatMap.put("type", "json_object"); request.put("response_format", responseFormatMap); @@ -102,15 +117,20 @@ public String prompt(@Nullable String systemmsg, @Nonnull String usermsg, @Nulla rateLimiter.waitForLimit(); // retrieve response from OpenAI using httpClient 4 - HttpPost postRequest = new HttpPost(CHAT_COMPLETION_URL); + HttpPost postRequest = new HttpPost(this.chatURL); EntityBuilder entityBuilder = EntityBuilder.create(); entityBuilder.setContentType(ContentType.APPLICATION_JSON); entityBuilder.setContentEncoding("UTF-8"); entityBuilder.setText(requestJson); postRequest.setEntity(entityBuilder.build()); - postRequest.setHeader("Authorization", "Bearer " + openAiApiKey); - if (organizationId != null) { - postRequest.setHeader("OpenAI-Organization", organizationId); + if (this.apiKeyHeader != null) { + String[] headerSplitted = this.apiKeyHeader.split("\\s"); + String headername = headerSplitted[0]; + String prefix = headerSplitted.length > 1 ? headerSplitted[1].trim() + " " : ""; + postRequest.setHeader(headername, prefix + apiKey); + } + if (additionalHeader != null && additionalHeaderValue != null) { + postRequest.setHeader(additionalHeader, additionalHeaderValue); } String id = "#" + System.nanoTime(); LOG.debug("Request {} to OpenAI: {}", id, requestJson); @@ -123,7 +143,7 @@ public String prompt(@Nullable String systemmsg, @Nonnull String usermsg, @Nulla } @Nonnull - private String retrieveMessage(String id, CloseableHttpResponse response) throws AIServiceException, IOException { + protected String retrieveMessage(String id, CloseableHttpResponse response) throws AIServiceException, IOException { int statusCode = response.getStatusLine().getStatusCode(); HttpEntity responseEntity = response.getEntity(); if (statusCode != 200) { @@ -133,27 +153,38 @@ private String retrieveMessage(String id, CloseableHttpResponse response) throws if (bytes.size() > 0) { errorbody = errorbody + "\n" + new String(bytes.toByteArray(), Charsets.UTF_8); } - throw new AIServiceException("Error from OpenAI: " + statusCode + " " + errorbody); + throw new AIServiceException("Error from AI backend: " + response.getStatusLine() + " " + errorbody); } String responseJson = EntityUtils.toString(responseEntity); LOG.debug("Response {} from OpenAI: {}", id, responseJson); - Map responseMap = gson.fromJson(responseJson, Map.class); + return extractText(responseJson); + } - List> choices = ListUtils.emptyIfNull((List>) responseMap.get("choices")); - if (choices.isEmpty()) { - throw new AIServiceException("No choices in response: " + responseJson); - } - Map choice = choices.get(0); - String finish_reason = (String) choice.get("finish_reason"); - if (!"stop".equals(finish_reason)) { - throw new AIServiceException("Finish reason is not stop: " + finish_reason + " in response: " + responseJson); + @NotNull + protected String extractText(String responseJson) throws AIServiceException { + Map responseMap = gson.fromJson(responseJson, Map.class); + String text; + if (responseMap.containsKey("choices")) { // OpenAI format + List> choices = (List>) responseMap.get("choices"); + Map choice = choices.get(0); + String finish_reason = (String) choice.get("finish_reason"); + if (!"stop".equals(finish_reason)) { + throw new AIServiceException("Finish reason is not stop: " + finish_reason + " in response: " + responseJson); + } + Map message = MapUtils.emptyIfNull((Map) choice.get("message")); + text = (String) message.get("content"); + } else if (responseMap.containsKey("content")) { // Anthropic Claude format + List> content = (List>) responseMap.get("content"); + Map message = content.get(0); + text = (String) message.get("text"); + } else { + LOG.error("Response format not recognized: {}", responseJson); + throw new AIServiceException("Response format not recognized: " + responseJson); } - Map message = MapUtils.emptyIfNull((Map) choice.get("message")); - String text = (String) message.get("content"); if (text == null) { + LOG.error("No message in response: {}", responseJson); throw new AIServiceException("No message in response: " + responseJson); } - return text; } @@ -173,18 +204,22 @@ protected void deactivate() { @Modified protected void activate(Configuration configuration) { this.config = configuration; - this.openAiApiKey = StringUtils.defaultIfBlank(config.openAiApiKey(), System.getenv("OPENAI_API_KEY")); - this.openAiApiKey = StringUtils.defaultIfBlank(this.openAiApiKey, System.getProperty("openai.api.key")); + chatURL = StringUtils.defaultIfBlank(config.chatCompletionUrl(), CHAT_COMPLETION_URL).trim(); + String envVariable = isAnthropicClaude() ? "ANTHROPIC_API_KEY" : "OPENAI_API_KEY"; + this.apiKey = StringUtils.defaultIfBlank(config.openAiApiKey(), System.getenv(envVariable)); + this.apiKey = StringUtils.defaultIfBlank(this.apiKey, System.getProperty("openai.api.key")); if (config.openAiApiKeyFile() != null && !config.openAiApiKeyFile().isEmpty()) { try { - this.openAiApiKey = StringUtils.defaultIfBlank(this.openAiApiKey, + this.apiKey = StringUtils.defaultIfBlank(this.apiKey, FileUtils.readFileToString(new File(config.openAiApiKeyFile()), Charsets.UTF_8)); } catch (IOException e) { LOG.error("Could not read OpenAI key from {}", config.openAiApiKeyFile(), e); } } - openAiApiKey = StringUtils.trimToNull(openAiApiKey); - organizationId = StringUtils.trimToNull(config.organizationId()); + apiKey = StringUtils.trimToNull(apiKey); + additionalHeader = StringUtils.trimToNull(config.additionalHeader()); + additionalHeaderValue = StringUtils.trimToNull(config.additionalHeaderValue()); + apiKeyHeader = StringUtils.defaultIfBlank(config.apiKeyHeader(), "Authorization Bearer"); rateLimiter = null; if (isAvailable()) { int perMinuteLimit = config.requestsPerMinuteLimit() > 0 ? config.requestsPerMinuteLimit() : 20; @@ -203,27 +238,39 @@ protected void activate(Configuration configuration) { @AttributeDefinition(name = "Disable the GPT Chat Completion Service", description = "Disable the GPT Chat Completion Service", defaultValue = "false") boolean disabled() default false; // we want it to work by just deploying it. Admittedly this is a bit doubtful. - @AttributeDefinition(name = "OpenAI API Key from https://platform.openai.com/. If not given, we check the key file, the environment Variable OPENAI_API_KEY, and the system property openai.api.key .") + @AttributeDefinition(name = "Chat Completion URL", description = "The URL for chat completions. If not given, the default for OpenAI is used: " + CHAT_COMPLETION_URL + + " . In the case of Anthropic Claude it is https://api.anthropic.com/v1/messages.") + String chatCompletionUrl(); + + @AttributeDefinition(name = "API key for requests to AI backend, in the case of OpenAI from https://platform.openai.com/api-keys. If not given, we check the key file, the environment Variable OPENAI_API_KEY (or in the case of Anthropic Claude ANTHROPIC_API_KEY), and the system property openai.api.key .") String openAiApiKey(); // alternatively, a key file - @AttributeDefinition(name = "OpenAI API Key File containing the API key, as an alternative to Open AKI Key configuration and the variants described there.") + @AttributeDefinition(name = "OpenAI API key file containing the API key, as an alternative to API key configuration and the variants described there.") String openAiApiKeyFile(); - @AttributeDefinition(name = "Organization ID", description = "The organization ID from OpenAI, if you have one.") - String organizationId(); + @AttributeDefinition(name = "Header for API key", description = "The header in the request that is used for sending the API key. If it's two words like 'Authorization Bearer' then that second word is used as prefix for the value. Default: 'Authorization Bearer', as used by OpenAI.") + String apiKeyHeader(); + + @AttributeDefinition(name = "Additional Header", description = "Optionally, an additional header. In the cause of OpenAI that could be 'OpenAI-Organization'. In the case of Anthropic Claude it could be 'anthropic-version'.") + String additionalHeader(); + + @AttributeDefinition(name = "Additional Header Value", description = "The value for the additional header. In the case of OpenAI that could be the organization id. In the case of Anthropic Claude it could be the version of the API, e.g. '2023-06-01'.") + String additionalHeaderValue(); @AttributeDefinition(name = "Default model to use for the chat completion. If not configured we take a default, in this version " + DEFAULT_MODEL + ". Please consider the varying prices https://openai.com/pricing . For programming related questions a GPT-4 seems necessary, though. Do not configure if not necessary, to follow further changes.") String defaultModel(); - @AttributeDefinition(name = "Requests per minute", description = "The number of requests per minute - after half of that we do slow down. >0, the default is 100.", defaultValue = "20") + @AttributeDefinition(name = "Requests per minute", description = "The number of requests per minute - after half of that we do slow down. >0, the default is 100.") int requestsPerMinuteLimit() default 20; - @AttributeDefinition(name = "Requests per hour", description = "The number of requests per hour - after half of that we do slow down. >0, the default is 1000.", defaultValue = "60") + @AttributeDefinition(name = "Requests per hour", description = "The number of requests per hour - after half of that we do slow down. >0, the default is 1000.") int requestsPerHourLimit() default 60; - @AttributeDefinition(name = "Requests per day", description = "The number of requests per day - after half of that we do slow down. >0, the default is 12000.", defaultValue = "120") + @AttributeDefinition(name = "Requests per day", description = "The number of requests per day - after half of that we do slow down. >0, the default is 12000.") int requestsPerDayLimit() default 120; + @AttributeDefinition(name = "Maximum number of tokens", description = "Maximum number of tokens in the response. >0, the default is 2048.") + int maxTokens() default 2048; } } diff --git a/console/src/test/java/com/composum/sling/nodes/ai/impl/AIServiceImplClaudeCall.java b/console/src/test/java/com/composum/sling/nodes/ai/impl/AIServiceImplClaudeCall.java new file mode 100644 index 000000000..f797d6c15 --- /dev/null +++ b/console/src/test/java/com/composum/sling/nodes/ai/impl/AIServiceImplClaudeCall.java @@ -0,0 +1,51 @@ +package com.composum.sling.nodes.ai.impl; + +import static org.mockito.Mockito.when; + +import java.io.IOException; + +import org.apache.commons.lang3.StringUtils; +import org.mockito.Mockito; + +/** + * Calls Anthropic Claude create messsage service to verify the implementation works. + * Not a test since that needs network and OpenAI key. + * + * @see "https://docs.anthropic.com/claude/reference/messages_post" + */ +public class AIServiceImplClaudeCall { + + private static AIServiceImpl aiService = new AIServiceImpl(); + + public static void main(String[] args) throws Exception { + setUp(); + testPrompt(); + } + + public static void setUp() throws IOException { + AIServiceImpl.Configuration config = Mockito.mock(AIServiceImpl.Configuration.class); + when(config.disabled()).thenReturn(false); + when(config.chatCompletionUrl()).thenReturn("https://api.anthropic.com/v1/messages"); + when(config.apiKeyHeader()).thenReturn("x-api-key"); + when(config.defaultModel()).thenReturn("claude-3-opus-20240229"); + when(config.additionalHeader()).thenReturn("anthropic-version"); + String version = StringUtils.defaultIfBlank(System.getenv("ANTHROPIC_API_VERSION"), "2023-06-01"); + when (config.additionalHeaderValue()).thenReturn(version); + when(config.openAiApiKey()).thenReturn(null); // rely on env var ANTHROPIC_API_KEY + when(config.requestsPerMinuteLimit()).thenReturn(20); + when(config.requestsPerHourLimit()).thenReturn(60); + when(config.requestsPerDayLimit()).thenReturn(120); + when(config.maxTokens()).thenReturn(2048); + aiService.activate(config); + if (!aiService.isAvailable()) { + throw new IllegalStateException("AI service not available, probably because of missing ANTHROPIC_API_KEY"); + } + } + + public static void testPrompt() throws Exception { + System.out.println(aiService.prompt(null, "Hi!", null)); + System.out.println(aiService.prompt("Reverse the users message.", "Hi!", null)); + System.out.println(aiService.prompt(null, "Make a haiku in JSON format about Composum.", AIServiceImpl.ResponseFormat.JSON)); + System.out.println(aiService.prompt("Always respond in JSON", "Make a haiku about Composum.", AIServiceImpl.ResponseFormat.JSON)); + } +} diff --git a/console/src/test/java/com/composum/sling/nodes/ai/impl/AIServiceImplCall.java b/console/src/test/java/com/composum/sling/nodes/ai/impl/AIServiceImplOpenAICall.java similarity index 89% rename from console/src/test/java/com/composum/sling/nodes/ai/impl/AIServiceImplCall.java rename to console/src/test/java/com/composum/sling/nodes/ai/impl/AIServiceImplOpenAICall.java index 3b39646f9..e11ecfab3 100644 --- a/console/src/test/java/com/composum/sling/nodes/ai/impl/AIServiceImplCall.java +++ b/console/src/test/java/com/composum/sling/nodes/ai/impl/AIServiceImplOpenAICall.java @@ -6,7 +6,11 @@ import org.mockito.Mockito; -public class AIServiceImplCall { +/** + * Calls OpenAI chat completion service to verify the implementation works. + * Not a test since that needs network and OpenAI key. + */ +public class AIServiceImplOpenAICall { private static AIServiceImpl aiService = new AIServiceImpl(); diff --git a/console/src/test/java/com/composum/sling/nodes/ai/impl/AIServiceImplTest.java b/console/src/test/java/com/composum/sling/nodes/ai/impl/AIServiceImplTest.java new file mode 100644 index 000000000..2854caa00 --- /dev/null +++ b/console/src/test/java/com/composum/sling/nodes/ai/impl/AIServiceImplTest.java @@ -0,0 +1,55 @@ +package com.composum.sling.nodes.ai.impl; + +import static org.junit.Assert.assertEquals; + +import org.junit.Test; + +import com.composum.sling.nodes.ai.AIService; + +public class AIServiceImplTest { + + @Test + public void extractOpenAIResponse() throws AIService.AIServiceException { + String response = "{\n" + + " \"id\": \"chatcmpl-123\",\n" + + " \"model\": \"gpt-3.5-turbo-0125\",\n" + + " \"choices\": [\n" + + " {\n" + + " \"index\": 0,\n" + + " \"message\": {\n" + + " \"role\": \"assistant\",\n" + + " \"content\": \"Hello there, how may I assist you today?\"\n" + + " },\n" + + " \"finish_reason\": \"stop\"\n" + + " }\n" + + " ]\n" + + "}\n"; + String result = new AIServiceImpl().extractText(response); + assertEquals("Hello there, how may I assist you today?", result); + } + + @Test + public void extractClaudeResponse() throws AIService.AIServiceException { + String response = "{\n" + + " \"content\": [\n" + + " {\n" + + " \"text\": \"Hi! My name is Claude.\",\n" + + " \"type\": \"text\"\n" + + " }\n" + + " ],\n" + + " \"id\": \"msg_013Zva2CMHLNnXjNJJKqJ2EF\",\n" + + " \"model\": \"claude-3-opus-20240229\",\n" + + " \"role\": \"assistant\",\n" + + " \"stop_reason\": \"end_turn\",\n" + + " \"stop_sequence\": null,\n" + + " \"type\": \"message\",\n" + + " \"usage\": {\n" + + " \"input_tokens\": 10,\n" + + " \"output_tokens\": 25\n" + + " }\n" + + "}"; + String result = new AIServiceImpl().extractText(response); + assertEquals("Hi! My name is Claude.", result); + } + +}