Skip to content

Commit 6ed1845

Browse files
authored
Merge pull request #11 from ss-sebastian/cantonese-try-3
Cantonese try 3
2 parents: ff52db9 + 9510141; commit: 6ed1845

File tree

6 files changed

+368
-186
lines changed

6 files changed

+368
-186
lines changed

batchalign/constants.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
# CHAT punctuation specifications
2-
ENDING_PUNCT = [".", "?", "!", "+//.", "+/.", "+...", "+\"/.", "+..?", "+\".", "+//?", "+.", "+!?", "+/?", "..."]
2+
ENDING_PUNCT = [".", "?", "?", "!", "!", "+//.", "+/.", "+...", "+\"/.", "+..?","+..?", "+\".", "+//?", "+//?","+.", "+!?", "+!?", "+/?", "+/?", "...", "?","!"]
33
MOR_PUNCT = ["‡", "„", ","]
44
CHAT_IGNORE = ["xxx", "yyy", "www"]
55

batchalign/models/resolve.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212
},
1313
"whisper": {
1414
'eng': ("talkbank/CHATWhisper-en-large-v1", "openai/whisper-large-v2"),
15-
# 'yue': ("alvanlii/whisper-small-cantonese", "alvanlii/whisper-small-cantonese"),
15+
'yue': ("alvanlii/whisper-small-cantonese", "alvanlii/whisper-small-cantonese"),
1616
}
1717
}
1818

batchalign/models/utils.py

Lines changed: 26 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -15,13 +15,13 @@ def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_prec
1515
tensor containing the timestamps in seconds for each predicted token
1616
"""
1717
# Create a list with `decoder_layers` elements, each a tensor of shape
18-
# (batch size, attention_heads, output length, input length).
18+
# (batch_size, attention_heads, output_length, input_length).
1919
cross_attentions = []
2020
for i in range(self.config.decoder_layers):
2121
cross_attentions.append(torch.cat([x[i] for x in generate_outputs.cross_attentions], dim=2))
2222

23-
# Select specific cross-attention layers and heads. This is a tensor
24-
# of shape (batch size, num selected, output length, input length).
23+
# Select specific cross-attention layers and heads. This results in a tensor
24+
# of shape (batch_size, num_selected_heads, output_length, input_length).
2525
weights = torch.stack([cross_attentions[l][:, h] for l, h in alignment_heads])
2626
weights = weights.permute([1, 0, 2, 3])
2727
if num_frames is not None:
@@ -32,21 +32,39 @@ def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_prec
3232
weights = (weights - mean) / std
3333
weights = _median_filter(weights, self.config.median_filter_width)
3434

35-
# Average the different cross-attention heads.
35+
# Average the different cross-attention heads to get a matrix of shape
36+
# (batch_size, output_length, input_length).
3637
matrix = weights.mean(dim=1)
3738

38-
timestamps = torch.zeros_like(generate_outputs.sequences, dtype=torch.float32)
39+
# Initialize the timestamps tensor with the correct size.
40+
# We'll find the maximum length of `jump_times` across the batch.
41+
batch_size = generate_outputs.sequences.size(0)
42+
max_jump_length = 0
43+
batch_jump_times = []
3944

40-
# Perform dynamic time warping on each element of the batch.
41-
for batch_idx in range(timestamps.shape[0]):
45+
# First pass: Compute `jump_times` and find the maximum length.
46+
for batch_idx in range(batch_size):
4247
text_indices, time_indices = _dynamic_time_warping(-matrix[batch_idx].float().cpu().numpy())
4348
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
4449
jump_times = time_indices[jumps] * time_precision
45-
timestamps[batch_idx, 1:] = torch.tensor(jump_times)
50+
batch_jump_times.append(jump_times)
51+
if len(jump_times) > max_jump_length:
52+
max_jump_length = len(jump_times)
53+
54+
# Initialize timestamps tensor with appropriate size.
55+
# Adding 1 to account for the initial zero (timestamps[:, 0]).
56+
timestamps = torch.zeros((batch_size, max_jump_length + 1), dtype=torch.float32)
57+
58+
# Second pass: Assign `jump_times` to the timestamps tensor.
59+
for batch_idx, jump_times in enumerate(batch_jump_times):
60+
length = len(jump_times)
61+
# Assign `jump_times` to the appropriate slice in `timestamps`.
62+
timestamps[batch_idx, 1:1+length] = torch.tensor(jump_times, dtype=torch.float32)
4663

4764
return timestamps
4865

4966

67+
5068
@dataclass
5169
class ASRAudioFile:
5270
file : str

batchalign/models/utterance/infer.py

Lines changed: 17 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,8 @@ def __init__(self, model):
3131
self.tokenizer = AutoTokenizer.from_pretrained(model)
3232
self.model = BertForTokenClassification.from_pretrained(model).to(DEVICE)
3333

34+
self.max_length = self.model.config.max_position_embeddings
35+
3436
# eval mode
3537
self.model.eval()
3638

@@ -43,15 +45,27 @@ def __call__(self, passage):
4345
passage = passage.replace('.','')
4446

4547
# "tokenize" the result by just splitting by space
46-
input_tokenized = passage.split(' ')
48+
input_tokenized = passage.split(' ') if passage.strip() else []
49+
if not input_tokenized:
50+
raise ValueError("Tokenized input is empty after preprocessing")
51+
52+
if len(input_tokenized) > self.max_length:
53+
input_tokenized = input_tokenized[:self.max_length]
54+
55+
print(f"Input tokenized length: {len(input_tokenized)}, tokens: {input_tokenized}")
56+
4757

4858
# pass it through the tokenizer and model
4959
tokd = self.tokenizer([input_tokenized],
5060
return_tensors='pt',
51-
is_split_into_words=True).to(DEVICE)
61+
is_split_into_words=True,
62+
truncation=True,
63+
max_length=self.max_length
64+
).to(DEVICE)
5265

5366
# pass it through the model
54-
res = self.model(**tokd).logits
67+
with torch.no_grad():
68+
res = self.model(**tokd).logits
5569

5670
# argmax
5771
classified_targets = torch.argmax(res, dim=2).cpu()

0 commit comments

Comments
 (0)