Skip to content

Commit

Permalink
[.Net] refactor over streaming version api (microsoft#2461)
Browse files Browse the repository at this point in the history
* update

* update

* fix comment
  • Loading branch information
LittleLittleCloud authored May 5, 2024
1 parent 4711d7b commit e878be5
Show file tree
Hide file tree
Showing 39 changed files with 255 additions and 394 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public async Task ChatWithAnAgent(IStreamingAgent agent)

#region ChatWithAnAgent_GenerateStreamingReplyAsync
var textMessage = new TextMessage(Role.User, "Hello");
await foreach (var streamingReply in await agent.GenerateStreamingReplyAsync([message]))
await foreach (var streamingReply in agent.GenerateStreamingReplyAsync([message]))
{
if (streamingReply is TextMessageUpdate update)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ public async Task StreamingCallCodeSnippetAsync()
IStreamingAgent agent = default;
#region StreamingCallCodeSnippet
var helloTextMessage = new TextMessage(Role.User, "Hello");
var reply = await agent.GenerateStreamingReplyAsync([helloTextMessage]);
var reply = agent.GenerateStreamingReplyAsync([helloTextMessage]);
var finalTextMessage = new TextMessage(Role.Assistant, string.Empty, from: agent.Name);
await foreach (var message in reply)
{
Expand All @@ -24,7 +24,7 @@ public async Task StreamingCallCodeSnippetAsync()
#endregion StreamingCallCodeSnippet

#region StreamingCallWithFinalMessage
reply = await agent.GenerateStreamingReplyAsync([helloTextMessage]);
reply = agent.GenerateStreamingReplyAsync([helloTextMessage]);
TextMessage finalMessage = null;
await foreach (var message in reply)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public async Task CreateMistralAIClientAsync()
#endregion create_mistral_agent

#region streaming_chat
var reply = await agent.GenerateStreamingReplyAsync(
var reply = agent.GenerateStreamingReplyAsync(
messages: [new TextMessage(Role.User, "Hello, how are you?")]
);

Expand Down Expand Up @@ -75,7 +75,7 @@ public async Task MistralAIChatAgentGetWeatherToolUsageAsync()
#endregion create_get_weather_function_call_middleware

#region register_function_call_middleware
agent = agent.RegisterMiddleware(functionCallMiddleware);
agent = agent.RegisterStreamingMiddleware(functionCallMiddleware);
#endregion register_function_call_middleware

#region send_message_with_function_call
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ public async Task CreateOpenAIChatAgentAsync()
#endregion create_openai_chat_agent

#region create_openai_chat_agent_streaming
var streamingReply = await openAIChatAgent.GenerateStreamingReplyAsync(new[] { chatMessageContent });
var streamingReply = openAIChatAgent.GenerateStreamingReplyAsync(new[] { chatMessageContent });

await foreach (var streamingMessage in streamingReply)
{
Expand Down Expand Up @@ -123,7 +123,7 @@ public async Task OpenAIChatAgentGetWeatherFunctionCallAsync()
{ functions.GetWeatherFunctionContract.Name, functions.GetWeatherWrapper } // GetWeatherWrapper is a wrapper function for GetWeather, which is also auto-generated
});

openAIChatAgent = openAIChatAgent.RegisterMiddleware(functionCallMiddleware);
openAIChatAgent = openAIChatAgent.RegisterStreamingMiddleware(functionCallMiddleware);
#endregion create_function_call_middleware

#region chat_agent_send_function_call
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public async Task CreateSemanticKernelAgentAsync()
#endregion create_semantic_kernel_agent

#region create_semantic_kernel_agent_streaming
var streamingReply = await semanticKernelAgent.GenerateStreamingReplyAsync(new[] { chatMessageContent });
var streamingReply = semanticKernelAgent.GenerateStreamingReplyAsync(new[] { chatMessageContent });

await foreach (var streamingMessage in streamingReply)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,17 @@ public static async Task RunAsync()
var teacher = new AssistantAgent(
name: "teacher",
systemMessage: @"You are a teacher that create pre-school math question for student and check answer.
If the answer is correct, you terminate conversation by saying [TERMINATE].
If the answer is correct, you stop the conversation by saying [COMPLETE].
If the answer is wrong, you ask student to fix it.",
llmConfig: new ConversableAgentConfig
{
Temperature = 0,
ConfigList = [gpt35],
})
.RegisterPostProcess(async (_, reply, _) =>
.RegisterMiddleware(async (msgs, option, agent, _) =>
{
if (reply.GetContent()?.ToLower().Contains("terminate") is true)
var reply = await agent.GenerateReplyAsync(msgs, option);
if (reply.GetContent()?.ToLower().Contains("complete") is true)
{
return new TextMessage(Role.Assistant, GroupChatExtension.TERMINATE, from: reply.From);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,26 +85,16 @@ public static async Task<IAgent> CreateRunnerAgentAsync(InteractiveService servi
systemMessage: "You run dotnet code",
defaultReply: "No code available.")
.RegisterDotnetCodeBlockExectionHook(interactiveService: service)
.RegisterReply(async (msgs, _) =>
.RegisterMiddleware(async (msgs, option, agent, _) =>
{
if (msgs.Count() == 0)
if (msgs.Count() == 0 || msgs.All(msg => msg.From != "coder"))
{
return new TextMessage(Role.Assistant, "No code available. Coder please write code");
}

return null;
})
.RegisterPreProcess(async (msgs, _) =>
{
// retrieve the most recent message from coder
var coderMsg = msgs.LastOrDefault(msg => msg.From == "coder");
if (coderMsg is null)
{
return Enumerable.Empty<IMessage>();
}
else
{
return new[] { coderMsg };
var coderMsg = msgs.Last(msg => msg.From == "coder");
return await agent.GenerateReplyAsync([coderMsg], option);
}
})
.RegisterPrintMessage();
Expand All @@ -122,8 +112,9 @@ public static async Task<IAgent> CreateAdminAsync()
systemMessage: "You are group admin, terminate the group chat once task is completed by saying [TERMINATE] plus the final answer",
temperature: 0,
config: gpt3Config)
.RegisterPostProcess(async (_, reply, _) =>
.RegisterMiddleware(async (msgs, option, agent, _) =>
{
var reply = await agent.GenerateReplyAsync(msgs, option);
if (reply is TextMessage textMessage && textMessage.Content.Contains("TERMINATE") is true)
{
var content = $"{textMessage.Content}\n\n {GroupChatExtension.TERMINATE}";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ public static async Task RunAsync()
Console.WriteLine((reply as IMessage<ChatMessageContent>).Content.Items[0].As<TextContent>().Text);

var skAgentWithMiddleware = skAgent
.RegisterMessageConnector()
.RegisterMessageConnector() // Register the message connector to support more AutoGen built-in message types
.RegisterPrintMessage();

// Now the skAgentWithMiddleware supports more IMessage types like TextMessage, ImageMessage or MultiModalMessage
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ public static async Task RunAsync()
systemMessage: "You are a helpful assistant designed to output JSON.",
seed: 0, // explicitly set a seed to enable deterministic output
responseFormat: ChatCompletionsResponseFormat.JsonObject) // set response format to JSON object to enable JSON mode
.RegisterMessageConnector();
.RegisterMessageConnector()
.RegisterPrintMessage();
#endregion create_agent

#region chat_with_agent
Expand Down
3 changes: 1 addition & 2 deletions dotnet/sample/AutoGen.BasicSamples/Program.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Program.cs

using AutoGen.BasicSample;
await Example14_MistralClientAgent_TokenCount.RunAsync();
await Example02_TwoAgent_MathChat.RunAsync();
8 changes: 6 additions & 2 deletions dotnet/src/AutoGen.Core/Agent/IMiddlewareAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public interface IMiddlewareAgent : IAgent
void Use(IMiddleware middleware);
}

public interface IMiddlewareStreamAgent : IMiddlewareAgent, IStreamingAgent
public interface IMiddlewareStreamAgent : IStreamingAgent
{
/// <summary>
/// Get the inner agent.
Expand All @@ -44,7 +44,11 @@ public interface IMiddlewareAgent<out T> : IMiddlewareAgent
T TAgent { get; }
}

public interface IMiddlewareStreamAgent<out T> : IMiddlewareStreamAgent, IMiddlewareAgent<T>
public interface IMiddlewareStreamAgent<out T> : IMiddlewareStreamAgent
where T : IStreamingAgent
{
/// <summary>
/// Get the typed inner agent.
/// </summary>
T TStreamingAgent { get; }
}
3 changes: 1 addition & 2 deletions dotnet/src/AutoGen.Core/Agent/IStreamingAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;

namespace AutoGen.Core;

Expand All @@ -12,7 +11,7 @@ namespace AutoGen.Core;
/// </summary>
public interface IStreamingAgent : IAgent
{
public Task<IAsyncEnumerable<IStreamingMessage>> GenerateStreamingReplyAsync(
public IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,
CancellationToken cancellationToken = default);
Expand Down
26 changes: 15 additions & 11 deletions dotnet/src/AutoGen.Core/Agent/MiddlewareAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,25 @@ namespace AutoGen.Core;
/// </summary>
public class MiddlewareAgent : IMiddlewareAgent
{
private readonly IAgent _agent;
private IAgent _agent;
private readonly List<IMiddleware> middlewares = new();

/// <summary>
/// Create a new instance of <see cref="MiddlewareAgent"/>
/// </summary>
/// <param name="innerAgent">the inner agent where middleware will be added.</param>
/// <param name="name">the name of the agent if provided. Otherwise, the name of <paramref name="innerAgent"/> will be used.</param>
public MiddlewareAgent(IAgent innerAgent, string? name = null)
public MiddlewareAgent(IAgent innerAgent, string? name = null, IEnumerable<IMiddleware>? middlewares = null)
{
this.Name = name ?? innerAgent.Name;
this._agent = innerAgent;
if (middlewares != null && middlewares.Any())
{
foreach (var middleware in middlewares)
{
this.Use(middleware);
}
}
}

/// <summary>
Expand Down Expand Up @@ -55,13 +62,7 @@ public Task<IMessage> GenerateReplyAsync(
GenerateReplyOptions? options = null,
CancellationToken cancellationToken = default)
{
IAgent agent = this._agent;
foreach (var middleware in this.middlewares)
{
agent = new DelegateAgent(middleware, agent);
}

return agent.GenerateReplyAsync(messages, options, cancellationToken);
return _agent.GenerateReplyAsync(messages, options, cancellationToken);
}

/// <summary>
Expand All @@ -71,15 +72,18 @@ public Task<IMessage> GenerateReplyAsync(
/// </summary>
public void Use(Func<IEnumerable<IMessage>, GenerateReplyOptions?, IAgent, CancellationToken, Task<IMessage>> func, string? middlewareName = null)
{
this.middlewares.Add(new DelegateMiddleware(middlewareName, async (context, agent, cancellationToken) =>
var middleware = new DelegateMiddleware(middlewareName, async (context, agent, cancellationToken) =>
{
return await func(context.Messages, context.Options, agent, cancellationToken);
}));
});

this.Use(middleware);
}

public void Use(IMiddleware middleware)
{
this.middlewares.Add(middleware);
_agent = new DelegateAgent(middleware, _agent);
}

public override string ToString()
Expand Down
Loading

0 comments on commit e878be5

Please sign in to comment.