Skip to content

Commit

Permalink
add MEAI tool support
Browse files Browse the repository at this point in the history
  • Loading branch information
LittleLittleCloud committed Nov 1, 2024
1 parent 173acc6 commit 77b737f
Show file tree
Hide file tree
Showing 13 changed files with 356 additions and 17 deletions.
9 changes: 8 additions & 1 deletion dotnet/AutoGen.sln
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,13 @@ EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution Items", "{243E768F-EA7D-4AF1-B625-0398440BB1AB}"
ProjectSection(SolutionItems) = preProject
.editorconfig = .editorconfig
.gitattributes = .gitattributes
.gitignore = .gitignore
Directory.Build.props = Directory.Build.props
Directory.Build.targets = Directory.Build.targets
Directory.Packages.props = Directory.Packages.props
global.json = global.json
NuGet.config = NuGet.config
spelling.dic = spelling.dic
EndProjectSection
EndProject
Expand Down Expand Up @@ -123,7 +130,7 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "HelloAgent", "samples\Hello
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AIModelClientHostingExtensions", "src\Microsoft.AutoGen\Extensions\AIModelClientHostingExtensions\AIModelClientHostingExtensions.csproj", "{97550E87-48C6-4EBF-85E1-413ABAE9DBFD}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AutoGen.Agents.Tests", "Microsoft.AutoGen.Agents.Tests\Microsoft.AutoGen.Agents.Tests.csproj", "{CF4C92BD-28AE-4B8F-B173-601004AEC9BF}"
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.AutoGen.Agents.Tests", "Microsoft.AutoGen.Agents.Tests\Microsoft.AutoGen.Agents.Tests.csproj", "{CF4C92BD-28AE-4B8F-B173-601004AEC9BF}"
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "sample", "sample", "{686480D7-8FEC-4ED3-9C5D-CEBE1057A7ED}"
EndProject
Expand Down
4 changes: 0 additions & 4 deletions dotnet/Directory.Build.props
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,6 @@
<NoWarn>$(NoWarn);CA1829</NoWarn>
</PropertyGroup>

<ItemGroup>
<PackageReference Include="System.Text.Json" />
</ItemGroup>

<ItemGroup Condition="'$(IsTestProject)' == 'true'">
<PackageReference Include="ApprovalTests" />
<PackageReference Include="FluentAssertions" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,7 @@
<ProjectReference Include="..\..\src\AutoGen\AutoGen.csproj" />
<PackageReference Include="FluentAssertions" />
<PackageReference Include="Microsoft.SemanticKernel.Plugins.Web" />
<PackageReference Include="Microsoft.Extensions.AI" />
<PackageReference Include="System.Text.Json" VersionOverride="9.0.0-rc.2.24473.5" />
</ItemGroup>
</Project>
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using AutoGen.OpenAI;
using AutoGen.OpenAI.Extension;
using FluentAssertions;
using Microsoft.Extensions.AI;

/// <summary>
/// This example shows how to add type-safe function call to an agent.
Expand Down Expand Up @@ -37,13 +38,20 @@ public async Task<string> ConcatString(string[] strings)
/// </summary>
/// <param name="price">price, should be an integer</param>
/// <param name="taxRate">tax rate, should be in range (0, 1)</param>
[FunctionAttribute]
[Function]
public async Task<string> CalculateTax(int price, float taxRate)
{
return $"tax is {price * taxRate}";
}

public static async Task RunAsync()
/// <summary>
/// This example shows how to add type-safe function call using AutoGen.SourceGenerator.
/// The SourceGenerator will automatically generate FunctionDefinition and FunctionCallWrapper during compiling time.
///
/// For adding type-safe function call from M.E.A.I tools, please refer to <see cref="ToolCallWithMEAITools"/>.
/// </summary>
/// <returns></returns>
public static async Task ToolCallWithSourceGenerator()
{
var instance = new Example03_Agent_FunctionCall();
var gpt4o = LLMConfiguration.GetOpenAIGPT4o_mini();
Expand Down Expand Up @@ -101,4 +109,60 @@ public static async Task RunAsync()
// send aggregate message back to llm to get the final result
var finalResult = await agent.SendAsync(calculateTaxes);
}

