From b6e929696a470dd2ed0daf1448ec4e0407fd4f20 Mon Sep 17 00:00:00 2001 From: Aaron-Li <1427346151@qq.com> Date: Fri, 8 Dec 2023 16:55:32 +0800 Subject: [PATCH] support multiple GPU training --- TTS/tts/layers/xtts/trainer/gpt_trainer.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/TTS/tts/layers/xtts/trainer/gpt_trainer.py b/TTS/tts/layers/xtts/trainer/gpt_trainer.py index 6276f60af6..9a7a1d7783 100644 --- a/TTS/tts/layers/xtts/trainer/gpt_trainer.py +++ b/TTS/tts/layers/xtts/trainer/gpt_trainer.py @@ -321,7 +321,10 @@ def eval_step(self, batch, criterion): def on_train_epoch_start(self, trainer): trainer.model.eval() # the whole model to eval # put gpt model in training mode - trainer.model.xtts.gpt.train() + if hasattr(trainer.model, "module") and hasattr(trainer.model.module, "xtts"): + trainer.model.module.xtts.gpt.train() + else: + trainer.model.xtts.gpt.train() def on_init_end(self, trainer): # pylint: disable=W0613 # ignore similarities.pth on clearml save/upload @@ -387,7 +390,8 @@ def get_data_loader( else: loader = DataLoader( dataset, - batch_sampler=sampler, + sampler=sampler, + batch_size = config.eval_batch_size if is_eval else config.batch_size, collate_fn=dataset.collate_fn, num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, pin_memory=False,