Skip to content

Commit ec1a81a

Browse files
committed
Revert "update"
This reverts commit 49fc21e.
1 parent 4edb33a commit ec1a81a

File tree

1 file changed

+3
-6
lines changed

1 file changed

+3
-6
lines changed

projects/ChatGLM/chatglm.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,7 @@ def scaled_dot_product_attention(
180180
def forward(self, query_layer, key_layer, value_layer, attention_mask=None):
181181
# query_layer: [sq, b, np, hn] -[permute]-> [batch_size, head_num, seq_len, hidden_size]
182182
query_layer, key_layer, value_layer = [
183-
# k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]
184-
k.transpose(1, 2).transpose(0, 2) for k in [query_layer, key_layer, value_layer]
183+
k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]
185184
]
186185
if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
187186
context_layer = self.scaled_dot_product_attention(
@@ -195,8 +194,7 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask=None):
195194
query_layer, key_layer, value_layer, attention_mask
196195
)
197196

198-
# context_layer = context_layer.permute(2, 0, 1, 3)
199-
context_layer = context_layer.transpose(0, 1).transpose(0, 2)
197+
context_layer = context_layer.permute(2, 0, 1, 3)
200198
context_layer = context_layer.flatten(2)
201199
return context_layer
202200

@@ -711,8 +709,7 @@ def get_prompt(self, batch_size):
711709
)
712710
# seq_len, b, nh, hidden_size
713711
past_key_values = self.dropout(past_key_values)
714-
# past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
715-
past_key_values = past_key_values.transpose(0, 2).split(2)
712+
past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
716713
return past_key_values
717714

718715
def forward(

0 commit comments

Comments
 (0)