Skip to content

Commit

Permalink
feat(ollama): Add streaming support for function calls and improve Ol…
Browse files Browse the repository at this point in the history
…lamaApi

- Implement streaming tool call support in OllamaApi and OllamaChatModel
- Add OllamaApiHelper to manage merging of streaming chat response chunks
- Remove @Disabled annotations for streaming function call tests
- Update documentation to reflect new streaming function call capabilities
- Add a new default constructor for ChatResponse
- Update Ollama chat documentation to clarify streaming support requirements
- Deprecated withContent(), withImages(), and withToolCalls() methods
- Replaced with content(), images(), and toolCalls() methods

Add token and duration aggregation for Ollama chat responses

- Modify OllamaChatModel to support accumulating tokens and durations across multiple responses
- Update ChatResponse metadata generation to aggregate usage and duration metrics
- Add tests to verify metadata aggregation behavior

Refactor Ollama duration fields and tests

- Replace Duration fields in OllamaApi.ChatResponse with Long to represent durations in nanoseconds, ensuring precision and compatibility.
- Update methods to convert Long nanoseconds to Duration objects (getTotalDuration, getLoadDuration, getEvalDuration, getPromptEvalDuration).
- Adjust merge logic in OllamaApiHelper to sum Long values for duration fields.
- Modify test cases in OllamaChatModelTests to align with Long duration representation and Duration.ofNanos conversions.
- Add new test class OllamaDurationFieldsTests to validate JSON deserialization and Duration conversion for duration fields.

Resolves #1847
Related to #1800
Resolves #1796
Related to #1307
  • Loading branch information
tzolov committed Nov 29, 2024
1 parent cdffc72 commit e335beb
Show file tree
Hide file tree
Showing 12 changed files with 465 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package org.springframework.ai.ollama;

import java.time.Duration;
import java.util.Base64;
import java.util.HashSet;
import java.util.List;
Expand All @@ -33,6 +34,7 @@
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.DefaultUsage;
import org.springframework.ai.chat.model.AbstractToolCallSupport;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
Expand Down Expand Up @@ -110,24 +112,62 @@ public static Builder builder() {
return new Builder();
}

public static ChatResponseMetadata from(OllamaApi.ChatResponse response) {
static ChatResponseMetadata from(OllamaApi.ChatResponse response, ChatResponse perviousChatResponse) {
Assert.notNull(response, "OllamaApi.ChatResponse must not be null");

OllamaChatUsage newUsage = OllamaChatUsage.from(response);
Long promptTokens = newUsage.getPromptTokens();
Long generationTokens = newUsage.getGenerationTokens();
Long totalTokens = newUsage.getTotalTokens();

Duration evalDuration = response.getEvalDuration();
Duration promptEvalDuration = response.getPromptEvalDuration();
Duration loadDuration = response.getLoadDuration();
Duration totalDuration = response.getTotalDuration();

if (perviousChatResponse != null && perviousChatResponse.getMetadata() != null) {
if (perviousChatResponse.getMetadata().get("eval-duration") != null) {
evalDuration = evalDuration.plus(perviousChatResponse.getMetadata().get("eval-duration"));
}
if (perviousChatResponse.getMetadata().get("prompt-eval-duration") != null) {
promptEvalDuration = promptEvalDuration
.plus(perviousChatResponse.getMetadata().get("prompt-eval-duration"));
}
if (perviousChatResponse.getMetadata().get("load-duration") != null) {
loadDuration = loadDuration.plus(perviousChatResponse.getMetadata().get("load-duration"));
}
if (perviousChatResponse.getMetadata().get("total-duration") != null) {
totalDuration = totalDuration.plus(perviousChatResponse.getMetadata().get("total-duration"));
}
if (perviousChatResponse.getMetadata().getUsage() != null) {
promptTokens += perviousChatResponse.getMetadata().getUsage().getPromptTokens();
generationTokens += perviousChatResponse.getMetadata().getUsage().getGenerationTokens();
totalTokens += perviousChatResponse.getMetadata().getUsage().getTotalTokens();
}
}

DefaultUsage aggregatedUsage = new DefaultUsage(promptTokens, generationTokens, totalTokens);

return ChatResponseMetadata.builder()
.withUsage(OllamaChatUsage.from(response))
.withUsage(aggregatedUsage)
.withModel(response.model())
.withKeyValue("created-at", response.createdAt())
.withKeyValue("eval-duration", response.evalDuration())
.withKeyValue("eval-count", response.evalCount())
.withKeyValue("load-duration", response.loadDuration())
.withKeyValue("prompt-eval-duration", response.promptEvalDuration())
.withKeyValue("prompt-eval-count", response.promptEvalCount())
.withKeyValue("total-duration", response.totalDuration())
.withKeyValue("eval-duration", evalDuration)
.withKeyValue("eval-count", aggregatedUsage.getGenerationTokens().intValue())
.withKeyValue("load-duration", loadDuration)
.withKeyValue("prompt-eval-duration", promptEvalDuration)
.withKeyValue("prompt-eval-count", aggregatedUsage.getPromptTokens().intValue())
.withKeyValue("total-duration", totalDuration)
.withKeyValue("done", response.done())
.build();
}

// Public entry point: delegates to internalCall with no previous response,
// so usage/duration aggregation starts fresh for this prompt.
@Override
public ChatResponse call(Prompt prompt) {
	return this.internalCall(prompt, null);
}

private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) {

OllamaApi.ChatRequest request = ollamaChatRequest(prompt, false);

Expand Down Expand Up @@ -162,7 +202,8 @@ public ChatResponse call(Prompt prompt) {
}

var generator = new Generation(assistantMessage, generationMetadata);
ChatResponse chatResponse = new ChatResponse(List.of(generator), from(ollamaResponse));
ChatResponse chatResponse = new ChatResponse(List.of(generator),
from(ollamaResponse, previousChatResponse));

observationContext.setResponse(chatResponse);

Expand All @@ -175,14 +216,18 @@ && isToolCall(response, Set.of("stop"))) {
var toolCallConversation = handleToolCalls(prompt, response);
// Recursively call the call method with the tool call message
// conversation that contains the call responses.
return this.call(new Prompt(toolCallConversation, prompt.getOptions()));
return this.internalCall(new Prompt(toolCallConversation, prompt.getOptions()), response);
}

