Skip to content

Commit d3da621

Browse files
committed
fix resample
1 parent 6f265ff commit d3da621

File tree

4 files changed

+34
-5
lines changed

4 files changed

+34
-5
lines changed

fedot/api/api_utils/api_params_repository.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
from golem.core.optimisers.genetic.operators.inheritance import GeneticSchemeTypesEnum
66
from golem.core.optimisers.genetic.operators.mutation import MutationTypesEnum
77

8-
from fedot.core.composer.gp_composer.specific_operators import parameter_change_mutation, boosting_mutation
8+
from fedot.core.composer.gp_composer.specific_operators import parameter_change_mutation, boosting_mutation, \
9+
add_resample_mutation
910
from fedot.core.constants import AUTO_PRESET_NAME
1011
from fedot.core.repository.tasks import TaskTypesEnum
1112
from fedot.core.utils import default_fedot_data_dir
@@ -135,5 +136,7 @@ def _get_default_mutations(task_type: TaskTypesEnum, params) -> Sequence[Mutatio
135136
# TODO remove workaround after boosting mutation fix
136137
if task_type == TaskTypesEnum.ts_forecasting:
137138
mutations.append(partial(boosting_mutation, params=params))
139+
else:
140+
mutations.append(add_resample_mutation)
138141

139142
return mutations

fedot/core/composer/gp_composer/specific_operators.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,17 @@ def boosting_mutation(pipeline: Pipeline, requirements, graph_gen_params, **kwar
9090
return pipeline
9191

9292

93+
def add_resample_mutation(pipeline: Pipeline, **kwargs):
94+
resample_node = PipelineNode('resample')
95+
96+
p_nodes = [p_node for p_node in pipeline.primary_nodes]
97+
pipeline.add_node(resample_node)
98+
99+
for node in p_nodes:
100+
pipeline.connect_nodes(resample_node, node)
101+
return pipeline
102+
103+
93104
def choose_new_model(boosting_model_candidates: List[str]) -> str:
94105
""" Since 'linear' and 'dtreg' operations are suitable for solving the problem
95106
and they are simpler than others, they are preferred """

fedot/core/pipelines/pipeline.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from copy import deepcopy
22
from datetime import timedelta
33
from os import PathLike
4-
from typing import Optional, Tuple, Union, Sequence
4+
from typing import Optional, Tuple, Union, Sequence, List
55

66
import func_timeout
77
from golem.core.dag.graph import Graph
@@ -327,6 +327,19 @@ def root_node(self) -> Optional[PipelineNode]:
327327
raise ValueError(f'{ERROR_PREFIX} More than 1 root_nodes in pipeline')
328328
return root[0]
329329

330+
@property
331+
def primary_nodes(self) -> List[PipelineNode]:
332+
"""Finds pipelines sink-node
333+
334+
Returns:
335+
the final predictor-node
336+
"""
337+
if not self.nodes:
338+
return []
339+
primary_nodes = [node for node in self.nodes
340+
if not node.nodes_from]
341+
return primary_nodes
342+
330343
def pipeline_for_side_task(self, task_type: TaskTypesEnum) -> 'Pipeline':
331344
"""Returns pipeline formed from the last node solving the given problem and all its parents
332345

fedot/core/pipelines/verification_rules.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -166,12 +166,14 @@ def has_correct_location_of_resample(pipeline: Pipeline):
166166
is_resample_primary = True
167167
else:
168168
is_not_resample_primary = True
169-
if node.name == 'resample':
170-
raise ValueError(
171-
f'{ERROR_PREFIX} Pipeline can have only one resample operation located in start of the pipeline')
169+
else:
170+
if node.name == 'resample':
171+
raise ValueError(
172+
f'{ERROR_PREFIX} Pipeline can have only one resample operation located in start of the pipeline')
172173
if is_resample_primary and is_not_resample_primary:
173174
raise ValueError(
174175
f'{ERROR_PREFIX} Pipeline can have only one resample operation located in start of the pipeline')
176+
return True
175177

176178

177179
def get_wrong_links(ts_to_table_operations: list, ts_data_operations: list, non_ts_data_operations: list,

0 commit comments

Comments
 (0)