From 587eaf6fff46626e4554a30ae94fc3eaeac35818 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Fri, 8 Dec 2023 06:43:16 -0800
Subject: [PATCH] add ability to return raw text embedding and mask, conditionally dropped out

---
 classifier_free_guidance_pytorch/__init__.py  |  13 +-
 .../classifier_free_guidance_pytorch.py       | 154 ++++++++++++++++--
 setup.py                                      |   2 +-
 3 files changed, 149 insertions(+), 20 deletions(-)

diff --git a/classifier_free_guidance_pytorch/__init__.py b/classifier_free_guidance_pytorch/__init__.py
index 1c1d4a4..e67a962 100644
--- a/classifier_free_guidance_pytorch/__init__.py
+++ b/classifier_free_guidance_pytorch/__init__.py
@@ -1,5 +1,14 @@
-from classifier_free_guidance_pytorch.classifier_free_guidance_pytorch import NullConditioner, TextConditioner, AttentionTextConditioner
-from classifier_free_guidance_pytorch.classifier_free_guidance_pytorch import classifier_free_guidance, classifier_free_guidance_class_decorator
+from classifier_free_guidance_pytorch.classifier_free_guidance_pytorch import (
+    NullConditioner,
+    TextConditioner,
+    AttentionTextConditioner,
+    TextEmbeddingReturner
+)
+
+from classifier_free_guidance_pytorch.classifier_free_guidance_pytorch import (
+    classifier_free_guidance,
+    classifier_free_guidance_class_decorator
+)
 
 from classifier_free_guidance_pytorch.open_clip import OpenClipAdapter
 from classifier_free_guidance_pytorch.t5 import T5Adapter
diff --git a/classifier_free_guidance_pytorch/classifier_free_guidance_pytorch.py b/classifier_free_guidance_pytorch/classifier_free_guidance_pytorch.py
index c1caefb..3cd0e36 100644
--- a/classifier_free_guidance_pytorch/classifier_free_guidance_pytorch.py
+++ b/classifier_free_guidance_pytorch/classifier_free_guidance_pytorch.py
@@ -1,8 +1,9 @@
+from collections import namedtuple
 from functools import wraps, partial
 
 import torch
 import torch.nn.functional as F
-from torch import nn, einsum
+from torch import nn, einsum, Tensor
 
 from einops import rearrange, repeat, pack, unpack
 
@@ -25,6 +26,11 @@
 TEXT_CONDITIONER_NAME = 'text_conditioner'
 CONDITION_FUNCTION_KEY_NAME = 'cond_fns'
 
+TextCondReturn = namedtuple('TextCondReturn', [
+    'embed',
+    'mask'
+])
+
 # helper functions
 
 def exists(val):
@@ -69,7 +75,6 @@ def classifier_free_guidance(
     fn_params = signature(fn).parameters
 
     auto_handle_text_condition = texts_key_name not in fn_params and text_embeds_key_name not in fn_params
-    assert not (auto_handle_text_condition and cond_fns_keyname not in fn_params), f'{cond_fns_keyname} must be in the wrapped function for autohandling texts -> conditioning functions - ex. forward(..., {cond_fns_keyname})'
 
     @wraps(fn)
     def inner(
@@ -103,12 +108,16 @@ def fn_maybe_with_text(self, *args, **kwargs):
 
                 text_condition_input = dict(texts = texts) if exists(texts) else dict(text_embeds = text_embeds)
 
-                cond_fns = text_conditioner(**text_condition_input, cond_drop_prob = cond_drop_prob)
+                cond_fns, raw_text_cond = text_conditioner(**text_condition_input, cond_drop_prob = cond_drop_prob)
 
             elif isinstance(text_conditioner, NullConditioner):
-                cond_fns = text_conditioner()
+                cond_fns, raw_text_cond = text_conditioner()
+
+            if 'cond_fns' in fn_params:
+                kwargs.update(cond_fns = cond_fns)
 
-            kwargs.update(cond_fns = cond_fns)
+            if 'raw_text_cond' in fn_params:
+                kwargs.update(raw_text_cond = raw_text_cond)
 
             return fn(self, *args, **kwargs)
 
@@ -169,7 +178,8 @@ def __init__(
         text_condition_type: Union[
             Literal['film'],
             Literal['attention'],
-            Literal['null']
+            Literal['null'],
+            Literal['raw'],
         ] = 'film',
         text_condition_model_types: Tuple[str, ...] = ('t5',),
         text_condition_hidden_dims: Tuple[int, ...],
@@ -182,6 +192,8 @@
             condition_klass = TextConditioner
         elif text_condition_type == 'attention':
             condition_klass = AttentionTextConditioner
+        elif text_condition_type == 'raw':
+            condition_klass = TextEmbeddingReturner
         else:
             condition_klass = NullConditioner
 
@@ -387,7 +399,7 @@ def __init__(
         num_null_conditioners = len(hidden_dims)
         self.cond_fns = tuple(Identity() for _ in range(num_null_conditioners))
 
-        self.register_buffer('_device_param', torch.tensor(0.), persistent = False)
+        self.register_buffer('_device_param', torch.tensor(0), persistent = False)
 
     @property
     def device(self):
@@ -396,8 +408,8 @@
     def embed_texts(self, texts: List[str]):
         assert False, 'null conditioner cannot embed text'
 
-    def forward(self, *args, **kwarg) -> Tuple[Identity, ...]:
-        return self.cond_fns
+    def forward(self, *args, **kwarg):
+        return self.cond_fns, None
 
 # text conditioner with FiLM
 
@@ -472,10 +484,13 @@ def embed_texts(self, texts: List[str]):
     def forward(
         self,
         texts: Optional[List[str]] = None,
-        text_embeds: Optional[torch.Tensor] = None,
+        text_embeds: Optional[Tensor] = None,
         cond_drop_prob = None,
-        repeat_batch = 1, # for robotic transformer edge case
-    ) -> Tuple[Callable, ...]:
+        repeat_batch = 1,  # for robotic transformer edge case
+    ) -> Tuple[
+        Tuple[Callable, ...],
+        TextCondReturn
+    ]:
 
         assert exists(texts) ^ exists(text_embeds)
 
@@ -520,7 +535,7 @@
 
             cond_fns.append(wrapper_fn(cond_fn))
 
-        return tuple(cond_fns)
+        return tuple(cond_fns), TextCondReturn(text_embeds, None)
 
 # cross attention text conditioner
 
@@ -575,7 +590,7 @@ def __init__(
         for hidden_dim in hidden_dims:
             self.conditioners.append(CrossAttention(dim_latent, hidden_dim, flash = flash))
 
-        self.register_buffer('_device_param', torch.tensor(0.), persistent = False)
+        self.register_buffer('_device_param', torch.tensor(0), persistent = False)
 
     @property
     def device(self):
@@ -603,10 +618,13 @@ def embed_texts(self, texts: List[str]):
     def forward(
         self,
         texts: Optional[List[str]] = None,
-        text_embeds: Optional[torch.Tensor] = None,
+        text_embeds: Optional[Tensor] = None,
         cond_drop_prob = None,
         repeat_batch = 1, # for robotic transformer edge case
-    ) -> Tuple[Callable, ...]:
+    ) -> Tuple[
+        Tuple[Callable, ...],
+        TextCondReturn
+    ]:
 
         assert exists(texts) ^ exists(text_embeds)
 
@@ -646,4 +664,106 @@
 
             cond_fns.append(wrapper_fn(cond_fn))
 
-        return tuple(cond_fns)
+        return tuple(cond_fns), TextCondReturn(text_embeds, mask)
+
+# return raw text embedding
+
+@beartype
+class TextEmbeddingReturner(Conditioner):
+    def __init__(
+        self,
+        *,
+        dim_latent = None,
+        hidden_dims: Tuple[int, ...],
+        model_types = 't5',
+        model_names = None,
+        cond_drop_prob = 0.,
+    ):
+        super().__init__()
+        model_types = cast_tuple(model_types)
+        model_names = cast_tuple(model_names, length = len(model_types))
+
+        assert len(model_types) == len(model_names)
+        assert all([model_type in MODEL_TYPES for model_type in model_types])
+
+        text_models = []
+
+        for model_type, model_name in zip(model_types, model_names):
+            klass = CONDITION_CONFIG.get(model_type)
+            model = klass(model_name)
+            text_models.append(model)
+
+        self.text_models = text_models
+
+        self.to_latent_dims = nn.ModuleList([])
+
+        dim_latent = default(dim_latent, max([model.dim_latent for model in text_models]))
+
+        for model in text_models:
+            self.to_latent_dims.append(nn.Linear(model.dim_latent, dim_latent))
+
+        self.conditioners = nn.ModuleList([])
+
+        self.cond_drop_prob = cond_drop_prob
+
+        for hidden_dim in hidden_dims:
+            self.conditioners.append(nn.Identity())
+
+        self.register_buffer('_device_param', torch.tensor(0), persistent = False)
+
+    @property
+    def device(self):
+        return next(self.buffers()).device
+
+    def embed_texts(self, texts: List[str]):
+        device = self.device
+
+        text_embeds = []
+
+        for text_model, to_latent in zip(self.text_models, self.to_latent_dims):
+            text_embed = text_model.embed_text(texts, return_text_encodings = True)
+
+            text_embed = text_embed.to(device)
+
+            mask = (text_embed != 0).any(dim = -1)
+
+            text_embed = to_latent(text_embed)
+            text_embed = text_embed.masked_fill(~mask[..., None], 0.)
+
+            text_embeds.append(text_embed)
+
+        return torch.cat(text_embeds, dim = -2)
+
+    def forward(
+        self,
+        texts: Optional[List[str]] = None,
+        text_embeds: Optional[Tensor] = None,
+        cond_drop_prob = None
+    ) -> Tuple[
+        Tuple[Callable, ...],
+        TextCondReturn
+    ]:
+
+        assert exists(texts) ^ exists(text_embeds)
+
+        if self.training:
+            cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob)
+        else:
+            assert exists(cond_drop_prob), 'when not training, cond_drop_prob must be explicitly set'
+
+        if exists(texts):
+            batch = len(texts)
+
+        elif exists(text_embeds):
+            batch = text_embeds.shape[0]
+
+        if not exists(text_embeds):
+            text_embeds = self.embed_texts(texts)
+
+        mask = (text_embeds != 0).any(dim = -1)
+
+        if cond_drop_prob > 0.:
+            prob_keep_mask = prob_mask_like((batch, 1), 1. - cond_drop_prob, device = self.device)
+            mask = mask & prob_keep_mask
+
+        return tuple(self.conditioners), TextCondReturn(text_embeds, mask)
diff --git a/setup.py b/setup.py
index eda66c0..f635ce9 100644
--- a/setup.py
+++ b/setup.py
@@ -4,7 +4,7 @@
   name = 'classifier-free-guidance-pytorch',
   packages = find_packages(exclude=[]),
   include_package_data = True,
-  version = '0.3.0',
+  version = '0.4.0',
   license='MIT',
   description = 'Classifier Free Guidance - Pytorch',
   author = 'Phil Wang',
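
Usage sketch (editor's note, not part of the patch): the new TextEmbeddingReturner hands back
the raw text embedding and mask instead of FiLM or cross-attention conditioning functions.
This is built only from the constructor and forward defined in the diff above; the example
texts and hidden_dims value are illustrative, and leaving model_names unset is assumed to
fall back to the T5 adapter's default checkpoint.

    import torch
    from classifier_free_guidance_pytorch import TextEmbeddingReturner

    returner = TextEmbeddingReturner(
        hidden_dims = (256,),    # one nn.Identity() conditioner is returned per hidden dim
        model_types = 't5',
        cond_drop_prob = 0.25
    )

    # modules start in training mode, so cond_drop_prob above applies;
    # in eval mode, forward asserts cond_drop_prob is passed explicitly
    cond_fns, (text_embeds, mask) = returner(texts = ['a photo of a dog'])

    # text_embeds : (batch, seq, dim_latent) raw embeddings
    # mask        : (batch, seq) bool - whole rows are dropped with
    #               probability cond_drop_prob (the conditional dropout)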
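
Note on the decorator change: with the assert removed, a wrapped forward no longer has to
accept cond_fns - the decorator now inspects the signature and injects cond_fns and/or
raw_text_cond only when declared. Below is a hedged sketch patterned on the repo's README
MLP example; the toy architecture and dimensions are assumptions, and for the FiLM
TextConditioner the raw_text_cond.mask field is None (it is populated by the attention
and raw conditioners).

    import torch
    from torch import nn

    from classifier_free_guidance_pytorch import TextConditioner, classifier_free_guidance

    class MLP(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.proj_in = nn.Sequential(nn.Linear(dim, dim * 2), nn.ReLU())
            self.proj_out = nn.Linear(dim * 2, 1)

            # the decorator looks this attribute up by name ('text_conditioner')
            self.text_conditioner = TextConditioner(
                model_types = 't5',
                hidden_dims = (dim * 2,),
                hiddens_channel_first = False,
                cond_drop_prob = 0.25
            )

        @classifier_free_guidance
        def forward(self, x, cond_fns = None, raw_text_cond = None):
            # raw_text_cond is the new TextCondReturn (embed, mask) namedtuple
            cond_fn, = cond_fns
            hidden = cond_fn(self.proj_in(x))    # FiLM conditioning on the hidden
            return self.proj_out(hidden)

    mlp = MLP(256)

    # training: text conditioning is dropped out 25% of the time
    pred = mlp(torch.randn(2, 256), texts = ['a dog', 'a cat'])

    # inference: per the repo's README, pass cond_scale for classifier-free guidance
    mlp.eval()
    guided = mlp(torch.randn(2, 256), texts = ['a dog', 'a cat'], cond_scale = 3.)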