From ff2b42328198ab01a61ab44af589ec30eb59ac1c Mon Sep 17 00:00:00 2001 From: Jacob Alber Date: Thu, 30 Jan 2025 02:02:32 -0500 Subject: [PATCH] wip: Implementing GrpcGateway * Simplify Gateway RPC protocol (remove RPCs in favour of the message channel) * Generify Payload deserialization * Implement RegisterAgentType --- .../Core.Grpc/GrpcAgentRuntime.cs | 69 ++++++++++--------- .../Core.Grpc/ISerializationRegistry.cs | 38 ++++++++++ .../Runtime.Grpc/Services/Grpc/GrpcGateway.cs | 41 ++++++++--- .../Services/Grpc/GrpcGatewayService.cs | 64 ++++++++--------- protos/agent_worker.proto | 14 ++-- 5 files changed, 145 insertions(+), 81 deletions(-) diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs index f0be2376a15..1139dd35805 100644 --- a/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs @@ -2,7 +2,6 @@ // GrpcAgentRuntime.cs using System.Collections.Concurrent; -using Google.Protobuf; using Grpc.Core; using Microsoft.AutoGen.Contracts; using Microsoft.Extensions.Hosting; @@ -155,7 +154,7 @@ private async ValueTask HandleRequest(RpcRequest request, CancellationToken canc // Convert payload back to object var payload = request.Payload; - var message = PayloadToObject(payload); + var message = this.SerializationRegistry.PayloadToObject(payload); var messageContext = new MessageContext(request.RequestId, cancellationToken) { @@ -171,7 +170,7 @@ private async ValueTask HandleRequest(RpcRequest request, CancellationToken canc var response = new RpcResponse { RequestId = request.RequestId, - Payload = ObjectToPayload(result) + Payload = this.SerializationRegistry.ObjectToPayload(result) }; var responseMessage = new Message @@ -201,7 +200,7 @@ private async ValueTask HandleResponse(RpcResponse request, CancellationToken _ if (_pendingRequests.TryRemove(request.RequestId, out var resultSink)) { var payload = request.Payload; - var message = PayloadToObject(payload); + var message = this.SerializationRegistry.PayloadToObject(payload); resultSink.SetResult(message); } } @@ -255,35 +254,35 @@ public Task StopAsync(CancellationToken cancellationToken) return this._messageRouter.StopAsync(); } - private Payload ObjectToPayload(object message) { - if (!SerializationRegistry.Exists(message.GetType())) - { - SerializationRegistry.RegisterSerializer(message.GetType()); - } - var rpcMessage = (SerializationRegistry.GetSerializer(message.GetType()) ?? throw new Exception()).Serialize(message); - - var typeName = SerializationRegistry.TypeNameResolver.ResolveTypeName(message); - const string PAYLOAD_DATA_CONTENT_TYPE = "application/x-protobuf"; - - // Protobuf any to byte array - Payload payload = new() - { - DataType = typeName, - DataContentType = PAYLOAD_DATA_CONTENT_TYPE, - Data = rpcMessage.ToByteString() - }; - - return payload; - } - - private object PayloadToObject(Payload payload) { - var typeName = payload.DataType; - var data = payload.Data; - var type = SerializationRegistry.TypeNameResolver.ResolveTypeName(typeName); - var serializer = SerializationRegistry.GetSerializer(type) ?? throw new Exception(); - var any = Google.Protobuf.WellKnownTypes.Any.Parser.ParseFrom(data); - return serializer.Deserialize(any); - } + //private Payload ObjectToPayload(object message) { + // if (!SerializationRegistry.Exists(message.GetType())) + // { + // SerializationRegistry.RegisterSerializer(message.GetType()); + // } + // var rpcMessage = (SerializationRegistry.GetSerializer(message.GetType()) ?? throw new Exception()).Serialize(message); + + // var typeName = SerializationRegistry.TypeNameResolver.ResolveTypeName(message); + // const string PAYLOAD_DATA_CONTENT_TYPE = "application/x-protobuf"; + + // // Protobuf any to byte array + // Payload payload = new() + // { + // DataType = typeName, + // DataContentType = PAYLOAD_DATA_CONTENT_TYPE, + // Data = rpcMessage.ToByteString() + // }; + + // return payload; + //} + + //private object PayloadToObject(Payload payload) { + // var typeName = payload.DataType; + // var data = payload.Data; + // var type = SerializationRegistry.TypeNameResolver.ResolveTypeName(typeName); + // var serializer = SerializationRegistry.GetSerializer(type) ?? throw new Exception(); + // var any = Google.Protobuf.WellKnownTypes.Any.Parser.ParseFrom(data); + // return serializer.Deserialize(any); + //} public async ValueTask SendMessageAsync(object message, Contracts.AgentId recepient, Contracts.AgentId? sender = null, string? messageId = null, CancellationToken cancellationToken = default) { @@ -292,7 +291,7 @@ private object PayloadToObject(Payload payload) { SerializationRegistry.RegisterSerializer(message.GetType()); } - var payload = ObjectToPayload(message); + var payload = this.SerializationRegistry.ObjectToPayload(message); var request = new RpcRequest { RequestId = Guid.NewGuid().ToString(), @@ -388,6 +387,8 @@ public ValueTask AddSubscriptionAsync(ISubscriptionDefinition subscription) // Because we have an extensible definition of ISubscriptionDefinition, we cannot project it to the Gateway. // What this means is that we will have a much chattier interface between the Gateway and the Runtime. + // TODO: We will be able to make this better by treating unknown subscription types as an "everything" + // subscription. This will allow us to have a single subscription for all unknown types. //await this._client.AddSubscriptionAsync(new AddSubscriptionRequest //{ diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/ISerializationRegistry.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ISerializationRegistry.cs index 190ed3ec239..67380459502 100644 --- a/dotnet/src/Microsoft.AutoGen/Core.Grpc/ISerializationRegistry.cs +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ISerializationRegistry.cs @@ -1,6 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // ISerializationRegistry.cs +using Google.Protobuf; +using Microsoft.AutoGen.Protobuf; + namespace Microsoft.AutoGen.Core.Grpc; public interface IProtoSerializationRegistry @@ -25,3 +28,38 @@ public interface IProtoSerializationRegistry bool Exists(System.Type type); } + +public static class SerializerRegistryExtensions +{ + public static Payload ObjectToPayload(this IProtoSerializationRegistry this_, object message) + { + if (!this_.Exists(message.GetType())) + { + this_.RegisterSerializer(message.GetType()); + } + var rpcMessage = (this_.GetSerializer(message.GetType()) ?? throw new Exception()).Serialize(message); + + var typeName = this_.TypeNameResolver.ResolveTypeName(message); + const string PAYLOAD_DATA_CONTENT_TYPE = "application/x-protobuf"; + + // Protobuf any to byte array + Payload payload = new() + { + DataType = typeName, + DataContentType = PAYLOAD_DATA_CONTENT_TYPE, + Data = rpcMessage.ToByteString() + }; + + return payload; + } + + public static object PayloadToObject(this IProtoSerializationRegistry this_, Payload payload) + { + var typeName = payload.DataType; + var data = payload.Data; + var type = this_.TypeNameResolver.ResolveTypeName(typeName); + var serializer = this_.GetSerializer(type) ?? throw new Exception(); + var any = Google.Protobuf.WellKnownTypes.Any.Parser.ParseFrom(data); + return serializer.Deserialize(any); + } +} diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Grpc/GrpcGateway.cs b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Grpc/GrpcGateway.cs index afa4c3603b5..1672bb40c9a 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Grpc/GrpcGateway.cs +++ b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Grpc/GrpcGateway.cs @@ -2,8 +2,10 @@ // GrpcGateway.cs using System.Collections.Concurrent; +using System.Diagnostics; using Grpc.Core; using Microsoft.AutoGen.Contracts; +using Microsoft.AutoGen.Core.Grpc; using Microsoft.AutoGen.Protobuf; using Microsoft.AutoGen.Runtime.Grpc.Abstractions; using Microsoft.Extensions.Hosting; @@ -27,6 +29,8 @@ public sealed class GrpcGateway : BackgroundService, IGateway private readonly ConcurrentDictionary> _subscriptionsByTopic = new(); private readonly ISubscriptionsGrain _subscriptions; + private IProtoSerializationRegistry SerializationRegistry { get; } = new ProtoSerializationRegistry(); + // The mapping from agent id to worker process. private readonly ConcurrentDictionary<(string Type, string Key), GrpcWorkerConnection> _agentDirectory = new(); // RPC @@ -86,20 +90,20 @@ public async ValueTask InvokeRequestAsync(RpcRequest request, Cance return response; } - public async ValueTask StoreAsync(AgentState value, CancellationToken cancellationToken = default) + private async ValueTask StoreAsync(AgentState value, CancellationToken __ = default) { _ = value.AgentId ?? throw new ArgumentNullException(nameof(value.AgentId)); var agentState = _clusterClient.GetGrain($"{value.AgentId.Type}:{value.AgentId.Key}"); await agentState.WriteStateAsync(value, value.ETag); } - public async ValueTask ReadAsync(Protobuf.AgentId agentId, CancellationToken cancellationToken = default) + private async ValueTask ReadAsync(Protobuf.AgentId agentId, CancellationToken _ = default) { var agentState = _clusterClient.GetGrain($"{agentId.Type}:{agentId.Key}"); return await agentState.ReadStateAsync(); } - public async ValueTask RegisterAgentTypeAsync(RegisterAgentTypeRequest request, CancellationToken _ = default) + private async ValueTask RegisterAgentTypeAsync(RegisterAgentTypeRequest request, CancellationToken _ = default) { string requestId = string.Empty; try @@ -248,7 +252,7 @@ private async ValueTask DispatchEventAsync(CloudEvent evt, CancellationToken can { var registry = _clusterClient.GetGrain(0); //intentionally blocking - var targetAgentTypes = await registry.GetSubscribedAndHandlingAgentsAsync(evt.Source, evt.Type).ConfigureAwait(true); + var targetAgentTypes = await registry.GetSubscribedAndHandlingAgentsAsync(topic: evt.Source, eventType: evt.Type).ConfigureAwait(true); if (targetAgentTypes is not null && targetAgentTypes.Count > 0) { targetAgentTypes = targetAgentTypes.Distinct().ToList(); @@ -279,12 +283,25 @@ private async ValueTask DispatchRequestAsync(GrpcWorkerConnection connection, Rp if (request.Target is null) { // If the gateway knows how to service this request, treat the target as the "Gateway" - if (request.Method == "RegisterAgent") + if (request.Method == "RegisterAgent" && + request.Payload is not null && + request.Payload.DataType == nameof(RegisterAgentTypeRequest) && + request.Payload.Data is not null) { - //RegisterAgentTypeRequest request = + object? payloadData = this.SerializationRegistry.PayloadToObject(request.Payload); + if (payloadData is RegisterAgentTypeRequest registerAgentTypeRequest) + { + await RegisterAgentTypeAsync(requestId, connection, registerAgentTypeRequest).ConfigureAwait(false); + return; + } + else + { + await RespondBadRequestAsync(connection, $"Invalid payload type for \"RegisterAgent\" Expected: {nameof(RegisterAgentTypeRequest)}; got: {payloadData?.GetType().Name ?? "null"}.").ConfigureAwait(false); + return; + } //await RegisterAgentTypeAsync(requestId, connection, request.Payload).ConfigureAwait(false); - return; + //return; } throw new InvalidOperationException($"Request message is missing a target. Message: '{request}'."); @@ -413,7 +430,7 @@ public async Task SendMessageAsync(IConnection connection, CloudEvent cloudEvent await queue.ResponseStream.WriteAsync(new Message { CloudEvent = cloudEvent }, cancellationToken).ConfigureAwait(false); } - public async ValueTask UnsubscribeAsync(RemoveSubscriptionRequest request, CancellationToken cancellationToken = default) + private async ValueTask UnsubscribeAsync(RemoveSubscriptionRequest request, CancellationToken _ = default) { try { @@ -433,7 +450,7 @@ public async ValueTask UnsubscribeAsync(RemoveSubscr } } - public ValueTask> GetSubscriptionsAsync(GetSubscriptionsRequest request, CancellationToken cancellationToken = default) + private ValueTask> GetSubscriptionsAsync(GetSubscriptionsRequest request, CancellationToken _ = default) { return _gatewayRegistry.GetSubscriptionsAsync(request); } @@ -450,31 +467,37 @@ async ValueTask IGateway.BroadcastEventAsync(CloudEvent evt) ValueTask IGateway.StoreAsync(AgentState value) { + Debug.Fail("This method should not be called."); return StoreAsync(value, default); } ValueTask IGateway.ReadAsync(Protobuf.AgentId agentId) { + Debug.Fail("This method should not be called."); return ReadAsync(agentId, default); } ValueTask IGateway.RegisterAgentTypeAsync(string requestId, RegisterAgentTypeRequest request) { + Debug.Fail("This method should not be called."); return RegisterAgentTypeAsync(request, default); } ValueTask IGateway.SubscribeAsync(AddSubscriptionRequest request) { + Debug.Fail("This method should not be called."); return SubscribeAsync(request, default); } ValueTask IGateway.UnsubscribeAsync(RemoveSubscriptionRequest request) { + Debug.Fail("This method should not be called."); return UnsubscribeAsync(request, default); } ValueTask> IGateway.GetSubscriptionsAsync(GetSubscriptionsRequest request) { + Debug.Fail("This method should not be called."); return GetSubscriptionsAsync(request); } } diff --git a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Grpc/GrpcGatewayService.cs b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Grpc/GrpcGatewayService.cs index 4be883854f4..dee1db5138d 100644 --- a/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Grpc/GrpcGatewayService.cs +++ b/dotnet/src/Microsoft.AutoGen/Runtime.Grpc/Services/Grpc/GrpcGatewayService.cs @@ -28,41 +28,41 @@ public override async Task OpenChannel(IAsyncStreamReader requestStream } } - public override async Task GetState(AgentId request, ServerCallContext context) - { - var state = await Gateway.ReadAsync(request); - return new GetStateResponse { AgentState = state }; - } + //public override async Task GetState(AgentId request, ServerCallContext context) + //{ + // var state = await Gateway.ReadAsync(request); + // return new GetStateResponse { AgentState = state }; + //} - public override async Task SaveState(AgentState request, ServerCallContext context) - { - await Gateway.StoreAsync(request); - return new SaveStateResponse - { - //Success = true // TODO: Implement error handling - }; - } + //public override async Task SaveState(AgentState request, ServerCallContext context) + //{ + // await Gateway.StoreAsync(request); + // return new SaveStateResponse + // { + // //Success = true // TODO: Implement error handling + // }; + //} - public override async Task AddSubscription(AddSubscriptionRequest request, ServerCallContext context) - { - //request.RequestId = context.Peer; - return await Gateway.SubscribeAsync(request).ConfigureAwait(true); - } + //public override async Task AddSubscription(AddSubscriptionRequest request, ServerCallContext context) + //{ + // //request.RequestId = context.Peer; + // return await Gateway.SubscribeAsync(request).ConfigureAwait(true); + //} - public override async Task RemoveSubscription(RemoveSubscriptionRequest request, ServerCallContext context) - { - return await Gateway.UnsubscribeAsync(request).ConfigureAwait(true); - } + //public override async Task RemoveSubscription(RemoveSubscriptionRequest request, ServerCallContext context) + //{ + // return await Gateway.UnsubscribeAsync(request).ConfigureAwait(true); + //} - public override async Task GetSubscriptions(GetSubscriptionsRequest request, ServerCallContext context) - { - var subscriptions = await Gateway.GetSubscriptionsAsync(request); - return new GetSubscriptionsResponse { Subscriptions = { subscriptions } }; - } + //public override async Task GetSubscriptions(GetSubscriptionsRequest request, ServerCallContext context) + //{ + // var subscriptions = await Gateway.GetSubscriptionsAsync(request); + // return new GetSubscriptionsResponse { Subscriptions = { subscriptions } }; + //} - public override async Task RegisterAgent(RegisterAgentTypeRequest request, ServerCallContext context) - { - //request.RequestId = context.Peer; - return await Gateway.RegisterAgentTypeAsync(request).ConfigureAwait(true); - } + //public override async Task RegisterAgent(RegisterAgentTypeRequest request, ServerCallContext context) + //{ + // //request.RequestId = context.Peer; + // return await Gateway.RegisterAgentTypeAsync(request).ConfigureAwait(true); + //} } diff --git a/protos/agent_worker.proto b/protos/agent_worker.proto index 07375fe9725..ebaae2c1f89 100644 --- a/protos/agent_worker.proto +++ b/protos/agent_worker.proto @@ -106,10 +106,12 @@ message Message { service AgentRpc { rpc OpenChannel (stream Message) returns (stream Message); - rpc GetState(AgentId) returns (GetStateResponse); - rpc SaveState(AgentState) returns (SaveStateResponse); - rpc RegisterAgent(RegisterAgentTypeRequest) returns (RegisterAgentTypeResponse); - rpc AddSubscription(AddSubscriptionRequest) returns (AddSubscriptionResponse); - rpc RemoveSubscription(RemoveSubscriptionRequest) returns (RemoveSubscriptionResponse); - rpc GetSubscriptions(GetSubscriptionsRequest) returns (GetSubscriptionsResponse); + + // These should be coming through the RpcRequest/Response pathway through the Message Channel + //rpc GetState(AgentId) returns (GetStateResponse); + //rpc SaveState(AgentState) returns (SaveStateResponse); + //rpc RegisterAgent(RegisterAgentTypeRequest) returns (RegisterAgentTypeResponse); + //rpc AddSubscription(AddSubscriptionRequest) returns (AddSubscriptionResponse); + //rpc RemoveSubscription(RemoveSubscriptionRequest) returns (RemoveSubscriptionResponse); + //rpc GetSubscriptions(GetSubscriptionsRequest) returns (GetSubscriptionsResponse); }