From b80e4189fd5e6f22eeb1b5310d6c53f926d69dd9 Mon Sep 17 00:00:00 2001
From: eustlb <94853470+eustlb@users.noreply.github.com>
Date: Thu, 5 Dec 2024 13:46:29 +0100
Subject: [PATCH] [Whisper] Fix whisper tokenizer (#34537)

* handle single timestamp ending

* include last timestamp token

* handle single timestamp ending

* avoid floating point arithmetic limitations

* ensure float64 operations

* new test

* make fixup

* make copies

* handle edge case: double tokens ending with different tokens

* handle single timestamp ending

* make fixup

* handle conditioning on prev segments

* fix

* Update src/transformers/models/whisper/generation_whisper.py

Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com>

* [run-slow] whisper

* don't call item() to avoid unnecessary sync

* fix

---------

Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com>
Co-authored-by: Eustache Le Bihan
---
 .../models/whisper/generation_whisper.py      | 40 ++++++---
 .../models/whisper/tokenization_whisper.py    | 34 +++++--
 .../whisper/tokenization_whisper_fast.py      | 36 ++++++--
 tests/models/whisper/test_modeling_whisper.py | 88 +++++++++++++++++++
 4 files changed, 173 insertions(+), 25 deletions(-)

diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py
index 0ecdcb4dbdea..2f58375f3de7 100644
--- a/src/transformers/models/whisper/generation_whisper.py
+++ b/src/transformers/models/whisper/generation_whisper.py
@@ -308,6 +308,7 @@ def generate(
         num_segment_frames: Optional[int] = None,
         attention_mask: Optional[torch.Tensor] = None,
         time_precision: float = 0.02,
+        time_precision_features: float = 0.01,
         return_token_timestamps: Optional[bool] = None,
         return_segments: bool = False,
         return_dict_in_generate: Optional[bool] = None,
@@ -417,6 +418,8 @@ def generate(
             time_precision (`int`, *optional*, defaults to 0.02):
                 The duration of output token in seconds. *E.g.* 0.02 means that a generated token on average accounts
                 for 20 ms.
+            time_precision_features (`float`, *optional*, defaults to 0.01):
+                The duration represented by a feature frame in seconds.
             return_token_timestamps (`bool`, *optional*):
                 Whether to return token-level timestamps with the text. This can be used with or without the
                `return_timestamps` option. To get word-level timestamps, use the tokenizer to group the tokens into
@@ -629,7 +632,7 @@ def generate(
                 cur_bsz=cur_bsz,
                 batch_idx_map=batch_idx_map,
             )
-            time_offset = seek * time_precision / input_stride
+            time_offset = seek.to(torch.float64) * time_precision / input_stride
             seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames)

             # 6.2 cut out next 30s segment from input features
@@ -658,6 +661,7 @@ def generate(
                 config=self.config,
                 device=init_tokens.device,
                 suppress_tokens=suppress_tokens,
+                timestamp_begin=timestamp_begin,
                 kwargs=kwargs,
             )

@@ -718,6 +722,7 @@ def generate(
                     timestamp_begin=timestamp_begin,
                     seek_num_frames=seek_num_frames,
                     time_precision=time_precision,
+                    time_precision_features=time_precision_features,
                     input_stride=input_stride,
                     prev_idx=prev_i,
                     idx=i,
@@ -1665,6 +1670,7 @@ def _prepare_decoder_input_ids(
         config,
         device,
         suppress_tokens,
+        timestamp_begin,
         kwargs,
     ):
         if "decoder_input_ids" in kwargs:
@@ -1684,6 +1690,14 @@ def _prepare_decoder_input_ids(
             # according to https://github.com/openai/whisper/blob/e58f28804528831904c3b6f2c0e473f346223433/whisper/decoding.py#L609
             active_segments = [current_segments[i] if do_condition_on_prev_tokens[i] else None for i in batch_idx_map]

+            for segments in active_segments:
+                for seg in segments:
+                    if len(seg["tokens"]) > 2 and seg["tokens"][-2] >= timestamp_begin:
+                        # the segment finishes with two timestamp tokens
+                        # we need to ignore the last timestamp token
+                        # see https://github.com/huggingface/transformers/pull/34537
+                        seg["tokens"] = seg["tokens"][:-1]
+
             if prompt_ids is not None and generation_config.prompt_condition_type == "all-segments":
                 prev_ids = prompt_ids
             else:
@@ -1778,6 +1792,7 @@ def _retrieve_segment(
         timestamp_begin,
         seek_num_frames,
         time_precision,
+        time_precision_features,
         input_stride,
         prev_idx,
         idx,
@@ -1799,17 +1814,22 @@ def _retrieve_segment(
         segments = []
         if single_timestamp_ending:
             slices.append(len(seek_sequence))
+        else:
+            # we want to include the last timestamp token in the last segment, to know it was not a single-timestamp ending
+            slices[-1] += 1

         last_slice = 0
         # Add each segment to list of all segments
-        for current_slice in slices:
+        for i, current_slice in enumerate(slices):
+            is_last_slice = i == len(slices) - 1
             sliced_tokens = seek_sequence[last_slice:current_slice]
-            start_timestamp_pos = sliced_tokens[0].item() - timestamp_begin
-            end_timestamp_pos = sliced_tokens[-1].item() - timestamp_begin
+            start_timestamp_pos = sliced_tokens[0] - timestamp_begin
+            idx_sliced_tokens = -1 if not is_last_slice or single_timestamp_ending else -2
+            end_timestamp_pos = sliced_tokens[idx_sliced_tokens] - timestamp_begin
             segments.append(
                 {
-                    "start": time_offset[prev_idx] + start_timestamp_pos * time_precision,
-                    "end": time_offset[prev_idx] + end_timestamp_pos * time_precision,
+                    "start": time_offset[prev_idx] + start_timestamp_pos.to(torch.float64) * time_precision,
+                    "end": time_offset[prev_idx] + end_timestamp_pos.to(torch.float64) * time_precision,
                     "tokens": sliced_tokens,
                     "result": seek_outputs[idx],
                 }
@@ -1827,16 +1847,16 @@ def _retrieve_segment(
             # otherwise, ignore the unfinished segment and seek to the last timestamp
             # here we throw away all predictions after the last predicted "end of segment"
             # since we are cutting right in the middle of an audio
-            last_timestamp_pos = seek_sequence[last_slice - 1].item() - timestamp_begin
+            last_timestamp_pos = seek_sequence[last_slice - 2].item() - timestamp_begin
             segment_offset = last_timestamp_pos * input_stride
         else:
             # If whisper does not predict any "end of segment" token, then
segment" token, then # the whole decoding is considered a segment and we add it to the list of segments timestamps = seek_sequence[timestamp_tokens.nonzero().flatten()] - last_timestamp_pos = seek_num_frames[prev_idx] - if timestamps.numel() > 0 and timestamps[-1].item() != timestamp_begin: + last_timestamp_pos = int(seek_num_frames[prev_idx] * time_precision_features / time_precision) + if timestamps.numel() > 0 and timestamps[-1] != timestamp_begin: # no consecutive timestamps but it has a timestamp; use the last one. - last_timestamp_pos = timestamps[-1].item() - timestamp_begin + last_timestamp_pos = (timestamps[-1] - timestamp_begin).to(torch.float64) segments = [ { "start": time_offset[prev_idx], diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py index 0a6eb75c55f6..e537ef95da67 100644 --- a/src/transformers/models/whisper/tokenization_whisper.py +++ b/src/transformers/models/whisper/tokenization_whisper.py @@ -528,7 +528,9 @@ def basic_normalize(text, remove_diacritics=False): normalizer = BasicTextNormalizer(remove_diacritics=remove_diacritics) return normalizer(text) - def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_precision=0.02) -> str: + def _decode_with_timestamps( + self, token_ids, skip_special_tokens=False, time_precision=0.02, segment_size=1500 + ) -> str: """ Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>". @@ -538,15 +540,25 @@ def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_pre cur_max_timestamp = 0.0 prev_segments_len = 0.0 + penultimate_timestamp = 0.0 - for token in token_ids: + for i, token in enumerate(token_ids): if token >= timestamp_begin: timestamp = float((token - timestamp_begin) * time_precision) if timestamp < cur_max_timestamp: # next segment has started - prev_segments_len += cur_max_timestamp + last_was_single_ending = i >= 2 and not ( + token_ids[i - 1] >= timestamp_begin and token_ids[i - 2] >= timestamp_begin + ) + if last_was_single_ending: + prev_segments_len += time_precision * segment_size + else: + cur_max_timestamp = penultimate_timestamp + prev_segments_len += penultimate_timestamp + outputs = outputs[:-2] + penultimate_timestamp = cur_max_timestamp cur_max_timestamp = timestamp outputs.append(f"<|{(timestamp + prev_segments_len):.2f}|>") @@ -558,7 +570,7 @@ def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_pre ] return "".join(outputs) - def _compute_offsets(self, token_ids, time_precision=0.02): + def _compute_offsets(self, token_ids, time_precision=0.02, segment_size=1500): """ Compute offsets for a given tokenized input @@ -567,6 +579,8 @@ def _compute_offsets(self, token_ids, time_precision=0.02): List of tokenized input ids. Can be obtained using the `__call__` method. time_precision (`float`, *optional*, defaults to 0.02): The time ratio to convert from token to time. + segment_size (`int`, *optional*, defaults to 1500): + The number of features in the input mel spectrogram. 
""" offsets = [] # ensure torch tensor of token ids is placed on cpu @@ -597,7 +611,13 @@ def _compute_offsets(self, token_ids, time_precision=0.02): if start_timestamp_position < cur_max_timestamp: # next segment has started - prev_segments_len += cur_max_timestamp + is_single_ending = last_slice >= 2 and not ( + token_ids[last_slice - 2] >= timestamp_begin and token_ids[last_slice - 1] >= timestamp_begin + ) + if is_single_ending: + prev_segments_len += segment_size + else: + prev_segments_len += cur_max_timestamp cur_max_timestamp = end_timestamp_position @@ -609,8 +629,8 @@ def _compute_offsets(self, token_ids, time_precision=0.02): { "text": text, "timestamp": ( - (start_timestamp_position + prev_segments_len) * time_precision, - (end_timestamp_position + prev_segments_len) * time_precision, + start_timestamp_position * time_precision + prev_segments_len * time_precision, + end_timestamp_position * time_precision + prev_segments_len * time_precision, ), } ) diff --git a/src/transformers/models/whisper/tokenization_whisper_fast.py b/src/transformers/models/whisper/tokenization_whisper_fast.py index 66cf412cc2a8..f0383cb0def7 100644 --- a/src/transformers/models/whisper/tokenization_whisper_fast.py +++ b/src/transformers/models/whisper/tokenization_whisper_fast.py @@ -169,7 +169,9 @@ def _encode_plus(self, *args, **kwargs) -> BatchEncoding: return super()._encode_plus(*args, **kwargs) # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._decode_with_timestamps - def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_precision=0.02) -> str: + def _decode_with_timestamps( + self, token_ids, skip_special_tokens=False, time_precision=0.02, segment_size=1500 + ) -> str: """ Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>". @@ -179,15 +181,25 @@ def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_pre cur_max_timestamp = 0.0 prev_segments_len = 0.0 + penultimate_timestamp = 0.0 - for token in token_ids: + for i, token in enumerate(token_ids): if token >= timestamp_begin: timestamp = float((token - timestamp_begin) * time_precision) if timestamp < cur_max_timestamp: # next segment has started - prev_segments_len += cur_max_timestamp - + last_was_single_ending = i >= 2 and not ( + token_ids[i - 1] >= timestamp_begin and token_ids[i - 2] >= timestamp_begin + ) + if last_was_single_ending: + prev_segments_len += time_precision * segment_size + else: + cur_max_timestamp = penultimate_timestamp + prev_segments_len += penultimate_timestamp + outputs = outputs[:-2] + + penultimate_timestamp = cur_max_timestamp cur_max_timestamp = timestamp outputs.append(f"<|{(timestamp + prev_segments_len):.2f}|>") @@ -200,7 +212,7 @@ def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_pre return "".join(outputs) # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._compute_offsets - def _compute_offsets(self, token_ids, time_precision=0.02): + def _compute_offsets(self, token_ids, time_precision=0.02, segment_size=1500): """ Compute offsets for a given tokenized input @@ -209,6 +221,8 @@ def _compute_offsets(self, token_ids, time_precision=0.02): List of tokenized input ids. Can be obtained using the `__call__` method. time_precision (`float`, *optional*, defaults to 0.02): The time ratio to convert from token to time. 
+            segment_size (`int`, *optional*, defaults to 1500):
+                The number of features in the input mel spectrogram.
         """
         offsets = []
         # ensure torch tensor of token ids is placed on cpu
@@ -239,7 +253,13 @@ def _compute_offsets(self, token_ids, time_precision=0.02):

             if start_timestamp_position < cur_max_timestamp:
                 # next segment has started
-                prev_segments_len += cur_max_timestamp
+                is_single_ending = last_slice >= 2 and not (
+                    token_ids[last_slice - 2] >= timestamp_begin and token_ids[last_slice - 1] >= timestamp_begin
+                )
+                if is_single_ending:
+                    prev_segments_len += segment_size
+                else:
+                    prev_segments_len += cur_max_timestamp

             cur_max_timestamp = end_timestamp_position

@@ -251,8 +271,8 @@ def _compute_offsets(self, token_ids, time_precision=0.02):
                 {
                     "text": text,
                     "timestamp": (
-                        (start_timestamp_position + prev_segments_len) * time_precision,
-                        (end_timestamp_position + prev_segments_len) * time_precision,
+                        start_timestamp_position * time_precision + prev_segments_len * time_precision,
+                        end_timestamp_position * time_precision + prev_segments_len * time_precision,
                     ),
                 }
             )
diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py
index 9389c4f47def..faab43854cce 100644
--- a/tests/models/whisper/test_modeling_whisper.py
+++ b/tests/models/whisper/test_modeling_whisper.py
@@ -2096,6 +2096,94 @@ def test_tiny_longform_timestamps_generation(self):
         transcript = processor.batch_decode(generated_ids["sequences"], skip_special_tokens=True, output_offsets=True)
         self.assertEqual(transcript[0]["offsets"], EXPECTED_TRANSCRIPT)

+    @slow
+    def test_small_longform_timestamps_generation(self):
+        processor = WhisperProcessor.from_pretrained("openai/whisper-small.en")
+        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small.en")
+        model.to(torch_device)
+
+        dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
+        sample = dataset[0]["audio"]["array"]
+        sampling_rate = dataset[0]["audio"]["sampling_rate"]
+
+        sample = [*sample[: 15 * sampling_rate], *np.zeros(16 * sampling_rate).tolist(), *sample[15 * sampling_rate :]]
+        sample = np.array(sample)
+
+        input_features = processor(
+            sample,
+            sampling_rate=16_000,
+            padding="longest",
+            truncation=False,
+            return_attention_mask=True,
+            return_tensors="pt",
+        ).input_features
+
+        input_features = input_features.to(torch_device)
+        generated_ids = model.generate(input_features, return_timestamps=True, return_segments=True)
+
+        EXPECTED_TRANSCRIPT = [
+            {
+                "text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.",
+                "timestamp": (0.0, 6.38),
+            },
+            {
+                "text": " Nor is Mr. Quilter's manner less interesting than his matter.",
+                "timestamp": (6.38, 11.32),
+            },
+            {
+                "text": " He tells us that at this festive season of the year,",
+                "timestamp": (11.32, 15.0),
+            },
+            {
+                "text": " With Christmas and roast beef looming before us, similes drawn from eating and its results",
+                "timestamp": (30.0, 36.76),
+            },
+            {
+                "text": " occur most readily to the mind.",
+                "timestamp": (36.76, 39.80),
+            },
+            {
+                "text": " He has grave doubts whether Sir Frederick Layton's work is really Greek after all and",
+                "timestamp": (39.80, 45.36),
+            },
+            {
+                "text": " can discover in it but little of rocky Ithaca.",
+                "timestamp": (45.36, 49.0),
+            },
+            {
+                "text": " Lenell's pictures are a sort of up-guards-and-atom paintings, and Mason's exquisite ittles",
+                "timestamp": (49.0, 56.28),
+            },
+            {
+                "text": " are as national as a jingo poem. Mr. Burkett fosters landscape's smile at one much in",
+                "timestamp": (56.28, 64.12),
+            },
+            {
+                "text": " the same way that Mr. Karker used to flash his teeth. And Mr. John Collier gives his",
+                "timestamp": (64.12, 70.76),
+            },
+            {
+                "text": " sitter a cheerful slap on the back before he says, like a shampoo or in a Turkish bath,",
+                "timestamp": (70.76, 77.16),
+            },
+            {
+                "text": " Next Man",
+                "timestamp": (77.16, 78.16),
+            },
+        ]
+
+        transcript = processor.batch_decode(generated_ids["sequences"], skip_special_tokens=True, output_offsets=True)
+        self.assertEqual(transcript[0]["offsets"], EXPECTED_TRANSCRIPT)
+
+        transcript_segments = [
+            {
+                "text": processor.decode(seg["tokens"], skip_special_tokens=True),
+                "timestamp": (seg["start"].item(), seg["end"].item()),
+            }
+            for seg in generated_ids["segments"][0]
+        ]
+        self.assertEqual(transcript_segments, EXPECTED_TRANSCRIPT)
+
     @slow
     def test_large_timestamp_generation(self):
         set_seed(0)
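
Note on the float64 changes above (`seek.to(torch.float64)`, the `.to(torch.float64)` casts on the timestamp positions, and the "ensure float64 operations" bullet in the commit message): accumulating 0.02 s steps in float32, which is effectively what the segment bookkeeping does over long-form audio, drifts off the timestamp grid by enough to break exact comparisons. A minimal standalone sketch of that drift (a toy accumulation loop, not code from this patch):

```python
import torch

time_precision = 0.02  # seconds per generated token, as in `generate()`

# Accumulate one 30 s window (1500 steps of 0.02 s) in both dtypes.
acc32 = torch.zeros((), dtype=torch.float32)
acc64 = torch.zeros((), dtype=torch.float64)
for _ in range(1500):
    acc32 += time_precision
    acc64 += time_precision

# float32 typically lands on the order of 1e-5 away from 30.0, enough to break
# exact comparisons against the 0.02 s grid; float64 stays within ~1e-13.
print(f"float32 drift: {abs(acc32.item() - 30.0):.2e}")
print(f"float64 drift: {abs(acc64.item() - 30.0):.2e}")
```

Keeping the segment arithmetic in float64 end to end is what lets the new test assert exact tuples such as `(30.0, 36.76)` against `seg["start"].item()` and `seg["end"].item()`.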
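The other recurring pattern in this diff is the distinction between a window that ends with a single timestamp token (speech ran to the end of the 30 s window, so the next window starts a full `time_precision * segment_size` = 30 s later) and one that ends with a pair of timestamp tokens (the model closed the segment early at that time). A self-contained toy of just that predicate; the token ids and `timestamp_begin = 100` are made up (the real value is the id of `<|0.00|>`):

```python
# Toy version of the ending check used in `_decode_with_timestamps` above.
# Ids below 100 stand in for text tokens, ids >= 100 for timestamp tokens.
timestamp_begin = 100

def previous_window_had_single_ending(token_ids, i):
    """At index `i` (the first timestamp of a new window), report whether the
    previous window ended with a single timestamp token rather than a pair."""
    return i >= 2 and not (
        token_ids[i - 1] >= timestamp_begin and token_ids[i - 2] >= timestamp_begin
    )

# ... text (60), one closing timestamp (130), then a new window opens (101):
# a single ending, so a full 30 s is added to the running offset.
print(previous_window_had_single_ending([55, 60, 130, 101], 3))  # True

# ... two identical closing timestamps (130, 130), then a new window (101):
# a double ending, so the penultimate timestamp is the true end of the window.
print(previous_window_had_single_ending([55, 130, 130, 101], 3))  # False
```

The same two-token test, seen from the generation side, is what the new block in `_prepare_decoder_input_ids` relies on when it drops the duplicated final timestamp before conditioning on previous segments.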