Skip to content

Commit

Permalink
+add AMX-BF16 kernel DepthwiseConvolution_k7p3d1s1w4 for class SynetM…
Browse files Browse the repository at this point in the history
…ergedConvolution16b.
  • Loading branch information
ermig1979 committed Oct 10, 2024
1 parent 688403b commit ee2ffcb
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/2024.html
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ <h5>New features</h5>
<ul>
<li>Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of class SynetConvolution16bNhwcDepthwise.</li>
<li>AVX-512BW kernel Convolution32fNhwcDepthwise_k7p3d1s1w4 for class SynetConvolution32fNhwcDepthwise.</li>
<li>AMX-BF16 kernel DepthwiseConvolution_k7p3d1s1w4 for class SynetMergedConvolution16b.</li>
</ul>
<h5>Improving</h5>
Expand Down
120 changes: 119 additions & 1 deletion src/Simd/SimdAmxBf16SynetMergedConvolution16bDepthwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1067,9 +1067,127 @@ namespace Simd

//-------------------------------------------------------------------------------------------------

// Depthwise convolution kernel specialized for kernel 7x7, padding 3, stride 1,
// dilation 1, processing 4 output columns per inner iteration ("w4").
// T is the source element type (LoadSrc converts it to float lanes); term/type
// select the output conversion and activation that Save1 applies on store.
// src8   - input image/intermediate buffer (row index is masked by sM, which
//          suggests a power-of-two ring buffer when a.bufH[1] != 0 — see below);
// p, a   - convolution parameters and algorithm layout parameters;
// maC    - number of channels processed by this call (dstC below);
// yBeg/yEnd - half-open range of output rows to compute;
// weight - 7*7*F floats per F-channel group; bias/params - per-channel data;
// dst    - output buffer (raw bytes; element size depends on term/a.elem).
template<typename T, Term16bType term, SimdConvolutionActivationType type> static void DepthwiseConvolution_k7p3d1s1w4(const uint8_t* src8,
    const ConvParam& p, const AlgParam& a, size_t maC, size_t yBeg, size_t yEnd, const float* weight, const float* bias, const float* params, uint8_t* dst)
{
    // Kernel is only valid for 7x7/pad3/stride1/dilation1 and srcW a multiple of 4
    // (the dx loop below advances by 4 with no column tail handling).
    assert(p.IsKernel(7) && p.IsPad(3) && p.IsStride(1) && p.IsDilation(1) && Aligned(p.srcW, 4));
    const T* src = (T*)src8;
    size_t srcH = p.srcH, srcW = p.srcW;
    // sM: row-wrap mask for the source buffer; (sy & sM) wraps the row index when
    // a.bufH[1] is a power of two (ring buffer), and a.bufH[1]==0 gives sM == SIZE_MAX
    // (no wrap, whole-image layout). sD/sX/sY are the per-channel-group, per-column
    // and per-row strides for the two layouts.
    size_t sM = (a.bufH[1] - 1), sD = a.bufH[1] ? a.bufH[1] * p.srcW * F : F, sX = a.bufH[1] ? F : p.srcC, sY = sX * p.srcW, dstC = maC;
    // Destination strides: dX per column, dY per row; dy0 rebases the row index when
    // writing into an intermediate buffer (a.bufH[2] != 0); dD is the per-channel-group
    // advance of the dst pointer.
    size_t dX = (a.bufH[2] ? a.maC * 2 : p.dstC * a.elem[1]), dY = p.dstW * dX, dy0 = a.bufH[2] ? yBeg : 0, dD = a.bufH[2] ? F * 2 : F * a.elem[1];
    // wD: weight stride per F-channel group (7*7 taps); endW: last dx where the
    // right-edge taps (s7..s9) are still inside the row.
    size_t wD = 49 * F, dstCF = AlignLo(dstC, F), dstW = p.dstW, endW = dstW - 4;
    // Channels are rounded up to DF when writing to the intermediate buffer.
    size_t dstCe = a.bufH[2] ? AlignHi(dstC, DF) : dstC;

    // s0..s9: up to 10 consecutive source columns; w0..w6: one kernel row of taps;
    // d0..d3: accumulators for the 4 output columns being produced.
    __m512 s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, w0, w1, w2, w3, w4, w5, w6, d0, d1, d2, d3;

    __m512 _params[2], _bias[1];
    _params[0] = _mm512_set1_ps(params[0]);
    // These activations consume a second scalar parameter.
    if (type == SimdConvolutionActivationRestrictRange ||
        type == SimdConvolutionActivationHswish ||
        type == SimdConvolutionActivationHardSigmoid)
        _params[1] = _mm512_set1_ps(params[1]);
    // Process channels in groups of F (one zmm of floats per group).
    for (size_t dc = 0; dc < dstCe; dc += F)
    {
        _bias[0] = _mm512_loadu_ps(bias + dc);
        if (type == ::SimdConvolutionActivationPrelu)
            _params[0] = _mm512_loadu_ps(params + dc);
        // tailS masks loads of the last partial channel group; tailC is the store
        // mask — widened to 32 lanes for the DF-aligned intermediate-buffer case.
        __mmask16 tailS = TailMask16(dstC - dc);
        __mmask32 tailC = (dc == dstCF && a.bufH[2]) ? TailMask32(dstCe - dstCF) : tailS;
        for (size_t dy = yBeg; dy < yEnd; ++dy)
        {
            // 4 output columns per iteration; srcW (== dstW for stride 1) is
            // asserted to be a multiple of 4 above.
            for (size_t dx = 0; dx < dstW; dx += 4)
            {
                d0 = _mm512_setzero_ps();
                d1 = _mm512_setzero_ps();
                d2 = _mm512_setzero_ps();
                d3 = _mm512_setzero_ps();
                for (size_t ky = 0; ky < 7; ++ky)
                {
                    // Source row for this tap; size_t underflow for dy + ky < 3
                    // wraps to a huge value, so the single 'sy < srcH' test below
                    // rejects both top and bottom padding rows.
                    size_t sy = dy + ky - 3;
                    // (dx - 3) may also wrap at dx == 0, but columns before dx
                    // are only read under the 'if (dx)' guard; ps + 3*sX cancels
                    // the -3 exactly, so s3 is always a valid column.
                    const T* ps = src + (sy & sM) * sY + (dx - 3) * sX;
                    const float* pw = weight + ky * 7 * F;
                    if (sy < srcH)
                    {
                        w0 = _mm512_maskz_loadu_ps(tailS, pw + 0 * F);
                        w1 = _mm512_maskz_loadu_ps(tailS, pw + 1 * F);
                        w2 = _mm512_maskz_loadu_ps(tailS, pw + 2 * F);
                        // Left edge: at dx == 0 columns dx-3..dx-1 fall in the
                        // left padding, so skip the first three source loads.
                        if (dx)
                        {
                            s0 = LoadSrc(ps + 0 * sX, tailS);
                            d0 = _mm512_fmadd_ps(s0, w0, d0);

                            s1 = LoadSrc(ps + 1 * sX, tailS);
                            d0 = _mm512_fmadd_ps(s1, w1, d0);
                            d1 = _mm512_fmadd_ps(s1, w0, d1);

                            s2 = LoadSrc(ps + 2 * sX, tailS);
                            d0 = _mm512_fmadd_ps(s2, w2, d0);
                            d1 = _mm512_fmadd_ps(s2, w1, d1);
                            d2 = _mm512_fmadd_ps(s2, w0, d2);
                        }
                        // Columns dx..dx+3 are always in range; each loaded column
                        // feeds every accumulator whose 7-tap window covers it.
                        s3 = LoadSrc(ps + 3 * sX, tailS);
                        w3 = _mm512_maskz_loadu_ps(tailS, pw + 3 * F);
                        d0 = _mm512_fmadd_ps(s3, w3, d0);
                        d1 = _mm512_fmadd_ps(s3, w2, d1);
                        d2 = _mm512_fmadd_ps(s3, w1, d2);
                        d3 = _mm512_fmadd_ps(s3, w0, d3);

                        s4 = LoadSrc(ps + 4 * sX, tailS);
                        w4 = _mm512_maskz_loadu_ps(tailS, pw + 4 * F);
                        d0 = _mm512_fmadd_ps(s4, w4, d0);
                        d1 = _mm512_fmadd_ps(s4, w3, d1);
                        d2 = _mm512_fmadd_ps(s4, w2, d2);
                        d3 = _mm512_fmadd_ps(s4, w1, d3);

                        s5 = LoadSrc(ps + 5 * sX, tailS);
                        w5 = _mm512_maskz_loadu_ps(tailS, pw + 5 * F);
                        d0 = _mm512_fmadd_ps(s5, w5, d0);
                        d1 = _mm512_fmadd_ps(s5, w4, d1);
                        d2 = _mm512_fmadd_ps(s5, w3, d2);
                        d3 = _mm512_fmadd_ps(s5, w2, d3);

                        s6 = LoadSrc(ps + 6 * sX, tailS);
                        w6 = _mm512_maskz_loadu_ps(tailS, pw + 6 * F);
                        d0 = _mm512_fmadd_ps(s6, w6, d0);
                        d1 = _mm512_fmadd_ps(s6, w5, d1);
                        d2 = _mm512_fmadd_ps(s6, w4, d2);
                        d3 = _mm512_fmadd_ps(s6, w3, d3);
                        // Right edge: for the last 4-column group, columns
                        // dx+4..dx+6 fall in the right padding — skip them.
                        if (dx < endW)
                        {
                            s7 = LoadSrc(ps + 7 * sX, tailS);
                            d1 = _mm512_fmadd_ps(s7, w6, d1);
                            d2 = _mm512_fmadd_ps(s7, w5, d2);
                            d3 = _mm512_fmadd_ps(s7, w4, d3);

                            s8 = LoadSrc(ps + 8 * sX, tailS);
                            d2 = _mm512_fmadd_ps(s8, w6, d2);
                            d3 = _mm512_fmadd_ps(s8, w5, d3);

                            s9 = LoadSrc(ps + 9 * sX, tailS);
                            d3 = _mm512_fmadd_ps(s9, w6, d3);
                        }
                    }
                }
                // Store the 4 output columns: Save1 adds bias, applies the
                // activation and converts to the term's output format.
                uint8_t* pd = dst + (dy - dy0) * dY + dx * dX;
                Save1<term, type>(pd + 0 * dX, dD, d0, _bias, _params, tailC);
                Save1<term, type>(pd + 1 * dX, dD, d1, _bias, _params, tailC);
                Save1<term, type>(pd + 2 * dX, dD, d2, _bias, _params, tailC);
                Save1<term, type>(pd + 3 * dX, dD, d3, _bias, _params, tailC);
            }
        }
        // Advance to the next F-channel group.
        src += sD;
        dst += dD;
        weight += wD;
    }
}

