Skip to content

Commit

Permalink
Fix for latest TF
Browse files Browse the repository at this point in the history
  • Loading branch information
gravityrail committed Jul 8, 2024
1 parent 6696abf commit bf9a38f
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions TTS/tts/layers/xtts/stream_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ def generate(

elif is_sample_gen_mode:
# 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config)
logits_warper = self._get_logits_warper(generation_config, inputs_tensor.device)

# 12. expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
Expand All @@ -455,7 +455,7 @@ def generate(
)
elif is_sample_gen_stream_mode:
# 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config)
logits_warper = self._get_logits_warper(generation_config, inputs_tensor.device)

# 12. expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
Expand Down Expand Up @@ -517,7 +517,7 @@ def generate(

elif is_beam_sample_gen_mode:
# 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config)
logits_warper = self._get_logits_warper(generation_config, inputs_tensor.device)

if stopping_criteria.max_length is None:
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
Expand Down

0 comments on commit bf9a38f

Please sign in to comment.