-
Notifications
You must be signed in to change notification settings - Fork 175
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1151 from bghira/main
merge
- Loading branch information
Showing
13 changed files
with
296 additions
and
46 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# CLIP score tracking | ||
|
||
CLIP scores are loosely related to measurement of a model's ability to follow prompts; it is not at all related to image quality/fidelity. | ||
|
||
The `clip/mean` score of your model indicates how closely the features extracted from the image align with the features extracted from the prompt. It is currently a popular metric for determining general prompt adherence, though is typically evaluated across a very large (~5,000) number of test prompts (eg. Parti Prompts). | ||
|
||
CLIP score generation during model pretraining can help demonstrate that the model is approaching its objective, but once a `clip/mean` value around `.30` to `.39` is reached, the comparison seems to become less meaningful. Models that show an average CLIP score around `.33` can outperform a model with an average CLIP score of `.36` in human analysis. However, a model with a very low average CLIP score around `0.18` to `0.22` will probably be pretty poorly-performing. | ||
|
||
Within a single test run, some prompts will result in a very low CLIP score of around `0.14` (`clip/min` value in the tracker charts) even though their images align fairly well with the user prompt and have high image quality; conversely, CLIP scores as high as `0.39` (`clip/max` value in the tracker charts) may appear from images with questionable quality, as this test is not meant to capture this information. This is why such a large number of prompts are typically used to measure model performance - _and even then_.. | ||
|
||
On its own, CLIP scores do not take long to calculate; however, the number of prompts required for meaningful evaluation can make it take an incredibly long time. | ||
|
||
Since it doesn't take much to run, it doesn't hurt to include CLIP evaluation in small training runs. Perhaps you will discover a pattern of the outputs where it makes sense to abandon a training run or adjust other hyperparameters such as the learning rate. | ||
|
||
To include a standard prompt library for evaluation, `--validation_prompt_library` can be provided and then we will generate a somewhat relative benchmark between training runs. | ||
|
||
In `config.json`: | ||
|
||
```json | ||
{ | ||
... | ||
"evaluation_type": "clip", | ||
"pretrained_evaluation_model_name_or_path": "openai/clip-vit-large-patch14-336", | ||
"report_to": "tensorboard", # or wandb | ||
... | ||
} | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
from functools import partial | ||
from torchmetrics.functional.multimodal import clip_score | ||
from torchvision import transforms | ||
import torch, logging, os | ||
import numpy as np | ||
from PIL import Image | ||
from helpers.training.state_tracker import StateTracker | ||
|
||
logger = logging.getLogger("ModelEvaluator") | ||
logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO")) | ||
|
||
model_evaluator_map = { | ||
"clip": "CLIPModelEvaluator", | ||
} | ||
|
||
class ModelEvaluator: | ||
def __init__(self, pretrained_model_name_or_path): | ||
raise NotImplementedError("Subclasses is incomplete, no __init__ method was found.") | ||
|
||
def evaluate(self, images, prompts): | ||
raise NotImplementedError("Subclasses should implement the evaluate() method.") | ||
|
||
@staticmethod | ||
def from_config(args): | ||
"""Instantiate a ModelEvaluator from the training config, if set to do so.""" | ||
if not StateTracker.get_accelerator().is_main_process: | ||
return None | ||
if args.evaluation_type is not None and args.evaluation_type.lower() != "" and args.evaluation_type.lower() != "none": | ||
model_evaluator = model_evaluator_map[args.evaluation_type] | ||
return globals()[model_evaluator](args.pretrained_evaluation_model_name_or_path) | ||
|
||
return None | ||
|
||
|
||
class CLIPModelEvaluator(ModelEvaluator): | ||
def __init__(self, pretrained_model_name_or_path='openai/clip-vit-large-patch14-336'): | ||
self.clip_score_fn = partial(clip_score, model_name_or_path=pretrained_model_name_or_path) | ||
self.preprocess = transforms.Compose([ | ||
transforms.ToTensor() | ||
]) | ||
|
||
def evaluate(self, images, prompts): | ||
# Preprocess images | ||
images_tensor = torch.stack([self.preprocess(img) * 255 for img in images]) | ||
# Compute CLIP scores | ||
result = self.clip_score_fn(images_tensor, prompts).detach().cpu() | ||
|
||
return result |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.