Skip to content

Commit

Permalink
able to return non-conditioned model output for CFG++ (better CFG w/ …
Browse files Browse the repository at this point in the history
…ddim)
  • Loading branch information
lucidrains committed Jun 16, 2024
1 parent a9edb94 commit 545ea06
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 3 deletions.
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -203,3 +203,12 @@ guided_pred = model(data, texts = texts, cond_scale = 3.) # cond_scale stands f
year = {2023}
}
```

```bibtex
@inproceedings{Chung2024CFGMC,
title = {CFG++: Manifold-constrained Classifier Free Guidance for Diffusion Models},
author = {Hyungjin Chung and Jeongsol Kim and Geon Yeong Park and Hyelin Nam and Jong Chul Ye},
year = {2024},
url = {https://api.semanticscholar.org/CorpusID:270391454}
}
```
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def inner(
*args,
cond_scale: float = 1.,
rescale_phi: float = 0.,
return_unconditioned: bool = False,
cfg_routed_kwargs: Dict[str, Tuple[Any, Any]] = dict(), # to pass in separate arguments to forward and nulled forward calls (for handling caching when using CFG on transformer decoding)
**kwargs
):
Expand Down Expand Up @@ -181,10 +182,20 @@ def fn_maybe_with_text(self, *args, **kwargs):
rescaled_logits = scaled_logits * (logits.std(dim = dims, keepdim = True) / scaled_logits.std(dim = dims, keepdim= True))
logit_output = rescaled_logits * rescale_phi + scaled_logits * (1. - rescale_phi)

# can return unconditioned prediction
# for use in CFG++ https://arxiv.org/abs/2406.08070

output = logit_output

if return_unconditioned:
output = (output, null_logits)

# handle multiple outputs from original function

if is_empty(zipped_rest):
return logit_output
return output

return (logit_output, *zipped_rest)
return (output, *zipped_rest)

return inner

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
name = 'classifier-free-guidance-pytorch',
packages = find_packages(exclude=[]),
include_package_data = True,
version = '0.6.4',
version = '0.6.5',
license='MIT',
description = 'Classifier Free Guidance - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 545ea06

Please sign in to comment.