Skip to content

Commit

Permalink
Update library Spring Ai 1.0.0-Snapshot
Browse files Browse the repository at this point in the history
  • Loading branch information
lucasnscr committed Sep 2, 2024
1 parent d4c9622 commit 2ac9a5a
Show file tree
Hide file tree
Showing 8 changed files with 153 additions and 112 deletions.
9 changes: 9 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
<spring-retry.version>2.0.3</spring-retry.version>
</properties>
<dependencies>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-webflux</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
Expand Down Expand Up @@ -67,6 +71,11 @@
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-resolver-dns-native-macos</artifactId>
<classifier>osx-aarch_64</classifier>
</dependency>
</dependencies>

<build>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package com.ai.springdemo.aispringdemo.controller;

import com.ai.springdemo.aispringdemo.exception.InvalidMessageException;
import com.ai.springdemo.aispringdemo.service.ChatService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import reactor.core.publisher.Flux;

@RestController
public class ChatController {

private static final Logger logger = LoggerFactory.getLogger(ChatController.class);

private final ChatService questionService;

@Autowired
public ChatController(ChatService questionService){
this.questionService = questionService;
}

@PostMapping("/chat")
public Flux<String> getQuestion(@RequestBody String question){

logger.info("Received question: {}", question);

if (question == null || question.trim().isEmpty()) {
logger.error("Invalid question received: empty or null message");
return Flux.error(new InvalidMessageException("Question cannot be null or empty"));
}

return questionService.chat(question)
.doOnError(error -> logger.error("Error processing question: {}", error.getMessage()));
}

@ExceptionHandler(InvalidMessageException.class)
public ResponseEntity<String> handleInvalidMessageException(InvalidMessageException e) {
logger.error("Invalid message error: {}", e.getMessage());
return new ResponseEntity<>(e.getMessage(), HttpStatus.BAD_REQUEST);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import org.springframework.web.bind.annotation.*;

@RestController
@RequestMapping("/data")
public class DataController {

private final DataLoadingService dataLoadingService;
Expand All @@ -18,7 +17,7 @@ public DataController(DataLoadingService dataLoadingService, JdbcTemplate jdbcTe
this.dataLoadingService = dataLoadingService;
}

@PostMapping("/load")
@GetMapping("/data-load")
public ResponseEntity<String> load() {
try {
this.dataLoadingService.load();
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package com.ai.springdemo.aispringdemo.exception;

public class InvalidMessageException extends RuntimeException {

public InvalidMessageException(String message) {
super(message);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package com.ai.springdemo.aispringdemo.service;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.io.Resource;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Flux;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;

@Service
public class ChatService {

private static final Logger logger = LoggerFactory.getLogger(ChatService.class);

@Value("classpath:/prompts/system-qa.st")
private Resource qaSystemPromptResource;

@Value("classpath:/prompts/system-chatbot.st")
private Resource chatbotSystemPromptResource;

private final ChatClient chatClient;


private final VectorStore vectorStore;

@Autowired
public ChatService(ChatClient.Builder chatClientBuilder, VectorStore vectorStore) {
this.chatClient = chatClientBuilder.build();
this.vectorStore = vectorStore;
}

public Flux<String> chat(String userMessageText) {
Message systemMessage = createSystemMessage(userMessageText);
UserMessage userMessage = new UserMessage(userMessageText);
Prompt prompt = createPrompt(systemMessage, userMessage);

logger.info("Sending prompt to AI model.");

return chatClient.prompt(prompt)
.stream()
.chatResponse()
.flatMap(this::processChatResponse)
.doOnComplete(() -> logger.info("AI response processing completed."))
.doOnError(error -> logger.error("Error during AI response processing", error));
}

private Message createSystemMessage(String userMessageText) {
logger.info("Fetching relevant documents for message: {}", userMessageText);
List<Document> relevantDocuments = vectorStore.similaritySearch(userMessageText);
logger.info("Found {} relevant documents.", relevantDocuments.size());
String documentsContent = relevantDocuments.stream()
.map(Document::getContent)
.collect(Collectors.joining("\n"));
return new SystemPromptTemplate(qaSystemPromptResource)
.createMessage(Map.of("documents", documentsContent));
}

private Prompt createPrompt(Message systemMessage, UserMessage userMessage) {
return new Prompt(List.of(systemMessage, userMessage));
}

private Flux<String> processChatResponse(ChatResponse chatResponse) {
return Optional.ofNullable(chatResponse)
.filter(response -> response.getResults() != null && !response.getResults().isEmpty())
.map(response -> response.getResults().getFirst())
.map(result -> result.getOutput().getContent())
.filter(content -> content != null)
.map(Flux::just)
.orElseGet(() -> {
logger.warn("Received an empty or null response.");
return Flux.empty();
});
}
}

This file was deleted.

3 changes: 3 additions & 0 deletions src/main/resources/application.properties
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ spring.datasource.url=jdbc:postgresql://localhost:5433/vector_db
spring.datasource.username=postgres
spring.datasource.password=admin

spring.ai.vectorstore.pgvector.remove-existing-vector-store-table=false
spring.ai.vectorstore.pgvector.initialize-schema=true

spring.threads.virtual.enabled=true

spring.ai.openai.api-key={OPEN_AI_API_KEY}
Expand Down

0 comments on commit 2ac9a5a

Please sign in to comment.