Skip to content

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
cvnad1 committed Nov 2, 2024
1 parent a15a2c7 commit fc3c32a
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 8 deletions.
2 changes: 1 addition & 1 deletion optax/losses/_self_supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def triplet_margin_loss(
"""
chex.assert_equal_shape([anchor, positive, negative])

if not(jnp.ndim(anchor) == 2 and jnp.ndim(positive) == 2 and jnp.ndim(negative) == 2):
if not(anchor.ndim == 2 and positive.ndim == 2 and negative.dim == 2):
raise ValueError('Inputs must be 2D tensors')

# Calculate distances between pairs
Expand Down
15 changes: 8 additions & 7 deletions optax/losses/_self_supervised_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,14 @@ class TripletMarginLossTest(chex.TestCase):

def setUp(self):
super().setUp()
self.a1 = jnp.ones((2, 2))
self.p1 = jnp.zeros((2, 2))
self.n1 = jnp.ones((2, 2))*2

self.a2 = jnp.zeros((2, 2))
self.p2 = jnp.ones((2, 2))
self.n2 = jnp.ones((2, 2))*2
self.t = jnp.random.normal((2,2))
self.a1 = self.t*0
self.p1 = self.a1+1
self.n1 = self.p1+1

self.a2 = self.t*0+1
self.p2 = self.a2-1
self.n2 = self.a2+1

def testing_triplet_loss(self, a, p, n, margin=1.0, swap=False):
ap = jnp.linalg.norm(a - p)
Expand Down

0 comments on commit fc3c32a

Please sign in to comment.