Skip to content

Commit

Permalink
Support transcoding audio formats when saving tarred datasets (FLAC, …
Browse files Browse the repository at this point in the history
…OPUS) (NVIDIA#8102)

* Support transcoding audio formats when saving tarred datasets (FLAC, OPUS)

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

* Revert the default change

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

---------

Signed-off-by: Piotr Żelasko <petezor@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
pzelasko and pre-commit-ci[bot] authored Jan 3, 2024
1 parent 8a8258d commit ae95cda
Showing 1 changed file with 35 additions and 1 deletion.
36 changes: 35 additions & 1 deletion scripts/speech_recognition/convert_to_tarred_audio_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,11 @@
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime
from io import BytesIO
from typing import Any, List, Optional

import numpy as np
import soundfile
from joblib import Parallel, delayed
from omegaconf import DictConfig, OmegaConf, open_dict

Expand Down Expand Up @@ -179,6 +182,15 @@
action='store_true',
help="Do not write sharded manifests along with the aggregated manifest.",
)
# Optional re-encoding of the audio while the tarred dataset is written.
# The default of None means the option is unset.
parser.add_argument(
    "--force_codec",
    type=str,
    default=None,
    help=(
        "If specified, transcode the audio to the given format. "
        "Supports libsndfile formats (example values: 'opus', 'flac')."
    ),
)
parser.add_argument('--workers', type=int, default=1, help='Number of worker processes')
args = parser.parse_args()

Expand All @@ -193,6 +205,7 @@ class ASRTarredDatasetConfig:
sort_in_shards: bool = True
shard_manifests: bool = True
keep_files_together: bool = False
force_codec: Optional[str] = None


@dataclass
Expand Down Expand Up @@ -569,6 +582,25 @@ def _read_manifest(self, manifest_path: str, config: ASRTarredDatasetConfig):

return entries, total_duration, filtered_entries, filtered_duration

def _write_to_tar(self, tar, audio_filepath: str, squashed_filename: str) -> None:
    """Add one audio file to an open tar archive, transcoding it if requested.

    When ``self.config.force_codec`` is ``None``, or ``audio_filepath`` already
    ends with ``.<codec>``, the file on disk is added unchanged under the name
    ``squashed_filename``. Otherwise the audio is read with ``soundfile``,
    re-encoded in memory, and stored under ``squashed_filename`` with its final
    extension replaced by the codec name.

    Args:
        tar: Archive opened for writing; must provide ``add`` and ``addfile``
            (as ``tarfile.TarFile`` does).
        audio_filepath: Path of the source audio file on disk.
        squashed_filename: Member name to use inside the archive.
    """
    codec = self.config.force_codec
    if codec is None or audio_filepath.endswith(f".{codec}"):
        # Add existing file without transcoding.
        tar.add(audio_filepath, arcname=squashed_filename)
        return

    # Transcode to the desired format in-memory and add the result to the tar file.
    audio, sampling_rate = soundfile.read(audio_filepath, dtype=np.float32)
    if codec == "opus":
        # Opus is passed to soundfile as a subtype of the "ogg" format.
        kwargs = {"format": "ogg", "subtype": "opus"}
    else:
        kwargs = {"format": codec}
    encoded_audio = BytesIO()
    soundfile.write(encoded_audio, audio, sampling_rate, closefd=False, **kwargs)

    # Replace only the final extension. Cutting at the first dot would truncate
    # multi-dot names and could map distinct files onto the same member name.
    stem, dot, _ = squashed_filename.rpartition(".")
    if not dot:
        # No extension present: keep the whole name and append the codec.
        stem = squashed_filename
    ti = tarfile.TarInfo(f"{stem}.{codec}")

    # The member size must be set explicitly before addfile(); take it from
    # the buffer view so the encoded bytes are not copied just to be measured.
    with encoded_audio.getbuffer() as view:
        ti.size = view.nbytes
    encoded_audio.seek(0)
    tar.addfile(ti, encoded_audio)

def _create_shard(self, entries, target_dir, shard_id, manifest_folder):
"""Creates a tarball containing the audio files from `entries`.
"""
Expand All @@ -594,7 +626,7 @@ def _create_shard(self, entries, target_dir, shard_id, manifest_folder):
base = base.replace('.', '_')
squashed_filename = f'{base}{ext}'
if squashed_filename not in count:
tar.add(audio_filepath, arcname=squashed_filename)
self._write_to_tar(tar, audio_filepath, squashed_filename)
to_write = squashed_filename
count[squashed_filename] = 1
else:
Expand Down Expand Up @@ -671,6 +703,7 @@ def create_tar_datasets(min_duration: float, max_duration: float, target_dir: st
sort_in_shards=args.sort_in_shards,
shard_manifests=shard_manifests,
keep_files_together=args.keep_files_together,
force_codec=args.force_codec,
)
metadata.dataset_config = dataset_cfg

Expand All @@ -692,6 +725,7 @@ def create_tar_datasets(min_duration: float, max_duration: float, target_dir: st
sort_in_shards=args.sort_in_shards,
shard_manifests=shard_manifests,
keep_files_together=args.keep_files_together,
force_codec=args.force_codec,
)
builder.configure(config)
builder.create_new_dataset(manifest_path=args.manifest_path, target_dir=target_dir, num_workers=args.workers)
Expand Down

0 comments on commit ae95cda

Please sign in to comment.