Skip to content

Commit

Permalink
Experimenting with alternative siglip loss impl for better dist scaling
Browse files Browse the repository at this point in the history
  • Loading branch information
rwightman committed Oct 26, 2024
1 parent aedd550 commit ad71de2
Showing 1 changed file with 41 additions and 12 deletions.
53 changes: 41 additions & 12 deletions src/open_clip/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,16 +319,16 @@ def __init__(
cache_labels=False,
rank=0,
world_size=1,
bidir=True,
use_horovod=False,
impl='bidir',
):
super().__init__()
self.cache_labels = cache_labels
self.rank = rank
self.world_size = world_size
assert not use_horovod # FIXME need to look at hvd ops for ring transfers
self.use_horovod = use_horovod
self.bidir = bidir
self.impl = impl

# cache state FIXME cache not currently used, worthwhile?
self.prev_num_logits = 0
Expand Down Expand Up @@ -361,10 +361,9 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output
loss = self._loss(image_features, text_features, logit_scale, logit_bias)

if self.world_size > 1:
# exchange text features w/ neighbour world_size - 1 times
right_rank = (self.rank + 1) % self.world_size
left_rank = (self.rank - 1 + self.world_size) % self.world_size
if self.bidir:
if self.impl == 'bidir':
right_rank = (self.rank + 1) % self.world_size
left_rank = (self.rank - 1 + self.world_size) % self.world_size
text_features_to_right = text_features_to_left = text_features
num_bidir, remainder = divmod(self.world_size - 1, 2)
for i in range(num_bidir):
Expand All @@ -374,7 +373,6 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output
text_features_to_left,
text_features_to_right,
)

for f in text_features_recv:
loss += self._loss(
image_features,
Expand All @@ -387,21 +385,27 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output

if remainder:
text_features_recv = neighbour_exchange_with_grad(
left_rank, right_rank, text_features_to_right)

left_rank,
right_rank,
text_features_to_right
)
loss += self._loss(
image_features,
text_features_recv,
logit_scale,
logit_bias,
negative_only=True,
)
else:
elif self.impl == "shift":
right_rank = (self.rank + 1) % self.world_size
left_rank = (self.rank - 1 + self.world_size) % self.world_size
text_features_to_right = text_features
for i in range(self.world_size - 1):
text_features_from_left = neighbour_exchange_with_grad(
left_rank, right_rank, text_features_to_right)

left_rank,
right_rank,
text_features_to_right,
)
loss += self._loss(
image_features,
text_features_from_left,
Expand All @@ -410,5 +414,30 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output
negative_only=True,
)
text_features_to_right = text_features_from_left
elif self.impl == "reduce":
for i in range(self.world_size):
text_from_other = torch.distributed.nn.all_reduce(
text_features * (self.rank == i),
torch.distributed.ReduceOp.SUM,
)
loss += float(i != self.rank) * self._loss(
image_features,
text_from_other,
logit_scale,
logit_bias,
negative_only=True,
)
elif self.impl == "gather":
all_text = torch.distributed.nn.all_gather(text_features)
for i in range(self.world_size):
loss += float(i != self.rank) * self._loss(
image_features,
all_text[i],
logit_scale,
logit_bias,
negative_only=True,
)
else:
assert False

return {"contrastive_loss": loss} if output_dict else loss

0 comments on commit ad71de2

Please sign in to comment.