-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_model.py
31 lines (22 loc) · 898 Bytes
/
test_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from mdvae import VQMDVAE, SpeechVQVAE, VisualVQVAE
import hydra
import os
from omegaconf import DictConfig
import torch
# Root directory of the saved MD-VAE run. Both the Hydra config directory
# (``config_mdvae``) and the model checkpoint used by ``main`` are resolved
# relative to this folder.
path = r"checkpoints/2022/mdvae-Y2022M3D17-12h0"

# Release unoccupied memory cached by PyTorch's CUDA allocator before the
# models are built (a no-op when CUDA has not been initialised).
torch.cuda.empty_cache()
@hydra.main(config_path=f"{path}/config_mdvae", config_name="config")
def main(cfg: DictConfig) -> None:
    """Rebuild the VQ-MDVAE of a saved run and load its checkpoint weights.

    Steps:
      1. Construct the two VQ-VAEs: speech from ``cfg.vqvae_1`` and visual
         from ``cfg.vqvae_2``.
      2. Load each VQ-VAE from its checkpoint.
      3. Wrap both into a ``VQMDVAE`` built from ``cfg.model``.
      4. Load the VQMDVAE weights from ``{path}/mdvae_model_checkpoint``.

    Loading the weights is the only check this script performs; the loaded
    model is not used further here.

    Args:
        cfg: Hydra config composed from ``{path}/config_mdvae/config``.
    """
    # Hydra may switch the working directory to a per-run output folder.
    # Return to the launch directory so the relative checkpoint paths below
    # resolve as intended.
    os.chdir(hydra.utils.get_original_cwd())
    print("=" * 100)

    # --- VQ-VAE: pretrained speech and visual quantizers ---
    speech_vqvae = SpeechVQVAE(**cfg.vqvae_1)
    speech_vqvae.load(path_model=r"checkpoints/VQVAE/speech/model_checkpoint_Y2022M3D5")
    visual_vqvae = VisualVQVAE(**cfg.vqvae_2)
    visual_vqvae.load(path_model=r"checkpoints/VQVAE/visual/model_checkpoint_Y2022M2D13")

    # --- MDVAE built on top of the two VQ-VAEs ---
    model = VQMDVAE(config_model=cfg.model, vqvae_speech=speech_vqvae, vqvae_visual=visual_vqvae)
    model.load_model(path_model=f"{path}/mdvae_model_checkpoint")
    print("=" * 100)
# Run only when executed as a script; Hydra's decorator supplies ``cfg``.
if __name__ == '__main__':
    main()