run_m2pretrain.py
import os
import glob
import yaml
import torch
import random
import argparse
import numpy as np
from shutil import copyfile
from argparse import Namespace
from m2p_runner import Runner
from m2p_dataloader import MultiModalDataset
from torch.utils.data import DataLoader
from transformers import BertTokenizer
from transformers import RobertaTokenizer


def get_upstream_args():

    parser = argparse.ArgumentParser(description='Argument Parser for Upstream Models of the S3PLR project.')

    # required, set either (--run and --config) or (--resume)
    parser.add_argument('--run', default=None, choices=['transformer', 'apc'], help='Select pre-training task. \
                        For the transformer models, which type of pre-training (mockingjay, tera, aalbert, etc) \
                        is determined by config file.')
    parser.add_argument('--config', default=None, type=str, help='Path to experiment config.')
    parser.add_argument('--resume', default=None, help='Specify the upstream checkpoint path to resume training.')

    # ckpt and logging
    parser.add_argument('--name', default=None, type=str, help='Name for logging.')
    parser.add_argument('--ckpdir', default='', type=str, help='Path to store checkpoint results; if empty, the default is used.')
    parser.add_argument('--seed', default=1337, type=int, help='Random seed for reproducible results.')

    # options
    parser.add_argument('--test', default='', type=str, help='Input path to the saved model ckpt for testing.')
    parser.add_argument('--cpu', action='store_true', help='Disable GPU training.')
    parser.add_argument('--multi_gpu', action='store_true', help='Enable multi-GPU training.')
    parser.add_argument('--test_reconstruct', action='store_true', help='Test reconstruction capability.')
    parser.add_argument('--online_config', default=None, help='Explicitly specify the config for on-the-fly feature extraction.')
    parser.add_argument('--kaldi_data', action='store_true', help='Whether to use the Kaldi dataset.')

    # parse
    args = parser.parse_args()

    if args.resume is None:
        # fresh run: both --run and --config are required, and the config is read from YAML
        assert args.run is not None and args.config is not None, '`--run` and `--config` must be given if `--resume` is not provided'
        setattr(args, 'gpu', not args.cpu)
        config = yaml.load(open(args.config, 'r'), Loader=yaml.FullLoader)
        if args.online_config is not None:
            online_config = yaml.load(open(args.online_config, 'r'), Loader=yaml.FullLoader)
            config['online'] = online_config
    else:
        # resume: if a directory is given, pick the checkpoint with the largest step number
        if os.path.isdir(args.resume):
            ckpts = glob.glob(f'{args.resume}/*.ckpt')
            assert len(ckpts) > 0
            ckpts = sorted(ckpts, key=lambda pth: int(pth.split('-')[-1].split('.')[0]))
            resume_ckpt = ckpts[-1]
        else:
            resume_ckpt = args.resume

        def update_args(old, new):
            # merge the saved args into the current ones; checkpoint values take precedence
            old_dict = vars(old)
            new_dict = vars(new)
            old_dict.update(new_dict)
            return Namespace(**old_dict)

        # restore args and config from the checkpoint saved by the previous run
        ckpt = torch.load(resume_ckpt, map_location='cpu')
        args = update_args(args, ckpt['Settings']['Paras'])
        config = ckpt['Settings']['Config']
        setattr(args, 'resume', resume_ckpt)

    return args, config
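

# For reference, the checkpoint layout that --resume expects, inferred from the code
# above (file and directory names below are illustrative; the actual save format is
# defined in m2p_runner.Runner):
#
#   result/result_transformer/my_run/       <- directory passed to --resume
#       states-100000.ckpt                  <- names must end in '-<step>.ckpt'
#       states-200000.ckpt                  <- the largest <step> is resumed
#
#   each .ckpt is a torch.save()'d dict containing at least
#   {'Settings': {'Paras': <argparse Namespace>, 'Config': <config dict>}, ...}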


def train(args, config):

    # prepare the checkpoint directory and back up the config used for this run
    if args.ckpdir == '':
        if args.name is None: args.name = 'run_' + str(random.randint(0, 999))
        ckpdir = os.path.join('result/result_transformer/', args.name)
    else:
        ckpdir = args.ckpdir
    if not os.path.exists(ckpdir):
        os.makedirs(ckpdir)
    copyfile(args.config, os.path.join(ckpdir, args.config.split('/')[-1]))

    # the dataset handles bucketing itself (bucket_size comes from the config's batch_size),
    # so the DataLoader is built with batch_size=1
    tokenizer = BertTokenizer.from_pretrained(config['dataloader']['tokenizer_path'])
    dataset = MultiModalDataset(file_path=config['dataloader']['data_path'], sets=config['dataloader']['train_set'],
                                bucket_size=config['dataloader']['batch_size'], max_timestep=config['dataloader']['max_timestep'],
                                drop=True, acoustic_config=config['acoustic'], semantic_config=config['semantic'], tokenizer=tokenizer)
    dataloader = DataLoader(dataset=dataset, batch_size=1, shuffle=True, drop_last=False,
                            num_workers=config['dataloader']['n_jobs'], pin_memory=True)

    # build the model and run the pre-training loop
    runner = Runner(args, config, dataloader, ckpdir)
    runner.set_model()
    runner.train()
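

# For reference, a minimal sketch of the YAML config this script reads; the keys are the
# ones accessed above, while the values shown are placeholders, not project defaults:
#
#   dataloader:
#     data_path: <path to the prepared data>
#     train_set: <training split(s) passed to MultiModalDataset>
#     tokenizer_path: <pretrained tokenizer name or path for BertTokenizer>
#     batch_size: <bucket size used by the dataset>
#     max_timestep: <maximum sequence length>
#     n_jobs: <DataLoader worker count>
#   acoustic: {...}   # passed through to MultiModalDataset
#   semantic: {...}   # passed through to MultiModalDataset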


if __name__ == "__main__":
    args, config = get_upstream_args()
    train(args, config)
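

# Example invocations (illustrative only; the config path and run name are placeholders):
#
#   start a new pre-training run:
#       python run_m2pretrain.py --run transformer --config <path/to/config.yaml> --name my_run
#
#   resume from a checkpoint file or directory (the latest *.ckpt in a directory is used):
#       python run_m2pretrain.py --resume result/result_transformer/my_run/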