From d784c17876cf7e1abf2e4d18d7ac6932bd17ff94 Mon Sep 17 00:00:00 2001
From: Jack Gerrits <jack@jackgerrits.com>
Date: Tue, 28 Jan 2025 18:23:46 -0500
Subject: [PATCH] First go at dotnet worker

---
 .../Core.Grpc/GrpcAgentRuntime.cs             | 597 ++++++++++++++++++
 .../Core.Grpc/IAgentMessageSerializer.cs      |  23 +
 .../Core.Grpc/IAgentRuntimeExtensions.cs      | 101 +++
 .../Core.Grpc/IProtoMessageSerializer.cs      |  10 +
 .../Core.Grpc/ISerializationRegistry.cs       |  27 +
 .../Core.Grpc/ITypeNameResolver.cs            |   9 +
 .../Core.Grpc/ProtoTypeNameResolver.cs        |  21 +
 .../Core.Grpc/ProtobufConversionExtensions.cs |  61 ++
 .../Core.Grpc/ProtobufMessageSerializer.cs    |  46 ++
 .../Core.Grpc/SerializationRegistry.cs        |  30 +
 10 files changed, 925 insertions(+)
 create mode 100644 dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs
 create mode 100644 dotnet/src/Microsoft.AutoGen/Core.Grpc/IAgentMessageSerializer.cs
 create mode 100644 dotnet/src/Microsoft.AutoGen/Core.Grpc/IAgentRuntimeExtensions.cs
 create mode 100644 dotnet/src/Microsoft.AutoGen/Core.Grpc/IProtoMessageSerializer.cs
 create mode 100644 dotnet/src/Microsoft.AutoGen/Core.Grpc/ISerializationRegistry.cs
 create mode 100644 dotnet/src/Microsoft.AutoGen/Core.Grpc/ITypeNameResolver.cs
 create mode 100644 dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtoTypeNameResolver.cs
 create mode 100644 dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtobufConversionExtensions.cs
 create mode 100644 dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtobufMessageSerializer.cs
 create mode 100644 dotnet/src/Microsoft.AutoGen/Core.Grpc/SerializationRegistry.cs

diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs
new file mode 100644
index 000000000000..5deba58ae62b
--- /dev/null
+++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs
@@ -0,0 +1,597 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// GrpcAgentRuntime.cs
+
+using System.Collections.Concurrent;
+using System.Threading.Channels;
+using Google.Protobuf;
+using Grpc.Core;
+using Microsoft.AutoGen.Contracts;
+using Microsoft.Extensions.Hosting;
+using Microsoft.Extensions.Logging;
+using Microsoft.AutoGen.Protobuf;
+
+namespace Microsoft.AutoGen.Core.Grpc;
+
+public sealed class GrpcAgentRuntime(
+    AgentRpc.AgentRpcClient client,
+    IHostApplicationLifetime hostApplicationLifetime,
+    IServiceProvider serviceProvider,
+    ILogger<GrpcAgentRuntime> logger
+    ) : IAgentRuntime, IDisposable
+{
+    private readonly object _channelLock = new();
+
+    // Request ID ->
+    private readonly ConcurrentDictionary<string, ResultSink<object?>> _pendingRequests = new();
+    private Dictionary<AgentType, Func<Contracts.AgentId, IAgentRuntime, ValueTask<IHostableAgent>>> agentFactories = new();
+    private Dictionary<Contracts.AgentId, IHostableAgent> agentInstances = new();
+
+    private readonly Channel<(Message Message, TaskCompletionSource WriteCompletionSource)> _outboundMessagesChannel = Channel.CreateBounded<(Message, TaskCompletionSource)>(new BoundedChannelOptions(1024)
+    {
+        AllowSynchronousContinuations = true,
+        SingleReader = true,
+        SingleWriter = false,
+        FullMode = BoundedChannelFullMode.Wait
+    });
+
+    private readonly AgentRpc.AgentRpcClient _client = client;
+    public readonly IServiceProvider ServiceProvider = serviceProvider;
+
+    private readonly ILogger<GrpcAgentRuntime> _logger = logger;
+    private readonly CancellationTokenSource _shutdownCts = CancellationTokenSource.CreateLinkedTokenSource(hostApplicationLifetime.ApplicationStopping);
+    private AsyncDuplexStreamingCall<Message, Message>? _channel;
+    private Task? _readTask;
+    private Task? _writeTask;
+
+    private string _clientId = Guid.NewGuid().ToString();
+    private CallOptions CallOptions
+    {
+        get
+        {
+            var metadata = new Metadata
+            {
+                { "client-id", this._clientId }
+            };
+            return new CallOptions(headers: metadata);
+        }
+    }
+
+    public IProtoSerializationRegistry SerializationRegistry { get; } = new ProtoSerializationRegistry();
+
+    public void Dispose()
+    {
+        _outboundMessagesChannel.Writer.TryComplete();
+        _channel?.Dispose();
+    }
+
+    private async Task RunReadPump()
+    {
+        var channel = GetChannel();
+        while (!_shutdownCts.Token.IsCancellationRequested)
+        {
+            try
+            {
+                await foreach (var message in channel.ResponseStream.ReadAllAsync(_shutdownCts.Token))
+                {
+                    // next if message is null
+                    if (message == null)
+                    {
+                        continue;
+                    }
+                    switch (message.MessageCase)
+                    {
+                        case Message.MessageOneofCase.Request:
+                            var request = message.Request ?? throw new InvalidOperationException("Request is null.");
+                            await HandleRequest(request);
+                            break;
+                        case Message.MessageOneofCase.Response:
+                            var response = message.Response ?? throw new InvalidOperationException("Response is null.");
+                            await HandleResponse(response);
+                            break;
+                        case Message.MessageOneofCase.CloudEvent:
+                            var cloudEvent = message.CloudEvent ?? throw new InvalidOperationException("CloudEvent is null.");
+                            await HandlePublish(cloudEvent);
+                            break;
+                        default:
+                            throw new InvalidOperationException($"Unexpected message '{message}'.");
+                    }
+                }
+            }
+            catch (OperationCanceledException)
+            {
+                // Time to shut down.
+                break;
+            }
+            catch (Exception ex) when (!_shutdownCts.IsCancellationRequested)
+            {
+                _logger.LogError(ex, "Error reading from channel.");
+                channel = RecreateChannel(channel);
+            }
+            catch
+            {
+                // Shutdown requested.
+                break;
+            }
+        }
+    }
+
+    private async ValueTask HandleRequest(RpcRequest request, CancellationToken cancellationToken = default)
+    {
+        if (request is null)
+        {
+            throw new InvalidOperationException("Request is null.");
+        }
+        if (request.Payload is null)
+        {
+            throw new InvalidOperationException("Payload is null.");
+        }
+        if (request.Target is null)
+        {
+            throw new InvalidOperationException("Target is null.");
+        }
+        if (request.Source is null)
+        {
+            throw new InvalidOperationException("Source is null.");
+        }
+
+        var agentId = request.Target;
+        var agent = await EnsureAgentAsync(agentId.FromProtobuf());
+
+        // Convert payload back to object
+        var payload = request.Payload;
+        var message = PayloadToObject(payload);
+
+        var messageContext = new MessageContext(request.RequestId, cancellationToken)
+        {
+            Sender = request.Source.FromProtobuf(),
+            Topic = null,
+            IsRpc = true
+        };
+
+        var result = await agent.OnMessageAsync(message, messageContext);
+
+        if (result is not null)
+        {
+            var response = new RpcResponse
+            {
+                RequestId = request.RequestId,
+                Payload = ObjectToPayload(result)
+            };
+
+            var responseMessage = new Message
+            {
+                Response = response
+            };
+
+            await WriteChannelAsync(responseMessage, cancellationToken);
+        }
+    }
+
+    private async ValueTask HandleResponse(RpcResponse request, CancellationToken _ = default)
+    {
+        if (request is null)
+        {
+            throw new InvalidOperationException("Request is null.");
+        }
+        if (request.Payload is null)
+        {
+            throw new InvalidOperationException("Payload is null.");
+        }
+        if (request.RequestId is null)
+        {
+            throw new InvalidOperationException("RequestId is null.");
+        }
+
+        if (_pendingRequests.TryRemove(request.RequestId, out var resultSink))
+        {
+            var payload = request.Payload;
+            var message = PayloadToObject(payload);
+            resultSink.SetResult(message);
+        }
+    }
+
+    private async ValueTask HandlePublish(CloudEvent evt, CancellationToken cancellationToken = default)
+    {
+        if (evt is null)
+        {
+            throw new InvalidOperationException("CloudEvent is null.");
+        }
+        if (evt.ProtoData is null)
+        {
+            throw new InvalidOperationException("ProtoData is null.");
+        }
+        if (evt.Attributes is null)
+        {
+            throw new InvalidOperationException("Attributes is null.");
+        }
+
+        var topic = new TopicId(evt.Type, evt.Source);
+        var sender = new Contracts.AgentId
+        {
+            Type = evt.Attributes["agagentsendertype"].CeString,
+            Key = evt.Attributes["agagentsenderkey"].CeString
+        };
+
+        var messageId = evt.Id;
+        var typeName = evt.Attributes["dataschema"].CeString;
+        var serializer = SerializationRegistry.GetSerializer(typeName) ?? throw new Exception();
+        var message = serializer.Deserialize(evt.ProtoData);
+
+        var messageContext = new MessageContext(messageId, cancellationToken)
+        {
+            Sender = sender,
+            Topic = topic,
+            IsRpc = false
+        };
+        var agent = await EnsureAgentAsync(sender);
+        await agent.OnMessageAsync(message, messageContext);
+    }
+
+    private async Task RunWritePump()
+    {
+        var channel = GetChannel();
+        var outboundMessages = _outboundMessagesChannel.Reader;
+        while (!_shutdownCts.IsCancellationRequested)
+        {
+            (Message Message, TaskCompletionSource WriteCompletionSource) item = default;
+            try
+            {
+                await outboundMessages.WaitToReadAsync().ConfigureAwait(false);
+
+                // Read the next message if we don't already have an unsent message
+                // waiting to be sent.
+                if (!outboundMessages.TryRead(out item))
+                {
+                    break;
+                }
+
+                while (!_shutdownCts.IsCancellationRequested)
+                {
+                    await channel.RequestStream.WriteAsync(item.Message, _shutdownCts.Token).ConfigureAwait(false);
+                    item.WriteCompletionSource.TrySetResult();
+                    break;
+                }
+            }
+            catch (OperationCanceledException)
+            {
+                // Time to shut down.
+                item.WriteCompletionSource?.TrySetCanceled();
+                break;
+            }
+            catch (RpcException ex) when (ex.StatusCode == StatusCode.Unavailable)
+            {
+                // we could not connect to the endpoint - most likely we have the wrong port or failed ssl
+                // we need to let the user know what port we tried to connect to and then do backoff and retry
+                _logger.LogError(ex, "Error connecting to GRPC endpoint {Endpoint}.", Environment.GetEnvironmentVariable("AGENT_HOST"));
+                break;
+            }
+            catch (RpcException ex) when (ex.StatusCode == StatusCode.OK)
+            {
+                _logger.LogError(ex, "Error writing to channel, continuing (Status OK). {ex}", channel.ToString());
+                break;
+            }
+            catch (Exception ex) when (!_shutdownCts.IsCancellationRequested)
+            {
+                item.WriteCompletionSource?.TrySetException(ex);
+                _logger.LogError(ex, $"Error writing to channel.{ex}");
+                channel = RecreateChannel(channel);
+                continue;
+            }
+            catch
+            {
+                // Shutdown requested.
+                item.WriteCompletionSource?.TrySetCanceled();
+                break;
+            }
+        }
+
+        while (outboundMessages.TryRead(out var item))
+        {
+            item.WriteCompletionSource.TrySetCanceled();
+        }
+    }
+
+    // private override async ValueTask<RpcResponse> SendMessageAsync(Payload message, AgentId agentId, AgentId? agent = null, CancellationToken? cancellationToken = default)
+    // {
+    //     var request = new RpcRequest
+    //     {
+    //         RequestId = Guid.NewGuid().ToString(),
+    //         Source = agent,
+    //         Target = agentId,
+    //         Payload = message,
+    //     };
+
+    //     // Actually send it and wait for the response
+    //     throw new NotImplementedException();
+    // }
+
+    // new is intentional
+
+    // public new async ValueTask RuntimeSendRequestAsync(IAgent agent, RpcRequest request, CancellationToken cancellationToken = default)
+    // {
+    //     var requestId = Guid.NewGuid().ToString();
+    //     _pendingRequests[requestId] = ((Agent)agent, request.RequestId);
+    //     request.RequestId = requestId;
+    //     await WriteChannelAsync(new Message { Request = request }, cancellationToken).ConfigureAwait(false);
+    // }
+
+    private async Task WriteChannelAsync(Message message, CancellationToken cancellationToken = default)
+    {
+        var tcs = new TaskCompletionSource();
+        await _outboundMessagesChannel.Writer.WriteAsync((message, tcs), cancellationToken).ConfigureAwait(false);
+    }
+    private AsyncDuplexStreamingCall<Message, Message> GetChannel()
+    {
+        if (_channel is { } channel)
+        {
+            return channel;
+        }
+
+        lock (_channelLock)
+        {
+            if (_channel is not null)
+            {
+                return _channel;
+            }
+
+            return RecreateChannel(null);
+        }
+    }
+
+    private AsyncDuplexStreamingCall<Message, Message> RecreateChannel(AsyncDuplexStreamingCall<Message, Message>? channel)
+    {
+        if (_channel is null || _channel == channel)
+        {
+            lock (_channelLock)
+            {
+                if (_channel is null || _channel == channel)
+                {
+                    _channel?.Dispose();
+                    _channel = _client.OpenChannel(cancellationToken: _shutdownCts.Token);
+                }
+            }
+        }
+
+        return _channel;
+    }
+    public async Task StartAsync(CancellationToken cancellationToken)
+    {
+        _channel = GetChannel();
+        _logger.LogInformation("Starting " + GetType().Name + ",connecting to gRPC endpoint " + Environment.GetEnvironmentVariable("AGENT_HOST"));
+        var didSuppress = false;
+        if (!ExecutionContext.IsFlowSuppressed())
+        {
+            didSuppress = true;
+            ExecutionContext.SuppressFlow();
+        }
+
+        try
+        {
+            _readTask = Task.Run(RunReadPump, cancellationToken);
+            _writeTask = Task.Run(RunWritePump, cancellationToken);
+        }
+        finally
+        {
+            if (didSuppress)
+            {
+                ExecutionContext.RestoreFlow();
+            }
+        }
+    }
+
+    public async Task StopAsync(CancellationToken cancellationToken)
+    {
+        _shutdownCts.Cancel();
+
+        _outboundMessagesChannel.Writer.TryComplete();
+
+        if (_readTask is { } readTask)
+        {
+            await readTask.ConfigureAwait(false);
+        }
+
+        if (_writeTask is { } writeTask)
+        {
+            await writeTask.ConfigureAwait(false);
+        }
+        lock (_channelLock)
+        {
+            _channel?.Dispose();
+        }
+    }
+
+    private async ValueTask<IHostableAgent> EnsureAgentAsync(Contracts.AgentId agentId)
+    {
+        if (!this.agentInstances.TryGetValue(agentId, out IHostableAgent? agent))
+        {
+            if (!this.agentFactories.TryGetValue(agentId.Type, out Func<Contracts.AgentId, IAgentRuntime, ValueTask<IHostableAgent>>? factoryFunc))
+            {
+                throw new Exception($"Agent with name {agentId.Type} not found.");
+            }
+
+            agent = await factoryFunc(agentId, this);
+            this.agentInstances.Add(agentId, agent);
+        }
+
+        return this.agentInstances[agentId];
+    }
+
+    private Payload ObjectToPayload(object message) {
+        if (!SerializationRegistry.Exists(message.GetType()))
+        {
+            SerializationRegistry.RegisterSerializer(message.GetType());
+        }
+        var rpcMessage = (SerializationRegistry.GetSerializer(message.GetType()) ?? throw new Exception()).Serialize(message);
+
+        var typeName = SerializationRegistry.TypeNameResolver.ResolveTypeName(message);
+        const string PAYLOAD_DATA_CONTENT_TYPE = "application/x-protobuf";
+
+        // Protobuf any to byte array
+        Payload payload = new()
+        {
+            DataType = typeName,
+            DataContentType = PAYLOAD_DATA_CONTENT_TYPE,
+            Data = rpcMessage.ToByteString()
+        };
+
+        return payload;
+    }
+
+    private object PayloadToObject(Payload payload) {
+        var typeName = payload.DataType;
+        var data = payload.Data;
+        var type = SerializationRegistry.TypeNameResolver.ResolveTypeName(typeName);
+        var serializer = SerializationRegistry.GetSerializer(type) ?? throw new Exception();
+        var any = Google.Protobuf.WellKnownTypes.Any.Parser.ParseFrom(data);
+        return serializer.Deserialize(any);
+    }
+
+    public async ValueTask<object?> SendMessageAsync(object message, Contracts.AgentId recepient, Contracts.AgentId? sender = null, string? messageId = null, CancellationToken cancellationToken = default)
+    {
+        if (!SerializationRegistry.Exists(message.GetType()))
+        {
+            SerializationRegistry.RegisterSerializer(message.GetType());
+        }
+
+        var payload = ObjectToPayload(message);
+        var request = new RpcRequest
+        {
+            RequestId = Guid.NewGuid().ToString(),
+            Source = (sender ?? new Contracts.AgentId() ).ToProtobuf(),
+            Target = recepient.ToProtobuf(),
+            Payload = payload,
+        };
+
+        Message msg = new()
+        {
+            Request = request
+        };
+        // Create a future that will be completed when the response is received
+        var resultSink = new ResultSink<object?>();
+        this._pendingRequests.TryAdd(request.RequestId, resultSink);
+        await WriteChannelAsync(msg, cancellationToken);
+
+        return await resultSink.Future;
+    }
+
+    private CloudEvent CreateCloudEvent(Google.Protobuf.WellKnownTypes.Any payload, TopicId topic, string dataType, Contracts.AgentId sender, string messageId)
+    {
+        const string PAYLOAD_DATA_CONTENT_TYPE = "application/x-protobuf";
+        return new CloudEvent
+        {
+            ProtoData = payload,
+            Type = topic.Type,
+            Source = topic.Source,
+            Id = messageId,
+            Attributes = {
+                {
+                    "datacontenttype", new CloudEvent.Types.CloudEventAttributeValue { CeString = PAYLOAD_DATA_CONTENT_TYPE }
+                },
+                {
+                    "dataschema", new CloudEvent.Types.CloudEventAttributeValue { CeString = dataType }
+                },
+                {
+                    "agagentsendertype", new CloudEvent.Types.CloudEventAttributeValue { CeString = sender.Type }
+                },
+                {
+                    "agagentsenderkey", new CloudEvent.Types.CloudEventAttributeValue { CeString = sender.Key }
+                },
+                {
+                    "agmsgkind", new CloudEvent.Types.CloudEventAttributeValue { CeString = "publish" }
+                }
+            }
+        };
+    }
+
+    public async ValueTask PublishMessageAsync(object message, TopicId topic, Contracts.AgentId? sender = null, string? messageId = null, CancellationToken cancellationToken = default)
+    {
+        if (!SerializationRegistry.Exists(message.GetType()))
+        {
+            SerializationRegistry.RegisterSerializer(message.GetType());
+        }
+        var protoAny = (SerializationRegistry.GetSerializer(message.GetType()) ?? throw new Exception()).Serialize(message);
+        var typeName = SerializationRegistry.TypeNameResolver.ResolveTypeName(message);
+
+        var cloudEvent = CreateCloudEvent(protoAny, topic, typeName, sender ?? new Contracts.AgentId(), messageId ?? Guid.NewGuid().ToString());
+
+        Message msg = new()
+        {
+            CloudEvent = cloudEvent
+        };
+        await WriteChannelAsync(msg, cancellationToken);
+    }
+
+    public ValueTask<Contracts.AgentId> GetAgentAsync(Contracts.AgentId agentId, bool lazy = true)
+    {
+        throw new NotImplementedException();
+    }
+
+    public ValueTask<Contracts.AgentId> GetAgentAsync(AgentType agentType, string key = "default", bool lazy = true)
+    {
+        throw new NotImplementedException();
+    }
+
+    public ValueTask<Contracts.AgentId> GetAgentAsync(string agent, string key = "default", bool lazy = true)
+    {
+        throw new NotImplementedException();
+    }
+
+    public ValueTask<IDictionary<string, object>> SaveAgentStateAsync(Contracts.AgentId agentId)
+    {
+        throw new NotImplementedException();
+    }
+
+    public ValueTask LoadAgentStateAsync(Contracts.AgentId agentId, IDictionary<string, object> state)
+    {
+        throw new NotImplementedException();
+    }
+
+    public ValueTask<AgentMetadata> GetAgentMetadataAsync(Contracts.AgentId agentId)
+    {
+        throw new NotImplementedException();
+    }
+
+    public async ValueTask AddSubscriptionAsync(ISubscriptionDefinition subscription)
+    {
+        var _ = await this._client.AddSubscriptionAsync(new AddSubscriptionRequest{
+            Subscription = subscription.ToProtobuf()
+        },this.CallOptions);
+    }
+
+    public ValueTask RemoveSubscriptionAsync(string subscriptionId)
+    {
+        throw new NotImplementedException();
+    }
+
+    public ValueTask<AgentType> RegisterAgentFactoryAsync(AgentType type, Func<Contracts.AgentId, IAgentRuntime, ValueTask<IHostableAgent>> factoryFunc)
+    {
+        if (this.agentFactories.ContainsKey(type))
+        {
+            throw new Exception($"Agent with type {type} already exists.");
+        }
+        this.agentFactories.Add(type, async (agentId, runtime) => await factoryFunc(agentId, runtime));
+
+        this._client.RegisterAgentAsync(new RegisterAgentTypeRequest
+        {
+            Type = type.Name,
+
+        }, this.CallOptions);
+        return ValueTask.FromResult(type);
+    }
+
+    public ValueTask<AgentProxy> TryGetAgentProxyAsync(Contracts.AgentId agentId)
+    {
+        throw new NotImplementedException();
+    }
+
+    public ValueTask<IDictionary<string, object>> SaveStateAsync()
+    {
+        throw new NotImplementedException();
+    }
+
+    public ValueTask LoadStateAsync(IDictionary<string, object> state)
+    {
+        throw new NotImplementedException();
+    }
+}
+
diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/IAgentMessageSerializer.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/IAgentMessageSerializer.cs
new file mode 100644
index 000000000000..0cc422d54d85
--- /dev/null
+++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/IAgentMessageSerializer.cs
@@ -0,0 +1,23 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// IAgentMessageSerializer.cs
+
+namespace Microsoft.AutoGen.Core.Grpc;
+/// <summary>
+/// Interface for serializing and deserializing agent messages.
+/// </summary>
+public interface IAgentMessageSerializer
+{
+    /// <summary>
+    /// Serialize an agent message.
+    /// </summary>
+    /// <param name="message">The message to serialize.</param>
+    /// <returns>The serialized message.</returns>
+    Google.Protobuf.WellKnownTypes.Any Serialize(object message);
+
+    /// <summary>
+    /// Deserialize an agent message.
+    /// </summary>
+    /// <param name="message">The message to deserialize.</param>
+    /// <returns>The deserialized message.</returns>
+    object Deserialize(Google.Protobuf.WellKnownTypes.Any message);
+}
\ No newline at end of file
diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/IAgentRuntimeExtensions.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/IAgentRuntimeExtensions.cs
new file mode 100644
index 000000000000..c820baa527c7
--- /dev/null
+++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/IAgentRuntimeExtensions.cs
@@ -0,0 +1,101 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// IAgentRuntimeExtensions.cs
+
+using System.Diagnostics;
+using Google.Protobuf.Collections;
+using Microsoft.AutoGen.Contracts;
+using Microsoft.Extensions.DependencyInjection;
+using static Microsoft.AutoGen.Contracts.CloudEvent.Types;
+
+namespace Microsoft.AutoGen.Core.Grpc;
+
+public static class IAgentRuntimeExtensions
+{
+    public static (string?, string?) GetTraceIdAndState(IAgentRuntime runtime, IDictionary<string, string> metadata)
+    {
+        var dcp = runtime.RuntimeServiceProvider.GetRequiredService<DistributedContextPropagator>();
+        dcp.ExtractTraceIdAndState(metadata,
+            static (object? carrier, string fieldName, out string? fieldValue, out IEnumerable<string>? fieldValues) =>
+            {
+                var metadata = (IDictionary<string, string>)carrier!;
+                fieldValues = null;
+                metadata.TryGetValue(fieldName, out fieldValue);
+            },
+            out var traceParent,
+            out var traceState);
+        return (traceParent, traceState);
+    }
+    public static (string?, string?) GetTraceIdAndState(IAgentRuntime worker, MapField<string, CloudEventAttributeValue> metadata)
+    {
+        var dcp = worker.RuntimeServiceProvider.GetRequiredService<DistributedContextPropagator>();
+        dcp.ExtractTraceIdAndState(metadata,
+            static (object? carrier, string fieldName, out string? fieldValue, out IEnumerable<string>? fieldValues) =>
+            {
+                var metadata = (MapField<string, CloudEventAttributeValue>)carrier!;
+                fieldValues = null;
+                metadata.TryGetValue(fieldName, out var ceValue);
+                fieldValue = ceValue?.CeString;
+            },
+            out var traceParent,
+            out var traceState);
+        return (traceParent, traceState);
+    }
+    public static void Update(IAgentRuntime worker, RpcRequest request, Activity? activity = null)
+    {
+        var dcp = worker.RuntimeServiceProvider.GetRequiredService<DistributedContextPropagator>();
+        dcp.Inject(activity, request.Metadata, static (carrier, key, value) =>
+        {
+            var metadata = (IDictionary<string, string>)carrier!;
+            if (metadata.TryGetValue(key, out _))
+            {
+                metadata[key] = value;
+            }
+            else
+            {
+                metadata.Add(key, value);
+            }
+        });
+    }
+    public static void Update(IAgentRuntime worker, CloudEvent cloudEvent, Activity? activity = null)
+    {
+        var dcp = worker.RuntimeServiceProvider.GetRequiredService<DistributedContextPropagator>();
+        dcp.Inject(activity, cloudEvent.Attributes, static (carrier, key, value) =>
+        {
+            var mapField = (MapField<string, CloudEventAttributeValue>)carrier!;
+            if (mapField.TryGetValue(key, out var ceValue))
+            {
+                mapField[key] = new CloudEventAttributeValue { CeString = value };
+            }
+            else
+            {
+                mapField.Add(key, new CloudEventAttributeValue { CeString = value });
+            }
+        });
+    }
+
+    public static IDictionary<string, string> ExtractMetadata(IAgentRuntime worker, IDictionary<string, string> metadata)
+    {
+        var dcp = worker.RuntimeServiceProvider.GetRequiredService<DistributedContextPropagator>();
+        var baggage = dcp.ExtractBaggage(metadata, static (object? carrier, string fieldName, out string? fieldValue, out IEnumerable<string>? fieldValues) =>
+        {
+            var metadata = (IDictionary<string, string>)carrier!;
+            fieldValues = null;
+            metadata.TryGetValue(fieldName, out fieldValue);
+        });
+
+        return baggage as IDictionary<string, string> ?? new Dictionary<string, string>();
+    }
+    public static IDictionary<string, string> ExtractMetadata(IAgentRuntime worker, MapField<string, CloudEventAttributeValue> metadata)
+    {
+        var dcp = worker.RuntimeServiceProvider.GetRequiredService<DistributedContextPropagator>();
+        var baggage = dcp.ExtractBaggage(metadata, static (object? carrier, string fieldName, out string? fieldValue, out IEnumerable<string>? fieldValues) =>
+        {
+            var metadata = (MapField<string, CloudEventAttributeValue>)carrier!;
+            fieldValues = null;
+            metadata.TryGetValue(fieldName, out var ceValue);
+            fieldValue = ceValue?.CeString;
+        });
+
+        return baggage as IDictionary<string, string> ?? new Dictionary<string, string>();
+    }
+}
diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/IProtoMessageSerializer.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/IProtoMessageSerializer.cs
new file mode 100644
index 000000000000..ca690e508d2b
--- /dev/null
+++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/IProtoMessageSerializer.cs
@@ -0,0 +1,10 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// IProtoMessageSerializer.cs
+
+namespace Microsoft.AutoGen.Core.Grpc;
+
+public interface IProtoMessageSerializer
+{
+    Google.Protobuf.WellKnownTypes.Any Serialize(object input);
+    object Deserialize(Google.Protobuf.WellKnownTypes.Any input);
+}
\ No newline at end of file
diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/ISerializationRegistry.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ISerializationRegistry.cs
new file mode 100644
index 000000000000..190ed3ec239d
--- /dev/null
+++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ISerializationRegistry.cs
@@ -0,0 +1,27 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// ISerializationRegistry.cs
+
+namespace Microsoft.AutoGen.Core.Grpc;
+
+public interface IProtoSerializationRegistry
+{
+    /// <summary>
+    /// Registers a serializer for the specified type.
+    /// </summary>
+    /// <param name="type">The type to register.</param>
+    void RegisterSerializer(System.Type type) => RegisterSerializer(type, new ProtobufMessageSerializer(type));
+
+    void RegisterSerializer(System.Type type, IProtoMessageSerializer serializer);
+
+    /// <summary>
+    /// Gets the serializer for the specified type.
+    /// </summary>
+    /// <param name="type">The type to get the serializer for.</param>
+    /// <returns>The serializer for the specified type.</returns>
+    IProtoMessageSerializer? GetSerializer(System.Type type) => GetSerializer(TypeNameResolver.ResolveTypeName(type));
+    IProtoMessageSerializer? GetSerializer(string typeName);
+
+    ITypeNameResolver TypeNameResolver { get; }
+
+    bool Exists(System.Type type);
+}
diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/ITypeNameResolver.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ITypeNameResolver.cs
new file mode 100644
index 000000000000..24de4cb8b449
--- /dev/null
+++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ITypeNameResolver.cs
@@ -0,0 +1,9 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// ITypeNameResolver.cs
+
+namespace Microsoft.AutoGen.Core.Grpc;
+
+public interface ITypeNameResolver
+{
+    string ResolveTypeName(object input);
+}
\ No newline at end of file
diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtoTypeNameResolver.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtoTypeNameResolver.cs
new file mode 100644
index 000000000000..808116139ba6
--- /dev/null
+++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtoTypeNameResolver.cs
@@ -0,0 +1,21 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// ITypeNameResolver.cs
+
+using Google.Protobuf;
+
+namespace Microsoft.AutoGen.Core.Grpc;
+
+public class ProtoTypeNameResolver : ITypeNameResolver
+{
+    public string ResolveTypeName(object input)
+    {
+        if (input is IMessage protoMessage)
+        {
+            return protoMessage.Descriptor.FullName;
+        }
+        else
+        {
+            throw new ArgumentException("Input must be a protobuf message.");
+        }
+    }
+}
diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtobufConversionExtensions.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtobufConversionExtensions.cs
new file mode 100644
index 000000000000..4850b7825afe
--- /dev/null
+++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtobufConversionExtensions.cs
@@ -0,0 +1,61 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// ProtobufConversionExtensions.cs
+
+using Microsoft.AutoGen.Contracts;
+using Microsoft.AutoGen.Protobuf;
+
+namespace Microsoft.AutoGen.Core.Grpc;
+
+public static class ProtobufConversionExtensions
+{
+    // Convert an ISubscrptionDefinition to a Protobuf Subscription
+    public static Subscription? ToProtobuf(this ISubscriptionDefinition subscriptionDefinition)
+    {
+        // Check if is a TypeSubscription
+        if (subscriptionDefinition is Contracts.TypeSubscription typeSubscription)
+        {
+            return new Subscription
+            {
+                Id = typeSubscription.Id,
+                TypeSubscription = new Protobuf.TypeSubscription
+                {
+                    TopicType = typeSubscription.TopicType,
+                    AgentType = typeSubscription.AgentType
+                }
+            };
+        }
+
+        // Check if is a TypePrefixSubscription
+        if (subscriptionDefinition is Contracts.TypePrefixSubscription typePrefixSubscription)
+        {
+            return new Subscription
+            {
+                Id = typePrefixSubscription.Id,
+                TypePrefixSubscription = new Protobuf.TypePrefixSubscription
+                {
+                    TopicTypePrefix = typePrefixSubscription.TopicTypePrefix,
+                    AgentType = typePrefixSubscription.AgentType
+                }
+            };
+        }
+
+        return null;
+    }
+
+    // Convert AgentId from Protobuf to AgentId
+    public static Contracts.AgentId FromProtobuf(this Protobuf.AgentId agentId)
+    {
+        return new Contracts.AgentId(agentId.Type, agentId.Key);
+    }
+
+    // Convert AgentId from AgentId to Protobuf
+    public static Protobuf.AgentId ToProtobuf(this Contracts.AgentId agentId)
+    {
+        return new Protobuf.AgentId
+        {
+            Type = agentId.Type,
+            Key = agentId.Key
+        };
+    }
+
+}
\ No newline at end of file
diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtobufMessageSerializer.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtobufMessageSerializer.cs
new file mode 100644
index 000000000000..55c1aebfa47d
--- /dev/null
+++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtobufMessageSerializer.cs
@@ -0,0 +1,46 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// ProtobufMessageSerializer.cs
+
+using Google.Protobuf;
+using Google.Protobuf.WellKnownTypes;
+
+namespace Microsoft.AutoGen.Core.Grpc;
+
+/// <summary>
+/// Interface for serializing and deserializing agent messages.
+/// </summary>
+public class ProtobufMessageSerializer : IProtoMessageSerializer
+{
+    private System.Type _concreteType;
+
+    public ProtobufMessageSerializer(System.Type concreteType)
+    {
+        _concreteType = concreteType;
+    }
+
+    public object Deserialize(Any message)
+    {
+        // Check if the concrete type is a proto IMessage
+        if (typeof(IMessage).IsAssignableFrom(_concreteType))
+        {
+            var nameOfMethod = nameof(Any.Unpack);
+            var result = message.GetType().GetMethods().Where(m => m.Name == nameOfMethod && m.IsGenericMethod).First().MakeGenericMethod(_concreteType).Invoke(message, null);
+            return result as IMessage ?? throw new ArgumentException("Failed to deserialize", nameof(message));
+        }
+
+        // Raise an exception if the concrete type is not a proto IMessage
+        throw new ArgumentException("Concrete type must be a proto IMessage", nameof(_concreteType));
+    }
+
+    public Any Serialize(object message)
+    {
+        // Check if message is a proto IMessage
+        if (message is IMessage protoMessage)
+        {
+            return Any.Pack(protoMessage);
+        }
+
+        // Raise an exception if the message is not a proto IMessage
+        throw new ArgumentException("Message must be a proto IMessage", nameof(message));
+    }
+}
\ No newline at end of file
diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/SerializationRegistry.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/SerializationRegistry.cs
new file mode 100644
index 000000000000..d7bf3a37325c
--- /dev/null
+++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/SerializationRegistry.cs
@@ -0,0 +1,30 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// SerializationRegistry.cs
+
+namespace Microsoft.AutoGen.Core.Grpc;
+
+public class ProtoSerializationRegistry : IProtoSerializationRegistry
+{
+    private readonly Dictionary<Type, IProtoMessageSerializer> _serializers
+        = new Dictionary<Type, IProtoMessageSerializer>();
+
+    public bool Exists(Type type)
+    {
+        return _serializers.ContainsKey(type);
+    }
+
+    public IProtoMessageSerializer? GetSerializer(Type type)
+    {
+        _serializers.TryGetValue(type, out var serializer);
+        return serializer;
+    }
+
+    public void RegisterSerializer(Type type, IProtoMessageSerializer serializer)
+    {
+        if (_serializers.ContainsKey(type))
+        {
+            throw new InvalidOperationException($"Serializer already registered for {type.FullName}");
+        }
+        _serializers[type] = serializer;
+    }
+}