From 3e5377f81b0bf123e052e722bf47a5ec774c4c11 Mon Sep 17 00:00:00 2001 From: KlaudiaTH Date: Wed, 8 Nov 2023 18:31:22 +0100 Subject: [PATCH] Fix unnatural tokenizations if possible --- lm_eval/base.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/lm_eval/base.py b/lm_eval/base.py index 7ca3c677af..2b14dba8db 100644 --- a/lm_eval/base.py +++ b/lm_eval/base.py @@ -179,7 +179,20 @@ def loglikelihood(self, requests): context_enc = self.tok_encode(context) continuation_enc = self.tok_encode(continuation) - # continuation_enc = self.tok_encode(continuation, is_continuation=True) + ctx_cont_enc = self.tok_encode(context + continuation) + + if context_enc + continuation_enc != ctx_cont_enc: + if ctx_cont_enc[: len(context_enc)] == context_enc: + # continuation_enc is incorrect and context_enc is correct + continuation_enc = ctx_cont_enc[len(context_enc) :] + elif ctx_cont_enc[-len(continuation_enc) :] == continuation_enc: + # continuation_enc is correct and context_enc is incorrect + context_enc = ctx_cont_enc[: -len(continuation_enc)] + else: + # Both are incorrect + print( + f"WARNING: Unnatural tokenization of concatenated context ...{repr(context[-20:])} and continuation {repr(continuation)}" + ) new_reqs.append(((context, continuation), context_enc, continuation_enc))