Skip to content

Commit f12bcca

Browse files
committed
mmm
1 parent 76e47a9 commit f12bcca

File tree

7 files changed

+253
-24
lines changed

7 files changed

+253
-24
lines changed

batchalign/cli/cli.py

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ def batchalign(ctx, verbose):
110110
default=False, help="For utterance timing recovery, OpenAI Whisper (ASR) instead of Rev.AI (default).")
111111
@click.option("--wav2vec/--whisper_fa",
112112
default=True, help="Use Whisper instead of Wav2Vec for English (defaults for Whisper for non-English)")
113+
@click.option("--iic", is_flag=True, default=False, help="Use IIC forced alignment (for Chinese).")
113114
@click.option("--tencent/--rev",
114115
default=False, help="Use Tencent instead of Rev.AI (default).")
115116
@click.option("--funaudio/--rev",
@@ -118,7 +119,7 @@ def batchalign(ctx, verbose):
118119
@click.option("--wor/--nowor",
119120
default=True, help="Should we write word level alignment line? Default to yes.")
120121
@click.pass_context
121-
def align(ctx, in_dir, out_dir, whisper, wav2vec, tencent, funaudio, **kwargs):
122+
def align(ctx, in_dir, out_dir, whisper, wav2vec, iic, tencent, funaudio, **kwargs):
122123
"""Align transcripts against corresponding media files."""
123124
def loader(file):
124125
return (
@@ -129,26 +130,23 @@ def loader(file):
129130
def writer(doc, output):
130131
CHATFile(doc=doc).write(output, write_wor=kwargs.get("wor", True))
131132

132-
if not wav2vec:
133-
_dispatch("align", "eng", 1,
134-
["cha"], ctx,
135-
in_dir, out_dir,
136-
loader, writer, C,
137-
fa="whisper_fa",
138-
utr = ("whisper_utr" if whisper else
139-
("tencent_utr" if tencent else
140-
("funaudio_utr" if funaudio else "rev_utr"))),
141-
**kwargs)
133+
# Determine FA engine
134+
if iic:
135+
fa_engine = "iic_fa"
136+
elif not wav2vec:
137+
fa_engine = "whisper_fa"
142138
else:
143-
_dispatch("align", "eng", 1,
144-
["cha"], ctx,
145-
in_dir, out_dir,
146-
loader, writer, C,
147-
fa="wav2vec_fa",
148-
utr = ("whisper_utr" if whisper else
149-
("tencent_utr" if tencent else
150-
("funaudio_utr" if funaudio else "rev_utr"))),
151-
**kwargs)
139+
fa_engine = "wav2vec_fa"
140+
141+
_dispatch("align", "eng", 1,
142+
["cha"], ctx,
143+
in_dir, out_dir,
144+
loader, writer, C,
145+
fa=fa_engine,
146+
utr = ("whisper_utr" if whisper else
147+
("tencent_utr" if tencent else
148+
("funaudio_utr" if funaudio else "rev_utr"))),
149+
**kwargs)
152150

153151
#################### TRANSCRIBE ################################
154152

batchalign/pipelines/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from .cleanup import NgramRetraceEngine, DisfluencyReplacementEngine
88
from .speaker import NemoSpeakerEngine
99

10-
from .fa import WhisperFAEngine, Wave2VecFAEngine
10+
from .fa import WhisperFAEngine, Wave2VecFAEngine, IICFAEngine
1111
from .utr import WhisperUTREngine, RevUTREngine, TencentUTREngine, FunAudioUTREngine
1212

1313
from .analysis import EvaluationEngine

batchalign/pipelines/dispatch.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
StanzaUtteranceEngine, CorefEngine, Wave2VecFAEngine, TencentEngine,
1010
OAIWhisperEngine, TencentUTREngine, AliyunEngine, FunAudioEngine,
1111
FunAudioUTREngine, SeamlessTranslationModel, GoogleTranslateEngine,
12-
OAIWhisperEngine, PyannoteEngine)
12+
OAIWhisperEngine, PyannoteEngine, IICFAEngine)
1313

