-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfilter.py
executable file
·109 lines (92 loc) · 5.21 KB
/
filter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
import shutil
import matplotlib.pyplot as plt
import pandas as pd
import argparse
from tqdm import tqdm
from glob import glob
import soundfile as sf
import scipy.signal as signal
import numpy as np
# Default filtering thresholds. Each value is used as the default of the
# matching command-line option registered in the ``__main__`` section.

# Cutoff frequency
cutoff_freq_default = 11000  # minimum acceptable cutoff frequency (Hz)
# Minimum scores for the individual MOS predictors.
dns_mos_threshold_default = 3.7
wvmos_threshold_default = 3.9
sigmos_threshold_default = 3.2
nisqamos_threshold_default = 4.0
# NOTE: "thehold" is a misspelling of "threshold"; the name is kept because
# other code in this file refers to it.
utmos_thehold_default = 3.6
def read_wav_file(filename):
    """Read an audio file with ``soundfile``.

    Args:
        filename: Path of the audio file to load.

    Returns:
        A ``(sample_rate, data)`` tuple. Note that the order is the reverse
        of what ``sf.read`` itself returns.
    """
    samples, rate = sf.read(filename)
    return rate, samples
def calculate_cutoff_frequency(filename):
    """Estimate the highest frequency that still carries significant energy.

    The time-averaged STFT magnitude of the first channel is compared against
    a reference level, defined as the mean magnitude in the 50-2000 Hz band.
    Every bin above 2000 Hz whose averaged magnitude exceeds 5% of that
    reference counts as significant.

    Args:
        filename: Path of the audio file to analyse.

    Returns:
        The largest significant frequency above 2000 Hz, or ``2000`` when no
        bin above 2000 Hz is significant.
    """
    sample_rate, samples = read_wav_file(filename)

    # Multi-channel audio: analyse the first channel only.
    if samples.ndim > 1:
        samples = samples[:, 0]

    # Short-time Fourier transform, then average the magnitude over time so
    # that there is one value per frequency bin.
    freqs, _times, spectrum = signal.stft(samples, fs=sample_rate, nperseg=1024)
    mean_magnitude = np.abs(spectrum).mean(axis=1)

    # Reference level: mean magnitude inside the 50-2000 Hz band.
    in_reference_band = (freqs >= 50) & (freqs <= 2000)
    reference_level = np.mean(mean_magnitude[in_reference_band])

    # Bins above 2000 Hz that exceed 5% of the reference level.
    is_significant = (freqs > 2000) & (mean_magnitude > 0.05 * reference_level)
    significant_freqs = freqs[is_significant]

    if significant_freqs.size == 0:
        return 2000
    return np.max(significant_freqs)
def _resolve_thresholds():
    """Return the active filtering thresholds keyed by metric name.

    The command-line entry point publishes the parsed thresholds as
    module-level globals. When those globals are absent (for example when the
    module is imported and ``run`` is called directly), the ``*_default``
    constants are used instead.
    """
    module_globals = globals()
    return {
        "dnsmos": module_globals.get("dns_mos_threshold", dns_mos_threshold_default),
        "wvmos": module_globals.get("wvmos_threshold", wvmos_threshold_default),
        "sigmos": module_globals.get("sigmos_threshold", sigmos_threshold_default),
        "nisqa": module_globals.get("nisqamos_threshold", nisqamos_threshold_default),
        "utmos": module_globals.get("utmos_thehold", utmos_thehold_default),
        "cutoff_freq": module_globals.get("cutoff_freq", cutoff_freq_default),
    }


