Skip to content

Commit

Permalink
refactor tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nonnontrivial committed Apr 28, 2024
1 parent f6264b6 commit 3dedf12
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 34 deletions.
54 changes: 28 additions & 26 deletions ctts/test_pollution_api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pytest
from fastapi.testclient import TestClient

from .api import app
Expand All @@ -6,32 +7,33 @@
API_PREFIX = "/api/v1"


def test_get_pollution():
cities = {
(29.7756796, -95.4888013),
(40.7277478, -74.0000374),
(55.7545835, 37.6137138),
(39.905245, 116.4050653)
@pytest.mark.parametrize("lat, lon", [
(29.7756796, -95.4888013),
(40.7277478, -74.0000374),
(55.7545835, 37.6137138),
(39.905245, 116.4050653)
])
def test_get_city_pollution(lat, lon):
max_channels = {
"r": 255,
"g": 255,
"b": 255,
"a": 255
}
for lat, lon in cities:
r = client.get(f"{API_PREFIX}/pollution?lat={lat}&lon={lon}")
assert r.status_code == 200
assert r.json() == {
"r": 255,
"g": 255,
"b": 255,
"a": 255
}
res = client.get(f"{API_PREFIX}/pollution?lat={lat}&lon={lon}")
assert res.json() == max_channels


def test_get_pollution_out_of_bounds():
out_of_bounds_coords = {(76., -74.), (-65., -74.)}
for lat, lon in out_of_bounds_coords:
r = client.get(f"{API_PREFIX}/pollution?lat={lat}&lon={lon}")
assert r.status_code == 200
assert r.json() == {
"r": 0,
"g": 0,
"b": 0,
"a": 255
}
@pytest.mark.parametrize("lat, lon", [
(76., -74.),
(-65., -74.)
])
def test_out_of_bounds(lat, lon):
empty_channels = {
"r": 0,
"g": 0,
"b": 0,
"a": 255
}
res = client.get(f"{API_PREFIX}/pollution?lat={lat}&lon={lon}")
assert res.json() == empty_channels
19 changes: 11 additions & 8 deletions ctts/test_prediction_api.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import pytest
from fastapi.testclient import TestClient

from .api import app

client = TestClient(app)
lat, lon = (-30.2466, -70.7494)

API_PREFIX = "/api/v1"


Expand All @@ -13,9 +12,13 @@ def test_get_prediction_bad_status_without_lat_lon():
assert r.status_code != 200


def test_get_prediction():
r = client.get(
f"{API_PREFIX}/prediction?lat={lat}&lon={lon}"
)
assert r.status_code == 200
assert list(r.json().keys()) == ["sky_brightness"]
@pytest.mark.parametrize("coords, lowerbound, upperbound", [
((-30.2466, -70.7494), 6, 25),
((19.8264, -155.4750), 6, 28)
])
def test_prediction(coords, lowerbound, upperbound):
lat, lon = coords
response = client.get(f"{API_PREFIX}/prediction?lat={lat}&lon={lon}")
assert response.status_code == 200
brightness = response.json()["sky_brightness"]
assert lowerbound <= brightness <= upperbound

0 comments on commit 3dedf12

Please sign in to comment.