From fd8a926f72c1a5e9a4fc59d4c13be44df2952819 Mon Sep 17 00:00:00 2001 From: Kevin Donahue Date: Mon, 30 Sep 2024 20:02:44 -0400 Subject: [PATCH] use orm in consumer for observation storage --- api/requirements.txt | 2 +- docker-compose.yml | 4 +- pc/Dockerfile | 3 +- pc/README.md | 7 +- pc/pc/config.py | 6 +- pc/pc/main.py | 74 +------------------ .../pc/persistence/__init__.py | 0 pc/pc/persistence/db.py | 15 ++++ pc/pc/persistence/models.py | 14 ++++ pc/pc/rabbitmq.py | 43 +++++++++++ pc/pc/tests/__init__.py | 0 pc/pc/websockets_handler.py | 33 --------- pc/requirements.txt | 3 +- pp/.idea/misc.xml | 2 +- pp/.idea/pp.iml | 2 +- pp/Dockerfile | 3 +- pp/README.md | 4 +- pp/pp/cells/__init__.py | 0 pp/pp/{cells.py => cells/h3.py} | 0 pp/pp/config.py | 6 +- pp/pp/main.py | 10 +-- pp/pp/models/__init__.py | 0 pc/pc/model.py => pp/pp/models/models.py | 6 +- pp/pp/prediction.py | 42 +++++------ pp/pp/prediction_message.py | 13 ---- pp/requirements.txt | 1 + 26 files changed, 123 insertions(+), 170 deletions(-) rename pp/tests/test_prediction_publisher.py => pc/pc/persistence/__init__.py (100%) create mode 100644 pc/pc/persistence/db.py create mode 100644 pc/pc/persistence/models.py create mode 100644 pc/pc/rabbitmq.py create mode 100644 pc/pc/tests/__init__.py delete mode 100644 pc/pc/websockets_handler.py create mode 100644 pp/pp/cells/__init__.py rename pp/pp/{cells.py => cells/h3.py} (100%) create mode 100644 pp/pp/models/__init__.py rename pc/pc/model.py => pp/pp/models/models.py (57%) delete mode 100644 pp/pp/prediction_message.py diff --git a/api/requirements.txt b/api/requirements.txt index 5bc387f..24a0f25 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -1,7 +1,7 @@ astroplan==0.9.1 astropy==6.0.0 pandas==2.1.4 -torch~=2.2.2 +torch~=2.2.0 requests==2.31.0 fastapi~=0.110.2 httpx==0.26.0 diff --git a/docker-compose.yml b/docker-compose.yml index 35d4435..17e6acc 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,5 +1,3 @@ -version: "3" 
- services: postgres: image: "postgres:latest" @@ -65,8 +63,8 @@ services: build: ./pp environment: API_VERSION: "v1" - MODEL_VERSION: "0.1.0" API_HOST: "api" + MODEL_VERSION: "0.1.0" KEYDB_HOST: "keydb" RABBITMQ_HOST: "rabbitmq" TASK_SLEEP_INTERVAL: "0.5" diff --git a/pc/Dockerfile b/pc/Dockerfile index 0936498..6a6e30a 100644 --- a/pc/Dockerfile +++ b/pc/Dockerfile @@ -1,4 +1,5 @@ -FROM python:3.11.7-slim-bullseye +ARG VERSION=3.12.6-slim-bookworm +FROM python:${VERSION} LABEL maintainer="Kevin Donahue " diff --git a/pc/README.md b/pc/README.md index 166540e..6372cd2 100644 --- a/pc/README.md +++ b/pc/README.md @@ -1,11 +1,10 @@ # pc -prediction consumer. +> prediction consumer. -pulls predictions messages off of prediction queue and into postgres and websockets. - -## connect to timescale instance +Pulls brightness observation messages off of the prediction queue and into postgres. ```shell +# connect to postgres instance psql -d "postgres://postgres:password@localhost/postgres" ``` diff --git a/pc/pc/config.py b/pc/pc/config.py index f45ddf3..c537bc7 100644 --- a/pc/pc/config.py +++ b/pc/pc/config.py @@ -5,13 +5,9 @@ PG_DATABASE = os.getenv("PG_DATABASE", "localhost") PG_HOST = os.getenv("PG_HOST", "postgres") PG_PORT = int(os.getenv("PG_PORT", 5432)) - -pg_dsn = f"dbname={PG_DATABASE} user={PG_USER} password={PG_PASSWORD} host={PG_HOST}" +pg_dsn = f"postgres://{PG_USER}:{PG_PASSWORD}@{PG_HOST}:{PG_PORT}/{PG_DATABASE}" AMQP_USER = os.getenv("AMQP_USER", "guest") AMQP_PASSWORD = os.getenv("AMQP_PASSWORD", "guest") AMQP_HOST = os.getenv("AMQP_HOST", "localhost") AMQP_PREDICTION_QUEUE = os.getenv("AMQP_PREDICTION_QUEUE", "prediction") - -WS_HOST = os.getenv("WS_HOST", "consumer") -WS_PORT = int(os.getenv("WS_PORT", 8090)) diff --git a/pc/pc/main.py b/pc/pc/main.py index e2b584c..b4463c3 100644 --- a/pc/pc/main.py +++ b/pc/pc/main.py @@ -1,84 +1,18 @@ -import json import asyncio import logging -import psycopg -import aio_pika - -from pc.config import * -from 
pc.model import BrightnessMessage -from pc.websockets_handler import WebSocketsHandler +from pc.persistence.db import initialize_db +from pc.rabbitmq import consume_brightness_observations logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s") log = logging.getLogger(__name__) -websockets_handler = WebSocketsHandler() - - -def initialize_db(): - """create the predictions table if it does not exist""" - with psycopg.connect(pg_dsn) as conn: - with conn.cursor() as cur: - cur.execute(""" - CREATE TABLE IF NOT EXISTS predictions ( - id serial PRIMARY KEY, - h3_id text NOT NULL, - utc_iso text NOT NULL, - utc_ns bigint NOT NULL, - mpsas real NOT NULL, - model_version text NOT NULL - ) - """) - conn.commit() - - -def insert_brightness_message_in_db(message: BrightnessMessage): - """insert subset of brightness message into the predictions table""" - with psycopg.connect(pg_dsn) as conn: - with conn.cursor() as cur: - log.info(f"inserting brightness message for {message.h3_id}") - - cur.execute(""" - INSERT INTO predictions (h3_id, utc_iso, utc_ns, mpsas, model_version) - VALUES (%s, %s, %s, %s, %s) - """, (message.h3_id, message.utc_iso, message.utc_ns, message.mpsas, message.model_version)) - conn.commit() - - -async def consume_from_rabbitmq(): - """create table in pg if needed and begin consuming messages from the queue, - storing them in the predictions table""" - try: - amqp_connection = await aio_pika.connect_robust(f"amqp://{AMQP_USER}:{AMQP_PASSWORD}@{AMQP_HOST}") - except Exception as e: - import sys - - log.error(f"could not form amqp connection because {e}; has rabbitmq started?") - log.warning("exiting") - sys.exit(1) - else: - async with amqp_connection: - - channel = await amqp_connection.channel() - queue = await channel.declare_queue(AMQP_PREDICTION_QUEUE) - - async for m in queue: - async with m.process(): - # serialize the message coming over the queue and add to postgres - json_data = json.loads(m.body.decode()) - 
message = BrightnessMessage(**json_data) - - insert_brightness_message_in_db(message) - await websockets_handler.broadcast(message) - - await asyncio.Future() - async def main(): - coroutines = [websockets_handler.setup(), consume_from_rabbitmq()] + """run the primary coroutines together""" + coroutines = [initialize_db(), consume_brightness_observations()] await asyncio.gather(*coroutines) if __name__ == "__main__": - initialize_db() asyncio.run(main()) diff --git a/pp/tests/test_prediction_publisher.py b/pc/pc/persistence/__init__.py similarity index 100% rename from pp/tests/test_prediction_publisher.py rename to pc/pc/persistence/__init__.py diff --git a/pc/pc/persistence/db.py b/pc/pc/persistence/db.py new file mode 100644 index 0000000..d32456c --- /dev/null +++ b/pc/pc/persistence/db.py @@ -0,0 +1,15 @@ +import logging + +from pc.config import pg_dsn +from tortoise import Tortoise + +log = logging.getLogger(__name__) + + +async def initialize_db(): + log.info(f"initializing db at {pg_dsn}") + await Tortoise.init( + db_url=pg_dsn, + modules={"models": ["pc.persistence.models"]} + ) + await Tortoise.generate_schemas() diff --git a/pc/pc/persistence/models.py b/pc/pc/persistence/models.py new file mode 100644 index 0000000..400f8dc --- /dev/null +++ b/pc/pc/persistence/models.py @@ -0,0 +1,14 @@ +from tortoise import fields, models + + +class BrightnessObservation(models.Model): + uuid = fields.CharField(primary_key=True, max_length=36) + lat = fields.FloatField() + lon = fields.FloatField() + h3_id = fields.CharField(max_length=15) + utc_iso = fields.CharField(max_length=30) + mpsas = fields.FloatField() + model_version = fields.CharField(max_length=36) + + def __str__(self): + return f"{self.h3_id}:{self.uuid}" diff --git a/pc/pc/rabbitmq.py b/pc/pc/rabbitmq.py new file mode 100644 index 0000000..2f21608 --- /dev/null +++ b/pc/pc/rabbitmq.py @@ -0,0 +1,43 @@ +import json +import logging +import asyncio + +import aio_pika + +from pc.config import * +from 
pc.persistence.models import BrightnessObservation + +log = logging.getLogger(__name__) + + +# class RabbitMQConsumer: +# def __init__(self, user: str, password: str, host: str): +# self.url = f"amqp://{user}:{password}@{host}" +# +# async def connect(self): +# pass + + +async def consume_brightness_observations(): + """begin consuming messages from the queue, storing them in predictions table""" + try: + amqp_connection = await aio_pika.connect_robust(f"amqp://{AMQP_USER}:{AMQP_PASSWORD}@{AMQP_HOST}") + except Exception as e: + import sys + + log.error(f"could not form amqp connection because {e}; has rabbitmq started?") + log.warning("exiting") + sys.exit(1) + else: + async with amqp_connection: + channel = await amqp_connection.channel() + queue = await channel.declare_queue(AMQP_PREDICTION_QUEUE) + + async for message in queue: + async with message.process(): + brightness_observation_json = json.loads(message.body.decode()) + brightness_observation = BrightnessObservation(**brightness_observation_json) + + log.info(f"saving brightness observation {brightness_observation}") + await brightness_observation.save() + await asyncio.Future() diff --git a/pc/pc/tests/__init__.py b/pc/pc/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pc/pc/websockets_handler.py b/pc/pc/websockets_handler.py deleted file mode 100644 index 2792a91..0000000 --- a/pc/pc/websockets_handler.py +++ /dev/null @@ -1,33 +0,0 @@ -import asyncio -import json -import logging -from dataclasses import asdict - -from websockets import serve, broadcast - -from pc.config import * -from pc.model import BrightnessMessage - -log = logging.getLogger(__name__) - - -class WebSocketsHandler: - clients = set() - - async def setup(self): - async def register_client(websocket): - log.info(f"registering {websocket}") - self.clients.add(websocket) - try: - await websocket.wait_closed() - finally: - self.clients.remove(websocket) - - async with serve(register_client, WS_HOST, WS_PORT): - 
await asyncio.Future() - - async def broadcast(self, message: BrightnessMessage): - """send the message to all websockets""" - log.info(f"broadcasting to {len(self.clients)} websocket clients on {WS_HOST}:{WS_PORT}") - message_json = json.dumps(asdict(message)) - broadcast(self.clients, message_json) diff --git a/pc/requirements.txt b/pc/requirements.txt index 80e7be3..6788c32 100644 --- a/pc/requirements.txt +++ b/pc/requirements.txt @@ -1,3 +1,2 @@ aio-pika==9.4.2 -websockets==12.0 -psycopg~=3.2.1 \ No newline at end of file +tortoise-orm[asyncpg]~=0.21.6 diff --git a/pp/.idea/misc.xml b/pp/.idea/misc.xml index 320019d..fcc51a0 100644 --- a/pp/.idea/misc.xml +++ b/pp/.idea/misc.xml @@ -3,5 +3,5 @@ - + \ No newline at end of file diff --git a/pp/.idea/pp.iml b/pp/.idea/pp.iml index 3e9d178..3d5df41 100644 --- a/pp/.idea/pp.iml +++ b/pp/.idea/pp.iml @@ -5,7 +5,7 @@ - + diff --git a/pp/Dockerfile b/pp/Dockerfile index e9ee5dc..8c32c50 100644 --- a/pp/Dockerfile +++ b/pp/Dockerfile @@ -1,4 +1,5 @@ -FROM python:3.11.7-slim-bullseye +ARG VERSION=3.12.6-slim-bookworm +FROM python:${VERSION} LABEL maintainer="Kevin Donahue " diff --git a/pp/README.md b/pp/README.md index db590db..c94643f 100644 --- a/pp/README.md +++ b/pp/README.md @@ -1,8 +1,8 @@ # pp -prediction producer. +> prediction producer. 
-> handles sky brightness prediction across resolution 0 h3 cells and puts on rabbitmq +Handles sky brightness prediction across resolution 0 h3 cells and puts results on rabbitmq ## monitoring diff --git a/pp/pp/cells/__init__.py b/pp/pp/cells/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pp/pp/cells.py b/pp/pp/cells/h3.py similarity index 100% rename from pp/pp/cells.py rename to pp/pp/cells/h3.py diff --git a/pp/pp/config.py b/pp/pp/config.py index 8c78614..c354183 100644 --- a/pp/pp/config.py +++ b/pp/pp/config.py @@ -4,9 +4,7 @@ keydb_port = int(os.getenv("KEYDB_PORT", 6379)) rabbitmq_host = os.getenv("RABBITMQ_HOST", "localhost") -prediction_queue = os.getenv("PREDICTION_QUEUE", "prediction") - -task_sleep_interval = float(os.getenv("TASK_SLEEP_INTERVAL", "0.5")) +queue_name = os.getenv("QUEUE_NAME", "prediction") api_protocol = os.getenv("API_PROTOCOL", "http") api_port = int(os.getenv("API_PORT", "8000")) @@ -14,3 +12,5 @@ api_version = os.getenv("API_VERSION", "v1") model_version = os.getenv("MODEL_VERSION", "0.1.0") + +task_sleep_interval = float(os.getenv("TASK_SLEEP_INTERVAL", "0.5")) diff --git a/pp/pp/main.py b/pp/pp/main.py index 92fc221..8b71965 100644 --- a/pp/pp/main.py +++ b/pp/pp/main.py @@ -5,9 +5,9 @@ import pika from pika.exceptions import AMQPConnectionError -from .prediction import publish_cell_brightness -from .cells import get_h3_cells -from .config import rabbitmq_host, prediction_queue, task_sleep_interval +from .prediction import publish_observation_to_queue +from .cells.h3 import get_h3_cells +from .config import rabbitmq_host, queue_name, task_sleep_interval logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s") log = logging.getLogger(__name__) @@ -19,7 +19,7 @@ async def main(): try: connection = pika.BlockingConnection(pika.ConnectionParameters(rabbitmq_host)) channel = connection.channel() - channel.queue_declare(queue=prediction_queue) + 
channel.queue_declare(queue=queue_name) except AMQPConnectionError as e: import sys @@ -36,7 +36,7 @@ async def main(): async with httpx.AsyncClient() as client: while True: for cell_coords in h3_cell_coords: - await asyncio.create_task(publish_cell_brightness(client, cell_coords, channel)) + await asyncio.create_task(publish_observation_to_queue(client, cell_coords, channel)) await asyncio.sleep(task_sleep_interval) except Exception as e: log.error(f"could not continue publishing because {e}") diff --git a/pp/pp/models/__init__.py b/pp/pp/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pc/pc/model.py b/pp/pp/models/models.py similarity index 57% rename from pc/pc/model.py rename to pp/pp/models/models.py index 1c128ee..513f35d 100644 --- a/pc/pc/model.py +++ b/pp/pp/models/models.py @@ -1,13 +1,11 @@ -from dataclasses import dataclass +from pydantic import BaseModel -@dataclass -class BrightnessMessage: +class BrightnessObservation(BaseModel): uuid: str lat: float lon: float h3_id: str utc_iso: str - utc_ns: int mpsas: float model_version: str diff --git a/pp/pp/prediction.py b/pp/pp/prediction.py index 9df99ad..3a21d7e 100644 --- a/pp/pp/prediction.py +++ b/pp/pp/prediction.py @@ -1,33 +1,33 @@ import uuid import json +import logging from typing import Tuple -from dataclasses import asdict from datetime import datetime -import logging from pika.channel import Channel import redis import httpx import h3 -from .config import model_version, api_protocol, api_host, api_port, api_version, prediction_queue, keydb_host, \ +from .config import model_version, api_protocol, api_host, api_port, api_version, queue_name, keydb_host, \ keydb_port -from .prediction_message import BrightnessMessage +from .models.models import BrightnessObservation log = logging.getLogger(__name__) keydb = redis.Redis(host=keydb_host, port=keydb_port, db=0) -prediction_endpoint_url = f"{api_protocol}://{api_host}:{api_port}/api/{api_version}/predict" - def 
get_cell_id(lat, lon) -> str: """get the h3 cell for this lat and lon""" return h3.geo_to_h3(lat, lon, resolution=0) -async def create_brightness_message(client: httpx.AsyncClient, h3_lat: float, h3_lon: float) -> BrightnessMessage: +async def create_brightness_observation(client: httpx.AsyncClient, h3_lat: float, + h3_lon: float) -> BrightnessObservation: """create the object that will get published to the prediction queue.""" + prediction_endpoint_url = f"{api_protocol}://{api_host}:{api_port}/api/{api_version}/predict" + res = await client.get(prediction_endpoint_url, params={"lat": h3_lat, "lon": h3_lon}) res.raise_for_status() @@ -37,33 +37,33 @@ async def create_brightness_message(client: httpx.AsyncClient, h3_lat: float, h3 raise ValueError("no sky brightness reading in api response") utc_now = datetime.utcnow() - brightness_message = BrightnessMessage( + brightness_message = BrightnessObservation( uuid=str(uuid.uuid4()), lat=h3_lat, lon=h3_lon, h3_id=get_cell_id(h3_lat, h3_lon), utc_iso=utc_now.isoformat(), - utc_ns=int(utc_now.timestamp() * 1e9), mpsas=mpsas, model_version=model_version ) return brightness_message -async def publish_cell_brightness(client: httpx.AsyncClient, h3_coords: Tuple[float, float], channel: Channel): +async def publish_observation_to_queue(client: httpx.AsyncClient, h3_coords: Tuple[float, float], channel: Channel): """request and publish sky brightness at given h3 cell coords.""" - try: - lat, lon = h3_coords + lat, lon = h3_coords - m = await create_brightness_message(client, lat, lon) - message_body = asdict(m) - - log.info(f"publishing brightness message {message_body}") - channel.basic_publish(exchange="", routing_key=prediction_queue, body=json.dumps(message_body)) - - keydb.incr(m.h3_id) - log.info(f"{m.h3_id} has had {int(keydb.get(m.h3_id))} predictions published") + try: + observation = await create_brightness_observation(client, lat, lon) + log.info(f"publishing brightness observation for cell {observation.h3_id}") + 
channel.basic_publish(exchange="", routing_key=queue_name, body=json.dumps(observation.model_dump())) except httpx.HTTPStatusError as e: log.error(f"got bad status from api server {e}") except Exception as e: - log.error(f"could not publish prediction at {h3_coords} because {e}") + import traceback + + log.error("failed to publish prediction!") + log.error(traceback.format_exc()) + else: + keydb.incr(observation.h3_id) + log.info(f"cell {observation.h3_id} has had {int(keydb.get(observation.h3_id))} predictions published") diff --git a/pp/pp/prediction_message.py b/pp/pp/prediction_message.py deleted file mode 100644 index 1c128ee..0000000 --- a/pp/pp/prediction_message.py +++ /dev/null @@ -1,13 +0,0 @@ -from dataclasses import dataclass - - -@dataclass -class BrightnessMessage: - uuid: str - lat: float - lon: float - h3_id: str - utc_iso: str - utc_ns: int - mpsas: float - model_version: str diff --git a/pp/requirements.txt b/pp/requirements.txt index 9cd9e8d..e523ab1 100644 --- a/pp/requirements.txt +++ b/pp/requirements.txt @@ -2,3 +2,4 @@ h3==3.7.7 httpx==0.27.0 pika==1.3.2 redis==5.0.7 +pydantic==2.9.2