Skip to content

Commit 4661796

Browse files
committed
examples: rnn: fix incorrect data type in int8 example
1 parent b5f1f92 commit 4661796

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

examples/simple_rnn_int8.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ void compute_attention(float *context_vectors, int src_seq_length_max,
7979
int batch, int feature_size, int8_t *weights_src_layer,
8080
float weights_src_layer_scale, int32_t *compensation,
8181
uint8_t *dec_src_layer, float dec_src_layer_scale,
82-
float dec_src_layer_shift, float *annotations,
82+
float dec_src_layer_shift, uint8_t *annotations,
8383
float *weighted_annotations, float *weights_alignments) {
8484
// dst_iter : (n, c) matrix
8585
// src_layer: (n, c) matrix
@@ -168,7 +168,10 @@ void compute_attention(float *context_vectors, int src_seq_length_max,
168168
context_vectors[i * (feature_size + feature_size) + feature_size
169169
+ j]
170170
+= alignments[k * batch + i]
171-
* annotations[j + feature_size * (i + batch * k)];
171+
* (((float)annotations[j
172+
+ feature_size * (i + batch * k)]
173+
- dec_src_layer_shift)
174+
/ dec_src_layer_scale);
172175
}
173176

174177
void copy_context(float *src_iter, int n_layers, int n_states, int batch,
@@ -670,7 +673,7 @@ void simple_net() {
670673
feature_size, user_weights_attention_src_layer.data(),
671674
weights_attention_scale, weights_attention_sum_rows.data(),
672675
src_att_layer_handle, data_scale, data_shift,
673-
(float *)enc_bidir_dst_layer_memory.get_data_handle(),
676+
(uint8_t *)enc_bidir_dst_layer_memory.get_data_handle(),
674677
weighted_annotations.data(),
675678
user_weights_alignments.data());
676679

0 commit comments

Comments
 (0)