Skip to content

Commit

Permalink
add uuid and model version to brightness message
Browse files Browse the repository at this point in the history
  • Loading branch information
nonnontrivial committed Jun 30, 2024
1 parent 7fd8037 commit a374554
Show file tree
Hide file tree
Showing 9 changed files with 54 additions and 36 deletions.
4 changes: 2 additions & 2 deletions api/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ def get_log_config():


def create_prediction_response(prediction_obj: Prediction) -> PredictionResponse:
y = round(float(prediction_obj.y.item()), 4)
precision_digits = 4
y = round(float(prediction_obj.y.item()), precision_digits)
return PredictionResponse(sky_brightness=y)


Expand All @@ -36,7 +37,6 @@ async def get_prediction(lat, lon):
"""Predict sky brightness in magnitudes per square arcsecond for a lat and lon."""
try:
lat, lon = float(lat), float(lon)

prediction = await predict_sky_brightness(lat, lon)
return asdict(create_prediction_response(prediction))
except Exception as e:
Expand Down
6 changes: 5 additions & 1 deletion api/api/prediction/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ class Prediction:
y: torch.Tensor


path_to_state_dict = get_path_to_state_dict()


async def predict_sky_brightness(lat: float, lon: float) -> Prediction:
"""Predict sky brightness at utcnow for given lat and lon"""

Expand All @@ -49,7 +52,6 @@ async def predict_sky_brightness(lat: float, lon: float) -> Prediction:

model = NeuralNetwork()

path_to_state_dict = get_path_to_state_dict()
logging.debug(f"loading state dict at {path_to_state_dict}")
model.load_state_dict(torch.load(path_to_state_dict))
model.eval()
Expand All @@ -72,7 +74,9 @@ async def predict_sky_brightness(lat: float, lon: float) -> Prediction:
],
dtype=torch.float32,
).unsqueeze(0)

logging.debug(f"X vector for site is {X}")

with torch.no_grad():
predicted_y = model(X)
return Prediction(X=X, y=predicted_y)
1 change: 1 addition & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ services:
build: ./pp
environment:
API_VERSION: "v1"
MODEL_VERSION: "0.1.0"
API_HOST: "api"
KEYDB_HOST: "keydb"
RABBITMQ_HOST: "rabbitmq"
Expand Down
2 changes: 1 addition & 1 deletion pp/pp/cells.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@


def get_h3_cells() -> List[Tuple[float, float]]:
"""gets list of lat,lon coordinates of all resolution zero cells"""
"""gets coords of all resolution zero cells"""
resolution_zero_cells = h3.get_res0_indexes()
return [h3.h3_to_geo(c) for c in resolution_zero_cells]
1 change: 1 addition & 0 deletions pp/pp/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@
api_port = int(os.getenv("API_PORT", "8000"))
api_host = os.getenv("API_HOST", "localhost")
api_version = os.getenv("API_VERSION", "v1")
model_version = os.getenv("MODEL_VERSION", "0.1.0")
7 changes: 3 additions & 4 deletions pp/pp/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pika
from pika.exceptions import AMQPConnectionError

from .prediction import publish_cell_prediction
from .prediction import publish_cell_brightness
from .cells import get_h3_cells
from .config import rabbitmq_host, prediction_queue, task_sleep_interval

Expand All @@ -14,8 +14,7 @@


async def main():
"""initializes process of getting sky brightness predictions for h3 cells;
publishing them to prediction queue as available.
"""initializes process of publishing sky brightness
n.b. with 122 res 0 cells on 2016 macbook, this will publish at a rate of 1.4m/s
"""
Expand All @@ -40,7 +39,7 @@ async def main():
async with httpx.AsyncClient() as client:
while True:
for cell_coords in h3_cell_coords:
await asyncio.create_task(publish_cell_prediction(client, cell_coords, channel))
await asyncio.create_task(publish_cell_brightness(client, cell_coords, channel))
await asyncio.sleep(task_sleep_interval)
except Exception as e:
log.error(f"could not continue publishing because {e}")
Expand Down
12 changes: 0 additions & 12 deletions pp/pp/message.py

This file was deleted.

42 changes: 26 additions & 16 deletions pp/pp/prediction.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import uuid
import json
from typing import Tuple
from dataclasses import asdict
from datetime import datetime
Expand All @@ -8,50 +10,58 @@
import httpx
import h3

from .config import api_protocol, api_host, api_port, api_version, prediction_queue, keydb_host, keydb_port
from .message import PredictionMessage
from .config import model_version, api_protocol, api_host, api_port, api_version, prediction_queue, keydb_host, \
keydb_port
from .prediction_message import BrightnessMessage

log = logging.getLogger(__name__)

keydb = redis.Redis(host=keydb_host, port=keydb_port, db=0)

prediction_endpoint_url = f"{api_protocol}://{api_host}:{api_port}/api/{api_version}/predict"


async def get_prediction_message(client: httpx.AsyncClient, h3_lat: float, h3_lon: float) -> PredictionMessage:
"""create the object that will get published to rabbitmq."""
def get_cell_id(lat, lon) -> str:
"""get the h3 cell for this lat and lon"""
return h3.geo_to_h3(lat, lon, resolution=0)[0]


async def create_brightness_message(client: httpx.AsyncClient, h3_lat: float, h3_lon: float) -> BrightnessMessage:
"""create the object that will get published to the prediction queue."""
res = await client.get(prediction_endpoint_url, params={"lat": h3_lat, "lon": h3_lon})
res.raise_for_status()

data = res.json()

if (mpsas := data.get("sky_brightness", None)) is None:
raise ValueError("no sky brightness reading in api response")

message = PredictionMessage(
utc_now = datetime.utcnow().isoformat()
brightness_message = BrightnessMessage(
uuid=str(uuid.uuid4()),
lat=h3_lat,
lon=h3_lon,
utc=datetime.utcnow().isoformat(),
h3_id=get_cell_id(h3_lat, h3_lon),
utc=utc_now,
mpsas=mpsas,
h3_id=h3.geo_to_h3(h3_lat, h3_lon, resolution=0),
model_version=model_version
)
return message
return brightness_message


async def publish_cell_prediction(client: httpx.AsyncClient, h3_coords: Tuple[float, float], channel: Channel):
"""retrieve and publish a sky brightness prediction at h3 cell coords."""
import json

async def publish_cell_brightness(client: httpx.AsyncClient, h3_coords: Tuple[float, float], channel: Channel):
"""create and publish sky brightness at given h3 cell coords."""
try:
lat, lon = h3_coords
m = await get_prediction_message(client, lat, lon)

m = await create_brightness_message(client, lat, lon)
message_body = asdict(m)

log.info(f"publishing {message_body} to {prediction_queue}")
channel.basic_publish(exchange="", routing_key=prediction_queue, body=json.dumps(message_body))

keydb.incr(m.h3_id)
num_predictions_published = int(keydb.get(m.h3_id))
log.info(f"{m.h3_id} has had {num_predictions_published} predictions published")

log.info(f"{m.h3_id} has had {int(keydb.get(m.h3_id))} predictions published")
except httpx.HTTPStatusError as e:
log.error(f"got bad status from api server {e}")
except Exception as e:
Expand Down
15 changes: 15 additions & 0 deletions pp/pp/prediction_message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from dataclasses import dataclass


@dataclass
class BrightnessMessage:
uuid: str
lat: float
lon: float
# id of the h3 cell
h3_id: str
# utc datetime that this message was published
utc: str
# magnitudes per square arc second estimated by the model
mpsas: float
model_version: str

0 comments on commit a374554

Please sign in to comment.