Skip to content

Commit

Permalink
Merge pull request #20 from invrs-io/loss2
Browse files Browse the repository at this point in the history
Tweak loss functions
  • Loading branch information
mfschubert authored Oct 5, 2023
2 parents 53cbd79 + 41ad276 commit 9bad4b4
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 19 deletions.
16 changes: 6 additions & 10 deletions src/invrs_gym/challenge/diffract/metagrating_challenge.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,16 +118,12 @@ class MetagratingChallenge:

def loss(self, response: common.GratingResponse) -> jnp.ndarray:
"""Compute a scalar loss from the component `response`."""
assert response.transmission_efficiency.shape[-1] == 1

order_idx = common.index_for_order(self.transmission_order, response.expansion)
target = jnp.zeros_like(response.transmission_efficiency)
target = target.at[..., order_idx, :].set(1.0)

transmission_error = jnp.sum((response.transmission_efficiency - target) ** 2)
reflection_error = jnp.sum(response.reflection_efficiency**2)
num_batch = jnp.prod(jnp.asarray(target.shape[:-2]))
return (transmission_error + reflection_error) / num_batch
efficiency = _value_for_order(
response.transmission_efficiency,
expansion=response.expansion,
order=self.transmission_order,
)
return jnp.mean(1 - jnp.abs(jnp.sqrt(efficiency)))

def metrics(
self,
Expand Down
36 changes: 27 additions & 9 deletions src/invrs_gym/challenge/diffract/splitter_challenge.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,15 +145,33 @@ def loss(self, response: common.GratingResponse) -> jnp.ndarray:
"""Compute a scalar loss from the component `response`."""
assert response.transmission_efficiency.shape[-1] == 1

idxs = indices_for_splitting(response.expansion, self.splitting)
target = jnp.zeros_like(response.transmission_efficiency)
target_transmission = 1 / jnp.prod(jnp.asarray(self.splitting))
target = target.at[..., idxs, :].set(target_transmission)

transmission_error = jnp.sum((response.transmission_efficiency - target) ** 2)
reflection_error = jnp.sum(response.reflection_efficiency**2)
num_batch = jnp.prod(jnp.asarray(target.shape[:-2]))
return (transmission_error + reflection_error) / num_batch
efficiency = extract_orders_for_splitting(
response.transmission_efficiency,
expansion=response.expansion,
splitting=self.splitting,
)
assert efficiency.shape[-3:] == self.splitting + (1,)

zeroth_order_mask = jnp.zeros_like(efficiency, dtype=bool)
zeroth_order_mask = zeroth_order_mask.at[
..., self.splitting[0] // 2, self.splitting[1] // 2, :
].set(True)

# The target for each order is `1 / num_splits`, except for the order having
# maximum transmission. This order has target equal to the mean efficiency.
num_splits = self.splitting[0] * self.splitting[1]
max_efficiency = jnp.amax(efficiency, axis=(-3, -2), keepdims=True)
efficiency_sum = jnp.sum(efficiency, axis=(-3, -2), keepdims=True)
mean_efficiency_excluding_max = (efficiency_sum - max_efficiency) / (
num_splits - 1
)

target = jnp.where(
efficiency == max_efficiency,
jax.lax.stop_gradient(mean_efficiency_excluding_max),
1 / num_splits,
)
return jnp.linalg.norm(jnp.sqrt(target) - jnp.sqrt(efficiency)) ** 2

def metrics(
self,
Expand Down

0 comments on commit 9bad4b4

Please sign in to comment.