diff --git a/configure.py b/configure.py index 77a19852..417c2c79 100644 --- a/configure.py +++ b/configure.py @@ -40,7 +40,7 @@ "terminus": 8.0, "sd3": 5.0, } -lycoris_algos = ["lokr"] + lora_ranks = [1, 16, 64, 128, 256] learning_rates_by_rank = { 1: "3e-4", @@ -84,7 +84,7 @@ def configure_lycoris(): print("6. DyLoRA - Dynamic updates, efficient with large dims. (algo=dylora)") print("7. Diag-OFT - Fast convergence with orthogonal fine-tuning. (algo=diag-oft)") print("8. BOFT - Advanced version of Diag-OFT with more flexibility. (algo=boft)") - print("9. GLoRA/GLoKr - Generalized, still in development. (algo=glora/glokr)\n") + print("9. GLoRA - Generalized LoRA. (algo=glora)\n") # Prompt user to select an algorithm algo = prompt_user( diff --git a/helpers/training/save_hooks.py b/helpers/training/save_hooks.py index ed354a5c..83ea94b0 100644 --- a/helpers/training/save_hooks.py +++ b/helpers/training/save_hooks.py @@ -143,8 +143,9 @@ def __init__( self.pipeline_class = FluxPipeline elif args.flux and args.flux_attention_masked_training: from helpers.models.flux.transformer import ( - FluxTransformer2DModelWithMasking + FluxTransformer2DModelWithMasking, ) + self.denoiser_class = FluxTransformer2DModelWithMasking self.pipeline_class = FluxPipeline elif hasattr(args, "hunyuan_dit") and args.hunyuan_dit: @@ -313,7 +314,7 @@ def save_model_hook(self, models, weights, output_dir): StateTracker.save_training_state( os.path.join(output_dir, "training_state.json") ) - if "lora" in self.args.model_type and self.args.lora_type == "Standard": + if "lora" in self.args.model_type and self.args.lora_type == "standard": self._save_lora(models=models, weights=weights, output_dir=output_dir) return elif "lora" in self.args.model_type and self.args.lora_type == "lycoris": @@ -461,7 +462,7 @@ def load_model_hook(self, models, input_dir): f"Could not find training_state.json in checkpoint dir {input_dir}" ) - if "lora" in self.args.model_type and self.args.lora_type == "Standard": + if "lora" in self.args.model_type and self.args.lora_type == "standard": self._load_lora(models=models, input_dir=input_dir) elif "lora" in self.args.model_type and self.args.lora_type == "lycoris": self._load_lycoris(models=models, input_dir=input_dir)