diff --git a/src/invrs_gym/challenge/diffract/metagrating_challenge.py b/src/invrs_gym/challenge/diffract/metagrating_challenge.py index 8abe961..8e3890c 100644 --- a/src/invrs_gym/challenge/diffract/metagrating_challenge.py +++ b/src/invrs_gym/challenge/diffract/metagrating_challenge.py @@ -123,7 +123,7 @@ def loss(self, response: common.GratingResponse) -> jnp.ndarray: expansion=response.expansion, order=self.transmission_order, ) - return jnp.mean(jnp.sqrt(1 - transmission_efficiency)) + return jnp.mean(jnp.sqrt(jnp.abs(1 - transmission_efficiency))) def metrics( self,