diff --git a/docs/source/interfaces.ipynb b/docs/source/interfaces.ipynb index d140583ec..e4c08ebf2 100644 --- a/docs/source/interfaces.ipynb +++ b/docs/source/interfaces.ipynb @@ -1,12 +1,57 @@ { "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# A tour of PyCIEMSS interfaces and functionality" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load dependencies and interfaces" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import pyciemss\n", + "from pyciemss.interfaces import calibrate" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Select models and data" + ] + }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ - "import pyciemss" + "MODEL_PATH = \"https://raw.githubusercontent.com/DARPA-ASKEM/simulation-integration/main/data/models/\"\n", + "DATA_PATH = \"https://raw.githubusercontent.com/DARPA-ASKEM/simulation-integration/main/data/datasets/\"\n", + "\n", + "model1 = os.path.join(MODEL_PATH, \"SEIRHD_NPI_Type1_petrinet.json\")\n", + "model2 = os.path.join(MODEL_PATH, \"SEIRHD_NPI_Type2_petrinet.json\")\n", + "\n", + "dataset1 = os.path.join(DATA_PATH, \"traditional.csv\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Set parameters for sampling" ] }, { @@ -15,15 +60,27 @@ "metadata": {}, "outputs": [], "source": [ - "### SETUP ###\n", - "\n", - "model_1_path = \"https://raw.githubusercontent.com/DARPA-ASKEM/simulation-integration/main/data/models/SEIRHD_NPI_Type1_petrinet.json\"\n", "start_time = 0.0\n", "end_time = 100.\n", "logging_step_size = 10.0\n", "num_samples = 3" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Sample interface\n", + "Take `num_samples` number of samples from the (prior) distribution invoked by the chosen model." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Sample from model 1" + ] + }, { "cell_type": "code", "execution_count": 4, @@ -32,47 +89,75 @@ { "data": { "text/plain": [ - "{'beta_c': tensor([0.1086, 0.6213, 0.1839]),\n", - " 'kappa': tensor([0.4764, 0.3179, 0.4837]),\n", - " 'gamma': tensor([0.4524, 0.2553, 0.3339]),\n", - " 'hosp': tensor([0.1164, 0.0947, 0.0185]),\n", - " 'death_hosp': tensor([0.0288, 0.0273, 0.0981]),\n", - " 'D_state': tensor([[ 0.1427, 0.4143, 0.7265, 1.0684, 1.4406, 1.8455, 2.2859, 2.7650,\n", - " 3.2861],\n", - " [ 0.0831, 0.2762, 0.5478, 0.9133, 1.4034, 2.0602, 2.9402, 4.1194,\n", - " 5.6992],\n", - " [ 0.0744, 0.2808, 0.6614, 1.3446, 2.5686, 4.7612, 8.6879, 15.7182,\n", - " 28.2966]]),\n", - " 'E_state': tensor([[ 35.0912, 38.1662, 41.5179, 45.1635, 49.1286, 53.4413,\n", - " 58.1320, 63.2334, 54.9501],\n", - " [ 34.6768, 46.3779, 62.1469, 83.2771, 111.5897, 149.5242,\n", - " 200.3483, 268.4357, 326.5832],\n", - " [ 52.5087, 94.0385, 168.4659, 301.7839, 540.5568, 968.0879,\n", - " 1733.2552, 3101.5935, 4357.2822]]),\n", - " 'H_state': tensor([[ 4.1075, 5.1372, 5.6790, 6.1899, 6.7351, 7.3266, 7.9698, 8.6693,\n", - " 9.3940],\n", - " [ 2.7630, 4.2363, 5.7498, 7.7147, 10.3390, 13.8542, 18.5637, 24.8734,\n", - " 33.2846],\n", - " [ 0.7227, 1.4214, 2.5637, 4.5950, 8.2314, 14.7434, 26.4014, 47.2604,\n", - " 84.1758]]),\n", - " 'I_state': tensor([[ 19.0262, 20.7039, 22.5221, 24.4998, 26.6508, 28.9904,\n", - " 31.5350, 34.3025, 35.6219],\n", - " [ 30.3420, 40.7402, 54.5933, 73.1556, 98.0277, 131.3530,\n", - " 176.0024, 235.8200, 311.6320],\n", - " [ 33.4459, 59.9439, 107.3877, 192.3733, 344.5883, 617.1570,\n", - " 1105.0461, 1977.7384, 3386.9104]]),\n", - " 'R_state': tensor([[ 73.5677, 162.1008, 258.9751, 364.4325, 479.1595, 603.9607,\n", - " 739.7171, 887.3885, 1047.7235],\n", - " [ 58.4188, 146.9025, 265.9601, 425.5612, 639.4350, 926.0212,\n", - " 1310.0326, 1824.5624, 2513.5286],\n", - " [ 79.3534, 230.1014, 500.2771, 984.2914, 1851.3185, 3404.2715,\n", - " 6185.2612, 11163.6143, 20049.1035]]),\n", - " 'S_state': tensor([[19339868., 19339772., 19339670., 19339560., 19339438., 19339302.,\n", - " 19339160., 19338996., 19338848.],\n", - " [19339868., 19339764., 19339612., 19339418., 19339120., 19338778.,\n", - " 19338308., 19337656., 19336836.],\n", - " [19339836., 19339642., 19339224., 19338516., 19337282., 19334992.,\n", - " 19330954., 19323698., 19312068.]])}" + "{'persistent_beta_c': tensor([0.7351, 0.2722, 0.6301]),\n", + " 'persistent_kappa': tensor([0.7916, 0.2749, 0.2125]),\n", + " 'persistent_gamma': tensor([0.1581, 0.3249, 0.4120]),\n", + " 'persistent_hosp': tensor([0.0436, 0.0068, 0.1053]),\n", + " 'persistent_death_hosp': tensor([0.0174, 0.0480, 0.0121]),\n", + " 'persistent_I0': tensor([14.4330, 13.8162, 3.2406]),\n", + " 'D_state': tensor([[3.8956e-02, 5.4015e-01, 6.2445e+00, 7.0356e+01, 7.1390e+02, 4.1181e+03,\n", + " 9.4866e+03, 1.2776e+04, 1.4026e+04],\n", + " [1.2083e-02, 3.1182e-02, 4.7820e-02, 6.1282e-02, 7.2044e-02, 8.0632e-02,\n", + " 8.7482e-02, 9.2945e-02, 9.7302e-02],\n", + " [3.3920e-02, 7.5363e-02, 9.6949e-02, 1.0659e-01, 1.1071e-01, 1.1246e-01,\n", + " 1.1319e-01, 1.1350e-01, 1.1363e-01]]),\n", + " 'E_state': tensor([[3.4949e+02, 3.9743e+03, 4.4992e+04, 4.8423e+05, 3.2283e+06, 3.1408e+06,\n", + " 6.1947e+05, 9.3426e+04, 1.5437e+04],\n", + " [2.3967e+01, 1.9084e+01, 1.5221e+01, 1.2140e+01, 9.6831e+00, 7.7231e+00,\n", + " 6.1598e+00, 4.9130e+00, 3.3426e+00],\n", + " [1.1823e+01, 4.9530e+00, 2.0815e+00, 8.7482e-01, 3.6766e-01, 1.5452e-01,\n", + " 6.4940e-02, 2.7293e-02, 1.0832e-02]]),\n", + " 'H_state': tensor([[3.3353e+00, 3.8501e+01, 4.3723e+02, 4.8641e+03, 4.4252e+04, 1.5015e+05,\n", + " 1.3361e+05, 5.8951e+04, 1.9111e+04],\n", + " [1.9613e-01, 1.9018e-01, 1.5628e-01, 1.2527e-01, 1.0000e-01, 7.9770e-02,\n", + " 6.3625e-02, 5.0747e-02, 4.0408e-02],\n", + " [2.0707e+00, 1.2671e+00, 5.8662e-01, 2.5386e-01, 1.0768e-01, 4.5389e-02,\n", + " 1.9094e-02, 8.0272e-03, 3.3725e-03]]),\n", + " 'I_state': tensor([[2.1773e+02, 2.4765e+03, 2.8089e+04, 3.0869e+05, 2.5044e+06, 5.3752e+06,\n", + " 2.7832e+06, 8.4614e+05, 2.1702e+05],\n", + " [1.9746e+01, 1.5784e+01, 1.2589e+01, 1.0041e+01, 8.0086e+00, 6.3876e+00,\n", + " 5.0946e+00, 4.0634e+00, 3.1665e+00],\n", + " [8.9978e+00, 3.8057e+00, 1.5995e+00, 6.7224e-01, 2.8252e-01, 1.1874e-01,\n", + " 4.9902e-02, 2.0973e-02, 8.7333e-03]]),\n", + " 'R_state': tensor([[1.2462e+02, 1.5579e+03, 1.7838e+04, 1.9961e+05, 1.9135e+06, 8.7308e+06,\n", + " 1.5453e+07, 1.8156e+07, 1.8931e+07],\n", + " [6.5494e+01, 1.2300e+02, 1.6891e+02, 2.0553e+02, 2.3474e+02, 2.5804e+02,\n", + " 2.7662e+02, 2.9144e+02, 3.0325e+02],\n", + " [4.4201e+01, 6.9875e+01, 8.1021e+01, 8.5751e+01, 8.7746e+01, 8.8585e+01,\n", + " 8.8938e+01, 8.9086e+01, 8.9148e+01]]),\n", + " 'S_state': tensor([[19339298.0000, 19331996.0000, 19248676.0000, 18342560.0000,\n", + " 11648909.0000, 1938993.7500, 341272.4688, 172677.5625,\n", + " 143060.9844],\n", + " [19339930.0000, 19339878.0000, 19339856.0000, 19339812.0000,\n", + " 19339786.0000, 19339754.0000, 19339736.0000, 19339716.0000,\n", + " 19339712.0000],\n", + " [19339968.0000, 19339956.0000, 19339956.0000, 19339956.0000,\n", + " 19339956.0000, 19339956.0000, 19339956.0000, 19339956.0000,\n", + " 19339956.0000]]),\n", + " 'infected_observable': tensor([[2.1773e+02, 2.4765e+03, 2.8089e+04, 3.0869e+05, 2.5044e+06, 5.3752e+06,\n", + " 2.7832e+06, 8.4614e+05, 2.1702e+05],\n", + " [1.9746e+01, 1.5784e+01, 1.2589e+01, 1.0041e+01, 8.0086e+00, 6.3876e+00,\n", + " 5.0946e+00, 4.0634e+00, 3.1665e+00],\n", + " [8.9978e+00, 3.8057e+00, 1.5995e+00, 6.7224e-01, 2.8252e-01, 1.1874e-01,\n", + " 4.9902e-02, 2.0973e-02, 8.7333e-03]]),\n", + " 'exposed_observable': tensor([[3.4949e+02, 3.9743e+03, 4.4992e+04, 4.8423e+05, 3.2283e+06, 3.1408e+06,\n", + " 6.1947e+05, 9.3426e+04, 1.5437e+04],\n", + " [2.3967e+01, 1.9084e+01, 1.5221e+01, 1.2140e+01, 9.6831e+00, 7.7231e+00,\n", + " 6.1598e+00, 4.9130e+00, 3.3426e+00],\n", + " [1.1823e+01, 4.9530e+00, 2.0815e+00, 8.7482e-01, 3.6766e-01, 1.5452e-01,\n", + " 6.4940e-02, 2.7293e-02, 1.0832e-02]]),\n", + " 'hospitalized_observable': tensor([[3.3353e+00, 3.8501e+01, 4.3723e+02, 4.8641e+03, 4.4252e+04, 1.5015e+05,\n", + " 1.3361e+05, 5.8951e+04, 1.9111e+04],\n", + " [1.9613e-01, 1.9018e-01, 1.5628e-01, 1.2527e-01, 1.0000e-01, 7.9770e-02,\n", + " 6.3625e-02, 5.0747e-02, 4.0408e-02],\n", + " [2.0707e+00, 1.2671e+00, 5.8662e-01, 2.5386e-01, 1.0768e-01, 4.5389e-02,\n", + " 1.9094e-02, 8.0272e-03, 3.3725e-03]]),\n", + " 'dead_observable': tensor([[3.8956e-02, 5.4015e-01, 6.2445e+00, 7.0356e+01, 7.1390e+02, 4.1181e+03,\n", + " 9.4866e+03, 1.2776e+04, 1.4026e+04],\n", + " [1.2083e-02, 3.1182e-02, 4.7820e-02, 6.1282e-02, 7.2044e-02, 8.0632e-02,\n", + " 8.7482e-02, 9.2945e-02, 9.7302e-02],\n", + " [3.3920e-02, 7.5363e-02, 9.6949e-02, 1.0659e-01, 1.1071e-01, 1.1246e-01,\n", + " 1.1319e-01, 1.1350e-01, 1.1363e-01]])}" ] }, "execution_count": 4, @@ -81,10 +166,8 @@ } ], "source": [ - "### SAMPLE INTERFACE ###\n", - "\n", - "result = pyciemss.sample(model_1_path, end_time, logging_step_size, num_samples, start_time=start_time)\n", - "result[\"unprocessed_result\"]" + "result1 = pyciemss.sample(model1, end_time, logging_step_size, num_samples, start_time=start_time)\n", + "result1[\"unprocessed_result\"]" ] }, { @@ -115,17 +198,22 @@ " \n", " timepoint_id\n", " sample_id\n", - " beta_c_param\n", - " kappa_param\n", - " gamma_param\n", - " hosp_param\n", - " death_hosp_param\n", - " D_state\n", - " E_state\n", - " H_state\n", - " I_state\n", - " R_state\n", - " S_state\n", + " persistent_beta_c_param\n", + " persistent_kappa_param\n", + " persistent_gamma_param\n", + " persistent_hosp_param\n", + " persistent_death_hosp_param\n", + " persistent_I0_param\n", + " D_state_state\n", + " E_state_state\n", + " H_state_state\n", + " I_state_state\n", + " R_state_state\n", + " S_state_state\n", + " infected_observable_state\n", + " exposed_observable_state\n", + " hospitalized_observable_state\n", + " dead_observable_state\n", " \n", " \n", " \n", @@ -133,107 +221,153 @@ " 0\n", " 0\n", " 0\n", - " 0.108597\n", - " 0.476381\n", - " 0.45244\n", - " 0.116417\n", - " 0.028812\n", - " 0.142670\n", - " 35.091187\n", - " 4.107528\n", - " 19.026175\n", - " 73.567719\n", - " 19339868.0\n", + " 0.735054\n", + " 0.791559\n", + " 0.158109\n", + " 0.043577\n", + " 0.017374\n", + " 14.433031\n", + " 0.038956\n", + " 3.494867e+02\n", + " 3.335325\n", + " 2.177264e+02\n", + " 1.246180e+02\n", + " 19339298.0\n", + " 2.177264e+02\n", + " 3.494867e+02\n", + " 3.335325\n", + " 0.038956\n", " \n", " \n", " 1\n", " 1\n", " 0\n", - " 0.108597\n", - " 0.476381\n", - " 0.45244\n", - " 0.116417\n", - " 0.028812\n", - " 0.414323\n", - " 38.166237\n", - " 5.137194\n", - " 20.703896\n", - " 162.100815\n", - " 19339772.0\n", + " 0.735054\n", + " 0.791559\n", + " 0.158109\n", + " 0.043577\n", + " 0.017374\n", + " 14.433031\n", + " 0.540151\n", + " 3.974265e+03\n", + " 38.500515\n", + " 2.476499e+03\n", + " 1.557898e+03\n", + " 19331996.0\n", + " 2.476499e+03\n", + " 3.974265e+03\n", + " 38.500515\n", + " 0.540151\n", " \n", " \n", " 2\n", " 2\n", " 0\n", - " 0.108597\n", - " 0.476381\n", - " 0.45244\n", - " 0.116417\n", - " 0.028812\n", - " 0.726510\n", - " 41.517891\n", - " 5.678977\n", - " 22.522100\n", - " 258.975098\n", - " 19339670.0\n", + " 0.735054\n", + " 0.791559\n", + " 0.158109\n", + " 0.043577\n", + " 0.017374\n", + " 14.433031\n", + " 6.244518\n", + " 4.499189e+04\n", + " 437.227478\n", + " 2.808854e+04\n", + " 1.783779e+04\n", + " 19248676.0\n", + " 2.808854e+04\n", + " 4.499189e+04\n", + " 437.227478\n", + " 6.244518\n", " \n", " \n", " 3\n", " 3\n", " 0\n", - " 0.108597\n", - " 0.476381\n", - " 0.45244\n", - " 0.116417\n", - " 0.028812\n", - " 1.068370\n", - " 45.163460\n", - " 6.189935\n", - " 24.499756\n", - " 364.432495\n", - " 19339560.0\n", + " 0.735054\n", + " 0.791559\n", + " 0.158109\n", + " 0.043577\n", + " 0.017374\n", + " 14.433031\n", + " 70.356255\n", + " 4.842307e+05\n", + " 4864.138184\n", + " 3.086899e+05\n", + " 1.996144e+05\n", + " 18342560.0\n", + " 3.086899e+05\n", + " 4.842307e+05\n", + " 4864.138184\n", + " 70.356255\n", " \n", " \n", " 4\n", " 4\n", " 0\n", - " 0.108597\n", - " 0.476381\n", - " 0.45244\n", - " 0.116417\n", - " 0.028812\n", - " 1.440553\n", - " 49.128628\n", - " 6.735100\n", - " 26.650801\n", - " 479.159485\n", - " 19339438.0\n", + " 0.735054\n", + " 0.791559\n", + " 0.158109\n", + " 0.043577\n", + " 0.017374\n", + " 14.433031\n", + " 713.903564\n", + " 3.228286e+06\n", + " 44251.902344\n", + " 2.504426e+06\n", + " 1.913453e+06\n", + " 11648909.0\n", + " 2.504426e+06\n", + " 3.228286e+06\n", + " 44251.902344\n", + " 713.903564\n", " \n", " \n", "\n", "" ], "text/plain": [ - " timepoint_id sample_id beta_c_param kappa_param gamma_param \\\n", - "0 0 0 0.108597 0.476381 0.45244 \n", - "1 1 0 0.108597 0.476381 0.45244 \n", - "2 2 0 0.108597 0.476381 0.45244 \n", - "3 3 0 0.108597 0.476381 0.45244 \n", - "4 4 0 0.108597 0.476381 0.45244 \n", + " timepoint_id sample_id persistent_beta_c_param persistent_kappa_param \\\n", + "0 0 0 0.735054 0.791559 \n", + "1 1 0 0.735054 0.791559 \n", + "2 2 0 0.735054 0.791559 \n", + "3 3 0 0.735054 0.791559 \n", + "4 4 0 0.735054 0.791559 \n", + "\n", + " persistent_gamma_param persistent_hosp_param persistent_death_hosp_param \\\n", + "0 0.158109 0.043577 0.017374 \n", + "1 0.158109 0.043577 0.017374 \n", + "2 0.158109 0.043577 0.017374 \n", + "3 0.158109 0.043577 0.017374 \n", + "4 0.158109 0.043577 0.017374 \n", + "\n", + " persistent_I0_param D_state_state E_state_state H_state_state \\\n", + "0 14.433031 0.038956 3.494867e+02 3.335325 \n", + "1 14.433031 0.540151 3.974265e+03 38.500515 \n", + "2 14.433031 6.244518 4.499189e+04 437.227478 \n", + "3 14.433031 70.356255 4.842307e+05 4864.138184 \n", + "4 14.433031 713.903564 3.228286e+06 44251.902344 \n", + "\n", + " I_state_state R_state_state S_state_state infected_observable_state \\\n", + "0 2.177264e+02 1.246180e+02 19339298.0 2.177264e+02 \n", + "1 2.476499e+03 1.557898e+03 19331996.0 2.476499e+03 \n", + "2 2.808854e+04 1.783779e+04 19248676.0 2.808854e+04 \n", + "3 3.086899e+05 1.996144e+05 18342560.0 3.086899e+05 \n", + "4 2.504426e+06 1.913453e+06 11648909.0 2.504426e+06 \n", "\n", - " hosp_param death_hosp_param D_state E_state H_state I_state \\\n", - "0 0.116417 0.028812 0.142670 35.091187 4.107528 19.026175 \n", - "1 0.116417 0.028812 0.414323 38.166237 5.137194 20.703896 \n", - "2 0.116417 0.028812 0.726510 41.517891 5.678977 22.522100 \n", - "3 0.116417 0.028812 1.068370 45.163460 6.189935 24.499756 \n", - "4 0.116417 0.028812 1.440553 49.128628 6.735100 26.650801 \n", + " exposed_observable_state hospitalized_observable_state \\\n", + "0 3.494867e+02 3.335325 \n", + "1 3.974265e+03 38.500515 \n", + "2 4.499189e+04 437.227478 \n", + "3 4.842307e+05 4864.138184 \n", + "4 3.228286e+06 44251.902344 \n", "\n", - " R_state S_state \n", - "0 73.567719 19339868.0 \n", - "1 162.100815 19339772.0 \n", - "2 258.975098 19339670.0 \n", - "3 364.432495 19339560.0 \n", - "4 479.159485 19339438.0 " + " dead_observable_state \n", + "0 0.038956 \n", + "1 0.540151 \n", + "2 6.244518 \n", + "3 70.356255 \n", + "4 713.903564 " ] }, "execution_count": 5, @@ -242,22 +376,19 @@ } ], "source": [ - "result['data'].head()" + "result1['data'].head()" ] }, { - "cell_type": "code", - "execution_count": 6, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "### ENSEMBLE SETUP ###\n", - "model_2_path = \"https://raw.githubusercontent.com/DARPA-ASKEM/simulation-integration/main/data/models/SEIRHD_NPI_Type2_petrinet.json\"" + "### Sample from model 2" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -283,18 +414,23 @@ " \n", " timepoint_id\n", " sample_id\n", - " beta_c_param\n", - " beta_nc_param\n", - " kappa_param\n", - " gamma_param\n", - " hosp_param\n", - " death_hosp_param\n", - " D_state\n", - " E_state\n", - " H_state\n", - " I_state\n", - " R_state\n", - " S_state\n", + " persistent_beta_c_param\n", + " persistent_beta_nc_param\n", + " persistent_kappa_param\n", + " persistent_gamma_param\n", + " persistent_hosp_param\n", + " persistent_death_hosp_param\n", + " persistent_I0_param\n", + " D_state_state\n", + " E_state_state\n", + " H_state_state\n", + " I_state_state\n", + " R_state_state\n", + " S_state_state\n", + " infected_observable_state\n", + " exposed_observable_state\n", + " hospitalized_observable_state\n", + " dead_observable_state\n", " \n", " \n", " \n", @@ -302,127 +438,181 @@ " 0\n", " 0\n", " 0\n", - " 0.318328\n", - " 0.467599\n", - " 0.357985\n", - " 0.385949\n", - " 0.036065\n", - " 0.088924\n", - " 0.119807\n", - " 27.080441\n", - " 1.084867\n", - " 18.032969\n", - " 66.234688\n", - " 19339888.0\n", + " 0.53956\n", + " 0.301945\n", + " 0.491844\n", + " 0.235206\n", + " 0.170145\n", + " 0.077194\n", + " 1.9169\n", + " 0.363128\n", + " 61.741619\n", + " 5.386476\n", + " 44.884804\n", + " 53.556133\n", + " 19339876.0\n", + " 44.884804\n", + " 61.741619\n", + " 5.386476\n", + " 0.363128\n", " \n", " \n", " 1\n", " 1\n", " 0\n", - " 0.318328\n", - " 0.467599\n", - " 0.357985\n", - " 0.385949\n", - " 0.036065\n", - " 0.088924\n", - " 0.324788\n", - " 24.196489\n", - " 1.156573\n", - " 16.141342\n", - " 131.862915\n", - " 19339828.0\n", + " 0.53956\n", + " 0.301945\n", + " 0.491844\n", + " 0.235206\n", + " 0.170145\n", + " 0.077194\n", + " 1.9169\n", + " 1.956210\n", + " 181.897064\n", + " 17.140951\n", + " 132.448746\n", + " 230.586838\n", + " 19339478.0\n", + " 132.448746\n", + " 181.897064\n", + " 17.140951\n", + " 1.956210\n", " \n", " \n", " 2\n", " 2\n", " 0\n", - " 0.318328\n", - " 0.467599\n", - " 0.357985\n", - " 0.385949\n", - " 0.036065\n", - " 0.088924\n", - " 0.522439\n", - " 21.634531\n", - " 1.059445\n", - " 14.432345\n", - " 190.700455\n", - " 19339772.0\n", + " 0.53956\n", + " 0.301945\n", + " 0.491844\n", + " 0.235206\n", + " 0.170145\n", + " 0.077194\n", + " 1.9169\n", + " 6.737850\n", + " 536.286072\n", + " 50.710808\n", + " 390.508484\n", + " 753.597351\n", + " 19338306.0\n", + " 390.508484\n", + " 536.286072\n", + " 50.710808\n", + " 6.737850\n", " \n", " \n", " 3\n", " 3\n", " 0\n", - " 0.318328\n", - " 0.467599\n", - " 0.357985\n", - " 0.385949\n", - " 0.036065\n", - " 0.088924\n", - " 0.701110\n", - " 19.343807\n", - " 0.950698\n", - " 12.904219\n", - " 243.328217\n", - " 19339730.0\n", + " 0.53956\n", + " 0.301945\n", + " 0.491844\n", + " 0.235206\n", + " 0.170145\n", + " 0.077194\n", + " 1.9169\n", + " 20.846823\n", + " 1580.785278\n", + " 149.522675\n", + " 1151.173950\n", + " 2295.645020\n", + " 19334840.0\n", + " 1151.173950\n", + " 1580.785278\n", + " 149.522675\n", + " 20.846823\n", " \n", " \n", " 4\n", " 4\n", " 0\n", - " 0.318328\n", - " 0.467599\n", - " 0.357985\n", - " 0.385949\n", - " 0.036065\n", - " 0.088924\n", - " 0.861128\n", - " 17.295570\n", - " 0.850499\n", - " 11.537855\n", - " 290.386169\n", - " 19339684.0\n", + " 0.53956\n", + " 0.301945\n", + " 0.491844\n", + " 0.235206\n", + " 0.170145\n", + " 0.077194\n", + " 1.9169\n", + " 62.435204\n", + " 4656.605957\n", + " 440.647888\n", + " 3391.869141\n", + " 6840.399414\n", + " 19324652.0\n", + " 3391.869141\n", + " 4656.605957\n", + " 440.647888\n", + " 62.435204\n", " \n", " \n", "\n", "" ], "text/plain": [ - " timepoint_id sample_id beta_c_param beta_nc_param kappa_param \\\n", - "0 0 0 0.318328 0.467599 0.357985 \n", - "1 1 0 0.318328 0.467599 0.357985 \n", - "2 2 0 0.318328 0.467599 0.357985 \n", - "3 3 0 0.318328 0.467599 0.357985 \n", - "4 4 0 0.318328 0.467599 0.357985 \n", + " timepoint_id sample_id persistent_beta_c_param persistent_beta_nc_param \\\n", + "0 0 0 0.53956 0.301945 \n", + "1 1 0 0.53956 0.301945 \n", + "2 2 0 0.53956 0.301945 \n", + "3 3 0 0.53956 0.301945 \n", + "4 4 0 0.53956 0.301945 \n", "\n", - " gamma_param hosp_param death_hosp_param D_state E_state H_state \\\n", - "0 0.385949 0.036065 0.088924 0.119807 27.080441 1.084867 \n", - "1 0.385949 0.036065 0.088924 0.324788 24.196489 1.156573 \n", - "2 0.385949 0.036065 0.088924 0.522439 21.634531 1.059445 \n", - "3 0.385949 0.036065 0.088924 0.701110 19.343807 0.950698 \n", - "4 0.385949 0.036065 0.088924 0.861128 17.295570 0.850499 \n", + " persistent_kappa_param persistent_gamma_param persistent_hosp_param \\\n", + "0 0.491844 0.235206 0.170145 \n", + "1 0.491844 0.235206 0.170145 \n", + "2 0.491844 0.235206 0.170145 \n", + "3 0.491844 0.235206 0.170145 \n", + "4 0.491844 0.235206 0.170145 \n", "\n", - " I_state R_state S_state \n", - "0 18.032969 66.234688 19339888.0 \n", - "1 16.141342 131.862915 19339828.0 \n", - "2 14.432345 190.700455 19339772.0 \n", - "3 12.904219 243.328217 19339730.0 \n", - "4 11.537855 290.386169 19339684.0 " + " persistent_death_hosp_param persistent_I0_param D_state_state \\\n", + "0 0.077194 1.9169 0.363128 \n", + "1 0.077194 1.9169 1.956210 \n", + "2 0.077194 1.9169 6.737850 \n", + "3 0.077194 1.9169 20.846823 \n", + "4 0.077194 1.9169 62.435204 \n", + "\n", + " E_state_state H_state_state I_state_state R_state_state S_state_state \\\n", + "0 61.741619 5.386476 44.884804 53.556133 19339876.0 \n", + "1 181.897064 17.140951 132.448746 230.586838 19339478.0 \n", + "2 536.286072 50.710808 390.508484 753.597351 19338306.0 \n", + "3 1580.785278 149.522675 1151.173950 2295.645020 19334840.0 \n", + "4 4656.605957 440.647888 3391.869141 6840.399414 19324652.0 \n", + "\n", + " infected_observable_state exposed_observable_state \\\n", + "0 44.884804 61.741619 \n", + "1 132.448746 181.897064 \n", + "2 390.508484 536.286072 \n", + "3 1151.173950 1580.785278 \n", + "4 3391.869141 4656.605957 \n", + "\n", + " hospitalized_observable_state dead_observable_state \n", + "0 5.386476 0.363128 \n", + "1 17.140951 1.956210 \n", + "2 50.710808 6.737850 \n", + "3 149.522675 20.846823 \n", + "4 440.647888 62.435204 " ] }, - "execution_count": 7, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "result_2 = pyciemss.sample(model_2_path, end_time, logging_step_size, num_samples, start_time=start_time)\n", - "result_2['data'].head()" + "result2 = pyciemss.sample(model2, end_time, logging_step_size, num_samples, start_time=start_time)\n", + "result2['data'].head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Ensemble Sample Interface\n", + "Sample from an ensemble of model 1 and model 2 " ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -448,25 +638,25 @@ " \n", " timepoint_id\n", " sample_id\n", - " model_0/beta_c_param\n", - " model_0/kappa_param\n", - " model_0/gamma_param\n", - " model_0/hosp_param\n", - " model_0/death_hosp_param\n", - " model_1/beta_c_param\n", - " model_1/beta_nc_param\n", - " model_1/kappa_param\n", + " model_0/persistent_beta_c_param\n", + " model_0/persistent_kappa_param\n", + " model_0/persistent_gamma_param\n", + " model_0/persistent_hosp_param\n", + " model_0/persistent_death_hosp_param\n", + " model_0/persistent_I0_param\n", + " model_1/persistent_beta_c_param\n", + " model_1/persistent_beta_nc_param\n", " ...\n", - " model_0/H_state\n", - " model_0/I_state\n", - " model_0/R_state\n", - " model_0/S_state\n", - " model_1/D_state\n", - " model_1/E_state\n", - " model_1/H_state\n", - " model_1/I_state\n", - " model_1/R_state\n", - " model_1/S_state\n", + " model_0/H_state_state\n", + " model_0/I_state_state\n", + " model_0/R_state_state\n", + " model_0/S_state_state\n", + " model_1/D_state_state\n", + " model_1/E_state_state\n", + " model_1/H_state_state\n", + " model_1/I_state_state\n", + " model_1/R_state_state\n", + " model_1/S_state_state\n", " \n", " \n", " \n", @@ -474,192 +664,551 @@ " 0\n", " 0\n", " 0\n", - " 0.326421\n", - " 0.261223\n", - " 0.23488\n", - " 0.076074\n", - " 0.043145\n", - " 0.193485\n", - " 0.634903\n", - " 0.329632\n", + " 0.318085\n", + " 0.11662\n", + " 0.487817\n", + " 0.160195\n", + " 0.081918\n", + " 3.146115\n", + " 0.579869\n", + " 0.31039\n", " ...\n", - " 1.969539\n", - " 28.096125\n", - " 53.085701\n", - " 19339886.0\n", - " 0.122893\n", - " 21.867935\n", - " 1.392038\n", - " 14.199327\n", - " 63.355595\n", - " 19339900.0\n", + " 2.607337\n", + " 4.878901\n", + " 38.814823\n", + " 19339996.0\n", + " 0.016280\n", + " 133.996201\n", + " 1.183788\n", + " 86.414459\n", + " 96.795776\n", + " 19339724.0\n", " \n", " \n", " 1\n", " 1\n", " 0\n", - " 0.326421\n", - " 0.261223\n", - " 0.23488\n", - " 0.076074\n", - " 0.043145\n", - " 0.193485\n", - " 0.634903\n", - " 0.329632\n", + " 0.318085\n", + " 0.11662\n", + " 0.487817\n", + " 0.160195\n", + " 0.081918\n", + " 3.146115\n", + " 0.579869\n", + " 0.31039\n", " ...\n", - " 2.642947\n", - " 32.182053\n", - " 122.968758\n", - " 19339816.0\n", - " 0.302223\n", - " 15.316343\n", - " 1.210305\n", - " 9.964979\n", - " 113.576477\n", - " 19339864.0\n", + " 0.991349\n", + " 0.990093\n", + " 52.083225\n", + " 19339996.0\n", + " 0.115308\n", + " 642.588379\n", + " 5.893455\n", + " 414.542816\n", + " 574.882446\n", + " 19338406.0\n", " \n", " \n", " 2\n", " 2\n", " 0\n", - " 0.326421\n", - " 0.261223\n", - " 0.23488\n", - " 0.076074\n", - " 0.043145\n", - " 0.193485\n", - " 0.634903\n", - " 0.329632\n", + " 0.318085\n", + " 0.11662\n", + " 0.487817\n", + " 0.160195\n", + " 0.081918\n", + " 3.146115\n", + " 0.579869\n", + " 0.31039\n", " ...\n", - " 3.070854\n", - " 36.731533\n", - " 203.108948\n", - " 19339724.0\n", - " 0.441501\n", - " 10.735215\n", - " 0.880138\n", - " 6.984461\n", - " 148.973877\n", - " 19339840.0\n", + " 0.262363\n", + " 0.198274\n", + " 55.123028\n", + " 19339996.0\n", + " 0.593380\n", + " 3080.865723\n", + " 28.292126\n", + " 1987.789429\n", + " 2867.893311\n", + " 19332076.0\n", " \n", " \n", " 3\n", " 3\n", " 0\n", - " 0.326421\n", - " 0.261223\n", - " 0.23488\n", - " 0.076074\n", - " 0.043145\n", - " 0.193485\n", - " 0.634903\n", - " 0.329632\n", + " 0.318085\n", + " 0.11662\n", + " 0.487817\n", + " 0.160195\n", + " 0.081918\n", + " 3.146115\n", + " 0.579869\n", + " 0.31039\n", " ...\n", - " 3.512248\n", - " 41.923016\n", - " 294.621429\n", - " 19339622.0\n", - " 0.540961\n", - " 7.524287\n", - " 0.621196\n", - " 4.895389\n", - " 173.809586\n", - " 19339820.0\n", + " 0.061175\n", + " 0.039698\n", + " 55.782387\n", + " 19339996.0\n", + " 2.885418\n", + " 14743.889648\n", + " 135.559143\n", + " 9519.206055\n", + " 13856.029297\n", + " 19301794.0\n", " \n", " \n", " 4\n", " 4\n", " 0\n", - " 0.326421\n", - " 0.261223\n", - " 0.23488\n", - " 0.076074\n", - " 0.043145\n", - " 0.193485\n", - " 0.634903\n", - " 0.329632\n", + " 0.318085\n", + " 0.11662\n", + " 0.487817\n", + " 0.160195\n", + " 0.081918\n", + " 3.146115\n", + " 0.579869\n", + " 0.31039\n", " ...\n", - " 4.009643\n", - " 47.847847\n", - " 399.073700\n", - " 19339474.0\n", - " 0.610921\n", - " 5.273749\n", - " 0.435978\n", - " 3.431168\n", - " 191.220322\n", - " 19339806.0\n", + " 0.013418\n", + " 0.007948\n", + " 55.921246\n", + " 19339996.0\n", + " 13.843322\n", + " 69940.492188\n", + " 646.675903\n", + " 45301.042969\n", + " 66312.703125\n", + " 19157822.0\n", " \n", " \n", "\n", - "

