From 60e637f86a775a05061179b2fa85718d91e4b04f Mon Sep 17 00:00:00 2001 From: Christopher Teubert Date: Fri, 5 Jul 2024 14:11:18 -0700 Subject: [PATCH 1/3] setting parameters directly --- examples/01_Simulation.ipynb | 4 ++-- examples/02_param_est.ipynb | 36 +++++++++++++++--------------- examples/04_New Models.ipynb | 26 ++++++++++----------- examples/06_Combining_Models.ipynb | 2 +- src/progpy/prognostics_model.py | 6 +++++ 5 files changed, 40 insertions(+), 34 deletions(-) diff --git a/examples/01_Simulation.ipynb b/examples/01_Simulation.ipynb index 75ca322..436d224 100644 --- a/examples/01_Simulation.ipynb +++ b/examples/01_Simulation.ipynb @@ -1200,11 +1200,11 @@ "results1 = m.simulate_to_threshold(threshold_keys='impact', dt=0.1, save_freq=0.1)\n", "fig = results1.outputs.plot(title='default')\n", "\n", - "m.parameters['throwing_speed'] = 10\n", + "m['throwing_speed'] = 10\n", "results2 = m.simulate_to_threshold(threshold_keys='impact', dt=0.1, save_freq=0.1)\n", "fig = results2.outputs.plot(title='slow')\n", "\n", - "m.parameters['throwing_speed'] = 80\n", + "m['throwing_speed'] = 80\n", "results3 = m.simulate_to_threshold(threshold_keys='impact', dt=0.1, save_freq=0.1)\n", "fig = results3.outputs.plot(title='fast')" ] diff --git a/examples/02_param_est.ipynb b/examples/02_param_est.ipynb index 69aa940..1217789 100644 --- a/examples/02_param_est.ipynb +++ b/examples/02_param_est.ipynb @@ -137,7 +137,7 @@ "# Printing state before\n", "print('Model configuration before')\n", "for key in keys:\n", - " print(\"-\", key, m.parameters[key])\n", + " print(\"-\", key, m[key])\n", "print(' Error: ', m.calc_error(times, inputs, outputs, dt=0.1))" ] }, @@ -176,7 +176,7 @@ "source": [ "print('\\nOptimized configuration')\n", "for key in keys:\n", - " print(\"-\", key, m.parameters[key])\n", + " print(\"-\", key, m[key])\n", "print(' Error: ', m.calc_error(times, inputs, outputs, dt=0.1))" ] }, @@ -223,7 +223,7 @@ "m.estimate_params(times = times, inputs = inputs, outputs = outputs, keys = keys, dt=0.1, tol=1e-6)\n", "print('\\nOptimized configuration')\n", "for key in keys:\n", - " print(\"-\", key, m.parameters[key])\n", + " print(\"-\", key, m[key])\n", "print(' Error: ', m.calc_error(times, inputs, outputs, dt=0.1))" ] }, @@ -260,14 +260,14 @@ "metadata": {}, "outputs": [], "source": [ - "m.parameters['thrower_height'] = 3.1\n", - "m.parameters['throwing_speed'] = 29\n", + "m['thrower_height'] = 3.1\n", + "m['throwing_speed'] = 29\n", "\n", "# Using MAE, or Mean Absolute Error instead of the default Mean Squared Error.\n", "m.estimate_params(times = times, inputs = inputs, outputs = outputs, keys = keys, dt=0.1, tol=1e-9, error_method='MAX_E')\n", "print('\\nOptimized configuration')\n", "for key in keys:\n", - " print(\"-\", key, m.parameters[key])\n", + " print(\"-\", key, m[key])\n", "print(' Error: ', m.calc_error(times, inputs, outputs, dt=0.1, method='MAX_E'))" ] }, @@ -309,9 +309,9 @@ "results = m.simulate_to_threshold(save_freq=0.5, dt=('auto', 0.1))\n", "\n", "# Resetting parameters to their incorrectly set values.\n", - "m.parameters['thrower_height'] = 20\n", - "m.parameters['throwing_speed'] = 3.1\n", - "m.parameters['g'] = 15\n", + "m['thrower_height'] = 20\n", + "m['throwing_speed'] = 3.1\n", + "m['g'] = 15\n", "keys = ['thrower_height', 'throwing_speed', 'g']" ] }, @@ -324,7 +324,7 @@ "m.estimate_params(times = results.times, inputs = results.inputs, outputs = results.outputs, keys = keys)\n", "print('\\nOptimized configuration')\n", "for key in keys:\n", - " print(\"-\", key, m.parameters[key])\n", + " print(\"-\", key, m[key])\n", "print(' Error: ', m.calc_error(results.times, results.inputs, results.outputs))" ] }, @@ -353,7 +353,7 @@ "def AME(m, keys):\n", " error = 0\n", " for key in keys:\n", - " error += abs(m.parameters[key] - true_Values.parameters[key])\n", + " error += abs(m[key] - true_Values.parameters[key])\n", " return error" ] }, @@ -393,9 +393,9 @@ " results = m.simulate_to_threshold(save_freq=0.5, dt=('auto', 0.1))\n", " \n", " # Resetting parameters to their originally incorrectly set values.\n", - " m.parameters['thrower_height'] = 20\n", - " m.parameters['throwing_speed'] = 3.1\n", - " m.parameters['g'] = 15\n", + " m['thrower_height'] = 20\n", + " m['throwing_speed'] = 3.1\n", + " m['g'] = 15\n", "\n", " m.estimate_params(times = results.times, inputs = results.inputs, outputs = results.outputs, keys = keys, dt=0.1)\n", " error = AME(m, ['thrower_height', 'throwing_speed', 'g'])\n", @@ -439,9 +439,9 @@ "metadata": {}, "outputs": [], "source": [ - "m.parameters['thrower_height'] = 20\n", - "m.parameters['throwing_speed'] = 3.1\n", - "m.parameters['g'] = 15" + "m['thrower_height'] = 20\n", + "m['throwing_speed'] = 3.1\n", + "m['g'] = 15" ] }, { @@ -461,7 +461,7 @@ "m.estimate_params(times=times, inputs=inputs, outputs=outputs, keys=keys, dt=0.1)\n", "print('\\nOptimized configuration')\n", "for key in keys:\n", - " print(\"-\", key, m.parameters[key])\n", + " print(\"-\", key, m[key])\n", "error = AME(m, ['thrower_height', 'throwing_speed', 'g'])\n", "print('AME Error: ', error)" ] diff --git a/examples/04_New Models.ipynb b/examples/04_New Models.ipynb index 0cd7256..0bb2d2f 100644 --- a/examples/04_New Models.ipynb +++ b/examples/04_New Models.ipynb @@ -192,8 +192,8 @@ "class ThrownObject(ThrownObject):\n", " def initialize(self, u=None, z=None):\n", " return self.StateContainer({\n", - " 'x': self.parameters['thrower_height'],\n", - " 'v': self.parameters['throwing_speed']\n", + " 'x': self['thrower_height'],\n", + " 'v': self['throwing_speed']\n", " })" ] }, @@ -237,7 +237,7 @@ " def event_state(self, x): \n", " x_max = x['x'] + np.square(x['v'])/(9.81*2)\n", " return {\n", - " 'falling': np.maximum(x['v']/self.parameters['throwing_speed'],0),\n", + " 'falling': np.maximum(x['v']/self['throwing_speed'],0),\n", " 'impact': np.maximum(x['x']/x_max,0) if x['v'] < 0 else 1\n", " }" ] @@ -415,8 +415,8 @@ "\n", " def initialize(self, u=None, z=None):\n", " return self.StateContainer({\n", - " 'x': self.parameters['thrower_height'], # Thrown, so initial altitude is height of thrower\n", - " 'v': self.parameters['throwing_speed'] # Velocity at which the ball is thrown - this guy is a professional baseball pitcher\n", + " 'x': self['thrower_height'], # Thrown, so initial altitude is height of thrower\n", + " 'v': self['throwing_speed'] # Velocity at which the ball is thrown - this guy is a professional baseball pitcher\n", " })" ] }, @@ -440,7 +440,7 @@ " def dx(self, x, u):\n", " return self.StateContainer({\n", " 'x': x['v'], \n", - " 'v': self.parameters['g']}) # Acceleration of gravity" + " 'v': self['g']}) # Acceleration of gravity" ] }, { @@ -480,11 +480,11 @@ " \n", " def event_state(self, x): \n", " # Use speed and position to estimate maximum height\n", - " x_max = x['x'] + np.square(x['v'])/(-self.parameters['g']*2)\n", + " x_max = x['x'] + np.square(x['v'])/(-self['g']*2)\n", " # 1 until falling begins\n", " x_max = np.where(x['v'] > 0, x['x'], x_max)\n", " return {\n", - " 'falling': max(x['v']/self.parameters['throwing_speed'],0), # Throwing speed is max speed\n", + " 'falling': max(x['v']/self['throwing_speed'],0), # Throwing speed is max speed\n", " 'impact': max(x['x']/x_max,0) # 1 until falling begins, then it's fraction of height\n", " }" ] @@ -654,7 +654,7 @@ "outputs": [], "source": [ "obj = ThrownObject_ST()\n", - "print(\"Default Settings:\\n\\tthrower_height: {}\\n\\tthrowing_speed: {}\".format(obj.parameters['thrower_height'], obj.parameters['throwing_speed']))" + "print(\"Default Settings:\\n\\tthrower_height: {}\\n\\tthrowing_speed: {}\".format(obj['thrower_height'], obj['throwing_speed']))" ] }, { @@ -670,8 +670,8 @@ "metadata": {}, "outputs": [], "source": [ - "obj.parameters['thrower_height'] = 1.75 # Our thrower is 1.75 m tall\n", - "print(\"\\nUpdated Settings:\\n\\tthrower_height: {}\\n\\tthowing_speed: {}\".format(obj.parameters['thrower_height'], obj.parameters['throwing_speed']))" + "obj['thrower_height'] = 1.75 # Our thrower is 1.75 m tall\n", + "print(\"\\nUpdated Settings:\\n\\tthrower_height: {}\\n\\tthowing_speed: {}\".format(obj['thrower_height'], obj['throwing_speed']))" ] }, { @@ -743,7 +743,7 @@ " def time_of_event(self, x, *args, **kwargs):\n", " # calculate time when object hits ground given x['x'] and x['v']\n", " # 0 = x0 + v0*t - 0.5*g*t^2\n", - " g = self.parameters['g']\n", + " g = self['g']\n", " t_impact = -(x['v'] + np.sqrt(x['v']*x['v'] - 2*g*x['x']))/g\n", "\n", " # 0 = v0 - g*t\n", @@ -890,7 +890,7 @@ " def next_state(self, x, u, dt):\n", "\n", " A = np.array([[0, 1], [0, 0]]) # State transition matrix\n", - " B = np.array([[0], [self.parameters['g']]]) # Acceleration due to gravity\n", + " B = np.array([[0], [self['g']]]) # Acceleration due to gravity\n", " x.matrix += (np.matmul(A, x.matrix) + B) * dt\n", "\n", " return x" diff --git a/examples/06_Combining_Models.ipynb b/examples/06_Combining_Models.ipynb index 7839753..26f6c06 100644 --- a/examples/06_Combining_Models.ipynb +++ b/examples/06_Combining_Models.ipynb @@ -468,7 +468,7 @@ "outputs": [], "source": [ "import numpy as np\n", - "m_ensemble.parameters['aggregation_method'] = np.median\n", + "m_ensemble['aggregation_method'] = np.median\n", "\n", "results_ensemble_median = m_ensemble.simulate_to(t_end, future_loading)\n", "fig = plt.plot(results_ensemble_median.times, [z['v'] for z in results_ensemble_median.outputs], color='orange', label='ensemble -median')\n", diff --git a/src/progpy/prognostics_model.py b/src/progpy/prognostics_model.py index 773d1c0..2febd02 100644 --- a/src/progpy/prognostics_model.py +++ b/src/progpy/prognostics_model.py @@ -202,6 +202,12 @@ def __init__(self, data): self.parameters = PrognosticsModelParameters(self, params, self.param_callbacks) + def __getitem__(self, arg): + return self.parameters[arg] + + def __setitem__(self, key, value): + self.parameters[key] = value + def initialize(self, u=None, z=None): """ Calculate initial state given inputs and outputs. If not defined for a model, it will return parameters['x0'] From db46efe9226f942bb62c95a19dd00e816a708d12 Mon Sep 17 00:00:00 2001 From: Christopher Teubert Date: Fri, 5 Jul 2024 14:15:54 -0700 Subject: [PATCH 2/3] Updated examples --- examples/02_param_est.ipynb | 2 +- examples/07_State_Estimation.ipynb | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/02_param_est.ipynb b/examples/02_param_est.ipynb index 1217789..c860d74 100644 --- a/examples/02_param_est.ipynb +++ b/examples/02_param_est.ipynb @@ -353,7 +353,7 @@ "def AME(m, keys):\n", " error = 0\n", " for key in keys:\n", - " error += abs(m[key] - true_Values.parameters[key])\n", + " error += abs(m[key] - true_Values[key])\n", " return error" ] }, diff --git a/examples/07_State_Estimation.ipynb b/examples/07_State_Estimation.ipynb index 84cac76..5ee2394 100644 --- a/examples/07_State_Estimation.ipynb +++ b/examples/07_State_Estimation.ipynb @@ -84,8 +84,8 @@ "metadata": {}, "outputs": [], "source": [ - "print('Initial thrower height:', m.parameters['thrower_height'])\n", - "print('Initial speed:', m.parameters['throwing_speed'])" + "print('Initial thrower height:', m['thrower_height'])\n", + "print('Initial speed:', m['throwing_speed'])" ] }, { From c19e9bf01f2bc6d74ca6f7efad609e91cb8d8db6 Mon Sep 17 00:00:00 2001 From: Christopher Teubert Date: Fri, 5 Jul 2024 14:39:32 -0700 Subject: [PATCH 3/3] Add test --- tests/test_base_models.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/test_base_models.py b/tests/test_base_models.py index 86b63a6..0d1a6fc 100644 --- a/tests/test_base_models.py +++ b/tests/test_base_models.py @@ -1343,6 +1343,21 @@ def test_parameter_equality(self): self.assertTrue(m1.parameters == m2.parameters) # Checking to see previous equal statements stay the same self.assertTrue(m2.parameters == m1.parameters) + def test_direct_params(self): + m1 = LinearThrownObject() + print('test') + + # Accessing parameters directly + self.assertEqual(m1.parameters['g'], m1['g']) + + # Setting parameters + m1['g'] *= 2 # Doubling + self.assertEqual(m1.parameters['g'], m1['g']) + + # Updating from parameters + m1.parameters['g'] *= 2 + self.assertEqual(m1.parameters['g'], m1['g']) + # This allows the module to be executed directly def main(): load_test = unittest.TestLoader()