forked from raphychek/mbappe-nuplan
-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_caching.py
105 lines (84 loc) · 3.7 KB
/
run_caching.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import logging
import os
# Must be set before any geopandas/shapely-dependent nuplan import is executed:
# opts out of the deprecated PyGEOS backend. NOTE(review): intentionally placed
# between imports — do not move below the nuplan imports.
os.environ['USE_PYGEOS'] = '0'
from typing import Optional
import hydra
import pytorch_lightning as pl
from omegaconf import DictConfig
from nuplan.planning.script.builders.folder_builder import build_training_experiment_folder
from nuplan.planning.script.builders.logging_builder import build_logger
from nuplan.planning.script.builders.utils.utils_config import update_config_for_training
from nuplan.planning.script.builders.worker_pool_builder import build_worker
from nuplan.planning.script.profiler_context_manager import ProfilerContextManager
from nuplan.planning.script.utils import set_default_path
from nuplan.planning.training.experiments.caching import cache_data
from nuplan.planning.training.experiments.training import TrainingEngine, build_training_engine
from pathlib import Path
import tempfile
# Silence numba's verbose INFO/DEBUG output; only warnings and above pass through.
logging.getLogger('numba').setLevel(logging.WARNING)
# Module-level logger, per standard logging convention.
logger = logging.getLogger(__name__)
def main(cfg: DictConfig) -> Optional[TrainingEngine]:
    """
    Main entrypoint for training/validation experiments.

    Dispatches on ``cfg.py_func``: 'train' and 'test' build a training engine
    and run the corresponding Lightning trainer loop; 'cache' precomputes and
    stores features.

    :param cfg: omegaconf dictionary
    :return: the TrainingEngine for 'train'/'test', None for 'cache'.
    :raises NameError: if ``cfg.py_func`` is not one of the supported modes.
    """
    # Fix random seed across all workers for reproducibility.
    pl.seed_everything(cfg.seed, workers=True)

    # Logging, config normalization for training, and output folder creation.
    build_logger(cfg)
    update_config_for_training(cfg)
    build_training_experiment_folder(cfg=cfg)

    # Worker pool shared by every mode below.
    worker = build_worker(cfg)

    mode = cfg.py_func
    if mode in ('train', 'test'):
        # Engine construction is identical for both modes; only the trainer
        # call differs.
        with ProfilerContextManager(cfg.output_dir, cfg.enable_profiling, "build_training_engine"):
            engine = build_training_engine(cfg, worker)
        if mode == 'train':
            logger.info('Starting training...')
            with ProfilerContextManager(cfg.output_dir, cfg.enable_profiling, "training"):
                engine.trainer.fit(model=engine.model, datamodule=engine.datamodule)
        else:
            logger.info('Starting testing...')
            with ProfilerContextManager(cfg.output_dir, cfg.enable_profiling, "testing"):
                engine.trainer.test(model=engine.model, datamodule=engine.datamodule)
        return engine

    if mode == 'cache':
        # Precompute and cache all features; no engine is returned.
        logger.info('Starting caching...')
        with ProfilerContextManager(cfg.output_dir, cfg.enable_profiling, "caching"):
            cache_data(cfg=cfg, worker=worker)
        return None

    raise NameError(f'Function {cfg.py_func} does not exist')
if __name__ == '__main__':
    # Hydra config location and entry config for training-related jobs.
    CONFIG_PATH = 'nuplan/planning/script/config/training'
    CONFIG_NAME = 'default_training'

    # Output root for the feature cache and experiment artifacts.
    # NOTE(review): despite the `tempfile` import at the top of the file, this
    # is a plain relative directory, not a temporary one — confirm intent.
    SAVE_DIR = Path('nuplan_output_test/')  # optionally replace with persistent dir
    EXPERIMENT = 'caching'
    JOB_NAME = 'cache_urban_multi_agents'

    # Reset any previously-initialized Hydra state, then point Hydra at the
    # training config tree.
    hydra.core.global_hydra.GlobalHydra.instance().clear()
    hydra.initialize(config_path=CONFIG_PATH)

    # Compose the caching configuration for the urban multi-agent model.
    overrides = [
        f'group={SAVE_DIR}',
        f'cache.cache_path={SAVE_DIR}/cache',
        f'experiment_name={EXPERIMENT}',
        f'job_name={JOB_NAME}',
        'py_func=cache',
        'model=urban_multi_model',
        '+training=training_urban_multi_model',
        'scenario_builder=nuplan',
        'scenario_filter.limit_total_scenarios=0.04',
    ]
    cfg = hydra.compose(config_name=CONFIG_NAME, overrides=overrides)

    main(cfg)