Skip to content

Commit 723c581

Browse files
Varying first cutoff time for each target group (#258)
* implement and test varying minimum data per group
* update release notes
* pin scikit-learn for doc builds
* update docstring
* add guide for controlling cutoff times
* fix dfs test
* add guide to index
* Revert "fix dfs test". This reverts commit 584a5cb.
* pin version of featuretools
* update docstring
* update test case
* update docstring
* lint fix
* parametrize test
* lint fix
1 parent f396057 commit 723c581

File tree

8 files changed

+358
-97
lines changed

8 files changed

+358
-97
lines changed

composeml/label_maker.py

Lines changed: 94 additions & 96 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
from sys import stdout
22

3+
from pandas import Series
34
from tqdm import tqdm
45

56
from composeml.data_slice import DataSliceGenerator
@@ -62,13 +63,22 @@ def labeling_function(self, value):
6263
assert isinstance(value, dict), 'value type for labeling function not supported'
6364
self._labeling_function = value
6465

66+
def _check_cutoff_time(self, value):
67+
if isinstance(value, Series):
68+
if value.index.is_unique: return value.to_dict()
69+
else: raise ValueError('more than one cutoff time exists for a target group')
70+
else: return value
71+
6572
def slice(self, df, num_examples_per_instance, minimum_data=None, maximum_data=None, gap=None, drop_empty=True):
6673
"""Generates data slices of target entity.
6774
6875
Args:
6976
df (DataFrame): Data frame to create slices on.
7077
num_examples_per_instance (int): Number of examples per unique instance of target entity.
71-
minimum_data (str): Minimum data before starting the search. Default value is first time of index.
78+
minimum_data (int or str or Series): The amount of data needed before starting the search. Defaults to the first value in the time index.
79+
The value can be a datetime string to directly set the first cutoff time or a timedelta string to denote the amount of data needed before
80+
the first cutoff time. The value can also be an integer to denote the number of rows needed before the first cutoff time.
81+
If a Series, minimum_data should be datetime string, timedelta string, or integer values with a unique set of target groups as the corresponding index.
7282
maximum_data (str): Maximum data before stopping the search. Default value is last time of index.
7383
gap (str or int): Time between examples. Default value is window size.
7484
If an integer, search will start on the first event after the minimum data.
@@ -79,24 +89,32 @@ def slice(self, df, num_examples_per_instance, minimum_data=None, maximum_data=N
7989
"""
8090
self._check_example_count(num_examples_per_instance, gap)
8191
df = self.set_index(df)
82-
entity_groups = df.groupby(self.target_entity)
92+
target_groups = df.groupby(self.target_entity)
8393
num_examples_per_instance = ExampleSearch._check_number(num_examples_per_instance)
8494

85-
generator = DataSliceGenerator(
86-
window_size=self.window_size,
87-
min_data=minimum_data,
88-
max_data=maximum_data,
89-
drop_empty=drop_empty,
90-
gap=gap,
91-
)
95+
minimum_data = self._check_cutoff_time(minimum_data)
96+
minimum_data_varies = isinstance(minimum_data, dict)
97+
98+
for group_key, df in target_groups:
99+
if minimum_data_varies:
100+
if group_key not in minimum_data: continue
101+
min_data_for_group = minimum_data[group_key]
102+
else:
103+
min_data_for_group = minimum_data
104+
105+
generator = DataSliceGenerator(
106+
window_size=self.window_size,
107+
min_data=min_data_for_group,
108+
max_data=maximum_data,
109+
drop_empty=drop_empty,
110+
gap=gap,
111+
)
92112

93-
for entity_id, df in entity_groups:
94113
for ds in generator(df):
95-
setattr(ds.context, self.target_entity, entity_id)
114+
setattr(ds.context, self.target_entity, group_key)
96115
yield ds
97116

98-
if ds.context.slice_number >= num_examples_per_instance:
99-
break
117+
if ds.context.slice_number >= num_examples_per_instance: break
100118

101119
@property
102120
def _bar_format(self):
@@ -107,72 +125,6 @@ def _bar_format(self):
107125
value += self.target_entity + ": {n}/{total} "
108126
return value
109127

110-
def _run_search(
111-
self,
112-
df,
113-
generator,
114-
search,
115-
verbose=True,
116-
*args,
117-
**kwargs,
118-
):
119-
"""Search implementation to make label records.
120-
121-
Args:
122-
df (DataFrame): Data frame to search and extract labels.
123-
generator (DataSliceGenerator): The generator for data slices.
124-
search (LabelSearch or ExampleSearch): The type of search to be done.
125-
verbose (bool): Whether to render progress bar. Default value is True.
126-
*args: Positional arguments for labeling function.
127-
**kwargs: Keyword arguments for labeling function.
128-
129-
Returns:
130-
records (list(dict)): Label Records
131-
"""
132-
df = self.set_index(df)
133-
entity_groups = df.groupby(self.target_entity)
134-
multiplier = search.expected_count if search.is_finite else 1
135-
total = entity_groups.ngroups * multiplier
136-
137-
progress_bar, records = tqdm(
138-
total=total,
139-
bar_format=self._bar_format,
140-
disable=not verbose,
141-
file=stdout,
142-
), []
143-
144-
def missing_examples(entity_count):
145-
return entity_count * search.expected_count - progress_bar.n
146-
147-
for entity_count, (entity_id, df) in enumerate(entity_groups):
148-
for ds in generator(df):
149-
items = self.labeling_function.items()
150-
labels = {name: lf(ds, *args, **kwargs) for name, lf in items}
151-
valid_labels = search.is_valid_labels(labels)
152-
if not valid_labels: continue
153-
154-
records.append({
155-
self.target_entity: entity_id,
156-
'time': ds.context.slice_start,
157-
**labels,
158-
})
159-
160-
search.update_count(labels)
161-
# if finite search, progress bar is updated for each example found
162-
if search.is_finite: progress_bar.update(n=1)
163-
if search.is_complete: break
164-
165-
# if finite search, progress bar is updated for examples not found
166-
# otherwise, progress bar is updated for each entity group
167-
n = missing_examples(entity_count + 1) if search.is_finite else 1
168-
progress_bar.update(n=n)
169-
search.reset_count()
170-
171-
total -= progress_bar.n
172-
progress_bar.update(n=total)
173-
progress_bar.close()
174-
return records
175-
176128
def _check_example_count(self, num_examples_per_instance, gap):
177129
"""Checks whether example count corresponds to data slices."""
178130
if self.window_size is None and gap is None:
@@ -195,8 +147,11 @@ def search(self,
195147
df (DataFrame): Data frame to search and extract labels.
196148
num_examples_per_instance (int or dict): The expected number of examples to return from each entity group.
197149
A dictionary can be used to further specify the expected number of examples to return from each label.
198-
minimum_data (str): Minimum data before starting the search. Default value is first time of index.
199-
maximum_data (str): Maximum data before stopping the search. Default value is last time of index.
150+
minimum_data (int or str or Series): The amount of data needed before starting the search. Defaults to the first value in the time index.
151+
The value can be a datetime string to directly set the first cutoff time or a timedelta string to denote the amount of data needed before
152+
the first cutoff time. The value can also be an integer to denote the number of rows needed before the first cutoff time.
153+
If a Series, minimum_data should be datetime string, timedelta string, or integer values with a unique set of target groups as the corresponding index.
154+
maximum_data (str): Maximum data before stopping the search. Defaults to the last value in the time index.
200155
gap (str or int): Time between examples. Default value is window size.
201156
If an integer, search will start on the first event after the minimum data.
202157
drop_empty (bool): Whether to drop empty slices. Default value is True.
@@ -212,30 +167,73 @@ def search(self,
212167
is_label_search = isinstance(num_examples_per_instance, dict)
213168
search = (LabelSearch if is_label_search else ExampleSearch)(num_examples_per_instance)
214169

215-
generator = DataSliceGenerator(
216-
window_size=self.window_size,
217-
min_data=minimum_data,
218-
max_data=maximum_data,
219-
drop_empty=drop_empty,
220-
gap=gap,
221-
)
170+
# check minimum data cutoff time
171+
minimum_data = self._check_cutoff_time(minimum_data)
172+
minimum_data_varies = isinstance(minimum_data, dict)
173+
174+
df = self.set_index(df)
175+
total = search.expected_count if search.is_finite else 1
176+
target_groups = df.groupby(self.target_entity)
177+
total *= target_groups.ngroups
222178

223-
records = self._run_search(
224-
df=df,
225-
generator=generator,
226-
search=search,
227-
verbose=verbose,
228-
*args,
229-
**kwargs,
179+
progress_bar = tqdm(
180+
total=total,
181+
file=stdout,
182+
disable=not verbose,
183+
bar_format=self._bar_format,
230184
)
231185

186+
records = []
187+
for group_count, (group_key, df) in enumerate(target_groups, start=1):
188+
if minimum_data_varies:
189+
if group_key not in minimum_data: continue
190+
min_data_for_group = minimum_data[group_key]
191+
else:
192+
min_data_for_group = minimum_data
193+
194+
generator = DataSliceGenerator(
195+
window_size=self.window_size,
196+
min_data=min_data_for_group,
197+
max_data=maximum_data,
198+
drop_empty=drop_empty,
199+
gap=gap,
200+
)
201+
202+
for ds in generator(df):
203+
setattr(ds.context, self.target_entity, group_key)
204+
205+
items = self.labeling_function.items()
206+
labels = {name: lf(ds, *args, **kwargs) for name, lf in items}
207+
valid_labels = search.is_valid_labels(labels)
208+
if not valid_labels: continue
209+
210+
records.append({
211+
self.target_entity: group_key,
212+
'time': ds.context.slice_start,
213+
**labels,
214+
})
215+
216+
search.update_count(labels)
217+
# if finite search, update progress bar for the example found
218+
if search.is_finite: progress_bar.update(n=1)
219+
if search.is_complete: break
220+
221+
# if finite search, update progress bar for missing examples
222+
if search.is_finite: progress_bar.update(n=group_count * search.expected_count - progress_bar.n)
223+
else: progress_bar.update(n=1) # otherwise, update progress bar once for each group
224+
search.reset_count()
225+
226+
total -= progress_bar.n
227+
progress_bar.update(n=total)
228+
progress_bar.close()
229+
232230
lt = LabelTimes(
233231
data=records,
234232
target_columns=list(self.labeling_function),
235233
target_entity=self.target_entity,
236234
search_settings={
237235
'num_examples_per_instance': num_examples_per_instance,
238-
'minimum_data': str(minimum_data),
236+
'minimum_data': minimum_data,
239237
'maximum_data': str(maximum_data),
240238
'window_size': str(self.window_size),
241239
'gap': str(gap),

composeml/tests/test_data_slice/test_extension.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -64,3 +64,10 @@ def test_time_index_error(transactions):
6464
match = 'offset by frequency requires a time index'
6565
with raises(AssertionError, match=match):
6666
transactions.slice[::'1h']
67+
68+
69+
def test_minimum_data_per_group(transactions):
70+
lm = LabelMaker('customer_id', labeling_function=len, time_index='time', window_size='1h')
71+
minimum_data = {1: '2019-01-01 09:00:00', 3: '2019-01-01 12:00:00'}
72+
lengths = [len(ds) for ds in lm.slice(transactions, 1, minimum_data=minimum_data)]
73+
assert lengths == [2, 1]

composeml/tests/test_label_maker.py

Lines changed: 26 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -546,3 +546,29 @@ def test_search_with_maximum_data(transactions):
546546

547547
actual = lt.pipe(to_csv, index=False)
548548
assert actual == expected
549+
550+
551+
@pytest.mark.parametrize("minimum_data", [{1: '2019-01-01 09:30:00', 2: '2019-01-01 11:30:00'}, {1: '30min', 2: '1h'}, {1: 1, 2: 2}])
552+
def test_minimum_data_per_group(transactions, minimum_data):
553+
lm = LabelMaker('customer_id', labeling_function=len, time_index='time', window_size='1h')
554+
for supported_type in [minimum_data, pd.Series(minimum_data)]:
555+
lt = lm.search(transactions, 1, minimum_data=supported_type)
556+
actual = to_csv(lt, index=False)
557+
558+
expected = [
559+
'customer_id,time,len',
560+
'1,2019-01-01 09:30:00,2',
561+
'2,2019-01-01 11:30:00,2'
562+
]
563+
564+
assert actual == expected
565+
566+
567+
def test_minimum_data_per_group_error(transactions):
568+
lm = LabelMaker('customer_id', labeling_function=len, time_index='time', window_size='1h')
569+
data = ['2019-01-01 09:00:00', '2019-01-01 12:00:00']
570+
minimum_data = pd.Series(data=data, index=[1, 1])
571+
match = "more than one cutoff time exists for a target group"
572+
573+
with pytest.raises(ValueError, match=match):
574+
lm.search(transactions, 1, minimum_data=minimum_data)

dev-requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,3 +3,4 @@ jupyter==1.0.0
33
nbsphinx==0.8.6
44
pydata-sphinx-theme==0.6.3
55
evalml==0.28.0
6+
scikit-learn>=0.24.0,<1.0

docs/source/release_notes.rst

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@ Release Notes
88
* Enhancements
99
* Add ``maximum_data`` parameter to control when a search should stop (:pr:`216`)
1010
* Add optional automatic update checker (:pr:`223`, :pr:`229`, :pr:`232`)
11+
* Varying first cutoff time for each target group (:pr:`258`)
1112
* Fixes
1213
* Documentation Changes
1314
* Update doc tutorials to the latest API changes (:pr:`227`)

docs/source/user_guide.rst

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@ Use these guides to learn how to use label transformations and generate better t
88
:glob:
99
:maxdepth: 1
1010

11+
user_guide/controlling_cutoff_times
1112
user_guide/using_label_transforms
1213
user_guide/data_slice_generator
1314

0 commit comments

Comments
 (0)