Skip to content

Commit

Permalink
Revert "update"
Browse files Browse the repository at this point in the history
This reverts commit 49fc21e.
  • Loading branch information
0x404 committed Sep 22, 2024
1 parent 4edb33a commit ec1a81a
Showing 1 changed file with 3 additions and 6 deletions.
9 changes: 3 additions & 6 deletions projects/ChatGLM/chatglm.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,7 @@ def scaled_dot_product_attention(
def forward(self, query_layer, key_layer, value_layer, attention_mask=None):
# query_layer: [sq, b, np, hn] -[premute]-> [batch_size, head_num, seq_len, hidden_size]
query_layer, key_layer, value_layer = [
# k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]
k.transpose(1, 2).transpose(0, 2) for k in [query_layer, key_layer, value_layer]
k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]
]
if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
context_layer = self.scaled_dot_product_attention(
Expand All @@ -195,8 +194,7 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask=None):
query_layer, key_layer, value_layer, attention_mask
)

# context_layer = context_layer.permute(2, 0, 1, 3)
context_layer = context_layer.transpose(0, 1).transpose(0, 2)
context_layer = context_layer.permute(2, 0, 1, 3)
context_layer = context_layer.flatten(2)
return context_layer

Expand Down Expand Up @@ -711,8 +709,7 @@ def get_prompt(self, batch_size):
)
# seq_len, b, nh, hidden_size
past_key_values = self.dropout(past_key_values)
# past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
past_key_values = past_key_values.transpose(0, 2).split(2)
past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
return past_key_values

def forward(
Expand Down

0 comments on commit ec1a81a

Please sign in to comment.