@@ -1067,9 +1067,127 @@ namespace Simd
1067
1067
1068
1068
// -------------------------------------------------------------------------------------------------
1069
1069
1070
// Depthwise convolution specialised for a kernel of 7, pad 3, stride 1, dilation 1 and a source
// width that is a multiple of 4 (see the assert below). Channels are processed in blocks of F;
// for every output row in [yBeg, yEnd) four adjacent output pixels (d0..d3) are computed at once.
//
// Template parameters:
//   T    - element type the source bytes are reinterpreted as.
//   term - forwarded to Save1 (defined outside this view).
//   type - activation type; selects which activation parameters are prepared and is forwarded to Save1.
// Parameters:
//   src8         - source data as raw bytes.
//   p            - convolution parameters (srcH, srcW, srcC, dstW, dstC are read here).
//   a            - algorithm parameters (bufH[1], bufH[2], maC, elem[1] are read here).
//   maC          - number of channels to process in this call.
//   yBeg, yEnd   - half-open range of output rows to compute.
//   weight       - float weights; 49 * F floats per channel block, kernel row ky starts at ky * 7 * F.
//   bias         - float bias, read F values per channel block.
//   params       - activation parameters: params[0] (and params[1] for some activations) are
//                  broadcast; for Prelu, F values per channel block are read instead.
//   dst          - destination as raw bytes.
// NOTE(review): F and DF are defined outside this view; presumably F is the number of floats in a
// 512-bit register and DF = 2 * F - confirm.
+ template <typename T, Term16bType term, SimdConvolutionActivationType type> static void DepthwiseConvolution_k7p3d1s1w4 (const uint8_t * src8,
1071
+ const ConvParam& p, const AlgParam& a, size_t maC, size_t yBeg, size_t yEnd, const float * weight, const float * bias, const float * params, uint8_t * dst)
1072
+ {
1073
+ assert (p.IsKernel (7 ) && p.IsPad (3 ) && p.IsStride (1 ) && p.IsDilation (1 ) && Aligned (p.srcW , 4 ));
1074
+ const T* src = (T*)src8;
1075
+ size_t srcH = p.srcH , srcW = p.srcW ;
1076
// Source addressing, in elements of T:
//   sM - mask applied to the source row index (sy & sM); when a.bufH[1] is 0 the unsigned
//        subtraction wraps to all ones, so the mask leaves the row index unchanged.
//        (Presumably a.bufH[1] is a power of two when non-zero, making this a ring buffer - TODO confirm.)
//   sD - step to the next channel block, sX - step between adjacent pixels, sY - step between rows.
+ size_t sM = (a.bufH [1 ] - 1 ), sD = a.bufH [1 ] ? a.bufH [1 ] * p.srcW * F : F, sX = a.bufH [1 ] ? F : p.srcC , sY = sX * p.srcW , dstC = maC;
1077
// Destination addressing, in bytes (dst is uint8_t*):
//   dX - step between adjacent output pixels, dY - step between output rows,
//   dy0 - row index that maps to the start of dst, dD - step to the next channel block.
// NOTE(review): the factor 2 presumably is the size of a 16-bit output element - confirm.
+ size_t dX = (a.bufH [2 ] ? a.maC * 2 : p.dstC * a.elem [1 ]), dY = p.dstW * dX, dy0 = a.bufH [2 ] ? yBeg : 0 , dD = a.bufH [2 ] ? F * 2 : F * a.elem [1 ];
1078
// wD - weights per channel block (7 * 7 taps of F floats); endW - first dx of the last group of 4 pixels.
+ size_t wD = 49 * F, dstCF = AlignLo (dstC, F), dstW = p.dstW , endW = dstW - 4 ;
1079
// When an output buffer is used (a.bufH[2] != 0) the channel loop runs up to dstC rounded up to DF.
+ size_t dstCe = a.bufH [2 ] ? AlignHi (dstC, DF) : dstC;
1080
+
1081
+ __m512 s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, w0, w1, w2, w3, w4, w5, w6, d0, d1, d2, d3;
1082
+
1083
+ __m512 _params[2 ], _bias[1 ];
1084
+ _params[0 ] = _mm512_set1_ps (params[0 ]);
1085
// Only these activations get a second broadcast parameter; otherwise _params[1] is left unset.
+ if (type == SimdConvolutionActivationRestrictRange ||
1086
+ type == SimdConvolutionActivationHswish ||
1087
+ type == SimdConvolutionActivationHardSigmoid)
1088
+ _params[1 ] = _mm512_set1_ps (params[1 ]);
1089
// Loop over channel blocks of F.
+ for (size_t dc = 0 ; dc < dstCe; dc += F)
1090
+ {
1091
// Bias (and Prelu slopes) are loaded unmasked: F floats starting at channel dc.
+ _bias[0 ] = _mm512_loadu_ps (bias + dc);
1092
+ if (type == ::SimdConvolutionActivationPrelu)
1093
+ _params[0 ] = _mm512_loadu_ps (params + dc);
1094
// tailS masks source and weight loads to the remaining channels; tailC is the mask handed to Save1.
// For the block starting at dstCF with an output buffer, tailC is built from the padded count dstCe - dstCF.
// NOTE(review): TailMask16/TailMask32 are defined elsewhere; dstC - dc can wrap when dc > dstC -
// presumably the helpers treat that as an empty mask - confirm.
+ __mmask16 tailS = TailMask16 (dstC - dc);
1095
+ __mmask32 tailC = (dc == dstCF && a.bufH [2 ]) ? TailMask32 (dstCe - dstCF) : tailS;
1096
+ for (size_t dy = yBeg; dy < yEnd; ++dy)
1097
+ {
1098
// Four output pixels (columns dx .. dx + 3) per iteration.
+ for (size_t dx = 0 ; dx < dstW; dx += 4 )
1099
+ {
1100
+ d0 = _mm512_setzero_ps ();
1101
+ d1 = _mm512_setzero_ps ();
1102
+ d2 = _mm512_setzero_ps ();
1103
+ d3 = _mm512_setzero_ps ();
1104
+ for (size_t ky = 0 ; ky < 7 ; ++ky)
1105
+ {
1106
// Unsigned arithmetic: rows above the image wrap to huge values and fail the sy < srcH test below,
// so vertical padding rows contribute nothing.
+ size_t sy = dy + ky - 3 ;
1107
// ps points at source column dx - 3 of row sy. For dx == 0 the offset (dx - 3) wraps as an unsigned
// value; taps 0..2 are not loaded in that case (see `if (dx)`), only ps + 3 * sX and beyond.
+ const T* ps = src + (sy & sM ) * sY + (dx - 3 ) * sX ;
1108
// pw points at the 7 taps of kernel row ky for this channel block (F floats per tap).
+ const float * pw = weight + ky * 7 * F;
1109
+ if (sy < srcH)
1110
+ {
1111
// Source column j (relative to dx - 3) contributes to accumulator d_i with tap w(j - i).
+ w0 = _mm512_maskz_loadu_ps (tailS, pw + 0 * F);
1112
+ w1 = _mm512_maskz_loadu_ps (tailS, pw + 1 * F);
1113
+ w2 = _mm512_maskz_loadu_ps (tailS, pw + 2 * F);
1114
// Columns dx - 3 .. dx - 1 lie in the left padding for the first group, so they are skipped when dx == 0.
+ if (dx)
1115
+ {
1116
+ s0 = LoadSrc (ps + 0 * sX , tailS);
1117
+ d0 = _mm512_fmadd_ps (s0, w0, d0);
1118
+
1119
+ s1 = LoadSrc (ps + 1 * sX , tailS);
1120
+ d0 = _mm512_fmadd_ps (s1, w1, d0);
1121
+ d1 = _mm512_fmadd_ps (s1, w0, d1);
1122
+
1123
+ s2 = LoadSrc (ps + 2 * sX , tailS);
1124
+ d0 = _mm512_fmadd_ps (s2, w2, d0);
1125
+ d1 = _mm512_fmadd_ps (s2, w1, d1);
1126
+ d2 = _mm512_fmadd_ps (s2, w0, d2);
1127
+ }
1128
// Columns dx .. dx + 3 are always loaded; each feeds all four accumulators.
+ s3 = LoadSrc (ps + 3 * sX , tailS);
1129
+ w3 = _mm512_maskz_loadu_ps (tailS, pw + 3 * F);
1130
+ d0 = _mm512_fmadd_ps (s3, w3, d0);
1131
+ d1 = _mm512_fmadd_ps (s3, w2, d1);
1132
+ d2 = _mm512_fmadd_ps (s3, w1, d2);
1133
+ d3 = _mm512_fmadd_ps (s3, w0, d3);
1134
+
1135
+ s4 = LoadSrc (ps + 4 * sX , tailS);
1136
+ w4 = _mm512_maskz_loadu_ps (tailS, pw + 4 * F);
1137
+ d0 = _mm512_fmadd_ps (s4, w4, d0);
1138
+ d1 = _mm512_fmadd_ps (s4, w3, d1);
1139
+ d2 = _mm512_fmadd_ps (s4, w2, d2);
1140
+ d3 = _mm512_fmadd_ps (s4, w1, d3);
1141
+
1142
+ s5 = LoadSrc (ps + 5 * sX , tailS);
1143
+ w5 = _mm512_maskz_loadu_ps (tailS, pw + 5 * F);
1144
+ d0 = _mm512_fmadd_ps (s5, w5, d0);
1145
+ d1 = _mm512_fmadd_ps (s5, w4, d1);
1146
+ d2 = _mm512_fmadd_ps (s5, w3, d2);
1147
+ d3 = _mm512_fmadd_ps (s5, w2, d3);
1148
+
1149
+ s6 = LoadSrc (ps + 6 * sX , tailS);
1150
+ w6 = _mm512_maskz_loadu_ps (tailS, pw + 6 * F);
1151
+ d0 = _mm512_fmadd_ps (s6, w6, d0);
1152
+ d1 = _mm512_fmadd_ps (s6, w5, d1);
1153
+ d2 = _mm512_fmadd_ps (s6, w4, d2);
1154
+ d3 = _mm512_fmadd_ps (s6, w3, d3);
// Columns dx + 4 .. dx + 6 lie in the right padding for the last group, so they are skipped when dx == endW.
+ if (dx < endW)
1156
+ {
1157
+ s7 = LoadSrc (ps + 7 * sX , tailS);
1158
+ d1 = _mm512_fmadd_ps (s7, w6, d1);
1159
+ d2 = _mm512_fmadd_ps (s7, w5, d2);
1160
+ d3 = _mm512_fmadd_ps (s7, w4, d3);
1161
+
1162
+ s8 = LoadSrc (ps + 8 * sX , tailS);
1163
+ d2 = _mm512_fmadd_ps (s8, w6, d2);
1164
+ d3 = _mm512_fmadd_ps (s8, w5, d3);
1165
+
1166
+ s9 = LoadSrc (ps + 9 * sX , tailS);
1167
+ d3 = _mm512_fmadd_ps (s9, w6, d3);
1168
+ }
1169
+ }
1170
+ }
1171
// Store the four accumulated pixels; rows are addressed relative to dy0.
// NOTE(review): Save1 is defined elsewhere; presumably it applies bias and activation and stores under tailC - confirm.
+ uint8_t * pd = dst + (dy - dy0) * dY + dx * dX;
1172
+ Save1<term, type>(pd + 0 * dX, dD, d0, _bias, _params, tailC);
1173
+ Save1<term, type>(pd + 1 * dX, dD, d1, _bias, _params, tailC);
1174
+ Save1<term, type>(pd + 2 * dX, dD, d2, _bias, _params, tailC);
1175
+ Save1<term, type>(pd + 3 * dX, dD, d3, _bias, _params, tailC);
1176
+ }
1177
+ }
1178
// Advance source, destination and weights to the next channel block.
+ src += sD ;
1179
+ dst += dD;
1180
+ weight += wD;
1181
+ }
1182
+ }
1183
+
1184
+ // -------------------------------------------------------------------------------------------------
1185
+
1070
1186
template <typename T, Term16bType term, SimdConvolutionActivationType type, bool nofma> static void SetDepthwise (const ConvParam& p, DepthwisePtr& depthwise)
1071
1187
{
1072
- if (IsKernel (p, 3 ) && IsDilation (p, 1 ) && Aligned (p.dstC , F))
1188
+ if (p.IsKernel (7 ) && p.IsPad (3 ) && p.IsStride (1 ) && p.IsDilation (1 ) && Aligned (p.srcW , 4 ))
1189
+ depthwise = DepthwiseConvolution_k7p3d1s1w4<T, term, type>;
1190
+ else if (IsKernel (p, 3 ) && IsDilation (p, 1 ) && Aligned (p.dstC , F))
1073
1191
depthwise = DepthwiseConvolution3x3<T, term, type, nofma>;
1074
1192
else if (p.padX + p.padW > 2 && p.srcC >= 128 )
1075
1193
depthwise = DepthwiseConvolutionLargePad<T, term, type, nofma>;
0 commit comments