Skip to content

Commit

Permalink
fix runtime broadcasting
Browse files Browse the repository at this point in the history
  • Loading branch information
kostapetan committed Nov 28, 2024
1 parent f985f7d commit d2497d1
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 38 deletions.
22 changes: 9 additions & 13 deletions dotnet/src/Microsoft.AutoGen/Agents/AgentBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,6 @@ protected AgentBase(
runtime.AgentInstance = this;
this.EventTypes = eventTypes;
_logger = logger ?? LoggerFactory.Create(builder => { }).CreateLogger<AgentBase>();
var subscriptionRequest = new AddSubscriptionRequest
{
RequestId = Guid.NewGuid().ToString(),
Subscription = new Subscription
{
TypeSubscription = new TypeSubscription
{
AgentType = this.AgentId.Type,
TopicType = this.AgentId.Type + "/" + this.AgentId.Key
}
}
};
_runtime.SendMessageAsync(new Message { AddSubscriptionRequest = subscriptionRequest }).AsTask().Wait();
Completion = Start();
}
internal Task Completion { get; }
Expand Down Expand Up @@ -125,6 +112,15 @@ await this.InvokeWithActivityAsync(
case Message.MessageOneofCase.Response:
OnResponseCore(msg.Response);
break;
case Message.MessageOneofCase.RegisterAgentTypeResponse:
_logger.LogInformation($"Got {msg.MessageCase} with payload {msg.RegisterAgentTypeResponse}");
break;
case Message.MessageOneofCase.AddSubscriptionResponse:
_logger.LogInformation($"Got {msg.MessageCase} with payload {msg.AddSubscriptionResponse}");
break;
default:
_logger.LogInformation($"Got {msg.MessageCase}");
break;
}
}
public List<string> Subscribe(string topic)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
using System.Diagnostics;
using System.Reflection;
using System.Threading.Channels;
using Google.Protobuf;
using Google.Protobuf.Reflection;
using Grpc.Core;
using Microsoft.AutoGen.Abstractions;
using Microsoft.Extensions.DependencyInjection;
Expand Down Expand Up @@ -212,7 +214,7 @@ private async ValueTask RegisterAgentTypeAsync(string type, Type agentType, Canc
{
var events = agentType.GetInterfaces()
.Where(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IHandle<>))
.Select(i => i.GetGenericArguments().First().Name);
.Select(i => GetDescriptorName(i.GetGenericArguments().First()));
//var state = agentType.BaseType?.GetGenericArguments().First();
var topicTypes = agentType.GetCustomAttributes<TopicSubscriptionAttribute>().Select(t => t.Topic);

Expand All @@ -224,12 +226,26 @@ await WriteChannelAsync(new Message
RequestId = Guid.NewGuid().ToString(),
//TopicTypes = { topicTypes },
//StateType = state?.Name,
//Events = { events }
Events = { events }
}
}, cancellationToken).ConfigureAwait(false);
}
}

public static string GetDescriptorName(Type messageType)
{
if (typeof(IMessage).IsAssignableFrom(messageType))
{
var descriptorProperty = messageType.GetProperty("Descriptor", BindingFlags.Public | BindingFlags.Static);
if (descriptorProperty != null)
{
var descriptor = descriptorProperty.GetValue(null) as MessageDescriptor;
return descriptor?.FullName??messageType.Name;
}
}
return messageType.Name;
}

// new is intentional
public new async ValueTask SendResponseAsync(RpcResponse response, CancellationToken cancellationToken = default)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace Microsoft.AutoGen.Agents;

