Skip to content

Commit

Permalink
A bit of cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
rwightman committed Sep 16, 2023
1 parent 85725fc commit b88364b
Showing 1 changed file with 28 additions and 35 deletions.
63 changes: 28 additions & 35 deletions src/open_clip/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,31 +216,37 @@ def forward(
return contrastive_loss, distill_loss


class _Exchange(torch.autograd.Function):
def neighbour_exchange(from_rank, to_rank, tensor, group=None):
tensor_recv = torch.zeros_like(tensor)
send_op = torch.distributed.P2POp(
torch.distributed.isend,
tensor,
to_rank,
group=group,
)
recv_op = torch.distributed.P2POp(
torch.distributed.irecv,
tensor_recv,
from_rank,
group=group,
)
reqs = torch.distributed.batch_isend_irecv([send_op, recv_op])
for req in reqs:
req.wait()
return tensor_recv


class NeighbourExchange(torch.autograd.Function):
@staticmethod
def forward(ctx, left_rank, right_rank, group, tensor):
def forward(ctx, from_rank, to_rank, group, tensor):
ctx.group = group
ctx.left_rank = left_rank
ctx.right_rank = right_rank
tensor_from_left = torch.zeros_like(tensor)
send_op = torch.distributed.P2POp(
torch.distributed.isend,
tensor,
right_rank, # send to the right
)
recv_op = torch.distributed.P2POp(
torch.distributed.irecv,
tensor_from_left,
left_rank, # recv from left
)
reqs = torch.distributed.batch_isend_irecv([send_op, recv_op])
for req in reqs:
req.wait()
return tensor_from_left
ctx.from_rank = from_rank
ctx.to_rank = to_rank
return neighbour_exchange(from_rank, to_rank, tensor, group=group)

@staticmethod
def backward(ctx, grad_output):
return (None, None, None) + (_Exchange.apply(ctx.right_rank, ctx.left_rank, ctx.group, grad_output),)
return (None, None, None) + (NeighbourExchange.apply(ctx.to_rank, ctx.from_rank, ctx.group, grad_output),)


class SigLipLoss(nn.Module):
Expand Down Expand Up @@ -296,23 +302,10 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output
# 3. custom autograd.Function Exchange (gradient passed back in reverse right -> left)

# approach #3
text_features_from_left = _Exchange.apply(left_rank, right_rank, text_features_to_right)
text_features_from_left = NeighbourExchange.apply(left_rank, right_rank, text_features_to_right)

# approach #1
# text_features_from_left = torch.zeros_like(text_features_to_right)
# send_op = torch.distributed.P2POp(
# torch.distributed.isend, # send to the right
# text_features_to_right,
# right_rank,
# )
# recv_op = torch.distributed.P2POp(
# torch.distributed.irecv, # recv from left
# text_features_from_left,
# left_rank,
# )
# reqs = torch.distributed.batch_isend_irecv([send_op, recv_op])
# for req in reqs:
# req.wait()
# text_features_from_left = neighbour_exchange(left_rank, right_rank, text_features_to_right)

neg_loss = self._loss(
image_features,
Expand Down

0 comments on commit b88364b

Please sign in to comment.