Skip to content

Commit

Permalink
freeu: add one more test for identity scales
Browse files Browse the repository at this point in the history
It should act as a NOP when [1.0, 1.0] is used for backbone and skip
scales.
  • Loading branch information
deltheil committed Dec 1, 2023
1 parent 761678d commit b306c7d
Showing 1 changed file with 26 additions and 0 deletions.
26 changes: 26 additions & 0 deletions tests/foundationals/latent_diffusion/test_freeu.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import Iterator

import pytest
import torch

from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet
from refiners.foundationals.latent_diffusion.freeu import SDFreeUAdapter, FreeUResidualConcatenator
from refiners.fluxion import manual_seed


@pytest.fixture(scope="module", params=[True, False])
Expand Down Expand Up @@ -39,3 +41,27 @@ def test_freeu_adapter_too_many_scales(unet: SD1UNet | SDXLUNet) -> None:
def test_freeu_adapter_inconsistent_scales(unet: SD1UNet | SDXLUNet) -> None:
with pytest.raises(AssertionError):
SDFreeUAdapter(unet, backbone_scales=[1.2, 1.2], skip_scales=[0.9, 0.9, 0.9])


def test_freeu_identity_scales() -> None:
manual_seed(0)
text_embedding = torch.randn(1, 77, 768)
timestep = torch.randint(0, 999, size=(1, 1))
x = torch.randn(1, 4, 32, 32)

unet = SD1UNet(in_channels=4)
unet.set_clip_text_embedding(clip_text_embedding=text_embedding) # not flushed between forward-s

with torch.no_grad():
unet.set_timestep(timestep=timestep)
y_1 = unet(x.clone())

freeu = SDFreeUAdapter(unet, backbone_scales=[1.0, 1.0], skip_scales=[1.0, 1.0])
freeu.inject()

with torch.no_grad():
unet.set_timestep(timestep=timestep)
y_2 = unet(x.clone())

# The FFT -> inverse FFT sequence (skip features) introduces small numerical differences
assert torch.allclose(y_1, y_2, atol=1e-5)

0 comments on commit b306c7d

Please sign in to comment.