
Commit dc1440c

Neuron up mistral (#18222)
Signed-off-by: Satyajith Chilappagari <[email protected]>
1 parent 8171221 commit dc1440c

File tree

3 files changed, +36 -2 lines

3 files changed

+36
-2
lines changed

tests/neuron/2_core/test_mistral.py

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from vllm import LLM, SamplingParams
+
+
+def test_mistral():
+    llm = LLM(model="mistralai/Mistral-7B-v0.1",
+              tensor_parallel_size=2,
+              max_num_seqs=4,
+              max_model_len=512,
+              use_v2_block_manager=True,
+              override_neuron_config={
+                  "sequence_parallel_enabled": False,
+                  "skip_warmup": True
+              },
+              device="neuron")
+
+    prompts = [
+        "The president of the United States is",
+        "The capital of France is",
+    ]
+    outputs = llm.generate(prompts, SamplingParams(top_k=1))
+
+    expected_outputs = [
+        " the most powerful person in the world. He is the head of state "
+        "and head",
+        " a city of many faces. It is a city of history, culture, art"
+    ]
+
+    for expected_output, output in zip(expected_outputs, outputs):
+        generated_text = output.outputs[0].text
+        assert (expected_output == generated_text)
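
Note on the assertions in this test: SamplingParams(top_k=1) restricts decoding to the single most likely token at each step (greedy decoding), so the generated text is deterministic for a fixed model and can be compared against pinned strings. Below is a minimal sketch, not part of the commit, of how the returned RequestOutput objects are typically inspected; it assumes the `outputs` list produced inside test_mistral above.

# Sketch only; assumes the `outputs` list from test_mistral() above.
# Each RequestOutput carries the original prompt plus one completion per
# sample; with top_k=1 the completion text is deterministic.
for output in outputs:
    print(f"prompt:    {output.prompt!r}")
    print(f"generated: {output.outputs[0].text!r}")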

vllm/model_executor/model_loader/neuronx_distributed.py

Lines changed: 3 additions & 0 deletions
@@ -48,6 +48,9 @@
 # Models supported by Neuronx distributed for inference.
 _NEURON_SUPPORTED_MODELS: dict[str, tuple[str, str]] = {
     "LlamaForCausalLM":
+    ("neuronx_distributed_inference.models.llama.modeling_llama",
+     "NeuronLlamaForCausalLM"),
+    "MistralForCausalLM":
     ("neuronx_distributed_inference.models.llama.modeling_llama",
      "NeuronLlamaForCausalLM"),
     "DbrxForCausalLM":

vllm/platforms/neuron.py

Lines changed: 1 addition & 2 deletions
@@ -51,8 +51,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         assert (vllm_config.lora_config
                 is None), "LoRA is not supported for Neuron backend."
 
-        cache_config = vllm_config.cache_config
-        if cache_config:
+        if vllm_config.cache_config and vllm_config.model_config:
             # neuron needs block_size = max_model_len
             vllm_config.cache_config.block_size = \
                 vllm_config.model_config.max_model_len  # type: ignore
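
The old guard only checked cache_config, so a VllmConfig whose model_config was None would raise AttributeError on the max_model_len dereference; the new guard short-circuits instead. A minimal, self-contained sketch of the failure mode, using hypothetical stand-in dataclasses rather than vLLM's real config classes:

# Stand-in dataclasses (hypothetical, for illustration); vLLM's real
# CacheConfig/ModelConfig/VllmConfig are much richer than this.
from dataclasses import dataclass
from typing import Optional


@dataclass
class CacheConfig:
    block_size: Optional[int] = None


@dataclass
class ModelConfig:
    max_model_len: int = 512


@dataclass
class VllmConfig:
    cache_config: Optional[CacheConfig] = None
    model_config: Optional[ModelConfig] = None


cfg = VllmConfig(cache_config=CacheConfig(), model_config=None)

# Old guard: `if cfg.cache_config:` alone would fall through to
# `cfg.model_config.max_model_len` and raise AttributeError.
# New guard short-circuits safely when model_config is missing:
if cfg.cache_config and cfg.model_config:
    cfg.cache_config.block_size = cfg.model_config.max_model_len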
