
Commit 21699a9

add data inspection utilities and tests
1. Add data inspection utilities in the data loader.
2. Use omegaconf to load and save yaml config.
3. Add tests to run NHP.
1 parent 1f85043 commit 21699a9

7 files changed: 195 additions & 43 deletions

easy_tpp/config_factory/config.py

Lines changed: 9 additions & 8 deletions
@@ -1,33 +1,34 @@
 from abc import abstractmethod
 from typing import Any
+from omegaconf import OmegaConf
 
-from easy_tpp.utils import save_yaml_config, load_yaml_config, Registrable, logger
+from easy_tpp.utils import save_yaml_config, Registrable, logger
 
 
 class Config(Registrable):
 
-    def save_to_yaml_file(self, fn):
-        """Save the config into the yaml file 'fn'.
+    def save_to_yaml_file(self, config_dir):
+        """Save the config into the yaml file 'config_dir'.
 
         Args:
-            fn (str): Target filename.
+            config_dir (str): Target filename.
 
         Returns:
         """
         yaml_config = self.get_yaml_config()
-        save_yaml_config(fn, yaml_config)
+        OmegaConf.save(yaml_config, config_dir)
 
     @staticmethod
-    def build_from_yaml_file(yaml_fn, **kwargs):
+    def build_from_yaml_file(yaml_dir, **kwargs):
         """Load yaml config file from disk.
 
         Args:
-            yaml_fn (str): Path of the yaml config file.
+            yaml_dir (str): Path of the yaml config file.
 
         Returns:
             EasyTPP.Config: Config object corresponding to cls.
         """
-        config = load_yaml_config(yaml_fn)
+        config = OmegaConf.load(yaml_dir)
         pipeline_config = config.get('pipeline_config_id')
         config_cls = Config.by_name(pipeline_config.lower())
         logger.critical(f'Load pipeline config class {config_cls.__name__}')

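The switch from the in-house yaml helpers to OmegaConf keeps the call pattern the same: OmegaConf.save writes a dict-like config to a path, and OmegaConf.load returns a DictConfig that supports dict-style .get, which the updated build_from_yaml_file relies on. A minimal round-trip sketch (the config content here is illustrative, not from the commit):

    from omegaconf import OmegaConf

    # Illustrative config content, not taken from the repository.
    cfg = OmegaConf.create({'pipeline_config_id': 'data_config', 'data_format': 'json'})
    OmegaConf.save(cfg, 'demo_config.yaml')      # replaces save_yaml_config(fn, yaml_config)

    loaded = OmegaConf.load('demo_config.yaml')  # replaces load_yaml_config(yaml_fn)
    print(loaded.get('pipeline_config_id'))      # DictConfig supports .get -> 'data_config'
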
easy_tpp/config_factory/data_config.py

Lines changed: 4 additions & 2 deletions
@@ -69,8 +69,9 @@ def copy(self):
                               max_len=self.max_len)
 
 
+@Config.register('data_config')
 class DataConfig(Config):
-    def __init__(self, train_dir, valid_dir, test_dir, specs=None):
+    def __init__(self, train_dir, valid_dir, test_dir, data_format=None, specs=None):
         """Initialize the DataConfig object.
 
         Args:
@@ -83,7 +84,7 @@ def __init__(self, train_dir, valid_dir, test_dir, specs=None):
         self.valid_dir = valid_dir
         self.test_dir = test_dir
         self.data_specs = specs or DataSpecConfig()
-        self.data_format = train_dir.split('.')[-1]
+        self.data_format = train_dir.split('.')[-1] if data_format is None else data_format
 
     def get_yaml_config(self):
         """Return the config in dict (yaml compatible) format.
@@ -113,6 +114,7 @@ def parse_from_yaml_config(yaml_config):
             train_dir=yaml_config.get('train_dir'),
             valid_dir=yaml_config.get('valid_dir'),
             test_dir=yaml_config.get('test_dir'),
+            data_format=yaml_config.get('data_format'),
             specs=DataSpecConfig.parse_from_yaml_config(yaml_config.get('data_specs'))
         )

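The new data_format argument matters because Hugging Face dataset ids like easytpp/taxi carry no file extension, so the old extension-sniffing train_dir.split('.')[-1] would return a meaningless string for them. A quick sketch of how the fallback resolves (directory values are illustrative):

    # data_format omitted: inferred from the train_dir extension.
    cfg = DataConfig(train_dir='./data/taxi/train.json',
                     valid_dir='./data/taxi/dev.json',
                     test_dir='./data/taxi/test.json',
                     data_format=None)
    print(cfg.data_format)  # 'json'

    # Extension-less Hugging Face dataset id: data_format must be explicit.
    cfg = DataConfig(train_dir='easytpp/taxi',
                     valid_dir='easytpp/taxi',
                     test_dir='easytpp/taxi',
                     data_format='json')
    print(cfg.data_format)  # 'json'
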
easy_tpp/preprocess/data_loader.py

Lines changed: 150 additions & 32 deletions
@@ -1,11 +1,14 @@
+import matplotlib.pyplot as plt
+import numpy as np
+from collections import Counter
 from easy_tpp.preprocess.dataset import TPPDataset
 from easy_tpp.preprocess.dataset import get_data_loader
 from easy_tpp.preprocess.event_tokenizer import EventTokenizer
 from easy_tpp.utils import load_pickle, py_assert
 
 
 class TPPDataLoader:
-    def __init__(self, data_config, backend, **kwargs):
+    def __init__(self, data_config, **kwargs):
         """Initialize the dataloader
 
         Args:
@@ -14,45 +17,75 @@ def __init__(self, data_config, backend, **kwargs):
         """
         self.data_config = data_config
         self.num_event_types = data_config.data_specs.num_event_types
-        self.backend = backend
+        self.backend = kwargs.get('backend', 'torch')
         self.kwargs = kwargs
 
-    def build_input_from_pkl(self, source_dir: str, split: str):
-        data = load_pickle(source_dir)
+    def build_input(self, source_dir, data_format, split):
+        """Helper function to load and process a dataset based on its file format.
+
+        Args:
+            source_dir (str): Path to the dataset directory.
+            data_format (str): File format of the dataset, e.g., 'pkl' or 'json'.
+            split (str): Dataset split, e.g., 'train', 'dev', 'test'.
+
+        Returns:
+            dict: Dictionary containing sequences of event times, types, and intervals.
+        """
+        if data_format == 'pkl':
+            return self._build_input_from_pkl(source_dir, split)
+        elif data_format == 'json':
+            return self._build_input_from_json(source_dir, split)
+        else:
+            raise ValueError(f"Unsupported file format: {data_format}")
+
+    def _build_input_from_pkl(self, source_dir, split):
+        """Load and process data from a pickle file.
+
+        Args:
+            source_dir (str): Path to the pickle file.
+            split (str): Dataset split, e.g., 'train', 'dev', 'test'.
 
+        Returns:
+            dict: Dictionary with processed event sequences.
+        """
+        data = load_pickle(source_dir)
         py_assert(data["dim_process"] == self.num_event_types,
-                  ValueError,
-                  "inconsistent dim_process in different splits?")
+                  ValueError, "Inconsistent dim_process in different splits.")
 
         source_data = data[split]
-        time_seqs = [[x["time_since_start"] for x in seq] for seq in source_data]
-        type_seqs = [[x["type_event"] for x in seq] for seq in source_data]
-        time_delta_seqs = [[x["time_since_last_event"] for x in seq] for seq in source_data]
+        return {
+            'time_seqs': [[x["time_since_start"] for x in seq] for seq in source_data],
+            'type_seqs': [[x["type_event"] for x in seq] for seq in source_data],
+            'time_delta_seqs': [[x["time_since_last_event"] for x in seq] for seq in source_data]
+        }
 
-        input_dict = dict({'time_seqs': time_seqs, 'time_delta_seqs': time_delta_seqs, 'type_seqs': type_seqs})
-        return input_dict
+    def _build_input_from_json(self, source_dir, split):
+        """Load and process data from a JSON file.
 
-    def build_input_from_json(self, source_dir: str, split: str):
+        Args:
+            source_dir (str): Path to the JSON file or Hugging Face dataset name.
+            split (str): Dataset split, e.g., 'train', 'dev', 'test'.
+
+        Returns:
+            dict: Dictionary with processed event sequences.
+        """
         from datasets import load_dataset
-        split_ = 'validation' if split == 'dev' else split
-        # load locally
-        if source_dir.split('.')[-1] == 'json':
-            data = load_dataset('json', data_files={split_: source_dir}, split=split_)
+        split_mapped = 'validation' if split == 'dev' else split
+        if source_dir.endswith('.json'):
+            data = load_dataset('json', data_files={split_mapped: source_dir}, split=split_mapped)
         elif source_dir.startswith('easytpp'):
-            data = load_dataset(source_dir, split=split_)
+            data = load_dataset(source_dir, split=split_mapped)
         else:
-            raise NotImplementedError
+            raise ValueError("Unsupported source directory format for JSON.")
 
         py_assert(data['dim_process'][0] == self.num_event_types,
-                  ValueError,
-                  "inconsistent dim_process in different splits?")
-
-        time_seqs = data['time_since_start']
-        type_seqs = data['type_event']
-        time_delta_seqs = data['time_since_last_event']
+                  ValueError, "Inconsistent dim_process in different splits.")
 
-        input_dict = dict({'time_seqs': time_seqs, 'time_delta_seqs': time_delta_seqs, 'type_seqs': type_seqs})
-        return input_dict
+        return {
+            'time_seqs': data['time_since_start'],
+            'type_seqs': data['type_event'],
+            'time_delta_seqs': data['time_since_last_event']
+        }
 
     def get_loader(self, split='train', **kwargs):
         """Get the corresponding data loader.
@@ -68,12 +101,7 @@ def get_loader(self, split='train', **kwargs):
             EasyTPP.DataLoader: the data loader for tpp data.
         """
         data_dir = self.data_config.get_data_dir(split)
-        data_source_type = data_dir.split('.')[-1]
-
-        if data_source_type == 'pkl':
-            data = self.build_input_from_pkl(data_dir, split)
-        else:
-            data = self.build_input_from_json(data_dir, split)
+        data = self.build_input(data_dir, self.data_config.data_format, split)
 
         dataset = TPPDataset(data)
         tokenizer = EventTokenizer(self.data_config.data_specs)
@@ -109,3 +137,93 @@ def test_loader(self, **kwargs):
             EasyTPP.DataLoader: data loader for test set.
         """
         return self.get_loader('test', **kwargs)
+
+    def get_statistics(self, split='train'):
+        """Get basic statistics about the dataset.
+
+        Args:
+            split (str): Dataset split, e.g., 'train', 'dev', 'test'. Default is 'train'.
+
+        Returns:
+            dict: Dictionary containing statistics about the dataset.
+        """
+        data_dir = self.data_config.get_data_dir(split)
+        data = self.build_input(data_dir, self.data_config.data_format, split)
+
+        num_sequences = len(data['time_seqs'])
+        sequence_lengths = [len(seq) for seq in data['time_seqs']]
+        avg_sequence_length = sum(sequence_lengths) / num_sequences
+        all_event_types = [event for seq in data['type_seqs'] for event in seq]
+        event_type_counts = Counter(all_event_types)
+
+        # Calculate time_delta_seqs statistics
+        all_time_deltas = [delta for seq in data['time_delta_seqs'] for delta in seq]
+        mean_time_delta = np.mean(all_time_deltas) if all_time_deltas else 0
+        min_time_delta = np.min(all_time_deltas) if all_time_deltas else 0
+        max_time_delta = np.max(all_time_deltas) if all_time_deltas else 0
+
+        stats = {
+            "num_sequences": num_sequences,
+            "avg_sequence_length": avg_sequence_length,
+            "event_type_distribution": dict(event_type_counts),
+            "max_sequence_length": max(sequence_lengths),
+            "min_sequence_length": min(sequence_lengths),
+            "mean_time_delta": mean_time_delta,
+            "min_time_delta": min_time_delta,
+            "max_time_delta": max_time_delta
+        }
+
+        return stats
+
+    def plot_event_type_distribution(self, split='train'):
+        """Plot the distribution of event types in the dataset.
+
+        Args:
+            split (str): Dataset split, e.g., 'train', 'dev', 'test'. Default is 'train'.
+        """
+        stats = self.get_statistics(split)
+        event_type_distribution = stats['event_type_distribution']
+
+        plt.figure(figsize=(8, 6))
+        plt.bar(event_type_distribution.keys(), event_type_distribution.values(), color='skyblue')
+        plt.xlabel('Event Types')
+        plt.ylabel('Frequency')
+        plt.title(f'Event Type Distribution ({split} set)')
+        plt.show()
+
+    def plot_event_delta_times_distribution(self, split='train'):
+        """Plot the distribution of event delta times in the dataset.
+
+        Args:
+            split (str): Dataset split, e.g., 'train', 'dev', 'test'. Default is 'train'.
+        """
+        data_dir = self.data_config.get_data_dir(split)
+        data = self.build_input(data_dir, self.data_config.data_format, split)
+
+        # Flatten the time_delta_seqs to get all delta times
+        all_time_deltas = [delta for seq in data['time_delta_seqs'] for delta in seq]
+
+        plt.figure(figsize=(10, 6))
+        plt.hist(all_time_deltas, bins=30, color='skyblue', edgecolor='black')
+        plt.xlabel('Event Delta Times')
+        plt.ylabel('Frequency')
+        plt.title(f'Event Delta Times Distribution ({split} set)')
+        plt.grid(axis='y', alpha=0.75)
+        plt.show()
+
+    def plot_sequence_length_distribution(self, split='train'):
+        """Plot the distribution of sequence lengths in the dataset.
+
+        Args:
+            split (str): Dataset split, e.g., 'train', 'dev', 'test'. Default is 'train'.
+        """
+        data_dir = self.data_config.get_data_dir(split)
+        data = self.build_input(data_dir, self.data_config.data_format, split)
+        sequence_lengths = [len(seq) for seq in data['time_seqs']]
+
+        plt.figure(figsize=(8, 6))
+        plt.hist(sequence_lengths, bins=10, color='salmon', edgecolor='black')
+        plt.xlabel('Sequence Length')
+        plt.ylabel('Frequency')
+        plt.title(f'Sequence Length Distribution ({split} set)')
+        plt.show()
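
For a sense of what get_statistics returns, the same aggregation can be run standalone on two toy sequences (a subset of the fields only; names match the stats dict above):

    import numpy as np
    from collections import Counter

    # Toy data, illustrative only.
    time_delta_seqs = [[0.0, 1.5, 0.7], [0.0, 2.1]]
    type_seqs = [[0, 1, 1], [2, 0]]

    all_deltas = [d for seq in time_delta_seqs for d in seq]
    stats = {
        'num_sequences': len(type_seqs),
        'avg_sequence_length': sum(len(s) for s in type_seqs) / len(type_seqs),
        'event_type_distribution': dict(Counter(e for seq in type_seqs for e in seq)),
        'mean_time_delta': float(np.mean(all_deltas)),
    }
    print(stats)
    # {'num_sequences': 2, 'avg_sequence_length': 2.5,
    #  'event_type_distribution': {0: 2, 1: 2, 2: 1}, 'mean_time_delta': 0.86}
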
Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+pipeline_config_id: data_config
+
+data_format: json
+train_dir: easytpp/taxi # ./data/taxi/train.json
+valid_dir: easytpp/taxi # ./data/taxi/dev.json
+test_dir: easytpp/taxi # ./data/taxi/test.json
+data_specs:
+  num_event_types: 10
+  pad_token_id: 10
+  padding_side: right
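
The pipeline_config_id key is what Config.build_from_yaml_file dispatches on, and 'data_config' resolves precisely because of the @Config.register('data_config') decorator added in data_config.py above; the inline comments show the local-file alternative to the easytpp/taxi Hugging Face dataset id. A sketch of how the pieces could connect (the tail of build_from_yaml_file is truncated in this diff, so the exact wiring is assumed):

    from omegaconf import OmegaConf

    config = OmegaConf.load('./config.yaml')                          # path assumed
    cls = Config.by_name(config.get('pipeline_config_id').lower())    # -> DataConfig via the registry
    data_config = cls.parse_from_yaml_config(config)
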
Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+import os
+import sys
+# Get the directory of the current file
+current_file_path = os.path.abspath(__file__)
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(current_file_path))))
+
+from easy_tpp.config_factory import Config
+from easy_tpp.preprocess.data_loader import TPPDataLoader
+
+
+def main():
+    config = Config.build_from_yaml_file('./config.yaml')
+    tpp_loader = TPPDataLoader(config)
+    stats = tpp_loader.get_statistics(split='train')
+    print(stats)
+    tpp_loader.plot_event_type_distribution()
+    tpp_loader.plot_event_delta_times_distribution()
+
+if __name__ == '__main__':
+    main()

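One caveat in the test script: './config.yaml' resolves against the current working directory, so the script only works when launched from its own folder. A working-directory-independent variant might look like this (a sketch, not part of the commit):

    import os

    # Resolve config.yaml relative to this script file rather than the cwd.
    config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.yaml')
    config = Config.build_from_yaml_file(config_path)
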
requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -5,3 +5,4 @@ torch
 tensorboard
 packaging
 datasets
+omegaconf

version.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = '0.0.7.1'
+__version__ = '0.0.8'
