diff --git a/helpers/training/trainer.py b/helpers/training/trainer.py
index dca93610..75d1a93d 100644
--- a/helpers/training/trainer.py
+++ b/helpers/training/trainer.py
@@ -822,26 +822,40 @@
     def init_post_load_freeze(self):
         if self.unet is not None:
             logger.info("Applying BitFit freezing strategy to the U-net.")
-            self.unet = apply_bitfit_freezing(self.unet, self.config)
+            self.unet = apply_bitfit_freezing(
+                unwrap_model(self.accelerator, self.unet), self.config
+            )
         if self.transformer is not None:
             logger.warning(
                 "Training DiT models with BitFit is not yet tested, and unexpected results may occur."
             )
-            self.transformer = apply_bitfit_freezing(self.transformer, self.config)
+            self.transformer = apply_bitfit_freezing(
+                unwrap_model(self.accelerator, self.transformer), self.config
+            )
 
         if self.config.gradient_checkpointing:
             if self.unet is not None:
-                self.unet.enable_gradient_checkpointing()
+                unwrap_model(
+                    self.accelerator, self.unet
+                ).enable_gradient_checkpointing()
             if self.transformer is not None and self.config.model_family != "smoldit":
-                self.transformer.enable_gradient_checkpointing()
+                unwrap_model(
+                    self.accelerator, self.transformer
+                ).enable_gradient_checkpointing()
             if self.config.controlnet:
-                self.controlnet.enable_gradient_checkpointing()
+                unwrap_model(
+                    self.accelerator, self.controlnet
+                ).enable_gradient_checkpointing()
             if (
                 hasattr(self.config, "train_text_encoder")
                 and self.config.train_text_encoder
             ):
-                self.text_encoder_1.gradient_checkpointing_enable()
-                self.text_encoder_2.gradient_checkpointing_enable()
+                unwrap_model(
+                    self.accelerator, self.text_encoder_1
+                ).gradient_checkpointing_enable()
+                unwrap_model(
+                    self.accelerator, self.text_encoder_2
+                ).gradient_checkpointing_enable()
 
     def _recalculate_training_steps(self):
         # Scheduler and math around the number of training steps.
diff --git a/train.py b/train.py
index cb63532b..b3c72a18 100644
--- a/train.py
+++ b/train.py
@@ -8,8 +8,6 @@
 logger.setLevel(environ.get("SIMPLETUNER_LOG_LEVEL", "INFO"))
 
 if __name__ == "__main__":
-    global bf
-    bf = None
     trainer = None
     try:
         import multiprocessing
@@ -64,4 +62,4 @@
         print(e)
         print(traceback.format_exc())
         if trainer is not None and trainer.bf is not None:
-            bf.stop_fetching()
+            trainer.bf.stop_fetching()
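
For context on the first hunk: once a model has gone through accelerator.prepare(), Accelerate may hand back a wrapper (e.g. DistributedDataParallel) rather than the bare module, so methods like enable_gradient_checkpointing() are only reachable on the inner model. Below is a minimal sketch of what a helper with the unwrap_model(accelerator, model) signature used in this diff might look like; the body is an assumption modeled on common diffusers training-script practice, not SimpleTuner's actual implementation.

# Sketch only -- the real unwrap_model in SimpleTuner may differ.
def unwrap_model(accelerator, model):
    # Accelerator.unwrap_model strips the DDP/DeepSpeed/FSDP wrapper
    # that accelerator.prepare() placed around the module.
    model = accelerator.unwrap_model(model)
    # torch.compile wraps modules in an OptimizedModule that keeps the
    # original module on ._orig_mod; peel that off as well.
    return model._orig_mod if hasattr(model, "_orig_mod") else model

With a helper like this, the call sites above stay one-liners regardless of whether training is single-GPU, distributed, or compiled.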