Skip to content

Commit

Permalink
refactor!: easier to use timestemp schedules in diffusion
Browse files Browse the repository at this point in the history
BREAKING CHANGE
  • Loading branch information
samedii committed Aug 8, 2022
1 parent b6aa9ec commit 3a36c10
Show file tree
Hide file tree
Showing 9 changed files with 70 additions and 77 deletions.
1 change: 0 additions & 1 deletion perceptor/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from .lpips import LPIPS
from .super_resolution import SuperResolution, SuperResolutionDiscriminator
from .memorability import Memorability
from .midas_depth import MidasDepth
from .open_clip import OpenCLIP
from .resize import Resize
from .ruclip import RuCLIP
Expand Down
41 changes: 0 additions & 41 deletions perceptor/losses/midas_depth.py

This file was deleted.

26 changes: 25 additions & 1 deletion perceptor/losses/open_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,31 @@


class OpenCLIP(LossInterface):
def __init__(self, architecture="ViT-B-32-quickgelu", weights="laion400m_e31"):
def __init__(self, architecture="ViT-B-32", weights="laion2b_e16"):
"""
Args:
archicture (str): name of the clip model
weights (str): name of the weights
Available weight/model combinations are (in order of relevance):
- ("ViT-B-32", "laion2b_e16") (65.62%)
- ("ViT-B-16-plus-240", "laion400m_e32") (69.21%)
- ("ViT-B-16", "laion400m_e32") (67.07%)
- ("ViT-B-32", "laion400m_e32") (62.96%)
- ("ViT-L-14", "laion400m_e32") (72.77%)
- ("RN101", "yfcc15m") (34.8%)
- ("RN50", "yfcc15m") (32.7%)
- ("RN50", "cc12m") (36.45%)
- ("RN50-quickgelu", "openai")
- ("RN101-quickgelu", "openai")
- ("RN50x4", "openai")
- ("RN50x16", "openai")
- ("RN50x64", "openai")
- ("ViT-B-32-quickgelu", "openai")
- ("ViT-B-16", "openai")
- ("ViT-L-14", "openai")
- ("ViT-L-14-336", "openai")
"""
super().__init__()
self.architecture = architecture
self.weights = weights
Expand Down
53 changes: 28 additions & 25 deletions perceptor/models/monster_diffusion/monster_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,23 +41,15 @@ def training_ts(size):
random_ts = (diffusion.P_mean + torch.randn(size) * diffusion.P_std).exp()
return random_ts

@staticmethod
def schedule_ts(n_steps):
ramp = torch.linspace(0, 1, n_steps)
def _schedule_ts(self, n_steps):
ramp = torch.linspace(0, 1, n_steps).to(self.device)
min_inv_rho = diffusion.sigma_min ** (1 / diffusion.rho)
max_inv_rho = diffusion.sigma_max ** (1 / diffusion.rho)
return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** diffusion.rho

@staticmethod
def evaluation_ts():
n_steps = 1000
schedule_ts = MonsterDiffusion.schedule_ts(n_steps)
return torch.cat(
[
schedule_ts,
MonsterDiffusion.reversed_ts(schedule_ts, n_steps),
]
).unique()
def schedule_ts(self, n_steps):
schedule_ts = self._schedule_ts(n_steps)
return zip(schedule_ts[:-1], schedule_ts[1:])

@staticmethod
def sigmas(ts):
Expand All @@ -70,8 +62,7 @@ def alphas(ts):
@staticmethod
def random_noise(size):
return standardize.decode(
torch.randn(size, *settings.INPUT_SHAPE)
* MonsterDiffusion.sigmas(MonsterDiffusion.schedule_ts(100)[:1])
torch.randn(size, *settings.INPUT_SHAPE) * diffusion.sigma_max
)

