From b88364b9b9078df478e111dec6884258e588ca24 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 16 Sep 2023 10:51:56 -0700 Subject: [PATCH] A bit of cleanup --- src/open_clip/loss.py | 63 +++++++++++++++++++------------------------ 1 file changed, 28 insertions(+), 35 deletions(-) diff --git a/src/open_clip/loss.py b/src/open_clip/loss.py index 06134015b..638e453a3 100644 --- a/src/open_clip/loss.py +++ b/src/open_clip/loss.py @@ -216,31 +216,37 @@ def forward( return contrastive_loss, distill_loss -class _Exchange(torch.autograd.Function): +def neighbour_exchange(from_rank, to_rank, tensor, group=None): + tensor_recv = torch.zeros_like(tensor) + send_op = torch.distributed.P2POp( + torch.distributed.isend, + tensor, + to_rank, + group=group, + ) + recv_op = torch.distributed.P2POp( + torch.distributed.irecv, + tensor_recv, + from_rank, + group=group, + ) + reqs = torch.distributed.batch_isend_irecv([send_op, recv_op]) + for req in reqs: + req.wait() + return tensor_recv + + +class NeighbourExchange(torch.autograd.Function): @staticmethod - def forward(ctx, left_rank, right_rank, group, tensor): + def forward(ctx, from_rank, to_rank, group, tensor): ctx.group = group - ctx.left_rank = left_rank - ctx.right_rank = right_rank - tensor_from_left = torch.zeros_like(tensor) - send_op = torch.distributed.P2POp( - torch.distributed.isend, - tensor, - right_rank, # send to the right - ) - recv_op = torch.distributed.P2POp( - torch.distributed.irecv, - tensor_from_left, - left_rank, # recv from left - ) - reqs = torch.distributed.batch_isend_irecv([send_op, recv_op]) - for req in reqs: - req.wait() - return tensor_from_left + ctx.from_rank = from_rank + ctx.to_rank = to_rank + return neighbour_exchange(from_rank, to_rank, tensor, group=group) @staticmethod def backward(ctx, grad_output): - return (None, None, None) + (_Exchange.apply(ctx.right_rank, ctx.left_rank, ctx.group, grad_output),) + return (None, None, None) + (NeighbourExchange.apply(ctx.to_rank, ctx.from_rank, ctx.group, grad_output),) class SigLipLoss(nn.Module): @@ -296,23 +302,10 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output # 3. custom autograd.Function Exchange (gradient passed back in reverse right -> left) # approach #3 - text_features_from_left = _Exchange.apply(left_rank, right_rank, text_features_to_right) + text_features_from_left = NeighbourExchange.apply(left_rank, right_rank, text_features_to_right) # approach #1 - # text_features_from_left = torch.zeros_like(text_features_to_right) - # send_op = torch.distributed.P2POp( - # torch.distributed.isend, # send to the right - # text_features_to_right, - # right_rank, - # ) - # recv_op = torch.distributed.P2POp( - # torch.distributed.irecv, # recv from left - # text_features_from_left, - # left_rank, - # ) - # reqs = torch.distributed.batch_isend_irecv([send_op, recv_op]) - # for req in reqs: - # req.wait() + # text_features_from_left = neighbour_exchange(left_rank, right_rank, text_features_to_right) neg_loss = self._loss( image_features,