diff --git a/src/NATS.Extensions.Microsoft.DependencyInjection/NatsBuilder.cs b/src/NATS.Extensions.Microsoft.DependencyInjection/NatsBuilder.cs index cc9bc7877..a90289697 100644 --- a/src/NATS.Extensions.Microsoft.DependencyInjection/NatsBuilder.cs +++ b/src/NATS.Extensions.Microsoft.DependencyInjection/NatsBuilder.cs @@ -9,9 +9,10 @@ namespace NATS.Extensions.Microsoft.DependencyInjection; public class NatsBuilder { private readonly IServiceCollection _services; + private int _poolSize = 1; - private Func? _configureOpts; - private Action? _configureConnection; + private Func? _configureOpts; + private Action? _configureConnection; private object? _diKey = null; public NatsBuilder(IServiceCollection services) @@ -20,36 +21,51 @@ public NatsBuilder(IServiceCollection services) public NatsBuilder WithPoolSize(int size) { _poolSize = Math.Max(size, 1); + return this; } - public NatsBuilder ConfigureOptions(Func optsFactory) + public NatsBuilder ConfigureOptions(Func optsFactory) => + ConfigureOptions((_, opts) => optsFactory(opts)); + + public NatsBuilder ConfigureOptions(Func optsFactory) { - var previousFactory = _configureOpts; - _configureOpts = opts => + var configure = _configureOpts; + _configureOpts = (serviceProvider, opts) => { - // Apply the previous configurator if it exists. - if (previousFactory != null) - { - opts = previousFactory(opts); - } + opts = configure?.Invoke(serviceProvider, opts) ?? opts; - // Then apply the new configurator. - return optsFactory(opts); + return optsFactory(serviceProvider, opts); }; + return this; } - public NatsBuilder ConfigureConnection(Action connectionOpts) + public NatsBuilder ConfigureConnection(Action configureConnection) => + ConfigureConnection((_, con) => configureConnection(con)); + + public NatsBuilder ConfigureConnection(Action configureConnection) { - _configureConnection = connectionOpts; + var configure = _configureConnection; + _configureConnection = (serviceProvider, connection) => + { + configure?.Invoke(serviceProvider, connection); + + configureConnection(serviceProvider, connection); + }; + return this; } - public NatsBuilder AddJsonSerialization(JsonSerializerContext context) - => ConfigureOptions(opts => + public NatsBuilder AddJsonSerialization(JsonSerializerContext context) => + AddJsonSerialization(_ => context); + + public NatsBuilder AddJsonSerialization(Func contextFactory) + => ConfigureOptions((serviceProvider, opts) => { - var jsonRegistry = new NatsJsonContextSerializerRegistry(context); + var context = contextFactory(serviceProvider); + NatsJsonContextSerializerRegistry jsonRegistry = new(context); + return opts with { SerializerRegistry = jsonRegistry }; }); @@ -57,6 +73,7 @@ public NatsBuilder AddJsonSerialization(JsonSerializerContext context) public NatsBuilder WithKey(object key) { _diKey = key; + return this; } #endif @@ -117,18 +134,18 @@ private static NatsConnection PooledConnectionFactory(IServiceProvider provider, private NatsConnectionPool PoolFactory(IServiceProvider provider, object? diKey = null) { var options = NatsOpts.Default with { LoggerFactory = provider.GetRequiredService() }; - options = _configureOpts?.Invoke(options) ?? options; + options = _configureOpts?.Invoke(provider, options) ?? options; - return new NatsConnectionPool(_poolSize, options, _configureConnection ?? (_ => { })); + return new NatsConnectionPool(_poolSize, options, con => _configureConnection?.Invoke(provider, con)); } private NatsConnection SingleConnectionFactory(IServiceProvider provider, object? diKey = null) { var options = NatsOpts.Default with { LoggerFactory = provider.GetRequiredService() }; - options = _configureOpts?.Invoke(options) ?? options; + options = _configureOpts?.Invoke(provider, options) ?? options; var conn = new NatsConnection(options); - _configureConnection?.Invoke(conn); + _configureConnection?.Invoke(provider, conn); return conn; } diff --git a/tests/NATS.Extensions.Microsoft.DependencyInjection.Tests/MyResolvedService.cs b/tests/NATS.Extensions.Microsoft.DependencyInjection.Tests/MyResolvedService.cs new file mode 100644 index 000000000..ca0ab9373 --- /dev/null +++ b/tests/NATS.Extensions.Microsoft.DependencyInjection.Tests/MyResolvedService.cs @@ -0,0 +1,11 @@ +namespace NATS.Extensions.Microsoft.DependencyInjection.Tests; + +internal interface IMyResolvedService +{ + string GetValue(); +} + +internal class MyResolvedService(string value) : IMyResolvedService +{ + public string GetValue() => value; +} diff --git a/tests/NATS.Extensions.Microsoft.DependencyInjection.Tests/NatsHostingExtensionsTests.cs b/tests/NATS.Extensions.Microsoft.DependencyInjection.Tests/NatsHostingExtensionsTests.cs index 9e1ec9026..a2a62a5cc 100644 --- a/tests/NATS.Extensions.Microsoft.DependencyInjection.Tests/NatsHostingExtensionsTests.cs +++ b/tests/NATS.Extensions.Microsoft.DependencyInjection.Tests/NatsHostingExtensionsTests.cs @@ -80,6 +80,57 @@ public Task AddNatsClient_ConfigureOptionsSetsUrl() return Task.CompletedTask; } + [Fact] + public Task AddNatsClient_ConfigureOptionsSetsUrlResolvesServices() + { + var services = new ServiceCollection(); + services.AddSingleton(); + services.AddSingleton(new MyResolvedService("url-set")); + services.AddNatsClient(nats => nats + .ConfigureOptions((_, opts) => opts) // Add multiple to test chaining + .ConfigureOptions((serviceProvider, opts) => + { + opts = opts with + { + Url = serviceProvider.GetRequiredService().GetValue(), + }; + + return opts; + })); + + var provider = services.BuildServiceProvider(); + var nats = provider.GetRequiredService(); + + Assert.Equal("url-set", nats.Opts.Url); + + return Task.CompletedTask; + } + + [Fact] + public async Task AddNatsClient_ConfigureConnectionResolvesServices() + { + var services = new ServiceCollection(); + services.AddSingleton(); + services.AddSingleton(new MyResolvedService("url-set")); + services.AddNatsClient(nats => nats + .ConfigureConnection((_, _) => { }) // Add multiple to test chaining + .ConfigureConnection((serviceProvider, conn) => + { + conn.OnConnectingAsync = async instance => + { + var resolved = serviceProvider.GetRequiredService().GetValue(); + + return (resolved, instance.Port); + }; + })); + + var provider = services.BuildServiceProvider(); + var nats = provider.GetRequiredService(); + + (var host, var _) = await nats.OnConnectingAsync!((Host: "host", Port: 123)); + Assert.Equal("url-set", host); + } + #if NET8_0_OR_GREATER [Fact] public void AddNats_RegistersKeyedNatsConnection_WhenKeyIsProvided()