Skip to content

Commit

Permalink
Merge branch 'u/MEAI' of https://github.com/microsoft/autogen into u/…
Browse files Browse the repository at this point in the history
…MEAI
  • Loading branch information
LittleLittleCloud committed Nov 1, 2024
2 parents cd6eaf4 + c01868d commit 0d53364
Show file tree
Hide file tree
Showing 14 changed files with 140 additions and 95 deletions.
10 changes: 5 additions & 5 deletions dotnet/src/Microsoft.AutoGen/Abstractions/IAgentContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ public interface IAgentContext
IAgentBase? AgentInstance { get; set; }
DistributedContextPropagator DistributedContextPropagator { get; } // TODO: Remove this. An abstraction should not have a dependency on DistributedContextPropagator.
ILogger Logger { get; } // TODO: Remove this. An abstraction should not have a dependency on ILogger.
ValueTask Store(AgentState value);
ValueTask<AgentState> Read(AgentId agentId);
ValueTask SendResponseAsync(RpcRequest request, RpcResponse response);
ValueTask SendRequestAsync(IAgentBase agent, RpcRequest request);
ValueTask PublishEventAsync(CloudEvent @event);
ValueTask Store(AgentState value, CancellationToken cancellationToken = default);
ValueTask<AgentState> Read(AgentId agentId, CancellationToken cancellationToken = default);
ValueTask SendResponseAsync(RpcRequest request, RpcResponse response, CancellationToken cancellationToken = default);
ValueTask SendRequestAsync(IAgentBase agent, RpcRequest request, CancellationToken cancellationToken = default);
ValueTask PublishEventAsync(CloudEvent @event, CancellationToken cancellationToken = default);
}
10 changes: 5 additions & 5 deletions dotnet/src/Microsoft.AutoGen/Abstractions/IAgentWorkerRuntime.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ namespace Microsoft.AutoGen.Abstractions;

public interface IAgentWorkerRuntime
{
ValueTask PublishEvent(CloudEvent evt);
ValueTask SendRequest(IAgentBase agent, RpcRequest request);
ValueTask SendResponse(RpcResponse response);
ValueTask Store(AgentState value);
ValueTask<AgentState> Read(AgentId agentId);
ValueTask PublishEvent(CloudEvent evt, CancellationToken cancellationToken);
ValueTask SendRequest(IAgentBase agent, RpcRequest request, CancellationToken cancellationToken);
ValueTask SendResponse(RpcResponse response, CancellationToken cancellationToken);
ValueTask Store(AgentState value, CancellationToken cancellationToken);
ValueTask<AgentState> Read(AgentId agentId, CancellationToken cancellationToken);
}
5 changes: 2 additions & 3 deletions dotnet/src/Microsoft.AutoGen/Agents/AgentBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -202,15 +202,14 @@ public async ValueTask PublishEvent(CloudEvent item)
var activity = s_source.StartActivity($"PublishEvent '{item.Type}'", ActivityKind.Client, Activity.Current?.Context ?? default);
activity?.SetTag("peer.service", $"{item.Type}/{item.Source}");