/// <summary>
/// This example shows how to add type-safe function call from M.E.A.I tools.
///
/// For adding type-safe function call from source generator, please refer to <see cref="ToolCallWithSourceGenerator"/>.
/// </summary>
public static async Task ToolCallWithMEAITools()
{
var gpt4o = LLMConfiguration.GetOpenAIGPT4o_mini();
var instance = new Example03_Agent_FunctionCall();

AIFunction[] tools = [
AIFunctionFactory.Create(instance.UpperCase),
AIFunctionFactory.Create(instance.ConcatString),
AIFunctionFactory.Create(instance.CalculateTax),
];

var toolCallMiddleware = new FunctionCallMiddleware(tools);

var agent = new OpenAIChatAgent(
chatClient: gpt4o,
name: "agent",
systemMessage: "You are a helpful AI assistant")
.RegisterMessageConnector()
.RegisterStreamingMiddleware(toolCallMiddleware)
.RegisterPrintMessage();

// talk to the assistant agent
var upperCase = await agent.SendAsync("convert to upper case: hello world");
upperCase.GetContent()?.Should().Be("HELLO WORLD");
upperCase.Should().BeOfType<ToolCallAggregateMessage>();
upperCase.GetToolCalls().Should().HaveCount(1);
upperCase.GetToolCalls().First().FunctionName.Should().Be(nameof(UpperCase));

var concatString = await agent.SendAsync("concatenate strings: a, b, c, d, e");
concatString.GetContent()?.Should().Be("a b c d e");
concatString.Should().BeOfType<ToolCallAggregateMessage>();
concatString.GetToolCalls().Should().HaveCount(1);
concatString.GetToolCalls().First().FunctionName.Should().Be(nameof(ConcatString));

var calculateTax = await agent.SendAsync("calculate tax: 100, 0.1");
calculateTax.GetContent().Should().Be("tax is 10");
calculateTax.Should().BeOfType<ToolCallAggregateMessage>();
calculateTax.GetToolCalls().Should().HaveCount(1);
calculateTax.GetToolCalls().First().FunctionName.Should().Be(nameof(CalculateTax));

// parallel function calls
var calculateTaxes = await agent.SendAsync("calculate tax: 100, 0.1; calculate tax: 200, 0.2");
calculateTaxes.GetContent().Should().Be("tax is 10\ntax is 40"); // "tax is 10\n tax is 40
calculateTaxes.Should().BeOfType<ToolCallAggregateMessage>();
calculateTaxes.GetToolCalls().Should().HaveCount(2);
calculateTaxes.GetToolCalls().First().FunctionName.Should().Be(nameof(CalculateTax));

// send aggregate message back to llm to get the final result
var finalResult = await agent.SendAsync(calculateTaxes);
}
}
3 changes: 2 additions & 1 deletion dotnet/samples/AutoGen.BasicSamples/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
// When a new sample is created please add them to the allSamples collection
("Assistant Agent", Example01_AssistantAgent.RunAsync),
("Two-agent Math Chat", Example02_TwoAgent_MathChat.RunAsync),
("Agent Function Call", Example03_Agent_FunctionCall.RunAsync),
("Agent Function Call With Source Generator", Example03_Agent_FunctionCall.ToolCallWithSourceGenerator),
("Agent Function Call With M.E.A.I AI Functions", Example03_Agent_FunctionCall.ToolCallWithMEAITools),
("Dynamic Group Chat Coding Task", Example04_Dynamic_GroupChat_Coding_Task.RunAsync),
("DALL-E and GPT4v", Example05_Dalle_And_GPT4V.RunAsync),
("User Proxy Agent", Example06_UserProxyAgent.RunAsync),
Expand Down
1 change: 1 addition & 0 deletions dotnet/src/AutoGen.Core/AutoGen.Core.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
<ItemGroup>
<PackageReference Include="JsonSchema.Net.Generation" />
<PackageReference Include="System.Memory.Data" />
<PackageReference Include="Microsoft.Extensions.AI.Abstractions" />
</ItemGroup>

<ItemGroup Condition="'$(TargetFramework)' == 'netstandard2.0'">
Expand Down
67 changes: 67 additions & 0 deletions dotnet/src/AutoGen.Core/Function/FunctionAttribute.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json.Serialization;
using Microsoft.Extensions.AI;

namespace AutoGen.Core;

