
Commit

Merge pull request #888 from bghira/main
multigpu VAE cache rebuild fixes; random crop auto-rebuild; mobius flux; json backend now renamed to discovery; wandb guidance tracking
bghira authored Aug 27, 2024
2 parents d4b59d6 + d8bfea2 commit 202c437
Showing 16 changed files with 151 additions and 67 deletions.
4 changes: 2 additions & 2 deletions documentation/CONTROLNET.md
@@ -57,7 +57,7 @@ The dataloader configuration remains pretty close to a typical text-to-image dat
"instance_data_dir": "/Volumes/ml/datasets/canny-edge/animals/antelope-data",
"caption_strategy": "instanceprompt",
"instance_prompt": "an antelope",
"metadata_backend": "json",
"metadata_backend": "discovery",
"minimum_image_size": 1.0,
"maximum_image_size": 1.0,
"target_downsample_size": 1.0,
@@ -76,7 +76,7 @@ The dataloader configuration remains pretty close to a typical text-to-image dat
"instance_data_dir": "/Volumes/ml/datasets/canny-edge/animals/antelope-conditioning",
"caption_strategy": "instanceprompt",
"instance_prompt": "an antelope",
"metadata_backend": "json",
"metadata_backend": "discovery",
"crop": true,
"crop_aspect": "square",
"crop_style": "center",
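The same `json` → `discovery` rename repeats across the quickstart configs below, and the factory change further down keeps `json` as an accepted alias, so existing configs still load. For anyone updating by hand anyway, a rough migration sketch — the config path is an assumption taken from the quickstarts, adjust to your layout:

```python
import json

# Hypothetical path; the quickstarts below use config/multidatabackend.json.
config_path = "config/multidatabackend.json"

with open(config_path) as f:
    backends = json.load(f)

for backend in backends:
    # Rewrite the old backend name to its new canonical spelling.
    if backend.get("metadata_backend") == "json":
        backend["metadata_backend"] = "discovery"

with open(config_path, "w") as f:
    json.dump(backends, f, indent=2)
```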
4 changes: 2 additions & 2 deletions documentation/quickstart/FLUX.md
@@ -263,7 +263,7 @@ create a `DATALOADER_CONFIG` (config/multidatabackend.json) with this:
"disabled": false,
"skip_file_discovery": "",
"caption_strategy": "filename",
"metadata_backend": "json"
"metadata_backend": "discovery"
},
{
"id": "dreambooth-subject",
@@ -278,7 +278,7 @@ create a `DATALOADER_CONFIG` (config/multidatabackend.json) with this:
"instance_data_dir": "datasets/dreambooth-subject",
"caption_strategy": "instanceprompt",
"instance_prompt": "the name of your subject goes here",
"metadata_backend": "json"
"metadata_backend": "discovery"
},
{
"id": "text-embeds",
2 changes: 1 addition & 1 deletion documentation/quickstart/KOLORS.md
@@ -177,7 +177,7 @@ In your `OUTPUT_DIR` directory, create a multidatabackend.json:
"disabled": false,
"skip_file_discovery": "",
"caption_strategy": "filename",
"metadata_backend": "json"
"metadata_backend": "discovery"
},
{
"id": "text-embeds",
2 changes: 1 addition & 1 deletion documentation/quickstart/SD3.md
@@ -172,7 +172,7 @@ In your `OUTPUT_DIR` directory, create a multidatabackend.json:
"disabled": false,
"skip_file_discovery": "",
"caption_strategy": "filename",
"metadata_backend": "json"
"metadata_backend": "discovery"
},
{
"id": "text-embeds",
2 changes: 1 addition & 1 deletion documentation/quickstart/SIGMA.md
@@ -145,7 +145,7 @@ In your `OUTPUT_DIR` directory, create a multidatabackend.json:
"disabled": false,
"skip_file_discovery": "",
"caption_strategy": "filename",
"metadata_backend": "json"
"metadata_backend": "discovery"
},
{
"id": "text-embeds",
17 changes: 16 additions & 1 deletion helpers/arguments.py
@@ -143,12 +143,17 @@ def parse_args(input_args=None):
parser.add_argument(
"--flux_guidance_mode",
type=str,
choices=["constant", "random-range"],
choices=["constant", "random-range", "mobius"],
default="constant",
help=(
"Flux has a 'guidance' value used during training time that reflects the CFG range of your training samples."
" The default mode 'constant' will use a single value for every sample."
" The mode 'random-range' will randomly select a value from the range of the CFG for each sample."
" The mode 'mobius' will use a value that is a function of the remaining steps in the epoch, constructively"
" deconstructing the constructed deconstructions to then Mobius them back into the constructed reconstructions,"
" possibly resulting in the exploration of what is known as the Mobius space, a new continuous"
" realm of possibility brought about by destroying the model so that you can make it whole once more."
" Or so according to DataVoid, anyway. This is just a Flux-specific implementation of Mobius."
" Set the range using --flux_guidance_min and --flux_guidance_max."
),
)
@@ -2071,6 +2076,16 @@ def parse_args(input_args=None):
"Flux Schnell requires fewer inference steps. Consider reducing --validation_num_inference_steps to 4."
)

if args.flux_guidance_mode == "mobius":
logger.warning(
"Mobius training is only for the most elite. Pardon my English, but this is not for those who don't like to destroy something beautiful every now and then. If you feel perhaps this is not for you, please consider using a different guidance mode."
)
if args.flux_guidance_min < 1.0:
logger.warning(
"Flux minimum guidance value for Mobius training is 1.0. Updating value.."
)
args.flux_guidance_min = 1.0

if args.use_ema and args.ema_cpu_only:
args.ema_device = "cpu"

17 changes: 9 additions & 8 deletions helpers/data_backend/factory.py
@@ -154,12 +154,13 @@ def init_backend_config(backend: dict, args: dict, accelerator) -> dict:
output["config"]["hash_filenames"] = backend["hash_filenames"]

# check if caption_strategy=parquet with metadata_backend=json
if (
output["config"]["caption_strategy"] == "parquet"
and backend.get("metadata_backend", "json") == "json"
current_metadata_backend_type = backend.get("metadata_backend", "discovery")
if output["config"]["caption_strategy"] == "parquet" and (
current_metadata_backend_type == "json"
or current_metadata_backend_type == "discovery"
):
raise ValueError(
f"(id={backend['id']}) Cannot use caption_strategy=parquet with metadata_backend=json. Instead, it is recommended to use the textfile strategy and extract your captions into txt files."
f"(id={backend['id']}) Cannot use caption_strategy=parquet with metadata_backend={current_metadata_backend_type}. Instead, it is recommended to use the textfile strategy and extract your captions into txt files."
)

maximum_image_size = backend.get("maximum_image_size", args.maximum_image_size)
@@ -674,11 +675,11 @@ def configure_multi_databackend(
image_embed_data_backend = image_embed_backends[image_embed_backend_id]
info_log(f"(id={init_backend['id']}) Loading bucket manager.")
metadata_backend_args = {}
metadata_backend = backend.get("metadata_backend", "json")
if metadata_backend == "json":
from helpers.metadata.backends.json import JsonMetadataBackend
metadata_backend = backend.get("metadata_backend", "discovery")
if metadata_backend == "json" or metadata_backend == "discovery":
from helpers.metadata.backends.discovery import DiscoveryMetadataBackend

BucketManager_cls = JsonMetadataBackend
BucketManager_cls = DiscoveryMetadataBackend
elif metadata_backend == "parquet":
from helpers.metadata.backends.parquet import ParquetMetadataBackend

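Two behavioral notes fall out of the factory changes above, sketched here with illustrative field values (the `textfile` value comes from the error message's recommendation and is an assumption about the exact strategy name):

```python
# "json" still maps to DiscoveryMetadataBackend, so legacy configs load unchanged.
legacy_backend = {"id": "legacy", "caption_strategy": "filename", "metadata_backend": "json"}

# This combination now raises the ValueError shown above in init_backend_config.
bad_backend = {"id": "bad", "caption_strategy": "parquet", "metadata_backend": "discovery"}

# The error message suggests extracting captions into .txt files instead.
good_backend = {"id": "good", "caption_strategy": "textfile", "metadata_backend": "discovery"}
```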
helpers/metadata/backends/json.py → helpers/metadata/backends/discovery.py (renamed)
@@ -12,15 +12,15 @@
from helpers.image_manipulation.brightness import calculate_luminance
from helpers.training import image_file_extensions

logger = logging.getLogger("JsonMetadataBackend")
logger = logging.getLogger("DiscoveryMetadataBackend")
if should_log():
target_level = os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO")
else:
target_level = "ERROR"
logger.setLevel(target_level)


class JsonMetadataBackend(MetadataBackend):
class DiscoveryMetadataBackend(MetadataBackend):
def __init__(
self,
id: str,
@@ -275,7 +275,7 @@ def repeat_len(bucket):
return len(bucket) * (self.repeats + 1)

return sum(
repeat_len(bucket) // self.batch_size
(repeat_len(bucket) + (self.batch_size - 1)) // self.batch_size
for bucket in self.aspect_ratio_bucket_indices.values()
if repeat_len(bucket) >= self.batch_size
)
2 changes: 1 addition & 1 deletion helpers/metadata/backends/parquet.py
@@ -590,7 +590,7 @@ def repeat_len(bucket):
return len(bucket) * (self.repeats + 1)

return sum(
repeat_len(bucket) // self.batch_size
(repeat_len(bucket) + (self.batch_size - 1)) // self.batch_size
for bucket in self.aspect_ratio_bucket_indices.values()
if repeat_len(bucket) >= self.batch_size
)
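Both `__len__` fixes replace floor division with the ceiling-division idiom so that a trailing partial batch is still counted. A minimal sketch of the difference:

```python
# A 10-sample bucket (e.g. 5 images with repeats=1) and a batch size of 4:
batch_size = 4
bucket_len = 10

print(bucket_len // batch_size)                       # 2: old behavior, drops the final short batch
print((bucket_len + (batch_size - 1)) // batch_size)  # 3: new behavior, counts the partial batch
```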
30 changes: 30 additions & 0 deletions helpers/models/flux/__init__.py
@@ -1,5 +1,35 @@
import torch
import random
from helpers.models.flux.pipeline import FluxPipeline
from helpers.training import steps_remaining_in_epoch


def get_mobius_guidance(args, global_step, steps_per_epoch, batch_size, device):
"""
state of the art
"""
steps_remaining = steps_remaining_in_epoch(global_step, steps_per_epoch)

# Start with a linear mapping from remaining steps to a scale between 0 and 1
scale_factor = steps_remaining / steps_per_epoch

# we want the last 10% of the epoch to have a guidance of 1.0
threshold_step_count = max(1, int(steps_per_epoch * 0.1))

if (
steps_remaining <= threshold_step_count
): # Last few steps in the epoch, set guidance to 1.0
guidance_values = [1.0 for _ in range(batch_size)]
else:
# Sample between flux_guidance_min and flux_guidance_max with bias towards 1.0
guidance_values = [
random.uniform(args.flux_guidance_min, args.flux_guidance_max)
* scale_factor
+ (1.0 - scale_factor)
for _ in range(batch_size)
]

return guidance_values


def update_flux_schedule_to_fast(args, noise_scheduler_to_copy):
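A rough illustration of the schedule `get_mobius_guidance` produces, assuming the function is importable as defined above and an `argparse.Namespace` stands in for the parsed args: early in the epoch, values are sampled across the full guidance range; by the last 10% of steps they collapse to 1.0.

```python
from argparse import Namespace

from helpers.models.flux import get_mobius_guidance

# Illustrative args; only flux_guidance_min/max are read by the function.
args = Namespace(flux_guidance_min=1.0, flux_guidance_max=4.0)
steps_per_epoch = 100

for global_step in (0, 50, 95):
    values = get_mobius_guidance(args, global_step, steps_per_epoch, batch_size=4, device=None)
    # step 0  -> scale_factor 1.0, samples span the full [1.0, 4.0] range
    # step 50 -> scale_factor 0.5, samples fall in [1.0, 2.5]
    # step 95 -> within the last 10% of the epoch, so every value is 1.0
    print(global_step, [round(v, 2) for v in values])
```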
24 changes: 14 additions & 10 deletions helpers/multiaspect/sampler.py
@@ -503,12 +503,12 @@ def __iter__(self):
to_yield = self._yield_n_from_exhausted_bucket(
need_image_count, self.buckets[self.current_bucket]
)
# # add the available images
# to_yield.extend(
# self._validate_and_yield_images_from_samples(
# available_images, self.buckets[self.current_bucket]
# )
# )
# add the available images
to_yield.extend(
self._validate_and_yield_images_from_samples(
available_images, self.buckets[self.current_bucket]
)
)
else:
all_buckets_exhausted = False # Found a non-exhausted bucket
samples = random.sample(
@@ -576,17 +576,21 @@ def __len__(self):
backend_config = StateTracker.get_data_backend_config(self.id)
repeats = backend_config.get("repeats", 0)
# We need at least a multiplier of 1. Repeats is the number of extra sample steps.
multiplier = 1
if repeats > 0:
multiplier = repeats + 1
return (
multiplier = repeats + 1 if repeats > 0 else 1

total_samples = (
sum(
len(indices)
for indices in self.metadata_backend.aspect_ratio_bucket_indices.values()
)
* multiplier
)

# Calculate the total number of full batches
total_batches = (total_samples + (self.batch_size - 1)) // self.batch_size

return total_batches

@staticmethod
def convert_to_human_readable(
aspect_ratio_float: float, bucket: iter, resolution: int = 1024
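For reference, a small worked example of the rewritten `__len__` arithmetic, with hypothetical bucket contents: `repeats=1` doubles each sample, and the ceiling division keeps the final partial batch.

```python
buckets = {"1.0": ["img1", "img2", "img3"], "1.5": ["img4", "img5"]}
repeats, batch_size = 1, 4

multiplier = repeats + 1 if repeats > 0 else 1                       # 2
total_samples = sum(len(v) for v in buckets.values()) * multiplier   # (3 + 2) * 2 = 10
total_batches = (total_samples + (batch_size - 1)) // batch_size     # ceil(10 / 4) = 3
print(total_batches)
```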
15 changes: 15 additions & 0 deletions helpers/training/__init__.py
@@ -112,3 +112,18 @@
},
},
}


def steps_remaining_in_epoch(current_step: int, steps_per_epoch: int) -> int:
"""
Calculate the number of steps remaining in the current epoch.
Args:
current_step (int): The current step within the epoch.
steps_per_epoch (int): Total number of steps in the epoch.
Returns:
int: Number of steps remaining in the current epoch.
"""
remaining_steps = steps_per_epoch - (current_step % steps_per_epoch)
return remaining_steps
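A quick sanity check of the new helper; the results follow directly from the modulo arithmetic:

```python
from helpers.training import steps_remaining_in_epoch

print(steps_remaining_in_epoch(0, 100))    # 100: a full epoch ahead
print(steps_remaining_in_epoch(50, 100))   # 50
print(steps_remaining_in_epoch(250, 100))  # 50: 250 % 100 == 50 steps into the third epoch
```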
4 changes: 2 additions & 2 deletions tests/test_dataset.py
@@ -3,15 +3,15 @@
from PIL import Image
from pathlib import Path
from helpers.multiaspect.dataset import MultiAspectDataset
from helpers.metadata.backends.json import JsonMetadataBackend
from helpers.metadata.backends.discovery import DiscoveryMetadataBackend
from helpers.data_backend.base import BaseDataBackend


class TestMultiAspectDataset(unittest.TestCase):
def setUp(self):
self.instance_data_dir = "/some/fake/path"
self.accelerator = Mock()
self.metadata_backend = Mock(spec=JsonMetadataBackend)
self.metadata_backend = Mock(spec=DiscoveryMetadataBackend)
self.metadata_backend.__len__ = Mock(return_value=10)
self.image_metadata = {
"image_path": "fake_image_path",
8 changes: 4 additions & 4 deletions tests/test_metadata_backend.py
@@ -1,7 +1,7 @@
import unittest, json
from PIL import Image
from unittest.mock import Mock, patch, MagicMock
from helpers.metadata.backends.json import JsonMetadataBackend
from helpers.metadata.backends.discovery import DiscoveryMetadataBackend
from helpers.training.state_tracker import StateTracker
from tests.helpers.data import MockDataBackend

@@ -30,8 +30,8 @@ def setUp(self):
"helpers.training.state_tracker.StateTracker._save_to_disk",
return_value=True,
), patch("pathlib.Path.exists", return_value=True):
with self.assertLogs("JsonMetadataBackend", level="WARNING"):
self.metadata_backend = JsonMetadataBackend(
with self.assertLogs("DiscoveryMetadataBackend", level="WARNING"):
self.metadata_backend = DiscoveryMetadataBackend(
id="foo",
instance_data_dir=self.instance_data_dir,
cache_file=self.cache_file,
Expand Down Expand Up @@ -90,7 +90,7 @@ def test_load_cache_valid(self):
def test_load_cache_invalid(self):
invalid_cache_data = "this is not valid json"
with patch.object(self.data_backend, "read", return_value=invalid_cache_data):
with self.assertLogs("JsonMetadataBackend", level="WARNING"):
with self.assertLogs("DiscoveryMetadataBackend", level="WARNING"):
self.metadata_backend.reload_cache()

def test_save_cache(self):
6 changes: 3 additions & 3 deletions tests/test_sampler.py
@@ -4,7 +4,7 @@
from unittest import skip
from unittest.mock import Mock, MagicMock, patch
from helpers.multiaspect.sampler import MultiAspectSampler
from helpers.metadata.backends.json import JsonMetadataBackend
from helpers.metadata.backends.discovery import DiscoveryMetadataBackend
from helpers.multiaspect.state import BucketStateManager
from tests.helpers.data import MockDataBackend
from accelerate import PartialState
@@ -16,10 +16,10 @@ def setUp(self):
self.process_state = PartialState()
self.accelerator = MagicMock()
self.accelerator.log = MagicMock()
self.metadata_backend = Mock(spec=JsonMetadataBackend)
self.metadata_backend = Mock(spec=DiscoveryMetadataBackend)
self.metadata_backend.id = "foo"
self.metadata_backend.aspect_ratio_bucket_indices = {
"1.0": ["image1", "image2"]
"1.0": ["image1", "image2", "image3", "image4"],
}
self.metadata_backend.seen_images = {}
self.data_backend = MockDataBackend()