diff --git a/classifier_free_guidance_pytorch/open_clip.py b/classifier_free_guidance_pytorch/open_clip.py
index 2aa4fe0..d3d1bca 100644
--- a/classifier_free_guidance_pytorch/open_clip.py
+++ b/classifier_free_guidance_pytorch/open_clip.py
@@ -30,14 +30,16 @@ def __init__(
         self,
         name = DEFAULT_CLIP_NAME,
         pretrained = DEFAULT_PRETRAINED_CLIP,
-        text_embed_pad_value = 0.
+        text_embed_pad_value = 0.,
+        auto_move_clip_cuda = True
     ):
         name = default(name, DEFAULT_CLIP_NAME)
         pretrained = default(pretrained, DEFAULT_PRETRAINED_CLIP)
 
         clip, _, preprocess = open_clip.create_model_and_transforms(name, pretrained = pretrained)
-        if torch.cuda.is_available():
-            clip = clip.to("cuda")
+
+        if auto_move_clip_cuda and torch.cuda.is_available():
+            clip = clip.cuda()
 
         self.clip = clip
         clip.eval()
@@ -83,9 +85,12 @@ def embed_text(
         return_text_encodings = False,
         output_device = None
     ):
-        if output_device is None:
-            output_device = next(self.clip.parameters()).device
-        elif output_device != next(self.clip.parameters()).device:
+        clip_device = next(self.clip.parameters()).device
+
+        if not exists(output_device):
+            output_device = clip_device
+
+        if output_device != clip_device:
             self.clip = self.clip.to(output_device)
 
         texts = self.tokenizer(texts).to(output_device)
diff --git a/setup.py b/setup.py
index 44bae0a..436d1b7 100644
--- a/setup.py
+++ b/setup.py
@@ -4,7 +4,7 @@
   name = 'classifier-free-guidance-pytorch',
  packages = find_packages(exclude=[]),
  include_package_data = True,
-  version = '0.6.9',
+  version = '0.6.10',
  license='MIT',
  description = 'Classifier Free Guidance - Pytorch',
  author = 'Phil Wang',