Skip to content

Commit

Permalink
Merge pull request #290 from alexbrillant/fix-new-automl-and-choose-one-step-of-hp-space
Browse files Browse the repository at this point in the history

Fix new automl printing, and choose one step of hyperparams spaces
  • Loading branch information
guillaume-chevalier authored Mar 12, 2020
2 parents 59f8943 + da69ad3 commit 94bb857
Show file tree
Hide file tree
Showing 6 changed files with 168 additions and 92 deletions.
2 changes: 0 additions & 2 deletions examples/auto_ml_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ def main(tmpdir, sleep_time: float = 0, n_iter: int = 10):
pipeline,
refit_trial=True,
n_trials=n_iter,
print_metrics=False,
cache_folder_when_no_handle=str(tmpdir),
validation_split_function=validation_splitter(0.2),
hyperparams_optimizer=RandomSearchHyperparameterSelectionStrategy(),
Expand Down Expand Up @@ -106,7 +105,6 @@ def main(tmpdir, sleep_time: float = 0, n_iter: int = 10):
pipeline,
refit_trial=False,
n_trials=n_iter,
print_metrics=False,
cache_folder_when_no_handle=str(tmpdir),
validation_split_function=validation_splitter(0.2),
hyperparams_optimizer=RandomSearchHyperparameterSelectionStrategy(),
Expand Down
1 change: 0 additions & 1 deletion examples/boston_housing_meta_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ def main(tmpdir: LocalPath):
refit_trial=True,
n_trials=10,
epochs=10,
print_metrics=False,
cache_folder_when_no_handle=str(tmpdir),
scoring_callback=ScoringCallback(mean_squared_error, higher_score_is_better=False),
callbacks=[MetricCallback('mse', metric_function=mean_squared_error, higher_score_is_better=False)],
Expand Down
2 changes: 1 addition & 1 deletion neuraxle/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.3.3"
__version__ = "0.3.4"
9 changes: 5 additions & 4 deletions neuraxle/metaopt/auto_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ def fit(self, p, train_data_container: DataContainer, validation_data_container:
early_stopping = False

for i in range(self.epochs):
self.print_func('epoch {}/{}'.format(i, self.epochs))
self.print_func('\nepoch {}/{}'.format(i + 1, self.epochs))
p = p.handle_fit(train_data_container, context)

y_pred_train = p.handle_predict(train_data_container, context)
Expand Down Expand Up @@ -570,7 +570,6 @@ def __init__(
callbacks: List[BaseCallback] = None,
refit_scoring_function: Callable = None,
print_func: Callable = None,
print_metrics=True,
cache_folder_when_no_handle=None
):
BaseStep.__init__(self)
Expand All @@ -579,7 +578,6 @@ def __init__(
self.scoring_callback = scoring_callback
self.validation_split_function = create_split_data_container_function(validation_split_function)

self.print_metrics = print_metrics
if print_func is None:
print_func = print

Expand Down Expand Up @@ -631,7 +629,7 @@ def _fit_data_container(self, data_container: DataContainer, context: ExecutionC
training_data_container, validation_data_container = self.validation_split_function(data_container)

for trial_number in range(self.n_trial):
self.print_func('trial {}/{}'.format(trial_number, self.n_trial))
self.print_func('\ntrial {}/{}'.format(trial_number + 1, self.n_trial))

auto_ml_data = self._load_auto_ml_data(trial_number)
p = copy.deepcopy(self.pipeline)
Expand All @@ -650,6 +648,7 @@ def _fit_data_container(self, data_container: DataContainer, context: ExecutionC
)

repo_trial.set_success()
self.print_func('trial {}/{} score: {}'.format(trial_number + 1, self.n_trial, repo_trial.get_validation_scores()[-1]))
except Exception as error:
track = traceback.format_exc()
self.print_func(track)
Expand All @@ -658,6 +657,8 @@ def _fit_data_container(self, data_container: DataContainer, context: ExecutionC
self.hyperparams_repository.save_trial(repo_trial)

best_hyperparams = self.hyperparams_repository.get_best_hyperparams()

self.print_func('best hyperparams:\n{}'.format(json.dumps(best_hyperparams.to_nested_dict(), sort_keys=True, indent=4)))
p: BaseStep = self._load_virgin_model(hyperparams=best_hyperparams)
if self.refit_trial:
p = self.trainer.refit(
Expand Down
10 changes: 9 additions & 1 deletion neuraxle/metaopt/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,10 +311,14 @@ class MetricCallback(BaseCallback):
:class:`~neuraxle.base.HyperparameterSamples`,
:class:`~neuraxle.data_container.DataContainer`
"""
def __init__(self, name: str, metric_function: Callable, higher_score_is_better: bool):
def __init__(self, name: str, metric_function: Callable, higher_score_is_better: bool, print_metrics=True, print_function=None):
    """
    Configure a metric callback.

    :param name: label under which train/validation scores are recorded
    :param metric_function: callable computing a score from expected and predicted outputs
    :param higher_score_is_better: whether a larger score means a better model
    :param print_metrics: when True, scores are echoed through ``print_function`` at each call
    :param print_function: output callable for the scores; defaults to the builtin ``print``
    """
    self.name = name
    self.metric_function = metric_function
    self.higher_score_is_better = higher_score_is_better
    self.print_metrics = print_metrics
    # Fall back to the builtin print when no custom printer was supplied.
    self.print_function = print if print_function is None else print_function

def call(self, trial: Trial, epoch_number: int, total_epochs: int, input_train: DataContainer,
pred_train: DataContainer, input_val: DataContainer, pred_val: DataContainer, is_finished_and_fitted: bool):
Expand All @@ -333,6 +337,10 @@ def call(self, trial: Trial, epoch_number: int, total_epochs: int, input_train:
higher_score_is_better=self.higher_score_is_better
)

if self.print_metrics:
self.print_function('{} train: {}'.format(self.name, train_score))
self.print_function('{} validation: {}'.format(self.name, validation_score))

return False


Expand Down
Loading

0 comments on commit 94bb857

Please sign in to comment.