From ebf9fc123807997c327dd8e8ce7d12b46470f9c1 Mon Sep 17 00:00:00 2001 From: Kevin Donahue Date: Fri, 20 Dec 2024 21:13:46 -0500 Subject: [PATCH 1/2] include consumer in github action --- .github/workflows/python-app.yml | 1 + .idea/.gitignore | 3 - .idea/ctts.iml | 19 ------ .../inspectionProfiles/profiles_settings.xml | 6 -- .idea/misc.xml | 7 --- .idea/modules.xml | 8 --- .idea/vcs.xml | 6 -- pc/pc/main.py | 5 +- pc/pc/persistence/db.py | 2 +- pc/tests/fixtures.py | 27 ++++++++ pc/tests/test_consumer.py | 43 +++---------- pp/README.md | 2 + pp/pp/cells/cell_covering.py | 13 ++-- pp/pp/cells/cell_publisher.py | 4 +- pp/pp/main.py | 5 +- pp/tests/fixtures.py | 41 +++++++++++++ pp/tests/test_covering.py | 29 +++++++++ pp/tests/test_publisher.py | 61 +++++-------------- update-open-meteo-data.sh | 4 +- 19 files changed, 139 insertions(+), 147 deletions(-) delete mode 100644 .idea/.gitignore delete mode 100644 .idea/ctts.iml delete mode 100644 .idea/inspectionProfiles/profiles_settings.xml delete mode 100644 .idea/misc.xml delete mode 100644 .idea/modules.xml delete mode 100644 .idea/vcs.xml create mode 100644 pc/tests/fixtures.py create mode 100644 pp/tests/fixtures.py create mode 100644 pp/tests/test_covering.py diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index be1f509..94a9af9 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -22,6 +22,7 @@ jobs: package-dir: - api - pp + - pc steps: - uses: actions/checkout@v4 diff --git a/.idea/.gitignore b/.idea/.gitignore deleted file mode 100644 index 26d3352..0000000 --- a/.idea/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -# Default ignored files -/shelf/ -/workspace.xml diff --git a/.idea/ctts.iml b/.idea/ctts.iml deleted file mode 100644 index 21363d4..0000000 --- a/.idea/ctts.iml +++ /dev/null @@ -1,19 +0,0 @@ - - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml deleted file mode 100644 index 105ce2d..0000000 --- a/.idea/inspectionProfiles/profiles_settings.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml deleted file mode 100644 index efcf9cc..0000000 --- a/.idea/misc.xml +++ /dev/null @@ -1,7 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml deleted file mode 100644 index 74201b3..0000000 --- a/.idea/modules.xml +++ /dev/null @@ -1,8 +0,0 @@ - - - - - - - - \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml deleted file mode 100644 index 35eb1dd..0000000 --- a/.idea/vcs.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/pc/pc/main.py b/pc/pc/main.py index 33562a6..ed441e1 100644 --- a/pc/pc/main.py +++ b/pc/pc/main.py @@ -2,7 +2,7 @@ import logging import typing -from pc.persistence.db import create_connection_pool, create_brightness_table +from pc.persistence.db import create_pg_connection_pool, create_brightness_table from pc.persistence.models import BrightnessObservation from pc.consumer.consumer import Consumer from pc.config import amqp_url, prediction_queue, cycle_queue @@ -16,11 +16,12 @@ def on_cycle_completion(brightness_observation: BrightnessObservation): async def main(): - pool = await create_connection_pool() + pool = await create_pg_connection_pool() if pool is None: raise ValueError("no connection pool!") await create_brightness_table(pool) + consumer = Consumer( url=amqp_url, prediction_queue=prediction_queue, diff --git a/pc/pc/persistence/db.py b/pc/pc/persistence/db.py index b3ae500..32bcc2b 100644 --- a/pc/pc/persistence/db.py +++ b/pc/pc/persistence/db.py @@ -10,7 +10,7 @@ brightness_observation_table = "brightness_observation" -async def create_connection_pool() -> typing.Optional[asyncpg.Pool]: +async def create_pg_connection_pool() -> typing.Optional[asyncpg.Pool]: pool = await asyncpg.create_pool( user=pg_user, password=pg_password, diff --git a/pc/tests/fixtures.py b/pc/tests/fixtures.py new file mode 100644 index 0000000..7296191 --- /dev/null +++ b/pc/tests/fixtures.py @@ -0,0 +1,27 @@ +from unittest.mock import AsyncMock, patch + +import pytest +from pc.consumer.consumer import Consumer + +@pytest.fixture +async def mock_asyncpg_pool(): + with patch("asyncpg.create_pool") as mock_create_pool: + mock_pool = AsyncMock() + mock_create_pool.return_value = mock_pool + + mock_connection = AsyncMock() + mock_pool.acquire.return_value.__aenter__.return_value = mock_connection + yield mock_pool + +@pytest.fixture +def consumer(mock_asyncpg_pool): + prediction_queue="prediction" + cycle_queue="cycle" + + return Consumer( + url="amqp://localhost", + prediction_queue=prediction_queue, + cycle_queue=cycle_queue, + connection_pool=mock_asyncpg_pool, + on_cycle_completion=lambda _: None + ) diff --git a/pc/tests/test_consumer.py b/pc/tests/test_consumer.py index 8933f61..8ead988 100644 --- a/pc/tests/test_consumer.py +++ b/pc/tests/test_consumer.py @@ -2,40 +2,15 @@ import pytest import asyncpg -from aio_pika import Message -from pc.consumer.consumer import Consumer - -@pytest.fixture -async def mock_asyncpg_pool(): - with patch("asyncpg.create_pool") as mock_create_pool: - mock_pool = AsyncMock() - mock_create_pool.return_value = mock_pool - - mock_connection = AsyncMock() - mock_pool.acquire.return_value.__aenter__.return_value = mock_connection - yield mock_pool - -amqp_url="amqp://localhost" - -@pytest.fixture -def consumer(mock_asyncpg_pool): - prediction_queue="prediction" - cycle_queue="cycle" - return Consumer( - url=amqp_url, - prediction_queue=prediction_queue, - cycle_queue=cycle_queue, - connection_pool=mock_asyncpg_pool, - on_cycle_completion=lambda _: None - ) +from .fixtures import * +@patch("pc.consumer.consumer.Consumer.connect", new_callable=AsyncMock) @pytest.mark.asyncio -async def test_consumer_connection(consumer): - with patch("pc.consumer.consumer.Consumer.connect", new_callable=AsyncMock) as mock_connect: - mock_connection = AsyncMock() - mock_channel = AsyncMock() - mock_connect.return_value = mock_connection - mock_connection.channel.return_value = mock_channel - await consumer.connect() - mock_connect.assert_called_once() +async def test_consumer_can_connect(mock_connect, consumer): + mock_connection = AsyncMock() + mock_channel = AsyncMock() + mock_connect.return_value = mock_connection + mock_connection.channel.return_value = mock_channel + await consumer.connect() + mock_connect.assert_called_once() diff --git a/pp/README.md b/pp/README.md index 1e4f002..d7b641c 100644 --- a/pp/README.md +++ b/pp/README.md @@ -4,6 +4,8 @@ Retrieves sky brightness prediction across in-polygon h3 cells and puts results on rabbitmq. +> in-polygon is the interior of `land.geojson` + ## monitoring see the rabbitmq [dashboard](http://localhost:15672/#/) diff --git a/pp/pp/cells/cell_covering.py b/pp/pp/cells/cell_covering.py index aa91e70..280ba36 100644 --- a/pp/pp/cells/cell_covering.py +++ b/pp/pp/cells/cell_covering.py @@ -7,17 +7,18 @@ from ..config import resolution -def get_cell_id(lat, lon, resolution) -> str: - return h3.geo_to_h3(lat, lon, resolution=resolution) - - class CellCovering: - def __init__(self): - with open(Path(__file__).parent / "land.geojson", "r") as file: + def __init__(self, path_to_geojson: Path = Path(__file__).parent / "land.geojson"): + with open(path_to_geojson, "r") as file: geojson = json.load(file) self.polygons = [CellCovering.get_polygon_of_feature(f) for f in geojson["features"]] + @staticmethod + def get_cell_id(lat, lon, resolution) -> str: + return h3.geo_to_h3(lat, lon, resolution=resolution) + + @staticmethod def get_polygon_of_feature(feature: typing.Dict) -> typing.Dict: polygon = shape(feature["geometry"]) diff --git a/pp/pp/cells/cell_publisher.py b/pp/pp/cells/cell_publisher.py index a97b32b..b026fd8 100644 --- a/pp/pp/cells/cell_publisher.py +++ b/pp/pp/cells/cell_publisher.py @@ -8,7 +8,7 @@ from h3 import h3_to_geo from pika.adapters.blocking_connection import BlockingChannel -from ..cells.cell_covering import CellCovering, get_cell_id +from ..cells.cell_covering import CellCovering from ..config import resolution from ..stubs.brightness_service_pb2_grpc import BrightnessServiceStub from ..stubs import brightness_service_pb2 @@ -50,7 +50,7 @@ def publish_cell_brightness_message(self, cell) -> None: uuid=response.uuid, lat=lat, lon=lon, - h3_id=get_cell_id(lat, lon, resolution=resolution), + h3_id=CellCovering.get_cell_id(lat, lon, resolution=resolution), mpsas=response.mpsas, timestamp_utc=response.utc_iso, ) diff --git a/pp/pp/main.py b/pp/pp/main.py index c13b29c..edb0f49 100644 --- a/pp/pp/main.py +++ b/pp/pp/main.py @@ -5,14 +5,13 @@ from pika.exceptions import AMQPConnectionError from .config import rabbitmq_host, prediction_queue, cycle_queue, api_port, api_host -from .cells.cell_covering import CellCovering from .cells.cell_publisher import CellPublisher logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s") log = logging.getLogger(__name__) -def main(): +def run_publisher(): try: connection = pika.BlockingConnection(pika.ConnectionParameters(rabbitmq_host)) @@ -43,4 +42,4 @@ def main(): if __name__ == "__main__": - main() + run_publisher() diff --git a/pp/tests/fixtures.py b/pp/tests/fixtures.py new file mode 100644 index 0000000..4b2390b --- /dev/null +++ b/pp/tests/fixtures.py @@ -0,0 +1,41 @@ +from unittest.mock import MagicMock +import uuid + +import pytest + +from pp.cells.cell_publisher import CellPublisher + +@pytest.fixture +def mock_grpc_client(mocker): + from datetime import datetime, timezone + from pp.stubs.brightness_service_pb2 import BrightnessObservation + + mock_client_stub = mocker.MagicMock() + + mock_brightness_observation = BrightnessObservation() + mock_brightness_observation.uuid = str(uuid.uuid4()) + mock_brightness_observation.utc_iso = datetime.now(timezone.utc).isoformat() + mock_brightness_observation.mpsas = 10. + + mock_client_stub.GetBrightnessObservation.return_value = mock_brightness_observation + + mocker.patch("pp.cells.cell_publisher.BrightnessServiceStub", return_value=mock_client_stub) + return mock_client_stub + + +@pytest.fixture +def mock_pika_channel(mocker): + channel_mock = MagicMock() + connection_mock = MagicMock() + connection_mock.channel.return_value = channel_mock + mocker.patch("pika.BlockingConnection", return_value=connection_mock) + return channel_mock + + +@pytest.fixture +def publisher(mock_grpc_client, mock_pika_channel): + return CellPublisher(api_host="localhost", + api_port=50051, + channel=mock_pika_channel, + prediction_queue="prediction", + cycle_queue="cycle") diff --git a/pp/tests/test_covering.py b/pp/tests/test_covering.py new file mode 100644 index 0000000..f288112 --- /dev/null +++ b/pp/tests/test_covering.py @@ -0,0 +1,29 @@ +import json +from pathlib import Path + +import pytest +from pp.cells.cell_covering import CellCovering + +@pytest.mark.parametrize("geojson_path", [ + (Path.cwd() / "pp" / "cells" / "land.geojson") +]) +def test_cell_covering_polygons_is_one_to_one_with_features(geojson_path): + cell_covering = CellCovering(path_to_geojson=geojson_path) + with open(geojson_path) as f: + gj = json.load(f) + + assert len(cell_covering.polygons) == len(gj["features"]) + +@pytest.mark.parametrize("geojson_path", [ + (Path.cwd() / "pp" / "cells" / "land.geojson") +]) +def test_cell_covering_set_nonempty(geojson_path): + cell_covering = CellCovering(path_to_geojson=geojson_path) + assert bool(cell_covering.covering) + +@pytest.mark.parametrize("geojson_path", [ + (Path.cwd() / "pp" / "cells" / "fake.geojson") +]) +def test_cell_covering_with_bad_geojson_path(geojson_path): + with pytest.raises(FileNotFoundError): + cell_covering = CellCovering(path_to_geojson=geojson_path) diff --git a/pp/tests/test_publisher.py b/pp/tests/test_publisher.py index ea27d3d..0630881 100644 --- a/pp/tests/test_publisher.py +++ b/pp/tests/test_publisher.py @@ -1,55 +1,22 @@ -from unittest.mock import MagicMock from datetime import datetime, timedelta -import uuid import pytest - -from pp.cells.cell_publisher import CellPublisher -from pp.cells.cell_covering import CellCovering - - -@pytest.fixture -def mock_grpc_client(mocker): - from datetime import datetime, timezone - from pp.stubs.brightness_service_pb2 import BrightnessObservation - - mock_client_stub = mocker.MagicMock() - - mock_brightness_observation = BrightnessObservation() - mock_brightness_observation.uuid = str(uuid.uuid4()) - mock_brightness_observation.utc_iso = datetime.now(timezone.utc).isoformat() - mock_brightness_observation.mpsas = 10. - - mock_client_stub.GetBrightnessObservation.return_value = mock_brightness_observation - - mocker.patch("pp.cells.cell_publisher.BrightnessServiceStub", return_value=mock_client_stub) - return mock_client_stub - - -@pytest.fixture -def mock_pika_channel(mocker): - channel_mock = MagicMock() - connection_mock = MagicMock() - connection_mock.channel.return_value = channel_mock - mocker.patch("pika.BlockingConnection", return_value=connection_mock) - return channel_mock - - -@pytest.fixture -def publisher(mock_grpc_client, mock_pika_channel): - return CellPublisher(api_host="localhost", - api_port=50051, - channel=mock_pika_channel, - prediction_queue="prediction", - cycle_queue="cycle") - -def test_brightness_message_publish(publisher, mock_pika_channel): - cell = "89283082813ffff" - publisher.publish_cell_brightness_message(cell) +from .fixtures import * + +@pytest.mark.parametrize("cell_id", [ + ("89283082813ffff"), + ("8928308280fffff"), + ("89283082807ffff"), +]) +def test_can_publish_cell_brightness(cell_id, publisher, mock_pika_channel): + publisher.publish_cell_brightness_message(cell_id) mock_pika_channel.basic_publish.assert_called_once() -def test_cycle_completion_message_publish(publisher, mock_pika_channel): - then = datetime.now() - timedelta(minutes=5) +@pytest.mark.parametrize("minutes_ago", [ + (i) for i in range(1, 10) +]) +def test_can_publish_cycle_complete(minutes_ago, publisher, mock_pika_channel): + then = datetime.now() - timedelta(minutes=minutes_ago) now = datetime.now() publisher.publish_cycle_completion_message(then, now) mock_pika_channel.basic_publish.assert_called_once() diff --git a/update-open-meteo-data.sh b/update-open-meteo-data.sh index 39c7673..8f2521d 100755 --- a/update-open-meteo-data.sh +++ b/update-open-meteo-data.sh @@ -4,10 +4,8 @@ volume_name="open-meteo-data" if docker volume ls -q | grep -q "^${volume_name}$"; then echo "volume $volume_name exists; updating volume" - docker run -it --rm -v open-meteo-data:/app/data ghcr.io/open-meteo/open-meteo sync ecmwf_ifs04 cloud_cover,temperature_2m + docker run -it --rm -v open-meteo-data:/app/data ghcr.io/open-meteo/open-meteo sync ecmwf_ifs04 cloud_cover,temperature_2m else echo "$volume_name does not exist and must be created" exit 1 fi - - From a04d81fecad80d7ad492e42e0b66491d96e0870d Mon Sep 17 00:00:00 2001 From: Kevin Donahue Date: Sat, 21 Dec 2024 12:10:21 -0500 Subject: [PATCH 2/2] add better consumer test --- README.md | 11 ++++---- docker-compose.yml | 2 +- pc/pc/config.py | 2 ++ pc/pc/consumer/consumer.py | 18 ++++++++----- pc/pc/main.py | 18 ++++++------- pc/pc/persistence/db.py | 6 ++--- pc/tests/fixtures.py | 55 +++++++++++++++++++++++++------------- pc/tests/test_consumer.py | 23 ++++++++++------ 8 files changed, 83 insertions(+), 52 deletions(-) diff --git a/README.md b/README.md index e8f6e41..ab3013a 100644 --- a/README.md +++ b/README.md @@ -57,11 +57,11 @@ After rabbitmq starts up, the producer and consumer containers will start up, at which point you should see output like this: ```log -producer-1 | 2024-11-13 03:01:02,478 [INFO] publishing {'uuid': 'c6df89c5-a4fa-48fc-bfd8-11d08494902f', 'lat': 16.702868303031234, 'lon': -13.374845104752373, 'h3_id': '8055fffffffffff', 'mpsas': 6.862955570220947, 'timestamp_utc': '2024-11-13T03:01:02.478000+00:00'} to brightness.prediction -producer-1 | 2024-11-13 03:01:02,553 [INFO] publishing {'uuid': '9b5f2e8b-c22d-4d05-900e-0156f78632ce', 'lat': 26.283628653081813, 'lon': 62.954274989658984, 'h3_id': '8043fffffffffff', 'mpsas': 9.472949028015137, 'timestamp_utc': '2024-11-13T03:01:02.552848+00:00'} to brightness.prediction -producer-1 | 2024-11-13 03:01:02,625 [INFO] publishing {'uuid': 'fbbc3cd5-839d-43de-a7c4-8f51100679fd', 'lat': -4.530154895350926, 'lon': -42.02241568705745, 'h3_id': '8081fffffffffff', 'mpsas': 9.065463066101074, 'timestamp_utc': '2024-11-13T03:01:02.624759+00:00'} to brightness.prediction -producer-1 | 2024-11-13 03:01:02,626 [INFO] publishing {'start_time_utc': '2024-11-13T03:01:00.114586+00:00', 'end_time_utc': '2024-11-13T03:01:02.626208+00:00', 'duration_s': 2} to brightness.cycle -consumer-1 | 2024-11-13 03:01:02,631 [INFO] cycle completed with {'uuid': '4bb0c627-596c-42be-a93a-26f36c5ca3c1', 'lat': 55.25746462939812, 'lon': 127.08774514928741, 'h3_id': '8015fffffffffff', 'mpsas': 23.763256072998047, 'timestamp_utc': datetime.datetime(2024, 11, 13, 3, 1, 1, 129155, tzinfo=datetime.timezone.utc)} +producer-1 | 2024-12-21 17:08:55,237 [INFO] publishing {'uuid': '0cdacdcb-dcf3-4d5c-9e60-94d397d89840', 'lat': 69.66345294982115, 'lon': -30.968044606549025, 'h3_id': '8007fffffffffff', 'mpsas': 24.703824996948242, 'timestamp_utc': '2024-12-21T17:08:55.236185+00:00'} to brightness.prediction +producer-1 | 2024-12-21 17:08:55,355 [INFO] publishing {'uuid': 'f16a7b7c-039d-44d6-b764-fc37fadad1b7', 'lat': 26.80710329336693, 'lon': 109.167486033384, 'h3_id': '8041fffffffffff', 'mpsas': 10.82265853881836, 'timestamp_utc': '2024-12-21T17:08:55.354661+00:00'} to brightness.prediction +producer-1 | 2024-12-21 17:08:55,356 [INFO] publishing {'start_time_utc': '2024-12-21T17:08:34.174937+00:00', 'end_time_utc': '2024-12-21T17:08:55.356353+00:00', 'duration_s': 21} to brightness.cycle +producer-1 | 2024-12-21 17:08:55,502 [INFO] publishing {'uuid': 'bc236db7-dd78-43cb-925b-78ea7c777f5e', 'lat': 16.702868303031234, 'lon': -13.374845104752373, 'h3_id': '8055fffffffffff', 'mpsas': 6.5024333000183105, 'timestamp_utc': '2024-12-21T17:08:55.501490+00:00'} to brightness.prediction +consumer-1 | 2024-12-21 17:08:55,507 [INFO] cycle completed with max observation {'uuid': '0fbfe7cd-4b49-49b3-9c51-b5560706a2d8', 'lat': -69.66345294982115, 'lon': 149.03195539345094, 'h3_id': '80edfffffffffff', 'mpsas': 28.068134307861328, 'timestamp_utc': datetime.datetime(2024, 12, 21, 17, 8, 53, 2272, tzinfo=datetime.timezone.utc)} ``` The above output means: @@ -94,3 +94,4 @@ producer: ## licensing This project is licensed under the AGPL-3.0 license. + diff --git a/docker-compose.yml b/docker-compose.yml index 5cc99ad..626630c 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -11,7 +11,7 @@ services: - postgres-data:/var/lib/postgresql/data rabbitmq: - image: "rabbitmq:alpine" + image: "rabbitmq:latest" environment: RABBITMQ_DEFAULT_USER: "guest" RABBITMQ_DEFAULT_PASS: "guest" diff --git a/pc/pc/config.py b/pc/pc/config.py index bff1e2b..8d72a47 100644 --- a/pc/pc/config.py +++ b/pc/pc/config.py @@ -15,3 +15,5 @@ cycle_queue = os.getenv("AMQP_CYCLE_QUEUE", "brightness.cycle") amqp_url = f"amqp://{rabbitmq_user}:{rabbitmq_password}@{rabbitmq_host}" + +brightness_observation_table = "brightness_observation" diff --git a/pc/pc/consumer/consumer.py b/pc/pc/consumer/consumer.py index 875d2f8..167148c 100644 --- a/pc/pc/consumer/consumer.py +++ b/pc/pc/consumer/consumer.py @@ -35,18 +35,22 @@ async def connect(self): log.warning("exiting") sys.exit(1) - async def consume(self): + async def consume_from_queues(self): + """consume data from the prediction and cycle queues""" if self.connection is None: raise ValueError("there is no connection!") async with self.connection: channel = await self.connection.channel() - - prediction_queue = await channel.declare_queue(self._prediction_queue) - await prediction_queue.consume(self._on_prediction_message, no_ack=True) - - cycle_queue = await channel.declare_queue(self._cycle_queue) - await cycle_queue.consume(self._on_cycle_message, no_ack=True) + queues = { + self._prediction_queue: self._on_prediction_message, + self._cycle_queue: self._on_cycle_message + } + + for queue_name, handler in queues.items(): + log.info(f"consuming from {queue_name}") + queue = await channel.declare_queue(queue_name) + await queue.consume(handler, no_ack=True) log.info("waiting on messages") await asyncio.Future() diff --git a/pc/pc/main.py b/pc/pc/main.py index ed441e1..2a5d37c 100644 --- a/pc/pc/main.py +++ b/pc/pc/main.py @@ -2,7 +2,7 @@ import logging import typing -from pc.persistence.db import create_pg_connection_pool, create_brightness_table +from pc.persistence.db import create_pg_connection_pool, setup_table from pc.persistence.models import BrightnessObservation from pc.consumer.consumer import Consumer from pc.config import amqp_url, prediction_queue, cycle_queue @@ -11,16 +11,16 @@ logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s") log = logging.getLogger(__name__) -def on_cycle_completion(brightness_observation: BrightnessObservation): - log.info(f"cycle completed with {brightness_observation.model_dump()}") +def on_cycle_completion(max_observation: BrightnessObservation): + # TODO communicate this result to cycle recipients + log.info(f"cycle completed with max observation {max_observation.model_dump()}") -async def main(): +async def consume_brightness(): pool = await create_pg_connection_pool() if pool is None: raise ValueError("no connection pool!") - - await create_brightness_table(pool) + await setup_table(pool) consumer = Consumer( url=amqp_url, @@ -30,11 +30,11 @@ async def main(): on_cycle_completion=on_cycle_completion ) await consumer.connect() - await consumer.consume() + await consumer.consume_from_queues() if __name__ == "__main__": try: - asyncio.run(main()) + asyncio.run(consume_brightness()) except Exception as e: - log.error(f"failed to run: {e}") + log.error(f"failed to consume brightness: {e}") diff --git a/pc/pc/persistence/db.py b/pc/pc/persistence/db.py index 32bcc2b..ed679a5 100644 --- a/pc/pc/persistence/db.py +++ b/pc/pc/persistence/db.py @@ -3,13 +3,11 @@ import asyncpg -from ..config import pg_host,pg_port,pg_user,pg_password,pg_database +from ..config import pg_host, pg_port, pg_user, pg_password, pg_database, brightness_observation_table from .models import BrightnessObservation, CellCycle log = logging.getLogger(__name__) -brightness_observation_table = "brightness_observation" - async def create_pg_connection_pool() -> typing.Optional[asyncpg.Pool]: pool = await asyncpg.create_pool( user=pg_user, @@ -22,7 +20,7 @@ async def create_pg_connection_pool() -> typing.Optional[asyncpg.Pool]: ) return pool -async def create_brightness_table(pool: asyncpg.Pool): +async def setup_table(pool: asyncpg.Pool): async with pool.acquire() as conn: await conn.execute( f""" diff --git a/pc/tests/fixtures.py b/pc/tests/fixtures.py index 7296191..e8c78f3 100644 --- a/pc/tests/fixtures.py +++ b/pc/tests/fixtures.py @@ -1,27 +1,46 @@ -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest from pc.consumer.consumer import Consumer + +@pytest.fixture +def mock_connection(): + connection = AsyncMock() + channel = AsyncMock() + connection.channel.return_value = channel + return connection + +@pytest.fixture +def mock_channel(mock_connection): + return mock_connection.channel.return_value + @pytest.fixture -async def mock_asyncpg_pool(): - with patch("asyncpg.create_pool") as mock_create_pool: - mock_pool = AsyncMock() - mock_create_pool.return_value = mock_pool +def mock_queues(): + prediction_queue = AsyncMock() + cycle_queue = AsyncMock() + return {"prediction": prediction_queue, "cycle": cycle_queue} - mock_connection = AsyncMock() - mock_pool.acquire.return_value.__aenter__.return_value = mock_connection - yield mock_pool +@pytest.fixture +def mock_pool(): + return AsyncMock() + +@pytest.fixture +def mock_handler(): + return AsyncMock() + +@pytest.fixture +def mock_shutdown(): + return MagicMock() @pytest.fixture -def consumer(mock_asyncpg_pool): - prediction_queue="prediction" - cycle_queue="cycle" - - return Consumer( - url="amqp://localhost", - prediction_queue=prediction_queue, - cycle_queue=cycle_queue, - connection_pool=mock_asyncpg_pool, - on_cycle_completion=lambda _: None +def consumer(mock_connection, mock_pool, mock_handler, mock_shutdown): + consumer = Consumer( + url="amqp://test", + prediction_queue="prediction", + cycle_queue="cycle", + connection_pool=mock_pool, + on_cycle_completion=mock_handler, ) + consumer.connection = mock_connection + return consumer diff --git a/pc/tests/test_consumer.py b/pc/tests/test_consumer.py index 8ead988..8e823e2 100644 --- a/pc/tests/test_consumer.py +++ b/pc/tests/test_consumer.py @@ -1,16 +1,23 @@ +import asyncio from unittest.mock import AsyncMock, patch import pytest import asyncpg +from pc.consumer.consumer import Consumer + from .fixtures import * -@patch("pc.consumer.consumer.Consumer.connect", new_callable=AsyncMock) @pytest.mark.asyncio -async def test_consumer_can_connect(mock_connect, consumer): - mock_connection = AsyncMock() - mock_channel = AsyncMock() - mock_connect.return_value = mock_connection - mock_connection.channel.return_value = mock_channel - await consumer.connect() - mock_connect.assert_called_once() +async def test_consumer_can_consume_from_queues(consumer: Consumer, mock_channel, mock_queues): + mock_channel.declare_queue.side_effect = [ + mock_queues["prediction"], + mock_queues["cycle"], + ] + task = asyncio.create_task(consumer.consume_from_queues()) + await asyncio.sleep(0.1) + task.cancel() + mock_channel.declare_queue.assert_any_call("prediction") + mock_channel.declare_queue.assert_any_call("cycle") + mock_queues["prediction"].consume.assert_called_once() + mock_queues["cycle"].consume.assert_called_once()