Skip to content

Commit

Permalink
+add AVX-512BW kernel Convolution32fNhwcDepthwise_k7p3d1s1w4 for clas…
Browse files Browse the repository at this point in the history
…s SynetConvolution32fNhwcDepthwise.
  • Loading branch information
ermig1979 committed Oct 10, 2024
1 parent c92d5a8 commit 688403b
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/2024.html
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ <h4>Algorithms</h4>
<h5>New features</h5>
<ul>
<li>Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of class SynetConvolution16bNhwcDepthwise.</li>
<li>AVX-512BW kernel Convolution32fNhwcDepthwise_k7p3d1s1w4 for class SynetConvolution32fNhwcDepthwise.</li>
</ul>
<h5>Im
<h5>Improving</h5>
Expand Down
105 changes: 104 additions & 1 deletion src/Simd/SimdAvx512bwSynetConvolution32fNhwcDepthwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -885,9 +885,112 @@ namespace Simd

//-------------------------------------------------------------------------------------------------

template<::SimdConvolutionActivationType type> void Convolution32fNhwcDepthwise_k7p3d1s1w4(const float* src, const ConvParam& p, const float* weight, const float* bias, const float* params, float* dst)
{
assert(p.IsKernel(7) && p.IsPad(3) && p.IsStride(1) && p.IsDilation(1) && Aligned(p.srcW, 4));

size_t dstC = p.dstC, dstCF = AlignLo(p.dstC, F), dstW = p.dstW, srcH = p.srcH, end = dstW - 4;
__m512 s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, w0, w1, w2, w3, w4, w5, w6, d0, d1, d2, d3, _params[2];
_params[0] = _mm512_set1_ps(params[0]);
if (type == SimdConvolutionActivationRestrictRange ||
type == SimdConvolutionActivationHswish ||
type == SimdConvolutionActivationHardSigmoid)
_params[1] = _mm512_set1_ps(params[1]);
for (size_t dy = 0; dy < p.dstH; ++dy)
{
for (size_t dx = 0; dx < dstW; dx += 4)
{
for (size_t dc = 0; dc < dstC; dc += F)
{
__mmask16 tail = dc < dstCF ? __mmask16(-1) : TailMask16(dstC - dc);
if (type == SimdConvolutionActivationPrelu)
_params[0] = _mm512_maskz_loadu_ps(tail, params + dc);
d0 = bias ? _mm512_maskz_loadu_ps(tail, bias + dc) : _mm512_setzero_ps();
d1 = d0; d2 = d0; d3 = d0;
for (size_t ky = 0; ky < 7; ++ky)
{
size_t sy = dy + ky - 3;
const float* ps = src + (sy * dstW + dx - 3) * dstC + dc;
const float* pw = weight + ky * 7 * dstC + dc;
if (sy < srcH)
{
w0 = _mm512_maskz_loadu_ps(tail, pw + 0 * dstC);
w1 = _mm512_maskz_loadu_ps(tail, pw + 1 * dstC);
w2 = _mm512_maskz_loadu_ps(tail, pw + 2 * dstC);
if (dx)
{
s0 = _mm512_maskz_loadu_ps(tail, ps + 0 * dstC);
d0 = _mm512_fmadd_ps(s0, w0, d0);

s1 = _mm512_maskz_loadu_ps(tail, ps + 1 * dstC);
d0 = _mm512_fmadd_ps(s1, w1, d0);
d1 = _mm512_fmadd_ps(s1, w0, d1);

s2 = _mm512_maskz_loadu_ps(tail, ps + 2 * dstC);
d0 = _mm512_fmadd_ps(s2, w2, d0);
d1 = _mm512_fmadd_ps(s2, w1, d1);
d2 = _mm512_fmadd_ps(s2, w0, d2);
}
s3 = _mm512_maskz_loadu_ps(tail, ps + 3 * dstC);
w3 = _mm512_maskz_loadu_ps(tail, pw + 3 * dstC);
d0 = _mm512_fmadd_ps(s3, w3, d0);
d1 = _mm512_fmadd_ps(s3, w2, d1);
d2 = _mm512_fmadd_ps(s3, w1, d2);
d3 = _mm512_fmadd_ps(s3, w0, d3);

s4 = _mm512_maskz_loadu_ps(tail, ps + 4 * dstC);
w4 = _mm512_maskz_loadu_ps(tail, pw + 4 * dstC);
d0 = _mm512_fmadd_ps(s4, w4, d0);
d1 = _mm512_fmadd_ps(s4, w3, d1);
d2 = _mm512_fmadd_ps(s4, w2, d2);
d3 = _mm512_fmadd_ps(s4, w1, d3);

s5 = _mm512_maskz_loadu_ps(tail, ps + 5 * dstC);
w5 = _mm512_maskz_loadu_ps(tail, pw + 5 * dstC);
d0 = _mm512_fmadd_ps(s5, w5, d0);
d1 = _mm512_fmadd_ps(s5, w4, d1);
d2 = _mm512_fmadd_ps(s5, w3, d2);
d3 = _mm512_fmadd_ps(s5, w2, d3);

s6 = _mm512_maskz_loadu_ps(tail, ps + 6 * dstC);
w6 = _mm512_maskz_loadu_ps(tail, pw + 6 * dstC);
d0 = _mm512_fmadd_ps(s6, w6, d0);
d1 = _mm512_fmadd_ps(s6, w5, d1);
d2 = _mm512_fmadd_ps(s6, w4, d2);
d3 = _mm512_fmadd_ps(s6, w3, d3);
if (dx < end)
{
s7 = _mm512_maskz_loadu_ps(tail, ps + 7 * dstC);
d1 = _mm512_fmadd_ps(s7, w6, d1);
d2 = _mm512_fmadd_ps(s7, w5, d2);
d3 = _mm512_fmadd_ps(s7, w4, d3);

s8 = _mm512_maskz_loadu_ps(tail, ps + 8 * dstC);
d2 = _mm512_fmadd_ps(s8, w6, d2);
d3 = _mm512_fmadd_ps(s8, w5, d3);

s9 = _mm512_maskz_loadu_ps(tail, ps + 9 * dstC);
d3 = _mm512_fmadd_ps(s9, w6, d3);
}
}
}
float* pd = dst + (dy * dstW + dx) * dstC + dc;
_mm512_mask_storeu_ps(pd + 0 * dstC, tail, Activate<type>(d0, _params, 0));
_mm512_mask_storeu_ps(pd + 1 * dstC, tail, Activate<type>(d1, _params, 0));
_mm512_mask_storeu_ps(pd + 2 * dstC, tail, Activate<type>(d2, _params, 0));
_mm512_mask_storeu_ps(pd + 3 * dstC, tail, Activate<type>(d3, _params, 0));
}
}
}
}

//-------------------------------------------------------------------------------------------------

template <::SimdConvolutionActivationType type> SynetConvolution32fNhwcDepthwise::ConvolutionPtr Get(const ConvParam& p)
{
if (p.IsKernel(3) && p.IsDilation(1))
if (p.IsKernel(7) && p.IsPad(3) && p.IsStride(1) && p.IsDilation(1) && Aligned(p.srcW, 4))
return Convolution32fNhwcDepthwise_k7p3d1s1w4<type>;
else if (p.IsKernel(3) && p.IsDilation(1))
return Convolution32fNhwcDepthwise3x3<type>;
else
return Convolution32fNhwcDepthwiseDefault<type>;
Expand Down

0 comments on commit 688403b

Please sign in to comment.