public static class GrpcAgentWorkerHostBuilderExtensions
{
private const string _defaultAgentServiceAddress = "https://localhost:53071";
private const string _defaultAgentServiceAddress = "https://localhost:5001";
public static IHostApplicationBuilder AddGrpcAgentWorker(this IHostApplicationBuilder builder, string? agentServiceAddress = null)
{
builder.Services.AddGrpcClient<AgentRpc.AgentRpcClient>(options =>
Expand Down
27 changes: 8 additions & 19 deletions dotnet/src/Microsoft.AutoGen/Agents/Services/Grpc/GrpcGateway.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ public sealed class GrpcGateway : BackgroundService, IGateway
private readonly ConcurrentDictionary<string, Subscription> _subscriptionsByAgentType = new();
private readonly ConcurrentDictionary<string, List<string>> _subscriptionsByTopic = new();

private readonly ConcurrentDictionary<string, HashSet<string>> _agentsToEventsMap = new();
// The mapping from agent id to worker process.
private readonly ConcurrentDictionary<(string Type, string Key), GrpcWorkerConnection> _agentDirectory = new();
// RPC
Expand All @@ -42,10 +43,12 @@ public async ValueTask BroadcastEvent(CloudEvent evt)
{
// TODO: filter the workers that receive the event
var tasks = new List<Task>(_workers.Count);
foreach (var (_, connection) in _supportedAgentTypes)
foreach (var (key, connection) in _supportedAgentTypes)
{

tasks.Add(this.SendMessageAsync((IConnection)connection[0], evt, default));
if (_agentsToEventsMap.TryGetValue(key, out var events) && events.Contains(evt.Type))
{
tasks.Add(SendMessageAsync(connection[0], evt, default));
}
}
await Task.WhenAll(tasks).ConfigureAwait(false);
}
Expand Down Expand Up @@ -142,6 +145,7 @@ private async ValueTask RegisterAgentTypeAsync(GrpcWorkerConnection connection,
{
connection.AddSupportedType(msg.Type);
_supportedAgentTypes.GetOrAdd(msg.Type, _ => []).Add(connection);
_agentsToEventsMap.TryAdd(msg.Type, new HashSet<string>(msg.Events));

await _gatewayRegistry.RegisterAgentType(msg.Type, _reference).ConfigureAwait(true);
Message response = new()
Expand All @@ -153,22 +157,7 @@ private async ValueTask RegisterAgentTypeAsync(GrpcWorkerConnection connection,
Success = true
}
};
// add a default subscription for the agent type
//TODO: we should consider having constraints on the namespace or at least migrate all our examples to use well typed namesspaces like com.microsoft.autogen/hello/HelloAgents etc
var subscriptionRequest = new AddSubscriptionRequest
{
RequestId = Guid.NewGuid().ToString(),
Subscription = new Subscription
{
TypeSubscription = new TypeSubscription
{
AgentType = msg.Type,
TopicType = msg.Type
}
}
};
await AddSubscriptionAsync(connection, subscriptionRequest).ConfigureAwait(true);

// TODO: add Topics from the registration message
await connection.ResponseStream.WriteAsync(response).ConfigureAwait(false);
}
private async ValueTask DispatchEventAsync(CloudEvent evt)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ namespace Microsoft.AutoGen.Agents;

public static class OrleansRuntimeHostingExtenions
{
public static WebApplicationBuilder AddOrleans(this WebApplicationBuilder builder, bool local = false)
public static IHostApplicationBuilder AddOrleans(this WebApplicationBuilder builder, bool local = false)
{
return builder.AddOrleans(local);
return builder.AddOrleansImpl(local);
}

public static IHostApplicationBuilder AddOrleans(this IHostApplicationBuilder builder, bool local = false)
private static IHostApplicationBuilder AddOrleansImpl(this IHostApplicationBuilder builder, bool local = false)
{
builder.Services.AddSerializer(serializer => serializer.AddProtobufSerializer());
builder.Services.AddSingleton<IRegistryGrain, RegistryGrain>();
Expand Down
1 change: 1 addition & 0 deletions protos/agent_worker.proto
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ message Event {
message RegisterAgentTypeRequest {
string request_id = 1;
string type = 2;
repeated string events = 3;
}

message RegisterAgentTypeResponse {
Expand Down

0 comments on commit d2497d1

Please sign in to comment.