diff --git a/dotnet/src/Microsoft.AutoGen/Abstractions/IAgentContext.cs b/dotnet/src/Microsoft.AutoGen/Abstractions/IAgentContext.cs index d93b6246765..ab5972730fb 100644 --- a/dotnet/src/Microsoft.AutoGen/Abstractions/IAgentContext.cs +++ b/dotnet/src/Microsoft.AutoGen/Abstractions/IAgentContext.cs @@ -12,9 +12,9 @@ public interface IAgentContext IAgentBase? AgentInstance { get; set; } DistributedContextPropagator DistributedContextPropagator { get; } // TODO: Remove this. An abstraction should not have a dependency on DistributedContextPropagator. ILogger Logger { get; } // TODO: Remove this. An abstraction should not have a dependency on ILogger. - ValueTask Store(AgentState value); - ValueTask Read(AgentId agentId); - ValueTask SendResponseAsync(RpcRequest request, RpcResponse response); - ValueTask SendRequestAsync(IAgentBase agent, RpcRequest request); - ValueTask PublishEventAsync(CloudEvent @event); + ValueTask Store(AgentState value, CancellationToken cancellationToken = default); + ValueTask Read(AgentId agentId, CancellationToken cancellationToken = default); + ValueTask SendResponseAsync(RpcRequest request, RpcResponse response, CancellationToken cancellationToken = default); + ValueTask SendRequestAsync(IAgentBase agent, RpcRequest request, CancellationToken cancellationToken = default); + ValueTask PublishEventAsync(CloudEvent @event, CancellationToken cancellationToken = default); } diff --git a/dotnet/src/Microsoft.AutoGen/Abstractions/IAgentWorkerRuntime.cs b/dotnet/src/Microsoft.AutoGen/Abstractions/IAgentWorkerRuntime.cs index 1a255e13234..c03259f722f 100644 --- a/dotnet/src/Microsoft.AutoGen/Abstractions/IAgentWorkerRuntime.cs +++ b/dotnet/src/Microsoft.AutoGen/Abstractions/IAgentWorkerRuntime.cs @@ -5,9 +5,9 @@ namespace Microsoft.AutoGen.Abstractions; public interface IAgentWorkerRuntime { - ValueTask PublishEvent(CloudEvent evt); - ValueTask SendRequest(IAgentBase agent, RpcRequest request); - ValueTask SendResponse(RpcResponse response); - ValueTask Store(AgentState value); - ValueTask Read(AgentId agentId); + ValueTask PublishEvent(CloudEvent evt, CancellationToken cancellationToken); + ValueTask SendRequest(IAgentBase agent, RpcRequest request, CancellationToken cancellationToken); + ValueTask SendResponse(RpcResponse response, CancellationToken cancellationToken); + ValueTask Store(AgentState value, CancellationToken cancellationToken); + ValueTask Read(AgentId agentId, CancellationToken cancellationToken); } diff --git a/dotnet/src/Microsoft.AutoGen/Agents/AgentBase.cs b/dotnet/src/Microsoft.AutoGen/Agents/AgentBase.cs index af06c84e9ba..baa7ee201ed 100644 --- a/dotnet/src/Microsoft.AutoGen/Agents/AgentBase.cs +++ b/dotnet/src/Microsoft.AutoGen/Agents/AgentBase.cs @@ -202,15 +202,14 @@ public async ValueTask PublishEvent(CloudEvent item) var activity = s_source.StartActivity($"PublishEvent '{item.Type}'", ActivityKind.Client, Activity.Current?.Context ?? default); activity?.SetTag("peer.service", $"{item.Type}/{item.Source}"); - var completion = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); // TODO: fix activity Context.DistributedContextPropagator.Inject(activity, item.Metadata, static (carrier, key, value) => ((IDictionary)carrier!)[key] = value); await this.InvokeWithActivityAsync( - static async ((AgentBase Agent, CloudEvent Event, TaskCompletionSource) state) => + static async ((AgentBase Agent, CloudEvent Event) state) => { await state.Agent._context.PublishEventAsync(state.Event).ConfigureAwait(false); }, - (this, item, completion), + (this, item), activity, item.Type).ConfigureAwait(false); } diff --git a/dotnet/src/Microsoft.AutoGen/Agents/AgentContext.cs b/dotnet/src/Microsoft.AutoGen/Agents/AgentContext.cs index 325bc33a11d..7de1e6565d3 100644 --- a/dotnet/src/Microsoft.AutoGen/Agents/AgentContext.cs +++ b/dotnet/src/Microsoft.AutoGen/Agents/AgentContext.cs @@ -15,25 +15,25 @@ internal sealed class AgentContext(AgentId agentId, IAgentWorkerRuntime runtime, public ILogger Logger { get; } = logger; public IAgentBase? AgentInstance { get; set; } public DistributedContextPropagator DistributedContextPropagator { get; } = distributedContextPropagator; - public async ValueTask SendResponseAsync(RpcRequest request, RpcResponse response) + public async ValueTask SendResponseAsync(RpcRequest request, RpcResponse response, CancellationToken cancellationToken) { response.RequestId = request.RequestId; - await _runtime.SendResponse(response); + await _runtime.SendResponse(response, cancellationToken).ConfigureAwait(false); } - public async ValueTask SendRequestAsync(IAgentBase agent, RpcRequest request) + public async ValueTask SendRequestAsync(IAgentBase agent, RpcRequest request, CancellationToken cancellationToken) { - await _runtime.SendRequest(agent, request).ConfigureAwait(false); + await _runtime.SendRequest(agent, request, cancellationToken).ConfigureAwait(false); } - public async ValueTask PublishEventAsync(CloudEvent @event) + public async ValueTask PublishEventAsync(CloudEvent @event, CancellationToken cancellationToken) { - await _runtime.PublishEvent(@event).ConfigureAwait(false); + await _runtime.PublishEvent(@event, cancellationToken).ConfigureAwait(false); } - public async ValueTask Store(AgentState value) + public async ValueTask Store(AgentState value, CancellationToken cancellationToken) { - await _runtime.Store(value).ConfigureAwait(false); + await _runtime.Store(value, cancellationToken).ConfigureAwait(false); } - public ValueTask Read(AgentId agentId) + public ValueTask Read(AgentId agentId, CancellationToken cancellationToken) { - return _runtime.Read(agentId); + return _runtime.Read(agentId, cancellationToken); } } diff --git a/dotnet/src/Microsoft.AutoGen/Agents/GrpcAgentWorkerRuntime.cs b/dotnet/src/Microsoft.AutoGen/Agents/GrpcAgentWorkerRuntime.cs index 193f9dd2b63..b0550c1fb71 100644 --- a/dotnet/src/Microsoft.AutoGen/Agents/GrpcAgentWorkerRuntime.cs +++ b/dotnet/src/Microsoft.AutoGen/Agents/GrpcAgentWorkerRuntime.cs @@ -228,13 +228,13 @@ await WriteChannelAsync(new Message } } - public async ValueTask SendResponse(RpcResponse response) + public async ValueTask SendResponse(RpcResponse response, CancellationToken cancellationToken) { _logger.LogInformation("Sending response '{Response}'.", response); - await WriteChannelAsync(new Message { Response = response }).ConfigureAwait(false); + await WriteChannelAsync(new Message { Response = response }, cancellationToken).ConfigureAwait(false); } - public async ValueTask SendRequest(IAgentBase agent, RpcRequest request) + public async ValueTask SendRequest(IAgentBase agent, RpcRequest request, CancellationToken cancellationToken) { _logger.LogInformation("[{AgentId}] Sending request '{Request}'.", agent.AgentId, request); var requestId = Guid.NewGuid().ToString(); @@ -242,7 +242,7 @@ public async ValueTask SendRequest(IAgentBase agent, RpcRequest request) request.RequestId = requestId; try { - await WriteChannelAsync(new Message { Request = request }).ConfigureAwait(false); + await WriteChannelAsync(new Message { Request = request }, cancellationToken).ConfigureAwait(false); } catch (Exception exception) { @@ -253,11 +253,11 @@ public async ValueTask SendRequest(IAgentBase agent, RpcRequest request) } } - public async ValueTask PublishEvent(CloudEvent @event) + public async ValueTask PublishEvent(CloudEvent @event, CancellationToken cancellationToken) { try { - await WriteChannelAsync(new Message { CloudEvent = @event }).ConfigureAwait(false); + await WriteChannelAsync(new Message { CloudEvent = @event }, cancellationToken).ConfigureAwait(false); } catch (Exception exception) { @@ -265,7 +265,7 @@ public async ValueTask PublishEvent(CloudEvent @event) } } - private async Task WriteChannelAsync(Message message, CancellationToken cancellationToken = default) + private async Task WriteChannelAsync(Message message, CancellationToken cancellationToken) { var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); await _outboundMessagesChannel.Writer.WriteAsync((message, tcs), cancellationToken).ConfigureAwait(false); @@ -364,19 +364,19 @@ public async Task StopAsync(CancellationToken cancellationToken) _channel?.Dispose(); } } - public ValueTask Store(AgentState value) + public ValueTask Store(AgentState value, CancellationToken cancellationToken) { var agentId = value.AgentId ?? throw new InvalidOperationException("AgentId is required when saving AgentState."); - var response = _client.SaveState(value); + var response = _client.SaveState(value, cancellationToken: cancellationToken); if (!response.Success) { throw new InvalidOperationException($"Error saving AgentState for AgentId {agentId}."); } return ValueTask.CompletedTask; } - public async ValueTask Read(AgentId agentId) + public async ValueTask Read(AgentId agentId, CancellationToken cancellationToken) { - var response = await _client.GetStateAsync(agentId); + var response = await _client.GetStateAsync(agentId, cancellationToken: cancellationToken); // if (response.Success && response.AgentState.AgentId is not null) - why is success always false? if (response.AgentState.AgentId is not null) { diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py index 37d4646c685..8ef47806ac4 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py @@ -266,8 +266,8 @@ async def on_messages_stream( while isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content): event_logger.debug(ToolCallEvent(tool_calls=result.content, source=self.name)) # Add the tool call message to the output. - inner_messages.append(ToolCallMessage(content=result.content, source=self.name)) - yield ToolCallMessage(content=result.content, source=self.name) + inner_messages.append(ToolCallMessage(content=result.content, source=self.name, model_usage=result.usage)) + yield ToolCallMessage(content=result.content, source=self.name, model_usage=result.usage) # Execute the tool calls. results = await asyncio.gather( @@ -303,7 +303,8 @@ async def on_messages_stream( assert isinstance(result.content, str) yield Response( - chat_message=TextMessage(content=result.content, source=self.name), inner_messages=inner_messages + chat_message=TextMessage(content=result.content, source=self.name, model_usage=result.usage), + inner_messages=inner_messages, ) async def _execute_tool_call( diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py index 51dbcca333d..c8037671e13 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py @@ -1,7 +1,7 @@ from typing import List from autogen_core.components import FunctionCall, Image -from autogen_core.components.models import FunctionExecutionResult +from autogen_core.components.models import FunctionExecutionResult, RequestUsage from pydantic import BaseModel @@ -11,6 +11,9 @@ class BaseMessage(BaseModel): source: str """The name of the agent that sent this message.""" + model_usage: RequestUsage | None = None + """The model client usage incurred when producing this message.""" + class TextMessage(BaseMessage): """A text message.""" diff --git a/python/packages/autogen-agentchat/tests/test_assistant_agent.py b/python/packages/autogen-agentchat/tests/test_assistant_agent.py index 4589f86860d..20556ad783c 100644 --- a/python/packages/autogen-agentchat/tests/test_assistant_agent.py +++ b/python/packages/autogen-agentchat/tests/test_assistant_agent.py @@ -78,7 +78,7 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None: created=0, model=model, object="chat.completion", - usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0), + usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0), ), ChatCompletion( id="id2", @@ -88,7 +88,7 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None: created=0, model=model, object="chat.completion", - usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0), + usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0), ), ChatCompletion( id="id2", @@ -100,7 +100,7 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None: created=0, model=model, object="chat.completion", - usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0), + usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0), ), ] mock = _MockChatCompletion(chat_completions) @@ -113,9 +113,17 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None: result = await tool_use_agent.run("task") assert len(result.messages) == 4 assert isinstance(result.messages[0], TextMessage) + assert result.messages[0].model_usage is None assert isinstance(result.messages[1], ToolCallMessage) + assert result.messages[1].model_usage is not None + assert result.messages[1].model_usage.completion_tokens == 5 + assert result.messages[1].model_usage.prompt_tokens == 10 assert isinstance(result.messages[2], ToolCallResultMessage) + assert result.messages[2].model_usage is None assert isinstance(result.messages[3], TextMessage) + assert result.messages[3].model_usage is not None + assert result.messages[3].model_usage.completion_tokens == 5 + assert result.messages[3].model_usage.prompt_tokens == 10 # Test streaming. mock._curr_index = 0 # pyright: ignore @@ -158,7 +166,7 @@ async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None: created=0, model=model, object="chat.completion", - usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0), + usage=CompletionUsage(prompt_tokens=42, completion_tokens=43, total_tokens=85), ), ] mock = _MockChatCompletion(chat_completions) @@ -173,9 +181,17 @@ async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None: result = await tool_use_agent.run("task") assert len(result.messages) == 4 assert isinstance(result.messages[0], TextMessage) + assert result.messages[0].model_usage is None assert isinstance(result.messages[1], ToolCallMessage) + assert result.messages[1].model_usage is not None + assert result.messages[1].model_usage.completion_tokens == 43 + assert result.messages[1].model_usage.prompt_tokens == 42 assert isinstance(result.messages[2], ToolCallResultMessage) + assert result.messages[2].model_usage is None assert isinstance(result.messages[3], HandoffMessage) + assert result.messages[3].content == handoff.message + assert result.messages[3].target == handoff.target + assert result.messages[3].model_usage is None # Test streaming. mock._curr_index = 0 # pyright: ignore