pretrain.py
import os
import sys
import warnings

import hydra
from omegaconf import OmegaConf
import torch
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

# Make the repository root importable so the local packages below resolve
# regardless of the directory the script is launched from.
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(ROOT_DIR)

from model.module import PretrainingModule
from model.network import create_encoder_network
from data_utils.PretrainDataset import create_dataloader
@hydra.main(version_base="1.2", config_path="configs", config_name="pretrain")
def main(cfg):
    # Echo the resolved Hydra config so each run is reproducible from its logs.
    print("******************************** [Config] ********************************")
    print(OmegaConf.to_yaml(cfg))
    print("******************************** [Config] ********************************")

    pl.seed_everything(cfg.seed)

    # Log metrics to Weights & Biases.
    logger = WandbLogger(
        name=cfg.name,
        save_dir=cfg.wandb.save_dir,
        project=cfg.wandb.project,
    )

    trainer = pl.Trainer(
        logger=logger,
        accelerator="gpu",
        devices=cfg.gpu,
        log_every_n_steps=cfg.log_every_n_steps,
        max_epochs=cfg.training.max_epochs,
    )

    # Build the pretraining dataloader, the encoder network, and the
    # LightningModule that wraps the encoder for training.
    dataloader = create_dataloader(cfg.dataset)
    encoder = create_encoder_network(cfg.model.emb_dim)
    model = PretrainingModule(
        cfg=cfg.training,
        encoder=encoder,
    )

    model.train()
    trainer.fit(model, dataloader)
if __name__ == "__main__":
    torch.set_float32_matmul_precision("high")
    # Anomaly detection helps localize NaN/Inf gradients but slows training;
    # it is typically disabled for production runs.
    torch.autograd.set_detect_anomaly(True)
    torch.cuda.empty_cache()
    # Avoid "too many open files" errors from dataloader worker processes.
    torch.multiprocessing.set_sharing_strategy("file_system")
    warnings.simplefilter(action='ignore', category=FutureWarning)
    main()
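
The script reads all of its hyperparameters from a Hydra config at configs/pretrain.yaml, which is not shown here. Based solely on the fields accessed above (cfg.seed, cfg.name, cfg.gpu, cfg.log_every_n_steps, cfg.wandb.*, cfg.model.emb_dim, cfg.training.max_epochs, cfg.dataset), a minimal sketch of an equivalent config can be built with OmegaConf; every value below is an illustrative placeholder, not the repository's actual settings.

# Sketch of the config structure pretrain.py expects; placeholder values only.
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "name": "pretrain-run",                  # run name shown in W&B
    "seed": 42,
    "gpu": 1,                                 # passed to pl.Trainer(devices=...)
    "log_every_n_steps": 50,
    "wandb": {"save_dir": "./wandb_logs", "project": "pretraining"},
    "model": {"emb_dim": 256},                # embedding size for create_encoder_network
    "training": {"max_epochs": 100},          # also passed whole to PretrainingModule
    "dataset": {},                            # options consumed by create_dataloader
})

print(OmegaConf.to_yaml(cfg))                 # same dump format the script prints at startup

Because the entry point is decorated with @hydra.main, any of these fields can also be overridden from the command line, e.g. python pretrain.py training.max_epochs=200 gpu=2.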