Skip to content

Commit

Permalink
Support SQS FIFO Queue SequenceNumber as SequenceToken, add ParitionedQC
Browse files Browse the repository at this point in the history
  • Loading branch information
jamescarter-le committed Nov 5, 2024
1 parent be99547 commit 483cda0
Show file tree
Hide file tree
Showing 9 changed files with 416 additions and 21 deletions.
21 changes: 18 additions & 3 deletions src/AWS/Orleans.Streaming.SQS/Storage/SQSStorage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ internal class SQSStorage
private string queueUrl;
private AmazonSQSClient sqsClient;

private List<string> receiveAttributes;
private List<string> receiveMessageAttributes;


/// <summary>
/// The queue Name
/// </summary>
Expand All @@ -57,6 +61,9 @@ public SQSStorage(ILoggerFactory loggerFactory, string queueName, SqsOptions sqs
ParseDataConnectionString(sqsOptions.ConnectionString);
Logger = loggerFactory.CreateLogger<SQSStorage>();
CreateClient();

receiveAttributes = [..sqsOptions.ReceiveAttributes];
receiveMessageAttributes = [.. sqsOptions.ReceiveMessageAttributes];
}

private void ParseDataConnectionString(string dataConnectionString)
Expand Down Expand Up @@ -156,6 +163,7 @@ public async Task InitQueueAsync()

