Skip to content

Commit

Permalink
add keydb
Browse files — browse the repository at this point in the history
  • Loading branch information
nonnontrivial committed Jun 27, 2024
1 parent 9076009 commit 3985e43
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 15 deletions.
7 changes: 7 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
version: "3"

services:
keydb:
image: "eqalpha/keydb:latest"
ports:
- "6379:6379"

rabbitmq:
image: "rabbitmq:management"
ports:
Expand Down Expand Up @@ -41,12 +46,14 @@ services:
environment:
API_VERSION: "v1"
API_HOST: "api"
KEYDB_HOST: "keydb"
RABBITMQ_HOST: "rabbitmq"
SLEEP_INTERVAL: "0.5"
restart: on-failure
depends_on:
- rabbitmq
- api
- keydb
links:
- rabbitmq

Expand Down
2 changes: 2 additions & 0 deletions pp/pp/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os

keydb_host = os.getenv("KEYDB_HOST", "keydb")
keydb_port = int(os.getenv("KEYDB_PORT", 6379))
rabbitmq_host = os.getenv("RABBITMQ_HOST", "localhost")
prediction_queue = os.getenv("PREDICTION_QUEUE", "prediction")

Expand Down
5 changes: 2 additions & 3 deletions pp/pp/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pika
from pika.exceptions import AMQPConnectionError

from .prediction import predict_on_cell_coords
from .prediction import publish_cell_prediction
from .cells import get_res_zero_cell_coords
from .config import rabbitmq_host, prediction_queue, task_sleep_interval

Expand Down Expand Up @@ -39,9 +39,8 @@ async def main():
try:
async with httpx.AsyncClient() as client:
while True:
# allow predictions over cells to run in interleaved way
for cell_coords in resolution_zero_cell_coords:
await asyncio.create_task(predict_on_cell_coords(client, cell_coords, channel))
await asyncio.create_task(publish_cell_prediction(client, cell_coords, channel))
await asyncio.sleep(task_sleep_interval)
except Exception as e:
log.error(f"could not continue publishing because {e}")
Expand Down
22 changes: 10 additions & 12 deletions pp/pp/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,21 @@
import logging

from pika.channel import Channel
import redis
import httpx
import h3

from .config import api_protocol, api_host, api_port, api_version, prediction_queue
from .config import api_protocol, api_host, api_port, api_version, prediction_queue, keydb_host, keydb_port
from .message import PredictionMessage

log = logging.getLogger(__name__)
keydb = redis.Redis(host=keydb_host, port=keydb_port, db=0)

prediction_endpoint_url = f"{api_protocol}://{api_host}:{api_port}/api/{api_version}/predict"


async def get_prediction_message_for_lat_lon(client: httpx.AsyncClient, lat: float, lon: float) -> PredictionMessage:
"""create the object that will get published to rabbitmq"""
"""create the object that will get published to rabbitmq."""
res = await client.get(prediction_endpoint_url, params={"lat": lat, "lon": lon})
res.raise_for_status()

Expand All @@ -27,18 +29,15 @@ async def get_prediction_message_for_lat_lon(client: httpx.AsyncClient, lat: flo
message = PredictionMessage(
lat=lat,
lon=lon,
h3_id=h3.geo_to_h3(lat, lon, 0),
h3_id=h3.geo_to_h3(lat, lon, resolution=0),
utc=datetime.utcnow().isoformat(),
mpsas=mpsas,
)
return message


# message_store = {}


async def predict_on_cell_coords(client: httpx.AsyncClient, h3_coords: Tuple[float, float], channel: Channel):
"""retrieve and publish a sky brightness prediction at coords for the h3 cell"""
async def publish_cell_prediction(client: httpx.AsyncClient, h3_coords: Tuple[float, float], channel: Channel):
"""retrieve and publish a sky brightness prediction at h3 cell coords."""
import json

try:
Expand All @@ -50,10 +49,9 @@ async def predict_on_cell_coords(client: httpx.AsyncClient, h3_coords: Tuple[flo
log.info(f"publishing {message_body} to {prediction_queue}")
channel.basic_publish(exchange="", routing_key=prediction_queue, body=json.dumps(message_body))

# keep track of how many messages are published for each cell
# message_store[m.h3_id] = message_store.get(m.h3_id, 0) + 1
# with open("data.json", "w") as f:
# json.dump(message_store, f, indent=4)
keydb.incr(m.h3_id)
num_predictions_published = int(keydb.get(m.h3_id))
log.info(f"{m.h3_id} has had {num_predictions_published} predictions published")

except httpx.HTTPStatusError as e:
log.error(f"got bad status from api server {e}")
Expand Down
1 change: 1 addition & 0 deletions pp/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
h3==3.7.7
httpx==0.27.0
pika==1.3.2
redis==5.0.7

0 comments on commit 3985e43

Please sign in to comment.