
Commit b452eb1

tokenizers respect padding: true with non-null max_length
This commit changes tokenizer behavior to match the documentation and the Python library. Previously, passing { padding: true, max_length: 512 } or { padding: 'max_length', max_length: 512 } would both always pad all outputs to 512 tokens. After this change, { padding: true, max_length: 512 } pads the outputs to the longest encoding in the batch or to max_length, whichever is shorter. This commit also adds a test to prevent regressions.
Parent: 6d47745
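
For illustration, a minimal sketch of the two padding modes after this change, written against the public transformers.js API (the package name and model ID are assumptions for the example; any pretrained tokenizer behaves the same way):

import { AutoTokenizer } from "@huggingface/transformers";

// Model ID chosen for illustration only.
const tokenizer = await AutoTokenizer.from_pretrained("Xenova/bert-base-uncased");
const inputs = ["a", "b c"]; // encoded lengths: 1 and 2 tokens (no special tokens)

// After this commit: pads to the longest encoding in the batch (2 columns),
// because that is shorter than max_length.
const dynamic = tokenizer(inputs, {
    padding: true,
    truncation: true,
    max_length: 512,
    add_special_tokens: false,
});

// Unchanged: pads every sequence out to max_length (512 columns).
const fixed = tokenizer(inputs, {
    padding: 'max_length',
    truncation: true,
    max_length: 512,
    add_special_tokens: false,
});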

File tree: 2 files changed, +57 −11 lines


src/tokenizers.js (+10, −11)
@@ -2790,17 +2790,16 @@ export class PreTrainedTokenizer extends Callable {
         // At this point, tokens is batched: [batch_size, tokens]
         // However, array may be jagged. So, we pad to max_length
 
-        if (max_length === null) {
-            if (padding === 'max_length') {
-                max_length = this.model_max_length;
-            } else {
-                // Calculate max length from sequences
-                max_length = max(encodedTokens.map(x => x.input_ids.length))[0];
-            }
-        } else {
-            if (!truncation) {
-                console.warn(`Truncation was not explicitly activated but \`max_length\` is provided a specific value, please use \`truncation=true\` to explicitly truncate examples to max length.`)
-            }
+        if (truncation && max_length === null) {
+            max_length = this.model_max_length;
+        } else if (max_length && truncation === null) {
+            console.warn(`Truncation was not explicitly activated but \`max_length\` is provided a specific value, please use \`truncation=true\` to explicitly truncate examples to max length.`)
+        }
+
+        // padding: 'max_length' doesn't require any additional calculation
+        // but padding: true has to calculate max_length from the sequences
+        if (padding === true) {
+            max_length = Math.min(max(encodedTokens.map(x => x.input_ids.length))[0], max_length ?? Infinity);
         }
 
         // Ensure it is less than model max length
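
Read as a whole, the new code resolves the padding width in two steps: settle max_length against truncation first, then cap it by the longest encoding when padding: true. A standalone sketch of that logic (a hypothetical helper for illustration, not a function in the library; `lengths` stands in for encodedTokens.map(x => x.input_ids.length)):

// Hypothetical helper mirroring the new resolution logic (illustration only).
function resolveMaxLength({ padding, truncation, max_length, model_max_length, lengths }) {
    if (truncation && max_length === null) {
        max_length = model_max_length;
    } else if (max_length && truncation === null) {
        console.warn('Truncation was not explicitly activated but `max_length` is provided...');
    }
    if (padding === true) {
        // padding: true pads to the longest encoding, capped at max_length when one is set.
        max_length = Math.min(Math.max(...lengths), max_length ?? Infinity);
    }
    return max_length;
}

// With the commit-message example: { padding: true, max_length: 512 } on a batch
// whose longest encoding is 2 tokens now yields a width of 2, not 512:
// resolveMaxLength({ padding: true, truncation: true, max_length: 512, model_max_length: 512, lengths: [1, 2] }) === 2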

tests/tokenizers.test.js (+47)
@@ -180,6 +180,53 @@ describe("Tokenizer padding/truncation", () => {
           [0n, 0n],
         ]);
       }
+
+      {
+        // padding: true should pad encodings to match the longest encoding in the batch,
+        // regardless of what is set in max_length
+        const { input_ids, attention_mask, token_type_ids } = tokenizer(inputs, {
+          padding: true,
+          truncation: true,
+          add_special_tokens: false,
+          max_length: 3,
+        });
+
+        expect(input_ids.tolist()).toEqual([
+          [1037n, 0n],
+          [1038n, 1039n],
+        ]);
+        expect(attention_mask.tolist()).toEqual([
+          [1n, 0n],
+          [1n, 1n],
+        ]);
+        expect(token_type_ids.tolist()).toEqual([
+          [0n, 0n],
+          [0n, 0n],
+        ]);
+      }
+
+      {
+        // padding: 'max_length' should pad encodings to match max_length
+        const { input_ids, attention_mask, token_type_ids } = tokenizer(inputs, {
+          padding: 'max_length',
+          truncation: true,
+          add_special_tokens: false,
+          max_length: 3,
+        });
+
+        expect(input_ids.tolist()).toEqual([
+          [1037n, 0n, 0n],
+          [1038n, 1039n, 0n],
+        ]);
+        expect(attention_mask.tolist()).toEqual([
+          [1n, 0n, 0n],
+          [1n, 1n, 0n],
+        ]);
+        expect(token_type_ids.tolist()).toEqual([
+          [0n, 0n, 0n],
+          [0n, 0n, 0n],
+        ]);
+      }
     },
     MAX_TEST_EXECUTION_TIME,
   );
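
The expected widths in the new tests follow directly from the new rule. A quick check of the arithmetic (the inputs are assumed to encode to 1 and 2 tokens, matching the ids 1037n–1039n above):

const lengths = [1, 2]; // ["a", "b c"] without special tokens
const max_length = 3;

// padding: true → width = min(longest, max_length) = min(2, 3) = 2
const dynamicWidth = Math.min(Math.max(...lengths), max_length);

// padding: 'max_length' → width = max_length = 3
const fixedWidth = max_length;

console.log(dynamicWidth, fixedWidth); // 2 3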
