
Commit

Merge pull request #48 from bghira/main
v0.2 merge
bghira authored Aug 21, 2023
2 parents d527b8f + ceb34f1 commit 484cc39
Showing 13 changed files with 1,142 additions and 150 deletions.
helpers/aspect_bucket.py: 6 changes (5 additions, 1 deletion)
@@ -207,7 +207,7 @@ def __iter__(self):
self.handle_small_image(image_path, bucket)
continue
image = exif_transpose(image)
aspect_ratio = round(image.width / image.height, 3)
aspect_ratio = round(image.width / image.height, 2)
actual_bucket = str(aspect_ratio)
if actual_bucket != bucket:
self.handle_incorrect_bucket(image_path, bucket, actual_bucket)
@@ -217,6 +217,10 @@ def __iter__(self):
f"Yielding {image.width}x{image.height} sample from bucket: {bucket} with aspect {actual_bucket}"
)
to_yield.append(image_path)
try:
image.close()
except:
pass
if StateTracker.status_training():
self.seen_images[image_path] = actual_bucket
if self.debug_aspect_buckets:
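The two changes in this file are small but related: bucket keys are now derived from the aspect ratio rounded to 2 decimal places instead of 3, which merges near-identical ratios into one bucket, and the PIL image is explicitly closed once its dimensions have been read, so long sampling runs do not accumulate open file handles. A minimal sketch of the bucket-key idea, using a hypothetical compute_bucket_key helper rather than the project's sampler class:

from PIL import Image
from PIL.ImageOps import exif_transpose

def compute_bucket_key(path: str, precision: int = 2) -> str:
    # Hypothetical helper: honour EXIF rotation, then derive the bucket key
    # from the rounded width/height ratio.
    with Image.open(path) as image:
        image = exif_transpose(image)
        aspect_ratio = round(image.width / image.height, precision)
    # Leaving the `with` block closes the file handle, which is what the
    # explicit image.close() in the hunk above achieves inside the sampler loop.
    return str(aspect_ratio)

# A 1216x832 image yields "1.46" at precision=2 but "1.462" at precision=3,
# so the coarser rounding produces fewer, better-populated buckets.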
helpers/dreambooth_dataset.py: 124 changes (45 additions, 79 deletions)
@@ -12,7 +12,8 @@
from ctypes import c_int

logger = logging.getLogger("DatasetLoader")
logger.setLevel(logging.INFO)
target_level = os.environ.get('SIMPLETUNER_LOG_LEVEL', 'WARNING')
logger.setLevel(target_level)
from concurrent.futures import ThreadPoolExecutor
import threading

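The logger change makes verbosity operator-configurable: SIMPLETUNER_LOG_LEVEL is read from the environment and falls back to WARNING. A short sketch of the same pattern, with a guard against misspelled level names (the validation is an addition of this example, not something the diff does):

import logging
import os

logger = logging.getLogger("DatasetLoader")
target_level = os.environ.get("SIMPLETUNER_LOG_LEVEL", "WARNING").upper()
if target_level not in ("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"):
    target_level = "WARNING"   # fall back rather than raising on a typo
logger.setLevel(target_level)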
@@ -63,7 +64,6 @@ def __init__(
self.instance_images_path = list(Path(instance_data_root).iterdir())
self.num_instance_images = len(self.instance_images_path)
self.instance_prompt = instance_prompt
self._length = self.num_instance_images
self.aspect_ratio_buckets = aspect_ratio_buckets
self.use_original_images = use_original_images
self.accelerator = accelerator
@@ -73,9 +73,7 @@ def __init__(
self.caption_loop_count = 0
self.caption_strategy = caption_strategy
self.use_precomputed_token_ids = use_precomputed_token_ids
if len(self.aspect_ratio_bucket_indices) > 0:
pass
# self.update_cache()
self._length = self.num_instance_images
if not use_original_images:
logger.debug(f"Building transformations.")
self.image_transforms = self._get_image_transforms()
@@ -100,7 +98,7 @@ def _process_image(self, image_path_str, aspect_ratio_bucket_indices):
# Apply EXIF transforms
image = exif_transpose(image)
aspect_ratio = round(
image.width / image.height, 3
image.width / image.height, 2
) # Round to avoid excessive unique buckets
# Create a new bucket if it doesn't exist
if str(aspect_ratio) not in aspect_ratio_bucket_indices:
@@ -115,47 +113,14 @@ def _process_image(self, image_path_str, aspect_ratio_bucket_indices):
image.close()
return aspect_ratio_bucket_indices

def update_cache(self, base_dir=None, max_workers=64):
"""Update the aspect_ratio_bucket_indices based on the current state of the file system."""
if base_dir is None:
base_dir = self.instance_data_root
else:
base_dir = Path(base_dir)
if not base_dir.exists():
raise ValueError(f"Directory {base_dir} does not exist.")

logger.info(f"Looking for new images, to update aspect bucket cache.")
new_file_paths = [
str(path)
for path in base_dir.iterdir()
if path not in self.instance_images_path
]
logger.info(f"Discovered {len(new_file_paths)} new files to inspect for cache.")

for new_file in tqdm(new_file_paths, desc="Adding to cache"):
self.aspect_ratio_bucket_indices = self._process_image(
new_file, self.aspect_ratio_bucket_indices
)

# Update the instance_images_path to include the new images
self.instance_images_path += new_file_paths

# Update the total number of instance images
self.num_instance_images = len(self.instance_images_path)

# Save updated aspect_ratio_bucket_indices to the cache file
cache_file = self.instance_data_root / "aspect_ratio_bucket_indices.json"
with cache_file.open("w") as f:
json.dump(self.aspect_ratio_bucket_indices, f)

def _add_file_to_cache(self, file_path):
"""Add a single file to the cache (thread-safe)."""
try:
with Image.open(file_path) as image:
# Apply EXIF transforms
image = exif_transpose(image)
aspect_ratio = round(
image.width / image.height, 3
image.width / image.height, 2
) # Round to avoid excessive unique buckets

with threading.Lock():
@@ -182,19 +147,27 @@ def load_aspect_ratio_bucket_indices(self, cache_file):
logger.info("Loading of aspect bucket indexes completed.")
return aspect_ratio_bucket_indices

def _bucket_worker(self, tqdm_queue, files, aspect_ratio_bucket_indices_queue):
def _bucket_worker(self, tqdm_queue, files, aspect_ratio_bucket_indices_queue, existing_files_set):
for file in files:
# Process image as before, but now send results to queue instead of updating a manager.dict
if str(file) in existing_files_set:
tqdm_queue.put(1) # Update progress bar but skip further processing
continue
# Process image and send results to queue as before
aspect_ratio_bucket_indices = self._process_image(
str(file), self.aspect_ratio_bucket_indices
) # assuming _process_image now returns a value
)
tqdm_queue.put(1) # Update progress bar
aspect_ratio_bucket_indices_queue.put(
aspect_ratio_bucket_indices
) # Send processed data
aspect_ratio_bucket_indices_queue.put(aspect_ratio_bucket_indices)

def compute_aspect_ratio_bucket_indices(self, cache_file):
logger.warning("Computing aspect ratio bucket indices.")

# Step 1: Initialization Check
if hasattr(self, 'aspect_ratio_bucket_indices') and self.aspect_ratio_bucket_indices:
aspect_ratio_bucket_indices = self.aspect_ratio_bucket_indices
else:
aspect_ratio_bucket_indices = {}

def rglob_follow_symlinks(path: Path, pattern: str):
for p in path.glob(pattern):
yield p
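The new existing_files_set argument to _bucket_worker is what makes re-indexing incremental: the set is the union of every path already recorded in the bucket map, and workers skip those paths while still ticking the shared progress bar. A small, self-contained illustration of that lookup (paths and variable values here are made up for the example):

# Bucket map as it would be loaded from aspect_ratio_bucket_indices.json
aspect_ratio_bucket_indices = {
    "1.0": ["/data/a.png", "/data/b.png"],
    "1.5": ["/data/c.jpg"],
}

# Union of all already-indexed paths; membership checks are O(1).
existing_files_set = set().union(*aspect_ratio_bucket_indices.values())

for file in ["/data/a.png", "/data/new.jpg"]:
    if file in existing_files_set:
        continue  # already bucketed on a previous run, no need to reopen the image
    print(f"would index {file}")  # only /data/new.jpg reaches this line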
@@ -206,72 +179,65 @@ def rglob_follow_symlinks(path: Path, pattern: str):
if real_path.is_dir():
yield from rglob_follow_symlinks(real_path, pattern)


logger.info('Built queue object.')
tqdm_queue = Queue() # Queue for updating progress bar
aspect_ratio_bucket_indices_queue = (
Queue()
) # Queue for gathering data from processes
tqdm_queue = Queue()
aspect_ratio_bucket_indices_queue = Queue()
logger.info('Build file list..')
all_image_files = list(
rglob_follow_symlinks(
Path(self.instance_data_root), "*.[jJpP][pPnN][gG]"
)
)
self._length = len(all_image_files)
logger.info('Split file list into shards.')
files_split = np.array_split(all_image_files, 8)
existing_files_set = set().union(*self.aspect_ratio_bucket_indices.values())
workers = []
logger.info('Process lists...')
for files in files_split:
p = Process(
target=self._bucket_worker,
args=(tqdm_queue, files, aspect_ratio_bucket_indices_queue),
args=(tqdm_queue, files, aspect_ratio_bucket_indices_queue, existing_files_set),
)
p.start()
workers.append(p)

# Update progress bar and gather results in main process
aspect_ratio_bucket_indices = {}
logger.info('Update progress bar and gather results in main process.')
with tqdm(total=len(all_image_files)) as pbar:
while any(
p.is_alive() for p in workers
): # Continue until all processes are done
while (
not tqdm_queue.empty()
): # Update progress bar with each completed file
while any(p.is_alive() for p in workers):
while not tqdm_queue.empty():
pbar.update(tqdm_queue.get())
while (
not aspect_ratio_bucket_indices_queue.empty()
): # Gather results
aspect_ratio_bucket_indices.update(
aspect_ratio_bucket_indices_queue.get()
)
while not aspect_ratio_bucket_indices_queue.empty():
aspect_ratio_bucket_indices.update(aspect_ratio_bucket_indices_queue.get())

# Gather any remaining results
while not aspect_ratio_bucket_indices_queue.empty(): # Gather results
aspect_ratio_bucket_indices.update(
aspect_ratio_bucket_indices_queue.get()
)
while not aspect_ratio_bucket_indices_queue.empty():
aspect_ratio_bucket_indices.update(aspect_ratio_bucket_indices_queue.get())
logger.info('Join processes and finish up.')
for p in workers:
p.join() # Wait for processes to finish
p.join()

# Step 3: Updating the Cache
new_file_paths = [str(file) for file in all_image_files if str(file) not in self.instance_images_path]

# Update the instance_images_path to include the new images
self.instance_images_path += new_file_paths

# Update the total number of instance images
self.num_instance_images = len(self.instance_images_path)

# Save updated aspect_ratio_bucket_indices to the cache file
with cache_file.open("w") as f:
logger.info('Writing updated cache file to disk')
json.dump(aspect_ratio_bucket_indices, f)

logger.info('Completed aspect bucket update.')

return aspect_ratio_bucket_indices


def assign_to_buckets(self):
cache_file = self.instance_data_root / "aspect_ratio_bucket_indices.json"
output = None
if cache_file.exists():
output = self.load_aspect_ratio_bucket_indices(cache_file)
if output is not None and len(output) > 0:
logging.info(f'We found {len(output)} buckets')
return output
logger.info('Bucket assignment completed.')
return self.compute_aspect_ratio_bucket_indices(cache_file)

def __len__(self):
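The rewritten compute_aspect_ratio_bucket_indices above shards the discovered files across eight worker processes and drains two queues from the parent: one that only ticks a tqdm bar, and one that carries each worker's partial bucket dictionary (merged there with a plain dict.update). Below is a stripped-down, self-contained sketch of that shape; the per-image work is stubbed out, and the partial results are merged list-wise so colliding keys from different workers are combined rather than overwritten:

from multiprocessing import Process, Queue
import time

import numpy as np
from tqdm import tqdm


def worker(files, tqdm_queue: Queue, result_queue: Queue) -> None:
    buckets = {}
    for f in files:
        # Stub: the real worker opens each image and computes its rounded aspect ratio.
        buckets.setdefault("1.0", []).append(str(f))
        tqdm_queue.put(1)          # one progress tick per processed file
    result_queue.put(buckets)      # hand the partial index back to the parent


if __name__ == "__main__":
    all_files = [f"img_{i}.png" for i in range(100)]   # stand-in for the rglob scan
    tqdm_queue, result_queue = Queue(), Queue()
    workers = [
        Process(target=worker, args=(shard, tqdm_queue, result_queue))
        for shard in np.array_split(all_files, 8)
    ]
    for p in workers:
        p.start()

    merged = {}
    with tqdm(total=len(all_files)) as pbar:
        while any(p.is_alive() for p in workers):
            while not tqdm_queue.empty():
                pbar.update(tqdm_queue.get())
            while not result_queue.empty():
                for key, paths in result_queue.get().items():
                    merged.setdefault(key, []).extend(paths)
            time.sleep(0.05)       # avoid a hot spin while the workers run
        while not tqdm_queue.empty():
            pbar.update(tqdm_queue.get())
    # Drain whatever was queued between the last poll and worker exit, then join.
    while not result_queue.empty():
        for key, paths in result_queue.get().items():
            merged.setdefault(key, []).extend(paths)
    for p in workers:
        p.join()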
@@ -382,7 +348,7 @@ def __getitem__(self, image_path):
def _resize_for_condition_image(self, input_image: Image, resolution: int):
input_image = input_image.convert("RGB")
W, H = input_image.size
aspect_ratio = round(W / H, 3)
aspect_ratio = round(W / H, 2)
msg = f"Inspecting image of aspect {aspect_ratio} and size {W}x{H} to "
if W < H:
W = resolution
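The _resize_for_condition_image hunk above is cut off after the first branch, but the visible logic suggests the shorter edge is pinned to the target resolution and the other edge is scaled to keep the (now 2-decimal) aspect ratio. A hedged, self-contained sketch of that kind of shorter-edge resize, not the project's exact method:

from PIL import Image

def resize_shorter_edge(image: Image.Image, resolution: int) -> Image.Image:
    """Illustrative resize: pin the shorter edge to `resolution`, keep the ratio."""
    image = image.convert("RGB")
    W, H = image.size
    aspect_ratio = round(W / H, 2)
    if W < H:
        # Portrait: width becomes the target, height grows to preserve the ratio.
        W, H = resolution, int(round(resolution / aspect_ratio))
    else:
        # Landscape or square: height becomes the target.
        W, H = int(round(resolution * aspect_ratio)), resolution
    return image.resize((W, H))  # default resample; the real method may pick a filter

For example, a 1024x768 input with resolution=512 comes out at 681x512 here; with exact, unrounded math it would be 683x512, the small shift being the cost of the 2-decimal rounding.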
(The diffs for the remaining 11 changed files are not rendered in this view.)