//-------------------------------------------------------------------------------------------------

template<typename T, Term16bType term, SimdConvolutionActivationType type, bool nofma> static void SetDepthwise(const ConvParam& p, DepthwisePtr& depthwise)
{
if (IsKernel(p, 3) && IsDilation(p, 1) && Aligned(p.dstC, F))
if (p.IsKernel(7) && p.IsPad(3) && p.IsStride(1) && p.IsDilation(1) && Aligned(p.srcW, 4))
depthwise = DepthwiseConvolution_k7p3d1s1w4<T, term, type>;
else if (IsKernel(p, 3) && IsDilation(p, 1) && Aligned(p.dstC, F))
depthwise = DepthwiseConvolution3x3<T, term, type, nofma>;
else if(p.padX + p.padW > 2 && p.srcC >= 128)
depthwise = DepthwiseConvolutionLargePad<T, term, type, nofma>;
Expand Down
2 changes: 2 additions & 0 deletions src/Test/TestSynetMergedConvolution16b.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,8 @@ namespace Test
result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 116, 15, 5), Cnv(a1, 3, 2), Cnv(a2, 1, 1, 116), f32, f32, c), f1, f2);
#endif
#if 1
result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 256, 16, 16), Cnv(a1, 7, 1), Cnv(a2, 1, 1, 256), f32, f32, c), f1, f2);
result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 256, 16, 16), Cnv(a1, 7, 1), Cnv(a2, 1, 1, 256), b16, b16, c), f1, f2);
result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 304, 17, 15), Cnv(a1, 7, 1), Cnv(a2, 1, 1, 1216), b16, b16, c), f1, f2);
result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 76, 64, 64), Cnv(a1, 7, 1), Cnv(a2, 1, 1, 304), f32, b16, c), f1, f2);
result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 152, 32, 32), Cnv(a1, 7, 1), Cnv(a2, 1, 1, 608), f32, b16, c), f1, f2);
Expand Down

0 comments on commit ee2ffcb

Please sign in to comment.