From ad71de21f5ef29f1fadd6455ba8fa8728af1c9e0 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Sat, 26 Oct 2024 13:24:22 -0700
Subject: [PATCH] Experimenting with alternative siglip loss impl for better
 dist scaling

---
 src/open_clip/loss.py | 53 +++++++++++++++++++++++++++++++++----------
 1 file changed, 41 insertions(+), 12 deletions(-)

diff --git a/src/open_clip/loss.py b/src/open_clip/loss.py
index 5beaab1c3..9b39dbf31 100644
--- a/src/open_clip/loss.py
+++ b/src/open_clip/loss.py
@@ -319,8 +319,8 @@ def __init__(
             cache_labels=False,
             rank=0,
             world_size=1,
-            bidir=True,
             use_horovod=False,
+            impl='bidir',
     ):
         super().__init__()
         self.cache_labels = cache_labels
@@ -328,7 +328,7 @@ def __init__(
         self.world_size = world_size
         assert not use_horovod  # FIXME need to look at hvd ops for ring transfers
         self.use_horovod = use_horovod
-        self.bidir = bidir
+        self.impl = impl
 
         # cache state FIXME cache not currently used, worthwhile?
         self.prev_num_logits = 0
@@ -361,10 +361,9 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output
         loss = self._loss(image_features, text_features, logit_scale, logit_bias)
 
         if self.world_size > 1:
-            # exchange text features w/ neighbour world_size - 1 times
-            right_rank = (self.rank + 1) % self.world_size
-            left_rank = (self.rank - 1 + self.world_size) % self.world_size
-            if self.bidir:
+            if self.impl == 'bidir':
+                right_rank = (self.rank + 1) % self.world_size
+                left_rank = (self.rank - 1 + self.world_size) % self.world_size
                 text_features_to_right = text_features_to_left = text_features
                 num_bidir, remainder = divmod(self.world_size - 1, 2)
                 for i in range(num_bidir):
@@ -374,7 +373,6 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output
                         text_features_to_left,
                         text_features_to_right,
                     )
-
                     for f in text_features_recv:
                         loss += self._loss(
                             image_features,
@@ -387,8 +385,10 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output
 
                 if remainder:
                     text_features_recv = neighbour_exchange_with_grad(
-                        left_rank, right_rank, text_features_to_right)
-
+                        left_rank,
+                        right_rank,
+                        text_features_to_right
+                    )
                     loss += self._loss(
                         image_features,
                         text_features_recv,
@@ -396,12 +396,16 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output
                         logit_bias,
                         negative_only=True,
                     )
-            else:
+            elif self.impl == "shift":
+                right_rank = (self.rank + 1) % self.world_size
+                left_rank = (self.rank - 1 + self.world_size) % self.world_size
                 text_features_to_right = text_features
                 for i in range(self.world_size - 1):
                     text_features_from_left = neighbour_exchange_with_grad(
-                        left_rank, right_rank, text_features_to_right)
-
+                        left_rank,
+                        right_rank,
+                        text_features_to_right,
+                    )
                     loss += self._loss(
                         image_features,
                         text_features_from_left,
@@ -410,5 +414,30 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output
                         negative_only=True,
                     )
                     text_features_to_right = text_features_from_left
+            elif self.impl == "reduce":
+                for i in range(self.world_size):
+                    text_from_other = torch.distributed.nn.all_reduce(
+                        text_features * (self.rank == i),
+                        torch.distributed.ReduceOp.SUM,
+                    )
+                    loss += float(i != self.rank) * self._loss(
+                        image_features,
+                        text_from_other,
+                        logit_scale,
+                        logit_bias,
+                        negative_only=True,
+                    )
+            elif self.impl == "gather":
+                all_text = torch.distributed.nn.all_gather(text_features)
+                for i in range(self.world_size):
+                    loss += float(i != self.rank) * self._loss(
+                        image_features,
+                        all_text[i],
+                        logit_scale,
+                        logit_bias,
+                        negative_only=True,
+                    )
+            else:
+                assert False
 
         return {"contrastive_loss": loss} if output_dict else loss
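
For reference, a minimal usage sketch (not part of the patch) of how the new impl option could be selected from a training script. It assumes the patched SigLipLoss signature shown above (impl in {'bidir', 'shift', 'reduce', 'gather'}) and that torch.distributed has already been initialized by the launcher; build_siglip_loss is a hypothetical helper, not an open_clip API.

    # Sketch only: constructing the patched SigLipLoss with one of the
    # experimental distributed implementations.
    import torch
    from open_clip.loss import SigLipLoss  # assumes this patch is applied

    def build_siglip_loss(impl: str = 'bidir') -> SigLipLoss:  # hypothetical helper
        dist_ok = torch.distributed.is_available() and torch.distributed.is_initialized()
        return SigLipLoss(
            rank=torch.distributed.get_rank() if dist_ok else 0,
            world_size=torch.distributed.get_world_size() if dist_ok else 1,
            impl=impl,
        )

    # loss_fn = build_siglip_loss('gather')
    # loss = loss_fn(image_features, text_features, logit_scale, logit_bias)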