Merge pull request #341 from alexbrillant/trainer-train-method-before-mixins

Add Train Method To Trainer
alexbrillant authored May 22, 2020
2 parents 1425d58 + 5d3c0d0 commit 1711721
Showing 3 changed files with 207 additions and 99 deletions.
243 changes: 148 additions & 95 deletions neuraxle/metaopt/auto_ml.py
@@ -410,21 +410,19 @@ class Trainer:
    .. code-block:: python

        trainer = Trainer(
-            callbacks=[],
            epochs=10,
+            callbacks=[EarlyStoppingCallback()],
+            scoring_callback=ScoringCallback(mean_squared_error, higher_score_is_better=False),
+            validation_splitter=ValidationSplitter(test_size=0.15),
            print_func=print
        )

-        repo_trial = trainer.fit(
-            p=p,
-            trial_repository=repo_trial,
-            train_data_container=training_data_container,
-            validation_data_container=validation_data_container,
-            context=context
+        repo_trial = trainer.train(
+            pipeline=pipeline,
+            data_inputs=data_inputs,
+            expected_outputs=expected_outputs
        )

        pipeline = trainer.refit(repo_trial.pipeline, data_container, context)

    .. seealso::
        :class:`AutoML`,
@@ -439,25 +437,120 @@ class Trainer:

    def __init__(
            self,
-            epochs,
-            metrics=None,
-            callbacks=None,
-            print_metrics=True,
-            print_func=None
+            epochs: int,
+            scoring_callback: ScoringCallback,
+            validation_splitter: 'BaseValidationSplitter',
+            callbacks: List[BaseCallback] = None,
+            print_func: Callable = None
    ):
-        self.epochs = epochs
-        if metrics is None:
-            metrics = {}
-        self.metrics = metrics
-        self._initialize_metrics(metrics)
+        self.epochs: int = epochs
+        self.validation_split_function = validation_splitter

-        self.callbacks = CallbackList(callbacks)
+        if callbacks is None:
+            callbacks = []
+        callbacks: List[BaseCallback] = [scoring_callback] + callbacks
+        self.callbacks: CallbackList = CallbackList(callbacks)

        if print_func is None:
            print_func = print

        self.print_func = print_func
-        self.print_metrics = print_metrics

+    def train(self, pipeline: BaseStep, data_inputs, expected_outputs=None) -> Trial:
+        """
+        Train pipeline using the validation splitter.
+        Track training and validation metrics for each epoch.
+        Note: this method is a shortcut for the `execute_trial` method, with less boilerplate code needed; refer to `execute_trial` for full flexibility.
+
+        :param pipeline: pipeline to train on
+        :param data_inputs: data inputs
+        :param expected_outputs: expected outputs to fit on
+        :return: executed trial
+        """
+        validation_splits: List[Tuple[DataContainer, DataContainer]] = self.validation_split_function.split_data_container(
+            DataContainer(data_inputs=data_inputs, expected_outputs=expected_outputs)
+        )
+
+        repo_trial: Trial = Trial(
+            pipeline=pipeline,
+            hyperparams=pipeline.get_hyperparams(),
+            main_metric_name=self.get_main_metric_name()
+        )
+
+        self.execute_trial(
+            pipeline=pipeline,
+            trial_number=1,
+            repo_trial=repo_trial,
+            context=ExecutionContext(),
+            validation_splits=validation_splits,
+            n_trial=1,
+            delete_pipeline_on_completion=False
+        )
+
+        return repo_trial
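For orientation, here is a minimal usage sketch of the new `train` method. This is an editor's illustration, not part of the diff: the import paths, the toy `Pipeline`, and the `MultiplyByN` step are assumptions based on Neuraxle's public API around this release.

    import numpy as np
    from sklearn.metrics import mean_squared_error

    # Assumed import paths (Neuraxle ~0.3.x); adjust to the installed version.
    from neuraxle.metaopt.auto_ml import Trainer, ValidationSplitter
    from neuraxle.metaopt.callbacks import ScoringCallback
    from neuraxle.pipeline import Pipeline
    from neuraxle.steps.numpy import MultiplyByN

    trainer = Trainer(
        epochs=10,
        scoring_callback=ScoringCallback(mean_squared_error, higher_score_is_better=False),
        validation_splitter=ValidationSplitter(test_size=0.15)
    )

    # train() builds the validation splits and a single Trial itself, then
    # delegates to execute_trial() with trial_number=1, n_trial=1, and
    # delete_pipeline_on_completion=False, so the fitted pipelines stay
    # available on the returned trial.
    repo_trial = trainer.train(
        pipeline=Pipeline([MultiplyByN(2)]),
        data_inputs=np.array([0.0, 1.0, 2.0, 3.0]),
        expected_outputs=np.array([0.0, 2.0, 4.0, 6.0])
    )

Because the fitted pipelines are kept, the returned trial can feed straight into `trainer.refit(...)`, as in the class docstring above.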

+    def execute_trial(
+            self,
+            pipeline: BaseStep,
+            trial_number: int,
+            repo_trial: Trial,
+            context: ExecutionContext,
+            validation_splits: List[Tuple[DataContainer, DataContainer]],
+            n_trial: int,
+            delete_pipeline_on_completion: bool = True
+    ):
+        """
+        Train pipeline using the validation splitter.
+        Track training and validation metrics for each epoch.
+
+        :param pipeline: pipeline to train on
+        :param trial_number: trial number
+        :param repo_trial: repo trial
+        :param validation_splits: validation splits
+        :param context: execution context
+        :param n_trial: total number of trials that will be executed
+        :param delete_pipeline_on_completion: whether to delete the fitted pipeline once the trial split completes
+        :return: executed trial split
+        """
+        for training_data_container, validation_data_container in validation_splits:
+            p = copy.deepcopy(pipeline)
+            p.update_hyperparams(repo_trial.hyperparams)
+            repo_trial.set_hyperparams(p.get_hyperparams())
+
+            repo_trial_split: TrialSplit = repo_trial.new_validation_split(
+                pipeline=p,
+                delete_pipeline_on_completion=delete_pipeline_on_completion
+            )
+
+            with repo_trial_split:
+                trial_split_description = _get_trial_split_description(
+                    repo_trial=repo_trial,
+                    repo_trial_split=repo_trial_split,
+                    validation_splits=validation_splits,
+                    trial_number=trial_number,
+                    n_trial=n_trial
+                )
+
+                self.print_func('fitting trial {}'.format(
+                    trial_split_description
+                ))
+
+                repo_trial_split = self.fit_trial_split(
+                    trial_split=repo_trial_split,
+                    train_data_container=training_data_container,
+                    validation_data_container=validation_data_container,
+                    context=context
+                )
+
+                repo_trial_split.set_success()
+
+                self.print_func('success trial {} score: {}'.format(
+                    trial_split_description,
+                    repo_trial_split.get_validation_score()
+                ))
+
+        return repo_trial_split
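By contrast, `execute_trial` exposes everything `train` hides: the trial object, the execution context, the split list, and the trial counters. A sketch of the equivalent manual call, reusing `trainer`, `data_inputs`, and `expected_outputs` from the previous sketch; the `Trial`, `DataContainer`, and `ExecutionContext` import paths are assumptions:

    from neuraxle.base import ExecutionContext
    from neuraxle.data_container import DataContainer
    from neuraxle.metaopt.trial import Trial

    pipeline = Pipeline([MultiplyByN(2)])

    # Build the validation splits yourself instead of letting train() do it.
    validation_splits = trainer.validation_split_function.split_data_container(
        DataContainer(data_inputs=data_inputs, expected_outputs=expected_outputs)
    )

    repo_trial = Trial(
        pipeline=pipeline,
        hyperparams=pipeline.get_hyperparams(),
        main_metric_name=trainer.get_main_metric_name()
    )

    trainer.execute_trial(
        pipeline=pipeline,
        trial_number=0,  # zero-based here; printed as trial 1
        repo_trial=repo_trial,
        context=ExecutionContext(),
        validation_splits=validation_splits,
        n_trial=1,
        delete_pipeline_on_completion=False  # keep fitted pipelines for refit()
    )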

    def fit_trial_split(
            self,
@@ -514,21 +607,6 @@ def refit(self, p: BaseStep, data_container: DataContainer, context: ExecutionContext)

        return p

-    def _initialize_metrics(self, metrics):
-        """
-        Initialize metrics results dict for train, and validation using the metrics function dict.
-
-        :param metrics: metrics function dict
-        :return:
-        """
-        self.metrics_results_train = {}
-        self.metrics_results_validation = {}
-
-        for m in metrics:
-            self.metrics_results_train[m] = []
-            self.metrics_results_validation[m] = []

    def get_main_metric_name(self) -> str:
        """
        Get main metric name.
@@ -556,7 +634,6 @@ class AutoML(ForceHandleOnlyMixin, BaseStep):
                MetricCallback('mse', metric_function=mean_squared_error, higher_score_is_better=False)
            ],
            refit_trial=True,
-            print_metrics=False,
            cache_folder_when_no_handle=str(tmpdir)
        )
@@ -595,7 +672,7 @@ def __init__(
        BaseStep.__init__(self)
        ForceHandleOnlyMixin.__init__(self, cache_folder=cache_folder_when_no_handle)

-        self.validation_split_function: BaseValidationSplitter = validation_splitter
+        self.validation_splitter: BaseValidationSplitter = validation_splitter

        if print_func is None:
            print_func = print
@@ -619,17 +696,14 @@ def __init__(

        self.refit_scoring_function: Callable = refit_scoring_function

-        if callbacks is None:
-            callbacks = []

-        callbacks: List[BaseCallback] = [scoring_callback] + callbacks

        self.refit_trial: bool = refit_trial

        self.trainer = Trainer(
-            callbacks=callbacks,
            epochs=epochs,
-            print_func=self.print_func
+            scoring_callback=scoring_callback,
+            callbacks=callbacks,
+            print_func=self.print_func,
+            validation_splitter=validation_splitter
        )

    def _fit_data_container(self, data_container: DataContainer, context: ExecutionContext) -> 'BaseStep':
@@ -643,7 +717,7 @@ def _fit_data_container(self, data_container: DataContainer, context: ExecutionContext) -> 'BaseStep':
        :return: self
        """
-        validation_splits = self.validation_split_function.split_data_container(data_container)
+        validation_splits = self.validation_splitter.split_data_container(data_container)

        for trial_number in range(self.n_trial):
            try:
@@ -657,11 +731,13 @@ def _fit_data_container(self, data_container: DataContainer, context: ExecutionContext) -> 'BaseStep':
                with self.hyperparams_repository.new_trial(auto_ml_data) as repo_trial:
                    self.print_func('\ntrial {}/{}'.format(trial_number + 1, self.n_trial))

-                    repo_trial_split = self._execute_trial(
+                    repo_trial_split = self.trainer.execute_trial(
+                        pipeline=self.pipeline,
                        trial_number=trial_number,
                        repo_trial=repo_trial,
                        context=context,
-                        validation_splits=validation_splits
+                        validation_splits=validation_splits,
+                        n_trial=self.n_trial
                    )
            except (SystemError, SystemExit, EOFError, KeyboardInterrupt) as error:
                track = traceback.format_exc()
@@ -670,8 +746,13 @@ def _fit_data_container(self, data_container: DataContainer, context: ExecutionContext) -> 'BaseStep':
                raise error
            except Exception:
                track = traceback.format_exc()
-                self.print_func('failed trial {}'.format(
-                    self._get_trial_split_description(repo_trial, repo_trial_split, validation_splits, trial_number)))
+                self.print_func('failed trial {}'.format(_get_trial_split_description(
+                    repo_trial=repo_trial,
+                    repo_trial_split=repo_trial_split,
+                    validation_splits=validation_splits,
+                    trial_number=trial_number,
+                    n_trial=self.n_trial
+                )))
                self.print_func(track)
            finally:
                repo_trial.update_final_trial_status()
@@ -694,51 +775,6 @@ def _fit_data_container(self, data_container: DataContainer, context: ExecutionContext) -> 'BaseStep':

        return self

-    def _execute_trial(self, trial_number: int, repo_trial: Trial, context: ExecutionContext,
-                       validation_splits: List[Tuple[DataContainer, DataContainer]]):
-        for training_data_container, validation_data_container in validation_splits:
-            p = copy.deepcopy(self.pipeline)
-            p.update_hyperparams(repo_trial.hyperparams)
-            repo_trial.set_hyperparams(p.get_hyperparams())
-
-            with repo_trial.new_validation_split(p) as repo_trial_split:
-                trial_split_description = self._get_trial_split_description(
-                    repo_trial=repo_trial,
-                    repo_trial_split=repo_trial_split,
-                    validation_splits=validation_splits,
-                    trial_number=trial_number
-                )
-
-                self.print_func('fitting trial {}'.format(
-                    trial_split_description
-                ))
-
-                repo_trial_split = self.trainer.fit_trial_split(
-                    trial_split=repo_trial_split,
-                    train_data_container=training_data_container,
-                    validation_data_container=validation_data_container,
-                    context=context
-                )
-
-                repo_trial_split.set_success()
-
-                self.print_func('success trial {} score: {}'.format(
-                    trial_split_description,
-                    repo_trial_split.get_validation_score()
-                ))
-
-        return repo_trial_split
-
-    def _get_trial_split_description(self, repo_trial, repo_trial_split, validation_splits, trial_number):
-        trial_split_description = '{}/{} split {}/{}\nhyperparams: {}\n'.format(
-            trial_number + 1,
-            self.n_trial,
-            repo_trial_split.split_number + 1,
-            len(validation_splits),
-            json.dumps(repo_trial.hyperparams, sort_keys=True, indent=4)
-        )
-        return trial_split_description
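With the per-split loop now delegated to `Trainer.execute_trial`, the `AutoML` step is driven the same way as before. A hedged end-to-end sketch, reusing the imports from the sketches above and using only constructor arguments visible in this diff and its docstring, assuming the remaining arguments (such as `hyperparams_repository`) have usable defaults:

    auto_ml = AutoML(
        pipeline=Pipeline([MultiplyByN(2)]),
        n_trial=10,
        epochs=10,
        scoring_callback=ScoringCallback(mean_squared_error, higher_score_is_better=False),
        validation_splitter=ValidationSplitter(test_size=0.15),
        refit_trial=True
    )

    # AutoML is a BaseStep: fit() runs the n_trial loop shown above and,
    # with refit_trial=True, refits the best pipeline on the full data.
    auto_ml = auto_ml.fit(data_inputs, expected_outputs)
    best_pipeline = auto_ml.get_best_model()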

    def get_best_model(self):
        """
        Get best model using the hyperparams repository.
@@ -769,6 +805,23 @@ def _load_virgin_model(self, hyperparams: HyperparameterSamples) -> BaseStep:
        return copy.deepcopy(self.pipeline).update_hyperparams(hyperparams)


+def _get_trial_split_description(
+        repo_trial: Trial,
+        repo_trial_split: TrialSplit,
+        validation_splits: List[Tuple[DataContainer, DataContainer]],
+        trial_number: int,
+        n_trial: int
+):
+    trial_split_description = '{}/{} split {}/{}\nhyperparams: {}\n'.format(
+        trial_number + 1,
+        n_trial,
+        repo_trial_split.split_number + 1,
+        len(validation_splits),
+        json.dumps(repo_trial.hyperparams, sort_keys=True, indent=4)
+    )
+    return trial_split_description
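For illustration, given the format string above, the description printed for the first of two splits of trial 3 out of 10 would read roughly as follows (the hyperparameter key is hypothetical):

    3/10 split 1/2
    hyperparams: {
        "MultiplyByN__multiply_by": 2
    }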


class AutoMLContainer:
    """
    Data object for auto ml.