-
Notifications
You must be signed in to change notification settings - Fork 0
/
config.py
29 lines (22 loc) · 817 Bytes
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import os
import torch
from dotenv import load_dotenv
class Config:
    """Application configuration, grouped into sections.

    Creating an instance first calls ``load_dotenv()`` (python-dotenv) so that
    values defined in a ``.env`` file can be picked up by the ``os.getenv``
    calls made inside each section's constructor.

    Attributes:
        imagenet: Dataset locations, an ``ImageNetConfig``.
        runtime: Execution settings, a ``RuntimeConfig``.
    """

    def __init__(self):
        # The .env file must be loaded before any section reads the environment.
        load_dotenv()
        imagenet_section = ImageNetConfig()
        runtime_section = RuntimeConfig()
        self.imagenet = imagenet_section
        self.runtime = runtime_section
class ImageNetConfig:
    """ImageNet-related paths taken from environment variables.

    Attributes:
        annotations_dir: Value of ``IMAGENET_ANNOTATIONS_DIR``.
        data_dir: Value of ``IMAGENET_DATA_DIR``.
        synset_file: Value of ``IMAGENET_SYNSET_FILE``.

    An attribute is ``None`` when its variable is unset. No validation is
    performed here, so consumers must cope with a missing value.
    """

    def __init__(self):
        environ = os.environ
        self.annotations_dir: str | None = environ.get("IMAGENET_ANNOTATIONS_DIR")
        self.data_dir: str | None = environ.get("IMAGENET_DATA_DIR")
        self.synset_file: str | None = environ.get("IMAGENET_SYNSET_FILE")
class RuntimeConfig:
    """Execution settings read from environment variables.

    Attributes:
        batch_size: Integer from ``BATCH_SIZE`` (required).
        device: ``torch.device`` built from ``DEVICE`` when it is set and
            non-empty; otherwise ``"cuda"`` if ``torch.cuda.is_available()``
            is true, else ``"cpu"``.
        num_predictions: Integer from ``NUM_PREDICTIONS`` (required).
        num_workers: Integer from ``NUM_WORKERS`` when set and non-empty;
            otherwise ``os.cpu_count()``, or 1 if that cannot be determined.

    Raises:
        ValueError: If a required variable is missing or empty, or if any
            integer variable does not parse as an ``int``.
    """

    def __init__(self):
        self.batch_size: int = self._env_int("BATCH_SIZE")

        # Parenthesised on purpose: a conditional expression binds looser than
        # `or`, so the unparenthesised form ignored DEVICE whenever CUDA was
        # unavailable.
        default_device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(os.getenv("DEVICE") or default_device)

        self.num_predictions: int = self._env_int("NUM_PREDICTIONS")
        # os.cpu_count() may return None when the count is undeterminable.
        self.num_workers: int = self._env_int(
            "NUM_WORKERS", default=os.cpu_count() or 1
        )

    @staticmethod
    def _env_int(name, default=None):
        """Return environment variable *name* parsed as an ``int``.

        Args:
            name: Environment variable to read.
            default: Returned when the variable is unset or empty. ``None``
                means the variable is required.

        Raises:
            ValueError: If the variable is required but unset/empty, or if
                its value is not a valid integer.
        """
        raw = os.getenv(name)
        if not raw:
            if default is None:
                raise ValueError(f"environment variable {name} must be set")
            return default
        try:
            return int(raw)
        except ValueError as err:
            raise ValueError(
                f"environment variable {name} must be an integer, got {raw!r}"
            ) from err
# Shared configuration instance, built once at import time: importing this
# module calls load_dotenv() and reads the environment immediately, so any
# error raised while reading a variable surfaces on import.
config: Config = Config()