diff --git a/README.md b/README.md index c9897fc..e0c8104 100644 --- a/README.md +++ b/README.md @@ -2,12 +2,16 @@ APIs that tap into the quality of the night sky. -## HTTP APIs +## Building and training the sky brightness model -- artificial sky brightness (light pollution) -- predictive sky brightness +> Note that `globe_at_night.tar.gz` will need to unpacked into `./data/globe_at_night` + +- `python -m ctts.model.build` to write the csv that the model trains on +- `python -m ctts.model.train` to train on the data in the csv -### running locally +## HTTP APIs + +### how to run locally > Note: tested on python 3.11 @@ -16,6 +20,10 @@ pip install -r requirements.txt python -m uvicorn ctts.api:app --reload ``` + +- artificial sky brightness (light pollution) +- predictive sky brightness + ### endpoints #### `/api/v1/pollution` @@ -50,39 +58,3 @@ curl "http://localhost:8000/api/v1/prediction?lat=-30.2466&lon=-70.7494" Open the [ui](http://localhost:8000/docs) in a browser. -### running with docker - -> Note: image size is on the order of 5.13GB - -- build the image - -```sh -cd ctts -docker build -t ctts:latest . -``` - -- run the container - -```sh -docker run -d --name ctts -p 8000:80 ctts:latest -``` - -### validation - -> Note: validation of the model's predictions is an ongoing process. - -### running tests - -```sh -cd ctts -python -m pytest -``` - -### env - -Location of logfile can be controlled by setting `SKY_BRIGHTNESS_LOGFILE` to -some filename (in the home directory). - -```sh -export SKY_BRIGHTNESS_LOGFILE=ctts.log -``` diff --git a/ctts/model/build.py b/ctts/model/build.py index aac98bf..5acc708 100644 --- a/ctts/model/build.py +++ b/ctts/model/build.py @@ -9,7 +9,7 @@ from astropy.time import Time config = ConfigParser() -config.read("config.ini") +config.read(Path(__file__).parent / "config.ini") max_sqm = config.getint("sqm", "max") min_sqm = config.getint("sqm", "min") diff --git a/ctts/model/config.ini b/ctts/model/config.ini index 418c786..6c67335 100644 --- a/ctts/model/config.ini +++ b/ctts/model/config.ini @@ -3,4 +3,7 @@ max = 22 min = 16 [csv] -filename = globe_at_night.csv \ No newline at end of file +filename = globe_at_night.csv + +[train] +epochs = 100 \ No newline at end of file diff --git a/ctts/model/train.py b/ctts/model/train.py index f6041d0..950e234 100644 --- a/ctts/model/train.py +++ b/ctts/model/train.py @@ -13,13 +13,16 @@ features_size = len(features) config = ConfigParser() -config.read("config.ini") -# csv_filename = config.get("csv", "filename") +config.read(Path(__file__).parent / "config.ini") + +csv_filename = config.get("csv", "filename") +epochs = config.getint("train", "epochs") cwd = Path.cwd() -csv_filename = "globe_at_night.csv" +saved_model_path = cwd / "ctts" / "prediction" / "model.pth" -path_to_gan_dataframe = cwd / "data" / csv_filename +path_to_data_dir = Path(__file__).parent.parent.parent / "data" +path_to_gan_dataframe = path_to_data_dir / csv_filename if not path_to_gan_dataframe.exists(): raise FileNotFoundError() @@ -87,12 +90,10 @@ def test_model(data_loader: DataLoader, model: NeuralNetwork, loss_fn: nn.HuberL if __name__ == "__main__": - epochs = 100 - saved_model_path = cwd / "ctts" / "prediction" / "model.pth" print(f"starting training with {epochs} epochs, and saving state dict to {saved_model_path}") - for t in range(epochs): - print(f"epoch {t+1}") + for epoch in range(epochs): + print(f"epoch {epoch + 1}") train_loop(train_dataloader, model, loss_fn, optimizer) test_model(test_dataloader, model, loss_fn) diff --git a/ctts/prediction/constants.py b/ctts/prediction/constants.py index d3e617e..552ac92 100644 --- a/ctts/prediction/constants.py +++ b/ctts/prediction/constants.py @@ -11,9 +11,8 @@ HIDDEN_SIZE = 64 * 3 OUTPUT_SIZE = 1 -MODEL_STATE_DICT_FILE_NAME = "model.pth" -# ASTRO_TWILIGHT_DEGS = -18 +MODEL_STATE_DICT_FILE_NAME = "model.pth" OPEN_METEO_BASE_URL = "https://api.open-meteo.com" MAX_OKTAS = 8 diff --git a/ctts/prediction/meteo.py b/ctts/prediction/open_meteo_client.py similarity index 100% rename from ctts/prediction/meteo.py rename to ctts/prediction/open_meteo_client.py diff --git a/ctts/prediction/prediction.py b/ctts/prediction/prediction.py index 9e5149d..96596c0 100644 --- a/ctts/prediction/prediction.py +++ b/ctts/prediction/prediction.py @@ -11,7 +11,7 @@ MODEL_STATE_DICT_FILE_NAME, LOGFILE_KEY, ) -from .meteo import OpenMeteoClient +from .open_meteo_client import OpenMeteoClient from .nn import NeuralNetwork from .observer_site import ObserverSite