Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding Fix for DrainMode #247

Open
wants to merge 3 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,9 @@

namespace Microsoft.Azure.WebJobs.Extensions.RabbitMQ.Samples;

public class SamplesTypeLocator : ITypeLocator
public class SamplesTypeLocator(params Type[] types) : ITypeLocator
{
private readonly Type[] types;

public SamplesTypeLocator(params Type[] types)
{
this.types = types;
}
private readonly Type[] types = types;

public IReadOnlyList<Type> GetTypes()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,10 @@

namespace Microsoft.Azure.WebJobs.Extensions.RabbitMQ;

internal class RabbitMQClientBuilder : IConverter<RabbitMQAttribute, IModel>
internal class RabbitMQClientBuilder(RabbitMQExtensionConfigProvider configProvider, IOptions<RabbitMQOptions> options) : IConverter<RabbitMQAttribute, IModel>
{
private readonly RabbitMQExtensionConfigProvider configProvider;
private readonly IOptions<RabbitMQOptions> options;

public RabbitMQClientBuilder(RabbitMQExtensionConfigProvider configProvider, IOptions<RabbitMQOptions> options)
{
this.configProvider = configProvider;
this.options = options;
}
private readonly RabbitMQExtensionConfigProvider configProvider = configProvider;
private readonly IOptions<RabbitMQOptions> options = options;

public IModel Convert(RabbitMQAttribute attribute)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,14 @@
namespace Microsoft.Azure.WebJobs.Extensions.RabbitMQ;

[Extension("RabbitMQ")]
internal class RabbitMQExtensionConfigProvider : IExtensionConfigProvider
internal class RabbitMQExtensionConfigProvider(IOptions<RabbitMQOptions> options, INameResolver nameResolver, IRabbitMQServiceFactory rabbitMQServiceFactory, ILoggerFactory loggerFactory, IConfiguration configuration) : IExtensionConfigProvider
{
private readonly IOptions<RabbitMQOptions> options;
private readonly INameResolver nameResolver;
private readonly IRabbitMQServiceFactory rabbitMQServiceFactory;
private readonly ILogger logger;
private readonly IConfiguration configuration;
private readonly ConcurrentDictionary<string, IRabbitMQService> connectionParametersToService;

public RabbitMQExtensionConfigProvider(IOptions<RabbitMQOptions> options, INameResolver nameResolver, IRabbitMQServiceFactory rabbitMQServiceFactory, ILoggerFactory loggerFactory, IConfiguration configuration)
{
this.options = options;
this.nameResolver = nameResolver;
this.rabbitMQServiceFactory = rabbitMQServiceFactory;
this.logger = loggerFactory?.CreateLogger(LogCategories.CreateTriggerCategory("RabbitMQ"));
this.configuration = configuration;
this.connectionParametersToService = new ConcurrentDictionary<string, IRabbitMQService>();
}
private readonly IOptions<RabbitMQOptions> options = options;
private readonly INameResolver nameResolver = nameResolver;
private readonly IRabbitMQServiceFactory rabbitMQServiceFactory = rabbitMQServiceFactory;
private readonly ILogger logger = loggerFactory?.CreateLogger(LogCategories.CreateTriggerCategory("RabbitMQ"));
private readonly IConfiguration configuration = configuration;
private readonly ConcurrentDictionary<string, IRabbitMQService> connectionParametersToService = new();

public void Initialize(ExtensionConfigContext context)
{
Expand Down Expand Up @@ -88,15 +78,17 @@ internal RabbitMQContext CreateContext(RabbitMQAttribute attribute)

internal IRabbitMQService GetService(string connectionString, string queueName, bool disableCertificateValidation)
{
string[] keyArray = { connectionString, queueName, disableCertificateValidation.ToString() };
string[] keyArray =
[connectionString, queueName, disableCertificateValidation.ToString()];
string key = string.Join(",", keyArray);
return this.connectionParametersToService.GetOrAdd(key, _ => this.rabbitMQServiceFactory.CreateService(connectionString, queueName, disableCertificateValidation));
}

// Overloaded method used only for getting the RabbitMQ client
internal IRabbitMQService GetService(string connectionString, bool disableCertificateValidation)
{
string[] keyArray = { connectionString, disableCertificateValidation.ToString() };
string[] keyArray =
[connectionString, disableCertificateValidation.ToString()];
string key = string.Join(",", keyArray);
return this.connectionParametersToService.GetOrAdd(key, _ => this.rabbitMQServiceFactory.CreateService(connectionString, disableCertificateValidation));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,11 @@

namespace Microsoft.Azure.WebJobs.Extensions.RabbitMQ;

public class BasicDeliverEventArgsValueProvider : IValueProvider
public class BasicDeliverEventArgsValueProvider(BasicDeliverEventArgs input, Type destinationType) : IValueProvider
{
private readonly BasicDeliverEventArgs input;
private readonly BasicDeliverEventArgs input = input;

public BasicDeliverEventArgsValueProvider(BasicDeliverEventArgs input, Type destinationType)
{
this.input = input;
this.Type = destinationType;
}

public Type Type { get; }
public Type Type { get; } = destinationType;

public Task<object> GetValueAsync()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.WebJobs.Host;
using Microsoft.Azure.WebJobs.Host.Executors;
using Microsoft.Azure.WebJobs.Host.Listeners;
using Microsoft.Azure.WebJobs.Host.Scale;
Expand All @@ -33,7 +34,9 @@ internal sealed class RabbitMQListener : IListener, IScaleMonitor<RabbitMQTrigge
private readonly string queueName;
private readonly ushort prefetchCount;
private readonly string logDetails;
private readonly IDrainModeManager drainModeManager;

private readonly CancellationTokenSource listenerCancellationTokenSource;
private int listenerState = ListenerNotStarted;
private string consumerTag;

Expand All @@ -43,13 +46,16 @@ public RabbitMQListener(
ILogger logger,
string functionId,
string queueName,
ushort prefetchCount)
ushort prefetchCount,
IDrainModeManager drainModeManager)
{
this.channel = channel ?? throw new ArgumentNullException(nameof(channel));
this.executor = executor ?? throw new ArgumentNullException(nameof(executor));
this.logger = logger ?? throw new ArgumentNullException(nameof(logger));
this.queueName = !string.IsNullOrWhiteSpace(queueName) ? queueName : throw new ArgumentNullException(nameof(queueName));
this.prefetchCount = prefetchCount;
this.drainModeManager = drainModeManager;
this.listenerCancellationTokenSource = new CancellationTokenSource();

_ = !string.IsNullOrWhiteSpace(functionId) ? true : throw new ArgumentNullException(nameof(functionId));

Expand Down Expand Up @@ -103,7 +109,8 @@ async Task ReceivedHandler(object model, BasicDeliverEventArgs args)
using Activity activity = RabbitMQActivitySource.StartActivity(args.BasicProperties);

var input = new TriggeredFunctionData() { TriggerValue = args };
FunctionResult result = await this.executor.TryExecuteAsync(input, cancellationToken).ConfigureAwait(false);

FunctionResult result = await this.executor.TryExecuteAsync(input, this.listenerCancellationTokenSource.Token).ConfigureAwait(false);

if (!result.Succeeded)
{
Expand Down Expand Up @@ -143,6 +150,11 @@ public Task StopAsync(CancellationToken cancellationToken)
this.channel.BasicCancel(this.consumerTag);
this.channel.Close();

if (!this.drainModeManager.IsDrainModeEnabled)
{
this.listenerCancellationTokenSource.Cancel();
}

this.listenerState = ListenerStopped;
this.logger.LogDebug($"Stopped RabbitMQ trigger listener for {this.logDetails}.");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,14 @@ namespace Microsoft.Azure.WebJobs;
/// <summary>
/// Attribute used to bind a parameter to RabbitMQ trigger message.
/// </summary>
/// <remarks>
/// Initializes a new instance of the <see cref="RabbitMQTriggerAttribute"/> class.
/// </remarks>
/// <param name="queueName">RabbitMQ queue name.</param>
[AttributeUsage(AttributeTargets.Parameter)]
[Binding]
public sealed class RabbitMQTriggerAttribute : Attribute
public sealed class RabbitMQTriggerAttribute(string queueName) : Attribute
{
/// <summary>
/// Initializes a new instance of the <see cref="RabbitMQTriggerAttribute"/> class.
/// </summary>
/// <param name="queueName">RabbitMQ queue name.</param>
public RabbitMQTriggerAttribute(string queueName)
{
this.QueueName = queueName;
}

/// <summary>
/// Gets or sets the setting name for RabbitMQ connection URI.
/// </summary>
Expand All @@ -31,7 +26,7 @@ public RabbitMQTriggerAttribute(string queueName)
/// <summary>
/// Gets the RabbitMQ queue name.
/// </summary>
public string QueueName { get; private set; }
public string QueueName { get; private set; } = queueName;

/// <summary>
/// Gets or sets a value indicating whether certificate validation should be disabled. Not recommended for
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,20 @@

namespace Microsoft.Azure.WebJobs.Extensions.RabbitMQ;

internal class RabbitMQTriggerAttributeBindingProvider : ITriggerBindingProvider
internal class RabbitMQTriggerAttributeBindingProvider(
INameResolver nameResolver,
RabbitMQExtensionConfigProvider provider,
ILogger logger,
IOptions<RabbitMQOptions> options,
IConfiguration configuration,
IDrainModeManager drainModeManager) : ITriggerBindingProvider
{
private readonly INameResolver nameResolver;
private readonly RabbitMQExtensionConfigProvider provider;
private readonly ILogger logger;
private readonly IOptions<RabbitMQOptions> options;
private readonly IConfiguration configuration;

public RabbitMQTriggerAttributeBindingProvider(
INameResolver nameResolver,
RabbitMQExtensionConfigProvider provider,
ILogger logger,
IOptions<RabbitMQOptions> options,
IConfiguration configuration)
{
this.nameResolver = nameResolver ?? throw new ArgumentNullException(nameof(nameResolver));
this.provider = provider ?? throw new ArgumentNullException(nameof(provider));
this.logger = logger ?? throw new ArgumentNullException(nameof(logger));
this.options = options;
this.configuration = configuration;
}
private readonly INameResolver nameResolver = nameResolver ?? throw new ArgumentNullException(nameof(nameResolver));
private readonly RabbitMQExtensionConfigProvider provider = provider ?? throw new ArgumentNullException(nameof(provider));
private readonly ILogger logger = logger ?? throw new ArgumentNullException(nameof(logger));
private readonly IOptions<RabbitMQOptions> options = options;
private readonly IConfiguration configuration = configuration;
private readonly IDrainModeManager drainModeManager = drainModeManager;

public Task<ITriggerBinding> TryCreateAsync(TriggerBindingProviderContext context)
{
Expand All @@ -52,7 +45,7 @@ public Task<ITriggerBinding> TryCreateAsync(TriggerBindingProviderContext contex

IRabbitMQService service = this.provider.GetService(connectionString, queueName, disableCertificateValidation);

return Task.FromResult<ITriggerBinding>(new RabbitMQTriggerBinding(service, queueName, this.logger, parameter.ParameterType, this.options.Value.PrefetchCount));
return Task.FromResult<ITriggerBinding>(new RabbitMQTriggerBinding(service, queueName, this.logger, parameter.ParameterType, this.options.Value.PrefetchCount, this.drainModeManager));
}

private string Resolve(string name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using Microsoft.Azure.WebJobs.Host;
using Microsoft.Azure.WebJobs.Host.Bindings;
using Microsoft.Azure.WebJobs.Host.Listeners;
using Microsoft.Azure.WebJobs.Host.Protocols;
Expand All @@ -14,27 +15,18 @@

namespace Microsoft.Azure.WebJobs.Extensions.RabbitMQ;

internal class RabbitMQTriggerBinding : ITriggerBinding
internal class RabbitMQTriggerBinding(IRabbitMQService service, string queueName, ILogger logger, Type parameterType, ushort prefetchCount, IDrainModeManager drainModeManager) : ITriggerBinding
{
private readonly IRabbitMQService service;
private readonly ILogger logger;
private readonly Type parameterType;
private readonly string queueName;
private readonly ushort prefetchCount;

public RabbitMQTriggerBinding(IRabbitMQService service, string queueName, ILogger logger, Type parameterType, ushort prefetchCount)
{
this.service = service;
this.queueName = queueName;
this.logger = logger;
this.parameterType = parameterType;
this.prefetchCount = prefetchCount;
this.BindingDataContract = CreateBindingDataContract();
}
private readonly IRabbitMQService service = service;
private readonly ILogger logger = logger;
private readonly Type parameterType = parameterType;
private readonly string queueName = queueName;
private readonly ushort prefetchCount = prefetchCount;
private readonly IDrainModeManager drainModeManager = drainModeManager;

public Type TriggerValueType => typeof(BasicDeliverEventArgs);

public IReadOnlyDictionary<string, Type> BindingDataContract { get; } = new Dictionary<string, Type>();
public IReadOnlyDictionary<string, Type> BindingDataContract { get; } = CreateBindingDataContract();

public Task<ITriggerData> BindAsync(object value, ValueBindingContext context)
{
Expand All @@ -54,7 +46,8 @@ public Task<IListener> CreateListenerAsync(ListenerFactoryContext context)
this.logger,
context.Descriptor.Id,
this.queueName,
this.prefetchCount));
this.prefetchCount,
this.drainModeManager));
}

public ParameterDescriptor ToParameterDescriptor()
Expand Down