if (sqsOptions.FifoQueue)
{
// The stream must have these attributes to be a valid FIFO queue.
createQueueRequest.Attributes = new()
{
{ QueueAttributeName.FifoQueue, "true" },
Expand All @@ -166,7 +174,14 @@ public async Task InitQueueAsync()

// We require to bring down the AWS set SequenceNumber when on a FIFO queue
// in order to populate the SQSFIFOSequenceToken from it.
sqsOptions.ReceiveAttributes.Add("SequenceNumber");

if (!receiveMessageAttributes.Contains(MessageSystemAttributeName.SequenceNumber))
receiveMessageAttributes.Add(MessageSystemAttributeName.SequenceNumber);
if (!receiveMessageAttributes.Contains(MessageSystemAttributeName.MessageGroupId))
receiveMessageAttributes.Add(MessageSystemAttributeName.MessageGroupId);

// FIFO Queue does not support Long Polling
sqsOptions.ReceiveWaitTimeSeconds = null;
}

if (sqsOptions.ReceiveWaitTimeSeconds.HasValue)
Expand Down Expand Up @@ -254,8 +269,8 @@ public async Task<IEnumerable<SQSMessage>> GetMessages(int count = 1)
{
QueueUrl = queueUrl,
MaxNumberOfMessages = count <= MAX_NUMBER_OF_MESSAGE_TO_PEEK ? count : MAX_NUMBER_OF_MESSAGE_TO_PEEK,
AttributeNames = sqsOptions.ReceiveAttributes,
MessageAttributeNames = sqsOptions.ReceiveMessageAttributes,
AttributeNames = receiveAttributes,
MessageAttributeNames = receiveMessageAttributes,
};

if (sqsOptions.ReceiveWaitTimeSeconds.HasValue)
Expand Down
2 changes: 1 addition & 1 deletion src/AWS/Orleans.Streaming.SQS/Streams/SQSAdapter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public async Task QueueMessageBatchAsync<T>(StreamId streamId, IEnumerable<T> ev
if (this.sqsOptions.FifoQueue)
{
// Ensure the SQS Queue ensures FIFO order of messages over this QueueId.
sqsRequest.MessageGroupId = queueId.ToString();
sqsRequest.MessageGroupId = streamId.ToString();
}

foreach (var attr in sqsMessage.Attributes)
Expand Down
8 changes: 5 additions & 3 deletions src/AWS/Orleans.Streaming.SQS/Streams/SQSAdapterFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
using Orleans;
using Orleans.Configuration.Overrides;
using Orleans.Streaming.SQS.Streams;
using Orleans.Runtime;

namespace OrleansAWSUtils.Streams
{
Expand Down Expand Up @@ -44,9 +43,12 @@ public SQSAdapterFactory(
this.dataAdapter = dataAdapter;
this.loggerFactory = loggerFactory;
streamQueueMapper = new HashRingBasedStreamQueueMapper(queueMapperOptions, this.providerName);
adapterCache = new SimpleQueueAdapterCache(cacheOptions, this.providerName, this.loggerFactory);
}

if (sqsOptions.FifoQueue)
adapterCache = new StreamIdPartitionedQueueAdapterCache(cacheOptions, this.providerName, this.loggerFactory);
else
adapterCache = new SimpleQueueAdapterCache(cacheOptions, this.providerName, this.loggerFactory);
}

/// <summary> Init the factory.</summary>
public virtual void Init()
Expand Down
21 changes: 15 additions & 6 deletions src/AWS/Orleans.Streaming.SQS/Streams/SQSBatchContainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Amazon.SQS;
using Amazon.SQS.Model;
using SQSMessage = Amazon.SQS.Model.Message;

namespace OrleansAWSUtils.Streams
Expand Down Expand Up @@ -62,17 +65,17 @@ public SQSBatchContainer(StreamId streamId, List<object> events, Dictionary<stri

public IEnumerable<Tuple<T, StreamSequenceToken>> GetEvents<T>()
{
static StreamSequenceToken CreateStreamSequenceToken(StreamSequenceToken tok, int eventIndex)
static StreamSequenceToken CreateStreamSequenceToken(StreamId streamId, StreamSequenceToken tok, int eventIndex)
{
return tok switch
{
EventSequenceTokenV2 v2Tok => v2Tok.CreateSequenceTokenForEvent(eventIndex),
SQSFIFOSequenceToken fifoTok => fifoTok.CreateSequenceTokenForEvent(eventIndex),
SQSFIFOSequenceToken fifoTok => fifoTok.CreateSequenceTokenForEvent(streamId, eventIndex),
_ => throw new NotSupportedException("Unknown SequenceToken provided.")
};
}

return events.OfType<T>().Select((e, i) => Tuple.Create<T, StreamSequenceToken>(e, CreateStreamSequenceToken(sequenceToken, i)));
return events.OfType<T>().Select((e, i) => Tuple.Create<T, StreamSequenceToken>(e, CreateStreamSequenceToken(StreamId, sequenceToken, i)));
}

internal static SQSMessage ToSQSMessage<T>(
Expand All @@ -89,7 +92,7 @@ internal static SQSMessage ToSQSMessage<T>(
};
return new SQSMessage
{
Body = payload.ToString()
Body = payload.ToString(),
};
}

Expand All @@ -98,9 +101,15 @@ internal static SQSBatchContainer FromSQSMessage(Serializer<SQSBatchContainer> s
var json = JObject.Parse(msg.Body);
var sqsBatch = serializer.Deserialize(json["payload"].ToObject<byte[]>());
sqsBatch.Message = msg;

if (msg.Attributes.TryGetValue(MessageSystemAttributeName.SequenceNumber, out var fifoSeqNum))
{
if(!msg.Attributes.TryGetValue(MessageSystemAttributeName.MessageGroupId, out var messageGroupId))
throw new ArgumentException("FIFO SQS message does not have MessageGroupId attribute", nameof(msg));

if(msg.Attributes.TryGetValue("SequenceNumber", out var fifoSeqNum))
sqsBatch.sequenceToken = new SQSFIFOSequenceToken(UInt128.Parse(fifoSeqNum));
var streamId = StreamId.Parse(Encoding.UTF8.GetBytes(messageGroupId));
sqsBatch.sequenceToken = new SQSFIFOSequenceToken(streamId, UInt128.Parse(fifoSeqNum));
}
else
sqsBatch.sequenceToken = new EventSequenceTokenV2(sequenceNumber);

Expand Down
35 changes: 28 additions & 7 deletions src/AWS/Orleans.Streaming.SQS/Streams/SQSFIFOSequenceToken.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
using System;
using System.Globalization;
using System.IO;
using Newtonsoft.Json;
using Orleans;
using Orleans.Runtime;
using Orleans.Streams;

namespace OrleansAWSUtils.Streams
Expand All @@ -13,13 +15,20 @@ namespace OrleansAWSUtils.Streams
[GenerateSerializer]
public class SQSFIFOSequenceToken : StreamSequenceToken
{
/// <summary>
/// Gets the StreamId this SQS FIFO sequence token is related to,
/// as the SqsSequenceNumber is partitioned by StreamId (SQS MessageGroupId)
/// </summary>
[Id(2)][JsonProperty] public StreamId StreamId { get; set; }

/// <summary>
/// Gets the number of event batches in stream prior to this event batch
/// </summary>
[Id(0)]
[JsonProperty]
public UInt128 SqsSequenceNumber { get; set; }


/// <summary>
/// Gets the number of event batches in stream prior to this event batch
/// </summary>
Expand All @@ -40,19 +49,22 @@ public override long SequenceNumber
/// Initializes a new instance of the <see cref="SQSFIFOSequenceToken"/> class.
/// </summary>
/// <param name="seqNumber">The sequence number.</param>
public SQSFIFOSequenceToken(UInt128 seqNumber)
public SQSFIFOSequenceToken(StreamId streamId, UInt128 seqNumber)
{
StreamId = streamId;
SqsSequenceNumber = seqNumber;
EventIndex = 0;
}

/// <summary>
/// Initializes a new instance of the <see cref="SQSFIFOSequenceToken"/> class.
/// </summary>
/// <param name="streamId">The stream id for which this token relates.</param>
/// <param name="seqNumber">The sequence number.</param>
/// <param name="eventInd">The event index, for events which are part of a batch of events.</param>
public SQSFIFOSequenceToken(UInt128 seqNumber, int eventInd)
public SQSFIFOSequenceToken(StreamId streamId, UInt128 seqNumber, int eventInd)
{
StreamId = streamId;
SqsSequenceNumber = seqNumber;
EventIndex = eventInd;
}
Expand All @@ -72,9 +84,9 @@ public SQSFIFOSequenceToken()
/// </summary>
/// <param name="eventInd">The event index.</param>
/// <returns>A new sequence token.</returns>
public SQSFIFOSequenceToken CreateSequenceTokenForEvent(int eventInd)
public SQSFIFOSequenceToken CreateSequenceTokenForEvent(StreamId streamId, int eventInd)
{
return new SQSFIFOSequenceToken(SqsSequenceNumber, eventInd);
return new SQSFIFOSequenceToken(streamId, SqsSequenceNumber, eventInd);
}

/// <inheritdoc/>
Expand All @@ -87,8 +99,14 @@ public override bool Equals(object obj)
public override bool Equals(StreamSequenceToken other)
{
var token = other as SQSFIFOSequenceToken;
return token != null && (token.SqsSequenceNumber == SqsSequenceNumber &&
token.EventIndex == EventIndex);
if (token == null)
return false;

if (token.StreamId != StreamId)
return false;

return (token.SqsSequenceNumber == SqsSequenceNumber &&
token.EventIndex == EventIndex);
}

/// <inheritdoc/>
Expand All @@ -101,6 +119,9 @@ public override int CompareTo(StreamSequenceToken other)
if (token == null)
throw new ArgumentOutOfRangeException(nameof(other));

if(token.StreamId != StreamId)
throw new ArgumentOutOfRangeException(nameof(other), "Cannot compare between tokens of different StreamIds");

int difference = SqsSequenceNumber.CompareTo(token.SqsSequenceNumber);
return difference != 0 ? difference : EventIndex.CompareTo(token.EventIndex);
}
Expand All @@ -115,7 +136,7 @@ public override int GetHashCode()
/// <inheritdoc/>
public override string ToString()
{
return string.Format(CultureInfo.InvariantCulture, "[SQSFIFOSequenceToken: SeqNum={0}, EventIndex={1}]", SqsSequenceNumber, EventIndex);
return string.Format(CultureInfo.InvariantCulture, "[SQSFIFOSequenceToken: SeqNum={0}, EventIndex={1}, StreamId={2}]", SqsSequenceNumber, EventIndex, StreamId);
}
}
}
102 changes: 102 additions & 0 deletions src/AWS/Orleans.Streaming.SQS/Streams/StreamIdPartionedQueueCache.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
using System.Collections.Generic;
using System.Linq;
using Microsoft.Extensions.Logging;
using Orleans.Configuration;
using Orleans.Providers.Streams.Common;
using Orleans.Runtime;
using Orleans.Streams;

namespace Orleans.Streaming.SQS.Streams;

/// <summary>
/// Adapter for simple queue caches.
/// </summary>
public class StreamIdPartitionedQueueAdapterCache : IQueueAdapterCache
{
/// <summary>
/// Cache size property name for configuration
/// </summary>
public const string CacheSizePropertyName = "CacheSize";

private readonly int cacheSize;
private readonly string providerName;
private readonly ILoggerFactory loggerFactory;

/// <summary>
/// Adapter for simple queue caches.
/// </summary>
/// <param name="options">The options.</param>
/// <param name="providerName">The stream provider name.</param>
/// <param name="loggerFactory">The logger factory.</param>
public StreamIdPartitionedQueueAdapterCache(SimpleQueueCacheOptions options, string providerName, ILoggerFactory loggerFactory)
{
this.cacheSize = options.CacheSize;
this.loggerFactory = loggerFactory;
this.providerName = providerName;
}

/// <inheritdoc />
public IQueueCache CreateQueueCache(QueueId queueId)
{
return new StreamIdPartitionedQueueCache(cacheSize, this.loggerFactory.CreateLogger($"{typeof(SimpleQueueCache).FullName}.{providerName}.{queueId}"));
}
}

public class StreamIdPartitionedQueueCache : IQueueCache
{
private Dictionary<StreamId, IQueueCache> _partitionedCaches = new();

private ILogger logger;
private int maxCacheSize;
private readonly int CACHE_HISTOGRAM_MAX_BUCKET_SIZE = 10;

public StreamIdPartitionedQueueCache(int cacheSize, ILogger logger)
{
maxCacheSize = cacheSize;
this.logger = logger;
}

public int GetMaxAddCount() => CACHE_HISTOGRAM_MAX_BUCKET_SIZE;

public void AddToCache(IList<IBatchContainer> messages)
{
foreach (var messagesByStream in messages.GroupBy(x => x.StreamId))
{
GetPartitionedCache(messagesByStream.Key)
.AddToCache(messagesByStream.ToList());
}
}

public bool TryPurgeFromCache(out IList<IBatchContainer> purgedItems)
{
bool purged = false;
var collector = new List<IBatchContainer>();
foreach (var cache in _partitionedCaches.Values)
{
if (cache.TryPurgeFromCache(out var partitionedPurgedItems))
{
purged = true;
collector.AddRange(partitionedPurgedItems);
}
}

purgedItems = collector;
return purged;
}

public IQueueCacheCursor GetCacheCursor(StreamId streamId, StreamSequenceToken token)
=> GetPartitionedCache(streamId).GetCacheCursor(streamId, token);

public bool IsUnderPressure() =>
_partitionedCaches.Values.Any(cache => cache.IsUnderPressure());

private IQueueCache GetPartitionedCache(StreamId streamId)
{
if (!_partitionedCaches.TryGetValue(streamId, out var cache))
{
cache = new SimpleQueueCache(maxCacheSize, logger);
_partitionedCaches.Add(streamId, cache);
}
return cache;
}
}
Loading

0 comments on commit 483cda0

Please sign in to comment.