Expand All @@ -22,6 +25,10 @@ public FunctionAttribute(string? functionName = null, string? description = null

public class FunctionContract
{
private const string NamespaceKey = nameof(Namespace);

private const string ClassNameKey = nameof(ClassName);

/// <summary>
/// The namespace of the function.
/// </summary>
Expand Down Expand Up @@ -52,6 +59,7 @@ public class FunctionContract
/// <summary>
/// The return type of the function.
/// </summary>
[JsonIgnore]
public Type? ReturnType { get; set; }

/// <summary>
Expand All @@ -60,6 +68,39 @@ public class FunctionContract
/// Otherwise, the description will be null.
/// </summary>
public string? ReturnDescription { get; set; }

public static implicit operator FunctionContract(AIFunctionMetadata metadata)
{
return new FunctionContract
{
Namespace = metadata.AdditionalProperties.ContainsKey(NamespaceKey) ? metadata.AdditionalProperties[NamespaceKey] as string : null,
ClassName = metadata.AdditionalProperties.ContainsKey(ClassNameKey) ? metadata.AdditionalProperties[ClassNameKey] as string : null,
Name = metadata.Name,
Description = metadata.Description,
Parameters = metadata.Parameters?.Select(p => (FunctionParameterContract)p).ToList(),
ReturnType = metadata.ReturnParameter.ParameterType,
ReturnDescription = metadata.ReturnParameter.Description,
};
}

public static implicit operator AIFunctionMetadata(FunctionContract contract)
{
return new AIFunctionMetadata(contract.Name)
{
Description = contract.Description,
ReturnParameter = new AIFunctionReturnParameterMetadata()
{
Description = contract.ReturnDescription,
ParameterType = contract.ReturnType,
},
AdditionalProperties = new Dictionary<string, object?>
{
[NamespaceKey] = contract.Namespace,
[ClassNameKey] = contract.ClassName,
},
Parameters = [.. contract.Parameters?.Select(p => (AIFunctionParameterMetadata)p)],
};
}
}

public class FunctionParameterContract
Expand All @@ -79,6 +120,7 @@ public class FunctionParameterContract
/// <summary>
/// The type of the parameter.
/// </summary>
[JsonIgnore]
public Type? ParameterType { get; set; }

/// <summary>
Expand All @@ -90,4 +132,29 @@ public class FunctionParameterContract
/// The default value of the parameter.
/// </summary>
public object? DefaultValue { get; set; }

// convert to/from FunctionParameterMetadata
public static implicit operator FunctionParameterContract(AIFunctionParameterMetadata metadata)
{
return new FunctionParameterContract
{
Name = metadata.Name,
Description = metadata.Description,
ParameterType = metadata.ParameterType,
IsRequired = metadata.IsRequired,
DefaultValue = metadata.DefaultValue,
};
}

public static implicit operator AIFunctionParameterMetadata(FunctionParameterContract contract)
{
return new AIFunctionParameterMetadata(contract.Name!)
{
DefaultValue = contract.DefaultValue,
Description = contract.Description,
IsRequired = contract.IsRequired,
ParameterType = contract.ParameterType,
HasDefaultValue = contract.DefaultValue != null,
};
}
}
31 changes: 31 additions & 0 deletions dotnet/src/AutoGen.Core/Middleware/FunctionCallMiddleware.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;

namespace AutoGen.Core;

Expand Down Expand Up @@ -43,6 +45,19 @@ public FunctionCallMiddleware(
this.functionMap = functionMap;
}

/// <summary>
/// Create a new instance of <see cref="FunctionCallMiddleware"/> with a list of <see cref="AIFunction"/>.
/// </summary>
/// <param name="functions">function list</param>
/// <param name="name">optional middleware name. If not provided, the class name <see cref="FunctionCallMiddleware"/> will be used.</param>
public FunctionCallMiddleware(IEnumerable<AIFunction> functions, string? name = null)
{
this.Name = name ?? nameof(FunctionCallMiddleware);
this.functions = functions.Select(f => (FunctionContract)f.Metadata).ToArray();

this.functionMap = functions.Select(f => (f.Metadata.Name, this.AIToolInvokeWrapper(f.InvokeAsync))).ToDictionary(f => f.Name, f => f.Item2);
}

public string? Name { get; }

public async Task<IMessage> InvokeAsync(MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default)
Expand Down Expand Up @@ -173,4 +188,20 @@ private async Task<IMessage> InvokeToolCallMessagesAfterInvokingAgentAsync(ToolC
return toolCallMsg;
}
}

private Func<string, Task<string>> AIToolInvokeWrapper(Func<IEnumerable<KeyValuePair<string, object?>>?, CancellationToken, Task<object?>> lambda)
{
return async (string args) =>
{
var arguments = JsonSerializer.Deserialize<Dictionary<string, object?>>(args);
var result = await lambda(arguments, CancellationToken.None);
return result switch
{
string s => s,
JsonElement e => e.ToString(),
_ => JsonSerializer.Serialize(result),
};
};
}
}
2 changes: 2 additions & 0 deletions dotnet/test/AutoGen.Tests/AutoGen.Tests.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
<ProjectReference Include="..\..\src\AutoGen.SourceGenerator\AutoGen.SourceGenerator.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" />
<ProjectReference Include="..\..\src\AutoGen\AutoGen.csproj" />
<ProjectReference Include="..\AutoGen.Test.Share\AutoGen.Tests.Share.csproj" />
<PackageReference Include="Microsoft.Extensions.AI" />
<PackageReference Include="System.Text.Json" VersionOverride="9.0.0-rc.2.24473.5" />
</ItemGroup>

<ItemGroup>
Expand Down
3 changes: 2 additions & 1 deletion dotnet/test/AutoGen.Tests/BasicSampleTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ public async Task TwoAgentMathClassTestAsync()
[ApiKeyFact("OPENAI_API_KEY")]
public async Task AgentFunctionCallTestAsync()
{
await Example03_Agent_FunctionCall.RunAsync();
await Example03_Agent_FunctionCall.ToolCallWithSourceGenerator();
await Example03_Agent_FunctionCall.ToolCallWithMEAITools();
}

[ApiKeyFact("MISTRAL_API_KEY")]
Expand Down
Loading

0 comments on commit 77b737f

Please sign in to comment.