Skip to content

Commit

Permalink
setting parameters directly
Browse files Browse the repository at this point in the history
  • Loading branch information
teubert committed Jul 5, 2024
1 parent 31af475 commit 60e637f
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 34 deletions.
4 changes: 2 additions & 2 deletions examples/01_Simulation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1200,11 +1200,11 @@
"results1 = m.simulate_to_threshold(threshold_keys='impact', dt=0.1, save_freq=0.1)\n",
"fig = results1.outputs.plot(title='default')\n",
"\n",
"m.parameters['throwing_speed'] = 10\n",
"m['throwing_speed'] = 10\n",
"results2 = m.simulate_to_threshold(threshold_keys='impact', dt=0.1, save_freq=0.1)\n",
"fig = results2.outputs.plot(title='slow')\n",
"\n",
"m.parameters['throwing_speed'] = 80\n",
"m['throwing_speed'] = 80\n",
"results3 = m.simulate_to_threshold(threshold_keys='impact', dt=0.1, save_freq=0.1)\n",
"fig = results3.outputs.plot(title='fast')"
]
Expand Down
36 changes: 18 additions & 18 deletions examples/02_param_est.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@
"# Printing state before\n",
"print('Model configuration before')\n",
"for key in keys:\n",
" print(\"-\", key, m.parameters[key])\n",
" print(\"-\", key, m[key])\n",
"print(' Error: ', m.calc_error(times, inputs, outputs, dt=0.1))"
]
},
Expand Down Expand Up @@ -176,7 +176,7 @@
"source": [
"print('\\nOptimized configuration')\n",
"for key in keys:\n",
" print(\"-\", key, m.parameters[key])\n",
" print(\"-\", key, m[key])\n",
"print(' Error: ', m.calc_error(times, inputs, outputs, dt=0.1))"
]
},
Expand Down Expand Up @@ -223,7 +223,7 @@
"m.estimate_params(times = times, inputs = inputs, outputs = outputs, keys = keys, dt=0.1, tol=1e-6)\n",
"print('\\nOptimized configuration')\n",
"for key in keys:\n",
" print(\"-\", key, m.parameters[key])\n",
" print(\"-\", key, m[key])\n",
"print(' Error: ', m.calc_error(times, inputs, outputs, dt=0.1))"
]
},
Expand Down Expand Up @@ -260,14 +260,14 @@
"metadata": {},
"outputs": [],
"source": [
"m.parameters['thrower_height'] = 3.1\n",
"m.parameters['throwing_speed'] = 29\n",
"m['thrower_height'] = 3.1\n",
"m['throwing_speed'] = 29\n",
"\n",
"# Using MAE, or Mean Absolute Error instead of the default Mean Squared Error.\n",
"m.estimate_params(times = times, inputs = inputs, outputs = outputs, keys = keys, dt=0.1, tol=1e-9, error_method='MAX_E')\n",
"print('\\nOptimized configuration')\n",
"for key in keys:\n",
" print(\"-\", key, m.parameters[key])\n",
" print(\"-\", key, m[key])\n",
"print(' Error: ', m.calc_error(times, inputs, outputs, dt=0.1, method='MAX_E'))"
]
},
Expand Down Expand Up @@ -309,9 +309,9 @@
"results = m.simulate_to_threshold(save_freq=0.5, dt=('auto', 0.1))\n",
"\n",
"# Resetting parameters to their incorrectly set values.\n",
"m.parameters['thrower_height'] = 20\n",
"m.parameters['throwing_speed'] = 3.1\n",
"m.parameters['g'] = 15\n",
"m['thrower_height'] = 20\n",
"m['throwing_speed'] = 3.1\n",
"m['g'] = 15\n",
"keys = ['thrower_height', 'throwing_speed', 'g']"
]
},
Expand All @@ -324,7 +324,7 @@
"m.estimate_params(times = results.times, inputs = results.inputs, outputs = results.outputs, keys = keys)\n",
"print('\\nOptimized configuration')\n",
"for key in keys:\n",
" print(\"-\", key, m.parameters[key])\n",
" print(\"-\", key, m[key])\n",
"print(' Error: ', m.calc_error(results.times, results.inputs, results.outputs))"
]
},
Expand Down Expand Up @@ -353,7 +353,7 @@
"def AME(m, keys):\n",
" error = 0\n",
" for key in keys:\n",
" error += abs(m.parameters[key] - true_Values.parameters[key])\n",
" error += abs(m[key] - true_Values.parameters[key])\n",
" return error"
]
},
Expand Down Expand Up @@ -393,9 +393,9 @@
" results = m.simulate_to_threshold(save_freq=0.5, dt=('auto', 0.1))\n",
" \n",
" # Resetting parameters to their originally incorrectly set values.\n",
" m.parameters['thrower_height'] = 20\n",
" m.parameters['throwing_speed'] = 3.1\n",
" m.parameters['g'] = 15\n",
" m['thrower_height'] = 20\n",
" m['throwing_speed'] = 3.1\n",
" m['g'] = 15\n",
"\n",
" m.estimate_params(times = results.times, inputs = results.inputs, outputs = results.outputs, keys = keys, dt=0.1)\n",
" error = AME(m, ['thrower_height', 'throwing_speed', 'g'])\n",
Expand Down Expand Up @@ -439,9 +439,9 @@
"metadata": {},
"outputs": [],
"source": [
"m.parameters['thrower_height'] = 20\n",
"m.parameters['throwing_speed'] = 3.1\n",
"m.parameters['g'] = 15"
"m['thrower_height'] = 20\n",
"m['throwing_speed'] = 3.1\n",
"m['g'] = 15"
]
},
{
Expand All @@ -461,7 +461,7 @@
"m.estimate_params(times=times, inputs=inputs, outputs=outputs, keys=keys, dt=0.1)\n",
"print('\\nOptimized configuration')\n",
"for key in keys:\n",
" print(\"-\", key, m.parameters[key])\n",
" print(\"-\", key, m[key])\n",
"error = AME(m, ['thrower_height', 'throwing_speed', 'g'])\n",
"print('AME Error: ', error)"
]
Expand Down
26 changes: 13 additions & 13 deletions examples/04_New Models.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,8 @@
"class ThrownObject(ThrownObject):\n",
" def initialize(self, u=None, z=None):\n",
" return self.StateContainer({\n",
" 'x': self.parameters['thrower_height'],\n",
" 'v': self.parameters['throwing_speed']\n",
" 'x': self['thrower_height'],\n",
" 'v': self['throwing_speed']\n",
" })"
]
},
Expand Down Expand Up @@ -237,7 +237,7 @@
" def event_state(self, x): \n",
" x_max = x['x'] + np.square(x['v'])/(9.81*2)\n",
" return {\n",
" 'falling': np.maximum(x['v']/self.parameters['throwing_speed'],0),\n",
" 'falling': np.maximum(x['v']/self['throwing_speed'],0),\n",
" 'impact': np.maximum(x['x']/x_max,0) if x['v'] < 0 else 1\n",
" }"
]
Expand Down Expand Up @@ -415,8 +415,8 @@
"\n",
" def initialize(self, u=None, z=None):\n",
" return self.StateContainer({\n",
" 'x': self.parameters['thrower_height'], # Thrown, so initial altitude is height of thrower\n",
" 'v': self.parameters['throwing_speed'] # Velocity at which the ball is thrown - this guy is a professional baseball pitcher\n",
" 'x': self['thrower_height'], # Thrown, so initial altitude is height of thrower\n",
" 'v': self['throwing_speed'] # Velocity at which the ball is thrown - this guy is a professional baseball pitcher\n",
" })"
]
},
Expand All @@ -440,7 +440,7 @@
" def dx(self, x, u):\n",
" return self.StateContainer({\n",
" 'x': x['v'], \n",
" 'v': self.parameters['g']}) # Acceleration of gravity"
" 'v': self['g']}) # Acceleration of gravity"
]
},
{
Expand Down Expand Up @@ -480,11 +480,11 @@
" \n",
" def event_state(self, x): \n",
" # Use speed and position to estimate maximum height\n",
" x_max = x['x'] + np.square(x['v'])/(-self.parameters['g']*2)\n",
" x_max = x['x'] + np.square(x['v'])/(-self['g']*2)\n",
" # 1 until falling begins\n",
" x_max = np.where(x['v'] > 0, x['x'], x_max)\n",
" return {\n",
" 'falling': max(x['v']/self.parameters['throwing_speed'],0), # Throwing speed is max speed\n",
" 'falling': max(x['v']/self['throwing_speed'],0), # Throwing speed is max speed\n",
" 'impact': max(x['x']/x_max,0) # 1 until falling begins, then it's fraction of height\n",
" }"
]
Expand Down Expand Up @@ -654,7 +654,7 @@
"outputs": [],
"source": [
"obj = ThrownObject_ST()\n",
"print(\"Default Settings:\\n\\tthrower_height: {}\\n\\tthrowing_speed: {}\".format(obj.parameters['thrower_height'], obj.parameters['throwing_speed']))"
"print(\"Default Settings:\\n\\tthrower_height: {}\\n\\tthrowing_speed: {}\".format(obj['thrower_height'], obj['throwing_speed']))"
]
},
{
Expand All @@ -670,8 +670,8 @@
"metadata": {},
"outputs": [],
"source": [
"obj.parameters['thrower_height'] = 1.75 # Our thrower is 1.75 m tall\n",
"print(\"\\nUpdated Settings:\\n\\tthrower_height: {}\\n\\tthowing_speed: {}\".format(obj.parameters['thrower_height'], obj.parameters['throwing_speed']))"
"obj['thrower_height'] = 1.75 # Our thrower is 1.75 m tall\n",
"print(\"\\nUpdated Settings:\\n\\tthrower_height: {}\\n\\tthowing_speed: {}\".format(obj['thrower_height'], obj['throwing_speed']))"
]
},
{
Expand Down Expand Up @@ -743,7 +743,7 @@
" def time_of_event(self, x, *args, **kwargs):\n",
" # calculate time when object hits ground given x['x'] and x['v']\n",
" # 0 = x0 + v0*t - 0.5*g*t^2\n",
" g = self.parameters['g']\n",
" g = self['g']\n",
" t_impact = -(x['v'] + np.sqrt(x['v']*x['v'] - 2*g*x['x']))/g\n",
"\n",
" # 0 = v0 - g*t\n",
Expand Down Expand Up @@ -890,7 +890,7 @@
" def next_state(self, x, u, dt):\n",
"\n",
" A = np.array([[0, 1], [0, 0]]) # State transition matrix\n",
" B = np.array([[0], [self.parameters['g']]]) # Acceleration due to gravity\n",
" B = np.array([[0], [self['g']]]) # Acceleration due to gravity\n",
" x.matrix += (np.matmul(A, x.matrix) + B) * dt\n",
"\n",
" return x"
Expand Down
2 changes: 1 addition & 1 deletion examples/06_Combining_Models.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@
"outputs": [],
"source": [
"import numpy as np\n",
"m_ensemble.parameters['aggregation_method'] = np.median\n",
"m_ensemble['aggregation_method'] = np.median\n",
"\n",
"results_ensemble_median = m_ensemble.simulate_to(t_end, future_loading)\n",
"fig = plt.plot(results_ensemble_median.times, [z['v'] for z in results_ensemble_median.outputs], color='orange', label='ensemble -median')\n",
Expand Down
6 changes: 6 additions & 0 deletions src/progpy/prognostics_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,12 @@ def __init__(self, data):

self.parameters = PrognosticsModelParameters(self, params, self.param_callbacks)

def __getitem__(self, arg):
return self.parameters[arg]

def __setitem__(self, key, value):
self.parameters[key] = value

def initialize(self, u=None, z=None):
"""
Calculate initial state given inputs and outputs. If not defined for a model, it will return parameters['x0']
Expand Down

0 comments on commit 60e637f

Please sign in to comment.