Skip to content

Commit

Permalink
Updated to pass mypy and added missing method to TargetPredictor
Browse files Browse the repository at this point in the history
  • Loading branch information
Philip Bontrager committed Oct 10, 2023
1 parent 1102e1d commit e98e2cf
Show file tree
Hide file tree
Showing 7 changed files with 39 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@

from torch import nn, Tensor
from torchmultimodal.diffusion_labs.modules.losses.vlb_loss import VLBLoss
from torchmultimodal.diffusion_labs.schedules.schedule import DiffusionSchedule
from torchmultimodal.diffusion_labs.schedules.discrete_gaussian_schedule import (
DiscreteGaussianSchedule,
)


class DiffusionHybridLoss(nn.Module):
Expand Down Expand Up @@ -34,7 +36,7 @@ class DiffusionHybridLoss(nn.Module):

def __init__(
self,
schedule: DiffusionSchedule,
schedule: DiscreteGaussianSchedule,
simple_loss: nn.Module = nn.MSELoss(),
lmbda: float = 0.001,
):
Expand Down
6 changes: 4 additions & 2 deletions torchmultimodal/diffusion_labs/modules/losses/vlb_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

import torch
from torch import nn, Tensor
from torchmultimodal.diffusion_labs.schedules.schedule import DiffusionSchedule
from torchmultimodal.diffusion_labs.schedules.discrete_gaussian_schedule import (
DiscreteGaussianSchedule,
)


class VLBLoss(nn.Module):
Expand Down Expand Up @@ -36,7 +38,7 @@ class VLBLoss(nn.Module):
"""

def __init__(self, schedule: DiffusionSchedule):
def __init__(self, schedule: DiscreteGaussianSchedule):
super().__init__()
self.schedule = schedule

Expand Down
10 changes: 6 additions & 4 deletions torchmultimodal/diffusion_labs/predictors/noise_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

from torch import Tensor
from torchmultimodal.diffusion_labs.predictors.predictor import Predictor
from torchmultimodal.diffusion_labs.schedules.schedule import DiffusionSchedule
from torchmultimodal.diffusion_labs.schedules.discrete_gaussian_schedule import (
DiscreteGaussianSchedule,
)


class NoisePredictor(Predictor):
Expand All @@ -21,7 +23,7 @@ class NoisePredictor(Predictor):
"""

def __init__(
self, schedule: DiffusionSchedule, clamp_func: Optional[Callable] = None
self, schedule: DiscreteGaussianSchedule, clamp_func: Optional[Callable] = None
):
self.clamp_func = clamp_func
schedule.add_property("sqrt_recip_alphas_cumprod", _sqrt_recip_alphas_cumprod)
Expand All @@ -43,11 +45,11 @@ def predict_noise(self, prediction: Tensor, xt: Tensor, t: Tensor) -> Tensor:
return prediction


def _sqrt_recip_alphas_cumprod(schedule: DiffusionSchedule) -> Tensor:
def _sqrt_recip_alphas_cumprod(schedule: DiscreteGaussianSchedule) -> Tensor:
# pyre-ignore
return (1.0 / schedule.alphas_cumprod).sqrt()


def _sqrt_recipm1_alphas_cumprod(schedule: DiffusionSchedule) -> Tensor:
def _sqrt_recipm1_alphas_cumprod(schedule: DiscreteGaussianSchedule) -> Tensor:
# pyre-ignore
return (1.0 / schedule.alphas_cumprod - 1).sqrt()
6 changes: 4 additions & 2 deletions torchmultimodal/diffusion_labs/predictors/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
from typing import Callable, Optional, Protocol, runtime_checkable

from torch import Tensor
from torchmultimodal.diffusion_labs.schedules.schedule import DiffusionSchedule
from torchmultimodal.diffusion_labs.schedules.discrete_gaussian_schedule import (
DiscreteGaussianSchedule,
)


@runtime_checkable
Expand All @@ -18,7 +20,7 @@ class Predictor(Protocol):
trained to predict.
"""

schedule: DiffusionSchedule
schedule: DiscreteGaussianSchedule
clamp_func: Optional[Callable]

@abstractmethod
Expand Down
14 changes: 10 additions & 4 deletions torchmultimodal/diffusion_labs/predictors/target_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

from torch import Tensor
from torchmultimodal.diffusion_labs.predictors.predictor import Predictor
from torchmultimodal.diffusion_labs.schedules.schedule import DiffusionSchedule
from torchmultimodal.diffusion_labs.schedules.discrete_gaussian_schedule import (
DiscreteGaussianSchedule,
)


class TargetPredictor(Predictor):
Expand All @@ -21,7 +23,7 @@ class TargetPredictor(Predictor):
"""

