feature: gradient checkpointing
samedii committed Aug 8, 2022
1 parent 3a36c10 commit 92211eb
Showing 2 changed files with 48 additions and 0 deletions.
1 change: 1 addition & 0 deletions perceptor/utils/__init__.py
@@ -1,2 +1,3 @@
from .cache import cache
from .gradient_checkpoint import gradient_checkpoint
from .pil_image import pil_image
47 changes: 47 additions & 0 deletions perceptor/utils/gradient_checkpoint.py
@@ -0,0 +1,47 @@
from lantern import FunctionalBase, Tensor


class GradientCheckpoint(FunctionalBase):
    original: Tensor
    detached: Tensor

    def __init__(self, tensor):
        super().__init__(original=tensor, detached=tensor.detach().requires_grad_())

    def continue_backward(self):
        if self.grad is None:
            raise ValueError("Gradient is not defined")
        return self.original.backward(self.detached.grad)

    @property
    def grad(self):
        return self.detached.grad

    def tensor(self):
        return self.detached


def gradient_checkpoint(tensor: Tensor) -> GradientCheckpoint:
    """
    Checkpoint a tensor so that several losses can backpropagate to it
    independently; the accumulated gradient is then pushed through the
    shared (expensive) part of the graph in a single backward pass.

    Usage:
        checkpoint = gradient_checkpoint(images)
        for text_loss in text_losses:
            text_loss(checkpoint.tensor()).backward()
        checkpoint.continue_backward()
    """
    return GradientCheckpoint(tensor)


def test_gradient_checkpoint():
    import torch

    with torch.enable_grad():
        images = torch.zeros(1, 3, 64, 64).requires_grad_()
        checkpoint = gradient_checkpoint(images * 2)
        checkpoint.tensor().pow(2).mean().backward()
        assert checkpoint.grad is not None
        checkpoint.continue_backward()
        assert images.grad is not None
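
A minimal end-to-end usage sketch (the loss functions below are hypothetical stand-ins, and the import path assumes perceptor.utils as added in this commit): each loss backpropagates only to the detached tensor, and continue_backward then pushes the accumulated gradient through the shared graph once.

import torch
from perceptor.utils import gradient_checkpoint

images = torch.rand(1, 3, 64, 64, requires_grad=True)
shared = images * 2  # stands in for an expensive shared computation
checkpoint = gradient_checkpoint(shared)

# Hypothetical losses; each backward stops at the detached tensor,
# so the shared part of the graph is not traversed per loss.
for loss_fn in [lambda x: x.pow(2).mean(), lambda x: (1 - x).abs().mean()]:
    loss_fn(checkpoint.tensor()).backward()

# Push the summed gradient through the shared graph in one backward pass.
checkpoint.continue_backward()
assert images.grad is not None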
