MINOR: Fix race condition in KafkaConsumer close
We intended to make `KafkaConsumer.close()` idempotent,
but because the `closed` flag is checked without holding
the lock before the close logic begins, two or more
threads can each observe `closed=false` and attempt to
close the consumer.

Author: Jason Gustafson <[email protected]>

Reviewers: Apurva Mehta <[email protected]>, Ismael Juma <[email protected]>

Closes apache#3426 from hachikuji/minor-fix-consumer-idempotent-close
hachikuji authored and ijuma committed Jun 27, 2017
1 parent f1cc800 commit 031da88
Showing 3 changed files with 68 additions and 45 deletions.
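Conceptually, the fix moves the `closed` check inside the "light lock" that guards the consumer against multi-threaded access. The following is a minimal, hypothetical sketch of that pattern, not the actual KafkaConsumer code: the class name and comments are invented here, and the real `acquire()`/`release()` additionally track re-entrancy.

```java
import java.util.ConcurrentModificationException;
import java.util.concurrent.atomic.AtomicLong;

// Simplified illustration of the locking pattern this commit adopts.
public class IdempotentCloseSketch {

    private static final long NO_CURRENT_THREAD = -1L;
    private final AtomicLong currentThread = new AtomicLong(NO_CURRENT_THREAD);
    private volatile boolean closed = false;

    // The "light lock": a CAS on the calling thread id rather than a blocking lock.
    private void acquire() {
        long threadId = Thread.currentThread().getId();
        if (threadId != currentThread.get() && !currentThread.compareAndSet(NO_CURRENT_THREAD, threadId))
            throw new ConcurrentModificationException("Not safe for multi-threaded access");
    }

    private void release() {
        currentThread.set(NO_CURRENT_THREAD);
    }

    // Used by the regular API methods: the closed check happens *after* the
    // lock is held, so it can no longer race with close().
    private void acquireAndEnsureOpen() {
        acquire();
        if (closed) {
            release();
            throw new IllegalStateException("This consumer has already been closed.");
        }
    }

    // close() reads and flips the flag under the same lock, which makes it
    // idempotent: a later caller finds closed == true and does nothing, while
    // a caller that overlaps an in-flight close fails acquire() instead of
    // running the close logic a second time.
    public void close() {
        acquire();
        try {
            if (!closed) {
                closed = true;
                // ... release network client, coordinator, metrics, etc. ...
            }
        } finally {
            release();
        }
    }
}
```

With this ordering there is no window in which two threads can both pass an unlocked `closed` check and enter the close logic.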
@@ -804,7 +804,7 @@ private KafkaConsumer(ConsumerConfig config,
* @return The set of partitions currently assigned to this consumer
*/
public Set<TopicPartition> assignment() {
acquire();
acquireAndEnsureOpen();
try {
return Collections.unmodifiableSet(new HashSet<>(this.subscriptions.assignedPartitions()));
} finally {
@@ -818,7 +818,7 @@ public Set<TopicPartition> assignment() {
* @return The set of topics currently subscribed to
*/
public Set<String> subscription() {
acquire();
acquireAndEnsureOpen();
try {
return Collections.unmodifiableSet(new HashSet<>(this.subscriptions.subscription()));
} finally {
@@ -857,7 +857,7 @@ public Set<String> subscription() {
*/
@Override
public void subscribe(Collection<String> topics, ConsumerRebalanceListener listener) {
acquire();
acquireAndEnsureOpen();
try {
if (topics == null) {
throw new IllegalArgumentException("Topic collection to subscribe to cannot be null");
@@ -923,7 +923,7 @@ public void subscribe(Collection<String> topics) {
*/
@Override
public void subscribe(Pattern pattern, ConsumerRebalanceListener listener) {
acquire();
acquireAndEnsureOpen();
try {
if (pattern == null)
throw new IllegalArgumentException("Topic pattern to subscribe to cannot be null");
@@ -942,7 +942,7 @@ public void subscribe(Pattern pattern, ConsumerRebalanceListener listener) {
* also clears any partitions directly assigned through {@link #assign(Collection)}.
*/
public void unsubscribe() {
acquire();
acquireAndEnsureOpen();
try {
log.debug("Unsubscribed all topics or patterns and assigned partitions");
this.subscriptions.unsubscribe();
@@ -970,7 +970,7 @@ public void unsubscribe() {
*/
@Override
public void assign(Collection<TopicPartition> partitions) {
acquire();
acquireAndEnsureOpen();
try {
if (partitions == null) {
throw new IllegalArgumentException("Topic partition collection to assign to cannot be null");
@@ -1028,7 +1028,7 @@ public void assign(Collection<TopicPartition> partitions) {
*/
@Override
public ConsumerRecords<K, V> poll(long timeout) {
acquire();
acquireAndEnsureOpen();
try {
if (timeout < 0)
throw new IllegalArgumentException("Timeout must not be negative");
@@ -1134,7 +1134,7 @@ public boolean shouldBlock() {
*/
@Override
public void commitSync() {
acquire();
acquireAndEnsureOpen();
try {
coordinator.commitOffsetsSync(subscriptions.allConsumed(), Long.MAX_VALUE);
} finally {
@@ -1168,7 +1168,7 @@ public void commitSync() {
*/
@Override
public void commitSync(final Map<TopicPartition, OffsetAndMetadata> offsets) {
acquire();
acquireAndEnsureOpen();
try {
coordinator.commitOffsetsSync(new HashMap<>(offsets), Long.MAX_VALUE);
} finally {
@@ -1199,7 +1199,7 @@ public void commitAsync() {
*/
@Override
public void commitAsync(OffsetCommitCallback callback) {
acquire();
acquireAndEnsureOpen();
try {
commitAsync(subscriptions.allConsumed(), callback);
} finally {
@@ -1224,7 +1224,7 @@ public void commitAsync(OffsetCommitCallback callback) {
*/
@Override
public void commitAsync(final Map<TopicPartition, OffsetAndMetadata> offsets, OffsetCommitCallback callback) {
acquire();
acquireAndEnsureOpen();
try {
log.debug("Committing offsets: {} ", offsets);
coordinator.commitOffsetsAsync(new HashMap<>(offsets), callback);
@@ -1240,11 +1240,11 @@ public void commitAsync(final Map<TopicPartition, OffsetAndMetadata> offsets, Of
*/
@Override
public void seek(TopicPartition partition, long offset) {
if (offset < 0) {
throw new IllegalArgumentException("seek offset must not be a negative number");
}
acquire();
acquireAndEnsureOpen();
try {
if (offset < 0)
throw new IllegalArgumentException("seek offset must not be a negative number");

log.debug("Seeking to offset {} for partition {}", offset, partition);
this.subscriptions.seek(partition, offset);
} finally {
@@ -1258,7 +1258,7 @@ public void seek(TopicPartition partition, long offset) {
* If no partition is provided, seek to the first offset for all of the currently assigned partitions.
*/
public void seekToBeginning(Collection<TopicPartition> partitions) {
acquire();
acquireAndEnsureOpen();
try {
Collection<TopicPartition> parts = partitions.size() == 0 ? this.subscriptions.assignedPartitions() : partitions;
for (TopicPartition tp : parts) {
@@ -1279,7 +1279,7 @@ public void seekToBeginning(Collection<TopicPartition> partitions) {
* of the first message with an open transaction.
*/
public void seekToEnd(Collection<TopicPartition> partitions) {
acquire();
acquireAndEnsureOpen();
try {
Collection<TopicPartition> parts = partitions.size() == 0 ? this.subscriptions.assignedPartitions() : partitions;
for (TopicPartition tp : parts) {
@@ -1307,7 +1307,7 @@ public void seekToEnd(Collection<TopicPartition> partitions) {
* @throws org.apache.kafka.common.KafkaException for any other unrecoverable errors
*/
public long position(TopicPartition partition) {
acquire();
acquireAndEnsureOpen();
try {
if (!this.subscriptions.isAssigned(partition))
throw new IllegalArgumentException("You can only check the position for partitions assigned to this consumer.");
@@ -1341,7 +1341,7 @@ public long position(TopicPartition partition) {
*/
@Override
public OffsetAndMetadata committed(TopicPartition partition) {
acquire();
acquireAndEnsureOpen();
try {
Map<TopicPartition, OffsetAndMetadata> offsets = coordinator.fetchCommittedOffsets(Collections.singleton(partition));
return offsets.get(partition);
@@ -1375,7 +1375,7 @@ public OffsetAndMetadata committed(TopicPartition partition) {
*/
@Override
public List<PartitionInfo> partitionsFor(String topic) {
acquire();
acquireAndEnsureOpen();
try {
Cluster cluster = this.metadata.fetch();
List<PartitionInfo> parts = cluster.partitionsForTopic(topic);
@@ -1405,7 +1405,7 @@ public List<PartitionInfo> partitionsFor(String topic) {
*/
@Override
public Map<String, List<PartitionInfo>> listTopics() {
acquire();
acquireAndEnsureOpen();
try {
return fetcher.getAllTopicMetadata(requestTimeoutMs);
} finally {
@@ -1422,7 +1422,7 @@ public Map<String, List<PartitionInfo>> listTopics() {
*/
@Override
public void pause(Collection<TopicPartition> partitions) {
acquire();
acquireAndEnsureOpen();
try {
for (TopicPartition partition: partitions) {
log.debug("Pausing partition {}", partition);
@@ -1441,7 +1441,7 @@ public void pause(Collection<TopicPartition> partitions) {
*/
@Override
public void resume(Collection<TopicPartition> partitions) {
acquire();
acquireAndEnsureOpen();
try {
for (TopicPartition partition: partitions) {
log.debug("Resuming partition {}", partition);
@@ -1459,7 +1459,7 @@ public void resume(Collection<TopicPartition> partitions) {
*/
@Override
public Set<TopicPartition> paused() {
acquire();
acquireAndEnsureOpen();
try {
return Collections.unmodifiableSet(subscriptions.pausedPartitions());
} finally {
@@ -1487,14 +1487,19 @@ public Set<TopicPartition> paused() {
*/
@Override
public Map<TopicPartition, OffsetAndTimestamp> offsetsForTimes(Map<TopicPartition, Long> timestampsToSearch) {
for (Map.Entry<TopicPartition, Long> entry : timestampsToSearch.entrySet()) {
// we explicitly exclude the earliest and latest offset here so the timestamp in the returned
// OffsetAndTimestamp is always positive.
if (entry.getValue() < 0)
throw new IllegalArgumentException("The target time for partition " + entry.getKey() + " is " +
entry.getValue() + ". The target time cannot be negative.");
acquireAndEnsureOpen();
try {
for (Map.Entry<TopicPartition, Long> entry : timestampsToSearch.entrySet()) {
// we explicitly exclude the earliest and latest offset here so the timestamp in the returned
// OffsetAndTimestamp is always positive.
if (entry.getValue() < 0)
throw new IllegalArgumentException("The target time for partition " + entry.getKey() + " is " +
entry.getValue() + ". The target time cannot be negative.");
}
return fetcher.getOffsetsByTimes(timestampsToSearch, requestTimeoutMs);
} finally {
release();
}
return fetcher.getOffsetsByTimes(timestampsToSearch, requestTimeoutMs);
}

/**
@@ -1510,7 +1515,12 @@ public Map<TopicPartition, OffsetAndTimestamp> offsetsForTimes(Map<TopicPartitio
*/
@Override
public Map<TopicPartition, Long> beginningOffsets(Collection<TopicPartition> partitions) {
return fetcher.beginningOffsets(partitions, requestTimeoutMs);
acquireAndEnsureOpen();
try {
return fetcher.beginningOffsets(partitions, requestTimeoutMs);
} finally {
release();
}
}

/**
@@ -1532,7 +1542,13 @@ public Map<TopicPartition, Long> beginningOffsets(Collection<TopicPartition> par
*/
@Override
public Map<TopicPartition, Long> endOffsets(Collection<TopicPartition> partitions) {
return fetcher.endOffsets(partitions, requestTimeoutMs);
acquireAndEnsureOpen();
try {
return fetcher.endOffsets(partitions, requestTimeoutMs);
} finally {
release();
}

}

/**
@@ -1564,13 +1580,14 @@ public void close() {
* @throws IllegalArgumentException If the <code>timeout</code> is negative.
*/
public void close(long timeout, TimeUnit timeUnit) {
if (closed)
return;
if (timeout < 0)
throw new IllegalArgumentException("The timeout cannot be negative.");
acquire();
try {
close(timeUnit.toMillis(timeout), false);
if (!closed) {
closed = true;
close(timeUnit.toMillis(timeout), false);
}
} finally {
release();
}
@@ -1599,7 +1616,6 @@ private ClusterResourceListeners configureClusterResourceListeners(Deserializer<
private void close(long timeoutMs, boolean swallowException) {
log.trace("Closing the Kafka consumer.");
AtomicReference<Throwable> firstException = new AtomicReference<>();
this.closed = true;
try {
if (coordinator != null)
coordinator.close(Math.min(timeoutMs, requestTimeoutMs));
@@ -1651,23 +1667,25 @@ private void updateFetchPositions(Set<TopicPartition> partitions) {
}
}

/*
* Check that the consumer hasn't been closed.
/**
* Acquire the light lock and ensure that the consumer hasn't been closed.
* @throws IllegalStateException If the consumer has been closed
*/
private void ensureNotClosed() {
if (this.closed)
private void acquireAndEnsureOpen() {
acquire();
if (this.closed) {
release();
throw new IllegalStateException("This consumer has already been closed.");
}
}

/**
* Acquire the light lock protecting this consumer from multi-threaded access. Instead of blocking
* when the lock is not available, however, we just throw an exception (since multi-threaded usage is not
* supported).
* @throws IllegalStateException if the consumer has been closed
* @throws ConcurrentModificationException if another thread already has the lock
*/
private void acquire() {
ensureNotClosed();
long threadId = Thread.currentThread().getId();
if (threadId != currentThread.get() && !currentThread.compareAndSet(NO_CURRENT_THREAD, threadId))
throw new ConcurrentModificationException("KafkaConsumer is not safe for multi-threaded access");
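For context only (this snippet is not part of the commit), the externally visible contract the change preserves looks roughly like this; the broker address, group id, and deserializer choices below are placeholder assumptions.

```java
import java.util.Properties;

import org.apache.kafka.clients.consumer.ConsumerConfig;
import org.apache.kafka.clients.consumer.KafkaConsumer;
import org.apache.kafka.common.serialization.StringDeserializer;

public class CloseTwiceExample {
    public static void main(String[] args) {
        Properties props = new Properties();
        props.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); // placeholder
        props.put(ConsumerConfig.GROUP_ID_CONFIG, "example-group");           // placeholder
        props.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class.getName());
        props.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class.getName());

        KafkaConsumer<String, String> consumer = new KafkaConsumer<>(props);
        consumer.close();
        consumer.close(); // idempotent: the second (and any later) close is a no-op

        try {
            consumer.poll(100); // any other API call after close fails fast
        } catch (IllegalStateException e) {
            System.out.println(e.getMessage()); // "This consumer has already been closed."
        }
    }
}
```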
@@ -1248,6 +1248,7 @@ public void closeShouldBeIdempotent() {
KafkaConsumer<byte[], byte[]> consumer = newConsumer();
consumer.close();
consumer.close();
consumer.close();
}

@Test
@@ -333,7 +333,7 @@ class ConsumerBounceTest extends IntegrationTestHarness with Logging {
val rebalanceFuture = createConsumerToRebalance()

// consumer1 should leave group and close immediately even though rebalance is in progress
submitCloseAndValidate(consumer1, Long.MaxValue, None, Some(gracefulCloseTimeMs))
val closeFuture1 = submitCloseAndValidate(consumer1, Long.MaxValue, None, Some(gracefulCloseTimeMs))

// Rebalance should complete without waiting for consumer1 to timeout since consumer1 has left the group
waitForRebalance(2000, rebalanceFuture, consumer2)
@@ -343,7 +343,11 @@ class ConsumerBounceTest extends IntegrationTestHarness with Logging {
servers.foreach(server => killBroker(server.config.brokerId))

// consumer2 should close immediately without LeaveGroup request since there are no brokers available
submitCloseAndValidate(consumer2, Long.MaxValue, None, Some(0))
val closeFuture2 = submitCloseAndValidate(consumer2, Long.MaxValue, None, Some(0))

// Ensure futures complete to avoid concurrent shutdown attempt during test cleanup
closeFuture1.get(2000, TimeUnit.MILLISECONDS)
closeFuture2.get(2000, TimeUnit.MILLISECONDS)
}

private def createConsumer(groupId: String) : KafkaConsumer[Array[Byte], Array[Byte]] = {