5 rows × 25 columns

\n", + "

5 rows × 27 columns

\n", "" ], "text/plain": [ - " timepoint_id sample_id model_0/beta_c_param model_0/kappa_param \\\n", - "0 0 0 0.326421 0.261223 \n", - "1 1 0 0.326421 0.261223 \n", - "2 2 0 0.326421 0.261223 \n", - "3 3 0 0.326421 0.261223 \n", - "4 4 0 0.326421 0.261223 \n", + " timepoint_id sample_id model_0/persistent_beta_c_param \\\n", + "0 0 0 0.318085 \n", + "1 1 0 0.318085 \n", + "2 2 0 0.318085 \n", + "3 3 0 0.318085 \n", + "4 4 0 0.318085 \n", "\n", - " model_0/gamma_param model_0/hosp_param model_0/death_hosp_param \\\n", - "0 0.23488 0.076074 0.043145 \n", - "1 0.23488 0.076074 0.043145 \n", - "2 0.23488 0.076074 0.043145 \n", - "3 0.23488 0.076074 0.043145 \n", - "4 0.23488 0.076074 0.043145 \n", + " model_0/persistent_kappa_param model_0/persistent_gamma_param \\\n", + "0 0.11662 0.487817 \n", + "1 0.11662 0.487817 \n", + "2 0.11662 0.487817 \n", + "3 0.11662 0.487817 \n", + "4 0.11662 0.487817 \n", "\n", - " model_1/beta_c_param model_1/beta_nc_param model_1/kappa_param ... \\\n", - "0 0.193485 0.634903 0.329632 ... \n", - "1 0.193485 0.634903 0.329632 ... \n", - "2 0.193485 0.634903 0.329632 ... \n", - "3 0.193485 0.634903 0.329632 ... \n", - "4 0.193485 0.634903 0.329632 ... \n", + " model_0/persistent_hosp_param model_0/persistent_death_hosp_param \\\n", + "0 0.160195 0.081918 \n", + "1 0.160195 0.081918 \n", + "2 0.160195 0.081918 \n", + "3 0.160195 0.081918 \n", + "4 0.160195 0.081918 \n", "\n", - " model_0/H_state model_0/I_state model_0/R_state model_0/S_state \\\n", - "0 1.969539 28.096125 53.085701 19339886.0 \n", - "1 2.642947 32.182053 122.968758 19339816.0 \n", - "2 3.070854 36.731533 203.108948 19339724.0 \n", - "3 3.512248 41.923016 294.621429 19339622.0 \n", - "4 4.009643 47.847847 399.073700 19339474.0 \n", + " model_0/persistent_I0_param model_1/persistent_beta_c_param \\\n", + "0 3.146115 0.579869 \n", + "1 3.146115 0.579869 \n", + "2 3.146115 0.579869 \n", + "3 3.146115 0.579869 \n", + "4 3.146115 0.579869 \n", "\n", - " model_1/D_state model_1/E_state model_1/H_state model_1/I_state \\\n", - "0 0.122893 21.867935 1.392038 14.199327 \n", - "1 0.302223 15.316343 1.210305 9.964979 \n", - "2 0.441501 10.735215 0.880138 6.984461 \n", - "3 0.540961 7.524287 0.621196 4.895389 \n", - "4 0.610921 5.273749 0.435978 3.431168 \n", + " model_1/persistent_beta_nc_param ... model_0/H_state_state \\\n", + "0 0.31039 ... 2.607337 \n", + "1 0.31039 ... 0.991349 \n", + "2 0.31039 ... 0.262363 \n", + "3 0.31039 ... 0.061175 \n", + "4 0.31039 ... 0.013418 \n", "\n", - " model_1/R_state model_1/S_state \n", - "0 63.355595 19339900.0 \n", - "1 113.576477 19339864.0 \n", - "2 148.973877 19339840.0 \n", - "3 173.809586 19339820.0 \n", - "4 191.220322 19339806.0 \n", + " model_0/I_state_state model_0/R_state_state model_0/S_state_state \\\n", + "0 4.878901 38.814823 19339996.0 \n", + "1 0.990093 52.083225 19339996.0 \n", + "2 0.198274 55.123028 19339996.0 \n", + "3 0.039698 55.782387 19339996.0 \n", + "4 0.007948 55.921246 19339996.0 \n", "\n", - "[5 rows x 25 columns]" + " model_1/D_state_state model_1/E_state_state model_1/H_state_state \\\n", + "0 0.016280 133.996201 1.183788 \n", + "1 0.115308 642.588379 5.893455 \n", + "2 0.593380 3080.865723 28.292126 \n", + "3 2.885418 14743.889648 135.559143 \n", + "4 13.843322 69940.492188 646.675903 \n", + "\n", + " model_1/I_state_state model_1/R_state_state model_1/S_state_state \n", + "0 86.414459 96.795776 19339724.0 \n", + "1 414.542816 574.882446 19338406.0 \n", + "2 1987.789429 2867.893311 19332076.0 \n", + "3 9519.206055 13856.029297 19301794.0 \n", + "4 45301.042969 66312.703125 19157822.0 \n", + "\n", + "[5 rows x 27 columns]" ] }, - "execution_count": 8, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "### ENSEMBLE SAMPLE INTERFACE ###\n", - "\n", - "model_paths = [model_1_path, model_2_path]\n", + "model_paths = [model1, model2]\n", "solution_mappings = [lambda x : x, lambda x : x] # Conveniently, these two models operate on exactly the same state space, with the same names.\n", "\n", "ensemble_result = pyciemss.ensemble_sample(model_paths, solution_mappings, end_time, logging_step_size, num_samples, start_time=start_time)\n", "ensemble_result['data'].head()" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Calibrate interface\n", + "Calibrate a model to a dataset by mapping model state varibale or observables to columns in the dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'inferred_parameters': AutoGuideList(\n", + " (0): AutoDelta()\n", + " (1): AutoLowRankMultivariateNormal()\n", + " ),\n", + " 'loss': 243.56869277358055}" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data_mapping = {\"Infected\": \"I\"} # data_mapping = \"column_name\": \"observable/state_variable\"\n", + "num_iterations = 10\n", + "calibrated_results = calibrate(model1, dataset1, data_mapping=data_mapping, num_iterations=num_iterations)\n", + "parameter_estimates = calibrated_results[\"inferred_parameters\"]\n", + "calibrated_results" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'persistent_beta_c': tensor(0.3868, grad_fn=),\n", + " 'persistent_kappa': tensor(0.4557, grad_fn=),\n", + " 'persistent_gamma': tensor(0.2839, grad_fn=),\n", + " 'persistent_hosp': tensor(0.1006, grad_fn=),\n", + " 'persistent_death_hosp': tensor(0.0501, grad_fn=),\n", + " 'persistent_I0': tensor(9.2005, grad_fn=)}" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "parameter_estimates()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Pass the parameter estimates to `sample` to sample from the calibrated model" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'data': timepoint_id sample_id persistent_beta_c_param persistent_kappa_param \\\n", + " 0 0 0 0.414435 0.433556 \n", + " 1 1 0 0.414435 0.433556 \n", + " 2 2 0 0.414435 0.433556 \n", + " 3 3 0 0.414435 0.433556 \n", + " 4 4 0 0.414435 0.433556 \n", + " 5 5 0 0.414435 0.433556 \n", + " 6 6 0 0.414435 0.433556 \n", + " 7 7 0 0.414435 0.433556 \n", + " 8 8 0 0.414435 0.433556 \n", + " 9 0 1 0.424505 0.412429 \n", + " 10 1 1 0.424505 0.412429 \n", + " 11 2 1 0.424505 0.412429 \n", + " 12 3 1 0.424505 0.412429 \n", + " 13 4 1 0.424505 0.412429 \n", + " 14 5 1 0.424505 0.412429 \n", + " 15 6 1 0.424505 0.412429 \n", + " 16 7 1 0.424505 0.412429 \n", + " 17 8 1 0.424505 0.412429 \n", + " 18 0 2 0.420682 0.417782 \n", + " 19 1 2 0.420682 0.417782 \n", + " 20 2 2 0.420682 0.417782 \n", + " 21 3 2 0.420682 0.417782 \n", + " 22 4 2 0.420682 0.417782 \n", + " 23 5 2 0.420682 0.417782 \n", + " 24 6 2 0.420682 0.417782 \n", + " 25 7 2 0.420682 0.417782 \n", + " 26 8 2 0.420682 0.417782 \n", + " \n", + " persistent_gamma_param persistent_hosp_param \\\n", + " 0 0.293311 0.103353 \n", + " 1 0.293311 0.103353 \n", + " 2 0.293311 0.103353 \n", + " 3 0.293311 0.103353 \n", + " 4 0.293311 0.103353 \n", + " 5 0.293311 0.103353 \n", + " 6 0.293311 0.103353 \n", + " 7 0.293311 0.103353 \n", + " 8 0.293311 0.103353 \n", + " 9 0.295502 0.097597 \n", + " 10 0.295502 0.097597 \n", + " 11 0.295502 0.097597 \n", + " 12 0.295502 0.097597 \n", + " 13 0.295502 0.097597 \n", + " 14 0.295502 0.097597 \n", + " 15 0.295502 0.097597 \n", + " 16 0.295502 0.097597 \n", + " 17 0.295502 0.097597 \n", + " 18 0.265383 0.103509 \n", + " 19 0.265383 0.103509 \n", + " 20 0.265383 0.103509 \n", + " 21 0.265383 0.103509 \n", + " 22 0.265383 0.103509 \n", + " 23 0.265383 0.103509 \n", + " 24 0.265383 0.103509 \n", + " 25 0.265383 0.103509 \n", + " 26 0.265383 0.103509 \n", + " \n", + " persistent_death_hosp_param persistent_I0_param D_state_state \\\n", + " 0 0.054385 9.075936 0.206448 \n", + " 1 0.054385 9.075936 0.788959 \n", + " 2 0.054385 9.075936 1.864389 \n", + " 3 0.054385 9.075936 3.794818 \n", + " 4 0.054385 9.075936 7.252632 \n", + " 5 0.054385 9.075936 13.444811 \n", + " 6 0.054385 9.075936 24.531963 \n", + " 7 0.054385 9.075936 44.378624 \n", + " 8 0.054385 9.075936 79.884438 \n", + " 9 0.048039 8.883058 0.167836 \n", + " 10 0.048039 8.883058 0.612311 \n", + " 11 0.048039 8.883058 1.364731 \n", + " 12 0.048039 8.883058 2.598298 \n", + " 13 0.048039 8.883058 4.615367 \n", + " 14 0.048039 8.883058 7.912742 \n", + " 15 0.048039 8.883058 13.302500 \n", + " 16 0.048039 8.883058 22.111193 \n", + " 17 0.048039 8.883058 36.502872 \n", + " 18 0.056225 8.600613 0.200924 \n", + " 19 0.056225 8.600613 0.805037 \n", + " 20 0.056225 8.600613 2.002240 \n", + " 21 0.056225 8.600613 4.313284 \n", + " 22 0.056225 8.600613 8.766099 \n", + " 23 0.056225 8.600613 17.343636 \n", + " 24 0.056225 8.600613 33.863155 \n", + " 25 0.056225 8.600613 65.665779 \n", + " 26 0.056225 8.600613 126.837700 \n", + " \n", + " E_state_state H_state_state I_state_state R_state_state S_state_state \\\n", + " 0 48.883446 3.665491 34.711571 68.322739 19339884.0 \n", + " 1 87.496017 7.246261 62.214676 202.439713 19339684.0 \n", + " 2 156.703629 13.070772 111.426102 443.223328 19339318.0 \n", + " 3 280.640289 23.421701 199.555862 874.533142 19338660.0 \n", + " 4 502.558044 41.946537 357.364471 1646.964233 19337458.0 \n", + " 5 899.827332 75.112633 639.889954 3030.148926 19335396.0 \n", + " 6 1610.718506 134.477676 1145.518066 5506.589844 19331624.0 \n", + " 7 2881.894775 240.682861 2049.865967 9939.023438 19324872.0 \n", + " 8 4356.958984 429.294678 3563.881836 17856.638672 19313760.0 \n", + " 9 44.345261 3.307929 32.113510 66.215637 19339894.0 \n", + " 10 72.443115 6.026298 52.547020 185.707672 19339768.0 \n", + " 11 118.431320 9.936190 85.905792 381.589355 19339442.0 \n", + " 12 193.607986 16.255127 140.437637 701.888428 19338988.0 \n", + " 13 316.488403 26.574503 229.575348 1225.507690 19338252.0 \n", + " 14 517.314453 43.440029 375.261414 2081.445312 19337022.0 \n", + " 15 845.456970 71.001732 613.324341 3480.470947 19335016.0 \n", + " 16 1381.432251 116.031425 1002.214294 5766.817871 19331724.0 \n", + " 17 1922.659424 189.078857 1594.416748 9497.415039 19326782.0 \n", + " 18 50.199841 3.537989 37.854744 64.965637 19339900.0 \n", + " 19 96.637131 7.498452 72.997345 202.465958 19339708.0 \n", + " 20 186.180420 14.539669 140.637543 467.963348 19339236.0 \n", + " 21 358.674561 28.024151 270.942108 979.539673 19338400.0 \n", + " 22 690.908752 53.988152 521.930847 1965.070557 19336842.0 \n", + " 23 1330.613525 103.990067 1005.251038 3863.406982 19333710.0 \n", + " 24 2561.600342 200.248489 1935.499146 7519.072266 19327790.0 \n", + " 25 4927.651367 385.412231 3724.225586 14555.504883 19316402.0 \n", + " 26 7992.514160 739.002747 6967.222656 28067.730469 19296158.0 \n", + " \n", + " infected_observable_state exposed_observable_state \\\n", + " 0 34.711571 48.883446 \n", + " 1 62.214676 87.496017 \n", + " 2 111.426102 156.703629 \n", + " 3 199.555862 280.640289 \n", + " 4 357.364471 502.558044 \n", + " 5 639.889954 899.827332 \n", + " 6 1145.518066 1610.718506 \n", + " 7 2049.865967 2881.894775 \n", + " 8 3563.881836 4356.958984 \n", + " 9 32.113510 44.345261 \n", + " 10 52.547020 72.443115 \n", + " 11 85.905792 118.431320 \n", + " 12 140.437637 193.607986 \n", + " 13 229.575348 316.488403 \n", + " 14 375.261414 517.314453 \n", + " 15 613.324341 845.456970 \n", + " 16 1002.214294 1381.432251 \n", + " 17 1594.416748 1922.659424 \n", + " 18 37.854744 50.199841 \n", + " 19 72.997345 96.637131 \n", + " 20 140.637543 186.180420 \n", + " 21 270.942108 358.674561 \n", + " 22 521.930847 690.908752 \n", + " 23 1005.251038 1330.613525 \n", + " 24 1935.499146 2561.600342 \n", + " 25 3724.225586 4927.651367 \n", + " 26 6967.222656 7992.514160 \n", + " \n", + " hospitalized_observable_state dead_observable_state \n", + " 0 3.665491 0.206448 \n", + " 1 7.246261 0.788959 \n", + " 2 13.070772 1.864389 \n", + " 3 23.421701 3.794818 \n", + " 4 41.946537 7.252632 \n", + " 5 75.112633 13.444811 \n", + " 6 134.477676 24.531963 \n", + " 7 240.682861 44.378624 \n", + " 8 429.294678 79.884438 \n", + " 9 3.307929 0.167836 \n", + " 10 6.026298 0.612311 \n", + " 11 9.936190 1.364731 \n", + " 12 16.255127 2.598298 \n", + " 13 26.574503 4.615367 \n", + " 14 43.440029 7.912742 \n", + " 15 71.001732 13.302500 \n", + " 16 116.031425 22.111193 \n", + " 17 189.078857 36.502872 \n", + " 18 3.537989 0.200924 \n", + " 19 7.498452 0.805037 \n", + " 20 14.539669 2.002240 \n", + " 21 28.024151 4.313284 \n", + " 22 53.988152 8.766099 \n", + " 23 103.990067 17.343636 \n", + " 24 200.248489 33.863155 \n", + " 25 385.412231 65.665779 \n", + " 26 739.002747 126.837700 ,\n", + " 'unprocessed_result': {'persistent_beta_c': tensor([0.4144, 0.4245, 0.4207]),\n", + " 'persistent_kappa': tensor([0.4336, 0.4124, 0.4178]),\n", + " 'persistent_gamma': tensor([0.2933, 0.2955, 0.2654]),\n", + " 'persistent_hosp': tensor([0.1034, 0.0976, 0.1035]),\n", + " 'persistent_death_hosp': tensor([0.0544, 0.0480, 0.0562]),\n", + " 'persistent_I0': tensor([9.0759, 8.8831, 8.6006]),\n", + " 'D_state': tensor([[ 0.2064, 0.7890, 1.8644, 3.7948, 7.2526, 13.4448, 24.5320,\n", + " 44.3786, 79.8844],\n", + " [ 0.1678, 0.6123, 1.3647, 2.5983, 4.6154, 7.9127, 13.3025,\n", + " 22.1112, 36.5029],\n", + " [ 0.2009, 0.8050, 2.0022, 4.3133, 8.7661, 17.3436, 33.8632,\n", + " 65.6658, 126.8377]]),\n", + " 'E_state': tensor([[ 48.8834, 87.4960, 156.7036, 280.6403, 502.5580, 899.8273,\n", + " 1610.7185, 2881.8948, 4356.9590],\n", + " [ 44.3453, 72.4431, 118.4313, 193.6080, 316.4884, 517.3145,\n", + " 845.4570, 1381.4323, 1922.6594],\n", + " [ 50.1998, 96.6371, 186.1804, 358.6746, 690.9088, 1330.6135,\n", + " 2561.6003, 4927.6514, 7992.5142]]),\n", + " 'H_state': tensor([[ 3.6655, 7.2463, 13.0708, 23.4217, 41.9465, 75.1126, 134.4777,\n", + " 240.6829, 429.2947],\n", + " [ 3.3079, 6.0263, 9.9362, 16.2551, 26.5745, 43.4400, 71.0017,\n", + " 116.0314, 189.0789],\n", + " [ 3.5380, 7.4985, 14.5397, 28.0242, 53.9882, 103.9901, 200.2485,\n", + " 385.4122, 739.0027]]),\n", + " 'I_state': tensor([[ 34.7116, 62.2147, 111.4261, 199.5559, 357.3645, 639.8900,\n", + " 1145.5181, 2049.8660, 3563.8818],\n", + " [ 32.1135, 52.5470, 85.9058, 140.4376, 229.5753, 375.2614,\n", + " 613.3243, 1002.2143, 1594.4167],\n", + " [ 37.8547, 72.9973, 140.6375, 270.9421, 521.9308, 1005.2510,\n", + " 1935.4991, 3724.2256, 6967.2227]]),\n", + " 'R_state': tensor([[ 68.3227, 202.4397, 443.2233, 874.5331, 1646.9642, 3030.1489,\n", + " 5506.5898, 9939.0234, 17856.6387],\n", + " [ 66.2156, 185.7077, 381.5894, 701.8884, 1225.5077, 2081.4453,\n", + " 3480.4709, 5766.8179, 9497.4150],\n", + " [ 64.9656, 202.4660, 467.9633, 979.5397, 1965.0706, 3863.4070,\n", + " 7519.0723, 14555.5049, 28067.7305]]),\n", + " 'S_state': tensor([[19339884., 19339684., 19339318., 19338660., 19337458., 19335396.,\n", + " 19331624., 19324872., 19313760.],\n", + " [19339894., 19339768., 19339442., 19338988., 19338252., 19337022.,\n", + " 19335016., 19331724., 19326782.],\n", + " [19339900., 19339708., 19339236., 19338400., 19336842., 19333710.,\n", + " 19327790., 19316402., 19296158.]]),\n", + " 'infected_observable': tensor([[ 34.7116, 62.2147, 111.4261, 199.5559, 357.3645, 639.8900,\n", + " 1145.5181, 2049.8660, 3563.8818],\n", + " [ 32.1135, 52.5470, 85.9058, 140.4376, 229.5753, 375.2614,\n", + " 613.3243, 1002.2143, 1594.4167],\n", + " [ 37.8547, 72.9973, 140.6375, 270.9421, 521.9308, 1005.2510,\n", + " 1935.4991, 3724.2256, 6967.2227]]),\n", + " 'exposed_observable': tensor([[ 48.8834, 87.4960, 156.7036, 280.6403, 502.5580, 899.8273,\n", + " 1610.7185, 2881.8948, 4356.9590],\n", + " [ 44.3453, 72.4431, 118.4313, 193.6080, 316.4884, 517.3145,\n", + " 845.4570, 1381.4323, 1922.6594],\n", + " [ 50.1998, 96.6371, 186.1804, 358.6746, 690.9088, 1330.6135,\n", + " 2561.6003, 4927.6514, 7992.5142]]),\n", + " 'hospitalized_observable': tensor([[ 3.6655, 7.2463, 13.0708, 23.4217, 41.9465, 75.1126, 134.4777,\n", + " 240.6829, 429.2947],\n", + " [ 3.3079, 6.0263, 9.9362, 16.2551, 26.5745, 43.4400, 71.0017,\n", + " 116.0314, 189.0789],\n", + " [ 3.5380, 7.4985, 14.5397, 28.0242, 53.9882, 103.9901, 200.2485,\n", + " 385.4122, 739.0027]]),\n", + " 'dead_observable': tensor([[ 0.2064, 0.7890, 1.8644, 3.7948, 7.2526, 13.4448, 24.5320,\n", + " 44.3786, 79.8844],\n", + " [ 0.1678, 0.6123, 1.3647, 2.5983, 4.6154, 7.9127, 13.3025,\n", + " 22.1112, 36.5029],\n", + " [ 0.2009, 0.8050, 2.0022, 4.3133, 8.7661, 17.3436, 33.8632,\n", + " 65.6658, 126.8377]])}}" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "calibrated_sample_results = pyciemss.sample(model1, end_time, logging_step_size, num_samples, \n", + " start_time=start_time, inferred_parameters=parameter_estimates)\n", + "calibrated_sample_results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# TODO:\n", + "# - Add intervention example\n", + "# - Add examples for calibrate_ensemble and optimize interfaces as they become available\n", + "# - Plot results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { "kernelspec": { - "display_name": "CIEMSS-ENV", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -673,7 +1222,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.18" + "version": "3.10.9" } }, "nbformat": 4,