Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/issue-1118' into issue-1118
Browse files Browse the repository at this point in the history
  • Loading branch information
Saanidhyavats committed Nov 4, 2024
2 parents 6bf8aa5 + 63a91f6 commit 9df7799
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 18 deletions.
1 change: 1 addition & 0 deletions optax/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,5 @@
from optax.losses._regression import log_cosh
from optax.losses._regression import squared_error
from optax.losses._self_supervised import ntxent
from optax.losses._self_supervised import triplet_margin_loss
from optax.losses._smoothing import smooth_labels
18 changes: 10 additions & 8 deletions optax/losses/_self_supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,6 @@ def triplet_margin_loss(
) -> chex.Array:
"""Triplet margin loss function.
References:
V. Balntas et al. `Learning shallow convolutional feature
descriptors with triplet losses
<https://bmva-archive.org.uk/bmvc/2016/papers/paper119/paper119.pdf>`_, 2016
Measures the relative similarity between an anchor point,
a positive point, and a negative point using the distance
metric specified by p-norm. The loss encourages
Expand All @@ -133,21 +128,28 @@ def triplet_margin_loss(
swap: Use the distance swap optimization
reduction: Specifies the reduction to apply to the output:
'none' | 'mean' | 'sum'
Returns:
The triplet margin loss value.
If reduction is 'none': tensor of shape [batch_size]
If reduction is 'mean' or 'sum': scalar tensor.
example:
Examples:
>>> import jax.numpy as jnp
>>> import optax
>>> anchor = jnp.array([[1.0, 2.0], [3.0, 4.0]])
>>> positive = jnp.array([[1.1, 2.1], [3.1, 4.1]])
>>> negative = jnp.array([[2.0, 3.0], [4.0, 5.0]])
>>> margin = 1.0
>>> loss = triplet_margin_loss(anchor, positive, negative,
>>> loss = optax.losses.triplet_margin_loss(anchor, positive, negative,
>>> margin=margin, reduction='mean')
>>> print(loss)
References:
V. Balntas et al. `Learning shallow convolutional feature
descriptors with triplet losses
<https://bmva-archive.org.uk/bmvc/2016/papers/paper119/paper119.pdf>`_, 2016
"""
chex.assert_equal_shape([anchor, positive, negative])

Expand Down
24 changes: 14 additions & 10 deletions optax/losses/_self_supervised_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,20 @@ def setUp(self):
self.p2 = jnp.ones((2, 2))
self.n2 = jnp.ones((2, 2))*2

@chex.all_variants
def test_vmap(self):
# VMAP applied function result
original_loss = _self_supervised.triplet_margin_loss(
self.a1, self.p1, self.n1)
self.a1 = self.a1.reshape(1, *self.a1.shape)
self.p1 = self.p1.reshape(1, *self.p1.shape)
self.n1 = self.n1.reshape(1, *self.n1.shape)
vmap_loss = self.variant(jax.vmap(
_self_supervised.triplet_margin_loss
))(self.a1, self.p1,
self.n1)
np.testing.assert_allclose(vmap_loss, original_loss, atol=1e-4)

@chex.all_variants
def test_batched(self):
def testing_triplet_loss(a, p, n, margin=1.0, swap=False):
Expand Down Expand Up @@ -106,16 +120,6 @@ def testing_triplet_loss(a, p, n, margin=1.0, swap=False):
np.testing.assert_allclose(jit_loss, original_loss,
atol=1e-4)

# VMAP applied function result
self.a1 = self.a1.reshape(1, *self.a1.shape)
self.p1 = self.p1.reshape(1, *self.p1.shape)
self.n1 = self.n1.reshape(1, *self.n1.shape)
vmap_loss = self.variant(jax.vmap(
_self_supervised.triplet_margin_loss
))(self.a1, self.p1,
self.n1)
np.testing.assert_allclose(vmap_loss, original_loss, atol=1e-4)


if __name__ == '__main__':
  # Run the test suite when this file is executed directly.
  absltest.main()

0 comments on commit 9df7799

Please sign in to comment.