Skip to content

Commit

Permalink
Removed workers as argument
Browse files Browse the repository at this point in the history
  • Loading branch information
teubert committed Aug 6, 2024
1 parent b6b76fa commit 0a8f7c1
Showing 1 changed file with 1 addition and 10 deletions.
11 changes: 1 addition & 10 deletions src/progpy/data_models/lstm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,8 +439,6 @@ def from_data(cls, inputs, outputs, event_states=None, t_met=None, **kwargs):
If early stopping is desired. Default is True
early_stop.cfg (dict):
Configuration to pass into early stopping callback (if enabled). See keras documentation (https://keras.io/api/callbacks/early_stopping) for options. E.g., {'patience': 5}
workers (int):
Number of workers to use when training. One worker indicates no multiprocessing
Returns:
LSTMStateTransitionModel: Generated Model
Expand All @@ -460,7 +458,6 @@ def from_data(cls, inputs, outputs, event_states=None, t_met=None, **kwargs):
'normalize': True,
'early_stop': True,
'early_stop.cfg': {'patience': 3, 'monitor': 'loss'},
'workers': 1
}.copy() # Copy is needed to avoid updating default

params.update(LSTMStateTransitionModel.default_params)
Expand Down Expand Up @@ -498,10 +495,6 @@ def from_data(cls, inputs, outputs, event_states=None, t_met=None, **kwargs):
raise TypeError(f"epochs must be an integer greater than 0, not {type(params['epochs'])}")
if params['epochs'] < 1:
raise ValueError(f"epochs must be greater than 0, got {params['epochs']}")
if not isinstance(params['workers'], int):
raise TypeError(f"workers must be positive integer, got {type(params['workers'])}")
if params['workers'] < 1:
raise ValueError(f"workers must be positive integer, got {params['workers']}")
if np.isscalar(inputs): # Is scalar (e.g., SimResult)
inputs = [inputs]
if np.isscalar(outputs):
Expand Down Expand Up @@ -587,9 +580,7 @@ def from_data(cls, inputs, outputs, event_states=None, t_met=None, **kwargs):
output_data,
epochs=params['epochs'],
callbacks=callbacks,
validation_split=params['validation_split'],
workers=params['workers'],
use_multiprocessing=(params['workers'] > 1))
validation_split=params['validation_split'])

# Split model into separate models
n_state_layers = params['layers'] + 1 + (params['dropout'] > 0) + (params['normalize'])
Expand Down

0 comments on commit 0a8f7c1

Please sign in to comment.