From e56876adeb2ab5e2c509fcc4476aaa5b5f6fc368 Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Tue, 28 Jan 2025 18:23:46 -0500 Subject: [PATCH 1/8] feat: Set up Scaffolding for Core.Grpc * Define skeleton for GrpcAgentRuntime * Implement CloudEvent and RPC Payload serialization/marshaling --- .../Core.Grpc/GrpcAgentRuntime.cs | 597 ++++++++++++++++++ .../GrpcAgentWorkerHostBuilderExtension.cs | 70 ++ .../Core.Grpc/IAgentMessageSerializer.cs | 23 + .../Core.Grpc/IAgentRuntimeExtensions.cs | 102 +++ .../Core.Grpc/IProtoMessageSerializer.cs | 10 + .../Core.Grpc/ISerializationRegistry.cs | 27 + .../Core.Grpc/ITypeNameResolver.cs | 9 + .../Core.Grpc/ProtoSerializationRegistry.cs | 37 ++ .../Core.Grpc/ProtoTypeNameResolver.cs | 21 + .../Core.Grpc/ProtobufConversionExtensions.cs | 61 ++ .../Core.Grpc/ProtobufMessageSerializer.cs | 46 ++ .../src/Microsoft.AutoGen/Core/AgentsApp.cs | 2 + 12 files changed, 1005 insertions(+) create mode 100644 dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs create mode 100644 dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentWorkerHostBuilderExtension.cs create mode 100644 dotnet/src/Microsoft.AutoGen/Core.Grpc/IAgentMessageSerializer.cs create mode 100644 dotnet/src/Microsoft.AutoGen/Core.Grpc/IAgentRuntimeExtensions.cs create mode 100644 dotnet/src/Microsoft.AutoGen/Core.Grpc/IProtoMessageSerializer.cs create mode 100644 dotnet/src/Microsoft.AutoGen/Core.Grpc/ISerializationRegistry.cs create mode 100644 dotnet/src/Microsoft.AutoGen/Core.Grpc/ITypeNameResolver.cs create mode 100644 dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtoSerializationRegistry.cs create mode 100644 dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtoTypeNameResolver.cs create mode 100644 dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtobufConversionExtensions.cs create mode 100644 dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtobufMessageSerializer.cs diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs new file mode 100644 index 00000000000..5deba58ae62 --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs @@ -0,0 +1,597 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// GrpcAgentRuntime.cs + +using System.Collections.Concurrent; +using System.Threading.Channels; +using Google.Protobuf; +using Grpc.Core; +using Microsoft.AutoGen.Contracts; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; +using Microsoft.AutoGen.Protobuf; + +namespace Microsoft.AutoGen.Core.Grpc; + +public sealed class GrpcAgentRuntime( + AgentRpc.AgentRpcClient client, + IHostApplicationLifetime hostApplicationLifetime, + IServiceProvider serviceProvider, + ILogger logger + ) : IAgentRuntime, IDisposable +{ + private readonly object _channelLock = new(); + + // Request ID -> + private readonly ConcurrentDictionary> _pendingRequests = new(); + private Dictionary>> agentFactories = new(); + private Dictionary agentInstances = new(); + + private readonly Channel<(Message Message, TaskCompletionSource WriteCompletionSource)> _outboundMessagesChannel = Channel.CreateBounded<(Message, TaskCompletionSource)>(new BoundedChannelOptions(1024) + { + AllowSynchronousContinuations = true, + SingleReader = true, + SingleWriter = false, + FullMode = BoundedChannelFullMode.Wait + }); + + private readonly AgentRpc.AgentRpcClient _client = client; + public readonly IServiceProvider ServiceProvider = serviceProvider; + + private readonly ILogger _logger = logger; + private readonly CancellationTokenSource _shutdownCts = CancellationTokenSource.CreateLinkedTokenSource(hostApplicationLifetime.ApplicationStopping); + private AsyncDuplexStreamingCall? _channel; + private Task? _readTask; + private Task? _writeTask; + + private string _clientId = Guid.NewGuid().ToString(); + private CallOptions CallOptions + { + get + { + var metadata = new Metadata + { + { "client-id", this._clientId } + }; + return new CallOptions(headers: metadata); + } + } + + public IProtoSerializationRegistry SerializationRegistry { get; } = new ProtoSerializationRegistry(); + + public void Dispose() + { + _outboundMessagesChannel.Writer.TryComplete(); + _channel?.Dispose(); + } + + private async Task RunReadPump() + { + var channel = GetChannel(); + while (!_shutdownCts.Token.IsCancellationRequested) + { + try + { + await foreach (var message in channel.ResponseStream.ReadAllAsync(_shutdownCts.Token)) + { + // next if message is null + if (message == null) + { + continue; + } + switch (message.MessageCase) + { + case Message.MessageOneofCase.Request: + var request = message.Request ?? throw new InvalidOperationException("Request is null."); + await HandleRequest(request); + break; + case Message.MessageOneofCase.Response: + var response = message.Response ?? throw new InvalidOperationException("Response is null."); + await HandleResponse(response); + break; + case Message.MessageOneofCase.CloudEvent: + var cloudEvent = message.CloudEvent ?? throw new InvalidOperationException("CloudEvent is null."); + await HandlePublish(cloudEvent); + break; + default: + throw new InvalidOperationException($"Unexpected message '{message}'."); + } + } + } + catch (OperationCanceledException) + { + // Time to shut down. + break; + } + catch (Exception ex) when (!_shutdownCts.IsCancellationRequested) + { + _logger.LogError(ex, "Error reading from channel."); + channel = RecreateChannel(channel); + } + catch + { + // Shutdown requested. + break; + } + } + } + + private async ValueTask HandleRequest(RpcRequest request, CancellationToken cancellationToken = default) + { + if (request is null) + { + throw new InvalidOperationException("Request is null."); + } + if (request.Payload is null) + { + throw new InvalidOperationException("Payload is null."); + } + if (request.Target is null) + { + throw new InvalidOperationException("Target is null."); + } + if (request.Source is null) + { + throw new InvalidOperationException("Source is null."); + } + + var agentId = request.Target; + var agent = await EnsureAgentAsync(agentId.FromProtobuf()); + + // Convert payload back to object + var payload = request.Payload; + var message = PayloadToObject(payload); + + var messageContext = new MessageContext(request.RequestId, cancellationToken) + { + Sender = request.Source.FromProtobuf(), + Topic = null, + IsRpc = true + }; + + var result = await agent.OnMessageAsync(message, messageContext); + + if (result is not null) + { + var response = new RpcResponse + { + RequestId = request.RequestId, + Payload = ObjectToPayload(result) + }; + + var responseMessage = new Message + { + Response = response + }; + + await WriteChannelAsync(responseMessage, cancellationToken); + } + } + + private async ValueTask HandleResponse(RpcResponse request, CancellationToken _ = default) + { + if (request is null) + { + throw new InvalidOperationException("Request is null."); + } + if (request.Payload is null) + { + throw new InvalidOperationException("Payload is null."); + } + if (request.RequestId is null) + { + throw new InvalidOperationException("RequestId is null."); + } + + if (_pendingRequests.TryRemove(request.RequestId, out var resultSink)) + { + var payload = request.Payload; + var message = PayloadToObject(payload); + resultSink.SetResult(message); + } + } + + private async ValueTask HandlePublish(CloudEvent evt, CancellationToken cancellationToken = default) + { + if (evt is null) + { + throw new InvalidOperationException("CloudEvent is null."); + } + if (evt.ProtoData is null) + { + throw new InvalidOperationException("ProtoData is null."); + } + if (evt.Attributes is null) + { + throw new InvalidOperationException("Attributes is null."); + } + + var topic = new TopicId(evt.Type, evt.Source); + var sender = new Contracts.AgentId + { + Type = evt.Attributes["agagentsendertype"].CeString, + Key = evt.Attributes["agagentsenderkey"].CeString + }; + + var messageId = evt.Id; + var typeName = evt.Attributes["dataschema"].CeString; + var serializer = SerializationRegistry.GetSerializer(typeName) ?? throw new Exception(); + var message = serializer.Deserialize(evt.ProtoData); + + var messageContext = new MessageContext(messageId, cancellationToken) + { + Sender = sender, + Topic = topic, + IsRpc = false + }; + var agent = await EnsureAgentAsync(sender); + await agent.OnMessageAsync(message, messageContext); + } + + private async Task RunWritePump() + { + var channel = GetChannel(); + var outboundMessages = _outboundMessagesChannel.Reader; + while (!_shutdownCts.IsCancellationRequested) + { + (Message Message, TaskCompletionSource WriteCompletionSource) item = default; + try + { + await outboundMessages.WaitToReadAsync().ConfigureAwait(false); + + // Read the next message if we don't already have an unsent message + // waiting to be sent. + if (!outboundMessages.TryRead(out item)) + { + break; + } + + while (!_shutdownCts.IsCancellationRequested) + { + await channel.RequestStream.WriteAsync(item.Message, _shutdownCts.Token).ConfigureAwait(false); + item.WriteCompletionSource.TrySetResult(); + break; + } + } + catch (OperationCanceledException) + { + // Time to shut down. + item.WriteCompletionSource?.TrySetCanceled(); + break; + } + catch (RpcException ex) when (ex.StatusCode == StatusCode.Unavailable) + { + // we could not connect to the endpoint - most likely we have the wrong port or failed ssl + // we need to let the user know what port we tried to connect to and then do backoff and retry + _logger.LogError(ex, "Error connecting to GRPC endpoint {Endpoint}.", Environment.GetEnvironmentVariable("AGENT_HOST")); + break; + } + catch (RpcException ex) when (ex.StatusCode == StatusCode.OK) + { + _logger.LogError(ex, "Error writing to channel, continuing (Status OK). {ex}", channel.ToString()); + break; + } + catch (Exception ex) when (!_shutdownCts.IsCancellationRequested) + { + item.WriteCompletionSource?.TrySetException(ex); + _logger.LogError(ex, $"Error writing to channel.{ex}"); + channel = RecreateChannel(channel); + continue; + } + catch + { + // Shutdown requested. + item.WriteCompletionSource?.TrySetCanceled(); + break; + } + } + + while (outboundMessages.TryRead(out var item)) + { + item.WriteCompletionSource.TrySetCanceled(); + } + } + + // private override async ValueTask SendMessageAsync(Payload message, AgentId agentId, AgentId? agent = null, CancellationToken? cancellationToken = default) + // { + // var request = new RpcRequest + // { + // RequestId = Guid.NewGuid().ToString(), + // Source = agent, + // Target = agentId, + // Payload = message, + // }; + + // // Actually send it and wait for the response + // throw new NotImplementedException(); + // } + + // new is intentional + + // public new async ValueTask RuntimeSendRequestAsync(IAgent agent, RpcRequest request, CancellationToken cancellationToken = default) + // { + // var requestId = Guid.NewGuid().ToString(); + // _pendingRequests[requestId] = ((Agent)agent, request.RequestId); + // request.RequestId = requestId; + // await WriteChannelAsync(new Message { Request = request }, cancellationToken).ConfigureAwait(false); + // } + + private async Task WriteChannelAsync(Message message, CancellationToken cancellationToken = default) + { + var tcs = new TaskCompletionSource(); + await _outboundMessagesChannel.Writer.WriteAsync((message, tcs), cancellationToken).ConfigureAwait(false); + } + private AsyncDuplexStreamingCall GetChannel() + { + if (_channel is { } channel) + { + return channel; + } + + lock (_channelLock) + { + if (_channel is not null) + { + return _channel; + } + + return RecreateChannel(null); + } + } + + private AsyncDuplexStreamingCall RecreateChannel(AsyncDuplexStreamingCall? channel) + { + if (_channel is null || _channel == channel) + { + lock (_channelLock) + { + if (_channel is null || _channel == channel) + { + _channel?.Dispose(); + _channel = _client.OpenChannel(cancellationToken: _shutdownCts.Token); + } + } + } + + return _channel; + } + public async Task StartAsync(CancellationToken cancellationToken) + { + _channel = GetChannel(); + _logger.LogInformation("Starting " + GetType().Name + ",connecting to gRPC endpoint " + Environment.GetEnvironmentVariable("AGENT_HOST")); + var didSuppress = false; + if (!ExecutionContext.IsFlowSuppressed()) + { + didSuppress = true; + ExecutionContext.SuppressFlow(); + } + + try + { + _readTask = Task.Run(RunReadPump, cancellationToken); + _writeTask = Task.Run(RunWritePump, cancellationToken); + } + finally + { + if (didSuppress) + { + ExecutionContext.RestoreFlow(); + } + } + } + + public async Task StopAsync(CancellationToken cancellationToken) + { + _shutdownCts.Cancel(); + + _outboundMessagesChannel.Writer.TryComplete(); + + if (_readTask is { } readTask) + { + await readTask.ConfigureAwait(false); + } + + if (_writeTask is { } writeTask) + { + await writeTask.ConfigureAwait(false); + } + lock (_channelLock) + { + _channel?.Dispose(); + } + } + + private async ValueTask EnsureAgentAsync(Contracts.AgentId agentId) + { + if (!this.agentInstances.TryGetValue(agentId, out IHostableAgent? agent)) + { + if (!this.agentFactories.TryGetValue(agentId.Type, out Func>? factoryFunc)) + { + throw new Exception($"Agent with name {agentId.Type} not found."); + } + + agent = await factoryFunc(agentId, this); + this.agentInstances.Add(agentId, agent); + } + + return this.agentInstances[agentId]; + } + + private Payload ObjectToPayload(object message) { + if (!SerializationRegistry.Exists(message.GetType())) + { + SerializationRegistry.RegisterSerializer(message.GetType()); + } + var rpcMessage = (SerializationRegistry.GetSerializer(message.GetType()) ?? throw new Exception()).Serialize(message); + + var typeName = SerializationRegistry.TypeNameResolver.ResolveTypeName(message); + const string PAYLOAD_DATA_CONTENT_TYPE = "application/x-protobuf"; + + // Protobuf any to byte array + Payload payload = new() + { + DataType = typeName, + DataContentType = PAYLOAD_DATA_CONTENT_TYPE, + Data = rpcMessage.ToByteString() + }; + + return payload; + } + + private object PayloadToObject(Payload payload) { + var typeName = payload.DataType; + var data = payload.Data; + var type = SerializationRegistry.TypeNameResolver.ResolveTypeName(typeName); + var serializer = SerializationRegistry.GetSerializer(type) ?? throw new Exception(); + var any = Google.Protobuf.WellKnownTypes.Any.Parser.ParseFrom(data); + return serializer.Deserialize(any); + } + + public async ValueTask SendMessageAsync(object message, Contracts.AgentId recepient, Contracts.AgentId? sender = null, string? messageId = null, CancellationToken cancellationToken = default) + { + if (!SerializationRegistry.Exists(message.GetType())) + { + SerializationRegistry.RegisterSerializer(message.GetType()); + } + + var payload = ObjectToPayload(message); + var request = new RpcRequest + { + RequestId = Guid.NewGuid().ToString(), + Source = (sender ?? new Contracts.AgentId() ).ToProtobuf(), + Target = recepient.ToProtobuf(), + Payload = payload, + }; + + Message msg = new() + { + Request = request + }; + // Create a future that will be completed when the response is received + var resultSink = new ResultSink(); + this._pendingRequests.TryAdd(request.RequestId, resultSink); + await WriteChannelAsync(msg, cancellationToken); + + return await resultSink.Future; + } + + private CloudEvent CreateCloudEvent(Google.Protobuf.WellKnownTypes.Any payload, TopicId topic, string dataType, Contracts.AgentId sender, string messageId) + { + const string PAYLOAD_DATA_CONTENT_TYPE = "application/x-protobuf"; + return new CloudEvent + { + ProtoData = payload, + Type = topic.Type, + Source = topic.Source, + Id = messageId, + Attributes = { + { + "datacontenttype", new CloudEvent.Types.CloudEventAttributeValue { CeString = PAYLOAD_DATA_CONTENT_TYPE } + }, + { + "dataschema", new CloudEvent.Types.CloudEventAttributeValue { CeString = dataType } + }, + { + "agagentsendertype", new CloudEvent.Types.CloudEventAttributeValue { CeString = sender.Type } + }, + { + "agagentsenderkey", new CloudEvent.Types.CloudEventAttributeValue { CeString = sender.Key } + }, + { + "agmsgkind", new CloudEvent.Types.CloudEventAttributeValue { CeString = "publish" } + } + } + }; + } + + public async ValueTask PublishMessageAsync(object message, TopicId topic, Contracts.AgentId? sender = null, string? messageId = null, CancellationToken cancellationToken = default) + { + if (!SerializationRegistry.Exists(message.GetType())) + { + SerializationRegistry.RegisterSerializer(message.GetType()); + } + var protoAny = (SerializationRegistry.GetSerializer(message.GetType()) ?? throw new Exception()).Serialize(message); + var typeName = SerializationRegistry.TypeNameResolver.ResolveTypeName(message); + + var cloudEvent = CreateCloudEvent(protoAny, topic, typeName, sender ?? new Contracts.AgentId(), messageId ?? Guid.NewGuid().ToString()); + + Message msg = new() + { + CloudEvent = cloudEvent + }; + await WriteChannelAsync(msg, cancellationToken); + } + + public ValueTask GetAgentAsync(Contracts.AgentId agentId, bool lazy = true) + { + throw new NotImplementedException(); + } + + public ValueTask GetAgentAsync(AgentType agentType, string key = "default", bool lazy = true) + { + throw new NotImplementedException(); + } + + public ValueTask GetAgentAsync(string agent, string key = "default", bool lazy = true) + { + throw new NotImplementedException(); + } + + public ValueTask> SaveAgentStateAsync(Contracts.AgentId agentId) + { + throw new NotImplementedException(); + } + + public ValueTask LoadAgentStateAsync(Contracts.AgentId agentId, IDictionary state) + { + throw new NotImplementedException(); + } + + public ValueTask GetAgentMetadataAsync(Contracts.AgentId agentId) + { + throw new NotImplementedException(); + } + + public async ValueTask AddSubscriptionAsync(ISubscriptionDefinition subscription) + { + var _ = await this._client.AddSubscriptionAsync(new AddSubscriptionRequest{ + Subscription = subscription.ToProtobuf() + },this.CallOptions); + } + + public ValueTask RemoveSubscriptionAsync(string subscriptionId) + { + throw new NotImplementedException(); + } + + public ValueTask RegisterAgentFactoryAsync(AgentType type, Func> factoryFunc) + { + if (this.agentFactories.ContainsKey(type)) + { + throw new Exception($"Agent with type {type} already exists."); + } + this.agentFactories.Add(type, async (agentId, runtime) => await factoryFunc(agentId, runtime)); + + this._client.RegisterAgentAsync(new RegisterAgentTypeRequest + { + Type = type.Name, + + }, this.CallOptions); + return ValueTask.FromResult(type); + } + + public ValueTask TryGetAgentProxyAsync(Contracts.AgentId agentId) + { + throw new NotImplementedException(); + } + + public ValueTask> SaveStateAsync() + { + throw new NotImplementedException(); + } + + public ValueTask LoadStateAsync(IDictionary state) + { + throw new NotImplementedException(); + } +} + diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentWorkerHostBuilderExtension.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentWorkerHostBuilderExtension.cs new file mode 100644 index 00000000000..7f43b9620f5 --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentWorkerHostBuilderExtension.cs @@ -0,0 +1,70 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// GrpcAgentWorkerHostBuilderExtension.cs +using System.Diagnostics; +using Grpc.Core; +using Grpc.Net.Client.Configuration; +using Microsoft.AutoGen.Contracts; +using Microsoft.AutoGen.Protobuf; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; +namespace Microsoft.AutoGen.Core.Grpc; + +public static class GrpcAgentWorkerHostBuilderExtensions +{ + private const string _defaultAgentServiceAddress = "https://localhost:53071"; + + // TODO: How do we ensure AddGrpcAgentWorker and UseInProcessRuntime are mutually exclusive? + public static AgentsAppBuilder AddGrpcAgentWorker(this AgentsAppBuilder builder, string? agentServiceAddress = null) + { + builder.Services.AddGrpcClient(options => + { + options.Address = new Uri(agentServiceAddress ?? builder.Configuration["AGENT_HOST"] ?? _defaultAgentServiceAddress); + options.ChannelOptionsActions.Add(channelOptions => + { + var loggerFactory = new LoggerFactory(); + if (Debugger.IsAttached) + { + channelOptions.HttpHandler = new SocketsHttpHandler + { + EnableMultipleHttp2Connections = false, + KeepAlivePingDelay = TimeSpan.FromSeconds(200), + KeepAlivePingTimeout = TimeSpan.FromSeconds(100), + KeepAlivePingPolicy = HttpKeepAlivePingPolicy.Always + }; + } + else + { + channelOptions.HttpHandler = new SocketsHttpHandler + { + EnableMultipleHttp2Connections = true, + KeepAlivePingDelay = TimeSpan.FromSeconds(20), + KeepAlivePingTimeout = TimeSpan.FromSeconds(10), + KeepAlivePingPolicy = HttpKeepAlivePingPolicy.WithActiveRequests + }; + } + + var methodConfig = new MethodConfig + { + Names = { MethodName.Default }, + RetryPolicy = new RetryPolicy + { + MaxAttempts = 5, + InitialBackoff = TimeSpan.FromSeconds(1), + MaxBackoff = TimeSpan.FromSeconds(5), + BackoffMultiplier = 1.5, + RetryableStatusCodes = { StatusCode.Unavailable } + } + }; + + channelOptions.ServiceConfig = new() { MethodConfigs = { methodConfig } }; + channelOptions.ThrowOperationCanceledOnCancellation = true; + }); + }); + builder.Services.TryAddSingleton(DistributedContextPropagator.Current); + builder.Services.AddSingleton(); + builder.Services.AddSingleton(sp => (IHostedService)sp.GetRequiredService()); + return builder; + } +} diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/IAgentMessageSerializer.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/IAgentMessageSerializer.cs new file mode 100644 index 00000000000..0cc422d54d8 --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/IAgentMessageSerializer.cs @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// IAgentMessageSerializer.cs + +namespace Microsoft.AutoGen.Core.Grpc; +/// +/// Interface for serializing and deserializing agent messages. +/// +public interface IAgentMessageSerializer +{ + /// + /// Serialize an agent message. + /// + /// The message to serialize. + /// The serialized message. + Google.Protobuf.WellKnownTypes.Any Serialize(object message); + + /// + /// Deserialize an agent message. + /// + /// The message to deserialize. + /// The deserialized message. + object Deserialize(Google.Protobuf.WellKnownTypes.Any message); +} \ No newline at end of file diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/IAgentRuntimeExtensions.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/IAgentRuntimeExtensions.cs new file mode 100644 index 00000000000..8179ff4b494 --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/IAgentRuntimeExtensions.cs @@ -0,0 +1,102 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// IAgentRuntimeExtensions.cs + +using System.Diagnostics; +using Google.Protobuf.Collections; +using Microsoft.AutoGen.Contracts; +using Microsoft.AutoGen.Protobuf; +using Microsoft.Extensions.DependencyInjection; +using static Microsoft.AutoGen.Contracts.CloudEvent.Types; + +namespace Microsoft.AutoGen.Core.Grpc; + +public static class GrpcAgentRuntimeExtensions +{ + public static (string?, string?) GetTraceIdAndState(GrpcAgentRuntime runtime, IDictionary metadata) + { + var dcp = runtime.ServiceProvider.GetRequiredService(); + dcp.ExtractTraceIdAndState(metadata, + static (object? carrier, string fieldName, out string? fieldValue, out IEnumerable? fieldValues) => + { + var metadata = (IDictionary)carrier!; + fieldValues = null; + metadata.TryGetValue(fieldName, out fieldValue); + }, + out var traceParent, + out var traceState); + return (traceParent, traceState); + } + public static (string?, string?) GetTraceIdAndState(GrpcAgentRuntime worker, MapField metadata) + { + var dcp = worker.ServiceProvider.GetRequiredService(); + dcp.ExtractTraceIdAndState(metadata, + static (object? carrier, string fieldName, out string? fieldValue, out IEnumerable? fieldValues) => + { + var metadata = (MapField)carrier!; + fieldValues = null; + metadata.TryGetValue(fieldName, out var ceValue); + fieldValue = ceValue?.CeString; + }, + out var traceParent, + out var traceState); + return (traceParent, traceState); + } + public static void Update(GrpcAgentRuntime worker, RpcRequest request, Activity? activity = null) + { + var dcp = worker.ServiceProvider.GetRequiredService(); + dcp.Inject(activity, request.Metadata, static (carrier, key, value) => + { + var metadata = (IDictionary)carrier!; + if (metadata.TryGetValue(key, out _)) + { + metadata[key] = value; + } + else + { + metadata.Add(key, value); + } + }); + } + public static void Update(GrpcAgentRuntime worker, CloudEvent cloudEvent, Activity? activity = null) + { + var dcp = worker.ServiceProvider.GetRequiredService(); + dcp.Inject(activity, cloudEvent.Attributes, static (carrier, key, value) => + { + var mapField = (MapField)carrier!; + if (mapField.TryGetValue(key, out var ceValue)) + { + mapField[key] = new CloudEventAttributeValue { CeString = value }; + } + else + { + mapField.Add(key, new CloudEventAttributeValue { CeString = value }); + } + }); + } + + public static IDictionary ExtractMetadata(GrpcAgentRuntime worker, IDictionary metadata) + { + var dcp = worker.ServiceProvider.GetRequiredService(); + var baggage = dcp.ExtractBaggage(metadata, static (object? carrier, string fieldName, out string? fieldValue, out IEnumerable? fieldValues) => + { + var metadata = (IDictionary)carrier!; + fieldValues = null; + metadata.TryGetValue(fieldName, out fieldValue); + }); + + return baggage as IDictionary ?? new Dictionary(); + } + public static IDictionary ExtractMetadata(GrpcAgentRuntime worker, MapField metadata) + { + var dcp = worker.ServiceProvider.GetRequiredService(); + var baggage = dcp.ExtractBaggage(metadata, static (object? carrier, string fieldName, out string? fieldValue, out IEnumerable? fieldValues) => + { + var metadata = (MapField)carrier!; + fieldValues = null; + metadata.TryGetValue(fieldName, out var ceValue); + fieldValue = ceValue?.CeString; + }); + + return baggage as IDictionary ?? new Dictionary(); + } +} diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/IProtoMessageSerializer.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/IProtoMessageSerializer.cs new file mode 100644 index 00000000000..ca690e508d2 --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/IProtoMessageSerializer.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// IProtoMessageSerializer.cs + +namespace Microsoft.AutoGen.Core.Grpc; + +public interface IProtoMessageSerializer +{ + Google.Protobuf.WellKnownTypes.Any Serialize(object input); + object Deserialize(Google.Protobuf.WellKnownTypes.Any input); +} \ No newline at end of file diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/ISerializationRegistry.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ISerializationRegistry.cs new file mode 100644 index 00000000000..190ed3ec239 --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ISerializationRegistry.cs @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ISerializationRegistry.cs + +namespace Microsoft.AutoGen.Core.Grpc; + +public interface IProtoSerializationRegistry +{ + /// + /// Registers a serializer for the specified type. + /// + /// The type to register. + void RegisterSerializer(System.Type type) => RegisterSerializer(type, new ProtobufMessageSerializer(type)); + + void RegisterSerializer(System.Type type, IProtoMessageSerializer serializer); + + /// + /// Gets the serializer for the specified type. + /// + /// The type to get the serializer for. + /// The serializer for the specified type. + IProtoMessageSerializer? GetSerializer(System.Type type) => GetSerializer(TypeNameResolver.ResolveTypeName(type)); + IProtoMessageSerializer? GetSerializer(string typeName); + + ITypeNameResolver TypeNameResolver { get; } + + bool Exists(System.Type type); +} diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/ITypeNameResolver.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ITypeNameResolver.cs new file mode 100644 index 00000000000..24de4cb8b44 --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ITypeNameResolver.cs @@ -0,0 +1,9 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ITypeNameResolver.cs + +namespace Microsoft.AutoGen.Core.Grpc; + +public interface ITypeNameResolver +{ + string ResolveTypeName(object input); +} \ No newline at end of file diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtoSerializationRegistry.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtoSerializationRegistry.cs new file mode 100644 index 00000000000..e744bcb0eee --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtoSerializationRegistry.cs @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ProtoSerializationRegistry.cs + +namespace Microsoft.AutoGen.Core.Grpc; + +public class ProtoSerializationRegistry : IProtoSerializationRegistry +{ + private readonly Dictionary _serializers + = new Dictionary(); + + public ITypeNameResolver TypeNameResolver => new ProtoTypeNameResolver(); + + public bool Exists(Type type) + { + return _serializers.ContainsKey(TypeNameResolver.ResolveTypeName(type)); + } + + public IProtoMessageSerializer? GetSerializer(Type type) + { + return GetSerializer(TypeNameResolver.ResolveTypeName(type)); + } + + public IProtoMessageSerializer? GetSerializer(string typeName) + { + _serializers.TryGetValue(typeName, out var serializer); + return serializer; + } + + public void RegisterSerializer(Type type, IProtoMessageSerializer serializer) + { + if (_serializers.ContainsKey(TypeNameResolver.ResolveTypeName(type))) + { + throw new InvalidOperationException($"Serializer already registered for {type.FullName}"); + } + _serializers[TypeNameResolver.ResolveTypeName(type)] = serializer; + } +} diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtoTypeNameResolver.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtoTypeNameResolver.cs new file mode 100644 index 00000000000..a769b0f31c8 --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtoTypeNameResolver.cs @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ProtoTypeNameResolver.cs + +using Google.Protobuf; + +namespace Microsoft.AutoGen.Core.Grpc; + +public class ProtoTypeNameResolver : ITypeNameResolver +{ + public string ResolveTypeName(object input) + { + if (input is IMessage protoMessage) + { + return protoMessage.Descriptor.FullName; + } + else + { + throw new ArgumentException("Input must be a protobuf message."); + } + } +} diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtobufConversionExtensions.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtobufConversionExtensions.cs new file mode 100644 index 00000000000..4850b7825af --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtobufConversionExtensions.cs @@ -0,0 +1,61 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ProtobufConversionExtensions.cs + +using Microsoft.AutoGen.Contracts; +using Microsoft.AutoGen.Protobuf; + +namespace Microsoft.AutoGen.Core.Grpc; + +public static class ProtobufConversionExtensions +{ + // Convert an ISubscrptionDefinition to a Protobuf Subscription + public static Subscription? ToProtobuf(this ISubscriptionDefinition subscriptionDefinition) + { + // Check if is a TypeSubscription + if (subscriptionDefinition is Contracts.TypeSubscription typeSubscription) + { + return new Subscription + { + Id = typeSubscription.Id, + TypeSubscription = new Protobuf.TypeSubscription + { + TopicType = typeSubscription.TopicType, + AgentType = typeSubscription.AgentType + } + }; + } + + // Check if is a TypePrefixSubscription + if (subscriptionDefinition is Contracts.TypePrefixSubscription typePrefixSubscription) + { + return new Subscription + { + Id = typePrefixSubscription.Id, + TypePrefixSubscription = new Protobuf.TypePrefixSubscription + { + TopicTypePrefix = typePrefixSubscription.TopicTypePrefix, + AgentType = typePrefixSubscription.AgentType + } + }; + } + + return null; + } + + // Convert AgentId from Protobuf to AgentId + public static Contracts.AgentId FromProtobuf(this Protobuf.AgentId agentId) + { + return new Contracts.AgentId(agentId.Type, agentId.Key); + } + + // Convert AgentId from AgentId to Protobuf + public static Protobuf.AgentId ToProtobuf(this Contracts.AgentId agentId) + { + return new Protobuf.AgentId + { + Type = agentId.Type, + Key = agentId.Key + }; + } + +} \ No newline at end of file diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtobufMessageSerializer.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtobufMessageSerializer.cs new file mode 100644 index 00000000000..55c1aebfa47 --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtobufMessageSerializer.cs @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ProtobufMessageSerializer.cs + +using Google.Protobuf; +using Google.Protobuf.WellKnownTypes; + +namespace Microsoft.AutoGen.Core.Grpc; + +/// +/// Interface for serializing and deserializing agent messages. +/// +public class ProtobufMessageSerializer : IProtoMessageSerializer +{ + private System.Type _concreteType; + + public ProtobufMessageSerializer(System.Type concreteType) + { + _concreteType = concreteType; + } + + public object Deserialize(Any message) + { + // Check if the concrete type is a proto IMessage + if (typeof(IMessage).IsAssignableFrom(_concreteType)) + { + var nameOfMethod = nameof(Any.Unpack); + var result = message.GetType().GetMethods().Where(m => m.Name == nameOfMethod && m.IsGenericMethod).First().MakeGenericMethod(_concreteType).Invoke(message, null); + return result as IMessage ?? throw new ArgumentException("Failed to deserialize", nameof(message)); + } + + // Raise an exception if the concrete type is not a proto IMessage + throw new ArgumentException("Concrete type must be a proto IMessage", nameof(_concreteType)); + } + + public Any Serialize(object message) + { + // Check if message is a proto IMessage + if (message is IMessage protoMessage) + { + return Any.Pack(protoMessage); + } + + // Raise an exception if the message is not a proto IMessage + throw new ArgumentException("Message must be a proto IMessage", nameof(message)); + } +} \ No newline at end of file diff --git a/dotnet/src/Microsoft.AutoGen/Core/AgentsApp.cs b/dotnet/src/Microsoft.AutoGen/Core/AgentsApp.cs index bae09a9f191..cecd8d9ec48 100644 --- a/dotnet/src/Microsoft.AutoGen/Core/AgentsApp.cs +++ b/dotnet/src/Microsoft.AutoGen/Core/AgentsApp.cs @@ -4,6 +4,7 @@ using System.Diagnostics; using System.Reflection; using Microsoft.AutoGen.Contracts; +using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; @@ -21,6 +22,7 @@ public AgentsAppBuilder(HostApplicationBuilder? baseBuilder = null) } public IServiceCollection Services => this.builder.Services; + public IConfiguration Configuration => this.builder.Configuration; public void AddAgentsFromAssemblies() { From 593273491f557bfe33f41602561a69711e6c8bfc Mon Sep 17 00:00:00 2001 From: Jacob Alber Date: Wed, 29 Jan 2025 23:03:26 -0500 Subject: [PATCH 2/8] refactor: Extract Message channel logic to MessageRouter --- .../Core.Grpc/GrpcAgentRuntime.cs | 509 +++++++++++------- 1 file changed, 312 insertions(+), 197 deletions(-) diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs index 5deba58ae62..1797182a970 100644 --- a/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs @@ -12,89 +12,130 @@ namespace Microsoft.AutoGen.Core.Grpc; -public sealed class GrpcAgentRuntime( - AgentRpc.AgentRpcClient client, - IHostApplicationLifetime hostApplicationLifetime, - IServiceProvider serviceProvider, - ILogger logger - ) : IAgentRuntime, IDisposable +// TODO: Consider whether we want to just reuse IHandle +internal interface IMessageSink { - private readonly object _channelLock = new(); + public ValueTask OnMessageAsync(TMessage message, CancellationToken cancellation = default); +} - // Request ID -> - private readonly ConcurrentDictionary> _pendingRequests = new(); - private Dictionary>> agentFactories = new(); - private Dictionary agentInstances = new(); +internal sealed class AutoRestartChannel : IDisposable +{ + private readonly object _channelLock = new(); + private readonly AgentRpc.AgentRpcClient _client; + private readonly ILogger _logger; + private readonly CancellationTokenSource _shutdownCts; + private AsyncDuplexStreamingCall? _channel; - private readonly Channel<(Message Message, TaskCompletionSource WriteCompletionSource)> _outboundMessagesChannel = Channel.CreateBounded<(Message, TaskCompletionSource)>(new BoundedChannelOptions(1024) + public AutoRestartChannel(AgentRpc.AgentRpcClient client, + ILogger logger, + CancellationToken shutdownCancellation = default) { - AllowSynchronousContinuations = true, - SingleReader = true, - SingleWriter = false, - FullMode = BoundedChannelFullMode.Wait - }); + _client = client; + _logger = logger; + _shutdownCts = CancellationTokenSource.CreateLinkedTokenSource(shutdownCancellation); + } - private readonly AgentRpc.AgentRpcClient _client = client; - public readonly IServiceProvider ServiceProvider = serviceProvider; + public void EnsureConnected() + { + _logger.LogInformation("Connecting to gRPC endpoint " + Environment.GetEnvironmentVariable("AGENT_HOST")); - private readonly ILogger _logger = logger; - private readonly CancellationTokenSource _shutdownCts = CancellationTokenSource.CreateLinkedTokenSource(hostApplicationLifetime.ApplicationStopping); - private AsyncDuplexStreamingCall? _channel; - private Task? _readTask; - private Task? _writeTask; + if (this.RecreateChannel(null) == null) + { + throw new Exception("Failed to connect to gRPC endpoint."); + }; + } - private string _clientId = Guid.NewGuid().ToString(); - private CallOptions CallOptions + public AsyncDuplexStreamingCall StreamingCall { get { - var metadata = new Metadata + if (_channel is { } channel) { - { "client-id", this._clientId } - }; - return new CallOptions(headers: metadata); + return channel; + } + + lock (_channelLock) + { + if (_channel is not null) + { + return _channel; + } + + return RecreateChannel(null); + } } } - public IProtoSerializationRegistry SerializationRegistry { get; } = new ProtoSerializationRegistry(); + public AsyncDuplexStreamingCall RecreateChannel() => RecreateChannel(this._channel); + + private AsyncDuplexStreamingCall RecreateChannel(AsyncDuplexStreamingCall? ownedChannel) + { + // Make sure we are only re-creating the channel if it does not exit or we are the owner. + if (_channel is null || _channel == ownedChannel) + { + lock (_channelLock) + { + if (_channel is null || _channel == ownedChannel) + { + _channel?.Dispose(); + _channel = _client.OpenChannel(cancellationToken: _shutdownCts.Token); + } + } + } + + return _channel; + } public void Dispose() { - _outboundMessagesChannel.Writer.TryComplete(); - _channel?.Dispose(); + IDisposable? channelDisposable = Interlocked.Exchange(ref this._channel, null); + channelDisposable?.Dispose(); } +} + +internal sealed class MessageRouter(AgentRpc.AgentRpcClient client, + IMessageSink incomingMessageSink, + ILogger logger, + CancellationToken shutdownCancellation = default) : IDisposable +{ + private static readonly BoundedChannelOptions DefaultChannelOptions = new BoundedChannelOptions(1024) + { + AllowSynchronousContinuations = true, + SingleReader = true, + SingleWriter = false, + FullMode = BoundedChannelFullMode.Wait + }; + + private readonly ILogger _logger = logger; + + private readonly CancellationTokenSource _shutdownCts = CancellationTokenSource.CreateLinkedTokenSource(shutdownCancellation); + + private readonly IMessageSink _incomingMessageSink = incomingMessageSink; + private readonly Channel<(Message Message, TaskCompletionSource WriteCompletionSource)> _outboundMessagesChannel + // TODO: Enable a way to configure the channel options + = Channel.CreateBounded<(Message, TaskCompletionSource)>(DefaultChannelOptions); + + private readonly AutoRestartChannel _incomingMessageChannel = new AutoRestartChannel(client, logger, shutdownCancellation); + + private Task? _readTask; + private Task? _writeTask; private async Task RunReadPump() { - var channel = GetChannel(); + var cachedChannel = _incomingMessageChannel.StreamingCall; while (!_shutdownCts.Token.IsCancellationRequested) { try { - await foreach (var message in channel.ResponseStream.ReadAllAsync(_shutdownCts.Token)) + await foreach (var message in cachedChannel.ResponseStream.ReadAllAsync(_shutdownCts.Token)) { // next if message is null if (message == null) { continue; } - switch (message.MessageCase) - { - case Message.MessageOneofCase.Request: - var request = message.Request ?? throw new InvalidOperationException("Request is null."); - await HandleRequest(request); - break; - case Message.MessageOneofCase.Response: - var response = message.Response ?? throw new InvalidOperationException("Response is null."); - await HandleResponse(response); - break; - case Message.MessageOneofCase.CloudEvent: - var cloudEvent = message.CloudEvent ?? throw new InvalidOperationException("CloudEvent is null."); - await HandlePublish(cloudEvent); - break; - default: - throw new InvalidOperationException($"Unexpected message '{message}'."); - } + + await _incomingMessageSink.OnMessageAsync(message, _shutdownCts.Token); } } catch (OperationCanceledException) @@ -105,14 +146,199 @@ private async Task RunReadPump() catch (Exception ex) when (!_shutdownCts.IsCancellationRequested) { _logger.LogError(ex, "Error reading from channel."); - channel = RecreateChannel(channel); + cachedChannel = this._incomingMessageChannel.RecreateChannel(); + } + catch + { + // Shutdown requested. + break; + } + } + } + + private async Task RunWritePump() + { + var cachedChannel = this._incomingMessageChannel.StreamingCall; + var outboundMessages = _outboundMessagesChannel.Reader; + while (!_shutdownCts.IsCancellationRequested) + { + (Message Message, TaskCompletionSource WriteCompletionSource) item = default; + try + { + await outboundMessages.WaitToReadAsync().ConfigureAwait(false); + + // Read the next message if we don't already have an unsent message + // waiting to be sent. + if (!outboundMessages.TryRead(out item)) + { + break; + } + + while (!_shutdownCts.IsCancellationRequested) + { + await cachedChannel.RequestStream.WriteAsync(item.Message, _shutdownCts.Token).ConfigureAwait(false); + item.WriteCompletionSource.TrySetResult(); + break; + } + } + catch (OperationCanceledException) + { + // Time to shut down. + item.WriteCompletionSource?.TrySetCanceled(); + break; + } + catch (RpcException ex) when (ex.StatusCode == StatusCode.Unavailable) + { + // we could not connect to the endpoint - most likely we have the wrong port or failed ssl + // we need to let the user know what port we tried to connect to and then do backoff and retry + _logger.LogError(ex, "Error connecting to GRPC endpoint {Endpoint}.", Environment.GetEnvironmentVariable("AGENT_HOST")); + break; + } + catch (RpcException ex) when (ex.StatusCode == StatusCode.OK) + { + _logger.LogError(ex, "Error writing to channel, continuing (Status OK). {ex}", cachedChannel.ToString()); + break; + } + catch (Exception ex) when (!_shutdownCts.IsCancellationRequested) + { + item.WriteCompletionSource?.TrySetException(ex); + _logger.LogError(ex, $"Error writing to channel.{ex}"); + cachedChannel = this._incomingMessageChannel.RecreateChannel(); + continue; } catch { // Shutdown requested. + item.WriteCompletionSource?.TrySetCanceled(); break; } } + + while (outboundMessages.TryRead(out var item)) + { + item.WriteCompletionSource.TrySetCanceled(); + } + } + + public ValueTask RouteMessageAsync(Message message, CancellationToken cancellation = default) + { + var tcs = new TaskCompletionSource(); + return _outboundMessagesChannel.Writer.WriteAsync((message, tcs), cancellation); + } + + public ValueTask StartAsync(CancellationToken cancellation) + { + // TODO: Should we error out on a noncancellable token? + + this._incomingMessageChannel.EnsureConnected(); + var didSuppress = false; + + // Make sure we do not mistakenly flow the ExecutionContext into the background pumping tasks. + if (!ExecutionContext.IsFlowSuppressed()) + { + didSuppress = true; + ExecutionContext.SuppressFlow(); + } + + try + { + _readTask = Task.Run(RunReadPump, cancellation); + _writeTask = Task.Run(RunWritePump, cancellation); + + return ValueTask.CompletedTask; + } + catch (Exception ex) + { + return ValueTask.FromException(ex); + } + finally + { + if (didSuppress) + { + ExecutionContext.RestoreFlow(); + } + } + } + + // No point in returning a ValueTask here, since we are awaiting the two tasks + public async Task StopAsync() + { + _shutdownCts.Cancel(); + + _outboundMessagesChannel.Writer.TryComplete(); + + List pendingTasks = new(); + if (_readTask is { } readTask) + { + pendingTasks.Add(readTask); + } + + if (_writeTask is { } writeTask) + { + pendingTasks.Add(writeTask); + } + + await Task.WhenAll(pendingTasks).ConfigureAwait(false); + + this._incomingMessageChannel.Dispose(); + } + + public void Dispose() + { + _outboundMessagesChannel.Writer.TryComplete(); + this._incomingMessageChannel.Dispose(); + } +} + +public sealed class GrpcAgentRuntime: IHostedService, IAgentRuntime, IMessageSink, IDisposable +{ + public GrpcAgentRuntime(AgentRpc.AgentRpcClient client, + IHostApplicationLifetime hostApplicationLifetime, + IServiceProvider serviceProvider, + ILogger logger) + { + this._client = client; + this._logger = logger; + this._shutdownCts = CancellationTokenSource.CreateLinkedTokenSource(hostApplicationLifetime.ApplicationStopping); + + this._messageRouter = new MessageRouter(client, this, logger, this._shutdownCts.Token); + + this.ServiceProvider = serviceProvider; + } + + // Request ID -> + private readonly ConcurrentDictionary> _pendingRequests = new(); + + private Dictionary>> agentFactories = new(); + private Dictionary agentInstances = new(); + + private readonly AgentRpc.AgentRpcClient _client; + private readonly MessageRouter _messageRouter; + + private readonly ILogger _logger; + private readonly CancellationTokenSource _shutdownCts; + + public IServiceProvider ServiceProvider { get; } + + private string _clientId = Guid.NewGuid().ToString(); + private CallOptions CallOptions + { + get + { + var metadata = new Metadata + { + { "client-id", this._clientId } + }; + return new CallOptions(headers: metadata); + } + } + + public IProtoSerializationRegistry SerializationRegistry { get; } = new ProtoSerializationRegistry(); + + public void Dispose() + { + this._shutdownCts.Cancel(); + this._messageRouter.Dispose(); } private async ValueTask HandleRequest(RpcRequest request, CancellationToken cancellationToken = default) @@ -163,7 +389,7 @@ private async ValueTask HandleRequest(RpcRequest request, CancellationToken canc Response = response }; - await WriteChannelAsync(responseMessage, cancellationToken); + await this._messageRouter.RouteMessageAsync(responseMessage, cancellationToken); } } @@ -227,69 +453,7 @@ private async ValueTask HandlePublish(CloudEvent evt, CancellationToken cancella await agent.OnMessageAsync(message, messageContext); } - private async Task RunWritePump() - { - var channel = GetChannel(); - var outboundMessages = _outboundMessagesChannel.Reader; - while (!_shutdownCts.IsCancellationRequested) - { - (Message Message, TaskCompletionSource WriteCompletionSource) item = default; - try - { - await outboundMessages.WaitToReadAsync().ConfigureAwait(false); - - // Read the next message if we don't already have an unsent message - // waiting to be sent. - if (!outboundMessages.TryRead(out item)) - { - break; - } - - while (!_shutdownCts.IsCancellationRequested) - { - await channel.RequestStream.WriteAsync(item.Message, _shutdownCts.Token).ConfigureAwait(false); - item.WriteCompletionSource.TrySetResult(); - break; - } - } - catch (OperationCanceledException) - { - // Time to shut down. - item.WriteCompletionSource?.TrySetCanceled(); - break; - } - catch (RpcException ex) when (ex.StatusCode == StatusCode.Unavailable) - { - // we could not connect to the endpoint - most likely we have the wrong port or failed ssl - // we need to let the user know what port we tried to connect to and then do backoff and retry - _logger.LogError(ex, "Error connecting to GRPC endpoint {Endpoint}.", Environment.GetEnvironmentVariable("AGENT_HOST")); - break; - } - catch (RpcException ex) when (ex.StatusCode == StatusCode.OK) - { - _logger.LogError(ex, "Error writing to channel, continuing (Status OK). {ex}", channel.ToString()); - break; - } - catch (Exception ex) when (!_shutdownCts.IsCancellationRequested) - { - item.WriteCompletionSource?.TrySetException(ex); - _logger.LogError(ex, $"Error writing to channel.{ex}"); - channel = RecreateChannel(channel); - continue; - } - catch - { - // Shutdown requested. - item.WriteCompletionSource?.TrySetCanceled(); - break; - } - } - - while (outboundMessages.TryRead(out var item)) - { - item.WriteCompletionSource.TrySetCanceled(); - } - } + // private override async ValueTask SendMessageAsync(Payload message, AgentId agentId, AgentId? agent = null, CancellationToken? cancellationToken = default) // { @@ -315,89 +479,17 @@ private async Task RunWritePump() // await WriteChannelAsync(new Message { Request = request }, cancellationToken).ConfigureAwait(false); // } - private async Task WriteChannelAsync(Message message, CancellationToken cancellationToken = default) + + public ValueTask StartAsync(CancellationToken cancellationToken) { - var tcs = new TaskCompletionSource(); - await _outboundMessagesChannel.Writer.WriteAsync((message, tcs), cancellationToken).ConfigureAwait(false); + return this._messageRouter.StartAsync(cancellationToken); } - private AsyncDuplexStreamingCall GetChannel() - { - if (_channel is { } channel) - { - return channel; - } - lock (_channelLock) - { - if (_channel is not null) - { - return _channel; - } + Task IHostedService.StartAsync(CancellationToken cancellationToken) => this._messageRouter.StartAsync(cancellationToken).AsTask(); - return RecreateChannel(null); - } - } - - private AsyncDuplexStreamingCall RecreateChannel(AsyncDuplexStreamingCall? channel) + public Task StopAsync(CancellationToken cancellationToken) { - if (_channel is null || _channel == channel) - { - lock (_channelLock) - { - if (_channel is null || _channel == channel) - { - _channel?.Dispose(); - _channel = _client.OpenChannel(cancellationToken: _shutdownCts.Token); - } - } - } - - return _channel; - } - public async Task StartAsync(CancellationToken cancellationToken) - { - _channel = GetChannel(); - _logger.LogInformation("Starting " + GetType().Name + ",connecting to gRPC endpoint " + Environment.GetEnvironmentVariable("AGENT_HOST")); - var didSuppress = false; - if (!ExecutionContext.IsFlowSuppressed()) - { - didSuppress = true; - ExecutionContext.SuppressFlow(); - } - - try - { - _readTask = Task.Run(RunReadPump, cancellationToken); - _writeTask = Task.Run(RunWritePump, cancellationToken); - } - finally - { - if (didSuppress) - { - ExecutionContext.RestoreFlow(); - } - } - } - - public async Task StopAsync(CancellationToken cancellationToken) - { - _shutdownCts.Cancel(); - - _outboundMessagesChannel.Writer.TryComplete(); - - if (_readTask is { } readTask) - { - await readTask.ConfigureAwait(false); - } - - if (_writeTask is { } writeTask) - { - await writeTask.ConfigureAwait(false); - } - lock (_channelLock) - { - _channel?.Dispose(); - } + return this._messageRouter.StopAsync(); } private async ValueTask EnsureAgentAsync(Contracts.AgentId agentId) @@ -466,10 +558,11 @@ private object PayloadToObject(Payload payload) { { Request = request }; + // Create a future that will be completed when the response is received var resultSink = new ResultSink(); this._pendingRequests.TryAdd(request.RequestId, resultSink); - await WriteChannelAsync(msg, cancellationToken); + await this._messageRouter.RouteMessageAsync(msg, cancellationToken); return await resultSink.Future; } @@ -518,7 +611,8 @@ public async ValueTask PublishMessageAsync(object message, TopicId topic, Contra { CloudEvent = cloudEvent }; - await WriteChannelAsync(msg, cancellationToken); + + await this._messageRouter.RouteMessageAsync(msg, cancellationToken); } public ValueTask GetAgentAsync(Contracts.AgentId agentId, bool lazy = true) @@ -593,5 +687,26 @@ public ValueTask LoadStateAsync(IDictionary state) { throw new NotImplementedException(); } + + public async ValueTask OnMessageAsync(Message message, CancellationToken cancellation = default) + { + switch (message.MessageCase) + { + case Message.MessageOneofCase.Request: + var request = message.Request ?? throw new InvalidOperationException("Request is null."); + await HandleRequest(request); + break; + case Message.MessageOneofCase.Response: + var response = message.Response ?? throw new InvalidOperationException("Response is null."); + await HandleResponse(response); + break; + case Message.MessageOneofCase.CloudEvent: + var cloudEvent = message.CloudEvent ?? throw new InvalidOperationException("CloudEvent is null."); + await HandlePublish(cloudEvent); + break; + default: + throw new InvalidOperationException($"Unexpected message '{message}'."); + } + } } From 8410bfac1e34fd4914ddb47f552c159168b97e28 Mon Sep 17 00:00:00 2001 From: Jacob Alber Date: Wed, 29 Jan 2025 23:05:20 -0500 Subject: [PATCH 3/8] refactor: Move GrpcMessageRouter to own file --- .../Core.Grpc/GrpcAgentRuntime.cs | 281 +---------------- .../Core.Grpc/GrpcMessageRouter.cs | 288 ++++++++++++++++++ 2 files changed, 290 insertions(+), 279 deletions(-) create mode 100644 dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcMessageRouter.cs diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs index 1797182a970..b1f5d43669e 100644 --- a/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs @@ -2,7 +2,6 @@ // GrpcAgentRuntime.cs using System.Collections.Concurrent; -using System.Threading.Channels; using Google.Protobuf; using Grpc.Core; using Microsoft.AutoGen.Contracts; @@ -12,283 +11,7 @@ namespace Microsoft.AutoGen.Core.Grpc; -// TODO: Consider whether we want to just reuse IHandle -internal interface IMessageSink -{ - public ValueTask OnMessageAsync(TMessage message, CancellationToken cancellation = default); -} - -internal sealed class AutoRestartChannel : IDisposable -{ - private readonly object _channelLock = new(); - private readonly AgentRpc.AgentRpcClient _client; - private readonly ILogger _logger; - private readonly CancellationTokenSource _shutdownCts; - private AsyncDuplexStreamingCall? _channel; - - public AutoRestartChannel(AgentRpc.AgentRpcClient client, - ILogger logger, - CancellationToken shutdownCancellation = default) - { - _client = client; - _logger = logger; - _shutdownCts = CancellationTokenSource.CreateLinkedTokenSource(shutdownCancellation); - } - - public void EnsureConnected() - { - _logger.LogInformation("Connecting to gRPC endpoint " + Environment.GetEnvironmentVariable("AGENT_HOST")); - - if (this.RecreateChannel(null) == null) - { - throw new Exception("Failed to connect to gRPC endpoint."); - }; - } - - public AsyncDuplexStreamingCall StreamingCall - { - get - { - if (_channel is { } channel) - { - return channel; - } - - lock (_channelLock) - { - if (_channel is not null) - { - return _channel; - } - - return RecreateChannel(null); - } - } - } - - public AsyncDuplexStreamingCall RecreateChannel() => RecreateChannel(this._channel); - - private AsyncDuplexStreamingCall RecreateChannel(AsyncDuplexStreamingCall? ownedChannel) - { - // Make sure we are only re-creating the channel if it does not exit or we are the owner. - if (_channel is null || _channel == ownedChannel) - { - lock (_channelLock) - { - if (_channel is null || _channel == ownedChannel) - { - _channel?.Dispose(); - _channel = _client.OpenChannel(cancellationToken: _shutdownCts.Token); - } - } - } - - return _channel; - } - - public void Dispose() - { - IDisposable? channelDisposable = Interlocked.Exchange(ref this._channel, null); - channelDisposable?.Dispose(); - } -} - -internal sealed class MessageRouter(AgentRpc.AgentRpcClient client, - IMessageSink incomingMessageSink, - ILogger logger, - CancellationToken shutdownCancellation = default) : IDisposable -{ - private static readonly BoundedChannelOptions DefaultChannelOptions = new BoundedChannelOptions(1024) - { - AllowSynchronousContinuations = true, - SingleReader = true, - SingleWriter = false, - FullMode = BoundedChannelFullMode.Wait - }; - - private readonly ILogger _logger = logger; - - private readonly CancellationTokenSource _shutdownCts = CancellationTokenSource.CreateLinkedTokenSource(shutdownCancellation); - - private readonly IMessageSink _incomingMessageSink = incomingMessageSink; - private readonly Channel<(Message Message, TaskCompletionSource WriteCompletionSource)> _outboundMessagesChannel - // TODO: Enable a way to configure the channel options - = Channel.CreateBounded<(Message, TaskCompletionSource)>(DefaultChannelOptions); - - private readonly AutoRestartChannel _incomingMessageChannel = new AutoRestartChannel(client, logger, shutdownCancellation); - - private Task? _readTask; - private Task? _writeTask; - - private async Task RunReadPump() - { - var cachedChannel = _incomingMessageChannel.StreamingCall; - while (!_shutdownCts.Token.IsCancellationRequested) - { - try - { - await foreach (var message in cachedChannel.ResponseStream.ReadAllAsync(_shutdownCts.Token)) - { - // next if message is null - if (message == null) - { - continue; - } - - await _incomingMessageSink.OnMessageAsync(message, _shutdownCts.Token); - } - } - catch (OperationCanceledException) - { - // Time to shut down. - break; - } - catch (Exception ex) when (!_shutdownCts.IsCancellationRequested) - { - _logger.LogError(ex, "Error reading from channel."); - cachedChannel = this._incomingMessageChannel.RecreateChannel(); - } - catch - { - // Shutdown requested. - break; - } - } - } - - private async Task RunWritePump() - { - var cachedChannel = this._incomingMessageChannel.StreamingCall; - var outboundMessages = _outboundMessagesChannel.Reader; - while (!_shutdownCts.IsCancellationRequested) - { - (Message Message, TaskCompletionSource WriteCompletionSource) item = default; - try - { - await outboundMessages.WaitToReadAsync().ConfigureAwait(false); - - // Read the next message if we don't already have an unsent message - // waiting to be sent. - if (!outboundMessages.TryRead(out item)) - { - break; - } - while (!_shutdownCts.IsCancellationRequested) - { - await cachedChannel.RequestStream.WriteAsync(item.Message, _shutdownCts.Token).ConfigureAwait(false); - item.WriteCompletionSource.TrySetResult(); - break; - } - } - catch (OperationCanceledException) - { - // Time to shut down. - item.WriteCompletionSource?.TrySetCanceled(); - break; - } - catch (RpcException ex) when (ex.StatusCode == StatusCode.Unavailable) - { - // we could not connect to the endpoint - most likely we have the wrong port or failed ssl - // we need to let the user know what port we tried to connect to and then do backoff and retry - _logger.LogError(ex, "Error connecting to GRPC endpoint {Endpoint}.", Environment.GetEnvironmentVariable("AGENT_HOST")); - break; - } - catch (RpcException ex) when (ex.StatusCode == StatusCode.OK) - { - _logger.LogError(ex, "Error writing to channel, continuing (Status OK). {ex}", cachedChannel.ToString()); - break; - } - catch (Exception ex) when (!_shutdownCts.IsCancellationRequested) - { - item.WriteCompletionSource?.TrySetException(ex); - _logger.LogError(ex, $"Error writing to channel.{ex}"); - cachedChannel = this._incomingMessageChannel.RecreateChannel(); - continue; - } - catch - { - // Shutdown requested. - item.WriteCompletionSource?.TrySetCanceled(); - break; - } - } - - while (outboundMessages.TryRead(out var item)) - { - item.WriteCompletionSource.TrySetCanceled(); - } - } - - public ValueTask RouteMessageAsync(Message message, CancellationToken cancellation = default) - { - var tcs = new TaskCompletionSource(); - return _outboundMessagesChannel.Writer.WriteAsync((message, tcs), cancellation); - } - - public ValueTask StartAsync(CancellationToken cancellation) - { - // TODO: Should we error out on a noncancellable token? - - this._incomingMessageChannel.EnsureConnected(); - var didSuppress = false; - - // Make sure we do not mistakenly flow the ExecutionContext into the background pumping tasks. - if (!ExecutionContext.IsFlowSuppressed()) - { - didSuppress = true; - ExecutionContext.SuppressFlow(); - } - - try - { - _readTask = Task.Run(RunReadPump, cancellation); - _writeTask = Task.Run(RunWritePump, cancellation); - - return ValueTask.CompletedTask; - } - catch (Exception ex) - { - return ValueTask.FromException(ex); - } - finally - { - if (didSuppress) - { - ExecutionContext.RestoreFlow(); - } - } - } - - // No point in returning a ValueTask here, since we are awaiting the two tasks - public async Task StopAsync() - { - _shutdownCts.Cancel(); - - _outboundMessagesChannel.Writer.TryComplete(); - - List pendingTasks = new(); - if (_readTask is { } readTask) - { - pendingTasks.Add(readTask); - } - - if (_writeTask is { } writeTask) - { - pendingTasks.Add(writeTask); - } - - await Task.WhenAll(pendingTasks).ConfigureAwait(false); - - this._incomingMessageChannel.Dispose(); - } - - public void Dispose() - { - _outboundMessagesChannel.Writer.TryComplete(); - this._incomingMessageChannel.Dispose(); - } -} public sealed class GrpcAgentRuntime: IHostedService, IAgentRuntime, IMessageSink, IDisposable { @@ -301,7 +24,7 @@ public GrpcAgentRuntime(AgentRpc.AgentRpcClient client, this._logger = logger; this._shutdownCts = CancellationTokenSource.CreateLinkedTokenSource(hostApplicationLifetime.ApplicationStopping); - this._messageRouter = new MessageRouter(client, this, logger, this._shutdownCts.Token); + this._messageRouter = new GrpcMessageRouter(client, this, logger, this._shutdownCts.Token); this.ServiceProvider = serviceProvider; } @@ -313,7 +36,7 @@ public GrpcAgentRuntime(AgentRpc.AgentRpcClient client, private Dictionary agentInstances = new(); private readonly AgentRpc.AgentRpcClient _client; - private readonly MessageRouter _messageRouter; + private readonly GrpcMessageRouter _messageRouter; private readonly ILogger _logger; private readonly CancellationTokenSource _shutdownCts; diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcMessageRouter.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcMessageRouter.cs new file mode 100644 index 00000000000..0ecf8ee1413 --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcMessageRouter.cs @@ -0,0 +1,288 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// GrpcMessageRouter.cs + +using System.Threading.Channels; +using Grpc.Core; +using Microsoft.Extensions.Logging; +using Microsoft.AutoGen.Protobuf; + +namespace Microsoft.AutoGen.Core.Grpc; + +// TODO: Consider whether we want to just reuse IHandle +internal interface IMessageSink +{ + public ValueTask OnMessageAsync(TMessage message, CancellationToken cancellation = default); +} + +internal sealed class AutoRestartChannel : IDisposable +{ + private readonly object _channelLock = new(); + private readonly AgentRpc.AgentRpcClient _client; + private readonly ILogger _logger; + private readonly CancellationTokenSource _shutdownCts; + private AsyncDuplexStreamingCall? _channel; + + public AutoRestartChannel(AgentRpc.AgentRpcClient client, + ILogger logger, + CancellationToken shutdownCancellation = default) + { + _client = client; + _logger = logger; + _shutdownCts = CancellationTokenSource.CreateLinkedTokenSource(shutdownCancellation); + } + + public void EnsureConnected() + { + _logger.LogInformation("Connecting to gRPC endpoint " + Environment.GetEnvironmentVariable("AGENT_HOST")); + + if (this.RecreateChannel(null) == null) + { + throw new Exception("Failed to connect to gRPC endpoint."); + }; + } + + public AsyncDuplexStreamingCall StreamingCall + { + get + { + if (_channel is { } channel) + { + return channel; + } + + lock (_channelLock) + { + if (_channel is not null) + { + return _channel; + } + + return RecreateChannel(null); + } + } + } + + public AsyncDuplexStreamingCall RecreateChannel() => RecreateChannel(this._channel); + + private AsyncDuplexStreamingCall RecreateChannel(AsyncDuplexStreamingCall? ownedChannel) + { + // Make sure we are only re-creating the channel if it does not exit or we are the owner. + if (_channel is null || _channel == ownedChannel) + { + lock (_channelLock) + { + if (_channel is null || _channel == ownedChannel) + { + _channel?.Dispose(); + _channel = _client.OpenChannel(cancellationToken: _shutdownCts.Token); + } + } + } + + return _channel; + } + + public void Dispose() + { + IDisposable? channelDisposable = Interlocked.Exchange(ref this._channel, null); + channelDisposable?.Dispose(); + } +} + +internal sealed class GrpcMessageRouter(AgentRpc.AgentRpcClient client, + IMessageSink incomingMessageSink, + ILogger logger, + CancellationToken shutdownCancellation = default) : IDisposable +{ + private static readonly BoundedChannelOptions DefaultChannelOptions = new BoundedChannelOptions(1024) + { + AllowSynchronousContinuations = true, + SingleReader = true, + SingleWriter = false, + FullMode = BoundedChannelFullMode.Wait + }; + + private readonly ILogger _logger = logger; + + private readonly CancellationTokenSource _shutdownCts = CancellationTokenSource.CreateLinkedTokenSource(shutdownCancellation); + + private readonly IMessageSink _incomingMessageSink = incomingMessageSink; + private readonly Channel<(Message Message, TaskCompletionSource WriteCompletionSource)> _outboundMessagesChannel + // TODO: Enable a way to configure the channel options + = Channel.CreateBounded<(Message, TaskCompletionSource)>(DefaultChannelOptions); + + private readonly AutoRestartChannel _incomingMessageChannel = new AutoRestartChannel(client, logger, shutdownCancellation); + + private Task? _readTask; + private Task? _writeTask; + + private async Task RunReadPump() + { + var cachedChannel = _incomingMessageChannel.StreamingCall; + while (!_shutdownCts.Token.IsCancellationRequested) + { + try + { + await foreach (var message in cachedChannel.ResponseStream.ReadAllAsync(_shutdownCts.Token)) + { + // next if message is null + if (message == null) + { + continue; + } + + await _incomingMessageSink.OnMessageAsync(message, _shutdownCts.Token); + } + } + catch (OperationCanceledException) + { + // Time to shut down. + break; + } + catch (Exception ex) when (!_shutdownCts.IsCancellationRequested) + { + _logger.LogError(ex, "Error reading from channel."); + cachedChannel = this._incomingMessageChannel.RecreateChannel(); + } + catch + { + // Shutdown requested. + break; + } + } + } + + private async Task RunWritePump() + { + var cachedChannel = this._incomingMessageChannel.StreamingCall; + var outboundMessages = _outboundMessagesChannel.Reader; + while (!_shutdownCts.IsCancellationRequested) + { + (Message Message, TaskCompletionSource WriteCompletionSource) item = default; + try + { + await outboundMessages.WaitToReadAsync().ConfigureAwait(false); + + // Read the next message if we don't already have an unsent message + // waiting to be sent. + if (!outboundMessages.TryRead(out item)) + { + break; + } + + while (!_shutdownCts.IsCancellationRequested) + { + await cachedChannel.RequestStream.WriteAsync(item.Message, _shutdownCts.Token).ConfigureAwait(false); + item.WriteCompletionSource.TrySetResult(); + break; + } + } + catch (OperationCanceledException) + { + // Time to shut down. + item.WriteCompletionSource?.TrySetCanceled(); + break; + } + catch (RpcException ex) when (ex.StatusCode == StatusCode.Unavailable) + { + // we could not connect to the endpoint - most likely we have the wrong port or failed ssl + // we need to let the user know what port we tried to connect to and then do backoff and retry + _logger.LogError(ex, "Error connecting to GRPC endpoint {Endpoint}.", Environment.GetEnvironmentVariable("AGENT_HOST")); + break; + } + catch (RpcException ex) when (ex.StatusCode == StatusCode.OK) + { + _logger.LogError(ex, "Error writing to channel, continuing (Status OK). {ex}", cachedChannel.ToString()); + break; + } + catch (Exception ex) when (!_shutdownCts.IsCancellationRequested) + { + item.WriteCompletionSource?.TrySetException(ex); + _logger.LogError(ex, $"Error writing to channel.{ex}"); + cachedChannel = this._incomingMessageChannel.RecreateChannel(); + continue; + } + catch + { + // Shutdown requested. + item.WriteCompletionSource?.TrySetCanceled(); + break; + } + } + + while (outboundMessages.TryRead(out var item)) + { + item.WriteCompletionSource.TrySetCanceled(); + } + } + + public ValueTask RouteMessageAsync(Message message, CancellationToken cancellation = default) + { + var tcs = new TaskCompletionSource(); + return _outboundMessagesChannel.Writer.WriteAsync((message, tcs), cancellation); + } + + public ValueTask StartAsync(CancellationToken cancellation) + { + // TODO: Should we error out on a noncancellable token? + + this._incomingMessageChannel.EnsureConnected(); + var didSuppress = false; + + // Make sure we do not mistakenly flow the ExecutionContext into the background pumping tasks. + if (!ExecutionContext.IsFlowSuppressed()) + { + didSuppress = true; + ExecutionContext.SuppressFlow(); + } + + try + { + _readTask = Task.Run(RunReadPump, cancellation); + _writeTask = Task.Run(RunWritePump, cancellation); + + return ValueTask.CompletedTask; + } + catch (Exception ex) + { + return ValueTask.FromException(ex); + } + finally + { + if (didSuppress) + { + ExecutionContext.RestoreFlow(); + } + } + } + + // No point in returning a ValueTask here, since we are awaiting the two tasks + public async Task StopAsync() + { + _shutdownCts.Cancel(); + + _outboundMessagesChannel.Writer.TryComplete(); + + List pendingTasks = new(); + if (_readTask is { } readTask) + { + pendingTasks.Add(readTask); + } + + if (_writeTask is { } writeTask) + { + pendingTasks.Add(writeTask); + } + + await Task.WhenAll(pendingTasks).ConfigureAwait(false); + + this._incomingMessageChannel.Dispose(); + } + + public void Dispose() + { + _outboundMessagesChannel.Writer.TryComplete(); + this._incomingMessageChannel.Dispose(); + } +} + From f3e93e3a3947f5b395064cc4e253253852fe9b5c Mon Sep 17 00:00:00 2001 From: Jacob Alber Date: Wed, 29 Jan 2025 23:51:03 -0500 Subject: [PATCH 4/8] refactor: Add default IAgentRuntime.GetAgent implementations --- dotnet/src/Microsoft.AutoGen/Contracts/IAgentRuntime.cs | 6 ++++-- dotnet/src/Microsoft.AutoGen/Core/InProcessRuntime.cs | 6 ------ dotnet/test/Microsoft.AutoGen.Core.Tests/AgentTests.cs | 2 +- 3 files changed, 5 insertions(+), 9 deletions(-) diff --git a/dotnet/src/Microsoft.AutoGen/Contracts/IAgentRuntime.cs b/dotnet/src/Microsoft.AutoGen/Contracts/IAgentRuntime.cs index 0d84fbe72d3..bb360617dc0 100644 --- a/dotnet/src/Microsoft.AutoGen/Contracts/IAgentRuntime.cs +++ b/dotnet/src/Microsoft.AutoGen/Contracts/IAgentRuntime.cs @@ -53,7 +53,8 @@ public interface IAgentRuntime : ISaveState /// An optional key to specify variations of the agent. Defaults to "default". /// If true, the agent is fetched lazily. /// A task representing the asynchronous operation, returning the agent's ID. - public ValueTask GetAgentAsync(AgentType agentType, string key = "default", bool lazy = true/*, CancellationToken? = default*/); + public ValueTask GetAgentAsync(AgentType agentType, string key = "default", bool lazy = true/*, CancellationToken? = default*/) + => this.GetAgentAsync(new AgentId(agentType, key), lazy); /// /// Retrieves an agent by its string representation. @@ -62,7 +63,8 @@ public interface IAgentRuntime : ISaveState /// An optional key to specify variations of the agent. Defaults to "default". /// If true, the agent is fetched lazily. /// A task representing the asynchronous operation, returning the agent's ID. - public ValueTask GetAgentAsync(string agent, string key = "default", bool lazy = true/*, CancellationToken? = default*/); + public ValueTask GetAgentAsync(string agent, string key = "default", bool lazy = true/*, CancellationToken? = default*/) + => this.GetAgentAsync(new AgentId(agent, key), lazy); /// /// Saves the state of an agent. diff --git a/dotnet/src/Microsoft.AutoGen/Core/InProcessRuntime.cs b/dotnet/src/Microsoft.AutoGen/Core/InProcessRuntime.cs index 69b2d314e55..791376ae56e 100644 --- a/dotnet/src/Microsoft.AutoGen/Core/InProcessRuntime.cs +++ b/dotnet/src/Microsoft.AutoGen/Core/InProcessRuntime.cs @@ -140,12 +140,6 @@ public async ValueTask GetAgentAsync(AgentId agentId, bool lazy = true) return agentId; } - public ValueTask GetAgentAsync(AgentType agentType, string key = "default", bool lazy = true) - => this.GetAgentAsync(new AgentId(agentType, key), lazy); - - public ValueTask GetAgentAsync(string agent, string key = "default", bool lazy = true) - => this.GetAgentAsync(new AgentId(agent, key), lazy); - public async ValueTask GetAgentMetadataAsync(AgentId agentId) { IHostableAgent agent = await this.EnsureAgentAsync(agentId); diff --git a/dotnet/test/Microsoft.AutoGen.Core.Tests/AgentTests.cs b/dotnet/test/Microsoft.AutoGen.Core.Tests/AgentTests.cs index 11cd42362db..0de95399210 100644 --- a/dotnet/test/Microsoft.AutoGen.Core.Tests/AgentTests.cs +++ b/dotnet/test/Microsoft.AutoGen.Core.Tests/AgentTests.cs @@ -135,7 +135,7 @@ await runtime.RegisterAgentFactoryAsync("MyAgent", (id, runtime) => }); Assert.Null(agent); - await runtime.GetAgentAsync("MyAgent", lazy: false); + await runtime.GetAgentAsync(AgentId.FromStr("MyAgent"), lazy: false); Assert.NotNull(agent); Assert.True(agent.ReceivedItems.Count == 0); From 9ff345e5ea3f638f2936e7d664291b1256d16b69 Mon Sep 17 00:00:00 2001 From: Jacob Alber Date: Wed, 29 Jan 2025 23:54:33 -0500 Subject: [PATCH 5/8] feat: Implement remaining methods in GrpcAgentRuntime * Factor out AgentContainer for managing factory registration, instantiation, and subscription management --- .../Core.Grpc/GrpcAgentRuntime.cs | 227 +++++++++++------- 1 file changed, 136 insertions(+), 91 deletions(-) diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs index b1f5d43669e..f0be2376a15 100644 --- a/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs @@ -11,7 +11,74 @@ namespace Microsoft.AutoGen.Core.Grpc; +internal sealed class AgentsContainer(IAgentRuntime hostingRuntime) +{ + private readonly IAgentRuntime hostingRuntime = hostingRuntime; + + private Dictionary agentInstances = new(); + private Dictionary subscriptions = new(); + private Dictionary>> agentFactories = new(); + + public async ValueTask EnsureAgentAsync(Contracts.AgentId agentId) + { + if (!this.agentInstances.TryGetValue(agentId, out IHostableAgent? agent)) + { + if (!this.agentFactories.TryGetValue(agentId.Type, out Func>? factoryFunc)) + { + throw new Exception($"Agent with name {agentId.Type} not found."); + } + + agent = await factoryFunc(agentId, this.hostingRuntime); + this.agentInstances.Add(agentId, agent); + } + + return this.agentInstances[agentId]; + } + + public async ValueTask GetAgentAsync(Contracts.AgentId agentId, bool lazy = true) + { + if (!lazy) + { + await this.EnsureAgentAsync(agentId); + } + + return agentId; + } + + public AgentType RegisterAgentFactory(AgentType type, Func> factoryFunc) + { + if (this.agentFactories.ContainsKey(type)) + { + throw new Exception($"Agent factory with type {type} already exists."); + } + this.agentFactories.Add(type, factoryFunc); + return type; + } + + public void AddSubscription(ISubscriptionDefinition subscription) + { + if (this.subscriptions.ContainsKey(subscription.Id)) + { + throw new Exception($"Subscription with id {subscription.Id} already exists."); + } + + this.subscriptions.Add(subscription.Id, subscription); + } + + public bool RemoveSubscriptionAsync(string subscriptionId) + { + if (!this.subscriptions.ContainsKey(subscriptionId)) + { + throw new Exception($"Subscription with id {subscriptionId} does not exist."); + } + + return this.subscriptions.Remove(subscriptionId); + } + + public HashSet RegisteredAgentTypes => this.agentFactories.Keys.ToHashSet(); + public IEnumerable LiveAgents => this.agentInstances.Values; +} public sealed class GrpcAgentRuntime: IHostedService, IAgentRuntime, IMessageSink, IDisposable { @@ -25,21 +92,21 @@ public GrpcAgentRuntime(AgentRpc.AgentRpcClient client, this._shutdownCts = CancellationTokenSource.CreateLinkedTokenSource(hostApplicationLifetime.ApplicationStopping); this._messageRouter = new GrpcMessageRouter(client, this, logger, this._shutdownCts.Token); + this._agentsContainer = new AgentsContainer(this); this.ServiceProvider = serviceProvider; } - // Request ID -> + // Request ID -> ResultSink<...> private readonly ConcurrentDictionary> _pendingRequests = new(); - private Dictionary>> agentFactories = new(); - private Dictionary agentInstances = new(); - private readonly AgentRpc.AgentRpcClient _client; private readonly GrpcMessageRouter _messageRouter; private readonly ILogger _logger; private readonly CancellationTokenSource _shutdownCts; + + private readonly AgentsContainer _agentsContainer; public IServiceProvider ServiceProvider { get; } @@ -84,7 +151,7 @@ private async ValueTask HandleRequest(RpcRequest request, CancellationToken canc } var agentId = request.Target; - var agent = await EnsureAgentAsync(agentId.FromProtobuf()); + var agent = await this._agentsContainer.EnsureAgentAsync(agentId.FromProtobuf()); // Convert payload back to object var payload = request.Payload; @@ -172,36 +239,9 @@ private async ValueTask HandlePublish(CloudEvent evt, CancellationToken cancella Topic = topic, IsRpc = false }; - var agent = await EnsureAgentAsync(sender); + var agent = await this._agentsContainer.EnsureAgentAsync(sender); await agent.OnMessageAsync(message, messageContext); } - - - - // private override async ValueTask SendMessageAsync(Payload message, AgentId agentId, AgentId? agent = null, CancellationToken? cancellationToken = default) - // { - // var request = new RpcRequest - // { - // RequestId = Guid.NewGuid().ToString(), - // Source = agent, - // Target = agentId, - // Payload = message, - // }; - - // // Actually send it and wait for the response - // throw new NotImplementedException(); - // } - - // new is intentional - - // public new async ValueTask RuntimeSendRequestAsync(IAgent agent, RpcRequest request, CancellationToken cancellationToken = default) - // { - // var requestId = Guid.NewGuid().ToString(); - // _pendingRequests[requestId] = ((Agent)agent, request.RequestId); - // request.RequestId = requestId; - // await WriteChannelAsync(new Message { Request = request }, cancellationToken).ConfigureAwait(false); - // } - public ValueTask StartAsync(CancellationToken cancellationToken) { @@ -215,22 +255,6 @@ public Task StopAsync(CancellationToken cancellationToken) return this._messageRouter.StopAsync(); } - private async ValueTask EnsureAgentAsync(Contracts.AgentId agentId) - { - if (!this.agentInstances.TryGetValue(agentId, out IHostableAgent? agent)) - { - if (!this.agentFactories.TryGetValue(agentId.Type, out Func>? factoryFunc)) - { - throw new Exception($"Agent with name {agentId.Type} not found."); - } - - agent = await factoryFunc(agentId, this); - this.agentInstances.Add(agentId, agent); - } - - return this.agentInstances[agentId]; - } - private Payload ObjectToPayload(object message) { if (!SerializationRegistry.Exists(message.GetType())) { @@ -338,77 +362,98 @@ public async ValueTask PublishMessageAsync(object message, TopicId topic, Contra await this._messageRouter.RouteMessageAsync(msg, cancellationToken); } - public ValueTask GetAgentAsync(Contracts.AgentId agentId, bool lazy = true) - { - throw new NotImplementedException(); - } + public ValueTask GetAgentAsync(Contracts.AgentId agentId, bool lazy = true) => this._agentsContainer.GetAgentAsync(agentId, lazy); - public ValueTask GetAgentAsync(AgentType agentType, string key = "default", bool lazy = true) + public async ValueTask> SaveAgentStateAsync(Contracts.AgentId agentId) { - throw new NotImplementedException(); + IHostableAgent agent = await this._agentsContainer.EnsureAgentAsync(agentId); + return await agent.SaveStateAsync(); } - public ValueTask GetAgentAsync(string agent, string key = "default", bool lazy = true) + public async ValueTask LoadAgentStateAsync(Contracts.AgentId agentId, IDictionary state) { - throw new NotImplementedException(); + IHostableAgent agent = await this._agentsContainer.EnsureAgentAsync(agentId); + await agent.LoadStateAsync(state); } - public ValueTask> SaveAgentStateAsync(Contracts.AgentId agentId) + public async ValueTask GetAgentMetadataAsync(Contracts.AgentId agentId) { - throw new NotImplementedException(); + IHostableAgent agent = await this._agentsContainer.EnsureAgentAsync(agentId); + return agent.Metadata; } - public ValueTask LoadAgentStateAsync(Contracts.AgentId agentId, IDictionary state) + public ValueTask AddSubscriptionAsync(ISubscriptionDefinition subscription) { - throw new NotImplementedException(); + this._agentsContainer.AddSubscription(subscription); + + // Because we have an extensible definition of ISubscriptionDefinition, we cannot project it to the Gateway. + // What this means is that we will have a much chattier interface between the Gateway and the Runtime. + + //await this._client.AddSubscriptionAsync(new AddSubscriptionRequest + //{ + // Subscription = new Subscription + // { + // Id = subscription.Id, + // TopicType = subscription.TopicType, + // AgentType = subscription.AgentType.Name + // } + //}, this.CallOptions); + + return ValueTask.CompletedTask; } - public ValueTask GetAgentMetadataAsync(Contracts.AgentId agentId) + public ValueTask RemoveSubscriptionAsync(string subscriptionId) { - throw new NotImplementedException(); - } + this._agentsContainer.RemoveSubscriptionAsync(subscriptionId); - public async ValueTask AddSubscriptionAsync(ISubscriptionDefinition subscription) - { - var _ = await this._client.AddSubscriptionAsync(new AddSubscriptionRequest{ - Subscription = subscription.ToProtobuf() - },this.CallOptions); + // See above (AddSubscriptionAsync) for why this is commented out. + + //await this._client.RemoveSubscriptionAsync(new RemoveSubscriptionRequest + //{ + // Id = subscriptionId + //}, this.CallOptions); + + return ValueTask.CompletedTask; } - public ValueTask RemoveSubscriptionAsync(string subscriptionId) + public ValueTask RegisterAgentFactoryAsync(AgentType type, Func> factoryFunc) + => ValueTask.FromResult(this._agentsContainer.RegisterAgentFactory(type, factoryFunc)); + + public ValueTask TryGetAgentProxyAsync(Contracts.AgentId agentId) { - throw new NotImplementedException(); + // TODO: Do we want to support getting remote agent proxies? + return ValueTask.FromResult(new AgentProxy(agentId, this)); } - public ValueTask RegisterAgentFactoryAsync(AgentType type, Func> factoryFunc) + public async ValueTask> SaveStateAsync() { - if (this.agentFactories.ContainsKey(type)) + Dictionary state = new(); + foreach (var agent in this._agentsContainer.LiveAgents) { - throw new Exception($"Agent with type {type} already exists."); + state[agent.Id.ToString()] = await agent.SaveStateAsync(); } - this.agentFactories.Add(type, async (agentId, runtime) => await factoryFunc(agentId, runtime)); - - this._client.RegisterAgentAsync(new RegisterAgentTypeRequest - { - Type = type.Name, - }, this.CallOptions); - return ValueTask.FromResult(type); + return state; } - public ValueTask TryGetAgentProxyAsync(Contracts.AgentId agentId) + public async ValueTask LoadStateAsync(IDictionary state) { - throw new NotImplementedException(); - } + HashSet registeredTypes = this._agentsContainer.RegisteredAgentTypes; - public ValueTask> SaveStateAsync() - { - throw new NotImplementedException(); - } + foreach (var agentIdStr in state.Keys) + { + Contracts.AgentId agentId = Contracts.AgentId.FromStr(agentIdStr); + if (state[agentIdStr] is not IDictionary agentStateDict) + { + throw new Exception($"Agent state for {agentId} is not a {typeof(IDictionary)}: {state[agentIdStr].GetType()}"); + } - public ValueTask LoadStateAsync(IDictionary state) - { - throw new NotImplementedException(); + if (registeredTypes.Contains(agentId.Type)) + { + IHostableAgent agent = await this._agentsContainer.EnsureAgentAsync(agentId); + await agent.LoadStateAsync(agentStateDict); + } + } } public async ValueTask OnMessageAsync(Message message, CancellationToken cancellation = default) From 6f1f04f3e816fefc204549a2326d918734d2d461 Mon Sep 17 00:00:00 2001 From: Jacob Alber Date: Thu, 30 Jan 2025 00:19:29 -0500 Subject: [PATCH 6/8] fix: Get Core.Grpc test project building --- .../Core.Grpc/AgentsAppBuilderExtensions.cs | 21 ++ .../AgentGrpcTests.cs | 346 ++++++++++-------- .../Microsoft.AutoGen.Core.Grpc.Tests.csproj | 2 +- 3 files changed, 212 insertions(+), 157 deletions(-) create mode 100644 dotnet/src/Microsoft.AutoGen/Core.Grpc/AgentsAppBuilderExtensions.cs diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/AgentsAppBuilderExtensions.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/AgentsAppBuilderExtensions.cs new file mode 100644 index 00000000000..e19cc2f343d --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/AgentsAppBuilderExtensions.cs @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// AgentsAppBuilderExtensions.cs + +using Microsoft.AutoGen.Contracts; +using Microsoft.Extensions.DependencyInjection; + +namespace Microsoft.AutoGen.Core.Grpc; + +public static class AgentsAppBuilderExtensions +{ + public static AgentsAppBuilder UseGrpcRuntime(this AgentsAppBuilder this_, bool deliverToSelf = false) + { + this_.Services.AddSingleton(); + this_.Services.AddHostedService(services => + { + return (services.GetRequiredService() as GrpcAgentRuntime)!; + }); + + return this_; + } +} diff --git a/dotnet/test/Microsoft.AutoGen.Core.Grpc.Tests/AgentGrpcTests.cs b/dotnet/test/Microsoft.AutoGen.Core.Grpc.Tests/AgentGrpcTests.cs index f9f1341e8e0..bcce6a0c28e 100644 --- a/dotnet/test/Microsoft.AutoGen.Core.Grpc.Tests/AgentGrpcTests.cs +++ b/dotnet/test/Microsoft.AutoGen.Core.Grpc.Tests/AgentGrpcTests.cs @@ -1,183 +1,213 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // AgentGrpcTests.cs -using System.Collections.Concurrent; -using System.Text.Json; -using FluentAssertions; -using Google.Protobuf.Reflection; +//using System.Collections.Concurrent; +//using System.Text.Json; +//using FluentAssertions; +//using Google.Protobuf.Reflection; using Microsoft.AutoGen.Contracts; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; using Xunit; -using static Microsoft.AutoGen.Core.Grpc.Tests.AgentGrpcTests; namespace Microsoft.AutoGen.Core.Grpc.Tests; public class AgentGrpcTests { - /// - /// Verify that if the agent is not initialized via AgentWorker, it should throw the correct exception. - /// - /// void [Fact] - public async Task Agent_ShouldThrowException_WhenNotInitialized() - { - using var runtime = new GrpcRuntime(); - var (_, agent) = runtime.Start(false); // Do not initialize - - // Expect an exception when calling AddSubscriptionAsync because the agent is uninitialized - await Assert.ThrowsAsync( - async () => await agent.AddSubscriptionAsync("TestEvent") - ); - } - - /// - /// validate that the agent is initialized correctly with implicit subs - /// - /// void - [Fact] - public async Task Agent_ShouldInitializeCorrectly() + public void Agent_ShouldInitializeCorrectly() { using var runtime = new GrpcRuntime(); var (worker, agent) = runtime.Start(); Assert.Equal(nameof(GrpcAgentRuntime), worker.GetType().Name); - await Task.Delay(5000); - var subscriptions = await agent.GetSubscriptionsAsync(); - Assert.Equal(2, subscriptions.Count); - } - /// - /// Test AddSubscriptionAsync method - /// - /// void - [Fact] - public async Task SubscribeAsync_UnsubscribeAsync_and_GetSubscriptionsTest() - { - using var runtime = new GrpcRuntime(); - var (_, agent) = runtime.Start(); - await agent.AddSubscriptionAsync("TestEvent"); - await Task.Delay(100); - var subscriptions = await agent.GetSubscriptionsAsync().ConfigureAwait(true); - var found = false; - foreach (var subscription in subscriptions) - { - if (subscription.TypeSubscription.TopicType == "TestEvent") - { - found = true; - } - } - Assert.True(found); - await agent.RemoveSubscriptionAsync("TestEvent").ConfigureAwait(true); - await Task.Delay(1000); - subscriptions = await agent.GetSubscriptionsAsync().ConfigureAwait(true); - found = false; - foreach (var subscription in subscriptions) - { - if (subscription.TypeSubscription.TopicType == "TestEvent") - { - found = true; - } - } - Assert.False(found); } - /// - /// Test StoreAsync and ReadAsync methods - /// - /// void - [Fact] - public async Task StoreAsync_and_ReadAsyncTest() - { - using var runtime = new GrpcRuntime(); - var (_, agent) = runtime.Start(); - Dictionary state = new() - { - { "testdata", "Active" } - }; - await agent.StoreAsync(new AgentState - { - AgentId = agent.AgentId, - TextData = JsonSerializer.Serialize(state) - }).ConfigureAwait(true); - var readState = await agent.ReadAsync(agent.AgentId).ConfigureAwait(true); - var read = JsonSerializer.Deserialize>(readState.TextData) ?? new Dictionary { { "data", "No state data found" } }; - read.TryGetValue("testdata", out var value); - Assert.Equal("Active", value); - } + ///// + ///// Verify that if the agent is not initialized via AgentWorker, it should throw the correct exception. + ///// + ///// void + //[Fact] + //public async Task Agent_ShouldThrowException_WhenNotInitialized() + //{ + // using var runtime = new GrpcRuntime(); + // var (_, agent) = runtime.Start(false); // Do not initialize - /// - /// Test PublishMessageAsync method and ReceiveMessage method - /// - /// void - [Fact] - public async Task PublishMessageAsync_and_ReceiveMessageTest() + // // Expect an exception when calling AddSubscriptionAsync because the agent is uninitialized + // await Assert.ThrowsAsync( + // async () => await agent.AddSubscriptionAsync("TestEvent") + // ); + //} + + ///// + ///// validate that the agent is initialized correctly with implicit subs + ///// + ///// void + //[Fact] + //public async Task Agent_ShouldInitializeCorrectly() + //{ + // using var runtime = new GrpcRuntime(); + // var (worker, agent) = runtime.Start(); + // Assert.Equal(nameof(GrpcAgentRuntime), worker.GetType().Name); + // await Task.Delay(5000); + // var subscriptions = await agent.GetSubscriptionsAsync(); + // Assert.Equal(2, subscriptions.Count); + //} + ///// + ///// Test AddSubscriptionAsync method + ///// + ///// void + //[Fact] + //public async Task SubscribeAsync_UnsubscribeAsync_and_GetSubscriptionsTest() + //{ + // using var runtime = new GrpcRuntime(); + // var (_, agent) = runtime.Start(); + // await agent.AddSubscriptionAsync("TestEvent"); + // await Task.Delay(100); + // var subscriptions = await agent.GetSubscriptionsAsync().ConfigureAwait(true); + // var found = false; + // foreach (var subscription in subscriptions) + // { + // if (subscription.TypeSubscription.TopicType == "TestEvent") + // { + // found = true; + // } + // } + // Assert.True(found); + // await agent.RemoveSubscriptionAsync("TestEvent").ConfigureAwait(true); + // await Task.Delay(1000); + // subscriptions = await agent.GetSubscriptionsAsync().ConfigureAwait(true); + // found = false; + // foreach (var subscription in subscriptions) + // { + // if (subscription.TypeSubscription.TopicType == "TestEvent") + // { + // found = true; + // } + // } + // Assert.False(found); + //} + + ///// + ///// Test StoreAsync and ReadAsync methods + ///// + ///// void + //[Fact] + //public async Task StoreAsync_and_ReadAsyncTest() + //{ + // using var runtime = new GrpcRuntime(); + // var (_, agent) = runtime.Start(); + // Dictionary state = new() + // { + // { "testdata", "Active" } + // }; + // await agent.StoreAsync(new AgentState + // { + // AgentId = agent.AgentId, + // TextData = JsonSerializer.Serialize(state) + // }).ConfigureAwait(true); + // var readState = await agent.ReadAsync(agent.AgentId).ConfigureAwait(true); + // var read = JsonSerializer.Deserialize>(readState.TextData) ?? new Dictionary { { "data", "No state data found" } }; + // read.TryGetValue("testdata", out var value); + // Assert.Equal("Active", value); + //} + + ///// + ///// Test PublishMessageAsync method and ReceiveMessage method + ///// + ///// void + //[Fact] + //public async Task PublishMessageAsync_and_ReceiveMessageTest() + //{ + // using var runtime = new GrpcRuntime(); + // var (_, agent) = runtime.Start(); + // var topicType = "TestTopic"; + // await agent.AddSubscriptionAsync(topicType).ConfigureAwait(true); + // var subscriptions = await agent.GetSubscriptionsAsync().ConfigureAwait(true); + // var found = false; + // foreach (var subscription in subscriptions) + // { + // if (subscription.TypeSubscription.TopicType == topicType) + // { + // found = true; + // } + // } + // Assert.True(found); + // await agent.PublishMessageAsync(new TextMessage() + // { + // Source = topicType, + // TextMessage_ = "buffer" + // }, topicType).ConfigureAwait(true); + // await Task.Delay(100); + // Assert.True(TestAgent.ReceivedMessages.ContainsKey(topicType)); + // runtime.Stop(); + //} + + //[Fact] + //public async Task InvokeCorrectHandler() + //{ + // var agent = new TestAgent(new AgentsMetadata(TypeRegistry.Empty, new Dictionary(), new Dictionary>(), new Dictionary>()), new Logger(new LoggerFactory())); + + // await agent.HandleObjectAsync("hello world"); + // await agent.HandleObjectAsync(42); + + // agent.ReceivedItems.Should().HaveCount(2); + // agent.ReceivedItems[0].Should().Be("hello world"); + // agent.ReceivedItems[1].Should().Be(42); + //} +} + +/// +/// The test agent is a simple agent that is used for testing purposes. +/// +public class TestAgent(AgentId id, + IAgentRuntime runtime, + Logger? logger = null) : BaseAgent(id, runtime, "Test Agent", logger), + //IHandle, + //IHandle, + IHandle + +{ + //public ValueTask HandleAsync(TextMessage item, MessageContext messageContext) + //{ + // ReceivedMessages[item.Source] = item.Content; + // return ValueTask.CompletedTask; + //} + + public ValueTask HandleAsync(string item, MessageContext messageContext) { - using var runtime = new GrpcRuntime(); - var (_, agent) = runtime.Start(); - var topicType = "TestTopic"; - await agent.AddSubscriptionAsync(topicType).ConfigureAwait(true); - var subscriptions = await agent.GetSubscriptionsAsync().ConfigureAwait(true); - var found = false; - foreach (var subscription in subscriptions) - { - if (subscription.TypeSubscription.TopicType == topicType) - { - found = true; - } - } - Assert.True(found); - await agent.PublishMessageAsync(new TextMessage() - { - Source = topicType, - TextMessage_ = "buffer" - }, topicType).ConfigureAwait(true); - await Task.Delay(100); - Assert.True(TestAgent.ReceivedMessages.ContainsKey(topicType)); - runtime.Stop(); + ReceivedItems.Add(item); + return ValueTask.CompletedTask; } - [Fact] - public async Task InvokeCorrectHandler() + public ValueTask HandleAsync(int item, MessageContext messageContext) { - var agent = new TestAgent(new AgentsMetadata(TypeRegistry.Empty, new Dictionary(), new Dictionary>(), new Dictionary>()), new Logger(new LoggerFactory())); + ReceivedItems.Add(item); + return ValueTask.CompletedTask; + } - await agent.HandleObjectAsync("hello world"); - await agent.HandleObjectAsync(42); + //public ValueTask HandleAsync(RpcTextMessage item, MessageContext messageContext) + //{ + // ReceivedMessages[item.Source] = item.Content; + // return ValueTask.FromResult(item.Content); + //} - agent.ReceivedItems.Should().HaveCount(2); - agent.ReceivedItems[0].Should().Be("hello world"); - agent.ReceivedItems[1].Should().Be(42); - } + public List ReceivedItems { get; private set; } = []; /// - /// The test agent is a simple agent that is used for testing purposes. + /// Key: source + /// Value: message /// - public class TestAgent( - [FromKeyedServices("AgentsMetadata")] AgentsMetadata eventTypes, - Logger? logger = null) : Agent(eventTypes, logger), IHandle - { - public Task Handle(TextMessage item, CancellationToken cancellationToken = default) - { - ReceivedMessages[item.Source] = item.TextMessage_; - return Task.CompletedTask; - } - public Task Handle(string item) - { - ReceivedItems.Add(item); - return Task.CompletedTask; - } - public Task Handle(int item) - { - ReceivedItems.Add(item); - return Task.CompletedTask; - } - public List ReceivedItems { get; private set; } = []; + public static Dictionary ReceivedMessages { get; private set; } = new(); +} - /// - /// Key: source - /// Value: message - /// - public static ConcurrentDictionary ReceivedMessages { get; private set; } = new(); +[TypeSubscription("TestTopic")] +public class SubscribedAgent : TestAgent +{ + public SubscribedAgent(AgentId id, + IAgentRuntime runtime, + Logger? logger = null) : base(id, runtime, logger) + { } } @@ -211,14 +241,18 @@ private static int GetAvailablePort() private static async Task StartClientAsync() { - return await AgentsApp.StartAsync().ConfigureAwait(false); - } - private static async Task StartAppHostAsync() - { - return await Microsoft.AutoGen.Runtime.Grpc.Host.StartAsync(local: false, useGrpc: true).ConfigureAwait(false); + AgentsApp agentsApp = await new AgentsAppBuilder().UseGrpcRuntime().AddAgent("TestAgent").BuildAsync(); + await agentsApp.StartAsync(); + + return agentsApp.Host; } + //private static async Task StartAppHostAsync() + //{ + // return await Microsoft.AutoGen.Runtime.Grpc.Host.StartAsync(local: false, useGrpc: true).ConfigureAwait(false); + //} + /// /// Start - gets a new port and starts fresh instances /// @@ -230,14 +264,14 @@ private static async Task StartAppHostAsync() Environment.SetEnvironmentVariable("ASPNETCORE_HTTPS_PORTS", port.ToString()); Environment.SetEnvironmentVariable("AGENT_HOST", $"https://localhost:{port}"); - AppHost = StartAppHostAsync().GetAwaiter().GetResult(); + //AppHost = StartAppHostAsync().GetAwaiter().GetResult(); Client = StartClientAsync().GetAwaiter().GetResult(); var agent = ActivatorUtilities.CreateInstance(Client.Services); var worker = Client.Services.GetRequiredService(); if (initialize) { - Agent.Initialize(worker, agent); + //Agent.Initialize(worker, agent); } return (worker, agent); diff --git a/dotnet/test/Microsoft.AutoGen.Core.Grpc.Tests/Microsoft.AutoGen.Core.Grpc.Tests.csproj b/dotnet/test/Microsoft.AutoGen.Core.Grpc.Tests/Microsoft.AutoGen.Core.Grpc.Tests.csproj index f14497e75fb..a2dad2212a7 100644 --- a/dotnet/test/Microsoft.AutoGen.Core.Grpc.Tests/Microsoft.AutoGen.Core.Grpc.Tests.csproj +++ b/dotnet/test/Microsoft.AutoGen.Core.Grpc.Tests/Microsoft.AutoGen.Core.Grpc.Tests.csproj @@ -10,7 +10,7 @@ - + From 07bdb554977a34a2e27c6f8ed50bb6dbcc19303c Mon Sep 17 00:00:00 2001 From: Jacob Alber Date: Thu, 30 Jan 2025 01:45:32 -0500 Subject: [PATCH 7/8] feat: Restore Runtime.Grpc project and get it building --- dotnet/AutoGen.sln | 19 ++- .../Runtime.Grpc/Abstractions/IAgentGrain.cs | 6 +- .../Runtime.Grpc/Abstractions/IGateway.cs | 18 +- .../Abstractions/IGatewayRegistry.cs | 55 +++++- .../Abstractions/IRegistryGrain.cs | 15 ++ .../Microsoft.AutoGen.Runtime.Grpc.csproj | 1 + .../Services/AgentWorkerHostingExtensions.cs | 10 +- .../Runtime.Grpc/Services/Grpc/GrpcGateway.cs | 156 ++++++++++++------ .../Services/Grpc/GrpcGatewayService.cs | 15 +- .../Services/Grpc/GrpcWorkerConnection.cs | 2 + .../Services/Orleans/AgentStateGrain.cs | 23 +++ .../Services/Orleans/RegistryGrain.cs | 4 +- .../AddSubscriptionRequestSurrogate.cs | 6 +- .../AddSubscriptionResponseSurrogate.cs | 14 +- .../Orleans/Surrogates/AgentIdSurrogate.cs | 2 +- .../Orleans/Surrogates/AgentStateSurrogate.cs | 2 +- .../Orleans/Surrogates/CloudEventSurrogate.cs | 1 + .../Surrogates/GetSubscriptionsRequest.cs | 2 + .../RegisterAgentTypeRequestSurrogate.cs | 5 +- .../RegisterAgentTypeResponseSurrogate.cs | 14 +- .../Surrogates/RemoveSubscriptionRequest.cs | 2 + .../Surrogates/RemoveSubscriptionResponse.cs | 10 +- .../Orleans/Surrogates/RpcRequestSurrogate.cs | 2 +- .../Surrogates/RpcResponseSurrogate.cs | 1 + .../Surrogates/SubscriptionSurrogate.cs | 93 +++++------ .../TypePrefixSubscriptionSurrogate.cs | 56 +++---- .../Surrogates/TypeSubscriptionSurrogate.cs | 56 +++---- 27 files changed, 393 insertions(+), 197 deletions(-) diff --git a/dotnet/AutoGen.sln b/dotnet/AutoGen.sln index ab7a07464c5..61fdd8bf4ae 100644 --- a/dotnet/AutoGen.sln +++ b/dotnet/AutoGen.sln @@ -118,6 +118,8 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Hello", "Hello", "{F42F9C8E EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AutoGen.Core.Grpc", "src\Microsoft.AutoGen\Core.Grpc\Microsoft.AutoGen.Core.Grpc.csproj", "{3D83C6DB-ACEA-48F3-959F-145CCD2EE135}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AutoGen.Runtime.Grpc", "src\Microsoft.AutoGen\Runtime.Grpc\Microsoft.AutoGen.Runtime.Grpc.csproj", "{BEC2FDB8-5FC4-4B88-9D69-69759F63F4DC}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -296,16 +298,18 @@ Global {70A8D4B5-D0A6-4098-A6F3-6ED274B65E7D}.Debug|Any CPU.Build.0 = Debug|Any CPU {70A8D4B5-D0A6-4098-A6F3-6ED274B65E7D}.Release|Any CPU.ActiveCfg = Release|Any CPU {70A8D4B5-D0A6-4098-A6F3-6ED274B65E7D}.Release|Any CPU.Build.0 = Release|Any CPU - {3D83C6DB-ACEA-48F3-959F-145CCD2EE135}.CoreOnly|Any CPU.ActiveCfg = Debug|Any CPU - {3D83C6DB-ACEA-48F3-959F-145CCD2EE135}.CoreOnly|Any CPU.Build.0 = Debug|Any CPU - {3D83C6DB-ACEA-48F3-959F-145CCD2EE135}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {3D83C6DB-ACEA-48F3-959F-145CCD2EE135}.Debug|Any CPU.Build.0 = Debug|Any CPU - {3D83C6DB-ACEA-48F3-959F-145CCD2EE135}.Release|Any CPU.ActiveCfg = Release|Any CPU - {3D83C6DB-ACEA-48F3-959F-145CCD2EE135}.Release|Any CPU.Build.0 = Release|Any CPU {AAD593FE-A49B-425E-A9FE-A0022CD25E3D}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {AAD593FE-A49B-425E-A9FE-A0022CD25E3D}.Debug|Any CPU.Build.0 = Debug|Any CPU {AAD593FE-A49B-425E-A9FE-A0022CD25E3D}.Release|Any CPU.ActiveCfg = Release|Any CPU {AAD593FE-A49B-425E-A9FE-A0022CD25E3D}.Release|Any CPU.Build.0 = Release|Any CPU + {3D83C6DB-ACEA-48F3-959F-145CCD2EE135}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {3D83C6DB-ACEA-48F3-959F-145CCD2EE135}.Debug|Any CPU.Build.0 = Debug|Any CPU + {3D83C6DB-ACEA-48F3-959F-145CCD2EE135}.Release|Any CPU.ActiveCfg = Release|Any CPU + {3D83C6DB-ACEA-48F3-959F-145CCD2EE135}.Release|Any CPU.Build.0 = Release|Any CPU + {BEC2FDB8-5FC4-4B88-9D69-69759F63F4DC}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {BEC2FDB8-5FC4-4B88-9D69-69759F63F4DC}.Debug|Any CPU.Build.0 = Debug|Any CPU + {BEC2FDB8-5FC4-4B88-9D69-69759F63F4DC}.Release|Any CPU.ActiveCfg = Release|Any CPU + {BEC2FDB8-5FC4-4B88-9D69-69759F63F4DC}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -356,9 +360,10 @@ Global {EAFFE339-26CB-4019-991D-BCCE8E7D33A1} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64} {58AD8E1D-83BD-4950-A324-1A20677D78D9} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64} {70A8D4B5-D0A6-4098-A6F3-6ED274B65E7D} = {CE0AA8D5-12B8-4628-9589-DAD8CB0DDCF6} - {3D83C6DB-ACEA-48F3-959F-145CCD2EE135} = {18BF8DD7-0585-48BF-8F97-AD333080CE06} {AAD593FE-A49B-425E-A9FE-A0022CD25E3D} = {F42F9C8E-7BD9-4687-9B63-AFFA461AF5C1} {F42F9C8E-7BD9-4687-9B63-AFFA461AF5C1} = {CE0AA8D5-12B8-4628-9589-DAD8CB0DDCF6} + {3D83C6DB-ACEA-48F3-959F-145CCD2EE135} = {18BF8DD7-0585-48BF-8F97-AD333080CE06} + {BEC2FDB8-5FC4-4B88-9D69-69759F63F4DC} = {18BF8DD7-0585-48BF-8F97-AD333080CE06} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {93384647-528D-46C8-922C-8DB36A382F0B} diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Abstractions/IAgentGrain.cs b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Abstractions/IAgentGrain.cs index 947b6b0cbc0..bc6c098a8d2 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Abstractions/IAgentGrain.cs +++ b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Abstractions/IAgentGrain.cs @@ -1,10 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // IAgentGrain.cs +using Microsoft.AutoGen.Protobuf; + namespace Microsoft.AutoGen.Runtime.Grpc.Abstractions; internal interface IAgentGrain : IGrainWithStringKey { - ValueTask ReadStateAsync(); - ValueTask WriteStateAsync(Contracts.AgentState state, string eTag); + ValueTask ReadStateAsync(); + ValueTask WriteStateAsync(AgentState state, string eTag); } diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Abstractions/IGateway.cs b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Abstractions/IGateway.cs index 33bb94f7c49..6b6ff3bc7b1 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Abstractions/IGateway.cs +++ b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Abstractions/IGateway.cs @@ -1,18 +1,28 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // IGateway.cs using Microsoft.AutoGen.Contracts; +using Microsoft.AutoGen.Protobuf; namespace Microsoft.AutoGen.Runtime.Grpc.Abstractions; +public interface IConnection +{ +} + public interface IGateway : IGrainObserver { - ValueTask InvokeRequestAsync(RpcRequest request); ValueTask BroadcastEventAsync(CloudEvent evt); - ValueTask StoreAsync(Contracts.AgentState value); - ValueTask ReadAsync(AgentId agentId); - ValueTask RegisterAgentTypeAsync(RegisterAgentTypeRequest request); + + ValueTask InvokeRequestAsync(RpcRequest request); + + ValueTask StoreAsync(Protobuf.AgentState value); + ValueTask ReadAsync(Protobuf.AgentId agentId); + + ValueTask RegisterAgentTypeAsync(string requestId, RegisterAgentTypeRequest request); + ValueTask SubscribeAsync(AddSubscriptionRequest request); ValueTask UnsubscribeAsync(RemoveSubscriptionRequest request); ValueTask> GetSubscriptionsAsync(GetSubscriptionsRequest request); + Task SendMessageAsync(IConnection connection, CloudEvent cloudEvent); } diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Abstractions/IGatewayRegistry.cs b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Abstractions/IGatewayRegistry.cs index cb377841804..20f06e8b363 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Abstractions/IGatewayRegistry.cs +++ b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Abstractions/IGatewayRegistry.cs @@ -2,9 +2,62 @@ // IGatewayRegistry.cs using Microsoft.AutoGen.Contracts; +using Microsoft.AutoGen.Protobuf; namespace Microsoft.AutoGen.Runtime.Grpc.Abstractions; +public interface IRegistry +{ + //AgentsRegistryState State { get; set; } + /// + /// Registers a new agent type with the specified worker. + /// + /// The request containing agent type details. + /// The worker to register the agent type with. + /// A task representing the asynchronous operation. + /// removing CancellationToken from here as it is not compatible with Orleans Serialization + ValueTask RegisterAgentTypeAsync(RegisterAgentTypeRequest request, IAgentRuntime worker); + + /// + /// Unregisters an agent type from the specified worker. + /// + /// The type of the agent to unregister. + /// The worker to unregister the agent type from. + /// A task representing the asynchronous operation. + /// removing CancellationToken from here as it is not compatible with Orleans Serialization + ValueTask UnregisterAgentTypeAsync(string type, IAgentRuntime worker); + + /// + /// Gets a list of agents subscribed to and handling the specified topic and event type. + /// + /// The topic to check subscriptions for. + /// The event type to check subscriptions for. + /// A task representing the asynchronous operation, with the list of agent IDs as the result. + ValueTask> GetSubscribedAndHandlingAgentsAsync(string topic, string eventType); + + /// + /// Subscribes an agent to a topic. + /// + /// The subscription request. + /// A task representing the asynchronous operation. + /// removing CancellationToken from here as it is not compatible with Orleans Serialization + ValueTask SubscribeAsync(AddSubscriptionRequest request); + + /// + /// Unsubscribes an agent from a topic. + /// + /// The unsubscription request. + /// A task representing the asynchronous operation. + /// removing CancellationToken from here as it is not compatible with Orleans Serialization + ValueTask UnsubscribeAsync(RemoveSubscriptionRequest request); // TODO: This should have its own request type. + + /// + /// Gets the subscriptions for a specified agent type. + /// + /// A task representing the asynchronous operation, with the subscriptions as the result. + ValueTask> GetSubscriptionsAsync(GetSubscriptionsRequest request); +} + /// /// Interface for managing agent registration, placement, and subscriptions. /// @@ -15,7 +68,7 @@ public interface IGatewayRegistry : IRegistry /// /// The ID of the agent. /// A tuple containing the worker and a boolean indicating if it's a new placement. - ValueTask<(IGateway? Worker, bool NewPlacement)> GetOrPlaceAgent(AgentId agentId); + ValueTask<(IGateway? Worker, bool NewPlacement)> GetOrPlaceAgent(Protobuf.AgentId agentId); /// /// Removes a worker from the registry. diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Abstractions/IRegistryGrain.cs b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Abstractions/IRegistryGrain.cs index 81b59858619..ea7d99ffb8a 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Abstractions/IRegistryGrain.cs +++ b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Abstractions/IRegistryGrain.cs @@ -1,6 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // IRegistryGrain.cs +using Microsoft.AutoGen.Protobuf; +using System.Collections.Concurrent; + namespace Microsoft.AutoGen.Runtime.Grpc.Abstractions; /// @@ -9,3 +12,15 @@ namespace Microsoft.AutoGen.Runtime.Grpc.Abstractions; [Alias("Microsoft.AutoGen.Runtime.Grpc.Abstractions.IRegistryGrain")] public interface IRegistryGrain : IGatewayRegistry, IGrainWithIntegerKey { } + +public class AgentsRegistryState +{ + public ConcurrentDictionary> AgentsToEventsMap { get; set; } = new ConcurrentDictionary>(); + public ConcurrentDictionary> AgentsToTopicsMap { get; set; } = []; + public ConcurrentDictionary> TopicToAgentTypesMap { get; set; } = []; + public ConcurrentDictionary> EventsToAgentTypesMap { get; set; } = []; + public ConcurrentDictionary> GuidSubscriptionsMap { get; set; } = []; + public ConcurrentDictionary AgentTypes { get; set; } = []; + public string Etag { get; set; } = Guid.NewGuid().ToString(); +} + diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Microsoft.AutoGen.Runtime.Grpc.csproj b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Microsoft.AutoGen.Runtime.Grpc.csproj index b874a657d8f..caf2f64c55f 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Microsoft.AutoGen.Runtime.Grpc.csproj +++ b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Microsoft.AutoGen.Runtime.Grpc.csproj @@ -6,6 +6,7 @@ + diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/AgentWorkerHostingExtensions.cs b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/AgentWorkerHostingExtensions.cs index 3b130ca4bed..5fbcca9fede 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/AgentWorkerHostingExtensions.cs +++ b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/AgentWorkerHostingExtensions.cs @@ -3,7 +3,7 @@ using System.Diagnostics; using Microsoft.AspNetCore.Builder; -using Microsoft.AutoGen.Core; +//using Microsoft.AutoGen.Core; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection.Extensions; using Microsoft.Extensions.Hosting; @@ -18,10 +18,10 @@ public static WebApplicationBuilder AddAgentService(this WebApplicationBuilder b builder.Services.TryAddSingleton(DistributedContextPropagator.Current); builder.Services.AddGrpc(); - builder.Services.AddKeyedSingleton("AgentsMetadata", (sp, key) => - { - return ReflectionHelper.GetAgentsMetadata(AppDomain.CurrentDomain.GetAssemblies()); - }); + //builder.Services.AddKeyedSingleton("AgentsMetadata", (sp, key) => + //{ + // return ReflectionHelper.GetAgentsMetadata(AppDomain.CurrentDomain.GetAssemblies()); + //}); builder.Services.AddSingleton(); builder.Services.AddSingleton(sp => (IHostedService)sp.GetRequiredService()); diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Grpc/GrpcGateway.cs b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Grpc/GrpcGateway.cs index 26c99c89424..afa4c3603b5 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Grpc/GrpcGateway.cs +++ b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Grpc/GrpcGateway.cs @@ -4,6 +4,7 @@ using System.Collections.Concurrent; using Grpc.Core; using Microsoft.AutoGen.Contracts; +using Microsoft.AutoGen.Protobuf; using Microsoft.AutoGen.Runtime.Grpc.Abstractions; using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; @@ -30,6 +31,7 @@ public sealed class GrpcGateway : BackgroundService, IGateway private readonly ConcurrentDictionary<(string Type, string Key), GrpcWorkerConnection> _agentDirectory = new(); // RPC private readonly ConcurrentDictionary<(GrpcWorkerConnection, string), TaskCompletionSource> _pendingRequests = new(); + public GrpcGateway(IClusterClient clusterClient, ILogger logger) { _logger = logger; @@ -38,8 +40,26 @@ public GrpcGateway(IClusterClient clusterClient, ILogger logger) _gatewayRegistry = clusterClient.GetGrain(0); _subscriptions = clusterClient.GetGrain(0); } + public async ValueTask InvokeRequestAsync(RpcRequest request, CancellationToken cancellationToken = default) { + //if (string.IsNullOrWhiteSpace(request.Target.Type) && string.IsNullOrWhiteSpace(request.Target.Key)) + //{ + // // Check if this is a request to the gateway itself. + // switch (request.Method) + // { + // case "RegisterAgentType": + // { + // if (!request.Payload.DataType.Equals(nameof(RegisterAgentTypeRequest))) + // { + // return new(new RpcResponse { Error = "Invalid payload type." }); + // } + + // //return await TryServiceRegisterAgentType(request.RequestId, request.Payload, cancellationToken).ConfigureAwait(false); + // } + // } + //} + var agentId = (request.Target.Type, request.Target.Key); if (!_agentDirectory.TryGetValue(agentId, out var connection) || connection.Completion.IsCompleted == true) { @@ -65,63 +85,89 @@ public async ValueTask InvokeRequestAsync(RpcRequest request, Cance response.RequestId = originalRequestId; return response; } + public async ValueTask StoreAsync(AgentState value, CancellationToken cancellationToken = default) { _ = value.AgentId ?? throw new ArgumentNullException(nameof(value.AgentId)); var agentState = _clusterClient.GetGrain($"{value.AgentId.Type}:{value.AgentId.Key}"); await agentState.WriteStateAsync(value, value.ETag); } - public async ValueTask ReadAsync(AgentId agentId, CancellationToken cancellationToken = default) + + public async ValueTask ReadAsync(Protobuf.AgentId agentId, CancellationToken cancellationToken = default) { var agentState = _clusterClient.GetGrain($"{agentId.Type}:{agentId.Key}"); return await agentState.ReadStateAsync(); } - public async ValueTask RegisterAgentTypeAsync(RegisterAgentTypeRequest request, CancellationToken cancellationToken = default) + + public async ValueTask RegisterAgentTypeAsync(RegisterAgentTypeRequest request, CancellationToken _ = default) { + string requestId = string.Empty; try { - var connection = _workersByConnection[request.RequestId]; + var connection = _workersByConnection[requestId]; connection.AddSupportedType(request.Type); _supportedAgentTypes.GetOrAdd(request.Type, _ => []).Add(connection); await _gatewayRegistry.RegisterAgentTypeAsync(request, _reference).ConfigureAwait(true); return new RegisterAgentTypeResponse { - Success = true, - RequestId = request.RequestId + //Success = true, + //RequestId = request.RequestId }; } - catch (Exception ex) + catch (Exception) { return new RegisterAgentTypeResponse { - Success = false, - RequestId = request.RequestId, - Error = ex.Message + //Success = false, + //RequestId = request.RequestId, + //Error = ex.Message }; } } - public async ValueTask SubscribeAsync(AddSubscriptionRequest request, CancellationToken cancellationToken = default) + + private async ValueTask RegisterAgentTypeAsync(string requestId, GrpcWorkerConnection connection, RegisterAgentTypeRequest msg) + { + connection.AddSupportedType(msg.Type); + _supportedAgentTypes.GetOrAdd(msg.Type, _ => []).Add(connection); + + await _gatewayRegistry.RegisterAgentTypeAsync(msg, _reference).ConfigureAwait(true); + + Message response = new() + { + Response = new RpcResponse + { + RequestId = requestId, + Error = "", + //Success = true + } + }; + + await connection.ResponseStream.WriteAsync(response).ConfigureAwait(false); + } + + public async ValueTask SubscribeAsync(AddSubscriptionRequest request, CancellationToken _ = default) { try { await _gatewayRegistry.SubscribeAsync(request).ConfigureAwait(true); return new AddSubscriptionResponse { - Success = true, - RequestId = request.RequestId + //Success = true, + //RequestId = request.RequestId }; } - catch (Exception ex) + catch (Exception) { return new AddSubscriptionResponse { - Success = false, - RequestId = request.RequestId, - Error = ex.Message + //Success = false, + //RequestId = request.RequestId, + //Error = ex.Message }; } } + protected override async Task ExecuteAsync(CancellationToken stoppingToken) { while (!stoppingToken.IsCancellationRequested) @@ -145,6 +191,7 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken) _logger.LogWarning(exception, "Error removing worker from registry."); } } + internal async Task ConnectToWorkerProcess(IAsyncStreamReader requestStream, IServerStreamWriter responseStream, ServerCallContext context) { _logger.LogInformation("Received new connection from {Peer}.", context.Peer); @@ -153,10 +200,12 @@ internal async Task ConnectToWorkerProcess(IAsyncStreamReader requestSt _workersByConnection.GetOrAdd(context.Peer, workerProcess); await workerProcess.Connect().ConfigureAwait(false); } + internal async Task SendMessageAsync(GrpcWorkerConnection connection, CloudEvent cloudEvent, CancellationToken cancellationToken = default) { await connection.ResponseStream.WriteAsync(new Message { CloudEvent = cloudEvent }, cancellationToken).ConfigureAwait(false); } + internal async Task OnReceivedMessageAsync(GrpcWorkerConnection connection, Message message, CancellationToken cancellationToken = default) { _logger.LogInformation("Received message {Message} from connection {Connection}.", message, connection); @@ -171,18 +220,19 @@ internal async Task OnReceivedMessageAsync(GrpcWorkerConnection connection, Mess case Message.MessageOneofCase.CloudEvent: await DispatchEventAsync(message.CloudEvent, cancellationToken); break; - case Message.MessageOneofCase.RegisterAgentTypeRequest: - await RegisterAgentTypeAsync(connection, message.RegisterAgentTypeRequest); - break; - case Message.MessageOneofCase.AddSubscriptionRequest: - await AddSubscriptionAsync(connection, message.AddSubscriptionRequest); - break; + //case Message.MessageOneofCase.RegisterAgentTypeRequest: + // await RegisterAgentTypeAsync(connection, message.RegisterAgentTypeRequest); + // break; + //case Message.MessageOneofCase.AddSubscriptionRequest: + // await AddSubscriptionAsync(connection, message.AddSubscriptionRequest); + // break; default: // if it wasn't recognized return bad request await RespondBadRequestAsync(connection, $"Unknown message type for message '{message}'."); break; }; } + private void DispatchResponse(GrpcWorkerConnection connection, RpcResponse response) { if (!_pendingRequests.TryRemove((connection, response.RequestId), out var completion)) @@ -193,23 +243,7 @@ private void DispatchResponse(GrpcWorkerConnection connection, RpcResponse respo // Complete the request. completion.SetResult(response); } - private async ValueTask RegisterAgentTypeAsync(GrpcWorkerConnection connection, RegisterAgentTypeRequest msg) - { - connection.AddSupportedType(msg.Type); - _supportedAgentTypes.GetOrAdd(msg.Type, _ => []).Add(connection); - await _gatewayRegistry.RegisterAgentTypeAsync(msg, _reference).ConfigureAwait(true); - Message response = new() - { - RegisterAgentTypeResponse = new() - { - RequestId = msg.RequestId, - Error = "", - Success = true - } - }; - await connection.ResponseStream.WriteAsync(response).ConfigureAwait(false); - } private async ValueTask DispatchEventAsync(CloudEvent evt, CancellationToken cancellationToken = default) { var registry = _clusterClient.GetGrain(0); @@ -238,11 +272,21 @@ private async ValueTask DispatchEventAsync(CloudEvent evt, CancellationToken can _logger.LogWarning("No agent types found for event type {EventType}.", evt.Type); } } + private async ValueTask DispatchRequestAsync(GrpcWorkerConnection connection, RpcRequest request) { var requestId = request.RequestId; if (request.Target is null) { + // If the gateway knows how to service this request, treat the target as the "Gateway" + if (request.Method == "RegisterAgent") + { + //RegisterAgentTypeRequest request = + + //await RegisterAgentTypeAsync(requestId, connection, request.Payload).ConfigureAwait(false); + return; + } + throw new InvalidOperationException($"Request message is missing a target. Message: '{request}'."); } await InvokeRequestDelegate(connection, request, async request => @@ -260,6 +304,7 @@ await InvokeRequestDelegate(connection, request, async request => return await gateway.InvokeRequestAsync(request).ConfigureAwait(true); }).ConfigureAwait(false); } + private static async Task InvokeRequestDelegate(GrpcWorkerConnection connection, RpcRequest request, Func> func) { try @@ -273,6 +318,7 @@ private static async Task InvokeRequestDelegate(GrpcWorkerConnection connection, await connection.ResponseStream.WriteAsync(new Message { Response = new RpcResponse { RequestId = request.RequestId, Error = ex.Message } }).ConfigureAwait(false); } } + internal void OnRemoveWorkerProcess(GrpcWorkerConnection workerProcess) { _workers.TryRemove(workerProcess, out _); @@ -293,10 +339,12 @@ internal void OnRemoveWorkerProcess(GrpcWorkerConnection workerProcess) } } } + private static async ValueTask RespondBadRequestAsync(GrpcWorkerConnection connection, string error) { throw new RpcException(new Status(StatusCode.InvalidArgument, error)); } + private async ValueTask AddSubscriptionAsync(GrpcWorkerConnection connection, AddSubscriptionRequest request) { var topic = ""; @@ -317,15 +365,16 @@ private async ValueTask AddSubscriptionAsync(GrpcWorkerConnection connection, Ad //var response = new SubscriptionResponse { RequestId = request.RequestId, Error = "", Success = true }; Message response = new() { - AddSubscriptionResponse = new() + Response = new() { - RequestId = request.RequestId, + //RequestId = request.RequestId, Error = "", - Success = true + //Success = true } }; await connection.ResponseStream.WriteAsync(response).ConfigureAwait(false); } + private async ValueTask DispatchEventToAgentsAsync(IEnumerable agentTypes, CloudEvent evt) { var tasks = new List(agentTypes.Count()); @@ -341,6 +390,7 @@ private async ValueTask DispatchEventToAgentsAsync(IEnumerable agentType } await Task.WhenAll(tasks).ConfigureAwait(false); } + public async ValueTask BroadcastEventAsync(CloudEvent evt, CancellationToken cancellationToken = default) { var tasks = new List(_workers.Count); @@ -351,10 +401,12 @@ public async ValueTask BroadcastEventAsync(CloudEvent evt, CancellationToken can } await Task.WhenAll(tasks).ConfigureAwait(false); } + Task IGateway.SendMessageAsync(IConnection connection, CloudEvent cloudEvent) { return this.SendMessageAsync(connection, cloudEvent, default); } + public async Task SendMessageAsync(IConnection connection, CloudEvent cloudEvent, CancellationToken cancellationToken = default) { var queue = (GrpcWorkerConnection)connection; @@ -367,52 +419,60 @@ public async ValueTask UnsubscribeAsync(RemoveSubscr { await _gatewayRegistry.UnsubscribeAsync(request).ConfigureAwait(true); return new RemoveSubscriptionResponse - { - Success = true, + //Success = true, }; } - catch (Exception ex) + catch (Exception) { return new RemoveSubscriptionResponse { - Success = false, - Error = ex.Message + //Success = false, + //Error = ex.Message }; } } + public ValueTask> GetSubscriptionsAsync(GetSubscriptionsRequest request, CancellationToken cancellationToken = default) { return _gatewayRegistry.GetSubscriptionsAsync(request); } + async ValueTask IGateway.InvokeRequestAsync(RpcRequest request) { return await InvokeRequestAsync(request, default).ConfigureAwait(false); } + async ValueTask IGateway.BroadcastEventAsync(CloudEvent evt) { await BroadcastEventAsync(evt, default).ConfigureAwait(false); } + ValueTask IGateway.StoreAsync(AgentState value) { return StoreAsync(value, default); } - ValueTask IGateway.ReadAsync(AgentId agentId) + + ValueTask IGateway.ReadAsync(Protobuf.AgentId agentId) { return ReadAsync(agentId, default); } - ValueTask IGateway.RegisterAgentTypeAsync(RegisterAgentTypeRequest request) + + ValueTask IGateway.RegisterAgentTypeAsync(string requestId, RegisterAgentTypeRequest request) { return RegisterAgentTypeAsync(request, default); } + ValueTask IGateway.SubscribeAsync(AddSubscriptionRequest request) { return SubscribeAsync(request, default); } + ValueTask IGateway.UnsubscribeAsync(RemoveSubscriptionRequest request) { return UnsubscribeAsync(request, default); } + ValueTask> IGateway.GetSubscriptionsAsync(GetSubscriptionsRequest request) { return GetSubscriptionsAsync(request); diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Grpc/GrpcGatewayService.cs b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Grpc/GrpcGatewayService.cs index 9481922943c..4be883854f4 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Grpc/GrpcGatewayService.cs +++ b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Grpc/GrpcGatewayService.cs @@ -2,7 +2,8 @@ // GrpcGatewayService.cs using Grpc.Core; -using Microsoft.AutoGen.Contracts; +//using Microsoft.AutoGen.Contracts; +using Microsoft.AutoGen.Protobuf; namespace Microsoft.AutoGen.Runtime.Grpc; @@ -26,36 +27,42 @@ public override async Task OpenChannel(IAsyncStreamReader requestStream throw; } } + public override async Task GetState(AgentId request, ServerCallContext context) { var state = await Gateway.ReadAsync(request); return new GetStateResponse { AgentState = state }; } + public override async Task SaveState(AgentState request, ServerCallContext context) { await Gateway.StoreAsync(request); return new SaveStateResponse { - Success = true // TODO: Implement error handling + //Success = true // TODO: Implement error handling }; } + public override async Task AddSubscription(AddSubscriptionRequest request, ServerCallContext context) { - request.RequestId = context.Peer; + //request.RequestId = context.Peer; return await Gateway.SubscribeAsync(request).ConfigureAwait(true); } + public override async Task RemoveSubscription(RemoveSubscriptionRequest request, ServerCallContext context) { return await Gateway.UnsubscribeAsync(request).ConfigureAwait(true); } + public override async Task GetSubscriptions(GetSubscriptionsRequest request, ServerCallContext context) { var subscriptions = await Gateway.GetSubscriptionsAsync(request); return new GetSubscriptionsResponse { Subscriptions = { subscriptions } }; } + public override async Task RegisterAgent(RegisterAgentTypeRequest request, ServerCallContext context) { - request.RequestId = context.Peer; + //request.RequestId = context.Peer; return await Gateway.RegisterAgentTypeAsync(request).ConfigureAwait(true); } } diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Grpc/GrpcWorkerConnection.cs b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Grpc/GrpcWorkerConnection.cs index cba0f8c4772..6b2c544a8e5 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Grpc/GrpcWorkerConnection.cs +++ b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Grpc/GrpcWorkerConnection.cs @@ -3,6 +3,8 @@ using System.Threading.Channels; using Grpc.Core; +using Microsoft.AutoGen.Protobuf; +using Microsoft.AutoGen.Runtime.Grpc.Abstractions; namespace Microsoft.AutoGen.Runtime.Grpc; diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/AgentStateGrain.cs b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/AgentStateGrain.cs index 97869cd91fd..d745184f04e 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/AgentStateGrain.cs +++ b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/AgentStateGrain.cs @@ -1,10 +1,33 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // AgentStateGrain.cs +using Microsoft.AutoGen.Protobuf; using Microsoft.AutoGen.Runtime.Grpc.Abstractions; namespace Microsoft.AutoGen.Runtime.Grpc; +/// +/// Interface for managing the state of an agent. +/// +public interface IAgentState +{ + /// + /// Reads the current state of the agent asynchronously. + /// + /// A token to cancel the operation. + /// A task that represents the asynchronous read operation. The task result contains the current state of the agent. + ValueTask ReadStateAsync(CancellationToken cancellationToken = default); + + /// + /// Writes the specified state of the agent asynchronously. + /// + /// The state to write. + /// The ETag for concurrency control. + /// A token to cancel the operation. + /// A task that represents the asynchronous write operation. The task result contains the ETag of the written state. + ValueTask WriteStateAsync(AgentState state, string eTag, CancellationToken cancellationToken = default); +} + internal sealed class AgentStateGrain([PersistentState("state", "AgentStateStore")] IPersistentState state) : Grain, IAgentState, IAgentGrain { /// diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/RegistryGrain.cs b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/RegistryGrain.cs index 9de7065fdb6..1faa29f39af 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/RegistryGrain.cs +++ b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/RegistryGrain.cs @@ -1,9 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // RegistryGrain.cs using Microsoft.AutoGen.Contracts; +using Microsoft.AutoGen.Protobuf; using Microsoft.AutoGen.Runtime.Grpc.Abstractions; namespace Microsoft.AutoGen.Runtime.Grpc; + internal sealed class RegistryGrain([PersistentState("state", "AgentRegistryStore")] IPersistentState state) : Grain, IRegistryGrain { private readonly Dictionary _workerStates = new(); @@ -54,7 +56,7 @@ public ValueTask> GetSubscribedAndHandlingAgentsAsync(string topic, return new ValueTask>(agents); } - public ValueTask<(IGateway? Worker, bool NewPlacement)> GetOrPlaceAgent(AgentId agentId) + public ValueTask<(IGateway? Worker, bool NewPlacement)> GetOrPlaceAgent(Protobuf.AgentId agentId) { // TODO: Clarify the logic bool isNewPlacement; diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/AddSubscriptionRequestSurrogate.cs b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/AddSubscriptionRequestSurrogate.cs index 37e3af1b9d1..6bb6b90da4d 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/AddSubscriptionRequestSurrogate.cs +++ b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/AddSubscriptionRequestSurrogate.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // AddSubscriptionRequestSurrogate.cs +using Microsoft.AutoGen.Protobuf; + namespace Microsoft.AutoGen.Runtime.Grpc.Orleans.Surrogates; [GenerateSerializer] @@ -21,7 +23,7 @@ public AddSubscriptionRequest ConvertFromSurrogate( { var request = new AddSubscriptionRequest() { - RequestId = surrogate.RequestId, + //RequestId = surrogate.RequestId, Subscription = surrogate.Subscription }; return request; @@ -31,7 +33,7 @@ public AddSubscriptionRequestSurrogate ConvertToSurrogate( in AddSubscriptionRequest value) => new AddSubscriptionRequestSurrogate { - RequestId = value.RequestId, + //RequestId = value.RequestId, Subscription = value.Subscription }; } diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/AddSubscriptionResponseSurrogate.cs b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/AddSubscriptionResponseSurrogate.cs index 4c15784e0fc..6313d4f8da7 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/AddSubscriptionResponseSurrogate.cs +++ b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/AddSubscriptionResponseSurrogate.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // AddSubscriptionResponseSurrogate.cs +using Microsoft.AutoGen.Protobuf; + namespace Microsoft.AutoGen.Runtime.Grpc.Orleans.Surrogates; [GenerateSerializer] @@ -22,18 +24,18 @@ public AddSubscriptionResponse ConvertFromSurrogate( in AddSubscriptionResponseSurrogate surrogate) => new AddSubscriptionResponse { - RequestId = surrogate.RequestId, - Success = surrogate.Success, - Error = surrogate.Error + //RequestId = surrogate.RequestId, + //Success = surrogate.Success, + //Error = surrogate.Error }; public AddSubscriptionResponseSurrogate ConvertToSurrogate( in AddSubscriptionResponse value) => new AddSubscriptionResponseSurrogate { - RequestId = value.RequestId, - Success = value.Success, - Error = value.Error + //RequestId = value.RequestId, + //Success = value.Success, + //Error = value.Error }; } diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/AgentIdSurrogate.cs b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/AgentIdSurrogate.cs index ddef9e99757..d19c3b05802 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/AgentIdSurrogate.cs +++ b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/AgentIdSurrogate.cs @@ -3,7 +3,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // AgentIdSurrogate.cs -using Microsoft.AutoGen.Contracts; +using Microsoft.AutoGen.Protobuf; namespace Microsoft.AutoGen.Runtime.Grpc.Orleans.Surrogates; diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/AgentStateSurrogate.cs b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/AgentStateSurrogate.cs index a5291f94215..67b35ef1e8b 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/AgentStateSurrogate.cs +++ b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/AgentStateSurrogate.cs @@ -2,7 +2,7 @@ // AgentStateSurrogate.cs using Google.Protobuf; -using Microsoft.AutoGen.Contracts; +using Microsoft.AutoGen.Protobuf; namespace Microsoft.AutoGen.Runtime.Grpc.Orleans.Surrogates; diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/CloudEventSurrogate.cs b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/CloudEventSurrogate.cs index 22359a08981..7572ec3c31a 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/CloudEventSurrogate.cs +++ b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/CloudEventSurrogate.cs @@ -3,6 +3,7 @@ using Google.Protobuf; using Google.Protobuf.WellKnownTypes; +using Microsoft.AutoGen.Contracts; namespace Microsoft.AutoGen.Runtime.Grpc.Orleans.Surrogates; diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/GetSubscriptionsRequest.cs b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/GetSubscriptionsRequest.cs index ab4722ff8c7..c9910ca19c1 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/GetSubscriptionsRequest.cs +++ b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/GetSubscriptionsRequest.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // GetSubscriptionsRequest.cs +using Microsoft.AutoGen.Protobuf; + namespace Microsoft.AutoGen.Runtime.Grpc.Orleans.Surrogates; [GenerateSerializer] diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/RegisterAgentTypeRequestSurrogate.cs b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/RegisterAgentTypeRequestSurrogate.cs index fa50e597fab..b2abf686a12 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/RegisterAgentTypeRequestSurrogate.cs +++ b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/RegisterAgentTypeRequestSurrogate.cs @@ -2,6 +2,7 @@ // RegisterAgentTypeRequestSurrogate.cs using Google.Protobuf.Collections; +using Microsoft.AutoGen.Protobuf; namespace Microsoft.AutoGen.Runtime.Grpc.Orleans.Surrogates; @@ -27,7 +28,7 @@ public RegisterAgentTypeRequest ConvertFromSurrogate( { var request = new RegisterAgentTypeRequest() { - RequestId = surrogate.RequestId, + //RequestId = surrogate.RequestId, Type = surrogate.Type }; /* future @@ -40,7 +41,7 @@ public RegisterAgentTypeRequestSurrogate ConvertToSurrogate( in RegisterAgentTypeRequest value) => new RegisterAgentTypeRequestSurrogate { - RequestId = value.RequestId, + //RequestId = value.RequestId, Type = value.Type, /* future Events = value.Events, diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/RegisterAgentTypeResponseSurrogate.cs b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/RegisterAgentTypeResponseSurrogate.cs index 2c7d6788a76..c6bf562bf8b 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/RegisterAgentTypeResponseSurrogate.cs +++ b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/RegisterAgentTypeResponseSurrogate.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // RegisterAgentTypeResponseSurrogate.cs +using Microsoft.AutoGen.Protobuf; + namespace Microsoft.AutoGen.Runtime.Grpc.Orleans.Surrogates; [GenerateSerializer] @@ -22,18 +24,18 @@ public RegisterAgentTypeResponse ConvertFromSurrogate( in RegisterAgentTypeResponseSurrogate surrogate) => new RegisterAgentTypeResponse { - RequestId = surrogate.RequestId, - Success = surrogate.Success, - Error = surrogate.Error + //RequestId = surrogate.RequestId, + //Success = surrogate.Success, + //Error = surrogate.Error }; public RegisterAgentTypeResponseSurrogate ConvertToSurrogate( in RegisterAgentTypeResponse value) => new RegisterAgentTypeResponseSurrogate { - RequestId = value.RequestId, - Success = value.Success, - Error = value.Error + //RequestId = value.RequestId, + //Success = value.Success, + //Error = value.Error }; } diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/RemoveSubscriptionRequest.cs b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/RemoveSubscriptionRequest.cs index 27299728baa..96edcc10171 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/RemoveSubscriptionRequest.cs +++ b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/RemoveSubscriptionRequest.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // RemoveSubscriptionRequest.cs +using Microsoft.AutoGen.Protobuf; + namespace Microsoft.AutoGen.Runtime.Grpc.Orleans.Surrogates; [GenerateSerializer] diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/RemoveSubscriptionResponse.cs b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/RemoveSubscriptionResponse.cs index 88253c99b91..27fcf5edb48 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/RemoveSubscriptionResponse.cs +++ b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/RemoveSubscriptionResponse.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // RemoveSubscriptionResponse.cs +using Microsoft.AutoGen.Protobuf; + namespace Microsoft.AutoGen.Runtime.Grpc.Orleans.Surrogates; [GenerateSerializer] @@ -22,16 +24,16 @@ public RemoveSubscriptionResponse ConvertFromSurrogate( in RemoveSubscriptionResponseSurrogate surrogate) => new RemoveSubscriptionResponse { - Success = surrogate.Success, - Error = surrogate.Error + //Success = surrogate.Success, + //Error = surrogate.Error }; public RemoveSubscriptionResponseSurrogate ConvertToSurrogate( in RemoveSubscriptionResponse value) => new RemoveSubscriptionResponseSurrogate { - Success = value.Success, - Error = value.Error + //Success = value.Success, + //Error = value.Error }; } diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/RpcRequestSurrogate.cs b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/RpcRequestSurrogate.cs index 9791a68d795..a8cf07672a9 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/RpcRequestSurrogate.cs +++ b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/RpcRequestSurrogate.cs @@ -2,7 +2,7 @@ // RpcRequestSurrogate.cs using Google.Protobuf.Collections; -using Microsoft.AutoGen.Contracts; +using Microsoft.AutoGen.Protobuf; namespace Microsoft.AutoGen.Runtime.Grpc.Orleans.Surrogates; diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/RpcResponseSurrogate.cs b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/RpcResponseSurrogate.cs index 5c9fac246f8..fee1f79f522 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/RpcResponseSurrogate.cs +++ b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/RpcResponseSurrogate.cs @@ -2,6 +2,7 @@ // RpcResponseSurrogate.cs using Google.Protobuf.Collections; +using Microsoft.AutoGen.Protobuf; namespace Microsoft.AutoGen.Runtime.Grpc.Orleans.Surrogates; diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/SubscriptionSurrogate.cs b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/SubscriptionSurrogate.cs index 1fd56c17627..dc060023c16 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/SubscriptionSurrogate.cs +++ b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/SubscriptionSurrogate.cs @@ -1,54 +1,55 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // SubscriptionSurrogate.cs -using Microsoft.AutoGen.Contracts; +//using Microsoft.AutoGen.Contracts; +//using Microsoft.AutoGen.Protobuf; namespace Microsoft.AutoGen.Runtime.Grpc.Orleans.Surrogates; -[GenerateSerializer] -public struct SubscriptionSurrogate -{ - [Id(0)] - public TypeSubscription? TypeSubscription; - [Id(1)] - public TypePrefixSubscription? TypePrefixSubscription; - [Id(2)] - public string Id; -} +//[GenerateSerializer] +//public struct SubscriptionSurrogate +//{ +// [Id(0)] +// public TypeSubscription? TypeSubscription; +// [Id(1)] +// public TypePrefixSubscription? TypePrefixSubscription; +// [Id(2)] +// public string Id; +//} -[RegisterConverter] -public sealed class SubscriptionSurrogateConverter : - IConverter -{ - public Subscription ConvertFromSurrogate( - in SubscriptionSurrogate surrogate) - { - if (surrogate.TypeSubscription is not null) - { - return new Subscription - { - Id = surrogate.Id, - TypeSubscription = surrogate.TypeSubscription - }; - } - else - { - return new Subscription - { - Id = surrogate.Id, - TypePrefixSubscription = surrogate.TypePrefixSubscription - }; - } - } +//[RegisterConverter] +//public sealed class SubscriptionSurrogateConverter : +// IConverter +//{ +// public Subscription ConvertFromSurrogate( +// in SubscriptionSurrogate surrogate) +// { +// if (surrogate.TypeSubscription is not null) +// { +// return new Subscription +// { +// Id = surrogate.Id, +// TypeSubscription = surrogate.TypeSubscription +// }; +// } +// else +// { +// return new Subscription +// { +// Id = surrogate.Id, +// TypePrefixSubscription = surrogate.TypePrefixSubscription +// }; +// } +// } - public SubscriptionSurrogate ConvertToSurrogate( - in Subscription value) - { - return new SubscriptionSurrogate - { - Id = value.Id, - TypeSubscription = value.TypeSubscription, - TypePrefixSubscription = value.TypePrefixSubscription - }; - } -} +// public SubscriptionSurrogate ConvertToSurrogate( +// in Subscription value) +// { +// return new SubscriptionSurrogate +// { +// Id = value.Id, +// TypeSubscription = value.TypeSubscription, +// TypePrefixSubscription = value.TypePrefixSubscription +// }; +// } +//} diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/TypePrefixSubscriptionSurrogate.cs b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/TypePrefixSubscriptionSurrogate.cs index ca4d721315e..ff2d684c6ba 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/TypePrefixSubscriptionSurrogate.cs +++ b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/TypePrefixSubscriptionSurrogate.cs @@ -1,36 +1,36 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // TypePrefixSubscriptionSurrogate.cs -using Microsoft.AutoGen.Contracts; +//using Microsoft.AutoGen.Contracts; namespace Microsoft.AutoGen.Runtime.Grpc.Orleans.Surrogates; -[GenerateSerializer] -public struct TypePrefixSubscriptionSurrogate -{ - [Id(0)] - public string TopicTypePrefix; - [Id(1)] - public string AgentType; -} +//[GenerateSerializer] +//public struct TypePrefixSubscriptionSurrogate +//{ +// [Id(0)] +// public string TopicTypePrefix; +// [Id(1)] +// public string AgentType; +//} -[RegisterConverter] -public sealed class TypePrefixSubscriptionConverter : - IConverter -{ - public TypePrefixSubscription ConvertFromSurrogate( - in TypePrefixSubscriptionSurrogate surrogate) => - new TypePrefixSubscription - { - TopicTypePrefix = surrogate.TopicTypePrefix, - AgentType = surrogate.AgentType - }; +//[RegisterConverter] +//public sealed class TypePrefixSubscriptionConverter : +// IConverter +//{ +// public TypePrefixSubscription ConvertFromSurrogate( +// in TypePrefixSubscriptionSurrogate surrogate) => +// new TypePrefixSubscription +// { +// TopicTypePrefix = surrogate.TopicTypePrefix, +// AgentType = surrogate.AgentType +// }; - public TypePrefixSubscriptionSurrogate ConvertToSurrogate( - in TypePrefixSubscription value) => - new TypePrefixSubscriptionSurrogate - { - TopicTypePrefix = value.TopicTypePrefix, - AgentType = value.AgentType - }; -} +// public TypePrefixSubscriptionSurrogate ConvertToSurrogate( +// in TypePrefixSubscription value) => +// new TypePrefixSubscriptionSurrogate +// { +// TopicTypePrefix = value.TopicTypePrefix, +// AgentType = value.AgentType +// }; +//} diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/TypeSubscriptionSurrogate.cs b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/TypeSubscriptionSurrogate.cs index 57fa202ebfc..ff28bfcac61 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/TypeSubscriptionSurrogate.cs +++ b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Orleans/Surrogates/TypeSubscriptionSurrogate.cs @@ -1,36 +1,36 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // TypeSubscriptionSurrogate.cs -using Microsoft.AutoGen.Contracts; +//using Microsoft.AutoGen.Contracts; namespace Microsoft.AutoGen.Runtime.Grpc.Orleans.Surrogates; -[GenerateSerializer] -public struct TypeSubscriptionSurrogate -{ - [Id(0)] - public string TopicType; - [Id(1)] - public string AgentType; -} +//[GenerateSerializer] +//public struct TypeSubscriptionSurrogate +//{ +// [Id(0)] +// public string TopicType; +// [Id(1)] +// public string AgentType; +//} -[RegisterConverter] -public sealed class TypeSubscriptionSurrogateConverter : - IConverter -{ - public TypeSubscription ConvertFromSurrogate( - in TypeSubscriptionSurrogate surrogate) => - new TypeSubscription - { - TopicType = surrogate.TopicType, - AgentType = surrogate.AgentType - }; +//[RegisterConverter] +//public sealed class TypeSubscriptionSurrogateConverter : +// IConverter +//{ +// public TypeSubscription ConvertFromSurrogate( +// in TypeSubscriptionSurrogate surrogate) => +// new TypeSubscription +// { +// TopicType = surrogate.TopicType, +// AgentType = surrogate.AgentType +// }; - public TypeSubscriptionSurrogate ConvertToSurrogate( - in TypeSubscription value) => - new TypeSubscriptionSurrogate - { - TopicType = value.TopicType, - AgentType = value.AgentType - }; -} +// public TypeSubscriptionSurrogate ConvertToSurrogate( +// in TypeSubscription value) => +// new TypeSubscriptionSurrogate +// { +// TopicType = value.TopicType, +// AgentType = value.AgentType +// }; +//} From ff2b42328198ab01a61ab44af589ec30eb59ac1c Mon Sep 17 00:00:00 2001 From: Jacob Alber Date: Thu, 30 Jan 2025 02:02:32 -0500 Subject: [PATCH 8/8] wip: Implementing GrpcGateway * Simplify Gateway RPC protocol (remove RPCs in favour of the message channel) * Generify Payload deserialization * Implement RegisterAgentType --- .../Core.Grpc/GrpcAgentRuntime.cs | 69 ++++++++++--------- .../Core.Grpc/ISerializationRegistry.cs | 38 ++++++++++ .../Runtime.Grpc/Services/Grpc/GrpcGateway.cs | 41 ++++++++--- .../Services/Grpc/GrpcGatewayService.cs | 64 ++++++++--------- protos/agent_worker.proto | 14 ++-- 5 files changed, 145 insertions(+), 81 deletions(-) diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs index f0be2376a15..1139dd35805 100644 --- a/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs @@ -2,7 +2,6 @@ // GrpcAgentRuntime.cs using System.Collections.Concurrent; -using Google.Protobuf; using Grpc.Core; using Microsoft.AutoGen.Contracts; using Microsoft.Extensions.Hosting; @@ -155,7 +154,7 @@ private async ValueTask HandleRequest(RpcRequest request, CancellationToken canc // Convert payload back to object var payload = request.Payload; - var message = PayloadToObject(payload); + var message = this.SerializationRegistry.PayloadToObject(payload); var messageContext = new MessageContext(request.RequestId, cancellationToken) { @@ -171,7 +170,7 @@ private async ValueTask HandleRequest(RpcRequest request, CancellationToken canc var response = new RpcResponse { RequestId = request.RequestId, - Payload = ObjectToPayload(result) + Payload = this.SerializationRegistry.ObjectToPayload(result) }; var responseMessage = new Message @@ -201,7 +200,7 @@ private async ValueTask HandleResponse(RpcResponse request, CancellationToken _ if (_pendingRequests.TryRemove(request.RequestId, out var resultSink)) { var payload = request.Payload; - var message = PayloadToObject(payload); + var message = this.SerializationRegistry.PayloadToObject(payload); resultSink.SetResult(message); } } @@ -255,35 +254,35 @@ public Task StopAsync(CancellationToken cancellationToken) return this._messageRouter.StopAsync(); } - private Payload ObjectToPayload(object message) { - if (!SerializationRegistry.Exists(message.GetType())) - { - SerializationRegistry.RegisterSerializer(message.GetType()); - } - var rpcMessage = (SerializationRegistry.GetSerializer(message.GetType()) ?? throw new Exception()).Serialize(message); - - var typeName = SerializationRegistry.TypeNameResolver.ResolveTypeName(message); - const string PAYLOAD_DATA_CONTENT_TYPE = "application/x-protobuf"; - - // Protobuf any to byte array - Payload payload = new() - { - DataType = typeName, - DataContentType = PAYLOAD_DATA_CONTENT_TYPE, - Data = rpcMessage.ToByteString() - }; - - return payload; - } - - private object PayloadToObject(Payload payload) { - var typeName = payload.DataType; - var data = payload.Data; - var type = SerializationRegistry.TypeNameResolver.ResolveTypeName(typeName); - var serializer = SerializationRegistry.GetSerializer(type) ?? throw new Exception(); - var any = Google.Protobuf.WellKnownTypes.Any.Parser.ParseFrom(data); - return serializer.Deserialize(any); - } + //private Payload ObjectToPayload(object message) { + // if (!SerializationRegistry.Exists(message.GetType())) + // { + // SerializationRegistry.RegisterSerializer(message.GetType()); + // } + // var rpcMessage = (SerializationRegistry.GetSerializer(message.GetType()) ?? throw new Exception()).Serialize(message); + + // var typeName = SerializationRegistry.TypeNameResolver.ResolveTypeName(message); + // const string PAYLOAD_DATA_CONTENT_TYPE = "application/x-protobuf"; + + // // Protobuf any to byte array + // Payload payload = new() + // { + // DataType = typeName, + // DataContentType = PAYLOAD_DATA_CONTENT_TYPE, + // Data = rpcMessage.ToByteString() + // }; + + // return payload; + //} + + //private object PayloadToObject(Payload payload) { + // var typeName = payload.DataType; + // var data = payload.Data; + // var type = SerializationRegistry.TypeNameResolver.ResolveTypeName(typeName); + // var serializer = SerializationRegistry.GetSerializer(type) ?? throw new Exception(); + // var any = Google.Protobuf.WellKnownTypes.Any.Parser.ParseFrom(data); + // return serializer.Deserialize(any); + //} public async ValueTask SendMessageAsync(object message, Contracts.AgentId recepient, Contracts.AgentId? sender = null, string? messageId = null, CancellationToken cancellationToken = default) { @@ -292,7 +291,7 @@ private object PayloadToObject(Payload payload) { SerializationRegistry.RegisterSerializer(message.GetType()); } - var payload = ObjectToPayload(message); + var payload = this.SerializationRegistry.ObjectToPayload(message); var request = new RpcRequest { RequestId = Guid.NewGuid().ToString(), @@ -388,6 +387,8 @@ public ValueTask AddSubscriptionAsync(ISubscriptionDefinition subscription) // Because we have an extensible definition of ISubscriptionDefinition, we cannot project it to the Gateway. // What this means is that we will have a much chattier interface between the Gateway and the Runtime. + // TODO: We will be able to make this better by treating unknown subscription types as an "everything" + // subscription. This will allow us to have a single subscription for all unknown types. //await this._client.AddSubscriptionAsync(new AddSubscriptionRequest //{ diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/ISerializationRegistry.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ISerializationRegistry.cs index 190ed3ec239..67380459502 100644 --- a/dotnet/src/Microsoft.AutoGen/Core.Grpc/ISerializationRegistry.cs +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ISerializationRegistry.cs @@ -1,6 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // ISerializationRegistry.cs +using Google.Protobuf; +using Microsoft.AutoGen.Protobuf; + namespace Microsoft.AutoGen.Core.Grpc; public interface IProtoSerializationRegistry @@ -25,3 +28,38 @@ public interface IProtoSerializationRegistry bool Exists(System.Type type); } + +public static class SerializerRegistryExtensions +{ + public static Payload ObjectToPayload(this IProtoSerializationRegistry this_, object message) + { + if (!this_.Exists(message.GetType())) + { + this_.RegisterSerializer(message.GetType()); + } + var rpcMessage = (this_.GetSerializer(message.GetType()) ?? throw new Exception()).Serialize(message); + + var typeName = this_.TypeNameResolver.ResolveTypeName(message); + const string PAYLOAD_DATA_CONTENT_TYPE = "application/x-protobuf"; + + // Protobuf any to byte array + Payload payload = new() + { + DataType = typeName, + DataContentType = PAYLOAD_DATA_CONTENT_TYPE, + Data = rpcMessage.ToByteString() + }; + + return payload; + } + + public static object PayloadToObject(this IProtoSerializationRegistry this_, Payload payload) + { + var typeName = payload.DataType; + var data = payload.Data; + var type = this_.TypeNameResolver.ResolveTypeName(typeName); + var serializer = this_.GetSerializer(type) ?? throw new Exception(); + var any = Google.Protobuf.WellKnownTypes.Any.Parser.ParseFrom(data); + return serializer.Deserialize(any); + } +} diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Grpc/GrpcGateway.cs b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Grpc/GrpcGateway.cs index afa4c3603b5..1672bb40c9a 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Grpc/GrpcGateway.cs +++ b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Grpc/GrpcGateway.cs @@ -2,8 +2,10 @@ // GrpcGateway.cs using System.Collections.Concurrent; +using System.Diagnostics; using Grpc.Core; using Microsoft.AutoGen.Contracts; +using Microsoft.AutoGen.Core.Grpc; using Microsoft.AutoGen.Protobuf; using Microsoft.AutoGen.Runtime.Grpc.Abstractions; using Microsoft.Extensions.Hosting; @@ -27,6 +29,8 @@ public sealed class GrpcGateway : BackgroundService, IGateway private readonly ConcurrentDictionary> _subscriptionsByTopic = new(); private readonly ISubscriptionsGrain _subscriptions; + private IProtoSerializationRegistry SerializationRegistry { get; } = new ProtoSerializationRegistry(); + // The mapping from agent id to worker process. private readonly ConcurrentDictionary<(string Type, string Key), GrpcWorkerConnection> _agentDirectory = new(); // RPC @@ -86,20 +90,20 @@ public async ValueTask InvokeRequestAsync(RpcRequest request, Cance return response; } - public async ValueTask StoreAsync(AgentState value, CancellationToken cancellationToken = default) + private async ValueTask StoreAsync(AgentState value, CancellationToken __ = default) { _ = value.AgentId ?? throw new ArgumentNullException(nameof(value.AgentId)); var agentState = _clusterClient.GetGrain($"{value.AgentId.Type}:{value.AgentId.Key}"); await agentState.WriteStateAsync(value, value.ETag); } - public async ValueTask ReadAsync(Protobuf.AgentId agentId, CancellationToken cancellationToken = default) + private async ValueTask ReadAsync(Protobuf.AgentId agentId, CancellationToken _ = default) { var agentState = _clusterClient.GetGrain($"{agentId.Type}:{agentId.Key}"); return await agentState.ReadStateAsync(); } - public async ValueTask RegisterAgentTypeAsync(RegisterAgentTypeRequest request, CancellationToken _ = default) + private async ValueTask RegisterAgentTypeAsync(RegisterAgentTypeRequest request, CancellationToken _ = default) { string requestId = string.Empty; try @@ -248,7 +252,7 @@ private async ValueTask DispatchEventAsync(CloudEvent evt, CancellationToken can { var registry = _clusterClient.GetGrain(0); //intentionally blocking - var targetAgentTypes = await registry.GetSubscribedAndHandlingAgentsAsync(evt.Source, evt.Type).ConfigureAwait(true); + var targetAgentTypes = await registry.GetSubscribedAndHandlingAgentsAsync(topic: evt.Source, eventType: evt.Type).ConfigureAwait(true); if (targetAgentTypes is not null && targetAgentTypes.Count > 0) { targetAgentTypes = targetAgentTypes.Distinct().ToList(); @@ -279,12 +283,25 @@ private async ValueTask DispatchRequestAsync(GrpcWorkerConnection connection, Rp if (request.Target is null) { // If the gateway knows how to service this request, treat the target as the "Gateway" - if (request.Method == "RegisterAgent") + if (request.Method == "RegisterAgent" && + request.Payload is not null && + request.Payload.DataType == nameof(RegisterAgentTypeRequest) && + request.Payload.Data is not null) { - //RegisterAgentTypeRequest request = + object? payloadData = this.SerializationRegistry.PayloadToObject(request.Payload); + if (payloadData is RegisterAgentTypeRequest registerAgentTypeRequest) + { + await RegisterAgentTypeAsync(requestId, connection, registerAgentTypeRequest).ConfigureAwait(false); + return; + } + else + { + await RespondBadRequestAsync(connection, $"Invalid payload type for \"RegisterAgent\" Expected: {nameof(RegisterAgentTypeRequest)}; got: {payloadData?.GetType().Name ?? "null"}.").ConfigureAwait(false); + return; + } //await RegisterAgentTypeAsync(requestId, connection, request.Payload).ConfigureAwait(false); - return; + //return; } throw new InvalidOperationException($"Request message is missing a target. Message: '{request}'."); @@ -413,7 +430,7 @@ public async Task SendMessageAsync(IConnection connection, CloudEvent cloudEvent await queue.ResponseStream.WriteAsync(new Message { CloudEvent = cloudEvent }, cancellationToken).ConfigureAwait(false); } - public async ValueTask UnsubscribeAsync(RemoveSubscriptionRequest request, CancellationToken cancellationToken = default) + private async ValueTask UnsubscribeAsync(RemoveSubscriptionRequest request, CancellationToken _ = default) { try { @@ -433,7 +450,7 @@ public async ValueTask UnsubscribeAsync(RemoveSubscr } } - public ValueTask> GetSubscriptionsAsync(GetSubscriptionsRequest request, CancellationToken cancellationToken = default) + private ValueTask> GetSubscriptionsAsync(GetSubscriptionsRequest request, CancellationToken _ = default) { return _gatewayRegistry.GetSubscriptionsAsync(request); } @@ -450,31 +467,37 @@ async ValueTask IGateway.BroadcastEventAsync(CloudEvent evt) ValueTask IGateway.StoreAsync(AgentState value) { + Debug.Fail("This method should not be called."); return StoreAsync(value, default); } ValueTask IGateway.ReadAsync(Protobuf.AgentId agentId) { + Debug.Fail("This method should not be called."); return ReadAsync(agentId, default); } ValueTask IGateway.RegisterAgentTypeAsync(string requestId, RegisterAgentTypeRequest request) { + Debug.Fail("This method should not be called."); return RegisterAgentTypeAsync(request, default); } ValueTask IGateway.SubscribeAsync(AddSubscriptionRequest request) { + Debug.Fail("This method should not be called."); return SubscribeAsync(request, default); } ValueTask IGateway.UnsubscribeAsync(RemoveSubscriptionRequest request) { + Debug.Fail("This method should not be called."); return UnsubscribeAsync(request, default); } ValueTask> IGateway.GetSubscriptionsAsync(GetSubscriptionsRequest request) { + Debug.Fail("This method should not be called."); return GetSubscriptionsAsync(request); } } diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Grpc/GrpcGatewayService.cs b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Grpc/GrpcGatewayService.cs index 4be883854f4..dee1db5138d 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Grpc/GrpcGatewayService.cs +++ b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Grpc/GrpcGatewayService.cs @@ -28,41 +28,41 @@ public override async Task OpenChannel(IAsyncStreamReader requestStream } } - public override async Task GetState(AgentId request, ServerCallContext context) - { - var state = await Gateway.ReadAsync(request); - return new GetStateResponse { AgentState = state }; - } + //public override async Task GetState(AgentId request, ServerCallContext context) + //{ + // var state = await Gateway.ReadAsync(request); + // return new GetStateResponse { AgentState = state }; + //} - public override async Task SaveState(AgentState request, ServerCallContext context) - { - await Gateway.StoreAsync(request); - return new SaveStateResponse - { - //Success = true // TODO: Implement error handling - }; - } + //public override async Task SaveState(AgentState request, ServerCallContext context) + //{ + // await Gateway.StoreAsync(request); + // return new SaveStateResponse + // { + // //Success = true // TODO: Implement error handling + // }; + //} - public override async Task AddSubscription(AddSubscriptionRequest request, ServerCallContext context) - { - //request.RequestId = context.Peer; - return await Gateway.SubscribeAsync(request).ConfigureAwait(true); - } + //public override async Task AddSubscription(AddSubscriptionRequest request, ServerCallContext context) + //{ + // //request.RequestId = context.Peer; + // return await Gateway.SubscribeAsync(request).ConfigureAwait(true); + //} - public override async Task RemoveSubscription(RemoveSubscriptionRequest request, ServerCallContext context) - { - return await Gateway.UnsubscribeAsync(request).ConfigureAwait(true); - } + //public override async Task RemoveSubscription(RemoveSubscriptionRequest request, ServerCallContext context) + //{ + // return await Gateway.UnsubscribeAsync(request).ConfigureAwait(true); + //} - public override async Task GetSubscriptions(GetSubscriptionsRequest request, ServerCallContext context) - { - var subscriptions = await Gateway.GetSubscriptionsAsync(request); - return new GetSubscriptionsResponse { Subscriptions = { subscriptions } }; - } + //public override async Task GetSubscriptions(GetSubscriptionsRequest request, ServerCallContext context) + //{ + // var subscriptions = await Gateway.GetSubscriptionsAsync(request); + // return new GetSubscriptionsResponse { Subscriptions = { subscriptions } }; + //} - public override async Task RegisterAgent(RegisterAgentTypeRequest request, ServerCallContext context) - { - //request.RequestId = context.Peer; - return await Gateway.RegisterAgentTypeAsync(request).ConfigureAwait(true); - } + //public override async Task RegisterAgent(RegisterAgentTypeRequest request, ServerCallContext context) + //{ + // //request.RequestId = context.Peer; + // return await Gateway.RegisterAgentTypeAsync(request).ConfigureAwait(true); + //} } diff --git a/protos/agent_worker.proto b/protos/agent_worker.proto index 07375fe9725..ebaae2c1f89 100644 --- a/protos/agent_worker.proto +++ b/protos/agent_worker.proto @@ -106,10 +106,12 @@ message Message { service AgentRpc { rpc OpenChannel (stream Message) returns (stream Message); - rpc GetState(AgentId) returns (GetStateResponse); - rpc SaveState(AgentState) returns (SaveStateResponse); - rpc RegisterAgent(RegisterAgentTypeRequest) returns (RegisterAgentTypeResponse); - rpc AddSubscription(AddSubscriptionRequest) returns (AddSubscriptionResponse); - rpc RemoveSubscription(RemoveSubscriptionRequest) returns (RemoveSubscriptionResponse); - rpc GetSubscriptions(GetSubscriptionsRequest) returns (GetSubscriptionsResponse); + + // These should be coming through the RpcRequest/Response pathway through the Message Channel + //rpc GetState(AgentId) returns (GetStateResponse); + //rpc SaveState(AgentState) returns (SaveStateResponse); + //rpc RegisterAgent(RegisterAgentTypeRequest) returns (RegisterAgentTypeResponse); + //rpc AddSubscription(AddSubscriptionRequest) returns (AddSubscriptionResponse); + //rpc RemoveSubscription(RemoveSubscriptionRequest) returns (RemoveSubscriptionResponse); + //rpc GetSubscriptions(GetSubscriptionsRequest) returns (GetSubscriptionsResponse); }