diff --git a/modules/core/src/main/scala/fs2/kafka/KafkaConsumer.scala b/modules/core/src/main/scala/fs2/kafka/KafkaConsumer.scala index c52fe843b..f7887a5cc 100644 --- a/modules/core/src/main/scala/fs2/kafka/KafkaConsumer.scala +++ b/modules/core/src/main/scala/fs2/kafka/KafkaConsumer.scala @@ -12,7 +12,7 @@ import scala.collection.immutable.SortedSet import scala.concurrent.duration.FiniteDuration import scala.util.matching.Regex import cats.{Applicative, Foldable, Functor, Reducible} -import cats.data.{NonEmptySet, OptionT} +import cats.data.{Chain, NonEmptySet, OptionT} import cats.effect.* import cats.effect.implicits.* import cats.effect.std.* @@ -268,17 +268,14 @@ object KafkaConsumer { ): OnRebalance[F] = OnRebalance( onRevoked = revoked => { - val finishSignals = for { + for { finishers <- assignmentRef.modify(_.partition(entry => !revoked.contains(entry._1))) - revokeFinishers <- finishers - .toVector + revokeFinishers <- Chain.fromIterableOnce(finishers) .traverse { case (_, assignmentSignals) => assignmentSignals.signalStreamToTerminate.as(assignmentSignals.awaitStreamFinishedSignal) } } yield revokeFinishers - - finishSignals.flatMap(revokes => revokes.sequence_) }, onAssigned = assignedPartitions => { for { @@ -447,7 +444,9 @@ object KafkaConsumer { assignmentRef.updateAndGet(_ ++ assigned).flatMap(updateQueue.offer), onRevoked = revoked => initialAssignmentDone >> - assignmentRef.updateAndGet(_ -- revoked).flatMap(updateQueue.offer) + assignmentRef.updateAndGet(_ -- revoked) + .flatMap(updateQueue.offer) + .as(Chain.empty) ) Stream diff --git a/modules/core/src/main/scala/fs2/kafka/internal/KafkaConsumerActor.scala b/modules/core/src/main/scala/fs2/kafka/internal/KafkaConsumerActor.scala index e2c049625..80b1b5ba7 100644 --- a/modules/core/src/main/scala/fs2/kafka/internal/KafkaConsumerActor.scala +++ b/modules/core/src/main/scala/fs2/kafka/internal/KafkaConsumerActor.scala @@ -224,13 +224,17 @@ final private[kafka] class KafkaConsumerActor[F[_], K, V]( } .flatMap { res => val onRevoked = - res.onRebalances.foldLeft(F.unit)(_ >> _.onRevoked(revoked)) + res.onRebalances.foldLeftM(Chain.empty[F[Unit]]) { (revocationsAcc, revocationsNext) => + revocationsNext.onRevoked(revoked).map(revocationsAcc ++ _) + } res.logRevoked >> res.completeWithRecords >> res.completeWithoutRecords >> res.removeRevokedRecords >> - onRevoked.timeout(settings.sessionTimeout) //just to be extra-safe timeout this revoke + onRevoked //first we trigger all the streams to finalize + .flatMap(_.sequence_) //second we await streams termination (Eager mode returns immediately) + .timeout(settings.sessionTimeout) //just to be extra-safe timeout this revoke } } @@ -630,7 +634,7 @@ private[kafka] object KafkaConsumerActor { final case class OnRebalance[F[_]]( onAssigned: SortedSet[TopicPartition] => F[Unit], - onRevoked: SortedSet[TopicPartition] => F[Unit], + onRevoked: SortedSet[TopicPartition] => F[Chain[F[Unit]]], ) { override def toString: String =