-
Notifications
You must be signed in to change notification settings - Fork 1
/
training_vqmae.py
47 lines (40 loc) · 1.85 KB
/
training_vqmae.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from vqmae import MAE, MAE_Train, SpeechVQVAE, VoxcelebSequential
import hydra
from omegaconf import DictConfig
import os
@hydra.main(config_path="config_mae", config_name="config")
def main(cfg: DictConfig):
    """Train a MAE on top of a SpeechVQVAE, driven by the Hydra config.

    Steps, in order: build the training and validation
    ``VoxcelebSequential`` datasets, instantiate ``SpeechVQVAE`` from
    ``cfg.vqvae`` and call its ``load`` on a fixed checkpoint path,
    instantiate ``MAE`` from ``cfg.model``, wrap everything in ``MAE_Train``
    and call ``fit``.

    Args:
        cfg: Hydra configuration object; the ``vqvae``, ``model`` and
            ``train`` entries are read here.
    """
    # Hydra may have moved the process into its own run directory; go back to
    # the launch directory so the relative checkpoint path below resolves
    # from where the script was started.
    launch_dir = hydra.utils.get_original_cwd()
    os.chdir(launch_dir)

    # ---- Data ----
    # NOTE(review): the roots and h5 paths below look machine-specific /
    # placeholder values -- confirm before running on another machine.
    frames_per_clip = 200
    train_set = VoxcelebSequential(
        root=r"D:\These\data\Audio-Visual\voxceleb\train",
        h5_path=r"path-to-h5-train",
        frames_per_clip=frames_per_clip,
        train=True,
    )
    validation_set = VoxcelebSequential(
        root=r"D:\These\data\Audio-Visual\voxceleb\test\video",
        h5_path=r"path-to-h5-validation",
        frames_per_clip=frames_per_clip,
    )

    # ---- VQVAE ----
    speech_vqvae = SpeechVQVAE(**cfg.vqvae)
    speech_vqvae.load(
        path_model=r"checkpoint/SPEECH_VQVAE/2022-12-27/21-42/model_checkpoint"
    )

    # ---- MAE ----
    # Masking choices listed by the original author:
    # ["random", "horizontal", "vertical", "mosaic"]
    masking = "random"
    trainable_position = True
    masked_autoencoder = MAE(
        **cfg.model,
        vqvae_embedding=None,
        masking=masking,
        trainable_position=trainable_position,
    )

    # ---- Training ----
    # Run description handed to the trainer; the masking and position entries
    # reuse the values given to the MAE above so the two cannot drift apart.
    run_description = {
        "encoder_depth": 6,
        "decoder_depth": 4,
        "ratio": 0.50,
        "masking": masking,
        "trainable_position": trainable_position,
    }
    trainer = MAE_Train(
        masked_autoencoder,
        speech_vqvae,
        train_set,
        validation_set,
        config_training=cfg.train,
        tube_bool=True,
        follow=True,
        multigpu_bool=True,
        description=run_description,
    )
    # Optional, disabled in the original script: load an earlier checkpoint
    # before fitting.
    # trainer.load(path="checkpoint/RSMAE/2023-2-1/11-4/model_checkpoint")
    trainer.fit()
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()