Skip to content

Commit

Permalink
Merge pull request #16 from invrs-io/slow
Browse files Browse the repository at this point in the history
Speed up diffract tests
  • Loading branch information
mfschubert authored Oct 2, 2023
2 parents 64dde4f + 2d439ca commit 59748a7
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 15 deletions.
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1 +1,7 @@
# gym

# Testing
Some tests are marked as slow and are skipped by default. To run these manually, use
```
pytest --runslow
```
10 changes: 7 additions & 3 deletions src/invrs_gym/challenge/diffract/metagrating_challenge.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ def _value_for_order(
truncation=basis.Truncation.CIRCULAR,
)

SYMMETRIES = (symmetry.REFLECTION_E_W,)

# Minimum width and spacing are approximately 80 nm for the default dimensions
# of 1.371 x 0.525 um and grid shape of (118, 45).
MINIMUM_WIDTH = 7
Expand All @@ -195,16 +197,18 @@ def metagrating(
density_initializer: DensityInitializer = common.identity_initializer,
transmission_order: Tuple[int, int] = TRANSMISSION_ORDER,
transmission_lower_bound: float = TRANSMISSION_LOWER_BOUND,
spec: common.GratingSpec = METAGRATING_SPEC,
sim_params: common.GratingSimParams = METAGRATING_SIM_PARAMS,
) -> MetagratingChallenge:
"""Metagrating with 1.371 x 0.525 um design region."""
return MetagratingChallenge(
component=MetagratingComponent(
spec=METAGRATING_SPEC,
sim_params=METAGRATING_SIM_PARAMS,
spec=spec,
sim_params=sim_params,
density_initializer=density_initializer,
minimum_width=minimum_width,
minimum_spacing=minimum_spacing,
symmetries=(symmetry.REFLECTION_E_W,),
symmetries=SYMMETRIES,
),
transmission_order=transmission_order,
transmission_lower_bound=transmission_lower_bound,
Expand Down
6 changes: 4 additions & 2 deletions src/invrs_gym/challenge/diffract/splitter_challenge.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,12 +276,14 @@ def diffractive_splitter(
thickness_initializer: ThicknessInitializer = common.identity_initializer,
density_initializer: DensityInitializer = common.identity_initializer,
splitting: Tuple[int, int] = SPLITTING,
spec: common.GratingSpec = DIFFRACTIVE_SPLITTER_SPEC,
sim_params: common.GratingSimParams = DIFFRACTIVE_SPLITTER_SIM_PARAMS,
) -> DiffractiveSplitterChallenge:
"""Diffractive splitter with 7.2 x 7.2 um design region."""
return DiffractiveSplitterChallenge(
component=DiffractiveSplitterComponent(
spec=DIFFRACTIVE_SPLITTER_SPEC,
sim_params=DIFFRACTIVE_SPLITTER_SIM_PARAMS,
spec=spec,
sim_params=sim_params,
thickness_initializer=thickness_initializer,
density_initializer=density_initializer,
minimum_width=minimum_width,
Expand Down
16 changes: 12 additions & 4 deletions tests/challenge/diffract/test_metagrating_challenge.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,29 @@
"""Tests for `diffract.metagrating_challenge`."""

import dataclasses
import unittest

import jax
import jax.numpy as jnp
import optax
from fmmax import fmm
from parameterized import parameterized
from totypes import symmetry # type: ignore[import,attr-defined,unused-ignore]

from invrs_gym.challenge.diffract import metagrating_challenge

LIGHTWEIGHT_SIM_PARAMS = dataclasses.replace(
metagrating_challenge.METAGRATING_SIM_PARAMS,
approximate_num_terms=100,
formulation=fmm.Formulation.FFT,
)


class MetagratingComponentTest(unittest.TestCase):
def test_density_has_expected_properties(self):
mc = metagrating_challenge.MetagratingComponent(
spec=metagrating_challenge.METAGRATING_SPEC,
sim_params=metagrating_challenge.METAGRATING_SIM_PARAMS,
sim_params=LIGHTWEIGHT_SIM_PARAMS,
density_initializer=lambda _, seed_density: seed_density,
)
params = mc.init(jax.random.PRNGKey(0))
Expand All @@ -26,7 +34,7 @@ def test_density_has_expected_properties(self):
def test_can_jit_response(self):
mc = metagrating_challenge.MetagratingComponent(
spec=metagrating_challenge.METAGRATING_SPEC,
sim_params=metagrating_challenge.METAGRATING_SIM_PARAMS,
sim_params=LIGHTWEIGHT_SIM_PARAMS,
density_initializer=lambda _, seed_density: seed_density,
)
params = mc.init(jax.random.PRNGKey(0))
Expand All @@ -40,7 +48,7 @@ def jit_response_fn(params):
def test_multiple_wavelengths(self):
mc = metagrating_challenge.MetagratingComponent(
spec=metagrating_challenge.METAGRATING_SPEC,
sim_params=metagrating_challenge.METAGRATING_SIM_PARAMS,
sim_params=LIGHTWEIGHT_SIM_PARAMS,
density_initializer=lambda _, seed_density: seed_density,
)
params = mc.init(jax.random.PRNGKey(0))
Expand All @@ -54,7 +62,7 @@ def test_multiple_wavelengths(self):
class MetagratingChallengeTest(unittest.TestCase):
@parameterized.expand([[lambda fn: fn], [jax.jit]])
def test_optimize(self, step_fn_decorator):
mc = metagrating_challenge.metagrating()
mc = metagrating_challenge.metagrating(sim_params=LIGHTWEIGHT_SIM_PARAMS)

def loss_fn(params):
response, aux = mc.component.response(params)
Expand Down
4 changes: 2 additions & 2 deletions tests/challenge/diffract/test_reference_devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@


class ReferenceMetagratingTest(unittest.TestCase):
@pytest.mark.slow
@parameterized.expand(
[
# device name, expected, tolerance
Expand All @@ -28,6 +27,7 @@ class ReferenceMetagratingTest(unittest.TestCase):
["device5.csv", 0.841, 0.015], # Reticolo 0.841, Meep 0.843
]
)
@pytest.mark.slow
def test_efficiency_matches_expected(self, fname, expected_efficiency, tol):
# Compares efficiencies against those reported at.
# https://github.com/NanoComp/photonics-opt-testbed/tree/main/Metagrating3D
Expand Down Expand Up @@ -68,7 +68,6 @@ def test_efficiency_matches_expected(self, fname, expected_efficiency, tol):


class ReferenceDiffractiveSplitterTest(unittest.TestCase):
@pytest.mark.slow
@parameterized.expand(
[
[
Expand Down Expand Up @@ -100,6 +99,7 @@ class ReferenceDiffractiveSplitterTest(unittest.TestCase):
],
]
)
@pytest.mark.slow
def test_efficiency_matches_expected(
self,
fname,
Expand Down
16 changes: 12 additions & 4 deletions tests/challenge/diffract/test_splitter_challenge.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,28 @@
"""Tests for `diffract.splitter_challenge`."""

import dataclasses
import unittest

import jax
import jax.numpy as jnp
import optax
from fmmax import fmm
from parameterized import parameterized

from invrs_gym.challenge.diffract import splitter_challenge

LIGHTWEIGHT_SIM_PARAMS = dataclasses.replace(
splitter_challenge.DIFFRACTIVE_SPLITTER_SIM_PARAMS,
approximate_num_terms=100,
formulation=fmm.Formulation.FFT,
)


class SplitterComponentTest(unittest.TestCase):
def test_density_has_expected_properties(self):
mc = splitter_challenge.DiffractiveSplitterComponent(
spec=splitter_challenge.DIFFRACTIVE_SPLITTER_SPEC,
sim_params=splitter_challenge.DIFFRACTIVE_SPLITTER_SIM_PARAMS,
sim_params=LIGHTWEIGHT_SIM_PARAMS,
thickness_initializer=lambda _, thickness: thickness,
density_initializer=lambda _, seed_density: seed_density,
)
Expand All @@ -30,7 +38,7 @@ def test_density_has_expected_properties(self):
def test_can_jit_response(self):
mc = splitter_challenge.DiffractiveSplitterComponent(
spec=splitter_challenge.DIFFRACTIVE_SPLITTER_SPEC,
sim_params=splitter_challenge.DIFFRACTIVE_SPLITTER_SIM_PARAMS,
sim_params=LIGHTWEIGHT_SIM_PARAMS,
thickness_initializer=lambda _, thickness: thickness,
density_initializer=lambda _, seed_density: seed_density,
)
Expand All @@ -45,7 +53,7 @@ def jit_response_fn(params):
def test_multiple_wavelengths(self):
mc = splitter_challenge.DiffractiveSplitterComponent(
spec=splitter_challenge.DIFFRACTIVE_SPLITTER_SPEC,
sim_params=splitter_challenge.DIFFRACTIVE_SPLITTER_SIM_PARAMS,
sim_params=LIGHTWEIGHT_SIM_PARAMS,
thickness_initializer=lambda _, thickness: thickness,
density_initializer=lambda _, seed_density: seed_density,
)
Expand All @@ -60,7 +68,7 @@ def test_multiple_wavelengths(self):
class SplitterChallengeTest(unittest.TestCase):
@parameterized.expand([[lambda fn: fn], [jax.jit]])
def test_optimize(self, step_fn_decorator):
mc = splitter_challenge.diffractive_splitter()
mc = splitter_challenge.diffractive_splitter(sim_params=LIGHTWEIGHT_SIM_PARAMS)

def loss_fn(params):
response, aux = mc.component.response(params)
Expand Down

0 comments on commit 59748a7

Please sign in to comment.