-
Notifications
You must be signed in to change notification settings - Fork 1
/
pretrain_speech_vqvae.py
28 lines (22 loc) · 1004 Bytes
/
pretrain_speech_vqvae.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from vqmae import SpeechVQVAE, Speech_VQVAE_Train, VoxcelebSequential
import hydra
from omegaconf import DictConfig
import os
@hydra.main(config_path="config_vqvae", config_name="config")
def main(cfg: DictConfig):
    """Pre-train the speech VQ-VAE.

    Hydra builds ``cfg`` from ``config_vqvae/config``. ``cfg.model`` is
    unpacked as keyword arguments into ``SpeechVQVAE`` and ``cfg.train`` is
    passed to ``Speech_VQVAE_Train`` as ``config_training``.

    Dataset locations can be overridden with the optional config keys
    ``data.root`` and ``data.h5_path``. When they are absent, the original
    hard-coded paths are used, so existing configs keep working unchanged.
    """
    # Go back to the directory the script was launched from (Hydra may have
    # switched into its own run directory) so relative paths such as
    # checkpoint locations resolve against the launch directory.
    os.chdir(hydra.utils.get_original_cwd())

    # --- Data ---
    # Use membership tests rather than attribute access: reading a missing key
    # from a struct-mode DictConfig (as Hydra typically produces) raises.
    data_cfg = cfg["data"] if "data" in cfg else {}
    data_root = (data_cfg["root"] if "root" in data_cfg
                 else r"D:\These\data\Audio-Visual\voxceleb\test\video")
    h5_path = (data_cfg["h5_path"] if "h5_path" in data_cfg
               else r"E:\H5\modality_spectrogram_test.hdf5")
    data_train = VoxcelebSequential(root=data_root,
                                    h5_path=h5_path,
                                    frames_per_clip=1,
                                    train=True)

    # --- Model ---
    vqvae = SpeechVQVAE(**cfg.model)
    # Optional (disabled): load saved model weights before training.
    # vqvae.load(path_model=r"checkpoint/VQVAE/2023-1-10/22-36/model_checkpoint")

    # --- Training ---
    # NOTE(review): the same dataset object is passed twice; presumably the
    # second positional argument is the validation set. Confirm whether a
    # held-out split was intended here.
    pretrain_vqvae = Speech_VQVAE_Train(vqvae, data_train, data_train,
                                        config_training=cfg.train)
    # Optional (disabled): load a saved trainer checkpoint before fitting.
    # pretrain_vqvae.load(path=r"checkpoint/VQVAE/2022-12-28/12-7/model_checkpoint")
    pretrain_vqvae.fit()
if __name__ == "__main__":
    # Called without arguments: the @hydra.main decorator on main() builds
    # and supplies the cfg argument.
    main()