Commit 0c3b272

Fix bias update in scoped all reduce (#1456)
1 parent 309e0c4 commit 0c3b272

File tree

7 files changed: +12, -12 lines changed

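What changed, in short: the tensor-parallel linear wrappers in optimum-habana split a linear layer into a shard-local matmul, an all-reduce over the tensor-parallel group, and a post_all_reduce step that applies the bias once after the reduce. This commit makes that bias add out of place, so every call site has to propagate the returned tensor instead of relying on mutation. A minimal sketch of the pattern, assuming a hypothetical ShardedLinear stand-in (only all_reduce, post_all_reduce, and the bias handling mirror the diff; the rest is illustrative):

import torch
import torch.distributed as dist


class ShardedLinear(torch.nn.Module):
    """Hypothetical stand-in for the repo's scoped all-reduce linear wrapper."""

    def __init__(self, weight_shard, bias, mp_group=None):
        super().__init__()
        self.weight = torch.nn.Parameter(weight_shard)
        self.bias = bias  # full bias, applied exactly once after the reduce
        self.mp_group = mp_group

    def forward(self, x):
        # Shard-local partial product; the bias is deliberately NOT added here.
        return x @ self.weight.t()

    def all_reduce(self, out):
        # Sum the partial outputs across tensor-parallel ranks, in place.
        # The repo uses dist.inference_all_reduce; plain all_reduce shown here.
        dist.all_reduce(out, group=self.mp_group)

    def post_all_reduce(self, out):
        # The fix: out-of-place add, so callers MUST use the return value.
        return out + self.bias if self.bias is not None else out

The seven diffs below are the mechanical consequence: the model files re-bind the return value of post_all_reduce, and modeling_all_models.py hosts the out-of-place add itself.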

optimum/habana/transformers/models/falcon/modeling_falcon.py
Lines changed: 2 additions & 2 deletions

@@ -578,7 +578,7 @@ def attention_all_reduce(self, attn_output):
 
     def post_attn_forward(self, attn_output):
         if hasattr(self.dense, "all_reduce"):
-            self.dense.post_all_reduce(attn_output)
+            return self.dense.post_all_reduce(attn_output)
         return attn_output
 
@@ -598,7 +598,7 @@ def mlp_all_reduce(self, x):
 
     def post_mlp_forward(self, x):
         if hasattr(self.dense_4h_to_h, "all_reduce"):
-            self.dense_4h_to_h.post_all_reduce(x)
+            return self.dense_4h_to_h.post_all_reduce(x)
         return x

optimum/habana/transformers/models/gemma/modeling_gemma.py
Lines changed: 1 addition & 1 deletion

@@ -357,7 +357,7 @@ def attention_all_reduce(self, attn_output):
 
     def post_attn_forward(self, attn_output):
         if hasattr(self.o_proj, "post_all_reduce"):
-            self.o_proj.post_all_reduce(attn_output)
+            return self.o_proj.post_all_reduce(attn_output)
         return attn_output

optimum/habana/transformers/models/llama/modeling_llama.py
Lines changed: 1 addition & 1 deletion

@@ -760,7 +760,7 @@ def attention_all_reduce(self, attn_output):
 
     def post_attn_forward(self, attn_output):
         if hasattr(self.o_proj, "post_all_reduce"):
-            self.o_proj.post_all_reduce(attn_output)
+            return self.o_proj.post_all_reduce(attn_output)
         return attn_output

optimum/habana/transformers/models/modeling_all_models.py
Lines changed: 2 additions & 4 deletions

@@ -164,7 +164,5 @@ def all_reduce(self, input):
         dist.inference_all_reduce(input, group=self.mp_group)
 
     def post_all_reduce(self, input):
-        # inplace addition needed for correct results
-        if self.bias is not None:
-            input += self.bias
-        return input
+        output = input + self.bias if (self.bias is not None) else input
+        return output
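This hunk is the heart of the fix. The old in-place `input += self.bias` mutated the caller's tensor, so ignoring the return value was harmless; the new out-of-place add allocates a fresh tensor, which is why every post_attn_forward and post_mlp_forward in this commit now returns the result. A self-contained repro of the failure mode (tensors are made up):

import torch

bias = torch.ones(4)
reduced = torch.zeros(4)  # stands in for the all-reduced partial sums

def post_all_reduce_old(x):
    x += bias  # in-place: caller sees the bias even if it drops the return
    return x

def post_all_reduce_new(x):
    # Mirrors the new code: out-of-place, the bias lives only in the return value.
    return x + bias if bias is not None else x

buf = reduced.clone()
post_all_reduce_old(buf)   # buf is now tensor([1., 1., 1., 1.])
buf = reduced.clone()
post_all_reduce_new(buf)   # return ignored: buf stays tensor([0., 0., 0., 0.])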

optimum/habana/transformers/models/qwen2/modeling_qwen2.py
Lines changed: 1 addition & 1 deletion

@@ -419,7 +419,7 @@ def attention_all_reduce(self, attn_output):
 
     def post_attn_forward(self, attn_output):
         if hasattr(self.o_proj, "post_all_reduce"):
-            self.o_proj.post_all_reduce(attn_output)
+            return self.o_proj.post_all_reduce(attn_output)
         return attn_output

optimum/habana/transformers/models/qwen2_moe/modeling_qwen2_moe.py
Lines changed: 1 addition & 1 deletion

@@ -491,7 +491,7 @@ def attention_all_reduce(self, attn_output):
 
     def post_attn_forward(self, attn_output):
         if hasattr(self.o_proj, "post_all_reduce"):
-            self.o_proj.post_all_reduce(attn_output)
+            return self.o_proj.post_all_reduce(attn_output)
         return attn_output

optimum/habana/transformers/models/starcoder2/modeling_starcoder2.py
Lines changed: 4 additions & 2 deletions

@@ -17,6 +17,7 @@
 ###############################################################################
 
 import math
+import os
 from typing import List, Optional, Tuple, Union
 
 import torch

@@ -307,7 +308,8 @@ def pre_attn_forward(
 
         if q_len == 1:
             # next token
-            with ht.sdp_kernel(enable_recompute=False):
+            use_recompute = True if os.getenv("QUANT_CONFIG", "") else False
+            with ht.sdp_kernel(enable_recompute=use_recompute):
                 attn_output = FusedSDPA.apply(
                     query_states, key_states, value_states, attention_mask, 0.0, False, None
                 )

@@ -374,7 +376,7 @@ def attention_all_reduce(self, attn_output):
 
     def post_attn_forward(self, attn_output):
         if hasattr(self.o_proj, "post_all_reduce"):
-            self.o_proj.post_all_reduce(attn_output)
+            return self.o_proj.post_all_reduce(attn_output)
         return attn_output
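The extra Starcoder2 hunk is independent of the bias fix: it enables FusedSDPA recompute whenever a quantization config is supplied through the QUANT_CONFIG environment variable. The `True if ... else False` in the diff is a verbose boolean cast; an equivalent form, assuming only that an empty or unset variable means no quantization:

import os

# Recompute mode for the fused SDPA kernel is tied to whether a
# quantization config path is present in the environment.
use_recompute = bool(os.environ.get("QUANT_CONFIG", ""))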
