Skip to content

Commit

Permalink
Merge pull request #39 from marvinpfoertner/jax-config-bugfix
Browse files Browse the repository at this point in the history
Fix `jax.config` related CI errors
  • Loading branch information
marvinpfoertner committed Apr 26, 2024
2 parents 75adeb8 + 8e6fe5a commit f1fc705
Show file tree
Hide file tree
Showing 10 changed files with 152 additions and 151 deletions.
22 changes: 11 additions & 11 deletions experiments/0000_cpu_stationary_1d.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,30 @@
{
"cell_type": "code",
"execution_count": null,
"id": "f416152f-fc5f-405a-8b24-1ca919bcdade",
"id": "8234a5d8-aa32-48ee-9c8c-81bc280c5218",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import probnum as pn\n",
"import experiment_utils\n",
"from experiment_utils import config\n",
"\n",
"import linpde_gp"
"config.experiment_name = \"0000_cpu_stationary_1d\"\n",
"config.target = \"jmlr\"\n",
"config.debug_mode = True"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8234a5d8-aa32-48ee-9c8c-81bc280c5218",
"id": "f416152f-fc5f-405a-8b24-1ca919bcdade",
"metadata": {},
"outputs": [],
"source": [
"import experiment_utils\n",
"from experiment_utils import config\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import probnum as pn\n",
"\n",
"config.experiment_name = \"0000_cpu_stationary_1d\"\n",
"config.target = \"jmlr\"\n",
"config.debug_mode = True\n",
"import linpde_gp\n",
"\n",
"plt.rcParams.update(config.tueplots_bundle())"
]
Expand Down
84 changes: 42 additions & 42 deletions experiments/0000_poisson_dirichlet_1d.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,31 @@
{
"cell_type": "code",
"execution_count": null,
"id": "dc637170-8b95-4e44-b0eb-3dbc8402c9a4",
"id": "9009b61f-6bf6-4ea3-af92-586c62b71af6",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import probnum as pn\n",
"import experiment_utils\n",
"from experiment_utils import config\n",
"\n",
"import linpde_gp\n",
"from linpde_gp.problems.pde import get_1d_dirichlet_boundary_observations"
"config.experiment_name = \"0000_poisson_dirichlet_1d\"\n",
"config.target = \"jmlr\"\n",
"config.debug_mode = True"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9009b61f-6bf6-4ea3-af92-586c62b71af6",
"id": "dc637170-8b95-4e44-b0eb-3dbc8402c9a4",
"metadata": {},
"outputs": [],
"source": [
"import experiment_utils\n",
"from experiment_utils import config\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import probnum as pn\n",
"\n",
"config.experiment_name = \"0000_poisson_dirichlet_1d\"\n",
"config.target = \"jmlr\"\n",
"config.debug_mode = True"
"import linpde_gp\n",
"from linpde_gp.problems.pde import get_1d_dirichlet_boundary_observations"
]
},
{
Expand Down Expand Up @@ -60,19 +60,19 @@
" f_X_pde: pn.randvars.RandomVariable | None = None,\n",
" ):\n",
" u_conditional_strs = []\n",
" \n",
"\n",
" for key in conditioned_on:\n",
" if key == \"bc\":\n",
" u_conditional_strs.append(r\"u\\vert_{\\partial \\Omega} = g\")\n",
" elif key == \"pde\":\n",
" u_conditional_strs.append(r\"-\\Delta u(x_i) = f(x_i)\")\n",
" \n",
"\n",
" u_label = (\n",
" fr\"$u \\mid {', '.join(u_conditional_strs)}$\"\n",
" if len(u_conditional_strs) > 0\n",
" else \"$u$\"\n",
" )\n",
" \n",
"\n",
" u.plot(\n",
" ax,\n",
" self._plt_grid,\n",
Expand All @@ -81,14 +81,14 @@
" color=\"C0\",\n",
" label=u_label,\n",
" )\n",
" \n",
"\n",
" ax.plot(\n",
" self._plt_grid,\n",
" self._bvp.solution(self._plt_grid),\n",
" color=\"C1\",\n",
" label=\"$u^*$\",\n",
" )\n",
" \n",
"\n",
" for key in conditioned_on:\n",
" if key == \"bc\":\n",
" X_bc, Y_bc = get_1d_dirichlet_boundary_observations(self._bvp.boundary_conditions)\n",
Expand All @@ -114,7 +114,7 @@
" color=\"C3\",\n",
" label=f\"$(f(x_1), \\dots, f(x_{X_pde.shape[0]}))$\",\n",
" )\n",
" \n",
"\n",
" ax.legend()\n",
"\n",
" def plot_pred_belief(\n",
Expand All @@ -126,13 +126,13 @@
" f_X_pde: pn.randvars.RandomVariable | None = None,\n",
" ):\n",
" u_conditional_strs = []\n",
" \n",
"\n",
" for key in conditioned_on:\n",
" if key == \"bc\":\n",
" u_conditional_strs.append(r\"u\\vert_{\\partial \\Omega} = g\")\n",
" elif key == \"pde\":\n",
" u_conditional_strs.append(r\"-\\Delta u(x_i) = f(x_i)\")\n",
" \n",
"\n",
" u_label = (\n",
" fr\"$-\\Delta u \\mid {', '.join(u_conditional_strs)}$\"\n",
" if len(u_conditional_strs) > 0\n",
Expand All @@ -147,14 +147,14 @@
" color=\"C0\",\n",
" label=u_label,\n",
" )\n",
" \n",
"\n",
" self._bvp.pde.rhs.plot(\n",
" ax,\n",
" self._plt_grid,\n",
" color=\"C1\",\n",
" label=\"$f$\",\n",
" )\n",
" \n",
"\n",
" if \"pde\" in conditioned_on:\n",
" ax.scatter(\n",
" X_pde,\n",
Expand All @@ -163,7 +163,7 @@
" c=\"C3\",\n",
" label=f\"$(f(x_1), \\dots, f(x_{X_pde.shape[0]}))$\",\n",
" )\n",
" \n",
"\n",
" ax.legend()"
]
},
Expand Down Expand Up @@ -468,14 +468,14 @@
"source": [
"for include_bc in [False, True]:\n",
" nrows = 3 if include_bc else 2\n",
" \n",
"\n",
" rc = config.tueplots_bundle(nrows=nrows, ncols=2)\n",
" rc.update(\n",
" {\n",
" \"lines.linewidth\": 1\n",
" }\n",
" )\n",
" \n",
"\n",
" with plt.rc_context(rc):\n",
" fig, ax = plt.subplots(nrows=nrows, ncols=2)\n",
"\n",
Expand Down Expand Up @@ -512,7 +512,7 @@
" X_pde=X_pde,\n",
" f_X_pde=Y_pde,\n",
" )\n",
" \n",
"\n",
" if include_bc:\n",
" ax[2, 0].set_title(\"(e)\")\n",
"\n",
Expand Down Expand Up @@ -682,49 +682,49 @@
"\n",
"with plt.rc_context(rc):\n",
" fig, ax = plt.subplots(nrows=3, ncols=2)\n",
" \n",
"\n",
" ax[0, 0].set_title(\"(a)\")\n",
" \n",
"\n",
" plotter.plot_belief(\n",
" ax=ax[0, 0],\n",
" u=u_prior,\n",
" )\n",
" \n",
"\n",
" ax[0, 1].set_title(\"(b)\")\n",
" \n",
"\n",
" plotter.plot_pred_belief(\n",
" ax=ax[0, 1],\n",
" u=u_prior,\n",
" )\n",
" \n",
"\n",
" ax[1, 0].set_title(\"(c)\")\n",
" \n",
"\n",
" plotter.plot_belief(\n",
" ax=ax[1, 0],\n",
" u=u_cond_bc,\n",
" conditioned_on=[\"bc\"],\n",
" )\n",
" \n",
"\n",
" ax[1, 1].set_title(\"(d)\")\n",
" \n",
"\n",
" plotter.plot_pred_belief(\n",
" ax=ax[1, 1],\n",
" u=u_cond_bc,\n",
" conditioned_on=[\"bc\"],\n",
" )\n",
" \n",
"\n",
" ax[2, 0].set_title(\"(e)\")\n",
" \n",
"\n",
" plotter.plot_belief(\n",
" ax=ax[2, 0],\n",
" u=u_cond_bc_pde,\n",
" conditioned_on=[\"bc\", \"pde\"],\n",
" X_pde=X_pde,\n",
" f_X_pde=Y_pde,\n",
" )\n",
" \n",
"\n",
" ax[2, 1].set_title(\"(f)\")\n",
" \n",
"\n",
" plotter.plot_pred_belief(\n",
" ax=ax[2, 1],\n",
" u=u_cond_bc_pde,\n",
Expand Down Expand Up @@ -857,7 +857,7 @@
" rhs=linpde_gp.functions.Constant(input_shape=(), value=f),\n",
" boundary_values=(u_l, u_r),\n",
" )\n",
" \n",
"\n",
" u_prior = pn.randprocs.GaussianProcess(\n",
" mean=linpde_gp.functions.Zero(\n",
" input_shape=bvp.pde.diffop.input_domain_shape,\n",
Expand Down Expand Up @@ -887,13 +887,13 @@
" X=X_pde,\n",
" b=pn.randvars.Normal(np.zeros_like(Y_pde), f_std**2 * np.eye(Y_pde.size)),\n",
" )\n",
" \n",
"\n",
" # Plotting\n",
" plotter = BeliefPlotter(bvp)\n",
" \n",
"\n",
" ax[0].cla()\n",
" ax[1].cla()\n",
" \n",
"\n",
" plotter.plot_belief(\n",
" ax[0],\n",
" u=u_cond_bc_pde,\n",
Expand Down
Loading

0 comments on commit f1fc705

Please sign in to comment.