@staticmethod
Expand Down Expand Up @@ -147,7 +138,7 @@ def forward(
return PredictionBatch(
denoised_xs=denoised_xs,
diffused_images=diffused_images,
ts=ts,
ts=torch.as_tensor(ts).flatten().to(self.device),
)

def predictions_(
Expand Down Expand Up @@ -237,11 +228,12 @@ def elucidated_sample(
)

n_steps = n_evaluations // 2
schedule_ts = self.schedule_ts(n_steps)[:, None].repeat(1, size).to(self.device)
i = 0
progress = tqdm(total=n_steps, disable=not progress, leave=False)
for from_ts, to_ts in zip(schedule_ts[:-1], schedule_ts[1:]):
reversed_ts = self.reversed_ts(from_ts, n_steps).clamp(max=schedule_ts[0])
for from_ts, to_ts in self.schedule_ts(n_steps):
reversed_ts = self.reversed_ts(from_ts, n_steps).clamp(
max=diffusion.sigma_max
)
reversed_diffused_images = self.inject_noise(
diffused_images, from_ts, reversed_ts
)
Expand Down Expand Up @@ -317,13 +309,11 @@ def linear_multistep_sample(
diffused_images = diffused_images.to(self.device)

n_steps = n_evaluations
schedule_ts = self.schedule_ts(n_steps)[:, None].repeat(1, size).to(self.device)
schedule_ts = self._schedule_ts(n_steps)

epses = list()
progress = tqdm(total=n_steps, disable=not progress, leave=False)
for from_index, from_ts, to_ts in zip(
range(n_steps), schedule_ts[:-1], schedule_ts[1:]
):
for from_index, (from_ts, to_ts) in enumerate(self.schedule_ts(n_steps)):

predictions = self.predictions(
diffused_images,
Expand All @@ -338,7 +328,7 @@ def linear_multistep_sample(
coeffs = [
self.linear_multistep_coeff(
current_order,
self.sigmas(schedule_ts[:, 0]).cpu().flatten(),
self.sigmas(schedule_ts).cpu().flatten(),
from_index,
to_index,
)
Expand All @@ -364,5 +354,18 @@ def linear_multistep_sample(


def test_monster_diffusion():
from perceptor import utils

model = MonsterDiffusion().cuda()
for images in model.sample(size=1, n_evaluations=50):
pass
utils.pil_image(images).save("tests/monster_diffusion.png")


def test_monster_diffusion_lms():
from perceptor import utils

model = MonsterDiffusion().cuda()
model.sample(size=1, n_evaluations=4)
for images in model.linear_multistep_sample(size=1, n_evaluations=50):
pass
utils.pil_image(images).save("tests/monster_diffusion_lms.png")
4 changes: 4 additions & 0 deletions perceptor/models/monster_diffusion/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ def __iter__(self):

@staticmethod
def sigmas(ts):
if isinstance(ts, float):
ts = torch.as_tensor(ts)
if ts.ndim == 0:
return torch.full((1,), ts).to(ts.device)
return ts[:, None, None, None]

@staticmethod
Expand Down
4 changes: 2 additions & 2 deletions perceptor/models/open_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ class OpenCLIP(torch.nn.Module):
def __init__(self, archicture="ViT-B-32", weights="laion2b_e16"):
"""
Args:
archicture: name of the clip model
weights: name of the weights
archicture (str): name of the clip model
weights (str): name of the weights
Available weight/model combinations are (in order of relevance):
- ("ViT-B-32", "laion2b_e16") (65.62%)
Expand Down
12 changes: 7 additions & 5 deletions perceptor/models/velocity_diffusion/velocity_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,17 @@ def shape(self):
return self.model.shape

@staticmethod
def schedule_ts(n_steps, from_sigma=1, to_sigma=1e-2, rho=0.7):
def schedule_ts(n_steps=500, from_sigma=1, to_sigma=1e-2, rho=0.7):
ramp = torch.linspace(0, 1, n_steps + 1)
min_inv_rho = to_sigma ** (1 / rho)
max_inv_rho = from_sigma ** (1 / rho)
return Model.sigmas_to_ts(
schedule_ts = Model.sigmas_to_ts(
(max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
)
return zip(schedule_ts[:-1], schedule_ts[1:])

def random_diffused(self, shape):
return diffusion_space.decode(torch.randn(shape)).to(self.device)

@staticmethod
def sigmas_to_ts(sigmas):
Expand Down Expand Up @@ -143,11 +147,9 @@ def test_velocity_diffusion():

n_iterations = 3

steps = diffusion.schedule_ts(n_iterations, from_sigma=1.0, rho=0.7)

diffused_images = torch.randn((1, 3, 512, 512)).to(device).add(1).div(2)

for from_ts, to_ts in zip(steps[:-1], steps[1:]):
for from_ts, to_ts in diffusion.schedule_ts(n_iterations, from_sigma=1.0, rho=0.7):
if (from_ts < 1.0).all():
new_from_ts = from_ts * 1.003
diffused_images = diffusion.predictions(
Expand Down
4 changes: 3 additions & 1 deletion perceptor/utils/pil_image.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from PIL import Image
from torchvision.transforms.functional import to_pil_image
from lantern import Tensor


def pil_image(images):
def pil_image(images: Tensor) -> Image:
if images.max() > 1 or images.min() < 0:
print("Warning: images are not in range [0, 1]")
n, c, h, w = images.shape
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "perceptor"
version = "0.4.0"
version = "0.5.0"
description = ""
authors = ["Richard Löwenström <samedii@gmail.com>", "dribnet"]
readme = "README.md"
Expand Down

0 comments on commit 3a36c10

Please sign in to comment.