File tree Expand file tree Collapse file tree 1 file changed +3
-3
lines changed
Expand file tree Collapse file tree 1 file changed +3
-3
lines changed Original file line number Diff line number Diff line change @@ -414,9 +414,9 @@ def _pad(
414414 max_len = max_length ,
415415 dtype = np .int64 )
416416 else :
417- batch_output [self .model_input_names [0 ]] = np .array (encoded_inputs [self .model_input_names [0 ]])
418- batch_output [self .model_input_names [1 ]] = np .array (encoded_inputs [self .model_input_names [1 ]])
419- batch_output [self .model_input_names [2 ]] = np .array (encoded_inputs [self .model_input_names [2 ]])
417+ batch_output [self .model_input_names [0 ]] = np .array (encoded_inputs [self .model_input_names [0 ]], dtype = np . float32 )
418+ batch_output [self .model_input_names [1 ]] = np .array (encoded_inputs [self .model_input_names [1 ]], dtype = np . float32 )
419+ batch_output [self .model_input_names [2 ]] = np .array (encoded_inputs [self .model_input_names [2 ]], dtype = np . int64 )
420420
421421 # non_pad_mask; replaced the use of event types by using the original sequence length
422422 seq_pad_mask = np .full_like (batch_output [self .model_input_names [2 ]], fill_value = True , dtype = bool )
You can’t perform that action at this time.
0 commit comments