Skip to content

Commit

Permalink
auto move clip to cuda if possible
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jun 26, 2024
1 parent 0f8eb10 commit 2a8457a
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
17 changes: 11 additions & 6 deletions classifier_free_guidance_pytorch/open_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,16 @@ def __init__(
self,
name = DEFAULT_CLIP_NAME,
pretrained = DEFAULT_PRETRAINED_CLIP,
text_embed_pad_value = 0.
text_embed_pad_value = 0.,
auto_move_clip_cuda = True
):
name = default(name, DEFAULT_CLIP_NAME)
pretrained = default(pretrained, DEFAULT_PRETRAINED_CLIP)

clip, _, preprocess = open_clip.create_model_and_transforms(name, pretrained = pretrained)
if torch.cuda.is_available():
clip = clip.to("cuda")

if auto_move_clip_cuda and torch.cuda.is_available():
clip = clip.cuda()

self.clip = clip
clip.eval()
Expand Down Expand Up @@ -83,9 +85,12 @@ def embed_text(
return_text_encodings = False,
output_device = None
):
if output_device is None:
output_device = next(self.clip.parameters()).device
elif output_device != next(self.clip.parameters()).device:
clip_device = next(self.clip.parameters()).device

if not exists(output_device):
output_device = clip_device

if output_device != clip_device:
self.clip = self.clip.to(output_device)

texts = self.tokenizer(texts).to(output_device)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
name = 'classifier-free-guidance-pytorch',
packages = find_packages(exclude=[]),
include_package_data = True,
version = '0.6.9',
version = '0.6.10',
license='MIT',
description = 'Classifier Free Guidance - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 2a8457a

Please sign in to comment.