Fix bias update in scoped all reduce (#1456)
skavulya authored Nov 6, 2024
1 parent 309e0c4 commit 0c3b272
Showing 7 changed files with 12 additions and 12 deletions.
4 changes: 2 additions & 2 deletions optimum/habana/transformers/models/falcon/modeling_falcon.py
@@ -578,7 +578,7 @@ def attention_all_reduce(self, attn_output):
 
     def post_attn_forward(self, attn_output):
         if hasattr(self.dense, "all_reduce"):
-            self.dense.post_all_reduce(attn_output)
+            return self.dense.post_all_reduce(attn_output)
         return attn_output


@@ -598,7 +598,7 @@ def mlp_all_reduce(self, x):
 
     def post_mlp_forward(self, x):
         if hasattr(self.dense_4h_to_h, "all_reduce"):
-            self.dense_4h_to_h.post_all_reduce(x)
+            return self.dense_4h_to_h.post_all_reduce(x)
         return x


2 changes: 1 addition & 1 deletion optimum/habana/transformers/models/gemma/modeling_gemma.py
@@ -357,7 +357,7 @@ def attention_all_reduce(self, attn_output):
 
     def post_attn_forward(self, attn_output):
         if hasattr(self.o_proj, "post_all_reduce"):
-            self.o_proj.post_all_reduce(attn_output)
+            return self.o_proj.post_all_reduce(attn_output)
         return attn_output


2 changes: 1 addition & 1 deletion optimum/habana/transformers/models/llama/modeling_llama.py
@@ -760,7 +760,7 @@ def attention_all_reduce(self, attn_output):
 
     def post_attn_forward(self, attn_output):
         if hasattr(self.o_proj, "post_all_reduce"):
-            self.o_proj.post_all_reduce(attn_output)
+            return self.o_proj.post_all_reduce(attn_output)
         return attn_output


6 changes: 2 additions & 4 deletions optimum/habana/transformers/models/modeling_all_models.py
@@ -164,7 +164,5 @@ def all_reduce(self, input):
         dist.inference_all_reduce(input, group=self.mp_group)
 
     def post_all_reduce(self, input):
-        # inplace addition needed for correct results
-        if self.bias is not None:
-            input += self.bias
-        return input
+        output = input + self.bias if (self.bias is not None) else input
+        return output
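
The change above is the substance of the fix: post_all_reduce previously added the bias to its argument in place and callers ignored the return value, whereas the new version builds the output out of place and returns it, which is why every post_attn_forward / post_mlp_forward in the model files now returns the result. A minimal, self-contained sketch of the new contract, using a toy stand-in class instead of the real scoped all-reduce linear wrapper (class name, shapes and values here are illustrative only):

import torch
from typing import Optional


class ScopedLinearSketch:
    # Toy stand-in: the matmul and the collective all-reduce live elsewhere;
    # only the post-all-reduce bias step is modeled here.
    def __init__(self, bias: Optional[torch.Tensor]):
        self.bias = bias

    def post_all_reduce(self, input: torch.Tensor) -> torch.Tensor:
        # Out-of-place addition, as in the diff above: callers must use the returned tensor.
        output = input + self.bias if (self.bias is not None) else input
        return output


def post_attn_forward(proj: ScopedLinearSketch, attn_output: torch.Tensor) -> torch.Tensor:
    # Fixed caller pattern: propagate the bias-added tensor instead of discarding it.
    if hasattr(proj, "post_all_reduce"):
        return proj.post_all_reduce(attn_output)
    return attn_output


proj = ScopedLinearSketch(bias=torch.full((4,), 0.5))
print(post_attn_forward(proj, torch.zeros(4)))  # every element is 0.5; the bias is not dropped

With the out-of-place version, a caller that kept the old statement without a return would silently lose the bias, which is why the caller-side return statements in the falcon, gemma, llama and qwen2 diffs accompany this change.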
2 changes: 1 addition & 1 deletion optimum/habana/transformers/models/qwen2/modeling_qwen2.py
@@ -419,7 +419,7 @@ def attention_all_reduce(self, attn_output):
 
     def post_attn_forward(self, attn_output):
         if hasattr(self.o_proj, "post_all_reduce"):
-            self.o_proj.post_all_reduce(attn_output)
+            return self.o_proj.post_all_reduce(attn_output)
         return attn_output


(additional changed file; path not shown)
@@ -491,7 +491,7 @@ def attention_all_reduce(self, attn_output):
 
     def post_attn_forward(self, attn_output):
         if hasattr(self.o_proj, "post_all_reduce"):
-            self.o_proj.post_all_reduce(attn_output)
+            return self.o_proj.post_all_reduce(attn_output)
         return attn_output


(additional changed file; path not shown)
@@ -17,6 +17,7 @@
 ###############################################################################
 
 import math
+import os
 from typing import List, Optional, Tuple, Union
 
 import torch
@@ -307,7 +308,8 @@ def pre_attn_forward(
 
             if q_len == 1:
                 # next token
-                with ht.sdp_kernel(enable_recompute=False):
+                use_recompute = True if os.getenv("QUANT_CONFIG", "") else False
+                with ht.sdp_kernel(enable_recompute=use_recompute):
                     attn_output = FusedSDPA.apply(
                         query_states, key_states, value_states, attention_mask, 0.0, False, None
                     )
@@ -374,7 +376,7 @@ def attention_all_reduce(self, attn_output):
 
     def post_attn_forward(self, attn_output):
         if hasattr(self.o_proj, "post_all_reduce"):
-            self.o_proj.post_all_reduce(attn_output)
+            return self.o_proj.post_all_reduce(attn_output)
         return attn_output


