Feature/EulerScheduler #138

Merged · 20 commits into finegrain-ai:main from israfelsr:feature/euler-scheduler · Jan 10, 2024
Conversation

@israfelsr (Contributor) commented:

Hi!

This PR implements the EulerScheduler for Stable Diffusion. I based the implementation on Elucidating the Design Space of Diffusion-Based Generative Models and tested the behaviour against the implementation in Diffusers.

Euler needs some extra inputs in the step function: s_t_min, s_t_max, s_churn and s_noise. I hardcoded the necessary ones for now. Should I add them with the suggested defaults? The rest is working as expected!
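
For context, here is a minimal sketch of what that signature could look like with the defaults diffusers uses, which reduce to plain deterministic Euler; parameter names follow diffusers and the paper, not necessarily this PR:

from torch import Tensor

# Sketch only: a step signature carrying the stochastic-sampling knobs from
# Karras et al. With s_churn=0 no extra noise is injected, so the sampler
# degenerates to deterministic Euler.
def step(
    x: Tensor,
    noise: Tensor,
    step: int,
    s_churn: float = 0.0,          # total amount of noise to re-inject
    s_tmin: float = 0.0,           # churn only applies for sigma in [s_tmin, s_tmax]
    s_tmax: float = float("inf"),
    s_noise: float = 1.0,          # scale of the injected noise
) -> Tensor:
    ...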

The code to test the implementation is in the scheduler test file. You can run the test directly from test_schedulers.

from typing import cast
from refiners.foundationals.latent_diffusion.schedulers import EulerScheduler
from refiners.fluxion.utils import manual_seed
from torch import randn, Tensor, allclose

def test_euler_solver_diffusers():
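    # Step-by-step comparison against diffusers' EulerDiscreteScheduler:
    # feed the same sample/noise at every timestep and require the outputs
    # to match within 1% relative tolerance.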
    from diffusers import EulerDiscreteScheduler
    manual_seed(0)
    diffusers_scheduler = EulerDiscreteScheduler(beta_end=0.012,
                                                 beta_schedule="scaled_linear",
                                                 beta_start=0.00085,
                                                 num_train_timesteps=1000,
                                                 steps_offset=1,
                                                 timestep_spacing="trailing")
    diffusers_scheduler.set_timesteps(30)
    refiners_scheduler = EulerScheduler(num_inference_steps=30)

    sample = randn(1, 4, 32, 32)
    noise = randn(1, 4, 32, 32)

    for step, timestep in enumerate(diffusers_scheduler.timesteps):
        diffusers_output = cast(Tensor,
                                diffusers_scheduler.step(
                                    noise, timestep,
                                    sample).prev_sample)  # type: ignore
        refiners_output = refiners_scheduler(x=sample, noise=noise, step=step)

        assert allclose(diffusers_output, refiners_output,
                        rtol=0.01), f"outputs differ at step {step}"

test_euler_solver_diffusers()

@limiteinductive (Contributor) left a comment:

Thank you so much for submitting your work! I can see that most of the job has been completed.

And don't forget to lint your code with black and ruff (and check types with pyright).

poetry run black .
poetry run ruff . --fix
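
And, assuming pyright is set up as a dev dependency, type checking would be:

poetry run pyright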

@limiteinductive (Contributor) commented:

I made an end-to-end test with the following script:

from pathlib import Path
from refiners.fluxion.utils import manual_seed
from refiners.foundationals.latent_diffusion.schedulers.euler import EulerScheduler
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import StableDiffusion_1

import torch


torch.set_grad_enabled(False)

test_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
hub = Path("/mnt/ssd2/hub/finegrain/stable-diffusion-1-5")

scheduler = EulerScheduler(30, device=test_device)
sd = StableDiffusion_1(device=test_device, scheduler=scheduler)
sd.unet.load_from_safetensors(hub / "unet.safetensors")
sd.lda.load_from_safetensors(hub / "lda.safetensors")
sd.clip_text_encoder.load_from_safetensors(hub / "CLIPTextEncoderL.safetensors")
n_steps = 30

prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
clip_text_embedding = sd.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
sd.set_num_inference_steps(n_steps)

manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device)

for step in sd.steps:
    x = sd(
        x,
        step=step,
        clip_text_embedding=clip_text_embedding,
        condition_scale=7.5,
    )
predicted_image = sd.lda.decode_latents(x)
predicted_image.save("cute_cat_euler.png")

And I get this:

[image]

@israfelsr (Contributor, Author) commented:

There was an error computing predicted_x. However, the sampling is still quite poor (see images). The "better" result is from using noise_schedule: NoiseSchedule = NoiseSchedule.KARRAS.

With quadratic noise schedule:
[image: cute_cat_euler_29]

With Karras:
[image: cute_cat_euler_29-2]

I tried adding other parameters to compute gamma, but it didn't improve; if anything, it became more unstable.
I also looked into the intermediate step images, and they converge faster to a plausible solution than DDIM. However, it gets stuck there without adding more details and only oversaturates the colors.
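
For reference, the per-step gamma in the paper (Algorithm 2) is defined as follows; this is a minimal sketch using the paper's parameter names, not necessarily this PR's exact code:

import math

# Per-step churn factor from Karras et al., Algorithm 2. gamma controls how
# much noise is re-injected before each Euler step: sigma_hat = sigma * (1 + gamma).
def compute_gamma(sigma: float, num_steps: int, s_churn: float = 0.0,
                  s_tmin: float = 0.0, s_tmax: float = float("inf")) -> float:
    if s_tmin <= sigma <= s_tmax:
        # capped at sqrt(2) - 1 so that sigma_hat at most doubles the variance
        return min(s_churn / num_steps, math.sqrt(2) - 1)
    return 0.0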

Do you have any ideas on how to improve the sampling?

@deltheil (Member) left a comment:

Some parts are missing; I checked that adding them should solve the issue.

@israfelsr (Contributor, Author) commented:

I went back to the 'epsilon' prediction and added the scaling functions. With this, I was able to generate a couple of images, but the algorithm is still quite unstable. The image below, for instance, was generated with s_churn=1.0 and s_noise=1.1 (I set these as the defaults for now, so you can reproduce it).

You can test it using the following code:

from pathlib import Path
from refiners.fluxion.utils import manual_seed
from refiners.foundationals.latent_diffusion.schedulers.euler import EulerScheduler
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import StableDiffusion_1
from tqdm import tqdm

import torch

torch.set_grad_enabled(False)

test_device = torch.device("mps")
hub = Path("./tests/weights")

n_steps = 30
scheduler = EulerScheduler(n_steps, device=test_device)
sd = StableDiffusion_1(device=test_device, scheduler=scheduler)
sd.unet.load_from_safetensors(hub / "unet.safetensors")
sd.lda.load_from_safetensors(hub / "lda.safetensors")
sd.clip_text_encoder.load_from_safetensors(hub /
                                           "CLIPTextEncoderL.safetensors")

prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
clip_text_embedding = sd.compute_clip_text_embedding(
    text=prompt, negative_text=negative_prompt)
sd.set_num_inference_steps(n_steps)

manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device)

x = x * scheduler.init_noise_sigma
for step in tqdm(sd.steps):
    x = sd(
        x,
        step=step,
        clip_text_embedding=clip_text_embedding,
        condition_scale=7.5,
    )
predicted_image = sd.lda.decode_latents(x)
predicted_image.save("cute_cat_euler.png")

Notice that I'm scaling the latents before the loop, and I'm also adding the scaling before the UNet call. This is the solution used in diffusers; for the other schedulers, the scaling function just returns its input unchanged (here).
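
For reference, a minimal sketch of that scaling as diffusers implements it (EulerDiscreteScheduler.scale_model_input), where sigma is the noise level of the current step:

from torch import Tensor

# Pre-UNet input scaling used by Euler-type schedulers in diffusers;
# DDIM-style schedulers return the sample unchanged instead.
def scale_model_input(sample: Tensor, sigma: float) -> Tensor:
    return sample / (sigma**2 + 1) ** 0.5

As for the factor applied before the loop: with "trailing" or "linspace" timestep spacing, diffusers' init_noise_sigma is simply the largest sigma of the schedule.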

[image: churn-1-noise1-1]

@deltheil (Member) commented:

I went back to the 'epsilon' prediction and added the scaling functions.

Thanks! I will have a close look, stay tuned. In the meantime: could you please sync your main branch with https://github.com/finegrain-ai/refiners and then rebase israfelsr:feature/euler-scheduler?

@deltheil (Member) left a comment:

FYI, by going back to the defaults for s_churn and s_noise, and fixing scale_model_input (see comment), I get:

[image: cute_cat_euler]

@deltheil (Member) commented:

Hey @israfelsr, are you done with it? i.e. is it ready for (final) review? Thanks!

@israfelsr (Contributor, Author) commented:

Ready for the final review! 🙌🏽

@deltheil added the run-ci label Dec 27, 2023
@deltheil (Member) left a comment:

Please see additional comments. Also, could you please rebase on main so as to incorporate the latest changes?

@deltheil removed the run-ci label Jan 10, 2024
@deltheil (Member) left a comment:

Please take a look at the two final comments. We should be good to go right after. Thanks!

@deltheil added the run-ci label Jan 10, 2024
@deltheil added and removed the run-ci label Jan 10, 2024
@limiteinductive self-requested a review January 10, 2024 10:25
@deltheil merged commit 8423c5e into finegrain-ai:main Jan 10, 2024
1 check passed
@israfelsr deleted the feature/euler-scheduler branch January 10, 2024 10:45