@@ -180,8 +180,7 @@ def scaled_dot_product_attention(
    def forward(self, query_layer, key_layer, value_layer, attention_mask=None):
        # query_layer: [sq, b, np, hn] -[permute]-> [batch_size, head_num, seq_len, hidden_size]
        query_layer, key_layer, value_layer = [
-            # k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]
-            k.transpose(1, 2).transpose(0, 2) for k in [query_layer, key_layer, value_layer]
+            k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]
        ]
        if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
            context_layer = self.scaled_dot_product_attention(
@@ -195,8 +194,7 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask=None):
                query_layer, key_layer, value_layer, attention_mask
            )

-        # context_layer = context_layer.permute(2, 0, 1, 3)
-        context_layer = context_layer.transpose(0, 1).transpose(0, 2)
+        context_layer = context_layer.permute(2, 0, 1, 3)
        context_layer = context_layer.flatten(2)
        return context_layer

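The permute calls above are drop-in replacements for the transpose chains they supersede: permute(1, 2, 0, 3) reorders [sq, b, np, hn] to [b, np, sq, hn] in one step, and permute(2, 0, 1, 3) maps the attention output back to [sq, b, np, hn]. A minimal standalone sketch of that equivalence (tensor names and sizes below are illustrative, not taken from the model config):

```python
import torch

# [sq, b, np, hn] = [seq_len, batch, num_heads, head_dim]; sizes are arbitrary.
x = torch.randn(5, 2, 8, 16)
assert torch.equal(x.permute(1, 2, 0, 3), x.transpose(1, 2).transpose(0, 2))

# Attention output laid out as [b, np, sq, hn], restored to [sq, b, np, hn].
ctx = torch.randn(2, 8, 5, 16)
assert torch.equal(ctx.permute(2, 0, 1, 3), ctx.transpose(0, 1).transpose(0, 2))
```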
@@ -711,8 +709,7 @@ def get_prompt(self, batch_size):
        )
        # seq_len, b, nh, hidden_size
        past_key_values = self.dropout(past_key_values)
-        # past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
-        past_key_values = past_key_values.transpose(0, 2).split(2)
+        past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
        return past_key_values

    def forward(
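The same reasoning applies to get_prompt(): permute([2, 1, 0, 3, 4]) swaps dims 0 and 2 of the 5-D prefix tensor, which is exactly what transpose(0, 2) did, so the subsequent split(2) along dim 0 sees identical data. A small sketch with an arbitrary 5-D tensor (shape chosen only for illustration):

```python
import torch

pkv = torch.randn(2, 3, 4, 8, 16)  # arbitrary 5-D shape; dim 2 becomes dim 0 after the permute
a = pkv.permute([2, 1, 0, 3, 4]).split(2)
b = pkv.transpose(0, 2).split(2)
assert all(torch.equal(x, y) for x, y in zip(a, b))
```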