diff --git a/src/refiners/foundationals/latent_diffusion/auto_encoder.py b/src/refiners/foundationals/latent_diffusion/auto_encoder.py index 89cbb868b..6c9edfa09 100644 --- a/src/refiners/foundationals/latent_diffusion/auto_encoder.py +++ b/src/refiners/foundationals/latent_diffusion/auto_encoder.py @@ -415,8 +415,8 @@ def _generate_latent_tiles(size: _ImageSize, tile_size: _ImageSize, overlap: int """ tiles: list[_Tile] = [] - for x in range(0, size.width, tile_size.width - overlap): - for y in range(0, size.height, tile_size.height - overlap): + for x in range(0, max(size.width - overlap, 1), tile_size.width - overlap): + for y in range(0, max(size.height - overlap, 1), tile_size.height - overlap): tile = _Tile( top=max(0, y), left=max(0, x), diff --git a/tests/foundationals/latent_diffusion/conftest.py b/tests/foundationals/latent_diffusion/conftest.py index 8e05561f8..787be04bb 100644 --- a/tests/foundationals/latent_diffusion/conftest.py +++ b/tests/foundationals/latent_diffusion/conftest.py @@ -1,7 +1,6 @@ from pathlib import Path import pytest -import torch from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import StableDiffusionXLPipeline @@ -93,25 +92,6 @@ def refiners_sdxl( ) -@pytest.fixture(scope="module", params=["SD1.5", "SDXL"]) -def refiners_autoencoder( - request: pytest.FixtureRequest, - refiners_sd15_autoencoder: SD1Autoencoder, - refiners_sdxl_autoencoder: SDXLAutoencoder, - test_dtype_fp32_bf16_fp16: torch.dtype, -) -> SD1Autoencoder | SDXLAutoencoder: - model_version = request.param - match (model_version, test_dtype_fp32_bf16_fp16): - case ("SD1.5", _): - return refiners_sd15_autoencoder - case ("SDXL", torch.float16): - return refiners_sdxl_autoencoder - case ("SDXL", _): - return refiners_sdxl_autoencoder - case _: - raise ValueError(f"Unknown model version: {model_version}") - - @pytest.fixture(scope="module") def diffusers_sd15_pipeline( sd15_diffusers_runwayml_path: str, diff --git a/tests/foundationals/latent_diffusion/test_autoencoders.py b/tests/foundationals/latent_diffusion/test_autoencoders.py index d21ac3ee8..3190c8d88 100644 --- a/tests/foundationals/latent_diffusion/test_autoencoders.py +++ b/tests/foundationals/latent_diffusion/test_autoencoders.py @@ -7,7 +7,11 @@ from tests.utils import ensure_similar_images from refiners.fluxion.utils import no_grad -from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder +from refiners.foundationals.latent_diffusion import ( + LatentDiffusionAutoencoder, + SD1Autoencoder, + SDXLAutoencoder, +) @pytest.fixture(scope="module") @@ -16,17 +20,24 @@ def sample_image() -> Image.Image: if not test_image.is_file(): warn(f"could not reference image at {test_image}, skipping") pytest.skip(allow_module_level=True) - img = Image.open(test_image) # type: ignore + img = Image.open(test_image) assert img.size == (512, 512) return img -@pytest.fixture(scope="module") +@pytest.fixture(scope="module", params=["SD1.5", "SDXL"]) def autoencoder( - refiners_autoencoder: LatentDiffusionAutoencoder, + request: pytest.FixtureRequest, + refiners_sd15_autoencoder: SD1Autoencoder, + refiners_sdxl_autoencoder: SDXLAutoencoder, test_device: torch.device, + test_dtype_fp32_bf16_fp16: torch.dtype, ) -> LatentDiffusionAutoencoder: - return refiners_autoencoder.to(test_device) + model_version = request.param + if model_version == "SDXL" and test_dtype_fp32_bf16_fp16 == torch.float16: + pytest.skip("SDXL autoencoder does not support float16") + ae = refiners_sd15_autoencoder if model_version == "SD1.5" else refiners_sdxl_autoencoder + return ae.to(device=test_device, dtype=test_dtype_fp32_bf16_fp16) @no_grad() @@ -34,7 +45,7 @@ def test_encode_decode_image(autoencoder: LatentDiffusionAutoencoder, sample_ima encoded = autoencoder.image_to_latents(sample_image) decoded = autoencoder.latents_to_image(encoded) - assert decoded.mode == "RGB" # type: ignore + assert decoded.mode == "RGB" # Ensure no saturation. The green channel (band = 1) must not max out. assert max(iter(decoded.getdata(band=1))) < 255 # type: ignore @@ -53,7 +64,7 @@ def test_encode_decode_images(autoencoder: LatentDiffusionAutoencoder, sample_im @no_grad() def test_tiled_autoencoder(autoencoder: LatentDiffusionAutoencoder, sample_image: Image.Image): - sample_image = sample_image.resize((2048, 2048)) # type: ignore + sample_image = sample_image.resize((2048, 2048)) with autoencoder.tiled_inference(sample_image, tile_size=(512, 512)): encoded = autoencoder.tiled_image_to_latents(sample_image) @@ -64,7 +75,7 @@ def test_tiled_autoencoder(autoencoder: LatentDiffusionAutoencoder, sample_image @no_grad() def test_tiled_autoencoder_rectangular_tiles(autoencoder: LatentDiffusionAutoencoder, sample_image: Image.Image): - sample_image = sample_image.resize((2048, 2048)) # type: ignore + sample_image = sample_image.resize((2048, 2048)) with autoencoder.tiled_inference(sample_image, tile_size=(512, 1024)): encoded = autoencoder.tiled_image_to_latents(sample_image) @@ -75,7 +86,7 @@ def test_tiled_autoencoder_rectangular_tiles(autoencoder: LatentDiffusionAutoenc @no_grad() def test_tiled_autoencoder_large_tile(autoencoder: LatentDiffusionAutoencoder, sample_image: Image.Image): - sample_image = sample_image.resize((1024, 1024)) # type: ignore + sample_image = sample_image.resize((1024, 1024)) with autoencoder.tiled_inference(sample_image, tile_size=(2048, 2048)): encoded = autoencoder.tiled_image_to_latents(sample_image) @@ -87,7 +98,7 @@ def test_tiled_autoencoder_large_tile(autoencoder: LatentDiffusionAutoencoder, s @no_grad() def test_tiled_autoencoder_rectangular_image(autoencoder: LatentDiffusionAutoencoder, sample_image: Image.Image): sample_image = sample_image.crop((0, 0, 300, 500)) - sample_image = sample_image.resize((sample_image.width * 4, sample_image.height * 4)) # type: ignore + sample_image = sample_image.resize((sample_image.width * 4, sample_image.height * 4)) with autoencoder.tiled_inference(sample_image, tile_size=(512, 512)): encoded = autoencoder.tiled_image_to_latents(sample_image) @@ -96,6 +107,28 @@ def test_tiled_autoencoder_rectangular_image(autoencoder: LatentDiffusionAutoenc ensure_similar_images(sample_image, result, min_psnr=37, min_ssim=0.985) +@no_grad() +@pytest.mark.parametrize("img_width", [960, 968, 976, 1016, 1024, 1032]) +def test_tiled_autoencoder_pathologic_sizes( + refiners_sd15_autoencoder: SD1Autoencoder, + sample_image: Image.Image, + test_device: torch.device, + img_width: int, +): + # 968 is the pathologic case, just larger than (tile size - overlap): (128 - 8 + 1) * 8 = 968 + + autoencoder = refiners_sd15_autoencoder.to(device=test_device, dtype=torch.float32) + + sample_image = sample_image.crop((0, 0, img_width // 4, 400)) + sample_image = sample_image.resize((sample_image.width * 4, sample_image.height * 4)) + + with autoencoder.tiled_inference(sample_image, tile_size=(1024, 1024)): + encoded = autoencoder.tiled_image_to_latents(sample_image) + result = autoencoder.tiled_latents_to_image(encoded) + + ensure_similar_images(sample_image, result, min_psnr=37, min_ssim=0.985) + + def test_value_error_tile_encode_no_context(autoencoder: LatentDiffusionAutoencoder, sample_image: Image.Image) -> None: with pytest.raises(ValueError): autoencoder.tiled_image_to_latents(sample_image)