diff --git a/arroyo/processing/processor.py b/arroyo/processing/processor.py index ba87c9ab..0fe77a1c 100644 --- a/arroyo/processing/processor.py +++ b/arroyo/processing/processor.py @@ -298,6 +298,7 @@ def run(self) -> None: logger.info("Closing %r...", self.__consumer) self.__consumer.close() + self.__processor_factory.shutdown() logger.info("Processor terminated") raise @@ -446,6 +447,7 @@ def _shutdown(self) -> None: logger.info("Stopping consumer") self.__metrics_buffer.flush() self.__consumer.close() + self.__processor_factory.shutdown() logger.info("Stopped") # if there was an active processing strategy, it should be shut down diff --git a/arroyo/processing/strategies/abstract.py b/arroyo/processing/strategies/abstract.py index 33064e5f..6008809a 100644 --- a/arroyo/processing/strategies/abstract.py +++ b/arroyo/processing/strategies/abstract.py @@ -118,3 +118,11 @@ def create_with_partitions( :param partitions: A mapping of a ``Partition`` to it's most recent offset. """ raise NotImplementedError + + def shutdown(self) -> None: + """ + Custom code to execute when the ``StreamProcessor`` shuts down entirely. + + Note that this code will also be executed on crashes of the strategy. + """ + pass diff --git a/tests/processing/test_processor.py b/tests/processing/test_processor.py index eb509235..d5c57ebe 100644 --- a/tests/processing/test_processor.py +++ b/tests/processing/test_processor.py @@ -1,10 +1,10 @@ +import time from datetime import datetime, timedelta from typing import Any, Mapping, Optional, Sequence, cast from unittest import mock -import time -import pytest import py.path +import pytest from arroyo.backends.local.backend import LocalBroker from arroyo.backends.local.storages.abstract import MessageStorage @@ -123,7 +123,9 @@ def test_stream_processor_lifecycle() -> None: with pytest.raises(InvalidStateError): processor._run_once() - with assert_changes(lambda: int(consumer.close.call_count), 0, 1): + with assert_changes(lambda: int(consumer.close.call_count), 0, 1), assert_changes( + lambda: int(factory.shutdown.call_count), 0, 1 + ): processor._shutdown() assert list((type(call), call.name) for call in metrics.calls) == [ @@ -564,12 +566,14 @@ def test_healthcheck(tmpdir: py.path.local) -> None: strategy.submit.side_effect = InvalidMessage(partition, 1) factory = mock.Mock() factory.create_with_partitions.return_value = Healthcheck( - healthcheck_file=str(tmpdir.join("health.txt")), - next_step=strategy + healthcheck_file=str(tmpdir.join("health.txt")), next_step=strategy ) processor: StreamProcessor[int] = StreamProcessor( - consumer, topic, factory, IMMEDIATE, + consumer, + topic, + factory, + IMMEDIATE, ) # Assignment