diff --git a/pyproject.toml b/pyproject.toml index d0b7c96a4..f59c70226 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -139,7 +139,7 @@ black = true [tool.pyright] include = ["src/refiners", "tests", "scripts"] strict = ["*"] -exclude = ["**/__pycache__", "tests/weights"] +exclude = ["**/__pycache__", "tests/weights", "tests/repos"] reportMissingTypeStubs = "warning" [tool.coverage.run] diff --git a/src/refiners/foundationals/latent_diffusion/image_prompt.py b/src/refiners/foundationals/latent_diffusion/image_prompt.py index 6cf95ee04..c29da7939 100644 --- a/src/refiners/foundationals/latent_diffusion/image_prompt.py +++ b/src/refiners/foundationals/latent_diffusion/image_prompt.py @@ -1,5 +1,5 @@ import math -from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload +from typing import TYPE_CHECKING, Any, Generic, TypeVar import torch from jaxtyping import Float @@ -454,37 +454,42 @@ def set_clip_image_embedding(self, image_embedding: Tensor) -> None: """ self.set_context("ip_adapter", {"clip_image_embedding": image_embedding}) - @overload - def compute_clip_image_embedding(self, image_prompt: Tensor, weights: list[float] | None = None) -> Tensor: ... - - @overload - def compute_clip_image_embedding(self, image_prompt: Image.Image) -> Tensor: ... - - @overload - def compute_clip_image_embedding( - self, image_prompt: list[Image.Image], weights: list[float] | None = None - ) -> Tensor: ... - def compute_clip_image_embedding( self, - image_prompt: Tensor | Image.Image | list[Image.Image], + image_prompt: Image.Image | list[Image.Image] | Tensor, weights: list[float] | None = None, concat_batches: bool = True, ) -> Tensor: - """Compute the CLIP image embedding. + """Compute CLIP image embeddings from the provided image prompts. Args: - image_prompt: The image prompt to use. - weights: The scale to use for the image prompt. - concat_batches: Whether to concatenate the batches. + image_prompt: A single image or a list of images to compute embeddings for. + This can be a PIL Image, a list of PIL Images, or a Tensor. + weights: An optional list of scaling factors for the conditional embeddings. + If provided, it must have the same length as the number of images in `image_prompt`. + Each weight scales the corresponding image's conditional embedding, allowing you to + adjust the influence of each image. Defaults to uniform weights of 1.0. + concat_batches: Determines how embeddings are concatenated when multiple images are provided: + - If `True`, embeddings from multiple images are concatenated along the feature + dimension to form a longer sequence of image tokens. This is useful when you want to + treat multiple images as a single combined input. + - If `False`, embeddings are kept separate along the batch dimension, treating each image + independently. Returns: - The CLIP image embedding. + A Tensor containing the CLIP image embeddings. + The structure of the returned Tensor depends on the `concat_batches` parameter: + - If `concat_batches` is `True` and multiple images are provided, the embeddings are + concatenated along the feature dimension. + - If `concat_batches` is `False` or a single image is provided, the embeddings are returned + as a batch, with one embedding per image. """ if isinstance(image_prompt, Image.Image): image_prompt = self.preprocess_image(image_prompt) elif isinstance(image_prompt, list): - assert all(isinstance(image, Image.Image) for image in image_prompt) + assert all( + isinstance(image, Image.Image) for image in image_prompt + ), "All elements of `image_prompt` must be of PIL Images." image_prompt = torch.cat([self.preprocess_image(image) for image in image_prompt]) negative_embedding, conditional_embedding = self._compute_clip_image_embedding(image_prompt)