Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add optimizer utils #17

Merged
merged 3 commits into from
Oct 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 103 additions & 74 deletions notebooks/diffract.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -8,161 +8,190 @@
"outputs": [],
"source": [
"from importlib import reload\n",
"\n",
"import dataclasses\n",
"import jax\n",
"from jax import tree_util\n",
"import jax.numpy as jnp\n",
"import matplotlib.pyplot as plt\n",
"import numpy as onp\n",
"import optax\n",
"import time\n",
"\n",
"from fmmax import basis, fmm\n",
"\n",
"from totypes import types\n",
"from invrs_gym.challenge.diffract import metagrating_challenge, splitter_challenge"
"from invrs_gym.challenge.diffract import metagrating_challenge, splitter_challenge\n",
"from invrs_gym.challenge.extractor import challenge, component"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8c48cb28-598e-4155-a2ff-3e7d78587699",
"id": "d9d9f7b5-f7fe-4c09-aa37-079f7b8d1f54",
"metadata": {},
"outputs": [],
"source": [
"def optimize(challenge, opt=optax.adam(0.02), num_steps=200):\n",
"from invrs_gym.utils import optimizer\n",
"reload(optimizer)\n",
"\n",
"\n",
" def loss_fn(params):\n",
" response, aux = challenge.component.response(params)\n",
" loss = challenge.loss(response)\n",
" return loss, (response, aux)\n",
" \n",
" def clip(leaf):\n",
" (value,), treedef = tree_util.tree_flatten(leaf)\n",
" return tree_util.tree_unflatten(\n",
" treedef, (jnp.clip(value, leaf.lower_bound, leaf.upper_bound),)\n",
" )\n",
" \n",
" def transform(leaf):\n",
" if isinstance(leaf, types.BoundedArray):\n",
" return clip(leaf)\n",
" if isinstance(leaf, types.Density2DArray):\n",
" return clip(types.symmetrize_density(leaf))\n",
" return leaf \n",
" \n",
" @jax.jit\n",
" def step_fn(params, state):\n",
" (value, (response, aux)), grad = jax.value_and_grad(loss_fn, has_aux=True)(params)\n",
" metrics = challenge.metrics(response, params, aux)\n",
" updates, state = opt.update(grad, state)\n",
" params = optax.apply_updates(params, updates)\n",
" params = tree_util.tree_map(\n",
" transform, params, is_leaf=lambda x: isinstance(x, types.CUSTOM_TYPES)\n",
" )\n",
" params = tree_util.tree_map(lambda x: jnp.clip(x, 0, 1), params)\n",
" return params, state, (value, response, aux, metrics)\n",
"def optimize(challenge, opt, num_steps, response_kwargs={}, use_jit=True):\n",
" params, state, step_fn = optimizer.setup_optimization(\n",
" challenge=challenge, optimizer=opt, response_kwargs=response_kwargs\n",
" )\n",
" if use_jit:\n",
" step_fn = jax.jit(step_fn)\n",
"\n",
" params = challenge.component.init(jax.random.PRNGKey(0))\n",
" state = opt.init(params)\n",
" \n",
" loss_values = []\n",
" metrics_values = []\n",
" for i in range(num_steps):\n",
" params, state, (value, response, aux, metrics) = step_fn(params, state)\n",
" print(i, value)\n",
" loss_values.append(value)\n",
" metrics_values.append(metrics)\n",
" try:\n",
" t0 = time.time()\n",
" params, state, (value, response, aux, metrics) = step_fn(params, state)\n",
" print(f\"step {i} ({time.time() - t0:.2f}s): loss={value}\")\n",
" loss_values.append(value)\n",
" metrics_values.append(metrics)\n",
" except Exception as e:\n",
" if \"KeyboardInterrupt\" in str(e):\n",
" print(\"Terminating optimization\")\n",
" break\n",
" else:\n",
" raise e\n",
"\n",
" return loss_values, metrics_values, response, params"
" return params, response, aux, loss_values, metrics_values"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e93003c2-a171-42f9-8a71-25cf36a64f7f",
"id": "b521450f-9acc-49c4-aa0e-e3b368a6acd1",
"metadata": {},
"outputs": [],
"source": [
"(\n",
" splitter_loss_values,\n",
" splitter_metrics_values,\n",
" splitter_response,\n",
" splitter_params,\n",
") = optimize(splitter_challenge.diffractive_splitter(), num_steps=5)"
" metagrating_params,\n",
" metagrating_response,\n",
" metagrating_aux,\n",
"    metagrating_loss_values,\n",
" metagrating_metrics_values\n",
") = optimize(\n",
" metagrating_challenge.metagrating(),\n",
" opt=optax.adam(0.02),\n",
" num_steps=200,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "919401d9-38c9-4d0b-a61f-2e5b345c1a6b",
"id": "33fd7ccb-9f29-4485-9047-deafa3b9b888",
"metadata": {},
"outputs": [],
"source": [
"for key, value in splitter_metrics_values[-1].items():\n",
" print(f\"{key}={value}\")"
"plt.figure(figsize=(8, 5))\n",
"plt.subplot(121)\n",
"plt.plot(metagrating_loss_values)\n",
"plt.plot([m[\"distance_to_window\"] for m in metagrating_metrics_values])\n",
"plt.subplot(122)\n",
"plt.imshow(jnp.tile(metagrating_params.array, (1, 1)))\n",
"plt.colorbar()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "48d6daae-c79e-4852-947c-a4a17ea8482a",
"id": "a51052d2-54a7-4120-bb06-1c61f176d18d",
"metadata": {},
"outputs": [],
"source": [
"plt.figure(figsize=(8, 5))\n",
"plt.subplot(121)\n",
"plt.plot(splitter_loss_values)\n",
"plt.subplot(122)\n",
"plt.imshow(jnp.tile(splitter_params[\"density\"].array, (2, 2)))\n",
"plt.colorbar()"
"(\n",
" splitter_params,\n",
" splitter_response,\n",
" splitter_aux,\n",
" splitter_loss_values,\n",
" splitter_metrics_values\n",
") = optimize(\n",
" splitter_challenge.diffractive_splitter(),\n",
" opt=optax.adam(0.02),\n",
" num_steps=100,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eed3eac7-b03e-4c24-86da-b1380cbfda29",
"id": "fd5bb84f-f672-4751-836b-4d05fd7da40f",
"metadata": {},
"outputs": [],
"source": [
"print(f\"thickness={splitter_params['thickness'].array}\")\n",
"for key, value in splitter_metrics_values[-1].items():\n",
" print(f\"{key}={value}\")\n",
"\n",
"plt.figure(figsize=(12, 5))\n",
"plt.subplot(131)\n",
"plt.plot(splitter_loss_values)\n",
"plt.subplot(132)\n",
"plt.imshow(jnp.tile(splitter_params[\"density\"].array, (2, 2)))\n",
"\n",
"plt.subplot(133)\n",
"plt.imshow(splitter_challenge.extract_orders_for_splitting(\n",
" splitter_response.transmission_efficiency,\n",
" splitter_response.expansion,\n",
" (9, 9),\n",
"))\n",
"plt.colorbar()"
"))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b818f159-a136-4c63-8cce-37ec6a371d4c",
"id": "c16fae32-856c-4b0a-9d62-8a1efc2402bb",
"metadata": {},
"outputs": [],
"source": [
"pec = challenge.photon_extractor()\n",
"\n",
"(\n",
" metagrating_loss_values,\n",
" metagrating_metrics_values,\n",
" metagrating_response,\n",
" metagrating_params,\n",
") = optimize(metagrating_challenge.metagrating(), num_steps=5)"
"    extractor_params,\n",
"    extractor_response,\n",
"    extractor_aux,\n",
"    extractor_loss_values,\n",
"    extractor_metrics_values,\n",
") = optimize(challenge=pec, opt=optax.adam(0.02), num_steps=100)\n",
"\n",
"_, extractor_aux = pec.component.response(extractor_params, compute_fields=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5eba56e1-7bd0-4e8d-bd6e-95e6f58b67ef",
"id": "e5da18a7-65b9-4798-b0a3-5b05d90e4697",
"metadata": {},
"outputs": [],
"source": [
"plt.figure(figsize=(8, 5))\n",
"plt.subplot(121)\n",
"plt.plot(metagrating_loss_values)\n",
"plt.subplot(122)\n",
"plt.imshow(jnp.tile(metagrating_params.array, (2, 2)))\n",
"plt.colorbar()"
"x, y, z = extractor_aux[\"field_coordinates\"]\n",
"ex, ey, ez = extractor_aux[\"efield\"]\n",
"\n",
"print(extractor_response)\n",
"\n",
"plt.figure(figsize=(10, 4))\n",
"xplot, zplot = jnp.meshgrid(x, z, indexing=\"ij\")\n",
"field_plot = jnp.sqrt(jnp.abs(ex)**2 + jnp.abs(ey)**2 + jnp.abs(ez)**2)\n",
"ax = plt.subplot(121)\n",
"plt.pcolormesh(xplot, zplot, field_plot[:, :, 1], cmap=\"magma\")\n",
"ax.axis(\"equal\")\n",
"ax.axis(\"off\")\n",
"ax.set_ylim(ax.get_ylim()[::-1])\n",
"\n",
"ax = plt.subplot(122)\n",
"plt.imshow(extractor_params.array, cmap=\"gray\")\n",
"ax.axis(\"off\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "497cf599-5593-49b0-9deb-09f0a77d69a2",
"id": "e5514da1-5ae2-4d91-a4dc-997bf21aadd8",
"metadata": {},
"outputs": [],
"source": []
Expand Down
Empty file added src/invrs_gym/utils/__init__.py
Empty file.
119 changes: 119 additions & 0 deletions src/invrs_gym/utils/optimizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""Implements a basic optimization."""


from typing import Any, Callable, Dict, Protocol, Tuple

import jax
import jax.numpy as jnp
import optax # type: ignore[import]
from jax import tree_util
from totypes import types # type: ignore[import,attr-defined,unused-ignore]

AuxDict = Dict[str, Any]


class Component(Protocol):
    """Structural type for an optimizable component (see `Challenge.component`)."""

    def init(self, key: jax.Array) -> Any:
        """Returns initial parameters for the component, given a PRNG key."""
        ...

    def response(self, params: Any) -> Tuple[Any, AuxDict]:
        """Computes the physical response for `params`, returning `(response, aux)`."""
        ...


class Challenge(Protocol):
    """Structural type for an optimization challenge wrapping a `Component`."""

    # The component whose parameters are optimized.
    component: Component

    def loss(self, response: Any) -> jnp.ndarray:
        """Returns the scalar loss for the given component response."""
        ...

    def metrics(self, response: Any, params: Any, aux: AuxDict) -> AuxDict:
        """Returns a dict of metrics computed from response, params, and aux."""
        ...


def setup_optimization(
    challenge: Challenge,
    optimizer: optax.GradientTransformation,
    response_kwargs: Dict[str, Any] | None = None,
) -> Tuple[
    Any,  # params
    Any,  # state
    # f(params, state) -> params, state, (value, response, aux, metrics)
    Callable[
        [Any, Any],
        Tuple[Any, Any, Tuple[jnp.ndarray, Any, AuxDict, AuxDict]],
    ],
]:
    """Set up the optimization of `challenge` using the given `optimizer`.

    Example usage is as follows:

        params, state, step_fn = setup_optimization(
            challenge=my_challenge(),
            optimizer=optax.adam(0.02),
        )
        for _ in range(num_steps):
            params, state, (loss, response, aux, metrics) = step_fn(params, state)

    Args:
        challenge: The challenge to be optimized.
        optimizer: Optax gradient transformation used to drive the optimization.
        response_kwargs: Optional keyword arguments to be passed to the `response`
            method of the `challenge.component`. `None` (the default) is
            equivalent to passing no extra keyword arguments.

    Returns:
        The `(params, state, step_fn)` tuple, where `step_fn` has the signature
        `fn(params, state) -> (params, state, (value, response, aux, metrics))`.
    """
    # A `None` sentinel avoids the mutable-default-argument pitfall: a shared
    # `{}` default object could be mutated by one caller and silently leak
    # into every subsequent call.
    if response_kwargs is None:
        response_kwargs = {}

    def loss_fn(params: Any) -> Tuple[jnp.ndarray, Tuple[Any, AuxDict]]:
        # Returns the scalar loss plus `(response, aux)` so that a single
        # `value_and_grad` call yields everything `step_fn` needs.
        response, aux = challenge.component.response(params, **response_kwargs)
        loss = challenge.loss(response)
        return loss, (response, aux)

    def clip(
        leaf: types.BoundedArray | types.Density2DArray,
    ) -> types.BoundedArray | types.Density2DArray:
        # Clips the leaf's underlying array to the leaf's declared bounds.
        (value,), treedef = tree_util.tree_flatten(leaf)
        return tree_util.tree_unflatten(
            treedef, (jnp.clip(value, leaf.lower_bound, leaf.upper_bound),)
        )

    def apply_fixed_pixels(
        leaf: types.Density2DArray,
    ) -> types.Density2DArray:
        # Overwrites pixels constrained to be solid (upper bound) or void
        # (lower bound), when such constraints are present on the leaf.
        (value,), treedef = tree_util.tree_flatten(leaf)
        if leaf.fixed_solid is not None:
            value = jnp.where(leaf.fixed_solid, leaf.upper_bound, value)
        if leaf.fixed_void is not None:
            value = jnp.where(leaf.fixed_void, leaf.lower_bound, value)
        return tree_util.tree_unflatten(treedef, (value,))

    def transform(leaf: Any) -> Any:
        # Projects a custom-typed leaf back onto its constraint set; leaves of
        # other types pass through unchanged.
        if isinstance(leaf, types.BoundedArray):
            return clip(leaf)
        if isinstance(leaf, types.Density2DArray):
            leaf = types.symmetrize_density(leaf)
            leaf = clip(leaf)
            return apply_fixed_pixels(leaf)
        return leaf

    def step_fn(
        params: Any, state: Any
    ) -> Tuple[Any, Any, Tuple[jnp.ndarray, Any, AuxDict, AuxDict]]:
        # One optimization step: gradient, optimizer update, then projection of
        # the params pytree back onto the constraints of its custom leaf types.
        (value, (response, aux)), grad = jax.value_and_grad(loss_fn, has_aux=True)(
            params
        )
        metrics = challenge.metrics(response, params, aux)
        updates, state = optimizer.update(grad, state)
        params = optax.apply_updates(params, updates)
        params = tree_util.tree_map(
            transform, params, is_leaf=lambda x: isinstance(x, types.CUSTOM_TYPES)
        )
        return params, state, (value, response, aux, metrics)

    params = challenge.component.init(jax.random.PRNGKey(0))

    # Eliminate weak types by specifying types of all arrays in the params pytree.
    params = tree_util.tree_map(lambda x: jnp.asarray(x, jnp.asarray(x).dtype), params)
    state = optimizer.init(params)
    return params, state, step_fn
Loading