Skip to content

Commit

Permalink
Refactor NatsBuilder service registration logic
Browse files Browse the repository at this point in the history
  • Loading branch information
mtmk committed Dec 5, 2024
1 parent 752a9e8 commit 5175afe
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 12 deletions.
38 changes: 28 additions & 10 deletions src/NATS.Extensions.Microsoft.DependencyInjection/NatsBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.DependencyInjection.Extensions;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using NATS.Client.Core;
using NATS.Net;

Expand Down Expand Up @@ -82,12 +83,33 @@ public NatsBuilder WithKey(object key)
}
#endif

/// <summary>
/// Override the default <see cref="BoundedChannelFullMode"/> for the pending messages channel.
/// </summary>
/// <param name="pending">Full mode for the subscription channel.</param>
/// <returns>Builder to allow method chaining.</returns>
/// <remarks>
/// This will be applied to options overriding values set for <c>SubPendingChannelFullMode</c> in options.
/// By default, the pending messages channel will wait for space to be available when full.
/// Note that this is not the same as <c>NatsOpts</c> default <c>SubPendingChannelFullMode</c> which is <c>DropNewest</c>.
/// </remarks>
public NatsBuilder WithSubPendingChannelFullMode(BoundedChannelFullMode pending)
{
_pending = pending;
return this;
}

