diff --git a/README.md b/README.md index b248bac..e82b7f4 100644 --- a/README.md +++ b/README.md @@ -1 +1,7 @@ # gym + +# Testing +Some tests are marked as slow and are skipped by default. To run these manually, use +``` +pytest --runslow +``` diff --git a/src/invrs_gym/challenge/diffract/metagrating_challenge.py b/src/invrs_gym/challenge/diffract/metagrating_challenge.py index 8e3890c..a003b1a 100644 --- a/src/invrs_gym/challenge/diffract/metagrating_challenge.py +++ b/src/invrs_gym/challenge/diffract/metagrating_challenge.py @@ -178,6 +178,8 @@ def _value_for_order( truncation=basis.Truncation.CIRCULAR, ) +SYMMETRIES = (symmetry.REFLECTION_E_W,) + # Minimum width and spacing are approximately 80 nm for the default dimensions # of 1.371 x 0.525 um and grid shape of (118, 45). MINIMUM_WIDTH = 7 @@ -195,16 +197,18 @@ def metagrating( density_initializer: DensityInitializer = common.identity_initializer, transmission_order: Tuple[int, int] = TRANSMISSION_ORDER, transmission_lower_bound: float = TRANSMISSION_LOWER_BOUND, + spec: common.GratingSpec = METAGRATING_SPEC, + sim_params: common.GratingSimParams = METAGRATING_SIM_PARAMS, ) -> MetagratingChallenge: """Metagrating with 1.371 x 0.525 um design region.""" return MetagratingChallenge( component=MetagratingComponent( - spec=METAGRATING_SPEC, - sim_params=METAGRATING_SIM_PARAMS, + spec=spec, + sim_params=sim_params, density_initializer=density_initializer, minimum_width=minimum_width, minimum_spacing=minimum_spacing, - symmetries=(symmetry.REFLECTION_E_W,), + symmetries=SYMMETRIES, ), transmission_order=transmission_order, transmission_lower_bound=transmission_lower_bound, diff --git a/src/invrs_gym/challenge/diffract/splitter_challenge.py b/src/invrs_gym/challenge/diffract/splitter_challenge.py index cfd2fae..7713405 100644 --- a/src/invrs_gym/challenge/diffract/splitter_challenge.py +++ b/src/invrs_gym/challenge/diffract/splitter_challenge.py @@ -276,12 +276,14 @@ def diffractive_splitter( thickness_initializer: ThicknessInitializer = common.identity_initializer, density_initializer: DensityInitializer = common.identity_initializer, splitting: Tuple[int, int] = SPLITTING, + spec: common.GratingSpec = DIFFRACTIVE_SPLITTER_SPEC, + sim_params: common.GratingSimParams = DIFFRACTIVE_SPLITTER_SIM_PARAMS, ) -> DiffractiveSplitterChallenge: """Diffractive splitter with 7.2 x 7.2 um design region.""" return DiffractiveSplitterChallenge( component=DiffractiveSplitterComponent( - spec=DIFFRACTIVE_SPLITTER_SPEC, - sim_params=DIFFRACTIVE_SPLITTER_SIM_PARAMS, + spec=spec, + sim_params=sim_params, thickness_initializer=thickness_initializer, density_initializer=density_initializer, minimum_width=minimum_width, diff --git a/tests/challenge/diffract/test_metagrating_challenge.py b/tests/challenge/diffract/test_metagrating_challenge.py index ed871e9..155a79b 100644 --- a/tests/challenge/diffract/test_metagrating_challenge.py +++ b/tests/challenge/diffract/test_metagrating_challenge.py @@ -1,21 +1,29 @@ """Tests for `diffract.metagrating_challenge`.""" +import dataclasses import unittest import jax import jax.numpy as jnp import optax +from fmmax import fmm from parameterized import parameterized from totypes import symmetry # type: ignore[import,attr-defined,unused-ignore] from invrs_gym.challenge.diffract import metagrating_challenge +LIGHTWEIGHT_SIM_PARAMS = dataclasses.replace( + metagrating_challenge.METAGRATING_SIM_PARAMS, + approximate_num_terms=100, + formulation=fmm.Formulation.FFT, +) + class MetagratingComponentTest(unittest.TestCase): def test_density_has_expected_properties(self): mc = metagrating_challenge.MetagratingComponent( spec=metagrating_challenge.METAGRATING_SPEC, - sim_params=metagrating_challenge.METAGRATING_SIM_PARAMS, + sim_params=LIGHTWEIGHT_SIM_PARAMS, density_initializer=lambda _, seed_density: seed_density, ) params = mc.init(jax.random.PRNGKey(0)) @@ -26,7 +34,7 @@ def test_density_has_expected_properties(self): def test_can_jit_response(self): mc = metagrating_challenge.MetagratingComponent( spec=metagrating_challenge.METAGRATING_SPEC, - sim_params=metagrating_challenge.METAGRATING_SIM_PARAMS, + sim_params=LIGHTWEIGHT_SIM_PARAMS, density_initializer=lambda _, seed_density: seed_density, ) params = mc.init(jax.random.PRNGKey(0)) @@ -40,7 +48,7 @@ def jit_response_fn(params): def test_multiple_wavelengths(self): mc = metagrating_challenge.MetagratingComponent( spec=metagrating_challenge.METAGRATING_SPEC, - sim_params=metagrating_challenge.METAGRATING_SIM_PARAMS, + sim_params=LIGHTWEIGHT_SIM_PARAMS, density_initializer=lambda _, seed_density: seed_density, ) params = mc.init(jax.random.PRNGKey(0)) @@ -54,7 +62,7 @@ def test_multiple_wavelengths(self): class MetagratingChallengeTest(unittest.TestCase): @parameterized.expand([[lambda fn: fn], [jax.jit]]) def test_optimize(self, step_fn_decorator): - mc = metagrating_challenge.metagrating() + mc = metagrating_challenge.metagrating(sim_params=LIGHTWEIGHT_SIM_PARAMS) def loss_fn(params): response, aux = mc.component.response(params) diff --git a/tests/challenge/diffract/test_reference_devices.py b/tests/challenge/diffract/test_reference_devices.py index ed4e676..6de6d4c 100644 --- a/tests/challenge/diffract/test_reference_devices.py +++ b/tests/challenge/diffract/test_reference_devices.py @@ -17,7 +17,6 @@ class ReferenceMetagratingTest(unittest.TestCase): - @pytest.mark.slow @parameterized.expand( [ # device name, expected, tolerance @@ -28,6 +27,7 @@ class ReferenceMetagratingTest(unittest.TestCase): ["device5.csv", 0.841, 0.015], # Reticolo 0.841, Meep 0.843 ] ) + @pytest.mark.slow def test_efficiency_matches_expected(self, fname, expected_efficiency, tol): # Compares efficiencies against those reported at. # https://github.com/NanoComp/photonics-opt-testbed/tree/main/Metagrating3D @@ -68,7 +68,6 @@ def test_efficiency_matches_expected(self, fname, expected_efficiency, tol): class ReferenceDiffractiveSplitterTest(unittest.TestCase): - @pytest.mark.slow @parameterized.expand( [ [ @@ -100,6 +99,7 @@ class ReferenceDiffractiveSplitterTest(unittest.TestCase): ], ] ) + @pytest.mark.slow def test_efficiency_matches_expected( self, fname, diff --git a/tests/challenge/diffract/test_splitter_challenge.py b/tests/challenge/diffract/test_splitter_challenge.py index 434ffd7..e187ae8 100644 --- a/tests/challenge/diffract/test_splitter_challenge.py +++ b/tests/challenge/diffract/test_splitter_challenge.py @@ -1,20 +1,28 @@ """Tests for `diffract.splitter_challenge`.""" +import dataclasses import unittest import jax import jax.numpy as jnp import optax +from fmmax import fmm from parameterized import parameterized from invrs_gym.challenge.diffract import splitter_challenge +LIGHTWEIGHT_SIM_PARAMS = dataclasses.replace( + splitter_challenge.DIFFRACTIVE_SPLITTER_SIM_PARAMS, + approximate_num_terms=100, + formulation=fmm.Formulation.FFT, +) + class SplitterComponentTest(unittest.TestCase): def test_density_has_expected_properties(self): mc = splitter_challenge.DiffractiveSplitterComponent( spec=splitter_challenge.DIFFRACTIVE_SPLITTER_SPEC, - sim_params=splitter_challenge.DIFFRACTIVE_SPLITTER_SIM_PARAMS, + sim_params=LIGHTWEIGHT_SIM_PARAMS, thickness_initializer=lambda _, thickness: thickness, density_initializer=lambda _, seed_density: seed_density, ) @@ -30,7 +38,7 @@ def test_density_has_expected_properties(self): def test_can_jit_response(self): mc = splitter_challenge.DiffractiveSplitterComponent( spec=splitter_challenge.DIFFRACTIVE_SPLITTER_SPEC, - sim_params=splitter_challenge.DIFFRACTIVE_SPLITTER_SIM_PARAMS, + sim_params=LIGHTWEIGHT_SIM_PARAMS, thickness_initializer=lambda _, thickness: thickness, density_initializer=lambda _, seed_density: seed_density, ) @@ -45,7 +53,7 @@ def jit_response_fn(params): def test_multiple_wavelengths(self): mc = splitter_challenge.DiffractiveSplitterComponent( spec=splitter_challenge.DIFFRACTIVE_SPLITTER_SPEC, - sim_params=splitter_challenge.DIFFRACTIVE_SPLITTER_SIM_PARAMS, + sim_params=LIGHTWEIGHT_SIM_PARAMS, thickness_initializer=lambda _, thickness: thickness, density_initializer=lambda _, seed_density: seed_density, ) @@ -60,7 +68,7 @@ def test_multiple_wavelengths(self): class SplitterChallengeTest(unittest.TestCase): @parameterized.expand([[lambda fn: fn], [jax.jit]]) def test_optimize(self, step_fn_decorator): - mc = splitter_challenge.diffractive_splitter() + mc = splitter_challenge.diffractive_splitter(sim_params=LIGHTWEIGHT_SIM_PARAMS) def loss_fn(params): response, aux = mc.component.response(params)