Skip to content

Commit

Permalink
improve: stable diffusion vae mode and forced denoised
Browse files Browse the repository at this point in the history
  • Loading branch information
samedii committed Sep 27, 2022
1 parent 8daccb4 commit da87bf7
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 23 deletions.
9 changes: 3 additions & 6 deletions perceptor/models/guided_diffusion/predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,12 +172,9 @@ def dynamic_threshold(self, quantile=0.95) -> "Predictions":

def forced_denoised_images(self, denoised_images) -> "Predictions":
denoised_xs = diffusion_space.encode(denoised_images)
if (self.from_sigmas >= 1e-3).all():
predicted_noise = (
self.from_diffused_xs - denoised_xs * self.from_alphas
) / self.from_sigmas
else:
predicted_noise = self.predicted_noise
predicted_noise = (
self.from_diffused_xs - denoised_xs * self.from_alphas
) / self.from_sigmas.clamp(min=1e-7)
return self.replace(predicted_noise=predicted_noise)

def forced_predicted_noise(self, predicted_noise) -> "Predictions":
Expand Down
15 changes: 4 additions & 11 deletions perceptor/models/stable_diffusion/predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,17 +188,10 @@ def dynamic_threshold(self, quantile=0.95) -> "Predictions":
return self.forced_denoised_latents(denoised_latents)

def forced_denoised_latents(self, denoised_latents) -> "Predictions":
# if (self.from_sigmas >= 1e-3).all():
# predicted_noise = (
# self.from_diffused_latents - denoised_latents * self.from_alphas
# ) / self.from_sigmas
# else:
# predicted_noise = self.predicted_noise
# return self.replace(
# velocities=self.from_alphas * predicted_noise
# - self.from_sigmas * denoised_latents
# )
pass
predicted_noise = (
self.from_diffused_latents - denoised_latents * self.from_alphas
) / self.from_sigmas.clamp(min=1e-7)
return self.replace(predicted_noise=predicted_noise)

def forced_predicted_noise(self, predicted_noise) -> "Predictions":
return self.replace(predicted_noise=predicted_noise)
Expand Down
2 changes: 1 addition & 1 deletion perceptor/models/stable_diffusion/stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def encode(
raise Exception(f"Width must be divisible by 32, got {w}")
return (
0.18215
* self.vae.encode(diffusion_space.encode(images.to(self.device))).sample()
* self.vae.encode(diffusion_space.encode(images.to(self.device))).mode()
)

def decode(
Expand Down
8 changes: 4 additions & 4 deletions perceptor/models/velocity_diffusion/predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def step(self, to_ts, eta=0.0):
return diffusion_space.decode(to_diffused_xs)

def correction(self, previous: "Predictions"):
return previous.forced_denoised(
return previous.forced_denoised_images(
(self.denoised_images + previous.denoised_images) / 2
)

Expand Down Expand Up @@ -169,12 +169,12 @@ def dynamic_threshold(self, quantile=0.95) -> "Predictions":
# / dynamic_threshold
# imagen's dynamic thresholding divides by threshold but this makes the images gray
)
return self.forced_denoised(diffusion_space.decode(denoised_xs))
return self.forced_denoised_images(diffusion_space.decode(denoised_xs))

def static_threshold(self):
return self.forced_denoised(clamp_with_grad(self.denoised_images, 0, 1))
return self.forced_denoised_images(clamp_with_grad(self.denoised_images, 0, 1))

def forced_denoised(self, denoised_images) -> "Predictions":
def forced_denoised_images(self, denoised_images) -> "Predictions":
denoised_xs = diffusion_space.encode(denoised_images)
if (self.from_sigmas >= 1e-3).all():
predicted_noise = (
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "perceptor"
version = "0.6.0"
version = "0.6.1"
description = "Modular image generation library"
authors = ["Richard Löwenström <samedii@gmail.com>", "dribnet"]
readme = "README.md"
Expand All @@ -14,6 +14,7 @@ numpy = "^1.22.2"
tqdm = "^4.62.3"
einops = "^0.4.0"
imageio = "^2.14.1"

kornia = "^0.6.3"
Pillow = "^9.0.1"
timm = "^0.5.4"
Expand Down

0 comments on commit da87bf7

Please sign in to comment.