Skip to content

Commit

Permalink
update endpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
nonnontrivial committed Jun 1, 2024
1 parent c323e61 commit 74f949b
Show file tree
Hide file tree
Showing 18 changed files with 104 additions and 85 deletions.
3 changes: 3 additions & 0 deletions api/api/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import os

api_version = os.getenv("API_VERSION", "v1")
38 changes: 16 additions & 22 deletions api/api/main.py
Original file line number Diff line number Diff line change
@@ -1,53 +1,47 @@
from dataclasses import dataclass, asdict
import os
from dataclasses import asdict

from fastapi import FastAPI, HTTPException, APIRouter

from .config import api_version
from .models import PredictionResponse
from .pollution.pollution import ArtificialNightSkyBrightnessMapImage, Coords
from .prediction.prediction import (
Prediction,
predict_sky_brightness,
)

api_version = os.getenv("API_VERSION", "v1")

app = FastAPI()
main_router = APIRouter(prefix=f"/api/{api_version}")


@dataclass
class PredictionResponse:
"""response with sky brightness in magnitudes per square arcsecond"""
sky_brightness: float
def create_prediction_response(prediction_obj: Prediction) -> PredictionResponse:
y = round(float(prediction_obj.y.item()), 4)
return PredictionResponse(sky_brightness=y)


@main_router.get("/prediction")
@main_router.get("/predict")
async def get_prediction(lat, lon):
"""Predict sky brightness in magnitudes per square arcsecond for a lat and lon"""

def create_prediction_response(prediction_obj: Prediction) -> PredictionResponse:
y = round(float(prediction_obj.y.item()), 4)
return PredictionResponse(sky_brightness=y)

"""Predict sky brightness in magnitudes per square arcsecond for a lat and lon."""
try:
lat, lon = float(lat), float(lon)

prediction = await predict_sky_brightness(lat, lon)
return asdict(create_prediction_response(prediction))
except Exception as e:
raise HTTPException(status_code=500, detail=f"could not get prediction because {e}")
raise HTTPException(status_code=500, detail=f"failed to predict because {e}")


@main_router.get("/pollution")
@main_router.get("/lp")
async def get_artificial_light_pollution(lat, lon):
"""Get artificial light pollution at a lat and lon
Source: https://djlorenz.github.io/astronomy/lp2022/
"""
"""Get artificial light pollution at a lat and lon. Source https://djlorenz.github.io/astronomy/lp2022/"""
try:
lat, lon = float(lat), float(lon)

map_image = ArtificialNightSkyBrightnessMapImage()
pixel_rgba = map_image.get_pixel_value_at_coords(coords=Coords(lat, lon))
return {channel: pixel_value for channel, pixel_value in zip(("r", "g", "b", "a"), pixel_rgba)}

channels = ("r", "g", "b", "a")
return {channel: pixel_value for channel, pixel_value in zip(channels, pixel_rgba)}
except Exception as e:
raise HTTPException(status_code=500, detail=f"could not get light pollution because {e}")

Expand Down
2 changes: 1 addition & 1 deletion api/api/model/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from torch.utils.data import DataLoader, TensorDataset, random_split

from ..prediction.constants import features
from ..prediction.nn import NeuralNetwork
from api.prediction.net.nn import NeuralNetwork

features_size = len(features)

Expand Down
7 changes: 7 additions & 0 deletions api/api/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from dataclasses import dataclass


@dataclass
class PredictionResponse:
"""in magnitudes per square arcsecond"""
sky_brightness: float
2 changes: 2 additions & 0 deletions api/api/prediction/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@

open_meteo_host = os.getenv("OPEN_METEO_HOST", "localhost")
open_meteo_port = int(os.getenv("OPEN_METEO_PORT", "8080"))

model_state_dict_file_name = os.getenv("MODEL_STATE_DICT_FILE_NAME", "model.pth")
4 changes: 0 additions & 4 deletions api/api/prediction/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,7 @@
"MoonAz",
]
num_features = len(features)

HIDDEN_SIZE = 64 * 3
OUTPUT_SIZE = 1

MODEL_STATE_DICT_FILE_NAME = "model.pth"
MAX_OKTAS = 8

LOGFILE_KEY = "SKY_BRIGHTNESS_LOGFILE"
Empty file.
2 changes: 2 additions & 0 deletions api/api/prediction/meteo/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
PROTOCOL = "http"
MAX_OKTAS = 8
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
import logging
import typing as t

from .config import open_meteo_host, open_meteo_port
from .constants import MAX_OKTAS
from .observer_site import ObserverSite
from .utils import get_astro_time_hour
from ..config import open_meteo_host, open_meteo_port
from ..observer_site import ObserverSite
from ..utils import get_astro_time_hour
from .constants import MAX_OKTAS, PROTOCOL


