Skip to content

Commit 1f85043

Browse files
committed
add tests - data loader test
1 parent ea238c3 commit 1f85043

File tree

6 files changed

+1045
-3
lines changed

6 files changed

+1045
-3
lines changed

easy_tpp/preprocess/dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def get_dt_stats(self):
116116
s_2_y = dts.var()
117117
m = dts.shape[0]
118118
n += m
119-
# Formulat taken from https://math.stackexchange.com/questions/3604607/can-i-work-out-the-variance-in-batches
119+
# Formula taken from https://math.stackexchange.com/questions/3604607/can-i-work-out-the-variance-in-batches
120120
s_2_x = (((n - 1) * s_2_x + (m - 1) * s_2_y) / (n + m - 1)) + (
121121
(n * m * ((x_bar - y_bar) ** 2)) / ((n + m) * (n + m - 1)))
122122
x_bar = (n * x_bar + m * y_bar) / (n + m)

easy_tpp/preprocess/event_tokenizer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,9 @@ def _pad(
414414
max_len=max_length,
415415
dtype=np.int64)
416416
else:
417-
batch_output = encoded_inputs
417+
batch_output[self.model_input_names[0]] = np.array(encoded_inputs[self.model_input_names[0]])
418+
batch_output[self.model_input_names[1]] = np.array(encoded_inputs[self.model_input_names[1]])
419+
batch_output[self.model_input_names[2]] = np.array(encoded_inputs[self.model_input_names[2]])
418420

419421
# non_pad_mask; replaced the use of event types by using the original sequence length
420422
seq_pad_mask = np.full_like(batch_output[self.model_input_names[2]], fill_value=True, dtype=bool)

examples/hf_data_loader.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,5 @@ def load_data_from_hf(hf_dir=None, local_dir=None):
1717
if __name__ == '__main__':
1818
# in case one fails to load from hf directly
1919
# one can load the json data file locally
20-
load_data_from_hf(hf_dir=None, local_dir={'validation':'dev.json'})
20+
# load_data_from_hf(hf_dir=None, local_dir={'validation':'dev.json'})
21+
load_data_from_hf(hf_dir='easytpp/taxi')

tests/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)