
Commit 73c57aa
Merge pull request #3 from symanto-research/feature/token-to-char-hf
HF spans can be passed to greedy coverage
2 parents: e15406f + 857a0ea

9 files changed, +171 -94 lines changed

README.md (+1, -1)

@@ -83,7 +83,7 @@ or the following equation if `word_ids` are not passed:

$$\textrm{match}(x_i) = \underset{i-k\leq j\leq i+k}{\textrm{min}}\ \textrm{dist}(x_i, y_j)$$

It is recommended to use a large radius `k` (e.g., 30) to avoid introducing matching errors at the end of the sequence if the "speed" of the tokenizations varies a lot.

-**Greedy-coverage**: aligns the tokens from two different tokenizers, using a greedy matching algorithm based on text coverage. This algorithm remove whitespaces from the text, and finds the positions (start, end) that each token covers in the text without whitespaces. Once we have the lists of (start, end) for each token and for each tokenization, we merge the tokens of the second tokenization that are spanned by the tokens of the first tokenization. For instance, having computed $spans_a$ = [(0, 5), (5, 13), (13, 23)] and $spans_b$ = [(0, 4), (5, 8), (8, 11), (11, 14), (15, 19), (19, 21), (21, 23)], the alignment will be [(0, [0]), (1, [1, 2, 3]), (2, [4, 5, 6])]. `merge-tokenizers` provides a C and a Python implementation of this algorithm.
+**Greedy-coverage**: aligns the tokens from two different tokenizers, using a greedy matching algorithm based on text coverage. This algorithm first removes whitespace from the text and finds the character positions (start, end) that each token covers in the whitespace-free text. This step can be skipped if you pass the character spans that each token covers, for instance obtained with `token_to_chars` from HuggingFace tokenizers. Once we have the lists of (start, end) for each token and for each tokenization, we merge the tokens of the second tokenization that are spanned by the tokens of the first tokenization. For instance, having computed $spans_a$ = [(0, 5), (5, 13), (13, 23)] and $spans_b$ = [(0, 4), (5, 8), (8, 11), (11, 14), (15, 19), (19, 21), (21, 23)], the alignment will be [(0, [0]), (1, [1, 2, 3]), (2, [4, 5, 6])]. `merge-tokenizers` provides a C and a Python implementation of this algorithm.

# 🔎 What algorithm should I use?
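For reference, here is a minimal, self-contained sketch of the greedy coverage step described in the updated paragraph, applied to the example spans above. It illustrates the idea only and is not the library's own `merge_spans` routine.

```python
from typing import List, Tuple

def greedy_coverage(
    spans_a: List[Tuple[int, int]], spans_b: List[Tuple[int, int]]
) -> List[Tuple[int, List[int]]]:
    """Greedily attach each token of tokenization b to the token of
    tokenization a whose (start, end) span covers it."""
    alignment, j = [], 0
    for i, (_, end_a) in enumerate(spans_a):
        covered = []
        # consume b-tokens that start before the current a-token ends
        while j < len(spans_b) and spans_b[j][0] < end_a:
            covered.append(j)
            j += 1
        alignment.append((i, covered))
    return alignment

spans_a = [(0, 5), (5, 13), (13, 23)]
spans_b = [(0, 4), (5, 8), (8, 11), (11, 14), (15, 19), (19, 21), (21, 23)]
print(greedy_coverage(spans_a, spans_b))
# [(0, [0]), (1, [1, 2, 3]), (2, [4, 5, 6])]
```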

assets/benchmark.md (+64, -64)

@@ -1,66 +1,66 @@
| Tokens | Algorithm | Mean | Std |
|:---------|:----------------------------|------------:|------------:|
-| 64-32 | DTWAligner | 0.000810194 | 0.00012322 |
-| 64-32 | PythonDTWAligner | 0.004805 | 0.0382287 |
-| 64-32 | GreedyDistanceAligner | 0.000986452 | 0.00013143 |
-| 64-32 | PythonGreedyCoverageAligner | 0.000352281 | 3.61758e-05 |
-| 64-32 | GreedyCoverageAligner | 0.000657192 | 0.00500252 |
-| 64-32 | FastDTWAligner | 0.0017894 | 0.000227047 |
-| 64-32 | TamuheyAligner | 0.000297001 | 3.36874e-05 |
-| 64-32 | WordIdsAligner | 0.000164852 | 1.55548e-05 |
-| 64-64 | DTWAligner | 0.00134678 | 0.000129305 |
-| 64-64 | PythonDTWAligner | 0.00585662 | 0.000402595 |
-| 64-64 | GreedyDistanceAligner | 0.00167739 | 0.000128032 |
-| 64-64 | PythonGreedyCoverageAligner | 0.000391849 | 3.21681e-05 |
-| 64-64 | GreedyCoverageAligner | 0.000489788 | 3.76893e-05 |
-| 64-64 | FastDTWAligner | 0.0021804 | 0.000175052 |
-| 64-64 | TamuheyAligner | 0.000346222 | 3.42165e-05 |
-| 64-64 | WordIdsAligner | 0.000283829 | 2.9102e-05 |
-| 128-64 | DTWAligner | 0.00264733 | 0.00027827 |
-| 128-64 | PythonDTWAligner | 0.0115429 | 0.000933241 |
-| 128-64 | GreedyDistanceAligner | 0.00237439 | 0.004617 |
-| 128-64 | PythonGreedyCoverageAligner | 0.000677936 | 5.42784e-05 |
-| 128-64 | GreedyCoverageAligner | 0.00105866 | 0.00445993 |
-| 128-64 | FastDTWAligner | 0.00313382 | 0.00021815 |
-| 128-64 | TamuheyAligner | 0.000589425 | 4.43911e-05 |
-| 128-64 | WordIdsAligner | 0.000301838 | 2.67991e-05 |
-| 128-128 | DTWAligner | 0.00478876 | 0.000480806 |
-| 128-128 | PythonDTWAligner | 0.0224406 | 0.00123022 |
-| 128-128 | GreedyDistanceAligner | 0.00401765 | 0.00395362 |
-| 128-128 | PythonGreedyCoverageAligner | 0.00093948 | 0.00425627 |
-| 128-128 | GreedyCoverageAligner | 0.00116883 | 0.00507499 |
-| 128-128 | FastDTWAligner | 0.00401919 | 0.00448164 |
-| 128-128 | TamuheyAligner | 0.000688969 | 4.54692e-05 |
-| 128-128 | WordIdsAligner | 0.000563423 | 3.19658e-05 |
-| 256-128 | DTWAligner | 0.00928647 | 0.00389654 |
-| 256-128 | PythonDTWAligner | 0.0441353 | 0.00618959 |
-| 256-128 | GreedyDistanceAligner | 0.00512323 | 0.00615063 |
-| 256-128 | PythonGreedyCoverageAligner | 0.00159762 | 0.00507555 |
-| 256-128 | GreedyCoverageAligner | 0.00177792 | 0.00383042 |
-| 256-128 | FastDTWAligner | 0.00586248 | 0.00397453 |
-| 256-128 | TamuheyAligner | 0.00151315 | 0.00547704 |
-| 256-128 | WordIdsAligner | 0.000931732 | 0.00502417 |
-| 256-256 | DTWAligner | 0.0170545 | 0.00156439 |
-| 256-256 | PythonDTWAligner | 0.0849948 | 0.0103587 |
-| 256-256 | GreedyDistanceAligner | 0.00850646 | 0.000559411 |
-| 256-256 | PythonGreedyCoverageAligner | 0.00142612 | 0.000129661 |
-| 256-256 | GreedyCoverageAligner | 0.00181911 | 0.000135627 |
-| 256-256 | FastDTWAligner | 0.00791399 | 0.0084514 |
-| 256-256 | TamuheyAligner | 0.00135891 | 0.000101657 |
-| 256-256 | WordIdsAligner | 0.00251443 | 0.0110405 |
-| 512-256 | DTWAligner | 0.0278183 | 0.0090862 |
-| 512-256 | PythonDTWAligner | 0.138419 | 0.0380331 |
-| 512-256 | GreedyDistanceAligner | 0.0109791 | 0.0111162 |
-| 512-256 | PythonGreedyCoverageAligner | 0.0030857 | 0.00920703 |
-| 512-256 | GreedyCoverageAligner | 0.0030004 | 0.00516345 |
-| 512-256 | FastDTWAligner | 0.0108567 | 0.0104887 |
-| 512-256 | TamuheyAligner | 0.00249206 | 0.00669024 |
-| 512-256 | WordIdsAligner | 0.00149341 | 0.0050352 |
-| 512-512 | DTWAligner | 0.0473118 | 0.0210875 |
-| 512-512 | PythonDTWAligner | 0.238191 | 0.100658 |
-| 512-512 | GreedyDistanceAligner | 0.0154189 | 0.0086549 |
-| 512-512 | PythonGreedyCoverageAligner | 0.003124 | 0.00870312 |
-| 512-512 | GreedyCoverageAligner | 0.0039112 | 0.0088561 |
-| 512-512 | FastDTWAligner | 0.0126372 | 0.0108067 |
-| 512-512 | TamuheyAligner | 0.0027519 | 0.0064969 |
-| 512-512 | WordIdsAligner | 0.00279659 | 0.0089556 |
+| 64-32 | DTWAligner | 0.000861357 | 0.000186175 |
+| 64-32 | PythonDTWAligner | 0.00488161 | 0.0384466 |
+| 64-32 | GreedyDistanceAligner | 0.00101318 | 8.03171e-05 |
+| 64-32 | PythonGreedyCoverageAligner | 0.000289684 | 2.87963e-05 |
+| 64-32 | GreedyCoverageAligner | 0.000599975 | 0.00527034 |
+| 64-32 | FastDTWAligner | 0.00185941 | 0.00019357 |
+| 64-32 | TamuheyAligner | 0.000320376 | 3.99697e-05 |
+| 64-32 | WordIdsAligner | 0.000179664 | 1.95877e-05 |
+| 64-64 | DTWAligner | 0.00136892 | 0.000118518 |
+| 64-64 | PythonDTWAligner | 0.00584478 | 0.000417107 |
+| 64-64 | GreedyDistanceAligner | 0.00168752 | 0.000148426 |
+| 64-64 | PythonGreedyCoverageAligner | 0.000310445 | 3.35894e-05 |
+| 64-64 | GreedyCoverageAligner | 0.000395369 | 2.96585e-05 |
+| 64-64 | FastDTWAligner | 0.00219548 | 0.000157088 |
+| 64-64 | TamuheyAligner | 0.00036602 | 3.45161e-05 |
+| 64-64 | WordIdsAligner | 0.000297422 | 2.06211e-05 |
+| 128-64 | DTWAligner | 0.00271824 | 0.000200187 |
+| 128-64 | PythonDTWAligner | 0.0115742 | 0.000744274 |
+| 128-64 | GreedyDistanceAligner | 0.00239153 | 0.00425989 |
+| 128-64 | PythonGreedyCoverageAligner | 0.000585796 | 0.000134413 |
+| 128-64 | GreedyCoverageAligner | 0.000711155 | 7.19061e-05 |
+| 128-64 | FastDTWAligner | 0.00339628 | 0.00433525 |
+| 128-64 | TamuheyAligner | 0.000844504 | 0.00459188 |
+| 128-64 | WordIdsAligner | 0.00032799 | 3.18738e-05 |
+| 128-128 | DTWAligner | 0.00509727 | 0.00440907 |
+| 128-128 | PythonDTWAligner | 0.0226934 | 0.004201 |
+| 128-128 | GreedyDistanceAligner | 0.00391005 | 0.000248378 |
+| 128-128 | PythonGreedyCoverageAligner | 0.000805967 | 0.00402241 |
+| 128-128 | GreedyCoverageAligner | 0.00079272 | 5.0157e-05 |
+| 128-128 | FastDTWAligner | 0.00392676 | 0.000197439 |
+| 128-128 | TamuheyAligner | 0.000951038 | 0.00426009 |
+| 128-128 | WordIdsAligner | 0.000611856 | 3.1337e-05 |
+| 256-128 | DTWAligner | 0.00987252 | 0.00110288 |
+| 256-128 | PythonDTWAligner | 0.0465292 | 0.00897979 |
+| 256-128 | GreedyDistanceAligner | 0.00516035 | 0.000530812 |
+| 256-128 | PythonGreedyCoverageAligner | 0.00173597 | 0.00690275 |
+| 256-128 | GreedyCoverageAligner | 0.00183678 | 0.00602486 |
+| 256-128 | FastDTWAligner | 0.00713865 | 0.00992345 |
+| 256-128 | TamuheyAligner | 0.00192115 | 0.00740615 |
+| 256-128 | WordIdsAligner | 0.000898538 | 0.00400415 |
+| 256-256 | DTWAligner | 0.0191239 | 0.00894593 |
+| 256-256 | PythonDTWAligner | 0.0894118 | 0.0111358 |
+| 256-256 | GreedyDistanceAligner | 0.00993978 | 0.00830519 |
+| 256-256 | PythonGreedyCoverageAligner | 0.00216607 | 0.0092726 |
+| 256-256 | GreedyCoverageAligner | 0.00206758 | 0.00684659 |
+| 256-256 | FastDTWAligner | 0.00829801 | 0.00766549 |
+| 256-256 | TamuheyAligner | 0.00181568 | 0.00433226 |
+| 256-256 | WordIdsAligner | 0.00127796 | 0.000218844 |
+| 512-256 | DTWAligner | 0.0302889 | 0.0120761 |
+| 512-256 | PythonDTWAligner | 0.143709 | 0.0406983 |
+| 512-256 | GreedyDistanceAligner | 0.0108034 | 0.00822557 |
+| 512-256 | PythonGreedyCoverageAligner | 0.00318358 | 0.0107403 |
+| 512-256 | GreedyCoverageAligner | 0.00435857 | 0.0134328 |
+| 512-256 | FastDTWAligner | 0.0112162 | 0.0100254 |
+| 512-256 | TamuheyAligner | 0.0028252 | 0.00688869 |
+| 512-256 | WordIdsAligner | 0.0019157 | 0.00707873 |
+| 512-512 | DTWAligner | 0.0483401 | 0.0218033 |
+| 512-512 | PythonDTWAligner | 0.241758 | 0.103833 |
+| 512-512 | GreedyDistanceAligner | 0.0151268 | 0.00600527 |
+| 512-512 | PythonGreedyCoverageAligner | 0.0022175 | 0.00379141 |
+| 512-512 | GreedyCoverageAligner | 0.00377235 | 0.0106682 |
+| 512-512 | FastDTWAligner | 0.0116152 | 0.00300244 |
+| 512-512 | TamuheyAligner | 0.00515227 | 0.0154369 |
+| 512-512 | WordIdsAligner | 0.00371507 | 0.0122044 |

assets/benchmark.png (425 Bytes, binary file changed)

merge_tokenizers/aligners/base.py (+24, -3)

@@ -44,6 +44,7 @@ def align_pair(
        tokenized_pair.preprocessed_tokens_b = preprocess_tokens(
            tokenized_pair.tokens_b
        )
+
        # If both tokenizations are the same, return 1-1 alignment
        if (
            tokenized_pair.preprocessed_tokens_a

@@ -69,20 +70,30 @@ def align(self, tokenized_set: TokenizedSet) -> List[Alignment]:
            if tokenized_set.word_ids
            else [[] for _ in range(len(tokenized_set.tokens))]
        )
+        spans = (
+            tokenized_set.spans
+            if tokenized_set.spans
+            else [[] for _ in range(len(tokenized_set.tokens))]
+        )
+
        tokens_a = tokenized_set.tokens[0]
        word_ids_a = word_ids[0]
+        spans_a = spans[0]
+
        return [
            self.align_pair(
                TokenizedPair(
                    tokens_a=tokens_a,
                    tokens_b=tokens_b,
                    word_ids_a=word_ids_a,
                    word_ids_b=word_ids_b,
+                    spans_a=spans_a,
+                    spans_b=spans_b,
                    text=tokenized_set.text,
                )
            )
-            for tokens_b, word_ids_b in zip(
-                tokenized_set.tokens[1:], word_ids[1:]
+            for tokens_b, word_ids_b, spans_b in zip(
+                tokenized_set.tokens[1:], word_ids[1:], spans[1:]
            )
        ]

@@ -164,16 +175,24 @@ def aggregate_features(
            else [[] for _ in range(len(tokenized_set.tokens))]
        )

+        spans = (
+            tokenized_set.spans
+            if tokenized_set.spans
+            else [[] for _ in range(len(tokenized_set.tokens))]
+        )
+
        tokens_a = tokenized_set.tokens[0]
        word_ids_a = word_ids[0]
+        spans_a = spans[0]
        features_a = tokenized_set.features[0]
        merged_features = []

-        for idx, (tokens_b, features_b, word_ids_b) in enumerate(
+        for idx, (tokens_b, features_b, word_ids_b, spans_b) in enumerate(
            zip(
                tokenized_set.tokens[1:],
                tokenized_set.features[1:],
                word_ids[1:],
+                spans[1:],
            )
        ):
            merged_features.append(

@@ -183,6 +202,8 @@ def aggregate_features(
                    tokens_b=tokens_b,
                    word_ids_a=word_ids_a,
                    word_ids_b=word_ids_b,
+                    spans_a=spans_a,
+                    spans_b=spans_b,
                    features_a=features_a,
                    features_b=features_b,
                    text=tokenized_set.text,
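A hedged usage sketch of the new `spans` field threaded through `align`. The `TokenizedSet` fields match the diff in `merge_tokenizers/types.py`; the aligner class name comes from the benchmark table, and its import path and constructor arguments are assumptions, so adjust them to your installed version.

```python
from merge_tokenizers.types import TokenizedSet
# Import path assumed; the class name appears in assets/benchmark.md.
from merge_tokenizers import PythonGreedyCoverageAligner

text = "Hello, beautiful world"
tokens_a = ["Hello", ",", "beautiful", "world"]
tokens_b = ["Hel", "lo", ",", "beauti", "ful", "world"]
# Character spans of each token over `text`, written by hand here; in practice
# they can come from a HuggingFace fast tokenizer (see the sketch further down).
spans_a = [(0, 5), (5, 6), (7, 16), (17, 22)]
spans_b = [(0, 3), (3, 5), (5, 6), (7, 13), (13, 16), (17, 22)]

aligner = PythonGreedyCoverageAligner()  # constructor arguments, if any, omitted
alignments = aligner.align(
    TokenizedSet(tokens=[tokens_a, tokens_b], spans=[spans_a, spans_b], text=text)
)
```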

merge_tokenizers/aligners/greedy_coverage.py (+24, -18)

@@ -73,27 +73,33 @@ def _align_pair(

        will result in [(0, [0]), (1, [1, 2, 3]), (2, [4, 5, 6])]
        """
-        text = tokenized_pair.text.lower().replace(" ", "").encode("utf-8")

        # Get the span covered by each token
        spans = {}
-        for tokenization, preprocessed_tokens in {
-            "a": tokenized_pair.preprocessed_tokens_a,
-            "b": tokenized_pair.preprocessed_tokens_b,
-        }.items():
-            ptr = (ctypes.c_char_p * len(preprocessed_tokens))(
-                *[token.encode("utf-8") for token in preprocessed_tokens]
-            )
-            c_spans = self.c_get_spans(
-                ptr,
-                text,
-                len(preprocessed_tokens),
-            )
-            spans[tokenization] = [
-                (c_spans[i].start, c_spans[i].end)
-                for i in range(len(preprocessed_tokens))
-            ]
-            self.c_free_spans(c_spans)
+        # If the spans covering the text are not passed, compute them.
+        if not tokenized_pair.spans_a and not tokenized_pair.spans_b:
+            text = tokenized_pair.text.lower().replace(" ", "").encode("utf-8")
+            for tokenization, preprocessed_tokens in {
+                "a": tokenized_pair.preprocessed_tokens_a,
+                "b": tokenized_pair.preprocessed_tokens_b,
+            }.items():
+                ptr = (ctypes.c_char_p * len(preprocessed_tokens))(
+                    *[token.encode("utf-8") for token in preprocessed_tokens]
+                )
+                c_spans = self.c_get_spans(
+                    ptr,
+                    text,
+                    len(preprocessed_tokens),
+                )
+                spans[tokenization] = [
+                    (c_spans[i].start, c_spans[i].end)
+                    for i in range(len(preprocessed_tokens))
+                ]
+                self.c_free_spans(c_spans)
+        # Otherwise, use them
+        else:
+            spans["a"] = tokenized_pair.spans_a
+            spans["b"] = tokenized_pair.spans_b

        # Merge the spans
        c_spans_a = (Tuple * len(spans["a"]))(*spans["a"])  # type: ignore
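The spans that short-circuit the C span computation above can come from a HuggingFace fast tokenizer via `token_to_chars`, as the README change mentions. A sketch (model names are placeholders); since both span lists refer to the same original text, the relative coverage used by the merge step is preserved.

```python
from transformers import AutoTokenizer

def hf_token_spans(tokenizer, text):
    # add_special_tokens=False avoids boundary tokens without a character span
    enc = tokenizer(text, add_special_tokens=False)
    spans = []
    for i in range(len(enc.tokens())):
        cs = enc.token_to_chars(i)  # CharSpan(start, end) over the original text
        spans.append((cs.start, cs.end))
    return enc.tokens(), spans

text = "Hello, beautiful world"
tokens_a, spans_a = hf_token_spans(AutoTokenizer.from_pretrained("bert-base-uncased"), text)
tokens_b, spans_b = hf_token_spans(AutoTokenizer.from_pretrained("roberta-base"), text)
# tokens_*/spans_* can now be passed as the tokens/spans of a TokenizedSet, or as
# tokens_a/tokens_b and spans_a/spans_b of a TokenizedPair, instead of `text`.
```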

merge_tokenizers/aligners/greedy_coverage_py.py (+15, -4)

@@ -117,11 +117,22 @@ def _align_pair(

        will result in [(0, [0]), (1, [1, 2, 3]), (2, [4, 5, 6])]
        """
-        text = tokenized_pair.text.lower().replace(" ", "")

-        # Get spans and align
-        spans_a = get_spans(tokenized_pair.preprocessed_tokens_a, text)
-        spans_b = get_spans(tokenized_pair.preprocessed_tokens_b, text)
+        # Get spans
+        # If the spans covering the text are not passed, compute them.
+        if not tokenized_pair.spans_a and not tokenized_pair.spans_b:
+            assert (
+                tokenized_pair.text
+            ), "`text` must be passed as argument when not passing `span_a` and `span_b`"
+            text = tokenized_pair.text.lower().replace(" ", "")
+            spans_a = get_spans(tokenized_pair.preprocessed_tokens_a, text)
+            spans_b = get_spans(tokenized_pair.preprocessed_tokens_b, text)
+        # Otherwise, use them.
+        else:
+            spans_a = tokenized_pair.spans_a
+            spans_b = tokenized_pair.spans_b
+
+        # Align spans
        alignments = merge_spans(spans_a, spans_b)

        # Merge alignments
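When no spans are passed, `get_spans` recovers them from the whitespace-free text. A rough sketch of what that step computes (not the library's `get_spans`; it assumes each preprocessed token appears verbatim in the whitespace-free text):

```python
from typing import List, Tuple

def char_spans(tokens: List[str], text_no_ws: str) -> List[Tuple[int, int]]:
    """(start, end) covered by each token in the whitespace-free text."""
    spans, pos = [], 0
    for token in tokens:
        start = text_no_ws.find(token, pos)
        if start == -1:  # token not found verbatim (e.g. normalization); fall back
            start = pos
        end = start + len(token)
        spans.append((start, end))
        pos = end
    return spans

print(char_spans(["hello", ",", "beautiful", "world"], "hello,beautifulworld"))
# [(0, 5), (5, 6), (6, 15), (15, 20)]
```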

merge_tokenizers/types.py (+33, -3)

@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Tuple

import numpy as np
from pydantic import BaseModel, field_validator, model_validator

@@ -13,6 +13,8 @@ class TokenizedPair(BaseModel):
    tokens_b: List[str]
    word_ids_a: List[int] = []
    word_ids_b: List[int] = []
+    spans_a: List[Tuple[int, int]] = []
+    spans_b: List[Tuple[int, int]] = []
    preprocessed_tokens_a: List[str] = []
    preprocessed_tokens_b: List[str] = []
    text: str = ""

@@ -21,7 +23,7 @@ class TokenizedPair(BaseModel):

    @field_validator("word_ids_a", "word_ids_b", mode="before")
    @classmethod
-    def prepare_word_ids_a(cls, word_ids):
+    def prepare_word_ids(cls, word_ids):
        if word_ids:
            if word_ids[0] is None:
                word_ids[0] = -1

@@ -31,6 +33,18 @@ def prepare_word_ids_a(cls, word_ids):
        else:
            return []

+    @field_validator("spans_a", "spans_b", mode="before")
+    @classmethod
+    def prepare_spans(cls, spans):
+        if spans:
+            if spans[0] is None:
+                spans[0] = (-1, -1)
+            if spans[-1] is None:
+                spans[-1] = (-1, -1)
+            return spans
+        else:
+            return []
+
    class Config:
        arbitrary_types_allowed = True

@@ -42,6 +56,7 @@ class TokenizedSet(BaseModel):

    tokens: List[List[str]]
    word_ids: List[List[int]] = []
+    spans: List[List[Tuple[int, int]]] = []
    features: List[np.ndarray] = []
    text: str = ""

@@ -58,7 +73,7 @@ def check_len_word_ids(self) -> "TokenizedSet":

    @field_validator("word_ids", mode="before")
    @classmethod
-    def prepare_word_ids_a(cls, _word_ids):
+    def prepare_word_ids(cls, _word_ids):
        new_word_ids = []
        if _word_ids:
            for word_ids in _word_ids:

@@ -71,6 +86,21 @@ def prepare_word_ids_a(cls, _word_ids):
        else:
            return []

+    @field_validator("spans", mode="before")
+    @classmethod
+    def prepare_spans(cls, _spans):
+        new_spans = []
+        if _spans:
+            for spans in _spans:
+                if spans[0] is None:
+                    spans[0] = (-1, -1)
+                if spans[-1] is None:
+                    spans[-1] = (-1, -1)
+                new_spans.append(spans)
+            return new_spans
+        else:
+            return []
+

class PositionAlignment(BaseModel):
    """

merge_tokenizers/version.py (+1, -1)

@@ -1,6 +1,6 @@
_MAJOR = "0"
_MINOR = "0"
-_REVISION = "5"
+_REVISION = "6"

VERSION_SHORT = "{0}.{1}".format(_MAJOR, _MINOR)
VERSION = "{0}.{1}.{2}".format(_MAJOR, _MINOR, _REVISION)
