
Incorrect inference results from the Qwen2 model #1836

Open
@dayunyan

Description

Describe the bug (Mandatory)
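When running a basic inference test with the Qwen2-7B-Instruct model, the output is completely inconsistent with the PyTorch output, even though the inputs are exactly the same.
Test code: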

import argparse
import mindspore as ms
from mindspore import context
from mindnlp.transformers import AutoModelForCausalLM, AutoTokenizer

context.set_context(device_target="Ascend")
# context.set_context(device_id=1)

prompt_system = "<|im_start|>system\n{}<|im_end|>\n"
prompt_user = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
prompt_assistant = "{}<|im_end|>\n"


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="Qwen/Qwen2-7B-Instruct",  # Qwen/Qwen2.5-3B
    )
    parser.add_argument("--inf_max_length", type=int, default=128)
    return parser.parse_args()

def inference(args):
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, mirror="modelscope", revision="master")
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path, mirror="modelscope", revision="master", ms_dtype=ms.float16
    )
    model.set_train(False)
    model.jit()

    messages = [
        {"role": "system", "content": "You are a helpful assistant."}
    ]
    with ms._no_grad():
        while True:
            inputs = input("Q: ")
            if inputs in ("exit", "Exit", "quit", "Quit", "e", "q"):
                break
            messages.append(
                {"role": "user", "content": inputs},
            )

            prompt = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )

            model_inputs = tokenizer([prompt], return_tensors="ms")
            print(f"{model_inputs}")

            outputs = model.generate(**model_inputs, max_new_tokens=args.inf_max_length)
            outputs = [
                output_ids[len(input_ids) :]
                for input_ids, output_ids in zip(model_inputs["input_ids"], outputs)
            ]
            # print(outputs)
            text_output = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
            print(f"A: {text_output}")

            messages.append({"role": "assistant", "content": text_output})


if __name__ == "__main__":
    args = get_args()
    inference(args)
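The script defines `prompt_system`, `prompt_user` and `prompt_assistant` but builds the prompt with `tokenizer.apply_chat_template` instead. As a rough sanity check that the prompt is not what diverges, here is a minimal sketch that renders the same ChatML format by hand from those templates. `render_chatml` is a hypothetical helper (not part of mindnlp or transformers), and it assumes a strictly alternating system/user/assistant history like the one the loop above builds.

```python
# Hypothetical helper: render a chat history with the ChatML templates from the test script.
PROMPT_SYSTEM = "<|im_start|>system\n{}<|im_end|>\n"
# The user template also opens the assistant turn, like add_generation_prompt=True.
PROMPT_USER = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
# The assistant template only closes the turn opened by the preceding user message.
PROMPT_ASSISTANT = "{}<|im_end|>\n"

TEMPLATES = {"system": PROMPT_SYSTEM, "user": PROMPT_USER, "assistant": PROMPT_ASSISTANT}


def render_chatml(messages):
    """Concatenate the per-role templates for a list of {"role", "content"} dicts."""
    return "".join(TEMPLATES[m["role"]].format(m["content"]) for m in messages)


if __name__ == "__main__":
    history = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "你是谁?"},
    ]
    # Prompt string copied from the first turn of the log in this issue.
    logged = (
        "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
        "<|im_start|>user\n你是谁?<|im_end|>\n<|im_start|>assistant\n"
    )
    print(render_chatml(history) == logged)
```

The first-turn `text:` prompt printed in the log has exactly this format, which suggests the garbled answer does not come from prompt formatting.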

  • Hardware Environment (Ascend/GPU/CPU):

/device ascend

  • Software Environment (Mandatory):
    -- MindSpore version (e.g., 1.7.0.Bxxx) : 2.3.1
    -- Python version (e.g., Python 3.7.5) : 3.9
    -- OS platform and distribution (e.g., Linux Ubuntu 16.04):
    -- GCC/Compiler version (if compiled from source):

  • Execution Mode (Mandatory) (PyNative/Graph):

/mode pynative
/mode graph

To Reproduce (Mandatory)
Steps to reproduce the behavior:

  1. Download the original model weights
  2. Run the test code
  3. Observe the output

Expected behavior (Mandatory)
Normal output, close to the output of the PyTorch version.
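To make "close to the PyTorch output" checkable, one option is to compare the next-token logits of both frameworks for the same `input_ids` rather than the decoded text, since a single diverging token changes everything generated after it. A minimal sketch, assuming the last-token logits from the MindSpore run and from a PyTorch `transformers` run have already been exported as plain Python float lists (the export step is not shown). `compare_logits` and its tolerance defaults are hypothetical, loosened because the model is loaded in float16.

```python
import math


def compare_logits(ms_logits, pt_logits, atol=1e-2, rtol=1e-2):
    """Compare two equal-length float sequences, e.g. last-token logits.

    atol/rtol are illustrative defaults loosened for float16, not reference values.
    """
    if len(ms_logits) != len(pt_logits) or not ms_logits:
        raise ValueError("logit vectors must be non-empty and of equal length")
    max_abs_diff = 0.0
    mismatches = 0
    for a, b in zip(ms_logits, pt_logits):
        # NaN or inf on either side is always counted as a mismatch.
        if not (math.isfinite(a) and math.isfinite(b)):
            mismatches += 1
            continue
        diff = abs(a - b)
        max_abs_diff = max(max_abs_diff, diff)
        # Same closeness rule as numpy.isclose: |a - b| <= atol + rtol * |b|.
        if diff > atol + rtol * abs(b):
            mismatches += 1

    def argmax(xs):
        # Not meaningful if xs contains NaN; check "mismatches" first in that case.
        return max(range(len(xs)), key=lambda i: xs[i])

    return {
        "max_abs_diff": max_abs_diff,
        "mismatches": mismatches,
        # Greedy decoding picks the argmax, so this shows whether the next token agrees.
        "same_argmax": argmax(ms_logits) == argmax(pt_logits),
    }


if __name__ == "__main__":
    print(compare_logits([0.5, 2.0, -1.0], [0.5, 2.01, -1.0]))
```

A non-zero `mismatches` count or `same_argmax == False` on the very first generated position would point at the forward pass itself rather than at the generation loop.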

Screenshots / Logs (Mandatory)

(MindSpore) [ma-user huawei-ict-2024]$python inference.py 
/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/numpy/core/getlimits.py:499: UserWarning: The value of the smallest subnormal for <class 'numpy.float64'> type is zero.
  setattr(self, word, getattr(machar, word).flat[0])
/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for <class 'numpy.float64'> type is zero.
  return self._float_to_str(self.smallest_subnormal)
/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/numpy/core/getlimits.py:499: UserWarning: The value of the smallest subnormal for <class 'numpy.float32'> type is zero.
  setattr(self, word, getattr(machar, word).flat[0])
/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for <class 'numpy.float32'> type is zero.
  return self._float_to_str(self.smallest_subnormal)
Building prefix dict from the default dictionary ...
Loading model from cache /tmp/jieba.cache
Loading model cost 1.319 seconds.
Prefix dict has been built successfully.
Qwen2ForCausalLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`.`PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.
Sliding Window Attention is enabled but not implemented for `eager`; unexpected results may be encountered.
[MS_ALLOC_CONF]Runtime config:  enable_vmm:True  vmm_align_size:2MB
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████| 4/4 [03:39<00:00, 54.93s/it]
Q: 你是谁?
('text: <|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n你是谁?<|im_end|>\n<|im_start|>assistant\n', 'prompt: <|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n你是谁?<|im_end|>\n<|im_start|>assistant\n')
{'input_ids': Tensor(shape=[1, 22], dtype=Int64, value=
[[151644,   8948,    198 ... 151644,  77091,    198]]), 'attention_mask': Tensor(shape=[1, 22], dtype=Int64, value=
[[1, 1, 1 ... 1, 1, 1]])}
A: ![](![](https!://s3.cn-north-1.amazonaws.com/c![](https!://s3.cn-north-1.amazonaws.com!/[object%20Image!].png![](https!://s3!//cdn.cnbj1.fds.api.mi!//cdn.cnb!//cdn.cnb!//![](!//![](!//![](!//![](!//![](!//![](!//![](!//![](!//![](!//![](!//![](!//![](!//![](!//![](!//!
Q: 今天星期几?
('text: <|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n你是谁?<|im_end|>\n<|im_start|>assistant\n![](![](https!://s3.cn-north-1.amazonaws.com/c![](https!://s3.cn-north-1.amazonaws.com!/[object%20Image!].png![](https!://s3!//cdn.cnbj1.fds.api.mi!//cdn.cnb!//cdn.cnb!//![](!//![](!//![](!//![](!//![](!//![](!//![](!//![](!//![](!//![](!//![](!//![](!//![](!//![](!//!<|im_end|>\n<|im_start|>user\n今天星期几?<|im_end|>\n<|im_start|>assistant\n', 'prompt: <|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n你是谁?<|im_end|>\n<|im_start|>assistant\n![](![](https!://s3.cn-north-1.amazonaws.com/c![](https!://s3.cn-north-1.amazonaws.com!/[object%20Image!].png![](https!://s3!//cdn.cnbj1.fds.api.mi!//cdn.cnb!//cdn.cnb!//![](!//![](!//![](!//![](!//![](!//![](!//![](!//![](!//![](!//![](!//![](!//![](!//![](!//![](!//!<|im_end|>\n<|im_start|>user\n今天星期几?<|im_end|>\n<|im_start|>assistant\n')
{'input_ids': Tensor(shape=[1, 149], dtype=Int64, value=
[[151644,   8948,    198 ... 151644,  77091,    198]]), 'attention_mask': Tensor(shape=[1, 149], dtype=Int64, value=
[[1, 1, 1 ... 1, 1, 1]])}
A: ![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](![](

Additional context (Optional)

Metadata

Labels: bug (Something isn't working)
