Skip to content

Commit

Permalink
Merge pull request #3391 from aaron-lii/multi-gpu
Browse files Browse the repository at this point in the history
support multiple GPU training for XTTS
  • Loading branch information
erogol authored Dec 12, 2023
2 parents b0fe0e6 + b6e9296 commit 934b87b
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions TTS/tts/layers/xtts/trainer/gpt_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,10 @@ def eval_step(self, batch, criterion):
def on_train_epoch_start(self, trainer):
trainer.model.eval() # the whole model to eval
# put gpt model in training mode
trainer.model.xtts.gpt.train()
if hasattr(trainer.model, "module") and hasattr(trainer.model.module, "xtts"):
trainer.model.module.xtts.gpt.train()
else:
trainer.model.xtts.gpt.train()

def on_init_end(self, trainer): # pylint: disable=W0613
# ignore similarities.pth on clearml save/upload
Expand Down Expand Up @@ -387,7 +390,8 @@ def get_data_loader(
else:
loader = DataLoader(
dataset,
batch_sampler=sampler,
sampler=sampler,
batch_size = config.eval_batch_size if is_eval else config.batch_size,
collate_fn=dataset.collate_fn,
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
pin_memory=False,
Expand Down

0 comments on commit 934b87b

Please sign in to comment.