Skip to content

Commit

Permalink
Merge branch 'dev' into 152-make-tensor-flow-optional
Browse files Browse the repository at this point in the history
  • Loading branch information
teubert authored Aug 7, 2024
2 parents 93f467e + 1bc324f commit c236657
Showing 1 changed file with 2 additions and 11 deletions.
13 changes: 2 additions & 11 deletions src/progpy/data_models/lstm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,8 +449,6 @@ def from_data(cls, inputs, outputs, event_states=None, t_met=None, **kwargs):
If early stopping is desired. Default is True
early_stop.cfg (dict):
Configuration to pass into early stopping callback (if enabled). See keras documentation (https://keras.io/api/callbacks/early_stopping) for options. E.g., {'patience': 5}
workers (int):
Number of workers to use when training. One worker indicates no multiprocessing
Returns:
LSTMStateTransitionModel: Generated Model
Expand All @@ -470,7 +468,6 @@ def from_data(cls, inputs, outputs, event_states=None, t_met=None, **kwargs):
'normalize': True,
'early_stop': True,
'early_stop.cfg': {'patience': 3, 'monitor': 'loss'},
'workers': 1
}.copy() # Copy is needed to avoid updating default

params.update(LSTMStateTransitionModel.default_params)
Expand Down Expand Up @@ -508,10 +505,6 @@ def from_data(cls, inputs, outputs, event_states=None, t_met=None, **kwargs):
raise TypeError(f"epochs must be an integer greater than 0, not {type(params['epochs'])}")
if params['epochs'] < 1:
raise ValueError(f"epochs must be greater than 0, got {params['epochs']}")
if not isinstance(params['workers'], int):
raise TypeError(f"workers must be positive integer, got {type(params['workers'])}")
if params['workers'] < 1:
raise ValueError(f"workers must be positive integer, got {params['workers']}")
if np.isscalar(inputs): # Is scalar (e.g., SimResult)
inputs = [inputs]
if np.isscalar(outputs):
Expand Down Expand Up @@ -592,17 +585,15 @@ def from_data(cls, inputs, outputs, event_states=None, t_met=None, **kwargs):
output_data.append(t_all)

model = keras.Model(inputs, outputs)
model.compile(optimizer="rmsprop", loss="mse", metrics=["mae"])
model.compile(optimizer="rmsprop", loss="mse", metrics=["mae"]*len(outputs))

# Train model
history = model.fit(
u_all,
output_data,
epochs=params['epochs'],
callbacks=callbacks,
validation_split=params['validation_split'],
workers=params['workers'],
use_multiprocessing=(params['workers'] > 1))
validation_split=params['validation_split'])

# Split model into separate models
n_state_layers = params['layers'] + 1 + (params['dropout'] > 0) + (params['normalize'])
Expand Down

0 comments on commit c236657

Please sign in to comment.