Skip to content

Commit

Permalink
adding n_jobs param and fixes
Browse files (browse the repository at this point in the history)
  • Loading branch information
aPovidlo committed Sep 11, 2023
1 parent fc2e4f5 commit 8680db2
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 14 deletions.
2 changes: 1 addition & 1 deletion fedot/core/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ def get_not_encoded_data(self):
else:
cat_features_names = np.array([f'cat_feature_{i}' for i in range(1, cat_features.shape[1] + 1)])

if num_features and cat_features:
if num_features is not None and cat_features is not None:
new_features = np.hstack((num_features, cat_features))
new_features_names = np.hstack((num_features_names, cat_features_names))
new_features_idx = np.array(range(new_features.shape[1]))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,13 @@


class FedotCatBoostImplementation(ModelImplementation):
__operation_params = ['use_eval_set']
__operation_params = ['use_eval_set', 'n_jobs']

def __init__(self, params: Optional[OperationParameters] = None):
super().__init__(params)

self.params.update(**self.params.to_dict())

# TODO: Adding checking params compatibility with each other
# self.check_params(self.params.to_dict())
self.params.update(**self.check_and_update_params(self.params.to_dict()))

self.model_params = {k: v for k, v in self.params.to_dict().items() if k not in self.__operation_params}
self.model = None
Expand All @@ -30,15 +28,13 @@ def fit(self, input_data: InputData):
input_data = input_data.get_not_encoded_data()

if self.params.get('use_eval_set'):
# TODO: Using this method for tuning
train_input, eval_input = train_test_data_setup(input_data)

train_input = self.convert_to_pool(train_input)
eval_input = self.convert_to_pool(eval_input)

self.model.fit(
X=train_input,
eval_set=eval_input,
)
self.model.fit(X=train_input, eval_set=eval_input)

else:
train_input = self.convert_to_pool(input_data)
Expand All @@ -53,7 +49,9 @@ def predict(self, input_data: InputData):
return prediction

@staticmethod
def check_params(params):
def check_and_update_params(params):
params['thread_count'] = params['n_jobs']

if params['use_best_model'] or params['early_stopping_rounds'] and not params['use_eval_set']:
params['use_best_model'] = False
params['early_stopping_rounds'] = False
Expand Down
6 changes: 4 additions & 2 deletions fedot/core/repository/data/default_operation_params.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,17 @@
"iterations": 1000,
"use_eval_set": false,
"use_best_model": false,
"early_stopping_rounds": 100
"early_stopping_rounds": null,
"n_jobs": 1
},
"catboostreg": {
"allow_writing_files": false,
"verbose": false,
"iterations": 1000,
"use_eval_set": false,
"use_best_model": false,
"early_stopping_rounds": 100
"early_stopping_rounds": null,
"n_jobs": 1
},
"lgbm": {
"num_leaves": 32,
Expand Down
4 changes: 2 additions & 2 deletions test/integration/models/test_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def test_boosting_classification_operation():
)

for model_name in model_names:
pipeline = PipelineBuilder().add_node(model_name).build()
pipeline.fit(train_data, n_jobs=-1)
pipeline = PipelineBuilder().add_node(model_name, params={'n_jobs': -1}).build()
pipeline.fit(train_data)
predicted_output = pipeline.predict(test_data, output_mode='labels')
metric = roc_auc(test_data.target, predicted_output.predict)

Expand Down

0 comments on commit 8680db2

Please sign in to comment.