
Commit c5ebc0e
Merge pull request #12 from bghira/main
updates and fixes
bghira authored Aug 7, 2023
2 parents bf0bd5b + 8749493 commit c5ebc0e
Showing 4 changed files with 75 additions and 70 deletions.
2 changes: 0 additions & 2 deletions helpers/aspect_bucket.py
@@ -21,7 +21,6 @@ class BalancedBucketSampler(torch.utils.data.Sampler):
def __init__(
self,
aspect_ratio_bucket_indices,
accelerator,
batch_size: int = 15,
seen_images_path: str = "/notebooks/SimpleTuner/seen_images.json",
state_path: str = "/notebooks/SimpleTuner/bucket_sampler_state.json",
@@ -37,7 +36,6 @@ def __init__(
self.current_bucket = 0
self.seen_images_path = seen_images_path
self.state_path = state_path
self.accelerator = accelerator
self.seen_images = self.load_seen_images()
self.drop_caption_every_n_percent = drop_caption_every_n_percent
self.debug_aspect_buckets = debug_aspect_buckets
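With the accelerator argument gone, BalancedBucketSampler is constructed from the bucket index dict and its sampling options alone, with no Accelerate object threaded through; the train_sdxl.py hunk later in this commit updates the call site accordingly. A minimal sketch of that updated call site, assuming the helpers package is importable; the bucket dict below is placeholder data and the state-file paths simply reuse the defaults shown above:

from helpers.aspect_bucket import BalancedBucketSampler

# Placeholder bucket index: aspect-ratio key -> list of image paths.
aspect_ratio_bucket_indices = {
    "1.0": ["/data/img_0001.png", "/data/img_0002.png"],
    "1.5": ["/data/img_0003.png"],
}

# No accelerator argument any more; only sampling-related options remain.
sampler = BalancedBucketSampler(
    aspect_ratio_bucket_indices,
    batch_size=15,
    seen_images_path="/notebooks/SimpleTuner/seen_images.json",
    state_path="/notebooks/SimpleTuner/bucket_sampler_state.json",
)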
116 changes: 61 additions & 55 deletions helpers/dreambooth_dataset.py
@@ -107,8 +107,8 @@ def _process_image(self, image_path_str, aspect_ratio_bucket_indices):
aspect_ratio_bucket_indices[str(aspect_ratio)] = []
aspect_ratio_bucket_indices[str(aspect_ratio)].append(image_path_str)
except Exception as e:
logging.error(f"Error processing image {image_path_str}.")
logging.error(e)
logger.error(f"Error processing image {image_path_str}.")
logger.error(e)
return aspect_ratio_bucket_indices
finally:
if "image" in locals():
@@ -166,19 +166,20 @@ def _add_file_to_cache(self, file_path):
file_path
)
except Exception as e:
logging.error(f"Error processing image {file_path}.")
logging.error(e)
logger.error(f"Error processing image {file_path}.")
logger.error(e)

def load_aspect_ratio_bucket_indices(self, cache_file):
logging.info("Loading aspect ratio bucket indices from cache file.")
logger.info("Loading aspect ratio bucket indices from cache file.")
with cache_file.open("r") as f:
try:
aspect_ratio_bucket_indices = json.load(f)
except:
logging.warn(
logger.warn(
f"Could not load aspect ratio bucket indices from {cache_file}. Creating a new one!"
)
aspect_ratio_bucket_indices = {}
logger.info("Loading of aspect bucket indexes completed.")
return aspect_ratio_bucket_indices

@@ -193,7 +194,7 @@ def _bucket_worker(self, tqdm_queue, files, aspect_ratio_bucket_indices_queue):
) # Send processed data

def compute_aspect_ratio_bucket_indices(self, cache_file):
logging.warning("Computing aspect ratio bucket indices.")
logger.warning("Computing aspect ratio bucket indices.")
def rglob_follow_symlinks(path: Path, pattern: str):
for p in path.glob(pattern):
yield p
@@ -206,57 +207,60 @@ def rglob_follow_symlinks(path: Path, pattern: str):
yield from rglob_follow_symlinks(real_path, pattern)


with self.accelerator.main_process_first():
tqdm_queue = Queue() # Queue for updating progress bar
aspect_ratio_bucket_indices_queue = (
Queue()
) # Queue for gathering data from processes

all_image_files = list(
rglob_follow_symlinks(
Path(self.instance_data_root), "*.[jJpP][pPnN][gG]"
)
logger.info('Built queue object.')
tqdm_queue = Queue() # Queue for updating progress bar
aspect_ratio_bucket_indices_queue = (
Queue()
) # Queue for gathering data from processes
logger.info('Build file list..')
all_image_files = list(
rglob_follow_symlinks(
Path(self.instance_data_root), "*.[jJpP][pPnN][gG]"
)
)
logger.info('Split file list into shards.')
files_split = np.array_split(all_image_files, 8)
workers = []
logger.info('Process lists...')
for files in files_split:
p = Process(
target=self._bucket_worker,
args=(tqdm_queue, files, aspect_ratio_bucket_indices_queue),
)
p.start()
workers.append(p)

# Update progress bar and gather results in main process
aspect_ratio_bucket_indices = {}
logger.info('Update progress bar and gather results in main process.')
with tqdm(total=len(all_image_files)) as pbar:
while any(
p.is_alive() for p in workers
): # Continue until all processes are done
while (
not tqdm_queue.empty()
): # Update progress bar with each completed file
pbar.update(tqdm_queue.get())
while (
not aspect_ratio_bucket_indices_queue.empty()
): # Gather results
aspect_ratio_bucket_indices.update(
aspect_ratio_bucket_indices_queue.get()
)

files_split = np.array_split(all_image_files, 8)
workers = []
for files in files_split:
p = Process(
target=self._bucket_worker,
args=(tqdm_queue, files, aspect_ratio_bucket_indices_queue),
)
p.start()
workers.append(p)

# Update progress bar and gather results in main process
aspect_ratio_bucket_indices = {}
with tqdm(total=len(all_image_files)) as pbar:
while any(
p.is_alive() for p in workers
): # Continue until all processes are done
while (
not tqdm_queue.empty()
): # Update progress bar with each completed file
pbar.update(tqdm_queue.get())
while (
not aspect_ratio_bucket_indices_queue.empty()
): # Gather results
aspect_ratio_bucket_indices.update(
aspect_ratio_bucket_indices_queue.get()
)

# Gather any remaining results
while not aspect_ratio_bucket_indices_queue.empty(): # Gather results
aspect_ratio_bucket_indices.update(
aspect_ratio_bucket_indices_queue.get()
)

for p in workers:
p.join() # Wait for processes to finish

with cache_file.open("w") as f:
json.dump(aspect_ratio_bucket_indices, f)
# Gather any remaining results
while not aspect_ratio_bucket_indices_queue.empty(): # Gather results
aspect_ratio_bucket_indices.update(
aspect_ratio_bucket_indices_queue.get()
)
logger.info('Join processes and finish up.')
for p in workers:
p.join() # Wait for processes to finish

with cache_file.open("w") as f:
logger.info('Writing updated cache file to disk')
json.dump(aspect_ratio_bucket_indices, f)
logger.info('Completed aspect bucket update.')
return aspect_ratio_bucket_indices

def assign_to_buckets(self):
@@ -265,7 +269,9 @@
if cache_file.exists():
output = self.load_aspect_ratio_bucket_indices(cache_file)
if output is not None and len(output) > 0:
logging.info(f'We found {len(output)} buckets')
return output
logger.info('Bucket assignment completed.')
return self.compute_aspect_ratio_bucket_indices(cache_file)

def __len__(self):
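The reworked compute_aspect_ratio_bucket_indices keeps the same fan-out pattern: the image list is split into shards, each shard goes to a multiprocessing Process that reports per-file progress on one Queue and its partial bucket dict on another, and the main process drains both queues while driving tqdm, then joins the workers and writes the cache. A stripped-down sketch of that pattern, assuming a toy bucket_key helper in place of the real aspect-ratio computation and a 4-way split instead of 8:

from multiprocessing import Process, Queue

import numpy as np
from tqdm import tqdm


def bucket_key(path: str) -> str:
    # Stand-in for opening the image and rounding its aspect ratio.
    return "1.0" if len(path) % 2 == 0 else "1.5"


def bucket_worker(tqdm_queue: Queue, files, result_queue: Queue) -> None:
    buckets = {}
    for file in files:
        buckets.setdefault(bucket_key(str(file)), []).append(str(file))
        tqdm_queue.put(1)  # one progress tick per processed file
    result_queue.put(buckets)  # ship this shard's buckets back to the parent


def compute_buckets(all_files):
    tqdm_queue, result_queue = Queue(), Queue()
    workers = [
        Process(target=bucket_worker, args=(tqdm_queue, shard, result_queue))
        for shard in np.array_split(all_files, 4)
    ]
    for worker in workers:
        worker.start()

    buckets = {}

    def merge(partial):
        # Extend per-key lists so shards sharing an aspect ratio do not clobber each other.
        for key, paths in partial.items():
            buckets.setdefault(key, []).extend(paths)

    with tqdm(total=len(all_files)) as pbar:
        while any(worker.is_alive() for worker in workers):
            while not tqdm_queue.empty():
                pbar.update(tqdm_queue.get())
            while not result_queue.empty():
                merge(result_queue.get())
    while not result_queue.empty():  # gather anything that arrived after the workers exited
        merge(result_queue.get())
    for worker in workers:
        worker.join()
    return buckets


if __name__ == "__main__":
    print(compute_buckets([f"img_{i:04d}.png" for i in range(100)]))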
3 changes: 1 addition & 2 deletions helpers/sdxl_embeds.py
@@ -134,5 +134,4 @@ def precompute_embeddings_for_prompts(self, prompts):
Args:
prompts (list[str]): All of the prompts.
"""
with self.accelerator.main_process_first():
self.compute_embeddings_for_prompts(prompts, return_concat=False)
self.compute_embeddings_for_prompts(prompts, return_concat=False)
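precompute_embeddings_for_prompts no longer wraps its work in accelerator.main_process_first(); the guard moves to the call site in train_sdxl.py later in this commit, which keeps this helper usable without an Accelerator. A minimal sketch of that call-site pattern with Hugging Face Accelerate, where build_or_load_cache is a placeholder for the real embedding precomputation:

from accelerate import Accelerator

accelerator = Accelerator()


def build_or_load_cache() -> None:
    # Placeholder: the main process writes the cache to disk here; replicas
    # entering the block afterwards find it populated and only read it.
    print(f"process {accelerator.process_index}: cache ready")


# main_process_first() lets the main process run the block before anyone else;
# the other processes wait at the barrier and then execute the same block.
with accelerator.main_process_first():
    build_or_load_cache()

Run under accelerate launch with several processes, the main rank executes the block first and the replicas follow once the barrier releases.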
24 changes: 13 additions & 11 deletions train_sdxl.py
@@ -990,6 +990,7 @@ def collate_fn(examples):

# DataLoaders creation:
# Dataset and DataLoaders creation:
logger.info('Creating dataset iterator object')
train_dataset = DreamBoothDataset(
instance_data_root=args.instance_data_dir,
accelerator=accelerator,
@@ -1004,13 +1005,15 @@
debug_dataset_loader=args.debug_dataset_loader,
caption_strategy=args.caption_strategy
)
logger.info('Creating aspect bucket sampler')
custom_balanced_sampler = BalancedBucketSampler(
train_dataset.aspect_ratio_bucket_indices,
batch_size=args.train_batch_size,
seen_images_path=args.seen_state_path,
state_path=args.state_path,
debug_aspect_buckets=args.debug_aspect_buckets,
)
logger.info('Plugging sampler into dataloader')
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_size=args.train_batch_size,
@@ -1019,11 +1022,13 @@
collate_fn=lambda examples: collate_fn(examples),
num_workers=args.dataloader_num_workers,
)
logger.info('Initialise text embedding cache')
embed_cache = TextEmbeddingCache(
text_encoders=text_encoders, tokenizers=tokenizers, accelerator=accelerator
)
logger.info(f"Pre-computing text embeds / updating cache.")
embed_cache.precompute_embeddings_for_prompts(train_dataset.get_all_captions())
with accelerator.main_process_first():
logger.info(f"Pre-computing text embeds / updating cache.")
embed_cache.precompute_embeddings_for_prompts(train_dataset.get_all_captions())

if args.validation_prompt is not None:
(
@@ -1044,14 +1049,11 @@
text_encoder_2.to("cpu")
memory_after_unload = torch.cuda.memory_allocated() / 1024**3
memory_saved = memory_after_unload - memory_before_unload
gc.collect()
torch.cuda.empty_cache()
if accelerator.is_main_process:
logger.info(
f"After nuking text encoders from orbit, we freed {abs(round(memory_saved, 2))} GB of VRAM."
"This number might be massively understated, because of how CUDA memory management works."
"The real memories were the friends we trained a model on along the way."
)
logger.info(
f"After nuking text encoders from orbit, we freed {abs(round(memory_saved, 2))} GB of VRAM."
"This number might be massively understated, because of how CUDA memory management works."
"The real memories were the friends we trained a model on along the way."
)

# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
@@ -1068,7 +1070,7 @@
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
)

accelerator.wait_for_everyone()
# Prepare everything with our `accelerator`.
logger.info(f'Loading our accelerator...')
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
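The text-encoder unloading block now logs its VRAM figure from every rank and drops the explicit gc.collect() / torch.cuda.empty_cache() calls. A small sketch of the measurement pattern it keeps, with TinyEncoder standing in for the real text encoders:

import logging

import torch

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class TinyEncoder(torch.nn.Module):
    # Illustrative stand-in for a text encoder's weights.
    def __init__(self) -> None:
        super().__init__()
        self.proj = torch.nn.Linear(1024, 1024)


def offload_and_report(modules) -> None:
    if not torch.cuda.is_available():
        logger.info("No CUDA device available; nothing to measure.")
        return
    memory_before = torch.cuda.memory_allocated() / 1024**3
    for module in modules:
        module.to("cpu")  # drop the GPU copies of the weights
    memory_after = torch.cuda.memory_allocated() / 1024**3
    # memory_allocated() only counts tensors PyTorch currently holds, so the
    # number can understate what the caching allocator eventually releases.
    logger.info(f"Freed roughly {abs(round(memory_after - memory_before, 2))} GB of VRAM.")


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    offload_and_report([TinyEncoder().to(device) for _ in range(2)])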
