
fix!: face diffusion working and refactoring
random diffused and remove pseudo linear sampler

BREAKING CHANGE
samedii committed Aug 1, 2022
1 parent e50ac20 commit b6aa9ec
Showing 11 changed files with 164 additions and 202 deletions.
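The breaking change replaces the removed PseudoLinearSampler drawer with a plain sampling loop driven by the new schedule_indices and random_diffused helpers. A minimal migration sketch, assembled from the test functions added in this commit (the import path follows the docstrings in this diff; the save path is illustrative):

    import perceptor

    # new API: no PseudoLinearSampler; the model hands out (from_index, to_index) pairs
    model = perceptor.models.GuidedDiffusion("pixelart").cuda()
    diffused = model.random_diffused((1, 3, 256, 256))  # decoded random noise

    for from_index, to_index in model.schedule_indices(n_steps=50):
        eps = model.eps(diffused, from_index)
        diffused = model.step(diffused, eps, from_index, to_index)

    denoised = model.denoise(diffused, to_index)
    perceptor.utils.pil_image(denoised).save("pixelart.png")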
2 changes: 1 addition & 1 deletion perceptor/drawers/__init__.py
@@ -1,5 +1,5 @@
 from .deep_image_prior import DeepImagePrior
-from .diffusion import BruteDiffusion, PseudoLinearSampler
+from .diffusion import BruteDiffusion
 from .raw import Raw
 from .jpeg import JPEG
 from .rudalle import BruteRuDalle
1 change: 0 additions & 1 deletion perceptor/drawers/diffusion/__init__.py
@@ -1,2 +1 @@
 from .brute_diffusion import BruteDiffusion
-from .pseudo_linear_sampler import PseudoLinearSampler
71 changes: 0 additions & 71 deletions perceptor/drawers/diffusion/pseudo_linear_sampler.py

This file was deleted.

6 changes: 4 additions & 2 deletions perceptor/losses/lpips.py
@@ -4,15 +4,17 @@


 class LPIPS(LossInterface):
-    def __init__(self, name="squeeze", linear_layers=True):
+    def __init__(self, name="squeeze", linear_layers=True, spatial=False):
         """
         LPIPS loss. Expects images of shape (batch_size, 3, height, width) between 0 and 1.
         Args:
             name (str): name of the loss. Available options: ["alex", "vgg", "squeeze"]
         """
         super().__init__()
-        self.model = lpips.LPIPS(net=name, lpips=linear_layers, verbose=False)
+        self.model = lpips.LPIPS(
+            net=name, lpips=linear_layers, spatial=spatial, verbose=False
+        )
         self.model.eval()
         self.model.requires_grad_(False)

44 changes: 32 additions & 12 deletions perceptor/models/guided_diffusion/guided_diffusion.py
@@ -22,21 +22,13 @@ def __init__(self, name="standard"):
         diffusion = models.GuidedDiffusion("pixelart").to(device)
-        from_index = 999
-        n_steps = 200
-        indices = torch.linspace(from_index, 0, n_steps).to(device).long()
+        diffused_image = torch.randn((1, 3, 256, 256))
-        diffused_image = diffusion.diffuse(
-            init_images,
-            indices[0],
-        )
-        for from_index, to_index in zip(
-            indices[:-1], indices[1:]
-        ):
+        for from_index, to_index in model.schedule_indices():
             eps = diffusion.eps(diffused_image, from_index)
             denoised_image = diffusion.denoise(diffused_image, from_index, eps)
             diffused_image = diffusion.step(diffused_image, eps, from_index, to_index)
         denoised_image = diffusion.denoise(diffused_image, to_index)
         """
         super().__init__()
         self.name = name
@@ -66,6 +58,21 @@ def __init__(self, name="standard"):
     def device(self):
         return next(iter(self.parameters())).device

+    def schedule_indices(self, from_index=999, to_index=20, n_steps=None):
+        if from_index < to_index:
+            raise ValueError("from_index must be greater than to_index")
+        if n_steps is None:
+            n_steps = (from_index - to_index) // 2
+        schedule_indices = torch.linspace(from_index, to_index, n_steps).long()
+        from_indices = schedule_indices[:-1]
+        to_indices = schedule_indices[1:]
+        if (from_indices == to_indices).any():
+            raise ValueError("Schedule indices must be unique")
+        return zip(from_indices, to_indices)
+
+    def random_diffused(self, shape):
+        return diffusion_space.decode(torch.randn(shape)).to(self.device)
+
     def forward(self, diffused, from_index):
         return self.denoise(diffused, from_index)

@@ -138,7 +145,7 @@ def sqrt_one_minus_alphas_cumprod(self, index):
             .to(self.device)[None, None, None, None]
         )

-    def step(self, from_diffused, eps, from_index, to_index, noise=None, eta=0.0):
+    def step(self, from_diffused, eps, from_index, to_index, noise=None, eta=1.0):
         if to_index > from_index:
             raise ValueError("to_index must be smaller than from_index")
         if noise is None:
@@ -272,3 +279,16 @@ def create_model(
         resblock_updown=resblock_updown,
         use_new_attention_order=use_new_attention_order,
     )
+
+
+def test_pixelart_diffusion():
+    from perceptor import utils
+
+    model = GuidedDiffusion("pixelart").cuda()
+    diffused = model.random_diffused((1, 3, 256, 256))
+
+    for from_index, to_index in model.schedule_indices(n_steps=50):
+        eps = model.eps(diffused, from_index)
+        diffused = model.step(diffused, eps, from_index, to_index)
+    denoised = model.denoise(diffused, to_index)
+    utils.pil_image(denoised).save("tests/pixelart.png")
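Two behavioural notes on this file: schedule_indices yields consecutive (from_index, to_index) pairs from a torch.linspace schedule (defaulting to roughly one step per two indices), and the default eta in step moved from 0.0 to 1.0, which in DDIM-style samplers means fully stochastic, DDPM-like steps instead of deterministic ones. A small sketch of the pairing logic, mirroring the method above:

    import torch

    # same construction as GuidedDiffusion.schedule_indices with its defaults
    from_index, to_index, n_steps = 999, 20, (999 - 20) // 2
    indices = torch.linspace(from_index, to_index, n_steps).long()
    pairs = list(zip(indices[:-1], indices[1:]))
    # pairs walk monotonically from 999 down to 20, each spanning ~2 indices
    assert pairs[0][0] == 999 and pairs[-1][1] == 20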
@@ -20,26 +20,26 @@ model:
       out_channels: 3
       model_channels: 224
       attention_resolutions:
-      # note: this isn't actually the resolution but
-      # the downsampling factor, i.e. this corresponds to
-      # attention on spatial resolution 8,16,32, as the
-      # spatial resolution of the latents is 64 for f4
-      - 8
-      - 4
-      - 2
+        # note: this isn't actually the resolution but
+        # the downsampling factor, i.e. this corresponds to
+        # attention on spatial resolution 8,16,32, as the
+        # spatial resolution of the latents is 64 for f4
+        - 8
+        - 4
+        - 2
       num_res_blocks: 2
       channel_mult:
-      - 1
-      - 2
-      - 3
-      - 4
+        - 1
+        - 2
+        - 3
+        - 4
       num_head_channels: 32
     first_stage_config:
       target: ldm.models.autoencoder.VQModelInterface
       params:
         embed_dim: 3
         n_embed: 8192
-        ckpt_path: models/first_stage_models/vq-f4/model.ckpt
+        ckpt_path: models/latent-diffusion-vq-f4.pt
         ddconfig:
           double_z: false
           z_channels: 3
@@ -48,9 +48,9 @@ model:
           out_ch: 3
           ch: 128
           ch_mult:
-          - 1
-          - 2
-          - 4
+            - 1
+            - 2
+            - 4
           num_res_blocks: 2
           attn_resolutions: []
           dropout: 0.0
@@ -72,7 +72,6 @@ data:
     params:
       size: 256

-
 lightning:
   callbacks:
     image_logger:
@@ -83,4 +82,4 @@ lightning:
       increase_log_steps: False

 trainer:
-  benchmark: True
+  benchmark: True
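The comment kept in this config deserves a gloss: with the vq-f4 first stage, a 256x256 image is encoded to a 64x64 latent grid, so the attention_resolutions entries are downsampling factors of that grid rather than absolute resolutions. A sketch of the arithmetic (256x256 inputs assumed, as elsewhere in this commit):

    image_size = 256
    latent_size = image_size // 4               # vq-f4 encodes at factor 4 -> 64
    factors = [8, 4, 2]                         # attention_resolutions above
    print([latent_size // f for f in factors])  # [8, 16, 32], as the comment says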
52 changes: 47 additions & 5 deletions perceptor/models/latent_diffusion/face.py
@@ -14,6 +14,22 @@
 @cache
 class Face(torch.nn.Module):
     def __init__(self, eta=0.0):
+        """
+        Usage:
+            from tqdm import tqdm
+            import perceptor
+
+            model = perceptor.models.latent_diffusion.Face().cuda()
+            diffused_latents = model.random_latents((1, 3, 256, 256)).cuda()
+
+            for from_index, to_index in tqdm(model.schedule_indices(n_steps=50)):
+                denoised_latents = model.denoise(diffused_latents, from_index)
+                diffused_latents = model.step(
+                    diffused_latents, denoised_latents, from_index, to_index
+                )
+            denoised_latents = model.denoise(diffused_latents, to_index)
+            images = model.images(denoised_latents)
+        """
         super().__init__()
         self.eta = eta

@@ -42,19 +58,29 @@ def __init__(self, eta=0.0):
     def device(self):
         return next(iter(self.parameters())).device

+    def schedule_indices(self, from_index=999, to_index=50, n_steps=None):
+        if from_index < to_index:
+            raise ValueError("from_index must be greater than to_index")
+        if n_steps is None:
+            n_steps = (from_index - to_index) // 2
+        schedule_indices = torch.linspace(from_index, to_index, n_steps).long()
+        from_indices = schedule_indices[:-1]
+        to_indices = schedule_indices[1:]
+        if (from_indices == to_indices).any():
+            raise ValueError("Schedule indices must be unique")
+        return zip(from_indices, to_indices)
+
     @staticmethod
     def latent_shape(height, width):
-        return [4, height // 8, width // 8]
+        return [3, height // 4, width // 4]

     def forward(self, latents, index):
         return self.velocity(latents, index)

     def velocity(self, latents, index):
         raise NotImplementedError()

     def random_latents(self, images_shape):
         return torch.randn(
-            images_shape[0], *self.latent_shape(*images_shape[-2:]), device=self.device
+            (images_shape[0], *self.latent_shape(*images_shape[-2:])),
+            device=self.device,
         )

     def latents(self, images):
@@ -145,3 +171,19 @@ def step(

     def eps(self, latents, index):
         return self.model.apply_model(latents, self.ts(index), cond=None)
+
+
+def test_face():
+    import perceptor
+
+    model = perceptor.models.latent_diffusion.Face().cuda()
+    diffused_latents = model.random_latents((1, 3, 256, 256)).cuda()
+
+    for from_index, to_index in model.schedule_indices(to_index=50, n_steps=50):
+        denoised_latents = model.denoise(diffused_latents, from_index)
+        diffused_latents = model.step(
+            diffused_latents, denoised_latents, from_index, to_index
+        )
+    denoised_latents = model.denoise(diffused_latents, to_index)
+    images = model.images(denoised_latents)
+    perceptor.utils.pil_image(images).save("tests/face.png")
33 changes: 30 additions & 3 deletions perceptor/models/latent_diffusion/finetuned_text2image.py
@@ -110,16 +110,25 @@ def __init__(self, guidance_scale=5, eta=0.0):
     def device(self):
         return next(iter(self.parameters())).device

+    def schedule_indices(self, from_index=999, to_index=50, n_steps=None):
+        if from_index < to_index:
+            raise ValueError("from_index must be greater than to_index")
+        if n_steps is None:
+            n_steps = (from_index - to_index) // 2
+        schedule_indices = torch.linspace(from_index, to_index, n_steps).long()
+        from_indices = schedule_indices[:-1]
+        to_indices = schedule_indices[1:]
+        if (from_indices == to_indices).any():
+            raise ValueError("Schedule indices must be unique")
+        return zip(from_indices, to_indices)
+
     @staticmethod
     def latent_shape(height, width):
         return [4, height // 8, width // 8]

     def forward(self, latents, conditioning, index):
         return self.denoise(latents, conditioning, index)

     def velocity(self, latents, conditioning, index):
         raise NotImplementedError()

     def random_latents(self, images_shape):
         return torch.randn(
             images_shape[0], *self.latent_shape(*images_shape[-2:]), device=self.device
@@ -258,3 +267,21 @@ def eps(self, latents, index, conditioning):
         return eps_unconditioned + self.guidance_scale * (
             eps_conditioned - eps_unconditioned
         )
+
+
+def test_finetuned_text2image():
+    from tqdm import tqdm
+    import perceptor
+
+    model = perceptor.models.latent_diffusion.FinetunedText2Image().cuda()
+    conditioning = model.conditioning(["photograph of a playful cat"])
+    diffused_latents = model.random_latents((1, 3, 512, 512)).cuda()
+
+    for from_index, to_index in tqdm(model.schedule_indices(to_index=50, n_steps=50)):
+        denoised_latents = model.denoise(diffused_latents, from_index, conditioning)
+        diffused_latents = model.step(
+            diffused_latents, denoised_latents, from_index, to_index
+        )
+    denoised_latents = model.denoise(diffused_latents, to_index, conditioning)
+    images = model.images(denoised_latents)
+    perceptor.utils.pil_image(images).save("tests/finetuned_text2image.png")
