11/* ******************************************************************************
2- * Copyright 2016-2019 Intel Corporation
2+ * Copyright 2016-2020 Intel Corporation
33*
44* Licensed under the Apache License, Version 2.0 (the "License");
55* you may not use this file except in compliance with the License.
@@ -229,8 +229,8 @@ void ref_eltwise_bwd_t<data_type>::execute_backward_generic(
229229 });
230230}
231231
232- template <impl:: data_type_t data_type >
233- void ref_eltwise_bwd_t <data_type>::execute_backward_dense(
232+ template <>
233+ void ref_eltwise_bwd_t <data_type:: f32 >::execute_backward_dense(
234234 const exec_ctx_t &ctx) const {
235235 auto src = CTX_IN_MEM (const data_t *, DNNL_ARG_SRC);
236236 auto diff_dst = CTX_IN_MEM (const data_t *, DNNL_ARG_DIFF_DST);
@@ -239,7 +239,7 @@ void ref_eltwise_bwd_t<data_type>::execute_backward_dense(
239239 const memory_desc_wrapper data_d (pd ()->src_md ());
240240 const memory_desc_wrapper diff_data_d (pd ()->diff_src_md ());
241241
242- const ptrdiff_t nelems = static_cast < ptrdiff_t >( data_d.nelems (true ) );
242+ const auto nelems = data_d.nelems (true );
243243 const auto alg_kind = pd ()->desc ()->alg_kind ;
244244 const float alpha = pd ()->desc ()->alpha ;
245245 const float beta = pd ()->desc ()->beta ;
@@ -248,11 +248,57 @@ void ref_eltwise_bwd_t<data_type>::execute_backward_dense(
248248 diff_dst += diff_data_d.offset0 ();
249249 diff_src += diff_data_d.offset0 ();
250250
251- parallel_nd (nelems, [&](ptrdiff_t e) {
252- const data_t dd = diff_dst[e];
253- const data_t s = src[e];
254- data_t &ds = diff_src[e];
255- ds = compute_eltwise_scalar_bwd (alg_kind, dd, s, alpha, beta);
251+ parallel (0 , [&](const int ithr, const int nthr) {
252+ dim_t start = 0 , end = 0 ;
253+ balance211 (nelems, nthr, ithr, start, end);
254+ if (start == end) return ;
255+
256+ for (dim_t i = start; i < end; i++) {
257+ diff_src[i] = compute_eltwise_scalar_bwd (
258+ alg_kind, diff_dst[i], src[i], alpha, beta);
259+ }
260+ });
261+ }
262+ // bf16 specialization of the dense backward pass: converts src and diff_dst to float in scratchpad buffers, applies compute_eltwise_scalar_bwd element by element, and converts the result back into diff_src.
263+ template <>
264+ void ref_eltwise_bwd_t <data_type::bf16 >::execute_backward_dense(
265+ const exec_ctx_t &ctx) const {
266+ using namespace memory_tracking ::names;
267+ // Input (src, diff_dst) and output (diff_src) buffers obtained from the execution context.
268+ auto src = CTX_IN_MEM (const data_t *, DNNL_ARG_SRC);
269+ auto diff_dst = CTX_IN_MEM (const data_t *, DNNL_ARG_DIFF_DST);
270+ auto diff_src = CTX_OUT_MEM (data_t *, DNNL_ARG_DIFF_SRC);
271+ // Float work buffers taken from the scratchpad under key_eltwise_src / key_eltwise_diff_dst.
272+ auto scratchpad = ctx.get_scratchpad_grantor ();
273+ auto s_f = scratchpad.template get <float >(key_eltwise_src);
274+ auto dd_f = scratchpad.template get <float >(key_eltwise_diff_dst);
275+ // Descriptor wrappers for src and diff_src; used below for nelems and offset0.
276+ const memory_desc_wrapper data_d (pd ()->src_md ());
277+ const memory_desc_wrapper diff_data_d (pd ()->diff_src_md ());
278+ // NOTE(review): nelems(true) presumably counts padded elements as well -- confirm against memory_desc_wrapper.
279+ const auto nelems = data_d.nelems (true );
280+ const auto alg_kind = pd ()->desc ()->alg_kind ;
281+ const float alpha = pd ()->desc ()->alpha ;
282+ const float beta = pd ()->desc ()->beta ;
283+ // Advance the data pointers by the descriptors' offset0; diff_dst uses the diff_src descriptor's offset.
284+ src += data_d.offset0 ();
285+ diff_dst += diff_data_d.offset0 ();
286+ diff_src += diff_data_d.offset0 ();
287+ // Each thread processes only the [start, end) index range that balance211 fills in for it.
288+ parallel (0 , [&](const int ithr, const int nthr) {
289+ dim_t start = 0 , end = 0 ;
290+ balance211 (nelems, nthr, ithr, start, end);
291+ if (start == end) return ;
292+ // Convert this thread's slice of both inputs into the float work buffers.
293+ cvt_bfloat16_to_float (s_f + start, src + start, end - start);
294+ cvt_bfloat16_to_float (dd_f + start, diff_dst + start, end - start);
295+ // Compute in float; the result overwrites dd_f in place, so dd_f holds the diff_src values afterwards.
296+ for (dim_t i = start; i < end; i++) {
297+ dd_f[i] = compute_eltwise_scalar_bwd (
298+ alg_kind, dd_f[i], s_f[i], alpha, beta);
299+ }
300+ // Convert the float results for this slice back and store them in diff_src.
301+ cvt_float_to_bfloat16 (diff_src + start, dd_f + start, end - start);
256302 });
257303}
258304
@@ -264,7 +310,6 @@ template struct ref_eltwise_fwd_t<data_type::u8>;
264310
265311template struct ref_eltwise_bwd_t <data_type::f32 >;
266312template struct ref_eltwise_bwd_t <data_type::bf16 >;
267- template struct ref_eltwise_bwd_t <data_type::s32>;
268313
269314} // namespace cpu
270315} // namespace impl
0 commit comments