Skip to content

Commit

Permalink
some extra asserts for no conditioning
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Dec 8, 2023
1 parent 49e5eb0 commit dc80a4f
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ def fn_maybe_with_text(self, *args, **kwargs):

text_conditioner = getattr(self, text_conditioner_name, None)

cond_drop_prob = kwargs.pop(cond_drop_prob_keyname, None)

assert not exists(cond_drop_prob) or 0. <= cond_drop_prob <= 1.

# auto convert texts -> conditioning functions

if exists(texts) ^ exists(text_embeds):
Expand All @@ -104,13 +108,13 @@ def fn_maybe_with_text(self, *args, **kwargs):

assert exists(text_conditioner) and is_bearable(text_conditioner, Conditioner), 'text_conditioner must be set on your network with the correct hidden dimensions to be conditioned on'

cond_drop_prob = kwargs.pop(cond_drop_prob_keyname, None)

text_condition_input = dict(texts = texts) if exists(texts) else dict(text_embeds = text_embeds)

cond_fns, raw_text_cond = text_conditioner(**text_condition_input, cond_drop_prob = cond_drop_prob)

elif isinstance(text_conditioner, NullConditioner):
assert cond_drop_prob == 0., 'null conditioner has nothing to dropout'

cond_fns, raw_text_cond = text_conditioner()

if 'cond_fns' in fn_params:
Expand All @@ -129,7 +133,7 @@ def fn_maybe_with_text(self, *args, **kwargs):
return fn_maybe_with_text(self, *args, **kwargs)

assert cond_scale >= 1, 'invalid conditioning scale, must be greater or equal to 1'

kwargs_without_cond_dropout = {**kwargs, cond_drop_prob_keyname: 0.}
kwargs_with_cond_dropout = {**kwargs, cond_drop_prob_keyname: 1.}

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
name = 'classifier-free-guidance-pytorch',
packages = find_packages(exclude=[]),
include_package_data = True,
version = '0.4.0',
version = '0.4.1',
license='MIT',
description = 'Classifier Free Guidance - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit dc80a4f

Please sign in to comment.