
Commit 40be9ec

Replace invalid utf-8 sequences with the replacement character
Before, when parsing message content, any invalid utf-8 sequence we came across would accumulate forever in `undecoded_tokens`, and all subsequent content would be dropped when we eventually hit the next stop token. Now we detect invalid utf-8 sequences, replace them with the utf-8 replacement character '\uFFFD', and continue parsing further content.

In real-world scenarios, gpt-oss models sometimes generate invalid utf-8 sequences. This can be caused by temperature settings that are too high, by prompts that use utf-8 characters in unusual ways outside the training data, or by some combination of the two.

The net effect is that parsing keeps making forward progress after hitting an invalid utf-8 sequence. That matters when inference servers are streaming long message content and users expect tokens to be streamed back as they are generated rather than buffered for long periods in our `StreamableParser`. See vllm-project/vllm#26480 for one such real-world scenario encountered in vLLM.
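The approach is the same one an incremental utf-8 decoder takes: buffer bytes that are merely incomplete, emit '\uFFFD' for bytes that can never become valid. A minimal sketch of that distinction using Python's standard `codecs` module (illustrative only; the Rust implementation below hand-rolls this via `valid_up_to()` and `error_len()`):

```python
import codecs

# errors="replace" mirrors the new StreamableParser behavior:
# incomplete sequences wait for more bytes, invalid ones become U+FFFD.
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")

assert decoder.decode(b"\xf0\x9f") == ""             # incomplete: buffered, no output yet
assert decoder.decode(b"\x98\x8a") == "\U0001f60a"   # completes the 4-byte emoji sequence
assert decoder.decode(b"\xf0\x9f\x20") == "\ufffd "  # 0x20 can't continue it: replace and move on
```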
1 parent 508cbaa commit 40be9ec

2 files changed: +215 −5 lines changed

src/encoding.rs

Lines changed: 59 additions & 5 deletions
@@ -1032,6 +1032,7 @@ pub struct StreamableParser {
     stop_tokens: HashSet<Rank>,
     last_content_delta: Option<String>,
     undecoded_tokens: Vec<Rank>,
+    undecoded_bytes: Vec<u8>,
 }

 #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
@@ -1068,6 +1069,7 @@ impl StreamableParser {
             stop_tokens,
             last_content_delta: None,
             undecoded_tokens: Vec::new(),
+            undecoded_bytes: Vec::new(),
         })
     }

@@ -1148,14 +1150,60 @@ impl StreamableParser {
                 match self
                     .encoding
                     .tokenizer()
-                    .decode_utf8(&self.undecoded_tokens)
+                    .decode_bytes(&self.undecoded_tokens)
                 {
-                    Ok(decoded) => {
-                        content_tokens.extend(self.undecoded_tokens.iter().copied());
-                        self.last_content_delta = Some(decoded);
+                    Ok(decoded_bytes) => {
+                        self.undecoded_bytes.extend(decoded_bytes.iter().copied());
+                        match String::from_utf8(self.undecoded_bytes.clone()) {
+                            Ok(decoded_str) => {
+                                self.encoding
+                                    .render_text_into(&decoded_str, content_tokens)?;
+                                self.last_content_delta = Some(decoded_str);
+                                self.undecoded_bytes.clear();
+                            }
+                            Err(e) => {
+                                let utf8_error = e.utf8_error();
+                                let decoded_bytes = e.into_bytes();
+
+                                let valid_len = utf8_error.valid_up_to();
+
+                                let mut content_delta = String::new();
+                                if valid_len > 0 {
+                                    let valid_str = String::from_utf8(
+                                        decoded_bytes[..valid_len].to_vec(),
+                                    )
+                                    .unwrap();
+                                    self.encoding
+                                        .render_text_into(&valid_str, content_tokens)?;
+                                    content_delta.push_str(&valid_str);
+                                    self.undecoded_bytes.drain(..valid_len);
+                                }
+
+                                match utf8_error.error_len() {
+                                    Some(error_len) => {
+                                        let replacement = '\u{FFFD}'.to_string();
+                                        self.encoding.render_text_into(
+                                            &replacement,
+                                            content_tokens,
+                                        )?;
+                                        content_delta.push_str(&replacement);
+                                        self.undecoded_bytes.drain(..error_len);
+                                    }
+                                    None => {
+                                        // waiting on next byte in our utf-8 sequence
+                                        self.last_content_delta = None;
+                                    }
+                                }
+
+                                if !content_delta.is_empty() {
+                                    self.last_content_delta = Some(content_delta);
+                                }
+                            }
+                        }
                         self.undecoded_tokens.clear();
                     }
                     Err(_) => {
+                        // Invalid bytes, so wait on the next token
                         self.last_content_delta = None;
                     }
                 }
@@ -1167,7 +1215,13 @@ impl StreamableParser {
             true
         };
         if is_eos {
-            let text = self.encoding.tokenizer().decode_utf8(content_tokens)?;
+            let content_text = self.encoding.tokenizer().decode_utf8(content_tokens)?;
+            let tokens_text = self
+                .encoding
+                .tokenizer()
+                .decode_utf8(self.undecoded_tokens.clone())?;
+            let bytes_text = String::from_utf8_lossy(&self.undecoded_bytes);
+            let text = content_text + &tokens_text + &bytes_text;
             let message = Message {
                 author: header.author.clone(),
                 recipient: header.recipient.clone(),
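On the end-of-message path above, any bytes still waiting on a continuation byte when the stop token arrives are flushed lossily via `String::from_utf8_lossy`, so a dangling multi-byte prefix becomes a replacement character instead of vanishing. A quick Python illustration using the same bytes the tests below derive from token 9552:

```python
# 32 is a space; 240, 159 is the start of a 4-byte utf-8 sequence that never completed.
leftover = bytes([32, 240, 159])

# Lossy decoding replaces the dangling prefix with U+FFFD, matching
# String::from_utf8_lossy on the Rust side.
assert leftover.decode("utf-8", errors="replace") == " \uFFFD"
```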

tests/test_harmony.py

Lines changed: 156 additions & 0 deletions
@@ -981,3 +981,159 @@ def test_streamable_parser_tool_call_with_constrain_adjacent():
     ]

     assert parser.messages == expected
+
+
+def test_streamable_parser_invalid_utf8_decoding():
+    encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
+
+    # Confirm our token sequence is invalid utf-8:
+    # token 9552 corresponds to the bytes [32, 240, 159];
+    # 32 is a space, 240, 159 is an incomplete utf-8 sequence
+    invalid_token_sequence = [9552, 9552]
+    with pytest.raises(HarmonyError):
+        encoding.decode_utf8(invalid_token_sequence)
+
+    prefix_tokens = encoding.encode("<|start|>assistant<|message|>", allowed_special="all")
+    suffix_tokens = encoding.encode("worked<|end|>", allowed_special="all")
+    tokens = prefix_tokens + invalid_token_sequence + suffix_tokens
+    parser = StreamableParser(encoding, None)
+    for token in tokens:
+        parser.process(token)
+
+    expected = [
+        # Confirm we got the utf-8 replacement characters for the invalid sequences
+        # and the remaining valid utf-8 sequence
+        Message.from_role_and_content(Role.ASSISTANT, " \uFFFD \uFFFDworked"),
+    ]
+    assert parser.messages == expected
+
+
+def test_streamable_parser_invalid_utf8_decoding_split_across_tokens():
+    encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
+
+    valid_token_sequence = encoding.encode("XY")
+    encoding.decode_utf8(valid_token_sequence)
+
+    # Confirm prepending a specific token makes invalid utf-8:
+    # token 9552 ends with the start of a multi-byte utf-8 sequence,
+    # which means prepending it to our previously valid sequence
+    # makes it invalid utf-8
+    invalid_token_sequence = [9552] + valid_token_sequence
+    with pytest.raises(HarmonyError):
+        encoding.decode_utf8(invalid_token_sequence)
+
+    prefix_tokens = encoding.encode("<|start|>assistant<|message|>", allowed_special="all")
+    suffix_tokens = encoding.encode("<|end|>", allowed_special="all")
+    tokens = prefix_tokens + invalid_token_sequence + suffix_tokens
+    parser = StreamableParser(encoding, None)
+    for token in tokens:
+        parser.process(token)
+
+    expected = [
+        # One utf-8 replacement character but otherwise kept our space
+        # (from token 9552) and "X" and "Y" tokens
+        Message.from_role_and_content(Role.ASSISTANT, " \uFFFDXY"),
+    ]
+    assert parser.messages == expected
+
+
+def test_streamable_parser_invalid_utf8_decoding_multi_byte_token():
+    encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
+
+    # Valid utf-8 sequence
+    valid_token_sequence = encoding.encode(" interesting")
+    encoding.decode_utf8(valid_token_sequence)
+
+    # Confirm prepending a specific token makes invalid utf-8:
+    # token 9552 ends with the start of a multi-byte utf-8 sequence,
+    # which means prepending it to our previously valid sequence
+    # makes it invalid utf-8
+    invalid_token_sequence = [9552] + valid_token_sequence
+    with pytest.raises(HarmonyError):
+        encoding.decode_utf8(invalid_token_sequence)
+
+    prefix_tokens = encoding.encode("<|start|>assistant<|message|>", allowed_special="all")
+    suffix_tokens = encoding.encode("<|end|>", allowed_special="all")
+    tokens = prefix_tokens + invalid_token_sequence + suffix_tokens
+    parser = StreamableParser(encoding, None)
+    for token in tokens:
+        parser.process(token)
+
+    expected = [
+        # One utf-8 replacement character and the contents of our second token,
+        # which maps to the text " interesting"
+        Message.from_role_and_content(Role.ASSISTANT, " \uFFFD interesting"),
+    ]
+    assert parser.messages == expected
+
+
+def test_streamable_parser_invalid_utf8_decoding_multi_byte_token_no_eos_marker():
+    """Ensure we don't leave partially decoded tokens with no EOS marker."""
+    encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
+
+    # Valid utf-8 sequence
+    valid_token_sequence = encoding.encode(" interesting")
+    encoding.decode_utf8(valid_token_sequence)
+
+    # Confirm prepending a specific token makes invalid utf-8:
+    # token 9552 ends with the start of a multi-byte utf-8 sequence,
+    # which means prepending it to our previously valid sequence
+    # makes it invalid utf-8
+    invalid_token_sequence = [9552] + valid_token_sequence
+    with pytest.raises(HarmonyError):
+        encoding.decode_utf8(invalid_token_sequence)
+
+    prefix_tokens = encoding.encode("<|start|>assistant<|message|>", allowed_special="all")
+    suffix_tokens = encoding.encode(" story")
+    tokens = prefix_tokens + invalid_token_sequence + suffix_tokens
+    parser = StreamableParser(encoding, None)
+
+    content_deltas = []
+    for token in tokens:
+        parser.process(token)
+        if parser.last_content_delta is not None:
+            content_deltas.append(parser.last_content_delta)
+
+    # No EOS, so no full message, but make sure we have the current content
+    assert parser.current_content == " \uFFFD interesting story"
+
+    # Ensure all the deltas combine to form our expected content
+    assert "".join(content_deltas) == " \uFFFD interesting story"
+
+    # Confirm we can keep accumulating content delta and content
+    one_more_token = encoding.encode("Y")[0]
+    parser.process(one_more_token)
+    assert parser.last_content_delta == "Y"
+    assert parser.current_content == " \uFFFD interesting storyY"
+
+
+def test_streamable_parser_tricky_utf8_decoding():
+    """Try text with various types of utf-8 sequences that are more likely to fail."""
+    encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
+
+    tricky_utf8_text = (
+        "Hello Müller, Γειά σου, Привет, שלום, مرحبا, नमस्ते, こんにちは, 안녕하세요,"
+        " 你好. Normalized (naïve) vs. decomposed (naïve) characters. "
+        "Some emojis: 😊👋🏾👨‍👩‍👧‍👦🇺🇸."
+    )
+    valid_token_sequence = encoding.encode(tricky_utf8_text)
+
+    prefix_tokens = encoding.encode("<|start|>assistant<|message|>", allowed_special="all")
+    suffix_tokens = encoding.encode("<|end|>", allowed_special="all")
+    tokens = prefix_tokens + valid_token_sequence + suffix_tokens
+    parser = StreamableParser(encoding, None)
+
+    content_deltas = []
+    for token in tokens:
+        parser.process(token)
+        if parser.last_content_delta is not None:
+            content_deltas.append(parser.last_content_delta)
+
+    expected = [
+        Message.from_role_and_content(Role.ASSISTANT, tricky_utf8_text),
+    ]
+    # Ensure we got the entirety of our tricky utf-8 text as message content
+    assert parser.messages == expected
+
+    # Ensure if we're accumulating content deltas we still get the full utf-8 text
+    assert "".join(content_deltas) == tricky_utf8_text
