Skip to content

Commit

Permalink
configurable AI endpoint; add possibility to call Anthropic Claude o…
Browse files Browse the repository at this point in the history
…r local models
  • Loading branch information
stoerr committed Mar 25, 2024
1 parent d3a6479 commit 62dfaa6
Show file tree
Hide file tree
Showing 4 changed files with 200 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.apache.commons.collections4.ListUtils;
import org.apache.commons.collections4.MapUtils;
import org.apache.commons.io.Charsets;
import org.apache.commons.io.FileUtils;
Expand All @@ -27,6 +26,7 @@
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClientBuilder;
import org.apache.http.util.EntityUtils;
import org.jetbrains.annotations.NotNull;
import org.osgi.framework.Constants;
import org.osgi.service.component.annotations.Activate;
import org.osgi.service.component.annotations.Component;
Expand All @@ -53,22 +53,31 @@ public class AIServiceImpl implements AIService {

public static final String SERVICE_NAME = "Composum Nodes AI Service";

private static final Logger LOG = LoggerFactory.getLogger(AIServiceImpl.class);
protected static final Logger LOG = LoggerFactory.getLogger(AIServiceImpl.class);

/** The default model - probably a GPT-4 is needed for complicated stuff like JCR queries. */
/**
* The default model - probably a GPT-4 is needed for complicated stuff like JCR queries.
*/
public static final String DEFAULT_MODEL = "gpt-4-turbo-preview";
protected static final String CHAT_COMPLETION_URL = "https://api.openai.com/v1/chat/completions";

private Configuration config;
private String openAiApiKey;
private RateLimiter rateLimiter;
private Gson gson = new Gson();
private CloseableHttpClient httpClient;
private String organizationId;
protected Configuration config;
protected String apiKey;
protected RateLimiter rateLimiter;
protected Gson gson = new Gson();
protected CloseableHttpClient httpClient;
protected String chatURL;
protected String apiKeyHeader;
protected String additionalHeader;
protected String additionalHeaderValue;

@Override
public boolean isAvailable() {
return config != null && !config.disabled() && StringUtils.isNotBlank(openAiApiKey);
return config != null && !config.disabled() && StringUtils.isNotBlank(apiKey);
}

protected boolean isAnthropicClaude() {
return chatURL != null && chatURL.contains("api.anthropic.com");
}

@Nonnull
Expand All @@ -83,7 +92,10 @@ public String prompt(@Nullable String systemmsg, @Nonnull String usermsg, @Nulla

Map<String, Object> request = new HashMap<>();
request.put("model", StringUtils.defaultIfBlank(config.defaultModel(), DEFAULT_MODEL));
if (systemmsg != null) {
if (systemmsg != null && isAnthropicClaude()) {
request.put("system", systemmsg);
request.put("messages", asList(userMessage));
} else if (systemmsg != null) {
Map<String, Object> systemMessage = new LinkedHashMap<>();
systemMessage.put("role", "system");
systemMessage.put("content", systemmsg);
Expand All @@ -92,7 +104,10 @@ public String prompt(@Nullable String systemmsg, @Nonnull String usermsg, @Nulla
request.put("messages", asList(userMessage));
}
request.put("temperature", 0);
if (responseFormat == ResponseFormat.JSON) {
if (config.maxTokens() > 0) {
request.put("max_tokens", config.maxTokens());
}
if (responseFormat == ResponseFormat.JSON && !isAnthropicClaude()) {
Map<String, String> responseFormatMap = new LinkedHashMap<>();
responseFormatMap.put("type", "json_object");
request.put("response_format", responseFormatMap);
Expand All @@ -102,15 +117,20 @@ public String prompt(@Nullable String systemmsg, @Nonnull String usermsg, @Nulla

rateLimiter.waitForLimit();
// retrieve response from OpenAI using httpClient 4
HttpPost postRequest = new HttpPost(CHAT_COMPLETION_URL);
HttpPost postRequest = new HttpPost(this.chatURL);
EntityBuilder entityBuilder = EntityBuilder.create();
entityBuilder.setContentType(ContentType.APPLICATION_JSON);
entityBuilder.setContentEncoding("UTF-8");
entityBuilder.setText(requestJson);
postRequest.setEntity(entityBuilder.build());
postRequest.setHeader("Authorization", "Bearer " + openAiApiKey);
if (organizationId != null) {
postRequest.setHeader("OpenAI-Organization", organizationId);
if (this.apiKeyHeader != null) {
String[] headerSplitted = this.apiKeyHeader.split("\\s");
String headername = headerSplitted[0];
String prefix = headerSplitted.length > 1 ? headerSplitted[1].trim() + " " : "";
postRequest.setHeader(headername, prefix + apiKey);
}
if (additionalHeader != null && additionalHeaderValue != null) {
postRequest.setHeader(additionalHeader, additionalHeaderValue);
}
String id = "#" + System.nanoTime();
LOG.debug("Request {} to OpenAI: {}", id, requestJson);
Expand All @@ -123,7 +143,7 @@ public String prompt(@Nullable String systemmsg, @Nonnull String usermsg, @Nulla
}

@Nonnull
private String retrieveMessage(String id, CloseableHttpResponse response) throws AIServiceException, IOException {
protected String retrieveMessage(String id, CloseableHttpResponse response) throws AIServiceException, IOException {
int statusCode = response.getStatusLine().getStatusCode();
HttpEntity responseEntity = response.getEntity();
if (statusCode != 200) {
Expand All @@ -133,27 +153,38 @@ private String retrieveMessage(String id, CloseableHttpResponse response) throws
if (bytes.size() > 0) {
errorbody = errorbody + "\n" + new String(bytes.toByteArray(), Charsets.UTF_8);
}
throw new AIServiceException("Error from OpenAI: " + statusCode + " " + errorbody);
throw new AIServiceException("Error from AI backend: " + response.getStatusLine() + " " + errorbody);
}
String responseJson = EntityUtils.toString(responseEntity);
LOG.debug("Response {} from OpenAI: {}", id, responseJson);
Map<String, Object> responseMap = gson.fromJson(responseJson, Map.class);
return extractText(responseJson);
}

List<Map<String, Object>> choices = ListUtils.emptyIfNull((List<Map<String, Object>>) responseMap.get("choices"));
if (choices.isEmpty()) {
throw new AIServiceException("No choices in response: " + responseJson);
}
Map<String, Object> choice = choices.get(0);
String finish_reason = (String) choice.get("finish_reason");
if (!"stop".equals(finish_reason)) {
throw new AIServiceException("Finish reason is not stop: " + finish_reason + " in response: " + responseJson);
@NotNull
protected String extractText(String responseJson) throws AIServiceException {
Map<String, Object> responseMap = gson.fromJson(responseJson, Map.class);
String text;
if (responseMap.containsKey("choices")) { // OpenAI format
List<Map<String, Object>> choices = (List<Map<String, Object>>) responseMap.get("choices");
Map<String, Object> choice = choices.get(0);
String finish_reason = (String) choice.get("finish_reason");
if (!"stop".equals(finish_reason)) {
throw new AIServiceException("Finish reason is not stop: " + finish_reason + " in response: " + responseJson);
}
Map<String, Object> message = MapUtils.emptyIfNull((Map<String, Object>) choice.get("message"));
text = (String) message.get("content");
} else if (responseMap.containsKey("content")) { // Anthropic Claude format
List<Map<String, Object>> content = (List<Map<String, Object>>) responseMap.get("content");
Map<String, Object> message = content.get(0);
text = (String) message.get("text");
} else {
LOG.error("Response format not recognized: {}", responseJson);
throw new AIServiceException("Response format not recognized: " + responseJson);
}
Map<String, Object> message = MapUtils.emptyIfNull((Map<String, Object>) choice.get("message"));
String text = (String) message.get("content");
if (text == null) {
LOG.error("No message in response: {}", responseJson);
throw new AIServiceException("No message in response: " + responseJson);
}

return text;
}

Expand All @@ -173,18 +204,22 @@ protected void deactivate() {
@Modified
protected void activate(Configuration configuration) {
this.config = configuration;
this.openAiApiKey = StringUtils.defaultIfBlank(config.openAiApiKey(), System.getenv("OPENAI_API_KEY"));
this.openAiApiKey = StringUtils.defaultIfBlank(this.openAiApiKey, System.getProperty("openai.api.key"));
chatURL = StringUtils.defaultIfBlank(config.chatCompletionUrl(), CHAT_COMPLETION_URL).trim();
String envVariable = isAnthropicClaude() ? "ANTHROPIC_API_KEY" : "OPENAI_API_KEY";
this.apiKey = StringUtils.defaultIfBlank(config.openAiApiKey(), System.getenv(envVariable));
this.apiKey = StringUtils.defaultIfBlank(this.apiKey, System.getProperty("openai.api.key"));
if (config.openAiApiKeyFile() != null && !config.openAiApiKeyFile().isEmpty()) {
try {
this.openAiApiKey = StringUtils.defaultIfBlank(this.openAiApiKey,
this.apiKey = StringUtils.defaultIfBlank(this.apiKey,
FileUtils.readFileToString(new File(config.openAiApiKeyFile()), Charsets.UTF_8));
} catch (IOException e) {
LOG.error("Could not read OpenAI key from {}", config.openAiApiKeyFile(), e);
}
}
openAiApiKey = StringUtils.trimToNull(openAiApiKey);
organizationId = StringUtils.trimToNull(config.organizationId());
apiKey = StringUtils.trimToNull(apiKey);
additionalHeader = StringUtils.trimToNull(config.additionalHeader());
additionalHeaderValue = StringUtils.trimToNull(config.additionalHeaderValue());
apiKeyHeader = StringUtils.defaultIfBlank(config.apiKeyHeader(), "Authorization Bearer");
rateLimiter = null;
if (isAvailable()) {
int perMinuteLimit = config.requestsPerMinuteLimit() > 0 ? config.requestsPerMinuteLimit() : 20;
Expand All @@ -203,27 +238,39 @@ protected void activate(Configuration configuration) {
@AttributeDefinition(name = "Disable the GPT Chat Completion Service", description = "Disable the GPT Chat Completion Service", defaultValue = "false")
boolean disabled() default false; // we want it to work by just deploying it. Admittedly this is a bit doubtful.

@AttributeDefinition(name = "OpenAI API Key from https://platform.openai.com/. If not given, we check the key file, the environment Variable OPENAI_API_KEY, and the system property openai.api.key .")
@AttributeDefinition(name = "Chat Completion URL", description = "The URL for chat completions. If not given, the default for OpenAI is used: " + CHAT_COMPLETION_URL +
" . In the case of Anthropic Claude it is https://api.anthropic.com/v1/messages.")
String chatCompletionUrl();

@AttributeDefinition(name = "API key for requests to AI backend, in the case of OpenAI from https://platform.openai.com/api-keys. If not given, we check the key file, the environment Variable OPENAI_API_KEY (or in the case of Anthropic Claude ANTHROPIC_API_KEY), and the system property openai.api.key .")
String openAiApiKey();

// alternatively, a key file
@AttributeDefinition(name = "OpenAI API Key File containing the API key, as an alternative to Open AKI Key configuration and the variants described there.")
@AttributeDefinition(name = "OpenAI API key file containing the API key, as an alternative to API key configuration and the variants described there.")
String openAiApiKeyFile();

@AttributeDefinition(name = "Organization ID", description = "The organization ID from OpenAI, if you have one.")
String organizationId();
@AttributeDefinition(name = "Header for API key", description = "The header in the request that is used for sending the API key. If it's two words like 'Authorization Bearer' then that second word is used as prefix for the value. Default: 'Authorization Bearer', as used by OpenAI.")
String apiKeyHeader();

@AttributeDefinition(name = "Additional Header", description = "Optionally, an additional header. In the cause of OpenAI that could be 'OpenAI-Organization'. In the case of Anthropic Claude it could be 'anthropic-version'.")
String additionalHeader();

@AttributeDefinition(name = "Additional Header Value", description = "The value for the additional header. In the case of OpenAI that could be the organization id. In the case of Anthropic Claude it could be the version of the API, e.g. '2023-06-01'.")
String additionalHeaderValue();

@AttributeDefinition(name = "Default model to use for the chat completion. If not configured we take a default, in this version " + DEFAULT_MODEL + ". Please consider the varying prices https://openai.com/pricing . For programming related questions a GPT-4 seems necessary, though. Do not configure if not necessary, to follow further changes.")
String defaultModel();

@AttributeDefinition(name = "Requests per minute", description = "The number of requests per minute - after half of that we do slow down. >0, the default is 100.", defaultValue = "20")
@AttributeDefinition(name = "Requests per minute", description = "The number of requests per minute - after half of that we do slow down. >0, the default is 100.")
int requestsPerMinuteLimit() default 20;

@AttributeDefinition(name = "Requests per hour", description = "The number of requests per hour - after half of that we do slow down. >0, the default is 1000.", defaultValue = "60")
@AttributeDefinition(name = "Requests per hour", description = "The number of requests per hour - after half of that we do slow down. >0, the default is 1000.")
int requestsPerHourLimit() default 60;

@AttributeDefinition(name = "Requests per day", description = "The number of requests per day - after half of that we do slow down. >0, the default is 12000.", defaultValue = "120")
@AttributeDefinition(name = "Requests per day", description = "The number of requests per day - after half of that we do slow down. >0, the default is 12000.")
int requestsPerDayLimit() default 120;

@AttributeDefinition(name = "Maximum number of tokens", description = "Maximum number of tokens in the response. >0, the default is 2048.")
int maxTokens() default 2048;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package com.composum.sling.nodes.ai.impl;

import static org.mockito.Mockito.when;

import java.io.IOException;

import org.apache.commons.lang3.StringUtils;
import org.mockito.Mockito;

/**
* Calls Anthropic Claude create messsage service to verify the implementation works.
* Not a test since that needs network and OpenAI key.
*
* @see "https://docs.anthropic.com/claude/reference/messages_post"
*/
public class AIServiceImplClaudeCall {

private static AIServiceImpl aiService = new AIServiceImpl();

public static void main(String[] args) throws Exception {
setUp();
testPrompt();
}

public static void setUp() throws IOException {
AIServiceImpl.Configuration config = Mockito.mock(AIServiceImpl.Configuration.class);
when(config.disabled()).thenReturn(false);
when(config.chatCompletionUrl()).thenReturn("https://api.anthropic.com/v1/messages");
when(config.apiKeyHeader()).thenReturn("x-api-key");
when(config.defaultModel()).thenReturn("claude-3-opus-20240229");
when(config.additionalHeader()).thenReturn("anthropic-version");
String version = StringUtils.defaultIfBlank(System.getenv("ANTHROPIC_API_VERSION"), "2023-06-01");
when (config.additionalHeaderValue()).thenReturn(version);
when(config.openAiApiKey()).thenReturn(null); // rely on env var ANTHROPIC_API_KEY
when(config.requestsPerMinuteLimit()).thenReturn(20);
when(config.requestsPerHourLimit()).thenReturn(60);
when(config.requestsPerDayLimit()).thenReturn(120);
when(config.maxTokens()).thenReturn(2048);
aiService.activate(config);
if (!aiService.isAvailable()) {
throw new IllegalStateException("AI service not available, probably because of missing ANTHROPIC_API_KEY");
}
}

public static void testPrompt() throws Exception {
System.out.println(aiService.prompt(null, "Hi!", null));
System.out.println(aiService.prompt("Reverse the users message.", "Hi!", null));
System.out.println(aiService.prompt(null, "Make a haiku in JSON format about Composum.", AIServiceImpl.ResponseFormat.JSON));
System.out.println(aiService.prompt("Always respond in JSON", "Make a haiku about Composum.", AIServiceImpl.ResponseFormat.JSON));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@

import org.mockito.Mockito;

public class AIServiceImplCall {
/**
* Calls OpenAI chat completion service to verify the implementation works.
* Not a test since that needs network and OpenAI key.
*/
public class AIServiceImplOpenAICall {

private static AIServiceImpl aiService = new AIServiceImpl();

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package com.composum.sling.nodes.ai.impl;

import static org.junit.Assert.assertEquals;

import org.junit.Test;

import com.composum.sling.nodes.ai.AIService;

public class AIServiceImplTest {

@Test
public void extractOpenAIResponse() throws AIService.AIServiceException {
String response = "{\n" +
" \"id\": \"chatcmpl-123\",\n" +
" \"model\": \"gpt-3.5-turbo-0125\",\n" +
" \"choices\": [\n" +
" {\n" +
" \"index\": 0,\n" +
" \"message\": {\n" +
" \"role\": \"assistant\",\n" +
" \"content\": \"Hello there, how may I assist you today?\"\n" +
" },\n" +
" \"finish_reason\": \"stop\"\n" +
" }\n" +
" ]\n" +
"}\n";
String result = new AIServiceImpl().extractText(response);
assertEquals("Hello there, how may I assist you today?", result);
}

@Test
public void extractClaudeResponse() throws AIService.AIServiceException {
String response = "{\n" +
" \"content\": [\n" +
" {\n" +
" \"text\": \"Hi! My name is Claude.\",\n" +
" \"type\": \"text\"\n" +
" }\n" +
" ],\n" +
" \"id\": \"msg_013Zva2CMHLNnXjNJJKqJ2EF\",\n" +
" \"model\": \"claude-3-opus-20240229\",\n" +
" \"role\": \"assistant\",\n" +
" \"stop_reason\": \"end_turn\",\n" +
" \"stop_sequence\": null,\n" +
" \"type\": \"message\",\n" +
" \"usage\": {\n" +
" \"input_tokens\": 10,\n" +
" \"output_tokens\": 25\n" +
" }\n" +
"}";
String result = new AIServiceImpl().extractText(response);
assertEquals("Hi! My name is Claude.", result);
}

}

0 comments on commit 62dfaa6

Please sign in to comment.