From 545ea060dd5605db264030c72e9aea1d71d11100 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Sun, 16 Jun 2024 08:16:20 -0700 Subject: [PATCH] able to return non-conditioned model output for CFG++ (better CFG w/ ddim) --- README.md | 9 +++++++++ .../classifier_free_guidance_pytorch.py | 15 +++++++++++++-- setup.py | 2 +- 3 files changed, 23 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 5257956..e7a62c0 100644 --- a/README.md +++ b/README.md @@ -203,3 +203,12 @@ guided_pred = model(data, texts = texts, cond_scale = 3.) # cond_scale stands f year = {2023} } ``` + +```bibtex +@inproceedings{Chung2024CFGMC, + title = {CFG++: Manifold-constrained Classifier Free Guidance for Diffusion Models}, + author = {Hyungjin Chung and Jeongsol Kim and Geon Yeong Park and Hyelin Nam and Jong Chul Ye}, + year = {2024}, + url = {https://api.semanticscholar.org/CorpusID:270391454} +} +``` diff --git a/classifier_free_guidance_pytorch/classifier_free_guidance_pytorch.py b/classifier_free_guidance_pytorch/classifier_free_guidance_pytorch.py index 4ff55fa..e94799d 100644 --- a/classifier_free_guidance_pytorch/classifier_free_guidance_pytorch.py +++ b/classifier_free_guidance_pytorch/classifier_free_guidance_pytorch.py @@ -89,6 +89,7 @@ def inner( *args, cond_scale: float = 1., rescale_phi: float = 0., + return_unconditioned: bool = False, cfg_routed_kwargs: Dict[str, Tuple[Any, Any]] = dict(), # to pass in separate arguments to forward and nulled forward calls (for handling caching when using CFG on transformer decoding) **kwargs ): @@ -181,10 +182,20 @@ def fn_maybe_with_text(self, *args, **kwargs): rescaled_logits = scaled_logits * (logits.std(dim = dims, keepdim = True) / scaled_logits.std(dim = dims, keepdim= True)) logit_output = rescaled_logits * rescale_phi + scaled_logits * (1. - rescale_phi) + # can return unconditioned prediction + # for use in CFG++ https://arxiv.org/abs/2406.08070 + + output = logit_output + + if return_unconditioned: + output = (output, null_logits) + + # handle multiple outputs from original function + if is_empty(zipped_rest): - return logit_output + return output - return (logit_output, *zipped_rest) + return (output, *zipped_rest) return inner diff --git a/setup.py b/setup.py index f04ab42..b01a47d 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ name = 'classifier-free-guidance-pytorch', packages = find_packages(exclude=[]), include_package_data = True, - version = '0.6.4', + version = '0.6.5', license='MIT', description = 'Classifier Free Guidance - Pytorch', author = 'Phil Wang',