def __init__(
self, schedule: DiffusionSchedule, clamp_func: Optional[Callable] = None
self, schedule: DiscreteGaussianSchedule, clamp_func: Optional[Callable] = None
):
self.clamp_func = clamp_func
self.schedule = schedule
Expand All @@ -32,5 +34,9 @@ def predict_x0(self, prediction: Tensor, xt: Tensor, t: Tensor) -> Tensor:
return prediction

def predict_noise(self, prediction: Tensor, xt: Tensor, t: Tensor) -> Tensor:
# TODO: For DDIM add predict_noise
pass
shape, dtype = xt.shape, xt.dtype
x_coef = self.schedule("sqrt_recip_alphas_cumprod", t, shape)
e_coef = self.schedule("sqrt_recip_alphas_cumprod_minus_one", t, shape)
x0 = self.predict_x0(prediction, xt, t)
e = (x_coef * xt - x0) / e_coef
return e.to(dtype)
10 changes: 6 additions & 4 deletions torchmultimodal/diffusion_labs/samplers/ddim.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
from torch import nn, Tensor
from torchmultimodal.diffusion_labs.predictors.predictor import Predictor
from torchmultimodal.diffusion_labs.samplers.sampler import Sampler
from torchmultimodal.diffusion_labs.schedules.schedule import DiffusionSchedule
from torchmultimodal.diffusion_labs.schedules.discrete_gaussian_schedule import (
DiscreteGaussianSchedule,
)
from torchmultimodal.diffusion_labs.utils.common import DiffusionOutput


Expand Down Expand Up @@ -38,7 +40,7 @@ class DDIModule(nn.Module, Sampler):
Attributes:
model (nn.Module):
schedule (DiffusionSchedule): defines noise diffusion throughout time
schedule (DiscreteGaussianSchedule): defines noise diffusion throughout time
predictor (Predictor): used to help predict x0
eval_steps (Tensor): a subset of steps to sample at inference time
eta (float): scaling factor used in Equation 12 of Song et. al
Expand All @@ -54,7 +56,7 @@ class DDIModule(nn.Module, Sampler):
def __init__(
self,
model: nn.Module,
schedule: DiffusionSchedule,
schedule: DiscreteGaussianSchedule,
predictor: Predictor,
eval_steps: Optional[Tensor] = None,
progress_bar: bool = True,
Expand Down Expand Up @@ -136,7 +138,7 @@ def generator(
) -> Generator[Tensor, None, None]:
"""Generate xt for each t in self.eval_steps"""
steps = self.eval_steps.flip(0)
for step, next_step in zip(steps[:-1], steps[1:], strict=True):
for step, next_step in zip(steps[:-1], steps[1:]):
# Convert steps to batched tensors
t = step * torch.ones(x.size(0), device=x.device, dtype=torch.long)
t1 = next_step * torch.ones(x.size(0), device=x.device, dtype=torch.long)
Expand Down
8 changes: 5 additions & 3 deletions torchmultimodal/diffusion_labs/samplers/ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
from torch import nn, Tensor
from torchmultimodal.diffusion_labs.predictors.predictor import Predictor
from torchmultimodal.diffusion_labs.samplers.sampler import Sampler
from torchmultimodal.diffusion_labs.schedules.schedule import DiffusionSchedule
from torchmultimodal.diffusion_labs.schedules.discrete_gaussian_schedule import (
DiscreteGaussianSchedule,
)
from torchmultimodal.diffusion_labs.utils.common import DiffusionOutput


Expand All @@ -36,7 +38,7 @@ class DDPModule(nn.Module, Sampler):
Attributes:
model (nn.Module): prediction neural network
schedule (DiffusionSchedule): defines diffusion of noise through time
schedule (DiscreteGaussianSchedule): defines diffusion of noise through time
predictor (Predictor): predictor class to handle predictions depending on the model input
eval_steps (Tensor): subset of steps to sample at inference
Expand All @@ -50,7 +52,7 @@ class DDPModule(nn.Module, Sampler):
def __init__(
self,
model: nn.Module,
schedule: DiffusionSchedule,
schedule: DiscreteGaussianSchedule,
predictor: Predictor,
eval_steps: Optional[Tensor] = None,
progress_bar: bool = True,
Expand Down

0 comments on commit e98e2cf

Please sign in to comment.