diff --git a/.github/workflows/pypi-release.yml b/.github/workflows/pypi-release.yml new file mode 100644 index 0000000..f184b08 --- /dev/null +++ b/.github/workflows/pypi-release.yml @@ -0,0 +1,26 @@ +name: Publish Python package + +on: + release: + types: [published] + +jobs: + publish: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.x" + - name: Install pypa/setuptools + run: >- + python -m + pip install wheel + - name: Build a binary wheel + run: >- + python setup.py sdist bdist_wheel + - name: Publish to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + password: ${{ secrets.PYPI_API_TOKEN }} \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..1b8d1cc --- /dev/null +++ b/.gitignore @@ -0,0 +1,164 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +.idea/ + +logs/ +*.pt +*.ckpt \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..c37bdaf --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Charactr Inc. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
\ No newline at end of file diff --git a/README.md b/README.md index 3a35457..83d28b8 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,124 @@ # Vocos: Closing the gap between time-domain and Fourier-based neural vocoders for high-quality audio synthesis -[Audio samples](https://charactr-platform.github.io/vocos/) | Paper [[abs]](https://arxiv.org/abs/2306.00814) [[pdf]](https://arxiv.org/pdf/2306.00814.pdf) +[Audio samples](https://charactr-platform.github.io/vocos/) | +Paper [[abs]](https://arxiv.org/abs/2306.00814) [[pdf]](https://arxiv.org/pdf/2306.00814.pdf) + +## Installation + +To use Vocos only in inference mode, install it using: + +```bash +pip install vocos +``` + +If you wish to train the model, install it with additional dependencies: + +```bash +pip install vocos[train] +``` + +## Usage + +### Reconstruct audio from mel-spectrogram + +```python +import torch + +from vocos import Vocos + +vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz") + +mel = torch.randn(1, 100, 256) # B, C, T + +with torch.no_grad(): + audio = vocos.decode(mel) +``` + +Copy-synthesis from a file: + +```python +import torchaudio + +y, sr = torchaudio.load(YOUR_AUDIO_FILE) +if y.size(0) > 1: # mix to mono + y = y.mean(dim=0, keepdim=True) +y = torchaudio.functional.resample(y, orig_freq=sr, new_freq=24000) + +with torch.no_grad(): + y_hat = vocos(y) +``` + +### Reconstruct audio from EnCodec + +Additionally, you need to provide a `bandwidth_id` which corresponds to the lookup embedding for bandwidth from the +list: `[1.5, 3.0, 6.0, 12.0]`. + +```python +vocos = Vocos.from_pretrained("charactr/vocos-encodec-24khz") + +quantized_features = torch.randn(1, 128, 256) +bandwidth_id = torch.tensor([3]) # 12 kbps + +with torch.no_grad(): + audio = vocos.decode(quantized_features, bandwidth_id=bandwidth_id) +``` + +Copy-synthesis from a file: It extracts and quantizes features with EnCodec, then reconstructs them with Vocos in a +single forward pass. + +```python +y, sr = torchaudio.load(YOUR_AUDIO_FILE) +if y.size(0) > 1: # mix to mono + y = y.mean(dim=0, keepdim=True) +y = torchaudio.functional.resample(y, orig_freq=sr, new_freq=24000) + +with torch.no_grad(): + y_hat = vocos(y, bandwidth_id=bandwidth_id) +``` + +## Pre-trained models + +The provided models were trained up to 2.5 million generator iterations, which resulted in slightly better objective +scores +compared to those reported in the paper. + +| Model Name | Dataset | Training Iterations | Parameters +|-------------------------------------------------------------------------------------|---------------|---------------------|------------| +| [charactr/vocos-mel-24khz](https://huggingface.co/charactr/vocos-mel-24khz) | LibriTTS | 2.5 M | 13.5 M +| [charactr/vocos-encodec-24khz](https://huggingface.co/charactr/vocos-encodec-24khz) | DNS Challenge | 2.5 M | 7.9 M + +## Training + +Prepare a filelist of audio files for the training and validation set: + +```bash +find $TRAIN_DATASET_DIR -name *.wav > filelist.train +find $VAL_DATASET_DIR -name *.wav > filelist.val +``` + +Fill a config file, e.g. [vocos.yaml](configs%2Fvocos.yaml), with your filelist paths and start training with: + +```bash +python train.py -c configs/vocos.yaml +``` + +Refer to [Pytorch Lightning documentation](https://lightning.ai/docs/pytorch/stable/) for details about customizing the +training pipeline. 
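+
+Alternatively, the components can be wired up in plain Python instead of going through `train.py`. The snippet below is
+only a minimal, untested sketch that mirrors the `class_path`/`init_args` pairs from [vocos.yaml](configs%2Fvocos.yaml);
+it assumes those `init_args` map one-to-one onto the constructors (including `vocos.models.VocosBackbone`, whose
+signature is taken from the config) and uses a plain `Trainer` in place of the full `trainer` section:
+
+```python
+import pytorch_lightning as pl
+
+from vocos.dataset import DataConfig, VocosDataModule
+from vocos.experiment import VocosExp
+from vocos.feature_extractors import MelSpectrogramFeatures
+from vocos.heads import ISTFTHead
+from vocos.models import VocosBackbone
+
+# Data pipeline, mirroring the `data` section of configs/vocos.yaml
+datamodule = VocosDataModule(
+    train_params=DataConfig(
+        filelist_path="filelist.train", sampling_rate=24000, num_samples=16384, batch_size=16, num_workers=8
+    ),
+    val_params=DataConfig(
+        filelist_path="filelist.val", sampling_rate=24000, num_samples=48384, batch_size=16, num_workers=8
+    ),
+)
+
+# GAN experiment, mirroring the `model` section of configs/vocos.yaml
+model = VocosExp(
+    feature_extractor=MelSpectrogramFeatures(
+        sample_rate=24000, n_fft=1024, hop_length=256, n_mels=100, padding="center"
+    ),
+    backbone=VocosBackbone(input_channels=100, dim=512, intermediate_dim=1536, num_layers=8),
+    head=ISTFTHead(dim=512, n_fft=1024, hop_length=256, padding="center"),
+    sample_rate=24000,
+    initial_learning_rate=2e-4,
+    mel_loss_coeff=45,
+    mrd_loss_coeff=0.1,
+)
+
+trainer = pl.Trainer(accelerator="gpu", devices=[0], max_steps=2_000_000, log_every_n_steps=100)
+trainer.fit(model, datamodule=datamodule)
+```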
+ +## Citation + +If this code contributes to your research, please cite our work: + +``` +@article{siuzdak2023vocos, + title={Vocos: Closing the gap between time-domain and Fourier-based neural vocoders for high-quality audio synthesis}, + author={Siuzdak, Hubert}, + journal={arXiv preprint arXiv:2306.00814}, + year={2023} +} +``` + +## License + +The code in this repository is released under the MIT license as found in the +[LICENSE](LICENSE) file. \ No newline at end of file diff --git a/configs/vocos-encodec.yaml b/configs/vocos-encodec.yaml new file mode 100644 index 0000000..775a52c --- /dev/null +++ b/configs/vocos-encodec.yaml @@ -0,0 +1,86 @@ +# pytorch_lightning==1.8.6 +seed_everything: 4444 + +data: + class_path: vocos.dataset.VocosDataModule + init_args: + train_params: + filelist_path: ??? + sampling_rate: 24000 + num_samples: 24000 + batch_size: 16 + num_workers: 8 + + val_params: + filelist_path: ??? + sampling_rate: 24000 + num_samples: 24000 + batch_size: 16 + num_workers: 8 + +model: + class_path: vocos.experiment.VocosEncodecExp + init_args: + sample_rate: 24000 + initial_learning_rate: 2e-4 + mel_loss_coeff: 45 + mrd_loss_coeff: 1.0 + num_warmup_steps: 0 # Optimizers warmup steps + pretrain_mel_steps: 0 # 0 means GAN objective from the first iteration + + # automatic evaluation + evaluate_utmos: true + evaluate_pesq: true + evaluate_periodicty: true + + feature_extractor: + class_path: vocos.feature_extractors.EncodecFeatures + init_args: + encodec_model: encodec_24khz + bandwidths: [1.5, 3.0, 6.0, 12.0] + train_codebooks: false + + backbone: + class_path: vocos.models.VocosBackbone + init_args: + input_channels: 128 + dim: 384 + intermediate_dim: 1152 + num_layers: 8 + adanorm_num_embeddings: 4 # len(bandwidths) + + head: + class_path: vocos.heads.ISTFTHead + init_args: + dim: 384 + n_fft: 1280 + hop_length: 320 + padding: same + +trainer: + logger: + class_path: pytorch_lightning.loggers.TensorBoardLogger + init_args: + save_dir: logs/ + callbacks: + - class_path: pytorch_lightning.callbacks.LearningRateMonitor + - class_path: pytorch_lightning.callbacks.ModelSummary + init_args: + max_depth: 2 + - class_path: pytorch_lightning.callbacks.ModelCheckpoint + init_args: + monitor: val_loss + filename: vocos_checkpoint_{epoch}_{step}_{val_loss:.4f} + save_top_k: 3 + save_last: true + - class_path: vocos.helpers.GradNormCallback + + # Lightning calculates max_steps across all optimizer steps (rather than number of batches) + # This equals to 1M steps per generator and 1M per discriminator + max_steps: 2000000 + # You might want to limit val batches when evaluating all the metrics, as they are time-consuming + limit_val_batches: 100 + accelerator: gpu + strategy: ddp + devices: [0] + log_every_n_steps: 100 diff --git a/configs/vocos-imdct.yaml b/configs/vocos-imdct.yaml new file mode 100644 index 0000000..7bdc5cf --- /dev/null +++ b/configs/vocos-imdct.yaml @@ -0,0 +1,86 @@ +# pytorch_lightning==1.8.6 +seed_everything: 4444 + +data: + class_path: vocos.dataset.VocosDataModule + init_args: + train_params: + filelist_path: ??? + sampling_rate: 24000 + num_samples: 16384 + batch_size: 16 + num_workers: 8 + + val_params: + filelist_path: ??? 
+ sampling_rate: 24000 + num_samples: 48384 + batch_size: 16 + num_workers: 8 + +model: + class_path: vocos.experiment.VocosExp + init_args: + sample_rate: 24000 + initial_learning_rate: 2e-4 + mel_loss_coeff: 45 + mrd_loss_coeff: 0.1 + num_warmup_steps: 0 # Optimizers warmup steps + pretrain_mel_steps: 0 # 0 means GAN objective from the first iteration + + # automatic evaluation + evaluate_utmos: true + evaluate_pesq: true + evaluate_periodicty: true + + feature_extractor: + class_path: vocos.feature_extractors.MelSpectrogramFeatures + init_args: + sample_rate: 24000 + n_fft: 1024 + hop_length: 256 + n_mels: 100 + padding: center + + backbone: + class_path: vocos.models.VocosBackbone + init_args: + input_channels: 100 + dim: 512 + intermediate_dim: 1536 + num_layers: 8 + + head: + class_path: vocos.heads.IMDCTCosHead + init_args: + dim: 512 + mdct_frame_len: 512 # mel-spec hop_length * 2 + padding: center + +trainer: + logger: + class_path: pytorch_lightning.loggers.TensorBoardLogger + init_args: + save_dir: logs/ + callbacks: + - class_path: pytorch_lightning.callbacks.LearningRateMonitor + - class_path: pytorch_lightning.callbacks.ModelSummary + init_args: + max_depth: 2 + - class_path: pytorch_lightning.callbacks.ModelCheckpoint + init_args: + monitor: val_loss + filename: vocos_checkpoint_{epoch}_{step}_{val_loss:.4f} + save_top_k: 3 + save_last: true + - class_path: vocos.helpers.GradNormCallback + + # Lightning calculates max_steps across all optimizer steps (rather than number of batches) + # This equals to 1M steps per generator and 1M per discriminator + max_steps: 2000000 + # You might want to limit val batches when evaluating all the metrics, as they are time-consuming + limit_val_batches: 100 + accelerator: gpu + strategy: ddp + devices: [0] + log_every_n_steps: 100 diff --git a/configs/vocos-resnet.yaml b/configs/vocos-resnet.yaml new file mode 100644 index 0000000..783272a --- /dev/null +++ b/configs/vocos-resnet.yaml @@ -0,0 +1,86 @@ +# pytorch_lightning==1.8.6 +seed_everything: 4444 + +data: + class_path: vocos.dataset.VocosDataModule + init_args: + train_params: + filelist_path: ??? + sampling_rate: 24000 + num_samples: 16384 + batch_size: 16 + num_workers: 8 + + val_params: + filelist_path: ??? 
+ sampling_rate: 24000 + num_samples: 48384 + batch_size: 16 + num_workers: 8 + +model: + class_path: vocos.experiment.VocosExp + init_args: + sample_rate: 24000 + initial_learning_rate: 2e-4 + mel_loss_coeff: 45 + mrd_loss_coeff: 0.1 + num_warmup_steps: 0 # Optimizers warmup steps + pretrain_mel_steps: 0 # 0 means GAN objective from the first iteration + + # automatic evaluation + evaluate_utmos: true + evaluate_pesq: true + evaluate_periodicty: true + + feature_extractor: + class_path: vocos.feature_extractors.MelSpectrogramFeatures + init_args: + sample_rate: 24000 + n_fft: 1024 + hop_length: 256 + n_mels: 100 + padding: center + + backbone: + class_path: vocos.models.VocosResNetBackbone + init_args: + input_channels: 100 + dim: 512 + num_blocks: 3 + + head: + class_path: vocos.heads.ISTFTHead + init_args: + dim: 512 + n_fft: 1024 + hop_length: 256 + padding: center + +trainer: + logger: + class_path: pytorch_lightning.loggers.TensorBoardLogger + init_args: + save_dir: logs/ + callbacks: + - class_path: pytorch_lightning.callbacks.LearningRateMonitor + - class_path: pytorch_lightning.callbacks.ModelSummary + init_args: + max_depth: 2 + - class_path: pytorch_lightning.callbacks.ModelCheckpoint + init_args: + monitor: val_loss + filename: vocos_checkpoint_{epoch}_{step}_{val_loss:.4f} + save_top_k: 3 + save_last: true + - class_path: vocos.helpers.GradNormCallback + + # Lightning calculates max_steps across all optimizer steps (rather than number of batches) + # This equals to 1M steps per generator and 1M per discriminator + max_steps: 2000000 + # You might want to limit val batches when evaluating all the metrics, as they are time-consuming + limit_val_batches: 100 + accelerator: gpu + strategy: ddp + devices: [0] + log_every_n_steps: 100 diff --git a/configs/vocos.yaml b/configs/vocos.yaml new file mode 100644 index 0000000..f2f4181 --- /dev/null +++ b/configs/vocos.yaml @@ -0,0 +1,87 @@ +# pytorch_lightning==1.8.6 +seed_everything: 4444 + +data: + class_path: vocos.dataset.VocosDataModule + init_args: + train_params: + filelist_path: ??? + sampling_rate: 24000 + num_samples: 16384 + batch_size: 16 + num_workers: 8 + + val_params: + filelist_path: ??? 
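+      # placeholder: replace with the path to your validation filelist (e.g. filelist.val, see README "Training")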
+ sampling_rate: 24000 + num_samples: 48384 + batch_size: 16 + num_workers: 8 + +model: + class_path: vocos.experiment.VocosExp + init_args: + sample_rate: 24000 + initial_learning_rate: 2e-4 + mel_loss_coeff: 45 + mrd_loss_coeff: 0.1 + num_warmup_steps: 0 # Optimizers warmup steps + pretrain_mel_steps: 0 # 0 means GAN objective from the first iteration + + # automatic evaluation + evaluate_utmos: true + evaluate_pesq: true + evaluate_periodicty: true + + feature_extractor: + class_path: vocos.feature_extractors.MelSpectrogramFeatures + init_args: + sample_rate: 24000 + n_fft: 1024 + hop_length: 256 + n_mels: 100 + padding: center + + backbone: + class_path: vocos.models.VocosBackbone + init_args: + input_channels: 100 + dim: 512 + intermediate_dim: 1536 + num_layers: 8 + + head: + class_path: vocos.heads.ISTFTHead + init_args: + dim: 512 + n_fft: 1024 + hop_length: 256 + padding: center + +trainer: + logger: + class_path: pytorch_lightning.loggers.TensorBoardLogger + init_args: + save_dir: logs/ + callbacks: + - class_path: pytorch_lightning.callbacks.LearningRateMonitor + - class_path: pytorch_lightning.callbacks.ModelSummary + init_args: + max_depth: 2 + - class_path: pytorch_lightning.callbacks.ModelCheckpoint + init_args: + monitor: val_loss + filename: vocos_checkpoint_{epoch}_{step}_{val_loss:.4f} + save_top_k: 3 + save_last: true + - class_path: vocos.helpers.GradNormCallback + + # Lightning calculates max_steps across all optimizer steps (rather than number of batches) + # This equals to 1M steps per generator and 1M per discriminator + max_steps: 2000000 + # You might want to limit val batches when evaluating all the metrics, as they are time-consuming + limit_val_batches: 100 + accelerator: gpu + strategy: ddp + devices: [0] + log_every_n_steps: 100 diff --git a/metrics/UTMOS.py b/metrics/UTMOS.py new file mode 100644 index 0000000..5e6e9a5 --- /dev/null +++ b/metrics/UTMOS.py @@ -0,0 +1,223 @@ +import os + +import fairseq +import pytorch_lightning as pl +import requests +import torch +import torch.nn as nn +from tqdm import tqdm + +UTMOS_CKPT_URL = "https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/epoch%3D3-step%3D7459.ckpt" +WAV2VEC_URL = "https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/wav2vec_small.pt" + +""" +UTMOS score, automatic Mean Opinion Score (MOS) prediction system, +adapted from https://huggingface.co/spaces/sarulab-speech/UTMOS-demo +""" + + +class UTMOSScore: + """Predicting score for each audio clip.""" + + def __init__(self, device, ckpt_path="epoch=3-step=7459.ckpt"): + self.device = device + filepath = os.path.join(os.path.dirname(__file__), ckpt_path) + if not os.path.exists(filepath): + download_file(UTMOS_CKPT_URL, filepath) + self.model = BaselineLightningModule.load_from_checkpoint(filepath).eval().to(device) + + def score(self, wavs: torch.tensor) -> torch.tensor: + """ + Args: + wavs: audio waveform to be evaluated. When len(wavs) == 1 or 2, + the model processes the input as a single audio clip. The model + performs batch processing when len(wavs) == 3. 
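+                (Here len refers to the number of dimensions, i.e. len(wavs.shape): 1-D and
+                2-D inputs are treated as a single clip, a 3-D input as a batch.)
+
+        Returns:
+            A tensor of shape (B,) with one predicted MOS value per clip; the raw model output
+            is rescaled with `* 2 + 3`.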
+ """ + if len(wavs.shape) == 1: + out_wavs = wavs.unsqueeze(0).unsqueeze(0) + elif len(wavs.shape) == 2: + out_wavs = wavs.unsqueeze(0) + elif len(wavs.shape) == 3: + out_wavs = wavs + else: + raise ValueError("Dimension of input tensor needs to be <= 3.") + bs = out_wavs.shape[0] + batch = { + "wav": out_wavs, + "domains": torch.zeros(bs, dtype=torch.int).to(self.device), + "judge_id": torch.ones(bs, dtype=torch.int).to(self.device) * 288, + } + with torch.no_grad(): + output = self.model(batch) + + return output.mean(dim=1).squeeze(1).cpu().detach() * 2 + 3 + + +def download_file(url, filename): + """ + Downloads a file from the given URL + + Args: + url (str): The URL of the file to download. + filename (str): The name to save the file as. + """ + print(f"Downloading file {filename}...") + response = requests.get(url, stream=True) + response.raise_for_status() + + total_size_in_bytes = int(response.headers.get("content-length", 0)) + progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) + + with open(filename, "wb") as f: + for chunk in response.iter_content(chunk_size=8192): + progress_bar.update(len(chunk)) + f.write(chunk) + + progress_bar.close() + + +def load_ssl_model(ckpt_path="wav2vec_small.pt"): + filepath = os.path.join(os.path.dirname(__file__), ckpt_path) + if not os.path.exists(filepath): + download_file(WAV2VEC_URL, filepath) + SSL_OUT_DIM = 768 + model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([filepath]) + ssl_model = model[0] + ssl_model.remove_pretraining_modules() + return SSL_model(ssl_model, SSL_OUT_DIM) + + +class BaselineLightningModule(pl.LightningModule): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + self.construct_model() + self.save_hyperparameters() + + def construct_model(self): + self.feature_extractors = nn.ModuleList( + [load_ssl_model(ckpt_path="wav2vec_small.pt"), DomainEmbedding(3, 128),] + ) + output_dim = sum([feature_extractor.get_output_dim() for feature_extractor in self.feature_extractors]) + output_layers = [LDConditioner(judge_dim=128, num_judges=3000, input_dim=output_dim)] + output_dim = output_layers[-1].get_output_dim() + output_layers.append( + Projection(hidden_dim=2048, activation=torch.nn.ReLU(), range_clipping=False, input_dim=output_dim) + ) + + self.output_layers = nn.ModuleList(output_layers) + + def forward(self, inputs): + outputs = {} + for feature_extractor in self.feature_extractors: + outputs.update(feature_extractor(inputs)) + x = outputs + for output_layer in self.output_layers: + x = output_layer(x, inputs) + return x + + +class SSL_model(nn.Module): + def __init__(self, ssl_model, ssl_out_dim) -> None: + super(SSL_model, self).__init__() + self.ssl_model, self.ssl_out_dim = ssl_model, ssl_out_dim + + def forward(self, batch): + wav = batch["wav"] + wav = wav.squeeze(1) # [batches, audio_len] + res = self.ssl_model(wav, mask=False, features_only=True) + x = res["x"] + return {"ssl-feature": x} + + def get_output_dim(self): + return self.ssl_out_dim + + +class DomainEmbedding(nn.Module): + def __init__(self, n_domains, domain_dim) -> None: + super().__init__() + self.embedding = nn.Embedding(n_domains, domain_dim) + self.output_dim = domain_dim + + def forward(self, batch): + return {"domain-feature": self.embedding(batch["domains"])} + + def get_output_dim(self): + return self.output_dim + + +class LDConditioner(nn.Module): + """ + Conditions ssl output by listener embedding + """ + + def __init__(self, input_dim, judge_dim, num_judges=None): + 
super().__init__() + self.input_dim = input_dim + self.judge_dim = judge_dim + self.num_judges = num_judges + assert num_judges != None + self.judge_embedding = nn.Embedding(num_judges, self.judge_dim) + # concat [self.output_layer, phoneme features] + + self.decoder_rnn = nn.LSTM( + input_size=self.input_dim + self.judge_dim, + hidden_size=512, + num_layers=1, + batch_first=True, + bidirectional=True, + ) # linear? + self.out_dim = self.decoder_rnn.hidden_size * 2 + + def get_output_dim(self): + return self.out_dim + + def forward(self, x, batch): + judge_ids = batch["judge_id"] + if "phoneme-feature" in x.keys(): + concatenated_feature = torch.cat( + (x["ssl-feature"], x["phoneme-feature"].unsqueeze(1).expand(-1, x["ssl-feature"].size(1), -1)), dim=2 + ) + else: + concatenated_feature = x["ssl-feature"] + if "domain-feature" in x.keys(): + concatenated_feature = torch.cat( + (concatenated_feature, x["domain-feature"].unsqueeze(1).expand(-1, concatenated_feature.size(1), -1),), + dim=2, + ) + if judge_ids != None: + concatenated_feature = torch.cat( + ( + concatenated_feature, + self.judge_embedding(judge_ids).unsqueeze(1).expand(-1, concatenated_feature.size(1), -1), + ), + dim=2, + ) + decoder_output, (h, c) = self.decoder_rnn(concatenated_feature) + return decoder_output + + +class Projection(nn.Module): + def __init__(self, input_dim, hidden_dim, activation, range_clipping=False): + super(Projection, self).__init__() + self.range_clipping = range_clipping + output_dim = 1 + if range_clipping: + self.proj = nn.Tanh() + + self.net = nn.Sequential( + nn.Linear(input_dim, hidden_dim), activation, nn.Dropout(0.3), nn.Linear(hidden_dim, output_dim), + ) + self.output_dim = output_dim + + def forward(self, x, batch): + output = self.net(x) + + # range clipping + if self.range_clipping: + return self.proj(output) * 2.0 + 3 + else: + return output + + def get_output_dim(self): + return self.output_dim diff --git a/metrics/periodicity.py b/metrics/periodicity.py new file mode 100644 index 0000000..728017c --- /dev/null +++ b/metrics/periodicity.py @@ -0,0 +1,105 @@ +import librosa +import numpy as np +import torch +import torchaudio +import torchcrepe +from torchcrepe.loudness import REF_DB + +SILENCE_THRESHOLD = -60 +UNVOICED_THRESHOLD = 0.21 + +""" +Periodicity metrics adapted from https://github.com/descriptinc/cargan +""" + + +def predict_pitch( + audio: torch.Tensor, silence_threshold: float = SILENCE_THRESHOLD, unvoiced_treshold: float = UNVOICED_THRESHOLD +): + """ + Predicts pitch and periodicity for the given audio. + + Args: + audio (Tensor): The audio waveform. + silence_threshold (float): The threshold for silence detection. + unvoiced_treshold (float): The threshold for unvoiced detection. + + Returns: + pitch (ndarray): The predicted pitch. + periodicity (ndarray): The predicted periodicity. 
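+            Both arrays are framewise at a 10 ms hop (torchcrepe.SAMPLE_RATE // 100, pad=False);
+            low-energy frames get periodicity 0, and frames with periodicity below
+            unvoiced_treshold are marked as torchcrepe.UNVOICED (NaN).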
+ """ + # torchcrepe inference + pitch, periodicity = torchcrepe.predict( + audio, + fmin=50.0, + fmax=550, + sample_rate=torchcrepe.SAMPLE_RATE, + model="full", + return_periodicity=True, + device=audio.device, + pad=False, + ) + pitch = pitch.cpu().numpy() + periodicity = periodicity.cpu().numpy() + + # Calculate dB-scaled spectrogram and set low energy frames to unvoiced + hop_length = torchcrepe.SAMPLE_RATE // 100 # default CREPE + stft = torchaudio.functional.spectrogram( + audio, + window=torch.hann_window(torchcrepe.WINDOW_SIZE, device=audio.device), + n_fft=torchcrepe.WINDOW_SIZE, + hop_length=hop_length, + win_length=torchcrepe.WINDOW_SIZE, + power=2, + normalized=False, + pad=0, + center=False, + ) + + # Perceptual weighting + freqs = librosa.fft_frequencies(sr=torchcrepe.SAMPLE_RATE, n_fft=torchcrepe.WINDOW_SIZE) + perceptual_stft = librosa.perceptual_weighting(stft.cpu().numpy(), freqs) - REF_DB + silence = perceptual_stft.mean(axis=1) < silence_threshold + + periodicity[silence] = 0 + pitch[periodicity < unvoiced_treshold] = torchcrepe.UNVOICED + + return pitch, periodicity + + +def calculate_periodicity_metrics(y: torch.Tensor, y_hat: torch.Tensor): + """ + Calculates periodicity metrics for the predicted and true audio data. + + Args: + y (Tensor): The true audio data. + y_hat (Tensor): The predicted audio data. + + Returns: + periodicity_loss (float): The periodicity loss. + pitch_loss (float): The pitch loss. + f1 (float): The F1 score for voiced/unvoiced classification + """ + true_pitch, true_periodicity = predict_pitch(y) + pred_pitch, pred_periodicity = predict_pitch(y_hat) + + true_voiced = ~np.isnan(true_pitch) + pred_voiced = ~np.isnan(pred_pitch) + + periodicity_loss = np.sqrt(((pred_periodicity - true_periodicity) ** 2).mean(axis=1)).mean() + + # Update pitch rmse + voiced = true_voiced & pred_voiced + difference_cents = 1200 * (np.log2(true_pitch[voiced]) - np.log2(pred_pitch[voiced])) + pitch_loss = np.sqrt((difference_cents ** 2).mean()) + + # voiced/unvoiced precision and recall + true_positives = (true_voiced & pred_voiced).sum() + false_positives = (~true_voiced & pred_voiced).sum() + false_negatives = (true_voiced & ~pred_voiced).sum() + + precision = true_positives / (true_positives + false_positives) + recall = true_positives / (true_positives + false_negatives) + f1 = 2 * precision * recall / (precision + recall) + + return periodicity_loss, pitch_loss, f1 diff --git a/requirements-train.txt b/requirements-train.txt new file mode 100644 index 0000000..d7c23fc --- /dev/null +++ b/requirements-train.txt @@ -0,0 +1,7 @@ +pytorch_lightning==1.8.6 +jsonargparse[signatures] +transformers +matplotlib +torchcrepe +pesq +fairseq diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..410d868 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,8 @@ +torch +torchaudio +numpy +scipy +einops +pyyaml +huggingface_hub +encodec==0.1.1 \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..6f72868 --- /dev/null +++ b/setup.py @@ -0,0 +1,39 @@ +import io +import os + +from setuptools import find_packages, setup + +for line in open("vocos/__init__.py"): + line = line.strip() + if "__version__" in line: + context = {} + exec(line, context) + VERSION = context["__version__"] + + +def read(*paths, **kwargs): + content = "" + with io.open( + os.path.join(os.path.dirname(__file__), *paths), encoding=kwargs.get("encoding", "utf8"), + ) as open_file: + content = open_file.read().strip() + return content + + +def 
read_requirements(path): + return [line.strip() for line in read(path).split("\n") if not line.startswith(('"', "#", "-", "git+"))] + + +setup( + name="vocos", + version=VERSION, + author="Hubert Siuzdak", + author_email="huberts@charactr.com", + description="Fourier-based neural vocoder for high-quality audio synthesis", + url="https://github.com/charactr-platform/vocos", + long_description=read("README.md"), + long_description_content_type="text/markdown", + packages=find_packages(), + install_requires=read_requirements("requirements.txt"), + extras_require={"train": read_requirements("requirements-train.txt")}, +) diff --git a/train.py b/train.py new file mode 100644 index 0000000..4b50288 --- /dev/null +++ b/train.py @@ -0,0 +1,6 @@ +from pytorch_lightning.cli import LightningCLI + + +if __name__ == "__main__": + cli = LightningCLI(run=False) + cli.trainer.fit(model=cli.model, datamodule=cli.datamodule) diff --git a/vocos/__init__.py b/vocos/__init__.py new file mode 100644 index 0000000..19db5fc --- /dev/null +++ b/vocos/__init__.py @@ -0,0 +1,4 @@ +from vocos.pretrained import Vocos + + +__version__ = "0.0.1" diff --git a/vocos/dataset.py b/vocos/dataset.py new file mode 100644 index 0000000..25b3bc1 --- /dev/null +++ b/vocos/dataset.py @@ -0,0 +1,73 @@ +from dataclasses import dataclass + +import numpy as np +import torch +import torchaudio +from pytorch_lightning import LightningDataModule +from torch.utils.data import Dataset, DataLoader + +torch.set_num_threads(1) + + +@dataclass +class DataConfig: + filelist_path: str + sampling_rate: int + num_samples: int + batch_size: int + num_workers: int + + +class VocosDataModule(LightningDataModule): + def __init__(self, train_params: DataConfig, val_params: DataConfig): + super().__init__() + self.train_config = train_params + self.val_config = val_params + + def _get_dataloder(self, cfg: DataConfig, train: bool): + dataset = VocosDataset(cfg, train=train) + dataloader = DataLoader( + dataset, batch_size=cfg.batch_size, num_workers=cfg.num_workers, shuffle=train, pin_memory=True, + ) + return dataloader + + def train_dataloader(self) -> DataLoader: + return self._get_dataloder(self.train_config, train=True) + + def val_dataloader(self) -> DataLoader: + return self._get_dataloder(self.val_config, train=False) + + +class VocosDataset(Dataset): + def __init__(self, cfg: DataConfig, train: bool): + with open(cfg.filelist_path) as f: + self.filelist = f.read().splitlines() + self.sampling_rate = cfg.sampling_rate + self.num_samples = cfg.num_samples + self.train = train + + def __len__(self) -> int: + return len(self.filelist) + + def __getitem__(self, index: int) -> torch.Tensor: + audio_path = self.filelist[index] + y, sr = torchaudio.load(audio_path) + if y.size(0) > 1: + # mix to mono + y = y.mean(dim=0, keepdim=True) + gain = np.random.uniform(-1, -6) if self.train else -3 + y, _ = torchaudio.sox_effects.apply_effects_tensor(y, sr, [["norm", f"{gain:.2f}"]]) + if sr != self.sampling_rate: + y = torchaudio.functional.resample(y, orig_freq=sr, new_freq=self.sampling_rate) + if y.size(-1) < self.num_samples: + pad_length = self.num_samples - y.size(-1) + padding_tensor = y.repeat(1, 1 + pad_length // y.size(-1)) + y = torch.cat((y, padding_tensor[:, :pad_length]), dim=1) + elif self.train: + start = np.random.randint(low=0, high=y.size(-1) - self.num_samples + 1) + y = y[:, start : start + self.num_samples] + else: + # During validation, take always the first segment for determinism + y = y[:, : self.num_samples] + + return y[0] diff --git 
a/vocos/discriminators.py b/vocos/discriminators.py new file mode 100644 index 0000000..2f6dece --- /dev/null +++ b/vocos/discriminators.py @@ -0,0 +1,202 @@ +from typing import Tuple, List + +import torch +from torch import nn +from torch.nn import Conv2d +from torch.nn.utils import weight_norm + + +class MultiPeriodDiscriminator(nn.Module): + """ + Multi-Period Discriminator module adapted from https://github.com/jik876/hifi-gan. + Additionally, it allows incorporating conditional information with a learned embeddings table. + + Args: + periods (tuple[int]): Tuple of periods for each discriminator. + num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator. + Defaults to None. + """ + + def __init__(self, periods: Tuple[int] = (2, 3, 5, 7, 11), num_embeddings: int = None): + super().__init__() + self.discriminators = nn.ModuleList([DiscriminatorP(period=p, num_embeddings=num_embeddings) for p in periods]) + + def forward( + self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None + ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]: + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for d in self.discriminators: + y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id) + y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorP(nn.Module): + def __init__( + self, + period: int, + in_channels: int = 1, + kernel_size: int = 5, + stride: int = 3, + lrelu_slope: float = 0.1, + num_embeddings: int = None, + ): + super().__init__() + self.period = period + self.convs = nn.ModuleList( + [ + weight_norm(Conv2d(in_channels, 32, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))), + weight_norm(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))), + weight_norm(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))), + weight_norm(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))), + weight_norm(Conv2d(1024, 1024, (kernel_size, 1), (1, 1), padding=(kernel_size // 2, 0))), + ] + ) + if num_embeddings is not None: + self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=1024) + torch.nn.init.zeros_(self.emb.weight) + + self.conv_post = weight_norm(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + self.lrelu_slope = lrelu_slope + + def forward( + self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + x = x.unsqueeze(1) + fmap = [] + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = torch.nn.functional.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for i, l in enumerate(self.convs): + x = l(x) + x = torch.nn.functional.leaky_relu(x, self.lrelu_slope) + if i > 0: + fmap.append(x) + if cond_embedding_id is not None: + emb = self.emb(cond_embedding_id) + h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True) + else: + h = 0 + x = self.conv_post(x) + fmap.append(x) + x += h + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiResolutionDiscriminator(nn.Module): + def __init__( + self, + resolutions: Tuple[Tuple[int, int, int]] = ((1024, 256, 1024), (2048, 512, 2048), (512, 128, 512)), + num_embeddings: int = None, + ): + """ + 
Multi-Resolution Discriminator module adapted from https://github.com/mindslab-ai/univnet. + Additionally, it allows incorporating conditional information with a learned embeddings table. + + Args: + resolutions (tuple[tuple[int, int, int]]): Tuple of resolutions for each discriminator. + Each resolution should be a tuple of (n_fft, hop_length, win_length). + num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator. + Defaults to None. + """ + super().__init__() + self.discriminators = nn.ModuleList( + [DiscriminatorR(resolution=r, num_embeddings=num_embeddings) for r in resolutions] + ) + + def forward( + self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None + ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]: + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + + for d in self.discriminators: + y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id) + y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorR(nn.Module): + def __init__( + self, + resolution: Tuple[int, int, int], + channels: int = 64, + in_channels: int = 1, + num_embeddings: int = None, + lrelu_slope: float = 0.1, + ): + super().__init__() + self.resolution = resolution + self.in_channels = in_channels + self.lrelu_slope = lrelu_slope + self.convs = nn.ModuleList( + [ + weight_norm(nn.Conv2d(in_channels, channels, kernel_size=(7, 5), stride=(2, 2), padding=(3, 2))), + weight_norm(nn.Conv2d(channels, channels, kernel_size=(5, 3), stride=(2, 1), padding=(2, 1))), + weight_norm(nn.Conv2d(channels, channels, kernel_size=(5, 3), stride=(2, 2), padding=(2, 1))), + weight_norm(nn.Conv2d(channels, channels, kernel_size=3, stride=(2, 1), padding=1)), + weight_norm(nn.Conv2d(channels, channels, kernel_size=3, stride=(2, 2), padding=1)), + ] + ) + if num_embeddings is not None: + self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels) + torch.nn.init.zeros_(self.emb.weight) + self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), padding=(1, 1))) + + def forward( + self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + fmap = [] + x = self.spectrogram(x) + x = x.unsqueeze(1) + for l in self.convs: + x = l(x) + x = torch.nn.functional.leaky_relu(x, self.lrelu_slope) + fmap.append(x) + if cond_embedding_id is not None: + emb = self.emb(cond_embedding_id) + h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True) + else: + h = 0 + x = self.conv_post(x) + fmap.append(x) + x += h + x = torch.flatten(x, 1, -1) + + return x, fmap + + def spectrogram(self, x: torch.Tensor) -> torch.Tensor: + n_fft, hop_length, win_length = self.resolution + magnitude_spectrogram = torch.stft( + x, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + window=None, # interestingly rectangular window kind of works here + center=True, + return_complex=True, + ).abs() + + return magnitude_spectrogram diff --git a/vocos/experiment.py b/vocos/experiment.py new file mode 100644 index 0000000..22857d9 --- /dev/null +++ b/vocos/experiment.py @@ -0,0 +1,371 @@ +import math + +import numpy as np +import pytorch_lightning as pl +import torch +import torchaudio +import transformers + +from vocos.discriminators import MultiPeriodDiscriminator, MultiResolutionDiscriminator +from 
vocos.feature_extractors import FeatureExtractor +from vocos.heads import FourierHead +from vocos.helpers import plot_spectrogram_to_numpy +from vocos.loss import DiscriminatorLoss, GeneratorLoss, FeatureMatchingLoss, MelSpecReconstructionLoss +from vocos.models import Backbone +from vocos.modules import safe_log + + +class VocosExp(pl.LightningModule): + # noinspection PyUnusedLocal + def __init__( + self, + feature_extractor: FeatureExtractor, + backbone: Backbone, + head: FourierHead, + sample_rate: int, + initial_learning_rate: float, + num_warmup_steps: int = 0, + mel_loss_coeff: float = 45, + mrd_loss_coeff: float = 1.0, + pretrain_mel_steps: int = 0, + decay_mel_coeff: bool = False, + evaluate_utmos: bool = False, + evaluate_pesq: bool = False, + evaluate_periodicty: bool = False, + ): + """ + Args: + feature_extractor (FeatureExtractor): An instance of FeatureExtractor to extract features from audio signals. + backbone (Backbone): An instance of Backbone model. + head (FourierHead): An instance of Fourier head to generate spectral coefficients and reconstruct a waveform. + sample_rate (int): Sampling rate of the audio signals. + initial_learning_rate (float): Initial learning rate for the optimizer. + num_warmup_steps (int): Number of steps for the warmup phase of learning rate scheduler. Default is 0. + mel_loss_coeff (float, optional): Coefficient for Mel-spectrogram loss in the loss function. Default is 45. + mrd_loss_coeff (float, optional): Coefficient for Multi Resolution Discriminator loss. Default is 1.0. + pretrain_mel_steps (int, optional): Number of steps to pre-train the model without the GAN objective. Default is 0. + decay_mel_coeff (bool, optional): If True, the Mel-spectrogram loss coefficient is decayed during training. Default is False. + evaluate_utmos (bool, optional): If True, UTMOS scores are computed for each validation run. + evaluate_pesq (bool, optional): If True, PESQ scores are computed for each validation run. + evaluate_periodicty (bool, optional): If True, periodicity scores are computed for each validation run. 
+ """ + super().__init__() + self.save_hyperparameters(ignore=["feature_extractor", "backbone", "head"]) + + self.feature_extractor = feature_extractor + self.backbone = backbone + self.head = head + + self.multiperioddisc = MultiPeriodDiscriminator() + self.multiresddisc = MultiResolutionDiscriminator() + + self.disc_loss = DiscriminatorLoss() + self.gen_loss = GeneratorLoss() + self.feat_matching_loss = FeatureMatchingLoss() + self.melspec_loss = MelSpecReconstructionLoss(sample_rate=sample_rate) + + self.train_discriminator = False + self.base_mel_coeff = self.mel_loss_coeff = mel_loss_coeff + + def configure_optimizers(self): + disc_params = [ + {"params": self.multiperioddisc.parameters()}, + {"params": self.multiresddisc.parameters()}, + ] + gen_params = [ + {"params": self.feature_extractor.parameters()}, + {"params": self.backbone.parameters()}, + {"params": self.head.parameters()}, + ] + + opt_disc = torch.optim.AdamW(disc_params, lr=self.hparams.initial_learning_rate) + opt_gen = torch.optim.AdamW(gen_params, lr=self.hparams.initial_learning_rate) + + max_steps = self.trainer.max_steps // 2 # Max steps per optimizer + scheduler_disc = transformers.get_cosine_schedule_with_warmup( + opt_disc, num_warmup_steps=self.hparams.num_warmup_steps, num_training_steps=max_steps, + ) + scheduler_gen = transformers.get_cosine_schedule_with_warmup( + opt_gen, num_warmup_steps=self.hparams.num_warmup_steps, num_training_steps=max_steps, + ) + + return ( + [opt_disc, opt_gen], + [{"scheduler": scheduler_disc, "interval": "step"}, {"scheduler": scheduler_gen, "interval": "step"}], + ) + + def forward(self, audio_input, **kwargs): + features = self.feature_extractor(audio_input, **kwargs) + x = self.backbone(features, **kwargs) + audio_output = self.head(x) + return audio_output + + def training_step(self, batch, batch_idx, optimizer_idx, **kwargs): + audio_input = batch + + # train discriminator + if optimizer_idx == 0 and self.train_discriminator: + with torch.no_grad(): + audio_hat = self(audio_input, **kwargs) + + real_score_mp, gen_score_mp, _, _ = self.multiperioddisc(y=audio_input, y_hat=audio_hat, **kwargs,) + real_score_mrd, gen_score_mrd, _, _ = self.multiresddisc(y=audio_input, y_hat=audio_hat, **kwargs,) + loss_mp, loss_mp_real, _ = self.disc_loss( + disc_real_outputs=real_score_mp, disc_generated_outputs=gen_score_mp + ) + loss_mrd, loss_mrd_real, _ = self.disc_loss( + disc_real_outputs=real_score_mrd, disc_generated_outputs=gen_score_mrd + ) + loss_mp /= len(loss_mp_real) + loss_mrd /= len(loss_mrd_real) + loss = loss_mp + self.hparams.mrd_loss_coeff * loss_mrd + + self.log("discriminator/total", loss, prog_bar=True) + self.log("discriminator/multi_period_loss", loss_mp) + self.log("discriminator/multi_res_loss", loss_mrd) + return loss + + # train generator + if optimizer_idx == 1: + audio_hat = self(audio_input, **kwargs) + if self.train_discriminator: + _, gen_score_mp, fmap_rs_mp, fmap_gs_mp = self.multiperioddisc( + y=audio_input, y_hat=audio_hat, **kwargs, + ) + _, gen_score_mrd, fmap_rs_mrd, fmap_gs_mrd = self.multiresddisc( + y=audio_input, y_hat=audio_hat, **kwargs, + ) + loss_gen_mp, list_loss_gen_mp = self.gen_loss(disc_outputs=gen_score_mp) + loss_gen_mrd, list_loss_gen_mrd = self.gen_loss(disc_outputs=gen_score_mrd) + loss_gen_mp = loss_gen_mp / len(list_loss_gen_mp) + loss_gen_mrd = loss_gen_mrd / len(list_loss_gen_mrd) + loss_fm_mp = self.feat_matching_loss(fmap_r=fmap_rs_mp, fmap_g=fmap_gs_mp) / len(fmap_rs_mp) + loss_fm_mrd = 
self.feat_matching_loss(fmap_r=fmap_rs_mrd, fmap_g=fmap_gs_mrd) / len(fmap_rs_mrd) + + self.log("generator/multi_period_loss", loss_gen_mp) + self.log("generator/multi_res_loss", loss_gen_mrd) + self.log("generator/feature_matching_mp", loss_fm_mp) + self.log("generator/feature_matching_mrd", loss_fm_mrd) + else: + loss_gen_mp = loss_gen_mrd = loss_fm_mp = loss_fm_mrd = 0 + + mel_loss = self.melspec_loss(audio_hat, audio_input) + loss = ( + loss_gen_mp + + self.hparams.mrd_loss_coeff * loss_gen_mrd + + loss_fm_mp + + self.hparams.mrd_loss_coeff * loss_fm_mrd + + self.mel_loss_coeff * mel_loss + ) + + self.log("generator/total_loss", loss, prog_bar=True) + self.log("mel_loss_coeff", self.mel_loss_coeff) + self.log("generator/mel_loss", mel_loss) + + if self.global_step % 1000 == 0 and self.global_rank == 0: + self.logger.experiment.add_audio( + "train/audio_in", audio_input[0].data.cpu(), self.global_step, self.hparams.sample_rate + ) + self.logger.experiment.add_audio( + "train/audio_pred", audio_hat[0].data.cpu(), self.global_step, self.hparams.sample_rate + ) + with torch.no_grad(): + mel = safe_log(self.melspec_loss.mel_spec(audio_input[0])) + mel_hat = safe_log(self.melspec_loss.mel_spec(audio_hat[0])) + self.logger.experiment.add_image( + "train/mel_target", + plot_spectrogram_to_numpy(mel.data.cpu().numpy()), + self.global_step, + dataformats="HWC", + ) + self.logger.experiment.add_image( + "train/mel_pred", + plot_spectrogram_to_numpy(mel_hat.data.cpu().numpy()), + self.global_step, + dataformats="HWC", + ) + + return loss + + def on_validation_epoch_start(self): + if self.hparams.evaluate_utmos: + from metrics.UTMOS import UTMOSScore + + if not hasattr(self, "utmos_model"): + self.utmos_model = UTMOSScore(device=self.device) + + def validation_step(self, batch, batch_idx, **kwargs): + audio_input = batch + audio_hat = self(audio_input, **kwargs) + + audio_16_khz = torchaudio.functional.resample(audio_input, orig_freq=self.hparams.sample_rate, new_freq=16000) + audio_hat_16khz = torchaudio.functional.resample(audio_hat, orig_freq=self.hparams.sample_rate, new_freq=16000) + + if self.hparams.evaluate_periodicty: + from metrics.periodicity import calculate_periodicity_metrics + + periodicity_loss, pitch_loss, f1_score = calculate_periodicity_metrics(audio_16_khz, audio_hat_16khz) + else: + periodicity_loss = pitch_loss = f1_score = 0 + + if self.hparams.evaluate_utmos: + utmos_score = self.utmos_model.score(audio_hat_16khz.unsqueeze(1)).mean() + else: + utmos_score = torch.zeros(1, device=self.device) + + if self.hparams.evaluate_pesq: + from pesq import pesq + + pesq_score = 0 + for ref, deg in zip(audio_16_khz.cpu().numpy(), audio_hat_16khz.cpu().numpy()): + pesq_score += pesq(16000, ref, deg, "wb", on_error=1) + pesq_score /= len(audio_16_khz) + pesq_score = torch.tensor(pesq_score) + else: + pesq_score = torch.zeros(1, device=self.device) + + mel_loss = self.melspec_loss(audio_hat.unsqueeze(1), audio_input.unsqueeze(1)) + total_loss = mel_loss + (5 - utmos_score) + (5 - pesq_score) + + return { + "val_loss": total_loss, + "mel_loss": mel_loss, + "utmos_score": utmos_score, + "pesq_score": pesq_score, + "periodicity_loss": periodicity_loss, + "pitch_loss": pitch_loss, + "f1_score": f1_score, + "audio_input": audio_input[0], + "audio_pred": audio_hat[0], + } + + def validation_epoch_end(self, outputs): + if self.global_rank == 0: + *_, audio_in, audio_pred = outputs[0].values() + self.logger.experiment.add_audio( + "val_in", audio_in.data.cpu().numpy(), self.global_step, 
self.hparams.sample_rate + ) + self.logger.experiment.add_audio( + "val_pred", audio_pred.data.cpu().numpy(), self.global_step, self.hparams.sample_rate + ) + mel_target = safe_log(self.melspec_loss.mel_spec(audio_in)) + mel_hat = safe_log(self.melspec_loss.mel_spec(audio_pred)) + self.logger.experiment.add_image( + "val_mel_target", + plot_spectrogram_to_numpy(mel_target.data.cpu().numpy()), + self.global_step, + dataformats="HWC", + ) + self.logger.experiment.add_image( + "val_mel_hat", + plot_spectrogram_to_numpy(mel_hat.data.cpu().numpy()), + self.global_step, + dataformats="HWC", + ) + avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean() + mel_loss = torch.stack([x["mel_loss"] for x in outputs]).mean() + utmos_score = torch.stack([x["utmos_score"] for x in outputs]).mean() + pesq_score = torch.stack([x["pesq_score"] for x in outputs]).mean() + periodicity_loss = np.array([x["periodicity_loss"] for x in outputs]).mean() + pitch_loss = np.array([x["pitch_loss"] for x in outputs]).mean() + f1_score = np.array([x["f1_score"] for x in outputs]).mean() + + self.log("val_loss", avg_loss, sync_dist=True) + self.log("val/mel_loss", mel_loss, sync_dist=True) + self.log("val/utmos_score", utmos_score, sync_dist=True) + self.log("val/pesq_score", pesq_score, sync_dist=True) + self.log("val/periodicity_loss", periodicity_loss, sync_dist=True) + self.log("val/pitch_loss", pitch_loss, sync_dist=True) + self.log("val/f1_score", f1_score, sync_dist=True) + + @property + def global_step(self): + """ + Override global_step so that it returns the total number of batches processed + """ + return self.trainer.fit_loop.epoch_loop.total_batch_idx + + def on_train_batch_start(self, *args): + if self.global_step >= self.hparams.pretrain_mel_steps: + self.train_discriminator = True + else: + self.train_discriminator = False + + def on_train_batch_end(self, *args): + def mel_loss_coeff_decay(current_step, num_cycles=0.5): + max_steps = self.trainer.max_steps // 2 + if current_step < self.hparams.num_warmup_steps: + return 1.0 + progress = float(current_step - self.hparams.num_warmup_steps) / float( + max(1, max_steps - self.hparams.num_warmup_steps) + ) + return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) + + if self.hparams.decay_mel_coeff: + self.mel_loss_coeff = self.base_mel_coeff * mel_loss_coeff_decay(self.global_step + 1) + + +class VocosEncodecExp(VocosExp): + """ + VocosEncodecExp is a subclass of VocosExp that overrides the parent experiment to function as a conditional GAN. + It manages an additional `bandwidth_id` attribute, which denotes a learnable embedding corresponding to + a specific bandwidth value of EnCodec. During training, a random bandwidth_id is generated for each step, + while during validation, a fixed bandwidth_id is used. 
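+    `bandwidth_id` is an index into `feature_extractor.bandwidths` (here [1.5, 3.0, 6.0, 12.0] kbps);
+    validation always uses index 0, i.e. the lowest bandwidth, and an EnCodec resynthesis at that
+    bandwidth is logged alongside the Vocos output for reference.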
+ """ + + def __init__( + self, + feature_extractor: FeatureExtractor, + backbone: Backbone, + head: FourierHead, + sample_rate: int, + initial_learning_rate: float, + num_warmup_steps: int, + mel_loss_coeff: float = 45, + mrd_loss_coeff: float = 1.0, + pretrain_mel_steps: int = 0, + decay_mel_coeff: bool = False, + evaluate_utmos: bool = False, + evaluate_pesq: bool = False, + evaluate_periodicty: bool = False, + ): + super().__init__( + feature_extractor, + backbone, + head, + sample_rate, + initial_learning_rate, + num_warmup_steps, + mel_loss_coeff, + mrd_loss_coeff, + pretrain_mel_steps, + decay_mel_coeff, + evaluate_utmos, + evaluate_pesq, + evaluate_periodicty, + ) + # Override with conditional discriminators + self.multiperioddisc = MultiPeriodDiscriminator(num_embeddings=len(self.feature_extractor.bandwidths)) + self.multiresddisc = MultiResolutionDiscriminator(num_embeddings=len(self.feature_extractor.bandwidths)) + + def training_step(self, *args): + bandwidth_id = torch.randint(low=0, high=len(self.feature_extractor.bandwidths), size=(1,), device=self.device,) + output = super().training_step(*args, bandwidth_id=bandwidth_id) + return output + + def validation_step(self, *args): + bandwidth_id = torch.tensor([0], device=self.device) + output = super().validation_step(*args, bandwidth_id=bandwidth_id) + return output + + def validation_epoch_end(self, outputs): + if self.global_rank == 0: + *_, audio_in, _ = outputs[0].values() + # Resynthesis with encodec for reference + self.feature_extractor.encodec.set_target_bandwidth(self.feature_extractor.bandwidths[0]) + encodec_audio = self.feature_extractor.encodec(audio_in[None, None, :]) + self.logger.experiment.add_audio( + "encodec", encodec_audio[0, 0].data.cpu().numpy(), self.global_step, self.hparams.sample_rate, + ) + + super().validation_epoch_end(outputs) diff --git a/vocos/feature_extractors.py b/vocos/feature_extractors.py new file mode 100644 index 0000000..0b4d47b --- /dev/null +++ b/vocos/feature_extractors.py @@ -0,0 +1,96 @@ +from typing import List + +import torch +import torchaudio +from encodec import EncodecModel +from torch import nn + +from vocos.modules import safe_log + + +class FeatureExtractor(nn.Module): + """Base class for feature extractors.""" + + def forward(self, audio: torch.Tensor, **kwargs) -> torch.Tensor: + """ + Extract features from the given audio. + + Args: + audio (Tensor): Input audio waveform. + + Returns: + Tensor: Extracted features of shape (B, C, L), where B is the batch size, + C denotes output features, and L is the sequence length. 
+ """ + raise NotImplementedError("Subclasses must implement the forward method.") + + +class MelSpectrogramFeatures(FeatureExtractor): + def __init__(self, sample_rate=24000, n_fft=1024, hop_length=256, n_mels=100, padding="center"): + super().__init__() + if padding not in ["center", "same"]: + raise ValueError("Padding must be 'center' or 'same'.") + self.padding = padding + self.mel_spec = torchaudio.transforms.MelSpectrogram( + sample_rate=sample_rate, + n_fft=n_fft, + hop_length=hop_length, + n_mels=n_mels, + center=padding == "center", + power=1, + ) + + def forward(self, audio, **kwargs): + if self.padding == "same": + pad = self.mel_spec.win_length - self.mel_spec.hop_length + audio = torch.nn.functional.pad(audio, (pad // 2, pad // 2), mode="reflect") + mel = self.mel_spec(audio) + features = safe_log(mel) + return features + + +class EncodecFeatures(FeatureExtractor): + def __init__( + self, + encodec_model: str = "encodec_24khz", + bandwidths: List[float] = [1.5, 3.0, 6.0, 12.0], + train_codebooks: bool = False, + ): + super().__init__() + if encodec_model == "encodec_24khz": + encodec = EncodecModel.encodec_model_24khz + elif encodec_model == "encodec_48khz": + encodec = EncodecModel.encodec_model_48khz + else: + raise ValueError( + f"Unsupported encodec_model: {encodec_model}. Supported options are 'encodec_24khz' and 'encodec_48khz'." + ) + self.encodec = encodec(pretrained=True) + for param in self.encodec.parameters(): + param.requires_grad = False + self.num_q = self.encodec.quantizer.get_num_quantizers_for_bandwidth( + self.encodec.frame_rate, bandwidth=max(bandwidths) + ) + codebook_weights = torch.cat([vq.codebook for vq in self.encodec.quantizer.vq.layers[: self.num_q]], dim=0) + self.codebook_weights = torch.nn.Parameter(codebook_weights, requires_grad=train_codebooks) + self.bandwidths = bandwidths + + @torch.no_grad() + def get_encodec_codes(self, audio): + audio = audio.unsqueeze(1) + emb = self.encodec.encoder(audio) + codes = self.encodec.quantizer.encode(emb, self.encodec.frame_rate, self.encodec.bandwidth) + return codes + + def forward(self, audio: torch.Tensor, bandwidth_id: torch.Tensor): + self.encodec.eval() # Force eval mode as Pytorch Lightning automatically sets child modules to training mode + self.encodec.set_target_bandwidth(self.bandwidths[bandwidth_id]) + codes = self.get_encodec_codes(audio) + # Instead of summing in the loop, it stores subsequent VQ dictionaries in a single `self.codebook_weights` + # with offsets given by the number of bins, and finally summed in a vectorized operation. + offsets = torch.arange( + 0, self.encodec.quantizer.bins * len(codes), self.encodec.quantizer.bins, device=audio.device + ) + embeddings_idxs = codes + offsets.view(-1, 1, 1) + features = torch.nn.functional.embedding(embeddings_idxs, self.codebook_weights).sum(dim=0) + return features.transpose(1, 2) diff --git a/vocos/heads.py b/vocos/heads.py new file mode 100644 index 0000000..181865f --- /dev/null +++ b/vocos/heads.py @@ -0,0 +1,152 @@ +import torch +from torch import nn +from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz + +from vocos.spectral_ops import IMDCT, ISTFT +from vocos.modules import symexp + + +class FourierHead(nn.Module): + """Base class for inverse fourier modules.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, + L is the sequence length, and H denotes the model dimension. 
+ + Returns: + Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. + """ + raise NotImplementedError("Subclasses must implement the forward method.") + + +class ISTFTHead(FourierHead): + """ + ISTFT Head module for predicting STFT complex coefficients. + + Args: + dim (int): Hidden dimension of the model. + n_fft (int): Size of Fourier transform. + hop_length (int): The distance between neighboring sliding window frames, which should align with + the resolution of the input features. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + """ + + def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"): + super().__init__() + out_dim = n_fft + 2 + self.out = torch.nn.Linear(dim, out_dim) + self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the ISTFTHead module. + + Args: + x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, + L is the sequence length, and H denotes the model dimension. + + Returns: + Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. + """ + x = self.out(x).transpose(1, 2) + mag, p = x.chunk(2, dim=1) + mag = torch.exp(mag) + mag = torch.clip(mag, max=1e2) # safeguard to prevent excessively large magnitudes + x = torch.cos(p) + y = torch.sin(p) + phase = torch.atan2(y, x) + S = mag * torch.exp(phase * 1j) + audio = self.istft(S) + return audio + + +class IMDCTSymExpHead(FourierHead): + """ + IMDCT Head module for predicting MDCT coefficients with symmetric exponential function + + Args: + dim (int): Hidden dimension of the model. + mdct_frame_len (int): Length of the MDCT frame. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized + based on perceptual scaling. Defaults to None. + clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False. + """ + + def __init__( + self, dim: int, mdct_frame_len: int, padding: str = "same", sample_rate: int = None, clip_audio: bool = False, + ): + super().__init__() + out_dim = mdct_frame_len // 2 + self.out = nn.Linear(dim, out_dim) + self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding) + self.clip_audio = clip_audio + + if sample_rate is not None: + # optionally init the last layer following mel-scale + m_max = _hz_to_mel(sample_rate // 2) + m_pts = torch.linspace(0, m_max, out_dim) + f_pts = _mel_to_hz(m_pts) + scale = 1 - (f_pts / f_pts.max()) + + with torch.no_grad(): + self.out.weight.mul_(scale.view(-1, 1)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the IMDCTSymExpHead module. + + Args: + x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, + L is the sequence length, and H denotes the model dimension. + + Returns: + Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. 
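As a quick orientation for the head interface, here is a hedged shape-check sketch for ISTFTHead: a dummy backbone output of shape (B, L, H) is projected to n_fft + 2 channels (log-magnitude and phase), turned into a complex spectrogram and inverted, so with "same" padding the waveform length is L * hop_length. All dimensions are arbitrary example values.

# Shape-only usage sketch for ISTFTHead (illustrative hyperparameters).
import torch
from vocos.heads import ISTFTHead

head = ISTFTHead(dim=512, n_fft=1024, hop_length=256, padding="same")
x = torch.randn(2, 100, 512)           # (B, L, H) dummy backbone output
with torch.no_grad():
    audio = head(x)
print(audio.shape)                      # expected: torch.Size([2, 25600]), i.e. L * hop_length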
+ """ + x = self.out(x) + x = symexp(x) + x = torch.clip(x, min=-1e2, max=1e2) # safeguard to prevent excessively large magnitudes + audio = self.imdct(x) + if self.clip_audio: + audio = torch.clip(x, min=-1.0, max=1.0) + + return audio + + +class IMDCTCosHead(FourierHead): + """ + IMDCT Head module for predicting MDCT coefficients with parametrizing MDCT = exp(m) ยท cos(p) + + Args: + dim (int): Hidden dimension of the model. + mdct_frame_len (int): Length of the MDCT frame. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False. + """ + + def __init__(self, dim: int, mdct_frame_len: int, padding: str = "same", clip_audio: bool = False): + super().__init__() + self.clip_audio = clip_audio + self.out = nn.Linear(dim, mdct_frame_len) + self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the IMDCTCosHead module. + + Args: + x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, + L is the sequence length, and H denotes the model dimension. + + Returns: + Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. + """ + x = self.out(x) + m, p = x.chunk(2, dim=2) + m = torch.exp(m).clip(max=1e2) # safeguard to prevent excessively large magnitudes + audio = self.imdct(m * torch.cos(p)) + if self.clip_audio: + audio = torch.clip(x, min=-1.0, max=1.0) + return audio diff --git a/vocos/helpers.py b/vocos/helpers.py new file mode 100644 index 0000000..3d30301 --- /dev/null +++ b/vocos/helpers.py @@ -0,0 +1,71 @@ +import matplotlib +import numpy as np +import torch +from matplotlib import pyplot as plt +from pytorch_lightning import Callback + +matplotlib.use("Agg") + + +def save_figure_to_numpy(fig: plt.Figure) -> np.ndarray: + """ + Save a matplotlib figure to a numpy array. + + Args: + fig (Figure): Matplotlib figure object. + + Returns: + ndarray: Numpy array representing the figure. + """ + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + return data + + +def plot_spectrogram_to_numpy(spectrogram: np.ndarray) -> np.ndarray: + """ + Plot a spectrogram and convert it to a numpy array. + + Args: + spectrogram (ndarray): Spectrogram data. + + Returns: + ndarray: Numpy array representing the plotted spectrogram. + """ + spectrogram = spectrogram.astype(np.float32) + fig, ax = plt.subplots(figsize=(12, 3)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") + plt.colorbar(im, ax=ax) + plt.xlabel("Frames") + plt.ylabel("Channels") + plt.tight_layout() + + fig.canvas.draw() + data = save_figure_to_numpy(fig) + plt.close() + return data + + +class GradNormCallback(Callback): + """ + Callback to log the gradient norm. + """ + + def on_after_backward(self, trainer, model): + model.log("grad_norm", gradient_norm(model)) + + +def gradient_norm(model: torch.nn.Module, norm_type: float = 2.0) -> torch.Tensor: + """ + Compute the gradient norm. + + Args: + model (Module): PyTorch model. + norm_type (float, optional): Type of the norm. Defaults to 2.0. + + Returns: + Tensor: Gradient norm. 
+ """ + grads = [p.grad for p in model.parameters() if p.grad is not None] + total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type) for g in grads]), norm_type) + return total_norm diff --git a/vocos/loss.py b/vocos/loss.py new file mode 100644 index 0000000..e6b0ed5 --- /dev/null +++ b/vocos/loss.py @@ -0,0 +1,114 @@ +from typing import List, Tuple + +import torch +import torchaudio +from torch import nn + +from vocos.modules import safe_log + + +class MelSpecReconstructionLoss(nn.Module): + """ + L1 distance between the mel-scaled magnitude spectrograms of the ground truth sample and the generated sample + """ + + def __init__( + self, sample_rate: int = 24000, n_fft: int = 1024, hop_length: int = 256, n_mels: int = 100, + ): + super().__init__() + self.mel_spec = torchaudio.transforms.MelSpectrogram( + sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, center=True, power=1, + ) + + def forward(self, y_hat, y) -> torch.Tensor: + """ + Args: + y_hat (Tensor): Predicted audio waveform. + y (Tensor): Ground truth audio waveform. + + Returns: + Tensor: L1 loss between the mel-scaled magnitude spectrograms. + """ + mel_hat = safe_log(self.mel_spec(y_hat)) + mel = safe_log(self.mel_spec(y)) + + loss = torch.nn.functional.l1_loss(mel, mel_hat) + + return loss + + +class GeneratorLoss(nn.Module): + """ + Generator Loss module. Calculates the loss for the generator based on discriminator outputs. + """ + + def forward(self, disc_outputs: List[torch.Tensor]) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """ + Args: + disc_outputs (List[Tensor]): List of discriminator outputs. + + Returns: + Tuple[Tensor, List[Tensor]]: Tuple containing the total loss and a list of loss values from + the sub-discriminators + """ + loss = 0 + gen_losses = [] + for dg in disc_outputs: + l = torch.mean(torch.clamp(1 - dg, min=0)) + gen_losses.append(l) + loss += l + + return loss, gen_losses + + +class DiscriminatorLoss(nn.Module): + """ + Discriminator Loss module. Calculates the loss for the discriminator based on real and generated outputs. + """ + + def forward( + self, disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor] + ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]: + """ + Args: + disc_real_outputs (List[Tensor]): List of discriminator outputs for real samples. + disc_generated_outputs (List[Tensor]): List of discriminator outputs for generated samples. + + Returns: + Tuple[Tensor, List[Tensor], List[Tensor]]: A tuple containing the total loss, a list of loss values from + the sub-discriminators for real outputs, and a list of + loss values for generated outputs. + """ + loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean(torch.clamp(1 - dr, min=0)) + g_loss = torch.mean(torch.clamp(1 + dg, min=0)) + loss += r_loss + g_loss + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +class FeatureMatchingLoss(nn.Module): + """ + Feature Matching Loss module. Calculates the feature matching loss between feature maps of the sub-discriminators. + """ + + def forward(self, fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]) -> torch.Tensor: + """ + Args: + fmap_r (List[List[Tensor]]): List of feature maps from real samples. + fmap_g (List[List[Tensor]]): List of feature maps from generated samples. + + Returns: + Tensor: The calculated feature matching loss. 
+ """ + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += torch.mean(torch.abs(rl - gl)) + + return loss diff --git a/vocos/models.py b/vocos/models.py new file mode 100644 index 0000000..886a88a --- /dev/null +++ b/vocos/models.py @@ -0,0 +1,117 @@ +from typing import Optional + +import torch +from torch import nn +from torch.nn.utils import weight_norm + +from vocos.modules import ConvNeXtBlock, ResBlock1, AdaLayerNorm + + +class Backbone(nn.Module): + """Base class for the generator's backbone. It preserves the same temporal resolution across all layers.""" + + def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: + """ + Args: + x (Tensor): Input tensor of shape (B, C, L), where B is the batch size, + C denotes output features, and L is the sequence length. + + Returns: + Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length, + and H denotes the model dimension. + """ + raise NotImplementedError("Subclasses must implement the forward method.") + + +class VocosBackbone(Backbone): + """ + Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization + + Args: + input_channels (int): Number of input features channels. + dim (int): Hidden dimension of the model. + intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock. + num_layers (int): Number of ConvNeXtBlock layers. + layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`. + adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm. + None means non-conditional model. Defaults to None. + """ + + def __init__( + self, + input_channels: int, + dim: int, + intermediate_dim: int, + num_layers: int, + layer_scale_init_value: Optional[float] = None, + adanorm_num_embeddings: Optional[int] = None, + ): + super().__init__() + self.input_channels = input_channels + self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3) + self.adanorm = adanorm_num_embeddings is not None + if adanorm_num_embeddings: + self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6) + else: + self.norm = nn.LayerNorm(dim, eps=1e-6) + layer_scale_init_value = layer_scale_init_value or 1 / num_layers + self.convnext = nn.ModuleList( + [ + ConvNeXtBlock( + dim=dim, + intermediate_dim=intermediate_dim, + layer_scale_init_value=layer_scale_init_value, + adanorm_num_embeddings=adanorm_num_embeddings, + ) + for _ in range(num_layers) + ] + ) + self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv1d, nn.Linear)): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + def forward(self, x: torch.Tensor, bandwidth_id: Optional[torch.Tensor] = None) -> torch.Tensor: + x = self.embed(x) + if self.adanorm: + assert bandwidth_id is not None + x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id) + else: + x = self.norm(x.transpose(1, 2)) + x = x.transpose(1, 2) + for conv_block in self.convnext: + x = conv_block(x, cond_embedding_id=bandwidth_id) + x = self.final_layer_norm(x.transpose(1, 2)) + return x + + +class VocosResNetBackbone(Backbone): + """ + Vocos backbone module built with ResBlocks. + + Args: + input_channels (int): Number of input features channels. + dim (int): Hidden dimension of the model. + num_blocks (int): Number of ResBlock1 blocks. + layer_scale_init_value (float, optional): Initial value for layer scaling. 
Defaults to None. + """ + + def __init__( + self, input_channels, dim, num_blocks, layer_scale_init_value=None, + ): + super().__init__() + self.input_channels = input_channels + self.embed = weight_norm(nn.Conv1d(input_channels, dim, kernel_size=3, padding=1)) + layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3 + self.resnet = nn.Sequential( + *[ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value) for _ in range(num_blocks)] + ) + + def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: + x = self.embed(x) + x = self.resnet(x) + x = x.transpose(1, 2) + return x diff --git a/vocos/modules.py b/vocos/modules.py new file mode 100644 index 0000000..9688a97 --- /dev/null +++ b/vocos/modules.py @@ -0,0 +1,213 @@ +from typing import Optional + +import torch +from torch import nn +from torch.nn.utils import weight_norm, remove_weight_norm + + +class ConvNeXtBlock(nn.Module): + """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal. + + Args: + dim (int): Number of input channels. + intermediate_dim (int): Dimensionality of the intermediate layer. + layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. + Defaults to None. + adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm. + None means non-conditional LayerNorm. Defaults to None. + """ + + def __init__( + self, + dim: int, + intermediate_dim: int, + layer_scale_init_value: Optional[float] = None, + adanorm_num_embeddings: Optional[int] = None, + ): + super().__init__() + self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv + self.adanorm = adanorm_num_embeddings is not None + if adanorm_num_embeddings: + self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6) + else: + self.norm = nn.LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(intermediate_dim, dim) + self.gamma = ( + nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) + if layer_scale_init_value > 0 + else None + ) + + def forward(self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None) -> torch.Tensor: + residual = x + x = self.dwconv(x) + x = x.transpose(1, 2) # (B, C, T) -> (B, T, C) + if self.adanorm: + assert cond_embedding_id is not None + x = self.norm(x, cond_embedding_id) + else: + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.transpose(1, 2) # (B, T, C) -> (B, C, T) + + x = residual + x + return x + + +class AdaLayerNorm(nn.Module): + """ + Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes + + Args: + num_embeddings (int): Number of embeddings. + embedding_dim (int): Dimension of the embeddings. 
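Since both backbones are assembled from ConvNeXtBlock, a minimal sketch of its contract may help: the block maps (B, C, T) to (B, C, T) with a residual connection, so stacking blocks never changes the sequence length. Sizes below are arbitrary.

# Residual, resolution-preserving behaviour of ConvNeXtBlock (toy sizes).
import torch
from vocos.modules import ConvNeXtBlock

block = ConvNeXtBlock(dim=64, intermediate_dim=192, layer_scale_init_value=1 / 6)
x = torch.randn(2, 64, 120)             # (B, C, T)
with torch.no_grad():
    y = block(x)
assert y.shape == x.shape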
+ """ + + def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.dim = embedding_dim + self.scale = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) + self.shift = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) + torch.nn.init.ones_(self.scale.weight) + torch.nn.init.zeros_(self.shift.weight) + + def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor: + scale = self.scale(cond_embedding_id) + shift = self.shift(cond_embedding_id) + x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps) + x = x * scale + shift + return x + + +class ResBlock1(nn.Module): + """ + ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions, + but without upsampling layers. + + Args: + dim (int): Number of input channels. + kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3. + dilation (tuple[int], optional): Dilation factors for the dilated convolutions. + Defaults to (1, 3, 5). + lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function. + Defaults to 0.1. + layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. + Defaults to None. + """ + + def __init__( + self, + dim: int, + kernel_size: int = 3, + dilation: tuple[int] = (1, 3, 5), + lrelu_slope: float = 0.1, + layer_scale_init_value: float = None, + ): + super().__init__() + self.lrelu_slope = lrelu_slope + self.convs1 = nn.ModuleList( + [ + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=dilation[0], + padding=self.get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=dilation[1], + padding=self.get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=dilation[2], + padding=self.get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))), + weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))), + weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))), + ] + ) + + self.gamma = nn.ParameterList( + [ + nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True) + if layer_scale_init_value is not None + else None, + nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True) + if layer_scale_init_value is not None + else None, + nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True) + if layer_scale_init_value is not None + else None, + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma): + xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope) + xt = c1(xt) + xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope) + xt = c2(xt) + if gamma is not None: + xt = gamma * xt + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + @staticmethod + def get_padding(kernel_size: int, dilation: int = 1) -> int: + return int((kernel_size * dilation - dilation) / 2) + + +def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor: + 
""" + Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values. + + Args: + x (Tensor): Input tensor. + clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7. + + Returns: + Tensor: Element-wise logarithm of the input tensor with clipping applied. + """ + return torch.log(torch.clip(x, min=clip_val)) + + +def symlog(x: torch.Tensor) -> torch.Tensor: + return torch.sign(x) * torch.log1p(x.abs()) + + +def symexp(x: torch.Tensor) -> torch.Tensor: + return torch.sign(x) * (torch.exp(x.abs()) - 1) diff --git a/vocos/pretrained.py b/vocos/pretrained.py new file mode 100644 index 0000000..00a5883 --- /dev/null +++ b/vocos/pretrained.py @@ -0,0 +1,95 @@ +from typing import Tuple, Any, Union, Dict + +import torch +import yaml +from huggingface_hub import hf_hub_download +from torch import nn +from vocos.feature_extractors import FeatureExtractor, EncodecFeatures +from vocos.heads import FourierHead +from vocos.models import Backbone + + +def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -> Any: + """Instantiates a class with the given args and init. + + Args: + args: Positional arguments required for instantiation. + init: Dict of the form {"class_path":...,"init_args":...}. + + Returns: + The instantiated class object. + """ + kwargs = init.get("init_args", {}) + if not isinstance(args, tuple): + args = (args,) + class_module, class_name = init["class_path"].rsplit(".", 1) + module = __import__(class_module, fromlist=[class_name]) + args_class = getattr(module, class_name) + return args_class(*args, **kwargs) + + +class Vocos(nn.Module): + """ + The Vocos class represents a Fourier-based neural vocoder for audio synthesis. + This class is primarily designed for inference, with support for loading from pretrained + model checkpoints. It consists of three main components: a feature extractor, + a backbone, and a head. + """ + + def __init__( + self, feature_extractor: FeatureExtractor, backbone: Backbone, head: FourierHead, + ): + super().__init__() + self.feature_extractor = feature_extractor + self.backbone = backbone + self.head = head + + @classmethod + def from_hparams(cls, config_path: str) -> "Vocos": + """ + Class method to create a new Vocos model instance from hyperparameters stored in a yaml configuration file. + """ + with open(config_path, "r") as f: + config = yaml.safe_load(f) + feature_extractor = instantiate_class(args=(), init=config["feature_extractor"]) + backbone = instantiate_class(args=(), init=config["backbone"]) + head = instantiate_class(args=(), init=config["head"]) + model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head) + return model + + @classmethod + def from_pretrained(self, repo_id: str) -> "Vocos": + """ + Class method to create a new Vocos model instance from a pre-trained model stored in the Hugging Face model hub. + """ + config_path = hf_hub_download(repo_id=repo_id, filename="config.yaml") + model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin") + model = self.from_hparams(config_path) + state_dict = torch.load(model_path, map_location="cpu") + if isinstance(model.feature_extractor, EncodecFeatures): + encodec_parameters = { + "feature_extractor.encodec." 
+ key: value + for key, value in model.feature_extractor.encodec.state_dict().items() + } + state_dict.update(encodec_parameters) + model.load_state_dict(state_dict) + model.eval() + return model + + def forward(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor: + """ + Method to run a copy-synthesis from audio waveform. The feature extractor first processes the audio input, + which is then passed through the backbone and the head to reconstruct the audio output. + """ + features = self.feature_extractor(audio_input, **kwargs) + audio_output = self.decode(features, **kwargs) + return audio_output + + def decode(self, features_input: torch.Tensor, **kwargs: Any) -> torch.Tensor: + """ + Method to decode audio waveform from already calculated features. The features input is passed through + the backbone and the head to reconstruct the audio output. + """ + x = self.backbone(features_input, **kwargs) + audio_output = self.head(x) + return audio_output diff --git a/vocos/spectral_ops.py b/vocos/spectral_ops.py new file mode 100644 index 0000000..a8eda1c --- /dev/null +++ b/vocos/spectral_ops.py @@ -0,0 +1,192 @@ +import numpy as np +import scipy +import torch +from torch import nn, view_as_real, view_as_complex + + +class ISTFT(nn.Module): + """ + Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with + windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges. + See issue: https://github.com/pytorch/pytorch/issues/62323 + Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs. + The NOLA constraint is met as we trim padded samples anyway. + + Args: + n_fft (int): Size of Fourier transform. + hop_length (int): The distance between neighboring sliding window frames. + win_length (int): The size of window frame and STFT filter. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + """ + + def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"): + super().__init__() + if padding not in ["center", "same"]: + raise ValueError("Padding must be 'center' or 'same'.") + self.padding = padding + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + window = torch.hann_window(win_length) + self.register_buffer("window", window) + + def forward(self, spec: torch.Tensor) -> torch.Tensor: + """ + Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram. + + Args: + spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size, + N is the number of frequency bins, and T is the number of time frames. + + Returns: + Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal. 
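Returning to the Vocos wrapper in vocos/pretrained.py, here is a hedged end-to-end usage sketch for copy-synthesis from a pretrained checkpoint. The Hugging Face repo id and the 24 kHz assumption are illustrative guesses rather than confirmed release names, and the call needs network access to download the config and weights.

# Assumed usage of Vocos.from_pretrained; substitute whichever checkpoint is actually published.
import torch
from vocos.pretrained import Vocos

vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")   # assumed repo id
audio_in = torch.randn(1, 24000)                             # 1 s of dummy 24 kHz audio
with torch.no_grad():
    audio_out = vocos(audio_in)                              # extract features, then decode
print(audio_out.shape)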
+ """ + if self.padding == "center": + # Fallback to pytorch native implementation + return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True) + elif self.padding == "same": + pad = (self.win_length - self.hop_length) // 2 + else: + raise ValueError("Padding must be 'center' or 'same'.") + + assert spec.dim() == 3, "Expected a 3D tensor as input" + B, N, T = spec.shape + + # Inverse FFT + ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward") + ifft = ifft * self.window[None, :, None] + + # Overlap and Add + output_size = (T - 1) * self.hop_length + self.win_length + y = torch.nn.functional.fold( + ifft, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length), + )[:, 0, 0, pad:-pad] + + # Window envelope + window_sq = self.window.square().expand(1, T, -1).transpose(1, 2) + window_envelope = torch.nn.functional.fold( + window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length), + ).squeeze()[pad:-pad] + + # Normalize + assert (window_envelope > 1e-11).all() + y = y / window_envelope + + return y + + +class MDCT(nn.Module): + """ + Modified Discrete Cosine Transform (MDCT) module. + + Args: + frame_len (int): Length of the MDCT frame. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + """ + + def __init__(self, frame_len: int, padding: str = "same"): + super().__init__() + if padding not in ["center", "same"]: + raise ValueError("Padding must be 'center' or 'same'.") + self.padding = padding + self.frame_len = frame_len + N = frame_len // 2 + n0 = (N + 1) / 2 + window = torch.from_numpy(scipy.signal.cosine(frame_len)).float() + self.register_buffer("window", window) + + pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len) + post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N) + # view_as_real: NCCL Backend does not support ComplexFloat data type + # https://github.com/pytorch/pytorch/issues/71613 + self.register_buffer("pre_twiddle", view_as_real(pre_twiddle)) + self.register_buffer("post_twiddle", view_as_real(post_twiddle)) + + def forward(self, audio: torch.Tensor) -> torch.Tensor: + """ + Apply the Modified Discrete Cosine Transform (MDCT) to the input audio. + + Args: + audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size + and T is the length of the audio. + + Returns: + Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames + and N is the number of frequency bins. + """ + if self.padding == "center": + audio = torch.nn.functional.pad(audio, (self.frame_len // 2, self.frame_len // 2)) + elif self.padding == "same": + # hop_length is 1/2 frame_len + audio = torch.nn.functional.pad(audio, (self.frame_len // 4, self.frame_len // 4)) + else: + raise ValueError("Padding must be 'center' or 'same'.") + + x = audio.unfold(-1, self.frame_len, self.frame_len // 2) + N = self.frame_len // 2 + x = x * self.window.expand(x.shape) + X = torch.fft.fft(x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1)[..., :N] + res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N) + return torch.real(res) * np.sqrt(2) + + +class IMDCT(nn.Module): + """ + Inverse Modified Discrete Cosine Transform (IMDCT) module. + + Args: + frame_len (int): Length of the MDCT frame. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". 
+ """ + + def __init__(self, frame_len: int, padding: str = "same"): + super().__init__() + if padding not in ["center", "same"]: + raise ValueError("Padding must be 'center' or 'same'.") + self.padding = padding + self.frame_len = frame_len + N = frame_len // 2 + n0 = (N + 1) / 2 + window = torch.from_numpy(scipy.signal.cosine(frame_len)).float() + self.register_buffer("window", window) + + pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N) + post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2)) + self.register_buffer("pre_twiddle", view_as_real(pre_twiddle)) + self.register_buffer("post_twiddle", view_as_real(post_twiddle)) + + def forward(self, X: torch.Tensor) -> torch.Tensor: + """ + Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients. + + Args: + X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size, + L is the number of frames, and N is the number of frequency bins. + + Returns: + Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio. + """ + B, L, N = X.shape + Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device) + Y[..., :N] = X + Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,))) + y = torch.fft.ifft(Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1) + y = torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape)) * np.sqrt(N) * np.sqrt(2) + result = y * self.window.expand(y.shape) + output_size = (1, (L + 1) * N) + audio = torch.nn.functional.fold( + result.transpose(1, 2), + output_size=output_size, + kernel_size=(1, self.frame_len), + stride=(1, self.frame_len // 2), + )[:, 0, 0, :] + + if self.padding == "center": + pad = self.frame_len // 2 + elif self.padding == "same": + pad = self.frame_len // 4 + else: + raise ValueError("Padding must be 'center' or 'same'.") + + audio = audio[:, pad:-pad] + return audio