Skip to content

Commit 4027764

Browse files
committed
use encode instead of encode_batch
1 parent 52a6deb commit 4027764

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

axlearn/experiments/text/gpt/vocabulary_fuji_v3.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -146,17 +146,20 @@ def _encode_tf(self, s: tf.Tensor) -> tf.Tensor:
146146
s = tf.reshape(s, (1,))
147147
need_unpack = True
148148

149-
def helper(s):
150-
s = s.numpy()
151-
res = self._tokenizer.encode_batch(
152-
[item.decode("utf-8") for item in s], add_special_tokens=True
153-
)
154-
# The return does not include EOS, but we need to remove BOS.
155-
res = [item.ids[1:] if item.ids[0] == self.bos_id else item.ids for item in res]
149+
def helper_en(s):
150+
res = []
151+
for item in s.numpy():
152+
item = item.decode("utf-8")
153+
encoded = self._tokenizer.encode(item, add_special_tokens=True)
154+
ids = encoded.ids
155+
# The return does not include EOS, but we need to remove BOS.
156+
if len(ids) > 0 and ids[0] == self.bos_id:
157+
ids = ids[1:]
158+
res.append(ids)
156159
return tf.ragged.constant(res, dtype=tf.int32)
157160

158161
ret = tf.py_function(
159-
helper, inp=[s], Tout=tf.RaggedTensorSpec([None, None], dtype=tf.int32)
162+
helper_en, inp=[s], Tout=tf.RaggedTensorSpec([None, None], dtype=tf.int32)
160163
)
161164
if need_unpack:
162165
return ret[0]

0 commit comments

Comments
 (0)