Skip to content

Commit e281a4a

Browse files
committed
src: add a specialization for bf16 bwd ref impl
* Per-element conversion in parallel_nd slows down execution due to the lack of vectorization and the multiple function calls to the cast operator.
1 parent c743dda commit e281a4a

File tree

4 files changed

+74
-14
lines changed

4 files changed

+74
-14
lines changed

src/common/memory_tracking.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2018-2019 Intel Corporation
2+
* Copyright 2018-2020 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -173,6 +173,8 @@ enum {
173173
key_conv_wei_reduction,
174174
key_conv_wei_bia_reduction,
175175
key_conv_wei_bia_reduction_bctx,
176+
key_eltwise_diff_dst,
177+
key_eltwise_src,
176178
key_iprod_bias_bf16_convert_wsp,
177179
key_iprod_dst_bf16_convert_wsp,
178180
key_iprod_int_dat_in_acc_dt,

src/cpu/cpu_eltwise_list.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2019 Intel Corporation
2+
* Copyright 2019-2020 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -51,7 +51,6 @@ static const pd_create_f impl_list[] = {
5151
INSTANCE(ref_eltwise_fwd_t<s32>),
5252
INSTANCE(ref_eltwise_fwd_t<s8>),
5353
INSTANCE(ref_eltwise_fwd_t<u8>),
54-
INSTANCE(ref_eltwise_bwd_t<s32>),
5554
/* eol */
5655
nullptr,
5756
};

src/cpu/ref_eltwise.cpp

Lines changed: 55 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2016-2019 Intel Corporation
2+
* Copyright 2016-2020 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -229,8 +229,8 @@ void ref_eltwise_bwd_t<data_type>::execute_backward_generic(
229229
});
230230
}
231231

232-
template <impl::data_type_t data_type>
233-
void ref_eltwise_bwd_t<data_type>::execute_backward_dense(
232+
template <>
233+
void ref_eltwise_bwd_t<data_type::f32>::execute_backward_dense(
234234
const exec_ctx_t &ctx) const {
235235
auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
236236
auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
@@ -239,7 +239,7 @@ void ref_eltwise_bwd_t<data_type>::execute_backward_dense(
239239
const memory_desc_wrapper data_d(pd()->src_md());
240240
const memory_desc_wrapper diff_data_d(pd()->diff_src_md());
241241

242-
const ptrdiff_t nelems = static_cast<ptrdiff_t>(data_d.nelems(true));
242+
const auto nelems = data_d.nelems(true);
243243
const auto alg_kind = pd()->desc()->alg_kind;
244244
const float alpha = pd()->desc()->alpha;
245245
const float beta = pd()->desc()->beta;
@@ -248,11 +248,57 @@ void ref_eltwise_bwd_t<data_type>::execute_backward_dense(
248248
diff_dst += diff_data_d.offset0();
249249
diff_src += diff_data_d.offset0();
250250

251-
parallel_nd(nelems, [&](ptrdiff_t e) {
252-
const data_t dd = diff_dst[e];
253-
const data_t s = src[e];
254-
data_t &ds = diff_src[e];
255-
ds = compute_eltwise_scalar_bwd(alg_kind, dd, s, alpha, beta);
251+
parallel(0, [&](const int ithr, const int nthr) {
252+
dim_t start = 0, end = 0;
253+
balance211(nelems, nthr, ithr, start, end);
254+
if (start == end) return;
255+
256+
for (dim_t i = start; i < end; i++) {
257+
diff_src[i] = compute_eltwise_scalar_bwd(
258+
alg_kind, diff_dst[i], src[i], alpha, beta);
259+
}
260+
});
261+
}
262+
263+
template <>
264+
void ref_eltwise_bwd_t<data_type::bf16>::execute_backward_dense(
265+
const exec_ctx_t &ctx) const {
266+
using namespace memory_tracking::names;
267+
268+
auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
269+
auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
270+
auto diff_src = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_SRC);
271+
272+
auto scratchpad = ctx.get_scratchpad_grantor();
273+
auto s_f = scratchpad.template get<float>(key_eltwise_src);
274+
auto dd_f = scratchpad.template get<float>(key_eltwise_diff_dst);
275+
276+
const memory_desc_wrapper data_d(pd()->src_md());
277+
const memory_desc_wrapper diff_data_d(pd()->diff_src_md());
278+
279+
const auto nelems = data_d.nelems(true);
280+
const auto alg_kind = pd()->desc()->alg_kind;
281+
const float alpha = pd()->desc()->alpha;
282+
const float beta = pd()->desc()->beta;
283+
284+
src += data_d.offset0();
285+
diff_dst += diff_data_d.offset0();
286+
diff_src += diff_data_d.offset0();
287+
288+
parallel(0, [&](const int ithr, const int nthr) {
289+
dim_t start = 0, end = 0;
290+
balance211(nelems, nthr, ithr, start, end);
291+
if (start == end) return;
292+
293+
cvt_bfloat16_to_float(s_f + start, src + start, end - start);
294+
cvt_bfloat16_to_float(dd_f + start, diff_dst + start, end - start);
295+
296+
for (dim_t i = start; i < end; i++) {
297+
dd_f[i] = compute_eltwise_scalar_bwd(
298+
alg_kind, dd_f[i], s_f[i], alpha, beta);
299+
}
300+
301+
cvt_float_to_bfloat16(diff_src + start, dd_f + start, end - start);
256302
});
257303
}
258304

@@ -264,7 +310,6 @@ template struct ref_eltwise_fwd_t<data_type::u8>;
264310

265311
template struct ref_eltwise_bwd_t<data_type::f32>;
266312
template struct ref_eltwise_bwd_t<data_type::bf16>;
267-
template struct ref_eltwise_bwd_t<data_type::s32>;
268313

269314
} // namespace cpu
270315
} // namespace impl

src/cpu/ref_eltwise.hpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2016-2019 Intel Corporation
2+
* Copyright 2016-2020 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -136,10 +136,24 @@ struct ref_eltwise_bwd_t : public primitive_impl_t {
136136
if (use_generic && !one_of(diff_dst_d.ndims(), 4, 5))
137137
return status::unimplemented;
138138

139+
if (data_type == data_type::bf16) init_scratchpad();
140+
139141
return status::success;
140142
}
141143

142144
bool use_dense_;
145+
146+
private:
147+
void init_scratchpad() {
148+
const memory_desc_wrapper data_d(src_md());
149+
const memory_desc_wrapper diff_data_d(diff_dst_md());
150+
using namespace memory_tracking::names;
151+
auto scratchpad = scratchpad_registry().registrar();
152+
const auto diff_dst_size = diff_data_d.nelems(true) * sizeof(float);
153+
scratchpad.book(
154+
key_eltwise_src, data_d.nelems(true) * sizeof(float));
155+
scratchpad.book(key_eltwise_diff_dst, diff_dst_size);
156+
}
143157
};
144158

145159
ref_eltwise_bwd_t(const pd_t *apd) : primitive_impl_t(apd) {}

0 commit comments

Comments
 (0)