Commit e2d4d3a
Fix more rebase issues.
LaurentMazare authored and ArthurZucker committed Jan 13, 2025
1 parent 2d69919 commit e2d4d3a
Showing 2 changed files with 4 additions and 12 deletions.
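Both files receive the same two fixes: o_proj is built from config.hidden_size rather than self.hidden_size, and HeliumRotaryEmbedding is constructed from the config object instead of from explicit dim / max_position_embeddings / base keyword arguments.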
8 changes: 2 additions & 6 deletions src/transformers/models/helium/modeling_helium.py
@@ -254,7 +254,7 @@ def __init__(self, config: HeliumConfig, layer_idx: Optional[int] = None):
         self.v_proj = nn.Linear(
             config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
         )
-        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
 
     def forward(
         self,
@@ -502,11 +502,7 @@ def __init__(self, config: HeliumConfig):
             [HeliumDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
         self.norm = HeliumRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.rotary_emb = HeliumRotaryEmbedding(
-            dim=config.head_dim,
-            max_position_embeddings=config.max_position_embeddings,
-            base=config.rope_theta,
-        )
+        self.rotary_emb = HeliumRotaryEmbedding(config)
         self.gradient_checkpointing = False
 
         # Initialize weights and apply final processing
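A note on the o_proj hunk above: the switch to config.hidden_size suggests the refactored attention base class (GraniteAttention, see the modular file below) no longer sets a hidden_size attribute on self; reading the value from config keeps the same hidden_size-to-hidden_size projection without depending on that attribute. This is inferred from the diff rather than stated in the commit message.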
8 changes: 2 additions & 6 deletions src/transformers/models/helium/modular_helium.py
@@ -113,7 +113,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
 class HeliumAttention(GraniteAttention):
     def __init__(self, config: HeliumConfig, layer_idx: Optional[int] = None):
         super().__init__(config, layer_idx)
-        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
         self.scaling = 1 / math.sqrt(self.head_dim)


@@ -137,11 +137,7 @@ def __init__(self, config: HeliumConfig):
             [HeliumDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
         self.norm = HeliumRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.rotary_emb = HeliumRotaryEmbedding(
-            dim=config.head_dim,
-            max_position_embeddings=config.max_position_embeddings,
-            base=config.rope_theta,
-        )
+        self.rotary_emb = HeliumRotaryEmbedding(config)
         self.gradient_checkpointing = False
 
         # Initialize weights and apply final processing
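For context on the rotary-embedding hunks: HeliumRotaryEmbedding(config) matches the pattern recent transformers releases use, where the constructor receives the whole config and reads head_dim, max_position_embeddings, and rope_theta from it rather than taking them as keyword arguments. The sketch below illustrates that pattern under the default (non-scaled) RoPE parametrization; MinimalConfig and ConfigDrivenRotaryEmbedding are illustrative names, not the actual Helium source.

# A minimal sketch of a config-driven rotary embedding, assuming the default
# RoPE parametrization. Illustrative only -- not the actual Helium source.
from dataclasses import dataclass

import torch
from torch import nn


@dataclass
class MinimalConfig:
    head_dim: int = 64
    max_position_embeddings: int = 2048
    rope_theta: float = 10000.0


class ConfigDrivenRotaryEmbedding(nn.Module):
    def __init__(self, config: MinimalConfig, device=None):
        super().__init__()
        # The values previously passed as dim= / base= keyword arguments
        # are now read off the config object.
        dim = config.head_dim
        base = config.rope_theta
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids):
        # Angles are the outer product of positions and inverse frequencies;
        # each half of the head dimension shares the same set of angles.
        freqs = position_ids[:, :, None].float() * self.inv_freq[None, None, :]
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos().to(x.dtype), emb.sin().to(x.dtype)


# Usage mirrors the new call site: the constructor takes only the config.
config = MinimalConfig()
rotary_emb = ConfigDrivenRotaryEmbedding(config)
cos, sin = rotary_emb(torch.zeros(1, 8, config.head_dim), torch.arange(8)[None, :])

The apply_rotary_pos_emb helper visible in the first hunk header of modular_helium.py then consumes the returned cos/sin pair.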
