-
Notifications
You must be signed in to change notification settings - Fork 1
/
classification_speaker_dependent.py
60 lines (49 loc) · 2.49 KB
/
classification_speaker_dependent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from vqmae import MAE, SpeechVQVAE, Classifier_Train, EvaluationDatasetSpeakerDependent, h5_creation, size_model
import hydra
from omegaconf import DictConfig
import os
# ---------------------------------------------------------------------------------------------
# Experiment paths. Each location can be overridden through an environment variable so the
# script can run on a machine other than the original author's; when a variable is unset the
# original hard-coded value is used unchanged.
#   root     -> raw dataset directory passed as `root` to the evaluation datasets (RAVDESS_ROOT)
#   h5_path  -> HDF5 file passed as `h5_path` to the evaluation datasets   (RAVDESS_H5_PATH)
#   mae_path -> pretrained MAE run directory; `<mae_path>/config_mae` holds the Hydra config
#               read by `main` and `<mae_path>/model` its weights         (MAE_CHECKPOINT_PATH)
root = os.environ.get("RAVDESS_ROOT", r"D:\These\data\Audio\RAVDESS")
dataset_name = "ravdess"
h5_path = os.environ.get("RAVDESS_H5_PATH", r"H5/ravdess.hdf5")
mae_path = os.environ.get("MAE_CHECKPOINT_PATH", r"checkpoint/RSMAE/2023-2-22/12-45")
# ---------------------------------------------------------------------------------------------
@hydra.main(config_path=f"{mae_path}/config_mae", config_name="config")
def main(cfg: DictConfig):
    """Train a classifier on top of a pretrained MAE on a speaker-dependent split.

    Hydra composes the config named ``config`` from ``<mae_path>/config_mae`` and passes it in
    as ``cfg``. Its ``vqvae`` and ``model`` sections build the speech VQ-VAE and the MAE, and
    its ``train`` section configures the classifier training. The final accuracy and F1 score
    returned by ``Classifier_Train.fit`` are printed.
    """
    # Hydra may switch into a per-run output directory; return to the launch directory so the
    # relative paths used below (H5/..., checkpoint/...) resolve as written.
    os.chdir(hydra.utils.get_original_cwd())

    # --- Data ---
    # Both splits are built from the same arguments so they cannot drift apart. Presumably
    # ratio_train=80 means an 80/20 train/validation partition; confirm in
    # EvaluationDatasetSpeakerDependent.
    dataset_kwargs = {
        "root": root,
        "ratio_train": 80,
        "frames_per_clip": 200,
        "dataset": dataset_name,
        "h5_path": h5_path,
    }
    data_train = EvaluationDatasetSpeakerDependent(train=True, **dataset_kwargs)
    data_validation = EvaluationDatasetSpeakerDependent(train=False, **dataset_kwargs)

    # --- Speech VQ-VAE: architecture from cfg.vqvae, weights from a fixed checkpoint ---
    vqvae = SpeechVQVAE(**cfg.vqvae)
    vqvae.load(path_model=r"checkpoint/SPEECH_VQVAE/2022-12-27/21-42/model_checkpoint")

    # --- Pretrained MAE: architecture from cfg.model, weights from <mae_path>/model ---
    mae = MAE(**cfg.model, trainable_position=True)
    mae.load(path_model=f"{mae_path}/model")
    size_model(mae, "mae")  # presumably reports the model size; see vqmae.size_model
    # Optional toggle: uncomment to disable gradients on the MAE weights
    # (assumes MAE is a torch module).
    # mae = mae.requires_grad_(False)

    # --- Training ---
    pretrain_classifier = Classifier_Train(
        mae,
        vqvae,
        data_train,
        data_validation,
        config_training=cfg.train,
        follow=True,
        query2emo=False,
    )
    # Optional toggle: uncomment to load weights from an earlier classifier checkpoint.
    # pretrain_classifier.load(path="checkpoint/CLASSIFIER/2023-1-23/10-31/model_checkpoint")
    accuracy, f1 = pretrain_classifier.fit()
    print("-" * 50)
    print(f"Final accuracy: {accuracy}")
    print(f"Final F1: {f1}")
# Script entry point. `main` is wrapped by hydra.main, which builds `cfg` itself,
# so it is called here without arguments.
if "__main__" == __name__:
    main()