Skip to content

Commit ff3ffab

Browse files
author
Fomenko, Evarist M
committed
all: rnn: align computing (per-oc) mask with the library
The i-th bit in the mask corresponds to the i-th dimension. The dimensions are enumerated from the outermost one, aka C-array style.
1 parent 28f4c96 commit ff3ffab

File tree

3 files changed

+5
-3
lines changed

3 files changed

+5
-3
lines changed

examples/cpu_rnn_inference_int8.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,9 @@ void simple_net() {
237237
///
238238
const float data_shift = 64.;
239239
const float data_scale = 63.;
240-
const int weights_scale_mask = 3; // 11 for last two dimensions of ldigo
240+
const int weights_scale_mask = 0
241+
+ (1 << 3) // bit, indicating the unique scales for `g` dim in `ldigo`
242+
+ (1 << 4); // bit, indicating the unique scales for `o` dim in `ldigo`
241243
//[quantize]
242244
std::vector<float> weights_scales(lstm_n_gates * feature_size);
243245
// assign halves of vector with arbitrary values

src/cpu/rnn/rnn_reorders.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ struct rnn_weights_reorder_t : public primitive_impl_t {
125125
if (itag == format_tag::undef) return invalid_arguments;
126126

127127
const int mask = attr->rnn_weights_qparams_.mask_;
128-
if (!utils::one_of(mask, 0, 3)) return unimplemented;
128+
if (!utils::one_of(mask, 0, 24)) return unimplemented;
129129

130130
auto _pd = new pd_t(
131131
engine, attr, src_engine, src_md, dst_engine, dst_md);

tests/benchdnn/rnn/rnn.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ void create_dnnl_rnn_attr(const prb_t &p, dnnl_primitive_attr_t *dnnl_attr) {
5656

5757
if (p.scale_policy == policy_t::PER_OC) {
5858
DNN_SAFE_V(dnnl_primitive_attr_set_rnn_weights_qparams(
59-
*dnnl_attr, p.dic * p.n_gates(), 0x3, p.wei_oc_scales));
59+
*dnnl_attr, p.dic * p.n_gates(), 0x18, p.wei_oc_scales));
6060
} else if (p.scale_policy == policy_t::COMMON && p.wei_scale != 1.) {
6161
DNN_SAFE_V(dnnl_primitive_attr_set_rnn_weights_qparams(
6262
*dnnl_attr, 1, 0, &p.wei_scale));

0 commit comments

Comments
 (0)