/// <summary>
/// Override the default <see cref="INatsSerializerRegistry"/> for the options.
/// </summary>
/// <param name="registry">Serializer registry to use.</param>
/// <returns>Builder to allow method chaining.</returns>
/// <remarks>
/// This will be applied to options overriding values set for <c>SerializerRegistry</c> in options.
/// By default, NatsClient registry will be used which allows ad-hoc JSON serialization.
/// Note that this is not the same as <c>NatsOpts</c> default <c>SerializerRegistry</c> which
/// doesn't do ad-hoc JSON serialization.
/// </remarks>
public NatsBuilder WithSerializerRegistry(INatsSerializerRegistry registry)
{
_serializerRegistry = registry;
Expand All @@ -104,18 +126,16 @@ internal IServiceCollection Build()
_services.TryAddSingleton<INatsConnectionPool>(static provider => provider.GetRequiredService<NatsConnectionPool>());
_services.TryAddTransient<NatsConnection>(static provider => PooledConnectionFactory(provider, null));
_services.TryAddTransient<INatsConnection>(static provider => provider.GetRequiredService<NatsConnection>());
_services.TryAddTransient<NatsClient>(static provider => new NatsClient(provider.GetRequiredService<NatsConnection>()));
_services.TryAddTransient<INatsClient>(static provider => provider.GetRequiredService<NatsClient>());
_services.TryAddTransient<INatsClient>(static provider => provider.GetRequiredService<NatsConnection>());
}
else
{
#if NET8_0_OR_GREATER
_services.TryAddKeyedSingleton<NatsConnectionPool>(_diKey, PoolFactory);
_services.TryAddKeyedSingleton<INatsConnectionPool>(_diKey, static (provider, key) => provider.GetRequiredKeyedService<NatsConnectionPool>(key));
_services.TryAddKeyedTransient(_diKey, PooledConnectionFactory);
_services.TryAddKeyedTransient<NatsConnection>(_diKey, PooledConnectionFactory);
_services.TryAddKeyedTransient<INatsConnection>(_diKey, static (provider, key) => provider.GetRequiredKeyedService<NatsConnection>(key));
_services.TryAddKeyedTransient<NatsClient>(_diKey, static (provider, key) => new NatsClient(provider.GetRequiredKeyedService<NatsConnection>(key)));
_services.TryAddKeyedTransient<INatsClient>(_diKey, static (provider, key) => provider.GetRequiredKeyedService<NatsClient>(key));
_services.TryAddKeyedTransient<INatsClient>(_diKey, static (provider, key) => provider.GetRequiredKeyedService<NatsConnection>(key));
#endif
}
}
Expand All @@ -125,16 +145,14 @@ internal IServiceCollection Build()
{
_services.TryAddSingleton<NatsConnection>(provider => SingleConnectionFactory(provider));
_services.TryAddSingleton<INatsConnection>(static provider => provider.GetRequiredService<NatsConnection>());
_services.TryAddSingleton<NatsClient>(static provider => new NatsClient(provider.GetRequiredService<NatsConnection>()));
_services.TryAddSingleton<INatsClient>(static provider => provider.GetRequiredService<NatsClient>());
_services.TryAddSingleton<INatsClient>(static provider => provider.GetRequiredService<NatsConnection>());
}
else
{
#if NET8_0_OR_GREATER
_services.TryAddKeyedSingleton(_diKey, SingleConnectionFactory);
_services.TryAddKeyedSingleton<INatsConnection>(_diKey, static (provider, key) => provider.GetRequiredKeyedService<NatsConnection>(key));
_services.TryAddKeyedSingleton<NatsClient>(_diKey, static (provider, key) => new NatsClient(provider.GetRequiredKeyedService<NatsConnection>(key)));
_services.TryAddKeyedSingleton<INatsClient>(_diKey, static (provider, key) => provider.GetRequiredKeyedService<NatsClient>(key));
_services.TryAddKeyedSingleton<INatsClient>(_diKey, static (provider, key) => provider.GetRequiredKeyedService<NatsConnection>(key));
#endif
}
}
Expand Down Expand Up @@ -174,7 +192,7 @@ private NatsConnection SingleConnectionFactory(IServiceProvider provider, object

private NatsOpts GetNatsOpts(IServiceProvider provider)
{
var options = NatsOpts.Default with { LoggerFactory = provider.GetRequiredService<ILoggerFactory>() };
var options = NatsOpts.Default with { LoggerFactory = provider.GetService<ILoggerFactory>() ?? NullLoggerFactory.Instance };
options = _configureOpts?.Invoke(provider, options) ?? options;

if (_serializerRegistry != null)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System.Threading.Channels;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
Expand Down Expand Up @@ -25,11 +26,10 @@ public void AddNatsClient_RegistersNatsConnectionAsSingleton_WhenPoolSizeIsOne()

var natsClient1 = provider.GetRequiredService<INatsClient>();
var natsClient2 = provider.GetRequiredService<INatsClient>();
var natsClient3 = provider.GetRequiredService<NatsClient>();

Assert.NotNull(natsClient1);
Assert.Same(natsClient1, natsClient2);
Assert.Same(natsClient1, natsClient3);
Assert.Same(natsClient1, natsConnection1); // Same Connection implements INatsClient
Assert.Same(natsClient1.Connection, natsConnection1);
}

Expand All @@ -43,9 +43,138 @@ public void AddNatsClient_RegistersNatsConnectionAsTransient_WhenPoolSizeIsGreat
var provider = services.BuildServiceProvider();
var natsConnection1 = provider.GetRequiredService<INatsConnection>();
var natsConnection2 = provider.GetRequiredService<INatsConnection>();
var natsConnection3 = provider.GetRequiredService<INatsConnection>();
var natsConnection4 = provider.GetRequiredService<INatsConnection>();

Assert.NotNull(natsConnection1);
Assert.NotSame(natsConnection1, natsConnection2); // Transient should return different instances
Assert.NotSame(natsConnection3, natsConnection4);
Assert.Same(natsConnection1, natsConnection3); // The pool is round-robin
Assert.Same(natsConnection2, natsConnection4);

var natsClient1 = provider.GetRequiredService<INatsClient>();
var natsClient2 = provider.GetRequiredService<INatsClient>();
var natsClient3 = provider.GetRequiredService<INatsClient>();
var natsClient4 = provider.GetRequiredService<INatsClient>();

Assert.NotNull(natsClient1);
Assert.NotSame(natsClient1, natsClient2);
Assert.NotSame(natsClient3, natsClient4);
Assert.Same(natsClient1, natsClient3);
Assert.Same(natsClient2, natsClient4);
Assert.Same(natsClient1, natsConnection1);
Assert.Same(natsClient1.Connection, natsConnection1);
}