var completion = new TaskCompletionSource<CloudEvent>(TaskCreationOptions.RunContinuationsAsynchronously);
// TODO: fix activity
Context.DistributedContextPropagator.Inject(activity, item.Metadata, static (carrier, key, value) => ((IDictionary<string, string>)carrier!)[key] = value);
await this.InvokeWithActivityAsync(
static async ((AgentBase Agent, CloudEvent Event, TaskCompletionSource<CloudEvent>) state) =>
static async ((AgentBase Agent, CloudEvent Event) state) =>
{
await state.Agent._context.PublishEventAsync(state.Event).ConfigureAwait(false);
},
(this, item, completion),
(this, item),
activity,
item.Type).ConfigureAwait(false);
}
Expand Down
20 changes: 10 additions & 10 deletions dotnet/src/Microsoft.AutoGen/Agents/AgentContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,25 @@ internal sealed class AgentContext(AgentId agentId, IAgentWorkerRuntime runtime,
public ILogger Logger { get; } = logger;
public IAgentBase? AgentInstance { get; set; }
public DistributedContextPropagator DistributedContextPropagator { get; } = distributedContextPropagator;
public async ValueTask SendResponseAsync(RpcRequest request, RpcResponse response)
public async ValueTask SendResponseAsync(RpcRequest request, RpcResponse response, CancellationToken cancellationToken)
{
response.RequestId = request.RequestId;
await _runtime.SendResponse(response);
await _runtime.SendResponse(response, cancellationToken).ConfigureAwait(false);
}
public async ValueTask SendRequestAsync(IAgentBase agent, RpcRequest request)
public async ValueTask SendRequestAsync(IAgentBase agent, RpcRequest request, CancellationToken cancellationToken)
{
await _runtime.SendRequest(agent, request).ConfigureAwait(false);
await _runtime.SendRequest(agent, request, cancellationToken).ConfigureAwait(false);
}
public async ValueTask PublishEventAsync(CloudEvent @event)
public async ValueTask PublishEventAsync(CloudEvent @event, CancellationToken cancellationToken)
{
await _runtime.PublishEvent(@event).ConfigureAwait(false);
await _runtime.PublishEvent(@event, cancellationToken).ConfigureAwait(false);
}
public async ValueTask Store(AgentState value)
public async ValueTask Store(AgentState value, CancellationToken cancellationToken)
{
await _runtime.Store(value).ConfigureAwait(false);
await _runtime.Store(value, cancellationToken).ConfigureAwait(false);
}
public ValueTask<AgentState> Read(AgentId agentId)
public ValueTask<AgentState> Read(AgentId agentId, CancellationToken cancellationToken)
{
return _runtime.Read(agentId);
return _runtime.Read(agentId, cancellationToken);
}
}
64 changes: 47 additions & 17 deletions dotnet/src/Microsoft.AutoGen/Agents/GrpcAgentWorkerRuntime.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public sealed class GrpcAgentWorkerRuntime : IHostedService, IDisposable, IAgent
private readonly ConcurrentDictionary<string, Type> _agentTypes = new();
private readonly ConcurrentDictionary<(string Type, string Key), IAgentBase> _agents = new();
private readonly ConcurrentDictionary<string, (IAgentBase Agent, string OriginalRequestId)> _pendingRequests = new();
private readonly Channel<Message> _outboundMessagesChannel = Channel.CreateBounded<Message>(new BoundedChannelOptions(1024)
private readonly Channel<(Message Message, TaskCompletionSource WriteCompletionSource)> _outboundMessagesChannel = Channel.CreateBounded<(Message, TaskCompletionSource)>(new BoundedChannelOptions(1024)
{
AllowSynchronousContinuations = true,
SingleReader = true,
Expand Down Expand Up @@ -138,40 +138,50 @@ private async Task RunWritePump()
var outboundMessages = _outboundMessagesChannel.Reader;
while (!_shutdownCts.IsCancellationRequested)
{
(Message Message, TaskCompletionSource WriteCompletionSource) item = default;
try
{
await outboundMessages.WaitToReadAsync().ConfigureAwait(false);

// Read the next message if we don't already have an unsent message
// waiting to be sent.
if (!outboundMessages.TryRead(out var message))
if (!outboundMessages.TryRead(out item))
{
break;
}

while (!_shutdownCts.IsCancellationRequested)
{
await channel.RequestStream.WriteAsync(message, _shutdownCts.Token).ConfigureAwait(false);
await channel.RequestStream.WriteAsync(item.Message, _shutdownCts.Token).ConfigureAwait(false);
item.WriteCompletionSource.TrySetResult();
break;
}
}
catch (OperationCanceledException)
{
// Time to shut down.
item.WriteCompletionSource?.TrySetCanceled();
break;
}
catch (Exception ex) when (!_shutdownCts.IsCancellationRequested)
{
item.WriteCompletionSource?.TrySetException(ex);
_logger.LogError(ex, "Error writing to channel.");
channel = RecreateChannel(channel);
continue;
}
catch
{
// Shutdown requested.
item.WriteCompletionSource?.TrySetCanceled();
break;
}
}

while (outboundMessages.TryRead(out var item))
{
item.WriteCompletionSource.TrySetCanceled();
}
}

private IAgentBase GetOrActivateAgent(AgentId agentId)
Expand Down Expand Up @@ -213,33 +223,53 @@ await WriteChannelAsync(new Message
//StateType = state?.Name,
//Events = { events }
}
}).ConfigureAwait(false);
},
_shutdownCts.Token).ConfigureAwait(false);
}
}

