Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Direct Parameters #145

Merged
merged 4 commits into from
Jul 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/01_Simulation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1200,11 +1200,11 @@
"results1 = m.simulate_to_threshold(threshold_keys='impact', dt=0.1, save_freq=0.1)\n",
"fig = results1.outputs.plot(title='default')\n",
"\n",
"m.parameters['throwing_speed'] = 10\n",
"m['throwing_speed'] = 10\n",
"results2 = m.simulate_to_threshold(threshold_keys='impact', dt=0.1, save_freq=0.1)\n",
"fig = results2.outputs.plot(title='slow')\n",
"\n",
"m.parameters['throwing_speed'] = 80\n",
"m['throwing_speed'] = 80\n",
"results3 = m.simulate_to_threshold(threshold_keys='impact', dt=0.1, save_freq=0.1)\n",
"fig = results3.outputs.plot(title='fast')"
]
Expand Down
36 changes: 18 additions & 18 deletions examples/02_param_est.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@
"# Printing state before\n",
"print('Model configuration before')\n",
"for key in keys:\n",
" print(\"-\", key, m.parameters[key])\n",
" print(\"-\", key, m[key])\n",
"print(' Error: ', m.calc_error(times, inputs, outputs, dt=0.1))"
]
},
Expand Down Expand Up @@ -176,7 +176,7 @@
"source": [
"print('\\nOptimized configuration')\n",
"for key in keys:\n",
" print(\"-\", key, m.parameters[key])\n",
" print(\"-\", key, m[key])\n",
"print(' Error: ', m.calc_error(times, inputs, outputs, dt=0.1))"
]
},
Expand Down Expand Up @@ -223,7 +223,7 @@
"m.estimate_params(times = times, inputs = inputs, outputs = outputs, keys = keys, dt=0.1, tol=1e-6)\n",
"print('\\nOptimized configuration')\n",
"for key in keys:\n",
" print(\"-\", key, m.parameters[key])\n",
" print(\"-\", key, m[key])\n",
"print(' Error: ', m.calc_error(times, inputs, outputs, dt=0.1))"
]
},
Expand Down Expand Up @@ -260,14 +260,14 @@
"metadata": {},
"outputs": [],
"source": [
"m.parameters['thrower_height'] = 3.1\n",
"m.parameters['throwing_speed'] = 29\n",
"m['thrower_height'] = 3.1\n",
"m['throwing_speed'] = 29\n",
"\n",
"# Using MAE, or Mean Absolute Error instead of the default Mean Squared Error.\n",
"m.estimate_params(times = times, inputs = inputs, outputs = outputs, keys = keys, dt=0.1, tol=1e-9, error_method='MAX_E')\n",
"print('\\nOptimized configuration')\n",
"for key in keys:\n",
" print(\"-\", key, m.parameters[key])\n",
" print(\"-\", key, m[key])\n",
"print(' Error: ', m.calc_error(times, inputs, outputs, dt=0.1, method='MAX_E'))"
]
},
Expand Down Expand Up @@ -309,9 +309,9 @@
"results = m.simulate_to_threshold(save_freq=0.5, dt=('auto', 0.1))\n",
"\n",
"# Resetting parameters to their incorrectly set values.\n",
"m.parameters['thrower_height'] = 20\n",
"m.parameters['throwing_speed'] = 3.1\n",
"m.parameters['g'] = 15\n",
"m['thrower_height'] = 20\n",
"m['throwing_speed'] = 3.1\n",
"m['g'] = 15\n",
"keys = ['thrower_height', 'throwing_speed', 'g']"
]
},
Expand All @@ -324,7 +324,7 @@
"m.estimate_params(times = results.times, inputs = results.inputs, outputs = results.outputs, keys = keys)\n",
"print('\\nOptimized configuration')\n",
"for key in keys:\n",
" print(\"-\", key, m.parameters[key])\n",
" print(\"-\", key, m[key])\n",
"print(' Error: ', m.calc_error(results.times, results.inputs, results.outputs))"
]
},
Expand Down Expand Up @@ -353,7 +353,7 @@
"def AME(m, keys):\n",
" error = 0\n",
" for key in keys:\n",
" error += abs(m.parameters[key] - true_Values.parameters[key])\n",
" error += abs(m[key] - true_Values[key])\n",
" return error"
]
},
Expand Down Expand Up @@ -393,9 +393,9 @@
" results = m.simulate_to_threshold(save_freq=0.5, dt=('auto', 0.1))\n",
" \n",
" # Resetting parameters to their originally incorrectly set values.\n",
" m.parameters['thrower_height'] = 20\n",
" m.parameters['throwing_speed'] = 3.1\n",
" m.parameters['g'] = 15\n",
" m['thrower_height'] = 20\n",
" m['throwing_speed'] = 3.1\n",
" m['g'] = 15\n",
"\n",
" m.estimate_params(times = results.times, inputs = results.inputs, outputs = results.outputs, keys = keys, dt=0.1)\n",
" error = AME(m, ['thrower_height', 'throwing_speed', 'g'])\n",
Expand Down Expand Up @@ -439,9 +439,9 @@
"metadata": {},
"outputs": [],
"source": [
"m.parameters['thrower_height'] = 20\n",
"m.parameters['throwing_speed'] = 3.1\n",
"m.parameters['g'] = 15"
"m['thrower_height'] = 20\n",
"m['throwing_speed'] = 3.1\n",
"m['g'] = 15"
]
},
{
Expand All @@ -461,7 +461,7 @@
"m.estimate_params(times=times, inputs=inputs, outputs=outputs, keys=keys, dt=0.1)\n",
"print('\\nOptimized configuration')\n",
"for key in keys:\n",
" print(\"-\", key, m.parameters[key])\n",
" print(\"-\", key, m[key])\n",
"error = AME(m, ['thrower_height', 'throwing_speed', 'g'])\n",
"print('AME Error: ', error)"
]
Expand Down
26 changes: 13 additions & 13 deletions examples/04_New Models.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,8 @@
"class ThrownObject(ThrownObject):\n",
" def initialize(self, u=None, z=None):\n",
" return self.StateContainer({\n",
" 'x': self.parameters['thrower_height'],\n",
" 'v': self.parameters['throwing_speed']\n",
" 'x': self['thrower_height'],\n",
" 'v': self['throwing_speed']\n",
" })"
]
},
Expand Down Expand Up @@ -237,7 +237,7 @@
" def event_state(self, x): \n",
" x_max = x['x'] + np.square(x['v'])/(9.81*2)\n",
" return {\n",
" 'falling': np.maximum(x['v']/self.parameters['throwing_speed'],0),\n",
" 'falling': np.maximum(x['v']/self['throwing_speed'],0),\n",
" 'impact': np.maximum(x['x']/x_max,0) if x['v'] < 0 else 1\n",
" }"
]
Expand Down Expand Up @@ -415,8 +415,8 @@
"\n",
" def initialize(self, u=None, z=None):\n",
" return self.StateContainer({\n",
" 'x': self.parameters['thrower_height'], # Thrown, so initial altitude is height of thrower\n",
" 'v': self.parameters['throwing_speed'] # Velocity at which the ball is thrown - this guy is a professional baseball pitcher\n",
" 'x': self['thrower_height'], # Thrown, so initial altitude is height of thrower\n",
" 'v': self['throwing_speed'] # Velocity at which the ball is thrown - this guy is a professional baseball pitcher\n",
" })"
]
},
Expand All @@ -440,7 +440,7 @@
" def dx(self, x, u):\n",
" return self.StateContainer({\n",
" 'x': x['v'], \n",
" 'v': self.parameters['g']}) # Acceleration of gravity"
" 'v': self['g']}) # Acceleration of gravity"
]
},
{
Expand Down Expand Up @@ -480,11 +480,11 @@
" \n",
" def event_state(self, x): \n",
" # Use speed and position to estimate maximum height\n",
" x_max = x['x'] + np.square(x['v'])/(-self.parameters['g']*2)\n",
" x_max = x['x'] + np.square(x['v'])/(-self['g']*2)\n",
" # 1 until falling begins\n",
" x_max = np.where(x['v'] > 0, x['x'], x_max)\n",
" return {\n",
" 'falling': max(x['v']/self.parameters['throwing_speed'],0), # Throwing speed is max speed\n",
" 'falling': max(x['v']/self['throwing_speed'],0), # Throwing speed is max speed\n",
" 'impact': max(x['x']/x_max,0) # 1 until falling begins, then it's fraction of height\n",
" }"
]
Expand Down Expand Up @@ -654,7 +654,7 @@
"outputs": [],
"source": [
"obj = ThrownObject_ST()\n",
"print(\"Default Settings:\\n\\tthrower_height: {}\\n\\tthrowing_speed: {}\".format(obj.parameters['thrower_height'], obj.parameters['throwing_speed']))"
"print(\"Default Settings:\\n\\tthrower_height: {}\\n\\tthrowing_speed: {}\".format(obj['thrower_height'], obj['throwing_speed']))"
]
},
{
Expand All @@ -670,8 +670,8 @@
"metadata": {},
"outputs": [],
"source": [
"obj.parameters['thrower_height'] = 1.75 # Our thrower is 1.75 m tall\n",
"print(\"\\nUpdated Settings:\\n\\tthrower_height: {}\\n\\tthowing_speed: {}\".format(obj.parameters['thrower_height'], obj.parameters['throwing_speed']))"
"obj['thrower_height'] = 1.75 # Our thrower is 1.75 m tall\n",
"print(\"\\nUpdated Settings:\\n\\tthrower_height: {}\\n\\tthowing_speed: {}\".format(obj['thrower_height'], obj['throwing_speed']))"
]
},
{
Expand Down Expand Up @@ -743,7 +743,7 @@
" def time_of_event(self, x, *args, **kwargs):\n",
" # calculate time when object hits ground given x['x'] and x['v']\n",
" # 0 = x0 + v0*t - 0.5*g*t^2\n",
" g = self.parameters['g']\n",
" g = self['g']\n",
" t_impact = -(x['v'] + np.sqrt(x['v']*x['v'] - 2*g*x['x']))/g\n",
"\n",
" # 0 = v0 - g*t\n",
Expand Down Expand Up @@ -890,7 +890,7 @@
" def next_state(self, x, u, dt):\n",
"\n",
" A = np.array([[0, 1], [0, 0]]) # State transition matrix\n",
" B = np.array([[0], [self.parameters['g']]]) # Acceleration due to gravity\n",
" B = np.array([[0], [self['g']]]) # Acceleration due to gravity\n",
" x.matrix += (np.matmul(A, x.matrix) + B) * dt\n",
"\n",
" return x"
Expand Down
2 changes: 1 addition & 1 deletion examples/06_Combining_Models.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@
"outputs": [],
"source": [
"import numpy as np\n",
"m_ensemble.parameters['aggregation_method'] = np.median\n",
"m_ensemble['aggregation_method'] = np.median\n",
"\n",
"results_ensemble_median = m_ensemble.simulate_to(t_end, future_loading)\n",
"fig = plt.plot(results_ensemble_median.times, [z['v'] for z in results_ensemble_median.outputs], color='orange', label='ensemble -median')\n",
Expand Down
4 changes: 2 additions & 2 deletions examples/07_State_Estimation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@
"metadata": {},
"outputs": [],
"source": [
"print('Initial thrower height:', m.parameters['thrower_height'])\n",
"print('Initial speed:', m.parameters['throwing_speed'])"
"print('Initial thrower height:', m['thrower_height'])\n",
"print('Initial speed:', m['throwing_speed'])"
]
},
{
Expand Down
6 changes: 6 additions & 0 deletions src/progpy/prognostics_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,12 @@ def __init__(self, data):

self.parameters = PrognosticsModelParameters(self, params, self.param_callbacks)

def __getitem__(self, arg):
return self.parameters[arg]

def __setitem__(self, key, value):
self.parameters[key] = value

def initialize(self, u=None, z=None):
"""
Calculate initial state given inputs and outputs. If not defined for a model, it will return parameters['x0']
Expand Down
15 changes: 15 additions & 0 deletions tests/test_base_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1343,6 +1343,21 @@ def test_parameter_equality(self):
self.assertTrue(m1.parameters == m2.parameters) # Checking to see previous equal statements stay the same
self.assertTrue(m2.parameters == m1.parameters)

def test_direct_params(self):
m1 = LinearThrownObject()
print('test')

# Accessing parameters directly
self.assertEqual(m1.parameters['g'], m1['g'])

# Setting parameters
m1['g'] *= 2 # Doubling
self.assertEqual(m1.parameters['g'], m1['g'])

# Updating from parameters
m1.parameters['g'] *= 2
self.assertEqual(m1.parameters['g'], m1['g'])

# This allows the module to be executed directly
def main():
load_test = unittest.TestLoader()
Expand Down
Loading