Pull request overview
This PR adds manual induction head support to the Infinite Attention mechanism, enabling models to use specialized attention heads where queries and keys share the same projection (tied-Wk). The implementation includes configuration updates, new projection layers for induction heads, and integration into the attention computation pipeline.
Changes:
- Added `n_induction_head` and `n_ind_head_dim` parameters to configuration classes and CLI, allowing specification of manual induction head count and dimensions
- Modified InfiniteHeadAttention to initialize separate projection layers (`c_attn_k_ind`, `c_attn_v_ind`, `c_proj_ind`) and compute attention for manual induction heads alongside regular heads
- Created a YAML configuration sweep to explore various ratios of manual induction heads to regular heads and different head dimensions
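The tied-Wk idea can be sketched in isolation. In this minimal illustration, reusing `c_attn_k_ind` for queries is inferred from the summary above (the PR adds no separate `c_attn_q_ind`), and all dimensions are illustrative:

```python
import torch
import torch.nn as nn

# Hedged sketch of the tied-Wk idea: queries and keys for the manual
# induction heads come from one shared projection (reusing c_attn_k_ind for
# queries is an inference from the PR summary; dimensions are illustrative).
n_embd, n_induction_head, n_ind_head_dim = 64, 2, 16
B, T = 2, 8
c_attn_k_ind = nn.Linear(n_embd, n_induction_head * n_ind_head_dim, bias=False)

x = torch.randn(B, T, n_embd)
k_ind = c_attn_k_ind(x)  # keys
q_ind = c_attn_k_ind(x)  # queries reuse the same tied Wk
q_ind = q_ind.view(B, T, n_induction_head, n_ind_head_dim).transpose(1, 2)
k_ind = k_ind.view(B, T, n_induction_head, n_ind_head_dim).transpose(1, 2)
```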
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| gpt_conf.py | Added configuration fields for manual induction head parameters |
| train_args.py | Added CLI arguments for specifying induction head count and dimensions |
| variations/attention_variations.py | Implemented manual induction head support with initialization, forward pass computation, and output integration |
| explorations/infinite_manual_induction_dim_sweep.yaml | Defined sweep experiments testing different head ratios and dimensionalities |
Comments suppressed due to low confidence (2)
variations/attention_variations.py:1325
- The post-activation L2 normalization (line 1322) and cproj_scale division (line 1325) are applied to the regular attention output y but not to the manual induction output y_ind before they are combined. This could lead to scale mismatches when combining the two outputs (line 1366), especially when post_act_l2_norm is enabled or cproj_scale is not 1.0. Consider applying these transformations to y_ind as well before combining.
```python
if self.post_act_l2_norm:
    y = y / y.norm(dim=-1, keepdim=True).clamp_min(1e-6)
if self.cproj_scale is not None and self.cproj_scale != 1.0:
    y = y / self.cproj_scale
```
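A minimal sketch of the suggested fix, applying the same post-activation transforms to the induction output so the two branches stay on comparable scales before they are combined (`normalize_output` is a hypothetical helper; the flag names mirror those in the comment above):

```python
import torch

# Hedged sketch (helper name is hypothetical): apply the same post-activation
# transforms to the induction output y_ind that the regular path applies to y,
# so the two branches stay on comparable scales before they are combined.
def normalize_output(y, post_act_l2_norm=True, cproj_scale=2.0):
    if post_act_l2_norm:
        y = y / y.norm(dim=-1, keepdim=True).clamp_min(1e-6)
    if cproj_scale is not None and cproj_scale != 1.0:
        y = y / cproj_scale
    return y

y = torch.randn(2, 8, 64)      # regular attention output, (B, T, n_embd)
y_ind = torch.randn(2, 8, 64)  # manual induction output
y, y_ind = normalize_output(y), normalize_output(y_ind)
combined = y + y_ind           # per-token norms now match across branches
```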
variations/attention_variations.py:1231
- Manual induction heads (q_ind, k_ind) are not receiving rotary position encodings or QK normalization, while the regular attention heads (q, k) are. This creates an inconsistency in positional information and normalization between the two types of heads. Consider whether manual induction heads should also receive these transformations, especially rotary embeddings which encode positional information that may be important for induction behavior.
```python
# Apply Rotary Position Encodings
if (self.rotary_emb_q is not None) and (self.rotary_emb_k is not None):
    q = self.rotary_emb_q(q)
    k = self.rotary_emb_k(k)
# Apply QK Norm
if self.use_qk_norm:
    q = q / (q.norm(dim=-1, keepdim=True) + 1e-6)
    k = k / (k.norm(dim=-1, keepdim=True) + 1e-6)
if self.use_v_norm:
    v = v / (v.norm(dim=-1, keepdim=True) + 1e-6)
```
```python
self.c_attn_q = self.linear_variant_q(self.n_embd, self.n_head * self.n_qk_head_dim, config, bias=config.bias)
self.c_attn_k = self.linear_variant_k(self.n_embd, self.n_kv_group * self.n_qk_head_dim, config, bias=config.bias)
self.c_attn_v = self.linear_variant_v(self.n_embd, self.n_kv_group * self.n_v_head_dim, config, bias=config.bias)

if self.use_manual_induction:
    # "Manual induction" heads tie Q and K to a shared Wk projection.
    ind_total_dim = self.n_induction_head * self.n_ind_head_dim
    self.c_attn_k_ind = self.linear_variant_k(self.n_embd, ind_total_dim, config, bias=config.bias)
    self.c_attn_v_ind = self.linear_variant_v(self.n_embd, ind_total_dim, config, bias=config.bias)
    self.c_proj_ind = self.linear_variant_attn_proj(ind_total_dim, self.n_embd, config, bias=config.bias)
```
When n_head is 0 (as configured in the YAML sweep at line 59), the code will fail when computing c_attn_q projection dimensions. Line 1074 calculates self.n_head * self.n_qk_head_dim which would be 0 when n_head is 0, creating an invalid linear layer with 0 output dimensions. The code should either validate that n_head is positive when use_manual_induction is false, or handle the case where only manual induction heads are used without regular attention heads.
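One way the suggested guard could look at init time (the helper name and return convention are hypothetical; it either rejects the invalid configuration or tells the caller to skip building the regular q/k/v projections):

```python
# Hedged sketch of an __init__-time guard (helper name and return convention
# are hypothetical): require regular heads unless an induction-only
# configuration is explicitly requested.
def validate_head_config(n_head, use_manual_induction, n_induction_head):
    if n_head == 0 and not (use_manual_induction and n_induction_head > 0):
        raise ValueError(
            "n_head must be positive unless manual induction heads are enabled"
        )
    # Tells the caller whether to build the regular q/k/v projections at all.
    return n_head > 0

build_regular_heads = validate_head_config(8, False, 0)
```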
```diff
@@ -1239,6 +1272,15 @@
        is_causal=True,
    )
```
When n_head is 0, the forward pass will fail at multiple points. Lines 1211-1213 will attempt to view/reshape tensors with 0 heads, and the _expand_kv function (line 1237-1238) and attention computation (lines 1266-1273 or 1290-1311) will all process empty tensors. The code should have logic to skip regular attention computation entirely when n_head is 0 and only manual induction heads are present.
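A sketch of the induction-only fallback at the point where the two outputs are combined (`combine_outputs` is a hypothetical helper standing in for the in-line logic):

```python
import torch

# Hedged sketch (combine_outputs is a hypothetical stand-in for the in-line
# logic): when n_head == 0 the regular branch is skipped entirely and the
# induction output is used directly instead of being added to y.
def combine_outputs(y_regular, y_ind):
    if y_regular is None:  # n_head == 0: induction-only configuration
        return y_ind
    return y_regular + y_ind

y_ind = torch.randn(2, 8, 64)
out = combine_outputs(None, y_ind)  # induction-only path
```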
```python
# Concat Heads or Inf Concat Heads
if self.use_concat_heads:
    # (B, nh, T, v_dim) → (B, T, nh*v_dim); avoid extra .contiguous()
    # flatten heads → (B, T, n_head * n_v_head_dim)
    y = y.transpose(1, 2).contiguous().view(B, T, self.n_head * self.n_v_head_dim)
    if self.l2_norm_attn_cproj:
        cproj_weight = F.normalize(self.c_proj.weight, p=2, dim=self.cproj_norm_dim)
        y = F.linear(y, cproj_weight, self.c_proj.bias)
    else:
        y = self.c_proj(y)
elif self.n_cproj == 1:
    # Sum heads first: (B, nh, T, v_dim) → (B, T, v_dim)
    y = y.sum(dim=1)
    if self.l2_norm_attn_cproj:
        cproj_weight = F.normalize(self.c_proj.weight, p=2, dim=self.cproj_norm_dim)
        y = F.linear(y, cproj_weight, self.c_proj.bias)
    else:
        y = self.c_proj(y)
else:
    # Sum heads first: (B, nh, T, v_dim) → (B, T, v_dim)
    y_sum = y.sum(dim=1)

    # Parallel small projections then fuse; avoids Python-level loop
    if self.l2_norm_attn_cproj:
        proj_outputs = [
            F.linear(y_sum, F.normalize(proj.weight, p=2, dim=self.cproj_norm_dim), proj.bias)
            for proj in self.c_proj_list
        ]
    else:
        proj_outputs = [proj(y_sum) for proj in self.c_proj_list]
    y = torch.stack(proj_outputs, dim=0).sum(dim=0)
```
When n_head is 0, lines 1327-1357 will fail because they assume y exists and has valid dimensions for projection. The code needs to handle the case where only manual induction heads are used (n_head=0) by skipping the regular attention output processing and only computing y_ind, then using it directly instead of adding it to y.
```python
if self.use_manual_induction:
    y_ind = torch.nn.functional.scaled_dot_product_attention(
        q_ind,
        k_ind,
        v_ind,
        dropout_p=self.dropout if self.training else 0,
        is_causal=True,
    )
```
In the flash attention path, when use_qk_norm_scale is enabled, the regular attention q is scaled by qk_scaling_factor (lines 1244-1245), but q_ind for manual induction heads is not scaled before being passed to scaled_dot_product_attention (lines 1276-1282). This inconsistency means manual induction heads won't benefit from the learned QK norm scaling when it's enabled, potentially leading to different attention behaviors between the two head types.
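A sketch of the suggested fix: scale `q_ind` the same way before the SDPA call (here `qk_scaling_factor` is a stand-in for the learned factor referenced above, and the shapes are illustrative):

```python
import torch
import torch.nn.functional as F

# Hedged sketch: scale q_ind by the same factor the flash path applies to the
# regular q (qk_scaling_factor here is a stand-in for the learned parameter).
qk_scaling_factor = torch.tensor(1.5)
q_ind = torch.randn(2, 4, 8, 16)  # (B, n_induction_head, T, n_ind_head_dim)
k_ind = torch.randn(2, 4, 8, 16)
v_ind = torch.randn(2, 4, 8, 16)
q_ind = q_ind * qk_scaling_factor
y_ind = F.scaled_dot_product_attention(q_ind, k_ind, v_ind, is_causal=True)
```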
```python
att_ind = (q_ind @ k_ind.transpose(-2, -1))
att_ind = att_ind / math.sqrt(self.n_ind_head_dim)
att_ind = att_ind.masked_fill(self.bias[:, :, :T, :T].to(x.device) == 0, float('-inf'))
att_ind = F.softmax(att_ind, dim=-1)
```
In the manual attention implementation path, the manual induction heads use standard softmax (line 1317) while the regular attention heads can use a configurable softmax_variant_attn (lines 1304-1307). This inconsistency means manual induction heads won't benefit from alternative softmax variants that may be configured for better attention stability or performance.
Suggested change:

```diff
-att_ind = F.softmax(att_ind, dim=-1)
+if self.softmax_variant_attn != 'softmax':
+    att_ind = self.softmax_layer_attn(att_ind)
+else:
+    att_ind = F.softmax(att_ind, dim=-1)
```
This pull request introduces support for manual induction heads in Infinite Attention, allowing for more flexible configuration of attention head types and dimensions. The main changes involve updating configuration files and classes to accept new parameters, modifying the attention module to implement manual induction heads, and updating the forward pass logic to integrate these heads into the attention computation.
Manual Induction Head Support for Infinite Attention
Configuration and Argument Updates:
- Added `n_induction_head` and `n_ind_head_dim` parameters to `GPTConfig` and CLI argument parsing, enabling specification of the number and dimensionality of manual induction heads. [1] [2]
- Added `explorations/infinite_manual_induction_dim_sweep.yaml` to define sweep experiments for manual induction head dimensionality and head ratios.

Attention Module Implementation:

- Updated `variations/attention_variations.py` to initialize additional projection layers (`c_attn_k_ind`, `c_attn_v_ind`, `c_proj_ind`) for manual induction heads when enabled, and validated required parameters. [1] [2]