1+ import matplotlib .pyplot as plt
2+ import numpy as np
3+ from collections import Counter
14from easy_tpp .preprocess .dataset import TPPDataset
25from easy_tpp .preprocess .dataset import get_data_loader
36from easy_tpp .preprocess .event_tokenizer import EventTokenizer
47from easy_tpp .utils import load_pickle , py_assert
58
69
class TPPDataLoader:

    def __init__(self, data_config, **kwargs):
        """Initialize the dataloader.

        Args:
            data_config: data config object; must expose
                ``data_specs.num_event_types``.
            **kwargs: optional settings. ``backend`` selects the framework
                backend and defaults to ``'torch'`` when absent.
        """
        self.data_config = data_config
        self.num_event_types = data_config.data_specs.num_event_types
        # Fix: the key must be 'backend' (no leading space), otherwise the
        # lookup never matches and the user-supplied backend is ignored.
        self.backend = kwargs.get('backend', 'torch')
        self.kwargs = kwargs
1922
20- def build_input_from_pkl (self , source_dir : str , split : str ):
21- data = load_pickle (source_dir )
23+ def build_input (self , source_dir , data_format , split ):
24+ """Helper function to load and process dataset based on file format.
25+
26+ Args:
27+ source_dir (str): Path to dataset directory.
28+ split (str): Dataset split, e.g., 'train', 'dev', 'test'.
29+
30+ Returns:
31+ dict: Dictionary containing sequences of event times, types, and intervals.
32+ """
33+
34+ if data_format == 'pkl' :
35+ return self ._build_input_from_pkl (source_dir , split )
36+ elif data_format == 'json' :
37+ return self ._build_input_from_json (source_dir , split )
38+ else :
39+ raise ValueError (f"Unsupported file format: { data_format } " )
40+
41+ def _build_input_from_pkl (self , source_dir , split ):
42+ """Load and process data from a pickle file.
43+
44+ Args:
45+ source_dir (str): Path to the pickle file.
46+ split (str): Dataset split, e.g., 'train', 'dev', 'test'.
2247
48+ Returns:
49+ dict: Dictionary with processed event sequences.
50+ """
51+ data = load_pickle (source_dir )
2352 py_assert (data ["dim_process" ] == self .num_event_types ,
24- ValueError ,
25- "inconsistent dim_process in different splits?" )
53+ ValueError , "Inconsistent dim_process in different splits." )
2654
2755 source_data = data [split ]
28- time_seqs = [[x ["time_since_start" ] for x in seq ] for seq in source_data ]
29- type_seqs = [[x ["type_event" ] for x in seq ] for seq in source_data ]
30- time_delta_seqs = [[x ["time_since_last_event" ] for x in seq ] for seq in source_data ]
56+ return {
57+ 'time_seqs' : [[x ["time_since_start" ] for x in seq ] for seq in source_data ],
58+ 'type_seqs' : [[x ["type_event" ] for x in seq ] for seq in source_data ],
59+ 'time_delta_seqs' : [[x ["time_since_last_event" ] for x in seq ] for seq in source_data ]
60+ }
3161
32- input_dict = dict ({ 'time_seqs' : time_seqs , 'time_delta_seqs' : time_delta_seqs , 'type_seqs' : type_seqs })
33- return input_dict
62+ def _build_input_from_json ( self , source_dir , split ):
63+ """Load and process data from a JSON file.
3464
35- def build_input_from_json (self , source_dir : str , split : str ):
65+ Args:
66+ source_dir (str): Path to the JSON file or Hugging Face dataset name.
67+ split (str): Dataset split, e.g., 'train', 'dev', 'test'.
68+
69+ Returns:
70+ dict: Dictionary with processed event sequences.
71+ """
3672 from datasets import load_dataset
37- split_ = 'validation' if split == 'dev' else split
38- # load locally
39- if source_dir .split ('.' )[- 1 ] == 'json' :
40- data = load_dataset ('json' , data_files = {split_ : source_dir }, split = split_ )
73+ split_mapped = 'validation' if split == 'dev' else split
74+ if source_dir .endswith ('.json' ):
75+ data = load_dataset ('json' , data_files = {split_mapped : source_dir }, split = split_mapped )
4176 elif source_dir .startswith ('easytpp' ):
42- data = load_dataset (source_dir , split = split_ )
77+ data = load_dataset (source_dir , split = split_mapped )
4378 else :
44- raise NotImplementedError
79+ raise ValueError ( "Unsupported source directory format for JSON." )
4580
4681 py_assert (data ['dim_process' ][0 ] == self .num_event_types ,
47- ValueError ,
48- "inconsistent dim_process in different splits?" )
49-
50- time_seqs = data ['time_since_start' ]
51- type_seqs = data ['type_event' ]
52- time_delta_seqs = data ['time_since_last_event' ]
82+ ValueError , "Inconsistent dim_process in different splits." )
5383
54- input_dict = dict ({'time_seqs' : time_seqs , 'time_delta_seqs' : time_delta_seqs , 'type_seqs' : type_seqs })
55- return input_dict
84+ return {
85+ 'time_seqs' : data ['time_since_start' ],
86+ 'type_seqs' : data ['type_event' ],
87+ 'time_delta_seqs' : data ['time_since_last_event' ]
88+ }
5689
5790 def get_loader (self , split = 'train' , ** kwargs ):
5891 """Get the corresponding data loader.
@@ -68,12 +101,7 @@ def get_loader(self, split='train', **kwargs):
68101 EasyTPP.DataLoader: the data loader for tpp data.
69102 """
70103 data_dir = self .data_config .get_data_dir (split )
71- data_source_type = data_dir .split ('.' )[- 1 ]
72-
73- if data_source_type == 'pkl' :
74- data = self .build_input_from_pkl (data_dir , split )
75- else :
76- data = self .build_input_from_json (data_dir , split )
104+ data = self .build_input (data_dir , self .data_config .data_format , split )
77105
78106 dataset = TPPDataset (data )
79107 tokenizer = EventTokenizer (self .data_config .data_specs )
@@ -109,3 +137,93 @@ def test_loader(self, **kwargs):
109137 EasyTPP.DataLoader: data loader for test set.
110138 """
111139 return self .get_loader ('test' , ** kwargs )
140+
141+ def get_statistics (self , split = 'train' ):
142+ """Get basic statistics about the dataset.
143+
144+ Args:
145+ split (str): Dataset split, e.g., 'train', 'dev', 'test'. Default is 'train'.
146+
147+ Returns:
148+ dict: Dictionary containing statistics about the dataset.
149+ """
150+ data_dir = self .data_config .get_data_dir (split )
151+ data = self .build_input (data_dir , self .data_config .data_format , split )
152+
153+ num_sequences = len (data ['time_seqs' ])
154+ sequence_lengths = [len (seq ) for seq in data ['time_seqs' ]]
155+ avg_sequence_length = sum (sequence_lengths ) / num_sequences
156+ all_event_types = [event for seq in data ['type_seqs' ] for event in seq ]
157+ event_type_counts = Counter (all_event_types )
158+
159+ # Calculate time_delta_seqs statistics
160+ all_time_deltas = [delta for seq in data ['time_delta_seqs' ] for delta in seq ]
161+ mean_time_delta = np .mean (all_time_deltas ) if all_time_deltas else 0
162+ min_time_delta = np .min (all_time_deltas ) if all_time_deltas else 0
163+ max_time_delta = np .max (all_time_deltas ) if all_time_deltas else 0
164+
165+ stats = {
166+ "num_sequences" : num_sequences ,
167+ "avg_sequence_length" : avg_sequence_length ,
168+ "event_type_distribution" : dict (event_type_counts ),
169+ "max_sequence_length" : max (sequence_lengths ),
170+ "min_sequence_length" : min (sequence_lengths ),
171+ "mean_time_delta" : mean_time_delta ,
172+ "min_time_delta" : min_time_delta ,
173+ "max_time_delta" : max_time_delta
174+ }
175+
176+ return stats
177+
178+ def plot_event_type_distribution (self , split = 'train' ):
179+ """Plot the distribution of event types in the dataset.
180+
181+ Args:
182+ split (str): Dataset split, e.g., 'train', 'dev', 'test'. Default is 'train'.
183+ """
184+ stats = self .get_statistics (split )
185+ event_type_distribution = stats ['event_type_distribution' ]
186+
187+ plt .figure (figsize = (8 , 6 ))
188+ plt .bar (event_type_distribution .keys (), event_type_distribution .values (), color = 'skyblue' )
189+ plt .xlabel ('Event Types' )
190+ plt .ylabel ('Frequency' )
191+ plt .title (f'Event Type Distribution ({ split } set)' )
192+ plt .show ()
193+
194+ def plot_event_delta_times_distribution (self , split = 'train' ):
195+ """Plot the distribution of event delta times in the dataset.
196+
197+ Args:
198+ split (str): Dataset split, e.g., 'train', 'dev', 'test'. Default is 'train'.
199+ """
200+ data_dir = self .data_config .get_data_dir (split )
201+ data = self .build_input (data_dir , self .data_config .data_format , split )
202+
203+ # Flatten the time_delta_seqs to get all delta times
204+ all_time_deltas = [delta for seq in data ['time_delta_seqs' ] for delta in seq ]
205+
206+ plt .figure (figsize = (10 , 6 ))
207+ plt .hist (all_time_deltas , bins = 30 , color = 'skyblue' , edgecolor = 'black' )
208+ plt .xlabel ('Event Delta Times' )
209+ plt .ylabel ('Frequency' )
210+ plt .title (f'Event Delta Times Distribution ({ split } set)' )
211+ plt .grid (axis = 'y' , alpha = 0.75 )
212+ plt .show ()
213+
214+ def plot_sequence_length_distribution (self , split = 'train' ):
215+ """Plot the distribution of sequence lengths in the dataset.
216+
217+ Args:
218+ split (str): Dataset split, e.g., 'train', 'dev', 'test'. Default is 'train'.
219+ """
220+ data_dir = self .data_config .get_data_dir (split )
221+ data = self .build_input (data_dir , self .data_config .data_format , split )
222+ sequence_lengths = [len (seq ) for seq in data ['time_seqs' ]]
223+
224+ plt .figure (figsize = (8 , 6 ))
225+ plt .hist (sequence_lengths , bins = 10 , color = 'salmon' , edgecolor = 'black' )
226+ plt .xlabel ('Sequence Length' )
227+ plt .ylabel ('Frequency' )
228+ plt .title (f'Sequence Length Distribution ({ split } set)' )
229+ plt .show ()
0 commit comments