diff --git a/pom.xml b/pom.xml index 934339d..6f3e19b 100644 --- a/pom.xml +++ b/pom.xml @@ -19,6 +19,10 @@ 2.0.3 + + org.springframework.boot + spring-boot-starter-webflux + org.springframework.boot spring-boot-starter-web @@ -67,6 +71,11 @@ spring-boot-starter-test test + + io.netty + netty-resolver-dns-native-macos + osx-aarch_64 + diff --git a/src/main/java/com/ai/springdemo/aispringdemo/controller/ChatController.java b/src/main/java/com/ai/springdemo/aispringdemo/controller/ChatController.java new file mode 100644 index 0000000..178f380 --- /dev/null +++ b/src/main/java/com/ai/springdemo/aispringdemo/controller/ChatController.java @@ -0,0 +1,44 @@ +package com.ai.springdemo.aispringdemo.controller; + +import com.ai.springdemo.aispringdemo.exception.InvalidMessageException; +import com.ai.springdemo.aispringdemo.service.ChatService; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.http.HttpStatus; +import org.springframework.http.ResponseEntity; +import org.springframework.web.bind.annotation.*; +import reactor.core.publisher.Flux; + +@RestController +public class ChatController { + + private static final Logger logger = LoggerFactory.getLogger(ChatController.class); + + private final ChatService questionService; + + @Autowired + public ChatController(ChatService questionService){ + this.questionService = questionService; + } + + @PostMapping("/chat") + public Flux getQuestion(@RequestBody String question){ + + logger.info("Received question: {}", question); + + if (question == null || question.trim().isEmpty()) { + logger.error("Invalid question received: empty or null message"); + return Flux.error(new InvalidMessageException("Question cannot be null or empty")); + } + + return questionService.chat(question) + .doOnError(error -> logger.error("Error processing question: {}", error.getMessage())); + } + + @ExceptionHandler(InvalidMessageException.class) + public ResponseEntity handleInvalidMessageException(InvalidMessageException e) { + logger.error("Invalid message error: {}", e.getMessage()); + return new ResponseEntity<>(e.getMessage(), HttpStatus.BAD_REQUEST); + } +} diff --git a/src/main/java/com/ai/springdemo/aispringdemo/controller/DataController.java b/src/main/java/com/ai/springdemo/aispringdemo/controller/DataController.java index a865d49..0086c7b 100644 --- a/src/main/java/com/ai/springdemo/aispringdemo/controller/DataController.java +++ b/src/main/java/com/ai/springdemo/aispringdemo/controller/DataController.java @@ -8,7 +8,6 @@ import org.springframework.web.bind.annotation.*; @RestController -@RequestMapping("/data") public class DataController { private final DataLoadingService dataLoadingService; @@ -18,7 +17,7 @@ public DataController(DataLoadingService dataLoadingService, JdbcTemplate jdbcTe this.dataLoadingService = dataLoadingService; } - @PostMapping("/load") + @GetMapping("/data-load") public ResponseEntity load() { try { this.dataLoadingService.load(); diff --git a/src/main/java/com/ai/springdemo/aispringdemo/controller/QuestionController.java b/src/main/java/com/ai/springdemo/aispringdemo/controller/QuestionController.java deleted file mode 100644 index 8208fd9..0000000 --- a/src/main/java/com/ai/springdemo/aispringdemo/controller/QuestionController.java +++ /dev/null @@ -1,34 +0,0 @@ -package com.ai.springdemo.aispringdemo.controller; - -import com.ai.springdemo.aispringdemo.service.QuestionService; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.http.ResponseEntity; -import org.springframework.web.bind.annotation.*; - -import java.util.LinkedHashMap; -import java.util.Map; - -@RestController -public class QuestionController { - private final QuestionService questionService; - - @Autowired - public QuestionController(QuestionService questionService){ - this.questionService = questionService; - } - - @PostMapping("/simple-question") - public ResponseEntity getQuestion(@RequestBody String question){ - return ResponseEntity.ok().body(questionService.simpleQuestion(question)); - } - - @GetMapping("/question") - public Map completion(@RequestParam(value = "question", defaultValue = "List three tools to trial") String question, - @RequestParam(value = "prompstuff", defaultValue = "true") boolean prompstuff) { - String answer = this.questionService.question(question, prompstuff); - Map map = new LinkedHashMap(); - map.put("question", question); - map.put("answer", answer); - return map; - } -} diff --git a/src/main/java/com/ai/springdemo/aispringdemo/exception/InvalidMessageException.java b/src/main/java/com/ai/springdemo/aispringdemo/exception/InvalidMessageException.java new file mode 100644 index 0000000..f04330f --- /dev/null +++ b/src/main/java/com/ai/springdemo/aispringdemo/exception/InvalidMessageException.java @@ -0,0 +1,8 @@ +package com.ai.springdemo.aispringdemo.exception; + +public class InvalidMessageException extends RuntimeException { + + public InvalidMessageException(String message) { + super(message); + } +} diff --git a/src/main/java/com/ai/springdemo/aispringdemo/service/ChatService.java b/src/main/java/com/ai/springdemo/aispringdemo/service/ChatService.java new file mode 100644 index 0000000..9f109fb --- /dev/null +++ b/src/main/java/com/ai/springdemo/aispringdemo/service/ChatService.java @@ -0,0 +1,88 @@ +package com.ai.springdemo.aispringdemo.service; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.document.Document; +import org.springframework.ai.vectorstore.VectorStore; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.core.io.Resource; +import org.springframework.stereotype.Service; +import reactor.core.publisher.Flux; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; + +@Service +public class ChatService { + + private static final Logger logger = LoggerFactory.getLogger(ChatService.class); + + @Value("classpath:/prompts/system-qa.st") + private Resource qaSystemPromptResource; + + @Value("classpath:/prompts/system-chatbot.st") + private Resource chatbotSystemPromptResource; + + private final ChatClient chatClient; + + + private final VectorStore vectorStore; + + @Autowired + public ChatService(ChatClient.Builder chatClientBuilder, VectorStore vectorStore) { + this.chatClient = chatClientBuilder.build(); + this.vectorStore = vectorStore; + } + + public Flux chat(String userMessageText) { + Message systemMessage = createSystemMessage(userMessageText); + UserMessage userMessage = new UserMessage(userMessageText); + Prompt prompt = createPrompt(systemMessage, userMessage); + + logger.info("Sending prompt to AI model."); + + return chatClient.prompt(prompt) + .stream() + .chatResponse() + .flatMap(this::processChatResponse) + .doOnComplete(() -> logger.info("AI response processing completed.")) + .doOnError(error -> logger.error("Error during AI response processing", error)); + } + + private Message createSystemMessage(String userMessageText) { + logger.info("Fetching relevant documents for message: {}", userMessageText); + List relevantDocuments = vectorStore.similaritySearch(userMessageText); + logger.info("Found {} relevant documents.", relevantDocuments.size()); + String documentsContent = relevantDocuments.stream() + .map(Document::getContent) + .collect(Collectors.joining("\n")); + return new SystemPromptTemplate(qaSystemPromptResource) + .createMessage(Map.of("documents", documentsContent)); + } + + private Prompt createPrompt(Message systemMessage, UserMessage userMessage) { + return new Prompt(List.of(systemMessage, userMessage)); + } + + private Flux processChatResponse(ChatResponse chatResponse) { + return Optional.ofNullable(chatResponse) + .filter(response -> response.getResults() != null && !response.getResults().isEmpty()) + .map(response -> response.getResults().getFirst()) + .map(result -> result.getOutput().getContent()) + .filter(content -> content != null) + .map(Flux::just) + .orElseGet(() -> { + logger.warn("Received an empty or null response."); + return Flux.empty(); + }); + } +} diff --git a/src/main/java/com/ai/springdemo/aispringdemo/service/QuestionService.java b/src/main/java/com/ai/springdemo/aispringdemo/service/QuestionService.java deleted file mode 100644 index 9190b47..0000000 --- a/src/main/java/com/ai/springdemo/aispringdemo/service/QuestionService.java +++ /dev/null @@ -1,76 +0,0 @@ -package com.ai.springdemo.aispringdemo.service; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.ai.chat.client.ChatClient; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.UserMessage; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.chat.prompt.SystemPromptTemplate; -import org.springframework.ai.document.Document; -import org.springframework.ai.vectorstore.VectorStore; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.beans.factory.annotation.Value; -import org.springframework.core.io.Resource; -import org.springframework.stereotype.Service; - -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; - -@Service -public class QuestionService { - - private static final Logger logger = LoggerFactory.getLogger(QuestionService.class); - - @Value("classpath:/prompts/system-qa.st") - private Resource qaSystemPromptResource; - - @Value("classpath:/prompts/system-chatbot.st") - private Resource chatbotSystemPromptResource; - - private final ChatClient chatClient; - - - private final VectorStore vectorStore; - - @Autowired - public QuestionService(ChatClient.Builder chatClientBuilder, VectorStore vectorStore) { - this.chatClient = chatClientBuilder.build(); - this.vectorStore = vectorStore; - } - - public String question(String message, boolean prompstuff) { - Message systemMessage = getSystemMessage(message, prompstuff); - UserMessage userMessage = new UserMessage(message); - Prompt prompt = new Prompt(List.of(systemMessage, userMessage)); - - logger.info("Asking AI model to reply to question."); - ChatResponse chatResponse = chatClient.prompt(prompt).call().chatResponse(); - logger.info("AI responded."); - return chatResponse.getResults().getFirst().getOutput().getContent(); - } - - private Message getSystemMessage(String message, boolean prompstuff) { - if (prompstuff) { - logger.info("Retrieving relevant documents"); - List similarDocuments = vectorStore.similaritySearch(message); - logger.info(String.format("Found %s relevant documents.", similarDocuments.size())); - String documents = similarDocuments.stream().map(Document::getContent).collect(Collectors.joining("\n")); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.qaSystemPromptResource); - return systemPromptTemplate.createMessage(Map.of("documents", documents)); - } else { - logger.info("Not stuffing the prompt, using generic prompt"); - return new SystemPromptTemplate(this.chatbotSystemPromptResource).createMessage(); - } - } - - - public String simpleQuestion(String question) { - Message systemMessage = getSystemMessage(question, true); - UserMessage userMessage = new UserMessage(question); - Prompt prompt = new Prompt(List.of(systemMessage, userMessage)); - return chatClient.prompt(prompt).call().chatResponse().getResults().getFirst().getOutput().getContent(); - } -} diff --git a/src/main/resources/application.properties b/src/main/resources/application.properties index 8885781..c8fcdd2 100644 --- a/src/main/resources/application.properties +++ b/src/main/resources/application.properties @@ -7,6 +7,9 @@ spring.datasource.url=jdbc:postgresql://localhost:5433/vector_db spring.datasource.username=postgres spring.datasource.password=admin +spring.ai.vectorstore.pgvector.remove-existing-vector-store-table=false +spring.ai.vectorstore.pgvector.initialize-schema=true + spring.threads.virtual.enabled=true spring.ai.openai.api-key={OPEN_AI_API_KEY}