Add bidirectional exchange option, more cleanup
rwightman committed Sep 18, 2023
1 parent b88364b commit 134e61a
Showing 1 changed file with 126 additions and 32 deletions.
src/open_clip/loss.py
@@ -236,6 +236,39 @@ def neighbour_exchange(from_rank, to_rank, tensor, group=None):
    return tensor_recv
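Judging from its signature and the way it is called in forward below, neighbour_exchange performs a single one-directional ring step: it sends tensor to to_rank while receiving a same-shaped tensor from from_rank. A hypothetical call (editor's sketch; rank names are illustrative and an initialized torch.distributed process group is assumed):

    # Hypothetical: receive the left neighbour's tensor while sending ours right.
    received = neighbour_exchange(left_rank, right_rank, tensor)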


def neighbour_exchange_bidir(left_rank, right_rank, tensor_to_left, tensor_to_right, group=None):
    tensor_from_left = torch.zeros_like(tensor_to_right)
    tensor_from_right = torch.zeros_like(tensor_to_left)
    send_op_left = torch.distributed.P2POp(
        torch.distributed.isend,
        tensor_to_left,
        left_rank,
        group=group,
    )
    send_op_right = torch.distributed.P2POp(
        torch.distributed.isend,
        tensor_to_right,
        right_rank,
        group=group,
    )
    recv_op_left = torch.distributed.P2POp(
        torch.distributed.irecv,
        tensor_from_left,
        left_rank,
        group=group,
    )
    recv_op_right = torch.distributed.P2POp(
        torch.distributed.irecv,
        tensor_from_right,
        right_rank,
        group=group,
    )
    reqs = torch.distributed.batch_isend_irecv([send_op_right, send_op_left, recv_op_right, recv_op_left])
    for req in reqs:
        req.wait()
    return tensor_from_right, tensor_from_left
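A minimal usage sketch for the bidirectional exchange (editor's illustration, not part of this commit; assumes a multi-process job launched with e.g. torchrun --nproc_per_node=2 demo.py):

    # demo.py -- hypothetical script
    import torch
    import torch.distributed as dist

    from open_clip.loss import neighbour_exchange_bidir

    def main():
        dist.init_process_group("gloo")  # use "nccl" on GPUs
        rank, world = dist.get_rank(), dist.get_world_size()
        left = (rank - 1 + world) % world
        right = (rank + 1) % world
        t = torch.full((2,), float(rank))
        # Send t both ways; receive both neighbours' tensors in one batched call.
        from_right, from_left = neighbour_exchange_bidir(left, right, t, t)
        print(f"rank {rank}: from_left={from_left.tolist()}, from_right={from_right.tolist()}")

    if __name__ == "__main__":
        main()

Note the return order: the first result is the tensor received from right_rank, the second the one received from left_rank.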


class NeighbourExchange(torch.autograd.Function):
    @staticmethod
    def forward(ctx, from_rank, to_rank, group, tensor):
@@ -249,29 +282,54 @@ def backward(ctx, grad_output):
        return (None, None, None) + (NeighbourExchange.apply(ctx.to_rank, ctx.from_rank, ctx.group, grad_output),)


def neighbour_exchange_with_grad(from_rank, to_rank, tensor, group=None):
    return NeighbourExchange.apply(from_rank, to_rank, group, tensor)


class NeighbourExchangeBidir(torch.autograd.Function):
    @staticmethod
    def forward(ctx, left_rank, right_rank, group, tensor_to_left, tensor_to_right):
        ctx.group = group
        ctx.left_rank = left_rank
        ctx.right_rank = right_rank
        return neighbour_exchange_bidir(left_rank, right_rank, tensor_to_left, tensor_to_right, group=group)

    @staticmethod
    def backward(ctx, *grad_outputs):
        return (None, None, None) + \
            NeighbourExchangeBidir.apply(ctx.right_rank, ctx.left_rank, ctx.group, *grad_outputs)


def neighbour_exchange_bidir_with_grad(left_rank, right_rank, tensor_to_left, tensor_to_right, group=None):
    return NeighbourExchangeBidir.apply(left_rank, right_rank, group, tensor_to_left, tensor_to_right)
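The autograd wrappers route gradients along the reverse ring: whatever a rank sent forward produces a gradient that comes back from the recipient, so loss terms computed against a neighbour's text features still update that neighbour's text tower. A sketch of the mechanics (editor's illustration; assumes an initialized process group and identical code running on every rank):

    x = torch.randn(4, 8, requires_grad=True)
    # Receive the left neighbour's features while sending ours right.
    y = neighbour_exchange_with_grad(left_rank, right_rank, x)
    y.pow(2).sum().backward()
    # x.grad now holds the gradient computed by the right neighbour (the
    # rank that received x), shipped back by NeighbourExchange.backward.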


class SigLipLoss(nn.Module):

    def __init__(
            self,
            cache_labels=False,
            rank=0,
            world_size=1,
            bidir=False,
            use_horovod=False,
    ):
        super().__init__()
        self.cache_labels = cache_labels
        self.rank = rank
        self.world_size = world_size
-        self.use_horovod = use_horovod # FIXME need to look at hvd ops for ring transfers
        assert not use_horovod # FIXME need to look at hvd ops for ring transfers
        self.use_horovod = use_horovod
        self.bidir = bidir # FIXME need further verification & benchmarking of bidirectional mode

        # cache state FIXME cache not currently used, worthwhile?
        self.prev_num_logits = 0
        self.labels = {}

-    def get_ground_truth(self, device, num_logits, negative_only=False) -> torch.Tensor:
-        labels = -torch.ones((num_logits, num_logits), device=device)
    def get_ground_truth(self, device, dtype, num_logits, negative_only=False) -> torch.Tensor:
        labels = -torch.ones((num_logits, num_logits), device=device, dtype=dtype)
        if not negative_only:
-            labels = 2 * torch.eye(num_logits, device=device) + labels
            labels = 2 * torch.eye(num_logits, device=device, dtype=dtype) + labels
        return labels
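For concreteness (editor's note): with num_logits = 3 the ground truth is +1 on the diagonal (matched image-text pairs) and -1 everywhere else; with negative_only=True it is all -1:

    import torch

    labels = 2 * torch.eye(3) - torch.ones((3, 3))
    print(labels)
    # tensor([[ 1., -1., -1.],
    #         [-1.,  1., -1.],
    #         [-1., -1.,  1.]])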

    def get_logits(self, image_features, text_features, logit_scale, logit_bias=None):
@@ -282,7 +340,12 @@ def get_logits(self, image_features, text_features, logit_scale, logit_bias=None

    def _loss(self, image_features, text_features, logit_scale, logit_bias=None, negative_only=False):
        logits = self.get_logits(image_features, text_features, logit_scale, logit_bias)
-        labels = self.get_ground_truth(image_features.device, image_features.shape[0], negative_only=negative_only)
        labels = self.get_ground_truth(
            image_features.device,
            image_features.dtype,
            image_features.shape[0],
            negative_only=negative_only,
        )
        loss = -F.logsigmoid(labels * logits).sum() / image_features.shape[0]
        return loss
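This is the pairwise sigmoid loss of SigLIP: each image-text pair is an independent binary classification, so no rank ever needs the full global similarity matrix. A toy numeric sketch (editor's illustration; it assumes get_logits computes logit_scale * image_features @ text_features.T + logit_bias, per the elided lines above):

    import torch
    import torch.nn.functional as F

    img = F.normalize(torch.randn(2, 4), dim=-1)
    txt = F.normalize(torch.randn(2, 4), dim=-1)
    logit_scale, logit_bias = torch.tensor(10.0), torch.tensor(-10.0)

    logits = logit_scale * img @ txt.T + logit_bias
    labels = 2 * torch.eye(2) - 1  # +1 for matched pairs, -1 otherwise
    loss = -F.logsigmoid(labels * logits).sum() / img.shape[0]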

@@ -291,35 +354,66 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output

        if self.world_size > 1:
            # exchange text features w/ neighbour world_size - 1 times
-            text_features_to_right = text_features
-            for i in range(self.world_size - 1):
-                right_rank = (self.rank + 1) % self.world_size
-                left_rank = (self.rank - 1 + self.world_size) % self.world_size
-
-                # FIXME having issues with distributed exchange, possibly gradient flow, three approaches
-                # 1. no intervention, do isend/irecv in forward, avg loss, leave up to DDP to reduce grads
-                # 2. extra all_reduce (sum) (nn. ver w/ grads) of loss in final step
-                # 3. custom autograd.Function Exchange (gradient passed back in reverse right -> left)
-
-                # approach #3
-                text_features_from_left = NeighbourExchange.apply(left_rank, right_rank, text_features_to_right)
-
-                # approach #1
-                # text_features_from_left = neighbour_exchange(left_rank, right_rank, text_features_to_right)
-
-                neg_loss = self._loss(
-                    image_features,
-                    text_features_from_left,
-                    logit_scale,
-                    logit_bias,
-                    negative_only=True,
-                )
-                loss += neg_loss
-                text_features_to_right = text_features_from_left
-
-            # loss /= self.world_size # better without this
            right_rank = (self.rank + 1) % self.world_size
            left_rank = (self.rank - 1 + self.world_size) % self.world_size
            if self.bidir:
                text_features_to_right = text_features_to_left = text_features
                num_bidir, remainder = divmod(self.world_size - 1, 2)
                for i in range(num_bidir):
                    text_features_recv = neighbour_exchange_bidir_with_grad(
                        left_rank,
                        right_rank,
                        text_features_to_left,
                        text_features_to_right,
                    )

                    for f in text_features_recv:
                        loss += self._loss(
                            image_features,
                            f,
                            logit_scale,
                            logit_bias,
                            negative_only=True,
                        )
                    text_features_to_left, text_features_to_right = text_features_recv

                if remainder:
                    text_features_recv = neighbour_exchange_with_grad(
                        left_rank, right_rank, text_features_to_right)

                    loss += self._loss(
                        image_features,
                        text_features_recv,
                        logit_scale,
                        logit_bias,
                        negative_only=True,
                    )
            else:
                text_features_to_right = text_features
                for i in range(self.world_size - 1):
                    # FIXME having issues with distributed exchange, possibly gradient flow, three approaches
                    # 1. no intervention, do isend/irecv in forward, avg loss, leave up to DDP to reduce grads
                    # 2. extra all_reduce (sum) (nn. ver w/ grads) of loss in final step
                    # 3. custom autograd.Function Exchange (gradient passed back in reverse right -> left)

                    # approach #3
                    text_features_from_left = neighbour_exchange_with_grad(
                        left_rank, right_rank, text_features_to_right)

                    # approach #1
                    # text_features_from_left = neighbour_exchange(left_rank, right_rank, text_features_to_right)

                    loss += self._loss(
                        image_features,
                        text_features_from_left,
                        logit_scale,
                        logit_bias,
                        negative_only=True,
                    )
                    text_features_to_right = text_features_from_left

            # approach #2
            # loss /= self.world_size
            # loss = torch.distributed.nn.all_reduce(loss)

return {"contrastive_loss": loss} if output_dict else loss
