diff --git a/dotnet/src/Microsoft.AutoGen/Microsoft.AutoGen.Runtime.Grpc/AgentStateGrain.cs b/dotnet/src/Microsoft.AutoGen/Microsoft.AutoGen.Runtime.Grpc/AgentStateGrain.cs index c6b2db205a37..a3fac39e21d9 100644 --- a/dotnet/src/Microsoft.AutoGen/Microsoft.AutoGen.Runtime.Grpc/AgentStateGrain.cs +++ b/dotnet/src/Microsoft.AutoGen/Microsoft.AutoGen.Runtime.Grpc/AgentStateGrain.cs @@ -1,11 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // AgentStateGrain.cs +using System.Data; using Microsoft.AutoGen.Abstractions; namespace Microsoft.AutoGen.Runtime.Grpc; -internal sealed class AgentStateGrain([PersistentState("state", "AgentStateStore")] IPersistentState state) : Grain +internal sealed class AgentStateGrain([PersistentState("state", "AgentStateStore")] IPersistentState state) : Grain, IAgentGrain { /// public async ValueTask WriteStateAsync(AgentState newState, string eTag/*, CancellationToken cancellationToken = default*/) @@ -22,7 +23,7 @@ public async ValueTask WriteStateAsync(AgentState newState, string eTag/ else { //TODO - this is probably not the correct behavior to just throw - I presume we want to somehow let the caller know that the state has changed and they need to re-read it - throw new ArgumentException( + throw new DBConcurrencyException( "The provided ETag does not match the current ETag. The state has been modified by another request."); } return state.Etag; @@ -34,3 +35,9 @@ public ValueTask ReadStateAsync(/*CancellationToken cancellationToke return ValueTask.FromResult(state.State); } } + +internal interface IAgentGrain : IGrainWithStringKey +{ + ValueTask ReadStateAsync(); + ValueTask WriteStateAsync(AgentState state, string eTag); +} diff --git a/dotnet/src/Microsoft.AutoGen/Microsoft.AutoGen.Runtime.Grpc/GrpcGateway.cs b/dotnet/src/Microsoft.AutoGen/Microsoft.AutoGen.Runtime.Grpc/GrpcGateway.cs index 980be774198f..613637295623 100644 --- a/dotnet/src/Microsoft.AutoGen/Microsoft.AutoGen.Runtime.Grpc/GrpcGateway.cs +++ b/dotnet/src/Microsoft.AutoGen/Microsoft.AutoGen.Runtime.Grpc/GrpcGateway.cs @@ -14,7 +14,6 @@ public sealed class GrpcGateway : BackgroundService, IGateway private static readonly TimeSpan s_agentResponseTimeout = TimeSpan.FromSeconds(30); private readonly ILogger _logger; private readonly IClusterClient _clusterClient; - private readonly ConcurrentDictionary _agentState = new(); private readonly IRegistryGrain _gatewayRegistry; private readonly IGateway _reference; @@ -60,12 +59,18 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken) } } - internal Task ConnectToWorkerProcess(IAsyncStreamReader requestStream, IServerStreamWriter responseStream, ServerCallContext context) + internal async Task ConnectToWorkerProcess(IAsyncStreamReader requestStream, IServerStreamWriter responseStream, ServerCallContext context) { _logger.LogInformation("Received new connection from {Peer}.", context.Peer); var workerProcess = new GrpcWorkerConnection(this, requestStream, responseStream, context); _workers[workerProcess] = workerProcess; - return workerProcess.Completion; + var completion = new TaskCompletionSource(); + var _ = Task.Run(() => + { + completion.SetResult(workerProcess.Connect()); + }); + + await completion.Task; } public async ValueTask BroadcastEvent(CloudEvent evt) @@ -172,18 +177,16 @@ private static async Task InvokeRequestDelegate(GrpcWorkerConnection connection, public async ValueTask StoreAsync(AgentState value) { - var agentId = value.AgentId ?? throw new ArgumentNullException(nameof(value.AgentId)); - _agentState[agentId.Key] = value; + var agentState = _clusterClient.GetGrain($"{value.AgentId.Type}:{value.AgentId.Key}"); + await agentState.WriteStateAsync(value, value.ETag); } public async ValueTask ReadAsync(AgentId agentId) { - if (_agentState.TryGetValue(agentId.Key, out var state)) - { - return state; - } - return new AgentState { AgentId = agentId }; + var agentState = _clusterClient.GetGrain($"{agentId.Type}:{agentId.Key}"); + return await agentState.ReadStateAsync(); } + internal void OnRemoveWorkerProcess(GrpcWorkerConnection workerProcess) { _workers.TryRemove(workerProcess, out _); @@ -208,7 +211,7 @@ internal void OnRemoveWorkerProcess(GrpcWorkerConnection workerProcess) public async ValueTask InvokeRequest(RpcRequest request, CancellationToken cancellationToken = default) { var agentId = (request.Target.Type, request.Target.Key); - if (!_agentDirectory.TryGetValue(agentId, out var connection) || connection.Completion.IsCompleted) + if (!_agentDirectory.TryGetValue(agentId, out var connection) || connection.Completion?.IsCompleted == true) { // Activate the agent on a compatible worker process. if (_supportedAgentTypes.TryGetValue(request.Target.Type, out var workers)) diff --git a/dotnet/src/Microsoft.AutoGen/Microsoft.AutoGen.Runtime.Grpc/GrpcWorkerConnection.cs b/dotnet/src/Microsoft.AutoGen/Microsoft.AutoGen.Runtime.Grpc/GrpcWorkerConnection.cs index e9c77d1ede92..22e7bf7cba7e 100644 --- a/dotnet/src/Microsoft.AutoGen/Microsoft.AutoGen.Runtime.Grpc/GrpcWorkerConnection.cs +++ b/dotnet/src/Microsoft.AutoGen/Microsoft.AutoGen.Runtime.Grpc/GrpcWorkerConnection.cs @@ -10,13 +10,14 @@ namespace Microsoft.AutoGen.Runtime.Grpc; internal sealed class GrpcWorkerConnection : IAsyncDisposable { private static long s_nextConnectionId; - private readonly Task _readTask; - private readonly Task _writeTask; + private Task? _readTask; + private Task? _writeTask; private readonly string _connectionId = Interlocked.Increment(ref s_nextConnectionId).ToString(); private readonly object _lock = new(); private readonly HashSet _supportedTypes = []; private readonly GrpcGateway _gateway; private readonly CancellationTokenSource _shutdownCancellationToken = new(); + public Task? Completion { get; private set; } public GrpcWorkerConnection(GrpcGateway agentWorker, IAsyncStreamReader requestStream, IServerStreamWriter responseStream, ServerCallContext context) { @@ -25,7 +26,10 @@ public GrpcWorkerConnection(GrpcGateway agentWorker, IAsyncStreamReader ResponseStream = responseStream; ServerCallContext = context; _outboundMessages = Channel.CreateUnbounded(new UnboundedChannelOptions { AllowSynchronousContinuations = true, SingleReader = true, SingleWriter = false }); + } + public Task Connect() + { var didSuppress = false; if (!ExecutionContext.IsFlowSuppressed()) { @@ -46,7 +50,7 @@ public GrpcWorkerConnection(GrpcGateway agentWorker, IAsyncStreamReader } } - Completion = Task.WhenAll(_readTask, _writeTask); + return Completion = Task.WhenAll(_readTask, _writeTask); } public IAsyncStreamReader RequestStream { get; } @@ -76,8 +80,6 @@ public async Task SendMessage(Message message) await _outboundMessages.Writer.WriteAsync(message).ConfigureAwait(false); } - public Task Completion { get; } - public async Task RunReadPump() { await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding); @@ -122,7 +124,10 @@ public async Task RunWritePump() public async ValueTask DisposeAsync() { _shutdownCancellationToken.Cancel(); - await Completion.ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing); + if (Completion is not null) + { + await Completion.ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing); + } } public override string ToString() => $"Connection-{_connectionId}"; diff --git a/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/GrpcGatewayServiceTests.cs b/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/GrpcGatewayServiceTests.cs index 6b496d58c409..f9991996c885 100644 --- a/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/GrpcGatewayServiceTests.cs +++ b/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/GrpcGatewayServiceTests.cs @@ -29,9 +29,20 @@ public async Task Test_OpenChannel() var requestStream = new TestAsyncStreamReader(callContext); var responseStream = new TestServerStreamWriter(callContext); + await service.RegisterAgent(new RegisterAgentTypeRequest { Type = nameof(PBAgent), RequestId = $"{Guid.NewGuid()}", Events = { "", "" } }, callContext); + await service.RegisterAgent(new RegisterAgentTypeRequest { Type = nameof(GMAgent), RequestId = $"{Guid.NewGuid()}", Events = { "", "" } }, callContext); + await service.OpenChannel(requestStream, responseStream, callContext); - requestStream.AddMessage(new Message { }); + var bgEvent = new CloudEvent + { + Id = "1", + Source = "gh/repo/1", + Type = "test", + + }; + + requestStream.AddMessage(new Message { CloudEvent = bgEvent }); requestStream.Complete(); @@ -62,7 +73,7 @@ public async Task Test_GetState() var service = new GrpcGatewayService(gateway); var callContext = TestServerCallContext.Create(); - var response = await service.GetState(new AgentId { }, callContext); + var response = await service.GetState(new AgentId { Key = "", Type = "" }, callContext); response.Should().NotBeNull(); } diff --git a/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/GrpcGatewayTests.cs b/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/GrpcGatewayTests.cs index 2cb1a094c939..e92fdcde6483 100644 --- a/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/GrpcGatewayTests.cs +++ b/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/GrpcGatewayTests.cs @@ -29,7 +29,8 @@ public async Task TestBroadcastEvent() // 1. Register Agent // 2. Broadcast Event - // 3. + // 3. + await gateway.BroadcastEvent(evt); //var registry = fixture.Registry; //var subscriptions = fixture.Subscriptions; diff --git a/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/Helpers/Orleans/SiloBuilderConfigurator.cs b/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/Helpers/Orleans/SiloBuilderConfigurator.cs index 84d404e1f6bc..fb345777a82b 100644 --- a/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/Helpers/Orleans/SiloBuilderConfigurator.cs +++ b/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/Helpers/Orleans/SiloBuilderConfigurator.cs @@ -12,7 +12,10 @@ public void Configure(ISiloBuilder siloBuilder) { siloBuilder.ConfigureServices(services => { - services.AddSerializer(a=> a.AddProtobufSerializer()); + services.AddSerializer(a => a.AddProtobufSerializer()); }); + siloBuilder.AddMemoryStreams("StreamProvider") + .AddMemoryGrainStorage("PubSubStore") + .AddMemoryGrainStorage("AgentStateStore"); } } diff --git a/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/Helpers/Orleans/Surrogates/AgentStateSurrogate.cs b/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/Helpers/Orleans/Surrogates/AgentStateSurrogate.cs index 31aa7cc5670a..56d20be28887 100644 --- a/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/Helpers/Orleans/Surrogates/AgentStateSurrogate.cs +++ b/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/Helpers/Orleans/Surrogates/AgentStateSurrogate.cs @@ -2,7 +2,6 @@ // AgentStateSurrogate.cs using Google.Protobuf; -using Google.Protobuf.WellKnownTypes; using Microsoft.AutoGen.Abstractions; namespace Microsoft.AutoGen.Runtime.Grpc.Tests.Helpers.Orleans.Surrogates; @@ -21,7 +20,7 @@ public struct AgentStateSurrogate [Id(4)] public string Etag; [Id(5)] - public Any ProtoData; + public ByteString ProtoData; } [RegisterConverter] @@ -35,7 +34,7 @@ public AgentState ConvertFromSurrogate( TextData = surrogate.TextData, BinaryData = surrogate.BinaryData, AgentId = surrogate.AgentId, - ProtoData = surrogate.ProtoData, + // ProtoData = surrogate.ProtoData, ETag = surrogate.Etag }; @@ -47,7 +46,7 @@ public AgentStateSurrogate ConvertToSurrogate( BinaryData = value.BinaryData, TextData = value.TextData, Etag = value.ETag, - ProtoData = value.ProtoData + ProtoData = value.ProtoData.Value }; } diff --git a/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/Microsoft.AutoGen.Runtime.Grpc.Tests.csproj b/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/Microsoft.AutoGen.Runtime.Grpc.Tests.csproj index 0a0df26b158a..0a852fbf7e61 100644 --- a/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/Microsoft.AutoGen.Runtime.Grpc.Tests.csproj +++ b/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/Microsoft.AutoGen.Runtime.Grpc.Tests.csproj @@ -13,6 +13,7 @@ + diff --git a/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/TestAgent.cs b/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/TestAgent.cs new file mode 100644 index 000000000000..36dc24d20935 --- /dev/null +++ b/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/TestAgent.cs @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// TestAgent.cs + +using System.Collections.Concurrent; +using Microsoft.AutoGen.Core; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Tests.Events; + +namespace Microsoft.AutoGen.Runtime.Grpc.Tests; + +public class PBAgent([FromKeyedServices("EventTypes")] EventTypes eventTypes, ILogger? logger = null) + : Agent(eventTypes, logger) + , IHandle + , IHandle +{ + public async Task Handle(NewMessageReceived item, CancellationToken cancellationToken = default) + { + ReceivedMessages[AgentId.Key] = item.Message; + var hello = new Hello { Message = item.Message }; + await PublishEventAsync(hello); + } + public Task Handle(GoodBye item, CancellationToken cancellationToken) + { + _logger.LogInformation($"Received GoodBye message {item.Message}"); + return Task.CompletedTask; + } + + public static ConcurrentDictionary ReceivedMessages { get; private set; } = new(); +} + +public class GMAgent([FromKeyedServices("EventTypes")] EventTypes eventTypes, ILogger? logger = null) + : Agent(eventTypes, logger) + , IHandle +{ + public async Task Handle(Hello item, CancellationToken cancellationToken) + { + _logger.LogInformation($"Received Hello message {item.Message}"); + ReceivedMessages[AgentId.Key] = item.Message; + await PublishEventAsync(new GoodBye { Message = "" }); + } + + public static ConcurrentDictionary ReceivedMessages { get; private set; } = new(); +} diff --git a/dotnet/test/Microsoft.Autogen.Tests.Shared/Protos/messages.proto b/dotnet/test/Microsoft.Autogen.Tests.Shared/Protos/messages.proto index 95f9beddab8e..cb68d45e7550 100644 --- a/dotnet/test/Microsoft.Autogen.Tests.Shared/Protos/messages.proto +++ b/dotnet/test/Microsoft.Autogen.Tests.Shared/Protos/messages.proto @@ -7,7 +7,7 @@ message TextMessage { string message = 1; string source = 2; } -message Input { +message Hello { string message = 1; } message InputProcessed {