From d2497d146ced62dcc895e3d576afafe79653cf19 Mon Sep 17 00:00:00 2001 From: Kosta Petan Date: Thu, 28 Nov 2024 15:36:00 +0100 Subject: [PATCH] fix runtime broadcasting --- .../src/Microsoft.AutoGen/Agents/AgentBase.cs | 22 +++++++-------- .../Agents/Services/Grpc/GrpcAgentWorker.cs | 20 ++++++++++++-- .../GrpcAgentWorkerHostBuilderExtension.cs | 2 +- .../Agents/Services/Grpc/GrpcGateway.cs | 27 ++++++------------- .../Orleans/OrleansRuntimeHostingExtenions.cs | 6 ++--- protos/agent_worker.proto | 1 + 6 files changed, 40 insertions(+), 38 deletions(-) diff --git a/dotnet/src/Microsoft.AutoGen/Agents/AgentBase.cs b/dotnet/src/Microsoft.AutoGen/Agents/AgentBase.cs index 545734b50575..5a3900a65e69 100644 --- a/dotnet/src/Microsoft.AutoGen/Agents/AgentBase.cs +++ b/dotnet/src/Microsoft.AutoGen/Agents/AgentBase.cs @@ -36,19 +36,6 @@ protected AgentBase( runtime.AgentInstance = this; this.EventTypes = eventTypes; _logger = logger ?? LoggerFactory.Create(builder => { }).CreateLogger(); - var subscriptionRequest = new AddSubscriptionRequest - { - RequestId = Guid.NewGuid().ToString(), - Subscription = new Subscription - { - TypeSubscription = new TypeSubscription - { - AgentType = this.AgentId.Type, - TopicType = this.AgentId.Type + "/" + this.AgentId.Key - } - } - }; - _runtime.SendMessageAsync(new Message { AddSubscriptionRequest = subscriptionRequest }).AsTask().Wait(); Completion = Start(); } internal Task Completion { get; } @@ -125,6 +112,15 @@ await this.InvokeWithActivityAsync( case Message.MessageOneofCase.Response: OnResponseCore(msg.Response); break; + case Message.MessageOneofCase.RegisterAgentTypeResponse: + _logger.LogInformation($"Got {msg.MessageCase} with payload {msg.RegisterAgentTypeResponse}"); + break; + case Message.MessageOneofCase.AddSubscriptionResponse: + _logger.LogInformation($"Got {msg.MessageCase} with payload {msg.AddSubscriptionResponse}"); + break; + default: + _logger.LogInformation($"Got {msg.MessageCase}"); + break; } } public List Subscribe(string topic) diff --git a/dotnet/src/Microsoft.AutoGen/Agents/Services/Grpc/GrpcAgentWorker.cs b/dotnet/src/Microsoft.AutoGen/Agents/Services/Grpc/GrpcAgentWorker.cs index 48f07573430d..cb928e77e29d 100644 --- a/dotnet/src/Microsoft.AutoGen/Agents/Services/Grpc/GrpcAgentWorker.cs +++ b/dotnet/src/Microsoft.AutoGen/Agents/Services/Grpc/GrpcAgentWorker.cs @@ -5,6 +5,8 @@ using System.Diagnostics; using System.Reflection; using System.Threading.Channels; +using Google.Protobuf; +using Google.Protobuf.Reflection; using Grpc.Core; using Microsoft.AutoGen.Abstractions; using Microsoft.Extensions.DependencyInjection; @@ -212,7 +214,7 @@ private async ValueTask RegisterAgentTypeAsync(string type, Type agentType, Canc { var events = agentType.GetInterfaces() .Where(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IHandle<>)) - .Select(i => i.GetGenericArguments().First().Name); + .Select(i => GetDescriptorName(i.GetGenericArguments().First())); //var state = agentType.BaseType?.GetGenericArguments().First(); var topicTypes = agentType.GetCustomAttributes().Select(t => t.Topic); @@ -224,12 +226,26 @@ await WriteChannelAsync(new Message RequestId = Guid.NewGuid().ToString(), //TopicTypes = { topicTypes }, //StateType = state?.Name, - //Events = { events } + Events = { events } } }, cancellationToken).ConfigureAwait(false); } } + public static string GetDescriptorName(Type messageType) + { + if (typeof(IMessage).IsAssignableFrom(messageType)) + { + var descriptorProperty = messageType.GetProperty("Descriptor", BindingFlags.Public | BindingFlags.Static); + if (descriptorProperty != null) + { + var descriptor = descriptorProperty.GetValue(null) as MessageDescriptor; + return descriptor?.FullName??messageType.Name; + } + } + return messageType.Name; + } + // new is intentional public new async ValueTask SendResponseAsync(RpcResponse response, CancellationToken cancellationToken = default) { diff --git a/dotnet/src/Microsoft.AutoGen/Agents/Services/Grpc/GrpcAgentWorkerHostBuilderExtension.cs b/dotnet/src/Microsoft.AutoGen/Agents/Services/Grpc/GrpcAgentWorkerHostBuilderExtension.cs index 4f214caa8203..ef7975c3864f 100644 --- a/dotnet/src/Microsoft.AutoGen/Agents/Services/Grpc/GrpcAgentWorkerHostBuilderExtension.cs +++ b/dotnet/src/Microsoft.AutoGen/Agents/Services/Grpc/GrpcAgentWorkerHostBuilderExtension.cs @@ -11,7 +11,7 @@ namespace Microsoft.AutoGen.Agents; public static class GrpcAgentWorkerHostBuilderExtensions { - private const string _defaultAgentServiceAddress = "https://localhost:53071"; + private const string _defaultAgentServiceAddress = "https://localhost:5001"; public static IHostApplicationBuilder AddGrpcAgentWorker(this IHostApplicationBuilder builder, string? agentServiceAddress = null) { builder.Services.AddGrpcClient(options => diff --git a/dotnet/src/Microsoft.AutoGen/Agents/Services/Grpc/GrpcGateway.cs b/dotnet/src/Microsoft.AutoGen/Agents/Services/Grpc/GrpcGateway.cs index ab24a0e15fe5..f5d92e6cbacb 100644 --- a/dotnet/src/Microsoft.AutoGen/Agents/Services/Grpc/GrpcGateway.cs +++ b/dotnet/src/Microsoft.AutoGen/Agents/Services/Grpc/GrpcGateway.cs @@ -24,6 +24,7 @@ public sealed class GrpcGateway : BackgroundService, IGateway private readonly ConcurrentDictionary _subscriptionsByAgentType = new(); private readonly ConcurrentDictionary> _subscriptionsByTopic = new(); + private readonly ConcurrentDictionary> _agentsToEventsMap = new(); // The mapping from agent id to worker process. private readonly ConcurrentDictionary<(string Type, string Key), GrpcWorkerConnection> _agentDirectory = new(); // RPC @@ -42,10 +43,12 @@ public async ValueTask BroadcastEvent(CloudEvent evt) { // TODO: filter the workers that receive the event var tasks = new List(_workers.Count); - foreach (var (_, connection) in _supportedAgentTypes) + foreach (var (key, connection) in _supportedAgentTypes) { - - tasks.Add(this.SendMessageAsync((IConnection)connection[0], evt, default)); + if (_agentsToEventsMap.TryGetValue(key, out var events) && events.Contains(evt.Type)) + { + tasks.Add(SendMessageAsync(connection[0], evt, default)); + } } await Task.WhenAll(tasks).ConfigureAwait(false); } @@ -142,6 +145,7 @@ private async ValueTask RegisterAgentTypeAsync(GrpcWorkerConnection connection, { connection.AddSupportedType(msg.Type); _supportedAgentTypes.GetOrAdd(msg.Type, _ => []).Add(connection); + _agentsToEventsMap.TryAdd(msg.Type, new HashSet(msg.Events)); await _gatewayRegistry.RegisterAgentType(msg.Type, _reference).ConfigureAwait(true); Message response = new() @@ -153,22 +157,7 @@ private async ValueTask RegisterAgentTypeAsync(GrpcWorkerConnection connection, Success = true } }; - // add a default subscription for the agent type - //TODO: we should consider having constraints on the namespace or at least migrate all our examples to use well typed namesspaces like com.microsoft.autogen/hello/HelloAgents etc - var subscriptionRequest = new AddSubscriptionRequest - { - RequestId = Guid.NewGuid().ToString(), - Subscription = new Subscription - { - TypeSubscription = new TypeSubscription - { - AgentType = msg.Type, - TopicType = msg.Type - } - } - }; - await AddSubscriptionAsync(connection, subscriptionRequest).ConfigureAwait(true); - + // TODO: add Topics from the registration message await connection.ResponseStream.WriteAsync(response).ConfigureAwait(false); } private async ValueTask DispatchEventAsync(CloudEvent evt) diff --git a/dotnet/src/Microsoft.AutoGen/Agents/Services/Orleans/OrleansRuntimeHostingExtenions.cs b/dotnet/src/Microsoft.AutoGen/Agents/Services/Orleans/OrleansRuntimeHostingExtenions.cs index 374e49f7a500..a5a8e126cb2b 100644 --- a/dotnet/src/Microsoft.AutoGen/Agents/Services/Orleans/OrleansRuntimeHostingExtenions.cs +++ b/dotnet/src/Microsoft.AutoGen/Agents/Services/Orleans/OrleansRuntimeHostingExtenions.cs @@ -13,12 +13,12 @@ namespace Microsoft.AutoGen.Agents; public static class OrleansRuntimeHostingExtenions { - public static WebApplicationBuilder AddOrleans(this WebApplicationBuilder builder, bool local = false) + public static IHostApplicationBuilder AddOrleans(this WebApplicationBuilder builder, bool local = false) { - return builder.AddOrleans(local); + return builder.AddOrleansImpl(local); } - public static IHostApplicationBuilder AddOrleans(this IHostApplicationBuilder builder, bool local = false) + private static IHostApplicationBuilder AddOrleansImpl(this IHostApplicationBuilder builder, bool local = false) { builder.Services.AddSerializer(serializer => serializer.AddProtobufSerializer()); builder.Services.AddSingleton(); diff --git a/protos/agent_worker.proto b/protos/agent_worker.proto index 4d346dfecd63..853df79fff7d 100644 --- a/protos/agent_worker.proto +++ b/protos/agent_worker.proto @@ -50,6 +50,7 @@ message Event { message RegisterAgentTypeRequest { string request_id = 1; string type = 2; + repeated string events = 3; } message RegisterAgentTypeResponse {