Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,10 @@ def text_to_ids(self, text):
ids = self.tokens_to_ids(tokens)
return ids

def apply_chat_template(self, *args, **kwargs):
"""Appies chat template and tokenizes results"""
return self.tokenizer.apply_chat_template(*args, **kwargs)

def ids_to_text(self, ids, remove_special_tokens=True):
"""
Converts token IDs back to text.
Expand Down
12 changes: 12 additions & 0 deletions nemo/collections/common/tokenizers/tokenizer_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,33 +26,45 @@ class TokenizerSpec(ABC):

@abstractmethod
def text_to_tokens(self, text):
"""Converts text into a list of tokens."""
pass

@abstractmethod
def tokens_to_text(self, tokens):
"""Converts a list of tokens back into text."""
pass

@abstractmethod
def tokens_to_ids(self, tokens):
"""Converts a list of tokens to their corresponding IDs."""
pass

@abstractmethod
def ids_to_tokens(self, ids):
"""Converts a list of token IDs back to tokens."""
pass

@abstractmethod
def text_to_ids(self, text):
"""Converts text directly to token IDs."""
pass

@abstractmethod
def ids_to_text(self, ids):
"""Converts token IDs back to text."""
pass

def add_special_tokens(self, special_tokens: List[str]):
"""Adds special tokens (eos, pad, cls...) to vocab."""
raise NotImplementedError("To be implemented")

def apply_chat_template(self, *args, **kwargs):
"""Appies chat template and tokenizes results"""
raise NotImplementedError("To be implemented")

@property
def name(self):
"""name of the class"""
return type(self).__name__

@property
Expand Down
39 changes: 39 additions & 0 deletions tests/collections/common/test_apply_chat_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
from nemo.collections.nlp.modules.common.tokenizer_utils import get_tokenizer


def test_chat_template():
transformers = pytest.importorskip("transformers")
path = "/home/TestData/akoumparouli/tokenizer_with_chat_template/"
tokenizers = [get_tokenizer(path), transformers.AutoTokenizer.from_pretrained(path)]
prompt = "Give me a short introduction to pytest."
messages = [{"role": "system", "content": "You are a helpful CI assistant."}, {"role": "user", "content": prompt}]
texts = [
tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) for tokenizer in tokenizers
]
assert texts[0] == texts[1]


def test_throws_chat_template():
path = "/home/TestData/akoumparouli/tokenizer_without_chat_template/"
tokenizer = get_tokenizer(path)
prompt = "Give me a short introduction to pytest."
messages = [{"role": "system", "content": "You are a helpful CI assistant."}, {"role": "user", "content": prompt}]
try:
tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
except ValueError as e:
assert 'Cannot use chat template functions because tokenizer.chat_template is not set' in str(e)
Loading