class OpenMeteoClient:
def __init__(self, site: ObserverSite) -> None:
self.site = site
self.url_base = f"http://{open_meteo_host}:{open_meteo_port}"
self.url_base = f"{PROTOCOL}://{open_meteo_host}:{open_meteo_port}"

async def get_values_at_site(self) -> t.Tuple[int, float]:
"""get cloudcover and elevation values for the observer site"""
Expand Down Expand Up @@ -45,5 +44,5 @@ def get_cloud_cover_as_oktas(self, cloud_cover_percentage: int):
"""convert percentage to integer oktas value (eights of sky covered)"""
import numpy as np

percentage_as_oktas = np.interp(cloud_cover_percentage, (0, 100), (0, MAX_OKTAS))
return int(percentage_as_oktas)
percentage_as_oktas = int(np.interp(cloud_cover_percentage, (0, 100), (0, MAX_OKTAS)))
return percentage_as_oktas
Empty file.
2 changes: 1 addition & 1 deletion api/api/prediction/nn.py → api/api/prediction/net/nn.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch.nn as nn

from .constants import (
from ..constants import (
num_features,
HIDDEN_SIZE,
OUTPUT_SIZE,
Expand Down
30 changes: 18 additions & 12 deletions api/api/prediction/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,11 @@
import torch
from astropy.coordinates import EarthLocation

from .constants import (
MODEL_STATE_DICT_FILE_NAME,
LOGFILE_KEY,
)
from .open_meteo_client import OpenMeteoClient
from .nn import NeuralNetwork
from api.prediction.meteo.open_meteo_client import OpenMeteoClient
from api.prediction.net.nn import NeuralNetwork
from .observer_site import ObserverSite
from .constants import LOGFILE_KEY
from .config import model_state_dict_file_name

logfile_name = os.getenv(LOGFILE_KEY)
path_to_logfile = (Path.home() / logfile_name) if logfile_name else None
Expand All @@ -26,34 +24,41 @@
)


def get_path_to_state_dict():
return Path(__file__).parent / model_state_dict_file_name


@dataclass
class Prediction:
X: torch.Tensor
y: torch.Tensor


async def predict_sky_brightness(lat: float, lon: float) -> Prediction:
"""Predict sky brightness at utcnow for the lat and lon"""
"""Predict sky brightness at utcnow for given lat and lon"""

logging.debug(f"registering site at {lat},{lon}")
location = EarthLocation.from_geodetic(lon * u.degree, lat * u.degree)

site = ObserverSite(location=location)
meteo_client = OpenMeteoClient(site=site)

try:
cloud_cover, elevation = await meteo_client.get_values_at_site()
logging.debug(f"meteo_client response at {lat},{lon} is {cloud_cover}o, {elevation}m")

model = NeuralNetwork()

path_to_state_dict = Path(__file__).parent / MODEL_STATE_DICT_FILE_NAME
path_to_state_dict = get_path_to_state_dict()
logging.debug(f"loading state dict at {path_to_state_dict}")
model.load_state_dict(torch.load(path_to_state_dict))
model.eval()
except Exception as e:
logging.error(f"failed to predict because {e}")
empty_tensor = torch.empty(4, 4)
return Prediction(X=empty_tensor, y=empty_tensor)
import traceback
logging.error(traceback.format_exc())
raise ValueError(f"{e}")
else:
torch.set_printoptions(sci_mode=False)

X = torch.tensor(
[
site.latitude.value,
Expand All @@ -66,6 +71,7 @@ async def predict_sky_brightness(lat: float, lon: float) -> Prediction:
],
dtype=torch.float32,
).unsqueeze(0)
logging.debug(f"X vector for site is {X}")
with torch.no_grad():
predicted_y = model(X)
return Prediction(X=X, y=predicted_y)
7 changes: 4 additions & 3 deletions api/api/tests/test_pollution_api.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import pytest
from fastapi.testclient import TestClient

from api.config import api_version
from api.main import app

API_PREFIX = f"/api/{api_version}"
client = TestClient(app)
API_PREFIX = "/api/v1"


@pytest.mark.parametrize("lat, lon", [
Expand All @@ -20,7 +21,7 @@ def test_get_city_pollution(lat, lon):
"b": 255,
"a": 255
}
res = client.get(f"{API_PREFIX}/pollution?lat={lat}&lon={lon}")
res = client.get(f"{API_PREFIX}/lp?lat={lat}&lon={lon}")
assert res.json() == max_channels


Expand All @@ -35,5 +36,5 @@ def test_out_of_bounds(lat, lon):
"b": 0,
"a": 255
}
res = client.get(f"{API_PREFIX}/pollution?lat={lat}&lon={lon}")
res = client.get(f"{API_PREFIX}/lp?lat={lat}&lon={lon}")
assert res.json() == empty_channels
7 changes: 4 additions & 3 deletions api/api/tests/test_prediction_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
from fastapi.testclient import TestClient

from api.main import app
from api.config import api_version

client = TestClient(app)
API_PREFIX = "/api/v1"
API_PREFIX = f"/api/{api_version}"


def test_get_prediction_bad_status_without_lat_lon():
r = client.get(f"{API_PREFIX}/prediction")
r = client.get(f"{API_PREFIX}/predict")
assert r.status_code != 200


Expand All @@ -18,7 +19,7 @@ def test_get_prediction_bad_status_without_lat_lon():
])
def test_prediction(coords, lowerbound, upperbound):
lat, lon = coords
response = client.get(f"{API_PREFIX}/prediction?lat={lat}&lon={lon}")
response = client.get(f"{API_PREFIX}/predict?lat={lat}&lon={lon}")
assert response.status_code == 200
brightness = response.json()["sky_brightness"]
assert lowerbound <= brightness <= upperbound
5 changes: 2 additions & 3 deletions cpp/README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# cpp

