diff --git a/src/AWS/Orleans.Streaming.SQS/Storage/SQSStorage.cs b/src/AWS/Orleans.Streaming.SQS/Storage/SQSStorage.cs index e4e73d1265..e9f6f4b967 100644 --- a/src/AWS/Orleans.Streaming.SQS/Storage/SQSStorage.cs +++ b/src/AWS/Orleans.Streaming.SQS/Storage/SQSStorage.cs @@ -3,12 +3,15 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Net; +using System.Text; using System.Threading.Tasks; using Amazon.SQS.Model; using Microsoft.Extensions.Logging; using Orleans.Streaming.SQS; using SQSMessage = Amazon.SQS.Model.Message; using Orleans; +using Orleans.Configuration; namespace OrleansAWSUtils.Storage { @@ -23,14 +26,21 @@ internal class SQSStorage public const int MAX_NUMBER_OF_MESSAGE_TO_PEEK = 10; private const string AccessKeyPropertyName = "AccessKey"; private const string SecretKeyPropertyName = "SecretKey"; + private const string SessionTokenPropertyName = "SessionToken"; private const string ServicePropertyName = "Service"; + private readonly SqsOptions sqsOptions; private readonly ILogger Logger; private string accessKey; private string secretKey; + private string sessionToken; private string service; private string queueUrl; private AmazonSQSClient sqsClient; + private List receiveAttributes; + private List receiveMessageAttributes; + + /// /// The queue Name /// @@ -41,19 +51,26 @@ internal class SQSStorage /// /// logger factory to use /// The name of the queue - /// The connection string + /// The options for the SQS connection /// The service ID - public SQSStorage(ILoggerFactory loggerFactory, string queueName, string connectionString, string serviceId = "") + public SQSStorage(ILoggerFactory loggerFactory, string queueName, SqsOptions sqsOptions, string serviceId = "") { - QueueName = string.IsNullOrWhiteSpace(serviceId) ? queueName : $"{serviceId}-{queueName}"; - ParseDataConnectionString(connectionString); + if (sqsOptions is null) throw new ArgumentNullException(nameof(sqsOptions)); + this.sqsOptions = sqsOptions; + QueueName = ConstructQueueName(queueName, sqsOptions, serviceId); + ParseDataConnectionString(sqsOptions.ConnectionString); Logger = loggerFactory.CreateLogger(); CreateClient(); + + receiveAttributes = [..sqsOptions.ReceiveAttributes]; + receiveMessageAttributes = [.. sqsOptions.ReceiveMessageAttributes]; } private void ParseDataConnectionString(string dataConnectionString) { - var parameters = dataConnectionString.Split(';', StringSplitOptions.RemoveEmptyEntries); + if(string.IsNullOrEmpty(dataConnectionString)) throw new ArgumentNullException(nameof(dataConnectionString)); + + var parameters = dataConnectionString.Split(new[] { ';' }, StringSplitOptions.RemoveEmptyEntries); var serviceConfig = parameters.FirstOrDefault(p => p.Contains(ServicePropertyName)); if (!string.IsNullOrWhiteSpace(serviceConfig)) @@ -78,6 +95,14 @@ private void ParseDataConnectionString(string dataConnectionString) if (value.Length == 2 && !string.IsNullOrWhiteSpace(value[1])) accessKey = value[1]; } + + var sessionTokenConfig = parameters.Where(p => p.Contains(SessionTokenPropertyName)).FirstOrDefault(); + if (!string.IsNullOrWhiteSpace(sessionTokenConfig)) + { + var value = sessionTokenConfig.Split(new[] { '=' }, StringSplitOptions.RemoveEmptyEntries); + if (value.Length == 2 && !string.IsNullOrWhiteSpace(value[1])) + sessionToken = value[1]; + } } private void CreateClient() @@ -89,6 +114,12 @@ private void CreateClient() var credentials = new BasicAWSCredentials("dummy", "dummyKey"); sqsClient = new AmazonSQSClient(credentials, new AmazonSQSConfig { ServiceURL = service }); } + else if (!string.IsNullOrEmpty(accessKey) && !string.IsNullOrEmpty(secretKey) && !string.IsNullOrEmpty(sessionToken)) + { + // AWS SQS instance (auth via explicit credentials) + var credentials = new SessionAWSCredentials(accessKey, secretKey, sessionToken); + sqsClient = new AmazonSQSClient(credentials, new AmazonSQSConfig { RegionEndpoint = AWSUtils.GetRegionEndpoint(service) }); + } else if (!string.IsNullOrEmpty(accessKey) && !string.IsNullOrEmpty(secretKey)) { // AWS SQS instance (auth via explicit credentials) @@ -128,7 +159,44 @@ public async Task InitQueueAsync() { if (string.IsNullOrWhiteSpace(await GetQueueUrl())) { - var response = await sqsClient.CreateQueueAsync(QueueName); + var createQueueRequest = new CreateQueueRequest(QueueName); + + if (sqsOptions.FifoQueue) + { + // The stream must have these attributes to be a valid FIFO queue. + createQueueRequest.Attributes = new() + { + { QueueAttributeName.FifoQueue, "true" }, + { QueueAttributeName.FifoThroughputLimit, "perMessageGroupId" }, + { QueueAttributeName.DeduplicationScope, "messageGroup" }, + { QueueAttributeName.ContentBasedDeduplication, "true" }, + }; + + // We require to bring down the AWS set SequenceNumber when on a FIFO queue + // in order to populate the SQSFIFOSequenceToken from it. + + if (!receiveMessageAttributes.Contains(MessageSystemAttributeName.SequenceNumber)) + receiveMessageAttributes.Add(MessageSystemAttributeName.SequenceNumber); + if (!receiveMessageAttributes.Contains(MessageSystemAttributeName.MessageGroupId)) + receiveMessageAttributes.Add(MessageSystemAttributeName.MessageGroupId); + + // FIFO Queue does not support Long Polling + sqsOptions.ReceiveWaitTimeSeconds = null; + } + + if (sqsOptions.ReceiveWaitTimeSeconds.HasValue) + { + createQueueRequest.Attributes.Add(QueueAttributeName.ReceiveMessageWaitTimeSeconds, + sqsOptions.ReceiveWaitTimeSeconds.Value.ToString()); + } + + if (sqsOptions.VisibilityTimeoutSeconds.HasValue) + { + createQueueRequest.Attributes.Add(QueueAttributeName.VisibilityTimeout, + sqsOptions.VisibilityTimeoutSeconds.Value.ToString()); + } + + var response = await sqsClient.CreateQueueAsync(createQueueRequest); queueUrl = response.QueueUrl; } } @@ -169,7 +237,11 @@ public async Task AddMessage(SendMessageRequest message) throw new InvalidOperationException("Queue not initialized"); message.QueueUrl = queueUrl; - await sqsClient.SendMessageAsync(message); + var response = await sqsClient.SendMessageAsync(message); + if (response.HttpStatusCode != HttpStatusCode.OK) + { + throw new Exception("Failed to send message into SQS. "); + } } catch (Exception exc) { @@ -192,7 +264,18 @@ public async Task> GetMessages(int count = 1) if (count < 1) throw new ArgumentOutOfRangeException(nameof(count)); - var request = new ReceiveMessageRequest { QueueUrl = queueUrl, MaxNumberOfMessages = count <= MAX_NUMBER_OF_MESSAGE_TO_PEEK ? count : MAX_NUMBER_OF_MESSAGE_TO_PEEK }; + + var request = new ReceiveMessageRequest + { + QueueUrl = queueUrl, + MaxNumberOfMessages = count <= MAX_NUMBER_OF_MESSAGE_TO_PEEK ? count : MAX_NUMBER_OF_MESSAGE_TO_PEEK, + AttributeNames = receiveAttributes, + MessageAttributeNames = receiveMessageAttributes, + }; + + if (sqsOptions.ReceiveWaitTimeSeconds.HasValue) + request.WaitTimeSeconds = sqsOptions.ReceiveWaitTimeSeconds.Value; + var response = await sqsClient.ReceiveMessageAsync(request); return response.Messages; } @@ -221,7 +304,7 @@ public async Task DeleteMessage(SQSMessage message) if (string.IsNullOrWhiteSpace(queueUrl)) throw new InvalidOperationException("Queue not initialized"); - await sqsClient.DeleteMessageAsync( + var result = await sqsClient.DeleteMessageAsync( new DeleteMessageRequest { QueueUrl = queueUrl, ReceiptHandle = message.ReceiptHandle }); } catch (Exception exc) @@ -230,10 +313,71 @@ await sqsClient.DeleteMessageAsync( } } + public async Task DeleteMessages(IEnumerable messages) + { + try + { + foreach (var message in messages) + { + ValidateMessageForDeletion(message); + } + + foreach (var batch in messages.Chunk(MAX_NUMBER_OF_MESSAGE_TO_PEEK)) + { + var deleteRequest = new DeleteMessageBatchRequest + { + QueueUrl = queueUrl, + Entries = batch + .Select((m, i) => + new DeleteMessageBatchRequestEntry(i.ToString(), m.ReceiptHandle)) + .ToList() + }; + + var result = await sqsClient.DeleteMessageBatchAsync(deleteRequest); + foreach (var failed in result.Failed) + { + Logger.LogWarning("Failed to delete message {MessageId} from SQS queue {QueueName}. Error code: {ErrorCode}. Error message: {ErrorMessage}", + failed.Id, QueueName, failed.Code, failed.Message); + } + } + } + catch (Exception exc) + { + ReportErrorAndRethrow(exc, "GetMessages", ErrorCode.StreamProviderManagerBase); + } + } + + private void ValidateMessageForDeletion(SQSMessage message) + { + if (message == null) + throw new ArgumentNullException(nameof(message)); + + if (string.IsNullOrWhiteSpace(message.ReceiptHandle)) + throw new ArgumentNullException(nameof(message.ReceiptHandle)); + + if (string.IsNullOrWhiteSpace(queueUrl)) + throw new InvalidOperationException("Queue not initialized"); + } + private void ReportErrorAndRethrow(Exception exc, string operation, ErrorCode errorCode) { Logger.LogError((int)errorCode, exc, "Error doing {Operation} for SQS queue {QueueName}", operation, QueueName); throw new AggregateException($"Error doing {operation} for SQS queue {QueueName}", exc); } + + private static string ConstructQueueName(string queueName, SqsOptions sqsOptions, string serviceId) + { + var queueNameBuilder = new StringBuilder(); + if (!string.IsNullOrEmpty(serviceId)) + { + queueNameBuilder.Append(serviceId); + queueNameBuilder.Append("-"); + } + + queueNameBuilder.Append(queueName); + if (sqsOptions.FifoQueue) + queueNameBuilder.Append(".fifo"); + return queueNameBuilder.ToString(); + } } } diff --git a/src/AWS/Orleans.Streaming.SQS/Streams/ISQSDataAdapter.cs b/src/AWS/Orleans.Streaming.SQS/Streams/ISQSDataAdapter.cs new file mode 100644 index 0000000000..097917a5cc --- /dev/null +++ b/src/AWS/Orleans.Streaming.SQS/Streams/ISQSDataAdapter.cs @@ -0,0 +1,8 @@ +using Orleans.Streams; +using SQSMessage = Amazon.SQS.Model.Message; + +namespace Orleans.Streaming.SQS.Streams; +public interface ISQSDataAdapter : IQueueDataAdapter +{ + IBatchContainer GetBatchContainer(SQSMessage sqsMessage, ref long sequenceNumber); +} diff --git a/src/AWS/Orleans.Streaming.SQS/Streams/SQSAdapter.cs b/src/AWS/Orleans.Streaming.SQS/Streams/SQSAdapter.cs index a73af88957..6a964aca8d 100644 --- a/src/AWS/Orleans.Streaming.SQS/Streams/SQSAdapter.cs +++ b/src/AWS/Orleans.Streaming.SQS/Streams/SQSAdapter.cs @@ -4,17 +4,20 @@ using System.Collections.Concurrent; using System.Collections.Generic; using System.Threading.Tasks; +using Amazon.SQS.Model; using Microsoft.Extensions.Logging; +using Orleans.Configuration; using Orleans.Runtime; -using Orleans.Serialization; +using Orleans.Streaming.SQS.Streams; +using System.Linq; namespace OrleansAWSUtils.Streams { internal class SQSAdapter : IQueueAdapter { protected readonly string ServiceId; - private readonly Serializer serializer; - protected readonly string DataConnectionString; + private readonly ISQSDataAdapter dataAdapter; + protected SqsOptions sqsOptions; private readonly IConsistentRingStreamQueueMapper streamQueueMapper; protected readonly ConcurrentDictionary Queues = new ConcurrentDictionary(); private readonly ILoggerFactory loggerFactory; @@ -23,13 +26,13 @@ internal class SQSAdapter : IQueueAdapter public StreamProviderDirection Direction { get { return StreamProviderDirection.ReadWrite; } } - public SQSAdapter(Serializer serializer, IConsistentRingStreamQueueMapper streamQueueMapper, ILoggerFactory loggerFactory, string dataConnectionString, string serviceId, string providerName) + public SQSAdapter(ISQSDataAdapter dataAdapter, IConsistentRingStreamQueueMapper streamQueueMapper, ILoggerFactory loggerFactory, SqsOptions sqsOptions, string serviceId, string providerName) { - if (string.IsNullOrEmpty(dataConnectionString)) throw new ArgumentNullException(nameof(dataConnectionString)); + if (sqsOptions is null) throw new ArgumentNullException(nameof(sqsOptions)); if (string.IsNullOrEmpty(serviceId)) throw new ArgumentNullException(nameof(serviceId)); this.loggerFactory = loggerFactory; - this.serializer = serializer; - DataConnectionString = dataConnectionString; + this.sqsOptions = sqsOptions; + this.dataAdapter = dataAdapter; this.ServiceId = serviceId; Name = providerName; this.streamQueueMapper = streamQueueMapper; @@ -37,7 +40,7 @@ public SQSAdapter(Serializer serializer, IConsistentRingStrea public IQueueAdapterReceiver CreateReceiver(QueueId queueId) { - return SQSAdapterReceiver.Create(this.serializer, this.loggerFactory, queueId, DataConnectionString, this.ServiceId); + return SQSAdapterReceiver.Create(this.dataAdapter, this.loggerFactory, queueId, sqsOptions, this.ServiceId); } public async Task QueueMessageBatchAsync(StreamId streamId, IEnumerable events, StreamSequenceToken token, Dictionary requestContext) @@ -50,12 +53,29 @@ public async Task QueueMessageBatchAsync(StreamId streamId, IEnumerable ev SQSStorage queue; if (!Queues.TryGetValue(queueId, out queue)) { - var tmpQueue = new SQSStorage(this.loggerFactory, queueId.ToString(), DataConnectionString, this.ServiceId); + var tmpQueue = new SQSStorage(this.loggerFactory, queueId.ToString(), sqsOptions, this.ServiceId); await tmpQueue.InitQueueAsync(); queue = Queues.GetOrAdd(queueId, tmpQueue); } - var msg = SQSBatchContainer.ToSQSMessage(this.serializer, streamId, events, requestContext); - await queue.AddMessage(msg); + + var sqsMessage = dataAdapter.ToQueueMessage(streamId, events, token, requestContext); + var sqsRequest = new SendMessageRequest(string.Empty, sqsMessage.Body); + + if (this.sqsOptions.FifoQueue) + { + // Ensure the SQS Queue ensures FIFO order of messages over this QueueId. + sqsRequest.MessageGroupId = streamId.ToString(); + } + + foreach (var attr in sqsMessage.Attributes) + { + sqsRequest.MessageAttributes.Add(attr.Key, new MessageAttributeValue + { + DataType = "String", + StringValue = attr.Value + }); + } + await queue.AddMessage(sqsRequest); } } } diff --git a/src/AWS/Orleans.Streaming.SQS/Streams/SQSAdapterFactory.cs b/src/AWS/Orleans.Streaming.SQS/Streams/SQSAdapterFactory.cs index 7f3ce62882..b86e772db0 100644 --- a/src/AWS/Orleans.Streaming.SQS/Streams/SQSAdapterFactory.cs +++ b/src/AWS/Orleans.Streaming.SQS/Streams/SQSAdapterFactory.cs @@ -8,7 +8,7 @@ using Orleans.Configuration; using Orleans; using Orleans.Configuration.Overrides; -using Orleans.Serialization; +using Orleans.Streaming.SQS.Streams; namespace OrleansAWSUtils.Streams { @@ -18,7 +18,7 @@ public class SQSAdapterFactory : IQueueAdapterFactory private readonly string providerName; private readonly SqsOptions sqsOptions; private readonly ClusterOptions clusterOptions; - private readonly Serializer serializer; + private readonly ISQSDataAdapter dataAdapter; private readonly ILoggerFactory loggerFactory; private readonly HashRingBasedStreamQueueMapper streamQueueMapper; private readonly IQueueAdapterCache adapterCache; @@ -34,18 +34,21 @@ public SQSAdapterFactory( HashRingStreamQueueMapperOptions queueMapperOptions, SimpleQueueCacheOptions cacheOptions, IOptions clusterOptions, - Orleans.Serialization.Serializer serializer, + ISQSDataAdapter dataAdapter, ILoggerFactory loggerFactory) { this.providerName = name; this.sqsOptions = sqsOptions; this.clusterOptions = clusterOptions.Value; - this.serializer = serializer.GetSerializer(); + this.dataAdapter = dataAdapter; this.loggerFactory = loggerFactory; streamQueueMapper = new HashRingBasedStreamQueueMapper(queueMapperOptions, this.providerName); - adapterCache = new SimpleQueueAdapterCache(cacheOptions, this.providerName, this.loggerFactory); - } + if (sqsOptions.FifoQueue) + adapterCache = new StreamIdPartitionedQueueAdapterCache(cacheOptions, this.providerName, this.loggerFactory); + else + adapterCache = new SimpleQueueAdapterCache(cacheOptions, this.providerName, this.loggerFactory); + } /// Init the factory. public virtual void Init() @@ -60,7 +63,7 @@ public virtual void Init() /// Creates the Azure Queue based adapter. public virtual Task CreateAdapter() { - var adapter = new SQSAdapter(this.serializer, this.streamQueueMapper, this.loggerFactory, this.sqsOptions.ConnectionString, this.clusterOptions.ServiceId, this.providerName); + var adapter = new SQSAdapter(this.dataAdapter, this.streamQueueMapper, this.loggerFactory, this.sqsOptions, this.clusterOptions.ServiceId, this.providerName); return Task.FromResult(adapter); } @@ -92,7 +95,10 @@ public static SQSAdapterFactory Create(IServiceProvider services, string name) var cacheOptions = services.GetOptionsByName(name); var queueMapperOptions = services.GetOptionsByName(name); IOptions clusterOptions = services.GetProviderClusterOptions(name); - var factory = ActivatorUtilities.CreateInstance(services, name, sqsOptions, cacheOptions, queueMapperOptions, clusterOptions); + var dataAdapter = services.GetKeyedService(name) + ?? services.GetService() + ?? ActivatorUtilities.CreateInstance(services); + var factory = ActivatorUtilities.CreateInstance(services, name, sqsOptions, cacheOptions, queueMapperOptions, clusterOptions, dataAdapter); factory.Init(); return factory; } diff --git a/src/AWS/Orleans.Streaming.SQS/Streams/SQSAdapterReceiver.cs b/src/AWS/Orleans.Streaming.SQS/Streams/SQSAdapterReceiver.cs index 649ea1e064..7fb82f3d61 100644 --- a/src/AWS/Orleans.Streaming.SQS/Streams/SQSAdapterReceiver.cs +++ b/src/AWS/Orleans.Streaming.SQS/Streams/SQSAdapterReceiver.cs @@ -6,7 +6,8 @@ using System.Linq; using System.Threading.Tasks; using Microsoft.Extensions.Logging; -using Orleans.Serialization; +using Orleans.Configuration; +using Orleans.Streaming.SQS.Streams; using SQSMessage = Amazon.SQS.Model.Message; namespace OrleansAWSUtils.Streams @@ -17,25 +18,24 @@ namespace OrleansAWSUtils.Streams internal class SQSAdapterReceiver : IQueueAdapterReceiver { private SQSStorage queue; - private long lastReadMessage; private Task outstandingTask; + private long lastReadMessage = -1; private readonly ILogger logger; - private readonly Serializer serializer; - + private readonly ISQSDataAdapter dataAdapter; public QueueId Id { get; private set; } - public static IQueueAdapterReceiver Create(Serializer serializer, ILoggerFactory loggerFactory, QueueId queueId, string dataConnectionString, string serviceId) + public static IQueueAdapterReceiver Create(ISQSDataAdapter dataAdapter, ILoggerFactory loggerFactory, QueueId queueId, SqsOptions sqsOptions, string serviceId) { if (queueId.IsDefault) throw new ArgumentNullException(nameof(queueId)); - if (string.IsNullOrEmpty(dataConnectionString)) throw new ArgumentNullException(nameof(dataConnectionString)); + if (sqsOptions is null) throw new ArgumentNullException(nameof(sqsOptions)); if (string.IsNullOrEmpty(serviceId)) throw new ArgumentNullException(nameof(serviceId)); - var queue = new SQSStorage(loggerFactory, queueId.ToString(), dataConnectionString, serviceId); - return new SQSAdapterReceiver(serializer, loggerFactory, queueId, queue); + var queue = new SQSStorage(loggerFactory, queueId.ToString(), sqsOptions, serviceId); + return new SQSAdapterReceiver(dataAdapter, loggerFactory, queueId, queue); } - private SQSAdapterReceiver(Serializer serializer, ILoggerFactory loggerFactory, QueueId queueId, SQSStorage queue) + private SQSAdapterReceiver(ISQSDataAdapter dataAdapter, ILoggerFactory loggerFactory, QueueId queueId, SQSStorage queue) { if (queueId.IsDefault) throw new ArgumentNullException(nameof(queueId)); if (queue == null) throw new ArgumentNullException(nameof(queue)); @@ -43,7 +43,7 @@ private SQSAdapterReceiver(Serializer serializer, ILoggerFact Id = queueId; this.queue = queue; logger = loggerFactory.CreateLogger(); - this.serializer = serializer; + this.dataAdapter = dataAdapter; } public Task Initialize(TimeSpan timeout) @@ -84,9 +84,7 @@ public async Task> GetQueueMessagesAsync(int maxCount) outstandingTask = task; IEnumerable messages = await task; - List messageBatch = messages - .Select(msg => (IBatchContainer)SQSBatchContainer.FromSQSMessage(this.serializer, msg, lastReadMessage++)).ToList(); - + List messageBatch = messages.Select(x => dataAdapter.GetBatchContainer(x, ref lastReadMessage)).ToList(); return messageBatch; } finally @@ -102,7 +100,7 @@ public async Task MessagesDeliveredAsync(IList messages) var queueRef = queue; // store direct ref, in case we are somehow asked to shutdown while we are receiving. if (messages.Count == 0 || queueRef == null) return; List cloudQueueMessages = messages.Cast().Select(b => b.Message).ToList(); - outstandingTask = Task.WhenAll(cloudQueueMessages.Select(queueRef.DeleteMessage)); + outstandingTask = queue.DeleteMessages(cloudQueueMessages); try { await outstandingTask; diff --git a/src/AWS/Orleans.Streaming.SQS/Streams/SQSBatchContainer.cs b/src/AWS/Orleans.Streaming.SQS/Streams/SQSBatchContainer.cs index dc56e1f6e3..f0bb0db923 100644 --- a/src/AWS/Orleans.Streaming.SQS/Streams/SQSBatchContainer.cs +++ b/src/AWS/Orleans.Streaming.SQS/Streams/SQSBatchContainer.cs @@ -1,4 +1,3 @@ -using Amazon.SQS.Model; using Newtonsoft.Json; using Newtonsoft.Json.Linq; using Orleans.Providers.Streams.Common; @@ -8,17 +7,20 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Text; +using Amazon.SQS; +using Amazon.SQS.Model; using SQSMessage = Amazon.SQS.Model.Message; namespace OrleansAWSUtils.Streams { [Serializable] [Orleans.GenerateSerializer] - internal class SQSBatchContainer : IBatchContainer + public class SQSBatchContainer : IBatchContainer { [JsonProperty] [Orleans.Id(0)] - private EventSequenceTokenV2 sequenceToken; + private StreamSequenceToken sequenceToken; [JsonProperty] [Orleans.Id(1)] @@ -31,7 +33,7 @@ internal class SQSBatchContainer : IBatchContainer [NonSerialized] // Need to store reference to the original SQS Message to be able to delete it later on. // Don't need to serialize it, since we are never interested in sending it to stream consumers. - internal SQSMessage Message; + public SQSMessage Message; [Orleans.Id(3)] public StreamId StreamId { get; private set; } @@ -42,17 +44,17 @@ public StreamSequenceToken SequenceToken } [JsonConstructor] - private SQSBatchContainer( + public SQSBatchContainer( StreamId streamId, List events, Dictionary requestContext, - EventSequenceTokenV2 sequenceToken) + StreamSequenceToken sequenceToken) : this(streamId, events, requestContext) { this.sequenceToken = sequenceToken; } - private SQSBatchContainer(StreamId streamId, List events, Dictionary requestContext) + public SQSBatchContainer(StreamId streamId, List events, Dictionary requestContext) { if (events == null) throw new ArgumentNullException(nameof(events), "Message contains no events"); @@ -63,10 +65,20 @@ private SQSBatchContainer(StreamId streamId, List events, Dictionary> GetEvents() { - return events.OfType().Select((e, i) => Tuple.Create(e, sequenceToken.CreateSequenceTokenForEvent(i))); + static StreamSequenceToken CreateStreamSequenceToken(StreamId streamId, StreamSequenceToken tok, int eventIndex) + { + return tok switch + { + EventSequenceTokenV2 v2Tok => v2Tok.CreateSequenceTokenForEvent(eventIndex), + SQSFIFOSequenceToken fifoTok => fifoTok.CreateSequenceTokenForEvent(streamId, eventIndex), + _ => throw new NotSupportedException("Unknown SequenceToken provided.") + }; + } + + return events.OfType().Select((e, i) => Tuple.Create(e, CreateStreamSequenceToken(StreamId, sequenceToken, i))); } - internal static SendMessageRequest ToSQSMessage( + internal static SQSMessage ToSQSMessage( Serializer serializer, StreamId streamId, IEnumerable events, @@ -78,18 +90,29 @@ internal static SendMessageRequest ToSQSMessage( { { "payload", JToken.FromObject(rawBytes) } }; - return new SendMessageRequest + return new SQSMessage { - MessageBody = payload.ToString() + Body = payload.ToString(), }; } - internal static SQSBatchContainer FromSQSMessage(Serializer serializer, SQSMessage msg, long sequenceId) + internal static SQSBatchContainer FromSQSMessage(Serializer serializer, SQSMessage msg, long sequenceNumber) { var json = JObject.Parse(msg.Body); var sqsBatch = serializer.Deserialize(json["payload"].ToObject()); sqsBatch.Message = msg; - sqsBatch.sequenceToken = new EventSequenceTokenV2(sequenceId); + + if (msg.Attributes.TryGetValue(MessageSystemAttributeName.SequenceNumber, out var fifoSeqNum)) + { + if(!msg.Attributes.TryGetValue(MessageSystemAttributeName.MessageGroupId, out var messageGroupId)) + throw new ArgumentException("FIFO SQS message does not have MessageGroupId attribute", nameof(msg)); + + var streamId = StreamId.Parse(Encoding.UTF8.GetBytes(messageGroupId)); + sqsBatch.sequenceToken = new SQSFIFOSequenceToken(streamId, UInt128.Parse(fifoSeqNum)); + } + else + sqsBatch.sequenceToken = new EventSequenceTokenV2(sequenceNumber); + return sqsBatch; } diff --git a/src/AWS/Orleans.Streaming.SQS/Streams/SQSDataAdapter.cs b/src/AWS/Orleans.Streaming.SQS/Streams/SQSDataAdapter.cs new file mode 100644 index 0000000000..5f517c4e0b --- /dev/null +++ b/src/AWS/Orleans.Streaming.SQS/Streams/SQSDataAdapter.cs @@ -0,0 +1,46 @@ +using System; +using System.Collections.Generic; +using System.Threading; +using Orleans.Runtime; +using Orleans.Streams; +using OrleansAWSUtils.Streams; +using SQSMessage = Amazon.SQS.Model.Message; + +namespace Orleans.Streaming.SQS.Streams; + +/// +/// Default SQS Stream data adapter. Users may subclass to override event data to stream mapping. +/// + +public class SQSDataAdapter : ISQSDataAdapter +{ + private readonly Serialization.Serializer serializer; + + public SQSDataAdapter(Serialization.Serializer serializer) + { + this.serializer = serializer; + } + + /// + /// Convert an SQS Message to a batch container + /// + /// + /// + public virtual IBatchContainer GetBatchContainer(SQSMessage sqsMessage, ref long sequenceNumber) + { + return SQSBatchContainer.FromSQSMessage( + serializer.GetSerializer(), + sqsMessage, + Interlocked.Increment(ref sequenceNumber)); + } + + public virtual SQSMessage ToQueueMessage(StreamId streamId, IEnumerable events, StreamSequenceToken token, Dictionary requestContext) + { + if (token != null) throw new ArgumentException("SQS streams currently does not support non-null StreamSequenceToken.", nameof(token)); + return SQSBatchContainer.ToSQSMessage( + serializer.GetSerializer(), + streamId, + events, + requestContext); + } +} diff --git a/src/AWS/Orleans.Streaming.SQS/Streams/SQSFIFOSequenceToken.cs b/src/AWS/Orleans.Streaming.SQS/Streams/SQSFIFOSequenceToken.cs new file mode 100644 index 0000000000..5809a1d7e4 --- /dev/null +++ b/src/AWS/Orleans.Streaming.SQS/Streams/SQSFIFOSequenceToken.cs @@ -0,0 +1,142 @@ +using System; +using System.Globalization; +using System.IO; +using Newtonsoft.Json; +using Orleans; +using Orleans.Runtime; +using Orleans.Streams; + +namespace OrleansAWSUtils.Streams +{ + /// + /// Stream sequence token that tracks sequence number and event index + /// + [Serializable] + [GenerateSerializer] + public class SQSFIFOSequenceToken : StreamSequenceToken + { + /// + /// Gets the StreamId this SQS FIFO sequence token is related to, + /// as the SqsSequenceNumber is partitioned by StreamId (SQS MessageGroupId) + /// + [Id(2)][JsonProperty] public StreamId StreamId { get; set; } + + /// + /// Gets the number of event batches in stream prior to this event batch + /// + [Id(0)] + [JsonProperty] + public UInt128 SqsSequenceNumber { get; set; } + + + /// + /// Gets the number of event batches in stream prior to this event batch + /// + public override long SequenceNumber + { + get => throw new NotSupportedException(); + protected set => throw new NotSupportedException(); + } + + /// + /// Gets the number of events in batch prior to this event + /// + [Id(1)] + [JsonProperty] + public override int EventIndex { get; protected set; } + + /// + /// Initializes a new instance of the class. + /// + /// The sequence number. + public SQSFIFOSequenceToken(StreamId streamId, UInt128 seqNumber) + { + StreamId = streamId; + SqsSequenceNumber = seqNumber; + EventIndex = 0; + } + + /// + /// Initializes a new instance of the class. + /// + /// The stream id for which this token relates. + /// The sequence number. + /// The event index, for events which are part of a batch of events. + public SQSFIFOSequenceToken(StreamId streamId, UInt128 seqNumber, int eventInd) + { + StreamId = streamId; + SqsSequenceNumber = seqNumber; + EventIndex = eventInd; + } + + /// + /// Initializes a new instance of the class. + /// + /// + /// This constructor is for serializer use only. + /// + public SQSFIFOSequenceToken() + { + } + + /// + /// Creates a sequence token for a specific event in the current batch + /// + /// The event index. + /// A new sequence token. + public SQSFIFOSequenceToken CreateSequenceTokenForEvent(StreamId streamId, int eventInd) + { + return new SQSFIFOSequenceToken(streamId, SqsSequenceNumber, eventInd); + } + + /// + public override bool Equals(object obj) + { + return Equals(obj as SQSFIFOSequenceToken); + } + + /// + public override bool Equals(StreamSequenceToken other) + { + var token = other as SQSFIFOSequenceToken; + if (token == null) + return false; + + if (token.StreamId != StreamId) + return false; + + return (token.SqsSequenceNumber == SqsSequenceNumber && + token.EventIndex == EventIndex); + } + + /// + public override int CompareTo(StreamSequenceToken other) + { + if (other == null) + return 1; + + var token = other as SQSFIFOSequenceToken; + if (token == null) + throw new ArgumentOutOfRangeException(nameof(other)); + + if(token.StreamId != StreamId) + throw new ArgumentOutOfRangeException(nameof(other), "Cannot compare between tokens of different StreamIds"); + + int difference = SqsSequenceNumber.CompareTo(token.SqsSequenceNumber); + return difference != 0 ? difference : EventIndex.CompareTo(token.EventIndex); + } + + /// + public override int GetHashCode() + { + // why 397? + return (EventIndex * 397) ^ SqsSequenceNumber.GetHashCode(); + } + + /// + public override string ToString() + { + return string.Format(CultureInfo.InvariantCulture, "[SQSFIFOSequenceToken: SeqNum={0}, EventIndex={1}, StreamId={2}]", SqsSequenceNumber, EventIndex, StreamId); + } + } +} diff --git a/src/AWS/Orleans.Streaming.SQS/Streams/SQSStreamBuilder.cs b/src/AWS/Orleans.Streaming.SQS/Streams/SQSStreamBuilder.cs index a84d757c66..e92055f7f1 100644 --- a/src/AWS/Orleans.Streaming.SQS/Streams/SQSStreamBuilder.cs +++ b/src/AWS/Orleans.Streaming.SQS/Streams/SQSStreamBuilder.cs @@ -2,6 +2,7 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Options; using Orleans.Configuration; +using Orleans.Streaming.SQS.Streams; using OrleansAWSUtils.Streams; namespace Orleans.Hosting @@ -36,6 +37,12 @@ public SiloSqsStreamConfigurator ConfigurePartitioning(int numOfparitions = Hash this.Configure(ob => ob.Configure(options => options.TotalQueueCount = numOfparitions)); return this; } + + public SiloSqsStreamConfigurator UseDataAdapter(Func factory) + { + this.ConfigureComponent(factory); + return this; + } } public class ClusterClientSqsStreamConfigurator : ClusterClientPersistentStreamConfigurator @@ -63,5 +70,11 @@ public ClusterClientSqsStreamConfigurator ConfigurePartitioning(int numOfparitio this.Configure(ob => ob.Configure(options => options.TotalQueueCount = numOfparitions)); return this; } + + public ClusterClientSqsStreamConfigurator UseDataAdapter(Func factory) + { + this.ConfigureComponent(factory); + return this; + } } } diff --git a/src/AWS/Orleans.Streaming.SQS/Streams/SQSStreamProviderUtils.cs b/src/AWS/Orleans.Streaming.SQS/Streams/SQSStreamProviderUtils.cs index fe0418e779..141c79b689 100644 --- a/src/AWS/Orleans.Streaming.SQS/Streams/SQSStreamProviderUtils.cs +++ b/src/AWS/Orleans.Streaming.SQS/Streams/SQSStreamProviderUtils.cs @@ -1,4 +1,4 @@ -using System.Collections.Generic; +using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; using Microsoft.Extensions.Logging; @@ -24,10 +24,12 @@ public static async Task DeleteAllUsedQueues(string providerName, string cluster var queueMapper = new HashRingBasedStreamQueueMapper(new HashRingStreamQueueMapperOptions(), providerName); List allQueues = queueMapper.GetAllQueues().ToList(); + var sqsOptions = new SqsOptions { ConnectionString = storageConnectionString }; + var deleteTasks = new List(); foreach (var queueId in allQueues) { - var manager = new SQSStorage(loggerFactory, queueId.ToString(), storageConnectionString, clusterId); + var manager = new SQSStorage(loggerFactory, queueId.ToString(), sqsOptions, clusterId); manager.InitQueueAsync().Wait(); deleteTasks.Add(manager.DeleteQueue()); } diff --git a/src/AWS/Orleans.Streaming.SQS/Streams/SqsStreamOptions.cs b/src/AWS/Orleans.Streaming.SQS/Streams/SqsStreamOptions.cs index 3cce70482f..afce5b51d4 100644 --- a/src/AWS/Orleans.Streaming.SQS/Streams/SqsStreamOptions.cs +++ b/src/AWS/Orleans.Streaming.SQS/Streams/SqsStreamOptions.cs @@ -1,9 +1,46 @@ - + +using System.Collections.Generic; + namespace Orleans.Configuration { public class SqsOptions { + /// + /// Specifies the connection string to use for connecting to SQS. + /// + /// + /// Example for AWS: Service=eu-west-1;AccessKey=XXXXXX;SecretKey=XXXXXX;SessionToken=XXXXXX; + /// + /// + /// Example for LocalStack: Service=http://localhost:4566 + /// [Redact] public string ConnectionString { get; set; } + + /// + /// Specifies which SQS Attributes should be retrieved about the SQS message from the Queue. + /// + public List ReceiveAttributes { get; set; } = new(); + + /// + /// Specifies which Message Attributes should be retrieved with the SQS messages. + /// + public List ReceiveMessageAttributes { get; set; } = new(); + + /// + /// The optional duration to long-poll for new SQS messages. + /// + public int? ReceiveWaitTimeSeconds { get; set; } + + /// + /// The visibility timeout begins when Amazon SQS returns a message. + /// During this time, the consumer processes and deletes the message. + /// However, if the consumer fails before deleting the message and your system doesn't call the DeleteMessage action for that message before the visibility timeout expires, + /// the message becomes visible to other consumers and the message is received again. + /// If a message must be received only once, your consumer should delete it within the duration of the visibility timeout. + /// + public int? VisibilityTimeoutSeconds { get; set; } + + public bool FifoQueue { get; set; } } } diff --git a/src/AWS/Orleans.Streaming.SQS/Streams/StreamIdPartionedQueueCache.cs b/src/AWS/Orleans.Streaming.SQS/Streams/StreamIdPartionedQueueCache.cs new file mode 100644 index 0000000000..347f8714be --- /dev/null +++ b/src/AWS/Orleans.Streaming.SQS/Streams/StreamIdPartionedQueueCache.cs @@ -0,0 +1,102 @@ +using System.Collections.Generic; +using System.Linq; +using Microsoft.Extensions.Logging; +using Orleans.Configuration; +using Orleans.Providers.Streams.Common; +using Orleans.Runtime; +using Orleans.Streams; + +namespace Orleans.Streaming.SQS.Streams; + +/// +/// Adapter for simple queue caches. +/// +public class StreamIdPartitionedQueueAdapterCache : IQueueAdapterCache +{ + /// + /// Cache size property name for configuration + /// + public const string CacheSizePropertyName = "CacheSize"; + + private readonly int cacheSize; + private readonly string providerName; + private readonly ILoggerFactory loggerFactory; + + /// + /// Adapter for simple queue caches. + /// + /// The options. + /// The stream provider name. + /// The logger factory. + public StreamIdPartitionedQueueAdapterCache(SimpleQueueCacheOptions options, string providerName, ILoggerFactory loggerFactory) + { + this.cacheSize = options.CacheSize; + this.loggerFactory = loggerFactory; + this.providerName = providerName; + } + + /// + public IQueueCache CreateQueueCache(QueueId queueId) + { + return new StreamIdPartitionedQueueCache(cacheSize, this.loggerFactory.CreateLogger($"{typeof(SimpleQueueCache).FullName}.{providerName}.{queueId}")); + } +} + +public class StreamIdPartitionedQueueCache : IQueueCache +{ + private Dictionary _partitionedCaches = new(); + + private ILogger logger; + private int maxCacheSize; + private readonly int CACHE_HISTOGRAM_MAX_BUCKET_SIZE = 10; + + public StreamIdPartitionedQueueCache(int cacheSize, ILogger logger) + { + maxCacheSize = cacheSize; + this.logger = logger; + } + + public int GetMaxAddCount() => CACHE_HISTOGRAM_MAX_BUCKET_SIZE; + + public void AddToCache(IList messages) + { + foreach (var messagesByStream in messages.GroupBy(x => x.StreamId)) + { + GetPartitionedCache(messagesByStream.Key) + .AddToCache(messagesByStream.ToList()); + } + } + + public bool TryPurgeFromCache(out IList purgedItems) + { + bool purged = false; + var collector = new List(); + foreach (var cache in _partitionedCaches.Values) + { + if (cache.TryPurgeFromCache(out var partitionedPurgedItems)) + { + purged = true; + collector.AddRange(partitionedPurgedItems); + } + } + + purgedItems = collector; + return purged; + } + + public IQueueCacheCursor GetCacheCursor(StreamId streamId, StreamSequenceToken token) + => GetPartitionedCache(streamId).GetCacheCursor(streamId, token); + + public bool IsUnderPressure() => + _partitionedCaches.Values.Any(cache => cache.IsUnderPressure()); + + private IQueueCache GetPartitionedCache(StreamId streamId) + { + if (!_partitionedCaches.TryGetValue(streamId, out var cache)) + { + cache = new SimpleQueueCache(maxCacheSize, logger); + _partitionedCaches.Add(streamId, cache); + } + return cache; + } +} diff --git a/test/Extensions/AWSUtils.Tests/Streaming/SQSAdapterTests.cs b/test/Extensions/AWSUtils.Tests/Streaming/SQSAdapterTests.cs index 62dfe83e2e..259485a1e0 100644 --- a/test/Extensions/AWSUtils.Tests/Streaming/SQSAdapterTests.cs +++ b/test/Extensions/AWSUtils.Tests/Streaming/SQSAdapterTests.cs @@ -12,6 +12,10 @@ using Xunit.Abstractions; using OrleansAWSUtils.Storage; using Orleans.Configuration; +using Orleans.Serialization; +using Orleans.Serialization.Session; +using Orleans.Streaming.SQS.Streams; +using Microsoft.Extensions.DependencyInjection; namespace AWSUtils.Tests.Streaming { @@ -25,6 +29,7 @@ public class SQSAdapterTests : IAsyncLifetime private const int NumMessagesPerBatch = 20; private readonly string clusterId; public static readonly string SQS_STREAM_PROVIDER_NAME = "SQSAdapterTests"; + private readonly TimeSpan QueuePollRate = TimeSpan.FromSeconds(1); public SQSAdapterTests(ITestOutputHelper output, TestEnvironmentFixture fixture) { @@ -59,7 +64,9 @@ public async Task SendAndReceiveFromSQS() { ConnectionString = AWSTestConstants.SqsConnectionString, }; - var adapterFactory = new SQSAdapterFactory(SQS_STREAM_PROVIDER_NAME, options, new HashRingStreamQueueMapperOptions(), new SimpleQueueCacheOptions(), Options.Create(new ClusterOptions()), null, null); + var clusterOptions = new ClusterOptions { ServiceId = this.clusterId }; + var dataAdapter = new SQSDataAdapter(fixture.Serializer); + var adapterFactory = new SQSAdapterFactory(SQS_STREAM_PROVIDER_NAME, options, new HashRingStreamQueueMapperOptions(), new SimpleQueueCacheOptions(), Options.Create(clusterOptions), dataAdapter, NullLoggerFactory.Instance); adapterFactory.Init(); await SendAndReceiveFromQueueAdapter(adapterFactory); } @@ -90,13 +97,14 @@ private async Task SendAndReceiveFromQueueAdapter(IQueueAdapterFactory adapterFa QueueId queueId = receiverKvp.Key; var receiver = receiverKvp.Value; var qCache = caches[queueId]; - Task task = Task.Factory.StartNew(() => + Task task = Task.Factory.StartNew(async () => { while (receivedBatches < NumBatches) { - var messages = receiver.GetQueueMessagesAsync(SQSStorage.MAX_NUMBER_OF_MESSAGE_TO_PEEK).Result.ToArray(); + var messages = (await receiver.GetQueueMessagesAsync(SQSStorage.MAX_NUMBER_OF_MESSAGE_TO_PEEK)).ToArray(); if (!messages.Any()) { + await Task.Delay(QueuePollRate); continue; } foreach (var message in messages.Cast()) @@ -130,6 +138,9 @@ private async Task SendAndReceiveFromQueueAdapter(IQueueAdapterFactory adapterFa events.Take(NumMessagesPerBatch).ToArray(), null, RequestContextExtensions.Export(this.fixture.DeepCopier)).Wait()))); await Task.WhenAll(work); + // Wait for everything to be consumed. + await Task.Delay(QueuePollRate * 2); + // Make sure we got back everything we sent Assert.Equal(NumBatches, receivedBatches); diff --git a/test/Extensions/AWSUtils.Tests/Streaming/SQSDataAdapterTests.cs b/test/Extensions/AWSUtils.Tests/Streaming/SQSDataAdapterTests.cs new file mode 100644 index 0000000000..327e69e6d8 --- /dev/null +++ b/test/Extensions/AWSUtils.Tests/Streaming/SQSDataAdapterTests.cs @@ -0,0 +1,257 @@ +using System.Collections.Concurrent; +using System.Data; +using System.Globalization; +using System.Text; +using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.Extensions.Options; +using Orleans.Providers.Streams.Common; +using Orleans.Runtime; +using Orleans.Streams; +using OrleansAWSUtils.Streams; +using AWSUtils.Tests.StorageTests; +using TestExtensions; +using Xunit; +using Xunit.Abstractions; +using OrleansAWSUtils.Storage; +using Orleans.Configuration; +using Orleans.Serialization; +using Orleans.Streaming.SQS.Streams; +using Message = Amazon.SQS.Model.Message; + +namespace AWSUtils.Tests.Streaming +{ + [TestCategory("AWS"), TestCategory("SQS")] + [Collection(TestEnvironmentFixture.DefaultCollection)] + public class SQSDataAdapterTests : IAsyncLifetime + { + private readonly ITestOutputHelper output; + private readonly TestEnvironmentFixture fixture; + private const int NumBatches = 20; + private const int NumMessagesPerBatch = 20; + private readonly string clusterId; + public static readonly string SQS_STREAM_PROVIDER_NAME = "SQSAdapterTests"; + private readonly TimeSpan QueuePollRate = TimeSpan.FromSeconds(1); + + public SQSDataAdapterTests(ITestOutputHelper output, TestEnvironmentFixture fixture) + { + if (!AWSTestConstants.IsSqsAvailable) + { + throw new SkipException("Empty connection string"); + } + + this.output = output; + this.fixture = fixture; + this.clusterId = MakeClusterId(); + } + + public Task InitializeAsync() => Task.CompletedTask; + + public async Task DisposeAsync() + { + if (!string.IsNullOrWhiteSpace(AWSTestConstants.SqsConnectionString)) + { + await SQSStreamProviderUtils.DeleteAllUsedQueues( + SQS_STREAM_PROVIDER_NAME, + this.clusterId, + AWSTestConstants.SqsConnectionString, + NullLoggerFactory.Instance); + } + } + + [SkippableFact] + public async Task SendAndReceiveFromSQS() + { + var options = new SqsOptions + { + ConnectionString = AWSTestConstants.SqsConnectionString, + ReceiveMessageAttributes = new[] { "StreamId" }.ToList() + }; + var clusterOptions = new ClusterOptions { ServiceId = this.clusterId }; + var dataAdapter = new StringOrIntSqlDataAdapter(fixture.Serializer); + var adapterFactory = new SQSAdapterFactory(SQS_STREAM_PROVIDER_NAME, options, new HashRingStreamQueueMapperOptions(), new SimpleQueueCacheOptions(), Options.Create(clusterOptions), dataAdapter, NullLoggerFactory.Instance); + adapterFactory.Init(); + await SendAndReceiveFromQueueAdapter(adapterFactory); + } + + private async Task SendAndReceiveFromQueueAdapter(IQueueAdapterFactory adapterFactory) + { + IQueueAdapter adapter = await adapterFactory.CreateAdapter(); + IQueueAdapterCache cache = adapterFactory.GetQueueAdapterCache(); + + // Create receiver per queue + IStreamQueueMapper mapper = adapterFactory.GetStreamQueueMapper(); + Dictionary receivers = mapper.GetAllQueues().ToDictionary(queueId => queueId, adapter.CreateReceiver); + Dictionary caches = mapper.GetAllQueues().ToDictionary(queueId => queueId, cache.CreateQueueCache); + + await Task.WhenAll(receivers.Values.Select(receiver => receiver.Initialize(TimeSpan.FromSeconds(5)))); + + // test using 2 streams + Guid streamId1 = Guid.NewGuid(); + Guid streamId2 = Guid.NewGuid(); + + int receivedBatches = 0; + var streamsPerQueue = new ConcurrentDictionary>(); + + // reader threads (at most 2 active queues because only two streams) + var work = new List(); + foreach (KeyValuePair receiverKvp in receivers) + { + QueueId queueId = receiverKvp.Key; + var receiver = receiverKvp.Value; + var qCache = caches[queueId]; + Task task = Task.Factory.StartNew(async () => + { + while (receivedBatches < NumBatches) + { + var messages = (await receiver.GetQueueMessagesAsync(SQSStorage.MAX_NUMBER_OF_MESSAGE_TO_PEEK)).ToArray(); + if (!messages.Any()) + { + await Task.Delay(QueuePollRate); + continue; + } + foreach (var message in messages.Cast()) + { + streamsPerQueue.AddOrUpdate(queueId, + id => new HashSet { message.StreamId }, + (id, set) => + { + set.Add(message.StreamId); + return set; + }); + output.WriteLine("Queue {0} received message on stream {1}", queueId, + message.StreamId); + Assert.Equal(NumMessagesPerBatch / 2, message.GetEvents().Count()); // "Half the events were ints" + Assert.Equal(NumMessagesPerBatch / 2, message.GetEvents().Count()); // "Half the events were strings" + } + Interlocked.Add(ref receivedBatches, messages.Length); + qCache.AddToCache(messages); + } + }); + work.Add(task); + } + + // send events + List events = CreateEvents(NumMessagesPerBatch); + work.Add(Task.Factory.StartNew(() => Enumerable.Range(0, NumBatches) + .Select(i => i % 2 == 0 ? streamId1 : streamId2) + .ToList() + .ForEach(streamId => + adapter.QueueMessageBatchAsync(StreamId.Create(streamId.ToString(), streamId), + events.Take(NumMessagesPerBatch).ToArray(), null, RequestContextExtensions.Export(this.fixture.DeepCopier)).Wait()))); + await Task.WhenAll(work); + + // Wait for everything to be consumed. + await Task.Delay(QueuePollRate * 2); + + // Make sure we got back everything we sent + Assert.Equal(NumBatches, receivedBatches); + + // check to see if all the events are in the cache and we can enumerate through them + StreamSequenceToken firstInCache = new EventSequenceTokenV2(0); + foreach (KeyValuePair> kvp in streamsPerQueue) + { + var receiver = receivers[kvp.Key]; + var qCache = caches[kvp.Key]; + + foreach (StreamId streamGuid in kvp.Value) + { + // read all messages in cache for stream + IQueueCacheCursor cursor = qCache.GetCacheCursor(streamGuid, firstInCache); + int messageCount = 0; + StreamSequenceToken tenthInCache = null; + StreamSequenceToken lastToken = firstInCache; + while (cursor.MoveNext()) + { + Exception ex; + messageCount++; + IBatchContainer batch = cursor.GetCurrent(out ex); + output.WriteLine("Token: {0}", batch.SequenceToken); + Assert.True(batch.SequenceToken.CompareTo(lastToken) >= 0, $"order check for event {messageCount}"); + lastToken = batch.SequenceToken; + if (messageCount == 10) + { + tenthInCache = batch.SequenceToken; + } + } + output.WriteLine("On Queue {0} we received a total of {1} message on stream {2}", kvp.Key, messageCount, streamGuid); + Assert.Equal(NumBatches / 2, messageCount); + Assert.NotNull(tenthInCache); + + // read all messages from the 10th + cursor = qCache.GetCacheCursor(streamGuid, tenthInCache); + messageCount = 0; + while (cursor.MoveNext()) + { + messageCount++; + } + output.WriteLine("On Queue {0} we received a total of {1} message on stream {2}", kvp.Key, messageCount, streamGuid); + const int expected = NumBatches / 2 - 10 + 1; // all except the first 10, including the 10th (10 + 1) + Assert.Equal(expected, messageCount); + } + } + } + + private List CreateEvents(int count) + { + return Enumerable.Range(0, count).Select(i => + { + if (i % 2 == 0) + { + return Random.Shared.Next(int.MaxValue) as object; + } + return Random.Shared.Next(int.MaxValue).ToString(CultureInfo.InvariantCulture); + }).ToList(); + } + + internal static string MakeClusterId() + { + const string DeploymentIdFormat = "cluster-{0}"; + string now = DateTime.UtcNow.ToString("yyyy-MM-dd-hh-mm-ss-ffff"); + return string.Format(DeploymentIdFormat, now); + } + + private class StringOrIntSqlDataAdapter : SQSDataAdapter + { + public StringOrIntSqlDataAdapter(Serializer serializer) : base(serializer) + { + } + + public override IBatchContainer GetBatchContainer(Message sqsMessage, ref long sequenceNumber) + { + // Example extracts the StreamId as an attribute instead of it being serialized in the body. + if (!sqsMessage.MessageAttributes.TryGetValue("StreamId", out var streamIdStr)) + throw new DataException("SQS Message did not contain a StreamId attribute."); + var streamId = StreamId.Parse(Encoding.UTF8.GetBytes(streamIdStr.StringValue)); + + // Contrived example sends strings as quoted, and longs as unquoted. + var events = sqsMessage.Body.Split(Environment.NewLine) + .Select(x => (object)(x.StartsWith('"') ? x.Trim('"') : int.Parse(x))) + .ToList(); + + return new SQSBatchContainer( + streamId, + events, + new Dictionary(), + new EventSequenceTokenV2(Interlocked.Increment(ref sequenceNumber)) + ); + } + + public override Message ToQueueMessage(StreamId streamId, IEnumerable events, StreamSequenceToken token, Dictionary requestContext) + { + // Contrived example sends strings as quoted, and longs as unquoted. + var serializedData = string.Join(Environment.NewLine, + events.Select(x => x is string ? $"\"{x}\"" : x.ToString())); + + // Example includes the StreamId as an attribute. + return new Message + { + Attributes = new() + { + { "StreamId", streamId.ToString() } + }, + Body = serializedData + }; + } + } + } +} diff --git a/test/Extensions/AWSUtils.Tests/Streaming/SQSFIFOStreamTests.cs b/test/Extensions/AWSUtils.Tests/Streaming/SQSFIFOStreamTests.cs new file mode 100644 index 0000000000..67dfaa9907 --- /dev/null +++ b/test/Extensions/AWSUtils.Tests/Streaming/SQSFIFOStreamTests.cs @@ -0,0 +1,244 @@ +using AWSUtils.Tests.StorageTests; +using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.Extensions.Configuration; +using Orleans.TestingHost; +using UnitTests.StreamingTests; +using Xunit; +using TestExtensions; +using UnitTests.Streaming; +using OrleansAWSUtils.Streams; + +namespace AWSUtils.Tests.Streaming +{ + [TestCategory("AWS"), TestCategory("SQS")] + public class SQSFIFOStreamTests : TestClusterPerTest + { + public static readonly string SQS_STREAM_PROVIDER_NAME = "SQSProvider"; + + private SingleStreamTestRunner runner; + + protected override void ConfigureTestCluster(TestClusterBuilder builder) + { + if (!AWSTestConstants.IsSqsAvailable) + { + throw new SkipException("Empty connection string"); + } + builder.AddSiloBuilderConfigurator(); + builder.AddClientBuilderConfigurator(); + } + + private class MySiloBuilderConfigurator : ISiloConfigurator + { + public void Configure(ISiloBuilder hostBuilder) + { + hostBuilder + .AddSqsStreams("SQSProvider", options => + { + options.ConfigurePullingAgent(agent => + { + agent.Configure(x => + { + x.MaxEventDeliveryTime = TimeSpan.FromSeconds(55); + x.GetQueueMsgsTimerPeriod = TimeSpan.FromMilliseconds(25); + }); + }); + options.ConfigureSqs(opt => + opt.Configure(sqs => + { + sqs.FifoQueue = true; + sqs.ConnectionString = AWSTestConstants.SqsConnectionString; + })); + }) + .AddSqsStreams("SQSProvider2", options => + { + options.ConfigurePullingAgent(agent => + { + agent.Configure(x => + { + x.MaxEventDeliveryTime = TimeSpan.FromSeconds(55); + x.GetQueueMsgsTimerPeriod = TimeSpan.FromMilliseconds(25); + }); + }); + options.ConfigureSqs(opt => + opt.Configure(sqs => + { + sqs.FifoQueue = true; + sqs.ConnectionString = AWSTestConstants.SqsConnectionString; + })); + }) + .AddMemoryGrainStorage("MemoryStore", op=>op.NumStorageGrains = 1); + + if (!string.IsNullOrEmpty(AWSTestConstants.DynamoDbService)) + { + hostBuilder + .AddDynamoDBGrainStorage("DynamoDBStore", options => + { + options.Service = AWSTestConstants.DynamoDbService; + options.SecretKey = AWSTestConstants.DynamoDbSecretKey; + options.AccessKey = AWSTestConstants.DynamoDbAccessKey; + options.DeleteStateOnClear = true; + }) + .AddDynamoDBGrainStorage("PubSubStore", options => + { + options.Service = AWSTestConstants.DynamoDbService; + options.SecretKey = AWSTestConstants.DynamoDbSecretKey; + options.AccessKey = AWSTestConstants.DynamoDbAccessKey; + }); + } + else + { + hostBuilder + .AddMemoryGrainStorage("DynamoDBStore") + .AddMemoryGrainStorage("PubSubStore"); + } + } + } + + private class MyClientBuilderConfigurator : IClientBuilderConfigurator + { + public void Configure(IConfiguration configuration, IClientBuilder clientBuilder) + { + clientBuilder + .AddSqsStreams("SQSProvider", (options => + { + options.ConfigureSqs(opt => + opt.Configure(sqs => + { + sqs.FifoQueue = true; + sqs.ConnectionString = AWSTestConstants.SqsConnectionString; + })); + })); + } + } + + public override async Task InitializeAsync() + { + await base.InitializeAsync(); + runner = new SingleStreamTestRunner(this.InternalClient, SQS_STREAM_PROVIDER_NAME); + } + + public override async Task DisposeAsync() + { + var clusterId = HostedCluster.Options.ClusterId; + await base.DisposeAsync(); + if (!string.IsNullOrWhiteSpace(AWSTestConstants.SqsConnectionString)) + { + SQSStreamProviderUtils.DeleteAllUsedQueues(SQS_STREAM_PROVIDER_NAME, clusterId, AWSTestConstants.SqsConnectionString, NullLoggerFactory.Instance).Wait(); + } + } + + ////------------------------ One to One ----------------------// + + [SkippableFact] + public async Task SQS_01_OneProducerGrainOneConsumerGrain() + { + await runner.StreamTest_01_OneProducerGrainOneConsumerGrain(); + } + + [SkippableFact] + public async Task SQS_02_OneProducerGrainOneConsumerClient() + { + await runner.StreamTest_02_OneProducerGrainOneConsumerClient(); + } + + [SkippableFact] + public async Task SQS_03_OneProducerClientOneConsumerGrain() + { + await runner.StreamTest_03_OneProducerClientOneConsumerGrain(); + } + + [SkippableFact] + public async Task SQS_04_OneProducerClientOneConsumerClient() + { + await runner.StreamTest_04_OneProducerClientOneConsumerClient(); + } + + //------------------------ MANY to Many different grains ----------------------// + + [SkippableFact] + public async Task SQS_05_ManyDifferent_ManyProducerGrainsManyConsumerGrains() + { + await runner.StreamTest_05_ManyDifferent_ManyProducerGrainsManyConsumerGrains(); + } + + [SkippableFact] + public async Task SQS_06_ManyDifferent_ManyProducerGrainManyConsumerClients() + { + await runner.StreamTest_06_ManyDifferent_ManyProducerGrainManyConsumerClients(); + } + + [SkippableFact] + public async Task SQS_07_ManyDifferent_ManyProducerClientsManyConsumerGrains() + { + await runner.StreamTest_07_ManyDifferent_ManyProducerClientsManyConsumerGrains(); + } + + [SkippableFact] + public async Task SQS_08_ManyDifferent_ManyProducerClientsManyConsumerClients() + { + await runner.StreamTest_08_ManyDifferent_ManyProducerClientsManyConsumerClients(); + } + + //------------------------ MANY to Many Same grains ----------------------// + [SkippableFact] + public async Task SQS_09_ManySame_ManyProducerGrainsManyConsumerGrains() + { + await runner.StreamTest_09_ManySame_ManyProducerGrainsManyConsumerGrains(); + } + + [SkippableFact] + public async Task SQS_10_ManySame_ManyConsumerGrainsManyProducerGrains() + { + await runner.StreamTest_10_ManySame_ManyConsumerGrainsManyProducerGrains(); + } + + [SkippableFact] + public async Task SQS_11_ManySame_ManyProducerGrainsManyConsumerClients() + { + await runner.StreamTest_11_ManySame_ManyProducerGrainsManyConsumerClients(); + } + + [SkippableFact] + public async Task SQS_12_ManySame_ManyProducerClientsManyConsumerGrains() + { + await runner.StreamTest_12_ManySame_ManyProducerClientsManyConsumerGrains(); + } + + //------------------------ MANY to Many producer consumer same grain ----------------------// + + [SkippableFact] + public async Task SQS_13_SameGrain_ConsumerFirstProducerLater() + { + await runner.StreamTest_13_SameGrain_ConsumerFirstProducerLater(false); + } + + [SkippableFact] + public async Task SQS_14_SameGrain_ProducerFirstConsumerLater() + { + await runner.StreamTest_14_SameGrain_ProducerFirstConsumerLater(false); + } + + //----------------------------------------------// + + [SkippableFact] + public async Task SQS_15_ConsumeAtProducersRequest() + { + await runner.StreamTest_15_ConsumeAtProducersRequest(); + } + + [SkippableFact] + public async Task SQS_16_MultipleStreams_ManyDifferent_ManyProducerGrainsManyConsumerGrains() + { + var multiRunner = new MultipleStreamsTestRunner(this.InternalClient, SQS_STREAM_PROVIDER_NAME, 16, false); + await multiRunner.StreamTest_MultipleStreams_ManyDifferent_ManyProducerGrainsManyConsumerGrains(); + } + + [SkippableFact] + public async Task SQS_17_MultipleStreams_1J_ManyProducerGrainsManyConsumerGrains() + { + var multiRunner = new MultipleStreamsTestRunner(this.InternalClient, SQS_STREAM_PROVIDER_NAME, 17, false); + await multiRunner.StreamTest_MultipleStreams_ManyDifferent_ManyProducerGrainsManyConsumerGrains( + this.HostedCluster.StartAdditionalSilo); + } + } +} diff --git a/test/Extensions/AWSUtils.Tests/Streaming/SQSStreamTests.cs b/test/Extensions/AWSUtils.Tests/Streaming/SQSStreamTests.cs index d97930abfe..29d036d5f7 100644 --- a/test/Extensions/AWSUtils.Tests/Streaming/SQSStreamTests.cs +++ b/test/Extensions/AWSUtils.Tests/Streaming/SQSStreamTests.cs @@ -40,20 +40,31 @@ public void Configure(ISiloBuilder hostBuilder) { options.ConnectionString = AWSTestConstants.SqsConnectionString; }) - .AddDynamoDBGrainStorage("DynamoDBStore", options => - { - options.Service = AWSTestConstants.DynamoDbService; - options.SecretKey = AWSTestConstants.DynamoDbSecretKey; - options.AccessKey = AWSTestConstants.DynamoDbAccessKey; - options.DeleteStateOnClear = true; - }) - .AddDynamoDBGrainStorage("PubSubStore", options => - { - options.Service = AWSTestConstants.DynamoDbService; - options.SecretKey = AWSTestConstants.DynamoDbSecretKey; - options.AccessKey = AWSTestConstants.DynamoDbAccessKey; - }) .AddMemoryGrainStorage("MemoryStore", op=>op.NumStorageGrains = 1); + + if (!string.IsNullOrEmpty(AWSTestConstants.DynamoDbService)) + { + hostBuilder + .AddDynamoDBGrainStorage("DynamoDBStore", options => + { + options.Service = AWSTestConstants.DynamoDbService; + options.SecretKey = AWSTestConstants.DynamoDbSecretKey; + options.AccessKey = AWSTestConstants.DynamoDbAccessKey; + options.DeleteStateOnClear = true; + }) + .AddDynamoDBGrainStorage("PubSubStore", options => + { + options.Service = AWSTestConstants.DynamoDbService; + options.SecretKey = AWSTestConstants.DynamoDbSecretKey; + options.AccessKey = AWSTestConstants.DynamoDbAccessKey; + }); + } + else + { + hostBuilder + .AddMemoryGrainStorage("DynamoDBStore") + .AddMemoryGrainStorage("PubSubStore"); + } } } diff --git a/test/Grains/TestGrainInterfaces/IStreamingGrain.cs b/test/Grains/TestGrainInterfaces/IStreamingGrain.cs index f13cb5ba20..e6c62842d3 100644 --- a/test/Grains/TestGrainInterfaces/IStreamingGrain.cs +++ b/test/Grains/TestGrainInterfaces/IStreamingGrain.cs @@ -1,4 +1,4 @@ -using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging; using Orleans.Streams; namespace UnitTests.GrainInterfaces diff --git a/test/Grains/TestInternalGrains/StreamingGrain.cs b/test/Grains/TestInternalGrains/StreamingGrain.cs index 1d0d71ed47..2a6713b600 100644 --- a/test/Grains/TestInternalGrains/StreamingGrain.cs +++ b/test/Grains/TestInternalGrains/StreamingGrain.cs @@ -229,6 +229,8 @@ private async Task ProduceItem(string data) if (_cleanedUpFlag.IsSet) return false; + data += $", random#{Guid.NewGuid()}"; + StreamItem item = new StreamItem(data, _streamId); await _observer.OnNextAsync(item); _itemsProduced++;