Skip to content

Commit

Permalink
modify some foundational tests to also test in float16 and bfloat16
Browse files Browse the repository at this point in the history
  • Loading branch information
Laurent2916 committed Oct 3, 2024
1 parent 4d4079b commit f9b87fc
Show file tree
Hide file tree
Showing 6 changed files with 170 additions and 45 deletions.
35 changes: 23 additions & 12 deletions tests/foundationals/clip/test_image_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,16 @@


@pytest.fixture(scope="module")
def our_encoder(test_weights_path: Path, test_device: torch.device) -> CLIPImageEncoderH:
def our_encoder(
test_weights_path: Path,
test_device: torch.device,
test_dtype_fp32_bf16_fp16: torch.dtype,
) -> CLIPImageEncoderH:
weights = test_weights_path / "CLIPImageEncoderH.safetensors"
if not weights.is_file():
warn(f"could not find weights at {weights}, skipping")
pytest.skip(allow_module_level=True)
encoder = CLIPImageEncoderH(device=test_device)
encoder = CLIPImageEncoderH(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
tensors = load_from_safetensors(weights)
encoder.load_state_dict(tensors)
return encoder
Expand All @@ -31,24 +35,31 @@ def stabilityai_unclip_weights_path(test_weights_path: Path):


@pytest.fixture(scope="module")
def ref_encoder(stabilityai_unclip_weights_path: Path, test_device: torch.device) -> CLIPVisionModelWithProjection:
return CLIPVisionModelWithProjection.from_pretrained(stabilityai_unclip_weights_path, subfolder="image_encoder").to( # type: ignore
test_device # type: ignore
)
def ref_encoder(
stabilityai_unclip_weights_path: Path,
test_device: torch.device,
test_dtype_fp32_bf16_fp16: torch.dtype,
) -> CLIPVisionModelWithProjection:
return CLIPVisionModelWithProjection.from_pretrained( # type: ignore
stabilityai_unclip_weights_path,
subfolder="image_encoder",
).to(device=test_device, dtype=test_dtype_fp32_bf16_fp16)


@no_grad()
@pytest.mark.flaky(reruns=3)
def test_encoder(
ref_encoder: CLIPVisionModelWithProjection,
our_encoder: CLIPImageEncoderH,
test_device: torch.device,
):
x = torch.randn(1, 3, 224, 224).to(test_device)
assert ref_encoder.dtype == our_encoder.dtype
assert ref_encoder.device == our_encoder.device
x = torch.randn((1, 3, 224, 224), dtype=ref_encoder.dtype, device=ref_encoder.device)

with no_grad():
ref_embeddings = ref_encoder(x).image_embeds
our_embeddings = our_encoder(x)
ref_embeddings = ref_encoder(x).image_embeds
our_embeddings = our_encoder(x)

assert ref_embeddings.shape == (1, 1024)
assert our_embeddings.shape == (1, 1024)

assert (our_embeddings - ref_embeddings).abs().max() < 0.01
assert torch.allclose(our_embeddings, ref_embeddings, atol=0.05)
31 changes: 20 additions & 11 deletions tests/foundationals/clip/test_text_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,17 @@


@pytest.fixture(scope="module")
def our_encoder(test_weights_path: Path, test_device: torch.device) -> CLIPTextEncoderL:
def our_encoder(
test_weights_path: Path,
test_device: torch.device,
test_dtype_fp32_fp16: torch.dtype,
) -> CLIPTextEncoderL:
weights = test_weights_path / "CLIPTextEncoderL.safetensors"
if not weights.is_file():
warn(f"could not find weights at {weights}, skipping")
pytest.skip(allow_module_level=True)
encoder = CLIPTextEncoderL(device=test_device)
tensors = load_from_safetensors(weights)
encoder = CLIPTextEncoderL(device=test_device, dtype=test_dtype_fp32_fp16)
encoder.load_state_dict(tensors)
return encoder

Expand All @@ -56,8 +60,15 @@ def ref_tokenizer(runwayml_weights_path: Path) -> transformers.CLIPTokenizer:


@pytest.fixture(scope="module")
def ref_encoder(runwayml_weights_path: Path, test_device: torch.device) -> transformers.CLIPTextModel:
return transformers.CLIPTextModel.from_pretrained(runwayml_weights_path, subfolder="text_encoder").to(test_device) # type: ignore
def ref_encoder(
runwayml_weights_path: Path,
test_device: torch.device,
test_dtype_fp32_fp16: torch.dtype,
) -> transformers.CLIPTextModel:
return transformers.CLIPTextModel.from_pretrained( # type: ignore
runwayml_weights_path,
subfolder="text_encoder",
).to(device=test_device, dtype=test_dtype_fp32_fp16) # type: ignore


def test_basics(ref_tokenizer: transformers.CLIPTokenizer, our_encoder: CLIPTextEncoderL):
Expand All @@ -70,12 +81,12 @@ def prompt(request: pytest.FixtureRequest):
return long_prompt if request.param == "<long prompt>" else request.param


@no_grad()
def test_encoder(
prompt: str,
ref_tokenizer: transformers.CLIPTokenizer,
ref_encoder: transformers.CLIPTextModel,
our_encoder: CLIPTextEncoderL,
test_device: torch.device,
):
ref_tokens = ref_tokenizer( # type: ignore
prompt,
Expand All @@ -89,18 +100,16 @@ def test_encoder(
our_tokens = tokenizer(prompt)
assert torch.equal(our_tokens, ref_tokens)

with no_grad():
ref_embeddings = ref_encoder(ref_tokens.to(test_device))[0]
our_embeddings = our_encoder(prompt)
ref_embeddings = ref_encoder(ref_tokens.to(device=ref_encoder.device))[0]
our_embeddings = our_encoder(prompt)

assert ref_embeddings.shape == (1, 77, 768)
assert our_embeddings.shape == (1, 77, 768)

# FG-336 - Not strictly equal because we do not use the same implementation
# of self-attention. We use `scaled_dot_product_attention` which can have
# numerical differences depending on the backend.
# Also we use FP16 weights.
assert (our_embeddings - ref_embeddings).abs().max() < 0.01
# numerical differences depending on the backend. Also we use FP16 weights.
torch.testing.assert_close(our_embeddings, ref_embeddings, atol=0.035, rtol=0.0)


def test_list_string_tokenizer(
Expand Down
15 changes: 8 additions & 7 deletions tests/foundationals/dinov2/test_dinov2.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def test_dinov2_facebook_weights(
) -> None:
manual_seed(2)
input_data = torch.randn(
(1, 3, resolution, resolution),
size=(1, 3, resolution, resolution),
device=test_device,
)

Expand All @@ -129,27 +129,28 @@ def test_dinov2_facebook_weights(


@no_grad()
def test_dinov2_float16(
def test_dinov2(
resolution: int,
test_dtype_fp32_bf16_fp16: torch.dtype,
test_device: torch.device,
) -> None:
if test_device.type == "cpu":
warn("not running on CPU, skipping")
pytest.skip()

model = DINOv2_small(device=test_device, dtype=torch.float16)
model = DINOv2_small(device=test_device, dtype=test_dtype_fp32_bf16_fp16)

manual_seed(2)
input_data = torch.randn(
(1, 3, resolution, resolution),
size=(1, 3, resolution, resolution),
device=test_device,
dtype=torch.float16,
dtype=test_dtype_fp32_bf16_fp16,
)

output = model(input_data)
sequence_length = (resolution // model.patch_size) ** 2 + 1
assert output.shape == (1, sequence_length, model.embedding_dim)
assert output.dtype == torch.float16
assert output.dtype == test_dtype_fp32_bf16_fp16


@no_grad()
Expand All @@ -162,7 +163,7 @@ def test_dinov2_batch_size(
batch_size = 4
manual_seed(2)
input_data = torch.randn(
(batch_size, 3, resolution, resolution),
size=(batch_size, 3, resolution, resolution),
device=test_device,
)

Expand Down
45 changes: 33 additions & 12 deletions tests/foundationals/latent_diffusion/test_auto_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,46 @@
from PIL import Image
from tests.utils import ensure_similar_images

from refiners.fluxion.utils import load_from_safetensors, no_grad
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
from refiners.fluxion.utils import no_grad
from refiners.foundationals.latent_diffusion import LatentDiffusionAutoencoder, SD1Autoencoder, SDXLAutoencoder


@pytest.fixture(scope="module")
def ref_path() -> Path:
return Path(__file__).parent / "test_auto_encoder_ref"


@pytest.fixture(scope="module")
def lda(test_weights_path: Path, test_device: torch.device) -> LatentDiffusionAutoencoder:
lda_weights = test_weights_path / "lda.safetensors"
if not lda_weights.is_file():
warn(f"could not find weights at {lda_weights}, skipping")
pytest.skip(allow_module_level=True)
encoder = LatentDiffusionAutoencoder(device=test_device)
tensors = load_from_safetensors(lda_weights)
encoder.load_state_dict(tensors)
return encoder
@pytest.fixture(scope="module", params=["SD1.5", "SDXL"])
def lda(
request: pytest.FixtureRequest,
test_weights_path: Path,
test_dtype_fp32_bf16_fp16: torch.dtype,
test_device: torch.device,
) -> LatentDiffusionAutoencoder:
model_version = request.param
match (model_version, test_dtype_fp32_bf16_fp16):
case ("SD1.5", _):
weight_path = test_weights_path / "lda.safetensors"
if not weight_path.is_file():
warn(f"could not find weights at {weight_path}, skipping")
pytest.skip(allow_module_level=True)
model = SD1Autoencoder().load_from_safetensors(weight_path)
case ("SDXL", torch.float16):
weight_path = test_weights_path / "sdxl-lda-fp16-fix.safetensors"
if not weight_path.is_file():
warn(f"could not find weights at {weight_path}, skipping")
pytest.skip(allow_module_level=True)
model = SDXLAutoencoder().load_from_safetensors(weight_path)
case ("SDXL", _):
weight_path = test_weights_path / "sdxl-lda.safetensors"
if not weight_path.is_file():
warn(f"could not find weights at {weight_path}, skipping")
pytest.skip(allow_module_level=True)
model = SDXLAutoencoder().load_from_safetensors(weight_path)
case _:
raise ValueError(f"Unknown model version: {model_version}")
model = model.to(device=test_device, dtype=test_dtype_fp32_bf16_fp16)
return model


@pytest.fixture(scope="module")
Expand Down
77 changes: 77 additions & 0 deletions tests/foundationals/latent_diffusion/test_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import torch
from PIL import Image

from refiners.fluxion.utils import manual_seed, no_grad
from refiners.foundationals.latent_diffusion import StableDiffusion_1, StableDiffusion_1_Inpainting, StableDiffusion_XL
from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel


@no_grad()
def test_sample_noise_zero_offset(test_device: torch.device, test_dtype_fp32_bf16_fp16: torch.dtype) -> None:
manual_seed(2)
latents_0 = LatentDiffusionModel.sample_noise(
size=(1, 4, 64, 64),
device=test_device,
dtype=test_dtype_fp32_bf16_fp16,
)
manual_seed(2)
latents_1 = LatentDiffusionModel.sample_noise(
size=(1, 4, 64, 64),
offset_noise=0.0, # should be no-op
device=test_device,
dtype=test_dtype_fp32_bf16_fp16,
)

assert torch.allclose(latents_0, latents_1, atol=1e-6, rtol=0)


@no_grad()
def test_sd15_one_step(test_device: torch.device, test_dtype_fp32_bf16_fp16: torch.dtype) -> None:
sd = StableDiffusion_1(device=test_device, dtype=test_dtype_fp32_bf16_fp16)

# prepare inputs
latent_noise = torch.randn(1, 4, 64, 64, device=test_device, dtype=test_dtype_fp32_bf16_fp16)
text_embedding = sd.compute_clip_text_embedding("")

# run the pipeline of models, for a single step
output = sd(latent_noise, step=0, clip_text_embedding=text_embedding)

assert output.shape == (1, 4, 64, 64)


@no_grad()
def test_sd15_inpainting_one_step(test_device: torch.device, test_dtype_fp32_bf16_fp16: torch.dtype) -> None:
sd = StableDiffusion_1_Inpainting(device=test_device, dtype=test_dtype_fp32_bf16_fp16)

# prepare inputs
latent_noise = torch.randn(1, 4, 64, 64, device=test_device, dtype=test_dtype_fp32_bf16_fp16)
target_image = Image.new("RGB", (512, 512))
mask = Image.new("L", (512, 512))
sd.set_inpainting_conditions(target_image=target_image, mask=mask)
text_embedding = sd.compute_clip_text_embedding("")

# run the pipeline of models, for a single step
output = sd(latent_noise, step=0, clip_text_embedding=text_embedding)

assert output.shape == (1, 4, 64, 64)


@no_grad()
def test_sdxl_one_step(test_device: torch.device, test_dtype_fp32_bf16_fp16: torch.dtype) -> None:
sd = StableDiffusion_XL(device=test_device, dtype=test_dtype_fp32_bf16_fp16)

# prepare inputs
latent_noise = torch.randn(1, 4, 128, 128, device=test_device, dtype=test_dtype_fp32_bf16_fp16)
text_embedding, pooled_text_embedding = sd.compute_clip_text_embedding("")
time_ids = sd.default_time_ids

# run the pipeline of models, for a single step
output = sd(
latent_noise,
step=0,
clip_text_embedding=text_embedding,
pooled_text_embedding=pooled_text_embedding,
time_ids=time_ids,
)

assert output.shape == (1, 4, 128, 128)
12 changes: 9 additions & 3 deletions tests/foundationals/latent_diffusion/test_sd15_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,15 @@


@pytest.fixture(scope="module")
def refiners_sd15_unet(test_device: torch.device) -> SD1UNet:
unet = SD1UNet(in_channels=4, device=test_device)
return unet
def refiners_sd15_unet(
test_device: torch.device,
test_dtype_fp32_bf16_fp16: torch.dtype,
) -> SD1UNet:
return SD1UNet(
in_channels=4,
device=test_device,
dtype=test_dtype_fp32_bf16_fp16,
)


def test_unet_context_flush(refiners_sd15_unet: SD1UNet):
Expand Down

0 comments on commit f9b87fc

Please sign in to comment.