diff --git a/pom.xml b/pom.xml
index 934339d..6f3e19b 100644
--- a/pom.xml
+++ b/pom.xml
@@ -19,6 +19,10 @@
2.0.3
+
+ org.springframework.boot
+ spring-boot-starter-webflux
+
org.springframework.boot
spring-boot-starter-web
@@ -67,6 +71,11 @@
spring-boot-starter-test
test
+
+ io.netty
+ netty-resolver-dns-native-macos
+ osx-aarch_64
+
diff --git a/src/main/java/com/ai/springdemo/aispringdemo/controller/ChatController.java b/src/main/java/com/ai/springdemo/aispringdemo/controller/ChatController.java
new file mode 100644
index 0000000..178f380
--- /dev/null
+++ b/src/main/java/com/ai/springdemo/aispringdemo/controller/ChatController.java
@@ -0,0 +1,44 @@
+package com.ai.springdemo.aispringdemo.controller;
+
+import com.ai.springdemo.aispringdemo.exception.InvalidMessageException;
+import com.ai.springdemo.aispringdemo.service.ChatService;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.http.HttpStatus;
+import org.springframework.http.ResponseEntity;
+import org.springframework.web.bind.annotation.*;
+import reactor.core.publisher.Flux;
+
+@RestController
+public class ChatController {
+
+ private static final Logger logger = LoggerFactory.getLogger(ChatController.class);
+
+ private final ChatService questionService;
+
+ @Autowired
+ public ChatController(ChatService questionService){
+ this.questionService = questionService;
+ }
+
+ @PostMapping("/chat")
+ public Flux getQuestion(@RequestBody String question){
+
+ logger.info("Received question: {}", question);
+
+ if (question == null || question.trim().isEmpty()) {
+ logger.error("Invalid question received: empty or null message");
+ return Flux.error(new InvalidMessageException("Question cannot be null or empty"));
+ }
+
+ return questionService.chat(question)
+ .doOnError(error -> logger.error("Error processing question: {}", error.getMessage()));
+ }
+
+ @ExceptionHandler(InvalidMessageException.class)
+ public ResponseEntity handleInvalidMessageException(InvalidMessageException e) {
+ logger.error("Invalid message error: {}", e.getMessage());
+ return new ResponseEntity<>(e.getMessage(), HttpStatus.BAD_REQUEST);
+ }
+}
diff --git a/src/main/java/com/ai/springdemo/aispringdemo/controller/DataController.java b/src/main/java/com/ai/springdemo/aispringdemo/controller/DataController.java
index a865d49..0086c7b 100644
--- a/src/main/java/com/ai/springdemo/aispringdemo/controller/DataController.java
+++ b/src/main/java/com/ai/springdemo/aispringdemo/controller/DataController.java
@@ -8,7 +8,6 @@
import org.springframework.web.bind.annotation.*;
@RestController
-@RequestMapping("/data")
public class DataController {
private final DataLoadingService dataLoadingService;
@@ -18,7 +17,7 @@ public DataController(DataLoadingService dataLoadingService, JdbcTemplate jdbcTe
this.dataLoadingService = dataLoadingService;
}
- @PostMapping("/load")
+ @GetMapping("/data-load")
public ResponseEntity load() {
try {
this.dataLoadingService.load();
diff --git a/src/main/java/com/ai/springdemo/aispringdemo/controller/QuestionController.java b/src/main/java/com/ai/springdemo/aispringdemo/controller/QuestionController.java
deleted file mode 100644
index 8208fd9..0000000
--- a/src/main/java/com/ai/springdemo/aispringdemo/controller/QuestionController.java
+++ /dev/null
@@ -1,34 +0,0 @@
-package com.ai.springdemo.aispringdemo.controller;
-
-import com.ai.springdemo.aispringdemo.service.QuestionService;
-import org.springframework.beans.factory.annotation.Autowired;
-import org.springframework.http.ResponseEntity;
-import org.springframework.web.bind.annotation.*;
-
-import java.util.LinkedHashMap;
-import java.util.Map;
-
-@RestController
-public class QuestionController {
- private final QuestionService questionService;
-
- @Autowired
- public QuestionController(QuestionService questionService){
- this.questionService = questionService;
- }
-
- @PostMapping("/simple-question")
- public ResponseEntity getQuestion(@RequestBody String question){
- return ResponseEntity.ok().body(questionService.simpleQuestion(question));
- }
-
- @GetMapping("/question")
- public Map completion(@RequestParam(value = "question", defaultValue = "List three tools to trial") String question,
- @RequestParam(value = "prompstuff", defaultValue = "true") boolean prompstuff) {
- String answer = this.questionService.question(question, prompstuff);
- Map map = new LinkedHashMap();
- map.put("question", question);
- map.put("answer", answer);
- return map;
- }
-}
diff --git a/src/main/java/com/ai/springdemo/aispringdemo/exception/InvalidMessageException.java b/src/main/java/com/ai/springdemo/aispringdemo/exception/InvalidMessageException.java
new file mode 100644
index 0000000..f04330f
--- /dev/null
+++ b/src/main/java/com/ai/springdemo/aispringdemo/exception/InvalidMessageException.java
@@ -0,0 +1,8 @@
+package com.ai.springdemo.aispringdemo.exception;
+
+public class InvalidMessageException extends RuntimeException {
+
+ public InvalidMessageException(String message) {
+ super(message);
+ }
+}
diff --git a/src/main/java/com/ai/springdemo/aispringdemo/service/ChatService.java b/src/main/java/com/ai/springdemo/aispringdemo/service/ChatService.java
new file mode 100644
index 0000000..9f109fb
--- /dev/null
+++ b/src/main/java/com/ai/springdemo/aispringdemo/service/ChatService.java
@@ -0,0 +1,88 @@
+package com.ai.springdemo.aispringdemo.service;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.ai.chat.client.ChatClient;
+import org.springframework.ai.chat.messages.Message;
+import org.springframework.ai.chat.messages.UserMessage;
+import org.springframework.ai.chat.model.ChatResponse;
+import org.springframework.ai.chat.prompt.Prompt;
+import org.springframework.ai.chat.prompt.SystemPromptTemplate;
+import org.springframework.ai.document.Document;
+import org.springframework.ai.vectorstore.VectorStore;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.beans.factory.annotation.Value;
+import org.springframework.core.io.Resource;
+import org.springframework.stereotype.Service;
+import reactor.core.publisher.Flux;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.stream.Collectors;
+
+@Service
+public class ChatService {
+
+ private static final Logger logger = LoggerFactory.getLogger(ChatService.class);
+
+ @Value("classpath:/prompts/system-qa.st")
+ private Resource qaSystemPromptResource;
+
+ @Value("classpath:/prompts/system-chatbot.st")
+ private Resource chatbotSystemPromptResource;
+
+ private final ChatClient chatClient;
+
+
+ private final VectorStore vectorStore;
+
+ @Autowired
+ public ChatService(ChatClient.Builder chatClientBuilder, VectorStore vectorStore) {
+ this.chatClient = chatClientBuilder.build();
+ this.vectorStore = vectorStore;
+ }
+
+ public Flux chat(String userMessageText) {
+ Message systemMessage = createSystemMessage(userMessageText);
+ UserMessage userMessage = new UserMessage(userMessageText);
+ Prompt prompt = createPrompt(systemMessage, userMessage);
+
+ logger.info("Sending prompt to AI model.");
+
+ return chatClient.prompt(prompt)
+ .stream()
+ .chatResponse()
+ .flatMap(this::processChatResponse)
+ .doOnComplete(() -> logger.info("AI response processing completed."))
+ .doOnError(error -> logger.error("Error during AI response processing", error));
+ }
+
+ private Message createSystemMessage(String userMessageText) {
+ logger.info("Fetching relevant documents for message: {}", userMessageText);
+ List relevantDocuments = vectorStore.similaritySearch(userMessageText);
+ logger.info("Found {} relevant documents.", relevantDocuments.size());
+ String documentsContent = relevantDocuments.stream()
+ .map(Document::getContent)
+ .collect(Collectors.joining("\n"));
+ return new SystemPromptTemplate(qaSystemPromptResource)
+ .createMessage(Map.of("documents", documentsContent));
+ }
+
+ private Prompt createPrompt(Message systemMessage, UserMessage userMessage) {
+ return new Prompt(List.of(systemMessage, userMessage));
+ }
+
+ private Flux processChatResponse(ChatResponse chatResponse) {
+ return Optional.ofNullable(chatResponse)
+ .filter(response -> response.getResults() != null && !response.getResults().isEmpty())
+ .map(response -> response.getResults().getFirst())
+ .map(result -> result.getOutput().getContent())
+ .filter(content -> content != null)
+ .map(Flux::just)
+ .orElseGet(() -> {
+ logger.warn("Received an empty or null response.");
+ return Flux.empty();
+ });
+ }
+}
diff --git a/src/main/java/com/ai/springdemo/aispringdemo/service/QuestionService.java b/src/main/java/com/ai/springdemo/aispringdemo/service/QuestionService.java
deleted file mode 100644
index 9190b47..0000000
--- a/src/main/java/com/ai/springdemo/aispringdemo/service/QuestionService.java
+++ /dev/null
@@ -1,76 +0,0 @@
-package com.ai.springdemo.aispringdemo.service;
-
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-import org.springframework.ai.chat.client.ChatClient;
-import org.springframework.ai.chat.messages.Message;
-import org.springframework.ai.chat.messages.UserMessage;
-import org.springframework.ai.chat.model.ChatResponse;
-import org.springframework.ai.chat.prompt.Prompt;
-import org.springframework.ai.chat.prompt.SystemPromptTemplate;
-import org.springframework.ai.document.Document;
-import org.springframework.ai.vectorstore.VectorStore;
-import org.springframework.beans.factory.annotation.Autowired;
-import org.springframework.beans.factory.annotation.Value;
-import org.springframework.core.io.Resource;
-import org.springframework.stereotype.Service;
-
-import java.util.List;
-import java.util.Map;
-import java.util.stream.Collectors;
-
-@Service
-public class QuestionService {
-
- private static final Logger logger = LoggerFactory.getLogger(QuestionService.class);
-
- @Value("classpath:/prompts/system-qa.st")
- private Resource qaSystemPromptResource;
-
- @Value("classpath:/prompts/system-chatbot.st")
- private Resource chatbotSystemPromptResource;
-
- private final ChatClient chatClient;
-
-
- private final VectorStore vectorStore;
-
- @Autowired
- public QuestionService(ChatClient.Builder chatClientBuilder, VectorStore vectorStore) {
- this.chatClient = chatClientBuilder.build();
- this.vectorStore = vectorStore;
- }
-
- public String question(String message, boolean prompstuff) {
- Message systemMessage = getSystemMessage(message, prompstuff);
- UserMessage userMessage = new UserMessage(message);
- Prompt prompt = new Prompt(List.of(systemMessage, userMessage));
-
- logger.info("Asking AI model to reply to question.");
- ChatResponse chatResponse = chatClient.prompt(prompt).call().chatResponse();
- logger.info("AI responded.");
- return chatResponse.getResults().getFirst().getOutput().getContent();
- }
-
- private Message getSystemMessage(String message, boolean prompstuff) {
- if (prompstuff) {
- logger.info("Retrieving relevant documents");
- List similarDocuments = vectorStore.similaritySearch(message);
- logger.info(String.format("Found %s relevant documents.", similarDocuments.size()));
- String documents = similarDocuments.stream().map(Document::getContent).collect(Collectors.joining("\n"));
- SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.qaSystemPromptResource);
- return systemPromptTemplate.createMessage(Map.of("documents", documents));
- } else {
- logger.info("Not stuffing the prompt, using generic prompt");
- return new SystemPromptTemplate(this.chatbotSystemPromptResource).createMessage();
- }
- }
-
-
- public String simpleQuestion(String question) {
- Message systemMessage = getSystemMessage(question, true);
- UserMessage userMessage = new UserMessage(question);
- Prompt prompt = new Prompt(List.of(systemMessage, userMessage));
- return chatClient.prompt(prompt).call().chatResponse().getResults().getFirst().getOutput().getContent();
- }
-}
diff --git a/src/main/resources/application.properties b/src/main/resources/application.properties
index 8885781..c8fcdd2 100644
--- a/src/main/resources/application.properties
+++ b/src/main/resources/application.properties
@@ -7,6 +7,9 @@ spring.datasource.url=jdbc:postgresql://localhost:5433/vector_db
spring.datasource.username=postgres
spring.datasource.password=admin
+spring.ai.vectorstore.pgvector.remove-existing-vector-store-table=false
+spring.ai.vectorstore.pgvector.initialize-schema=true
+
spring.threads.virtual.enabled=true
spring.ai.openai.api-key={OPEN_AI_API_KEY}