
fuse q and kv parameters for gqa case (bigscience-workshop#291)
RezaYazdaniAminabadi authored Nov 21, 2023
1 parent 2348eed commit 8415d03
Showing 1 changed file with 28 additions and 56 deletions.
megatron/model/transformer.py (28 additions & 56 deletions)
@@ -532,39 +532,23 @@ def __init__(self, config, layer_number,
             config.num_attention_heads, world_size)
 
         # Per GQA head and per partition values
-        if self.use_gqa:
-            kv_projection_size = config.kv_channels * config.num_key_value_heads
-            self.num_key_value_heads_per_partition = core.utils.divide(
-                config.num_key_value_heads, world_size)
-            self.num_key_value_groups = core.utils.divide(
-                config.num_attention_heads, config.num_key_value_heads)
-            assert self.hidden_size_per_attention_head == core.utils.divide(
-                kv_projection_size, config.num_key_value_heads)
+        self.num_key_value_heads_per_partition = core.utils.divide(
+            config.num_key_value_heads, world_size)
+        self.num_key_value_groups = core.utils.divide(
+            config.num_attention_heads, config.num_key_value_heads)
+        kv_projection_size = config.kv_channels * config.num_key_value_heads
+        assert self.hidden_size_per_attention_head == core.utils.divide(
+            kv_projection_size, config.num_key_value_heads)
 
         # Strided linear layer.
-        if attention_type == AttnType.self_attn and not self.use_gqa:
+        if attention_type == AttnType.self_attn:
             self.query_key_value = tensor_parallel.ColumnParallelLinear(
                 config.hidden_size,
-                3 * projection_size,
+                projection_size + 2 * kv_projection_size,
                 config=config,
                 init_method=config.init_method,
                 bias=args.add_bias_linear,
                 gather_output=False)
-        elif attention_type == AttnType.self_attn and self.use_gqa:
-            self.query = tensor_parallel.ColumnParallelLinear(
-                config.hidden_size,
-                projection_size,
-                config=config,
-                init_method=config.init_method,
-                bias=config.add_bias_linear,
-                gather_output=False)
-            self.key_value = tensor_parallel.ColumnParallelLinear(
-                config.hidden_size,
-                2 * kv_projection_size,
-                config=config,
-                init_method=config.init_method,
-                bias=config.add_bias_linear,
-                gather_output=False)
         else:
             assert attention_type == AttnType.cross_attn
             self.query = tensor_parallel.ColumnParallelLinear(
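
For reference, a minimal sketch of the width arithmetic behind the fused layer: the single query_key_value ColumnParallelLinear now emits projection_size + 2 * kv_projection_size features, which collapses to the familiar 3 * projection_size when num_key_value_heads equals num_attention_heads. The numbers below are made up for illustration only and are not taken from any shipped config.

    # Illustrative GQA layout: 64 query heads sharing 8 KV heads, head dim 128.
    hidden_size = 8192
    num_attention_heads = 64
    num_key_value_heads = 8
    kv_channels = hidden_size // num_attention_heads         # 128, per-head dim

    projection_size = kv_channels * num_attention_heads      # 8192: all query heads
    kv_projection_size = kv_channels * num_key_value_heads   # 1024: all KV heads

    # Output width of the fused QKV ColumnParallelLinear introduced by this commit.
    fused_width = projection_size + 2 * kv_projection_size   # 8192 + 2 * 1024 = 10240

    # With num_key_value_heads == num_attention_heads (plain MHA) the same formula
    # gives 3 * projection_size, so the non-GQA path keeps its old width.
    print(projection_size, kv_projection_size, fused_width)
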
@@ -657,6 +641,13 @@ def repeat_kv(self, hidden_states, n_rep):
         return hidden_states.reshape(slen, batch,
                                      num_key_value_heads_per_partition * n_rep,
                                      head_dim)
 
+    def split_tensor(self, mixed_x_layer):
+        query_layer = mixed_x_layer[:, :, :, :-2, :].reshape(mixed_x_layer.shape[:2] + (-1, self.hidden_size_per_attention_head))
+        key_layer = mixed_x_layer[:, :, :, -2, :]
+        value_layer = mixed_x_layer[:, :, :, -1, :]
+
+        return query_layer, key_layer, value_layer
+
     def forward(self, hidden_states, attention_mask,
                 encoder_output=None, inference_params=None,
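
The new split_tensor helper undoes the fused GEMM's packing: after the view in forward (shown in the next hunk), each KV head owns a block of (num_key_value_groups + 2) slots, of which the first num_key_value_groups are its query heads and the last two are its key and value. Below is a standalone sketch of that view-and-slice with toy sizes (sq=3, b=2, 4 KV heads, 8 query heads per KV head, head dim 16); plain PyTorch, and all names and sizes are illustrative only.

    import torch

    sq, b, nkv, groups, hn = 3, 2, 4, 8, 16

    # Output of the fused query_key_value GEMM: [sq, b, (nq + 2 * nkv) * hn]
    mixed_x_layer = torch.randn(sq, b, nkv * (groups + 2) * hn)

    # [sq, b, (nq + 2 * nkv) * hn] --> [sq, b, nkv, (groups + 2), hn]
    mixed_x_layer = mixed_x_layer.view(sq, b, -1, groups + 2, hn)

    # First `groups` slots per KV head are queries, the last two are K and V.
    query = mixed_x_layer[:, :, :, :-2, :].reshape(sq, b, -1, hn)  # [sq, b, np, hn], np = nkv * groups = 32
    key = mixed_x_layer[:, :, :, -2, :]                            # [sq, b, nkv, hn]
    value = mixed_x_layer[:, :, :, -1, :]                          # [sq, b, nkv, hn]

    print(query.shape, key.shape, value.shape)
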
@@ -686,45 +677,26 @@ def forward(self, hidden_states, attention_mask,
         # =====================
         # Query, Key, and Value
         # =====================
 
-        if self.attention_type == AttnType.self_attn and not self.use_gqa:
-            # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
+        if self.attention_type == AttnType.self_attn:
+            # Attention heads [sq, b, h] --> [sq, b, ((nq + 2 * nkv) * hn)]
             mixed_x_layer, _ = self.query_key_value(hidden_states)
 
-            # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
+            # [sq, b, ((nq + 2 * nkv) * hn)] --> [sq, b, nkv, (nq // nkv + 2), hn]
             new_tensor_shape = mixed_x_layer.size()[:-1] + \
-                (self.num_attention_heads_per_partition,
-                 3 * self.hidden_size_per_attention_head)
+                (-1, (self.num_key_value_groups + 2),
+                 self.hidden_size_per_attention_head)
             mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
 
-            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
-            (query_layer,
+            # [sq, b, nkv, (nq // nkv + 2), hn] --> 3 [sq, b, np, hn]
+            (query_layer,
              key_layer,
-             value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_x_layer, 3)
-        elif self.attention_type == AttnType.self_attn and self.use_gqa:
-            # Attention head [sq, b, h] --> [sq, b, hp]
-            query_layer, _ = self.query(hidden_states)
-            # [sq, b, hp] --> [sq, b, np, hn]
-            new_tensor_shape = query_layer.size()[:-1] + \
-                (self.num_attention_heads_per_partition,
-                 self.hidden_size_per_attention_head)
-            query_layer = query_layer.view(*new_tensor_shape)
-
-            # Attention heads [sq, b, h] --> [sq, b, (np * 2 * hn)]
-            mixed_kv_layer, _ = self.key_value(hidden_states)
-            # [sq, b, (np * 2 * hn)] --> [sq, b, np, 2 * hn]
-            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
-                (self.num_key_value_heads_per_partition,
-                 2 * self.hidden_size_per_attention_head)
-            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
-            # [sq, b, np, 2 * hn] --> 2 [sq, b, np, hn]
-            (key_layer,
-             value_layer) = tensor_parallel.split_tensor_along_last_dim(
-                mixed_kv_layer, 2)
+             value_layer) = self.split_tensor(mixed_x_layer)
 
             # Repeat kv
-            key_layer = self.repeat_kv(key_layer, self.num_key_value_groups)
-            value_layer = self.repeat_kv(value_layer,
-                                         self.num_key_value_groups)
+            if self.use_gqa:
+                key_layer = self.repeat_kv(key_layer, self.num_key_value_groups)
+                value_layer = self.repeat_kv(value_layer,
+                                             self.num_key_value_groups)
         else:
             assert not self.use_gqa, 'GQA + cross-attn not tested yet'
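
Because key_layer and value_layer now come out with only num_key_value_heads_per_partition heads, the GQA branch still has to broadcast them across the query groups before attention scores are computed; that is what the repeat_kv calls guarded by use_gqa do. Below is a plain-PyTorch sketch of the same expand-and-reshape idea, reusing the toy sizes from the earlier note; only the helper's name and shape convention are taken from the diff, the rest is illustrative.

    import torch

    def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
        # [sq, b, nkv, hn] --> [sq, b, nkv * n_rep, hn]: duplicate each KV head
        # n_rep times so K/V line up head-for-head with the query tensor.
        if n_rep == 1:
            return hidden_states
        sq, b, nkv, hn = hidden_states.shape
        expanded = hidden_states[:, :, :, None, :].expand(sq, b, nkv, n_rep, hn)
        return expanded.reshape(sq, b, nkv * n_rep, hn)

    sq, b, nkv, groups, hn = 3, 2, 4, 8, 16
    key_layer = torch.randn(sq, b, nkv, hn)
    print(repeat_kv(key_layer, groups).shape)  # torch.Size([3, 2, 32, 16]) == [sq, b, np, hn]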