public async ValueTask SendResponse(RpcResponse response)
public async ValueTask SendResponse(RpcResponse response, CancellationToken cancellationToken)
{
_logger.LogInformation("Sending response '{Response}'.", response);
await WriteChannelAsync(new Message { Response = response }).ConfigureAwait(false);
await WriteChannelAsync(new Message { Response = response }, cancellationToken).ConfigureAwait(false);
}

public async ValueTask SendRequest(IAgentBase agent, RpcRequest request)
public async ValueTask SendRequest(IAgentBase agent, RpcRequest request, CancellationToken cancellationToken)
{
_logger.LogInformation("[{AgentId}] Sending request '{Request}'.", agent.AgentId, request);
var requestId = Guid.NewGuid().ToString();
_pendingRequests[requestId] = (agent, request.RequestId);
request.RequestId = requestId;
await WriteChannelAsync(new Message { Request = request }).ConfigureAwait(false);
try
{
await WriteChannelAsync(new Message { Request = request }, cancellationToken).ConfigureAwait(false);
}
catch (Exception exception)
{
if (_pendingRequests.TryRemove(requestId, out _))
{
agent.ReceiveMessage(new Message { Response = new RpcResponse { RequestId = request.RequestId, Error = exception.Message } });
}
}
}

public async ValueTask PublishEvent(CloudEvent @event)
public async ValueTask PublishEvent(CloudEvent @event, CancellationToken cancellationToken)
{
await WriteChannelAsync(new Message { CloudEvent = @event }).ConfigureAwait(false);
try
{
await WriteChannelAsync(new Message { CloudEvent = @event }, cancellationToken).ConfigureAwait(false);
}
catch (Exception exception)
{
_logger.LogWarning(exception, "Failed to publish event '{Event}'.", @event);
}
}

private async Task WriteChannelAsync(Message message)
private async Task WriteChannelAsync(Message message, CancellationToken cancellationToken)
{
await _outboundMessagesChannel.Writer.WriteAsync(message).ConfigureAwait(false);
var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
await _outboundMessagesChannel.Writer.WriteAsync((message, tcs), cancellationToken).ConfigureAwait(false);
await tcs.Task.WaitAsync(cancellationToken);
}

