Skip to content

Commit

Permalink
Updating from main to fix model loading errors (#211)
Browse files Browse the repository at this point in the history
* Change setup to fix MIRA name error (#205)

* changed mira version before bug

* fix error

* Add utilities for loading distributions from AMR (#200)

* added mira distribution loading

* added normal distribution

* fixed Normal2 and Normal3

* nit

* added minimal mira_distribution_to_pyro test

* Symbolic Rate law to Pytorch Rate law (#201)

* I believe I wrote the correct code, based on experiments in the notebook. Will test next.

* FAILED test/test_mira/test_rate_law.py::TestRateLaw::test_rate_law_compilation - AttributeError: 'ScaledBetaNoisePetriNetODESystem' object has no attribute 'beta'

* Added Symbolic_Deriv_Experiments notebook

* Something weird is happening. I can confirm that 'beta' is an attribute of ScaledBetaNoisePetriNetODESystem after setting up the model, but then it can't be found at sample time

* Clarified the bug in the Symbolic derivatives notebook

* Expected and actual derivative match

* Time varying parameter rate law correctly read

* Thought we added this already

* Added kwargs to from_askenet and from_mira and compile_rate_law_p to load_petri_net

* Blocked on gyorilab/mira#189 but tests pass by making compile_rate_law_p False by default

* Removed unnecessary pygraphviz dependency

* Unit test to fail when concept name does not equal rate law symbols

* All tests pass with default compile_rate_law_p = False

* Merged from main. removed dependency on older version of mira

* point mira to the github repo main branch

* point mira to the github repo main branch

* load_and_calibrate_and_sample(..., compile_rate_law_p=True) works with the caveat that the ScaledBetaNoisePetriNetODESystem solution was returning very slightly negative values, so I set mean = torch.abs(solution[var_name]) to address the issue

* merged changes to MiraPetriNetODESystem and ScaledBetaNoisePetriNetODESystem from main.  ScaledBetaNoisePetriNetODESystem has default compiled_rate_law_p=True

* observation_function for ScaledBetaNoisePetriNetODESystem now uses torch.maximum(solution[var_name], torch.tensor(1e-9)) to deal with overshooting derivatives

* aggregate parameters is now by default opt-out, and AMR models with multiple parameters per transition can be interpreted using compile_rate_law

---------

Co-authored-by: Sam Witty <samawitty@gmail.com>
Co-authored-by: Jeremy Zucker <djinnome@gmail.com>
  • Loading branch information
3 people authored Jul 10, 2023
1 parent fb3e00e commit fe6f625
Show file tree
Hide file tree
Showing 15 changed files with 1,017 additions and 81 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,17 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 3,
"id": "ecbb71f2-4ac5-459f-8c3b-825b5f520bf3",
"metadata": {},
"outputs": [],
"source": [
"from collections.abc import Callable"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "8d9914c4",
"metadata": {},
"outputs": [],
Expand All @@ -21,11 +31,158 @@
"from mira.modeling.viz import GraphicalModel\n",
"from mira.modeling import Model\n",
"from mira.modeling.askenet.petrinet import AskeNetPetriNetModel\n",
"\n",
"import torch\n",
"from pyciemss.interfaces import setup_model, calibrate, intervene\n",
"from pyciemss.PetriNetODE.interfaces import load_petri_model, setup_petri_model\n",
"from collections.abc import Callable\n",
"from typing import Tuple\n",
"import sympy\n",
"import sympytorch"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "f6cb8d58-c4f1-4301-bd49-46875ba35571",
"metadata": {},
"outputs": [],
"source": [
"def make_model() -> Callable[[float, Tuple[torch.Tensor]], Tuple[torch.Tensor]]:\n",
" \"\"\"Compile the deriv function during initialization.\"\"\"\n",
" state_variables = \"beta, total_population, susceptible_population, infectious_population, gamma, recovered_population\"\n",
" beta, total_pop,S, I, gamma, R = sympy.symbols(state_variables)\n",
" susceptible = Concept(name=\"susceptible_population\", identifiers={\"ido\": \"0000514\"})\n",
" infectious = Concept(name=\"infectious_population\", identifiers={\"ido\": \"0000513\"}) # http://purl.obolibrary.org/obo/IDO_0000513\n",
" recovered = Concept(name=\"recovered_population\", identifiers={\"ido\": \"0000592\"})\n",
" \n",
" # Set a value for the total population\n",
" total_pop = 100000 \n",
" S_to_I = ControlledConversion(\n",
" controller = infectious,\n",
" subject=susceptible,\n",
" outcome=infectious,\n",
" rate_law=(beta/total_pop)*S*I\n",
" )\n",
" I_to_R = NaturalConversion(\n",
" subject=infectious,\n",
" outcome=recovered,\n",
" rate_law=gamma*I\n",
" )\n",
" template_model = TemplateModel(\n",
" templates=[S_to_I, I_to_R],\n",
" parameters={\n",
" 'beta': Parameter(name='beta', value=0.55), # transmission rate\n",
" 'total_population': Parameter(name='total_population', value=total_pop),\n",
" 'gamma': Parameter(name='gamma', value=0.2), # recovery rate\n",
" },\n",
" initials={\n",
" 'susceptible_population': (Initial(concept=self.susceptible, value=total_pop-1)),\n",
" 'infectious_population': (Initial(concept=self.infectious, value=1)),\n",
" 'recovered_population': (Initial(concept=self.recovered, value=0))\n",
" }\n",
" )\n",
" model=Model(template_model)\n",
" return model"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "6a1478f5-4a77-44e2-a729-1cf4f200eebd",
"metadata": {},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'self' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[18], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m model \u001b[38;5;241m=\u001b[39m \u001b[43mmake_model\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
"Cell \u001b[0;32mIn[17], line 26\u001b[0m, in \u001b[0;36mmake_model\u001b[0;34m()\u001b[0m\n\u001b[1;32m 11\u001b[0m S_to_I \u001b[38;5;241m=\u001b[39m ControlledConversion(\n\u001b[1;32m 12\u001b[0m controller \u001b[38;5;241m=\u001b[39m infectious,\n\u001b[1;32m 13\u001b[0m subject\u001b[38;5;241m=\u001b[39msusceptible,\n\u001b[1;32m 14\u001b[0m outcome\u001b[38;5;241m=\u001b[39minfectious,\n\u001b[1;32m 15\u001b[0m rate_law\u001b[38;5;241m=\u001b[39m(beta\u001b[38;5;241m/\u001b[39mtotal_pop)\u001b[38;5;241m*\u001b[39mS\u001b[38;5;241m*\u001b[39mI\n\u001b[1;32m 16\u001b[0m )\n\u001b[1;32m 17\u001b[0m I_to_R \u001b[38;5;241m=\u001b[39m NaturalConversion(\n\u001b[1;32m 18\u001b[0m subject\u001b[38;5;241m=\u001b[39minfectious,\n\u001b[1;32m 19\u001b[0m outcome\u001b[38;5;241m=\u001b[39mrecovered,\n\u001b[1;32m 20\u001b[0m rate_law\u001b[38;5;241m=\u001b[39mgamma\u001b[38;5;241m*\u001b[39mI\n\u001b[1;32m 21\u001b[0m )\n\u001b[1;32m 22\u001b[0m template_model \u001b[38;5;241m=\u001b[39m TemplateModel(\n\u001b[1;32m 23\u001b[0m templates\u001b[38;5;241m=\u001b[39m[S_to_I, I_to_R],\n\u001b[1;32m 24\u001b[0m parameters\u001b[38;5;241m=\u001b[39m{\n\u001b[1;32m 25\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mbeta\u001b[39m\u001b[38;5;124m'\u001b[39m: Parameter(name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mbeta\u001b[39m\u001b[38;5;124m'\u001b[39m, value\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0.55\u001b[39m), \u001b[38;5;66;03m# transmission rate\u001b[39;00m\n\u001b[0;32m---> 26\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtotal_population\u001b[39m\u001b[38;5;124m'\u001b[39m: Parameter(name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtotal_population\u001b[39m\u001b[38;5;124m'\u001b[39m, value\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241m.\u001b[39mtotal_pop),\n\u001b[1;32m 27\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mgamma\u001b[39m\u001b[38;5;124m'\u001b[39m: Parameter(name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mgamma\u001b[39m\u001b[38;5;124m'\u001b[39m, value\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0.2\u001b[39m), \u001b[38;5;66;03m# recovery rate\u001b[39;00m\n\u001b[1;32m 28\u001b[0m },\n\u001b[1;32m 29\u001b[0m initials\u001b[38;5;241m=\u001b[39m{\n\u001b[1;32m 30\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124msusceptible_population\u001b[39m\u001b[38;5;124m'\u001b[39m: (Initial(concept\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msusceptible, value\u001b[38;5;241m=\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtotal_pop\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m))),\n\u001b[1;32m 31\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124minfectious_population\u001b[39m\u001b[38;5;124m'\u001b[39m: (Initial(concept\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minfectious, value\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m)),\n\u001b[1;32m 32\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mrecovered_population\u001b[39m\u001b[38;5;124m'\u001b[39m: (Initial(concept\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrecovered, value\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m))\n\u001b[1;32m 33\u001b[0m }\n\u001b[1;32m 34\u001b[0m )\n\u001b[1;32m 35\u001b[0m model\u001b[38;5;241m=\u001b[39mModel(template_model)\n\u001b[1;32m 36\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m model\n",
"\u001b[0;31mNameError\u001b[0m: name 'self' is not defined"
]
}
],
"source": [
"model = make_model()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "fb719d87-c045-4d46-a764-501f6e788758",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([-0.9900, -0.0100, 1.0000])"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import sympy, torch, sympytorch\n",
"from torch import tensor\n",
"total_pop, S, I, R, beta, gamma = sympy.symbols('total_pop, S, I, R, beta, gamma')\n",
"S_to_I = beta*I*S/total_pop\n",
"I_to_R = gamma*I\n",
"\n",
"# = sympytorch.SymPyModule(expressions=[S_to_I, I_to_R])\n",
"dSdt = -S_to_I\n",
"dIdt = S_to_I - I_to_R\n",
"dRdt = I_to_R\n",
"\n",
"import sympy"
"compile_deriv = sympytorch.SymPyModule(expressions=[dSdt, dIdt, dRdt])\n",
"\n",
"compiled_deriv = compile_deriv(beta=getattr(self, 'beta'),\n",
" gamma=getattr(self, 'gamma'),\n",
" S=states['S'],\n",
" I=states['I'],\n",
" R=states['R'],\n",
" total_pop=sum(states[i] for i in states)\n",
" )\n",
"\n",
"compile_deriv(beta=tensor(1.),\n",
" gamma=tensor(1.),\n",
" S=tensor(99.),\n",
" I=tensor(1.),\n",
" R=tensor(0.0),\n",
" total_pop=tensor(100.)\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "40c5d7eb-72f2-41a7-9b15-efec7458447c",
"metadata": {},
"outputs": [],
"source": [
" \n",
"x_ = torch.rand(3)\n",
"out = mod(x_name=x_) # out has shape (3, 2)\n",
"\n",
"assert torch.equal(out[:, 0], x_.cos())\n",
"assert torch.equal(out[:, 1], 2 * x_.sin())\n",
"assert out.requires_grad # from the two Parameters initialised as 1.0 and 2.0\n",
"assert {x.item() for x in mod.parameters()} == {1.0, 2.0}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "efda5703-eb9f-4118-ad38-685bb0fe18d4",
"metadata": {},
"outputs": [],
"source": [
"def get_fluxes(rate_law):\n",
" \n",
" "
]
},
{
Expand Down
Loading

0 comments on commit fe6f625

Please sign in to comment.