Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add default subscriptions for the agent type - Implicitly created sub… #4324

Merged
43 changes: 28 additions & 15 deletions dotnet/src/Microsoft.AutoGen/Agents/AgentBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,40 @@ namespace Microsoft.AutoGen.Agents;
public abstract class AgentBase : IAgentBase, IHandle
{
public static readonly ActivitySource s_source = new("AutoGen.Agent");
public AgentId AgentId => _context.AgentId;
public AgentId AgentId => _runtime.AgentId;
private readonly object _lock = new();
private readonly Dictionary<string, TaskCompletionSource<RpcResponse>> _pendingRequests = [];

private readonly Channel<object> _mailbox = Channel.CreateUnbounded<object>();
private readonly IAgentRuntime _context;
private readonly IAgentRuntime _runtime;
public string Route { get; set; } = "base";

protected internal ILogger<AgentBase> _logger;
public IAgentRuntime Context => _context;
public IAgentRuntime Context => _runtime;
protected readonly EventTypes EventTypes;

protected AgentBase(
IAgentRuntime context,
IAgentRuntime runtime,
EventTypes eventTypes,
ILogger<AgentBase>? logger = null)
{
_context = context;
context.AgentInstance = this;
_runtime = runtime;
runtime.AgentInstance = this;
this.EventTypes = eventTypes;
_logger = logger ?? LoggerFactory.Create(builder => { }).CreateLogger<AgentBase>();
var subscriptionRequest = new AddSubscriptionRequest
{
RequestId = Guid.NewGuid().ToString(),
Subscription = new Subscription
{
TypeSubscription = new TypeSubscription
{
AgentType = this.AgentId.Type,
TopicType = this.AgentId.Type + "/" + this.AgentId.Key
}
}
};
_runtime.SendMessageAsync(new Message { AddSubscriptionRequest = subscriptionRequest }).AsTask().Wait();
Completion = Start();
}
internal Task Completion { get; }
Expand Down Expand Up @@ -131,19 +144,19 @@ public List<string> Subscribe(string topic)
}
}
};
_context.SendMessageAsync(message).AsTask().Wait();
_runtime.SendMessageAsync(message).AsTask().Wait();