return response;
}

// Public streaming entry point: delegates to internalStream with no previous
// response, so usage/duration aggregation starts fresh for this prompt.
@Override
public Flux<ChatResponse> stream(Prompt prompt) {
	return this.internalStream(prompt, null);
}

private Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse) {
return Flux.deferContextual(contextView -> {
OllamaApi.ChatRequest request = ollamaChatRequest(prompt, true);

Expand Down Expand Up @@ -223,7 +268,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
}

var generator = new Generation(assistantMessage, generationMetadata);
return new ChatResponse(List.of(generator), from(chunk));
return new ChatResponse(List.of(generator), from(chunk, previousChatResponse));
});

// @formatter:off
Expand All @@ -232,7 +277,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
var toolCallConversation = handleToolCalls(prompt, response);
// Recursively call the stream method with the tool call message
// conversation that contains the call responses.
return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
return this.internalStream(new Prompt(toolCallConversation, prompt.getOptions()), response);
}
else {
return Flux.just(response);
Expand All @@ -256,15 +301,15 @@ OllamaApi.ChatRequest ollamaChatRequest(Prompt prompt, boolean stream) {

List<OllamaApi.Message> ollamaMessages = prompt.getInstructions().stream().map(message -> {
if (message instanceof UserMessage userMessage) {
var messageBuilder = OllamaApi.Message.builder(Role.USER).withContent(message.getContent());
var messageBuilder = OllamaApi.Message.builder(Role.USER).content(message.getContent());
if (!CollectionUtils.isEmpty(userMessage.getMedia())) {
messageBuilder.withImages(
messageBuilder.images(
userMessage.getMedia().stream().map(media -> this.fromMediaData(media.getData())).toList());
}
return List.of(messageBuilder.build());
}
else if (message instanceof SystemMessage systemMessage) {
return List.of(OllamaApi.Message.builder(Role.SYSTEM).withContent(systemMessage.getContent()).build());
return List.of(OllamaApi.Message.builder(Role.SYSTEM).content(systemMessage.getContent()).build());
}
else if (message instanceof AssistantMessage assistantMessage) {
List<ToolCall> toolCalls = null;
Expand All @@ -276,14 +321,14 @@ else if (message instanceof AssistantMessage assistantMessage) {
}).toList();
}
return List.of(OllamaApi.Message.builder(Role.ASSISTANT)
.withContent(assistantMessage.getContent())
.withToolCalls(toolCalls)
.content(assistantMessage.getContent())
.toolCalls(toolCalls)
.build());
}
else if (message instanceof ToolResponseMessage toolMessage) {
return toolMessage.getResponses()
.stream()
.map(tr -> OllamaApi.Message.builder(Role.TOOL).withContent(tr.responseData()).build())
.map(tr -> OllamaApi.Message.builder(Role.TOOL).content(tr.responseData()).build())
.toList();
}
throw new IllegalArgumentException("Unsupported message type: " + message.getMessageType());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,14 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;

import com.fasterxml.jackson.annotation.JsonFormat;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonFormat.Feature;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import reactor.core.publisher.Flux;
Expand Down Expand Up @@ -133,11 +136,40 @@ public Flux<ChatResponse> streamingChat(ChatRequest chatRequest) {
Assert.notNull(chatRequest, REQUEST_BODY_NULL_ERROR);
Assert.isTrue(chatRequest.stream(), "Request must set the stream property to true.");

AtomicBoolean isInsideTool = new AtomicBoolean(false);

return this.webClient.post()
.uri("/api/chat")
.body(Mono.just(chatRequest), ChatRequest.class)
.retrieve()
.bodyToFlux(ChatResponse.class)
.map(chunk -> {
if (OllamaApiHelper.isStreamingToolCall(chunk)) {
isInsideTool.set(true);
}
return chunk;
})
// Group all chunks belonging to the same function call.
// Flux<ChatResponse> -> Flux<Flux<ChatResponse>>
.windowUntil(chunk -> {
if (isInsideTool.get() && OllamaApiHelper.isStreamingDone(chunk)) {
isInsideTool.set(false);
return true;
}
return !isInsideTool.get();
})
// Merging the window chunks into a single chunk.
// Reduce the inner Flux<ChatResponse> window into a single
// Mono<ChatResponse>,
// Flux<Flux<ChatResponse>> -> Flux<Mono<ChatResponse>>
.concatMapIterable(window -> {
Mono<ChatResponse> monoChunk = window.reduce(
new ChatResponse(),
(previous, current) -> OllamaApiHelper.merge(previous, current));
return List.of(monoChunk);
})
// Flux<Mono<ChatResponse>> -> Flux<ChatResponse>
.flatMap(mono -> mono)
.handle((data, sink) -> {
if (logger.isTraceEnabled()) {
logger.trace(data);
Expand Down Expand Up @@ -332,25 +364,51 @@ public Builder(Role role) {
this.role = role;
}

/**
 * Sets the message content.
 * @deprecated Use {@link #content(String)} instead.
 */
@Deprecated
public Builder withContent(String content) {
	// Delegate to the replacement method so both code paths stay in sync.
	return content(content);
}

/**
 * Sets the textual content of the message.
 * @param content the message content
 * @return this builder
 */
public Builder content(String content) {
	this.content = content;
	return this;
}

/**
 * Sets the base64-encoded images attached to the message.
 * @deprecated Use {@link #images(List)} instead.
 */
@Deprecated
public Builder withImages(List<String> images) {
	// Delegate to the replacement method so both code paths stay in sync.
	return images(images);
}

/**
 * Sets the images attached to the message.
 * @param images the list of image payloads
 * @return this builder
 */
public Builder images(List<String> images) {
	this.images = images;
	return this;
}

/**
 * Sets the tool calls carried by the message.
 * @deprecated Use {@link #toolCalls(List)} instead.
 */
@Deprecated
public Builder withToolCalls(List<ToolCall> toolCalls) {
	// Delegate to the replacement method so both code paths stay in sync.
	return toolCalls(toolCalls);
}

/**
 * Sets the tool calls carried by the message.
 * @param toolCalls the list of tool calls
 * @return this builder
 */
public Builder toolCalls(List<ToolCall> toolCalls) {
	this.toolCalls = toolCalls;
	return this;
}

// Assembles the immutable Message from the accumulated builder state.
public Message build() {
	return new Message(this.role, this.content, this.images, this.toolCalls);
}

}
}

Expand Down Expand Up @@ -538,13 +596,36 @@ public record ChatResponse(
@JsonProperty("message") Message message,
@JsonProperty("done_reason") String doneReason,
@JsonProperty("done") Boolean done,
@JsonProperty("total_duration") Duration totalDuration,
@JsonProperty("load_duration") Duration loadDuration,
@JsonProperty("total_duration") Long totalDuration,
@JsonProperty("load_duration") Long loadDuration,
@JsonProperty("prompt_eval_count") Integer promptEvalCount,
@JsonProperty("prompt_eval_duration") Duration promptEvalDuration,
@JsonProperty("prompt_eval_duration") Long promptEvalDuration,
@JsonProperty("eval_count") Integer evalCount,
@JsonProperty("eval_duration") Duration evalDuration
@JsonProperty("eval_duration") Long evalDuration
) {
// All-null response used as the identity/seed value when reducing a window of
// streamed response chunks into a single merged chunk (see the
// window.reduce(new ChatResponse(), OllamaApiHelper::merge) call in streamingChat).
ChatResponse() {
	this(null, null, null, null, null, null, null, null, null, null, null);
}

/**
 * @return the total request duration, or {@code null} when the field is absent.
 */
public Duration getTotalDuration() {
	Long nanos = this.totalDuration();
	return (nanos == null) ? null : Duration.ofNanos(nanos);
}

/**
 * @return the model load duration, or {@code null} when the field is absent.
 */
public Duration getLoadDuration() {
	Long nanos = this.loadDuration();
	return (nanos == null) ? null : Duration.ofNanos(nanos);
}

/**
 * @return the prompt evaluation duration, or {@code null} when the field is absent.
 */
public Duration getPromptEvalDuration() {
	Long nanos = this.promptEvalDuration();
	return (nanos == null) ? null : Duration.ofNanos(nanos);
}

/**
 * @return the generation (eval) duration, or {@code null} when the field is absent.
 */
public Duration getEvalDuration() {
	// Mirrors the other duration accessors: nanosecond Long -> Duration.
	return (this.evalDuration() != null) ? Duration.ofNanos(this.evalDuration()) : null;
}
}

/**
Expand Down
Loading

0 comments on commit e335beb

Please sign in to comment.