@@ -410,21 +410,19 @@ class Trainer:
410
410
.. code-block:: python
411
411
412
412
trainer = Trainer(
413
- callbacks=[],
414
413
epochs=10,
414
+ callbacks=[EarlyStoppingCallback()],
415
+ scoring_callback=ScoringCallback(mean_squared_error, higher_score_is_better=False),
416
+ validation_splitter=ValidationSplitter(test_size=0.15),
415
417
print_func=print
416
418
)
417
419
418
- repo_trial = trainer.fit(
419
- p=p,
420
- trial_repository=repo_trial,
421
- train_data_container=training_data_container,
422
- validation_data_container=validation_data_container,
423
- context=context
420
+ repo_trial = trainer.train(
421
+ pipeline=pipeline,
422
+ data_inputs=data_inputs,
423
+ expected_outputs=expected_outputs
424
424
)
425
425
426
- pipeline = trainer.refit(repo_trial.pipeline, data_container, context)
427
-
428
426
429
427
.. seealso::
430
428
:class:`AutoML`,
@@ -439,25 +437,120 @@ class Trainer:
439
437
440
438
def __init__ (
441
439
self ,
442
- epochs ,
443
- metrics = None ,
444
- callbacks = None ,
445
- print_metrics = True ,
446
- print_func = None
440
+ epochs : int ,
441
+ scoring_callback : ScoringCallback ,
442
+ validation_splitter : 'BaseValidationSplitter' ,
443
+ callbacks : List [ BaseCallback ] = None ,
444
+ print_func : Callable = None
447
445
):
448
- self .epochs = epochs
449
- if metrics is None :
450
- metrics = {}
451
- self .metrics = metrics
452
- self ._initialize_metrics (metrics )
446
+ self .epochs : int = epochs
447
+ self .validation_split_function = validation_splitter
453
448
454
- self .callbacks = CallbackList (callbacks )
449
+ if callbacks is None :
450
+ callbacks = []
451
+ callbacks : List [BaseCallback ] = [scoring_callback ] + callbacks
452
+ self .callbacks : CallbackList = CallbackList (callbacks )
455
453
456
454
if print_func is None :
457
455
print_func = print
458
456
459
457
self .print_func = print_func
460
- self .print_metrics = print_metrics
458
+
459
+ def train (self , pipeline : BaseStep , data_inputs , expected_outputs = None ) -> Trial :
460
+ """
461
+ Train pipeline using the validation splitter.
462
+ Track training, and validation metrics for each epoch.
463
+ Note: the present method is just a shortcut to using the `execute_trial` method with less boilerplate code needed.
464
+ Refer to `execute_trial` for full flexibility
465
+
466
+ :param pipeline: pipeline to train on
467
+ :param data_inputs: data inputs
468
+ :param expected_outputs: expected ouptuts to fit on
469
+ :return: executed trial
470
+ """
471
+ validation_splits : List [Tuple [DataContainer , DataContainer ]] = self .validation_split_function .split_data_container (
472
+ DataContainer (data_inputs = data_inputs , expected_outputs = expected_outputs )
473
+ )
474
+
475
+ repo_trial : Trial = Trial (
476
+ pipeline = pipeline ,
477
+ hyperparams = pipeline .get_hyperparams (),
478
+ main_metric_name = self .get_main_metric_name ()
479
+ )
480
+
481
+ self .execute_trial (
482
+ pipeline = pipeline ,
483
+ trial_number = 1 ,
484
+ repo_trial = repo_trial ,
485
+ context = ExecutionContext (),
486
+ validation_splits = validation_splits ,
487
+ n_trial = 1 ,
488
+ delete_pipeline_on_completion = False
489
+ )
490
+
491
+ return repo_trial
492
+
493
+ def execute_trial (
494
+ self ,
495
+ pipeline : BaseStep ,
496
+ trial_number : int ,
497
+ repo_trial : Trial ,
498
+ context : ExecutionContext ,
499
+ validation_splits : List [Tuple [DataContainer , DataContainer ]],
500
+ n_trial : int ,
501
+ delete_pipeline_on_completion : bool = True
502
+ ):
503
+ """
504
+ Train pipeline using the validation splitter.
505
+ Track training, and validation metrics for each epoch.
506
+
507
+ :param pipeline: pipeline to train on
508
+ :param trial_number: trial number
509
+ :param repo_trial: repo trial
510
+ :param validation_splits: validation splits
511
+ :param context: execution context
512
+ :param n_trial: total number of trials that will be executed
513
+ :param delete_pipeline_on_completion: bool to delete pipeline on completion or not
514
+ :return: executed trial split
515
+ """
516
+ for training_data_container , validation_data_container in validation_splits :
517
+ p = copy .deepcopy (pipeline )
518
+ p .update_hyperparams (repo_trial .hyperparams )
519
+ repo_trial .set_hyperparams (p .get_hyperparams ())
520
+
521
+ repo_trial_split : TrialSplit = repo_trial .new_validation_split (
522
+ pipeline = p ,
523
+ delete_pipeline_on_completion = delete_pipeline_on_completion
524
+ )
525
+
526
+ with repo_trial_split :
527
+ trial_split_description = _get_trial_split_description (
528
+ repo_trial = repo_trial ,
529
+ repo_trial_split = repo_trial_split ,
530
+ validation_splits = validation_splits ,
531
+ trial_number = trial_number ,
532
+ n_trial = n_trial
533
+ )
534
+
535
+ self .print_func ('fitting trial {}' .format (
536
+ trial_split_description
537
+ ))
538
+
539
+ repo_trial_split = self .fit_trial_split (
540
+ trial_split = repo_trial_split ,
541
+ train_data_container = training_data_container ,
542
+ validation_data_container = validation_data_container ,
543
+ context = context
544
+ )
545
+
546
+ repo_trial_split .set_success ()
547
+
548
+ self .print_func ('success trial {} score: {}' .format (
549
+ trial_split_description ,
550
+ repo_trial_split .get_validation_score ()
551
+ ))
552
+
553
+ return repo_trial_split
461
554
462
555
def fit_trial_split (
463
556
self ,
@@ -514,21 +607,6 @@ def refit(self, p: BaseStep, data_container: DataContainer, context: ExecutionCo
514
607
515
608
return p
516
609
517
- def _initialize_metrics (self , metrics ):
518
- """
519
- Initialize metrics results dict for train, and validation using the metrics function dict.
520
-
521
- :param metrics: metrics function dict
522
-
523
- :return:
524
- """
525
- self .metrics_results_train = {}
526
- self .metrics_results_validation = {}
527
-
528
- for m in metrics :
529
- self .metrics_results_train [m ] = []
530
- self .metrics_results_validation [m ] = []
531
-
532
610
def get_main_metric_name (self ) -> str :
533
611
"""
534
612
Get main metric name.
@@ -556,7 +634,6 @@ class AutoML(ForceHandleOnlyMixin, BaseStep):
556
634
MetricCallback('mse', metric_function=mean_squared_error, higher_score_is_better=False)
557
635
],
558
636
refit_trial=True,
559
- print_metrics=False,
560
637
cache_folder_when_no_handle=str(tmpdir)
561
638
)
562
639
@@ -595,7 +672,7 @@ def __init__(
595
672
BaseStep .__init__ (self )
596
673
ForceHandleOnlyMixin .__init__ (self , cache_folder = cache_folder_when_no_handle )
597
674
598
- self .validation_split_function : BaseValidationSplitter = validation_splitter
675
+ self .validation_splitter : BaseValidationSplitter = validation_splitter
599
676
600
677
if print_func is None :
601
678
print_func = print
@@ -619,17 +696,14 @@ def __init__(
619
696
620
697
self .refit_scoring_function : Callable = refit_scoring_function
621
698
622
- if callbacks is None :
623
- callbacks = []
624
-
625
- callbacks : List [BaseCallback ] = [scoring_callback ] + callbacks
626
-
627
699
self .refit_trial : bool = refit_trial
628
700
629
701
self .trainer = Trainer (
630
- callbacks = callbacks ,
631
702
epochs = epochs ,
632
- print_func = self .print_func
703
+ scoring_callback = scoring_callback ,
704
+ callbacks = callbacks ,
705
+ print_func = self .print_func ,
706
+ validation_splitter = validation_splitter
633
707
)
634
708
635
709
def _fit_data_container (self , data_container : DataContainer , context : ExecutionContext ) -> 'BaseStep' :
@@ -643,7 +717,7 @@ def _fit_data_container(self, data_container: DataContainer, context: ExecutionC
643
717
644
718
:return: self
645
719
"""
646
- validation_splits = self .validation_split_function .split_data_container (data_container )
720
+ validation_splits = self .validation_splitter .split_data_container (data_container )
647
721
648
722
for trial_number in range (self .n_trial ):
649
723
try :
@@ -657,11 +731,13 @@ def _fit_data_container(self, data_container: DataContainer, context: ExecutionC
657
731
with self .hyperparams_repository .new_trial (auto_ml_data ) as repo_trial :
658
732
self .print_func ('\n trial {}/{}' .format (trial_number + 1 , self .n_trial ))
659
733
660
- repo_trial_split = self ._execute_trial (
734
+ repo_trial_split = self .trainer .execute_trial (
735
+ pipeline = self .pipeline ,
661
736
trial_number = trial_number ,
662
737
repo_trial = repo_trial ,
663
738
context = context ,
664
- validation_splits = validation_splits
739
+ validation_splits = validation_splits ,
740
+ n_trial = self .n_trial
665
741
)
666
742
except (SystemError , SystemExit , EOFError , KeyboardInterrupt ) as error :
667
743
track = traceback .format_exc ()
@@ -670,8 +746,13 @@ def _fit_data_container(self, data_container: DataContainer, context: ExecutionC
670
746
raise error
671
747
except Exception :
672
748
track = traceback .format_exc ()
673
- self .print_func ('failed trial {}' .format (
674
- self ._get_trial_split_description (repo_trial , repo_trial_split , validation_splits , trial_number )))
749
+ self .print_func ('failed trial {}' .format (_get_trial_split_description (
750
+ repo_trial = repo_trial ,
751
+ repo_trial_split = repo_trial_split ,
752
+ validation_splits = validation_splits ,
753
+ trial_number = trial_number ,
754
+ n_trial = self .n_trial
755
+ )))
675
756
self .print_func (track )
676
757
finally :
677
758
repo_trial .update_final_trial_status ()
@@ -694,51 +775,6 @@ def _fit_data_container(self, data_container: DataContainer, context: ExecutionC
694
775
695
776
return self
696
777
697
- def _execute_trial (self , trial_number : int , repo_trial : Trial , context : ExecutionContext ,
698
- validation_splits : List [Tuple [DataContainer , DataContainer ]]):
699
- for training_data_container , validation_data_container in validation_splits :
700
- p = copy .deepcopy (self .pipeline )
701
- p .update_hyperparams (repo_trial .hyperparams )
702
- repo_trial .set_hyperparams (p .get_hyperparams ())
703
-
704
- with repo_trial .new_validation_split (p ) as repo_trial_split :
705
- trial_split_description = self ._get_trial_split_description (
706
- repo_trial = repo_trial ,
707
- repo_trial_split = repo_trial_split ,
708
- validation_splits = validation_splits ,
709
- trial_number = trial_number
710
- )
711
-
712
- self .print_func ('fitting trial {}' .format (
713
- trial_split_description
714
- ))
715
-
716
- repo_trial_split = self .trainer .fit_trial_split (
717
- trial_split = repo_trial_split ,
718
- train_data_container = training_data_container ,
719
- validation_data_container = validation_data_container ,
720
- context = context
721
- )
722
-
723
- repo_trial_split .set_success ()
724
-
725
- self .print_func ('success trial {} score: {}' .format (
726
- trial_split_description ,
727
- repo_trial_split .get_validation_score ()
728
- ))
729
-
730
- return repo_trial_split
731
-
732
- def _get_trial_split_description (self , repo_trial , repo_trial_split , validation_splits , trial_number ):
733
- trial_split_description = '{}/{} split {}/{}\n hyperparams: {}\n ' .format (
734
- trial_number + 1 ,
735
- self .n_trial ,
736
- repo_trial_split .split_number + 1 ,
737
- len (validation_splits ),
738
- json .dumps (repo_trial .hyperparams , sort_keys = True , indent = 4 )
739
- )
740
- return trial_split_description
741
-
742
778
def get_best_model (self ):
743
779
"""
744
780
Get best model using the hyperparams repository.
@@ -769,6 +805,23 @@ def _load_virgin_model(self, hyperparams: HyperparameterSamples) -> BaseStep:
769
805
return copy .deepcopy (self .pipeline ).update_hyperparams (hyperparams )
770
806
771
807
808
+ def _get_trial_split_description (
809
+ repo_trial : Trial ,
810
+ repo_trial_split : TrialSplit ,
811
+ validation_splits : List [Tuple [DataContainer , DataContainer ]],
812
+ trial_number : int ,
813
+ n_trial : int
814
+ ):
815
+ trial_split_description = '{}/{} split {}/{}\n hyperparams: {}\n ' .format (
816
+ trial_number + 1 ,
817
+ n_trial ,
818
+ repo_trial_split .split_number + 1 ,
819
+ len (validation_splits ),
820
+ json .dumps (repo_trial .hyperparams , sort_keys = True , indent = 4 )
821
+ )
822
+ return trial_split_description
823
+
824
+
772
825
class AutoMLContainer :
773
826
"""
774
827
Data object for auto ml.
0 commit comments