-
Notifications
You must be signed in to change notification settings - Fork 27
/
Copy pathpretrain.py
61 lines (55 loc) · 2.23 KB
/
pretrain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import os
import argparse
import torch.nn as nn
import MinkowskiEngine as ME
import pytorch_lightning as pl
from utils.read_config import generate_config
from pretrain.model_builder import make_model
from pytorch_lightning.plugins import DDPPlugin
from pretrain.lightning_trainer import LightningPretrain
from pretrain.lightning_datamodule import PretrainDataModule
from pretrain.lightning_trainer_spconv import LightningPretrainSpconv
def main():
    """
    Launch the pretraining run.

    Reads ``--cfg_file`` and ``--resume_path`` from the command line, builds
    the config with ``generate_config``, creates the data module and the
    point/image models, wraps them in the Lightning module selected by
    ``config["model_points"]`` and hands everything to ``pl.Trainer.fit``.

    Raises:
        ValueError: if ``config["model_points"]`` is neither ``"minkunet"``
            nor ``"voxelnet"``.
    """
    parser = argparse.ArgumentParser(description="arg parser")
    parser.add_argument(
        "--cfg_file", type=str, default="config/slidr_minkunet.yaml", help="specify the config for training"
    )
    parser.add_argument(
        "--resume_path", type=str, default=None, help="provide a path to resume an incomplete training"
    )
    args = parser.parse_args()
    config = generate_config(args.cfg_file)
    # A path given on the command line overrides whatever the config holds.
    if args.resume_path:
        config['resume_path'] = args.resume_path

    # Print the config once rather than once per process. Environment values
    # are always strings, so the rank must be compared against "0"; comparing
    # against the integer 0 only matched when LOCAL_RANK was unset.
    if os.environ.get("LOCAL_RANK", "0") == "0":
        print(
            "\n" + "\n".join(f"{key:20}: {value}" for key, value in config.items())
        )

    dm = PretrainDataModule(config)
    model_points, model_images = make_model(config)
    if config["num_gpus"] > 1:
        # Multi-GPU run: convert the batch-norm layers of both models to
        # their synchronized variants.
        model_points = ME.MinkowskiSyncBatchNorm.convert_sync_batchnorm(model_points)
        model_images = nn.SyncBatchNorm.convert_sync_batchnorm(model_images)

    if config["model_points"] == "minkunet":
        module = LightningPretrain(model_points, model_images, config)
    elif config["model_points"] == "voxelnet":
        module = LightningPretrainSpconv(model_points, model_images, config)
    else:
        # Fail here with a clear message instead of hitting an unbound
        # ``module`` at ``trainer.fit`` below.
        raise ValueError(
            f"Unsupported model_points {config['model_points']!r}; "
            "expected 'minkunet' or 'voxelnet'"
        )

    path = os.path.join(config["working_dir"], config["datetime"])
    trainer = pl.Trainer(
        gpus=config["num_gpus"],
        accelerator="ddp",
        default_root_dir=path,
        checkpoint_callback=True,
        max_epochs=config["num_epochs"],
        plugins=DDPPlugin(find_unused_parameters=False),
        num_sanity_val_steps=0,
        # ``.get`` yields None when the config has no resume_path entry,
        # instead of raising KeyError.
        resume_from_checkpoint=config.get("resume_path"),
        check_val_every_n_epoch=1,
    )
    print("Starting the training")
    trainer.fit(module, dm)
# Run the pretraining only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()