Skip to content

Commit

Permalink
wip: Implementing GrpcGateway
Browse files Browse the repository at this point in the history
* Simplify Gateway RPC protocol (remove RPCs in favour of the message channel)
* Generify Payload deserialization
* Implement RegisterAgentType
  • Loading branch information
lokitoth committed Jan 30, 2025
1 parent 07bdb55 commit ff2b423
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 81 deletions.
69 changes: 35 additions & 34 deletions dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// GrpcAgentRuntime.cs

using System.Collections.Concurrent;
using Google.Protobuf;
using Grpc.Core;
using Microsoft.AutoGen.Contracts;
using Microsoft.Extensions.Hosting;
Expand Down Expand Up @@ -155,7 +154,7 @@ private async ValueTask HandleRequest(RpcRequest request, CancellationToken canc

// Convert payload back to object
var payload = request.Payload;
var message = PayloadToObject(payload);
var message = this.SerializationRegistry.PayloadToObject(payload);

var messageContext = new MessageContext(request.RequestId, cancellationToken)
{
Expand All @@ -171,7 +170,7 @@ private async ValueTask HandleRequest(RpcRequest request, CancellationToken canc
var response = new RpcResponse
{
RequestId = request.RequestId,
Payload = ObjectToPayload(result)
Payload = this.SerializationRegistry.ObjectToPayload(result)
};

var responseMessage = new Message
Expand Down Expand Up @@ -201,7 +200,7 @@ private async ValueTask HandleResponse(RpcResponse request, CancellationToken _
if (_pendingRequests.TryRemove(request.RequestId, out var resultSink))
{
var payload = request.Payload;
var message = PayloadToObject(payload);
var message = this.SerializationRegistry.PayloadToObject(payload);
resultSink.SetResult(message);
}
}
Expand Down Expand Up @@ -255,35 +254,35 @@ public Task StopAsync(CancellationToken cancellationToken)
return this._messageRouter.StopAsync();
}

private Payload ObjectToPayload(object message) {
if (!SerializationRegistry.Exists(message.GetType()))
{
SerializationRegistry.RegisterSerializer(message.GetType());
}
var rpcMessage = (SerializationRegistry.GetSerializer(message.GetType()) ?? throw new Exception()).Serialize(message);

var typeName = SerializationRegistry.TypeNameResolver.ResolveTypeName(message);
const string PAYLOAD_DATA_CONTENT_TYPE = "application/x-protobuf";

// Protobuf any to byte array
Payload payload = new()
{
DataType = typeName,
DataContentType = PAYLOAD_DATA_CONTENT_TYPE,
Data = rpcMessage.ToByteString()
};

return payload;
}

private object PayloadToObject(Payload payload) {
var typeName = payload.DataType;
var data = payload.Data;
var type = SerializationRegistry.TypeNameResolver.ResolveTypeName(typeName);
var serializer = SerializationRegistry.GetSerializer(type) ?? throw new Exception();
var any = Google.Protobuf.WellKnownTypes.Any.Parser.ParseFrom(data);
return serializer.Deserialize(any);
}
//private Payload ObjectToPayload(object message) {
// if (!SerializationRegistry.Exists(message.GetType()))
// {
// SerializationRegistry.RegisterSerializer(message.GetType());
// }
// var rpcMessage = (SerializationRegistry.GetSerializer(message.GetType()) ?? throw new Exception()).Serialize(message);

// var typeName = SerializationRegistry.TypeNameResolver.ResolveTypeName(message);
// const string PAYLOAD_DATA_CONTENT_TYPE = "application/x-protobuf";

// // Protobuf any to byte array
// Payload payload = new()
// {
// DataType = typeName,
// DataContentType = PAYLOAD_DATA_CONTENT_TYPE,
// Data = rpcMessage.ToByteString()
// };

// return payload;
//}

//private object PayloadToObject(Payload payload) {
// var typeName = payload.DataType;
// var data = payload.Data;
// var type = SerializationRegistry.TypeNameResolver.ResolveTypeName(typeName);
// var serializer = SerializationRegistry.GetSerializer(type) ?? throw new Exception();
// var any = Google.Protobuf.WellKnownTypes.Any.Parser.ParseFrom(data);
// return serializer.Deserialize(any);
//}

public async ValueTask<object?> SendMessageAsync(object message, Contracts.AgentId recepient, Contracts.AgentId? sender = null, string? messageId = null, CancellationToken cancellationToken = default)
{
Expand All @@ -292,7 +291,7 @@ private object PayloadToObject(Payload payload) {
SerializationRegistry.RegisterSerializer(message.GetType());
}

var payload = ObjectToPayload(message);
var payload = this.SerializationRegistry.ObjectToPayload(message);
var request = new RpcRequest
{
RequestId = Guid.NewGuid().ToString(),
Expand Down Expand Up @@ -388,6 +387,8 @@ public ValueTask AddSubscriptionAsync(ISubscriptionDefinition subscription)

// Because we have an extensible definition of ISubscriptionDefinition, we cannot project it to the Gateway.
// What this means is that we will have a much chattier interface between the Gateway and the Runtime.
// TODO: We will be able to make this better by treating unknown subscription types as an "everything"
// subscription. This will allow us to have a single subscription for all unknown types.

//await this._client.AddSubscriptionAsync(new AddSubscriptionRequest
//{
Expand Down
38 changes: 38 additions & 0 deletions dotnet/src/Microsoft.AutoGen/Core.Grpc/ISerializationRegistry.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// ISerializationRegistry.cs

using Google.Protobuf;
using Microsoft.AutoGen.Protobuf;

namespace Microsoft.AutoGen.Core.Grpc;

public interface IProtoSerializationRegistry
Expand All @@ -25,3 +28,38 @@ public interface IProtoSerializationRegistry

bool Exists(System.Type type);
}

public static class SerializerRegistryExtensions
{
public static Payload ObjectToPayload(this IProtoSerializationRegistry this_, object message)
{
if (!this_.Exists(message.GetType()))
{
this_.RegisterSerializer(message.GetType());
}
var rpcMessage = (this_.GetSerializer(message.GetType()) ?? throw new Exception()).Serialize(message);

var typeName = this_.TypeNameResolver.ResolveTypeName(message);
const string PAYLOAD_DATA_CONTENT_TYPE = "application/x-protobuf";

// Protobuf any to byte array
Payload payload = new()
{
DataType = typeName,
DataContentType = PAYLOAD_DATA_CONTENT_TYPE,
Data = rpcMessage.ToByteString()
};

return payload;
}

public static object PayloadToObject(this IProtoSerializationRegistry this_, Payload payload)
{
var typeName = payload.DataType;
var data = payload.Data;
var type = this_.TypeNameResolver.ResolveTypeName(typeName);
var serializer = this_.GetSerializer(type) ?? throw new Exception();
var any = Google.Protobuf.WellKnownTypes.Any.Parser.ParseFrom(data);
return serializer.Deserialize(any);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
// GrpcGateway.cs

using System.Collections.Concurrent;
using System.Diagnostics;
using Grpc.Core;
using Microsoft.AutoGen.Contracts;
using Microsoft.AutoGen.Core.Grpc;
using Microsoft.AutoGen.Protobuf;
using Microsoft.AutoGen.Runtime.Grpc.Abstractions;
using Microsoft.Extensions.Hosting;
Expand All @@ -27,6 +29,8 @@ public sealed class GrpcGateway : BackgroundService, IGateway
private readonly ConcurrentDictionary<string, List<string>> _subscriptionsByTopic = new();
private readonly ISubscriptionsGrain _subscriptions;

private IProtoSerializationRegistry SerializationRegistry { get; } = new ProtoSerializationRegistry();

// The mapping from agent id to worker process.
private readonly ConcurrentDictionary<(string Type, string Key), GrpcWorkerConnection> _agentDirectory = new();
// RPC
Expand Down Expand Up @@ -86,20 +90,20 @@ public async ValueTask<RpcResponse> InvokeRequestAsync(RpcRequest request, Cance
return response;
}

public async ValueTask StoreAsync(AgentState value, CancellationToken cancellationToken = default)
private async ValueTask StoreAsync(AgentState value, CancellationToken __ = default)
{
_ = value.AgentId ?? throw new ArgumentNullException(nameof(value.AgentId));
var agentState = _clusterClient.GetGrain<IAgentGrain>($"{value.AgentId.Type}:{value.AgentId.Key}");
await agentState.WriteStateAsync(value, value.ETag);
}

public async ValueTask<AgentState> ReadAsync(Protobuf.AgentId agentId, CancellationToken cancellationToken = default)
private async ValueTask<AgentState> ReadAsync(Protobuf.AgentId agentId, CancellationToken _ = default)
{
var agentState = _clusterClient.GetGrain<IAgentGrain>($"{agentId.Type}:{agentId.Key}");
return await agentState.ReadStateAsync();
}

public async ValueTask<RegisterAgentTypeResponse> RegisterAgentTypeAsync(RegisterAgentTypeRequest request, CancellationToken _ = default)
private async ValueTask<RegisterAgentTypeResponse> RegisterAgentTypeAsync(RegisterAgentTypeRequest request, CancellationToken _ = default)
{
string requestId = string.Empty;
try
Expand Down Expand Up @@ -248,7 +252,7 @@ private async ValueTask DispatchEventAsync(CloudEvent evt, CancellationToken can
{
var registry = _clusterClient.GetGrain<IRegistryGrain>(0);
//intentionally blocking
var targetAgentTypes = await registry.GetSubscribedAndHandlingAgentsAsync(evt.Source, evt.Type).ConfigureAwait(true);
var targetAgentTypes = await registry.GetSubscribedAndHandlingAgentsAsync(topic: evt.Source, eventType: evt.Type).ConfigureAwait(true);
if (targetAgentTypes is not null && targetAgentTypes.Count > 0)
{
targetAgentTypes = targetAgentTypes.Distinct().ToList();
Expand Down Expand Up @@ -279,12 +283,25 @@ private async ValueTask DispatchRequestAsync(GrpcWorkerConnection connection, Rp
if (request.Target is null)
{
// If the gateway knows how to service this request, treat the target as the "Gateway"
if (request.Method == "RegisterAgent")
if (request.Method == "RegisterAgent" &&
request.Payload is not null &&
request.Payload.DataType == nameof(RegisterAgentTypeRequest) &&
request.Payload.Data is not null)
{
//RegisterAgentTypeRequest request =
object? payloadData = this.SerializationRegistry.PayloadToObject(request.Payload);
if (payloadData is RegisterAgentTypeRequest registerAgentTypeRequest)
{
await RegisterAgentTypeAsync(requestId, connection, registerAgentTypeRequest).ConfigureAwait(false);
return;
}
else
{
await RespondBadRequestAsync(connection, $"Invalid payload type for \"RegisterAgent\" Expected: {nameof(RegisterAgentTypeRequest)}; got: {payloadData?.GetType().Name ?? "null"}.").ConfigureAwait(false);
return;
}

//await RegisterAgentTypeAsync(requestId, connection, request.Payload).ConfigureAwait(false);
return;
//return;
}

throw new InvalidOperationException($"Request message is missing a target. Message: '{request}'.");
Expand Down Expand Up @@ -413,7 +430,7 @@ public async Task SendMessageAsync(IConnection connection, CloudEvent cloudEvent
await queue.ResponseStream.WriteAsync(new Message { CloudEvent = cloudEvent }, cancellationToken).ConfigureAwait(false);
}

public async ValueTask<RemoveSubscriptionResponse> UnsubscribeAsync(RemoveSubscriptionRequest request, CancellationToken cancellationToken = default)
private async ValueTask<RemoveSubscriptionResponse> UnsubscribeAsync(RemoveSubscriptionRequest request, CancellationToken _ = default)
{
try
{
Expand All @@ -433,7 +450,7 @@ public async ValueTask<RemoveSubscriptionResponse> UnsubscribeAsync(RemoveSubscr
}
}

public ValueTask<List<Subscription>> GetSubscriptionsAsync(GetSubscriptionsRequest request, CancellationToken cancellationToken = default)
private ValueTask<List<Subscription>> GetSubscriptionsAsync(GetSubscriptionsRequest request, CancellationToken _ = default)
{
return _gatewayRegistry.GetSubscriptionsAsync(request);
}
Expand All @@ -450,31 +467,37 @@ async ValueTask IGateway.BroadcastEventAsync(CloudEvent evt)

ValueTask IGateway.StoreAsync(AgentState value)
{
Debug.Fail("This method should not be called.");
return StoreAsync(value, default);
}

ValueTask<AgentState> IGateway.ReadAsync(Protobuf.AgentId agentId)
{
Debug.Fail("This method should not be called.");
return ReadAsync(agentId, default);
}

ValueTask<RegisterAgentTypeResponse> IGateway.RegisterAgentTypeAsync(string requestId, RegisterAgentTypeRequest request)
{
Debug.Fail("This method should not be called.");
return RegisterAgentTypeAsync(request, default);
}

ValueTask<AddSubscriptionResponse> IGateway.SubscribeAsync(AddSubscriptionRequest request)
{
Debug.Fail("This method should not be called.");
return SubscribeAsync(request, default);
}

ValueTask<RemoveSubscriptionResponse> IGateway.UnsubscribeAsync(RemoveSubscriptionRequest request)
{
Debug.Fail("This method should not be called.");
return UnsubscribeAsync(request, default);
}

ValueTask<List<Subscription>> IGateway.GetSubscriptionsAsync(GetSubscriptionsRequest request)
{
Debug.Fail("This method should not be called.");
return GetSubscriptionsAsync(request);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,41 +28,41 @@ public override async Task OpenChannel(IAsyncStreamReader<Message> requestStream
}
}

public override async Task<GetStateResponse> GetState(AgentId request, ServerCallContext context)
{
var state = await Gateway.ReadAsync(request);
return new GetStateResponse { AgentState = state };
}
//public override async Task<GetStateResponse> GetState(AgentId request, ServerCallContext context)
//{
// var state = await Gateway.ReadAsync(request);
// return new GetStateResponse { AgentState = state };
//}

public override async Task<SaveStateResponse> SaveState(AgentState request, ServerCallContext context)
{
await Gateway.StoreAsync(request);
return new SaveStateResponse
{
//Success = true // TODO: Implement error handling
};
}
//public override async Task<SaveStateResponse> SaveState(AgentState request, ServerCallContext context)
//{
// await Gateway.StoreAsync(request);
// return new SaveStateResponse
// {
// //Success = true // TODO: Implement error handling
// };
//}

public override async Task<AddSubscriptionResponse> AddSubscription(AddSubscriptionRequest request, ServerCallContext context)
{
//request.RequestId = context.Peer;
return await Gateway.SubscribeAsync(request).ConfigureAwait(true);
}
//public override async Task<AddSubscriptionResponse> AddSubscription(AddSubscriptionRequest request, ServerCallContext context)
//{
// //request.RequestId = context.Peer;
// return await Gateway.SubscribeAsync(request).ConfigureAwait(true);
//}

public override async Task<RemoveSubscriptionResponse> RemoveSubscription(RemoveSubscriptionRequest request, ServerCallContext context)
{
return await Gateway.UnsubscribeAsync(request).ConfigureAwait(true);
}
//public override async Task<RemoveSubscriptionResponse> RemoveSubscription(RemoveSubscriptionRequest request, ServerCallContext context)
//{
// return await Gateway.UnsubscribeAsync(request).ConfigureAwait(true);
//}

public override async Task<GetSubscriptionsResponse> GetSubscriptions(GetSubscriptionsRequest request, ServerCallContext context)
{
var subscriptions = await Gateway.GetSubscriptionsAsync(request);
return new GetSubscriptionsResponse { Subscriptions = { subscriptions } };
}
//public override async Task<GetSubscriptionsResponse> GetSubscriptions(GetSubscriptionsRequest request, ServerCallContext context)
//{
// var subscriptions = await Gateway.GetSubscriptionsAsync(request);
// return new GetSubscriptionsResponse { Subscriptions = { subscriptions } };
//}

public override async Task<RegisterAgentTypeResponse> RegisterAgent(RegisterAgentTypeRequest request, ServerCallContext context)
{
//request.RequestId = context.Peer;
return await Gateway.RegisterAgentTypeAsync(request).ConfigureAwait(true);
}
//public override async Task<RegisterAgentTypeResponse> RegisterAgent(RegisterAgentTypeRequest request, ServerCallContext context)
//{
// //request.RequestId = context.Peer;
// return await Gateway.RegisterAgentTypeAsync(request).ConfigureAwait(true);
//}
}
14 changes: 8 additions & 6 deletions protos/agent_worker.proto
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,12 @@ message Message {

service AgentRpc {
rpc OpenChannel (stream Message) returns (stream Message);
rpc GetState(AgentId) returns (GetStateResponse);
rpc SaveState(AgentState) returns (SaveStateResponse);
rpc RegisterAgent(RegisterAgentTypeRequest) returns (RegisterAgentTypeResponse);
rpc AddSubscription(AddSubscriptionRequest) returns (AddSubscriptionResponse);
rpc RemoveSubscription(RemoveSubscriptionRequest) returns (RemoveSubscriptionResponse);
rpc GetSubscriptions(GetSubscriptionsRequest) returns (GetSubscriptionsResponse);

// These should be coming through the RpcRequest/Response pathway through the Message Channel
//rpc GetState(AgentId) returns (GetStateResponse);
//rpc SaveState(AgentState) returns (SaveStateResponse);
//rpc RegisterAgent(RegisterAgentTypeRequest) returns (RegisterAgentTypeResponse);
//rpc AddSubscription(AddSubscriptionRequest) returns (AddSubscriptionResponse);
//rpc RemoveSubscription(RemoveSubscriptionRequest) returns (RemoveSubscriptionResponse);
//rpc GetSubscriptions(GetSubscriptionsRequest) returns (GetSubscriptionsResponse);
}

0 comments on commit ff2b423

Please sign in to comment.