@@ -79,7 +79,7 @@ void compute_attention(float *context_vectors, int src_seq_length_max,
         int batch, int feature_size, int8_t *weights_src_layer,
         float weights_src_layer_scale, int32_t *compensation,
         uint8_t *dec_src_layer, float dec_src_layer_scale,
-        float dec_src_layer_shift, float *annotations,
+        float dec_src_layer_shift, uint8_t *annotations,
         float *weighted_annotations, float *weights_alignments) {
     // dst_iter : (n, c) matrix
     // src_layer: (n, c) matrix
@@ -168,7 +168,10 @@ void compute_attention(float *context_vectors, int src_seq_length_max,
                 context_vectors[i * (feature_size + feature_size) + feature_size
                         + j]
                         += alignments[k * batch + i]
-                        * annotations[j + feature_size * (i + batch * k)];
+                        * (((float)annotations[j
+                                  + feature_size * (i + batch * k)]
+                                - dec_src_layer_shift)
+                                / dec_src_layer_scale);
 }

 void copy_context(float *src_iter, int n_layers, int n_states, int batch,
@@ -670,7 +673,7 @@ void simple_net() {
             feature_size, user_weights_attention_src_layer.data(),
             weights_attention_scale, weights_attention_sum_rows.data(),
             src_att_layer_handle, data_scale, data_shift,
-            (float *)enc_bidir_dst_layer_memory.get_data_handle(),
+            (uint8_t *)enc_bidir_dst_layer_memory.get_data_handle(),
             weighted_annotations.data(),
             user_weights_alignments.data());

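The expression added in the second hunk is the inverse of the affine uint8 quantization applied to the encoder output: since annotations now aliases the raw uint8 data of enc_bidir_dst_layer_memory, each element has to be mapped back to a real value, ((float)q - dec_src_layer_shift) / dec_src_layer_scale, before it is accumulated into the float context vectors. A minimal sketch of that convention, assuming the usual round-and-clamp quantizer (the quantize/dequantize helpers and the sample scale/shift values below are illustrative, not taken from the example):

#include <cmath>
#include <cstdint>
#include <cstdio>

// Hypothetical helper (not from the example): stores a real value x as
// q = round(x * scale + shift), clamped to the uint8 range [0, 255].
static uint8_t quantize(float x, float scale, float shift) {
    float q = std::round(x * scale + shift);
    if (q < 0.f) q = 0.f;
    if (q > 255.f) q = 255.f;
    return (uint8_t)q;
}

// Inverse mapping, matching the expression introduced in the second hunk.
static float dequantize(uint8_t q, float scale, float shift) {
    return ((float)q - shift) / scale;
}

int main() {
    const float data_scale = 63.f, data_shift = 64.f; // sample values only
    float x = 0.5f;
    uint8_t q = quantize(x, data_scale, data_shift);
    printf("x = %f, q = %u, dequantized = %f\n", x, (unsigned)q,
            dequantize(q, data_scale, data_shift));
    return 0;
}

Dequantizing on the fly keeps the attention accumulation in float while letting annotations share storage with the quantized encoder output, which is why the first and third hunks only change the pointer type rather than copying the data.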