Skip to content

Commit 1711721

Browse files
authored
Merge pull request #341 from alexbrillant/trainer-train-method-before-mixins
Add Train Method To Trainer
2 parents 1425d58 + 5d3c0d0 commit 1711721

File tree

3 files changed

+207
-99
lines changed

3 files changed

+207
-99
lines changed

neuraxle/metaopt/auto_ml.py

Lines changed: 148 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -410,21 +410,19 @@ class Trainer:
410410
.. code-block:: python
411411
412412
trainer = Trainer(
413-
callbacks=[],
414413
epochs=10,
414+
callbacks=[EarlyStoppingCallback()],
415+
scoring_callback=ScoringCallback(mean_squared_error, higher_score_is_better=False),
416+
validation_splitter=ValidationSplitter(test_size=0.15),
415417
print_func=print
416418
)
417419
418-
repo_trial = trainer.fit(
419-
p=p,
420-
trial_repository=repo_trial,
421-
train_data_container=training_data_container,
422-
validation_data_container=validation_data_container,
423-
context=context
420+
repo_trial = trainer.train(
421+
pipeline=pipeline,
422+
data_inputs=data_inputs,
423+
expected_outputs=expected_outputs
424424
)
425425
426-
pipeline = trainer.refit(repo_trial.pipeline, data_container, context)
427-
428426
429427
.. seealso::
430428
:class:`AutoML`,
@@ -439,25 +437,120 @@ class Trainer:
439437

440438
def __init__(
        self,
        epochs: int,
        scoring_callback: ScoringCallback,
        validation_splitter: 'BaseValidationSplitter',
        callbacks: List[BaseCallback] = None,
        print_func: Callable = None
):
    """
    Create a trainer that fits a pipeline over validation splits.

    :param epochs: number of epochs to fit each trial split for
    :param scoring_callback: main scoring callback, always executed before the others
    :param validation_splitter: splitter producing (train, validation) data containers
    :param callbacks: optional extra callbacks, executed after the scoring callback
    :param print_func: logging function; defaults to the built-in ``print``
    """
    self.epochs: int = epochs
    # NOTE(review): attribute is named `validation_split_function` while the
    # constructor argument is `validation_splitter` — `train` reads this exact
    # attribute name, so keep them in sync if renaming.
    self.validation_split_function = validation_splitter

    # The scoring callback always runs first; user-supplied callbacks follow.
    extra_callbacks: List[BaseCallback] = [] if callbacks is None else callbacks
    self.callbacks: CallbackList = CallbackList([scoring_callback] + extra_callbacks)

    self.print_func = print if print_func is None else print_func
458+
459+
def train(self, pipeline: BaseStep, data_inputs, expected_outputs=None) -> Trial:
    """
    Train the pipeline on every (train, validation) split produced by the
    validation splitter, tracking training and validation metrics each epoch.

    This is a convenience shortcut around :func:`execute_trial` with less
    boilerplate; refer to `execute_trial` for full flexibility.

    :param pipeline: pipeline to train on
    :param data_inputs: data inputs
    :param expected_outputs: expected outputs to fit on
    :return: the executed trial
    """
    # A single, self-contained trial: trial 1 of 1, with a fresh execution context.
    repo_trial: Trial = Trial(
        pipeline=pipeline,
        hyperparams=pipeline.get_hyperparams(),
        main_metric_name=self.get_main_metric_name()
    )

    splits: List[Tuple[DataContainer, DataContainer]] = \
        self.validation_split_function.split_data_container(
            DataContainer(data_inputs=data_inputs, expected_outputs=expected_outputs)
        )

    # delete_pipeline_on_completion=False keeps the fitted pipeline on the
    # trial so callers can inspect or refit it afterwards.
    self.execute_trial(
        pipeline=pipeline,
        trial_number=1,
        repo_trial=repo_trial,
        context=ExecutionContext(),
        validation_splits=splits,
        n_trial=1,
        delete_pipeline_on_completion=False
    )

    return repo_trial
492+
493+
def execute_trial(
        self,
        pipeline: BaseStep,
        trial_number: int,
        repo_trial: Trial,
        context: ExecutionContext,
        validation_splits: List[Tuple[DataContainer, DataContainer]],
        n_trial: int,
        delete_pipeline_on_completion: bool = True
):
    """
    Fit the pipeline on each validation split, tracking training and
    validation metrics for every epoch.

    :param pipeline: pipeline to train on
    :param trial_number: trial number
    :param repo_trial: repo trial
    :param context: execution context
    :param validation_splits: list of (train, validation) data container pairs
    :param n_trial: total number of trials that will be executed
    :param delete_pipeline_on_completion: whether to delete the pipeline once the split completes
    :return: the last executed trial split
    """
    for train_data_container, val_data_container in validation_splits:
        # Deep-copy so every split starts from an unfitted pipeline.
        split_pipeline = copy.deepcopy(pipeline)
        split_pipeline.update_hyperparams(repo_trial.hyperparams)
        repo_trial.set_hyperparams(split_pipeline.get_hyperparams())

        repo_trial_split: TrialSplit = repo_trial.new_validation_split(
            pipeline=split_pipeline,
            delete_pipeline_on_completion=delete_pipeline_on_completion
        )

        # The context manager marks the split failed on exception.
        with repo_trial_split:
            description = _get_trial_split_description(
                repo_trial=repo_trial,
                repo_trial_split=repo_trial_split,
                validation_splits=validation_splits,
                trial_number=trial_number,
                n_trial=n_trial
            )

            self.print_func('fitting trial {}'.format(description))

            repo_trial_split = self.fit_trial_split(
                trial_split=repo_trial_split,
                train_data_container=train_data_container,
                validation_data_container=val_data_container,
                context=context
            )

            repo_trial_split.set_success()

            self.print_func('success trial {} score: {}'.format(
                description,
                repo_trial_split.get_validation_score()
            ))

    # NOTE(review): returns only the LAST iteration's split, and raises
    # NameError when `validation_splits` is empty — confirm callers always
    # provide at least one split.
    return repo_trial_split
461554

462555
def fit_trial_split(
463556
self,
@@ -514,21 +607,6 @@ def refit(self, p: BaseStep, data_container: DataContainer, context: ExecutionCo
514607

515608
return p
516609

517-
def _initialize_metrics(self, metrics):
518-
"""
519-
Initialize metrics results dict for train, and validation using the metrics function dict.
520-
521-
:param metrics: metrics function dict
522-
523-
:return:
524-
"""
525-
self.metrics_results_train = {}
526-
self.metrics_results_validation = {}
527-
528-
for m in metrics:
529-
self.metrics_results_train[m] = []
530-
self.metrics_results_validation[m] = []
531-
532610
def get_main_metric_name(self) -> str:
533611
"""
534612
Get main metric name.
@@ -556,7 +634,6 @@ class AutoML(ForceHandleOnlyMixin, BaseStep):
556634
MetricCallback('mse', metric_function=mean_squared_error, higher_score_is_better=False)
557635
],
558636
refit_trial=True,
559-
print_metrics=False,
560637
cache_folder_when_no_handle=str(tmpdir)
561638
)
562639
@@ -595,7 +672,7 @@ def __init__(
595672
BaseStep.__init__(self)
596673
ForceHandleOnlyMixin.__init__(self, cache_folder=cache_folder_when_no_handle)
597674

598-
self.validation_split_function: BaseValidationSplitter = validation_splitter
675+
self.validation_splitter: BaseValidationSplitter = validation_splitter
599676

600677
if print_func is None:
601678
print_func = print
@@ -619,17 +696,14 @@ def __init__(
619696

620697
self.refit_scoring_function: Callable = refit_scoring_function
621698

622-
if callbacks is None:
623-
callbacks = []
624-
625-
callbacks: List[BaseCallback] = [scoring_callback] + callbacks
626-
627699
self.refit_trial: bool = refit_trial
628700

629701
self.trainer = Trainer(
630-
callbacks=callbacks,
631702
epochs=epochs,
632-
print_func=self.print_func
703+
scoring_callback=scoring_callback,
704+
callbacks=callbacks,
705+
print_func=self.print_func,
706+
validation_splitter=validation_splitter
633707
)
634708

635709
def _fit_data_container(self, data_container: DataContainer, context: ExecutionContext) -> 'BaseStep':
@@ -643,7 +717,7 @@ def _fit_data_container(self, data_container: DataContainer, context: ExecutionC
643717
644718
:return: self
645719
"""
646-
validation_splits = self.validation_split_function.split_data_container(data_container)
720+
validation_splits = self.validation_splitter.split_data_container(data_container)
647721

648722
for trial_number in range(self.n_trial):
649723
try:
@@ -657,11 +731,13 @@ def _fit_data_container(self, data_container: DataContainer, context: ExecutionC
657731
with self.hyperparams_repository.new_trial(auto_ml_data) as repo_trial:
658732
self.print_func('\ntrial {}/{}'.format(trial_number + 1, self.n_trial))
659733

660-
repo_trial_split = self._execute_trial(
734+
repo_trial_split = self.trainer.execute_trial(
735+
pipeline=self.pipeline,
661736
trial_number=trial_number,
662737
repo_trial=repo_trial,
663738
context=context,
664-
validation_splits=validation_splits
739+
validation_splits=validation_splits,
740+
n_trial=self.n_trial
665741
)
666742
except (SystemError, SystemExit, EOFError, KeyboardInterrupt) as error:
667743
track = traceback.format_exc()
@@ -670,8 +746,13 @@ def _fit_data_container(self, data_container: DataContainer, context: ExecutionC
670746
raise error
671747
except Exception:
672748
track = traceback.format_exc()
673-
self.print_func('failed trial {}'.format(
674-
self._get_trial_split_description(repo_trial, repo_trial_split, validation_splits, trial_number)))
749+
self.print_func('failed trial {}'.format(_get_trial_split_description(
750+
repo_trial=repo_trial,
751+
repo_trial_split=repo_trial_split,
752+
validation_splits=validation_splits,
753+
trial_number=trial_number,
754+
n_trial=self.n_trial
755+
)))
675756
self.print_func(track)
676757
finally:
677758
repo_trial.update_final_trial_status()
@@ -694,51 +775,6 @@ def _fit_data_container(self, data_container: DataContainer, context: ExecutionC
694775

695776
return self
696777

697-
def _execute_trial(self, trial_number: int, repo_trial: Trial, context: ExecutionContext,
698-
validation_splits: List[Tuple[DataContainer, DataContainer]]):
699-
for training_data_container, validation_data_container in validation_splits:
700-
p = copy.deepcopy(self.pipeline)
701-
p.update_hyperparams(repo_trial.hyperparams)
702-
repo_trial.set_hyperparams(p.get_hyperparams())
703-
704-
with repo_trial.new_validation_split(p) as repo_trial_split:
705-
trial_split_description = self._get_trial_split_description(
706-
repo_trial=repo_trial,
707-
repo_trial_split=repo_trial_split,
708-
validation_splits=validation_splits,
709-
trial_number=trial_number
710-
)
711-
712-
self.print_func('fitting trial {}'.format(
713-
trial_split_description
714-
))
715-
716-
repo_trial_split = self.trainer.fit_trial_split(
717-
trial_split=repo_trial_split,
718-
train_data_container=training_data_container,
719-
validation_data_container=validation_data_container,
720-
context=context
721-
)
722-
723-
repo_trial_split.set_success()
724-
725-
self.print_func('success trial {} score: {}'.format(
726-
trial_split_description,
727-
repo_trial_split.get_validation_score()
728-
))
729-
730-
return repo_trial_split
731-
732-
def _get_trial_split_description(self, repo_trial, repo_trial_split, validation_splits, trial_number):
733-
trial_split_description = '{}/{} split {}/{}\nhyperparams: {}\n'.format(
734-
trial_number + 1,
735-
self.n_trial,
736-
repo_trial_split.split_number + 1,
737-
len(validation_splits),
738-
json.dumps(repo_trial.hyperparams, sort_keys=True, indent=4)
739-
)
740-
return trial_split_description
741-
742778
def get_best_model(self):
743779
"""
744780
Get best model using the hyperparams repository.
@@ -769,6 +805,23 @@ def _load_virgin_model(self, hyperparams: HyperparameterSamples) -> BaseStep:
769805
return copy.deepcopy(self.pipeline).update_hyperparams(hyperparams)
770806

771807

808+
def _get_trial_split_description(
        repo_trial: Trial,
        repo_trial_split: TrialSplit,
        validation_splits: List[Tuple[DataContainer, DataContainer]],
        trial_number: int,
        n_trial: int
):
    """
    Build the 'trial x/y split i/j' header used in the trainer's log messages,
    followed by the trial hyperparams serialized as indented, key-sorted JSON.

    :param repo_trial: trial whose hyperparams are printed
    :param repo_trial_split: split whose (0-based) split_number is printed 1-based
    :param validation_splits: all splits, used only for the total count
    :param trial_number: 0-based trial index, printed 1-based
    :param n_trial: total number of trials
    :return: the formatted description string
    """
    split_index = repo_trial_split.split_number + 1
    total_splits = len(validation_splits)
    hyperparams_json = json.dumps(repo_trial.hyperparams, sort_keys=True, indent=4)
    return '{}/{} split {}/{}\nhyperparams: {}\n'.format(
        trial_number + 1,
        n_trial,
        split_index,
        total_splits,
        hyperparams_json
    )
823+
824+
772825
class AutoMLContainer:
773826
"""
774827
Data object for auto ml.

0 commit comments

Comments
 (0)