diff --git a/src/progpy/data_models/lstm_model.py b/src/progpy/data_models/lstm_model.py
index 4e27de4..0449ad4 100644
--- a/src/progpy/data_models/lstm_model.py
+++ b/src/progpy/data_models/lstm_model.py
@@ -439,8 +439,6 @@ def from_data(cls, inputs, outputs, event_states=None, t_met=None, **kwargs):
                 If early stopping is desired. Default is True
             early_stop.cfg (dict):
                 Configuration to pass into early stopping callback (if enabled). See keras documentation (https://keras.io/api/callbacks/early_stopping) for options. E.g., {'patience': 5}
-            workers (int):
-                Number of workers to use when training. One worker indicates no multiprocessing
 
         Returns:
             LSTMStateTransitionModel: Generated Model
@@ -460,7 +458,6 @@ def from_data(cls, inputs, outputs, event_states=None, t_met=None, **kwargs):
             'normalize': True,
             'early_stop': True,
             'early_stop.cfg': {'patience': 3, 'monitor': 'loss'},
-            'workers': 1
         }.copy()  # Copy is needed to avoid updating default
 
         params.update(LSTMStateTransitionModel.default_params)
@@ -498,10 +495,6 @@ def from_data(cls, inputs, outputs, event_states=None, t_met=None, **kwargs):
             raise TypeError(f"epochs must be an integer greater than 0, not {type(params['epochs'])}")
         if params['epochs'] < 1:
             raise ValueError(f"epochs must be greater than 0, got {params['epochs']}")
-        if not isinstance(params['workers'], int):
-            raise TypeError(f"workers must be positive integer, got {type(params['workers'])}")
-        if params['workers'] < 1:
-            raise ValueError(f"workers must be positive integer, got {params['workers']}")
         if np.isscalar(inputs):  # Is scalar (e.g., SimResult)
             inputs = [inputs]
         if np.isscalar(outputs):
@@ -579,7 +572,7 @@ def from_data(cls, inputs, outputs, event_states=None, t_met=None, **kwargs):
             output_data.append(t_all)
 
         model = keras.Model(inputs, outputs)
-        model.compile(optimizer="rmsprop", loss="mse", metrics=["mae"])
+        model.compile(optimizer="rmsprop", loss="mse", metrics=["mae"]*len(outputs))
 
         # Train model
         history = model.fit(
@@ -587,9 +580,7 @@ def from_data(cls, inputs, outputs, event_states=None, t_met=None, **kwargs):
             output_data,
             epochs=params['epochs'],
             callbacks=callbacks,
-            validation_split=params['validation_split'],
-            workers=params['workers'],
-            use_multiprocessing=(params['workers'] > 1))
+            validation_split=params['validation_split'])
 
         # Split model into separate models
         n_state_layers = params['layers'] + 1 + (params['dropout'] > 0) + (params['normalize'])
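
Context on the `metrics` change (not part of the patch): Keras 3 matches a list of metrics against the structure of a model's outputs, so a multi-output `keras.Model` wants one entry per output; `["mae"]*len(outputs)` builds exactly that. A minimal sketch of the pattern on a toy two-output model (the names `out_a`/`out_b` are illustrative, not from progpy):

```python
import numpy as np
import keras

# Toy two-output model standing in for the multi-output network built in from_data
inp = keras.Input(shape=(4,))
hidden = keras.layers.Dense(8, activation="relu")(inp)
out_a = keras.layers.Dense(1, name="out_a")(hidden)
out_b = keras.layers.Dense(1, name="out_b")(hidden)
model = keras.Model(inp, [out_a, out_b])

outputs = [out_a, out_b]
# One "mae" per output; metric i is applied to output i, so the list length
# must match the number of outputs
model.compile(optimizer="rmsprop", loss="mse", metrics=["mae"] * len(outputs))

x = np.random.rand(16, 4)
y = [np.random.rand(16, 1), np.random.rand(16, 1)]
model.fit(x, y, epochs=1, verbose=0)
```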
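On the removed `workers`/`use_multiprocessing` arguments: Keras 3's `Model.fit` no longer accepts them, which is why the `workers` parameter is dropped from the docstring, defaults, and validation as well. If parallel data loading is ever needed again, those knobs now live on the dataset object rather than on `fit()`. A hedged sketch assuming Keras 3's `keras.utils.PyDataset`; `ArrayDataset` is a hypothetical helper, not part of progpy:

```python
import numpy as np
import keras

class ArrayDataset(keras.utils.PyDataset):
    """Hypothetical batching wrapper (not part of progpy)."""
    def __init__(self, x, y, batch_size=32, **kwargs):
        # In Keras 3, PyDataset's constructor accepts workers,
        # use_multiprocessing, and max_queue_size
        super().__init__(**kwargs)
        self.x, self.y, self.batch_size = x, y, batch_size

    def __len__(self):
        # Number of batches per epoch
        return int(np.ceil(len(self.x) / self.batch_size))

    def __getitem__(self, idx):
        lo = idx * self.batch_size
        return self.x[lo:lo + self.batch_size], self.y[lo:lo + self.batch_size]

# Parallelism is configured on the dataset, not on fit():
# ds = ArrayDataset(x, y, workers=4, use_multiprocessing=True)
# model.fit(ds, epochs=params['epochs'], callbacks=callbacks)
```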