Merge pull request #1146 from bghira/feature/clip-evaluation
add clip score tracking
bghira authored Nov 13, 2024
2 parents dfc3f8f + 6a8b707 commit c9f8025
Showing 12 changed files with 203 additions and 10 deletions.
17 changes: 17 additions & 0 deletions OPTIONS.md
@@ -203,6 +203,11 @@ A lot of settings are instead set through the [dataloader config](/documentation
- **What**: Output image resolution, measured in pixels, or, formatted as: `widthxheight`, as in `1024x1024`. Multiple resolutions can be defined, separated by commas.
- **Why**: All images generated during validation will be this resolution. Useful if the model is being trained with a different resolution.

### `--validation_model_evaluator`

- **What**: Enable CLIP evaluation of generated images during validations.
- **Why**: CLIP scores measure how closely the features extracted from a generated image align with the features extracted from the validation prompt that produced it. This gives a rough indication of whether prompt adherence is improving, though a large number of validation prompts is required for the score to be meaningful.
- **Options**: "none" or "clip"

### `--crop`

@@ -467,6 +472,8 @@ usage: train.py [-h] [--snr_gamma SNR_GAMMA] [--use_soft_min_snr]
[--model_card_note MODEL_CARD_NOTE]
[--model_card_safe_for_work] [--logging_dir LOGGING_DIR]
[--benchmark_base_model] [--disable_benchmark]
[--validation_model_evaluator {clip,none}]
[--pretrained_validation_model_name_or_path PRETRAINED_VALIDATION_MODEL_NAME_OR_PATH]
[--validation_on_startup] [--validation_seed_source {gpu,cpu}]
[--validation_torch_compile]
[--validation_torch_compile_mode {max-autotune,reduce-overhead,default}]
@@ -1236,6 +1243,16 @@ options:
--disable_benchmark By default, the model will be benchmarked on the first
batch of the first epoch. This can be disabled with
this option.
--validation_model_evaluator {clip,none}
Validations must be enabled for model evaluation to
function. The default is to use no evaluator, and
'clip' will use a CLIP model to evaluate the resulting
model's performance during validations.
--pretrained_validation_model_name_or_path PRETRAINED_VALIDATION_MODEL_NAME_OR_PATH
Optionally provide a custom model to use for ViT
evaluations. The default is currently clip-vit-large-
patch14-336, which uses a smaller patch size (greater
accuracy) and an input resolution of 336x336.
--validation_on_startup
When training begins, the starting model will have
validation prompts run through it, for later
4 changes: 4 additions & 0 deletions documentation/DREAMBOOTH.md
@@ -222,6 +222,10 @@ Alternatively, one might use the real name of their subject, or a 'similar enoug

After a number of training experiments, it seems as though a 'similar enough' celebrity is the best choice, especially if prompting the model for the person's real name ends up looking dissimilar.

# CLIP score tracking

If you wish to enable evaluations to score the model's performance, see [this document](/documentation/evaluation/CLIP_SCORES.md) for information on configuring and interpreting CLIP scores.

# Refiner tuning

If you're a fan of the SDXL refiner, you may find that it causes your generations to "ruin" the results of your Dreamboothed model.
4 changes: 4 additions & 0 deletions documentation/MIXTURE_OF_EXPERTS.md
@@ -105,6 +105,10 @@ If you'd like a demonstration dataset, [pseudo-camera-10k](https://huggingface.c

Stage two refiner training will automatically select images from each of your training sets, and use those as inputs for partial denoising at validation time.

## CLIP score tracking

If you wish to enable evaluations to score the model's performance, see [this document](/documentation/evaluation/CLIP_SCORES.md) for information on configuring and interpreting CLIP scores.

## Putting it all together at inference time

If you'd like to plug both of the models together to experiment with in a simple script, this will get you started:
27 changes: 27 additions & 0 deletions documentation/evaluation/CLIP_SCORES.md
@@ -0,0 +1,27 @@
# CLIP score tracking

CLIP scores are loosely related to a model's ability to follow prompts; they are not related to image quality or fidelity at all.

The `clip/mean` score of your model indicates how closely the features extracted from the image align with the features extracted from the prompt. It is currently a popular metric for gauging general prompt adherence, though it is typically evaluated across a very large number (~5,000) of test prompts (e.g. Parti Prompts).
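
For reference, SimpleTuner computes this metric with the `clip_score` functional from `torchmetrics` (see `helpers/training/evaluation.py` added in this commit). The following is only a rough standalone sketch of that computation; the image path and prompt are placeholders, and it assumes `torchmetrics`, `torchvision`, and `transformers` are installed:

```python
from functools import partial

from PIL import Image
from torchmetrics.functional.multimodal import clip_score
from torchvision import transforms

# Bind the metric to the same default checkpoint the evaluator uses.
clip_score_fn = partial(clip_score, model_name_or_path="openai/clip-vit-large-patch14-336")

image = Image.open("validation_sample.png").convert("RGB")  # placeholder image
prompt = "a photo of a corgi wearing sunglasses"            # placeholder prompt

# clip_score expects pixel values in the 0-255 range, shaped (N, C, H, W).
image_tensor = (transforms.ToTensor()(image) * 255).unsqueeze(0)

score = clip_score_fn(image_tensor, [prompt]).detach().cpu()
print(f"clip score: {score.item():.4f}")
```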

CLIP score generation during model pretraining can help demonstrate that the model is approaching its objective, but once a `clip/mean` value of around `.30` to `.39` is reached, the comparison seems to become less meaningful. A model with an average CLIP score around `.33` can outperform a model with an average CLIP score of `.36` in human analysis. However, a model with a very low average CLIP score, around `0.18` to `0.22`, will probably perform poorly.

Within a single test run, some prompts will produce a very low CLIP score of around `0.14` (the `clip/min` value in the tracker charts) even though their images align fairly well with the user prompt and have high image quality; conversely, CLIP scores as high as `0.39` (the `clip/max` value in the tracker charts) can come from images of questionable quality, as this metric is not designed to capture image quality. This is why such a large number of prompts is typically used to measure model performance - _and even then_, the results should be read with caution.

On their own, CLIP scores do not take long to calculate; however, the number of prompts required for a meaningful evaluation can make the process take a very long time.

Since it costs little to run, it doesn't hurt to include CLIP evaluation in small training runs. You may discover a pattern in the outputs that makes it sensible to abandon a training run or to adjust hyperparameters such as the learning rate.

To use a standard prompt library for evaluation, provide `--validation_prompt_library`; this produces a roughly comparable benchmark between training runs.

In `config.json`:

```json
{
...
"evaluation_type": "clip",
"pretrained_evaluation_model_name_or_path": "openai/clip-vit-large-patch14-336",
"report_to": "tensorboard", # or wandb
...
}
```
4 changes: 4 additions & 0 deletions documentation/quickstart/FLUX.md
@@ -191,6 +191,10 @@ A set of diverse prompt will help determine whether the model is collapsing as i

> ℹ️ Flux is a flow-matching model and shorter prompts that have strong similarities will result in practically the same image being produced by the model. Be sure to use longer, more descriptive prompts.
#### CLIP score tracking

If you wish to enable evaluations to score the model's performance, see [this document](/documentation/evaluation/CLIP_SCORES.md) for information on configuring and interpreting CLIP scores.

#### Flux time schedule shifting

Flow-matching models such as Flux and SD3 have a property called "shift" that allows us to shift the trained portion of the timestep schedule using a simple decimal value.
4 changes: 4 additions & 0 deletions documentation/quickstart/KOLORS.md
@@ -260,3 +260,7 @@ bash train.sh
This will begin the text embed and VAE output caching to disk.

For more information, see the [dataloader](/documentation/DATALOADER.md) and [tutorial](/TUTORIAL.md) documents.

### CLIP score tracking

If you wish to enable evaluations to score the model's performance, see [this document](/documentation/evaluation/CLIP_SCORES.md) for information on configuring and interpreting CLIP scores.
6 changes: 5 additions & 1 deletion documentation/quickstart/SD3.md
@@ -349,4 +349,8 @@ For more information on regularisation datasets, see [this section](/documentati

### Quantised training

See [this section](/documentation/DREAMBOOTH.md#quantised-model-training-loralycoris-only) of the Dreambooth guide for information on configuring quantisation for SD3 and other models.

### CLIP score tracking

If you wish to enable evaluations to score the model's performance, see [this document](/documentation/evaluation/CLIP_SCORES.md) for information on configuring and interpreting CLIP scores.
4 changes: 4 additions & 0 deletions documentation/quickstart/SIGMA.md
@@ -216,3 +216,7 @@ bash train.sh
This will begin the text embed and VAE output caching to disk.

For more information, see the [dataloader](/documentation/DATALOADER.md) and [tutorial](/TUTORIAL.md) documents.

### CLIP score tracking

If you wish to enable evaluations to score the model's performance, see [this document](/documentation/evaluation/CLIP_SCORES.md) for information on configuring and interpreting CLIP scores.
20 changes: 20 additions & 0 deletions helpers/configuration/cmd_args.py
@@ -1330,6 +1330,26 @@ def get_argument_parser():
            " This can be disabled with this option."
        ),
    )
    parser.add_argument(
        "--validation_model_evaluator",
        type=str,
        default=None,
        choices=["clip", "none"],
        help=(
            "Validations must be enabled for model evaluation to function. The default is to use no evaluator,"
            " and 'clip' will use a CLIP model to evaluate the resulting model's performance during validations."
        )
    )
    parser.add_argument(
        "--pretrained_validation_model_name_or_path",
        type=str,
        default="openai/clip-vit-large-patch14-336",
        help=(
            "Optionally provide a custom model to use for ViT evaluations."
            " The default is currently clip-vit-large-patch14-336, which uses a smaller patch size (greater accuracy)"
            " and an input resolution of 336x336."
        )
    )
    parser.add_argument(
        "--validation_on_startup",
        action="store_true",
48 changes: 48 additions & 0 deletions helpers/training/evaluation.py
@@ -0,0 +1,48 @@
from functools import partial
from torchmetrics.functional.multimodal import clip_score
from torchvision import transforms
import torch, logging, os
import numpy as np
from PIL import Image
from helpers.training.state_tracker import StateTracker

logger = logging.getLogger("ModelEvaluator")
logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO"))

model_evaluator_map = {
    "clip": "CLIPModelEvaluator",
}

class ModelEvaluator:
    def __init__(self, pretrained_model_name_or_path):
        raise NotImplementedError("Subclasses should implement the __init__() method.")

    def evaluate(self, images, prompts):
        raise NotImplementedError("Subclasses should implement the evaluate() method.")

    @staticmethod
    def from_config(args):
        """Instantiate a ModelEvaluator from the training config, if set to do so."""
        if not StateTracker.get_accelerator().is_main_process:
            return None
        if args.validation_model_evaluator and args.validation_model_evaluator.lower() not in ("", "none"):
            model_evaluator = model_evaluator_map[args.validation_model_evaluator]
            return globals()[model_evaluator](args.pretrained_validation_model_name_or_path)

        return None


class CLIPModelEvaluator(ModelEvaluator):
    def __init__(self, pretrained_model_name_or_path='openai/clip-vit-large-patch14-336'):
        self.clip_score_fn = partial(clip_score, model_name_or_path=pretrained_model_name_or_path)
        self.preprocess = transforms.Compose([
            transforms.ToTensor()
        ])

    def evaluate(self, images, prompts):
        # Preprocess images; clip_score expects pixel values in the 0-255 range.
        images_tensor = torch.stack([self.preprocess(img) * 255 for img in images])
        # Compute CLIP scores
        result = self.clip_score_fn(images_tensor, prompts).detach().cpu()

        return result
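
As a rough sketch of how this class can be exercised on its own (not part of the commit), the snippet below constructs `CLIPModelEvaluator` directly, which bypasses the main-process gate applied by `ModelEvaluator.from_config()`; the image path and prompt are placeholders, and the same `torchmetrics` dependencies are assumed:

```python
from PIL import Image

from helpers.training.evaluation import CLIPModelEvaluator

# Uses the default openai/clip-vit-large-patch14-336 checkpoint.
evaluator = CLIPModelEvaluator()

images = [Image.open("validation_sample.png").convert("RGB")]  # placeholder image
prompts = ["a photo of a corgi wearing sunglasses"]            # placeholder prompt

score = evaluator.evaluate(images, prompts)
print(f"clip score: {score.item():.4f}")
```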
10 changes: 10 additions & 0 deletions helpers/training/trainer.py
@@ -21,6 +21,7 @@
from helpers.caching.memory import reclaim_memory
from helpers.training.multi_process import _get_rank as get_rank
from helpers.training.validation import Validation, prepare_validation_prompt_list
from helpers.training.evaluation import ModelEvaluator
from helpers.training.state_tracker import StateTracker
from helpers.training.schedulers import load_scheduler_from_args
from helpers.training.custom_schedule import get_lr_scheduler
@@ -1353,6 +1354,7 @@ def init_validations(self):
):
logger.error("Cannot run validations with DeepSpeed ZeRO stage 3.")
return
model_evaluator = ModelEvaluator.from_config(args=self.config)
self.validation = Validation(
accelerator=self.accelerator,
unet=self.unet,
@@ -1374,6 +1376,7 @@
ema_model=self.ema_model,
vae=self.vae,
controlnet=self.controlnet if self.config.controlnet else None,
model_evaluator=model_evaluator
)
if not self.config.train_text_encoder and self.validation is not None:
self.validation.clear_text_encoders()
@@ -2589,6 +2592,13 @@ def train(self):
self.guidance_values_list = []
if grad_norm is not None:
wandb_logs["grad_norm"] = grad_norm
if self.validation is not None and hasattr(self.validation, 'evaluation_result'):
eval_result = self.validation.get_eval_result()
if isinstance(eval_result, dict):
# add the dict to wandb_logs
self.validation.clear_eval_result()
wandb_logs.update(eval_result)

progress_bar.update(1)
self.state["global_step"] += 1
current_epoch_step += 1