
Commit

add ability to return raw text embedding and mask, conditionally dropped out
lucidrains committed Dec 8, 2023
1 parent 01425ec commit 587eaf6
Showing 3 changed files with 148 additions and 20 deletions.
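
In short, each conditioner's forward now returns a pair: the conditioning functions as before, plus a TextCondReturn namedtuple carrying the raw text embedding and its mask. A minimal sketch of the new call shape (the hidden_dims, cond_drop_prob, and prompt values below are illustrative, not taken from this commit):

    from classifier_free_guidance_pytorch import TextConditioner

    text_conditioner = TextConditioner(
        model_types = 't5',
        hidden_dims = (512,),     # illustrative
        cond_drop_prob = 0.25     # illustrative
    )

    # previously only cond_fns was returned; the raw conditioning now comes back as well
    cond_fns, raw_text_cond = text_conditioner(texts = ['a photo of a dog'], cond_drop_prob = 0.)

    raw_text_cond.embed    # raw text embeddings
    raw_text_cond.mask     # None for the FiLM conditioner, a boolean mask for the attention and raw conditioners
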
13 changes: 11 additions & 2 deletions classifier_free_guidance_pytorch/__init__.py
@@ -1,5 +1,14 @@
from classifier_free_guidance_pytorch.classifier_free_guidance_pytorch import NullConditioner, TextConditioner, AttentionTextConditioner
from classifier_free_guidance_pytorch.classifier_free_guidance_pytorch import classifier_free_guidance, classifier_free_guidance_class_decorator
from classifier_free_guidance_pytorch.classifier_free_guidance_pytorch import (
NullConditioner,
TextConditioner,
AttentionTextConditioner,
TextEmbeddingReturner
)

from classifier_free_guidance_pytorch.classifier_free_guidance_pytorch import (
classifier_free_guidance,
classifier_free_guidance_class_decorator
)

from classifier_free_guidance_pytorch.open_clip import OpenClipAdapter
from classifier_free_guidance_pytorch.t5 import T5Adapter
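
With the expanded exports, the new returner is importable straight from the package root, e.g.:

    # TextEmbeddingReturner now sits alongside the existing conditioners
    from classifier_free_guidance_pytorch import TextEmbeddingReturner, TextConditioner
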
153 changes: 136 additions & 17 deletions classifier_free_guidance_pytorch/classifier_free_guidance_pytorch.py
@@ -1,8 +1,9 @@
from collections import namedtuple
from functools import wraps, partial

import torch
import torch.nn.functional as F
from torch import nn, einsum
from torch import nn, einsum, Tensor

from einops import rearrange, repeat, pack, unpack

@@ -25,6 +26,11 @@
TEXT_CONDITIONER_NAME = 'text_conditioner'
CONDITION_FUNCTION_KEY_NAME = 'cond_fns'

TextCondReturn = namedtuple('TextCondReturn', [
'embed',
'mask'
])

# helper functions

def exists(val):
@@ -69,7 +75,6 @@ def classifier_free_guidance(
fn_params = signature(fn).parameters

auto_handle_text_condition = texts_key_name not in fn_params and text_embeds_key_name not in fn_params
assert not (auto_handle_text_condition and cond_fns_keyname not in fn_params), f'{cond_fns_keyname} must be in the wrapped function for autohandling texts -> conditioning functions - ex. forward(..., {cond_fns_keyname})'

@wraps(fn)
def inner(
@@ -103,12 +108,16 @@ def fn_maybe_with_text(self, *args, **kwargs):

text_condition_input = dict(texts = texts) if exists(texts) else dict(text_embeds = text_embeds)

cond_fns = text_conditioner(**text_condition_input, cond_drop_prob = cond_drop_prob)
cond_fns, raw_text_cond = text_conditioner(**text_condition_input, cond_drop_prob = cond_drop_prob)

elif isinstance(text_conditioner, NullConditioner):
cond_fns = text_conditioner()
cond_fns, raw_text_cond = text_conditioner()

if 'cond_fns' in fn_params:
kwargs.update(cond_fns = cond_fns)

kwargs.update(cond_fns = cond_fns)
if 'raw_text_cond' in fn_params:
kwargs.update(raw_text_cond = raw_text_cond)

return fn(self, *args, **kwargs)
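
Since cond_fns and raw_text_cond are now injected only when the wrapped forward declares them, a model opts into the raw embedding simply by adding the keyword argument. A hedged sketch (ToyModel and its dimensions are hypothetical; the decorator, the TextConditioner arguments, and the text_conditioner attribute name follow the library as used in this diff):

    import torch
    from torch import nn
    from classifier_free_guidance_pytorch import TextConditioner, classifier_free_guidance

    class ToyModel(nn.Module):
        def __init__(self, dim = 256):
            super().__init__()
            # the decorator auto-handles text via the attribute named 'text_conditioner'
            self.text_conditioner = TextConditioner(
                model_types = 't5',
                hidden_dims = (dim,),
                cond_drop_prob = 0.25
            )
            self.net = nn.Linear(dim, dim)

        @classifier_free_guidance
        def forward(self, x, cond_fns = None, raw_text_cond = None):
            # both keyword arguments are filled in by the decorator because their names
            # appear in this signature; raw_text_cond is the new TextCondReturn(embed, mask),
            # while cond_fns holds one conditioning function per entry of hidden_dims
            embed, mask = raw_text_cond
            return self.net(x)
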

@@ -169,7 +178,8 @@ def __init__(
text_condition_type: Union[
Literal['film'],
Literal['attention'],
Literal['null']
Literal['null'],
Literal['raw'],
] = 'film',
text_condition_model_types: Tuple[str, ...] = ('t5',),
text_condition_hidden_dims: Tuple[int, ...],
@@ -182,6 +192,8 @@
condition_klass = TextConditioner
elif text_condition_type == 'attention':
condition_klass = AttentionTextConditioner
elif text_condition_type == 'raw':
condition_klass = TextEmbeddingReturner
else:
condition_klass = NullConditioner

@@ -387,7 +399,7 @@ def __init__(
num_null_conditioners = len(hidden_dims)
self.cond_fns = tuple(Identity() for _ in range(num_null_conditioners))

self.register_buffer('_device_param', torch.tensor(0.), persistent = False)
self.register_buffer('_device_param', torch.tensor(0), persistent = False)

@property
def device(self):
@@ -396,8 +408,8 @@ def device(self):
def embed_texts(self, texts: List[str]):
assert False, 'null conditioner cannot embed text'

def forward(self, *args, **kwarg) -> Tuple[Identity, ...]:
return self.cond_fns
def forward(self, *args, **kwarg):
return self.cond_fns, None

# text conditioner with FiLM

@@ -472,10 +484,13 @@ def embed_texts(self, texts: List[str]):
def forward(
self,
texts: Optional[List[str]] = None,
text_embeds: Optional[torch.Tensor] = None,
text_embeds: Optional[Tensor] = None,
cond_drop_prob = None,
repeat_batch = 1, # for robotic transformer edge case
) -> Tuple[Callable, ...]:
repeat_batch = 1, # for robotic transformer edge case
) -> Tuple[
Tuple[Callable, ...],
TextCondReturn
]:

assert exists(texts) ^ exists(text_embeds)

@@ -520,7 +535,7 @@ def forward(

cond_fns.append(wrapper_fn(cond_fn))

return tuple(cond_fns)
return tuple(cond_fns), TextCondReturn(text_embeds, None)

# cross attention text conditioner

Expand Down Expand Up @@ -575,7 +590,7 @@ def __init__(
for hidden_dim in hidden_dims:
self.conditioners.append(CrossAttention(dim_latent, hidden_dim, flash = flash))

self.register_buffer('_device_param', torch.tensor(0.), persistent = False)
self.register_buffer('_device_param', torch.tensor(0), persistent = False)

@property
def device(self):
@@ -603,10 +618,13 @@ def embed_texts(self, texts: List[str]):
def forward(
self,
texts: Optional[List[str]] = None,
text_embeds: Optional[torch.Tensor] = None,
text_embeds: Optional[Tensor] = None,
cond_drop_prob = None,
repeat_batch = 1, # for robotic transformer edge case
) -> Tuple[Callable, ...]:
) -> Tuple[
Tuple[Callable, ...],
TextCondReturn
]:

assert exists(texts) ^ exists(text_embeds)

@@ -646,4 +664,105 @@ def forward(

cond_fns.append(wrapper_fn(cond_fn))

return tuple(cond_fns)
return tuple(cond_fns), TextCondReturn(text_embeds, mask)

# return raw text embedding

@beartype
class TextEmbeddingReturner(Conditioner):
def __init__(
self,
*,
dim_latent = None,
hidden_dims: Tuple[int, ...],
model_types = 't5',
model_names = None,
cond_drop_prob = 0.,
):
super().__init__()
model_types = cast_tuple(model_types)
model_names = cast_tuple(model_names, length = len(model_types))

assert len(model_types) == len(model_names)
assert all([model_type in MODEL_TYPES for model_type in model_types])

text_models = []

for model_type, model_name in zip(model_types, model_names):
klass = CONDITION_CONFIG.get(model_type)
model = klass(model_name)
text_models.append(model)

self.text_models = text_models

self.to_latent_dims = nn.ModuleList([])

dim_latent = default(dim_latent, max([model.dim_latent for model in text_models]))

for model in text_models:
self.to_latent_dims.append(nn.Linear(model.dim_latent, dim_latent))

self.conditioners = nn.ModuleList([])

self.cond_drop_prob = cond_drop_prob

for hidden_dim in hidden_dims:
self.conditioners.append(nn.Identity())

self.register_buffer('_device_param', torch.tensor(0), persistent = False)

@property
def device(self):
return next(self.buffers()).device

def embed_texts(self, texts: List[str]):
device = self.device

text_embeds = []

for text_model, to_latent in zip(self.text_models, self.to_latent_dims):
text_embed = text_model.embed_text(texts, return_text_encodings = True)

text_embed = text_embed.to(device)

mask = (text_embed != 0).any(dim = -1)

text_embed = to_latent(text_embed)
text_embed = text_embed.masked_fill(~mask[..., None], 0.)

text_embeds.append(text_embed)

return torch.cat(text_embeds, dim = -2)

def forward(
self,
texts: Optional[List[str]] = None,
text_embeds: Optional[Tensor] = None,
cond_drop_prob = None
) -> Tuple[
Tuple[Callable, ...],
TextCondReturn
]:

assert exists(texts) ^ exists(text_embeds)

if self.training:
cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob)
else:
assert exists(cond_drop_prob), 'when not training, cond_drop_prob must be explicitly set'

if exists(texts):
batch = len(texts)

elif exists(text_embeds):
batch = text_embeds.shape[0]

if not exists(text_embeds):
text_embeds = self.embed_texts(texts)

mask = (text_embeds != 0).any(dim = -1)

if cond_drop_prob > 0.:
prob_keep_mask = prob_mask_like((batch, 1), 1. - cond_drop_prob, device = self.device)
mask = mask & prob_keep_mask

return tuple(self.conditioners), TextCondReturn(text_embeds, mask)
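
A hedged usage sketch of the new returner (the hidden_dims, cond_drop_prob, and prompt are illustrative): it embeds the texts, hands back identity conditioning functions, and returns the raw embedding together with a mask that is zeroed per batch element with probability cond_drop_prob.

    from classifier_free_guidance_pytorch import TextEmbeddingReturner

    returner = TextEmbeddingReturner(
        hidden_dims = (512,),    # only controls how many nn.Identity conditioners are returned
        cond_drop_prob = 0.25
    )

    cond_fns, (text_embeds, mask) = returner(texts = ['a photo of a cat'], cond_drop_prob = 0.)

    # text_embeds: (batch, seq, dim_latent) raw text encodings
    # mask:        (batch, seq) booleans, dropped out per sample with the given probability
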
2 changes: 1 addition & 1 deletion setup.py
@@ -4,7 +4,7 @@
name = 'classifier-free-guidance-pytorch',
packages = find_packages(exclude=[]),
include_package_data = True,
version = '0.3.0',
version = '0.4.0',
license='MIT',
description = 'Classifier Free Guidance - Pytorch',
author = 'Phil Wang',