def run(statistics_folder_paths, statisitcs_output_folder_path, origianl_folder_path, filtered_folder_path):
    """Merge per-file quality statistics, filter them, and copy the good audio.

    Every ``*.csv`` directly inside each statistics folder is loaded and
    concatenated. For each row the cutoff frequency of the referenced audio
    file is measured. Rows whose DNSMOS, WVMOS, SIGMOS, NISQA and UTMOS scores
    and measured cutoff frequency all reach their thresholds are kept, and the
    corresponding audio file is copied into the filtered folder.

    Two CSV files are written to ``<statisitcs_output_folder_path>/total_result``:
    ``statistics_all.csv`` (all rows plus a ``cutoff_freq`` column) and
    ``statistics_filtered.csv`` (only the rows that passed).

    Args:
        statistics_folder_paths: Iterable of folders containing statistics CSVs.
            The CSVs must provide the columns ``wav_file``, ``dnsmos``,
            ``wvmos``, ``sigmos``, ``nisqa`` and ``utmos_strong``.
        statisitcs_output_folder_path: Folder under which ``total_result`` is
            created.
        origianl_folder_path: Substring of each ``wav_file`` path that is
            replaced by ``filtered_folder_path`` to build the copy destination.
        filtered_folder_path: Replacement used to build the copy destination.
    """
    thresholds = _resolve_thresholds()

    csv_paths = []
    for statistics_folder_path in statistics_folder_paths:
        csv_paths += glob(os.path.join(statistics_folder_path, "*.csv"))

    # Read everything first and concatenate once. ``ignore_index=True`` gives
    # every row a unique label; without it the labels restart at 0 for each
    # CSV and label-based access would hit the wrong (or several) rows.
    frames = [pd.read_csv(csv_path) for csv_path in tqdm(csv_paths)]
    if frames:
        df_total = pd.concat(frames, axis=0, ignore_index=True)
    else:
        df_total = pd.DataFrame()

    cutoff_values = []
    keep_mask = []
    for record in tqdm(df_total.to_dict("records")):
        filepath = record["wav_file"]
        # NOTE(security): eval() executes arbitrary code found in the CSV
        # cells. It is kept because the cells may hold reprs that are not
        # plain literals; only run this on statistics files you trust.
        dnsmos_value = eval(record["dnsmos"])["OVRL_raw"]
        wvmos_value = record["wvmos"]
        sigmos_value = eval(record["sigmos"])["MOS_OVRL"]
        nisqamos_value = eval(record["nisqa"])["mos"]
        utmos_value = record["utmos_strong"]

        # The cutoff is measured for every row because it is reported in
        # statistics_all.csv. It uses its own name so that it does not shadow
        # the cutoff threshold it is compared against below.
        measured_cutoff = calculate_cutoff_frequency(filepath)
        cutoff_values.append(measured_cutoff)

        passed = bool(
            dnsmos_value >= thresholds["dnsmos"]
            and wvmos_value >= thresholds["wvmos"]
            and sigmos_value >= thresholds["sigmos"]
            and nisqamos_value >= thresholds["nisqa"]
            and utmos_value >= thresholds["utmos"]
            and measured_cutoff >= thresholds["cutoff_freq"]
        )
        keep_mask.append(passed)

        if passed:
            out_path = filepath.replace(origianl_folder_path, filtered_folder_path)
            out_dir = os.path.dirname(out_path)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            shutil.copyfile(filepath, out_path)

    # Assign the whole column at once, as floats, instead of writing floats
    # cell by cell into an integer column.
    df_total["cutoff_freq"] = np.asarray(cutoff_values, dtype=float)
    # A boolean ndarray is used so an empty mask is treated as row selection.
    df_new = df_total.loc[np.asarray(keep_mask, dtype=bool)]

    # Create the output directory before writing either file.
    result_dir = os.path.join(statisitcs_output_folder_path, "total_result")
    os.makedirs(result_dir, exist_ok=True)
    df_total.to_csv(os.path.join(result_dir, "statistics_all.csv"), index=False)
    df_new.to_csv(os.path.join(result_dir, "statistics_filtered.csv"), index=False)
if __name__ == '__main__':
    # Command-line entry point: parse the options, publish the thresholds as
    # module-level globals (run() reads them from there), then start the job.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--statistics_folder_paths',
        type=str,
        required=True,
        default=None,
        nargs='+',
        help='The paths of the statistics folders',
    )

    # Required single-value path options.
    for _flag, _help in (
        ('--statisitcs_output_folder_path', 'The path of the statistics output folder'),
        ('--origianl_folder_path', 'The path of the original folder'),
        ('--filtered_folder_path', 'The path of the filtered folder'),
    ):
        parser.add_argument(_flag, type=str, required=True, help=_help)

    # Optional float thresholds, each falling back to its module default.
    for _flag, _default, _help in (
        ('--dns_mos_threshold', dns_mos_threshold_default, 'The threshold of DNS MOS'),
        ('--wvmos_threshold', wvmos_threshold_default, 'The threshold of WV MOS'),
        ('--sigmos_threshold', sigmos_threshold_default, 'The threshold of SIGMOS'),
        ('--nisqamos_threshold', nisqamos_threshold_default, 'The threshold of NISQ MOS'),
        ('--utmos_thehold', utmos_thehold_default, 'The threshold of UT MOS'),
        ('--cutoff_freq', cutoff_freq_default, 'The default cutoff frequency'),
    ):
        parser.add_argument(_flag, type=float, default=_default, help=_help)

    args = parser.parse_args()

    (
        dns_mos_threshold,
        wvmos_threshold,
        sigmos_threshold,
        nisqamos_threshold,
        utmos_thehold,
        cutoff_freq,
    ) = (
        args.dns_mos_threshold,
        args.wvmos_threshold,
        args.sigmos_threshold,
        args.nisqamos_threshold,
        args.utmos_thehold,
        args.cutoff_freq,
    )

    run(
        statistics_folder_paths=args.statistics_folder_paths,
        statisitcs_output_folder_path=args.statisitcs_output_folder_path,
        origianl_folder_path=args.origianl_folder_path,
        filtered_folder_path=args.filtered_folder_path,
    )