
Commit 9764cba

fix cohere2 completions
1 parent 3bc1dfb

File tree

1 file changed: +9 -16 lines


tests/models/cohere2/test_modeling_cohere2.py: +9 -16
@@ -266,8 +266,8 @@ def test_model_flash_attn(self):
         # See https://github.com/huggingface/transformers/issues/31953 --- flash attn was generating garbage for Gemma2, especially in long context
         model_id = "CohereForAI/c4ai-command-r7b-12-2024"
         EXPECTED_TEXTS = [
-            '<BOS_TOKEN>Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many people died in the United States. I have found a few sites that say 500,000 but I am not sure if that is correct. I have also found a site that says 675,000 but I am not sure if that is correct either. I am trying to find out how many people died in the United States. I have found a few',
-            "<PAD><PAD><BOS_TOKEN>Hi today I'm going to be talking about the history of the United States. The United States of America is a country in North America. It is the third largest country in the world by total area and the third most populous country with over 320 million people. The United States is a federal republic consisting of 50 states and a federal district. The 48 contiguous states and the district of Columbia are in central North America between Canada and Mexico. The state of Alaska is in the"
+            '<BOS_TOKEN>Hello I am doing a project for my school and I need to create a website for a fictional company. I have the logo and the name of the company. I need a website that is simple and easy to navigate. I need a home page, about us, services, contact us, and a gallery. I need the website to be responsive and I need it to be able to be hosted on a server. I need the website to be done in a week. I need the website to be done in HTML,',
+            "<PAD><PAD><BOS_TOKEN>Hi today I'm going to show you how to make a simple and easy to make a chocolate cake.\n\nThis recipe is very simple and easy to make.\n\nYou will need:\n\n* 2 cups of flour\n* 1 cup of sugar\n* 1/2 cup of cocoa powder\n* 1 teaspoon of baking powder\n* 1 teaspoon of baking soda\n* 1/2 teaspoon of salt\n* 2 eggs\n* 1 cup of milk\n",
         ] # fmt: skip

         model = AutoModelForCausalLM.from_pretrained(
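
For context on where these literals come from: the test greedy-decodes two prompts with the flash-attention backend and compares against EXPECTED_TEXTS. A minimal sketch of how such expectations can be regenerated, assuming a CUDA machine with flash-attn installed; the token budget (max_new_tokens=100) and the left padding side are assumptions for illustration, not taken from the diff:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "CohereForAI/c4ai-command-r7b-12-2024"
    # Left padding reproduces the leading <PAD><PAD> in the second expected text.
    tokenizer = AutoTokenizer.from_pretrained(model_id, pad_token="<PAD>", padding_side="left")
    model = AutoModelForCausalLM.from_pretrained(
        model_id, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16
    ).to("cuda")

    prompts = ["Hello I am doing", "Hi today"]
    inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
    # Greedy decoding; 100 new tokens is an assumed budget for illustration.
    output = model.generate(**inputs, max_new_tokens=100, do_sample=False)
    print(tokenizer.batch_decode(output, skip_special_tokens=False))
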
@@ -285,21 +285,14 @@ def test_export_static_cache(self):
         if version.parse(torch.__version__) < version.parse("2.5.0"):
             self.skipTest(reason="This test requires torch >= 2.5 to run.")

-        from transformers.integrations.executorch import (
-            TorchExportableModuleWithStaticCache,
-            convert_and_export_with_cache,
-        )
+        from transformers.integrations.executorch import TorchExportableModuleWithStaticCache, convert_and_export_with_cache

-        tokenizer = AutoTokenizer.from_pretrained(
-            "CohereForAI/c4ai-command-r7b-12-2024", pad_token="<PAD>", padding_side="right"
-        )
+        model_id = "CohereForAI/c4ai-command-r7b-12-2024"
         EXPECTED_TEXT_COMPLETION = [
-            "Hello I am doing a project for my school and I need to know how to make a program that will take a number",
+            "Hello I am doing a project on the effects of social media on mental health. I have a few questions. 1. What is the relationship",
         ]
-        max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[
-            "input_ids"
-        ].shape[-1]

+        tokenizer = AutoTokenizer.from_pretrained(model_id, pad_token="<PAD>", padding_side="right")
         # Load model
         device = "cpu"
         dtype = torch.bfloat16
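
Note the deleted max_generation_length, which was derived from tokenizing the expected completion itself; the next hunk replaces it with a fixed 30-token budget for both max_length and max_cache_len. A sketch of the resulting arithmetic (the 5-token prompt length is hypothetical; the real count comes from the Cohere2 tokenizer):

    # With a fixed total budget of 30 tokens, the new-token count is whatever
    # remains after the prompt. The prompt length shown is illustrative only.
    prompt_token_ids = tokenizer(["Hello I am doing"], return_tensors="pt")["input_ids"]
    max_new_tokens = 30 - prompt_token_ids.shape[-1]  # e.g. 30 - 5 = 25 new tokens
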
@@ -314,18 +307,18 @@ def test_export_static_cache(self):
             generation_config=GenerationConfig(
                 use_cache=True,
                 cache_implementation=cache_implementation,
-                max_length=max_generation_length,
+                max_length=30,
                 cache_config={
                     "batch_size": batch_size,
-                    "max_cache_len": max_generation_length,
+                    "max_cache_len": 30,
                 },
             ),
         )

         prompts = ["Hello I am doing"]
         prompt_tokens = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
         prompt_token_ids = prompt_tokens["input_ids"]
-        max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]
+        max_new_tokens = 30 - prompt_token_ids.shape[-1]

         # Static Cache + export
         exported_program = convert_and_export_with_cache(model)
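
The hunk ends at the export call, but the unchanged tail of the test (not shown in the diff) is what consumes these values. A sketch of that continuation, assuming TorchExportableModuleWithStaticCache.generate keeps the signature used in similar transformers ExecuTorch tests:

    # Continuation sketch (not part of this commit's hunks): run the exported
    # program against the static cache and compare the decoded completion.
    ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
        exported_program=exported_program,
        prompt_token_ids=prompt_token_ids,
        max_new_tokens=max_new_tokens,
    )
    ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True)
    self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text)
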
