From 61b2491b37e4c6a5e478db60cca5f14dc8d14846 Mon Sep 17 00:00:00 2001
From: bghira
Date: Thu, 30 May 2024 11:40:19 -0600
Subject: [PATCH 1/4] debiased bucket training should rebuild cache upon epoch
 end (implements #416)

---
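Editor's note (sits in the git-am-ignored region, not part of the commit):
the trainer hunks below add an epoch-end hook that rebuilds stale caches.
A minimal sketch of that hook, condensed for readability — the function
name on_epoch_end and the use of .get() to collapse the three crop_aspect
checks are illustrative; everything else mirrors the train_sd21.py and
train_sdxl.py hunks and assumes the trainer's existing StateTracker import:

    def on_epoch_end(backend_id: str, backend: dict) -> None:
        backend_config = StateTracker.get_data_backend_config(backend_id)
        if backend_config.get("crop_aspect") == "random" and "metadata_backend" in backend:
            # Random crop aspects invalidate the bucket assignments every
            # epoch, so recompute them from scratch...
            backend["metadata_backend"].compute_aspect_ratio_bucket_indices(
                ignore_existing_cache=True
            )
            # ...and rebuild the VAE latent cache to match, if one exists.
            if "vaecache" in backend:
                backend["vaecache"].rebuild_cache()
            backend["metadata_backend"].save_cache()
        elif backend_config.get("vae_cache_clear_each_epoch") and "vaecache" in backend:
            # User-requested per-epoch rebuild, useful for random crops.
            backend["vaecache"].rebuild_cache()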
Doing nothing.") logger.info(f"Statistics: {aggregated_statistics}") return + else: + logger.debug(f"New files: {new_files}") num_cpus = ( StateTracker.get_args().aspect_bucket_worker_count ) # Using a fixed number for better control and predictability diff --git a/helpers/metadata/backends/json.py b/helpers/metadata/backends/json.py index b777b6e8..934f07f2 100644 --- a/helpers/metadata/backends/json.py +++ b/helpers/metadata/backends/json.py @@ -63,7 +63,9 @@ def __len__(self): if len(bucket) >= self.batch_size ) - def _discover_new_files(self, for_metadata: bool = False): + def _discover_new_files( + self, for_metadata: bool = False, ignore_existing_cache: bool = False + ): """ Discover new files that have not been processed yet. @@ -73,6 +75,13 @@ def _discover_new_files(self, for_metadata: bool = False): all_image_files = StateTracker.get_image_files( data_backend_id=self.data_backend.id ) + if ignore_existing_cache: + # Return all files and remove the existing buckets. + logger.debug( + f"Resetting the entire aspect bucket cache as we've received the signal to ignore existing cache." + ) + self.aspect_ratio_bucket_indices = {} + return list(all_image_files.keys()) if all_image_files is None: logger.debug("No image file cache available, retrieving fresh") all_image_files = self.data_backend.list_files( diff --git a/helpers/metadata/backends/parquet.py b/helpers/metadata/backends/parquet.py index 3edc1e3e..d7569979 100644 --- a/helpers/metadata/backends/parquet.py +++ b/helpers/metadata/backends/parquet.py @@ -87,7 +87,9 @@ def __len__(self): if len(bucket) >= self.batch_size ) - def _discover_new_files(self, for_metadata: bool = False): + def _discover_new_files( + self, for_metadata: bool = False, ignore_existing_cache: bool = False + ): """ Discover new files that have not been processed yet. @@ -123,6 +125,10 @@ def _discover_new_files(self, for_metadata: bool = False): for file in all_image_files if self.get_metadata_by_filepath(file) is None ] + elif ignore_existing_cache: + # Remove existing aspect bucket indices and return all image files. + result = all_image_files + self.aspect_ratio_bucket_indices = {} else: processed_files = set( path @@ -209,7 +215,7 @@ def save_image_metadata(self): """Save image metadata to a JSON file.""" self.data_backend.write(self.metadata_file, json.dumps(self.image_metadata)) - def compute_aspect_ratio_bucket_indices(self): + def compute_aspect_ratio_bucket_indices(self, ignore_existing_cache: bool = False): """ Compute the aspect ratio bucket indices without any threads or queues. @@ -219,7 +225,9 @@ def compute_aspect_ratio_bucket_indices(self): dict: The aspect ratio bucket indices. 
""" logger.info("Discovering new files...") - new_files = self._discover_new_files() + new_files = self._discover_new_files( + ignore_existing_cache=ignore_existing_cache + ) existing_files_set = set().union(*self.aspect_ratio_bucket_indices.values()) # Initialize aggregated statistics diff --git a/helpers/training/state_tracker.py b/helpers/training/state_tracker.py index 95d8fb34..dfcb4541 100644 --- a/helpers/training/state_tracker.py +++ b/helpers/training/state_tracker.py @@ -285,11 +285,14 @@ def set_vae_cache_files(cls, raw_file_list: list, data_backend_id: str): @classmethod def get_vae_cache_files(cls: list, data_backend_id: str): - if data_backend_id not in cls.all_vae_cache_files: + if ( + data_backend_id not in cls.all_vae_cache_files + or cls.all_vae_cache_files.get(data_backend_id) is None + ): cls.all_vae_cache_files[data_backend_id] = cls._load_from_disk( "all_vae_cache_files_{}".format(data_backend_id) ) - return cls.all_vae_cache_files[data_backend_id] + return cls.all_vae_cache_files[data_backend_id] or {} @classmethod def set_text_cache_files(cls, raw_file_list: list, data_backend_id: str): diff --git a/train_sd21.py b/train_sd21.py index 4190c575..371da069 100644 --- a/train_sd21.py +++ b/train_sd21.py @@ -1053,12 +1053,28 @@ def main(): backend_config = StateTracker.get_data_backend_config(backend_id) logger.debug(f"Backend config: {backend_config}") if ( - "deepfloyd" not in args.model_type - and "vae_cache_clear_each_epoch" in backend_config + "crop_aspect" in backend_config + and backend_config["crop_aspect"] is not None + and backend_config["crop_aspect"] == "random" + and "metadata_backend" in backend + ): + # when the aspect ratio is random, we need to shuffle the dataset on each epoch. + backend["metadata_backend"].compute_aspect_ratio_bucket_indices( + ignore_existing_cache=True + ) + # we have to rebuild the VAE cache if it exists. + if "vaecache" in backend: + backend["vaecache"].rebuild_cache() + backend["metadata_backend"].save_cache() + elif ( + "vae_cache_clear_each_epoch" in backend_config and backend_config["vae_cache_clear_each_epoch"] + and "vaecache" in backend ): - # We will clear the cache and then rebuild it. This is useful for random crops. + # If the user has specified that this should happen, + # we will clear the cache and then rebuild it. This is useful for random crops. logger.debug(f"VAE Cache rebuild is enabled. Rebuilding.") + logger.debug(f"Backend config: {backend_config}") backend["vaecache"].rebuild_cache() current_epoch = epoch StateTracker.set_epoch(epoch) diff --git a/train_sdxl.py b/train_sdxl.py index d57826e2..f87afb39 100644 --- a/train_sdxl.py +++ b/train_sdxl.py @@ -1143,11 +1143,26 @@ def main(): for backend_id, backend in StateTracker.get_data_backends().items(): backend_config = StateTracker.get_data_backend_config(backend_id) if ( + "crop_aspect" in backend_config + and backend_config["crop_aspect"] is not None + and backend_config["crop_aspect"] == "random" + and "metadata_backend" in backend + ): + # when the aspect ratio is random, we need to shuffle the dataset on each epoch. + backend["metadata_backend"].compute_aspect_ratio_bucket_indices( + ignore_existing_cache=True + ) + # we have to rebuild the VAE cache if it exists. + if "vaecache" in backend: + backend["vaecache"].rebuild_cache() + backend["metadata_backend"].save_cache() + elif ( "vae_cache_clear_each_epoch" in backend_config and backend_config["vae_cache_clear_each_epoch"] and "vaecache" in backend ): - # We will clear the cache and then rebuild it. 
 helpers/metadata/backends/parquet.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/helpers/metadata/backends/parquet.py b/helpers/metadata/backends/parquet.py
index d7569979..dd557404 100644
--- a/helpers/metadata/backends/parquet.py
+++ b/helpers/metadata/backends/parquet.py
@@ -110,7 +110,13 @@ def _discover_new_files(
             )
         else:
             logger.debug("Using cached image file list")
-
+        if ignore_existing_cache:
+            # Return all files and remove the existing buckets.
+            logger.debug(
+                f"Resetting the entire aspect bucket cache as we've received the signal to ignore existing cache."
+            )
+            self.aspect_ratio_bucket_indices = {}
+            return list(all_image_files.keys())
         # Flatten the list if it contains nested lists
         if any(isinstance(i, list) for i in all_image_files):
             all_image_files = [item for sublist in all_image_files for item in sublist]

From f3f5d7bb29ee0a4e85a654640d96bbf09bbd7357 Mon Sep 17 00:00:00 2001
From: bghira
Date: Thu, 30 May 2024 13:46:38 -0600
Subject: [PATCH 4/4] parquet backend: resolve retrieval of captions and other
 oddities with handling of filenames

---
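Editor's note (sits in the git-am-ignored region, not part of the commit):
the prompts.py and parquet.py hunks below both normalise an image path into
the lookup key used against the parquet caption table. The shared behaviour
as a standalone helper — the function name is illustrative, the logic
follows the diff:

    import os

    def normalise_lookup_key(image_path: str, instance_data_root: str,
                             identifier_includes_extension: bool) -> str:
        key = image_path
        # Strip the backend's instance_data_root prefix and any leading slash.
        if instance_data_root and instance_data_root in key:
            key = key.replace(instance_data_root, "")
        if key.startswith("/"):
            key = key[1:]
        # Only drop the extension when the parquet filename column omits it.
        if not identifier_includes_extension:
            key = os.path.splitext(key)[0]
        return key

    assert normalise_lookup_key("/data/img/image_3.jpg", "/data/img", False) == "image_3"
    assert normalise_lookup_key("/data/img/image_3.jpg", "/data/img", True) == "image_3.jpg"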
 helpers/data_backend/factory.py               |  3 ++
 helpers/image_manipulation/training_sample.py | 16 +++++---
 helpers/metadata/backends/parquet.py          |  7 ++++
 helpers/prompts.py                            | 40 ++++++++++++-------
 tests/test_prompthandler.py                   |  2 +-
 5 files changed, 47 insertions(+), 21 deletions(-)

diff --git a/helpers/data_backend/factory.py b/helpers/data_backend/factory.py
index c4b7d507..68990864 100644
--- a/helpers/data_backend/factory.py
+++ b/helpers/data_backend/factory.py
@@ -99,6 +99,9 @@ def init_backend_config(backend: dict, args: dict, accelerator) -> dict:
         output["config"]["caption_strategy"] = backend["caption_strategy"]
     else:
         output["config"]["caption_strategy"] = args.caption_strategy
+    output["config"]["instance_data_root"] = backend.get(
+        "instance_data_dir", backend.get("aws_data_prefix", "")
+    )
 
     maximum_image_size = backend.get("maximum_image_size", args.maximum_image_size)
     target_downsample_size = backend.get(
diff --git a/helpers/image_manipulation/training_sample.py b/helpers/image_manipulation/training_sample.py
index 653a9031..40839753 100644
--- a/helpers/image_manipulation/training_sample.py
+++ b/helpers/image_manipulation/training_sample.py
@@ -227,11 +227,17 @@ def _select_random_aspect(self):
                 return 1.0
             # filter to portrait or landscape buckets, depending on our aspect ratio
             available_aspects = self._trim_aspect_bucket_list()
-            logger.debug(
-                f"Available aspect buckets: {available_aspects} for {self.aspect_ratio} from {self.crop_aspect_buckets}"
-            )
-            selected_aspect = random.choice(available_aspects)
-            logger.debug(f"Randomly selected aspect ratio: {selected_aspect}")
+            if len(available_aspects) == 0:
+                selected_aspect = 1.0
+                logger.warning(
+                    f"Image dimensions do not fit into the configured aspect buckets. Using square crop."
+                )
+            else:
+                logger.debug(
+                    f"Available aspect buckets: {available_aspects} for {self.aspect_ratio} from {self.crop_aspect_buckets}"
+                )
+                selected_aspect = random.choice(available_aspects)
+                logger.debug(f"Randomly selected aspect ratio: {selected_aspect}")
         else:
             raise ValueError(
                 "Aspect buckets must be a list of floats or dictionaries."
diff --git a/helpers/metadata/backends/parquet.py b/helpers/metadata/backends/parquet.py
index dd557404..9ac4fcf4 100644
--- a/helpers/metadata/backends/parquet.py
+++ b/helpers/metadata/backends/parquet.py
@@ -341,6 +341,13 @@ def _process_for_bucket(
             image_path_filtered = os.path.splitext(
                 os.path.split(image_path_str)[-1]
             )[0]
+            if self.instance_data_root in image_path_filtered:
+                image_path_filtered = image_path_filtered.replace(
+                    self.instance_data_root, ""
+                )
+            # remove leading /
+            if image_path_filtered.startswith("/"):
+                image_path_filtered = image_path_filtered[1:]
             if image_path_filtered.isdigit():
                 image_path_filtered = int(image_path_filtered)
diff --git a/helpers/prompts.py b/helpers/prompts.py
index 63be11dd..b940918f 100644
--- a/helpers/prompts.py
+++ b/helpers/prompts.py
@@ -176,7 +176,6 @@ def prepare_instance_prompt_from_parquet(
                 "Instance prompt is required when instance_prompt_only is enabled."
) return instance_prompt - image_filename_stem = os.path.splitext(os.path.split(image_path)[1])[0] ( parquet_db, filename_column, @@ -184,14 +183,28 @@ def prepare_instance_prompt_from_parquet( fallback_caption_column, identifier_includes_extension, ) = StateTracker.get_parquet_database(sampler_backend_id) - if identifier_includes_extension: - image_filename_stem = image_path + backend_config = StateTracker.get_data_backend_config( + data_backend_id=data_backend.id + ) + instance_data_root = backend_config.get("instance_data_root") + image_filename_stem = image_path + if instance_data_root is not None and instance_data_root in image_filename_stem: + image_filename_stem = image_filename_stem.replace(instance_data_root, "") + if image_filename_stem.startswith("/"): + image_filename_stem = image_filename_stem[1:] + + if not identifier_includes_extension: + image_filename_stem = os.path.splitext(image_filename_stem)[0] + + logger.debug( + f"for image_path: {image_path} we have image_filename_stem: {image_filename_stem}" + ) # parquet_db is a dataframe. let's find the row that matches the image filename. if parquet_db is None: raise ValueError( f"Parquet database not found for sampler {sampler_backend_id}." ) - image_caption = None + image_caption = "" # Are the types incorrect, eg. the column is int64 vs str stem? if "int" in str(parquet_db[filename_column].dtype): if image_filename_stem.isdigit(): @@ -387,17 +400,14 @@ def get_all_captions( data_backend=data_backend, ) elif caption_strategy == "parquet": - try: - caption = PromptHandler.prepare_instance_prompt_from_parquet( - image_path, - use_captions=use_captions, - prepend_instance_prompt=prepend_instance_prompt, - instance_prompt=instance_prompt, - data_backend=data_backend, - sampler_backend_id=data_backend.id, - ) - except: - continue + caption = PromptHandler.prepare_instance_prompt_from_parquet( + image_path, + use_captions=use_captions, + prepend_instance_prompt=prepend_instance_prompt, + instance_prompt=instance_prompt, + data_backend=data_backend, + sampler_backend_id=data_backend.id, + ) elif caption_strategy == "instanceprompt": return [instance_prompt] else: diff --git a/tests/test_prompthandler.py b/tests/test_prompthandler.py index 4bd49717..6f02889c 100644 --- a/tests/test_prompthandler.py +++ b/tests/test_prompthandler.py @@ -19,7 +19,7 @@ def setUp(self): @patch("helpers.training.state_tracker.StateTracker.get_parquet_database") def test_prepare_instance_prompt_from_parquet(self, mock_get_parquet_database): # Setup - image_path = "path/to/image_3.jpg" + image_path = "image_3.jpg" use_captions = True prepend_instance_prompt = True data_backend = MagicMock()