improve: update to diffusers 0.4.2
samedii committed Oct 13, 2022
1 parent 04a558b commit 3763564
Showing 5 changed files with 40 additions and 28 deletions.
docs/source/conf.py (2 changes: 1 addition & 1 deletion)

@@ -16,7 +16,7 @@
 project = "perceptor"
 copyright = "2022, Richard Löwenström"
 author = "Richard Löwenström"
-release = "v0.6.2"
+release = "v0.6.3"

 # -- General configuration ---------------------------------------------------
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
perceptor/models/stable_diffusion/stable_diffusion.py (31 changes: 20 additions & 11 deletions)

@@ -16,7 +16,7 @@
 # @cache
 class StableDiffusion(torch.nn.Module):
     def __init__(
-        self, name="CompVis/stable-diffusion-v1-4", fp16=False, auth_token=True
+        self, name="CompVis/stable-diffusion-v1-4", fp16=True, auth_token=True
     ):
         """
         Args:
@@ -37,12 +37,14 @@ def __init__(
             name,
             scheduler=scheduler,
             use_auth_token=auth_token,
-            **dict(
-                revision="fp16",
-                torch_dtype=torch.float16,
-            )
-            if fp16
-            else dict(),
+            **(
+                dict(
+                    revision="fp16",
+                    torch_dtype=torch.float16,
+                )
+                if fp16
+                else dict()
+            ),
         )

         self.vae = pipeline.vae
@@ -124,13 +126,15 @@ def encode(
             raise Exception(f"Width must be divisible by 32, got {w}")
         return (
             0.18215
-            * self.vae.encode(diffusion_space.encode(images.to(self.device))).mode()
+            * self.vae.encode(
+                diffusion_space.encode(images.to(self.device))
+            ).latent_dist.mode()
         )

     def decode(
         self, latents: lantern.Tensor.dims("NCHW").float()
     ) -> lantern.Tensor.dims("NCHW"):
-        return diffusion_space.decode(self.vae.decode(latents / 0.18215))
+        return diffusion_space.decode(self.vae.decode(latents / 0.18215).sample)

     @contextmanager
     def finetuneable_vae(self):
@@ -167,7 +171,10 @@ def random_diffused_latents(self, shape) -> lantern.Tensor:
             raise ValueError("Height must be divisible by 32")
         if w % 8 != 0:
             raise ValueError("Width must be divisible by 32")
-        return torch.randn((n, self.unet.in_channels, h // 8, w // 8)).to(self.device)
+        return (
+            torch.randn((n, self.unet.in_channels, h // 8, w // 8)).to(self.device)
+            * self.scheduler.init_noise_sigma
+        )

     def indices(self, indices) -> lantern.Tensor:
         if isinstance(indices, float) or isinstance(indices, int):
@@ -281,7 +288,9 @@ def test_stable_diffusion_step():

     # compare with diffusers
     scheduler = DDIMScheduler(
-        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
+        beta_start=0.00085,
+        beta_end=0.012,
+        beta_schedule="scaled_linear",
     )
     scheduler.set_timesteps(1000)
     pipeline = StableDiffusionPipeline.from_pretrained(
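The stable_diffusion.py changes above track API changes in diffusers 0.4.x: vae.encode() and vae.decode() now return output objects (latents live under .latent_dist, decoded images under .sample), and schedulers expose init_noise_sigma for scaling the initial noise. Below is a minimal sketch of those calls, not part of the commit; the model name, the 0.18215 latent scaling, and the DDIM betas are taken from the diff, everything else is illustrative.

import torch
from diffusers import DDIMScheduler, StableDiffusionPipeline

scheduler = DDIMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
)
# Passing revision="fp16" and torch_dtype=torch.float16 here would load the
# half precision weights, which the wrapper now requests by default (fp16=True).
pipeline = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    scheduler=scheduler,
    use_auth_token=True,
)
vae = pipeline.vae

images = torch.rand(1, 3, 512, 512) * 2 - 1  # images scaled to [-1, 1]

with torch.no_grad():
    # vae.encode() returns an output object; the latents come from its
    # .latent_dist distribution via .mode() (or .sample()).
    latents = 0.18215 * vae.encode(images).latent_dist.mode()

    # vae.decode() likewise wraps its result; the decoded tensor is in .sample.
    decoded = vae.decode(latents / 0.18215).sample

# Schedulers now expose init_noise_sigma (1.0 for DDIM), so freshly drawn
# latent noise is multiplied by it before the denoising loop.
noise = torch.randn_like(latents) * scheduler.init_noise_sigma

The left-hand side of the diff shows the older 0.2.x behaviour, where encode() returned the latent distribution directly, decode() returned a tensor, and fresh latents were not scaled by a scheduler sigma.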
perceptor/models/velocity_diffusion/velocity_diffusion.py (4 changes: 3 additions & 1 deletion)

@@ -222,7 +222,9 @@ def test_conditioned_velocity_diffusion():
 def test_convert_sigma_ts():
     diffusion = VelocityDiffusion("cc12m_1_cfg")
     from_ts = 0.3
-    assert from_ts == diffusion.sigmas_to_ts(diffusion.sigmas(from_ts))
+    assert (
+        from_ts - diffusion.sigmas_to_ts(diffusion.sigmas(from_ts)).squeeze()
+    ).abs() <= 1e-5


 def test_schedule_ts():
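The test change above swaps exact equality for an absolute tolerance, since the sigma-to-t conversion only round-trips to floating point accuracy. A tiny stand-alone illustration of the same check; the conversion expressions here are stand-ins, not the library's own functions.

import torch

from_ts = 0.3
sigmas = torch.sin(torch.tensor(from_ts) * torch.pi / 2)  # stand-in for diffusion.sigmas(from_ts)
round_tripped = torch.asin(sigmas) * 2 / torch.pi         # stand-in for diffusion.sigmas_to_ts(sigmas)

# Bit-exact equality can fail after the float round trip; compare with a tolerance.
assert (from_ts - round_tripped.squeeze()).abs() <= 1e-5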
poetry.lock (27 changes: 14 additions & 13 deletions)

Some generated files are not rendered by default.

pyproject.toml (4 changes: 2 additions & 2 deletions)

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "perceptor"
-version = "0.6.2"
+version = "0.6.3"
 description = "Modular image generation library"
 authors = ["Richard Löwenström <samedii@gmail.com>", "dribnet"]
 readme = "README.md"

@@ -31,7 +31,7 @@ ninja = "^1.10.2"
 lpips = "^0.1.4"
 pytorch-lantern = "^0.12.0"
 taming-transformers-rom1504 = "^0.0.6"
-diffusers = "^0.2.4"
+diffusers = "^0.4.2"
 open-clip-torch = "^2.0.2"
 pytorch-zero-lit = "^0.2.2"
