diff --git a/src/main/java/com/mongodb/kafka/connect/sink/MongoSinkTask.java b/src/main/java/com/mongodb/kafka/connect/sink/MongoSinkTask.java index c29d5be4..8fd58092 100644 --- a/src/main/java/com/mongodb/kafka/connect/sink/MongoSinkTask.java +++ b/src/main/java/com/mongodb/kafka/connect/sink/MongoSinkTask.java @@ -60,17 +60,24 @@ public String version() { * * @param props initial configuration */ + @SuppressWarnings("try") @Override public void start(final Map props) { LOGGER.info("Starting MongoDB sink task"); - MongoSinkConfig sinkConfig; + MongoClient client = null; try { - sinkConfig = new MongoSinkConfig(props); - } catch (Exception e) { - throw new ConnectException("Failed to start new task", e); + MongoSinkConfig sinkConfig = new MongoSinkConfig(props); + client = createMongoClient(sinkConfig); + startedTask = new StartedMongoSinkTask(sinkConfig, client, createErrorReporter()); + } catch (RuntimeException taskStartingException) { + //noinspection EmptyTryBlock + try (MongoClient autoCloseableClient = client) { + // just using try-with-resources to ensure they all get closed, even in the case of exceptions + } catch (RuntimeException resourceReleasingException) { + taskStartingException.addSuppressed(resourceReleasingException); + } + throw new ConnectException("Failed to start MongoDB sink task", taskStartingException); } - startedTask = - new StartedMongoSinkTask(sinkConfig, createMongoClient(sinkConfig), createErrorReporter()); LOGGER.debug("Started MongoDB sink task"); } @@ -115,7 +122,7 @@ public void flush(final Map currentOffsets) { public void stop() { LOGGER.info("Stopping MongoDB sink task"); if (startedTask != null) { - startedTask.stop(); + startedTask.close(); } } diff --git a/src/main/java/com/mongodb/kafka/connect/sink/StartedMongoSinkTask.java b/src/main/java/com/mongodb/kafka/connect/sink/StartedMongoSinkTask.java index 5b1a3573..decf5109 100644 --- a/src/main/java/com/mongodb/kafka/connect/sink/StartedMongoSinkTask.java +++ b/src/main/java/com/mongodb/kafka/connect/sink/StartedMongoSinkTask.java @@ -50,7 +50,7 @@ import com.mongodb.kafka.connect.util.time.InnerOuterTimer.InnerTimer; import com.mongodb.kafka.connect.util.time.Timer; -final class StartedMongoSinkTask { +final class StartedMongoSinkTask implements AutoCloseable { private final MongoSinkConfig sinkConfig; private final MongoClient mongoClient; private final ErrorReporter errorReporter; @@ -93,7 +93,8 @@ private String getMBeanName() { /** @see MongoSinkTask#stop() */ @SuppressWarnings("try") - void stop() { + @Override + public void close() { try (MongoClient autoCloseable = mongoClient) { statistics.unregister(); } diff --git a/src/main/java/com/mongodb/kafka/connect/source/MongoSourceTask.java b/src/main/java/com/mongodb/kafka/connect/source/MongoSourceTask.java index 8940c228..229c0a04 100644 --- a/src/main/java/com/mongodb/kafka/connect/source/MongoSourceTask.java +++ b/src/main/java/com/mongodb/kafka/connect/source/MongoSourceTask.java @@ -103,30 +103,29 @@ public String version() { return Versions.VERSION; } + @SuppressWarnings("try") @Override public void start(final Map props) { LOGGER.info("Starting MongoDB source task"); - MongoSourceConfig sourceConfig; - try { - sourceConfig = new MongoSourceConfig(props); - } catch (Exception e) { - throw new ConnectException("Failed to start new task", e); - } - - boolean shouldCopyData = shouldCopyData(context, sourceConfig); - String connectorName = JmxStatisticsManager.getConnectorName(props); - StatisticsManager statisticsManager = new JmxStatisticsManager(shouldCopyData, connectorName); + StatisticsManager statisticsManager = null; + MongoClient mongoClient = null; + MongoCopyDataManager copyDataManager = null; try { + MongoSourceConfig sourceConfig = new MongoSourceConfig(props); + boolean shouldCopyData = shouldCopyData(context, sourceConfig); + String connectorName = JmxStatisticsManager.getConnectorName(props); + statisticsManager = new JmxStatisticsManager(shouldCopyData, connectorName); + StatisticsManager statsManager = statisticsManager; CommandListener statisticsCommandListener = new CommandListener() { @Override public void commandSucceeded(final CommandSucceededEvent event) { - mongoCommandSucceeded(event, statisticsManager.currentStatistics()); + mongoCommandSucceeded(event, statsManager.currentStatistics()); } @Override public void commandFailed(final CommandFailedEvent event) { - mongoCommandFailed(event, statisticsManager.currentStatistics()); + mongoCommandFailed(event, statsManager.currentStatistics()); } }; @@ -137,10 +136,11 @@ public void commandFailed(final CommandFailedEvent event) { .applyToSslSettings(sslBuilder -> setupSsl(sslBuilder, sourceConfig)); setServerApi(builder, sourceConfig); - MongoClient mongoClient = + mongoClient = MongoClients.create( builder.build(), getMongoDriverInformation(CONNECTOR_TYPE, sourceConfig.getString(PROVIDER_CONFIG))); + copyDataManager = shouldCopyData ? new MongoCopyDataManager(sourceConfig, mongoClient) : null; startedTask = new StartedMongoSourceTask( @@ -148,15 +148,19 @@ public void commandFailed(final CommandFailedEvent event) { // in case it changes, because there is no // documentation stating that it cannot be changed. () -> context, - sourceConfig, - mongoClient, - shouldCopyData ? new MongoCopyDataManager(sourceConfig, mongoClient) : null, - statisticsManager); - LOGGER.info("Started MongoDB source task"); - } catch (RuntimeException e) { - statisticsManager.close(); - throw e; + sourceConfig, mongoClient, copyDataManager, statisticsManager); + } catch (RuntimeException taskStartingException) { + //noinspection EmptyTryBlock + try (StatisticsManager autoCloseableStatisticsManager = statisticsManager; + MongoClient autoCloseableMongoClient = mongoClient; + MongoCopyDataManager autoCloseableCopyDataManager = copyDataManager) { + // just using try-with-resources to ensure they all get closed, even in the case of exceptions + } catch (RuntimeException resourceReleasingException) { + taskStartingException.addSuppressed(resourceReleasingException); + } + throw new ConnectException("Failed to start MongoDB source task", taskStartingException); } + LOGGER.info("Started MongoDB source task"); } @VisibleForTesting(otherwise = VisibleForTesting.AccessModifier.PRIVATE) diff --git a/src/test/java/com/mongodb/kafka/connect/sink/StartedMongoSinkTaskTest.java b/src/test/java/com/mongodb/kafka/connect/sink/StartedMongoSinkTaskTest.java index 9f15e4ba..fc798a53 100644 --- a/src/test/java/com/mongodb/kafka/connect/sink/StartedMongoSinkTaskTest.java +++ b/src/test/java/com/mongodb/kafka/connect/sink/StartedMongoSinkTaskTest.java @@ -57,6 +57,7 @@ import org.apache.kafka.connect.data.Schema; import org.apache.kafka.connect.errors.DataException; import org.apache.kafka.connect.sink.SinkRecord; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -95,6 +96,7 @@ final class StartedMongoSinkTaskTest { private Map properties; private BulkWritesCapturingClient client; private InMemoryErrorReporter errorReporter; + private StartedMongoSinkTask task; @BeforeEach void setUp() { @@ -106,12 +108,18 @@ void setUp() { errorReporter = new InMemoryErrorReporter(); } + @AfterEach + void tearDown() { + if (task != null) { + task.close(); + } + } + @Test void put() { MongoSinkConfig config = new MongoSinkConfig(properties); client.configureCapturing(DEFAULT_NAMESPACE); - StartedMongoSinkTask task = - new StartedMongoSinkTask(config, client.mongoClient(), errorReporter); + task = new StartedMongoSinkTask(config, client.mongoClient(), errorReporter); RecordsAndExpectations recordsAndExpectations = new RecordsAndExpectations( asList( @@ -123,7 +131,6 @@ void put() { task.put(recordsAndExpectations.records()); recordsAndExpectations.assertExpectations( client.capturedBulkWrites().get(DEFAULT_NAMESPACE), errorReporter.reported()); - task.stop(); } @Test @@ -131,8 +138,7 @@ void putTolerateAllPostProcessingError() { properties.put(MongoSinkTopicConfig.ERRORS_TOLERANCE_CONFIG, ErrorTolerance.ALL.value()); MongoSinkConfig config = new MongoSinkConfig(properties); client.configureCapturing(DEFAULT_NAMESPACE); - StartedMongoSinkTask task = - new StartedMongoSinkTask(config, client.mongoClient(), errorReporter); + task = new StartedMongoSinkTask(config, client.mongoClient(), errorReporter); RecordsAndExpectations recordsAndExpectations = new RecordsAndExpectations( asList( @@ -145,7 +151,6 @@ void putTolerateAllPostProcessingError() { task.put(recordsAndExpectations.records()); recordsAndExpectations.assertExpectations( client.capturedBulkWrites().get(DEFAULT_NAMESPACE), errorReporter.reported()); - task.stop(); } /** @@ -162,8 +167,7 @@ void putTolerateAllAnyError() { collection -> when(collection.bulkWrite(anyList(), any(BulkWriteOptions.class))) .thenThrow(new MongoCommandException(new BsonDocument(), new ServerAddress()))); - StartedMongoSinkTask task = - new StartedMongoSinkTask(config, client.mongoClient(), errorReporter); + task = new StartedMongoSinkTask(config, client.mongoClient(), errorReporter); RecordsAndExpectations recordsAndExpectations = new RecordsAndExpectations( asList( @@ -182,15 +186,13 @@ void putTolerateAllAnyError() { task.put(recordsAndExpectations.records()); recordsAndExpectations.assertExpectations( client.capturedBulkWrites().get(DEFAULT_NAMESPACE), errorReporter.reported()); - task.stop(); } @Test void putTolerateNonePostProcessingError() { MongoSinkConfig config = new MongoSinkConfig(properties); client.configureCapturing(DEFAULT_NAMESPACE); - StartedMongoSinkTask task = - new StartedMongoSinkTask(config, client.mongoClient(), errorReporter); + task = new StartedMongoSinkTask(config, client.mongoClient(), errorReporter); RecordsAndExpectations recordsAndExpectations = new RecordsAndExpectations( asList(Records.simpleValid(TEST_TOPIC, 0), Records.simpleInvalid(TEST_TOPIC, 1)), @@ -199,7 +201,6 @@ void putTolerateNonePostProcessingError() { assertThrows(RuntimeException.class, () -> task.put(recordsAndExpectations.records())); recordsAndExpectations.assertExpectations( client.capturedBulkWrites().get(DEFAULT_NAMESPACE), errorReporter.reported()); - task.stop(); } @Test @@ -213,8 +214,7 @@ void putTolerateNoneWriteError() { .thenThrow(bulkWriteException(emptyList(), true)) // batch2 .thenReturn(BulkWriteResult.unacknowledged())); - StartedMongoSinkTask task = - new StartedMongoSinkTask(config, client.mongoClient(), errorReporter); + task = new StartedMongoSinkTask(config, client.mongoClient(), errorReporter); RecordsAndExpectations recordsAndExpectations = new RecordsAndExpectations( asList( @@ -227,7 +227,6 @@ void putTolerateNoneWriteError() { assertThrows(DataException.class, () -> task.put(recordsAndExpectations.records())); recordsAndExpectations.assertExpectations( client.capturedBulkWrites().get(DEFAULT_NAMESPACE), errorReporter.reported()); - task.stop(); } @Test @@ -246,8 +245,7 @@ void putTolerateAllOrderedWriteError() { .thenThrow(bulkWriteException(emptyList(), true)) // batch4 .thenThrow(bulkWriteException(singletonList(1), true))); - StartedMongoSinkTask task = - new StartedMongoSinkTask(config, client.mongoClient(), errorReporter); + task = new StartedMongoSinkTask(config, client.mongoClient(), errorReporter); RecordsAndExpectations recordsAndExpectations = new RecordsAndExpectations( asList( @@ -281,7 +279,6 @@ void putTolerateAllOrderedWriteError() { task.put(recordsAndExpectations.records()); recordsAndExpectations.assertExpectations( client.capturedBulkWrites().get(DEFAULT_NAMESPACE), errorReporter.reported()); - task.stop(); } @Test @@ -303,8 +300,7 @@ void putTolerateAllUnorderedWriteError() { .thenThrow(bulkWriteException(emptyList(), true)) // batch4 .thenThrow(bulkWriteException(singletonList(1), true))); - StartedMongoSinkTask task = - new StartedMongoSinkTask(config, client.mongoClient(), errorReporter); + task = new StartedMongoSinkTask(config, client.mongoClient(), errorReporter); RecordsAndExpectations recordsAndExpectations = new RecordsAndExpectations( asList( @@ -338,7 +334,6 @@ void putTolerateAllUnorderedWriteError() { task.put(recordsAndExpectations.records()); recordsAndExpectations.assertExpectations( client.capturedBulkWrites().get(DEFAULT_NAMESPACE), errorReporter.reported()); - task.stop(); } @SuppressWarnings("unchecked")