return new List<string> { topic };
}
public async Task StoreAsync(AgentState state, CancellationToken cancellationToken = default)
{
await _context.StoreAsync(state, cancellationToken).ConfigureAwait(false);
await _runtime.StoreAsync(state, cancellationToken).ConfigureAwait(false);
return;
}
public async Task<T> ReadAsync<T>(AgentId agentId, CancellationToken cancellationToken = default) where T : IMessage, new()
{
var agentState = await _context.ReadAsync(agentId, cancellationToken).ConfigureAwait(false);
return agentState.FromAgentState<T>();
var agentstate = await _runtime.ReadAsync(agentId, cancellationToken).ConfigureAwait(false);
return agentstate.FromAgentState<T>();
}
private void OnResponseCore(RpcResponse response)
{
Expand Down Expand Up @@ -171,7 +184,7 @@ private async Task OnRequestCoreAsync(RpcRequest request, CancellationToken canc
{
response = new RpcResponse { Error = ex.Message };
}
await _context.SendResponseAsync(request, response, cancellationToken).ConfigureAwait(false);
await _runtime.SendResponseAsync(request, response, cancellationToken).ConfigureAwait(false);
}

protected async Task<RpcResponse> RequestAsync(AgentId target, string method, Dictionary<string, string> parameters)
Expand All @@ -195,7 +208,7 @@ protected async Task<RpcResponse> RequestAsync(AgentId target, string method, Di
activity?.SetTag("peer.service", target.ToString());

var completion = new TaskCompletionSource<RpcResponse>(TaskCreationOptions.RunContinuationsAsynchronously);
_context.Update(request, activity);
_runtime.Update(request, activity);
await this.InvokeWithActivityAsync(
static async ((AgentBase Agent, RpcRequest Request, TaskCompletionSource<RpcResponse>) state, CancellationToken ct) =>
{
Expand All @@ -206,7 +219,7 @@ static async ((AgentBase Agent, RpcRequest Request, TaskCompletionSource<RpcResp
self._pendingRequests[request.RequestId] = completion;
}

await state.Agent._context.SendRequestAsync(state.Agent, state.Request, ct).ConfigureAwait(false);
await state.Agent._runtime.SendRequestAsync(state.Agent, state.Request).ConfigureAwait(false);

await completion.Task.ConfigureAwait(false);
},
Expand All @@ -231,11 +244,11 @@ public async ValueTask PublishEventAsync(CloudEvent item, CancellationToken canc
activity?.SetTag("peer.service", $"{item.Type}/{item.Source}");

// TODO: fix activity
_context.Update(item, activity);
_runtime.Update(item, activity);
await this.InvokeWithActivityAsync(
static async ((AgentBase Agent, CloudEvent Event) state, CancellationToken ct) =>
{
await state.Agent._context.PublishEventAsync(state.Event, ct).ConfigureAwait(false);
await state.Agent._runtime.PublishEventAsync(state.Event).ConfigureAwait(false);
},
(this, item),
activity,
Expand Down
32 changes: 27 additions & 5 deletions dotnet/src/Microsoft.AutoGen/Agents/Services/AgentWorker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ public class AgentWorker :
private readonly CancellationTokenSource _shutdownCts;
private readonly IServiceProvider _serviceProvider;
private readonly IEnumerable<Tuple<string, Type>> _configuredAgentTypes;
private readonly ConcurrentDictionary<string, Subscription> _subscriptionsByAgentType = new();
private readonly ConcurrentDictionary<string, List<string>> _subscriptionsByTopic = new();
private readonly DistributedContextPropagator _distributedContextPropagator;
private readonly CancellationTokenSource _shutdownCancellationToken = new();
private Task? _mailboxTask;
Expand Down Expand Up @@ -96,11 +98,7 @@ public async Task RunMessagePump()
if (message == null) { continue; }
switch (message)
{
case Message.MessageOneofCase.AddSubscriptionResponse:
break;
case Message.MessageOneofCase.RegisterAgentTypeResponse:
break;
case Message msg:
case Message msg when msg.CloudEvent != null:

var item = msg.CloudEvent;

Expand All @@ -110,6 +108,13 @@ public async Task RunMessagePump()
agentToInvoke.ReceiveMessage(msg);
}
break;
case Message msg when msg.AddSubscriptionRequest != null:
await AddSubscriptionRequestAsync(msg.AddSubscriptionRequest).ConfigureAwait(true);
break;
case Message msg when msg.AddSubscriptionResponse != null:
break;
case Message msg when msg.RegisterAgentTypeResponse != null:
break;
default:
throw new InvalidOperationException($"Unexpected message '{message}'.");
}
Expand All @@ -123,6 +128,23 @@ public async Task RunMessagePump()
}
}
}
private async ValueTask AddSubscriptionRequestAsync(AddSubscriptionRequest subscription)
{
var topic = subscription.Subscription.TypeSubscription.TopicType;
var agentType = subscription.Subscription.TypeSubscription.AgentType;
_subscriptionsByAgentType[agentType] = subscription.Subscription;
_subscriptionsByTopic.GetOrAdd(topic, _ => []).Add(agentType);
Message response = new()
{
AddSubscriptionResponse = new()
{
RequestId = subscription.RequestId,
Error = "",
Success = true
}
};
await _mailbox.Writer.WriteAsync(response).ConfigureAwait(false);
}

public async Task StartAsync(CancellationToken cancellationToken)
{
Expand Down
16 changes: 16 additions & 0 deletions dotnet/src/Microsoft.AutoGen/Agents/Services/Grpc/GrpcGateway.cs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,22 @@ private async ValueTask RegisterAgentTypeAsync(GrpcWorkerConnection connection,
Success = true
}
};
// add a default subscription for the agent type
//TODO: we should consider having constraints on the namespace or at least migrate all our examples to use well typed namesspaces like com.microsoft.autogen/hello/HelloAgents etc
var subscriptionRequest = new AddSubscriptionRequest
{
RequestId = Guid.NewGuid().ToString(),
Subscription = new Subscription
{
TypeSubscription = new TypeSubscription
{
AgentType = msg.Type,
TopicType = msg.Type
}
}
};
await AddSubscriptionAsync(connection, subscriptionRequest).ConfigureAwait(true);

await connection.ResponseStream.WriteAsync(response).ConfigureAwait(false);
}
private async ValueTask DispatchEventAsync(CloudEvent evt)
Expand Down
4 changes: 4 additions & 0 deletions dotnet/test/Microsoft.AutoGen.Agents.Tests/AgentBaseTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ public class AgentBaseTests(InMemoryAgentRuntimeFixture fixture)
public async Task ItInvokeRightHandlerTestAsync()
{
var mockContext = new Mock<IAgentRuntime>();
mockContext.SetupGet(x => x.AgentId).Returns(new AgentId("test", "test"));
// mock SendMessageAsync
mockContext.Setup(x => x.SendMessageAsync(It.IsAny<Message>(), It.IsAny<CancellationToken>()))
.Returns(new ValueTask());
var agent = new TestAgent(mockContext.Object, new EventTypes(TypeRegistry.Empty, [], []), new Logger<AgentBase>(new LoggerFactory()));

await agent.HandleObject("hello world");
Expand Down
Loading