diff --git a/optax/losses/_self_supervised.py b/optax/losses/_self_supervised.py index 3a6f3738..a1eda10e 100644 --- a/optax/losses/_self_supervised.py +++ b/optax/losses/_self_supervised.py @@ -144,7 +144,7 @@ def triplet_margin_loss( """ chex.assert_equal_shape([anchor, positive, negative]) - if not(jnp.ndim(anchor) == 2 and jnp.ndim(positive) == 2 and jnp.ndim(negative) == 2): + if not(anchor.ndim == 2 and positive.ndim == 2 and negative.dim == 2): raise ValueError('Inputs must be 2D tensors') # Calculate distances between pairs diff --git a/optax/losses/_self_supervised_test.py b/optax/losses/_self_supervised_test.py index 36b61e24..d886621f 100644 --- a/optax/losses/_self_supervised_test.py +++ b/optax/losses/_self_supervised_test.py @@ -58,13 +58,14 @@ class TripletMarginLossTest(chex.TestCase): def setUp(self): super().setUp() - self.a1 = jnp.ones((2, 2)) - self.p1 = jnp.zeros((2, 2)) - self.n1 = jnp.ones((2, 2))*2 - - self.a2 = jnp.zeros((2, 2)) - self.p2 = jnp.ones((2, 2)) - self.n2 = jnp.ones((2, 2))*2 + self.t = jnp.random.normal((2,2)) + self.a1 = self.t*0 + self.p1 = self.a1+1 + self.n1 = self.p1+1 + + self.a2 = self.t*0+1 + self.p2 = self.a2-1 + self.n2 = self.a2+1 def testing_triplet_loss(self, a, p, n, margin=1.0, swap=False): ap = jnp.linalg.norm(a - p)