diff --git a/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/client/AzureOpenAiClient.java b/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/client/AzureOpenAiClient.java index 4d90e9bd47..c97e03db42 100644 --- a/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/client/AzureOpenAiClient.java +++ b/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/client/AzureOpenAiClient.java @@ -73,7 +73,7 @@ public AiResponse generate(Prompt prompt) { List messages = prompt.getMessages(); List azureMessages = new ArrayList<>(); for (Message message : messages) { - String messageType = message.getMessageType().getValue(); + String messageType = message.getMessageTypeValue(); ChatRole chatRole = ChatRole.fromString(messageType); azureMessages.add(new ChatMessage(chatRole, message.getContent())); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/AbstractMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/AbstractMessage.java index 085e721124..c971d80dfb 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/AbstractMessage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/AbstractMessage.java @@ -87,4 +87,9 @@ public MessageType getMessageType() { return this.messageType; } + @Override + public String getMessageTypeValue() { + return this.messageType.getValue(); + } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/Message.java b/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/Message.java index a405eb10b8..190cd96427 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/Message.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/Message.java @@ -26,4 +26,6 @@ public interface Message { MessageType getMessageType(); + String getMessageTypeValue(); + } diff --git a/spring-ai-openai/src/main/java/org/springframework/ai/openai/client/OpenAiClient.java b/spring-ai-openai/src/main/java/org/springframework/ai/openai/client/OpenAiClient.java index b503dc5041..1db7246286 100644 --- a/spring-ai-openai/src/main/java/org/springframework/ai/openai/client/OpenAiClient.java +++ b/spring-ai-openai/src/main/java/org/springframework/ai/openai/client/OpenAiClient.java @@ -27,6 +27,7 @@ import org.springframework.ai.client.Generation; import org.springframework.ai.prompt.Prompt; import org.springframework.ai.prompt.messages.Message; +import org.springframework.ai.prompt.messages.MessageType; import org.springframework.util.Assert; import java.util.ArrayList; @@ -79,7 +80,7 @@ public AiResponse generate(Prompt prompt) { List messages = prompt.getMessages(); List theoMessages = new ArrayList<>(); for (Message message : messages) { - String messageType = message.getMessageType().getValue(); + String messageType = message.getMessageTypeValue(); theoMessages.add(new ChatMessage(messageType, message.getContent())); } ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder() @@ -148,14 +149,14 @@ private List convertToChatMessages(List messages) { for (Message promptMessage : messages) { switch (promptMessage.getMessageType()) { case USER: - chatMessages.add(new ChatMessage("user", promptMessage.getContent())); + chatMessages.add(new ChatMessage(MessageType.USER.getValue(), promptMessage.getContent())); break; case ASSISTANT: // TODO - valid? - chatMessages.add(new ChatMessage("assistant", promptMessage.getContent())); + chatMessages.add(new ChatMessage(MessageType.ASSISTANT.getValue(), promptMessage.getContent())); break; case SYSTEM: - chatMessages.add(new ChatMessage("system", promptMessage.getContent())); + chatMessages.add(new ChatMessage(MessageType.SYSTEM.getValue(), promptMessage.getContent())); break; case FUNCTION: logger.error(