diff --git a/Npgmq.Example/Program.cs b/Npgmq.Example/Program.cs index 72ce1dc..d26d62d 100644 --- a/Npgmq.Example/Program.cs +++ b/Npgmq.Example/Program.cs @@ -1,29 +1,61 @@ using System.Reflection; using Microsoft.Extensions.Configuration; using Npgmq; +using Npgsql; var configuration = new ConfigurationBuilder() .AddEnvironmentVariables() .AddUserSecrets(Assembly.GetExecutingAssembly()) .Build(); -var npgmq = new NpgmqClient(configuration.GetConnectionString("ExampleDB")!); +var connectionString = configuration.GetConnectionString("ExampleDB")!; -await npgmq.InitAsync(); -await npgmq.CreateQueueAsync("example_queue"); - -var msgId = await npgmq.SendAsync("example_queue", new MyMessageType +// Test Npgmq with connection string { - Foo = "Test", - Bar = 123 -}); -Console.WriteLine($"Sent message with id {msgId}"); + var npgmq = new NpgmqClient(connectionString); + + await npgmq.InitAsync(); + await npgmq.CreateQueueAsync("example_queue"); + + var msgId = await npgmq.SendAsync("example_queue", new MyMessageType + { + Foo = "Connection string test", + Bar = 1 + }); + Console.WriteLine($"Sent message with id {msgId}"); + + var msg = await npgmq.ReadAsync("example_queue"); + if (msg != null) + { + Console.WriteLine($"Read message with id {msg.MsgId}: Foo = {msg.Message?.Foo}, Bar = {msg.Message?.Bar}"); + await npgmq.ArchiveAsync("example_queue", msg.MsgId); + } +} -var msg = await npgmq.ReadAsync("example_queue"); -if (msg != null) +// Test Npgmq with connection object and a transaction { - Console.WriteLine($"Read message with id {msg.MsgId}: Foo = {msg.Message?.Foo}, Bar = {msg.Message?.Bar}"); - await npgmq.ArchiveAsync("example_queue", msg.MsgId); + await using var connection = new NpgsqlConnection(connectionString); + await connection.OpenAsync(); + var npgmq = new NpgmqClient(connection); + + await using (var tx = connection.BeginTransaction()) + { + var msgId = await npgmq.SendAsync("example_queue", new MyMessageType + { + Foo = "Connection object test", + Bar = 2 + }); + Console.WriteLine($"Sent message with id {msgId}"); + + await tx.CommitAsync(); + } + + var msg = await npgmq.ReadAsync("example_queue"); + if (msg != null) + { + Console.WriteLine($"Read message with id {msg.MsgId}: Foo = {msg.Message?.Foo}, Bar = {msg.Message?.Bar}"); + await npgmq.ArchiveAsync("example_queue", msg.MsgId); + } } internal class MyMessageType diff --git a/Npgmq.Test/NpgmqClientTest.cs b/Npgmq.Test/NpgmqClientTest.cs index 22ab577..66b9e77 100644 --- a/Npgmq.Test/NpgmqClientTest.cs +++ b/Npgmq.Test/NpgmqClientTest.cs @@ -29,9 +29,8 @@ public NpgmqClientTest() .Build(); _connectionString = configuration.GetConnectionString("Test")!; - _connection = new NpgsqlConnection(_connectionString); - _sut = new NpgmqClient(_connectionString); + _sut = new NpgmqClient(_connection); } public void Dispose() @@ -545,6 +544,47 @@ public async Task ReadBatchAsync_should_return_list_of_messages() }); } + [Fact] + public async Task Client_s() + { + // Arrange + await ResetTestQueueAsync(); + + // Act + var msgId = await _sut.SendAsync(TestQueueName, new TestMessage + { + Foo = 123, + Bar = "Test", + Baz = DateTimeOffset.Parse("2023-09-01T01:23:45-04:00") + }); + + // Assert + Assert.Equal(1, await _connection.ExecuteScalarAsync($"SELECT count(*) FROM pgmq.q_{TestQueueName};")); + Assert.Equal(1, await _connection.ExecuteScalarAsync($"SELECT count(*) FROM pgmq.q_{TestQueueName} WHERE vt <= CURRENT_TIMESTAMP;")); + Assert.Equal(msgId, await _connection.ExecuteScalarAsync($"SELECT msg_id FROM pgmq.q_{TestQueueName} LIMIT 1;")); + } + + [Fact] + public async Task ConnectionString_should_be_used_to_connect() + { + // Arrange + await ResetTestQueueAsync(); + var sut2 = new NpgmqClient(_connectionString); + + // Act + var msgId = await sut2.SendAsync(TestQueueName, new TestMessage + { + Foo = 123, + Bar = "Test", + Baz = DateTimeOffset.Parse("2023-09-01T01:23:45-04:00") + }); + + // Assert + Assert.Equal(1, await _connection.ExecuteScalarAsync($"SELECT count(*) FROM pgmq.q_{TestQueueName};")); + Assert.Equal(1, await _connection.ExecuteScalarAsync($"SELECT count(*) FROM pgmq.q_{TestQueueName} WHERE vt <= CURRENT_TIMESTAMP;")); + Assert.Equal(msgId, await _connection.ExecuteScalarAsync($"SELECT msg_id FROM pgmq.q_{TestQueueName} LIMIT 1;")); + } + [Fact] public async Task SendAsync_should_add_message() { @@ -565,6 +605,69 @@ public async Task SendAsync_should_add_message() Assert.Equal(msgId, await _connection.ExecuteScalarAsync($"SELECT msg_id FROM pgmq.q_{TestQueueName} LIMIT 1;")); } + [Fact] + public async Task SendAsync_should_commit_with_database_transaction() + { + // Arrange + await ResetTestQueueAsync(); + await using var connection2 = new NpgsqlConnection(_connectionString); + await connection2.OpenAsync(); + + // Act + await using var transaction = await _connection.BeginTransactionAsync(); + var msgId = await _sut.SendAsync(TestQueueName, new TestMessage + { + Foo = 123, + Bar = "Test", + Baz = DateTimeOffset.Parse("2023-09-01T01:23:45-04:00") + }); + + // Assert + Assert.Equal(1, await _connection.ExecuteScalarAsync($"SELECT count(*) FROM pgmq.q_{TestQueueName};")); + Assert.Equal(0, await connection2.ExecuteScalarAsync($"SELECT count(*) FROM pgmq.q_{TestQueueName};")); + Assert.Equal(0, await _connection.ExecuteScalarAsync($"SELECT count(*) FROM pgmq.q_{TestQueueName} WHERE vt <= CURRENT_TIMESTAMP;")); + Assert.Equal(msgId, await _connection.ExecuteScalarAsync($"SELECT msg_id FROM pgmq.q_{TestQueueName} LIMIT 1;")); + + // Act + await transaction.CommitAsync(); + + // Assert + Assert.Equal(1, await connection2.ExecuteScalarAsync($"SELECT count(*) FROM pgmq.q_{TestQueueName};")); + Assert.Equal(1, await connection2.ExecuteScalarAsync($"SELECT count(*) FROM pgmq.q_{TestQueueName} WHERE vt <= CURRENT_TIMESTAMP;")); + Assert.Equal(msgId, await connection2.ExecuteScalarAsync($"SELECT msg_id FROM pgmq.q_{TestQueueName} LIMIT 1;")); + } + + [Fact] + public async Task SendAsync_should_rollback_with_database_transaction() + { + // Arrange + await ResetTestQueueAsync(); + await using var connection2 = new NpgsqlConnection(_connectionString); + await connection2.OpenAsync(); + + // Act + await using var transaction = await _connection.BeginTransactionAsync(); + var msgId = await _sut.SendAsync(TestQueueName, new TestMessage + { + Foo = 123, + Bar = "Test", + Baz = DateTimeOffset.Parse("2023-09-01T01:23:45-04:00") + }); + + // Assert + Assert.Equal(1, await _connection.ExecuteScalarAsync($"SELECT count(*) FROM pgmq.q_{TestQueueName};")); + Assert.Equal(0, await connection2.ExecuteScalarAsync($"SELECT count(*) FROM pgmq.q_{TestQueueName};")); + Assert.Equal(0, await _connection.ExecuteScalarAsync($"SELECT count(*) FROM pgmq.q_{TestQueueName} WHERE vt <= CURRENT_TIMESTAMP;")); + Assert.Equal(msgId, await _connection.ExecuteScalarAsync($"SELECT msg_id FROM pgmq.q_{TestQueueName} LIMIT 1;")); + + // Act + await transaction.RollbackAsync(); + + // Assert + Assert.Equal(0, await _connection.ExecuteScalarAsync($"SELECT count(*) FROM pgmq.q_{TestQueueName};")); + Assert.Equal(0, await connection2.ExecuteScalarAsync($"SELECT count(*) FROM pgmq.q_{TestQueueName};")); + } + [Fact] public async Task SendAsync_should_add_string_message() { diff --git a/Npgmq.sln b/Npgmq.sln index 21e8272..5f0d93a 100644 --- a/Npgmq.sln +++ b/Npgmq.sln @@ -8,6 +8,7 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution ProjectSection(SolutionItems) = preProject LICENSE = LICENSE README.md = README.md + global.json = global.json EndProjectSection EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "workflows", "workflows", "{8C37002D-05C6-4B1F-B4FC-C2F45C5E5328}" diff --git a/Npgmq/NpgmqClient.cs b/Npgmq/NpgmqClient.cs index 76da5e6..2ecbd46 100644 --- a/Npgmq/NpgmqClient.cs +++ b/Npgmq/NpgmqClient.cs @@ -1,4 +1,5 @@ -using System.Data.Common; +using System.Data; +using System.Data.Common; using System.Text.Json; using Npgsql; using NpgsqlTypes; @@ -13,7 +14,7 @@ public class NpgmqClient : INpgmqClient public const int DefaultPollTimeoutSeconds = 5; public const int DefaultPollIntervalMilliseconds = 250; - private readonly string _connectionString; + private readonly NpgmqCommandFactory _commandFactory; /// /// Create a new PGMQ client. @@ -21,179 +22,107 @@ public class NpgmqClient : INpgmqClient /// The connection string. public NpgmqClient(string connectionString) { - _connectionString = connectionString; + _commandFactory = new NpgmqCommandFactory(connectionString); } + /// + /// Create a new PGMQ client using a pre-existing connection. + /// + /// The connection to use. + public NpgmqClient(NpgsqlConnection connection) + { + _commandFactory = new NpgmqCommandFactory(connection); + } + public async Task ArchiveAsync(string queueName, long msgId) { - var connection = new NpgsqlConnection(_connectionString); - await using (connection.ConfigureAwait(false)) - { - await connection.OpenAsync(); - var cmd = new NpgsqlCommand("SELECT pgmq.archive(@queue_name, @msg_id);", connection); - await using (cmd.ConfigureAwait(false)) - { - cmd.Parameters.AddWithValue("@queue_name", queueName); - cmd.Parameters.AddWithValue("@msg_id", msgId); - var result = await cmd.ExecuteScalarAsync().ConfigureAwait(false); - return (bool)result!; - } - } + await using var cmd = await _commandFactory.CreateAsync("SELECT pgmq.archive(@queue_name, @msg_id);").ConfigureAwait(false); + cmd.Parameters.AddWithValue("@queue_name", queueName); + cmd.Parameters.AddWithValue("@msg_id", msgId); + var result = await cmd.ExecuteScalarAsync().ConfigureAwait(false); + return (bool)result!; } public async Task> ArchiveBatchAsync(string queueName, IEnumerable msgIds) { - var connection = new NpgsqlConnection(_connectionString); - await using (connection.ConfigureAwait(false)) + await using var cmd = await _commandFactory.CreateAsync("SELECT pgmq.archive(@queue_name, @msg_ids);").ConfigureAwait(false); + cmd.Parameters.AddWithValue("@queue_name", queueName); + cmd.Parameters.AddWithValue("@msg_ids", msgIds); + await using var reader = await cmd.ExecuteReaderAsync().ConfigureAwait(false); + var result = new List(); + while (await reader.ReadAsync().ConfigureAwait(false)) { - await connection.OpenAsync(); - var cmd = new NpgsqlCommand("SELECT pgmq.archive(@queue_name, @msg_ids);", connection); - await using (cmd.ConfigureAwait(false)) - { - cmd.Parameters.AddWithValue("@queue_name", queueName); - cmd.Parameters.AddWithValue("@msg_ids", msgIds); - var reader = await cmd.ExecuteReaderAsync().ConfigureAwait(false); - await using (reader.ConfigureAwait(false)) - { - var result = new List(); - while (await reader.ReadAsync().ConfigureAwait(false)) - { - result.Add(reader.GetInt64(0)); - } - return result; - } - } + result.Add(reader.GetInt64(0)); } + return result; } public async Task CreateQueueAsync(string queueName) { - var connection = new NpgsqlConnection(_connectionString); - await using (connection.ConfigureAwait(false)) - { - await connection.OpenAsync(); - var cmd = new NpgsqlCommand("SELECT pgmq.create(@queue_name);", connection); - await using (cmd.ConfigureAwait(false)) - { - cmd.Parameters.AddWithValue("queue_name", queueName); - await cmd.ExecuteNonQueryAsync().ConfigureAwait(false); - } - } + await using var cmd = await _commandFactory.CreateAsync("SELECT pgmq.create(@queue_name);").ConfigureAwait(false); + cmd.Parameters.AddWithValue("queue_name", queueName); + await cmd.ExecuteNonQueryAsync().ConfigureAwait(false); } public async Task CreateUnloggedQueueAsync(string queueName) { - var connection = new NpgsqlConnection(_connectionString); - await using (connection.ConfigureAwait(false)) - { - await connection.OpenAsync(); - var cmd = new NpgsqlCommand("SELECT pgmq.create_unlogged(@queue_name);", connection); - await using (cmd.ConfigureAwait(false)) - { - cmd.Parameters.AddWithValue("queue_name", queueName); - await cmd.ExecuteNonQueryAsync().ConfigureAwait(false); - } - } + await using var cmd = await _commandFactory.CreateAsync("SELECT pgmq.create_unlogged(@queue_name);").ConfigureAwait(false); + cmd.Parameters.AddWithValue("queue_name", queueName); + await cmd.ExecuteNonQueryAsync().ConfigureAwait(false); } public async Task DeleteAsync(string queueName, long msgId) { - var connection = new NpgsqlConnection(_connectionString); - await using (connection.ConfigureAwait(false)) - { - await connection.OpenAsync(); - var cmd = new NpgsqlCommand("SELECT pgmq.delete(@queue_name, @msg_id);", connection); - await using (cmd.ConfigureAwait(false)) - { - cmd.Parameters.AddWithValue("@queue_name", queueName); - cmd.Parameters.AddWithValue("@msg_id", msgId); - var result = await cmd.ExecuteScalarAsync().ConfigureAwait(false); - return (bool)result!; - } - } + await using var cmd = await _commandFactory.CreateAsync("SELECT pgmq.delete(@queue_name, @msg_id);").ConfigureAwait(false); + cmd.Parameters.AddWithValue("@queue_name", queueName); + cmd.Parameters.AddWithValue("@msg_id", msgId); + var result = await cmd.ExecuteScalarAsync().ConfigureAwait(false); + return (bool)result!; } public async Task> DeleteBatchAsync(string queueName, IEnumerable msgIds) { - var connection = new NpgsqlConnection(_connectionString); - await using (connection.ConfigureAwait(false)) + await using var cmd = await _commandFactory.CreateAsync("SELECT pgmq.delete(@queue_name, @msg_ids);").ConfigureAwait(false); + cmd.Parameters.AddWithValue("@queue_name", queueName); + cmd.Parameters.AddWithValue("@msg_ids", msgIds.ToArray()); + await using var reader = await cmd.ExecuteReaderAsync().ConfigureAwait(false); + var result = new List(); + while (await reader.ReadAsync().ConfigureAwait(false)) { - await connection.OpenAsync(); - var cmd = new NpgsqlCommand("SELECT pgmq.delete(@queue_name, @msg_ids);", connection); - await using (cmd.ConfigureAwait(false)) - { - cmd.Parameters.AddWithValue("@queue_name", queueName); - cmd.Parameters.AddWithValue("@msg_ids", msgIds.ToArray()); - var reader = await cmd.ExecuteReaderAsync().ConfigureAwait(false); - await using (reader.ConfigureAwait(false)) - { - var result = new List(); - while (await reader.ReadAsync().ConfigureAwait(false)) - { - result.Add(reader.GetInt64(0)); - } - return result; - } - } + result.Add(reader.GetInt64(0)); } + return result; } public async Task DropQueueAsync(string queueName) { - var connection = new NpgsqlConnection(_connectionString); - await using (connection.ConfigureAwait(false)) - { - await connection.OpenAsync(); - var cmd = new NpgsqlCommand("SELECT pgmq.drop_queue(@queue_name);", connection); - await using (cmd.ConfigureAwait(false)) - { - cmd.Parameters.AddWithValue("queue_name", queueName); - await cmd.ExecuteNonQueryAsync().ConfigureAwait(false); - } - } + await using var cmd = await _commandFactory.CreateAsync("SELECT pgmq.drop_queue(@queue_name);").ConfigureAwait(false); + cmd.Parameters.AddWithValue("queue_name", queueName); + await cmd.ExecuteNonQueryAsync().ConfigureAwait(false); } public async Task InitAsync() { - var connection = new NpgsqlConnection(_connectionString); - await using (connection.ConfigureAwait(false)) - { - await connection.OpenAsync(); - var cmd = new NpgsqlCommand("CREATE EXTENSION IF NOT EXISTS pgmq CASCADE;", connection); - await using (cmd.ConfigureAwait(false)) - { - await cmd.ExecuteNonQueryAsync().ConfigureAwait(false); - } - } + await using var cmd = await _commandFactory.CreateAsync("CREATE EXTENSION IF NOT EXISTS pgmq CASCADE;").ConfigureAwait(false); + await cmd.ExecuteNonQueryAsync().ConfigureAwait(false); } public async Task> ListQueuesAsync() { - var connection = new NpgsqlConnection(_connectionString); - await using (connection.ConfigureAwait(false)) + await using var cmd = await _commandFactory.CreateAsync("SELECT queue_name, created_at, is_partitioned, is_unlogged FROM pgmq.list_queues();").ConfigureAwait(false); + await using var reader = await cmd.ExecuteReaderAsync().ConfigureAwait(false); + var result = new List(); + while (await reader.ReadAsync().ConfigureAwait(false)) { - await connection.OpenAsync(); - var cmd = new NpgsqlCommand("SELECT queue_name, created_at, is_partitioned, is_unlogged FROM pgmq.list_queues();", connection); - await using (cmd.ConfigureAwait(false)) + result.Add(new NpgmqQueue { - var reader = await cmd.ExecuteReaderAsync().ConfigureAwait(false); - await using (reader.ConfigureAwait(false)) - { - var result = new List(); - while (await reader.ReadAsync().ConfigureAwait(false)) - { - result.Add(new NpgmqQueue - { - QueueName = reader.GetString(0), - CreatedAt = reader.GetDateTime(1), - IsPartitioned = reader.GetBoolean(2), - IsUnlogged = reader.GetBoolean(3) - }); - } - return result; - } - } + QueueName = reader.GetString(0), + CreatedAt = reader.GetDateTime(1), + IsPartitioned = reader.GetBoolean(2), + IsUnlogged = reader.GetBoolean(3) + }); } + return result; } public async Task?> PollAsync(string queueName, int vt = DefaultVt, int pollTimeoutSeconds = DefaultPollTimeoutSeconds, int pollIntervalMilliseconds = DefaultPollIntervalMilliseconds) where T : class @@ -204,77 +133,39 @@ public async Task> ListQueuesAsync() public async Task>> PollBatchAsync(string queueName, int vt = DefaultVt, int limit = DefaultReadBatchLimit, int pollTimeoutSeconds = DefaultPollTimeoutSeconds, int pollIntervalMilliseconds = DefaultPollIntervalMilliseconds) where T : class { - var connection = new NpgsqlConnection(_connectionString); - await using (connection.ConfigureAwait(false)) - { - await connection.OpenAsync(); - var cmd = new NpgsqlCommand("SELECT msg_id, read_ct, enqueued_at, vt, message FROM pgmq.read_with_poll(@queue_name, @vt, @limit, @poll_timeout_s, @poll_interval_ms);", connection); - await using (cmd.ConfigureAwait(false)) - { - cmd.Parameters.AddWithValue("queue_name", queueName); - cmd.Parameters.AddWithValue("vt", vt); - cmd.Parameters.AddWithValue("limit", limit); - cmd.Parameters.AddWithValue("poll_timeout_s", pollTimeoutSeconds); - cmd.Parameters.AddWithValue("poll_interval_ms", pollIntervalMilliseconds); - var reader = await cmd.ExecuteReaderAsync().ConfigureAwait(false); - await using (reader.ConfigureAwait(false)) - { - return await ReadMessagesAsync(reader).ConfigureAwait(false); - } - } - } + await using var cmd = await _commandFactory.CreateAsync("SELECT msg_id, read_ct, enqueued_at, vt, message FROM pgmq.read_with_poll(@queue_name, @vt, @limit, @poll_timeout_s, @poll_interval_ms);").ConfigureAwait(false); + cmd.Parameters.AddWithValue("queue_name", queueName); + cmd.Parameters.AddWithValue("vt", vt); + cmd.Parameters.AddWithValue("limit", limit); + cmd.Parameters.AddWithValue("poll_timeout_s", pollTimeoutSeconds); + cmd.Parameters.AddWithValue("poll_interval_ms", pollIntervalMilliseconds); + await using var reader = await cmd.ExecuteReaderAsync().ConfigureAwait(false); + return await ReadMessagesAsync(reader).ConfigureAwait(false); } public async Task?> PopAsync(string queueName) where T : class { - var connection = new NpgsqlConnection(_connectionString); - await using (connection.ConfigureAwait(false)) - { - await connection.OpenAsync(); - var cmd = new NpgsqlCommand("SELECT msg_id, read_ct, enqueued_at, vt, message FROM pgmq.pop(@queue_name);", connection); - await using (cmd.ConfigureAwait(false)) - { - cmd.Parameters.AddWithValue("queue_name", queueName); - var reader = await cmd.ExecuteReaderAsync().ConfigureAwait(false); - await using (reader.ConfigureAwait(false)) - { - var result = await ReadMessagesAsync(reader).ConfigureAwait(false); - return result.SingleOrDefault(); - } - } - } + await using var cmd = await _commandFactory.CreateAsync("SELECT msg_id, read_ct, enqueued_at, vt, message FROM pgmq.pop(@queue_name);").ConfigureAwait(false); + cmd.Parameters.AddWithValue("queue_name", queueName); + await using var reader = await cmd.ExecuteReaderAsync().ConfigureAwait(false); + var result = await ReadMessagesAsync(reader).ConfigureAwait(false); + return result.SingleOrDefault(); } public async Task PurgeQueueAsync(string queueName) { - var connection = new NpgsqlConnection(_connectionString); - await using (connection.ConfigureAwait(false)) - { - await connection.OpenAsync(); - var cmd = new NpgsqlCommand("SELECT pgmq.purge_queue(@queue_name);", connection); - await using (cmd.ConfigureAwait(false)) - { - cmd.Parameters.AddWithValue("queue_name", queueName); - var result = await cmd.ExecuteScalarAsync().ConfigureAwait(false); - return (long)result!; - } - } + await using var cmd = await _commandFactory.CreateAsync("SELECT pgmq.purge_queue(@queue_name);").ConfigureAwait(false); + cmd.Parameters.AddWithValue("queue_name", queueName); + var result = await cmd.ExecuteScalarAsync().ConfigureAwait(false); + return (long)result!; } public async Task QueueExistsAsync(string queueName) { - var connection = new NpgsqlConnection(_connectionString); - await using (connection.ConfigureAwait(false)) - { - await connection.OpenAsync(); - var cmd = new NpgsqlCommand("SELECT 1 WHERE EXISTS (SELECT * FROM pgmq.list_queues() WHERE queue_name = @queue_name);", connection); - await using (cmd.ConfigureAwait(false)) - { - cmd.Parameters.AddWithValue("queue_name", queueName); - var result = await cmd.ExecuteScalarAsync().ConfigureAwait(false); - return (int)(result ?? 0) == 1; - } - } + await using var cmd = await _commandFactory.CreateAsync("SELECT 1 WHERE EXISTS (SELECT * FROM pgmq.list_queues() WHERE queue_name = @queue_name);").ConfigureAwait(false); + cmd.Parameters.AddWithValue("queue_name", queueName); + var result = await cmd.ExecuteScalarAsync().ConfigureAwait(false); + return (int)(result ?? 0) == 1; } public async Task?> ReadAsync(string queueName, int vt = DefaultVt) where T : class @@ -285,23 +176,12 @@ public async Task QueueExistsAsync(string queueName) public async Task>> ReadBatchAsync(string queueName, int vt = DefaultVt, int limit = DefaultReadBatchLimit) where T : class { - var connection = new NpgsqlConnection(_connectionString); - await using (connection.ConfigureAwait(false)) - { - await connection.OpenAsync(); - var cmd = new NpgsqlCommand("SELECT msg_id, read_ct, enqueued_at, vt, message FROM pgmq.read(@queue_name, @vt, @limit);", connection); - await using (cmd.ConfigureAwait(false)) - { - cmd.Parameters.AddWithValue("queue_name", queueName); - cmd.Parameters.AddWithValue("vt", vt); - cmd.Parameters.AddWithValue("limit", limit); - var reader = await cmd.ExecuteReaderAsync().ConfigureAwait(false); - await using (reader.ConfigureAwait(false)) - { - return await ReadMessagesAsync(reader).ConfigureAwait(false); - } - } - } + await using var cmd = await _commandFactory.CreateAsync("SELECT msg_id, read_ct, enqueued_at, vt, message FROM pgmq.read(@queue_name, @vt, @limit);").ConfigureAwait(false); + cmd.Parameters.AddWithValue("queue_name", queueName); + cmd.Parameters.AddWithValue("vt", vt); + cmd.Parameters.AddWithValue("limit", limit); + await using var reader = await cmd.ExecuteReaderAsync().ConfigureAwait(false); + return await ReadMessagesAsync(reader).ConfigureAwait(false); } public async Task SendAsync(string queueName, T message) where T : class @@ -311,63 +191,36 @@ public async Task SendAsync(string queueName, T message) where T : clas public async Task SendDelayAsync(string queueName, T message, int delay) where T : class { - var connection = new NpgsqlConnection(_connectionString); - await using (connection.ConfigureAwait(false)) - { - await connection.OpenAsync(); - var cmd = new NpgsqlCommand("SELECT * FROM pgmq.send(@queue_name, @message, @delay);", connection); - await using (cmd.ConfigureAwait(false)) - { - cmd.Parameters.AddWithValue("queue_name", queueName); - cmd.Parameters.AddWithValue("message", NpgsqlDbType.Jsonb, SerializeMessage(message)); - cmd.Parameters.AddWithValue("delay", delay); - var result = await cmd.ExecuteScalarAsync().ConfigureAwait(false); - return (long)result!; - } - } + await using var cmd = await _commandFactory.CreateAsync("SELECT * FROM pgmq.send(@queue_name, @message, @delay);").ConfigureAwait(false); + cmd.Parameters.AddWithValue("queue_name", queueName); + cmd.Parameters.AddWithValue("message", NpgsqlDbType.Jsonb, SerializeMessage(message)); + cmd.Parameters.AddWithValue("delay", delay); + var result = await cmd.ExecuteScalarAsync().ConfigureAwait(false); + return (long)result!; } public async Task> SendBatchAsync(string queueName, IEnumerable messages) where T : class { - var connection = new NpgsqlConnection(_connectionString); - await using (connection.ConfigureAwait(false)) + await using var cmd = await _commandFactory.CreateAsync("SELECT * FROM pgmq.send_batch(@queue_name, @messages);").ConfigureAwait(false); + cmd.Parameters.AddWithValue("queue_name", queueName); + cmd.Parameters.AddWithValue("messages", NpgsqlDbType.Array | NpgsqlDbType.Jsonb, + messages.Select(SerializeMessage).ToArray()); + await using var reader = await cmd.ExecuteReaderAsync().ConfigureAwait(false); + var result = new List(); + while (await reader.ReadAsync().ConfigureAwait(false)) { - await connection.OpenAsync(); - var cmd = new NpgsqlCommand("SELECT * FROM pgmq.send_batch(@queue_name, @messages);", connection); - await using (cmd.ConfigureAwait(false)) - { - cmd.Parameters.AddWithValue("queue_name", queueName); - cmd.Parameters.AddWithValue("messages", NpgsqlDbType.Array | NpgsqlDbType.Jsonb, - messages.Select(SerializeMessage).ToArray()); - var reader = await cmd.ExecuteReaderAsync().ConfigureAwait(false); - await using (reader.ConfigureAwait(false)) - { - var result = new List(); - while (await reader.ReadAsync().ConfigureAwait(false)) - { - result.Add(reader.GetInt64(0)); - } - return result; - } - } + result.Add(reader.GetInt64(0)); } + return result; } public async Task SetVtAsync(string queueName, long msgId, int vtOffset) { - var connection = new NpgsqlConnection(_connectionString); - await using (connection.ConfigureAwait(false)) - { - await connection.OpenAsync(); - var cmd = new NpgsqlCommand("SELECT pgmq.set_vt(@queue_name, @msg_id, @vt_offset);", connection); - await using (cmd.ConfigureAwait(false)) - { - cmd.Parameters.AddWithValue("queue_name", queueName); - cmd.Parameters.AddWithValue("msg_id", msgId); - cmd.Parameters.AddWithValue("vt_offset", vtOffset); - await cmd.ExecuteNonQueryAsync().ConfigureAwait(false); - } - } + await using var cmd = await _commandFactory.CreateAsync("SELECT pgmq.set_vt(@queue_name, @msg_id, @vt_offset);").ConfigureAwait(false); + cmd.Parameters.AddWithValue("queue_name", queueName); + cmd.Parameters.AddWithValue("msg_id", msgId); + cmd.Parameters.AddWithValue("vt_offset", vtOffset); + await cmd.ExecuteNonQueryAsync().ConfigureAwait(false); } private static async Task>> ReadMessagesAsync(DbDataReader reader) where T : class diff --git a/Npgmq/NpgmqCommand.cs b/Npgmq/NpgmqCommand.cs new file mode 100644 index 0000000..2779d7c --- /dev/null +++ b/Npgmq/NpgmqCommand.cs @@ -0,0 +1,20 @@ +using System.Data; +using Npgsql; + +namespace Npgmq; + +internal class NpgmqCommand(string commandText, NpgsqlConnection connection, bool disposeConnection) + : NpgsqlCommand(commandText, connection) +{ + public override async ValueTask DisposeAsync() + { + if (disposeConnection && Connection != null) + { + if (Connection.State == ConnectionState.Open) + { + await Connection.CloseAsync().ConfigureAwait(false); + } + await Connection.DisposeAsync().ConfigureAwait(false); + } + } +} \ No newline at end of file diff --git a/Npgmq/NpgmqCommandFactory.cs b/Npgmq/NpgmqCommandFactory.cs new file mode 100644 index 0000000..4c9f0b2 --- /dev/null +++ b/Npgmq/NpgmqCommandFactory.cs @@ -0,0 +1,31 @@ +using System.Data; +using Npgsql; + +namespace Npgmq; + +internal class NpgmqCommandFactory +{ + private readonly string? _connectionString; + private readonly NpgsqlConnection? _connection; + + public NpgmqCommandFactory(NpgsqlConnection connection) + { + _connection = connection; + } + + public NpgmqCommandFactory(string connectionString) + { + _connectionString = connectionString; + } + + public async Task CreateAsync(string commandText) + { + var connection = _connection ?? new NpgsqlConnection(_connectionString ?? throw new InvalidOperationException("No connection or connection string provided.")); + if (connection.State != ConnectionState.Open) + { + await connection.OpenAsync().ConfigureAwait(false); + } + + return new NpgmqCommand(commandText, connection, _connection == null); + } +} diff --git a/README.md b/README.md index 8dd95d2..f530e55 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ public class MyMessageType } ``` -You can also send and read messages as JSON strings: +You can send and read messages as JSON strings, like this: ```csharp var msgId = await npgmq.SendAsync("my_queue", "{\"foo\":\"Test\",\"bar\":123}"); @@ -65,6 +65,13 @@ if (msg != null) } ``` +You can pass your own `NpgsqlConnection` to the `NpgmqClient` constructor, like this: + +```csharp +using var myConnection = new NpgsqlConnection(""); +var npgmq = new NpgmqClient(myConnection); +``` + ## Database Connection Npgmq uses Npgsql internally to connect to the database. diff --git a/global.json b/global.json new file mode 100644 index 0000000..b5b37b6 --- /dev/null +++ b/global.json @@ -0,0 +1,7 @@ +{ + "sdk": { + "version": "8.0.0", + "rollForward": "latestMajor", + "allowPrerelease": false + } +} \ No newline at end of file