Skip to content

Commit

Permalink
Restart task when commit fails (#74)
Browse files Browse the repository at this point in the history
* restart task when commit failed
  • Loading branch information
YongGang authored Jun 22, 2021
1 parent 4b48244 commit 6545ce1
Show file tree
Hide file tree
Showing 8 changed files with 137 additions and 7 deletions.
5 changes: 4 additions & 1 deletion src/main/java/com/salesforce/mirus/KafkaMonitor.java
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,10 @@ boolean partitionsChanged() {
}

private List<TopicPartition> fetchMatchingPartitions(Consumer<byte[], byte[]> consumer) {
return consumer.listTopics().entrySet().stream()
return consumer
.listTopics()
.entrySet()
.stream()
.filter(
e ->
topicsWhitelist.contains(e.getKey())
Expand Down
30 changes: 30 additions & 0 deletions src/main/java/com/salesforce/mirus/MirusSourceTask.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.errors.WakeupException;
import org.apache.kafka.common.header.Headers;
import org.apache.kafka.common.utils.SystemTime;
import org.apache.kafka.common.utils.Time;
import org.apache.kafka.connect.data.SchemaAndValue;
import org.apache.kafka.connect.header.ConnectHeaders;
import org.apache.kafka.connect.source.SourceRecord;
Expand Down Expand Up @@ -69,11 +71,15 @@ public class MirusSourceTask extends SourceTask {
private HeaderConverter headerConverter;
private ReplayPolicy replayPolicy;
private long replayWindowRecords;
private long successfulCommitTime = Long.MAX_VALUE;
private long lastNewRecordTime = Long.MAX_VALUE;

private final Map<TopicPartition, Long> latestOffsetMap = new HashMap<>();
private final Set<TopicPartition> loggingFlags = new HashSet<>();

protected AtomicBoolean shutDown = new AtomicBoolean(false);
protected Time time = new SystemTime();
private long commitFailureRestartMs;

@SuppressWarnings("unused")
public MirusSourceTask() {
Expand Down Expand Up @@ -107,6 +113,7 @@ public void start(Map<String, String> properties) {
this.destinationTopicNamePrefix = config.getDestinationTopicNamePrefix();
this.destinationTopicNameSuffix = config.getDestinationTopicNameSuffix();
this.enablePartitionMatching = config.getEnablePartitionMatching();
this.commitFailureRestartMs = config.getCommitFailureRestartMs();

this.keyConverter = config.getKeyConverter();
this.valueConverter = config.getValueConverter();
Expand Down Expand Up @@ -181,11 +188,17 @@ public List<SourceRecord> poll() {

try {
logger.trace("Calling poll");
checkCommitFailure();
ConsumerRecords<byte[], byte[]> result = consumer.poll(consumerPollTimeoutMillis);
logger.trace("Got {} records", result.count());
if (!result.isEmpty()) {
lastNewRecordTime = time.milliseconds();
return sourceRecords(result);
} else {
// If no new data has arrived since last successful commit, move the effective commit time forward
if (lastNewRecordTime <= successfulCommitTime) {
successfulCommitTime = time.milliseconds();
}
return Collections.emptyList();
}
} catch (WakeupException e) {
Expand All @@ -197,6 +210,23 @@ public List<SourceRecord> poll() {
return Collections.emptyList();
}

@Override
// Invoked by the Connect framework on a successful offset commit (override of
// SourceTask.commit()). Records the commit time so checkCommitFailure() can
// compare it against the time of the most recently polled record.
public void commit() {
  successfulCommitTime = time.milliseconds();
}

/**
 * Fails the task when records have been arriving but no successful offset commit has been
 * observed within {@code commitFailureRestartMs}. Throwing here causes the framework to
 * restart the task, re-establishing the Kafka connection.
 *
 * @throws RuntimeException when the commit-failure window has been exceeded
 */
private void checkCommitFailure() {
  long millisWithoutCommit = lastNewRecordTime - successfulCommitTime;
  if (millisWithoutCommit < commitFailureRestartMs) {
    return; // commits are keeping up (or no data has arrived since the last one)
  }
  String message =
      "Unable to commit offsets for more than "
          + commitFailureRestartMs / 1000
          + " seconds. "
          + "Attempting to restart task.";
  throw new RuntimeException(message);
}

List<SourceRecord> sourceRecords(ConsumerRecords<byte[], byte[]> pollResult) {
List<SourceRecord> sourceRecords = new ArrayList<>(pollResult.count());
pollResult.forEach(
Expand Down
8 changes: 4 additions & 4 deletions src/main/java/com/salesforce/mirus/config/SourceConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,19 @@

package com.salesforce.mirus.config;

import static java.util.stream.Collectors.toList;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;

import org.apache.kafka.common.config.ConfigDef;
import org.apache.kafka.connect.errors.ConnectException;
import org.apache.kafka.connect.runtime.ConnectorConfig;
import org.apache.kafka.connect.source.SourceRecord;
import org.apache.kafka.connect.transforms.Transformation;
import org.apache.kafka.connect.transforms.util.SimpleConfig;

import static java.util.stream.Collectors.toList;

public class SourceConfig {

private final SimpleConfig simpleConfig;
Expand Down Expand Up @@ -114,7 +113,8 @@ private List<Transformation<SourceRecord>> buildTransformations() {
}

private static List<Pattern> parseTopicsRegexList(List<String> topicsRegexList) {
return topicsRegexList.stream()
return topicsRegexList
.stream()
.map(
r -> {
String regex;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,12 @@ public enum SourceConfigDefinition {
"org.apache.kafka.connect.converters.ByteArrayConverter",
ConfigDef.Importance.MEDIUM,
"Converter class to apply to source record headers"),
COMMIT_FAILURE_RESTART_MS(
"commit.failure.restart.ms",
ConfigDef.Type.LONG,
TimeUnit.MILLISECONDS.convert(5, TimeUnit.MINUTES),
ConfigDef.Importance.MEDIUM,
"Fail task if no successful commit is seen for this time. Tasks automatically restart by default"),
@Deprecated
DESTINATION_BOOTSTRAP_SERVERS(
"destination.bootstrap.servers",
Expand Down
4 changes: 4 additions & 0 deletions src/main/java/com/salesforce/mirus/config/TaskConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ public long getConsumerPollTimeout() {
return simpleConfig.getLong(SourceConfigDefinition.POLL_TIMEOUT_MS.key);
}

/**
 * Returns the commit-failure restart window in milliseconds: how long the task may go without a
 * successful offset commit before it deliberately fails so it can be restarted. Backed by the
 * {@code commit.failure.restart.ms} config ({@code SourceConfigDefinition.COMMIT_FAILURE_RESTART_MS}).
 */
public long getCommitFailureRestartMs() {
  return simpleConfig.getLong(SourceConfigDefinition.COMMIT_FAILURE_RESTART_MS.key);
}

public String getDestinationTopicNamePrefix() {
return simpleConfig.getString(SourceConfigDefinition.DESTINATION_TOPIC_NAME_PREFIX.key);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ public class TaskConfigDefinition {
SourceConfigDefinition.ENABLE_PARTITION_MATCHING,
SourceConfigDefinition.SOURCE_KEY_CONVERTER,
SourceConfigDefinition.SOURCE_VALUE_CONVERTER,
SourceConfigDefinition.SOURCE_HEADER_CONVERTER);
SourceConfigDefinition.SOURCE_HEADER_CONVERTER,
SourceConfigDefinition.COMMIT_FAILURE_RESTART_MS);

static ConfigDef configDef() {
ConfigDef configDef = new ConfigDef();
Expand Down
87 changes: 87 additions & 0 deletions src/test/java/com/salesforce/mirus/MirusSourceTaskTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.nullValue;
import static org.junit.Assert.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import com.salesforce.mirus.config.SourceConfigDefinition;
import com.salesforce.mirus.config.TaskConfig;
import com.salesforce.mirus.config.TaskConfig.ReplayPolicy;
import com.salesforce.mirus.config.TaskConfigDefinition;
import java.nio.charset.StandardCharsets;
Expand All @@ -37,6 +40,7 @@
import org.apache.kafka.common.header.Headers;
import org.apache.kafka.common.header.internals.RecordHeaders;
import org.apache.kafka.common.record.TimestampType;
import org.apache.kafka.common.utils.Time;
import org.apache.kafka.connect.data.ConnectSchema;
import org.apache.kafka.connect.data.Schema;
import org.apache.kafka.connect.data.Struct;
Expand Down Expand Up @@ -355,4 +359,87 @@ public void testReplayFilterWindow() {
assertThat(result.size(), is(3));
assertThat(result.get(0).sourceOffset().get(MirusSourceTask.KEY_OFFSET), is(2L));
}

@Test(expected = RuntimeException.class)
// Verifies the task throws (triggering a restart) once new records have kept
// arriving for a full commit-failure window with no successful commit.
public void shouldThrowExceptionWhenCommitFailed() {
  // Inject a mocked clock so elapsed time can be simulated without sleeping.
  Time mockTime = mock(Time.class);
  mirusSourceTask.time = mockTime;
  long currentMillis = System.currentTimeMillis();
  when(mockTime.milliseconds()).thenReturn(currentMillis);
  // normal poll-commit cycle: poll two records, then commit successfully,
  // establishing the baseline successfulCommitTime.
  mockConsumer.addRecord(new ConsumerRecord<>(TOPIC, 0, 0, new byte[] {}, new byte[] {}));
  mockConsumer.addRecord(new ConsumerRecord<>(TOPIC, 1, 0, new byte[] {}, new byte[] {}));
  List<SourceRecord> result = mirusSourceTask.poll();
  assertThat(result.size(), is(2));
  mirusSourceTask.commit();

  // poll success but commit failed: advance the clock to half the restart
  // window, then to the full window, polling new data each time but never
  // committing again.
  TaskConfig config = new TaskConfig(mockTaskProperties());
  long elapseTime = config.getCommitFailureRestartMs() / 2;
  when(mockTime.milliseconds()).thenReturn(currentMillis + elapseTime);
  mockConsumer.addRecord(new ConsumerRecord<>(TOPIC, 0, 1, new byte[] {}, new byte[] {}));
  mockConsumer.addRecord(new ConsumerRecord<>(TOPIC, 1, 1, new byte[] {}, new byte[] {}));
  mirusSourceTask.poll();
  elapseTime = config.getCommitFailureRestartMs();
  when(mockTime.milliseconds()).thenReturn(currentMillis + elapseTime);
  mockConsumer.addRecord(new ConsumerRecord<>(TOPIC, 0, 2, new byte[] {}, new byte[] {}));
  mockConsumer.addRecord(new ConsumerRecord<>(TOPIC, 1, 2, new byte[] {}, new byte[] {}));
  mirusSourceTask.poll();

  // check commit failure and throw exception to restart task: by this poll the
  // gap between lastNewRecordTime and successfulCommitTime >= the window.
  mirusSourceTask.poll();
}

@Test
// Verifies no restart is triggered while the elapsed time since the last
// successful commit is still (just) inside the commit-failure window.
public void shouldNotThrowExceptionIfNotTimeToRestart() {
  // Inject a mocked clock so elapsed time can be simulated without sleeping.
  Time mockTime = mock(Time.class);
  mirusSourceTask.time = mockTime;
  long currentMillis = System.currentTimeMillis();
  when(mockTime.milliseconds()).thenReturn(currentMillis);
  // normal poll-commit cycle establishing the successful-commit baseline
  mockConsumer.addRecord(new ConsumerRecord<>(TOPIC, 0, 0, new byte[] {}, new byte[] {}));
  mockConsumer.addRecord(new ConsumerRecord<>(TOPIC, 1, 0, new byte[] {}, new byte[] {}));
  List<SourceRecord> result = mirusSourceTask.poll();
  assertThat(result.size(), is(2));
  mirusSourceTask.commit();

  // poll success but commit failed: advance the clock to 10 ms short of the
  // restart window and poll new records without committing.
  TaskConfig config = new TaskConfig(mockTaskProperties());
  long elapseTime = config.getCommitFailureRestartMs() - 10;
  when(mockTime.milliseconds()).thenReturn(currentMillis + elapseTime);
  mockConsumer.addRecord(new ConsumerRecord<>(TOPIC, 0, 1, new byte[] {}, new byte[] {}));
  mockConsumer.addRecord(new ConsumerRecord<>(TOPIC, 1, 1, new byte[] {}, new byte[] {}));
  mirusSourceTask.poll();

  // check commit failure, no exception thrown as time is not up to restart task
  mirusSourceTask.poll();
}

@Test
// Verifies the idle case: when no NEW records arrive during the window, the
// effective commit time is moved forward on empty polls, so exceeding the
// window in wall-clock time alone does not trigger a restart.
public void shouldNotThrowExceptionIfNoNewDataInCommitWindow() {
  // Inject a mocked clock so elapsed time can be simulated without sleeping.
  Time mockTime = mock(Time.class);
  mirusSourceTask.time = mockTime;
  long currentMillis = System.currentTimeMillis();
  when(mockTime.milliseconds()).thenReturn(currentMillis);
  // normal poll-commit cycle establishing the successful-commit baseline
  mockConsumer.addRecord(new ConsumerRecord<>(TOPIC, 0, 0, new byte[] {}, new byte[] {}));
  mockConsumer.addRecord(new ConsumerRecord<>(TOPIC, 1, 0, new byte[] {}, new byte[] {}));
  List<SourceRecord> result = mirusSourceTask.poll();
  assertThat(result.size(), is(2));
  mirusSourceTask.commit();

  // poll success but commit failed
  TaskConfig config = new TaskConfig(mockTaskProperties());
  // no new data: jump the clock past the restart window; the empty poll should
  // advance the effective commit time instead of flagging a failure.
  long elapseTime = config.getCommitFailureRestartMs() + 10;
  when(mockTime.milliseconds()).thenReturn(currentMillis + elapseTime);
  mirusSourceTask.poll();
  // new data coming
  mockConsumer.addRecord(new ConsumerRecord<>(TOPIC, 0, 1, new byte[] {}, new byte[] {}));
  mockConsumer.addRecord(new ConsumerRecord<>(TOPIC, 1, 1, new byte[] {}, new byte[] {}));
  mirusSourceTask.poll();

  // check commit failure
  mirusSourceTask.poll();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;

import org.apache.kafka.common.config.ConfigDef;
import org.apache.kafka.connect.connector.ConnectRecord;
import org.apache.kafka.connect.source.SourceRecord;
Expand Down

0 comments on commit 6545ce1

Please sign in to comment.