From 1129a835f16e730fee99395a730a47367999a65f Mon Sep 17 00:00:00 2001 From: Martin Schubert Date: Tue, 17 Oct 2023 15:17:49 -0700 Subject: [PATCH 1/5] Rename various modules and functions --- src/invrs_gym/challenge/__init__.py | 31 ----------------- src/invrs_gym/challenges/__init__.py | 33 +++++++++++++++++++ .../ceviche/__init__.py | 0 .../ceviche/challenge.py | 20 +++++------ .../ceviche/defaults.py | 0 .../ceviche}/transmission_loss.py | 0 .../diffract/__init__.py | 0 .../diffract/common.py | 0 .../diffract/metagrating_challenge.py | 2 +- .../diffract/splitter_challenge.py | 2 +- .../extractor/__init__.py | 0 .../extractor/challenge.py | 2 +- .../extractor/component.py | 0 src/invrs_gym/loss/__init__.py | 0 .../ceviche/test_challenge.py | 26 +++++++-------- .../ceviche/test_defaults.py | 2 +- .../ceviche}/test_transmission_loss.py | 2 +- .../broadband_metagrating_designs/device1.csv | 0 .../diffract/metagrating_designs/device1.csv | 0 .../diffract/metagrating_designs/device2.csv | 0 .../diffract/metagrating_designs/device3.csv | 0 .../diffract/metagrating_designs/device4.csv | 0 .../diffract/metagrating_designs/device5.csv | 0 .../diffract/splitter_designs/device1.csv | 0 .../diffract/splitter_designs/device2.csv | 0 .../diffract/splitter_designs/device3.csv | 0 .../diffract/test_common.py | 2 +- .../diffract/test_metagrating_challenge.py | 2 +- .../diffract/test_reference_devices.py | 2 +- .../diffract/test_splitter_challenge.py | 2 +- .../extractor/designs/device1.csv | 0 .../extractor/test_challenge.py | 2 +- .../extractor/test_component.py | 2 +- .../extractor/test_reference_devices.py | 2 +- tests/utils/test_optimizer.py | 4 +-- 35 files changed, 70 insertions(+), 68 deletions(-) delete mode 100644 src/invrs_gym/challenge/__init__.py create mode 100644 src/invrs_gym/challenges/__init__.py rename src/invrs_gym/{challenge => challenges}/ceviche/__init__.py (100%) rename src/invrs_gym/{challenge => challenges}/ceviche/challenge.py (97%) rename src/invrs_gym/{challenge => challenges}/ceviche/defaults.py (100%) rename src/invrs_gym/{loss => challenges/ceviche}/transmission_loss.py (100%) rename src/invrs_gym/{challenge => challenges}/diffract/__init__.py (100%) rename src/invrs_gym/{challenge => challenges}/diffract/common.py (100%) rename src/invrs_gym/{challenge => challenges}/diffract/metagrating_challenge.py (99%) rename src/invrs_gym/{challenge => challenges}/diffract/splitter_challenge.py (99%) rename src/invrs_gym/{challenge => challenges}/extractor/__init__.py (100%) rename src/invrs_gym/{challenge => challenges}/extractor/challenge.py (98%) rename src/invrs_gym/{challenge => challenges}/extractor/component.py (100%) delete mode 100644 src/invrs_gym/loss/__init__.py rename tests/{challenge => challenges}/ceviche/test_challenge.py (75%) rename tests/{challenge => challenges}/ceviche/test_defaults.py (97%) rename tests/{loss => challenges/ceviche}/test_transmission_loss.py (97%) rename tests/{challenge => challenges}/diffract/broadband_metagrating_designs/device1.csv (100%) rename tests/{challenge => challenges}/diffract/metagrating_designs/device1.csv (100%) rename tests/{challenge => challenges}/diffract/metagrating_designs/device2.csv (100%) rename tests/{challenge => challenges}/diffract/metagrating_designs/device3.csv (100%) rename tests/{challenge => challenges}/diffract/metagrating_designs/device4.csv (100%) rename tests/{challenge => challenges}/diffract/metagrating_designs/device5.csv (100%) rename tests/{challenge => challenges}/diffract/splitter_designs/device1.csv (100%) rename tests/{challenge => challenges}/diffract/splitter_designs/device2.csv (100%) rename tests/{challenge => challenges}/diffract/splitter_designs/device3.csv (100%) rename tests/{challenge => challenges}/diffract/test_common.py (95%) rename tests/{challenge => challenges}/diffract/test_metagrating_challenge.py (98%) rename tests/{challenge => challenges}/diffract/test_reference_devices.py (98%) rename tests/{challenge => challenges}/diffract/test_splitter_challenge.py (98%) rename tests/{challenge => challenges}/extractor/designs/device1.csv (100%) rename tests/{challenge => challenges}/extractor/test_challenge.py (98%) rename tests/{challenge => challenges}/extractor/test_component.py (98%) rename tests/{challenge => challenges}/extractor/test_reference_devices.py (99%) diff --git a/src/invrs_gym/challenge/__init__.py b/src/invrs_gym/challenge/__init__.py deleted file mode 100644 index 918f160..0000000 --- a/src/invrs_gym/challenge/__init__.py +++ /dev/null @@ -1,31 +0,0 @@ -__all__ = [ - "lightweight_beam_splitter_challenge", - "lightweight_mode_converter_challenge", - "lightweight_waveguide_bend_challenge", - "lightweight_wdm_challenge", - "beam_splitter_challenge", - "mode_converter_challenge", - "waveguide_bend_challenge", - "wdm_challenge", - "metagrating", - "broadband_metagrating", - "diffractive_splitter", - "photon_extractor", -] - -from invrs_gym.challenge.ceviche.challenge import ( - beam_splitter_challenge, - lightweight_beam_splitter_challenge, - lightweight_mode_converter_challenge, - lightweight_waveguide_bend_challenge, - lightweight_wdm_challenge, - mode_converter_challenge, - waveguide_bend_challenge, - wdm_challenge, -) -from invrs_gym.challenge.diffract.metagrating_challenge import ( - broadband_metagrating, - metagrating, -) -from invrs_gym.challenge.diffract.splitter_challenge import diffractive_splitter -from invrs_gym.challenge.extractor.challenge import photon_extractor diff --git a/src/invrs_gym/challenges/__init__.py b/src/invrs_gym/challenges/__init__.py new file mode 100644 index 0000000..08b2d30 --- /dev/null +++ b/src/invrs_gym/challenges/__init__.py @@ -0,0 +1,33 @@ +from invrs_gym.challenges.ceviche.challenge import beam_splitter as ceviche_beam_splitter +from invrs_gym.challenges.ceviche.challenge import mode_converter as ceviche_mode_converter +from invrs_gym.challenges.ceviche.challenge import waveguide_bend as ceviche_waveguide_bend +from invrs_gym.challenges.ceviche.challenge import wdm as ceviche_wdm +from invrs_gym.challenges.ceviche.challenge import lightweight_beam_splitter as ceviche_lightweight_beam_splitter +from invrs_gym.challenges.ceviche.challenge import lightweight_mode_converter as ceviche_lightweight_mode_converter +from invrs_gym.challenges.ceviche.challenge import lightweight_waveguide_bend as ceviche_lightweight_waveguide_bend +from invrs_gym.challenges.ceviche.challenge import lightweight_wdm as ceviche_lightweight_wdm + +from invrs_gym.challenges.diffract.metagrating_challenge import ( + broadband_metagrating, + metagrating, +) + +from invrs_gym.challenges.diffract.splitter_challenge import diffractive_splitter + +from invrs_gym.challenges.extractor.challenge import photon_extractor + + +BY_NAME = { + "ceviche_beam_splitter": ceviche_beam_splitter, + "ceviche_mode_converter": ceviche_mode_converter, + "ceviche_waveguide_bend": ceviche_waveguide_bend, + "ceviche_wdm": ceviche_wdm, + "ceviche_lightweight_beam_splitter": ceviche_lightweight_beam_splitter, + "ceviche_lightweight_mode_converter": ceviche_lightweight_mode_converter, + "ceviche_lightweight_waveguide_bend": ceviche_lightweight_waveguide_bend, + "ceviche_lightweight_wdm": ceviche_lightweight_wdm, + "metagrating": metagrating, + "broadband_metagrating": broadband_metagrating, + "diffractive_splitter": diffractive_splitter, + "photon_extractor": photon_extractor, +} \ No newline at end of file diff --git a/src/invrs_gym/challenge/ceviche/__init__.py b/src/invrs_gym/challenges/ceviche/__init__.py similarity index 100% rename from src/invrs_gym/challenge/ceviche/__init__.py rename to src/invrs_gym/challenges/ceviche/__init__.py diff --git a/src/invrs_gym/challenge/ceviche/challenge.py b/src/invrs_gym/challenges/ceviche/challenge.py similarity index 97% rename from src/invrs_gym/challenge/ceviche/challenge.py rename to src/invrs_gym/challenges/ceviche/challenge.py index 034c787..bede776 100644 --- a/src/invrs_gym/challenge/ceviche/challenge.py +++ b/src/invrs_gym/challenges/ceviche/challenge.py @@ -10,8 +10,8 @@ import numpy as onp from totypes import types # type: ignore[import,attr-defined,unused-ignore] -from invrs_gym.challenge.ceviche import defaults -from invrs_gym.loss import transmission_loss +from invrs_gym.challenges.ceviche import defaults +from invrs_gym.challenges.ceviche import transmission_loss AuxDict = Dict[str, Any] Params = Any @@ -271,7 +271,7 @@ def _wavelength_bound( # ----------------------------------------------------------------------------- -def beam_splitter_challenge( +def beam_splitter( minimum_width: int = defaults.MINIMUM_WIDTH, minimum_spacing: int = defaults.MINIMUM_SPACING, density_initializer: DensityInitializer = identity_initializer, @@ -290,7 +290,7 @@ def beam_splitter_challenge( ) -def lightweight_beam_splitter_challenge( +def lightweight_beam_splitter( minimum_width: int = defaults.LIGHTWEIGHT_MINIMUM_WIDTH, minimum_spacing: int = defaults.LIGHTWEIGHT_MINIMUM_SPACING, density_initializer: DensityInitializer = identity_initializer, @@ -309,7 +309,7 @@ def lightweight_beam_splitter_challenge( ) -def mode_converter_challenge( +def mode_converter( minimum_width: int = defaults.MINIMUM_WIDTH, minimum_spacing: int = defaults.MINIMUM_SPACING, density_initializer: DensityInitializer = identity_initializer, @@ -327,7 +327,7 @@ def mode_converter_challenge( ) -def lightweight_mode_converter_challenge( +def lightweight_mode_converter( minimum_width: int = defaults.LIGHTWEIGHT_MINIMUM_WIDTH, minimum_spacing: int = defaults.LIGHTWEIGHT_MINIMUM_SPACING, density_initializer: DensityInitializer = identity_initializer, @@ -345,7 +345,7 @@ def lightweight_mode_converter_challenge( ) -def waveguide_bend_challenge( +def waveguide_bend( minimum_width: int = defaults.MINIMUM_WIDTH, minimum_spacing: int = defaults.MINIMUM_SPACING, density_initializer: DensityInitializer = identity_initializer, @@ -364,7 +364,7 @@ def waveguide_bend_challenge( ) -def lightweight_waveguide_bend_challenge( +def lightweight_waveguide_bend( minimum_width: int = defaults.LIGHTWEIGHT_MINIMUM_WIDTH, minimum_spacing: int = defaults.LIGHTWEIGHT_MINIMUM_SPACING, density_initializer: DensityInitializer = identity_initializer, @@ -383,7 +383,7 @@ def lightweight_waveguide_bend_challenge( ) -def wdm_challenge( +def wdm( minimum_width: int = defaults.MINIMUM_WIDTH, minimum_spacing: int = defaults.MINIMUM_SPACING, density_initializer: DensityInitializer = identity_initializer, @@ -401,7 +401,7 @@ def wdm_challenge( ) -def lightweight_wdm_challenge( +def lightweight_wdm( minimum_width: int = defaults.LIGHTWEIGHT_MINIMUM_WIDTH, minimum_spacing: int = defaults.LIGHTWEIGHT_MINIMUM_SPACING, density_initializer: DensityInitializer = identity_initializer, diff --git a/src/invrs_gym/challenge/ceviche/defaults.py b/src/invrs_gym/challenges/ceviche/defaults.py similarity index 100% rename from src/invrs_gym/challenge/ceviche/defaults.py rename to src/invrs_gym/challenges/ceviche/defaults.py diff --git a/src/invrs_gym/loss/transmission_loss.py b/src/invrs_gym/challenges/ceviche/transmission_loss.py similarity index 100% rename from src/invrs_gym/loss/transmission_loss.py rename to src/invrs_gym/challenges/ceviche/transmission_loss.py diff --git a/src/invrs_gym/challenge/diffract/__init__.py b/src/invrs_gym/challenges/diffract/__init__.py similarity index 100% rename from src/invrs_gym/challenge/diffract/__init__.py rename to src/invrs_gym/challenges/diffract/__init__.py diff --git a/src/invrs_gym/challenge/diffract/common.py b/src/invrs_gym/challenges/diffract/common.py similarity index 100% rename from src/invrs_gym/challenge/diffract/common.py rename to src/invrs_gym/challenges/diffract/common.py diff --git a/src/invrs_gym/challenge/diffract/metagrating_challenge.py b/src/invrs_gym/challenges/diffract/metagrating_challenge.py similarity index 99% rename from src/invrs_gym/challenge/diffract/metagrating_challenge.py rename to src/invrs_gym/challenges/diffract/metagrating_challenge.py index 21def44..879d657 100644 --- a/src/invrs_gym/challenge/diffract/metagrating_challenge.py +++ b/src/invrs_gym/challenges/diffract/metagrating_challenge.py @@ -8,7 +8,7 @@ from fmmax import basis, fmm # type: ignore[import] from totypes import symmetry, types # type: ignore[import,attr-defined,unused-ignore] -from invrs_gym.challenge.diffract import common +from invrs_gym.challenges.diffract import common AuxDict = Dict[str, Any] DensityInitializer = Callable[[jax.Array, types.Density2DArray], types.Density2DArray] diff --git a/src/invrs_gym/challenge/diffract/splitter_challenge.py b/src/invrs_gym/challenges/diffract/splitter_challenge.py similarity index 99% rename from src/invrs_gym/challenge/diffract/splitter_challenge.py rename to src/invrs_gym/challenges/diffract/splitter_challenge.py index f5aeea2..f939e61 100644 --- a/src/invrs_gym/challenge/diffract/splitter_challenge.py +++ b/src/invrs_gym/challenges/diffract/splitter_challenge.py @@ -9,7 +9,7 @@ from fmmax import basis, fmm # type: ignore[import] from totypes import types # type: ignore[import,attr-defined,unused-ignore] -from invrs_gym.challenge.diffract import common +from invrs_gym.challenges.diffract import common PyTree = Any AuxDict = Dict[str, Any] diff --git a/src/invrs_gym/challenge/extractor/__init__.py b/src/invrs_gym/challenges/extractor/__init__.py similarity index 100% rename from src/invrs_gym/challenge/extractor/__init__.py rename to src/invrs_gym/challenges/extractor/__init__.py diff --git a/src/invrs_gym/challenge/extractor/challenge.py b/src/invrs_gym/challenges/extractor/challenge.py similarity index 98% rename from src/invrs_gym/challenge/extractor/challenge.py rename to src/invrs_gym/challenges/extractor/challenge.py index d798281..7a4d6f8 100644 --- a/src/invrs_gym/challenge/extractor/challenge.py +++ b/src/invrs_gym/challenges/extractor/challenge.py @@ -9,7 +9,7 @@ from jax import tree_util from totypes import symmetry, types # type: ignore[import,attr-defined,unused-ignore] -from invrs_gym.challenge.extractor import component as extractor_component +from invrs_gym.challenges.extractor import component as extractor_component AuxDict = Dict[str, Any] DensityInitializer = Callable[[jax.Array, types.Density2DArray], types.Density2DArray] diff --git a/src/invrs_gym/challenge/extractor/component.py b/src/invrs_gym/challenges/extractor/component.py similarity index 100% rename from src/invrs_gym/challenge/extractor/component.py rename to src/invrs_gym/challenges/extractor/component.py diff --git a/src/invrs_gym/loss/__init__.py b/src/invrs_gym/loss/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/challenge/ceviche/test_challenge.py b/tests/challenges/ceviche/test_challenge.py similarity index 75% rename from tests/challenge/ceviche/test_challenge.py rename to tests/challenges/ceviche/test_challenge.py index 7dd895c..1ce73da 100644 --- a/tests/challenge/ceviche/test_challenge.py +++ b/tests/challenges/ceviche/test_challenge.py @@ -7,16 +7,16 @@ import optax from parameterized import parameterized -from invrs_gym.challenge.ceviche import challenge +from invrs_gym.challenges.ceviche import challenge class CreateChallengesTest(unittest.TestCase): @parameterized.expand( [ - [challenge.lightweight_beam_splitter_challenge], - [challenge.lightweight_mode_converter_challenge], - [challenge.lightweight_waveguide_bend_challenge], - [challenge.lightweight_wdm_challenge], + [challenge.lightweight_beam_splitter], + [challenge.lightweight_mode_converter], + [challenge.lightweight_waveguide_bend], + [challenge.lightweight_wdm], ] ) def test_optimize(self, ceviche_challenge): @@ -60,14 +60,14 @@ def loss_fn(params): @parameterized.expand( [ - [challenge.beam_splitter_challenge], - [challenge.lightweight_beam_splitter_challenge], - [challenge.mode_converter_challenge], - [challenge.lightweight_mode_converter_challenge], - [challenge.waveguide_bend_challenge], - [challenge.lightweight_waveguide_bend_challenge], - [challenge.wdm_challenge], - [challenge.lightweight_wdm_challenge], + [challenge.beam_splitter], + [challenge.lightweight_beam_splitter], + [challenge.mode_converter], + [challenge.lightweight_mode_converter], + [challenge.waveguide_bend], + [challenge.lightweight_waveguide_bend], + [challenge.wdm], + [challenge.lightweight_wdm], ] ) def test_with_dummy_response(self, ceviche_challenge): diff --git a/tests/challenge/ceviche/test_defaults.py b/tests/challenges/ceviche/test_defaults.py similarity index 97% rename from tests/challenge/ceviche/test_defaults.py rename to tests/challenges/ceviche/test_defaults.py index 1935cfd..32ab863 100644 --- a/tests/challenge/ceviche/test_defaults.py +++ b/tests/challenges/ceviche/test_defaults.py @@ -6,7 +6,7 @@ from ceviche_challenges import units as u from parameterized import parameterized -from invrs_gym.challenge.ceviche import defaults +from invrs_gym.challenges.ceviche import defaults class CreateModelTest(unittest.TestCase): diff --git a/tests/loss/test_transmission_loss.py b/tests/challenges/ceviche/test_transmission_loss.py similarity index 97% rename from tests/loss/test_transmission_loss.py rename to tests/challenges/ceviche/test_transmission_loss.py index 4cecd53..ea0715f 100644 --- a/tests/loss/test_transmission_loss.py +++ b/tests/challenges/ceviche/test_transmission_loss.py @@ -3,7 +3,7 @@ import functools import unittest -from invrs_gym.loss import transmission_loss +from invrs_gym.challenges.ceviche import transmission_loss class OrthotopeSmoothLossTest(unittest.TestCase): diff --git a/tests/challenge/diffract/broadband_metagrating_designs/device1.csv b/tests/challenges/diffract/broadband_metagrating_designs/device1.csv similarity index 100% rename from tests/challenge/diffract/broadband_metagrating_designs/device1.csv rename to tests/challenges/diffract/broadband_metagrating_designs/device1.csv diff --git a/tests/challenge/diffract/metagrating_designs/device1.csv b/tests/challenges/diffract/metagrating_designs/device1.csv similarity index 100% rename from tests/challenge/diffract/metagrating_designs/device1.csv rename to tests/challenges/diffract/metagrating_designs/device1.csv diff --git a/tests/challenge/diffract/metagrating_designs/device2.csv b/tests/challenges/diffract/metagrating_designs/device2.csv similarity index 100% rename from tests/challenge/diffract/metagrating_designs/device2.csv rename to tests/challenges/diffract/metagrating_designs/device2.csv diff --git a/tests/challenge/diffract/metagrating_designs/device3.csv b/tests/challenges/diffract/metagrating_designs/device3.csv similarity index 100% rename from tests/challenge/diffract/metagrating_designs/device3.csv rename to tests/challenges/diffract/metagrating_designs/device3.csv diff --git a/tests/challenge/diffract/metagrating_designs/device4.csv b/tests/challenges/diffract/metagrating_designs/device4.csv similarity index 100% rename from tests/challenge/diffract/metagrating_designs/device4.csv rename to tests/challenges/diffract/metagrating_designs/device4.csv diff --git a/tests/challenge/diffract/metagrating_designs/device5.csv b/tests/challenges/diffract/metagrating_designs/device5.csv similarity index 100% rename from tests/challenge/diffract/metagrating_designs/device5.csv rename to tests/challenges/diffract/metagrating_designs/device5.csv diff --git a/tests/challenge/diffract/splitter_designs/device1.csv b/tests/challenges/diffract/splitter_designs/device1.csv similarity index 100% rename from tests/challenge/diffract/splitter_designs/device1.csv rename to tests/challenges/diffract/splitter_designs/device1.csv diff --git a/tests/challenge/diffract/splitter_designs/device2.csv b/tests/challenges/diffract/splitter_designs/device2.csv similarity index 100% rename from tests/challenge/diffract/splitter_designs/device2.csv rename to tests/challenges/diffract/splitter_designs/device2.csv diff --git a/tests/challenge/diffract/splitter_designs/device3.csv b/tests/challenges/diffract/splitter_designs/device3.csv similarity index 100% rename from tests/challenge/diffract/splitter_designs/device3.csv rename to tests/challenges/diffract/splitter_designs/device3.csv diff --git a/tests/challenge/diffract/test_common.py b/tests/challenges/diffract/test_common.py similarity index 95% rename from tests/challenge/diffract/test_common.py rename to tests/challenges/diffract/test_common.py index 17a5165..cccbd5a 100644 --- a/tests/challenge/diffract/test_common.py +++ b/tests/challenges/diffract/test_common.py @@ -7,7 +7,7 @@ from fmmax import basis from jax import tree_util -from invrs_gym.challenge.diffract import common +from invrs_gym.challenges.diffract import common class GatingResponseTest(unittest.TestCase): diff --git a/tests/challenge/diffract/test_metagrating_challenge.py b/tests/challenges/diffract/test_metagrating_challenge.py similarity index 98% rename from tests/challenge/diffract/test_metagrating_challenge.py rename to tests/challenges/diffract/test_metagrating_challenge.py index 155a79b..13a3124 100644 --- a/tests/challenge/diffract/test_metagrating_challenge.py +++ b/tests/challenges/diffract/test_metagrating_challenge.py @@ -10,7 +10,7 @@ from parameterized import parameterized from totypes import symmetry # type: ignore[import,attr-defined,unused-ignore] -from invrs_gym.challenge.diffract import metagrating_challenge +from invrs_gym.challenges.diffract import metagrating_challenge LIGHTWEIGHT_SIM_PARAMS = dataclasses.replace( metagrating_challenge.METAGRATING_SIM_PARAMS, diff --git a/tests/challenge/diffract/test_reference_devices.py b/tests/challenges/diffract/test_reference_devices.py similarity index 98% rename from tests/challenge/diffract/test_reference_devices.py rename to tests/challenges/diffract/test_reference_devices.py index b1aa90b..e34cbb4 100644 --- a/tests/challenge/diffract/test_reference_devices.py +++ b/tests/challenges/diffract/test_reference_devices.py @@ -10,7 +10,7 @@ import pytest from parameterized import parameterized -from invrs_gym.challenge.diffract import metagrating_challenge, splitter_challenge +from invrs_gym.challenges.diffract import metagrating_challenge, splitter_challenge PARENT_PATH = pathlib.Path(__file__).resolve().parent BROADBAND_METAGRATING_DIR = PARENT_PATH / "broadband_metagrating_designs" diff --git a/tests/challenge/diffract/test_splitter_challenge.py b/tests/challenges/diffract/test_splitter_challenge.py similarity index 98% rename from tests/challenge/diffract/test_splitter_challenge.py rename to tests/challenges/diffract/test_splitter_challenge.py index e187ae8..607f765 100644 --- a/tests/challenge/diffract/test_splitter_challenge.py +++ b/tests/challenges/diffract/test_splitter_challenge.py @@ -9,7 +9,7 @@ from fmmax import fmm from parameterized import parameterized -from invrs_gym.challenge.diffract import splitter_challenge +from invrs_gym.challenges.diffract import splitter_challenge LIGHTWEIGHT_SIM_PARAMS = dataclasses.replace( splitter_challenge.DIFFRACTIVE_SPLITTER_SIM_PARAMS, diff --git a/tests/challenge/extractor/designs/device1.csv b/tests/challenges/extractor/designs/device1.csv similarity index 100% rename from tests/challenge/extractor/designs/device1.csv rename to tests/challenges/extractor/designs/device1.csv diff --git a/tests/challenge/extractor/test_challenge.py b/tests/challenges/extractor/test_challenge.py similarity index 98% rename from tests/challenge/extractor/test_challenge.py rename to tests/challenges/extractor/test_challenge.py index 3f79021..bbc2dc3 100644 --- a/tests/challenge/extractor/test_challenge.py +++ b/tests/challenges/extractor/test_challenge.py @@ -10,7 +10,7 @@ from parameterized import parameterized from totypes import symmetry -from invrs_gym.challenge.extractor import challenge +from invrs_gym.challenges.extractor import challenge class SplitterChallengeTest(unittest.TestCase): diff --git a/tests/challenge/extractor/test_component.py b/tests/challenges/extractor/test_component.py similarity index 98% rename from tests/challenge/extractor/test_component.py rename to tests/challenges/extractor/test_component.py index 7cb37fd..c53ec57 100644 --- a/tests/challenge/extractor/test_component.py +++ b/tests/challenges/extractor/test_component.py @@ -9,7 +9,7 @@ from fmmax import fmm from jax import tree_util -from invrs_gym.challenge.extractor import challenge, component +from invrs_gym.challenges.extractor import challenge, component class ExtractorComponentTest(unittest.TestCase): diff --git a/tests/challenge/extractor/test_reference_devices.py b/tests/challenges/extractor/test_reference_devices.py similarity index 99% rename from tests/challenge/extractor/test_reference_devices.py rename to tests/challenges/extractor/test_reference_devices.py index 00f7745..5597f5f 100644 --- a/tests/challenge/extractor/test_reference_devices.py +++ b/tests/challenges/extractor/test_reference_devices.py @@ -10,7 +10,7 @@ import pytest from fmmax import basis -from invrs_gym.challenge.extractor import challenge +from invrs_gym.challenges.extractor import challenge DESIGNS_DIR = pathlib.Path(__file__).resolve().parent / "designs" diff --git a/tests/utils/test_optimizer.py b/tests/utils/test_optimizer.py index 70c9deb..8412878 100644 --- a/tests/utils/test_optimizer.py +++ b/tests/utils/test_optimizer.py @@ -4,14 +4,14 @@ import optax -from invrs_gym import challenge +from invrs_gym import challenges from invrs_gym.utils import optimizer class OptimizerTest(unittest.TestCase): def test_can_optimize(self): params, state, step_fn = optimizer.setup_optimization( - challenge=challenge.metagrating(), + challenge=challenges.metagrating(), optimizer=optax.adam(0.02), ) From 4b2ba5a3e15b8c0115cf35487cb1cf14f0b8a667 Mon Sep 17 00:00:00 2001 From: Martin Schubert Date: Tue, 17 Oct 2023 15:20:22 -0700 Subject: [PATCH 2/5] Rename resist -> oxide --- src/invrs_gym/challenges/__init__.py | 30 ++++++++++++++----- .../challenges/extractor/challenge.py | 4 +-- .../challenges/extractor/component.py | 18 +++++------ 3 files changed, 33 insertions(+), 19 deletions(-) diff --git a/src/invrs_gym/challenges/__init__.py b/src/invrs_gym/challenges/__init__.py index 08b2d30..ccef82b 100644 --- a/src/invrs_gym/challenges/__init__.py +++ b/src/invrs_gym/challenges/__init__.py @@ -1,11 +1,25 @@ -from invrs_gym.challenges.ceviche.challenge import beam_splitter as ceviche_beam_splitter -from invrs_gym.challenges.ceviche.challenge import mode_converter as ceviche_mode_converter -from invrs_gym.challenges.ceviche.challenge import waveguide_bend as ceviche_waveguide_bend +from invrs_gym.challenges.ceviche.challenge import ( + beam_splitter as ceviche_beam_splitter, +) +from invrs_gym.challenges.ceviche.challenge import ( + mode_converter as ceviche_mode_converter, +) +from invrs_gym.challenges.ceviche.challenge import ( + waveguide_bend as ceviche_waveguide_bend, +) from invrs_gym.challenges.ceviche.challenge import wdm as ceviche_wdm -from invrs_gym.challenges.ceviche.challenge import lightweight_beam_splitter as ceviche_lightweight_beam_splitter -from invrs_gym.challenges.ceviche.challenge import lightweight_mode_converter as ceviche_lightweight_mode_converter -from invrs_gym.challenges.ceviche.challenge import lightweight_waveguide_bend as ceviche_lightweight_waveguide_bend -from invrs_gym.challenges.ceviche.challenge import lightweight_wdm as ceviche_lightweight_wdm +from invrs_gym.challenges.ceviche.challenge import ( + lightweight_beam_splitter as ceviche_lightweight_beam_splitter, +) +from invrs_gym.challenges.ceviche.challenge import ( + lightweight_mode_converter as ceviche_lightweight_mode_converter, +) +from invrs_gym.challenges.ceviche.challenge import ( + lightweight_waveguide_bend as ceviche_lightweight_waveguide_bend, +) +from invrs_gym.challenges.ceviche.challenge import ( + lightweight_wdm as ceviche_lightweight_wdm, +) from invrs_gym.challenges.diffract.metagrating_challenge import ( broadband_metagrating, @@ -30,4 +44,4 @@ "broadband_metagrating": broadband_metagrating, "diffractive_splitter": diffractive_splitter, "photon_extractor": photon_extractor, -} \ No newline at end of file +} diff --git a/src/invrs_gym/challenges/extractor/challenge.py b/src/invrs_gym/challenges/extractor/challenge.py index 7a4d6f8..d94e481 100644 --- a/src/invrs_gym/challenges/extractor/challenge.py +++ b/src/invrs_gym/challenges/extractor/challenge.py @@ -96,11 +96,11 @@ def metrics( EXTRACTOR_SPEC = extractor_component.ExtractorSpec( permittivity_ambient=(1.0 + 0.0j) ** 2, - permittivity_resist=(1.46 + 0.0j) ** 2, + permittivity_oxide=(1.46 + 0.0j) ** 2, permittivity_extractor=(3.31 + 0.0j) ** 2, permittivity_substrate=(2.4102 + 0.0j) ** 2, thickness_ambient=1.0, - thickness_resist=0.13, + thickness_oxide=0.13, thickness_extractor=0.25, thickness_substrate_before_source=0.1, thickness_substrate_after_source=0.9, diff --git a/src/invrs_gym/challenges/extractor/component.py b/src/invrs_gym/challenges/extractor/component.py index 4ae39dc..6f630d0 100644 --- a/src/invrs_gym/challenges/extractor/component.py +++ b/src/invrs_gym/challenges/extractor/component.py @@ -41,11 +41,11 @@ class ExtractorSpec: Args: permittivity_ambient: Permittivity of the ambient material. - permittivity_resist: Permittivity of the resist material. + permittivity_oxide: Permittivity of the oxide material. permittivity_extractor: Permittivity of the extractor material. permittivity_substrate: Permittivity of the substrate. thickness_ambient: The thickness of the ambient layer. - thickness_resist: The thickness of the resist layer. + thickness_oxide: The thickness of the oxide layer. thickness_extractor: The thickness of the extractor layer. thickness_substrate_before_source: The distance between the substrate and the plane containing the source. @@ -67,12 +67,12 @@ class ExtractorSpec: """ permittivity_ambient: complex - permittivity_resist: complex + permittivity_oxide: complex permittivity_extractor: complex permittivity_substrate: complex thickness_ambient: float - thickness_resist: float + thickness_oxide: float thickness_extractor: float thickness_substrate_before_source: float thickness_substrate_after_source: float @@ -182,7 +182,7 @@ def __init__( # to ensure gridpoints are correctly spaced. self.layer_znum = ( _num_gridpoints(spec.thickness_ambient) + 1, - _num_gridpoints(spec.thickness_resist) + 1, + _num_gridpoints(spec.thickness_oxide) + 1, _num_gridpoints(spec.thickness_extractor) + 1, _num_gridpoints(spec.thickness_substrate_before_source) + 1, _num_gridpoints(spec.thickness_substrate_after_source) + 1, @@ -359,9 +359,9 @@ def eigensolve_pml(permittivity: jnp.ndarray) -> fmm.LayerSolveResult: solve_result_ambient = eigensolve_pml( permittivity=jnp.full(grid_shape, spec.permittivity_ambient) ) - solve_result_resist = eigensolve_pml( + solve_result_oxide = eigensolve_pml( permittivity=utils.interpolate_permittivity( - permittivity_solid=spec.permittivity_resist, + permittivity_solid=spec.permittivity_oxide, permittivity_void=spec.permittivity_ambient, density=density_array, ), @@ -379,14 +379,14 @@ def eigensolve_pml(permittivity: jnp.ndarray) -> fmm.LayerSolveResult: layer_solve_results = ( solve_result_ambient, - solve_result_resist, + solve_result_oxide, solve_result_extractor, solve_result_substrate, # Before the source. solve_result_substrate, # After the source. ) layer_thicknesses = ( jnp.asarray(spec.thickness_ambient), - jnp.asarray(spec.thickness_resist), + jnp.asarray(spec.thickness_oxide), jnp.asarray(spec.thickness_extractor), jnp.asarray(spec.thickness_substrate_before_source), jnp.asarray(spec.thickness_substrate_after_source), From f23b30f78d7ff00ec7f2f3f0ea72ca35923a25b3 Mon Sep 17 00:00:00 2001 From: Martin Schubert Date: Tue, 17 Oct 2023 15:25:30 -0700 Subject: [PATCH 3/5] Update mypy ignore statements --- src/invrs_gym/challenges/ceviche/challenge.py | 4 ++-- src/invrs_gym/challenges/ceviche/defaults.py | 6 +++--- src/invrs_gym/challenges/diffract/common.py | 4 ++-- src/invrs_gym/challenges/diffract/metagrating_challenge.py | 4 ++-- src/invrs_gym/challenges/diffract/splitter_challenge.py | 4 ++-- src/invrs_gym/challenges/extractor/challenge.py | 4 ++-- src/invrs_gym/challenges/extractor/component.py | 2 +- 7 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/invrs_gym/challenges/ceviche/challenge.py b/src/invrs_gym/challenges/ceviche/challenge.py index bede776..2566bb9 100644 --- a/src/invrs_gym/challenges/ceviche/challenge.py +++ b/src/invrs_gym/challenges/ceviche/challenge.py @@ -4,11 +4,11 @@ import functools from typing import Any, Callable, Dict, Optional, Sequence, Tuple -import agjax # type: ignore[import] +import agjax # type: ignore[import-untyped] import jax import jax.numpy as jnp import numpy as onp -from totypes import types # type: ignore[import,attr-defined,unused-ignore] +from totypes import types # type: ignore[import-untyped] from invrs_gym.challenges.ceviche import defaults from invrs_gym.challenges.ceviche import transmission_loss diff --git a/src/invrs_gym/challenges/ceviche/defaults.py b/src/invrs_gym/challenges/ceviche/defaults.py index ea15ee7..dfe8224 100644 --- a/src/invrs_gym/challenges/ceviche/defaults.py +++ b/src/invrs_gym/challenges/ceviche/defaults.py @@ -3,15 +3,15 @@ from typing import Union import jax.numpy as jnp -from ceviche_challenges import ( # type: ignore[import] +from ceviche_challenges import ( # type: ignore[import-untyped] beam_splitter, mode_converter, model_base, params, ) from ceviche_challenges import units as u -from ceviche_challenges import waveguide_bend, wdm # type: ignore[import] -from totypes import symmetry # type: ignore[import,attr-defined,unused-ignore] +from ceviche_challenges import waveguide_bend, wdm # type: ignore[import-untyped] +from totypes import symmetry # type: ignore[import-untyped] DeviceSpec = Union[ beam_splitter.spec.BeamSplitterSpec, diff --git a/src/invrs_gym/challenges/diffract/common.py b/src/invrs_gym/challenges/diffract/common.py index 4a0a6ec..7ad0f57 100644 --- a/src/invrs_gym/challenges/diffract/common.py +++ b/src/invrs_gym/challenges/diffract/common.py @@ -6,9 +6,9 @@ import jax import jax.numpy as jnp import numpy as onp -from fmmax import basis, fields, fmm, scattering, utils # type: ignore[import] +from fmmax import basis, fields, fmm, scattering, utils # type: ignore[import-untyped] from jax import tree_util -from totypes import types # type: ignore[import,attr-defined,unused-ignore] +from totypes import types # type: ignore[import-untyped] AuxDict = Dict[str, Any] Params = Any diff --git a/src/invrs_gym/challenges/diffract/metagrating_challenge.py b/src/invrs_gym/challenges/diffract/metagrating_challenge.py index 879d657..392c500 100644 --- a/src/invrs_gym/challenges/diffract/metagrating_challenge.py +++ b/src/invrs_gym/challenges/diffract/metagrating_challenge.py @@ -5,8 +5,8 @@ import jax import jax.numpy as jnp -from fmmax import basis, fmm # type: ignore[import] -from totypes import symmetry, types # type: ignore[import,attr-defined,unused-ignore] +from fmmax import basis, fmm # type: ignore[import-untyped] +from totypes import symmetry, types # type: ignore[import-untyped] from invrs_gym.challenges.diffract import common diff --git a/src/invrs_gym/challenges/diffract/splitter_challenge.py b/src/invrs_gym/challenges/diffract/splitter_challenge.py index f939e61..8e25294 100644 --- a/src/invrs_gym/challenges/diffract/splitter_challenge.py +++ b/src/invrs_gym/challenges/diffract/splitter_challenge.py @@ -6,8 +6,8 @@ import jax import jax.numpy as jnp -from fmmax import basis, fmm # type: ignore[import] -from totypes import types # type: ignore[import,attr-defined,unused-ignore] +from fmmax import basis, fmm # type: ignore[import-untyped] +from totypes import types # type: ignore[import-untyped] from invrs_gym.challenges.diffract import common diff --git a/src/invrs_gym/challenges/extractor/challenge.py b/src/invrs_gym/challenges/extractor/challenge.py index d94e481..b907559 100644 --- a/src/invrs_gym/challenges/extractor/challenge.py +++ b/src/invrs_gym/challenges/extractor/challenge.py @@ -4,10 +4,10 @@ from typing import Any, Callable, Dict, Tuple import jax -from fmmax import basis, fmm # type: ignore[import] +from fmmax import basis, fmm # type: ignore[import-untyped] from jax import numpy as jnp from jax import tree_util -from totypes import symmetry, types # type: ignore[import,attr-defined,unused-ignore] +from totypes import symmetry, types # type: ignore[import-untyped] from invrs_gym.challenges.extractor import component as extractor_component diff --git a/src/invrs_gym/challenges/extractor/component.py b/src/invrs_gym/challenges/extractor/component.py index 6f630d0..d3b4ea4 100644 --- a/src/invrs_gym/challenges/extractor/component.py +++ b/src/invrs_gym/challenges/extractor/component.py @@ -16,7 +16,7 @@ utils, ) from jax import tree_util -from totypes import types # type: ignore[import,attr-defined,unused-ignore] +from totypes import types # type: ignore[import-untyped] AuxDict = Dict[str, Any] DensityInitializer = Callable[[jax.Array, types.Density2DArray], types.Density2DArray] From 65ddbec9fdc5d0fbcb54e4f48959f24ca5d169a0 Mon Sep 17 00:00:00 2001 From: Martin Schubert Date: Tue, 17 Oct 2023 15:52:41 -0700 Subject: [PATCH 4/5] Mypy adjustments --- src/invrs_gym/challenges/ceviche/challenge.py | 2 +- src/invrs_gym/challenges/ceviche/defaults.py | 2 +- src/invrs_gym/challenges/diffract/common.py | 2 +- .../challenges/diffract/metagrating_challenge.py | 4 ++-- .../challenges/diffract/splitter_challenge.py | 14 +++++++------- src/invrs_gym/challenges/extractor/challenge.py | 2 +- src/invrs_gym/challenges/extractor/component.py | 4 ++-- 7 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/invrs_gym/challenges/ceviche/challenge.py b/src/invrs_gym/challenges/ceviche/challenge.py index 2566bb9..3f4bd6f 100644 --- a/src/invrs_gym/challenges/ceviche/challenge.py +++ b/src/invrs_gym/challenges/ceviche/challenge.py @@ -8,7 +8,7 @@ import jax import jax.numpy as jnp import numpy as onp -from totypes import types # type: ignore[import-untyped] +from totypes import types from invrs_gym.challenges.ceviche import defaults from invrs_gym.challenges.ceviche import transmission_loss diff --git a/src/invrs_gym/challenges/ceviche/defaults.py b/src/invrs_gym/challenges/ceviche/defaults.py index dfe8224..8f5af09 100644 --- a/src/invrs_gym/challenges/ceviche/defaults.py +++ b/src/invrs_gym/challenges/ceviche/defaults.py @@ -11,7 +11,7 @@ ) from ceviche_challenges import units as u from ceviche_challenges import waveguide_bend, wdm # type: ignore[import-untyped] -from totypes import symmetry # type: ignore[import-untyped] +from totypes import symmetry DeviceSpec = Union[ beam_splitter.spec.BeamSplitterSpec, diff --git a/src/invrs_gym/challenges/diffract/common.py b/src/invrs_gym/challenges/diffract/common.py index 7ad0f57..57195c7 100644 --- a/src/invrs_gym/challenges/diffract/common.py +++ b/src/invrs_gym/challenges/diffract/common.py @@ -8,7 +8,7 @@ import numpy as onp from fmmax import basis, fields, fmm, scattering, utils # type: ignore[import-untyped] from jax import tree_util -from totypes import types # type: ignore[import-untyped] +from totypes import types AuxDict = Dict[str, Any] Params = Any diff --git a/src/invrs_gym/challenges/diffract/metagrating_challenge.py b/src/invrs_gym/challenges/diffract/metagrating_challenge.py index 392c500..542516b 100644 --- a/src/invrs_gym/challenges/diffract/metagrating_challenge.py +++ b/src/invrs_gym/challenges/diffract/metagrating_challenge.py @@ -6,7 +6,7 @@ import jax import jax.numpy as jnp from fmmax import basis, fmm # type: ignore[import-untyped] -from totypes import symmetry, types # type: ignore[import-untyped] +from totypes import symmetry, types from invrs_gym.challenges.diffract import common @@ -82,7 +82,7 @@ def response( if wavelength is None: wavelength = self.sim_params.wavelength transmission_efficiency, reflection_efficiency = common.grating_efficiency( - density_array=params.array, + density_array=params.array, # type: ignore[arg-type] thickness=jnp.asarray(self.spec.thickness_grating), spec=self.spec, wavelength=jnp.asarray(wavelength), diff --git a/src/invrs_gym/challenges/diffract/splitter_challenge.py b/src/invrs_gym/challenges/diffract/splitter_challenge.py index 8e25294..4bf7bac 100644 --- a/src/invrs_gym/challenges/diffract/splitter_challenge.py +++ b/src/invrs_gym/challenges/diffract/splitter_challenge.py @@ -7,13 +7,13 @@ import jax import jax.numpy as jnp from fmmax import basis, fmm # type: ignore[import-untyped] -from totypes import types # type: ignore[import-untyped] +from totypes import types from invrs_gym.challenges.diffract import common -PyTree = Any +Params = Dict[str, types.BoundedArray | types.Density2DArray] AuxDict = Dict[str, Any] -ThicknessInitializer = Callable[[jax.Array, jnp.ndarray], jnp.ndarray] +ThicknessInitializer = Callable[[jax.Array, types.BoundedArray], types.BoundedArray] DensityInitializer = Callable[[jax.Array, types.Density2DArray], types.Density2DArray] @@ -77,7 +77,7 @@ def __init__( truncation=self.sim_params.truncation, ) - def init(self, key: jax.Array) -> PyTree: + def init(self, key: jax.Array) -> Params: """Return the initial parameters for the diffractive splitter component.""" key_thickness, key_density = jax.random.split(key) return { @@ -87,7 +87,7 @@ def init(self, key: jax.Array) -> PyTree: def response( self, - params: types.Density2DArray, + params: Params, wavelength: Optional[Union[float, jnp.ndarray]] = None, expansion: Optional[basis.Expansion] = None, ) -> Tuple[common.GratingResponse, AuxDict]: @@ -107,8 +107,8 @@ def response( if wavelength is None: wavelength = self.sim_params.wavelength transmission_efficiency, reflection_efficiency = common.grating_efficiency( - density_array=params[DENSITY].array, - thickness=params[THICKNESS].array, + density_array=params[DENSITY].array, # type: ignore[arg-type] + thickness=params[THICKNESS].array, # type: ignore[arg-type] spec=self.spec, wavelength=jnp.asarray(wavelength), polarization=self.sim_params.polarization, diff --git a/src/invrs_gym/challenges/extractor/challenge.py b/src/invrs_gym/challenges/extractor/challenge.py index b907559..fba1af1 100644 --- a/src/invrs_gym/challenges/extractor/challenge.py +++ b/src/invrs_gym/challenges/extractor/challenge.py @@ -7,7 +7,7 @@ from fmmax import basis, fmm # type: ignore[import-untyped] from jax import numpy as jnp from jax import tree_util -from totypes import symmetry, types # type: ignore[import-untyped] +from totypes import symmetry, types from invrs_gym.challenges.extractor import component as extractor_component diff --git a/src/invrs_gym/challenges/extractor/component.py b/src/invrs_gym/challenges/extractor/component.py index d3b4ea4..cce4e3f 100644 --- a/src/invrs_gym/challenges/extractor/component.py +++ b/src/invrs_gym/challenges/extractor/component.py @@ -16,7 +16,7 @@ utils, ) from jax import tree_util -from totypes import types # type: ignore[import-untyped] +from totypes import types AuxDict = Dict[str, Any] DensityInitializer = Callable[[jax.Array, types.Density2DArray], types.Density2DArray] @@ -232,7 +232,7 @@ def response( wavelength = self.sim_params.wavelength return simulate_extractor( - density_array=params.array, + density_array=params.array, # type: ignore[arg-type] spec=self.spec, layer_znum=self.layer_znum, wavelength=jnp.asarray(wavelength), From d884ce9b5ac621438f11bc046abddb88238d7c1b Mon Sep 17 00:00:00 2001 From: Martin Schubert Date: Tue, 17 Oct 2023 15:56:31 -0700 Subject: [PATCH 5/5] Add py.typed file --- pyproject.toml | 3 +++ src/invrs_gym/py.typed | 0 2 files changed, 3 insertions(+) create mode 100644 src/invrs_gym/py.typed diff --git a/pyproject.toml b/pyproject.toml index 7487e2d..48f63be 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,9 @@ dev = [ requires = ["setuptools>=45", "wheel"] build-backend = "setuptools.build_meta" +[tool.setuptools.package-data] +"invrs_gym" = ["py.typed"] + [tool.black] line-length = 88 target-version = ['py310'] diff --git a/src/invrs_gym/py.typed b/src/invrs_gym/py.typed new file mode 100644 index 0000000..e69de29