@@ -43,106 +43,275 @@ ${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")}
const lowp ivec4 out_axis_map = unhash_axis_map(out_layout);
const lowp int out_packed_dim = unhash_packed_dim(out_layout);

- void main() {
-   const ivec3 lpos = ivec3(gl_GlobalInvocationID);
+ #define MAX_WORKGROUP_SIZE 64
+
+ // SHARED_MEMORY_FACTOR scales the shared memory allocation and should be either 1 or a power of 2.
+ //
+ // Increasing the factor allows more data to be stored in shared memory and increases thread utilization during the reduction.
+ // Why? Because when performing a reduction, the number of active threads halves in each iteration.
+ // Increasing the scaling factor raises thread occupancy and hence utilizes the GPU better.
+ // e.g.
+ // If the local thread size in the x dimension is 32 and SHARED_MEMORY_FACTOR is 1, 32 elements will be loaded into shared memory.
+ // The first iteration of the reduce will have 16 threads sum up 32 elements.
+ // The second iteration will have 8 threads sum up 16 elements from the previous iteration, and so on.
+ // So thread utilization starts at 50%.
+ //
+ // By contrast, if the local thread size in the x dimension is 32 and SHARED_MEMORY_FACTOR is 2, 64 elements will be loaded into shared memory.
+ // The first iteration of the reduce will have 32 threads sum up 64 elements.
+ // The second iteration will have 16 threads sum up 32 elements from the previous iteration, and so on.
+ // Thus thread utilization starts at 100%.
+ #define SHARED_MEMORY_FACTOR 2
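+ // Note: a larger SHARED_MEMORY_FACTOR also increases the shared memory required per workgroup (see the allocation below).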
+
+ #define offset_pos_index(index) ((index) + ((index) >> 2))
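+ // offset_pos_index adds one padding slot for every four elements (index + index / 4),
+ // presumably to spread accesses across shared memory banks and reduce bank conflicts.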
+
+ shared VEC4_T shared_input[offset_pos_index(MAX_WORKGROUP_SIZE * SHARED_MEMORY_FACTOR)];
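+ // With MAX_WORKGROUP_SIZE = 64 and SHARED_MEMORY_FACTOR = 2 this allocates offset_pos_index(128) = 160 VEC4_T entries.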
+
+ // Function to reduce input data along the workgroup's x dimension
+ //
+ // The implementation resembles the reduction depicted below
+ // | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | 2 | 3 | 2 | 7 | 0 | 11 | 0 | 2 | current_stride -> 1
+ // | / | / | / | / | / | / | / | /
+ // | / | / | / | / | / | / | / | /
+ // | / | / | / | / | / | / | / | /
+ // | 11 | 1 | 9 | 1 | 2 | 2 | 8 | 5 | 5 | 3 | 9 | 7 | 11 | 11 | 2 | 2 | current_stride -> 2
+ // | / | / | / | /
+ // | / | / | / | /
+ // | / | / | / | /
+ // | 20 | 1 | 9 | 1 | 10 | 2 | 8 | 5 | 14 | 3 | 9 | 7 | 13 | 11 | 2 | 2 | current_stride -> 4
+ // | / | /
+ // | / | /
+ // | / | /
+ // | / | /
+ // | / | /
+ // | 30 | 1 | 9 | 1 | 10 | 2 | 8 | 5 | 27 | 3 | 9 | 7 | 13 | 11 | 2 | 2 | current_stride -> 8
+ // | /
+ // | /
+ // | /
+ // | /
+ // | /
+ // | /
+ // | /
+ // | /
+ // | /
+ // | 57 | 1 | 9 | 1 | 10 | 2 | 8 | 5 | 27 | 3 | 9 | 7 | 13 | 11 | 2 | 2 | current_stride -> 16
+ //
+ // Threads access shared memory indices in the following pattern
+ // Thread       | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | current_stride -> 1
+ // Shared Index | 0 | 2 | 4 | 6 | 8 | 10 | 12 | 14 | X | X | X | X | X | X | X | X | index *= 1
+ //
+ // Thread       | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | current_stride -> 2
+ // Shared Index | 0 | 4 | 8 | 12 | X | X | X | X | X | X | X | X | X | X | X | X | index *= 2
+ //
+ // Thread       | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | current_stride -> 4
+ // Shared Index | 0 | 8 | X | X | X | X | X | X | X | X | X | X | X | X | X | X | index *= 4
+ //
+ // Thread       | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | current_stride -> 8
+ // Shared Index | 0 | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | index *= 8
+
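+ // width_stride: number of shared memory elements in the segment reduced by this invocation's row
+ // shared_idx_offset: base shared memory index of the segment owned by this invocation's (y, z) coordinates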
+ void reduce_input(const int width_stride, const int shared_idx_offset) {
+   // wait for all shared memory writes to finish
+   memoryBarrierShared();
+   barrier();
+
+   // loop log2(width_stride) times
+   for (int current_stride = 1, index = int(gl_LocalInvocationID.x << 1); current_stride < width_stride; current_stride *= 2, index <<= 1) {
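+     // index starts at 2 * the local x id and doubles along with the stride, matching the access pattern shown above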
+     // if the index at this thread is within the width stride
+     if (index < width_stride) {
+       const int local_shared_idx = shared_idx_offset + index;
+       // add the value current_stride positions away to this thread's value
+       shared_input[offset_pos_index(local_shared_idx)] += shared_input[offset_pos_index(local_shared_idx + current_stride)];
+     }

-   if (any(greaterThanEqual(lpos, out_limits))) {
-     return;
+     memoryBarrierShared();
+     barrier();
  }
+ }

+ void reduce_non_packed_dim() {
+   const ivec3 lpos = ivec3(gl_GlobalInvocationID);
  const int width = int(sizes.x);
+   ivec3 in_pos = lpos_to_pos(lpos, in_axis_map);

-   if (in_packed_dim != W_DIM) {
-     VEC4_T mean = VEC4_T(0);
-     VEC4_T delta = VEC4_T(0);
-     VEC4_T delta2 = VEC4_T(0);
-     VEC4_T M2 = VEC4_T(0);
-
-     // Use Welford's online algorithm to compute mean and variance in one pass
-     // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
-     ivec3 in_pos = lpos_to_pos(lpos, in_axis_map);
-     for (int w = 0; w < width; ++w) {
-       in_pos[in_axis_map.x] = w;
-       VEC4_T v = load_texel(t_in, in_pos);
-       delta = v - mean;
-       mean += delta / (w + 1);
-       delta2 = v - mean;
-       M2 += delta * delta2;
+   // width batch read stride
+   const int width_stride = int(gl_WorkGroupSize.x) * SHARED_MEMORY_FACTOR;
+
+   // local memory starting offset for this thread
+   const int shared_idx_offset = width_stride * int(gl_WorkGroupSize.y * gl_LocalInvocationID.z + gl_LocalInvocationID.y);
+
+   // local memory index for this thread
+   const int shared_idx = shared_idx_offset + int(gl_LocalInvocationID.x);
+
+   VEC4_T mean = VEC4_T(0);
+   VEC4_T var = VEC4_T(0);
+
+   // First pass: loop over the width in stride increments to accumulate the sum used for the mean
+   for (int width_offset = 0; width_offset < width; width_offset += width_stride) {
+     // Read input into shared memory
+     for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
+       in_pos[in_axis_map.x] = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
+
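+       // out-of-bounds reads are left at zero so they do not contribute to the sum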
+       VEC4_T in_val = VEC4_T(0);
+       if (all(lessThan(in_pos, out_limits))) {
+         in_val = load_texel(t_in, in_pos);
+       }
+       shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = in_val;
    }

-     VEC4_T var = M2 / width;
-     VEC4_T rstd = pow(var + epsilon, VEC4_T(-0.5));
-     VEC4_T offset = -rstd * mean;
-
-     for (int w = 0; w < width; ++w) {
-       in_pos[in_axis_map.x] = w;
-       VEC4_T v = load_texel(t_in, in_pos);
-       // broadcasting
-       VEC4_T weight = load_texel(t_weight, ivec3(w, 0, 0)).xxxx;
-       VEC4_T bias = load_texel(t_bias, ivec3(w, 0, 0)).xxxx;
-       VEC4_T outtex = (v * rstd + offset) * weight + bias;
-       write_texel_lpos(t_out, ivec3(w, lpos.y, lpos.z), outtex, out_axis_map);
+     reduce_input(width_stride, shared_idx_offset);
+     mean += shared_input[offset_pos_index(shared_idx_offset)];
+   }
+
+   mean /= width;
+
+   memoryBarrierShared();
+   barrier();
+
+   // Second pass: loop over the width in stride increments to accumulate squared deviations for the variance
+   for (int width_offset = 0; width_offset < width; width_offset += width_stride) {
+     // Read input into shared memory
+     for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
+       in_pos[in_axis_map.x] = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
+
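+       // out-of-bounds reads default to the mean so their deltas are zero and do not affect the variance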
+       VEC4_T in_val = mean;
+       if (all(lessThan(in_pos, out_limits))) {
+         in_val = load_texel(t_in, in_pos);
+       }
+
+       const VEC4_T delta = in_val - mean;
+       shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = delta * delta;
    }

+     reduce_input(width_stride, shared_idx_offset);
+     var += shared_input[offset_pos_index(shared_idx_offset)];
+   }
+
+   var /= width;
+
+   VEC4_T rstd = pow(var + epsilon, VEC4_T(-0.5));
+   VEC4_T offset = -rstd * mean;
+
+   VEC4_T v = load_texel(t_in, lpos);
+   VEC4_T weight = load_texel(t_weight, ivec3(lpos.x, 0, 0)).xxxx;
+   VEC4_T bias = load_texel(t_bias, ivec3(lpos.x, 0, 0)).xxxx;
+   VEC4_T outtex = (v * rstd + offset) * weight + bias;
+
+   if (all(lessThan(lpos, out_limits))) {
+     write_texel_lpos(t_out, lpos, outtex, out_axis_map);
+   }
+
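+   // only the thread with x == 0 writes the mean and rstd texels for this row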
+   if (gl_GlobalInvocationID.x == 0) {
    write_texel(t_mean, lpos, mean);
    write_texel(t_rstd, lpos, rstd);
-   } else {
-     const int packed_width = divup4(width);
-
-     T mean = T(0);
-     T delta = T(0);
-     T delta2 = T(0);
-     T M2 = T(0);
-     // Use Welford's online algorithm to compute mean and variance in one pass
-     // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
-     ivec3 in_pos = lpos_to_pos(lpos, in_axis_map);
-     T width_counter = T(1);
-
-     const bool has_unaligned_width = (width & 0x3) != 0;
-     const int fully_packed_4_comp_count = packed_width - mix(0, 1, has_unaligned_width);
-
-     // iterate through texels that are fully packed, i.e. have 4 components
-     for (int w = 0; w < fully_packed_4_comp_count; ++w) {
-       in_pos[in_axis_map.x] = w;
-       VEC4_T v = load_texel(t_in, in_pos);
-       for (int i = 0; i < 4; i++) {
-         delta = v[i] - mean;
-         mean += delta / width_counter;
-         delta2 = v[i] - mean;
-         M2 += delta * delta2;
-         width_counter++;
+   }
+ }
+
+ void reduce_packed_dim() {
+   const ivec3 lpos = ivec3(gl_GlobalInvocationID);
+   const int width = int(sizes.x);
+   ivec3 in_pos = lpos_to_pos(lpos, in_axis_map);
+
+   // width batch read stride
+   const int width_stride = int(gl_WorkGroupSize.x) * SHARED_MEMORY_FACTOR;
+
+   // local memory starting offset for this thread
+   const int shared_idx_offset = width_stride * int(gl_WorkGroupSize.y * gl_LocalInvocationID.z + gl_LocalInvocationID.y);
+
+   // local memory index for this thread
+   const int shared_idx = shared_idx_offset + int(gl_LocalInvocationID.x);
+
+   const int last_packed_width_index = divup4(width) - 1;
+   T mean = T(0);
+   T var = T(0);
+   const int remain = width & 3;
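+   // remain is the number of valid components in the last texel; 0 means the width is a multiple of 4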
+
+   const int in_pos_x_limit = out_limits[in_axis_map.x];
+
+   // First pass: loop over the packed width in stride increments to accumulate the sum used for the mean
+   for (int width_offset = 0; width_offset <= last_packed_width_index; width_offset += width_stride) {
+     // Read input into shared memory
+     for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
+       const int in_pos_x = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
+       in_pos[in_axis_map.x] = in_pos_x;
+
+       VEC4_T in_val = VEC4_T(0);
+       if (in_pos_x < in_pos_x_limit) {
+         in_val = load_texel(t_in, in_pos);
      }
-   }

-     // handle the last texel if it's not 4-aligned
-     if (has_unaligned_width) {
-       in_pos[in_axis_map.x] = fully_packed_4_comp_count;
-       const int remaining_width = width & 0x3;
-
-       VEC4_T v = load_texel(t_in, in_pos);
-       for (int i = 0; i < remaining_width; i++) {
-         delta = v[i] - mean;
-         mean += delta / width_counter;
-         delta2 = v[i] - mean;
-         M2 += delta * delta2;
-         width_counter++;
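+       // zero out the padding components of the last texel when the width is not a multiple of 4,
+       // so they do not contribute to the sum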
+       if (in_pos_x == last_packed_width_index && remain != 0) {
+         const int remain_inv = 4 - remain;
+         in_val.y = mix(in_val.y, T(0), remain_inv > 2);
+         in_val.z = mix(in_val.z, T(0), remain_inv > 1);
+         in_val.w = mix(in_val.w, T(0), remain_inv > 0);
      }
+
+       shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = in_val;
    }

-     T var = M2 / (width_counter - 1);
-     T rstd = inversesqrt(var + epsilon);
-     T offset = -rstd * mean;
-
-     for (int w = 0; w < packed_width; ++w) {
-       in_pos[in_axis_map.x] = w;
-       VEC4_T v = load_texel(t_in, in_pos);
-       VEC4_T weight = load_texel(t_weight, ivec3(w, 0, 0));
-       VEC4_T bias = load_texel(t_bias, ivec3(w, 0, 0));
-       VEC4_T outtex = (v * rstd + offset) * weight + bias;
-       write_texel_lpos(t_out, ivec3(w, lpos.y, lpos.z), outtex, out_axis_map);
+     reduce_input(width_stride, shared_idx_offset);
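+     // each texel packs four consecutive width elements, so sum all four components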
+     const VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)];
+     mean += val.x + val.y + val.z + val.w;
+   }
+
+   mean /= width;
+
+   memoryBarrierShared();
+   barrier();
+
+   // Second pass: loop over the packed width in stride increments to accumulate squared deviations for the variance
+   for (int width_offset = 0; width_offset <= last_packed_width_index; width_offset += width_stride) {
+     // Read input into shared memory
+     for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
+       const int in_pos_x = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
+       in_pos[in_axis_map.x] = in_pos_x;
+
+       VEC4_T in_val = VEC4_T(mean);
+       if (in_pos_x < in_pos_x_limit) {
+         in_val = load_texel(t_in, in_pos);
+       }
+
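+       // replace the padding components of the last texel with the mean so their deltas are zero and do not affect the variance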
+       if (in_pos_x == last_packed_width_index && remain != 0) {
+         const int remain_inv = 4 - remain;
+         in_val.y = mix(in_val.y, mean.x, remain_inv > 2);
+         in_val.z = mix(in_val.z, mean.x, remain_inv > 1);
+         in_val.w = mix(in_val.w, mean.x, remain_inv > 0);
+       }
+
+       const VEC4_T delta = in_val - mean;
+       const VEC4_T delta2 = delta * delta;
+       shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = delta2;
    }

+     reduce_input(width_stride, shared_idx_offset);
+     const VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)];
+     var += val.x + val.y + val.z + val.w;
+   }
+
+   var /= width;
+
+   T rstd = pow(var + epsilon, T(-0.5));
+   T offset = -rstd * mean;
+
+   VEC4_T v = load_texel(t_in, lpos);
+   VEC4_T weight = load_texel(t_weight, ivec3(lpos.x, 0, 0));
+   VEC4_T bias = load_texel(t_bias, ivec3(lpos.x, 0, 0));
+   VEC4_T outtex = (v * rstd + offset) * weight + bias;
+
+   if (all(lessThan(lpos, out_limits))) {
+     write_texel_lpos(t_out, lpos, outtex, out_axis_map);
+   }
+
+   if (gl_GlobalInvocationID.x == 0) {
    write_texel(t_mean, lpos, VEC4_T(mean));
    write_texel(t_rstd, lpos, VEC4_T(rstd));
  }
}
+
+ void main() {
+   // dispatch based on whether the width dimension is the packed dimension
+   if (in_packed_dim != W_DIM) {
+     reduce_non_packed_dim();
+   } else {
+     reduce_packed_dim();
+   }
+ }
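
For reference, here is a minimal standalone sketch (plain Vulkan GLSL, not part of the diff above) that applies the same doubling-stride reduction pattern as reduce_input to a single 16-element row. The buffer name, binding, and fixed 16-thread workgroup are illustrative assumptions only, not taken from the shader above.

#version 450

// One 16-wide workgroup sums 16 floats in shared memory using the
// index = 2 * thread, stride-doubling pattern from the tables above.
layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in;

layout(std430, set = 0, binding = 0) buffer Data {
  float data[16];
};

shared float partial[16];

void main() {
  const uint t = gl_LocalInvocationID.x;

  // load one element per thread into shared memory
  partial[t] = data[t];
  memoryBarrierShared();
  barrier();

  // index starts at 2 * t and doubles along with the stride
  for (uint stride = 1u, index = t << 1; stride < 16u; stride *= 2u, index <<= 1) {
    if (index < 16u) {
      partial[index] += partial[index + stride];
    }
    memoryBarrierShared();
    barrier();
  }

  // partial[0] now holds the sum of all 16 elements
  if (t == 0u) {
    data[0] = partial[0];
  }
}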