diff --git a/src/open_clip/loss.py b/src/open_clip/loss.py
index 015e961cc..83c28c626 100644
--- a/src/open_clip/loss.py
+++ b/src/open_clip/loss.py
@@ -328,7 +328,7 @@ def __init__(
         self.world_size = world_size
         assert not use_horovod  # FIXME need to look at hvd ops for ring transfers
         self.use_horovod = use_horovod
-        self.bidir = bidir  # FIXME need further verification & benchmarking of bidirectional mode
+        self.bidir = bidir  # FIXME need further benchmarking of bidirectional mode
 
         # cache state FIXME cache not currently used, worthwhile?
         self.prev_num_logits = 0
@@ -399,18 +399,9 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output
             else:
                 text_features_to_right = text_features
                 for i in range(self.world_size - 1):
-                    # FIXME having issues with distributed exchange, possibly gradient flow, three approaches
-                    # 1. no intervention, do isend/irecv in forward, avg loss, leave up to DDP to reduce grads
-                    # 2. extra all_reduce (sum) (nn. ver w/ grads) of loss in final step
-                    # 3. custom autograd.Function Exchange (gradient passed back in reverse right -> left)
-
-                    # approach #3
                     text_features_from_left = neighbour_exchange_with_grad(
                         left_rank, right_rank, text_features_to_right)
 
-                    # approach #1
-                    # text_features_from_left = neighbour_exchange(left_rank, right_rank, text_features_to_right)
-
                     loss += self._loss(
                         image_features,
                         text_features_from_left,
@@ -420,8 +411,4 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output
                     )
                     text_features_to_right = text_features_from_left
 
-            # approach #2
-            # loss /= self.world_size
-            # loss = torch.distributed.nn.all_reduce(loss)
-
         return {"contrastive_loss": loss} if output_dict else loss