From 3985e438c40d73aef5e86d2500200f3c8697d88d Mon Sep 17 00:00:00 2001 From: Kevin Donahue Date: Wed, 26 Jun 2024 21:52:13 -0400 Subject: [PATCH] add keydb --- docker-compose.yml | 7 +++++++ pp/pp/config.py | 2 ++ pp/pp/main.py | 5 ++--- pp/pp/prediction.py | 22 ++++++++++------------ pp/requirements.txt | 1 + 5 files changed, 22 insertions(+), 15 deletions(-) diff --git a/docker-compose.yml b/docker-compose.yml index ab5c0ce..e2b1a9b 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,6 +1,11 @@ version: "3" services: + keydb: + image: "eqalpha/keydb:latest" + ports: + - "6379:6379" + rabbitmq: image: "rabbitmq:management" ports: @@ -41,12 +46,14 @@ services: environment: API_VERSION: "v1" API_HOST: "api" + KEYDB_HOST: "keydb" RABBITMQ_HOST: "rabbitmq" SLEEP_INTERVAL: "0.5" restart: on-failure depends_on: - rabbitmq - api + - keydb links: - rabbitmq diff --git a/pp/pp/config.py b/pp/pp/config.py index 6e715a7..fc620ae 100644 --- a/pp/pp/config.py +++ b/pp/pp/config.py @@ -1,5 +1,7 @@ import os +keydb_host = os.getenv("KEYDB_HOST", "keydb") +keydb_port = int(os.getenv("KEYDB_PORT", 6379)) rabbitmq_host = os.getenv("RABBITMQ_HOST", "localhost") prediction_queue = os.getenv("PREDICTION_QUEUE", "prediction") diff --git a/pp/pp/main.py b/pp/pp/main.py index ab1cd83..3091489 100644 --- a/pp/pp/main.py +++ b/pp/pp/main.py @@ -5,7 +5,7 @@ import pika from pika.exceptions import AMQPConnectionError -from .prediction import predict_on_cell_coords +from .prediction import publish_cell_prediction from .cells import get_res_zero_cell_coords from .config import rabbitmq_host, prediction_queue, task_sleep_interval @@ -39,9 +39,8 @@ async def main(): try: async with httpx.AsyncClient() as client: while True: - # allow predictions over cells to run in interleaved way for cell_coords in resolution_zero_cell_coords: - await asyncio.create_task(predict_on_cell_coords(client, cell_coords, channel)) + await asyncio.create_task(publish_cell_prediction(client, cell_coords, channel)) await asyncio.sleep(task_sleep_interval) except Exception as e: log.error(f"could not continue publishing because {e}") diff --git a/pp/pp/prediction.py b/pp/pp/prediction.py index f743985..5311ecd 100644 --- a/pp/pp/prediction.py +++ b/pp/pp/prediction.py @@ -4,19 +4,21 @@ import logging from pika.channel import Channel +import redis import httpx import h3 -from .config import api_protocol, api_host, api_port, api_version, prediction_queue +from .config import api_protocol, api_host, api_port, api_version, prediction_queue, keydb_host, keydb_port from .message import PredictionMessage log = logging.getLogger(__name__) +keydb = redis.Redis(host=keydb_host, port=keydb_port, db=0) prediction_endpoint_url = f"{api_protocol}://{api_host}:{api_port}/api/{api_version}/predict" async def get_prediction_message_for_lat_lon(client: httpx.AsyncClient, lat: float, lon: float) -> PredictionMessage: - """create the object that will get published to rabbitmq""" + """create the object that will get published to rabbitmq.""" res = await client.get(prediction_endpoint_url, params={"lat": lat, "lon": lon}) res.raise_for_status() @@ -27,18 +29,15 @@ async def get_prediction_message_for_lat_lon(client: httpx.AsyncClient, lat: flo message = PredictionMessage( lat=lat, lon=lon, - h3_id=h3.geo_to_h3(lat, lon, 0), + h3_id=h3.geo_to_h3(lat, lon, resolution=0), utc=datetime.utcnow().isoformat(), mpsas=mpsas, ) return message -# message_store = {} - - -async def predict_on_cell_coords(client: httpx.AsyncClient, h3_coords: Tuple[float, float], channel: Channel): - """retrieve and publish a sky brightness prediction at coords for the h3 cell""" +async def publish_cell_prediction(client: httpx.AsyncClient, h3_coords: Tuple[float, float], channel: Channel): + """retrieve and publish a sky brightness prediction at h3 cell coords.""" import json try: @@ -50,10 +49,9 @@ async def predict_on_cell_coords(client: httpx.AsyncClient, h3_coords: Tuple[flo log.info(f"publishing {message_body} to {prediction_queue}") channel.basic_publish(exchange="", routing_key=prediction_queue, body=json.dumps(message_body)) - # keep track of how many messages are published for each cell - # message_store[m.h3_id] = message_store.get(m.h3_id, 0) + 1 - # with open("data.json", "w") as f: - # json.dump(message_store, f, indent=4) + keydb.incr(m.h3_id) + num_predictions_published = int(keydb.get(m.h3_id)) + log.info(f"{m.h3_id} has had {num_predictions_published} predictions published") except httpx.HTTPStatusError as e: log.error(f"got bad status from api server {e}") diff --git a/pp/requirements.txt b/pp/requirements.txt index d56fc87..9cd9e8d 100644 --- a/pp/requirements.txt +++ b/pp/requirements.txt @@ -1,3 +1,4 @@ h3==3.7.7 httpx==0.27.0 pika==1.3.2 +redis==5.0.7