Skip to content

Commit

Permalink
reference config in build and train scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
nonnontrivial committed Apr 28, 2024
1 parent d532fe9 commit 27b17e0
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 53 deletions.
52 changes: 12 additions & 40 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@

APIs that tap into the quality of the night sky.

## HTTP APIs
## Building and training the sky brightness model

- artificial sky brightness (light pollution)
- predictive sky brightness
> Note that `globe_at_night.tar.gz` will need to unpacked into `./data/globe_at_night`
- `python -m ctts.model.build` to write the csv that the model trains on
- `python -m ctts.model.train` to train on the data in the csv

### running locally
## HTTP APIs

### how to run locally

> Note: tested on python 3.11
Expand All @@ -16,6 +20,10 @@ pip install -r requirements.txt
python -m uvicorn ctts.api:app --reload
```


- artificial sky brightness (light pollution)
- predictive sky brightness

### endpoints

#### `/api/v1/pollution`
Expand Down Expand Up @@ -50,39 +58,3 @@ curl "http://localhost:8000/api/v1/prediction?lat=-30.2466&lon=-70.7494"

Open the [ui](http://localhost:8000/docs) in a browser.

### running with docker

> Note: image size is on the order of 5.13GB
- build the image

```sh
cd ctts
docker build -t ctts:latest .
```

- run the container

```sh
docker run -d --name ctts -p 8000:80 ctts:latest
```

### validation

> Note: validation of the model's predictions is an ongoing process.
### running tests

```sh
cd ctts
python -m pytest
```

### env

Location of logfile can be controlled by setting `SKY_BRIGHTNESS_LOGFILE` to
some filename (in the home directory).

```sh
export SKY_BRIGHTNESS_LOGFILE=ctts.log
```
2 changes: 1 addition & 1 deletion ctts/model/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from astropy.time import Time

config = ConfigParser()
config.read("config.ini")
config.read(Path(__file__).parent / "config.ini")

max_sqm = config.getint("sqm", "max")
min_sqm = config.getint("sqm", "min")
Expand Down
5 changes: 4 additions & 1 deletion ctts/model/config.ini
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,7 @@ max = 22
min = 16

[csv]
filename = globe_at_night.csv
filename = globe_at_night.csv

[train]
epochs = 100
17 changes: 9 additions & 8 deletions ctts/model/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,16 @@
features_size = len(features)

config = ConfigParser()
config.read("config.ini")
# csv_filename = config.get("csv", "filename")
config.read(Path(__file__).parent / "config.ini")

csv_filename = config.get("csv", "filename")
epochs = config.getint("train", "epochs")

cwd = Path.cwd()
csv_filename = "globe_at_night.csv"
saved_model_path = cwd / "ctts" / "prediction" / "model.pth"

path_to_gan_dataframe = cwd / "data" / csv_filename
path_to_data_dir = Path(__file__).parent.parent.parent / "data"
path_to_gan_dataframe = path_to_data_dir / csv_filename
if not path_to_gan_dataframe.exists():
raise FileNotFoundError()

Expand Down Expand Up @@ -87,12 +90,10 @@ def test_model(data_loader: DataLoader, model: NeuralNetwork, loss_fn: nn.HuberL


if __name__ == "__main__":
epochs = 100
saved_model_path = cwd / "ctts" / "prediction" / "model.pth"
print(f"starting training with {epochs} epochs, and saving state dict to {saved_model_path}")

for t in range(epochs):
print(f"epoch {t+1}")
for epoch in range(epochs):
print(f"epoch {epoch + 1}")
train_loop(train_dataloader, model, loss_fn, optimizer)

test_model(test_dataloader, model, loss_fn)
Expand Down
3 changes: 1 addition & 2 deletions ctts/prediction/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@

HIDDEN_SIZE = 64 * 3
OUTPUT_SIZE = 1
MODEL_STATE_DICT_FILE_NAME = "model.pth"

# ASTRO_TWILIGHT_DEGS = -18
MODEL_STATE_DICT_FILE_NAME = "model.pth"

OPEN_METEO_BASE_URL = "https://api.open-meteo.com"
MAX_OKTAS = 8
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion ctts/prediction/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
MODEL_STATE_DICT_FILE_NAME,
LOGFILE_KEY,
)
from .meteo import OpenMeteoClient
from .open_meteo_client import OpenMeteoClient
from .nn import NeuralNetwork
from .observer_site import ObserverSite

Expand Down

0 comments on commit 27b17e0

Please sign in to comment.