[Feature request] Can we add batch inference or batch decoding for XTTS? #3776

Open
Onkarsus13 opened this issue Jun 5, 2024 · 4 comments
Labels
feature request feature requests for making TTS better. wontfix This will not be worked on but feel free to help.

Comments

@Onkarsus13

I tried batch inference in XTTS: I pad every text sequence in the batch to the longest one and add an attention mask for the padding, but for the shorter sequences I get random noise at the end of the audio.
It would be helpful to have this feature in Coqui TTS.
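For reference, the padding-plus-mask step described above looks roughly like this. It is a minimal sketch in plain Python (lists instead of tensors, and a hypothetical `pad_batch` helper, not Coqui TTS code): sequences are right-padded to the longest one, and the mask marks real tokens as `True` and padding as `False`. A mask like this only helps if the model's forward pass actually applies it to the padded positions.

```python
from typing import List, Tuple


def pad_batch(
    seqs: List[List[int]], pad_value: int
) -> Tuple[List[List[int]], List[List[bool]], List[int]]:
    """Right-pad token sequences to the longest one and build a padding mask.

    Hypothetical helper for illustration. mask[i][j] is True for a real token
    and False for padding (the 1 = attend / 0 = ignore attention-mask convention).
    """
    lengths = [len(s) for s in seqs]
    max_len = max(lengths)
    # Append pad_value after each sequence so the real tokens stay at the start.
    padded = [s + [pad_value] * (max_len - len(s)) for s in seqs]
    # Position j of item i is real only if j is inside that item's own length.
    mask = [[j < n for j in range(max_len)] for n in lengths]
    return padded, mask, lengths


padded, mask, lengths = pad_batch([[5, 6, 7], [8]], pad_value=0)
print(padded)   # [[5, 6, 7], [8, 0, 0]]
print(mask)     # [[True, True, True], [True, False, False]]
print(lengths)  # [3, 1]
```

Right-padding keeps each item's real tokens at the start, so its length is all that is needed later to cut off the padded tail.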

@Onkarsus13 Onkarsus13 added the feature request feature requests for making TTS better. label Jun 5, 2024
@tuanh123789

I face the same problem when running inference in batches. Did you solve it?

@Rakshith12-pixel

@Onkarsus13 Could you implement batched inference successfully?

@Onkarsus13
Author

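Yes, Rakshith, I was able to implement it,
but only as partial batch decoding: the GPT audio codes are still generated one sentence at a time, and only the GPT latent pass and the HiFi-GAN decoder run on the whole batch.
Let me share the code snippet with you guys: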

    # Module-level imports this method relies on:
    #   import torch
    #   import torch.nn.functional as F
    #   import torch.nn.utils.rnn as rnn_utils

    # Method to add to the Xtts class.
    @torch.inference_mode()
    def Pbatch_inference(
        self,
        text,  # list of sentences, one per batch item
        language,
        gpt_cond_latent,
        speaker_embedding,
        # GPT inference
        temperature=0.75,
        length_penalty=1.0,
        repetition_penalty=10.0,
        top_k=50,
        top_p=0.85,
        do_sample=True,
        num_beams=1,
        speed=1.0,
        enable_text_splitting=False,
        **hf_generate_kwargs,
    ):
        stop_token = 1025  # stop-audio token id, also used as the padding value for audio codes

        language = language.split("-")[0]  # remove the country code
        length_scale = 1.0 / max(speed, 0.05)
        gpt_cond_latent = gpt_cond_latent.to(self.device)
        speaker_embedding = speaker_embedding.to(self.device)
        batch_size = len(text)

        # One copy of the conditioning latent and speaker embedding per batch item.
        xg = gpt_cond_latent.repeat(batch_size, 1, 1)
        xse = speaker_embedding.repeat(batch_size, 1, 1)

        text_tokens, lens, GPT_in = [], [], []
        # Step 1 (not batched): generate the audio codes one sentence at a time.
        for sent in text:
            sent = sent.strip().lower()
            text_token = torch.IntTensor(self.tokenizer.encode(sent, lang=language)).unsqueeze(0).to(self.device)
            lens.append(text_token.shape[1])
            text_tokens.append(text_token)

            gpt_codes = self.gpt.generate(
                cond_latents=gpt_cond_latent,
                text_inputs=text_token,
                input_tokens=None,
                do_sample=do_sample,
                top_p=top_p,
                top_k=top_k,
                temperature=temperature,
                num_return_sequences=self.gpt_batch_size,
                num_beams=num_beams,
                length_penalty=length_penalty,
                repetition_penalty=repetition_penalty,
                output_attentions=False,
                **hf_generate_kwargs,
            )
            GPT_in.append(gpt_codes[0])  # keep the first returned sequence, shape [T_i]

        # Step 2: right-pad text tokens with 0 and audio codes with the stop token.
        text_padded = torch.zeros(batch_size, max(lens), dtype=torch.int32, device=self.device)
        for i, t in enumerate(text_tokens):
            text_padded[i, : lens[i]] = t.squeeze(0)  # t has shape [1, lens[i]]
        text_len = torch.tensor(lens, device=self.device)

        gpt_codes = rnn_utils.pad_sequence(GPT_in, batch_first=True, padding_value=stop_token)

        # Expected waveform length per item (not one value for the whole batch).
        expected_output_len = torch.tensor(
            [c.shape[-1] * self.gpt.code_stride_len for c in GPT_in], device=self.device
        )

        # Step 3 (batched): a single GPT forward pass returns the latents of every item.
        # The original snippet called `self.gpt2`; stock XTTS names this module `self.gpt`.
        gpt_latents = self.gpt(
            text_padded,
            text_len,
            gpt_codes,
            expected_output_len,
            cond_latents=xg,
            return_attentions=False,
            return_latent=True,
        )

        # Zero each item's latents from its first stop token onward (the padded tail).
        # Assigning a scalar avoids the shape mismatch of the original torch.zeros(...) slice,
        # and items that never emitted a stop token are left untouched.
        for i in range(gpt_codes.shape[0]):
            stop_pos = (gpt_codes[i] == stop_token).nonzero()
            if stop_pos.numel() > 0:
                gpt_latents[i, int(stop_pos[0, 0]) :, :] = 0

        if length_scale != 1.0:
            gpt_latents = F.interpolate(
                gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear"
            ).transpose(1, 2)

        # Step 4 (batched): vocode every item at once; wav has shape [batch, 1, samples].
        # Not squeezing keeps that shape even when batch_size == 1.
        wav = self.hifigan_decoder(gpt_latents, g=xse).cpu()

        return {
            "wav": wav,
            "gpt_latents": gpt_latents.cpu().numpy(),
            "speaker_embedding": speaker_embedding,
        }
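Zeroing the padded latents, as the loop in the snippet does, does not guarantee that the decoder outputs silence for those frames, which may explain the noise at the end of the shorter items. One option is to cut each decoded waveform at its own stop token instead. The sketch below uses plain Python and hypothetical helper names (`first_stop`, `trim_batch`). It assumes the decoder's output length is proportional to the number of latent frames and that the latent frames line up one-to-one with the padded codes; with real XTTS latents you may need to offset for the extra tokens the GPT forward pass adds.

```python
from typing import List, Sequence

STOP_AUDIO_TOKEN = 1025  # stop/padding value used in the snippet above


def first_stop(codes: Sequence[int], stop_token: int = STOP_AUDIO_TOKEN) -> int:
    """Index of the first stop token, or len(codes) if the item never stopped."""
    for idx, code in enumerate(codes):
        if code == stop_token:
            return idx
    return len(codes)


def trim_batch(wavs: List[List[float]], padded_codes: List[List[int]]) -> List[List[float]]:
    """Cut each decoded waveform at the point matching its first stop token.

    Assumption: samples are spread uniformly over latent frames, so item i keeps
    the fraction first_stop(codes_i) / num_frames of the batched waveform.
    A uniform speed interpolation rescales all frames equally, so the ratio still
    holds approximately.
    """
    num_frames = len(padded_codes[0])
    trimmed = []
    for wav, codes in zip(wavs, padded_codes):
        keep = len(wav) * first_stop(codes) // num_frames
        trimmed.append(wav[:keep])
    return trimmed


codes = [[7, 8, 9, STOP_AUDIO_TOKEN], [7, STOP_AUDIO_TOKEN, STOP_AUDIO_TOKEN, STOP_AUDIO_TOKEN]]
wavs = [[0.1] * 8, [0.2] * 8]
print([len(w) for w in trim_batch(wavs, codes)])  # [6, 2]
```

`trim_batch` returns waveforms of different lengths, so they would be saved one by one rather than stacked back into a single tensor.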


stale bot commented Sep 18, 2024

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions. You might also look at our discussion channels.

@stale stale bot added the wontfix This will not be worked on but feel free to help. label Sep 18, 2024
Projects
None yet
Development

No branches or pull requests

3 participants