diff --git a/src/test/java/com/snowflake/kafka/connector/internal/streaming/StreamingClientProviderIT.java b/src/test/java/com/snowflake/kafka/connector/internal/streaming/StreamingClientProviderIT.java index 037a6c5f7..79fd44118 100644 --- a/src/test/java/com/snowflake/kafka/connector/internal/streaming/StreamingClientProviderIT.java +++ b/src/test/java/com/snowflake/kafka/connector/internal/streaming/StreamingClientProviderIT.java @@ -18,168 +18,123 @@ package com.snowflake.kafka.connector.internal.streaming; import com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig; -import com.snowflake.kafka.connector.Utils; import com.snowflake.kafka.connector.internal.TestUtils; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.HashMap; -import java.util.List; import java.util.Map; -import java.util.stream.Collectors; -import net.snowflake.ingest.internal.com.github.benmanes.caffeine.cache.LoadingCache; import net.snowflake.ingest.streaming.SnowflakeStreamingIngestClient; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.mockito.Mockito; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; -@RunWith(Parameterized.class) public class StreamingClientProviderIT { - private final boolean enableClientOptimization; - private final Map clientConfig = TestUtils.getConfForStreaming(); - @Parameterized.Parameters(name = "enableClientOptimization: {0}") - public static Collection input() { - return Arrays.asList(new Object[][] {{true}, {false}}); + @Test + public void getClient_forOptimizationEnabled_returnSameClient() { + // given + Map clientConfig = getClientConfig(true); + StreamingClientProvider streamingClientProvider = + StreamingClientProvider.getStreamingClientProviderInstance(); + + // when + SnowflakeStreamingIngestClient client = streamingClientProvider.getClient(clientConfig); + SnowflakeStreamingIngestClient client2 = streamingClientProvider.getClient(clientConfig); + + // then + Assertions.assertSame(client, client2); + + int clientsRegistered = streamingClientProvider.getRegisteredClients().size(); + Assertions.assertEquals(1, clientsRegistered); } - public StreamingClientProviderIT(boolean enableClientOptimization) { - this.enableClientOptimization = enableClientOptimization; - this.clientConfig.put( - SnowflakeSinkConnectorConfig.ENABLE_STREAMING_CLIENT_OPTIMIZATION_CONFIG, - String.valueOf(this.enableClientOptimization)); + @Test + public void getClient_forOptimizationDisabled_returnDifferentClients() { + // given + Map clientConfig = getClientConfig(false); + StreamingClientProvider streamingClientProvider = + StreamingClientProvider.getStreamingClientProviderInstance(); + + // when + SnowflakeStreamingIngestClient client = streamingClientProvider.getClient(clientConfig); + SnowflakeStreamingIngestClient client2 = streamingClientProvider.getClient(clientConfig); + + // then + Assertions.assertNotNull(client); + Assertions.assertNotNull(client2); + Assertions.assertNotEquals(client, client2); + } + + @Test + public void getClient_forInvalidClient_returnNewInstance() throws Exception { + // given + Map clientConfig = getClientConfig(true); + StreamingClientProvider streamingClientProvider = + StreamingClientProvider.getStreamingClientProviderInstance(); + + // when + SnowflakeStreamingIngestClient client = streamingClientProvider.getClient(clientConfig); + client.close(); + SnowflakeStreamingIngestClient client2 = streamingClientProvider.getClient(clientConfig); + + // then + Assertions.assertTrue(client.isClosed()); + Assertions.assertFalse(client2.isClosed()); } @Test - public void testGetMultipleClients() throws Exception { - String validRegisteredClientName = "openRegisteredClient"; - String invalidRegisteredClientName = "closedRegisteredClient"; - String validUnregisteredClientName = "openUnregisteredClient"; - StreamingClientHandler clientCreator = new StreamingClientHandler(); - - // setup registered valid client - Map validRegisteredClientConfig = new HashMap<>(this.clientConfig); - validRegisteredClientConfig.put(Utils.NAME, validRegisteredClientName); - validRegisteredClientConfig.put(Utils.SF_OAUTH_CLIENT_ID, "0"); - StreamingClientProperties validRegisteredClientProps = - new StreamingClientProperties(validRegisteredClientConfig); - SnowflakeStreamingIngestClient validRegisteredClient = - clientCreator.createClient(validRegisteredClientProps); - - // setup registered invalid client - Map invalidRegisteredClientConfig = new HashMap<>(this.clientConfig); - invalidRegisteredClientConfig.put(Utils.NAME, invalidRegisteredClientName); - invalidRegisteredClientConfig.put(Utils.SF_OAUTH_CLIENT_ID, "1"); - StreamingClientProperties invalidRegisteredClientProps = - new StreamingClientProperties(invalidRegisteredClientConfig); - SnowflakeStreamingIngestClient invalidRegisteredClient = - clientCreator.createClient(invalidRegisteredClientProps); - invalidRegisteredClient.close(); - - // setup unregistered valid client - Map validUnregisteredClientConfig = new HashMap<>(this.clientConfig); - validUnregisteredClientConfig.put(Utils.NAME, validUnregisteredClientName); - validUnregisteredClientConfig.put(Utils.SF_OAUTH_CLIENT_ID, "2"); - StreamingClientProperties validUnregisteredClientProps = - new StreamingClientProperties(validUnregisteredClientConfig); - SnowflakeStreamingIngestClient validUnregisteredClient = - clientCreator.createClient(validUnregisteredClientProps); - - // inject registered clients - StreamingClientHandler streamingClientHandlerSpy = - Mockito.spy(StreamingClientHandler.class); // use this to verify behavior - LoadingCache registeredClients = - StreamingClientProvider.buildLoadingCache(streamingClientHandlerSpy); - - registeredClients.put(validRegisteredClientProps, validRegisteredClient); - registeredClients.put(invalidRegisteredClientProps, invalidRegisteredClient); + public void getClient_forMissingOptimizationConfig_returnValidClient() { + // given + Map clientConfig = getClientConfig(true); + clientConfig.remove(SnowflakeSinkConnectorConfig.ENABLE_STREAMING_CLIENT_OPTIMIZATION_CONFIG); + StreamingClientProvider streamingClientProvider = + StreamingClientProvider.getStreamingClientProviderInstance(); + + // when + SnowflakeStreamingIngestClient client = streamingClientProvider.getClient(clientConfig); + + // then + Assertions.assertNotNull(client); + Assertions.assertFalse(client.isClosed()); + } + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void closeClient_forValidClient_stopTheClient(boolean clientOptimizationEnabled) { + // given + Map clientConfig = getClientConfig(clientOptimizationEnabled); StreamingClientProvider streamingClientProvider = - StreamingClientProvider.getStreamingClientProviderForTests( - streamingClientHandlerSpy, registeredClients); - - assert streamingClientProvider.getRegisteredClients().size() == 2; - - // test 1: get registered valid client optimization returns existing client - SnowflakeStreamingIngestClient resultValidRegisteredClient = - streamingClientProvider.getClient(validRegisteredClientConfig); - - assert StreamingClientHandler.isClientValid(resultValidRegisteredClient); - assert resultValidRegisteredClient.getName().contains("_0"); - assert this.enableClientOptimization - == resultValidRegisteredClient.equals(validRegisteredClient); - Mockito.verify(streamingClientHandlerSpy, Mockito.times(this.enableClientOptimization ? 0 : 1)) - .createClient(validRegisteredClientProps); - assert streamingClientProvider.getRegisteredClients().size() == 2; - - // test 2: get registered invalid client creates new client regardless of optimization - SnowflakeStreamingIngestClient resultInvalidRegisteredClient = - streamingClientProvider.getClient(invalidRegisteredClientConfig); - - assert StreamingClientHandler.isClientValid(resultInvalidRegisteredClient); - assert resultInvalidRegisteredClient - .getName() - .contains("_" + (this.enableClientOptimization ? 0 : 1)); - assert !resultInvalidRegisteredClient.equals(invalidRegisteredClient); - Mockito.verify(streamingClientHandlerSpy, Mockito.times(1)) - .createClient(invalidRegisteredClientProps); - assert streamingClientProvider.getRegisteredClients().size() == 2; - - // test 3: get unregistered valid client creates and registers new client with optimization - SnowflakeStreamingIngestClient resultValidUnregisteredClient = - streamingClientProvider.getClient(validUnregisteredClientConfig); - - assert StreamingClientHandler.isClientValid(resultValidUnregisteredClient); - assert resultValidUnregisteredClient - .getName() - .contains("_" + (this.enableClientOptimization ? 1 : 2)); - assert !resultValidUnregisteredClient.equals(validUnregisteredClient); - Mockito.verify(streamingClientHandlerSpy, Mockito.times(1)) - .createClient(validUnregisteredClientProps); - assert streamingClientProvider.getRegisteredClients().size() - == (this.enableClientOptimization ? 3 : 2); - - // verify streamingClientHandler behavior - Mockito.verify(streamingClientHandlerSpy, Mockito.times(this.enableClientOptimization ? 2 : 3)) - .createClient(Mockito.any()); - - // test 4: get all clients multiple times and verify optimization doesn't create new clients - List gotClientList = new ArrayList<>(); - - for (int i = 0; i < 5; i++) { - gotClientList.add(streamingClientProvider.getClient(validRegisteredClientConfig)); - gotClientList.add(streamingClientProvider.getClient(invalidRegisteredClientConfig)); - gotClientList.add(streamingClientProvider.getClient(validUnregisteredClientConfig)); - } - - List distinctClients = - gotClientList.stream().distinct().collect(Collectors.toList()); - assert distinctClients.size() == (this.enableClientOptimization ? 3 : gotClientList.size()); - Mockito.verify( - streamingClientHandlerSpy, - Mockito.times(this.enableClientOptimization ? 2 : 3 + gotClientList.size())) - .createClient(Mockito.any()); - assert streamingClientProvider.getRegisteredClients().size() - == (this.enableClientOptimization ? 3 : 2); - - // close all clients - validRegisteredClient.close(); - invalidRegisteredClient.close(); - validUnregisteredClient.close(); - - resultValidRegisteredClient.close(); - resultInvalidRegisteredClient.close(); - resultValidUnregisteredClient.close(); - - distinctClients.stream() - .forEach( - client -> { - try { - client.close(); - } catch (Exception e) { - // do nothing - } - }); + StreamingClientProvider.getStreamingClientProviderInstance(); + + // when + SnowflakeStreamingIngestClient client = streamingClientProvider.getClient(clientConfig); + streamingClientProvider.closeClient(clientConfig, client); + + // then + Assertions.assertTrue(client.isClosed()); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void closeClient_forAlreadyStoppedClient_doesNotThrowException( + boolean clientOptimizationEnabled) { + // given + Map clientConfig = getClientConfig(clientOptimizationEnabled); + StreamingClientProvider streamingClientProvider = + StreamingClientProvider.getStreamingClientProviderInstance(); + + // when + SnowflakeStreamingIngestClient client = streamingClientProvider.getClient(clientConfig); + streamingClientProvider.closeClient(clientConfig, client); + + // then + Assertions.assertDoesNotThrow(() -> streamingClientProvider.closeClient(clientConfig, client)); + Assertions.assertTrue(client.isClosed()); + } + + private Map getClientConfig(boolean clientOptimizationEnabled) { + Map clientConfig = TestUtils.getConfForStreaming(); + clientConfig.put( + SnowflakeSinkConnectorConfig.ENABLE_STREAMING_CLIENT_OPTIMIZATION_CONFIG, + String.valueOf(clientOptimizationEnabled)); + return clientConfig; } } diff --git a/src/test/java/com/snowflake/kafka/connector/internal/streaming/StreamingClientProviderTest.java b/src/test/java/com/snowflake/kafka/connector/internal/streaming/StreamingClientProviderTest.java deleted file mode 100644 index 02d25f001..000000000 --- a/src/test/java/com/snowflake/kafka/connector/internal/streaming/StreamingClientProviderTest.java +++ /dev/null @@ -1,276 +0,0 @@ -/* - * Copyright (c) 2023 Snowflake Inc. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package com.snowflake.kafka.connector.internal.streaming; - -import com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig; -import com.snowflake.kafka.connector.Utils; -import com.snowflake.kafka.connector.internal.TestUtils; -import java.util.Arrays; -import java.util.Collection; -import java.util.Map; -import net.snowflake.ingest.internal.com.github.benmanes.caffeine.cache.LoadingCache; -import net.snowflake.ingest.streaming.SnowflakeStreamingIngestClient; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.mockito.Mockito; - -@RunWith(Parameterized.class) -public class StreamingClientProviderTest { - private final boolean enableClientOptimization; - private final Map clientConfig = TestUtils.getConfForStreaming(); - - @Parameterized.Parameters(name = "enableClientOptimization: {0}") - public static Collection input() { - return Arrays.asList(new Object[][] {{true}, {false}}); - } - - public StreamingClientProviderTest(boolean enableClientOptimization) { - this.enableClientOptimization = enableClientOptimization; - this.clientConfig.put( - SnowflakeSinkConnectorConfig.ENABLE_STREAMING_CLIENT_OPTIMIZATION_CONFIG, - String.valueOf(this.enableClientOptimization)); - } - - @Test - public void testFirstGetClient() { - // setup mock client and handler - SnowflakeStreamingIngestClient clientMock = Mockito.mock(SnowflakeStreamingIngestClient.class); - Mockito.when(clientMock.getName()).thenReturn(this.clientConfig.get(Utils.NAME)); - - StreamingClientHandler mockClientHandler = Mockito.mock(StreamingClientHandler.class); - Mockito.when(mockClientHandler.createClient(new StreamingClientProperties(this.clientConfig))) - .thenReturn(clientMock); - - // test provider gets new client - StreamingClientProvider streamingClientProvider = - StreamingClientProvider.getStreamingClientProviderForTests( - mockClientHandler, StreamingClientProvider.buildLoadingCache(mockClientHandler)); - SnowflakeStreamingIngestClient client = streamingClientProvider.getClient(this.clientConfig); - - // verify - should create a client regardless of optimization - assert client.equals(clientMock); - assert client.getName().contains(this.clientConfig.get(Utils.NAME)); - Mockito.verify(mockClientHandler, Mockito.times(1)) - .createClient(new StreamingClientProperties(this.clientConfig)); - } - - @Test - public void testGetInvalidClient() { - // setup handler, invalid mock client and valid returned client - SnowflakeStreamingIngestClient mockInvalidClient = - Mockito.mock(SnowflakeStreamingIngestClient.class); - Mockito.when(mockInvalidClient.getName()).thenReturn(this.clientConfig.get(Utils.NAME)); - Mockito.when(mockInvalidClient.isClosed()).thenReturn(true); - - SnowflakeStreamingIngestClient mockValidClient = - Mockito.mock(SnowflakeStreamingIngestClient.class); - Mockito.when(mockValidClient.getName()).thenReturn(this.clientConfig.get(Utils.NAME)); - Mockito.when(mockValidClient.isClosed()).thenReturn(false); - - StreamingClientHandler mockClientHandler = Mockito.mock(StreamingClientHandler.class); - Mockito.when(mockClientHandler.createClient(new StreamingClientProperties(this.clientConfig))) - .thenReturn(mockValidClient); - - // inject invalid client into provider - LoadingCache mockRegisteredClients = - Mockito.mock(LoadingCache.class); - Mockito.when(mockRegisteredClients.get(new StreamingClientProperties(this.clientConfig))) - .thenReturn(mockInvalidClient); - - // test provider gets new client - StreamingClientProvider streamingClientProvider = - StreamingClientProvider.getStreamingClientProviderForTests( - mockClientHandler, mockRegisteredClients); - SnowflakeStreamingIngestClient client = streamingClientProvider.getClient(this.clientConfig); - - // verify - returned client is valid even though we injected an invalid client - assert client.equals(mockValidClient); - assert !client.equals(mockInvalidClient); - assert client.getName().contains(this.clientConfig.get(Utils.NAME)); - assert !client.isClosed(); - Mockito.verify(mockClientHandler, Mockito.times(1)) - .createClient(new StreamingClientProperties(this.clientConfig)); - } - - @Test - public void testGetExistingClient() { - // setup existing client, handler and inject to registeredClients - SnowflakeStreamingIngestClient mockExistingClient = - Mockito.mock(SnowflakeStreamingIngestClient.class); - Mockito.when(mockExistingClient.getName()).thenReturn(this.clientConfig.get(Utils.NAME)); - Mockito.when(mockExistingClient.isClosed()).thenReturn(false); - - StreamingClientHandler mockClientHandler = Mockito.mock(StreamingClientHandler.class); - - LoadingCache mockRegisteredClients = - Mockito.mock(LoadingCache.class); - Mockito.when(mockRegisteredClients.get(new StreamingClientProperties(this.clientConfig))) - .thenReturn(mockExistingClient); - - // if optimization is disabled, we will create new client regardless of registeredClientws - if (!this.enableClientOptimization) { - Mockito.when(mockClientHandler.createClient(new StreamingClientProperties(this.clientConfig))) - .thenReturn(mockExistingClient); - } - - // test getting client - StreamingClientProvider streamingClientProvider = - StreamingClientProvider.getStreamingClientProviderForTests( - mockClientHandler, mockRegisteredClients); - SnowflakeStreamingIngestClient client = streamingClientProvider.getClient(this.clientConfig); - - // verify client and expected client creation - assert client.equals(mockExistingClient); - assert client.getName().equals(this.clientConfig.get(Utils.NAME)); - Mockito.verify(mockClientHandler, Mockito.times(this.enableClientOptimization ? 0 : 1)) - .createClient(new StreamingClientProperties(this.clientConfig)); - } - - @Test - public void testGetClientMissingConfig() { - // remove one client opt from config - this.clientConfig.remove( - SnowflakeSinkConnectorConfig.ENABLE_STREAMING_CLIENT_OPTIMIZATION_CONFIG); - - // setup existing client, handler and inject to registeredClients - SnowflakeStreamingIngestClient mockExistingClient = - Mockito.mock(SnowflakeStreamingIngestClient.class); - Mockito.when(mockExistingClient.getName()).thenReturn(this.clientConfig.get(Utils.NAME)); - Mockito.when(mockExistingClient.isClosed()).thenReturn(false); - - StreamingClientHandler mockClientHandler = Mockito.mock(StreamingClientHandler.class); - - LoadingCache mockRegisteredClients = - Mockito.mock(LoadingCache.class); - Mockito.when(mockRegisteredClients.get(new StreamingClientProperties(this.clientConfig))) - .thenReturn(mockExistingClient); - - // test getting client - StreamingClientProvider streamingClientProvider = - StreamingClientProvider.getStreamingClientProviderForTests( - mockClientHandler, mockRegisteredClients); - SnowflakeStreamingIngestClient client = streamingClientProvider.getClient(this.clientConfig); - - // verify returned existing client since removing the optimization should default to true - assert client.equals(mockExistingClient); - assert client.getName().equals(this.clientConfig.get(Utils.NAME)); - Mockito.verify(mockClientHandler, Mockito.times(0)) - .createClient(new StreamingClientProperties(this.clientConfig)); - } - - @Test - public void testCloseClients() throws Exception { - // setup valid existing client and handler - SnowflakeStreamingIngestClient mockExistingClient = - Mockito.mock(SnowflakeStreamingIngestClient.class); - Mockito.when(mockExistingClient.getName()).thenReturn(this.clientConfig.get(Utils.NAME)); - Mockito.when(mockExistingClient.isClosed()).thenReturn(false); - - StreamingClientHandler mockClientHandler = Mockito.mock(StreamingClientHandler.class); - Mockito.doCallRealMethod() - .when(mockClientHandler) - .closeClient(Mockito.any(SnowflakeStreamingIngestClient.class)); - - // inject existing client in for optimization - LoadingCache mockRegisteredClients = - Mockito.mock(LoadingCache.class); - if (this.enableClientOptimization) { - Mockito.when( - mockRegisteredClients.getIfPresent(new StreamingClientProperties(this.clientConfig))) - .thenReturn(mockExistingClient); - } - - // test closing valid client - StreamingClientProvider streamingClientProvider = - StreamingClientProvider.getStreamingClientProviderForTests( - mockClientHandler, mockRegisteredClients); - streamingClientProvider.closeClient(this.clientConfig, mockExistingClient); - - // verify existing client was closed, optimization will call given client and registered client - Mockito.verify(mockClientHandler, Mockito.times(this.enableClientOptimization ? 2 : 1)) - .closeClient(mockExistingClient); - Mockito.verify(mockExistingClient, Mockito.times(this.enableClientOptimization ? 2 : 1)) - .close(); - } - - @Test - public void testCloseInvalidClient() throws Exception { - // setup invalid existing client and handler - SnowflakeStreamingIngestClient mockInvalidClient = - Mockito.mock(SnowflakeStreamingIngestClient.class); - Mockito.when(mockInvalidClient.getName()).thenReturn(this.clientConfig.get(Utils.NAME)); - Mockito.when(mockInvalidClient.isClosed()).thenReturn(true); - - StreamingClientHandler mockClientHandler = Mockito.mock(StreamingClientHandler.class); - Mockito.doCallRealMethod() - .when(mockClientHandler) - .closeClient(Mockito.any(SnowflakeStreamingIngestClient.class)); - - // inject invalid existing client in for optimization - LoadingCache mockRegisteredClients = - Mockito.mock(LoadingCache.class); - if (this.enableClientOptimization) { - Mockito.when( - mockRegisteredClients.getIfPresent(new StreamingClientProperties(this.clientConfig))) - .thenReturn(mockInvalidClient); - } - - // test closing valid client - StreamingClientProvider streamingClientProvider = - StreamingClientProvider.getStreamingClientProviderForTests( - mockClientHandler, mockRegisteredClients); - streamingClientProvider.closeClient(this.clientConfig, mockInvalidClient); - - // verify handler close client no-op and client did not need to call close - Mockito.verify(mockClientHandler, Mockito.times(this.enableClientOptimization ? 2 : 1)) - .closeClient(mockInvalidClient); - Mockito.verify(mockInvalidClient, Mockito.times(0)).close(); - } - - @Test - public void testCloseUnregisteredClient() throws Exception { - // setup valid existing client and handler - SnowflakeStreamingIngestClient mockUnregisteredClient = - Mockito.mock(SnowflakeStreamingIngestClient.class); - Mockito.when(mockUnregisteredClient.getName()).thenReturn(this.clientConfig.get(Utils.NAME)); - Mockito.when(mockUnregisteredClient.isClosed()).thenReturn(false); - - StreamingClientHandler mockClientHandler = Mockito.mock(StreamingClientHandler.class); - Mockito.doCallRealMethod() - .when(mockClientHandler) - .closeClient(Mockito.any(SnowflakeStreamingIngestClient.class)); - - // ensure no clients are registered - LoadingCache mockRegisteredClients = - Mockito.mock(LoadingCache.class); - Mockito.when( - mockRegisteredClients.getIfPresent(new StreamingClientProperties(this.clientConfig))) - .thenReturn(null); - - // test closing valid client - StreamingClientProvider streamingClientProvider = - StreamingClientProvider.getStreamingClientProviderForTests( - mockClientHandler, mockRegisteredClients); - streamingClientProvider.closeClient(this.clientConfig, mockUnregisteredClient); - - // verify unregistered client was closed - Mockito.verify(mockClientHandler, Mockito.times(1)).closeClient(mockUnregisteredClient); - Mockito.verify(mockUnregisteredClient, Mockito.times(1)).close(); - } -}