@@ -15,13 +15,13 @@ def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_prec
1515 tensor containing the timestamps in seconds for each predicted token
1616 """
1717 # Create a list with `decoder_layers` elements, each a tensor of shape
18- # (batch size , attention_heads, output length, input length ).
18+ # (batch_size , attention_heads, output_length, input_length ).
1919 cross_attentions = []
2020 for i in range (self .config .decoder_layers ):
2121 cross_attentions .append (torch .cat ([x [i ] for x in generate_outputs .cross_attentions ], dim = 2 ))
2222
23- # Select specific cross-attention layers and heads. This is a tensor
24- # of shape (batch size, num selected, output length, input length ).
23+ # Select specific cross-attention layers and heads. This results in a tensor
24+ # of shape (batch_size, num_selected_heads, output_length, input_length ).
2525 weights = torch .stack ([cross_attentions [l ][:, h ] for l , h in alignment_heads ])
2626 weights = weights .permute ([1 , 0 , 2 , 3 ])
2727 if num_frames is not None :
@@ -32,21 +32,39 @@ def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_prec
3232 weights = (weights - mean ) / std
3333 weights = _median_filter (weights , self .config .median_filter_width )
3434
35- # Average the different cross-attention heads.
35+ # Average the different cross-attention heads to get a matrix of shape
36+ # (batch_size, output_length, input_length).
3637 matrix = weights .mean (dim = 1 )
3738
38- timestamps = torch .zeros_like (generate_outputs .sequences , dtype = torch .float32 )
39+ # Initialize the timestamps tensor with the correct size.
40+ # We'll find the maximum length of `jump_times` across the batch.
41+ batch_size = generate_outputs .sequences .size (0 )
42+ max_jump_length = 0
43+ batch_jump_times = []
3944
40- # Perform dynamic time warping on each element of the batch .
41- for batch_idx in range (timestamps . shape [ 0 ] ):
45+ # First pass: Compute `jump_times` and find the maximum length .
46+ for batch_idx in range (batch_size ):
4247 text_indices , time_indices = _dynamic_time_warping (- matrix [batch_idx ].float ().cpu ().numpy ())
4348 jumps = np .pad (np .diff (text_indices ), (1 , 0 ), constant_values = 1 ).astype (bool )
4449 jump_times = time_indices [jumps ] * time_precision
45- timestamps [batch_idx , 1 :] = torch .tensor (jump_times )
50+ batch_jump_times .append (jump_times )
51+ if len (jump_times ) > max_jump_length :
52+ max_jump_length = len (jump_times )
53+
54+ # Initialize timestamps tensor with appropriate size.
55+ # Adding 1 to account for the initial zero (timestamps[:, 0]).
56+ timestamps = torch .zeros ((batch_size , max_jump_length + 1 ), dtype = torch .float32 )
57+
58+ # Second pass: Assign `jump_times` to the timestamps tensor.
59+ for batch_idx , jump_times in enumerate (batch_jump_times ):
60+ length = len (jump_times )
61+ # Assign `jump_times` to the appropriate slice in `timestamps`.
62+ timestamps [batch_idx , 1 :1 + length ] = torch .tensor (jump_times , dtype = torch .float32 )
4663
4764 return timestamps
4865
4966
67+
5068@dataclass
5169class ASRAudioFile :
5270 file : str
0 commit comments