From 36c00813a64b193fda1604a1dfe941a0d3729f70 Mon Sep 17 00:00:00 2001
From: Kevin Donahue
Date: Sun, 27 Oct 2024 17:49:24 -0400
Subject: [PATCH] use asyncpg instead of orm

---
 README.md                     |  2 +-
 pc/pc/consumer/consumer.py    | 16 +++++++-----
 pc/pc/main.py                 | 26 ++++++++++++-------
 pc/pc/persistence/db.py       | 49 ++++++++++++++++++++++++++++-------
 pc/pc/persistence/models.py   | 21 +++++++--------
 pc/requirements.txt           |  6 ++---
 pc/tests/test_consumer.py     | 22 +++++++++++-----
 pp/pp/cells/cell_publisher.py |  6 +++--
 pp/pp/models/models.py        |  2 +-
 9 files changed, 100 insertions(+), 50 deletions(-)

diff --git a/README.md b/README.md
index bbc67e1..988dc80 100644
--- a/README.md
+++ b/README.md
@@ -19,7 +19,7 @@ earth, without a sensor.
 - [x] less noisy container startup
 - [ ] live updates to open meteo data while app is running
 - [ ] REST apis in addition to the gRPC ones
-- [ ] better storage of predictions in order to faciliate grouping/sorting
+- [x] better storage of predictions in order to faciliate grouping/sorting
 
 ## about
 
diff --git a/pc/pc/consumer/consumer.py b/pc/pc/consumer/consumer.py
index 8d7e372..9a875e5 100644
--- a/pc/pc/consumer/consumer.py
+++ b/pc/pc/consumer/consumer.py
@@ -3,19 +3,22 @@
 import asyncio
 import typing
 
+import asyncpg
 import aio_pika
 from aio_pika.abc import AbstractIncomingMessage, AbstractRobustChannel
 
 from pc.persistence.models import BrightnessObservation
+from pc.persistence.db import insert_brightness_observation
 
 log = logging.getLogger(__name__)
 
 
 class Consumer:
-    def __init__(self, url: str, prediction_queue: str, cycle_queue: str):
+    def __init__(self, url: str, prediction_queue: str, cycle_queue: str, connection_pool: asyncpg.Pool):
         self._amqp_url = url
         self._prediction_queue = prediction_queue
         self._cycle_queue = cycle_queue
+        self._pool = connection_pool
 
     async def start(self):
         try:
@@ -51,11 +54,10 @@ async def _on_prediction_message(self, message: AbstractIncomingMessage):
         """handle incoming message by storing in postgres"""
         try:
             log.debug(f"received message {message.body}")
-            brightness_observation_json = json.loads(message.body.decode())
-            brightness_observation = BrightnessObservation(**brightness_observation_json)
-
-            await brightness_observation.save()
+            message_dict: typing.Dict = json.loads(message.body.decode())
+            brightness = BrightnessObservation(**message_dict)
+            await insert_brightness_observation(self._pool, brightness)
         except Exception as e:
-            log.error(f"could not save brightness observation {e}")
+            log.error(f"could not save brightness observation: {e}")
         else:
-            log.info(f"saved {brightness_observation}")
+            log.info(f"saved brightness of {brightness.h3_id}")
diff --git a/pc/pc/main.py b/pc/pc/main.py
index 038fbc1..552f73d 100644
--- a/pc/pc/main.py
+++ b/pc/pc/main.py
@@ -1,7 +1,7 @@
 import asyncio
 import logging
 
-from pc.persistence.db import initialize_db
+from pc.persistence.db import create_pool, create_brightness_table
 from pc.consumer.consumer import Consumer
 from pc.config import amqp_url, prediction_queue, cycle_queue
 
@@ -11,14 +11,22 @@
 
 
 async def main():
-    """run the primary coroutines together"""
-    consumer = Consumer(url=amqp_url, prediction_queue=prediction_queue, cycle_queue=cycle_queue)
-    coroutines = [
-        initialize_db(),
-        consumer.start()
-    ]
-    await asyncio.gather(*coroutines)
+    pool = await create_pool()
+    if pool is None:
+        raise ValueError("connection pool is none")
+
+    await create_brightness_table(pool)
+    consumer = Consumer(
+        url=amqp_url,
+        prediction_queue=prediction_queue,
+        cycle_queue=cycle_queue,
+        connection_pool=pool
+    )
+    await consumer.start()
 
 
 if __name__ == "__main__":
-    asyncio.run(main())
+    try:
+        asyncio.run(main())
+    except Exception as e:
+        log.error(f"failed to run: {e}")
diff --git a/pc/pc/persistence/db.py b/pc/pc/persistence/db.py
index d32456c..490bc99 100644
--- a/pc/pc/persistence/db.py
+++ b/pc/pc/persistence/db.py
@@ -1,15 +1,46 @@
 import logging
+import typing
 
-from pc.config import pg_dsn
-from tortoise import Tortoise
+import asyncpg
 
-log = logging.getLogger(__name__)
+from ..config import pg_host,pg_port,pg_user,pg_password,pg_database
+from .models import BrightnessObservation
 
+log = logging.getLogger(__name__)
+table = "brightness_observation"
 
