diff --git a/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/MessageRegistryGrain.cs b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/MessageRegistryGrain.cs index 534114920d7d..1f524e552f77 100644 --- a/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/MessageRegistryGrain.cs +++ b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/MessageRegistryGrain.cs @@ -4,152 +4,79 @@ using Microsoft.AutoGen.Contracts; using Microsoft.AutoGen.RuntimeGateway.Grpc.Abstractions; using Microsoft.Extensions.Logging; +using Orleans.Concurrency; namespace Microsoft.AutoGen.RuntimeGateway.Grpc; - -internal sealed class MessageRegistryGrain( - [PersistentState("state", "PubSubStore")] IPersistentState state, - ILogger logger -) : Grain, IMessageRegistryGrain +[Reentrant] +internal sealed class MessageRegistryGrain : Grain, IMessageRegistryGrain { - // - // Helper class for managing state writes. - // - private readonly StateManager _stateManager = new(state); + public enum QueueType + { + DeadLetterQueue, + EventBuffer + } - // - // The number of times to retry writing the state before giving up. - // - private const int _retries = 5; /// /// The time to wait before removing a message from the event buffer. /// in milliseconds /// private const int _bufferTime = 5000; - private readonly ILogger _logger = logger; - - // - public async Task AddMessageToDeadLetterQueueAsync(string topic, CloudEvent message) - { - await TryWriteMessageAsync("dlq", topic, message).ConfigureAwait(true); - } - - /// - public async Task AddMessageToEventBufferAsync(string topic, CloudEvent message) - { - await TryWriteMessageAsync("eb", topic, message).ConfigureAwait(true); - // Schedule the removal task to run in the background after bufferTime - RemoveMessageAfterDelay(topic, message).Ignore(); - } /// - /// remove a specific message from the buffer for a given topic + /// maximum size of a message we will write to the state store in bytes /// - /// - /// - /// ValueTask - private async ValueTask RemoveMessage(string topic, CloudEvent message) - { - if (state.State.EventBuffer != null && state.State.EventBuffer.TryGetValue(topic, out List? events)) - { - if (events != null && events.Remove(message)) - { - state.State.EventBuffer.AddOrUpdate(topic, events, (_, _) => events); - await _stateManager.WriteStateAsync().ConfigureAwait(true); - return true; - } - } - return false; - } + /// set this to HALF your intended limit as protobuf strings are UTF8 but .NET UTF16 + private const int _maxMessageSize = 1024 * 1024 * 10; // 10MB /// - /// remove a specific message from the buffer for a given topic after a delay + /// maximum size of a each queue /// - /// - /// - private async Task RemoveMessageAfterDelay(string topic, CloudEvent message) + /// set this to HALF your intended limit as protobuf strings are UTF8 but .NET UTF16 + private const int _maxQueueSize = 1024 * 1024 * 10; // 10MB + + private readonly MessageRegistryQueue _dlqQueue; + private readonly MessageRegistryQueue _ebQueue; + + public MessageRegistryGrain( + [PersistentState("state", "PubSubStore")] IPersistentState state, + ILogger logger) { - await Task.Delay(_bufferTime); - await RemoveMessage(topic, message); + var stateManager = new StateManager(state); + _dlqQueue = new MessageRegistryQueue( + QueueType.DeadLetterQueue, + state, + stateManager, + logger, + _maxMessageSize, + _maxQueueSize); + + _ebQueue = new MessageRegistryQueue( + QueueType.EventBuffer, + state, + stateManager, + logger, + _maxMessageSize, + _maxQueueSize); } - /// - /// Tries to write a message to the given queue in Orleans state. - /// Allows for retries using etag for optimistic concurrency. - /// - /// - /// - /// - /// - /// - private async ValueTask TryWriteMessageAsync(string whichQueue, string topic, CloudEvent message) + // + public async Task AddMessageToDeadLetterQueueAsync(string topic, CloudEvent message) { - var retries = _retries; - while (!await WriteMessageAsync(whichQueue, topic, message, state.Etag).ConfigureAwait(false)) - { - if (retries-- <= 0) - { - throw new InvalidOperationException($"Failed to write MessageRegistryState after {_retries} retries."); - } - _logger.LogWarning("Failed to write MessageRegistryState. Retrying..."); - retries--; - } - if (retries == 0) { return false; } else { return true; } + await _dlqQueue.AddMessageAsync(topic, message); } - /// - /// Writes a message to the given queue in Orleans state. - /// - /// - /// - /// - /// - /// ValueTask - /// - private async ValueTask WriteMessageAsync(string whichQueue, string topic, CloudEvent message, string etag) + + /// + public async Task AddMessageToEventBufferAsync(string topic, CloudEvent message) { - if (state.Etag != null && state.Etag != etag) - { - return false; - } - switch (whichQueue) - { - case "dlq": - var dlqQueue = state.State.DeadLetterQueue.GetOrAdd(topic, _ => new()); - dlqQueue.Add(message); - state.State.DeadLetterQueue.AddOrUpdate(topic, dlqQueue, (_, _) => dlqQueue); - break; - case "eb": - var ebQueue = state.State.EventBuffer.GetOrAdd(topic, _ => new()); - ebQueue.Add(message); - state.State.EventBuffer.AddOrUpdate(topic, ebQueue, (_, _) => ebQueue); - break; - default: - throw new ArgumentException($"Invalid queue name: {whichQueue}"); - } - await _stateManager.WriteStateAsync().ConfigureAwait(true); - return true; + await _ebQueue.AddMessageAsync(topic, message); + _ebQueue.RemoveMessageAfterDelayAsync(topic, message, _bufferTime).Ignore(); } // public async Task> RemoveMessagesAsync(string topic) { - var messages = new List(); - if (state.State.DeadLetterQueue != null && state.State.DeadLetterQueue.Remove(topic, out List? letters)) - { - await _stateManager.WriteStateAsync().ConfigureAwait(true); - if (letters != null) - { - messages.AddRange(letters); - } - } - if (state.State.EventBuffer != null && state.State.EventBuffer.Remove(topic, out List? events)) - { - await _stateManager.WriteStateAsync().ConfigureAwait(true); - if (events != null) - { - messages.AddRange(events); - } - } - return messages; + var removedDeadLetter = await _dlqQueue.RemoveMessagesAsync(topic); + var removedBuffer = await _ebQueue.RemoveMessagesAsync(topic); + return removedDeadLetter.Concat(removedBuffer).ToList(); } } diff --git a/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/MessageRegistryQueue.cs b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/MessageRegistryQueue.cs new file mode 100644 index 000000000000..6d8d55b6e0e5 --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/RuntimeGateway.Grpc/Services/Orleans/MessageRegistryQueue.cs @@ -0,0 +1,140 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// MessageRegistryQueue.cs + +using System.Collections.Concurrent; +using Microsoft.AutoGen.Contracts; +using Microsoft.AutoGen.RuntimeGateway.Grpc.Abstractions; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AutoGen.RuntimeGateway.Grpc; + +public sealed class MessageRegistryQueue +{ + private ConcurrentDictionary> _queue = new(); + private readonly int _maxMessageSize; + private readonly int _maxQueueSize; + private readonly Dictionary _timestamps = new(); + private int _currentSize; + private readonly IPersistentState _state; + private readonly ILogger _logger; + private readonly StateManager _stateManager; + private readonly MessageRegistryGrain.QueueType _queueType; + + internal MessageRegistryQueue(MessageRegistryGrain.QueueType queueType, + IPersistentState state, + StateManager stateManager, + ILogger logger, + int maxMessageSize, + int maxQueueSize) + { + if (state.State == null) + { + state.State = new MessageRegistryState(); + } + _queueType = queueType; + _state = state; + // use the queueType to get the correct queue from state.State. + _queue = GetQueue(); + _stateManager = stateManager; + _logger = logger; + _maxMessageSize = maxMessageSize; + _maxQueueSize = maxQueueSize; + } + + public async Task AddMessageAsync(string topic, CloudEvent message) + { + var size = message.CalculateSize(); + if (size > _maxMessageSize) + { + _logger.LogWarning("Message size {Size} for topic {Topic} in queue {Name} exceeds the maximum message size {Max}.", + size, topic, _queueType.ToString(), _maxMessageSize); + return; + } + if (_currentSize + size > _maxQueueSize) + { + while (_currentSize + size > _maxQueueSize && _timestamps.Count > 0) + { + var oldest = _timestamps.OrderBy(x => x.Key).First(); + if (await RemoveOldestMessage(oldest.Value)) + { + _timestamps.Remove(oldest.Key); + } + } + } + await AddOrUpdate(topic, message); + _currentSize += size; + } + + public async Task> RemoveMessagesAsync(string topic) + { + var removed = new List(); + var queue = GetQueue(); + if (queue.Remove(topic, out var events)) + { + removed.AddRange(events); + var total = 0; + foreach (var e in events) { total += e.CalculateSize(); } + _currentSize -= total; + } + // Remove timestamps that refer to this topic + var toRemove = _timestamps.Where(x => x.Value == topic).Select(x => x.Key).ToList(); + foreach (var t in toRemove) { _timestamps.Remove(t); } + await _stateManager.WriteStateAsync().ConfigureAwait(true); + return removed; + } + + public async Task RemoveMessageAsync(string topic, CloudEvent message) + { + var queue = GetQueue(); + if (queue.TryGetValue(topic, out var events) && events.Remove(message)) + { + _currentSize -= message.CalculateSize(); + await _stateManager.WriteStateAsync().ConfigureAwait(true); + return true; + } + return false; + } + + private async Task RemoveOldestMessage(string topic) + { + var queue = GetQueue(); + if (queue.TryGetValue(topic, out var events) && events != null && events.Count > 0) + { + var oldestEvent = events[0]; + events.RemoveAt(0); + _currentSize -= oldestEvent.CalculateSize(); + _timestamps.Remove(_timestamps.OrderBy(x => x.Key).First().Key); + queue[topic] = events; + await _stateManager.WriteStateAsync().ConfigureAwait(true); + return true; + } + return false; + } + + private async Task AddOrUpdate(string topic, CloudEvent message) + { + var queue = GetQueue(); + var list = queue.GetOrAdd(topic, _ => new()); + list.Add(message); + queue.AddOrUpdate(topic, list, (_, _) => list); + await _stateManager.WriteStateAsync().ConfigureAwait(true); + _timestamps.Add(DateTime.UtcNow, topic); + } + + private ConcurrentDictionary> GetQueue() + { + return _queueType switch + { + MessageRegistryGrain.QueueType.DeadLetterQueue => _state.State.DeadLetterQueue, + MessageRegistryGrain.QueueType.EventBuffer => _state.State.EventBuffer, + _ => throw new ArgumentException($"Invalid queue type: {_queueType}.") + }; + } + + public async Task RemoveMessageAfterDelayAsync(string topic, CloudEvent message, int delay) + { + await Task.Delay(delay); + await RemoveMessageAsync(topic, message); + _currentSize -= message.CalculateSize(); + } +} diff --git a/dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/MessageRegistryTests.cs b/dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/MessageRegistryTests.cs index fee4cf73d77b..698cf122a690 100644 --- a/dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/MessageRegistryTests.cs +++ b/dotnet/test/Microsoft.AutoGen.RuntimeGateway.Grpc.Tests/MessageRegistryTests.cs @@ -4,23 +4,19 @@ using Microsoft.AutoGen.Contracts; using Microsoft.AutoGen.RuntimeGateway.Grpc.Abstractions; using Microsoft.AutoGen.RuntimeGateway.Grpc.Tests.Helpers.Orleans; -using Orleans.TestingHost; namespace Microsoft.AutoGen.RuntimeGateway.Grpc.Tests; -public class MessageRegistryTests : IClassFixture +public class MessageRegistryTests { - private readonly TestCluster _cluster; - - public MessageRegistryTests(ClusterFixture fixture) - { - _cluster = fixture.Cluster; - } + public MessageRegistryTests() { } [Fact] public async Task Write_and_Remove_Messages() { // Arrange - var grain = _cluster.GrainFactory.GetGrain(0); + var fixture = new ClusterFixture(); + var cluster = fixture.Cluster; + var grain = cluster.GrainFactory.GetGrain(0); var topic = Guid.NewGuid().ToString(); // Random topic var message = new CloudEvent { Id = Guid.NewGuid().ToString(), Source = "test-source", Type = "test-type" }; @@ -36,6 +32,7 @@ public async Task Write_and_Remove_Messages() // ensure the queue is empty removedMessages = await grain.RemoveMessagesAsync(topic); Assert.Empty(removedMessages); + cluster.StopAllSilos(); } /// /// Test that messages are removed from the event buffer after the buffer time @@ -44,7 +41,9 @@ public async Task Write_and_Remove_Messages() public async Task Write_and_Remove_Messages_BufferTime() { // Arrange - var grain = _cluster.GrainFactory.GetGrain(0); + var fixture = new ClusterFixture(); + var cluster = fixture.Cluster; + var grain = cluster.GrainFactory.GetGrain(0); var topic = Guid.NewGuid().ToString(); // Random topic var message = new CloudEvent { Id = Guid.NewGuid().ToString(), Source = "test-source", Type = "test-type" }; @@ -55,6 +54,7 @@ public async Task Write_and_Remove_Messages_BufferTime() // attempt to remove the topic from the queue var removedMessages = await grain.RemoveMessagesAsync(topic); Assert.Empty(removedMessages); + cluster.StopAllSilos(); } /// @@ -64,7 +64,9 @@ public async Task Write_and_Remove_Messages_BufferTime() public async Task Write_and_Remove_Messages_BufferTime_StillInBuffer() { // Arrange - var grain = _cluster.GrainFactory.GetGrain(0); + var fixture = new ClusterFixture(); + var cluster = fixture.Cluster; + var grain = cluster.GrainFactory.GetGrain(0); var topic = Guid.NewGuid().ToString(); // Random topic var message = new CloudEvent { Id = Guid.NewGuid().ToString(), Source = "test-source", Type = "test-type" }; @@ -75,5 +77,57 @@ public async Task Write_and_Remove_Messages_BufferTime_StillInBuffer() // attempt to remove the topic from the queue var removedMessages = await grain.RemoveMessagesAsync(topic); Assert.Single(removedMessages); + cluster.StopAllSilos(); + } + + /// + /// Test that messages which exceed the mas message size are not written to the event buffer + /// + [Fact] + public async Task Do_No_Buffer_If_Messages_Exceed_MaxMessageSize() + { + // Arrange + var fixture = new ClusterFixture(); + var cluster = fixture.Cluster; + var grain = cluster.GrainFactory.GetGrain(0); + var topic = Guid.NewGuid().ToString(); // Random topic + var maxMessageSize = 1024 * 1024 * 10; // 10MB + var message = new CloudEvent { Id = Guid.NewGuid().ToString(), Source = "test-source", Type = "test-type" }; + + // Act + await grain.AddMessageToDeadLetterQueueAsync(topic, message); // small message + message.BinaryData = Google.Protobuf.ByteString.CopyFrom(new byte[maxMessageSize + 1]); + await grain.AddMessageToDeadLetterQueueAsync(topic, message); // over the limit + // attempt to remove the topic from the queue + var removedMessages = await grain.RemoveMessagesAsync(topic); + Assert.Single(removedMessages); // only the small message should be in the buffer + cluster.StopAllSilos(); + } + + /// + /// Test that the queue cannot grow past the max queue size + /// + [Fact] + public async Task Do_No_Buffer_If_Queue_Exceeds_MaxQueueSize() + { + // Arrange + var fixture = new ClusterFixture(); + var cluster = fixture.Cluster; + var grain = cluster.GrainFactory.GetGrain(0); + var topic = Guid.NewGuid().ToString(); // Random topic + var bigMessage = 1024 * 1024 * 1; // 1MB + var message = new CloudEvent { Id = Guid.NewGuid().ToString(), Source = "test-source", Type = "test-type" }; + + // Act + for (int i = 0; i < 11; i++) + { + message.BinaryData = Google.Protobuf.ByteString.CopyFrom(new byte[bigMessage]); + message.Source = i.ToString(); + await grain.AddMessageToDeadLetterQueueAsync(topic, message); + } + // attempt to remove the topic from the queue + var removedMessages = await grain.RemoveMessagesAsync(topic); + Assert.Equal(9, removedMessages.Count); // only 3 messages should be in the buffer + cluster.StopAllSilos(); } }