diff --git a/perceptor/models/guided_diffusion/predictions.py b/perceptor/models/guided_diffusion/predictions.py index b5b6ca6..ae7ae4f 100644 --- a/perceptor/models/guided_diffusion/predictions.py +++ b/perceptor/models/guided_diffusion/predictions.py @@ -172,12 +172,9 @@ def dynamic_threshold(self, quantile=0.95) -> "Predictions": def forced_denoised_images(self, denoised_images) -> "Predictions": denoised_xs = diffusion_space.encode(denoised_images) - if (self.from_sigmas >= 1e-3).all(): - predicted_noise = ( - self.from_diffused_xs - denoised_xs * self.from_alphas - ) / self.from_sigmas - else: - predicted_noise = self.predicted_noise + predicted_noise = ( + self.from_diffused_xs - denoised_xs * self.from_alphas + ) / self.from_sigmas.clamp(min=1e-7) return self.replace(predicted_noise=predicted_noise) def forced_predicted_noise(self, predicted_noise) -> "Predictions": diff --git a/perceptor/models/stable_diffusion/predictions.py b/perceptor/models/stable_diffusion/predictions.py index 981e1f4..08825de 100644 --- a/perceptor/models/stable_diffusion/predictions.py +++ b/perceptor/models/stable_diffusion/predictions.py @@ -188,17 +188,10 @@ def dynamic_threshold(self, quantile=0.95) -> "Predictions": return self.forced_denoised_latents(denoised_latents) def forced_denoised_latents(self, denoised_latents) -> "Predictions": - # if (self.from_sigmas >= 1e-3).all(): - # predicted_noise = ( - # self.from_diffused_latents - denoised_latents * self.from_alphas - # ) / self.from_sigmas - # else: - # predicted_noise = self.predicted_noise - # return self.replace( - # velocities=self.from_alphas * predicted_noise - # - self.from_sigmas * denoised_latents - # ) - pass + predicted_noise = ( + self.from_diffused_latents - denoised_latents * self.from_alphas + ) / self.from_sigmas.clamp(min=1e-7) + return self.replace(predicted_noise=predicted_noise) def forced_predicted_noise(self, predicted_noise) -> "Predictions": return self.replace(predicted_noise=predicted_noise) diff --git a/perceptor/models/stable_diffusion/stable_diffusion.py b/perceptor/models/stable_diffusion/stable_diffusion.py index 9d07f80..e7b831e 100644 --- a/perceptor/models/stable_diffusion/stable_diffusion.py +++ b/perceptor/models/stable_diffusion/stable_diffusion.py @@ -120,7 +120,7 @@ def encode( raise Exception(f"Width must be divisible by 32, got {w}") return ( 0.18215 - * self.vae.encode(diffusion_space.encode(images.to(self.device))).sample() + * self.vae.encode(diffusion_space.encode(images.to(self.device))).mode() ) def decode( diff --git a/perceptor/models/velocity_diffusion/predictions.py b/perceptor/models/velocity_diffusion/predictions.py index 7c1f128..3ae2e1f 100644 --- a/perceptor/models/velocity_diffusion/predictions.py +++ b/perceptor/models/velocity_diffusion/predictions.py @@ -104,7 +104,7 @@ def step(self, to_ts, eta=0.0): return diffusion_space.decode(to_diffused_xs) def correction(self, previous: "Predictions"): - return previous.forced_denoised( + return previous.forced_denoised_images( (self.denoised_images + previous.denoised_images) / 2 ) @@ -169,12 +169,12 @@ def dynamic_threshold(self, quantile=0.95) -> "Predictions": # / dynamic_threshold # imagen's dynamic thresholding divides by threshold but this makes the images gray ) - return self.forced_denoised(diffusion_space.decode(denoised_xs)) + return self.forced_denoised_images(diffusion_space.decode(denoised_xs)) def static_threshold(self): - return self.forced_denoised(clamp_with_grad(self.denoised_images, 0, 1)) + return self.forced_denoised_images(clamp_with_grad(self.denoised_images, 0, 1)) - def forced_denoised(self, denoised_images) -> "Predictions": + def forced_denoised_images(self, denoised_images) -> "Predictions": denoised_xs = diffusion_space.encode(denoised_images) if (self.from_sigmas >= 1e-3).all(): predicted_noise = ( diff --git a/pyproject.toml b/pyproject.toml index c523487..896113c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "perceptor" -version = "0.6.0" +version = "0.6.1" description = "Modular image generation library" authors = ["Richard Löwenström ", "dribnet"] readme = "README.md" @@ -14,6 +14,7 @@ numpy = "^1.22.2" tqdm = "^4.62.3" einops = "^0.4.0" imageio = "^2.14.1" + kornia = "^0.6.3" Pillow = "^9.0.1" timm = "^0.5.4"