Skip to content

Commit

Permalink
Merge pull request #19 from invrs-io/loss
Browse files: browse the repository at this point in the history
Update loss function for diffract challenges
  • Loading branch information
mfschubert authored Oct 5, 2023
2 parents a265012 + f2fe5cf commit 53cbd79
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 24 deletions.
4 changes: 2 additions & 2 deletions src/invrs_gym/challenge/diffract/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,11 @@ def seed_density(grid_shape: Tuple[int, int], **kwargs: Any) -> types.Density2DA
def index_for_order(
    order: Tuple[int, int],
    expansion: basis.Expansion,
) -> int:
    """Return the index for the specified Fourier order and expansion.

    Args:
        order: The Fourier order to locate. It is compared elementwise against
            each row of `expansion.basis_coefficients`.
        expansion: Object whose `basis_coefficients` attribute is a 2D array
            holding one order per row.

    Returns:
        The row index of `order`, as a plain Python `int` so that it can be
        used as a static index.

    Raises:
        ValueError: If `order` does not match exactly one row.
    """
    matches = onp.all(expansion.basis_coefficients == order, axis=1)
    (order_idxs,) = onp.where(matches)
    # Exactly one row must match; report the problem explicitly rather than
    # through an opaque tuple-unpacking failure.
    if order_idxs.size != 1:
        raise ValueError(
            f"Expected exactly one match for order {order} in the expansion, "
            f"but found {order_idxs.size}."
        )
    return int(order_idxs[0])


def grating_efficiency(
Expand Down
16 changes: 10 additions & 6 deletions src/invrs_gym/challenge/diffract/metagrating_challenge.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,16 @@ class MetagratingChallenge:

def loss(self, response: common.GratingResponse) -> jnp.ndarray:
    """Compute a scalar loss from the component `response`.

    The target transmission is zero everywhere except at the index of
    `self.transmission_order` along the second-to-last axis, where it is one.
    The loss is the sum of squared differences between the transmission
    efficiency and this target, plus the sum of squared reflection
    efficiencies, divided by the product of all axes before the last two.

    Args:
        response: The grating response, providing `transmission_efficiency`,
            `reflection_efficiency`, and `expansion`.

    Returns:
        The scalar loss value.
    """
    # The trailing axis of the transmission efficiency must have size one.
    assert response.transmission_efficiency.shape[-1] == 1

    order_idx = common.index_for_order(self.transmission_order, response.expansion)
    target = jnp.zeros_like(response.transmission_efficiency)
    target = target.at[..., order_idx, :].set(1.0)

    transmission_error = jnp.sum((response.transmission_efficiency - target) ** 2)
    reflection_error = jnp.sum(response.reflection_efficiency**2)
    # Normalize by the number of elements in the leading axes.
    num_batch = jnp.prod(jnp.asarray(target.shape[:-2]))
    return (transmission_error + reflection_error) / num_batch

def metrics(
self,
Expand Down
41 changes: 25 additions & 16 deletions src/invrs_gym/challenge/diffract/splitter_challenge.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,16 +143,17 @@ class DiffractiveSplitterChallenge:

def loss(self, response: common.GratingResponse) -> jnp.ndarray:
    """Compute a scalar loss from the component `response`.

    The target transmission is zero everywhere except at the indices returned
    by `indices_for_splitting` for `self.splitting`, where it equals one
    divided by the product of the entries of `self.splitting`. The loss is the
    sum of squared differences between the transmission efficiency and this
    target, plus the sum of squared reflection efficiencies, divided by the
    product of all axes before the last two.

    Args:
        response: The grating response, providing `transmission_efficiency`,
            `reflection_efficiency`, and `expansion`.

    Returns:
        The scalar loss value.
    """
    # The trailing axis of the transmission efficiency must have size one.
    assert response.transmission_efficiency.shape[-1] == 1

    idxs = indices_for_splitting(response.expansion, self.splitting)
    target = jnp.zeros_like(response.transmission_efficiency)
    # Each targeted order receives an equal share of the total.
    target_transmission = 1 / jnp.prod(jnp.asarray(self.splitting))
    target = target.at[..., idxs, :].set(target_transmission)

    transmission_error = jnp.sum((response.transmission_efficiency - target) ** 2)
    reflection_error = jnp.sum(response.reflection_efficiency**2)
    # Normalize by the number of elements in the leading axes.
    num_batch = jnp.prod(jnp.asarray(target.shape[:-2]))
    return (transmission_error + reflection_error) / num_batch

def metrics(
self,
Expand Down Expand Up @@ -216,23 +217,31 @@ def metrics(
}


def indices_for_splitting(
    expansion: basis.Expansion,
    splitting: Tuple[int, int],
) -> Tuple[int, ...]:
    """Return the expansion indices of the orders selected by `splitting`.

    For each of the two directions, `splitting[i]` consecutive integer orders
    are taken, from `-splitting[i] // 2 + 1` up to and including
    `splitting[i] // 2`. Every (x, y) combination is looked up with
    `common.index_for_order`, iterating over x in the outer position.
    """
    count_x, count_y = splitting[0], splitting[1]
    x_orders = range(-count_x // 2 + 1, count_x // 2 + 1)
    y_orders = range(-count_y // 2 + 1, count_y // 2 + 1)
    return tuple(
        common.index_for_order((x_order, y_order), expansion)
        for x_order, y_order in itertools.product(x_orders, y_orders)
    )


def extract_orders_for_splitting(
    array: jnp.ndarray,
    expansion: basis.Expansion,
    splitting: Tuple[int, int],
) -> jnp.ndarray:
    """Extract the values from `array` for the specified splitting.

    Args:
        array: The array to extract from; its second-to-last axis is indexed
            by the positions returned by `indices_for_splitting`.
        expansion: The expansion used to look up the index of each order.
        splitting: The number of orders to extract along each direction.

    Returns:
        An array with shape `array.shape[:-2] + splitting + (array.shape[-1],)`
        and the same dtype as `array`.
    """
    num_splits = splitting[0] * splitting[1]
    shape = array.shape[:-2] + splitting + (array.shape[-1],)
    flat_shape = array.shape[:-2] + (num_splits, array.shape[-1])

    # The order lookup is shared with the loss via `indices_for_splitting`.
    idxs = indices_for_splitting(expansion, splitting)
    extracted = jnp.zeros(flat_shape, dtype=array.dtype)
    extracted = extracted.at[..., :, :].set(array[..., idxs, :])
    # Unflatten the extracted orders onto the two splitting axes.
    return extracted.reshape(shape)
Expand Down

0 comments on commit 53cbd79

Please sign in to comment.