[Fact]
public Task AddNatsClient_OptionsWithDefaults()
{
var services = new ServiceCollection();
services.AddNatsClient();

var provider = services.BuildServiceProvider();
var nats = provider.GetRequiredService<INatsConnection>();

Assert.Same(NullLoggerFactory.Instance, nats.Opts.LoggerFactory);

// These defaults are different from NatsOptions defaults but same as NatsClient defaults
// for ease of use for new users
Assert.Same(NatsClientDefaultSerializerRegistry.Default, nats.Opts.SerializerRegistry);
Assert.Equal(BoundedChannelFullMode.Wait, nats.Opts.SubPendingChannelFullMode);

return Task.CompletedTask;
}

[Fact]
public Task AddNatsClient_WithDefaultSerializerExplicitlySet()
{
var services = new ServiceCollection();
services.AddNatsClient(nats =>
{
// These two settings make the options same as NatsOptions defaults
nats.WithSerializerRegistry(NatsDefaultSerializerRegistry.Default)
.WithSubPendingChannelFullMode(BoundedChannelFullMode.DropNewest);
});

var provider = services.BuildServiceProvider();
var nats = provider.GetRequiredService<INatsConnection>();

Assert.Same(NatsDefaultSerializerRegistry.Default, nats.Opts.SerializerRegistry);
Assert.Equal(BoundedChannelFullMode.DropNewest, nats.Opts.SubPendingChannelFullMode);

return Task.CompletedTask;
}

[Fact]
public Task AddNatsClient_WithSerializerExplicitlySet()
{
var mySerializerRegistry = new NatsJsonContextSerializerRegistry(MyJsonContext.Default);

var services = new ServiceCollection();
services.AddNatsClient(nats =>
{
nats.ConfigureOptions(opts => opts with { SerializerRegistry = mySerializerRegistry });
});

var provider = services.BuildServiceProvider();
var nats = provider.GetRequiredService<INatsConnection>();

Assert.Same(mySerializerRegistry, nats.Opts.SerializerRegistry);

// You can only override this using .WithSubPendingChannelFullMode() on builder above
Assert.Equal(BoundedChannelFullMode.Wait, nats.Opts.SubPendingChannelFullMode);

return Task.CompletedTask;
}

[Fact]
public async Task AddNatsClient_WithDefaultSerializer()
{
await using var server = NatsServer.Start();
var cts = new CancellationTokenSource(TimeSpan.FromSeconds(10));
var cancellationToken = cts.Token;

// Default JSON serialization
{
var services = new ServiceCollection();
services.AddSingleton<ILoggerFactory, NullLoggerFactory>();
services.AddNatsClient(nats =>
{
nats.ConfigureOptions(opts => server.ClientOpts(opts));
});

var provider = services.BuildServiceProvider();
var nats = provider.GetRequiredService<INatsConnection>();

// Ad-hoc JSON serialization
await using var sub = await nats.SubscribeCoreAsync<MyAdHocData>("foo", cancellationToken: cancellationToken);
await nats.PingAsync(cancellationToken);
await nats.PublishAsync("foo", new MyAdHocData(1, "bar"), cancellationToken: cancellationToken);

var msg = await sub.Msgs.ReadAsync(cancellationToken);
Assert.Equal(1, msg.Data?.Id);
Assert.Equal("bar", msg.Data?.Name);
}

// Default raw serialization
{
var services = new ServiceCollection();
services.AddSingleton<ILoggerFactory, NullLoggerFactory>();
services.AddNatsClient(nats =>
{
nats.ConfigureOptions(opts => server.ClientOpts(opts));
nats.WithSerializerRegistry(NatsDefaultSerializerRegistry.Default);
});

var provider = services.BuildServiceProvider();
var nats = provider.GetRequiredService<INatsConnection>();

var exception = await Assert.ThrowsAsync<NatsException>(async () =>
{
await nats.PublishAsync("foo", new MyAdHocData(1, "bar"), cancellationToken: cancellationToken);
});
Assert.Matches("Can't serialize.*MyAdHocData", exception.Message);
}
}

[Fact]
Expand Down Expand Up @@ -218,3 +347,5 @@ public void AddNats_RegistersKeyedNatsConnection_WhenKeyIsProvided_pooled()
}
#endif
}

public record MyAdHocData(int Id, string Name);

0 comments on commit 5175afe

Please sign in to comment.