
Commit 7e2b7a0

Merge pull request #65 from cvejoski/change_pad_dtype

Set the dtype to float32 during tokenization when no padding is needed.

2 parents: a7f0de0 + 51d22e6

File tree

1 file changed: +3 −3 lines


easy_tpp/preprocess/event_tokenizer.py (3 additions, 3 deletions)

@@ -414,9 +414,9 @@ def _pad(
                                 max_len=max_length,
                                 dtype=np.int64)
         else:
-            batch_output[self.model_input_names[0]] = np.array(encoded_inputs[self.model_input_names[0]])
-            batch_output[self.model_input_names[1]] = np.array(encoded_inputs[self.model_input_names[1]])
-            batch_output[self.model_input_names[2]] = np.array(encoded_inputs[self.model_input_names[2]])
+            batch_output[self.model_input_names[0]] = np.array(encoded_inputs[self.model_input_names[0]], dtype=np.float32)
+            batch_output[self.model_input_names[1]] = np.array(encoded_inputs[self.model_input_names[1]], dtype=np.float32)
+            batch_output[self.model_input_names[2]] = np.array(encoded_inputs[self.model_input_names[2]], dtype=np.int64)

         # non_pad_mask; replaced the use of event types by using the original sequence length
         seq_pad_mask = np.full_like(batch_output[self.model_input_names[2]], fill_value=True, dtype=bool)
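The motivation for pinning the dtype can be illustrated in isolation: without an explicit `dtype`, `np.array` infers the element type from the Python values, so integer-valued event times become `int64` (on most platforms) while float-valued ones become `float64`. The no-padding branch could therefore produce a different dtype than the padded branch, which already forces one. A minimal sketch, independent of EasyTPP's `EventTokenizer` (the sequence values below are made up for illustration):

```python
import numpy as np

# Inferred dtypes depend on the Python values, so batches can disagree:
time_seqs_int = [[0, 2, 5], [0, 1, 7]]      # integer-valued event times
time_seqs_flt = [[0.0, 2.5], [0.0, 1.2]]    # float-valued event times

print(np.array(time_seqs_int).dtype)  # inferred integer dtype (int64 on most platforms)
print(np.array(time_seqs_flt).dtype)  # float64

# Forcing the dtype, as the commit does, makes every batch uniform:
a = np.array(time_seqs_int, dtype=np.float32)
b = np.array(time_seqs_flt, dtype=np.float32)
assert a.dtype == b.dtype == np.float32
```

A uniform `float32` for the time inputs and `int64` for the event-type input matches what the padded branch emits, so downstream tensor conversion sees consistent dtypes regardless of whether padding was applied.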

0 commit comments

Comments
 (0)