From 0b611d75d8512200df92fda21265eedfda69d4c9 Mon Sep 17 00:00:00 2001 From: Martin Schubert Date: Mon, 2 Oct 2023 10:36:16 -0700 Subject: [PATCH 1/2] Speed up diffract tests --- README.md | 6 ++++++ .../challenge/diffract/metagrating_challenge.py | 10 +++++++--- .../challenge/diffract/splitter_challenge.py | 6 ++++-- .../diffract/test_metagrating_challenge.py | 17 +++++++++++++---- .../diffract/test_reference_devices.py | 4 ++-- .../diffract/test_splitter_challenge.py | 17 +++++++++++++---- 6 files changed, 45 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index b248bac..db8cca4 100644 --- a/README.md +++ b/README.md @@ -1 +1,7 @@ # gym + +# Testing +Some tests are marked as slow and are skipped by default. To run these manually, use +``` +pytest --runslow +``` \ No newline at end of file diff --git a/src/invrs_gym/challenge/diffract/metagrating_challenge.py b/src/invrs_gym/challenge/diffract/metagrating_challenge.py index 8e3890c..a003b1a 100644 --- a/src/invrs_gym/challenge/diffract/metagrating_challenge.py +++ b/src/invrs_gym/challenge/diffract/metagrating_challenge.py @@ -178,6 +178,8 @@ def _value_for_order( truncation=basis.Truncation.CIRCULAR, ) +SYMMETRIES = (symmetry.REFLECTION_E_W,) + # Minimum width and spacing are approximately 80 nm for the default dimensions # of 1.371 x 0.525 um and grid shape of (118, 45). MINIMUM_WIDTH = 7 @@ -195,16 +197,18 @@ def metagrating( density_initializer: DensityInitializer = common.identity_initializer, transmission_order: Tuple[int, int] = TRANSMISSION_ORDER, transmission_lower_bound: float = TRANSMISSION_LOWER_BOUND, + spec: common.GratingSpec = METAGRATING_SPEC, + sim_params: common.GratingSimParams = METAGRATING_SIM_PARAMS, ) -> MetagratingChallenge: """Metagrating with 1.371 x 0.525 um design region.""" return MetagratingChallenge( component=MetagratingComponent( - spec=METAGRATING_SPEC, - sim_params=METAGRATING_SIM_PARAMS, + spec=spec, + sim_params=sim_params, density_initializer=density_initializer, minimum_width=minimum_width, minimum_spacing=minimum_spacing, - symmetries=(symmetry.REFLECTION_E_W,), + symmetries=SYMMETRIES, ), transmission_order=transmission_order, transmission_lower_bound=transmission_lower_bound, diff --git a/src/invrs_gym/challenge/diffract/splitter_challenge.py b/src/invrs_gym/challenge/diffract/splitter_challenge.py index cfd2fae..7713405 100644 --- a/src/invrs_gym/challenge/diffract/splitter_challenge.py +++ b/src/invrs_gym/challenge/diffract/splitter_challenge.py @@ -276,12 +276,14 @@ def diffractive_splitter( thickness_initializer: ThicknessInitializer = common.identity_initializer, density_initializer: DensityInitializer = common.identity_initializer, splitting: Tuple[int, int] = SPLITTING, + spec: common.GratingSpec = DIFFRACTIVE_SPLITTER_SPEC, + sim_params: common.GratingSimParams = DIFFRACTIVE_SPLITTER_SIM_PARAMS, ) -> DiffractiveSplitterChallenge: """Diffractive splitter with 7.2 x 7.2 um design region.""" return DiffractiveSplitterChallenge( component=DiffractiveSplitterComponent( - spec=DIFFRACTIVE_SPLITTER_SPEC, - sim_params=DIFFRACTIVE_SPLITTER_SIM_PARAMS, + spec=spec, + sim_params=sim_params, thickness_initializer=thickness_initializer, density_initializer=density_initializer, minimum_width=minimum_width, diff --git a/tests/challenge/diffract/test_metagrating_challenge.py b/tests/challenge/diffract/test_metagrating_challenge.py index ed871e9..3fe2da1 100644 --- a/tests/challenge/diffract/test_metagrating_challenge.py +++ b/tests/challenge/diffract/test_metagrating_challenge.py @@ -2,20 +2,29 @@ import unittest +import dataclasses import jax import jax.numpy as jnp import optax from parameterized import parameterized from totypes import symmetry # type: ignore[import,attr-defined,unused-ignore] +from fmmax import fmm from invrs_gym.challenge.diffract import metagrating_challenge +LIGHTWEIGHT_SIM_PARAMS = dataclasses.replace( + metagrating_challenge.METAGRATING_SIM_PARAMS, + approximate_num_terms=100, + formulation=fmm.Formulation.FFT, +) + + class MetagratingComponentTest(unittest.TestCase): def test_density_has_expected_properties(self): mc = metagrating_challenge.MetagratingComponent( spec=metagrating_challenge.METAGRATING_SPEC, - sim_params=metagrating_challenge.METAGRATING_SIM_PARAMS, + sim_params=LIGHTWEIGHT_SIM_PARAMS, density_initializer=lambda _, seed_density: seed_density, ) params = mc.init(jax.random.PRNGKey(0)) @@ -26,7 +35,7 @@ def test_density_has_expected_properties(self): def test_can_jit_response(self): mc = metagrating_challenge.MetagratingComponent( spec=metagrating_challenge.METAGRATING_SPEC, - sim_params=metagrating_challenge.METAGRATING_SIM_PARAMS, + sim_params=LIGHTWEIGHT_SIM_PARAMS, density_initializer=lambda _, seed_density: seed_density, ) params = mc.init(jax.random.PRNGKey(0)) @@ -40,7 +49,7 @@ def jit_response_fn(params): def test_multiple_wavelengths(self): mc = metagrating_challenge.MetagratingComponent( spec=metagrating_challenge.METAGRATING_SPEC, - sim_params=metagrating_challenge.METAGRATING_SIM_PARAMS, + sim_params=LIGHTWEIGHT_SIM_PARAMS, density_initializer=lambda _, seed_density: seed_density, ) params = mc.init(jax.random.PRNGKey(0)) @@ -54,7 +63,7 @@ def test_multiple_wavelengths(self): class MetagratingChallengeTest(unittest.TestCase): @parameterized.expand([[lambda fn: fn], [jax.jit]]) def test_optimize(self, step_fn_decorator): - mc = metagrating_challenge.metagrating() + mc = metagrating_challenge.metagrating(sim_params=LIGHTWEIGHT_SIM_PARAMS) def loss_fn(params): response, aux = mc.component.response(params) diff --git a/tests/challenge/diffract/test_reference_devices.py b/tests/challenge/diffract/test_reference_devices.py index ed4e676..6de6d4c 100644 --- a/tests/challenge/diffract/test_reference_devices.py +++ b/tests/challenge/diffract/test_reference_devices.py @@ -17,7 +17,6 @@ class ReferenceMetagratingTest(unittest.TestCase): - @pytest.mark.slow @parameterized.expand( [ # device name, expected, tolerance @@ -28,6 +27,7 @@ class ReferenceMetagratingTest(unittest.TestCase): ["device5.csv", 0.841, 0.015], # Reticolo 0.841, Meep 0.843 ] ) + @pytest.mark.slow def test_efficiency_matches_expected(self, fname, expected_efficiency, tol): # Compares efficiencies against those reported at. # https://github.com/NanoComp/photonics-opt-testbed/tree/main/Metagrating3D @@ -68,7 +68,6 @@ def test_efficiency_matches_expected(self, fname, expected_efficiency, tol): class ReferenceDiffractiveSplitterTest(unittest.TestCase): - @pytest.mark.slow @parameterized.expand( [ [ @@ -100,6 +99,7 @@ class ReferenceDiffractiveSplitterTest(unittest.TestCase): ], ] ) + @pytest.mark.slow def test_efficiency_matches_expected( self, fname, diff --git a/tests/challenge/diffract/test_splitter_challenge.py b/tests/challenge/diffract/test_splitter_challenge.py index 434ffd7..022b351 100644 --- a/tests/challenge/diffract/test_splitter_challenge.py +++ b/tests/challenge/diffract/test_splitter_challenge.py @@ -2,19 +2,28 @@ import unittest +import dataclasses import jax import jax.numpy as jnp import optax from parameterized import parameterized +from fmmax import fmm from invrs_gym.challenge.diffract import splitter_challenge +LIGHTWEIGHT_SIM_PARAMS = dataclasses.replace( + splitter_challenge.DIFFRACTIVE_SPLITTER_SIM_PARAMS, + approximate_num_terms=100, + formulation=fmm.Formulation.FFT, +) + + class SplitterComponentTest(unittest.TestCase): def test_density_has_expected_properties(self): mc = splitter_challenge.DiffractiveSplitterComponent( spec=splitter_challenge.DIFFRACTIVE_SPLITTER_SPEC, - sim_params=splitter_challenge.DIFFRACTIVE_SPLITTER_SIM_PARAMS, + sim_params=LIGHTWEIGHT_SIM_PARAMS, thickness_initializer=lambda _, thickness: thickness, density_initializer=lambda _, seed_density: seed_density, ) @@ -30,7 +39,7 @@ def test_density_has_expected_properties(self): def test_can_jit_response(self): mc = splitter_challenge.DiffractiveSplitterComponent( spec=splitter_challenge.DIFFRACTIVE_SPLITTER_SPEC, - sim_params=splitter_challenge.DIFFRACTIVE_SPLITTER_SIM_PARAMS, + sim_params=LIGHTWEIGHT_SIM_PARAMS, thickness_initializer=lambda _, thickness: thickness, density_initializer=lambda _, seed_density: seed_density, ) @@ -45,7 +54,7 @@ def jit_response_fn(params): def test_multiple_wavelengths(self): mc = splitter_challenge.DiffractiveSplitterComponent( spec=splitter_challenge.DIFFRACTIVE_SPLITTER_SPEC, - sim_params=splitter_challenge.DIFFRACTIVE_SPLITTER_SIM_PARAMS, + sim_params=LIGHTWEIGHT_SIM_PARAMS, thickness_initializer=lambda _, thickness: thickness, density_initializer=lambda _, seed_density: seed_density, ) @@ -60,7 +69,7 @@ def test_multiple_wavelengths(self): class SplitterChallengeTest(unittest.TestCase): @parameterized.expand([[lambda fn: fn], [jax.jit]]) def test_optimize(self, step_fn_decorator): - mc = splitter_challenge.diffractive_splitter() + mc = splitter_challenge.diffractive_splitter(sim_params=LIGHTWEIGHT_SIM_PARAMS) def loss_fn(params): response, aux = mc.component.response(params) From 2d439ca3bfb4d5fcec38435a08466c3aed4ba92a Mon Sep 17 00:00:00 2001 From: Martin Schubert Date: Mon, 2 Oct 2023 10:39:38 -0700 Subject: [PATCH 2/2] Formatting --- README.md | 2 +- tests/challenge/diffract/test_metagrating_challenge.py | 5 ++--- tests/challenge/diffract/test_splitter_challenge.py | 5 ++--- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index db8cca4..e82b7f4 100644 --- a/README.md +++ b/README.md @@ -4,4 +4,4 @@ Some tests are marked as slow and are skipped by default. To run these manually, use ``` pytest --runslow -``` \ No newline at end of file +``` diff --git a/tests/challenge/diffract/test_metagrating_challenge.py b/tests/challenge/diffract/test_metagrating_challenge.py index 3fe2da1..155a79b 100644 --- a/tests/challenge/diffract/test_metagrating_challenge.py +++ b/tests/challenge/diffract/test_metagrating_challenge.py @@ -1,18 +1,17 @@ """Tests for `diffract.metagrating_challenge`.""" +import dataclasses import unittest -import dataclasses import jax import jax.numpy as jnp import optax +from fmmax import fmm from parameterized import parameterized from totypes import symmetry # type: ignore[import,attr-defined,unused-ignore] -from fmmax import fmm from invrs_gym.challenge.diffract import metagrating_challenge - LIGHTWEIGHT_SIM_PARAMS = dataclasses.replace( metagrating_challenge.METAGRATING_SIM_PARAMS, approximate_num_terms=100, diff --git a/tests/challenge/diffract/test_splitter_challenge.py b/tests/challenge/diffract/test_splitter_challenge.py index 022b351..e187ae8 100644 --- a/tests/challenge/diffract/test_splitter_challenge.py +++ b/tests/challenge/diffract/test_splitter_challenge.py @@ -1,17 +1,16 @@ """Tests for `diffract.splitter_challenge`.""" +import dataclasses import unittest -import dataclasses import jax import jax.numpy as jnp import optax -from parameterized import parameterized from fmmax import fmm +from parameterized import parameterized from invrs_gym.challenge.diffract import splitter_challenge - LIGHTWEIGHT_SIM_PARAMS = dataclasses.replace( splitter_challenge.DIFFRACTIVE_SPLITTER_SIM_PARAMS, approximate_num_terms=100,