-async def initialize_db():
-    log.info(f"initializing db at {pg_dsn}")
-    await Tortoise.init(
-        db_url=pg_dsn,
-        modules={"models": ["pc.persistence.models"]}
+async def create_pool() -> typing.Optional[asyncpg.Pool]:
+    pool = await asyncpg.create_pool(
+        user=pg_user,
+        password=pg_password,
+        database=pg_database,
+        host=pg_host,
+        port=pg_port,
+        min_size=1,
+        max_size=10
     )
-    await Tortoise.generate_schemas()
+    return pool
+
+async def create_brightness_table(pool: asyncpg.Pool):
+    async with pool.acquire() as conn:
+        await conn.execute(
+            f"""
+            CREATE TABLE IF NOT EXISTS {table} (
+                uuid UUID PRIMARY KEY,
+                lat DOUBLE PRECISION NOT NULL,
+                lon DOUBLE PRECISION NOT NULL,
+                h3_id TEXT NOT NULL,
+                mpsas DOUBLE PRECISION NOT NULL,
+                timestamp_utc TIMESTAMPTZ NOT NULL
+            );
+            """
+        )
+
+
+async def insert_brightness_observation(pool, observation: BrightnessObservation):
+    async with pool.acquire() as conn:
+        await conn.execute(f"""
+            INSERT INTO {table} (uuid, lat, lon, h3_id, mpsas, timestamp_utc)
+            VALUES ($1, $2, $3, $4, $5, $6)
+        """, observation.uuid, observation.lat, observation.lon, observation.h3_id, observation.mpsas, observation.timestamp_utc)
+        log.info(f"Inserted observation: {observation}")
diff --git a/pc/pc/persistence/models.py b/pc/pc/persistence/models.py
index 9b5c13d..53f151f 100644
--- a/pc/pc/persistence/models.py
+++ b/pc/pc/persistence/models.py
@@ -1,13 +1,10 @@
-from tortoise import fields, models
+from pydantic import BaseModel
+from datetime import datetime
 
-
-class BrightnessObservation(models.Model):
-    uuid = fields.CharField(primary_key=True, max_length=36)
-    lat = fields.FloatField()
-    lon = fields.FloatField()
-    h3_id = fields.CharField(max_length=15)
-    utc_iso = fields.CharField(max_length=32)
-    mpsas = fields.FloatField()
-
-    def __str__(self):
-        return f"{self.__class__.__name__}(#{self.h3_id},{self.mpsas},{self.utc_iso})"
+class BrightnessObservation(BaseModel):
+    uuid: str
+    lat: float
+    lon: float
+    h3_id: str
+    mpsas: float
+    timestamp_utc: datetime
diff --git a/pc/requirements.txt b/pc/requirements.txt
index f7be743..eb85dba 100644
--- a/pc/requirements.txt
+++ b/pc/requirements.txt
@@ -1,3 +1,3 @@
-aio-pika==9.4.2
-asyncpg
-tortoise-orm[asyncpg]~=0.21.6
+aio-pika~=9.4.2
+asyncpg~=0.29.0
+pydantic~=2.9.2
diff --git a/pc/tests/test_consumer.py b/pc/tests/test_consumer.py
index 8a509a8..2d3cba9 100644
--- a/pc/tests/test_consumer.py
+++ b/pc/tests/test_consumer.py
@@ -1,17 +1,27 @@
-from unittest import mock
+from unittest.mock import AsyncMock, patch
 
 import pytest
+import asyncpg
 from aio_pika import Message
 
 from pc.consumer.consumer import Consumer
 
 @pytest.fixture
-def consumer():
+async def mock_asyncpg_pool():
+    with patch("asyncpg.create_pool") as mock_create_pool:
+        mock_pool = AsyncMock()
+        mock_create_pool.return_value = mock_pool
+
+        mock_connection = AsyncMock()
+        mock_pool.acquire.return_value.__aenter__.return_value = mock_connection
+        yield mock_pool
+
+@pytest.fixture
+def consumer(mock_asyncpg_pool):
     amqp_url="amqp://localhost"
     prediction_queue="prediction"
-    return Consumer(url=amqp_url, prediction_queue=prediction_queue,cycle_queue="")
+    return Consumer(url=amqp_url, prediction_queue=prediction_queue,cycle_queue="",connection_pool=mock_asyncpg_pool)
 
-@pytest.mark.skip
 @pytest.mark.asyncio
-async def test_can_consume_message(consumer):
-    pass
+async def test_consumer(consumer):
+    assert consumer is not None
diff --git a/pp/pp/cells/cell_publisher.py b/pp/pp/cells/cell_publisher.py
index 804612b..1e37f1f 100644
--- a/pp/pp/cells/cell_publisher.py
+++ b/pp/pp/cells/cell_publisher.py
@@ -52,10 +52,12 @@ def predict_cell_brightness(self, cell) -> None:
             lat=lat,
             lon=lon,
             h3_id=get_cell_id(lat, lon, resolution=6),
-            utc_iso=response.utc_iso,
             mpsas=response.mpsas,
+            timestamp_utc=response.utc_iso,
         )
-        self._publish(self._prediction_queue, brightness_observation.model_dump())
+        dumped = brightness_observation.model_dump()
+        dumped["timestamp_utc"] = brightness_observation.timestamp_utc.isoformat()
+        self._publish(self._prediction_queue, dumped)
 
     def run(self):
         cells = self.covering
diff --git a/pp/pp/models/models.py b/pp/pp/models/models.py
index 7d31bd1..f58dee7 100644
--- a/pp/pp/models/models.py
+++ b/pp/pp/models/models.py
@@ -7,8 +7,8 @@ class BrightnessObservation(BaseModel):
     lat: float
     lon: float
     h3_id: str
-    utc_iso: str
     mpsas: float
+    timestamp_utc: datetime
 
 class CellCycle(BaseModel):
     start_time_utc: datetime