diff --git a/projects/ChatGLM/chatglm.py b/projects/ChatGLM/chatglm.py
index fda48c103..cea239878 100644
--- a/projects/ChatGLM/chatglm.py
+++ b/projects/ChatGLM/chatglm.py
@@ -180,8 +180,7 @@ def scaled_dot_product_attention(
     def forward(self, query_layer, key_layer, value_layer, attention_mask=None):
         # query_layer: [sq, b, np, hn] -[premute]-> [batch_size, head_num, seq_len, hidden_size]
         query_layer, key_layer, value_layer = [
-            # k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]
-            k.transpose(1, 2).transpose(0, 2) for k in [query_layer, key_layer, value_layer]
+            k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]
         ]
         if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
             context_layer = self.scaled_dot_product_attention(
@@ -195,8 +194,7 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask=None):
                 query_layer, key_layer, value_layer, attention_mask
             )

-        # context_layer = context_layer.permute(2, 0, 1, 3)
-        context_layer = context_layer.transpose(0, 1).transpose(0, 2)
+        context_layer = context_layer.permute(2, 0, 1, 3)
         context_layer = context_layer.flatten(2)
         return context_layer

@@ -711,8 +709,7 @@ def get_prompt(self, batch_size):
         )
         # seq_len, b, nh, hidden_size
         past_key_values = self.dropout(past_key_values)
-        # past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
-        past_key_values = past_key_values.transpose(0, 2).split(2)
+        past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
         return past_key_values

     def forward(
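
The three hunks restore a single `permute` in place of the chained `transpose` calls that had been substituted for it. As a sanity check, here is a minimal standalone sketch (plain PyTorch standing in for the project's OneFlow tensors, on the assumption that both follow the same PyTorch-style `permute`/`transpose` semantics; all tensor sizes are arbitrary example values) verifying that each restored `permute` produces exactly the same result as the transpose chain it replaces:

```python
# Minimal equivalence check for the three permute/transpose swaps in this patch.
# Plain PyTorch is used only as a stand-in; the shapes are arbitrary examples.
import torch

# Attention inputs: [sq, b, np, hn] -> [b, np, sq, hn]
q = torch.randn(5, 2, 8, 16)
assert torch.equal(q.permute(1, 2, 0, 3), q.transpose(1, 2).transpose(0, 2))

# Attention output (context_layer): [b, np, sq, hn] -> [sq, b, np, hn]
ctx = torch.randn(2, 8, 5, 16)
assert torch.equal(ctx.permute(2, 0, 1, 3), ctx.transpose(0, 1).transpose(0, 2))

# 5-D prefix key/value tensor from get_prompt: swap dims 0 and 2, keep the rest
pkv = torch.randn(2, 4, 6, 8, 16)
assert torch.equal(pkv.permute([2, 1, 0, 3, 4]), pkv.transpose(0, 2))
```

Both spellings return the same view (identical data and strides); the single `permute` simply names the target axis order in one call instead of composing two swaps.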