Skip to content

Commit 8b19c86

Browse files
committed
refactor(webvtt): make WebVTTTimestamp public
Since WebVTTTimestamp is used in DoclingDocument, the class should be public. Strengthen validation of cue language start tag annotation. Signed-off-by: Cesar Berrospi Ramis <[email protected]>
1 parent a2b12ec commit 8b19c86

File tree

4 files changed

+69
-30
lines changed

4 files changed

+69
-30
lines changed

docling_core/types/doc/document.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@
6868
)
6969
from docling_core.types.doc.tokens import DocumentToken, TableToken
7070
from docling_core.types.doc.utils import parse_otsl_table_content, relative_path
71-
from docling_core.types.doc.webvtt import _WebVTTTimestamp
71+
from docling_core.types.doc.webvtt import WebVTTTimestamp
7272

7373
_logger = logging.getLogger(__name__)
7474

@@ -1230,14 +1230,14 @@ class ProvenanceTrack(BaseModel):
12301230
"""
12311231

12321232
start_time: Annotated[
1233-
_WebVTTTimestamp,
1233+
WebVTTTimestamp,
12341234
Field(
12351235
examples=["00.11.000", "00:00:06.500", "01:28:34.300"],
12361236
description="Start time offset of the track cue",
12371237
),
12381238
]
12391239
end_time: Annotated[
1240-
_WebVTTTimestamp,
1240+
WebVTTTimestamp,
12411241
Field(
12421242
examples=["00.12.000", "00:00:08.200", "01:29:30.100"],
12431243
description="End time offset of the track cue",

docling_core/types/doc/webvtt.py

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,18 @@ class _WebVTTLineTerminator(str, Enum):
2828
]
2929

3030

31-
class _WebVTTTimestamp(BaseModel):
31+
class WebVTTTimestamp(BaseModel):
3232
"""WebVTT timestamp.
3333
34+
The timestamp is a string consisting of the following components in the given order:
35+
36+
- hours (optional, required if non-zero): two or more digits, followed by a colon character (:)
37+
- minutes: two digits between 0 and 59
38+
- a colon character (:)
39+
- seconds: two digits between 0 and 59
40+
- a full stop character (.)
41+
- thousandths of a second: three digits
42+
3443
A WebVTT timestamp is always interpreted relative to the current playback position
3544
of the media data that the WebVTT file is to be synchronized with.
3645
"""
@@ -54,6 +63,7 @@ class _WebVTTTimestamp(BaseModel):
5463

5564
@model_validator(mode="after")
5665
def validate_raw(self) -> Self:
66+
"""Validate the WebVTT timestamp as a string."""
5767
m = self._pattern.match(self.raw)
5868
if not m:
5969
raise ValueError(f"Invalid WebVTT timestamp format: {self.raw}")
@@ -81,16 +91,15 @@ def seconds(self) -> float:
8191

8292
@override
8393
def __str__(self) -> str:
94+
"""Return a string representation of a WebVTT timestamp."""
8495
return self.raw
8596

8697

8798
class _WebVTTCueTimings(BaseModel):
8899
"""WebVTT cue timings."""
89100

90-
start: Annotated[
91-
_WebVTTTimestamp, Field(description="Start time offset of the cue")
92-
]
93-
end: Annotated[_WebVTTTimestamp, Field(description="End time offset of the cue")]
101+
start: Annotated[WebVTTTimestamp, Field(description="Start time offset of the cue")]
102+
end: Annotated[WebVTTTimestamp, Field(description="End time offset of the cue")]
94103

95104
@model_validator(mode="after")
96105
def check_order(self) -> Self:
@@ -224,6 +233,21 @@ def __str__(self):
224233
return f"<{self._get_name_with_classes()} {self.annotation}>"
225234

226235

236+
class _WebVTTCueLanguageSpanStartTag(_WebVTTCueSpanStartTagAnnotated):
237+
_bcp47_regex = re.compile(r"^[a-zA-Z]{2,3}(-[a-zA-Z0-9]{2,8})*$", re.IGNORECASE)
238+
239+
name: Literal["lang"] = Field("lang", description="The tag name")
240+
annotation: Annotated[
241+
str,
242+
Field(
243+
pattern=_bcp47_regex.pattern,
244+
min_length=2,
245+
max_length=99,
246+
description="Cue language span start tag annotation",
247+
),
248+
]
249+
250+
227251
class _WebVTTCueComponentBase(BaseModel):
228252
"""WebVTT caption or subtitle cue component.
229253
@@ -294,7 +318,7 @@ class _WebVTTCueLanguageSpan(_WebVTTCueComponentBase):
294318
"""
295319

296320
kind: Literal["lang"] = "lang"
297-
start_tag: _WebVTTCueSpanStartTagAnnotated
321+
start_tag: _WebVTTCueLanguageSpanStartTag
298322

299323

300324
_WebVTTCueComponent = Annotated[
@@ -369,7 +393,7 @@ def parse(cls, raw: str) -> "_WebVTTCueBlock":
369393
start, end = [t.strip() for t in timing_line.split("-->")]
370394
end = re.split(" |\t", end)[0] # ignore the cue settings list
371395
timings: _WebVTTCueTimings = _WebVTTCueTimings(
372-
start=_WebVTTTimestamp(raw=start), end=_WebVTTTimestamp(raw=end)
396+
start=WebVTTTimestamp(raw=start), end=WebVTTTimestamp(raw=end)
373397
)
374398
cue_text = " ".join(cue_lines).strip()
375399
# adding close tag for cue spans without end tag
@@ -409,13 +433,17 @@ def parse(cls, raw: str) -> "_WebVTTCueBlock":
409433
classes: list[str] = []
410434
if class_string:
411435
classes = [c for c in class_string.split(".") if c]
412-
st = (
413-
_WebVTTCueSpanStartTagAnnotated(
436+
st: _WebVTTCueSpanStartTag
437+
if annotation and ct == "lang":
438+
st = _WebVTTCueLanguageSpanStartTag(
414439
name=ct, classes=classes, annotation=annotation.strip()
415440
)
416-
if annotation
417-
else _WebVTTCueSpanStartTag(name=ct, classes=classes)
418-
)
441+
elif annotation:
442+
st = _WebVTTCueSpanStartTagAnnotated(
443+
name=ct, classes=classes, annotation=annotation.strip()
444+
)
445+
else:
446+
st = _WebVTTCueSpanStartTag(name=ct, classes=classes)
419447
it = _WebVTTCueInternalText(components=children)
420448
cp: _WebVTTCueComponent
421449
if ct == "c":

docs/DoclingDocument.json

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2143,7 +2143,7 @@
21432143
"description": "Provenance information for elements extracted from media assets.\n\nA `ProvenanceTrack` instance describes a cue in a text track associated with a\nmedia element (audio, video, subtitles, screen recordings, ...).",
21442144
"properties": {
21452145
"start_time": {
2146-
"$ref": "#/$defs/_WebVTTTimestamp",
2146+
"$ref": "#/$defs/WebVTTTimestamp",
21472147
"description": "Start time offset of the track cue",
21482148
"examples": [
21492149
"00.11.000",
@@ -2152,7 +2152,7 @@
21522152
]
21532153
},
21542154
"end_time": {
2155-
"$ref": "#/$defs/_WebVTTTimestamp",
2155+
"$ref": "#/$defs/WebVTTTimestamp",
21562156
"description": "End time offset of the track cue",
21572157
"examples": [
21582158
"00.12.000",
@@ -3067,8 +3067,8 @@
30673067
"title": "TitleItem",
30683068
"type": "object"
30693069
},
3070-
"_WebVTTTimestamp": {
3071-
"description": "WebVTT timestamp.\n\nA WebVTT timestamp is always interpreted relative to the current playback position\nof the media data that the WebVTT file is to be synchronized with.",
3070+
"WebVTTTimestamp": {
3071+
"description": "WebVTT timestamp.\n\nThe timestamp is a string consisting of the following components in the given order:\n\n- hours (optional, required if non-zero): two or more digits, followed by a colon character (:)\n- minutes: two digits between 0 and 59\n- a colon character (:)\n- seconds: two digits between 0 and 59\n- a full stop character (.)\n- thousandths of a second: three digits\n\nA WebVTT timestamp is always interpreted relative to the current playback position\nof the media data that the WebVTT file is to be synchronized with.",
30723072
"properties": {
30733073
"raw": {
30743074
"description": "A representation of the WebVTT Timestamp as a single string",
@@ -3079,7 +3079,7 @@
30793079
"required": [
30803080
"raw"
30813081
],
3082-
"title": "_WebVTTTimestamp",
3082+
"title": "WebVTTTimestamp",
30833083
"type": "object"
30843084
}
30853085
},

test/test_webvtt.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,18 @@
99
from pydantic import ValidationError
1010

1111
from docling_core.types.doc.webvtt import (
12+
WebVTTTimestamp,
1213
_WebVTTCueBlock,
1314
_WebVTTCueComponentWithTerminator,
1415
_WebVTTCueInternalText,
1516
_WebVTTCueItalicSpan,
1617
_WebVTTCueLanguageSpan,
18+
_WebVTTCueLanguageSpanStartTag,
1719
_WebVTTCueSpanStartTagAnnotated,
1820
_WebVTTCueTextSpan,
1921
_WebVTTCueTimings,
2022
_WebVTTCueVoiceSpan,
2123
_WebVTTFile,
22-
_WebVTTTimestamp,
2324
)
2425

2526
from .test_data_gen_flag import GEN_TEST_DATA
@@ -42,7 +43,7 @@ def test_vtt_cue_commponents() -> None:
4243
0.0,
4344
]
4445
for idx, ts in enumerate(valid_timestamps):
45-
model = _WebVTTTimestamp(raw=ts)
46+
model = WebVTTTimestamp(raw=ts)
4647
assert model.seconds == valid_total_seconds[idx]
4748

4849
"""Test invalid WebVTT timestamps."""
@@ -57,35 +58,35 @@ def test_vtt_cue_commponents() -> None:
5758
]
5859
for ts in invalid_timestamps:
5960
with pytest.raises(ValidationError):
60-
_WebVTTTimestamp(raw=ts)
61+
WebVTTTimestamp(raw=ts)
6162

6263
"""Test the timestamp __str__ method."""
63-
model = _WebVTTTimestamp(raw="00:01:02.345")
64+
model = WebVTTTimestamp(raw="00:01:02.345")
6465
assert str(model) == "00:01:02.345"
6566

6667
"""Test valid cue timings."""
67-
start = _WebVTTTimestamp(raw="00:10.005")
68-
end = _WebVTTTimestamp(raw="00:14.007")
68+
start = WebVTTTimestamp(raw="00:10.005")
69+
end = WebVTTTimestamp(raw="00:14.007")
6970
cue_timings = _WebVTTCueTimings(start=start, end=end)
7071
assert cue_timings.start == start
7172
assert cue_timings.end == end
7273
assert str(cue_timings) == "00:10.005 --> 00:14.007"
7374

7475
"""Test invalid cue timings with end timestamp before start."""
75-
start = _WebVTTTimestamp(raw="00:10.700")
76-
end = _WebVTTTimestamp(raw="00:10.500")
76+
start = WebVTTTimestamp(raw="00:10.700")
77+
end = WebVTTTimestamp(raw="00:10.500")
7778
with pytest.raises(ValidationError) as excinfo:
7879
_WebVTTCueTimings(start=start, end=end)
7980
assert "End timestamp must be greater than start timestamp" in str(excinfo.value)
8081

8182
"""Test invalid cue timings with missing end."""
82-
start = _WebVTTTimestamp(raw="00:10.500")
83+
start = WebVTTTimestamp(raw="00:10.500")
8384
with pytest.raises(ValidationError) as excinfo:
8485
_WebVTTCueTimings(start=start) # type: ignore[call-arg]
8586
assert "Field required" in str(excinfo.value)
8687

8788
"""Test invalid cue timings with missing start."""
88-
end = _WebVTTTimestamp(raw="00:10.500")
89+
end = WebVTTTimestamp(raw="00:10.500")
8990
with pytest.raises(ValidationError) as excinfo:
9091
_WebVTTCueTimings(end=end) # type: ignore[call-arg]
9192
assert "Field required" in str(excinfo.value)
@@ -272,3 +273,13 @@ def test_webvtt_file() -> None:
272273
assert len(block.payload) == 1
273274
assert isinstance(block.payload[0].component, _WebVTTCueTextSpan)
274275
assert block.payload[0].component.text == "Good."
276+
277+
278+
def test_webvtt_cue_language_span_start_tag():
279+
_WebVTTCueLanguageSpanStartTag.model_validate_json('{"annotation": "en"}')
280+
_WebVTTCueLanguageSpanStartTag.model_validate_json('{"annotation": "en-US"}')
281+
_WebVTTCueLanguageSpanStartTag.model_validate_json('{"annotation": "zh-Hant"}')
282+
with pytest.raises(ValidationError, match="should match pattern"):
283+
_WebVTTCueLanguageSpanStartTag.model_validate_json('{"annotation": "en_US"}')
284+
with pytest.raises(ValidationError, match="should match pattern"):
285+
_WebVTTCueLanguageSpanStartTag.model_validate_json('{"annotation": "123-de"}')

0 commit comments

Comments
 (0)