@@ -1,9 +1,10 @@
 from functools import partial
 from typing import Optional
+from copy import deepcopy
 
 from golem.core.log import default_log
 
-from fedot.core.constants import default_data_split_ratio_by_task
+from fedot.core.constants import DEFAULT_DATA_SPLIT_RATIO_BY_TASK, DEFAULT_CV_FOLDS_BY_TASK
 from fedot.core.data.data import InputData
 from fedot.core.data.data_split import train_test_data_setup
 from fedot.core.data.multi_modal import MultiModalData
@@ -31,7 +32,7 @@ def __init__(self,
                  cv_folds: Optional[int] = None,
                  validation_blocks: Optional[int] = None,
                  split_ratio: Optional[float] = None,
-                 shuffle: bool = False):
+                 shuffle: bool = False, ):
         self.cv_folds = cv_folds
         self.validation_blocks = validation_blocks
         self.split_ratio = split_ratio
@@ -45,13 +46,21 @@ def build(self, data: InputData) -> DataSource:
             data.shuffle()
 
         # Check split_ratio
-        split_ratio = self.split_ratio or default_data_split_ratio_by_task[data.task.task_type]
+        split_ratio = self.split_ratio or DEFAULT_DATA_SPLIT_RATIO_BY_TASK[data.task.task_type]
         if not (0 < split_ratio < 1):
             raise ValueError(f'split_ratio is {split_ratio} but should be between 0 and 1')
 
-        # Calculate the number of validation blocks
-        if self.validation_blocks is None and data.task.task_type is TaskTypesEnum.ts_forecasting:
-            self._propose_cv_folds_and_validation_blocks(data, split_ratio)
+        # Calculate the number of validation blocks and the number of cv folds for ts forecasting
+        if data.task.task_type is TaskTypesEnum.ts_forecasting:
+            if self.validation_blocks is None:
+                self._propose_cv_folds_and_validation_blocks(data, split_ratio)
+            # When the forecast length is small and the series is long, a huge number of validation blocks
+            # is proposed. Some models refit at every forecasting step, which can be very time-consuming.
+            # The solution is to increase the forecast length and reduce the number of validation blocks
+            # without shrinking the validation data, whose length equals forecast_length * validation_blocks.
+            max_validation_blocks = DEFAULT_CV_FOLDS_BY_TASK[data.task.task_type] if self.cv_folds is None else 1
+            if self.validation_blocks > max_validation_blocks:
+                data = self._propose_forecast_length(data, max_validation_blocks)
 
         # Split data
         if self.cv_folds is not None:
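For context, a minimal sketch of the capping rule introduced in build(); it is not part of the diff, and the concrete numbers as well as the DEFAULT_CV_FOLDS_FOR_TS stand-in are illustrative assumptions:

# Illustrative sketch of the cap on validation blocks (assumed values, not from the PR)
DEFAULT_CV_FOLDS_FOR_TS = 3          # stand-in for DEFAULT_CV_FOLDS_BY_TASK[ts_forecasting]
cv_folds = None                      # cross-validation was not requested by the user
validation_blocks = 30               # proposed for a long series with a short forecast_length

# Without cv folds the cap is the default fold count, otherwise one block per fold
max_validation_blocks = DEFAULT_CV_FOLDS_FOR_TS if cv_folds is None else 1

if validation_blocks > max_validation_blocks:
    # at this point build() would call _propose_forecast_length(data, max_validation_blocks)
    print(f'{validation_blocks} validation blocks exceed the cap of {max_validation_blocks}')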
@@ -73,7 +82,7 @@ def _build_holdout_producer(self, data: InputData) -> DataSource:
         that always returns same data split. Equivalent to 1-fold validation.
         """
 
-        split_ratio = self.split_ratio or default_data_split_ratio_by_task[data.task.task_type]
+        split_ratio = self.split_ratio or DEFAULT_DATA_SPLIT_RATIO_BY_TASK[data.task.task_type]
         train_data, test_data = train_test_data_setup(data, split_ratio, validation_blocks=self.validation_blocks)
 
         if RemoteEvaluator().is_enabled:
@@ -129,3 +138,11 @@ def _propose_cv_folds_and_validation_blocks(self, data, split_ratio):
         else:
             test_share = 1 / (self.cv_folds + 1)
         self.validation_blocks = int(data_shape * test_share // forecast_length)
+
+    def _propose_forecast_length(self, data, max_validation_blocks):
+        horizon = self.validation_blocks * data.task.task_params.forecast_length
+        self.validation_blocks = max_validation_blocks
+        # TODO: copy only the task instead of deep-copying all the data
+        data = deepcopy(data)
+        data.task.task_params.forecast_length = int(horizon // self.validation_blocks)
+        return data
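To make the rescaling in _propose_forecast_length concrete, here is a small worked example with illustrative numbers (not taken from the PR); the total validation horizon is preserved while each block forecasts further ahead:

# Worked example of the horizon-preserving rescaling (illustrative values)
forecast_length = 2
validation_blocks = 30
max_validation_blocks = 5

horizon = validation_blocks * forecast_length       # 30 * 2 = 60 points reserved for validation
validation_blocks = max_validation_blocks           # the model now refits 5 times instead of 30
forecast_length = horizon // validation_blocks      # 60 // 5 = 12 points forecast per block

assert forecast_length * validation_blocks == horizon   # the validation span is unchanged

When the horizon is not evenly divisible by the new block count, the integer division shortens the validation span slightly.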