Skip to content

Commit

Permalink
Merge pull request #23 from invrs-io/refactor
Browse files Browse the repository at this point in the history
Refactor
  • Loading branch information
mfschubert authored Oct 17, 2023
2 parents 002de81 + d884ce9 commit 8e91d14
Show file tree
Hide file tree
Showing 36 changed files with 120 additions and 101 deletions.
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ dev = [
requires = ["setuptools>=45", "wheel"]
build-backend = "setuptools.build_meta"

[tool.setuptools.package-data]
"invrs_gym" = ["py.typed"]

[tool.black]
line-length = 88
target-version = ['py310']
Expand Down
31 changes: 0 additions & 31 deletions src/invrs_gym/challenge/__init__.py

This file was deleted.

47 changes: 47 additions & 0 deletions src/invrs_gym/challenges/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from invrs_gym.challenges.ceviche.challenge import (
beam_splitter as ceviche_beam_splitter,
)
from invrs_gym.challenges.ceviche.challenge import (
mode_converter as ceviche_mode_converter,
)
from invrs_gym.challenges.ceviche.challenge import (
waveguide_bend as ceviche_waveguide_bend,
)
from invrs_gym.challenges.ceviche.challenge import wdm as ceviche_wdm
from invrs_gym.challenges.ceviche.challenge import (
lightweight_beam_splitter as ceviche_lightweight_beam_splitter,
)
from invrs_gym.challenges.ceviche.challenge import (
lightweight_mode_converter as ceviche_lightweight_mode_converter,
)
from invrs_gym.challenges.ceviche.challenge import (
lightweight_waveguide_bend as ceviche_lightweight_waveguide_bend,
)
from invrs_gym.challenges.ceviche.challenge import (
lightweight_wdm as ceviche_lightweight_wdm,
)

from invrs_gym.challenges.diffract.metagrating_challenge import (
broadband_metagrating,
metagrating,
)

from invrs_gym.challenges.diffract.splitter_challenge import diffractive_splitter

from invrs_gym.challenges.extractor.challenge import photon_extractor


BY_NAME = {
"ceviche_beam_splitter": ceviche_beam_splitter,
"ceviche_mode_converter": ceviche_mode_converter,
"ceviche_waveguide_bend": ceviche_waveguide_bend,
"ceviche_wdm": ceviche_wdm,
"ceviche_lightweight_beam_splitter": ceviche_lightweight_beam_splitter,
"ceviche_lightweight_mode_converter": ceviche_lightweight_mode_converter,
"ceviche_lightweight_waveguide_bend": ceviche_lightweight_waveguide_bend,
"ceviche_lightweight_wdm": ceviche_lightweight_wdm,
"metagrating": metagrating,
"broadband_metagrating": broadband_metagrating,
"diffractive_splitter": diffractive_splitter,
"photon_extractor": photon_extractor,
}
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
import functools
from typing import Any, Callable, Dict, Optional, Sequence, Tuple

import agjax # type: ignore[import]
import agjax # type: ignore[import-untyped]
import jax
import jax.numpy as jnp
import numpy as onp
from totypes import types # type: ignore[import,attr-defined,unused-ignore]
from totypes import types

from invrs_gym.challenge.ceviche import defaults
from invrs_gym.loss import transmission_loss
from invrs_gym.challenges.ceviche import defaults
from invrs_gym.challenges.ceviche import transmission_loss

AuxDict = Dict[str, Any]
Params = Any
Expand Down Expand Up @@ -271,7 +271,7 @@ def _wavelength_bound(
# -----------------------------------------------------------------------------


def beam_splitter_challenge(
def beam_splitter(
minimum_width: int = defaults.MINIMUM_WIDTH,
minimum_spacing: int = defaults.MINIMUM_SPACING,
density_initializer: DensityInitializer = identity_initializer,
Expand All @@ -290,7 +290,7 @@ def beam_splitter_challenge(
)


def lightweight_beam_splitter_challenge(
def lightweight_beam_splitter(
minimum_width: int = defaults.LIGHTWEIGHT_MINIMUM_WIDTH,
minimum_spacing: int = defaults.LIGHTWEIGHT_MINIMUM_SPACING,
density_initializer: DensityInitializer = identity_initializer,
Expand All @@ -309,7 +309,7 @@ def lightweight_beam_splitter_challenge(
)


def mode_converter_challenge(
def mode_converter(
minimum_width: int = defaults.MINIMUM_WIDTH,
minimum_spacing: int = defaults.MINIMUM_SPACING,
density_initializer: DensityInitializer = identity_initializer,
Expand All @@ -327,7 +327,7 @@ def mode_converter_challenge(
)


def lightweight_mode_converter_challenge(
def lightweight_mode_converter(
minimum_width: int = defaults.LIGHTWEIGHT_MINIMUM_WIDTH,
minimum_spacing: int = defaults.LIGHTWEIGHT_MINIMUM_SPACING,
density_initializer: DensityInitializer = identity_initializer,
Expand All @@ -345,7 +345,7 @@ def lightweight_mode_converter_challenge(
)


def waveguide_bend_challenge(
def waveguide_bend(
minimum_width: int = defaults.MINIMUM_WIDTH,
minimum_spacing: int = defaults.MINIMUM_SPACING,
density_initializer: DensityInitializer = identity_initializer,
Expand All @@ -364,7 +364,7 @@ def waveguide_bend_challenge(
)


def lightweight_waveguide_bend_challenge(
def lightweight_waveguide_bend(
minimum_width: int = defaults.LIGHTWEIGHT_MINIMUM_WIDTH,
minimum_spacing: int = defaults.LIGHTWEIGHT_MINIMUM_SPACING,
density_initializer: DensityInitializer = identity_initializer,
Expand All @@ -383,7 +383,7 @@ def lightweight_waveguide_bend_challenge(
)


def wdm_challenge(
def wdm(
minimum_width: int = defaults.MINIMUM_WIDTH,
minimum_spacing: int = defaults.MINIMUM_SPACING,
density_initializer: DensityInitializer = identity_initializer,
Expand All @@ -401,7 +401,7 @@ def wdm_challenge(
)


def lightweight_wdm_challenge(
def lightweight_wdm(
minimum_width: int = defaults.LIGHTWEIGHT_MINIMUM_WIDTH,
minimum_spacing: int = defaults.LIGHTWEIGHT_MINIMUM_SPACING,
density_initializer: DensityInitializer = identity_initializer,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
from typing import Union

import jax.numpy as jnp
from ceviche_challenges import ( # type: ignore[import]
from ceviche_challenges import ( # type: ignore[import-untyped]
beam_splitter,
mode_converter,
model_base,
params,
)
from ceviche_challenges import units as u
from ceviche_challenges import waveguide_bend, wdm # type: ignore[import]
from totypes import symmetry # type: ignore[import,attr-defined,unused-ignore]
from ceviche_challenges import waveguide_bend, wdm # type: ignore[import-untyped]
from totypes import symmetry

DeviceSpec = Union[
beam_splitter.spec.BeamSplitterSpec,
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
import jax
import jax.numpy as jnp
import numpy as onp
from fmmax import basis, fields, fmm, scattering, utils # type: ignore[import]
from fmmax import basis, fields, fmm, scattering, utils # type: ignore[import-untyped]
from jax import tree_util
from totypes import types # type: ignore[import,attr-defined,unused-ignore]
from totypes import types

AuxDict = Dict[str, Any]
Params = Any
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

import jax
import jax.numpy as jnp
from fmmax import basis, fmm # type: ignore[import]
from totypes import symmetry, types # type: ignore[import,attr-defined,unused-ignore]
from fmmax import basis, fmm # type: ignore[import-untyped]
from totypes import symmetry, types

from invrs_gym.challenge.diffract import common
from invrs_gym.challenges.diffract import common

AuxDict = Dict[str, Any]
DensityInitializer = Callable[[jax.Array, types.Density2DArray], types.Density2DArray]
Expand Down Expand Up @@ -82,7 +82,7 @@ def response(
if wavelength is None:
wavelength = self.sim_params.wavelength
transmission_efficiency, reflection_efficiency = common.grating_efficiency(
density_array=params.array,
density_array=params.array, # type: ignore[arg-type]
thickness=jnp.asarray(self.spec.thickness_grating),
spec=self.spec,
wavelength=jnp.asarray(wavelength),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@

import jax
import jax.numpy as jnp
from fmmax import basis, fmm # type: ignore[import]
from totypes import types # type: ignore[import,attr-defined,unused-ignore]
from fmmax import basis, fmm # type: ignore[import-untyped]
from totypes import types

from invrs_gym.challenge.diffract import common
from invrs_gym.challenges.diffract import common

PyTree = Any
Params = Dict[str, types.BoundedArray | types.Density2DArray]
AuxDict = Dict[str, Any]
ThicknessInitializer = Callable[[jax.Array, jnp.ndarray], jnp.ndarray]
ThicknessInitializer = Callable[[jax.Array, types.BoundedArray], types.BoundedArray]
DensityInitializer = Callable[[jax.Array, types.Density2DArray], types.Density2DArray]


Expand Down Expand Up @@ -77,7 +77,7 @@ def __init__(
truncation=self.sim_params.truncation,
)

def init(self, key: jax.Array) -> PyTree:
def init(self, key: jax.Array) -> Params:
"""Return the initial parameters for the diffractive splitter component."""
key_thickness, key_density = jax.random.split(key)
return {
Expand All @@ -87,7 +87,7 @@ def init(self, key: jax.Array) -> PyTree:

def response(
self,
params: types.Density2DArray,
params: Params,
wavelength: Optional[Union[float, jnp.ndarray]] = None,
expansion: Optional[basis.Expansion] = None,
) -> Tuple[common.GratingResponse, AuxDict]:
Expand All @@ -107,8 +107,8 @@ def response(
if wavelength is None:
wavelength = self.sim_params.wavelength
transmission_efficiency, reflection_efficiency = common.grating_efficiency(
density_array=params[DENSITY].array,
thickness=params[THICKNESS].array,
density_array=params[DENSITY].array, # type: ignore[arg-type]
thickness=params[THICKNESS].array, # type: ignore[arg-type]
spec=self.spec,
wavelength=jnp.asarray(wavelength),
polarization=self.sim_params.polarization,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from typing import Any, Callable, Dict, Tuple

import jax
from fmmax import basis, fmm # type: ignore[import]
from fmmax import basis, fmm # type: ignore[import-untyped]
from jax import numpy as jnp
from jax import tree_util
from totypes import symmetry, types # type: ignore[import,attr-defined,unused-ignore]
from totypes import symmetry, types

from invrs_gym.challenge.extractor import component as extractor_component
from invrs_gym.challenges.extractor import component as extractor_component

AuxDict = Dict[str, Any]
DensityInitializer = Callable[[jax.Array, types.Density2DArray], types.Density2DArray]
Expand Down Expand Up @@ -96,11 +96,11 @@ def metrics(

EXTRACTOR_SPEC = extractor_component.ExtractorSpec(
permittivity_ambient=(1.0 + 0.0j) ** 2,
permittivity_resist=(1.46 + 0.0j) ** 2,
permittivity_oxide=(1.46 + 0.0j) ** 2,
permittivity_extractor=(3.31 + 0.0j) ** 2,
permittivity_substrate=(2.4102 + 0.0j) ** 2,
thickness_ambient=1.0,
thickness_resist=0.13,
thickness_oxide=0.13,
thickness_extractor=0.25,
thickness_substrate_before_source=0.1,
thickness_substrate_after_source=0.9,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
utils,
)
from jax import tree_util
from totypes import types # type: ignore[import,attr-defined,unused-ignore]
from totypes import types

AuxDict = Dict[str, Any]
DensityInitializer = Callable[[jax.Array, types.Density2DArray], types.Density2DArray]
Expand All @@ -41,11 +41,11 @@ class ExtractorSpec:
Args:
permittivity_ambient: Permittivity of the ambient material.
permittivity_resist: Permittivity of the resist material.
permittivity_oxide: Permittivity of the oxide material.
permittivity_extractor: Permittivity of the extractor material.
permittivity_substrate: Permittivity of the substrate.
thickness_ambient: The thickness of the ambient layer.
thickness_resist: The thickness of the resist layer.
thickness_oxide: The thickness of the oxide layer.
thickness_extractor: The thickness of the extractor layer.
thickness_substrate_before_source: The distance between the substrate and
the plane containing the source.
Expand All @@ -67,12 +67,12 @@ class ExtractorSpec:
"""

permittivity_ambient: complex
permittivity_resist: complex
permittivity_oxide: complex
permittivity_extractor: complex
permittivity_substrate: complex

thickness_ambient: float
thickness_resist: float
thickness_oxide: float
thickness_extractor: float
thickness_substrate_before_source: float
thickness_substrate_after_source: float
Expand Down Expand Up @@ -182,7 +182,7 @@ def __init__(
# to ensure gridpoints are correctly spaced.
self.layer_znum = (
_num_gridpoints(spec.thickness_ambient) + 1,
_num_gridpoints(spec.thickness_resist) + 1,
_num_gridpoints(spec.thickness_oxide) + 1,
_num_gridpoints(spec.thickness_extractor) + 1,
_num_gridpoints(spec.thickness_substrate_before_source) + 1,
_num_gridpoints(spec.thickness_substrate_after_source) + 1,
Expand Down Expand Up @@ -232,7 +232,7 @@ def response(
wavelength = self.sim_params.wavelength

return simulate_extractor(
density_array=params.array,
density_array=params.array, # type: ignore[arg-type]
spec=self.spec,
layer_znum=self.layer_znum,
wavelength=jnp.asarray(wavelength),
Expand Down Expand Up @@ -359,9 +359,9 @@ def eigensolve_pml(permittivity: jnp.ndarray) -> fmm.LayerSolveResult:
solve_result_ambient = eigensolve_pml(
permittivity=jnp.full(grid_shape, spec.permittivity_ambient)
)
solve_result_resist = eigensolve_pml(
solve_result_oxide = eigensolve_pml(
permittivity=utils.interpolate_permittivity(
permittivity_solid=spec.permittivity_resist,
permittivity_solid=spec.permittivity_oxide,
permittivity_void=spec.permittivity_ambient,
density=density_array,
),
Expand All @@ -379,14 +379,14 @@ def eigensolve_pml(permittivity: jnp.ndarray) -> fmm.LayerSolveResult:

layer_solve_results = (
solve_result_ambient,
solve_result_resist,
solve_result_oxide,
solve_result_extractor,
solve_result_substrate, # Before the source.
solve_result_substrate, # After the source.
)
layer_thicknesses = (
jnp.asarray(spec.thickness_ambient),
jnp.asarray(spec.thickness_resist),
jnp.asarray(spec.thickness_oxide),
jnp.asarray(spec.thickness_extractor),
jnp.asarray(spec.thickness_substrate_before_source),
jnp.asarray(spec.thickness_substrate_after_source),
Expand Down
File renamed without changes.
Loading

0 comments on commit 8e91d14

Please sign in to comment.