diff --git a/main.py b/main.py index ca96e46..83c9328 100644 --- a/main.py +++ b/main.py @@ -18,10 +18,6 @@ @hydra.main(config_path=args.cp, config_name=args.cn) def main(cfg: DictConfig): text_process = TextProcess(**cfg.text_process) - if cfg.decoder.type == "beamsearch": - ctc_decoder = CTCDecoder(text_process=text_process, **cfg.ctcdecoder) - else: - ctc_decoder = None trainset = VivosDataset(**cfg.dataset, subset="train") testset = VivosDataset(**cfg.dataset, subset="test") @@ -32,7 +28,6 @@ def main(cfg: DictConfig): model = DeepSpeechModule( n_class=n_class, text_process=text_process, - ctc_decoder=ctc_decoder, cfg_optim=cfg.optimizer, **cfg.model ) diff --git a/model.py b/model.py index 1c24dc2..b154b97 100644 --- a/model.py +++ b/model.py @@ -16,7 +16,6 @@ def __init__( n_class: int, lr: float, text_process: TextProcess, - ctc_decoder: CTCDecoder, cfg_optim: dict, ): super().__init__() @@ -25,7 +24,6 @@ def __init__( ) self.lr = lr self.text_process = text_process - self.ctc_decoder = ctc_decoder self.cal_wer = torchmetrics.WordErrorRate() self.cfg_optim = cfg_optim self.criterion = nn.CTCLoss(zero_infinity=True) @@ -63,12 +61,8 @@ def validation_step(self, batch, batch_idx): outputs.permute(1, 0, 2), targets, input_lengths, target_lengths ) - if self.ctc_decoder: - # unsqueeze for batchsize 1 - predicts = [self.ctc_decoder(sent.unsqueeze(0)) for sent in outputs] - else: - decode = outputs.argmax(dim=-1) - predicts = [self.text_process.decode(sent) for sent in decode] + decode = outputs.argmax(dim=-1) + predicts = [self.text_process.decode(sent) for sent in decode] targets = [self.text_process.int2text(sent) for sent in targets] @@ -92,12 +86,8 @@ def test_step(self, batch, batch_idx): outputs.permute(1, 0, 2), targets, input_lengths, target_lengths ) - if self.ctc_decoder: - # unsqueeze for batchsize 1 - predicts = [self.ctc_decoder(sent.unsqueeze(0)) for sent in outputs] - else: - decode = outputs.argmax(dim=-1) - predicts = [self.text_process.decode(sent) for sent in decode] + decode = outputs.argmax(dim=-1) + predicts = [self.text_process.decode(sent) for sent in decode] targets = [self.text_process.int2text(sent) for sent in targets] diff --git a/utils.py b/utils.py index f5f3418..cb64a42 100644 --- a/utils.py +++ b/utils.py @@ -1,5 +1,4 @@ import torch -import ctcdecode class TextProcess: @@ -45,32 +44,32 @@ def int2text(self, s: torch.Tensor) -> str: return "".join([self.list_vocab[i] for i in s if i > 2]) -class CTCDecoder: - def __init__( - self, - alpha: float = 0.5, - beta: float = 0.96, - beam_size: int = 100, - kenlm_path: str = None, - text_process: TextProcess = None, - ): - self.text_process = text_process - labels = text_process.list_vocab - blank_id = labels.index("
") +# class CTCDecoder: +# def __init__( +# self, +# alpha: float = 0.5, +# beta: float = 0.96, +# beam_size: int = 100, +# kenlm_path: str = None, +# text_process: TextProcess = None, +# ): +# self.text_process = text_process +# labels = text_process.list_vocab +# blank_id = labels.index("
") - print("loading beam search with lm...") - self.decoder = ctcdecode.CTCBeamDecoder( - labels, - alpha=alpha, - beta=beta, - beam_width=beam_size, - blank_id=blank_id, - model_path=kenlm_path, - ) - print("finished loading beam search") +# print("loading beam search with lm...") +# self.decoder = ctcdecode.CTCBeamDecoder( +# labels, +# alpha=alpha, +# beta=beta, +# beam_width=beam_size, +# blank_id=blank_id, +# model_path=kenlm_path, +# ) +# print("finished loading beam search") - def __call__(self, output: torch.Tensor) -> str: - beam_result, beam_scores, timesteps, out_seq_len = self.decoder.decode(output) - tokens = beam_result[0][0] - seq_len = out_seq_len[0][0] - return self.text_process.int2text(tokens[:seq_len]) +# def __call__(self, output: torch.Tensor) -> str: +# beam_result, beam_scores, timesteps, out_seq_len = self.decoder.decode(output) +# tokens = beam_result[0][0] +# seq_len = out_seq_len[0][0] +# return self.text_process.int2text(tokens[:seq_len])