diff --git a/nemo/collections/common/tokenizers/huggingface/auto_tokenizer.py b/nemo/collections/common/tokenizers/huggingface/auto_tokenizer.py
index 756bce55f1c6..d4b4d077176e 100644
--- a/nemo/collections/common/tokenizers/huggingface/auto_tokenizer.py
+++ b/nemo/collections/common/tokenizers/huggingface/auto_tokenizer.py
@@ -312,6 +312,10 @@ def text_to_ids(self, text):
         ids = self.tokens_to_ids(tokens)
         return ids
 
+    def apply_chat_template(self, *args, **kwargs):
+        """Applies the chat template and tokenizes the result."""
+        return self.tokenizer.apply_chat_template(*args, **kwargs)
+
     def ids_to_text(self, ids, remove_special_tokens=True):
         """
         Converts token IDs back to text.
diff --git a/nemo/collections/common/tokenizers/tokenizer_spec.py b/nemo/collections/common/tokenizers/tokenizer_spec.py
index f6e905d75c3b..4d30f4c1c183 100644
--- a/nemo/collections/common/tokenizers/tokenizer_spec.py
+++ b/nemo/collections/common/tokenizers/tokenizer_spec.py
@@ -26,33 +26,45 @@ class TokenizerSpec(ABC):
 
     @abstractmethod
     def text_to_tokens(self, text):
+        """Converts text into a list of tokens."""
         pass
 
     @abstractmethod
     def tokens_to_text(self, tokens):
+        """Converts a list of tokens back into text."""
         pass
 
     @abstractmethod
     def tokens_to_ids(self, tokens):
+        """Converts a list of tokens to their corresponding IDs."""
         pass
 
     @abstractmethod
     def ids_to_tokens(self, ids):
+        """Converts a list of token IDs back to tokens."""
         pass
 
     @abstractmethod
     def text_to_ids(self, text):
+        """Converts text directly to token IDs."""
         pass
 
     @abstractmethod
     def ids_to_text(self, ids):
+        """Converts token IDs back to text."""
         pass
 
     def add_special_tokens(self, special_tokens: List[str]):
+        """Adds special tokens (eos, pad, cls, ...) to the vocabulary."""
+        raise NotImplementedError("To be implemented")
+
+    def apply_chat_template(self, *args, **kwargs):
+        """Applies the chat template and tokenizes the result."""
         raise NotImplementedError("To be implemented")
 
     @property
     def name(self):
+        """Name of the class."""
         return type(self).__name__
 
     @property
diff --git a/tests/collections/common/test_apply_chat_template.py b/tests/collections/common/test_apply_chat_template.py
new file mode 100644
index 000000000000..ce27c5c824f4
--- /dev/null
+++ b/tests/collections/common/test_apply_chat_template.py
@@ -0,0 +1,39 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+from nemo.collections.nlp.modules.common.tokenizer_utils import get_tokenizer
+
+
+def test_chat_template():
+    transformers = pytest.importorskip("transformers")
+    path = "/home/TestData/akoumparouli/tokenizer_with_chat_template/"
+    tokenizers = [get_tokenizer(path), transformers.AutoTokenizer.from_pretrained(path)]
+    prompt = "Give me a short introduction to pytest."
+    messages = [{"role": "system", "content": "You are a helpful CI assistant."}, {"role": "user", "content": prompt}]
+    texts = [
+        tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) for tokenizer in tokenizers
+    ]
+    assert texts[0] == texts[1]
+
+
+def test_throws_chat_template():
+    path = "/home/TestData/akoumparouli/tokenizer_without_chat_template/"
+    tokenizer = get_tokenizer(path)
+    prompt = "Give me a short introduction to pytest."
+    messages = [{"role": "system", "content": "You are a helpful CI assistant."}, {"role": "user", "content": prompt}]
+    with pytest.raises(
+        ValueError, match='Cannot use chat template functions because tokenizer.chat_template is not set'
+    ):
+        tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
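
For reference, a minimal usage sketch of the new pass-through (not part of the patch). It assumes a local Hugging Face tokenizer directory that ships a chat_template; the path below is a placeholder.

# Usage sketch; the tokenizer path is a placeholder, any HF tokenizer dir with a chat_template should work.
from nemo.collections.nlp.modules.common.tokenizer_utils import get_tokenizer

tokenizer = get_tokenizer("path/to/hf_tokenizer_with_chat_template")  # placeholder path
messages = [
    {"role": "system", "content": "You are a helpful CI assistant."},
    {"role": "user", "content": "Give me a short introduction to pytest."},
]
# The new method delegates directly to transformers' apply_chat_template,
# so the usual kwargs (tokenize, add_generation_prompt, ...) pass through unchanged.
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True)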