Skip to content

Commit

Permalink
[chore] add channels argument to resolve conflict
Browse files — browse the repository at this point in the history
  • Loading branch information
gudgud96 committed Mar 25, 2024
1 parent 71688f4 commit 9649ee3
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
3 changes: 2 additions & 1 deletion frechet_audio_distance/clap_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,10 @@ def update(*a):
if self.verbose:
print("[CLAP score] Loading audio from {}...".format(dir))
for fname in os.listdir(dir):
# assume CLAP input audio to always be mono channel
res = pool.apply_async(
load_audio_task,
args=(os.path.join(dir, fname), self.sample_rate, dtype),
args=(os.path.join(dir, fname), self.sample_rate, 1, dtype),
callback=update
)
task_results.append(res)
Expand Down
5 changes: 3 additions & 2 deletions frechet_audio_distance/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import soundfile as sf


def load_audio_task(fname, sample_rate, dtype="float32"):
def load_audio_task(fname, sample_rate, channels, dtype="float32"):
if dtype not in ['float64', 'float32', 'int32', 'int16']:
raise ValueError(f"dtype not supported: {dtype}")

Expand All @@ -15,7 +15,8 @@ def load_audio_task(fname, sample_rate, dtype="float32"):
wav_data = wav_data / float(2**31)

# Convert to mono
if len(wav_data.shape) > 1:
assert channels in [1, 2], "channels must be 1 or 2"
if len(wav_data.shape) > channels:
wav_data = np.mean(wav_data, axis=1)

if sr != sample_rate:
Expand Down

0 comments on commit 9649ee3

Please sign in to comment.