-
Notifications
You must be signed in to change notification settings - Fork 0
/
evaluate.py
163 lines (141 loc) · 5.73 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import shutil
from pathlib import Path
from typing import Optional
import librosa
import torch
import typer
from more_itertools import chunked
from rich import print
from rich.progress import track
from torch import Tensor
from hypersound.utils.eval import EvalModel, load_recordings
from hypersound.utils.metrics import compute_metrics, reduce_metric
def _load_batch(paths: list[Path], sample_rate: int, is_cuda: bool) -> Tensor:
    """Load ``paths`` with librosa at ``sample_rate`` and stack them into one tensor.

    ``torch.stack`` requires every loaded recording in the batch to have the
    same number of samples. When ``is_cuda`` is set, the stacked batch is
    moved to the ``"cuda"`` device.
    """
    batch: Tensor = torch.stack(
        [torch.tensor(librosa.load(str(path), sr=sample_rate)[0]) for path in paths]  # type: ignore
    )
    return batch.to("cuda") if is_cuda else batch


def main(
    # fmt: off
    # Eval model
    work_dir: Optional[Path] = typer.Argument(
        None, help="Working directory (directory of the run to evaluate)."
    ),
    cfg_path: Path = typer.Option(
        Path("code.cfg"), "--cfg_path", "-c", help="Config file used for training."
    ),
    ckpt_path: Path = typer.Option(
        Path("checkpoints/"), "--ckpt_path", "-p", help="Checkpoint file or directory."  # noqa
    ),
    reference_dir: Path = typer.Option(
        Path("eval/reference"), "--reference_dir", "-r", help="Directory containing reference recordings."  # noqa
    ),
    generated_dir: Path = typer.Option(
        Path("eval/generated"), "--generated_dir", "-g", help="Directory containing recordings generated by the model."  # noqa
    ),
    sample_rate: Optional[int] = typer.Option(
        None, "--sample_rate", "-s", help="Audio interpolation rate. If None, will use default training rate."
    ),
    audio_format: str = typer.Option(
        "wav", "--audio_format", "-f", help="Audio format."
    ),
    generate_training: bool = typer.Option(
        False, "--train", help="If set, will generate recordings for the training fold."
    ),
    is_cuda: bool = typer.Option(
        False, "--cuda", help="Enable GPU processing",
    ),
    evaluation_fold: str = typer.Option(
        "validation", help="Determines which fold to evaluate on (not applicable if using absolute recording paths)."
    ),
    clean: bool = typer.Option(
        False, "--clean", help="Remove generated recordings after evaluation."
    ),
    file_limit_override: Optional[int] = typer.Option(
        None, "--file_limit", help="Override config validation file limit."
    ),
    # Metrics
    batch_size: int = typer.Option(
        16, help="Metric evaluation batch size."
    ),
    pesq: bool = typer.Option(
        False, help="If set, will compute PESQ metric on audio resampled to 16000 Hz."
    ),
    stoi: bool = typer.Option(
        False, help="If set, will compute STOI."
    ),
    cdpam: bool = typer.Option(
        False, help="If set, will compute CDPAM metric on audio resampled to 22050 Hz."
    ),
    # fmt: on
):
    """
    Evaluate a trained model.

    With ``work_dir``: an ``EvalModel`` is built from ``work_dir / cfg_path``
    and ``work_dir / ckpt_path``. ``reference_dir`` and ``generated_dir`` are
    resolved relative to ``work_dir`` and passed to
    ``EvalModel.generate_recordings``. Evaluation then runs on their
    ``evaluation_fold`` subdirectories.

    Without ``work_dir``: ``reference_dir`` and ``generated_dir`` must be
    absolute, and the recordings found there are compared directly.

    Recordings are loaded at ``sample_rate`` (22050 Hz if not given) in
    batches of ``batch_size``. Metrics come from ``compute_metrics`` per batch
    and are aggregated per key with ``reduce_metric``.

    Raises:
        typer.BadParameter: if ``batch_size`` is smaller than 1.
        RuntimeError: if ``work_dir`` is missing and a recording directory is
            relative, if the numbers of reference and generated recordings
            differ, or if a reference/generated pair has mismatching names.
    """
    # Validate before any model loading / recording generation happens;
    # batch_size < 1 would otherwise only fail later in the progress total.
    if batch_size < 1:
        raise typer.BadParameter(f"batch_size must be at least 1, got {batch_size}.")

    if work_dir:
        eval_model = EvalModel(
            cfg_path=work_dir / cfg_path,
            ckpt_path=work_dir / ckpt_path,
            audio_format=audio_format,
            is_cuda=is_cuda,
            batch_size=batch_size,
            sample_rate=sample_rate,
            file_limit_override=file_limit_override,
        )
        reference_dir = work_dir / reference_dir
        generated_dir = work_dir / generated_dir
        eval_model.generate_recordings(reference_dir, generated_dir, generate_training)
        # From here on only the evaluated fold's subdirectories are used
        # (including by the --clean step below).
        reference_dir = reference_dir / evaluation_fold
        generated_dir = generated_dir / evaluation_fold
    elif not reference_dir.is_absolute() or not generated_dir.is_absolute():
        raise RuntimeError("When no work_dir is specified, dataset paths must be absolute.")

    reference, generated = load_recordings(reference_dir, generated_dir, audio_format)
    sample_rate = sample_rate or 22050  # Set default for metric evaluation

    print("[bold yellow]Evaluating recordings:")
    print(f" - reference ({len(reference)}): {reference_dir}")
    print(f" - generated ({len(generated)}): {generated_dir}")

    # Recordings are paired positionally batch-by-batch; zip() would silently
    # drop the surplus and misalign batch sizes, so refuse mismatched sets.
    if len(reference) != len(generated):
        raise RuntimeError(
            f"Expected the same number of reference and generated recordings, "
            f"but got {len(reference)} and {len(generated)}."
        )

    if cdpam and sample_rate != 22050:
        print(
            "[bold red]CDPAM can only be used with sampling rate 22050 Hz, "
            "recordings will be resampled for evaluation of this metric."
        )
    if pesq and sample_rate != 16000:
        print(
            "[bold red]PESQ can only be used with sampling rate 16000 Hz, "
            "recordings will be resampled for evaluation of this metric."
        )

    metrics: list[dict[str, Tensor]] = []
    for ref_paths, gen_paths in track(
        zip(chunked(reference, batch_size), chunked(generated, batch_size)),
        description="Running evaluation...",
        # Ceiling division so the final, partial batch is counted too.
        total=(len(reference) + batch_size - 1) // batch_size,
    ):
        for ref_path, gen_path in zip(ref_paths, gen_paths):
            # Generated files may carry a "reconstruction" prefix/suffix.
            expected_name = gen_path.name.replace("_reconstruction", "").replace("reconstruction_", "")
            if ref_path.name != expected_name:
                # Explicit raise (not assert) so the check survives `python -O`.
                raise RuntimeError(
                    f"Names of the recordings in target and generated directory should match, "
                    f"but got: {str(ref_path)} and {str(gen_path)}."
                )
        ref_recordings = _load_batch(ref_paths, sample_rate, is_cuda)
        gen_recordings = _load_batch(gen_paths, sample_rate, is_cuda)
        metrics.append(
            compute_metrics(
                preds=gen_recordings,
                target=ref_recordings,
                sample_rate=sample_rate,
                pesq=pesq,
                stoi=stoi,
                cdpam=cdpam,
            )
        )

    print("Finished evaluation.")
    if metrics:
        for key in metrics[0]:
            print(f"\t[bold yellow]{key}:[/] {reduce_metric(metrics, key):.4f}")
    else:
        # No batches were evaluated (no recordings found); avoid metrics[0] IndexError.
        print("[bold red]No recordings found, no metrics were computed.")

    if work_dir and clean:
        print("Removing generated recordings after evaluation...")
        shutil.rmtree(generated_dir)
        shutil.rmtree(reference_dir)
if __name__ == "__main__":
    # Script entry point: run `main` as a single-command Typer CLI.
    typer.run(main)