
Commit e584042

Further optimize added tokens splitter (#1265)
1 parent 75b352c commit e584042

File tree

3 files changed (+43 -72):

- src/tokenizers.js (+40 -37)
- src/utils/data-structures.js (+1 -17)
- tests/utils/data_structures.test.js (+2 -18)

src/tokenizers.js (+40 -37)

```diff
@@ -2598,21 +2598,13 @@ export class PreTrainedTokenizer extends Callable {
             this.decoder.end_of_word_suffix = this.model.end_of_word_suffix;
         }
 
-        // Divide added tokens into those that left/right strip, and those that don't
-        const added_tokens_with_strip = this.added_tokens.filter(x => x.rstrip || x.lstrip);
-        const added_tokens_without_strip = this.added_tokens.filter(x => !x.rstrip && !x.lstrip);
-        const split_regex = added_tokens_with_strip.length > 0 ? new RegExp(
-            added_tokens_with_strip.slice()
-                // Sort by length (desc) to avoid early partial matches
-                .sort((a, b) => b.content.length - a.content.length)
-                .map(x => `${x.lstrip ? '\\s*' : ''}(${escapeRegExp(x.content)})${x.rstrip ? '\\s*' : ''}`)
-                .join('|')
-        ) : null;
         this.added_tokens_splitter = new DictionarySplitter(
-            added_tokens_without_strip.map(x => x.content),
-            split_regex,
+            this.added_tokens.map(x => x.content),
         );
 
+        /** @type {Map<string, AddedToken>} */
+        this.added_tokens_map = new Map(this.added_tokens.map(x => [x.content, x]))
+
         // Set mask token if present (otherwise will be undefined, which is fine)
         this.mask_token = this.getToken('mask_token');
         this.mask_token_id = this.model.tokens_to_ids.get(this.mask_token);
@@ -2907,38 +2899,49 @@ export class PreTrainedTokenizer extends Callable {
         // First, we take care of special tokens. Needed to avoid issues arising from
         // normalization and/or pretokenization (which may not preserve special tokens)
         const sections = this.added_tokens_splitter.split(text);
-        const tokens = sections.map((x, section_index) => {
-            const addedToken = this.added_tokens.find(t => t.content === x);
-            if (addedToken !== undefined) {
-                // Ignore added tokens
-                return x
-            } else {
-                if (this.remove_space === true) {
-                    x = x.trim().split(/\s+/).join(' ');
-                }
-                if (this.do_lowercase_and_remove_accent) {
-                    x = lowercase_and_remove_accent(x);
-                }
 
-                if (this.normalizer !== null) {
-                    x = this.normalizer(x);
+        // Process left/right stripping of added tokens
+        for (let i = 0; i < sections.length; ++i) {
+            const addedToken = this.added_tokens_map.get(sections[i]);
+            if (addedToken) {
+                if (addedToken.lstrip && i > 0) {
+                    sections[i - 1] = sections[i - 1].trimEnd();
                 }
-
-                // If, after normalization, this section is empty (e.g., trimming whitespace),
-                // we return an empty array
-                if (x.length === 0) {
-                    return [];
+                if (addedToken.rstrip && i < sections.length - 1) {
+                    sections[i + 1] = sections[i + 1].trimStart();
                 }
+            }
+        }
 
-                const sectionTokens = (this.pre_tokenizer !== null) ? this.pre_tokenizer(x, {
-                    section_index,
-                }) : [x];
+        const tokens = sections.flatMap((x, section_index) => {
+            if (x.length === 0) return [];
+            if (this.added_tokens_map.has(x)) return [x]; // Return added tokens unchanged
 
-                const tokens = this.model(sectionTokens);
+            if (this.remove_space === true) {
+                x = x.trim().split(/\s+/).join(' ');
+            }
+            if (this.do_lowercase_and_remove_accent) {
+                x = lowercase_and_remove_accent(x);
+            }
+
+            if (this.normalizer !== null) {
+                x = this.normalizer(x);
+            }
 
-                return tokens;
+            // If, after normalization, this section is empty (e.g., trimming whitespace),
+            // we return an empty array
+            if (x.length === 0) {
+                return [];
            }
-        }).flat();
+
+            const sectionTokens = (this.pre_tokenizer !== null) ? this.pre_tokenizer(x, {
+                section_index,
+            }) : [x];
+
+            const tokens = this.model(sectionTokens);
+
+            return tokens;
+        });
 
         return tokens;
     }
```
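The rewritten `_encode_text` path replaces the per-section linear scan (`this.added_tokens.find(...)`) with an O(1) lookup in the new `added_tokens_map`, and moves lstrip/rstrip handling out of the regex into a plain pass that trims whitespace on the sections adjacent to a stripping token. A minimal standalone sketch of that strip pass (the `<tok>` token and its flags are hypothetical):

```js
// Hypothetical added token configured to strip whitespace on both sides.
const added_tokens_map = new Map([
    ["<tok>", { content: "<tok>", lstrip: true, rstrip: true }],
]);

// Sections as DictionarySplitter.split() might produce them.
const sections = ["hello ", "<tok>", " world"];

for (let i = 0; i < sections.length; ++i) {
    const addedToken = added_tokens_map.get(sections[i]);
    if (addedToken) {
        // lstrip consumes whitespace to the left of the token...
        if (addedToken.lstrip && i > 0) {
            sections[i - 1] = sections[i - 1].trimEnd();
        }
        // ...and rstrip consumes whitespace to the right.
        if (addedToken.rstrip && i < sections.length - 1) {
            sections[i + 1] = sections[i + 1].trimStart();
        }
    }
}

console.log(sections); // ["hello", "<tok>", "world"]
```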

src/utils/data-structures.js (+1 -17)

```diff
@@ -455,11 +455,9 @@ class TokenLatticeNode {
 export class DictionarySplitter {
     /**
      * @param {string[]} dictionary The dictionary of words to use for splitting.
-     * @param {RegExp} [splitRegex] Optional split regex for preprocessing the input text.
      */
-    constructor(dictionary, splitRegex = null) {
+    constructor(dictionary) {
         this.trie = this._buildTrie(dictionary);
-        this.splitRegex = splitRegex;
     }
 
     /**
@@ -486,20 +484,6 @@ export class DictionarySplitter {
      * @returns {string[]} An array of tokens.
      */
     split(text) {
-        return this.splitRegex ?
-            text.split(this.splitRegex)
-                .filter(x => x)
-                .flatMap(x => this._splitSingle(x))
-            : this._splitSingle(text)
-    }
-
-    /**
-     * Helper function to split a single text string into tokens.
-     * @param {string} text The input text to split.
-     * @returns {string[]} An array of tokens.
-     * @private
-     */
-    _splitSingle(text) {
         const result = [];
         const n = text.length;
         let start = 0;
```
tests/utils/data_structures.test.js (+2 -18)

```diff
@@ -34,32 +34,16 @@ describe("Priority queue", () => {
 
 describe("Dictionary splitter", () => {
   it("should split on a defined dictionary", () => {
-    const splitter = new DictionarySplitter(
-      ["a", "b", "c", "abc"],
-      null, // no split regex
-    );
+    const splitter = new DictionarySplitter(["a", "b", "c", "abc"]);
     const text = ".a.b.cc.abcdef.";
     const expected = [".", "a", ".", "b", ".", "c", "c", ".", "abc", "def."];
     const result = splitter.split(text);
     expect(result).toEqual(expected);
   });
-  it("should split on a defined dictionary w/ split regex", () => {
-    const splitter = new DictionarySplitter(
-      ["a", "b", "c", "abc"],
-      /\s+/, // split on whitespace
-    );
-    const text = "a b c";
-    const expected = ["a", "b", "c"];
-    const result = splitter.split(text);
-    expect(result).toEqual(expected);
-  });
 
   it("should handle multi-byte characters", () => {
     const text = "before🤗after\ud83etest";
-    const splitter = new DictionarySplitter(
-      ["🤗" /* '\ud83e\udd17' */, "\ud83e"],
-      null, // no split regex
-    );
+    const splitter = new DictionarySplitter(["🤗" /* '\ud83e\udd17' */, "\ud83e"]);
     const expected = ["before", "🤗", "after", "\ud83e", "test"];
     const result = splitter.split(text);
     expect(result).toEqual(expected);
```
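A note on the multi-byte test: JavaScript strings are sequences of UTF-16 code units, so the lone high surrogate `"\ud83e"` is a valid one-unit dictionary entry distinct from the two-unit pair that forms `"🤗"`. Per the expected output, the splitter matches the full pair where it can and falls back to the lone surrogate elsewhere:

```js
// "🤗" is the surrogate pair \ud83e\udd17 (two UTF-16 code units).
console.log("🤗".length);          // 2
console.log("🤗"[0] === "\ud83e"); // true
```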
