Commit 11ab703
Merge pull request #426 from bghira/main
parquet backend improvements and rebuilding buckets/vae cache on each epoch for randomised bucketing
bghira authored May 30, 2024
2 parents fa3e886 + 98fcd39 commit 11ab703
Showing 11 changed files with 143 additions and 43 deletions.
29 changes: 20 additions & 9 deletions helpers/caching/vae.py
@@ -259,15 +259,6 @@ def rebuild_cache(self):
First, we'll clear the cache before rebuilding it.
"""
self.debug_log("Rebuilding cache.")
self.debug_log("-> Clearing cache objects")
self.clear_cache()
self.debug_log("-> Split tasks between GPU(s)")
self.split_cache_between_processes()
self.debug_log("-> Load VAE")
self.init_vae()
self.debug_log("-> Process VAE cache")
self.process_buckets()
self.debug_log("-> Completed cache rebuild")
if self.accelerator.is_local_main_process:
self.debug_log("Updating StateTracker with new VAE cache entry list.")
StateTracker.set_vae_cache_files(
@@ -278,6 +269,26 @@ def rebuild_cache(self):
data_backend_id=self.id,
)
self.accelerator.wait_for_everyone()
self.debug_log("-> Clearing cache objects")
self.clear_cache()
self.debug_log("-> Split tasks between GPU(s)")
self.split_cache_between_processes()
self.debug_log("-> Load VAE")
self.init_vae()
if StateTracker.get_args().vae_cache_preprocess:
self.debug_log("-> Process VAE cache")
self.process_buckets()
if self.accelerator.is_local_main_process:
self.debug_log("Updating StateTracker with new VAE cache entry list.")
StateTracker.set_vae_cache_files(
self.data_backend.list_files(
instance_data_root=self.cache_dir,
str_pattern="*.pt",
),
data_backend_id=self.id,
)
self.accelerator.wait_for_everyone()
self.debug_log("-> Completed cache rebuild")

def clear_cache(self):
"""
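The reordered rebuild_cache publishes the current cache entry list to the StateTracker and synchronises all ranks before anything is cleared, and the expensive pre-encoding pass now only runs when vae_cache_preprocess is enabled. Below is a minimal runnable sketch of the new ordering; the print stand-ins and the nesting of the final StateTracker update under the preprocess gate are assumptions, not the project's real implementation.

class _StubAccelerator:
    # Stand-in for accelerate.Accelerator, just enough for this sketch.
    is_local_main_process = True
    def wait_for_everyone(self): pass

class VAECacheFlow:
    def __init__(self, accelerator, vae_cache_preprocess: bool):
        self.accelerator = accelerator
        self.vae_cache_preprocess = vae_cache_preprocess

    # Placeholder hooks; the real implementations live in helpers/caching/vae.py.
    def _publish_cache_list(self): print("update StateTracker VAE cache entry list")
    def clear_cache(self): print("clear cache objects")
    def split_cache_between_processes(self): print("split tasks between GPU(s)")
    def init_vae(self): print("load VAE")
    def process_buckets(self): print("process VAE cache")

    def rebuild_cache(self):
        # Publish the entry list and sync every rank *before* clearing...
        if self.accelerator.is_local_main_process:
            self._publish_cache_list()
        self.accelerator.wait_for_everyone()
        # ...then clear, repartition the work, and reload the VAE.
        self.clear_cache()
        self.split_cache_between_processes()
        self.init_vae()
        # Pre-encoding is now gated on the vae_cache_preprocess argument.
        if self.vae_cache_preprocess:
            self.process_buckets()
            if self.accelerator.is_local_main_process:
                self._publish_cache_list()
            self.accelerator.wait_for_everyone()

VAECacheFlow(_StubAccelerator(), vae_cache_preprocess=True).rebuild_cache()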
3 changes: 3 additions & 0 deletions helpers/data_backend/factory.py
@@ -99,6 +99,9 @@ def init_backend_config(backend: dict, args: dict, accelerator) -> dict:
output["config"]["caption_strategy"] = backend["caption_strategy"]
else:
output["config"]["caption_strategy"] = args.caption_strategy
output["config"]["instance_data_root"] = backend.get(
"instance_data_dir", backend.get("aws_data_prefix", "")
)

maximum_image_size = backend.get("maximum_image_size", args.maximum_image_size)
target_downsample_size = backend.get(
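This change records each backend's data root in its runtime config, preferring a local instance_data_dir and falling back to an S3 aws_data_prefix, then an empty string. A small sketch of the fallback chain; the backend dicts here are illustrative, not real dataloader configs.

def resolve_instance_data_root(backend: dict) -> str:
    # Prefer the local directory; fall back to the S3 prefix, then "".
    return backend.get("instance_data_dir", backend.get("aws_data_prefix", ""))

print(resolve_instance_data_root({"instance_data_dir": "/data/images"}))   # "/data/images"
print(resolve_instance_data_root({"aws_data_prefix": "datasets/photos"}))  # "datasets/photos"
print(resolve_instance_data_root({}))                                      # ""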
16 changes: 11 additions & 5 deletions helpers/image_manipulation/training_sample.py
@@ -227,11 +227,17 @@ def _select_random_aspect(self):
return 1.0
# filter to portrait or landscape buckets, depending on our aspect ratio
available_aspects = self._trim_aspect_bucket_list()
logger.debug(
f"Available aspect buckets: {available_aspects} for {self.aspect_ratio} from {self.crop_aspect_buckets}"
)
selected_aspect = random.choice(available_aspects)
logger.debug(f"Randomly selected aspect ratio: {selected_aspect}")
if len(available_aspects) == 0:
selected_aspect = 1.0
logger.warning(
f"Image dimensions do not fit into the configured aspect buckets. Using square crop."
)
else:
logger.debug(
f"Available aspect buckets: {available_aspects} for {self.aspect_ratio} from {self.crop_aspect_buckets}"
)
selected_aspect = random.choice(available_aspects)
logger.debug(f"Randomly selected aspect ratio: {selected_aspect}")
else:
raise ValueError(
"Aspect buckets must be a list of floats or dictionaries."
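Previously an empty bucket list would crash on random.choice([]); the guard now falls back to a square crop with a warning. A simplified sketch of the behaviour, assuming _trim_aspect_bucket_list filters the configured buckets by orientation (the filter below is an assumption, not the real implementation):

import random

def select_random_aspect(image_aspect: float, crop_aspect_buckets: list) -> float:
    # Keep only buckets on the same side of square as the image itself.
    if image_aspect >= 1.0:
        available = [b for b in crop_aspect_buckets if b >= 1.0]
    else:
        available = [b for b in crop_aspect_buckets if b < 1.0]
    if not available:
        # New behaviour: nothing fits, so use a square crop instead of crashing.
        return 1.0
    return random.choice(available)

print(select_random_aspect(1.5, [0.5, 0.75]))  # 1.0 -- square fallback
print(select_random_aspect(0.6, [0.5, 0.75]))  # 0.5 or 0.75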
10 changes: 8 additions & 2 deletions helpers/metadata/backends/base.py
@@ -157,15 +157,21 @@ def _bucket_worker(
time.sleep(0.001)
logger.debug(f"Bucket worker completed processing. Returning to main thread.")

def compute_aspect_ratio_bucket_indices(self):
def compute_aspect_ratio_bucket_indices(self, ignore_existing_cache: bool = False):
"""
Compute the aspect ratio bucket indices. The workhorse of this class.
Arguments:
ignore_existing_cache (bool): Whether to ignore the existing cache
and entirely recompute the aspect ratio bucket indices.
Returns:
dict: The aspect ratio bucket indices.
"""
logger.info("Discovering new files...")
new_files = self._discover_new_files()
new_files = self._discover_new_files(
ignore_existing_cache=ignore_existing_cache
)

existing_files_set = set().union(*self.aspect_ratio_bucket_indices.values())
logger.info(
11 changes: 10 additions & 1 deletion helpers/metadata/backends/json.py
@@ -63,7 +63,9 @@ def __len__(self):
if len(bucket) >= self.batch_size
)

def _discover_new_files(self, for_metadata: bool = False):
def _discover_new_files(
self, for_metadata: bool = False, ignore_existing_cache: bool = False
):
"""
Discover new files that have not been processed yet.
Expand All @@ -73,6 +75,13 @@ def _discover_new_files(self, for_metadata: bool = False):
all_image_files = StateTracker.get_image_files(
data_backend_id=self.data_backend.id
)
if ignore_existing_cache:
# Return all files and remove the existing buckets.
logger.debug(
f"Resetting the entire aspect bucket cache as we've received the signal to ignore existing cache."
)
self.aspect_ratio_bucket_indices = {}
return list(all_image_files.keys())
if all_image_files is None:
logger.debug("No image file cache available, retrieving fresh")
all_image_files = self.data_backend.list_files(
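This JSON backend implements the same ignore_existing_cache contract introduced on the base class above: when the flag is set, the aspect bucket index is dropped and every known file is treated as new. A standalone sketch of that contract (the data below is made up for illustration):

def discover_new_files(all_files, buckets, ignore_existing_cache=False):
    if ignore_existing_cache:
        buckets.clear()         # reset the aspect bucket index entirely
        return list(all_files)  # reprocess every file
    processed = set().union(*buckets.values()) if buckets else set()
    return [f for f in all_files if f not in processed]

buckets = {"1.0": ["a.png"], "1.5": ["b.png"]}
print(discover_new_files(["a.png", "b.png", "c.png"], buckets))        # ['c.png']
print(discover_new_files(["a.png", "b.png", "c.png"], buckets, True))  # all three files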
29 changes: 25 additions & 4 deletions helpers/metadata/backends/parquet.py
@@ -87,7 +87,9 @@ def __len__(self):
if len(bucket) >= self.batch_size
)

def _discover_new_files(self, for_metadata: bool = False):
def _discover_new_files(
self, for_metadata: bool = False, ignore_existing_cache: bool = False
):
"""
Discover new files that have not been processed yet.
@@ -108,7 +110,13 @@ def _discover_new_files(self, for_metadata: bool = False):
)
else:
logger.debug("Using cached image file list")

if ignore_existing_cache:
# Return all files and remove the existing buckets.
logger.debug(
f"Resetting the entire aspect bucket cache as we've received the signal to ignore existing cache."
)
self.aspect_ratio_bucket_indices = {}
return list(all_image_files.keys())
# Flatten the list if it contains nested lists
if any(isinstance(i, list) for i in all_image_files):
all_image_files = [item for sublist in all_image_files for item in sublist]
Expand All @@ -123,6 +131,10 @@ def _discover_new_files(self, for_metadata: bool = False):
for file in all_image_files
if self.get_metadata_by_filepath(file) is None
]
elif ignore_existing_cache:
# Remove existing aspect bucket indices and return all image files.
result = all_image_files
self.aspect_ratio_bucket_indices = {}
else:
processed_files = set(
path
@@ -209,7 +221,7 @@ def save_image_metadata(self):
"""Save image metadata to a JSON file."""
self.data_backend.write(self.metadata_file, json.dumps(self.image_metadata))

def compute_aspect_ratio_bucket_indices(self):
def compute_aspect_ratio_bucket_indices(self, ignore_existing_cache: bool = False):
"""
Compute the aspect ratio bucket indices without any threads or queues.
Expand All @@ -219,7 +231,9 @@ def compute_aspect_ratio_bucket_indices(self):
dict: The aspect ratio bucket indices.
"""
logger.info("Discovering new files...")
new_files = self._discover_new_files()
new_files = self._discover_new_files(
ignore_existing_cache=ignore_existing_cache
)

existing_files_set = set().union(*self.aspect_ratio_bucket_indices.values())
# Initialize aggregated statistics
@@ -327,6 +341,13 @@ def _process_for_bucket(
image_path_filtered = os.path.splitext(
os.path.split(image_path_str)[-1]
)[0]
if self.instance_data_root in image_path_filtered:
image_path_filtered = image_path_filtered.replace(
self.instance_data_root, ""
)
# remove leading /
if image_path_filtered.startswith("/"):
image_path_filtered = image_path_filtered[1:]
if image_path_filtered.isdigit():
image_path_filtered = int(image_path_filtered)

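The new stem normalisation strips the backend's instance_data_root and any leading slash before the (possibly numeric) parquet lookup; the root check mainly guards flat key layouts where the prefix survives into the basename. A hedged sketch of the logic, with hypothetical sample paths:

import os

def normalise_stem(image_path: str, instance_data_root: str):
    stem = os.path.splitext(os.path.split(image_path)[-1])[0]
    if instance_data_root and instance_data_root in stem:
        stem = stem.replace(instance_data_root, "")
    if stem.startswith("/"):
        stem = stem[1:]
    # Numeric stems are cast so they match int64 parquet filename columns.
    return int(stem) if stem.isdigit() else stem

print(normalise_stem("/data/laion/000123.jpg", "/data/laion"))     # 123
print(normalise_stem("laion-subset-000123.jpg", "laion-subset-"))  # 123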
40 changes: 25 additions & 15 deletions helpers/prompts.py
@@ -176,22 +176,35 @@ def prepare_instance_prompt_from_parquet(
"Instance prompt is required when instance_prompt_only is enabled."
)
return instance_prompt
image_filename_stem = os.path.splitext(os.path.split(image_path)[1])[0]
(
parquet_db,
filename_column,
caption_column,
fallback_caption_column,
identifier_includes_extension,
) = StateTracker.get_parquet_database(sampler_backend_id)
if identifier_includes_extension:
image_filename_stem = image_path
backend_config = StateTracker.get_data_backend_config(
data_backend_id=data_backend.id
)
instance_data_root = backend_config.get("instance_data_root")
image_filename_stem = image_path
if instance_data_root is not None and instance_data_root in image_filename_stem:
image_filename_stem = image_filename_stem.replace(instance_data_root, "")
if image_filename_stem.startswith("/"):
image_filename_stem = image_filename_stem[1:]

if not identifier_includes_extension:
image_filename_stem = os.path.splitext(image_filename_stem)[0]

logger.debug(
f"for image_path: {image_path} we have image_filename_stem: {image_filename_stem}"
)
# parquet_db is a dataframe. let's find the row that matches the image filename.
if parquet_db is None:
raise ValueError(
f"Parquet database not found for sampler {sampler_backend_id}."
)
image_caption = None
image_caption = ""
# Are the types incorrect, eg. the column is int64 vs str stem?
if "int" in str(parquet_db[filename_column].dtype):
if image_filename_stem.isdigit():
@@ -387,17 +400,14 @@ def get_all_captions(
data_backend=data_backend,
)
elif caption_strategy == "parquet":
try:
caption = PromptHandler.prepare_instance_prompt_from_parquet(
image_path,
use_captions=use_captions,
prepend_instance_prompt=prepend_instance_prompt,
instance_prompt=instance_prompt,
data_backend=data_backend,
sampler_backend_id=data_backend.id,
)
except:
continue
caption = PromptHandler.prepare_instance_prompt_from_parquet(
image_path,
use_captions=use_captions,
prepend_instance_prompt=prepend_instance_prompt,
instance_prompt=instance_prompt,
data_backend=data_backend,
sampler_backend_id=data_backend.id,
)
elif caption_strategy == "instanceprompt":
return [instance_prompt]
else:
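Caption lookup now derives the parquet identifier by stripping the backend's instance_data_root from the full path rather than always taking the basename, so relative sub-directory paths survive into the key; get_all_captions also stops swallowing parquet errors with a bare except. A sketch of the identifier derivation (the paths are hypothetical):

import os

def derive_identifier(image_path, instance_data_root, identifier_includes_extension):
    stem = image_path
    if instance_data_root is not None and instance_data_root in stem:
        stem = stem.replace(instance_data_root, "")
    if stem.startswith("/"):
        stem = stem[1:]
    if not identifier_includes_extension:
        stem = os.path.splitext(stem)[0]
    return stem

# Sub-directories are now preserved in the identifier:
print(derive_identifier("/data/set/a/001.jpg", "/data/set", False))  # "a/001"
print(derive_identifier("/data/set/a/001.jpg", "/data/set", True))   # "a/001.jpg"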
7 changes: 5 additions & 2 deletions helpers/training/state_tracker.py
@@ -285,11 +285,14 @@ def set_vae_cache_files(cls, raw_file_list: list, data_backend_id: str):

@classmethod
def get_vae_cache_files(cls: list, data_backend_id: str):
if data_backend_id not in cls.all_vae_cache_files:
if (
data_backend_id not in cls.all_vae_cache_files
or cls.all_vae_cache_files.get(data_backend_id) is None
):
cls.all_vae_cache_files[data_backend_id] = cls._load_from_disk(
"all_vae_cache_files_{}".format(data_backend_id)
)
return cls.all_vae_cache_files[data_backend_id]
return cls.all_vae_cache_files[data_backend_id] or {}

@classmethod
def set_text_cache_files(cls, raw_file_list: list, data_backend_id: str):
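The lookup is hardened in two ways: a stored None is now treated the same as a missing key (triggering a reload from disk), and the caller can never receive None. The same pattern reduced to a standalone sketch, where _load_from_disk is a stand-in for the StateTracker's real disk loader:

_all_vae_cache_files: dict = {}

def _load_from_disk(key: str):
    # Stand-in: pretend nothing is cached on disk either.
    return None

def get_vae_cache_files(data_backend_id: str):
    if (
        data_backend_id not in _all_vae_cache_files
        or _all_vae_cache_files.get(data_backend_id) is None
    ):
        _all_vae_cache_files[data_backend_id] = _load_from_disk(
            "all_vae_cache_files_{}".format(data_backend_id)
        )
    return _all_vae_cache_files[data_backend_id] or {}

print(get_vae_cache_files("backend-1"))  # {} rather than None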
2 changes: 1 addition & 1 deletion tests/test_prompthandler.py
@@ -19,7 +19,7 @@ def setUp(self):
@patch("helpers.training.state_tracker.StateTracker.get_parquet_database")
def test_prepare_instance_prompt_from_parquet(self, mock_get_parquet_database):
# Setup
image_path = "path/to/image_3.jpg"
image_path = "image_3.jpg"
use_captions = True
prepend_instance_prompt = True
data_backend = MagicMock()
22 changes: 19 additions & 3 deletions train_sd21.py
@@ -1053,12 +1053,28 @@ def main():
backend_config = StateTracker.get_data_backend_config(backend_id)
logger.debug(f"Backend config: {backend_config}")
if (
"deepfloyd" not in args.model_type
and "vae_cache_clear_each_epoch" in backend_config
"crop_aspect" in backend_config
and backend_config["crop_aspect"] is not None
and backend_config["crop_aspect"] == "random"
and "metadata_backend" in backend
):
# when the aspect ratio is random, we need to shuffle the dataset on each epoch.
backend["metadata_backend"].compute_aspect_ratio_bucket_indices(
ignore_existing_cache=True
)
# we have to rebuild the VAE cache if it exists.
if "vaecache" in backend:
backend["vaecache"].rebuild_cache()
backend["metadata_backend"].save_cache()
elif (
"vae_cache_clear_each_epoch" in backend_config
and backend_config["vae_cache_clear_each_epoch"]
and "vaecache" in backend
):
# We will clear the cache and then rebuild it. This is useful for random crops.
# If the user has specified that this should happen,
# we will clear the cache and then rebuild it. This is useful for random crops.
logger.debug(f"VAE Cache rebuild is enabled. Rebuilding.")
logger.debug(f"Backend config: {backend_config}")
backend["vaecache"].rebuild_cache()
current_epoch = epoch
StateTracker.set_epoch(epoch)
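Condensed, the new per-epoch branch (duplicated in train_sdxl.py below) behaves as sketched here; the backend and config dicts are placeholders for what StateTracker.get_data_backends() and get_data_backend_config() return:

def refresh_backend_for_epoch(backend: dict, config: dict) -> None:
    if config.get("crop_aspect") == "random" and "metadata_backend" in backend:
        # Randomised bucketing: re-bucket every image, then re-encode latents.
        backend["metadata_backend"].compute_aspect_ratio_bucket_indices(
            ignore_existing_cache=True
        )
        if "vaecache" in backend:
            backend["vaecache"].rebuild_cache()
        backend["metadata_backend"].save_cache()
    elif config.get("vae_cache_clear_each_epoch") and "vaecache" in backend:
        # Opt-in full VAE cache rebuild, useful for random crops.
        backend["vaecache"].rebuild_cache()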
17 changes: 16 additions & 1 deletion train_sdxl.py
@@ -1143,11 +1143,26 @@ def main():
for backend_id, backend in StateTracker.get_data_backends().items():
backend_config = StateTracker.get_data_backend_config(backend_id)
if (
"crop_aspect" in backend_config
and backend_config["crop_aspect"] is not None
and backend_config["crop_aspect"] == "random"
and "metadata_backend" in backend
):
# when the aspect ratio is random, we need to shuffle the dataset on each epoch.
backend["metadata_backend"].compute_aspect_ratio_bucket_indices(
ignore_existing_cache=True
)
# we have to rebuild the VAE cache if it exists.
if "vaecache" in backend:
backend["vaecache"].rebuild_cache()
backend["metadata_backend"].save_cache()
elif (
"vae_cache_clear_each_epoch" in backend_config
and backend_config["vae_cache_clear_each_epoch"]
and "vaecache" in backend
):
# We will clear the cache and then rebuild it. This is useful for random crops.
# If the user has specified that this should happen,
# we will clear the cache and then rebuild it. This is useful for random crops.
logger.debug(f"VAE Cache rebuild is enabled. Rebuilding.")
logger.debug(f"Backend config: {backend_config}")
backend["vaecache"].rebuild_cache()