this container is a producer for rabbitmq prediction queue.
producer for rabbitmq queue which handles sky brightness prediction across h3 cells.

it finds cells to request predictions for,
requests those predictions, and then sends
it finds cells to request predictions for, requests those predictions, and then sends
the response to the prediction queue (repeatedly)
32 changes: 19 additions & 13 deletions cpp/cpp/cells.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,27 @@ async def predict_on_cell(client: httpx.AsyncClient, coords: Tuple[float, float]
import json
import asyncio

# FIXME routing key
routing_key = "hello"
api_url = f"{api_protocol}://{api_host}:{api_port}/api/{api_version}/prediction"

lat, lon = coords
res = await client.get(api_url, params={"lat": lat, "lon": lon})
# res.raise_for_status()
data = res.json()
message_body = asdict(PredictionMessage(
lat=lat,
lon=lon,
time_of=datetime.utcnow().isoformat(),
sky_brightness_mpsas=data["sky_brightness"],
))
log.info(f"publishing prediction message {message_body}")
# FIXME routing key
channel.basic_publish(exchange="", routing_key="hello", body=json.dumps(message_body))

try:
res = await client.get(api_url, params={"lat": lat, "lon": lon})
res.raise_for_status()
data = res.json()
message_body = asdict(PredictionMessage(
lat=lat,
lon=lon,
time_of=datetime.utcnow().isoformat(),
sky_brightness_mpsas=data["sky_brightness"],
))
log.info(f"publishing prediction message {message_body}")
channel.basic_publish(exchange="", routing_key=routing_key, body=json.dumps(message_body))
except httpx.HTTPStatusError as e:
log.error(f"got bad status from api server {e}")
except Exception as e:
log.error(f"could not publish prediction at {coords} because {e}")
await asyncio.sleep(1)


Expand Down
9 changes: 6 additions & 3 deletions cpp/cpp/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,18 @@


async def main():
"""continuously request predictions for cells, and publish responses as available"""

connection = pika.BlockingConnection(pika.ConnectionParameters(rabbitmq_host))
channel = connection.channel()
channel.queue_declare(queue=prediction_queue)

cell_coords = get_res_zero_cell_coords()[:5]
log.info(f"producing predictions for {len(cell_coords)} resolution zero cells")

async with httpx.AsyncClient() as client:
while True:
cells = get_res_zero_cell_coords()
log.info(f"found {len(cells)} resolution zero cells")
tasks = [predict_on_cell(client, coords, channel) for coords in cells[:10]]
tasks = [predict_on_cell(client, coords, channel) for coords in cell_coords]
await asyncio.gather(*tasks)
await asyncio.sleep(1)

Expand Down
24 changes: 12 additions & 12 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,18 @@ services:
depends_on:
- openmeteo

cpp:
build: ./cpp
environment:
API_VERSION: "v1"
API_HOST: "api"
RABBITMQ_HOST: "rabbitmq"
restart: on-failure
depends_on:
- rabbitmq
- api
links:
- rabbitmq
# cpp:
# build: ./cpp
# environment:
# API_VERSION: "v1"
# API_HOST: "api"
# RABBITMQ_HOST: "rabbitmq"
# restart: on-failure
# depends_on:
# - rabbitmq
# - api
# links:
# - rabbitmq

volumes:
open-meteo-data:
Expand Down

0 comments on commit 74f949b

Please sign in to comment.