1414
from batchalign import BatchalignPipeline
1515
from batchalign.models import resolve
@@ -135,6 +135,8 @@ def dispatch_pipeline(pkg_str, lang, num_speakers=None, **arg_overrides):
135135
engines.append(CorefEngine())
136136
elif engine == "wav2vec_fa":
137137
engines.append(Wave2VecFAEngine())
138+
elif engine == "iic_fa":
139+
engines.append(IICFAEngine())
138140
elif engine == "seamless_translate":
139141
engines.append(SeamlessTranslationModel())
140142
elif engine == "tencent":
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from .whisper_fa import WhisperFAEngine
22
from .wave2vec_fa import Wave2VecFAEngine
3+
from .iic_fa import IICFAEngine

batchalign/pipelines/fa/iic_fa.py

Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
from batchalign.document import *
2+
from batchalign.pipelines.base import *
3+
from batchalign.utils import *
4+
from batchalign.utils.dp import *
5+
from batchalign.constants import *
6+
from batchalign.models.utils import ASRAudioFile
7+
8+
import logging
9+
L = logging.getLogger("batchalign")
10+
11+
import re
12+
import warnings
13+
import tempfile
14+
import os
15+
16+
from modelscope.pipelines import pipeline
17+
from modelscope.utils.constant import Tasks
18+
19+
import torch
20+
from torchaudio import load
21+
from torchaudio import transforms as T
22+
import torchaudio
23+
24+
class IICFAEngine(BatchalignEngine):
    """Forced alignment via the ModelScope (IIC) speech-timestamp model.

    Designed for Chinese: each utterance's words are exploded into
    characters, aligned character-by-character by the model, and the
    resulting timings are mapped back onto the original word objects with
    a dynamic-programming alignment.
    """

    tasks = [ Task.FORCED_ALIGNMENT ]

    def _hook_status(self, status_hook):
        # Register a callback(current_segment, total_segments) used for
        # per-segment progress reporting during process().
        self.status_hook = status_hook

    def __init__(self):
        self.status_hook = None
        # ModelScope timestamp-prediction pipeline (16 kHz offline model).
        self.__iic = pipeline(
            task=Tasks.speech_timestamp,
            model='iic/speech_timestamp_prediction-v1-16k-offline',
            model_revision="v2.0.4",
            output_dir='./tmp')

    def process(self, doc:Document, **kwargs):
        """Force-align ``doc`` against its media; returns ``doc`` with word times set.

        Requires every utterance to already carry an utterance-level
        alignment (used both to group audio into segments and to bound
        the recovered word timings).
        """

        # check that the document has a media path to align to
        assert doc.media != None and doc.media.url != None, f"We cannot forced-align something that doesn't have a media path! Provided media tier='{doc.media}'"

        if doc.langs[0] not in ["zho", "cmn", "yue"]:
            warnings.warn("Looks like you are not aligning Chinese with IIC; this aligner is designed for Chinese and may not work well with other languages.")

        # load the audio file
        L.debug(f"IIC FA is loading url {doc.media.url}...")
        audio_arr, rate = load(doc.media.url)
        # transpose and mean to get mono
        audio_tensor = torch.mean(audio_arr.transpose(0,1), dim=1)
        audio_file = ASRAudioFile(doc.media.url, audio_tensor, rate)
        L.debug(f"IIC FA finished loading media.")

        # collect utterances into ~15 second segments to be aligned;
        # we have to do this because the aligner does poorly with very
        # short segments
        groups = []
        group = []
        seg_start = 0

        L.debug(f"IIC FA grouping utterances...")

        for i in doc.content:
            if not isinstance(i, Utterance):
                continue
            if i.alignment == None:
                warnings.warn("We found at least one utterance without utterance-level alignment; this is usually not an issue, but if the entire transcript is unaligned, it means that utterance level timing recovery (which is fuzzy using ASR) failed due to the audio clarity. On this transcript, before running forced-alignment, please supply utterance-level links.")
                continue

            # pop the previous group onto the stack
            if (i.alignment[-1] - seg_start) > 15*1000:
                groups.append(group)
                group = []
                seg_start = i.alignment[0]

            # append the contents to the running group; each word keeps a
            # reference to its utterance-level alignment for offsetting
            for word in i.content:
                group.append((word, i.alignment))

        groups.append(group)

        L.debug(f"Begin IIC Inference...")

        for indx, grp in enumerate(groups):
            L.info(f"IIC FA processing segment {indx+1}/{len(groups)}...")
            if self.status_hook != None:
                self.status_hook(indx+1, len(groups))

            # perform alignment
            try:
                # create transcript with spaces between characters
                transcript = []
                for word, _ in grp:
                    # skip punctuation
                    if word.text.strip() not in MOR_PUNCT + ENDING_PUNCT:
                        # add spaces between each character for Chinese
                        transcript.append(" ".join(list(word.text)))

                transcript_text = " ".join(transcript)

                if len(transcript_text.strip()) == 0:
                    continue

                # extract audio chunk and write to temp file; segments of
                # 20 seconds or more are skipped entirely
                if (grp[-1][1][1] - grp[0][1][0]) < 20*1000:
                    # get the audio chunk as tensor
                    audio_chunk = audio_file.chunk(grp[0][1][0], grp[-1][1][1])

                    # create temporary file for the audio chunk
                    with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_file:
                        tmp_path = tmp_file.name

                    # write the audio chunk to temp file
                    # NOTE(review): the model is a 16 kHz model; this assumes
                    # the source sample rate is compatible — confirm upstream
                    torchaudio.save(tmp_path, audio_chunk.unsqueeze(0), rate)

                    try:
                        # call IIC aligner with the temp file
                        rec_result = self.__iic(input=(tmp_path, transcript_text),
                                                data_type=("sound", "text"))
                    finally:
                        # clean up temp file even if inference raised
                        if os.path.exists(tmp_path):
                            os.unlink(tmp_path)
                else:
                    continue
            except Exception as e:
                L.warning(f"IIC alignment failed for segment {indx+1}: {e}")
                continue

            # parse the result
            # text format: '<sil> 0.000 0.380;一 0.380 0.560;个 0.560 0.800;...'
            try:
                timings = []
                words = []

                for p in rec_result:
                    parts = p["text"].strip().split()
                    timestamps = p["timestamp"]

                    for token, tss in zip(parts, timestamps):
                        # BUGFIX: start must come from tss[0]; previously both
                        # start and end read tss[1], giving zero-length words
                        start, end = max(tss[0], 0), max(tss[1], 0)
                        words.append(token)
                        # offset the chunk-relative timestamps by the chunk
                        # start (presumably both already in milliseconds —
                        # TODO confirm against the model's output units)
                        timings.append((int(start + grp[0][1][0]),
                                        int(end + grp[0][1][0])))
            except Exception as e:
                # BUGFIX: previously `raise e` preceded this warning, making
                # the warn-and-skip path unreachable and aborting the whole
                # document on a single bad segment
                L.warning(f"Failed to parse IIC result for segment {indx+1}: {e}")
                continue

            # create reference backplates, which are the word ids to set the
            # timing for (one target per character, tagged with its word index)
            ref_targets = []
            for word_idx, (word, _) in enumerate(grp):
                for char in word.text:
                    ref_targets.append(ReferenceTarget(char, payload=word_idx))

            # create target backplates for the timings (payload indexes into
            # `timings`)
            payload_targets = []
            try:
                for time_idx, (word, time) in enumerate(zip(words, timings)):
                    for char in word:
                        payload_targets.append(PayloadTarget(char, payload=time_idx))
            except Exception as e:
                # `indx` is the segment counter; inner loops use their own
                # names so this message reports the right segment
                L.warning(f"Failed to create payload targets for segment {indx+1}: {e}")
                continue

            # alignment!
            alignments = align(payload_targets, ref_targets, tqdm=False)

            # set the ids back to the text ids
            # we do this BACKWARDS because we want to have the first timestamp
            # we get about a word first
            alignments.reverse()
            for elem in alignments:
                if isinstance(elem, Match):
                    grp[elem.reference_payload][0].time = (timings[elem.payload][0],
                                                           timings[elem.payload][1])

        L.debug(f"Correcting text...")

        # we now set the end alignment of each word to the start of the next
        for doc_ut, ut in enumerate(doc.content):
            if not isinstance(ut, Utterance):
                continue

            # correct each word by bumping it forward
            # and if its not a word we remove the timing
            for indx, w in enumerate(ut.content):
                if w.type in [TokenType.PUNCT, TokenType.FEAT, TokenType.ANNOT]:
                    w.time = None
                elif indx == len(ut.content)-1 and w.text in ENDING_PUNCT:
                    w.time = None
                elif indx != len(ut.content)-1:
                    # search forward for the next compatible time
                    tmp = indx+1
                    while tmp < len(ut.content)-1 and ut.content[tmp].time == None:
                        tmp += 1
                    if w.time == None:
                        continue
                    # NOTE(review): when ut.content[tmp].time is set, this word's
                    # end is left untouched despite the comment above — confirm
                    # whether bumping to the next word's start was intended
                    if ut.content[tmp].time == None:
                        # seek forward one utterance to find their start time
                        # BUGFIX: previously tested isinstance(doc.content,
                        # Utterance) — the whole list, never true — so the seek
                        # always fell through to the +500 ms guess; test the
                        # element at next_ut instead
                        next_ut = doc_ut + 1
                        while next_ut < len(doc.content)-1 and (not isinstance(doc.content[next_ut], Utterance) or doc.content[next_ut].alignment == None):
                            next_ut += 1
                        if next_ut < len(doc.content) and isinstance(doc.content[next_ut], Utterance) and doc.content[next_ut].alignment:
                            w.time = (w.time[0], doc.content[next_ut].alignment[0])
                        else:
                            w.time = (w.time[0], w.time[0]+500) # give half a second because we don't know

                    # just in case, bound the time by the utterance derived timings
                    if ut.alignment and ut.alignment[0] != None:
                        w.time = (max(w.time[0], ut.alignment[0]), min(w.time[1], ut.alignment[1]))
                    # if we ended up with timings that don't make sense, drop it
                    if w.time and w.time[0] >= w.time[1]:
                        w.time = None

            # clear any built-in timing (i.e. we should use utterance-derived timing)
            ut.time = None
            # correct the text: write (or rewrite) the \x15-delimited
            # CHAT time bullet from the utterance alignment
            if ut.alignment and ut.text != None:
                if '\x15' not in ut.text:
                    ut.text = (ut.text+f" \x15{ut.alignment[0]}_{ut.alignment[1]}\x15").strip()
                else:
                    ut.text = re.sub(r"\x15\d+_\d+\x15",
                                     f"\x15{ut.alignment[0]}_{ut.alignment[1]}\x15", ut.text).strip()
            elif ut.text != None:
                ut.text = re.sub(r"\x15\d+_\d+\x15", f"", ut.text).strip()

        return doc

batchalign/version

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
0.7.22-post.29
1+
0.7.22-post.30
22
November 28th, 2025
33
Tencent Fixes

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def read(fname):
5151
"click~=8.1",
5252
"matplotlib>=3.8.0,<4.0.0",
5353
"pyfiglet==1.0.2",
54-
"setuptools>=78.1.1",
54+
"setuptools",
5555
"soundfile~=0.12.0",
5656
"rich-click>=1.7.0",
5757
"typing-extensions",
@@ -66,6 +66,7 @@ def read(fname):
6666
"oss2",
6767
"openai-whisper>=20240930",
6868
"funasr",
69+
"modelscope[nlp,audio]",
6970
"cos-python-sdk-v5",
7071
"openai-whisper",
7172
"llvmlite>=0.44.0",

0 commit comments

Comments
 (0)