diff --git a/dotnet/src/Microsoft.AutoGen/Agents/AgentBase.cs b/dotnet/src/Microsoft.AutoGen/Agents/AgentBase.cs index 345e6d34c826..01ad856a2d49 100644 --- a/dotnet/src/Microsoft.AutoGen/Agents/AgentBase.cs +++ b/dotnet/src/Microsoft.AutoGen/Agents/AgentBase.cs @@ -15,27 +15,40 @@ namespace Microsoft.AutoGen.Agents; public abstract class AgentBase : IAgentBase, IHandle { public static readonly ActivitySource s_source = new("AutoGen.Agent"); - public AgentId AgentId => _context.AgentId; + public AgentId AgentId => _runtime.AgentId; private readonly object _lock = new(); private readonly Dictionary> _pendingRequests = []; private readonly Channel _mailbox = Channel.CreateUnbounded(); - private readonly IAgentRuntime _context; + private readonly IAgentRuntime _runtime; public string Route { get; set; } = "base"; protected internal ILogger _logger; - public IAgentRuntime Context => _context; + public IAgentRuntime Context => _runtime; protected readonly EventTypes EventTypes; protected AgentBase( - IAgentRuntime context, + IAgentRuntime runtime, EventTypes eventTypes, ILogger? logger = null) { - _context = context; - context.AgentInstance = this; + _runtime = runtime; + runtime.AgentInstance = this; this.EventTypes = eventTypes; _logger = logger ?? LoggerFactory.Create(builder => { }).CreateLogger(); + var subscriptionRequest = new AddSubscriptionRequest + { + RequestId = Guid.NewGuid().ToString(), + Subscription = new Subscription + { + TypeSubscription = new TypeSubscription + { + AgentType = this.AgentId.Type, + TopicType = this.AgentId.Type + "/" + this.AgentId.Key + } + } + }; + _runtime.SendMessageAsync(new Message { AddSubscriptionRequest = subscriptionRequest }).AsTask().Wait(); Completion = Start(); } internal Task Completion { get; } @@ -131,19 +144,19 @@ public List Subscribe(string topic) } } }; - _context.SendMessageAsync(message).AsTask().Wait(); + _runtime.SendMessageAsync(message).AsTask().Wait(); return new List { topic }; } public async Task StoreAsync(AgentState state, CancellationToken cancellationToken = default) { - await _context.StoreAsync(state, cancellationToken).ConfigureAwait(false); + await _runtime.StoreAsync(state, cancellationToken).ConfigureAwait(false); return; } public async Task ReadAsync(AgentId agentId, CancellationToken cancellationToken = default) where T : IMessage, new() { - var agentState = await _context.ReadAsync(agentId, cancellationToken).ConfigureAwait(false); - return agentState.FromAgentState(); + var agentstate = await _runtime.ReadAsync(agentId, cancellationToken).ConfigureAwait(false); + return agentstate.FromAgentState(); } private void OnResponseCore(RpcResponse response) { @@ -171,7 +184,7 @@ private async Task OnRequestCoreAsync(RpcRequest request, CancellationToken canc { response = new RpcResponse { Error = ex.Message }; } - await _context.SendResponseAsync(request, response, cancellationToken).ConfigureAwait(false); + await _runtime.SendResponseAsync(request, response, cancellationToken).ConfigureAwait(false); } protected async Task RequestAsync(AgentId target, string method, Dictionary parameters) @@ -195,7 +208,7 @@ protected async Task RequestAsync(AgentId target, string method, Di activity?.SetTag("peer.service", target.ToString()); var completion = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - _context.Update(request, activity); + _runtime.Update(request, activity); await this.InvokeWithActivityAsync( static async ((AgentBase Agent, RpcRequest Request, TaskCompletionSource) state, CancellationToken ct) => { @@ -206,7 +219,7 @@ static async ((AgentBase Agent, RpcRequest Request, TaskCompletionSource { - await state.Agent._context.PublishEventAsync(state.Event, ct).ConfigureAwait(false); + await state.Agent._runtime.PublishEventAsync(state.Event).ConfigureAwait(false); }, (this, item), activity, diff --git a/dotnet/src/Microsoft.AutoGen/Agents/Services/AgentWorker.cs b/dotnet/src/Microsoft.AutoGen/Agents/Services/AgentWorker.cs index a69da96fb3d4..f9a5050534c8 100644 --- a/dotnet/src/Microsoft.AutoGen/Agents/Services/AgentWorker.cs +++ b/dotnet/src/Microsoft.AutoGen/Agents/Services/AgentWorker.cs @@ -24,6 +24,8 @@ public class AgentWorker : private readonly CancellationTokenSource _shutdownCts; private readonly IServiceProvider _serviceProvider; private readonly IEnumerable> _configuredAgentTypes; + private readonly ConcurrentDictionary _subscriptionsByAgentType = new(); + private readonly ConcurrentDictionary> _subscriptionsByTopic = new(); private readonly DistributedContextPropagator _distributedContextPropagator; private readonly CancellationTokenSource _shutdownCancellationToken = new(); private Task? _mailboxTask; @@ -96,11 +98,7 @@ public async Task RunMessagePump() if (message == null) { continue; } switch (message) { - case Message.MessageOneofCase.AddSubscriptionResponse: - break; - case Message.MessageOneofCase.RegisterAgentTypeResponse: - break; - case Message msg: + case Message msg when msg.CloudEvent != null: var item = msg.CloudEvent; @@ -110,6 +108,13 @@ public async Task RunMessagePump() agentToInvoke.ReceiveMessage(msg); } break; + case Message msg when msg.AddSubscriptionRequest != null: + await AddSubscriptionRequestAsync(msg.AddSubscriptionRequest).ConfigureAwait(true); + break; + case Message msg when msg.AddSubscriptionResponse != null: + break; + case Message msg when msg.RegisterAgentTypeResponse != null: + break; default: throw new InvalidOperationException($"Unexpected message '{message}'."); } @@ -123,6 +128,23 @@ public async Task RunMessagePump() } } } + private async ValueTask AddSubscriptionRequestAsync(AddSubscriptionRequest subscription) + { + var topic = subscription.Subscription.TypeSubscription.TopicType; + var agentType = subscription.Subscription.TypeSubscription.AgentType; + _subscriptionsByAgentType[agentType] = subscription.Subscription; + _subscriptionsByTopic.GetOrAdd(topic, _ => []).Add(agentType); + Message response = new() + { + AddSubscriptionResponse = new() + { + RequestId = subscription.RequestId, + Error = "", + Success = true + } + }; + await _mailbox.Writer.WriteAsync(response).ConfigureAwait(false); + } public async Task StartAsync(CancellationToken cancellationToken) { diff --git a/dotnet/src/Microsoft.AutoGen/Agents/Services/Grpc/GrpcGateway.cs b/dotnet/src/Microsoft.AutoGen/Agents/Services/Grpc/GrpcGateway.cs index 45477c8eb5a6..ab24a0e15fe5 100644 --- a/dotnet/src/Microsoft.AutoGen/Agents/Services/Grpc/GrpcGateway.cs +++ b/dotnet/src/Microsoft.AutoGen/Agents/Services/Grpc/GrpcGateway.cs @@ -153,6 +153,22 @@ private async ValueTask RegisterAgentTypeAsync(GrpcWorkerConnection connection, Success = true } }; + // add a default subscription for the agent type + //TODO: we should consider having constraints on the namespace or at least migrate all our examples to use well typed namesspaces like com.microsoft.autogen/hello/HelloAgents etc + var subscriptionRequest = new AddSubscriptionRequest + { + RequestId = Guid.NewGuid().ToString(), + Subscription = new Subscription + { + TypeSubscription = new TypeSubscription + { + AgentType = msg.Type, + TopicType = msg.Type + } + } + }; + await AddSubscriptionAsync(connection, subscriptionRequest).ConfigureAwait(true); + await connection.ResponseStream.WriteAsync(response).ConfigureAwait(false); } private async ValueTask DispatchEventAsync(CloudEvent evt) diff --git a/dotnet/test/Microsoft.AutoGen.Agents.Tests/AgentBaseTests.cs b/dotnet/test/Microsoft.AutoGen.Agents.Tests/AgentBaseTests.cs index e58fdb00f0a0..7e272ce6bed9 100644 --- a/dotnet/test/Microsoft.AutoGen.Agents.Tests/AgentBaseTests.cs +++ b/dotnet/test/Microsoft.AutoGen.Agents.Tests/AgentBaseTests.cs @@ -23,6 +23,10 @@ public class AgentBaseTests(InMemoryAgentRuntimeFixture fixture) public async Task ItInvokeRightHandlerTestAsync() { var mockContext = new Mock(); + mockContext.SetupGet(x => x.AgentId).Returns(new AgentId("test", "test")); + // mock SendMessageAsync + mockContext.Setup(x => x.SendMessageAsync(It.IsAny(), It.IsAny())) + .Returns(new ValueTask()); var agent = new TestAgent(mockContext.Object, new EventTypes(TypeRegistry.Empty, [], []), new Logger(new LoggerFactory())); await agent.HandleObject("hello world");