improve: better gradient checkpoint
samedii committed Sep 25, 2022
1 parent 56156e3 commit 9bd5cda
Showing 5 changed files with 184 additions and 139 deletions.
perceptor/models/open_clip.py (2 changes: 0 additions & 2 deletions)
@@ -51,8 +51,6 @@ def __init__(
         if (architecture, weights) not in open_clip.list_pretrained():
             raise ValueError(f"Invalid architecture/weights: {architecture}/{weights}")

-        pretrained_cfg = open_clip.pretrained.get_pretrained_cfg(architecture, weights)
-
         # softmax on cpu does not support half precision
         start_device = (
             torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
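For context, a minimal sketch of where the constructor's validation data comes from; it uses only the public open_clip registry API (the deleted pretrained_cfg lookup had no remaining uses, which is why this hunk has no additions):

import open_clip

# Each registry entry is an (architecture, weights) pair such as ("RN50", "openai");
# the __init__ above rejects any pair that is not in this list.
for architecture, weights in open_clip.list_pretrained()[:5]:
    print(architecture, weights)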
perceptor/models/velocity_diffusion/velocity_diffusion.py (20 changes: 20 additions & 0 deletions)
@@ -142,6 +142,26 @@ def diffuse(self, denoised_images, ts, noise=None):
         alphas, sigmas = self.alphas(ts), self.sigmas(ts)
         return diffusion_space.decode(denoised_xs * alphas + noise * sigmas)

+    def inject_noise(
+        self, diffused_images, ts, reversed_ts, extra_noise_multiplier=1.003
+    ):
+        diffused_xs = diffusion_space.encode(diffused_images).to(self.device)
+
+        diffused_multiplier = self.alphas(reversed_ts) / self.alphas(ts)
+        target_sigmas = self.sigmas(reversed_ts)
+        additional_noise_std = (
+            target_sigmas.square()
+            - self.sigmas(ts).square() * diffused_multiplier.square()
+        ).sqrt()
+
+        reversed_diffused_xs = (
+            diffused_xs * diffused_multiplier
+            + additional_noise_std
+            * torch.randn_like(diffused_xs)
+            * extra_noise_multiplier
+        )
+        return diffusion_space.decode(reversed_diffused_xs)


VelocityDiffusion: Model = cache(Model)

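The arithmetic in inject_noise can be read off the diffusion forward model; the derivation below is a reconstruction from the code, not part of the commit. With x_t = α_t·x_0 + σ_t·ε, scaling a sample at timestep t by m = α_{t'}/α_t gives

    m·x_t = α_{t'}·x_0 + m·σ_t·ε

To land on the target marginal x_{t'} = α_{t'}·x_0 + σ_{t'}·ε, independent Gaussian noise with standard deviation s must be mixed in, where

    m²·σ_t² + s² = σ_{t'}²   ⟹   s = √(σ_{t'}² − m²·σ_t²)

which is exactly additional_noise_std. The default extra_noise_multiplier=1.003 adds slightly more noise than the closed form requires, presumably to counteract variance lost over repeated re-noising steps.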
perceptor/utils/gradient_checkpoint.py (47 changes: 40 additions & 7 deletions)
@@ -1,4 +1,5 @@
 from lantern import FunctionalBase, Tensor
+import torch


class GradientCheckpoint(FunctionalBase):
@@ -8,18 +9,50 @@ class GradientCheckpoint(FunctionalBase):
     def __init__(self, tensor):
         super().__init__(original=tensor, detached=tensor.detach().requires_grad_())

-    def continue_backward(self, retain_graph=False):
-        if self.grad is None:
+    def zero_grad_(self):
+        self.detached.grad.zero_()
+        return self
+
+    def backward(self, loss):
+        loss.backward()
+        gradients = self.detached.grad.clone()
+        self.zero_grad_()
+        return gradients
+
+    def continue_backward(self, gradients=None, retain_graph=False):
+        if self.detached.grad is None:
             raise ValueError("Gradient is not defined")
-        return self.original.backward(self.detached.grad, retain_graph=retain_graph)
-
-    @property
-    def grad(self):
-        return self.detached.grad
+        if gradients is None:
+            return self.original.backward(self.detached.grad, retain_graph=retain_graph)
+        else:
+            return self.original.backward(gradients, retain_graph=retain_graph)

     def tensor(self):
         return self.detached
+
+    @staticmethod
+    def nonzero_mean(gradients, dim=0):
+        if isinstance(gradients, list):
+            gradients = torch.stack(gradients)
+        return gradients.sum(dim).div(gradients.ne(0).sum(dim).add(1e-6))
+
+    @staticmethod
+    def nonzero_scale(tensor, dim=None):
+        if isinstance(tensor, list):
+            tensor = torch.stack(tensor)
+        shape = tensor.shape
+        if dim is None:
+            tensor = tensor.flatten()
+            dim = 0
+
+        mask = tensor.ne(0)
+        mean_square = tensor.square().sum(dim) / mask.sum(dim).add(1e-6)
+        mean = tensor.sum(dim) / mask.sum(dim).add(1e-6)
+        std = (mean_square - mean.square()).sqrt().add(1e-6)
+        scaled_tensor = tensor / std.unsqueeze(dim).add(1e-6)
+        return scaled_tensor.view(*shape)


def gradient_checkpoint(tensor: Tensor) -> GradientCheckpoint:
"""
@@ -42,6 +75,6 @@ def test_gradient_checkpoint():
     images = torch.zeros(1, 3, 64, 64).requires_grad_()
     checkpoint = gradient_checkpoint(images * 2)
     checkpoint.tensor().pow(2).mean().backward()
-    assert checkpoint.grad is not None
+    assert checkpoint.detached.grad is not None
     checkpoint.continue_backward()
     assert images.grad is not None
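A hypothetical end-to-end sketch of the new API (an illustration, not part of the commit; it assumes the module is importable at the path shown above): backward collects and resets the gradient at the checkpoint for each loss, nonzero_mean averages the per-loss gradients over their nonzero entries, and continue_backward pushes the combined gradient through the upstream graph in a single pass.

import torch
from perceptor.utils.gradient_checkpoint import GradientCheckpoint, gradient_checkpoint

images = torch.randn(1, 3, 64, 64).requires_grad_()
checkpoint = gradient_checkpoint(images * 2)

# Each call runs loss.backward(), clones the gradient that arrived at the
# checkpoint, and zeroes it so the next loss starts from a clean slate.
gradients = [
    checkpoint.backward(checkpoint.tensor().pow(2).mean()),
    checkpoint.backward(checkpoint.tensor().abs().mean()),
]

# Elementwise mean over nonzero contributions (the 1e-6 guards empty counts).
combined = GradientCheckpoint.nonzero_mean(gradients)

# Backpropagate the combined gradient through everything upstream of the
# checkpoint; here that is just the `images * 2` node.
checkpoint.continue_backward(combined)
assert images.grad is not None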
