Skip to content

Commit

Permalink
refactor wav2mel module
Browse files Browse the repository at this point in the history
  • Loading branch information
yistLin committed Mar 22, 2021
1 parent c14058e commit 9ed8f7d
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions data/wav2mel.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
"""Wav2Mel for processing audio data."""

import torch
+import torch.nn as nn
from torchaudio.sox_effects import apply_effects_tensor
from torchaudio.transforms import MelSpectrogram


-class Wav2Mel(torch.nn.Module):
+class Wav2Mel(nn.Module):
"""Transform audio file into mel spectrogram tensors."""

def __init__(
self,
-sample_rate: float = 16000,
+sample_rate: int = 16000,
norm_db: float = -3.0,
sil_threshold: float = 1.0,
sil_duration: float = 0.1,
Expand Down Expand Up @@ -41,7 +42,7 @@ def forward(self, wav_tensor: torch.Tensor, sample_rate: int) -> torch.Tensor:
return mel_tensor


-class SoxEffects(torch.nn.Module):
+class SoxEffects(nn.Module):
"""Transform waveform tensors."""

def __init__(
Expand Down Expand Up @@ -72,12 +73,12 @@ def forward(self, wav_tensor: torch.Tensor, sample_rate: int) -> torch.Tensor:
return wav_tensor


-class LogMelspectrogram(torch.nn.Module):
+class LogMelspectrogram(nn.Module):
"""Transform waveform tensors into log mel spectrogram tensors."""

def __init__(
self,
-sample_rate: float,
+sample_rate: int,
fft_window_ms: float,
fft_hop_ms: float,
f_min: float,
Expand All @@ -94,4 +95,4 @@ def __init__(

def forward(self, wav_tensor: torch.Tensor) -> torch.Tensor:
mel_tensor = self.melspectrogram(wav_tensor).squeeze(0).T # (time, n_mels)
-return torch.log(mel_tensor.squeeze(0) + 1e-9)
+return torch.log(torch.clamp(mel_tensor, min=1e-9))

0 comments on commit 9ed8f7d

Please sign in to comment.