diff --git a/src/System.CommandLine.Tests/Invocation/CommandHandlerTests.cs b/src/System.CommandLine.Tests/Invocation/CommandHandlerTests.cs index f799649ddc..d6289d5c5d 100644 --- a/src/System.CommandLine.Tests/Invocation/CommandHandlerTests.cs +++ b/src/System.CommandLine.Tests/Invocation/CommandHandlerTests.cs @@ -2,6 +2,7 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System.CommandLine.Binding; +using System.CommandLine.Builder; using System.CommandLine.Invocation; using System.CommandLine.IO; using System.CommandLine.Parsing; @@ -322,7 +323,6 @@ public async Task Method_parameters_of_type_InvocationContext_receive_the_curren boundContext.ParseResult.ValueForOption(option).Should().Be(123); } - private class ExecuteTestClass { public string boundName = default; @@ -507,5 +507,111 @@ public class OverridenVirtualTestCommandHandler : VirtualTestCommandHandler public override Task InvokeAsync(InvocationContext context) => Task.FromResult(41); } + + [Fact] + public static void FromBindingContext_forwards_invocation_to_bound_handler_type() + { + var command = new RootCommand + { + Handler = CommandHandler.FromBindingContext() + }; + var parser = new CommandLineBuilder(command) + .ConfigureBindingContext(context => context.AddService()) + .Build(); + + var console = new TestConsole(); + parser.Invoke(Array.Empty(), console); + console.Out.ToString().Should().Be(typeof(BindingContextResolvedCommandHandler).FullName); + } + + [Fact] + public static void FromBindingContext_returns_a_wrapper_type_instance() + { + ICommandHandler handler = CommandHandler.FromBindingContext(); + handler.Should().NotBeOfType(); + } + + [Fact] + public static void Subsequent_call_to_configure_overrides_service_registration() + { + ICommandHandler invokedHandler = null; + BindingContextCommandHandlerAction action = (handler, Console) => + { + invokedHandler = handler; + }; + var parser = new CommandLineBuilder(new RootCommand + { + Handler = CommandHandler.FromBindingContext() + }) + .ConfigureBindingContext(context => context.AddService(_ => action)) + .ConfigureBindingContext(context => context.AddService()) + .ConfigureBindingContext(context => context.AddService()) + .Build(); + parser.Invoke(Array.Empty(), new TestConsole()); + + invokedHandler.Should().NotBeNull(); + invokedHandler.Should().BeOfType(); + } + + public class BindingContextResolvedCommandHandler : ICommandHandler + { + public BindingContextResolvedCommandHandler(IConsole console) + { + Console = console; + } + + public IConsole Console { get; } + + public Task InvokeAsync(InvocationContext context) + { + Console.Out.Write(GetType().FullName); + return Task.FromResult(0); + } + } + + public interface IBindingContextCommandHandlerInterface : ICommandHandler + { + } + + public class BindingContextCommandHandler1 : IBindingContextCommandHandlerInterface + { + private readonly BindingContextCommandHandlerAction invokeAction; + + public BindingContextCommandHandler1(IConsole console, + BindingContextCommandHandlerAction invokeAction) + { + Console = console; + this.invokeAction = invokeAction; + } + + public IConsole Console { get; } + public Task InvokeAsync(InvocationContext context) + { + invokeAction(this, Console); + return Task.FromResult(0); + } + } + + public class BindingContextCommandHandler2 : IBindingContextCommandHandlerInterface + { + private readonly BindingContextCommandHandlerAction invokeAction; + + public BindingContextCommandHandler2(IConsole console, + BindingContextCommandHandlerAction invokeAction) + { + Console = console; + this.invokeAction = invokeAction; + } + + public IConsole Console { get; } + + public Task InvokeAsync(InvocationContext context) + { + invokeAction(this, Console); + return Task.FromResult(0); + } + } + + public delegate void BindingContextCommandHandlerAction(ICommandHandler handler, IConsole console); } } diff --git a/src/System.CommandLine/Binding/BindingContext.cs b/src/System.CommandLine/Binding/BindingContext.cs index a80a417287..1bd2b31685 100644 --- a/src/System.CommandLine/Binding/BindingContext.cs +++ b/src/System.CommandLine/Binding/BindingContext.cs @@ -51,7 +51,7 @@ public IConsole Console internal ServiceProvider ServiceProvider { get; } - public void AddModelBinder(ModelBinder binder) => + public void AddModelBinder(ModelBinder binder) => _modelBindersByValueDescriptor.Add(binder.ValueDescriptor.ValueType, binder); public ModelBinder GetModelBinder(IValueDescriptor valueDescriptor) @@ -65,19 +65,38 @@ public ModelBinder GetModelBinder(IValueDescriptor valueDescriptor) public void AddService(Type serviceType, Func factory) { + _ = serviceType ?? throw new ArgumentNullException(nameof(serviceType)); + _ = factory ?? throw new ArgumentNullException(nameof(factory)); ServiceProvider.AddService(serviceType, factory); } - + public void AddService(Func factory) { - if (factory is null) + _ = factory ?? throw new ArgumentNullException(nameof(factory)); + ServiceProvider.AddService(typeof(T), s => factory(s)); + } + + public void AddService(Type serviceType, Type? implementationType = null) + { + _ = serviceType ?? throw new ArgumentNullException(nameof(serviceType)); + implementationType ??= serviceType; + object factory(IServiceProvider serviceProvider) { - throw new ArgumentNullException(nameof(factory)); + var bindingContext = + serviceProvider.GetService(typeof(BindingContext)) as BindingContext + ?? this; + var valueDescriptor = new ModelBinder.AnonymousValueDescriptor(implementationType); + var modelBinder = bindingContext.GetModelBinder(valueDescriptor); + return modelBinder.CreateInstance(bindingContext)!; } - - ServiceProvider.AddService(typeof(T), s => factory(s)); + AddService(serviceType, factory); } + public void AddService() => + AddService(typeof(TService), typeof(TImplementation)); + + public void AddService() => AddService(); + internal bool TryGetValueSource( IValueDescriptor valueDescriptor, [MaybeNullWhen(false)] out IValueSource valueSource) @@ -108,8 +127,8 @@ internal bool TryBindToScalarValue( else { var parsed = ArgumentConverter.ConvertObject( - valueDescriptor as IArgument ?? new Argument(valueDescriptor.ValueName), - valueDescriptor.ValueType, + valueDescriptor as IArgument ?? new Argument(valueDescriptor.ValueName), + valueDescriptor.ValueType, value, resources); diff --git a/src/System.CommandLine/Builder/CommandLineBuilderExtensions.cs b/src/System.CommandLine/Builder/CommandLineBuilderExtensions.cs index ee966c9611..c4b2393d7c 100644 --- a/src/System.CommandLine/Builder/CommandLineBuilderExtensions.cs +++ b/src/System.CommandLine/Builder/CommandLineBuilderExtensions.cs @@ -143,6 +143,18 @@ public static CommandLineBuilder ConfigureConsole( return builder; } + public static CommandLineBuilder ConfigureBindingContext( + this CommandLineBuilder builder, + Action configureBindingContext) + { + builder.AddMiddleware((context, next) => + { + configureBindingContext?.Invoke(context.BindingContext); + return next(context); + }, default(MiddlewareOrder)); + return builder; + } + public static CommandLineBuilder EnableDirectives( this CommandLineBuilder builder, bool value = true) diff --git a/src/System.CommandLine/Invocation/CommandHandler.cs b/src/System.CommandLine/Invocation/CommandHandler.cs index 6b6464852e..ab1e16b910 100644 --- a/src/System.CommandLine/Invocation/CommandHandler.cs +++ b/src/System.CommandLine/Invocation/CommandHandler.cs @@ -282,6 +282,10 @@ public static ICommandHandler Create> action) => HandlerDescriptor.FromDelegate(action).GetCommandHandler(); + public static ICommandHandler FromBindingContext() + where THandler : ICommandHandler => + Create((InvocationContext context, THandler handler) => handler.InvokeAsync(context)); + internal static async Task GetExitCodeAsync(object value, InvocationContext context) { switch (value)