private AsyncDuplexStreamingCall<Message, Message> GetChannel()
Expand Down Expand Up @@ -269,7 +299,7 @@ private AsyncDuplexStreamingCall<Message, Message> RecreateChannel(AsyncDuplexSt
if (_channel is null || _channel == channel)
{
_channel?.Dispose();
_channel = _client.OpenChannel();
_channel = _client.OpenChannel(cancellationToken: _shutdownCts.Token);
}
}
}
Expand Down Expand Up @@ -334,19 +364,19 @@ public async Task StopAsync(CancellationToken cancellationToken)
_channel?.Dispose();
}
}
public ValueTask Store(AgentState value)
public ValueTask Store(AgentState value, CancellationToken cancellationToken)
{
var agentId = value.AgentId ?? throw new InvalidOperationException("AgentId is required when saving AgentState.");
var response = _client.SaveState(value);
var response = _client.SaveState(value, cancellationToken: cancellationToken);
if (!response.Success)
{
throw new InvalidOperationException($"Error saving AgentState for AgentId {agentId}.");
}
return ValueTask.CompletedTask;
}
public async ValueTask<AgentState> Read(AgentId agentId)
public async ValueTask<AgentState> Read(AgentId agentId, CancellationToken cancellationToken)
{
var response = await _client.GetStateAsync(agentId);
var response = await _client.GetStateAsync(agentId, cancellationToken: cancellationToken);
// if (response.Success && response.AgentState.AgentId is not null) - why is success always false?
if (response.AgentState.AgentId is not null)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
HandoffMessage,
InnerMessage,
ResetMessage,
StopMessage,
TextMessage,
ToolCallMessage,
ToolCallResultMessage,
Expand Down Expand Up @@ -232,8 +231,8 @@ def __init__(
def produced_message_types(self) -> List[type[ChatMessage]]:
"""The types of messages that the assistant agent produces."""
if self._handoffs:
return [TextMessage, HandoffMessage, StopMessage]
return [TextMessage, StopMessage]
return [TextMessage, HandoffMessage]
return [TextMessage]

async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
async for message in self.on_messages_stream(messages, cancellation_token):
Expand Down Expand Up @@ -267,8 +266,8 @@ async def on_messages_stream(
while isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content):
event_logger.debug(ToolCallEvent(tool_calls=result.content, source=self.name))
# Add the tool call message to the output.
inner_messages.append(ToolCallMessage(content=result.content, source=self.name))
yield ToolCallMessage(content=result.content, source=self.name)
inner_messages.append(ToolCallMessage(content=result.content, source=self.name, model_usage=result.usage))
yield ToolCallMessage(content=result.content, source=self.name, model_usage=result.usage)

# Execute the tool calls.
results = await asyncio.gather(
Expand Down Expand Up @@ -303,16 +302,10 @@ async def on_messages_stream(
self._model_context.append(AssistantMessage(content=result.content, source=self.name))

assert isinstance(result.content, str)
# Detect stop request.
request_stop = "terminate" in result.content.strip().lower()
if request_stop:
yield Response(
chat_message=StopMessage(content=result.content, source=self.name), inner_messages=inner_messages
)
else:
yield Response(
chat_message=TextMessage(content=result.content, source=self.name), inner_messages=inner_messages
)
yield Response(
chat_message=TextMessage(content=result.content, source=self.name, model_usage=result.usage),
inner_messages=inner_messages,
)

async def _execute_tool_call(
self, tool_call: FunctionCall, cancellation_token: CancellationToken
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import List

from autogen_core.components import FunctionCall, Image
from autogen_core.components.models import FunctionExecutionResult
from autogen_core.components.models import FunctionExecutionResult, RequestUsage
from pydantic import BaseModel


Expand All @@ -11,6 +11,9 @@ class BaseMessage(BaseModel):
source: str
"""The name of the agent that sent this message."""

model_usage: RequestUsage | None = None
"""The model client usage incurred when producing this message."""


class TextMessage(BaseMessage):
"""A text message."""
Expand Down
24 changes: 20 additions & 4 deletions python/packages/autogen-agentchat/tests/test_assistant_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
ChatCompletion(
id="id2",
Expand All @@ -88,7 +88,7 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
ChatCompletion(
id="id2",
Expand All @@ -100,7 +100,7 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
]
mock = _MockChatCompletion(chat_completions)
Expand All @@ -113,9 +113,17 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
result = await tool_use_agent.run("task")
assert len(result.messages) == 4
assert isinstance(result.messages[0], TextMessage)
assert result.messages[0].model_usage is None
assert isinstance(result.messages[1], ToolCallMessage)
assert result.messages[1].model_usage is not None
assert result.messages[1].model_usage.completion_tokens == 5
assert result.messages[1].model_usage.prompt_tokens == 10
assert isinstance(result.messages[2], ToolCallResultMessage)
assert result.messages[2].model_usage is None
assert isinstance(result.messages[3], TextMessage)
assert result.messages[3].model_usage is not None
assert result.messages[3].model_usage.completion_tokens == 5
assert result.messages[3].model_usage.prompt_tokens == 10

# Test streaming.
mock._curr_index = 0 # pyright: ignore
Expand Down Expand Up @@ -158,7 +166,7 @@ async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
usage=CompletionUsage(prompt_tokens=42, completion_tokens=43, total_tokens=85),
),
]
mock = _MockChatCompletion(chat_completions)
Expand All @@ -173,9 +181,17 @@ async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
result = await tool_use_agent.run("task")
assert len(result.messages) == 4
assert isinstance(result.messages[0], TextMessage)
assert result.messages[0].model_usage is None
assert isinstance(result.messages[1], ToolCallMessage)
assert result.messages[1].model_usage is not None
assert result.messages[1].model_usage.completion_tokens == 43
assert result.messages[1].model_usage.prompt_tokens == 42
assert isinstance(result.messages[2], ToolCallResultMessage)
assert result.messages[2].model_usage is None
assert isinstance(result.messages[3], HandoffMessage)
assert result.messages[3].content == handoff.message
assert result.messages[3].target == handoff.target
assert result.messages[3].model_usage is None

# Test streaming.
mock._curr_index = 0 # pyright: ignore
Expand Down
Loading

0 comments on commit 0d53364

Please sign in to comment.