from batchalign.document import *
from batchalign.pipelines.base import *
from batchalign.utils import *
from batchalign.utils.dp import *
from batchalign.constants import *
from batchalign.models.utils import ASRAudioFile

import logging
L = logging.getLogger("batchalign")

import re
import warnings
import tempfile
import os

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

import torch
import torchaudio
from torchaudio import load

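# Forced-alignment engine backed by ModelScope's IIC speech timestamp
# prediction pipeline: given a Document whose utterances already carry
# coarse utterance-level alignments, it recovers word-level timings,
# primarily for Chinese transcripts.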
class IICFAEngine(BatchalignEngine):
    tasks = [Task.FORCED_ALIGNMENT]

    def _hook_status(self, status_hook):
        self.status_hook = status_hook

    def __init__(self):
        self.status_hook = None
        self.__iic = pipeline(
            task=Tasks.speech_timestamp,
            model='iic/speech_timestamp_prediction-v1-16k-offline',
            model_revision="v2.0.4",
            output_dir='./tmp')
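        # NB: constructing the pipeline fetches the pinned model revision
        # from the ModelScope hub when it is not already cached, so the
        # first instantiation is assumed to need network access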

    def process(self, doc:Document, **kwargs):
        # check that the document has a media path to align to
        assert doc.media is not None and doc.media.url is not None, f"We cannot forced-align something that doesn't have a media path! Provided media tier='{doc.media}'"

        if doc.langs[0] not in ["zho", "cmn", "yue"]:
            warnings.warn("Looks like you are not aligning Chinese with IIC; this aligner is designed for Chinese and may not work well with other languages.")

        # load the audio file
        L.debug(f"IIC FA is loading url {doc.media.url}...")
        audio_arr, rate = load(doc.media.url)
        # average across channels to get mono audio
        audio_tensor = torch.mean(audio_arr.transpose(0,1), dim=1)
        audio_file = ASRAudioFile(doc.media.url, audio_tensor, rate)
        L.debug("IIC FA finished loading media.")

        # collect utterances into roughly 15-second segments to be aligned;
        # we have to do this because the aligner does poorly with very short segments
        groups = []
        group = []
        seg_start = 0

        L.debug("IIC FA grouping utterances...")

        for i in doc.content:
            if not isinstance(i, Utterance):
                continue
            if i.alignment is None:
                warnings.warn("We found at least one utterance without utterance-level alignment; this is usually not an issue, but if the entire transcript is unaligned, it means that utterance-level timing recovery (which is fuzzy using ASR) failed due to the audio clarity. On this transcript, before running forced alignment, please supply utterance-level links.")
                continue

            # pop the previous group onto the stack and start a new one
            if (i.alignment[-1] - seg_start) > 15*1000:
                groups.append(group)
                group = []
                seg_start = i.alignment[0]

            # append the contents to the running group
            for word in i.content:
                group.append((word, i.alignment))

        groups.append(group)
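        # (e.g. with utterances ending at 4 s, 9 s, and 16 s, the first two
        # land in one group; the third ends more than 15 s past the group's
        # start, so it opens a new group)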

        L.debug("Begin IIC Inference...")

        for indx, grp in enumerate(groups):
            L.info(f"IIC FA processing segment {indx+1}/{len(groups)}...")
            if self.status_hook is not None:
                self.status_hook(indx+1, len(groups))

            # perform alignment
            try:
                # create transcript with spaces between characters
                transcript = []
                for word, _ in grp:
                    # skip punctuation
                    if word.text.strip() not in MOR_PUNCT + ENDING_PUNCT:
                        # add spaces between each character for Chinese
                        transcript.append(" ".join(list(word.text)))

                transcript_text = " ".join(transcript)

                if len(transcript_text.strip()) == 0:
                    continue

                # extract audio chunk and write to temp file
                if (grp[-1][1][1] - grp[0][1][0]) < 20*1000:
                    # get the audio chunk as tensor
                    audio_chunk = audio_file.chunk(grp[0][1][0], grp[-1][1][1])

                    # create temporary file for the audio chunk
                    with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_file:
                        tmp_path = tmp_file.name

                    # write the audio chunk to temp file
                    torchaudio.save(tmp_path, audio_chunk.unsqueeze(0), rate)
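                    # (torchaudio.save expects a 2-D (channels, frames)
                    # tensor, hence the unsqueeze to restore a channel axis)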

                    try:
                        # call IIC aligner with the temp file
                        rec_result = self.__iic(input=(tmp_path, transcript_text),
                                                data_type=("sound", "text"))
                    finally:
                        # clean up temp file
                        if os.path.exists(tmp_path):
                            os.unlink(tmp_path)
                else:
                    # segment is too long to align reliably; skip it
                    continue
            except Exception as e:
                L.warning(f"IIC alignment failed for segment {indx+1}: {e}")
                continue

            # parse the result string
            # format: '<sil> 0.000 0.380;一 0.380 0.560;个 0.560 0.800;...'
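            # (assumption: each result entry also carries a parallel
            # `timestamp` list of [start, end] pairs in milliseconds, one
            # per whitespace-separated token of `text`; this matches model
            # revision v2.0.4 but may differ across revisions)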
            try:
                timings = []
                words = []

                for p in rec_result:
                    parts = p["text"].strip().split()
                    timestamps = p["timestamp"]

                    for part, tss in zip(parts, timestamps):
                        word, start, end = part, max(tss[0], 0), max(tss[1], 0)
                        words.append(word)
                        # shift the segment-relative timestamps (ms)
                        # into the full-audio timeline
                        timings.append((int(start + grp[0][1][0]),
                                        int(end + grp[0][1][0])))
            except Exception as e:
                L.warning(f"Failed to parse IIC result for segment {indx+1}: {e}")
                continue

            # create reference backplates, which are the word ids to set the timing for
            ref_targets = []
            for widx, (word, _) in enumerate(grp):
                for char in word.text:
                    ref_targets.append(ReferenceTarget(char, payload=widx))
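            # (e.g. the words ["你好", "吗"] yield reference characters
            # 你, 好, 吗 with payloads 0, 0, 1, so each character-level
            # match below can be traced back to the word it came from)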

            # create target backplates for the timings
            payload_targets = []
            try:
                for tidx, (word, _) in enumerate(zip(words, timings)):
                    for char in word:
                        payload_targets.append(PayloadTarget(char, payload=tidx))
            except Exception as e:
                L.warning(f"Failed to create payload targets for segment {indx+1}: {e}")
                continue

            # alignment!
            alignments = align(payload_targets, ref_targets, tqdm=False)

            # map the matched timings back onto the words;
            # we iterate BACKWARDS so that the first timestamp we got
            # for a word is the one that sticks
            alignments.reverse()
            for elem in alignments:
                if isinstance(elem, Match):
                    grp[elem.reference_payload][0].time = (timings[elem.payload][0],
                                                           timings[elem.payload][1])

        L.debug("Correcting text...")

        # we now set the end alignment of each word to the start of the next
        for doc_ut, ut in enumerate(doc.content):
            if not isinstance(ut, Utterance):
                continue

            # correct each word by bumping it forward,
            # and if it's not a word we remove the timing
            for indx, w in enumerate(ut.content):
                if w.type in [TokenType.PUNCT, TokenType.FEAT, TokenType.ANNOT]:
                    w.time = None
                elif indx == len(ut.content)-1 and w.text in ENDING_PUNCT:
                    w.time = None
                elif indx != len(ut.content)-1:
                    # search forward for the next compatible time
                    tmp = indx+1
                    while tmp < len(ut.content)-1 and ut.content[tmp].time is None:
                        tmp += 1
                    if w.time is None:
                        continue
                    if ut.content[tmp].time is None:
                        # seek forward to the next aligned utterance and use its start time
                        next_ut = doc_ut + 1
                        while next_ut < len(doc.content)-1 and (not isinstance(doc.content[next_ut], Utterance) or doc.content[next_ut].alignment is None):
                            next_ut += 1
                        if next_ut < len(doc.content) and isinstance(doc.content[next_ut], Utterance) and doc.content[next_ut].alignment:
                            w.time = (w.time[0], doc.content[next_ut].alignment[0])
                        else:
                            w.time = (w.time[0], w.time[0]+500) # give half a second because we don't know
                    else:
                        # bump this word's end forward to the next aligned word's start
                        w.time = (w.time[0], ut.content[tmp].time[0])

                # just in case, bound the time by the utterance-derived timings
                if w.time and ut.alignment and ut.alignment[0] is not None:
                    w.time = (max(w.time[0], ut.alignment[0]), min(w.time[1], ut.alignment[1]))
                # if we ended up with timings that don't make sense, drop them
                if w.time and w.time[0] >= w.time[1]:
                    w.time = None

            # clear any built-in timing (i.e. we should use utterance-derived timing)
            ut.time = None
            # correct the text
            if ut.alignment and ut.text is not None:
                if '\x15' not in ut.text:
                    ut.text = (ut.text+f" \x15{ut.alignment[0]}_{ut.alignment[1]}\x15").strip()
                else:
                    ut.text = re.sub(r"\x15\d+_\d+\x15",
                                     f"\x15{ut.alignment[0]}_{ut.alignment[1]}\x15", ut.text).strip()
            elif ut.text is not None:
                ut.text = re.sub(r"\x15\d+_\d+\x15", "", ut.text).strip()

        return doc
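
# A minimal usage sketch (illustrative, not a pinned API: the Document
# construction below is assumed to happen upstream, e.g. via an ASR engine
# that produces utterance-level alignments):
#
#     engine = IICFAEngine()
#     engine._hook_status(lambda done, total: print(f"{done}/{total}"))
#     aligned = engine.process(doc)  # `doc`: a Document with media + alignments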