Skip to content

Commit ee2ffcb

Browse files
committed
+add AMX-BF16 kernel DepthwiseConvolution_k7p3d1s1w4 for class SynetMergedConvolution16b.
1 parent 688403b commit ee2ffcb

File tree

3 files changed

+122
-1
lines changed

3 files changed

+122
-1
lines changed

docs/2024.html

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ <h5>New features</h5>
4141
<ul>
4242
<li>Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of class SynetConvolution16bNhwcDepthwise.</li>
4343
<li>AVX-512BW kernel Convolution32fNhwcDepthwise_k7p3d1s1w4 for class SynetConvolution32fNhwcDepthwise.</li>
44+
<li>AMX-BF16 kernel DepthwiseConvolution_k7p3d1s1w4 for class SynetMergedConvolution16b.</li>
4445
</ul>
4546
<h5>Im
4647
<h5>Improving</h5>

src/Simd/SimdAmxBf16SynetMergedConvolution16bDepthwise.cpp

Lines changed: 119 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1067,9 +1067,127 @@ namespace Simd
10671067

10681068
//-------------------------------------------------------------------------------------------------
10691069

1070+
template<typename T, Term16bType term, SimdConvolutionActivationType type> static void DepthwiseConvolution_k7p3d1s1w4(const uint8_t* src8,
1071+
const ConvParam& p, const AlgParam& a, size_t maC, size_t yBeg, size_t yEnd, const float* weight, const float* bias, const float* params, uint8_t* dst)
1072+
{
1073+
assert(p.IsKernel(7) && p.IsPad(3) && p.IsStride(1) && p.IsDilation(1) && Aligned(p.srcW, 4));
1074+
const T* src = (T*)src8;
1075+
size_t srcH = p.srcH, srcW = p.srcW;
1076+
size_t sM = (a.bufH[1] - 1), sD = a.bufH[1] ? a.bufH[1] * p.srcW * F : F, sX = a.bufH[1] ? F : p.srcC, sY = sX * p.srcW, dstC = maC;
1077+
size_t dX = (a.bufH[2] ? a.maC * 2 : p.dstC * a.elem[1]), dY = p.dstW * dX, dy0 = a.bufH[2] ? yBeg : 0, dD = a.bufH[2] ? F * 2 : F * a.elem[1];
1078+
size_t wD = 49 * F, dstCF = AlignLo(dstC, F), dstW = p.dstW, endW = dstW - 4;
1079+
size_t dstCe = a.bufH[2] ? AlignHi(dstC, DF) : dstC;
1080+
1081+
__m512 s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, w0, w1, w2, w3, w4, w5, w6, d0, d1, d2, d3;
1082+
1083+
__m512 _params[2], _bias[1];
1084+
_params[0] = _mm512_set1_ps(params[0]);
1085+
if (type == SimdConvolutionActivationRestrictRange ||
1086+
type == SimdConvolutionActivationHswish ||
1087+
type == SimdConvolutionActivationHardSigmoid)
1088+
_params[1] = _mm512_set1_ps(params[1]);
1089+
for (size_t dc = 0; dc < dstCe; dc += F)
1090+
{
1091+
_bias[0] = _mm512_loadu_ps(bias + dc);
1092+
if (type == ::SimdConvolutionActivationPrelu)
1093+
_params[0] = _mm512_loadu_ps(params + dc);
1094+
__mmask16 tailS = TailMask16(dstC - dc);
1095+
__mmask32 tailC = (dc == dstCF && a.bufH[2]) ? TailMask32(dstCe - dstCF) : tailS;
1096+
for (size_t dy = yBeg; dy < yEnd; ++dy)
1097+
{
1098+
for (size_t dx = 0; dx < dstW; dx += 4)
1099+
{
1100+
d0 = _mm512_setzero_ps();
1101+
d1 = _mm512_setzero_ps();
1102+
d2 = _mm512_setzero_ps();
1103+
d3 = _mm512_setzero_ps();
1104+
for (size_t ky = 0; ky < 7; ++ky)
1105+
{
1106+
size_t sy = dy + ky - 3;
1107+
const T* ps = src + (sy & sM) * sY + (dx - 3) * sX;
1108+
const float* pw = weight + ky * 7 * F;
1109+
if (sy < srcH)
1110+
{
1111+
w0 = _mm512_maskz_loadu_ps(tailS, pw + 0 * F);
1112+
w1 = _mm512_maskz_loadu_ps(tailS, pw + 1 * F);
1113+
w2 = _mm512_maskz_loadu_ps(tailS, pw + 2 * F);
1114+
if (dx)
1115+
{
1116+
s0 = LoadSrc(ps + 0 * sX, tailS);
1117+
d0 = _mm512_fmadd_ps(s0, w0, d0);
1118+
1119+
s1 = LoadSrc(ps + 1 * sX, tailS);
1120+
d0 = _mm512_fmadd_ps(s1, w1, d0);
1121+
d1 = _mm512_fmadd_ps(s1, w0, d1);
1122+
1123+
s2 = LoadSrc(ps + 2 * sX, tailS);
1124+
d0 = _mm512_fmadd_ps(s2, w2, d0);
1125+
d1 = _mm512_fmadd_ps(s2, w1, d1);
1126+
d2 = _mm512_fmadd_ps(s2, w0, d2);
1127+
}
1128+
s3 = LoadSrc(ps + 3 * sX, tailS);
1129+
w3 = _mm512_maskz_loadu_ps(tailS, pw + 3 * F);
1130+
d0 = _mm512_fmadd_ps(s3, w3, d0);
1131+
d1 = _mm512_fmadd_ps(s3, w2, d1);
1132+
d2 = _mm512_fmadd_ps(s3, w1, d2);
1133+
d3 = _mm512_fmadd_ps(s3, w0, d3);
1134+
1135+
s4 = LoadSrc(ps + 4 * sX, tailS);
1136+
w4 = _mm512_maskz_loadu_ps(tailS, pw + 4 * F);
1137+
d0 = _mm512_fmadd_ps(s4, w4, d0);
1138+
d1 = _mm512_fmadd_ps(s4, w3, d1);
1139+
d2 = _mm512_fmadd_ps(s4, w2, d2);
1140+
d3 = _mm512_fmadd_ps(s4, w1, d3);
1141+
1142+
s5 = LoadSrc(ps + 5 * sX, tailS);
1143+
w5 = _mm512_maskz_loadu_ps(tailS, pw + 5 * F);
1144+
d0 = _mm512_fmadd_ps(s5, w5, d0);
1145+
d1 = _mm512_fmadd_ps(s5, w4, d1);
1146+
d2 = _mm512_fmadd_ps(s5, w3, d2);
1147+
d3 = _mm512_fmadd_ps(s5, w2, d3);
1148+
1149+
s6 = LoadSrc(ps + 6 * sX, tailS);
1150+
w6 = _mm512_maskz_loadu_ps(tailS, pw + 6 * F);
1151+
d0 = _mm512_fmadd_ps(s6, w6, d0);
1152+
d1 = _mm512_fmadd_ps(s6, w5, d1);
1153+
d2 = _mm512_fmadd_ps(s6, w4, d2);
1154+
d3 = _mm512_fmadd_ps(s6, w3, d3);
1155+
if (dx < endW)
1156+
{
1157+
s7 = LoadSrc(ps + 7 * sX, tailS);
1158+
d1 = _mm512_fmadd_ps(s7, w6, d1);
1159+
d2 = _mm512_fmadd_ps(s7, w5, d2);
1160+
d3 = _mm512_fmadd_ps(s7, w4, d3);
1161+
1162+
s8 = LoadSrc(ps + 8 * sX, tailS);
1163+
d2 = _mm512_fmadd_ps(s8, w6, d2);
1164+
d3 = _mm512_fmadd_ps(s8, w5, d3);
1165+
1166+
s9 = LoadSrc(ps + 9 * sX, tailS);
1167+
d3 = _mm512_fmadd_ps(s9, w6, d3);
1168+
}
1169+
}
1170+
}
1171+
uint8_t* pd = dst + (dy - dy0) * dY + dx * dX;
1172+
Save1<term, type>(pd + 0 * dX, dD, d0, _bias, _params, tailC);
1173+
Save1<term, type>(pd + 1 * dX, dD, d1, _bias, _params, tailC);
1174+
Save1<term, type>(pd + 2 * dX, dD, d2, _bias, _params, tailC);
1175+
Save1<term, type>(pd + 3 * dX, dD, d3, _bias, _params, tailC);
1176+
}
1177+
}
1178+
src += sD;
1179+
dst += dD;
1180+
weight += wD;
1181+
}
1182+
}
1183+
1184+
//-------------------------------------------------------------------------------------------------
1185+
10701186
template<typename T, Term16bType term, SimdConvolutionActivationType type, bool nofma> static void SetDepthwise(const ConvParam& p, DepthwisePtr& depthwise)
10711187
{
1072-
if (IsKernel(p, 3) && IsDilation(p, 1) && Aligned(p.dstC, F))
1188+
if (p.IsKernel(7) && p.IsPad(3) && p.IsStride(1) && p.IsDilation(1) && Aligned(p.srcW, 4))
1189+
depthwise = DepthwiseConvolution_k7p3d1s1w4<T, term, type>;
1190+
else if (IsKernel(p, 3) && IsDilation(p, 1) && Aligned(p.dstC, F))
10731191
depthwise = DepthwiseConvolution3x3<T, term, type, nofma>;
10741192
else if(p.padX + p.padW > 2 && p.srcC >= 128)
10751193
depthwise = DepthwiseConvolutionLargePad<T, term, type, nofma>;

src/Test/TestSynetMergedConvolution16b.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,8 @@ namespace Test
285285
result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 116, 15, 5), Cnv(a1, 3, 2), Cnv(a2, 1, 1, 116), f32, f32, c), f1, f2);
286286
#endif
287287
#if 1
288+
result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 256, 16, 16), Cnv(a1, 7, 1), Cnv(a2, 1, 1, 256), f32, f32, c), f1, f2);
289+
result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 256, 16, 16), Cnv(a1, 7, 1), Cnv(a2, 1, 1, 256), b16, b16, c), f1, f2);
288290
result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 304, 17, 15), Cnv(a1, 7, 1), Cnv(a2, 1, 1, 1216), b16, b16, c), f1, f2);
289291
result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 76, 64, 64), Cnv(a1, 7, 1), Cnv(a2, 1, 1, 304), f32, b16, c), f1, f2);
290292
result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 152, 32, 32), Cnv(a1, 7, 1), Cnv(a2, 1, 1, 608), f32, b16, c), f1, f2);

0 commit comments

Comments
 (0)