diff --git a/src/main/java/com/snowflake/kafka/connector/internal/streaming/StreamingClientProvider.java b/src/main/java/com/snowflake/kafka/connector/internal/streaming/StreamingClientProvider.java
index f70a27d0f..493ac0d64 100644
--- a/src/main/java/com/snowflake/kafka/connector/internal/streaming/StreamingClientProvider.java
+++ b/src/main/java/com/snowflake/kafka/connector/internal/streaming/StreamingClientProvider.java
@@ -20,9 +20,11 @@
 import static com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig.ENABLE_STREAMING_CLIENT_OPTIMIZATION_DEFAULT;
 
 import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
 import com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig;
 import com.snowflake.kafka.connector.internal.KCLogger;
 import java.util.Map;
+import java.util.function.Supplier;
 import net.snowflake.ingest.internal.com.github.benmanes.caffeine.cache.Caffeine;
 import net.snowflake.ingest.internal.com.github.benmanes.caffeine.cache.LoadingCache;
 import net.snowflake.ingest.internal.com.github.benmanes.caffeine.cache.RemovalCause;
@@ -38,10 +40,10 @@
  * node with equal {@link StreamingClientProperties} will use the same client
  */
 public class StreamingClientProvider {
-  private static class StreamingClientProviderSingleton {
-    private static final StreamingClientProvider streamingClientProvider =
-        new StreamingClientProvider();
-  }
+  private static volatile StreamingClientProvider streamingClientProvider = null;
+
+  private static Supplier<StreamingClientHandler> clientHandlerSupplier =
+      DirectStreamingClientHandler::new;
 
   /**
    * Gets the current streaming provider
@@ -49,7 +51,45 @@ private static class StreamingClientProviderSingleton {
    * @return The streaming client provider
    */
   public static StreamingClientProvider getStreamingClientProviderInstance() {
-    return StreamingClientProviderSingleton.streamingClientProvider;
+    if (streamingClientProvider == null) {
+      synchronized (StreamingClientProvider.class) {
+        if (streamingClientProvider == null) {
+          streamingClientProvider = new StreamingClientProvider(clientHandlerSupplier.get());
+        }
+      }
+    }
+
+    return streamingClientProvider;
+  }
+
+  /**
+   * Resets the provider to its pre-initialization state. This method is currently used by the test
+   * code only.
+   */
+  @VisibleForTesting
+  public static void reset() {
+    synchronized (StreamingClientProvider.class) {
+      streamingClientProvider = null;
+      clientHandlerSupplier = DirectStreamingClientHandler::new;
+    }
+  }
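Note on the lazy initialization introduced above: the singleton-holder class is replaced with double-checked locking, which is only correct because `streamingClientProvider` is declared `volatile` (without it, another thread could observe a partially constructed instance). A minimal, generic sketch of the idiom this diff applies:

```java
// Double-checked locking in miniature; same shape as getStreamingClientProviderInstance().
class LazySingleton {
  // volatile is required: it guarantees safe publication of the constructed instance.
  private static volatile LazySingleton instance;

  static LazySingleton getInstance() {
    if (instance == null) {                  // fast path, no lock once initialized
      synchronized (LazySingleton.class) {
        if (instance == null) {              // re-check under the lock
          instance = new LazySingleton();
        }
      }
    }
    return instance;
  }
}
```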
+
+  /**
+   * This method allows providing a custom {@link StreamingClientHandler} to be used by the
+   * connector instead of the default {@link DirectStreamingClientHandler}.
+   *
+   * <p>This method is currently used by the test code only.
+   *
+   * @param streamingClientHandler The handler that will be used by the connector.
+   */
+  @VisibleForTesting
+  public static void overrideStreamingClientHandler(StreamingClientHandler streamingClientHandler) {
+    Preconditions.checkState(
+        streamingClientProvider == null,
+        "StreamingClientProvider is already initialized and cannot be overridden.");
+    synchronized (StreamingClientProvider.class) {
+      clientHandlerSupplier = () -> streamingClientHandler;
+    }
   }
 
   /**
@@ -92,8 +132,8 @@ public static StreamingClientProvider getStreamingClientProviderInstance() {
    * When a client is evicted, the cache will try closing the client, however it is best to still
    * call close client manually as eviction is executed lazily
    */
-  private StreamingClientProvider() {
-    this.streamingClientHandler = new DirectStreamingClientHandler();
+  private StreamingClientProvider(StreamingClientHandler streamingClientHandler) {
+    this.streamingClientHandler = streamingClientHandler;
     this.registeredClients = buildLoadingCache(this.streamingClientHandler);
   }
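The override has an ordering contract: `overrideStreamingClientHandler` must run before the first `getStreamingClientProviderInstance()` call, since `Preconditions.checkState` rejects an override once the singleton exists. A sketch of the intended lifecycle, mirroring the `beforeEach`/`afterEach` hooks in the test base class below:

```java
import com.snowflake.kafka.connector.internal.streaming.FakeStreamingClientHandler;
import com.snowflake.kafka.connector.internal.streaming.StreamingClientProvider;

class OverrideLifecycleSketch {
  public static void main(String[] args) {
    // 1. Drop any singleton a previous test may have left behind.
    StreamingClientProvider.reset();

    // 2. Install the fake before anything touches the provider; once the
    //    singleton exists, the Preconditions.checkState above rejects overrides.
    FakeStreamingClientHandler fake = new FakeStreamingClientHandler();
    StreamingClientProvider.overrideStreamingClientHandler(fake);

    // 3. The first lookup builds the provider around the fake handler.
    StreamingClientProvider provider =
        StreamingClientProvider.getStreamingClientProviderInstance();
    System.out.println(provider);

    // 4. Restore the default DirectStreamingClientHandler for later tests.
    StreamingClientProvider.reset();
  }
}
```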
diff --git a/src/test/java/com/snowflake/kafka/connector/ConnectClusterBaseIT.java b/src/test/java/com/snowflake/kafka/connector/ConnectClusterBaseIT.java
index 0c6dc3a14..d784ecf71 100644
--- a/src/test/java/com/snowflake/kafka/connector/ConnectClusterBaseIT.java
+++ b/src/test/java/com/snowflake/kafka/connector/ConnectClusterBaseIT.java
@@ -1,48 +1,35 @@
 package com.snowflake.kafka.connector;
 
-import static com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig.BUFFER_COUNT_RECORDS;
 import static com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig.BUFFER_FLUSH_TIME_SEC;
+import static com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig.INGESTION_METHOD_OPT;
 import static com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig.NAME;
-import static com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig.SNOWFLAKE_DATABASE;
-import static com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig.SNOWFLAKE_PRIVATE_KEY;
-import static com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig.SNOWFLAKE_SCHEMA;
-import static com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig.SNOWFLAKE_URL;
-import static com.snowflake.kafka.connector.SnowflakeSinkConnectorConfig.SNOWFLAKE_USER;
 import static org.apache.kafka.connect.runtime.ConnectorConfig.CONNECTOR_CLASS_CONFIG;
 import static org.apache.kafka.connect.runtime.ConnectorConfig.KEY_CONVERTER_CLASS_CONFIG;
 import static org.apache.kafka.connect.runtime.ConnectorConfig.TASKS_MAX_CONFIG;
 import static org.apache.kafka.connect.runtime.ConnectorConfig.VALUE_CONVERTER_CLASS_CONFIG;
 import static org.apache.kafka.connect.sink.SinkConnector.TOPICS_CONFIG;
-import static org.assertj.core.api.Assertions.assertThat;
-import static org.awaitility.Awaitility.await;
 
-import com.snowflake.kafka.connector.fake.SnowflakeFakeSinkConnector;
-import com.snowflake.kafka.connector.fake.SnowflakeFakeSinkTask;
-import java.time.Duration;
-import java.util.HashMap;
-import java.util.List;
+import com.snowflake.kafka.connector.internal.TestUtils;
+import com.snowflake.kafka.connector.internal.streaming.FakeStreamingClientHandler;
+import com.snowflake.kafka.connector.internal.streaming.IngestionMethodConfig;
+import com.snowflake.kafka.connector.internal.streaming.StreamingClientProvider;
 import java.util.Map;
-import org.apache.kafka.connect.runtime.AbstractStatus;
-import org.apache.kafka.connect.runtime.rest.entities.ConnectorStateInfo;
-import org.apache.kafka.connect.sink.SinkRecord;
 import org.apache.kafka.connect.storage.StringConverter;
 import org.apache.kafka.connect.util.clusters.EmbeddedConnectCluster;
 import org.junit.jupiter.api.AfterAll;
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.api.BeforeEach;
-import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.TestInstance;
 
 @TestInstance(TestInstance.Lifecycle.PER_CLASS)
-public class ConnectClusterBaseIT {
+class ConnectClusterBaseIT {
 
-  protected EmbeddedConnectCluster connectCluster;
+  EmbeddedConnectCluster connectCluster;
 
-  protected static final String TEST_TOPIC = "kafka-int-test";
-  protected static final String TEST_CONNECTOR_NAME = "test-connector";
-  protected static final Integer TASK_NUMBER = 1;
-  private static final Duration CONNECTOR_MAX_STARTUP_TIME = Duration.ofSeconds(20);
+  FakeStreamingClientHandler fakeStreamingClientHandler;
+
+  static final Integer TASK_NUMBER = 1;
 
   @BeforeAll
   public void beforeAll() {
@@ -52,14 +39,18 @@
         .numWorkers(3)
         .build();
     connectCluster.start();
-    connectCluster.kafka().createTopic(TEST_TOPIC);
-    connectCluster.configureConnector(TEST_CONNECTOR_NAME, createProperties());
-    await().timeout(CONNECTOR_MAX_STARTUP_TIME).until(this::isConnectorRunning);
   }
 
   @BeforeEach
-  public void before() {
-    SnowflakeFakeSinkTask.resetRecords();
+  public void beforeEach() {
+    StreamingClientProvider.reset();
+    fakeStreamingClientHandler = new FakeStreamingClientHandler();
+    StreamingClientProvider.overrideStreamingClientHandler(fakeStreamingClientHandler);
+  }
+
+  @AfterEach
+  public void afterEach() {
+    StreamingClientProvider.reset();
   }
 
   @AfterAll
@@ -70,55 +61,30 @@
     }
   }
 
-  @AfterEach
-  public void after() {
-    SnowflakeFakeSinkTask.resetRecords();
-  }
-
-  @Test
-  public void connectorShouldConsumeMessagesFromTopic() {
-    connectCluster.kafka().produce(TEST_TOPIC, "test1");
-    connectCluster.kafka().produce(TEST_TOPIC, "test2");
+  final Map<String, String> defaultProperties(String topicName, String connectorName) {
+    Map<String, String> config = TestUtils.getConf();
 
-    await()
-        .untilAsserted(
-            () -> {
-              List<SinkRecord> records = SnowflakeFakeSinkTask.getRecords();
-              assertThat(records).hasSize(2);
-              assertThat(records.stream().map(SinkRecord::value)).containsExactly("test1", "test2");
-            });
-  }
-
-  protected Map<String, String> createProperties() {
-    Map<String, String> config = new HashMap<>();
-
-    // kafka connect specific
-    // real connector will be specified with SNOW-1055561
-    config.put(CONNECTOR_CLASS_CONFIG, SnowflakeFakeSinkConnector.class.getName());
-    config.put(TOPICS_CONFIG, TEST_TOPIC);
+    config.put(CONNECTOR_CLASS_CONFIG, SnowflakeSinkConnector.class.getName());
+    config.put(NAME, connectorName);
+    config.put(TOPICS_CONFIG, topicName);
+    config.put(INGESTION_METHOD_OPT, IngestionMethodConfig.SNOWPIPE_STREAMING.toString());
+    config.put(Utils.SF_ROLE, "testrole_kafka");
+    config.put(BUFFER_FLUSH_TIME_SEC, "1");
     config.put(TASKS_MAX_CONFIG, TASK_NUMBER.toString());
     config.put(KEY_CONVERTER_CLASS_CONFIG, StringConverter.class.getName());
     config.put(VALUE_CONVERTER_CLASS_CONFIG, StringConverter.class.getName());
 
-    // kafka push specific
-    config.put(NAME, TEST_CONNECTOR_NAME);
-    config.put(SNOWFLAKE_URL, "https://test.testregion.snowflakecomputing.com:443");
-    config.put(SNOWFLAKE_USER, "testName");
-    config.put(SNOWFLAKE_PRIVATE_KEY, "testPrivateKey");
-    config.put(SNOWFLAKE_DATABASE, "testDbName");
-    config.put(SNOWFLAKE_SCHEMA, "testSchema");
-    config.put(BUFFER_COUNT_RECORDS, "1000000");
-    config.put(BUFFER_FLUSH_TIME_SEC, "1");
-
     return config;
   }
 
-  private boolean isConnectorRunning() {
-    ConnectorStateInfo status = connectCluster.connectorStatus(TEST_CONNECTOR_NAME);
-    return status != null
-        && status.connector().state().equals(AbstractStatus.State.RUNNING.toString())
-        && status.tasks().size() >= TASK_NUMBER
-        && status.tasks().stream()
-            .allMatch(state -> state.state().equals(AbstractStatus.State.RUNNING.toString()));
+  final void waitForConnectorRunning(String connectorName) {
+    try {
+      connectCluster
+          .assertions()
+          .assertConnectorAndAtLeastNumTasksAreRunning(
+              connectorName, 1, "The connector did not start.");
+    } catch (InterruptedException e) {
+      throw new IllegalStateException("The connector is not running", e);
+    }
   }
 }
diff --git a/src/test/java/com/snowflake/kafka/connector/SmtIT.java b/src/test/java/com/snowflake/kafka/connector/SmtIT.java
new file mode 100644
index 000000000..cc77d7858
--- /dev/null
+++ b/src/test/java/com/snowflake/kafka/connector/SmtIT.java
@@ -0,0 +1,78 @@
+package com.snowflake.kafka.connector;
+
+import static org.apache.kafka.connect.runtime.ConnectorConfig.TRANSFORMS_CONFIG;
+import static org.apache.kafka.connect.runtime.ConnectorConfig.VALUE_CONVERTER_CLASS_CONFIG;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.awaitility.Awaitility.await;
+
+import com.snowflake.kafka.connector.internal.TestUtils;
+import java.time.Duration;
+import java.util.Map;
+import java.util.function.UnaryOperator;
+import java.util.stream.Stream;
+import org.apache.kafka.connect.json.JsonConverter;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.CsvSource;
+
+public class SmtIT extends ConnectClusterBaseIT {
+
+  private String smtTopic;
+  private String smtConnector;
+
+  @BeforeEach
+  void before() {
+    smtTopic = TestUtils.randomTableName();
+    smtConnector = String.format("%s_connector", smtTopic);
+    connectCluster.kafka().createTopic(smtTopic);
+  }
+
+  @AfterEach
+  void after() {
+    connectCluster.kafka().deleteTopic(smtTopic);
+    connectCluster.deleteConnector(smtConnector);
+  }
+
+  private Map<String, String> smtProperties(
+      String smtTopic, String smtConnector, String behaviorOnNull) {
+    Map<String, String> config = defaultProperties(smtTopic, smtConnector);
+
+    config.put(VALUE_CONVERTER_CLASS_CONFIG, JsonConverter.class.getName());
+    config.put("value.converter.schemas.enable", "false");
+    config.put("behavior.on.null.values", behaviorOnNull);
+
+    config.put(TRANSFORMS_CONFIG, "extractField");
+    config.put(
+        "transforms.extractField.type", "org.apache.kafka.connect.transforms.ExtractField$Value");
+    config.put("transforms.extractField.field", "message");
+
+    return config;
+  }
+
+  @ParameterizedTest
+  @CsvSource({"DEFAULT, 20", "IGNORE, 10"})
+  void testIfSmtReturningNullsIngestDataCorrectly(String behaviorOnNull, int expectedRecordNumber) {
+    // given
+    connectCluster.configureConnector(
+        smtConnector, smtProperties(smtTopic, smtConnector, behaviorOnNull));
+    waitForConnectorRunning(smtConnector);
+
+    // when
+    Stream.iterate(0, UnaryOperator.identity())
+        .limit(10)
+        .flatMap(v -> Stream.of("{}", "{\"message\":\"value\"}"))
+        .forEach(message -> connectCluster.kafka().produce(smtTopic, message));
+
+    // then
+    await()
+        .timeout(Duration.ofSeconds(60))
+        .untilAsserted(
+            () -> {
+              assertThat(fakeStreamingClientHandler.ingestedRows()).hasSize(expectedRecordNumber);
+              assertThat(fakeStreamingClientHandler.getLatestCommittedOffsetTokensPerChannel())
+                  .hasSize(1)
+                  .containsValue("19");
+            });
+  }
+}
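A hypothetical subclass sketch (class, topic, and connector names invented here) showing how the refactored base class is meant to be used: build the real connector config from `defaultProperties` (which assumes `TestUtils.getConf()` can load real connection settings from the test profile), wait for startup, then assert against the fake handler installed in `beforeEach`:

```java
package com.snowflake.kafka.connector;

import static org.assertj.core.api.Assertions.assertThat;
import static org.awaitility.Awaitility.await;

import java.util.Map;
import org.junit.jupiter.api.Test;

class ExampleIngestionIT extends ConnectClusterBaseIT {

  @Test
  void ingestsSingleRecord() {
    // Illustrative names; real tests generate them, e.g. via TestUtils.randomTableName().
    String topic = "example-topic";
    String connector = "example-connector";
    connectCluster.kafka().createTopic(topic);

    Map<String, String> config = defaultProperties(topic, connector);
    connectCluster.configureConnector(connector, config);
    waitForConnectorRunning(connector);

    connectCluster.kafka().produce(topic, "{\"message\":\"value\"}");

    // The fake handler installed in beforeEach() records everything the connector ingests.
    await().untilAsserted(() -> assertThat(fakeStreamingClientHandler.ingestedRows()).hasSize(1));
  }
}
```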
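Why the two `@CsvSource` rows differ: every `{}` record has no `message` field, so `ExtractField$Value` turns it into a record with a null value; `behavior.on.null.values=DEFAULT` still ingests those records (20 rows total), while `IGNORE` drops them (10 rows), yet the committed offset token reaches "19" either way because offsets advance regardless. A standalone sketch of the transform's behavior (topic name and offsets are illustrative):

```java
import java.util.HashMap;
import java.util.Map;
import org.apache.kafka.connect.sink.SinkRecord;
import org.apache.kafka.connect.transforms.ExtractField;

class ExtractFieldSketch {
  public static void main(String[] args) {
    // The same transform the test wires up via transforms.extractField.*
    ExtractField.Value<SinkRecord> transform = new ExtractField.Value<>();
    Map<String, String> config = new HashMap<>();
    config.put("field", "message");
    transform.configure(config);

    Map<String, Object> value = new HashMap<>();
    value.put("message", "value");
    SinkRecord withField = new SinkRecord("t", 0, null, null, null, value, 0);
    // Field present: the whole value is replaced by the extracted field.
    System.out.println(transform.apply(withField).value()); // value

    SinkRecord withoutField =
        new SinkRecord("t", 0, null, null, null, new HashMap<String, Object>(), 1);
    // Field absent ("{}"): extraction yields null, and behavior.on.null.values
    // then decides whether the connector ingests or skips the record.
    System.out.println(transform.apply(withoutField).value()); // null

    transform.close();
  }
}
```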
diff --git a/src/test/java/com/snowflake/kafka/connector/fake/SnowflakeFakeSinkConnector.java b/src/test/java/com/snowflake/kafka/connector/fake/SnowflakeFakeSinkConnector.java
deleted file mode 100644
index 56671057b..000000000
--- a/src/test/java/com/snowflake/kafka/connector/fake/SnowflakeFakeSinkConnector.java
+++ /dev/null
@@ -1,53 +0,0 @@
-package com.snowflake.kafka.connector.fake;
-
-import com.snowflake.kafka.connector.SnowflakeSinkConnector;
-import com.snowflake.kafka.connector.Utils;
-import com.snowflake.kafka.connector.internal.KCLogger;
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import org.apache.kafka.common.config.ConfigDef;
-import org.apache.kafka.connect.connector.Task;
-import org.apache.kafka.connect.sink.SinkConnector;
-
-public class SnowflakeFakeSinkConnector extends SinkConnector {
-
-  private static final KCLogger LOGGER = new KCLogger(SnowflakeSinkConnector.class.getName());
-  private Map<String, String> config; // connector configuration, provided by
-  // user through kafka connect framework
-
-  @Override
-  public void start(Map<String, String> parsedConfig) {
-    LOGGER.debug("Starting " + SnowflakeFakeSinkConnector.class.getSimpleName());
-    config = new HashMap<>(parsedConfig);
-  }
-
-  @Override
-  public Class<? extends Task> taskClass() {
-    return SnowflakeFakeSinkTask.class;
-  }
-
-  @Override
-  public List<Map<String, String>> taskConfigs(int maxTasks) {
-    List<Map<String, String>> taskConfigs = new ArrayList<>();
-    Map<String, String> conf = new HashMap<>(config);
-    conf.put(Utils.TASK_ID, "fakeTask1");
-    taskConfigs.add(conf);
-    return taskConfigs;
-  }
-
-  @Override
-  public void stop() {
-    LOGGER.debug("Stopping " + SnowflakeFakeSinkConnector.class.getSimpleName());
-  }
-
-  @Override
-  public ConfigDef config() {
-    return new ConfigDef();
-  }
-
-  @Override
-  public String version() {
-    return Utils.VERSION;
-  }
-}
diff --git a/src/test/java/com/snowflake/kafka/connector/fake/SnowflakeFakeSinkTask.java b/src/test/java/com/snowflake/kafka/connector/fake/SnowflakeFakeSinkTask.java
deleted file mode 100644
index 253b80744..000000000
--- a/src/test/java/com/snowflake/kafka/connector/fake/SnowflakeFakeSinkTask.java
+++ /dev/null
@@ -1,49 +0,0 @@
-package com.snowflake.kafka.connector.fake;
-
-import com.snowflake.kafka.connector.SnowflakeSinkConnector;
-import com.snowflake.kafka.connector.Utils;
-import com.snowflake.kafka.connector.internal.KCLogger;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collection;
-import java.util.Collections;
-import java.util.List;
-import java.util.Map;
-import org.apache.kafka.connect.sink.SinkRecord;
-import org.apache.kafka.connect.sink.SinkTask;
-
-public class SnowflakeFakeSinkTask extends SinkTask {
-
-  private static final KCLogger LOGGER = new KCLogger(SnowflakeSinkConnector.class.getName());
-
-  private static final List<SinkRecord> records = Collections.synchronizedList(new ArrayList<>());
-
-  public static List<SinkRecord> getRecords() {
-    return Arrays.asList(records.toArray(new SinkRecord[0]));
-  }
-
-  public static void resetRecords() {
-    records.clear();
-  }
-
-  @Override
-  public String version() {
-    return Utils.VERSION;
-  }
-
-  @Override
-  public void start(Map<String, String> map) {
-    LOGGER.debug("Starting " + SnowflakeFakeSinkTask.class.getSimpleName());
-    resetRecords();
-  }
-
-  @Override
-  public void put(Collection<SinkRecord> collection) {
-    records.addAll(collection);
-  }
-
-  @Override
-  public void stop() {
-    LOGGER.debug("Stopping " + SnowflakeFakeSinkTask.class.getSimpleName());
-  }
-}
diff --git a/src/test/java/com/snowflake/kafka/connector/internal/streaming/FakeStreamingClientHandler.java b/src/test/java/com/snowflake/kafka/connector/internal/streaming/FakeStreamingClientHandler.java
index 9c45a2d49..8f953282e 100644
--- a/src/test/java/com/snowflake/kafka/connector/internal/streaming/FakeStreamingClientHandler.java
+++ b/src/test/java/com/snowflake/kafka/connector/internal/streaming/FakeStreamingClientHandler.java
@@ -1,21 +1,31 @@
 package com.snowflake.kafka.connector.internal.streaming;
 
+import java.util.Collection;
+import java.util.Map;
+import java.util.Set;
 import java.util.UUID;
+import java.util.concurrent.ConcurrentLinkedQueue;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.Collectors;
 import net.snowflake.ingest.streaming.FakeSnowflakeStreamingIngestClient;
 import net.snowflake.ingest.streaming.SnowflakeStreamingIngestClient;
 
 public class FakeStreamingClientHandler implements StreamingClientHandler {
 
-  private AtomicInteger createClientCalls = new AtomicInteger(0);
-  private AtomicInteger closeClientCalls = new AtomicInteger(0);
+  private final ConcurrentLinkedQueue<FakeSnowflakeStreamingIngestClient> clients =
+      new ConcurrentLinkedQueue<>();
+  private final AtomicInteger createClientCalls = new AtomicInteger(0);
+  private final AtomicInteger closeClientCalls = new AtomicInteger(0);
 
   @Override
   public SnowflakeStreamingIngestClient createClient(
       StreamingClientProperties streamingClientProperties) {
     createClientCalls.incrementAndGet();
-    return new FakeSnowflakeStreamingIngestClient(
-        streamingClientProperties.clientName + "_" + UUID.randomUUID());
+    FakeSnowflakeStreamingIngestClient ingestClient =
+        new FakeSnowflakeStreamingIngestClient(
+            streamingClientProperties.clientName + "_" + UUID.randomUUID());
+    clients.add(ingestClient);
+    return ingestClient;
   }
 
   @Override
@@ -35,4 +45,19 @@ public Integer getCreateClientCalls() {
   public Integer getCloseClientCalls() {
     return closeClientCalls.get();
   }
+
+  public Set<Map<String, Object>> ingestedRows() {
+    return clients.stream()
+        .map(FakeSnowflakeStreamingIngestClient::ingestedRecords)
+        .flatMap(Collection::stream)
+        .collect(Collectors.toSet());
+  }
+
+  public Map<String, String> getLatestCommittedOffsetTokensPerChannel() {
+    return this.clients.stream()
+        .map(FakeSnowflakeStreamingIngestClient::getLatestCommittedOffsetTokensPerChannel)
+        .map(Map::entrySet)
+        .flatMap(Collection::stream)
+        .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
+  }
 }
diff --git a/src/test/java/net/snowflake/ingest/streaming/FakeSnowflakeStreamingIngestChannel.java b/src/test/java/net/snowflake/ingest/streaming/FakeSnowflakeStreamingIngestChannel.java
index 991e1f8a7..3bacf2194 100644
--- a/src/test/java/net/snowflake/ingest/streaming/FakeSnowflakeStreamingIngestChannel.java
+++ b/src/test/java/net/snowflake/ingest/streaming/FakeSnowflakeStreamingIngestChannel.java
@@ -1,5 +1,6 @@
 package net.snowflake.ingest.streaming;
 
+import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Lists;
 import java.util.*;
 import java.util.concurrent.CompletableFuture;
@@ -154,4 +155,8 @@ public String getLatestCommittedOffsetToken() {
   public Map<String, ColumnProperties> getTableSchema() {
     throw new UnsupportedOperationException("Method is unsupported in fake communication channel");
   }
+
+  List<Map<String, Object>> getRows() {
+    return ImmutableList.copyOf(this.rows);
+  }
 }
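One subtlety in `ingestedRows()`: results are collected into a `Set`, so byte-identical row maps would collapse into a single element. The counts in `SmtIT` still work out, presumably because each ingested row carries record metadata (e.g. the Kafka offset) that keeps otherwise identical payloads distinct. An illustrative sketch (the `RECORD_CONTENT`/`RECORD_METADATA` shape is an assumption here, not taken from this diff):

```java
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

class SetDedupSketch {
  public static void main(String[] args) {
    Set<Map<String, Object>> rows = new HashSet<>();
    for (int offset = 0; offset < 2; offset++) {
      Map<String, Object> row = new HashMap<>();
      row.put("RECORD_CONTENT", "value"); // identical payload both times
      row.put("RECORD_METADATA", Map.of("offset", offset)); // distinct metadata
      rows.add(row);
    }
    // Prints 2, not 1: the metadata keeps otherwise equal rows apart in the Set.
    System.out.println(rows.size());
  }
}
```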
diff --git a/src/test/java/net/snowflake/ingest/streaming/FakeSnowflakeStreamingIngestClient.java b/src/test/java/net/snowflake/ingest/streaming/FakeSnowflakeStreamingIngestClient.java
index 5db917e70..e7c315345 100644
--- a/src/test/java/net/snowflake/ingest/streaming/FakeSnowflakeStreamingIngestClient.java
+++ b/src/test/java/net/snowflake/ingest/streaming/FakeSnowflakeStreamingIngestClient.java
@@ -1,10 +1,13 @@
 package net.snowflake.ingest.streaming;
 
 import com.snowflake.kafka.connector.internal.KCLogger;
+import java.util.Collection;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.stream.Collectors;
+import net.snowflake.ingest.utils.Pair;
 
 /**
  * Fake implementation of {@link SnowflakeStreamingIngestClient}. Uses in memory state only.
@@ -18,7 +21,7 @@ public class FakeSnowflakeStreamingIngestClient implements SnowflakeStreamingIng
   private boolean closed;
   private static final KCLogger LOGGER =
       new KCLogger(FakeSnowflakeStreamingIngestClient.class.getName());
-  private final ConcurrentHashMap<String, SnowflakeStreamingIngestChannel> channelCache =
+  private final ConcurrentHashMap<String, FakeSnowflakeStreamingIngestChannel> channelCache =
      new ConcurrentHashMap<>();
 
   public FakeSnowflakeStreamingIngestClient(String name) {
@@ -74,4 +77,17 @@ public Map<String, String> getLatestCommittedOffsetTokens(
   public void close() throws Exception {
     closed = true;
   }
+
+  public Set<Map<String, Object>> ingestedRecords() {
+    return channelCache.values().stream()
+        .map(FakeSnowflakeStreamingIngestChannel::getRows)
+        .flatMap(Collection::stream)
+        .collect(Collectors.toSet());
+  }
+
+  public Map<String, String> getLatestCommittedOffsetTokensPerChannel() {
+    return this.channelCache.values().stream()
+        .map(channel -> new Pair<>(channel.getName(), channel.getLatestCommittedOffsetToken()))
+        .collect(Collectors.toMap(Pair::getKey, Pair::getValue));
+  }
 }
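A caveat on the two `Collectors.toMap` merges (here and in `FakeStreamingClientHandler`): `toMap` without a merge function throws `IllegalStateException` on duplicate keys, so these helpers implicitly assume channel names are unique across all cached clients, which holds in these tests since each connector run opens distinct channels. A small standalone demonstration of that contract (channel names and tokens are illustrative):

```java
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;

class ToMapSketch {
  public static void main(String[] args) {
    // Unique keys collect cleanly into the per-channel offset map...
    Map<String, String> offsets =
        Stream.of(new String[][] {{"CHANNEL_0", "19"}, {"CHANNEL_1", "7"}})
            .collect(Collectors.toMap(e -> e[0], e -> e[1]));
    System.out.println(offsets); // {CHANNEL_0=19, CHANNEL_1=7}

    // ...but a duplicate key would make toMap throw IllegalStateException,
    // which is why uniqueness of channel names matters for the fakes above.
  }
}
```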