Skip to content

Commit e1d9986

Browse files
committed
[ET-VK] Minor performance improvements to native layer norm.
Pull Request resolved: #9892 This diff introduces minor performance improvements to the native layer norm function in the Vulkan backend of Executorch. In this new approach: The mean and variance values are calculated in 2 separate passes. Shader is dispatched based on input texture size, and input texel is read and stored in shared memory. Input stored in shard memory is then summed up using a reduce function. This implementation better utilizes a GPUs parallel processing capabilities. Differential Revision: [D72430290](https://our.internmc.facebook.com/intern/diff/D72430290/) ghstack-source-id: 276439596
1 parent 56c8dc2 commit e1d9986

File tree

2 files changed

+252
-83
lines changed

2 files changed

+252
-83
lines changed

backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl

+250-81
Original file line numberDiff line numberDiff line change
@@ -43,106 +43,275 @@ ${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")}
4343
const lowp ivec4 out_axis_map = unhash_axis_map(out_layout);
4444
const lowp int out_packed_dim = unhash_packed_dim(out_layout);
4545

46-
void main() {
47-
const ivec3 lpos = ivec3(gl_GlobalInvocationID);
46+
#define MAX_WORKGROUP_SIZE 64
47+
48+
// Shared memory factor increases shared memory allocation by a scale that should either be 1 or a power of 2.
49+
//
50+
// Increasing factor allows more data to be stored in shared memory and increase thread utilization during reduction.
51+
// Why? Because when performing reduction, the number of active threads becomes half in each iteration.
52+
// Increasing scaling factor increases the thread occupancy and hence utilize the GPU better.
53+
// eg.
54+
// If local thread size in x dimension is 32, and SHARED_MEMORY_FACTOR is 1, 32 elements will be loaded into shared memory.
55+
// First iteration of reduce will have 16 threads sum up 32 elements.
56+
// Second iteration will have 8 threads sum up 16 elements from previous iteration and so on.
57+
// So thread utilization starts at 50%.
58+
//
59+
// By contrast if local thread size in x dimension is 32, and SHARED_MEMORY_FACTOR is 2, 64 elements will be loaded into shared memory.
60+
// First iteration of reduce will have 32 threads sum up 64 elements.
61+
// Second iteration will have 32 threads sum up 16 elements from previous iteration and so on.
62+
// Thus thread utilization starts at 100%.
63+
#define SHARED_MEMORY_FACTOR 2
64+
65+
#define offset_pos_index(index) ((index) + ((index) >> 2))
66+
67+
shared VEC4_T shared_input[offset_pos_index(MAX_WORKGROUP_SIZE * SHARED_MEMORY_FACTOR)];
68+
69+
// Function to reduce input data in workgroup's x dimension
70+
//
71+
// The implementation resembles reduction as depicted below
72+
// | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | 2 | 3 | 2 | 7 | 0 | 11 | 0 | 2 | current_stride -> 1
73+
// | / | / | / | / | / | / | / | /
74+
// | / | / | / | / | / | / | / | /
75+
// | / | / | / | / | / | / | / | /
76+
// | 11 | 1 | 9 | 1 | 2 | 2 | 8 | 5 | 5 | 3 | 9 | 7 | 11 | 11 | 2 | 2 | current_stride -> 2
77+
// | / | / | / | /
78+
// | / | / | / | /
79+
// | / | / | / | /
80+
// | 20 | 1 | 9 | 1 | 10 | 2 | 8 | 5 |14 | 3 | 9 | 7 |13 | 11 | 2 | 2 | current_stride -> 4
81+
// | / | /
82+
// | / | /
83+
// | / | /
84+
// | / | /
85+
// | / | /
86+
// | 30 | 1 | 9 | 1 | 10 | 2 | 8 | 5 |27 | 3 | 9 | 7 |13 | 11 | 2 | 2 | current_stride -> 8
87+
// | /
88+
// | /
89+
// | /
90+
// | /
91+
// | /
92+
// | /
93+
// | /
94+
// | /
95+
// | /
96+
// | 57 | 1 | 9 | 1 | 10 | 2 | 8 | 5 |27 | 3 | 9 | 7 |13 | 11 | 2 | 2 | current_stride = -> 16
97+
//
98+
// Threads access shared index in following pattern
99+
// Thread | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | current_stride -> 1
100+
// Shared Index | 0 | 2 | 4 | 6 | 8 | 10 | 12 | 14 | X | X | X | X | X | X | X | X | index *= 1
101+
//
102+
// Thread | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | current_stride -> 2
103+
// Shared Index | 0 | 4 | 8 | 12 | X | X | X | X | X | X | X | X | X | X | X | X | index *= 2
104+
//
105+
// Thread | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | current_stride -> 4
106+
// Shared Index | 0 | 8 | X | X | X | X | X | X | X | X | X | X | X | X | X | X | index *= 4
107+
//
108+
// Thread | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | current_stride -> 8
109+
// Shared Index | 0 | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | index *= 8
110+
111+
void reduce_input(const int width_stride, const int shared_idx_offset) {
112+
// wait for all shared memory writes to finish
113+
memoryBarrierShared();
114+
barrier();
115+
116+
// loop log(width_stride) times
117+
for (int current_stride = 1, index = int(gl_LocalInvocationID.x << 1); current_stride < width_stride; current_stride *= 2, index <<= 1) {
118+
// if the index at this thread is within the width stride
119+
if (index < width_stride) {
120+
const int local_shared_idx = shared_idx_offset + index;
121+
// add the value at current stride to this thread's value
122+
shared_input[offset_pos_index(local_shared_idx)] += shared_input[offset_pos_index(local_shared_idx + current_stride)];
123+
}
48124

49-
if (any(greaterThanEqual(lpos, out_limits))) {
50-
return;
125+
memoryBarrierShared();
126+
barrier();
51127
}
128+
}
52129

130+
void reduce_non_packed_dim() {
131+
const ivec3 lpos = ivec3(gl_GlobalInvocationID);
53132
const int width = int(sizes.x);
133+
ivec3 in_pos = lpos_to_pos(lpos, in_axis_map);
54134

55-
if (in_packed_dim != W_DIM) {
56-
VEC4_T mean = VEC4_T(0);
57-
VEC4_T delta = VEC4_T(0);
58-
VEC4_T delta2 = VEC4_T(0);
59-
VEC4_T M2 = VEC4_T(0);
60-
61-
// Use Welford's online algorithm to compute mean and variance in one pass
62-
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
63-
ivec3 in_pos = lpos_to_pos(lpos, in_axis_map);
64-
for (int w = 0; w < width; ++w) {
65-
in_pos[in_axis_map.x] = w;
66-
VEC4_T v = load_texel(t_in, in_pos);
67-
delta = v - mean;
68-
mean += delta / (w + 1);
69-
delta2 = v - mean;
70-
M2 += delta * delta2;
135+
// width batch read stride
136+
const int width_stride = int(gl_WorkGroupSize.x) * SHARED_MEMORY_FACTOR;
137+
138+
// local memory starting offset for this thread
139+
const int shared_idx_offset = width_stride * int(gl_WorkGroupSize.y * gl_LocalInvocationID.z + gl_LocalInvocationID.y);
140+
141+
// local memory index for this thread
142+
const int shared_idx = shared_idx_offset + int(gl_LocalInvocationID.x);
143+
144+
VEC4_T mean = VEC4_T(0);
145+
VEC4_T var = VEC4_T(0);
146+
147+
// Loop over the width in stride increments
148+
for (int width_offset = 0; width_offset < width; width_offset += width_stride) {
149+
// Read input in shared memory
150+
for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
151+
in_pos[in_axis_map.x] = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
152+
153+
VEC4_T in_val = VEC4_T(0);
154+
if (all(lessThan(in_pos, out_limits))) {
155+
in_val = load_texel(t_in, in_pos);
156+
}
157+
shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = in_val;
71158
}
72159

73-
VEC4_T var = M2 / width;
74-
VEC4_T rstd = pow(var + epsilon, VEC4_T(-0.5));
75-
VEC4_T offset = -rstd * mean;
76-
77-
for (int w = 0; w < width; ++w) {
78-
in_pos[in_axis_map.x] = w;
79-
VEC4_T v = load_texel(t_in, in_pos);
80-
// broadcasting
81-
VEC4_T weight = load_texel(t_weight, ivec3(w, 0, 0)).xxxx;
82-
VEC4_T bias = load_texel(t_bias, ivec3(w, 0, 0)).xxxx;
83-
VEC4_T outtex = (v * rstd + offset) * weight + bias;
84-
write_texel_lpos(t_out, ivec3(w, lpos.y, lpos.z), outtex, out_axis_map);
160+
reduce_input(width_stride, shared_idx_offset);
161+
mean += shared_input[offset_pos_index(shared_idx_offset)];
162+
}
163+
164+
mean /= width;
165+
166+
memoryBarrierShared();
167+
barrier();
168+
169+
// Loop over the width in stride increments
170+
for (int width_offset = 0; width_offset < width; width_offset += width_stride) {
171+
// Read input in shared memory
172+
for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
173+
in_pos[in_axis_map.x] = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
174+
175+
VEC4_T in_val = mean;
176+
if (all(lessThan(in_pos, out_limits))) {
177+
in_val = load_texel(t_in, in_pos);
178+
}
179+
180+
const VEC4_T delta = in_val - mean;
181+
shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = delta * delta;
85182
}
86183

184+
reduce_input(width_stride, shared_idx_offset);
185+
var += shared_input[offset_pos_index(shared_idx_offset)];
186+
}
187+
188+
var /= width;
189+
190+
VEC4_T rstd = pow(var + epsilon, VEC4_T(-0.5));
191+
VEC4_T offset = -rstd * mean;
192+
193+
VEC4_T v = load_texel(t_in, lpos);
194+
VEC4_T weight = load_texel(t_weight, ivec3(lpos.x, 0, 0)).xxxx;
195+
VEC4_T bias = load_texel(t_bias, ivec3(lpos.x, 0, 0)).xxxx;
196+
VEC4_T outtex = (v * rstd + offset) * weight + bias;
197+
198+
if (all(lessThan(lpos, out_limits))) {
199+
write_texel_lpos(t_out, lpos, outtex, out_axis_map);
200+
}
201+
202+
if (gl_GlobalInvocationID.x == 0) {
87203
write_texel(t_mean, lpos, mean);
88204
write_texel(t_rstd, lpos, rstd);
89-
} else {
90-
const int packed_width = divup4(width);
91-
92-
T mean = T(0);
93-
T delta = T(0);
94-
T delta2 = T(0);
95-
T M2 = T(0);
96-
// Use Welford's online algorithm to compute mean and variance in one pass
97-
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
98-
ivec3 in_pos = lpos_to_pos(lpos, in_axis_map);
99-
T width_counter = T(1);
100-
101-
const bool has_unaligned_width = (width & 0x3) != 0;
102-
const int fully_packed_4_comp_count = packed_width - mix(0, 1, has_unaligned_width);
103-
104-
// iterate through texels that are fully packed ie. has 4 components
105-
for (int w = 0; w < fully_packed_4_comp_count; ++w) {
106-
in_pos[in_axis_map.x] = w;
107-
VEC4_T v = load_texel(t_in, in_pos);
108-
for (int i=0; i<4; i++) {
109-
delta = v[i] - mean;
110-
mean += delta / width_counter;
111-
delta2 = v[i] - mean;
112-
M2 += delta * delta2;
113-
width_counter++;
205+
}
206+
}
207+
208+
void reduce_packed_dim() {
209+
const ivec3 lpos = ivec3(gl_GlobalInvocationID);
210+
const int width = int(sizes.x);
211+
ivec3 in_pos = lpos_to_pos(lpos, in_axis_map);
212+
213+
// width batch read stride
214+
const int width_stride = int(gl_WorkGroupSize.x) * SHARED_MEMORY_FACTOR;
215+
216+
// local memory starting offset for this thread
217+
const int shared_idx_offset = width_stride * int(gl_WorkGroupSize.y * gl_LocalInvocationID.z + gl_LocalInvocationID.y);
218+
219+
// local memory index for this thread
220+
const int shared_idx = shared_idx_offset + int(gl_LocalInvocationID.x);
221+
222+
const int last_packed_width_index = divup4(width) - 1;
223+
T mean = T(0);
224+
T var = T(0);
225+
const int remain = width & 3;
226+
227+
const int in_pos_x_limit = out_limits[in_axis_map.x];
228+
229+
// Loop over the width in stride increments
230+
for (int width_offset = 0; width_offset <= last_packed_width_index; width_offset += width_stride) {
231+
// Read input in shared memory
232+
for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
233+
const int in_pos_x = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
234+
in_pos[in_axis_map.x] = in_pos_x;
235+
236+
VEC4_T in_val = VEC4_T(0);
237+
if (in_pos_x < in_pos_x_limit) {
238+
in_val = load_texel(t_in, in_pos);
114239
}
115-
}
116240

117-
// handle last texel if its not 4 aligned
118-
if (has_unaligned_width) {
119-
in_pos[in_axis_map.x] = fully_packed_4_comp_count;
120-
const int remaining_width = width & 0x3;
121-
122-
VEC4_T v = load_texel(t_in, in_pos);
123-
for (int i=0; i<remaining_width; i++) {
124-
delta = v[i] - mean;
125-
mean += delta / width_counter;
126-
delta2 = v[i] - mean;
127-
M2 += delta * delta2;
128-
width_counter++;
241+
if (in_pos_x == last_packed_width_index && remain != 0) {
242+
const int remain_inv = 4 - remain;
243+
in_val.y = mix(in_val.y, T(0), remain_inv > 2);
244+
in_val.z = mix(in_val.z, T(0), remain_inv > 1);
245+
in_val.w = mix(in_val.w, T(0), remain_inv > 0);
129246
}
247+
248+
shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = in_val;
130249
}
131250

132-
T var = M2 / (width_counter - 1);
133-
T rstd = inversesqrt(var + epsilon);
134-
T offset = -rstd * mean;
135-
136-
for (int w = 0; w < packed_width; ++w) {
137-
in_pos[in_axis_map.x] = w;
138-
VEC4_T v = load_texel(t_in, in_pos);
139-
VEC4_T weight = load_texel(t_weight, ivec3(w, 0, 0));
140-
VEC4_T bias = load_texel(t_bias, ivec3(w, 0, 0));
141-
VEC4_T outtex = (v * rstd + offset) * weight + bias;
142-
write_texel_lpos(t_out, ivec3(w, lpos.y, lpos.z), outtex, out_axis_map);
251+
reduce_input(width_stride, shared_idx_offset);
252+
const VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)];
253+
mean += val.x + val.y + val.z + val.w;
254+
}
255+
256+
mean /= width;
257+
258+
memoryBarrierShared();
259+
barrier();
260+
261+
// Loop over the width in stride increments
262+
for (int width_offset = 0; width_offset <= last_packed_width_index; width_offset += width_stride) {
263+
// Read input in shared memory
264+
for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
265+
const int in_pos_x = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
266+
in_pos[in_axis_map.x] = in_pos_x;
267+
268+
VEC4_T in_val = VEC4_T(mean);
269+
if (in_pos_x < in_pos_x_limit) {
270+
in_val = load_texel(t_in, in_pos);
271+
}
272+
273+
if (in_pos_x == last_packed_width_index && remain != 0) {
274+
const int remain_inv = 4 - remain;
275+
in_val.y = mix(in_val.y, mean.x, remain_inv > 2);
276+
in_val.z = mix(in_val.z, mean.x, remain_inv > 1);
277+
in_val.w = mix(in_val.w, mean.x, remain_inv > 0);
278+
}
279+
280+
const VEC4_T delta = in_val - mean;
281+
const VEC4_T delta2 = delta * delta;
282+
shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = delta2;
143283
}
144284

285+
reduce_input(width_stride, shared_idx_offset);
286+
const VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)];
287+
var += val.x + val.y + val.z + val.w;
288+
}
289+
290+
var /= width;
291+
292+
T rstd = pow(var + epsilon, T(-0.5));
293+
T offset = -rstd * mean;
294+
295+
VEC4_T v = load_texel(t_in, lpos);
296+
VEC4_T weight = load_texel(t_weight, ivec3(lpos.x, 0, 0));
297+
VEC4_T bias = load_texel(t_bias, ivec3(lpos.x, 0, 0));
298+
VEC4_T outtex = (v * rstd + offset) * weight + bias;
299+
300+
if (all(lessThan(lpos, out_limits))) {
301+
write_texel_lpos(t_out, lpos, outtex, out_axis_map);
302+
}
303+
304+
if (gl_GlobalInvocationID.x == 0) {
145305
write_texel(t_mean, lpos, VEC4_T(mean));
146306
write_texel(t_rstd, lpos, VEC4_T(rstd));
147307
}
148308
}
309+
310+
void main() {
311+
// if packed dimension width
312+
if (in_packed_dim != W_DIM) {
313+
reduce_non_packed_dim();
314+
} else {
315+
reduce_packed_dim();
316+
}
317+
}

backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,8 @@ void add_native_layer_norm_node(
8383

8484
std::vector<int64_t> in_sizes = t_input->sizes();
8585

86-
utils::uvec3 global_size = t_mean->logical_limits();
87-
utils::uvec3 local_size = adaptive_work_group_size(global_size);
86+
utils::uvec3 global_size = t_out->logical_limits();
87+
utils::uvec3 local_size = graph.create_local_wg_size(global_size);
8888

8989
std::string kernel_name("native_layer_norm");
9090
kernel_name.reserve(kShaderNameReserve);

0 commit comments

Comments
 (0)