@@ -2598,21 +2598,13 @@ export class PreTrainedTokenizer extends Callable {
             this.decoder.end_of_word_suffix = this.model.end_of_word_suffix;
         }
 
-        // Divide added tokens into those that left/right strip, and those that don't
-        const added_tokens_with_strip = this.added_tokens.filter(x => x.rstrip || x.lstrip);
-        const added_tokens_without_strip = this.added_tokens.filter(x => !x.rstrip && !x.lstrip);
-        const split_regex = added_tokens_with_strip.length > 0 ? new RegExp(
-            added_tokens_with_strip.slice()
-                // Sort by length (desc) to avoid early partial matches
-                .sort((a, b) => b.content.length - a.content.length)
-                .map(x => `${x.lstrip ? '\\s*' : ''}(${escapeRegExp(x.content)})${x.rstrip ? '\\s*' : ''}`)
-                .join('|')
-        ) : null;
         this.added_tokens_splitter = new DictionarySplitter(
-            added_tokens_without_strip.map(x => x.content),
-            split_regex,
+            this.added_tokens.map(x => x.content),
         );
 
+        /** @type {Map<string, AddedToken>} */
+        this.added_tokens_map = new Map(this.added_tokens.map(x => [x.content, x]))
+
         // Set mask token if present (otherwise will be undefined, which is fine)
         this.mask_token = this.getToken('mask_token');
         this.mask_token_id = this.model.tokens_to_ids.get(this.mask_token);
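
Note: the constructor now feeds every added token's content into a single DictionarySplitter and keeps a separate content-to-AddedToken map for later lookups, instead of maintaining a regex for the lstrip/rstrip tokens. As a rough, self-contained sketch of the kind of dictionary split the encode path relies on (this is not the library's actual DictionarySplitter, whose implementation lives elsewhere in tokenizers.js; the function name and example strings here are illustrative only):

// Hypothetical sketch of dictionary-based splitting: split `text` into sections,
// keeping any exact added-token string as its own section. Longest-first matching
// avoids picking a token that is merely a prefix of a longer one.
// NOT the library's DictionarySplitter, just an illustration of the behaviour
// the constructor change relies on.
function splitOnDictionary(text, words) {
    const sorted = [...words].sort((a, b) => b.length - a.length);
    const sections = [];
    let buffer = '';
    let i = 0;
    outer: while (i < text.length) {
        for (const w of sorted) {
            if (w.length > 0 && text.startsWith(w, i)) {
                if (buffer) sections.push(buffer);
                sections.push(w);
                buffer = '';
                i += w.length;
                continue outer;
            }
        }
        buffer += text[i++];
    }
    if (buffer) sections.push(buffer);
    return sections;
}

// Example: the added token is kept as its own section; surrounding text is untouched.
console.log(splitOnDictionary('Hello <|user|> world', ['<|user|>']));
// -> [ 'Hello ', '<|user|>', ' world' ]

Longest-first matching mirrors the intent of the old regex's length-descending sort; any empty sections a splitter like this might produce between adjacent tokens are filtered out later by the `x.length === 0` check in the second hunk.
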
@@ -2907,38 +2899,49 @@ export class PreTrainedTokenizer extends Callable {
         // First, we take care of special tokens. Needed to avoid issues arising from
         // normalization and/or pretokenization (which may not preserve special tokens)
         const sections = this.added_tokens_splitter.split(text);
-        const tokens = sections.map((x, section_index) => {
-            const addedToken = this.added_tokens.find(t => t.content === x);
-            if (addedToken !== undefined) {
-                // Ignore added tokens
-                return x
-            } else {
-                if (this.remove_space === true) {
-                    x = x.trim().split(/\s+/).join(' ');
-                }
-                if (this.do_lowercase_and_remove_accent) {
-                    x = lowercase_and_remove_accent(x);
-                }
 
-                if (this.normalizer !== null) {
-                    x = this.normalizer(x);
+        // Process left/right stripping of added tokens
+        for (let i = 0; i < sections.length; ++i) {
+            const addedToken = this.added_tokens_map.get(sections[i]);
+            if (addedToken) {
+                if (addedToken.lstrip && i > 0) {
+                    sections[i - 1] = sections[i - 1].trimEnd();
                 }
-
-                // If, after normalization, this section is empty (e.g., trimming whitespace),
-                // we return an empty array
-                if (x.length === 0) {
-                    return [];
+                if (addedToken.rstrip && i < sections.length - 1) {
+                    sections[i + 1] = sections[i + 1].trimStart();
                 }
+            }
+        }
 
-                const sectionTokens = (this.pre_tokenizer !== null) ? this.pre_tokenizer(x, {
-                    section_index,
-                }) : [x];
+        const tokens = sections.flatMap((x, section_index) => {
+            if (x.length === 0) return [];
+            if (this.added_tokens_map.has(x)) return [x]; // Return added tokens unchanged
 
-                const tokens = this.model(sectionTokens);
+            if (this.remove_space === true) {
+                x = x.trim().split(/\s+/).join(' ');
+            }
+            if (this.do_lowercase_and_remove_accent) {
+                x = lowercase_and_remove_accent(x);
+            }
+
+            if (this.normalizer !== null) {
+                x = this.normalizer(x);
+            }
 
-                return tokens;
+            // If, after normalization, this section is empty (e.g., trimming whitespace),
+            // we return an empty array
+            if (x.length === 0) {
+                return [];
             }
-        }).flat();
+
+            const sectionTokens = (this.pre_tokenizer !== null) ? this.pre_tokenizer(x, {
+                section_index,
+            }) : [x];
+
+            const tokens = this.model(sectionTokens);
+
+            return tokens;
+        });
 
         return tokens;
     }
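
For reference, the effect of the new lstrip/rstrip pass is that whitespace adjacent to an added token is absorbed by the token rather than left in the neighbouring text sections. A minimal standalone sketch (the section array and the lstrip/rstrip flags below are invented for the example; in the tokenizer they come from DictionarySplitter.split() and the AddedToken configuration):

// Standalone illustration of the lstrip/rstrip pass over split sections.
// The map and sections below are made up for the example.
const added_tokens_map = new Map([
    ['<mask>', { content: '<mask>', lstrip: true, rstrip: true }],
]);

const sections = ['Paris is the ', '<mask>', ' of France.'];

for (let i = 0; i < sections.length; ++i) {
    const addedToken = added_tokens_map.get(sections[i]);
    if (addedToken) {
        if (addedToken.lstrip && i > 0) {
            // Whitespace to the left of the token is absorbed by the token,
            // so it is trimmed from the end of the previous section.
            sections[i - 1] = sections[i - 1].trimEnd();
        }
        if (addedToken.rstrip && i < sections.length - 1) {
            // Likewise, whitespace to the right is trimmed from the next section.
            sections[i + 1] = sections[i + 1].trimStart();
        }
    }
}

console.log(sections); // -> [ 'Paris is the', '<mask>', 'of France.' ]

Keeping the trimming as a separate pass over `sections` lets the subsequent flatMap treat added tokens and ordinary text uniformly: added tokens are returned unchanged, and sections emptied by trimming are skipped by the length check.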