-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathlit_main_pretrain.py
56 lines (43 loc) · 1.34 KB
/
lit_main_pretrain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import logging
import os
import random
import warnings

import hydra
import numpy as np
import torch
from pytorch_lightning import Trainer, loggers, seed_everything

from datamodule.lit_unlabel_combined_pretrain_data_module import (
    UnlabelCombinedPretrainDataModule,
)
from models.lit_VideoMAETrainer import VideoMAETrainer
# Drop every Python warning for the whole process.
# NOTE(review): this also hides deprecation and numerical warnings from
# torch/Lightning; consider narrowing it to specific categories or modules.
warnings.simplefilter("ignore")

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
def _num_requested_devices(devices):
    """Return how many devices a Lightning ``devices`` spec names, or ``None``.

    ``len(devices)`` only works for list-like specs. Lightning also accepts an
    int (``-1`` meaning "all available") or a string such as ``"auto"``; for
    those, ``len`` either raises ``TypeError`` or counts characters.

    Args:
        devices: The value of ``cfg.devices`` (list-like, int, or str).

    Returns:
        The device count for list-like and int specs (``-1`` resolves to the
        number of visible CUDA devices), or ``None`` for strings and any other
        spec whose count is left for Lightning to resolve.
    """
    if isinstance(devices, str):
        return None
    if isinstance(devices, int):
        return torch.cuda.device_count() if devices == -1 else devices
    try:
        return len(devices)
    except TypeError:
        return None


@hydra.main(config_path="configs", config_name="config_pretrain.yaml")
def main(cfg):
    """Build the data module, model and Lightning Trainer from ``cfg`` and train.

    Reads ``seed``, ``devices``, ``accelerator``, ``strategy``,
    ``trainer.epochs`` and ``train`` from ``cfg``; the whole ``cfg`` is also
    passed to ``UnlabelCombinedPretrainDataModule`` and ``VideoMAETrainer``.
    An optional ``detect_anomaly`` key toggles autograd anomaly detection; it
    defaults to ``True``, the previously hard-coded value.

    Args:
        cfg: The composed Hydra/OmegaConf config.
    """
    logger.info("Trainer config: %s", cfg.trainer)

    # Seeds python `random`, numpy and torch (CPU and all CUDA devices), as the
    # former manual calls did; workers=True additionally lets Lightning give
    # each DataLoader worker process its own deterministic seed.
    seed_everything(cfg.seed, workers=True)

    data_module = UnlabelCombinedPretrainDataModule(cfg)
    model = VideoMAETrainer(cfg)

    # `len(cfg.devices)` crashed for int specs and mis-counted string specs.
    num_devices = _num_requested_devices(cfg.devices)
    if torch.cuda.is_available() and num_devices:
        logger.info("Using %d GPUs !", num_devices)

    # Event files go under ./tensor_board relative to the run-time working dir.
    train_logger = loggers.TensorBoardLogger("tensor_board", default_hp_metric=False)

    # PyTorch documents anomaly detection as a debugging aid that slows
    # execution; keep the old default but allow `detect_anomaly: false`.
    # (`in` is used because it is safe on struct-mode DictConfigs.)
    detect_anomaly = cfg.detect_anomaly if "detect_anomaly" in cfg else True

    trainer = Trainer(
        accelerator=cfg.accelerator,
        devices=cfg.devices,
        strategy=cfg.strategy,
        max_epochs=cfg.trainer.epochs,
        logger=train_logger,
        detect_anomaly=detect_anomaly,
    )

    if cfg.train:
        trainer.fit(model, data_module)
        logger.info("Callback metrics: %s", trainer.callback_metrics)
if __name__ == "__main__":
    # HYDRA_FULL_ERROR=1 asks Hydra to report the complete stack trace when
    # the decorated `main` fails, instead of its shortened error message.
    # It is set unconditionally, overriding any value already in the env.
    os.environ.update(HYDRA_FULL_ERROR="1")
    main()