Use streaming ASR to transcribe the audio #68

Open · wants to merge 1 commit into master
examples/libriheavy/matching.py (2 additions & 2 deletions)

@@ -94,8 +94,8 @@ def get_params() -> AttributeDict:
         # you can find the docs in textsearch/match.py#split_aligned_queries
         "preceding_context_length": 1000,
         "timestamp_position": "current",
-        "duration_add_on_left": 0.0,
-        "duration_add_on_right": 0.5,
+        "duration_add_on_left": -0.4,
+        "duration_add_on_right": -0.8,
         "silence_length_to_break": 0.45,
         "overlap_ratio": 0.25,
         "min_duration": 2,
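Note on the new defaults: the add-ons flip sign, from padding the matched segment (0.5 s extra on the right) to trimming it. Going by the parameter names and the docs pointer above (textsearch/match.py#split_aligned_queries), a positive add-on presumably extends the segment at that edge and a negative one shrinks it. The sketch below illustrates that assumed behavior with a hypothetical helper; it is not the library's actual code.

# Hedged sketch of the assumed add-on semantics; the real logic lives in
# textsearch/match.py. apply_add_ons is a hypothetical helper for illustration.
def apply_add_ons(start, end, add_on_left, add_on_right):
    new_start = max(0.0, start - add_on_left)  # positive value extends left
    new_end = end + add_on_right               # positive value extends right
    return new_start, new_end

# With the new defaults (-0.4, -0.8), a segment spanning [10.0, 15.0] seconds
# becomes [10.4, 14.2]: 0.4 s shaved on the left, 0.8 s on the right.
print(apply_add_ons(10.0, 15.0, -0.4, -0.8))  # (10.4, 14.2)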
examples/libriheavy/tools/asr_datamodule.py (29 additions & 54 deletions)

@@ -27,6 +27,7 @@
 from lhotse.cut import Cut
 from lhotse.dataset import (
     K2SpeechRecognitionDataset,
+    DynamicBucketingSampler,
     SimpleCutSampler,
 )
 from lhotse.dataset.input_strategies import (
@@ -38,53 +39,6 @@
 from textsearch.utils import str2bool
 
 
-class SpeechRecognitionDataset(K2SpeechRecognitionDataset):
-    def __init__(
-        self,
-        return_cuts: bool = False,
-        input_strategy: BatchIO = OnTheFlyFeatures(Fbank()),
-    ):
-        super().__init__(return_cuts=return_cuts, input_strategy=input_strategy)
-
-    def __getitem__(
-        self, cuts: CutSet
-    ) -> Dict[str, Union[torch.Tensor, List[Cut]]]:
-        """
-        Return a new batch, with the batch size automatically determined using the constraints
-        of max_frames and max_cuts.
-        """
-        self.hdf5_fix.update()
-
-        # Note: don't sort cuts here
-        # Sort the cuts by duration so that the first one determines the batch time dimensions.
-        # cuts = cuts.sort_by_duration(ascending=False)
-
-        # Resample cuts since the ASR model works at 16kHz
-        cuts = cuts.resample(16000)
-
-        # Get a tensor with batched feature matrices, shape (B, T, F)
-        # Collation performs auto-padding, if necessary.
-        input_tpl = self.input_strategy(cuts)
-        if len(input_tpl) == 3:
-            # An input strategy with fault tolerant audio reading mode.
-            # "cuts" may be a subset of the original "cuts" variable,
-            # that only has cuts for which we succesfully read the audio.
-            inputs, _, cuts = input_tpl
-        else:
-            inputs, _ = input_tpl
-
-        # Get a dict of tensors that encode the positional information about supervisions
-        # in the batch of feature matrices. The tensors are named "sequence_idx",
-        # "start_frame/sample" and "num_frames/samples".
-        supervision_intervals = self.input_strategy.supervision_intervals(cuts)
-
-        batch = {"inputs": inputs, "supervisions": supervision_intervals}
-        if self.return_cuts:
-            batch["supervisions"]["cut"] = [cut for cut in cuts]
-
-        return batch
-
-
 class AsrDataModule:
     """
     DataModule for k2 ASR experiments.
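One behavioral difference worth flagging: the deleted __getitem__ resampled every batch to 16 kHz ("the ASR model works at 16kHz"), whereas the stock K2SpeechRecognitionDataset used below does no resampling. If the input manifests are not already at 16 kHz, the resampling presumably has to happen before the cuts reach dataloaders(); a minimal sketch, with a hypothetical manifest path:

# Resample the CutSet once, up front, instead of per batch; lhotse applies
# resample() lazily, when the audio is actually loaded.
from lhotse import CutSet

cuts = CutSet.from_file("data/cuts.jsonl.gz")  # hypothetical path
cuts = cuts.resample(16000)  # matches the 16 kHz the deleted code enforced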
@@ -117,6 +71,13 @@ def add_arguments(cls, parser: argparse.ArgumentParser):
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-mel-bins",
type=int,
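For reference, str2bool is imported from textsearch.utils at the top of this file but its body is not part of the diff; it is presumably the usual argparse-friendly helper along these lines (a sketch, not the actual textsearch code):

import argparse

def str2bool(v) -> bool:
    # Accept real booleans (from defaults) as well as common string spellings.
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")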
@@ -130,22 +91,36 @@ def add_arguments(cls, parser: argparse.ArgumentParser):
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--batch-size",
type=int,
default=10,
help="The number of utterances in a batch",
)

     def dataloaders(self, cuts: CutSet) -> DataLoader:
         logging.debug("About to create test dataset")
-        dataset = SpeechRecognitionDataset(
+        dataset = K2SpeechRecognitionDataset(
             input_strategy=OnTheFlyFeatures(
                 Fbank(FbankConfig(num_mel_bins=self.args.num_mel_bins))
             ),
             return_cuts=self.args.return_cuts,
         )
 
-        sampler = SimpleCutSampler(
-            cuts,
-            max_duration=self.args.max_duration,
-            shuffle=False,
-            drop_last=False,
-        )
+        if self.args.bucketing_sampler:
+            logging.info("Using DynamicBucketingSampler.")
+            sampler = DynamicBucketingSampler(
+                cuts,
+                max_duration=self.args.max_duration,
+                shuffle=False,
+                drop_last=False,
+            )
+        else:
+            logging.info("Using SimpleCutSampler.")
+            sampler = SimpleCutSampler(
+                cuts,
+                max_cuts=self.args.batch_size,
+            )
 
         logging.debug("About to create test dataloader")
         dl = DataLoader(
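Taken together: the new default path batches by total audio duration with DynamicBucketingSampler, and --bucketing-sampler false falls back to fixed-size batches of --batch-size cuts. A standalone sketch of the two modes (the manifest path and the max_duration value are placeholders):

from lhotse import CutSet
from lhotse.dataset import DynamicBucketingSampler, SimpleCutSampler

cuts = CutSet.from_file("data/cuts.jsonl.gz")  # hypothetical manifest

# Default: buckets of similar-duration cuts, capped at max_duration seconds
# of audio per batch, which keeps padding frames to a minimum.
sampler = DynamicBucketingSampler(
    cuts, max_duration=600, shuffle=False, drop_last=False
)

# Fallback: a fixed number of cuts per batch, regardless of their durations.
# sampler = SimpleCutSampler(cuts, max_cuts=10)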