Removed compute_clip_image_embedding overloads #451

Merged
merged 2 commits into from
Sep 27, 2024
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -139,7 +139,7 @@ black = true
[tool.pyright]
include = ["src/refiners", "tests", "scripts"]
strict = ["*"]
exclude = ["**/__pycache__", "tests/weights"]
exclude = ["**/__pycache__", "tests/weights", "tests/repos"]
reportMissingTypeStubs = "warning"

[tool.coverage.run]
43 changes: 24 additions & 19 deletions src/refiners/foundationals/latent_diffusion/image_prompt.py
@@ -1,5 +1,5 @@
import math
from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload
from typing import TYPE_CHECKING, Any, Generic, TypeVar

import torch
from jaxtyping import Float
@@ -454,37 +454,42 @@ def set_clip_image_embedding(self, image_embedding: Tensor) -> None:
"""
self.set_context("ip_adapter", {"clip_image_embedding": image_embedding})

@overload
def compute_clip_image_embedding(self, image_prompt: Tensor, weights: list[float] | None = None) -> Tensor: ...

@overload
def compute_clip_image_embedding(self, image_prompt: Image.Image) -> Tensor: ...

@overload
def compute_clip_image_embedding(
self, image_prompt: list[Image.Image], weights: list[float] | None = None
) -> Tensor: ...

def compute_clip_image_embedding(
self,
image_prompt: Tensor | Image.Image | list[Image.Image],
image_prompt: Image.Image | list[Image.Image] | Tensor,
weights: list[float] | None = None,
concat_batches: bool = True,
) -> Tensor:
"""Compute the CLIP image embedding.
"""Compute CLIP image embeddings from the provided image prompts.

Args:
image_prompt: The image prompt to use.
weights: The scale to use for the image prompt.
concat_batches: Whether to concatenate the batches.
image_prompt: A single image or a list of images to compute embeddings for.
This can be a PIL Image, a list of PIL Images, or a Tensor.
weights: An optional list of scaling factors for the conditional embeddings.
If provided, it must have the same length as the number of images in `image_prompt`.
Each weight scales the corresponding image's conditional embedding, allowing you to
adjust the influence of each image. Defaults to uniform weights of 1.0.
concat_batches: Determines how embeddings are concatenated when multiple images are provided:
- If `True`, embeddings from multiple images are concatenated along the feature
dimension to form a longer sequence of image tokens. This is useful when you want to
treat multiple images as a single combined input.
- If `False`, embeddings are kept separate along the batch dimension, treating each image
independently.

Returns:
The CLIP image embedding.
A Tensor containing the CLIP image embeddings.
The structure of the returned Tensor depends on the `concat_batches` parameter:
- If `concat_batches` is `True` and multiple images are provided, the embeddings are
concatenated along the feature dimension.
- If `concat_batches` is `False` or a single image is provided, the embeddings are returned
as a batch, with one embedding per image.
"""
if isinstance(image_prompt, Image.Image):
image_prompt = self.preprocess_image(image_prompt)
elif isinstance(image_prompt, list):
assert all(isinstance(image, Image.Image) for image in image_prompt)
assert all(
isinstance(image, Image.Image) for image in image_prompt
), "All elements of `image_prompt` must be PIL Images."
image_prompt = torch.cat([self.preprocess_image(image) for image in image_prompt])

negative_embedding, conditional_embedding = self._compute_clip_image_embedding(image_prompt)
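The input handling above replaces three `@overload` stubs with one union-typed signature plus `isinstance` checks. A minimal sketch of that pattern follows, using only the standard library. `FakeImage`, `FakeTensor`, `preprocess`, and `normalize_prompt` are hypothetical stand-ins invented for illustration; they are not part of refiners, PIL, or torch, and the sketch raises `TypeError` where the diff uses an `assert`.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class FakeImage:
    """Hypothetical stand-in for PIL.Image.Image."""

    name: str


@dataclass
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor: one row per image in the batch."""

    rows: list[str]


def preprocess(image: FakeImage) -> FakeTensor:
    # Stand-in for preprocess_image: a single image becomes a batch of size 1.
    return FakeTensor(rows=[image.name])


def normalize_prompt(image_prompt: FakeImage | list[FakeImage] | FakeTensor) -> FakeTensor:
    """Accept any of the three input kinds and return one batched tensor."""
    if isinstance(image_prompt, FakeImage):
        return preprocess(image_prompt)
    if isinstance(image_prompt, list):
        if not all(isinstance(image, FakeImage) for image in image_prompt):
            raise TypeError("All elements of `image_prompt` must be images.")
        # Stand-in for torch.cat over the individually preprocessed images.
        return FakeTensor(rows=[row for image in image_prompt for row in preprocess(image).rows])
    # Already a tensor: pass it through unchanged.
    return image_prompt
```

With a single signature, callers can pass `weights` together with any input kind; the removed overload for a single `Image.Image` did not declare that parameter.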
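The rest of the function body is collapsed in this view, so the following is only a sketch of the behaviour the new docstring describes for `weights` and `concat_batches`, not the library's implementation. It models each image's embedding as a plain nested list of shape `[tokens][dim]`. The helper name `combine_embeddings` is hypothetical, and reading "a longer sequence of image tokens" as concatenation along the token axis is an assumption.

```python
from __future__ import annotations


def combine_embeddings(
    embeddings: list[list[list[float]]],
    weights: list[float] | None = None,
    concat_batches: bool = True,
) -> list[list[list[float]]]:
    """Scale each image's embedding, then merge or keep the batch.

    `embeddings` has shape [images][tokens][dim]; the result has shape
    [1][images * tokens][dim] when concat_batches is True, otherwise
    [images][tokens][dim].
    """
    if weights is None:
        # Documented default: uniform weights of 1.0.
        weights = [1.0] * len(embeddings)
    if len(weights) != len(embeddings):
        raise ValueError("`weights` must have one entry per image.")

    # Each weight scales the corresponding image's embedding.
    scaled = [
        [[weight * value for value in token] for token in embedding]
        for embedding, weight in zip(embeddings, weights)
    ]

    if concat_batches:
        # Assumption: all images become one longer token sequence in a single batch entry.
        return [[token for embedding in scaled for token in embedding]]
    # Otherwise each image stays a separate batch entry.
    return scaled


image_a = [[1.0, 1.0], [2.0, 2.0]]  # 2 tokens of dimension 2
image_b = [[2.0, 4.0], [6.0, 8.0]]

merged = combine_embeddings([image_a, image_b], weights=[1.0, 0.5])
separate = combine_embeddings([image_a, image_b], concat_batches=False)
```

Here `merged` holds one batch entry with four tokens, the last two scaled by 0.5, while `separate` keeps two batch entries of two tokens each.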