Skip to content

Commit 47ebb3a

Browse files
committed
use encode instead of encode_batch
1 parent 52a6deb commit 47ebb3a

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

axlearn/experiments/text/gpt/vocabulary_fuji_v3.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -146,17 +146,21 @@ def _encode_tf(self, s: tf.Tensor) -> tf.Tensor:
146146
s = tf.reshape(s, (1,))
147147
need_unpack = True
148148

149-
def helper(s):
149+
def helper_en(s):
150150
s = s.numpy()
151-
res = self._tokenizer.encode_batch(
152-
[item.decode("utf-8") for item in s], add_special_tokens=True
153-
)
154-
# The return does not include EOS, but we need to remove BOS.
155-
res = [item.ids[1:] if item.ids[0] == self.bos_id else item.ids for item in res]
151+
res = []
152+
for item in s:
153+
item = item.decode("utf-8")
154+
encoded = self._tokenizer.encode(item, add_special_tokens=True)
155+
ids = encoded.ids
156+
# The return does not include EOS, but we need to remove BOS.
157+
if len(ids) > 0 and ids[0] == self.bos_id:
158+
ids = ids[1:]
159+
res.append(ids)
156160
return tf.ragged.constant(res, dtype=tf.int32)
157161

158162
ret = tf.py_function(
159-
helper, inp=[s], Tout=tf.RaggedTensorSpec([None, None], dtype=tf.int32)
163+
helper_en, inp=[s], Tout=tf.RaggedTensorSpec([None, None], dtype=tf.int32)
160164
)
161165
if need_unpack:
162166
return ret[0]

0 commit comments

Comments
 (0)