
Data type mismatch error when fine-tuning ChatGLM3-6b with LoRA #1778

Open
dayunyan opened this issue Oct 26, 2024 · 1 comment
Labels
bug Something isn't working

Comments

@dayunyan

Describe the bug (Mandatory)
When fine-tuning ChatGLM3-6b with LoRA from mindnlp.peft, training fails with a TypeError raised inside the LoRA linear layer.

  • Hardware Environment (Ascend/GPU/CPU):
    GPU

  • Software Environment (Mandatory):
    -- MindSpore version: 2.2.14
    -- Python version: 3.9
    -- OS platform and distribution: Ubuntu 22.04
    -- GCC/Compiler version (if compiled from source):

  • Execute Mode (Mandatory) (PyNative/Graph):

/mode pynative

To Reproduce (Mandatory)
# Imports reconstructed from the traceback (mindnlp public API); args,
# train_dataset and eval_dataset come from the surrounding training script.
from mindnlp.transformers import ChatGLM3Tokenizer, ChatGLM3ForConditionalGeneration
from mindnlp.peft import LoraConfig, TaskType, get_peft_model
from mindnlp.engine import Trainer, TrainingArguments

tokenizer = ChatGLM3Tokenizer.from_pretrained(
    args.model_name_or_path, mirror="modelscope", revision="master"
)
model = ChatGLM3ForConditionalGeneration.from_pretrained(
    args.model_name_or_path, mirror="modelscope", revision="master"
)
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
training_args = TrainingArguments(
    output_dir=args.save_dir,
    evaluation_strategy="epoch",
    per_device_train_batch_size=args.batch_size,
    per_device_eval_batch_size=1,
    learning_rate=args.learning_rate,
    num_train_epochs=args.num_epochs,
    lr_scheduler_type="polynomial",
    lr_scheduler_kwargs={
        "lr_end": args.learning_rate * 0.0001,
        "power": args.power,
    },
    logging_steps=200,
    save_strategy="epoch",
    save_total_limit=1,
    # load_best_model_at_end=True,
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()
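
A quick check that may help triage (a hypothetical diagnostic, assuming the peft-wrapped model exposes the torch-style named_parameters() that mindnlp's core modules mimic): printing the dtypes around the wrapped query_key_value layers right after get_peft_model should show whether the base weights and the injected lora_A/lora_B weights disagree.

# Hypothetical diagnostic: compare base-weight vs. LoRA-adapter dtypes.
for name, param in model.named_parameters():
    if "query_key_value" in name:
        print(name, param.dtype)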

Expected behavior (Mandatory)
[screenshot]

Screenshots / Logs (Mandatory)
Traceback (most recent call last):
File "/home/zjj/xjd/huawei-ict-2024/ChatStyle/train.py", line 251, in
run(args)
File "/home/zjj/xjd/huawei-ict-2024/ChatStyle/train.py", line 91, in run
trainer.train() # resume_from_checkpoint="./checkpoints/checkpoint-8880"
File "/home/zjj/miniconda3/envs/mindspore2.2/lib/python3.9/site-packages/mindnlp/engine/trainer/base.py", line 755, in train
return inner_training_loop(
File "/home/zjj/miniconda3/envs/mindspore2.2/lib/python3.9/site-packages/mindnlp/engine/trainer/base.py", line 1107, in inner_training_loop
tr_loss_step, grads = self.training_step(model, inputs)
File "/home/zjj/miniconda3/envs/mindspore2.2/lib/python3.9/site-packages/mindnlp/engine/trainer/base.py", line 1382, in training_step
loss, grads = self.grad_fn(inputs)
File "/home/zjj/miniconda3/envs/mindspore2.2/lib/python3.9/site-packages/mindspore/ops/composite/base.py", line 625, in after_grad
return grad_(fn_, weights)(*args, **kwargs)
File "/home/zjj/miniconda3/envs/mindspore2.2/lib/python3.9/site-packages/mindspore/common/api.py", line 121, in wrapper
results = fn(*arg, **kwargs)
File "/home/zjj/miniconda3/envs/mindspore2.2/lib/python3.9/site-packages/mindspore/ops/composite/base.py", line 600, in after_grad
res = self._pynative_forward_run(fn, grad_, weights, args, kwargs)
File "/home/zjj/miniconda3/envs/mindspore2.2/lib/python3.9/site-packages/mindspore/ops/composite/base.py", line 650, in _pynative_forward_run
outputs = fn(*args, **new_kwargs)
File "/home/zjj/miniconda3/envs/mindspore2.2/lib/python3.9/site-packages/mindnlp/engine/trainer/base.py", line 1374, in forward
return self.compute_loss(model, inputs)
File "/home/zjj/miniconda3/envs/mindspore2.2/lib/python3.9/site-packages/mindnlp/engine/trainer/base.py", line 1396, in compute_loss
outputs = model(**inputs)
File "/home/zjj/miniconda3/envs/mindspore2.2/lib/python3.9/site-packages/mindnlp/core/nn/modules/module.py", line 391, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/zjj/miniconda3/envs/mindspore2.2/lib/python3.9/site-packages/mindnlp/core/nn/modules/module.py", line 402, in _call_impl
return forward_call(*args, **kwargs)
File "/home/zjj/miniconda3/envs/mindspore2.2/lib/python3.9/site-packages/mindnlp/peft/peft_model.py", line 373, in forward
return self.get_base_model()(*args, **kwargs)
File "/home/zjj/miniconda3/envs/mindspore2.2/lib/python3.9/site-packages/mindnlp/core/nn/modules/module.py", line 391, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/zjj/miniconda3/envs/mindspore2.2/lib/python3.9/site-packages/mindnlp/core/nn/modules/module.py", line 402, in _call_impl
return forward_call(*args, **kwargs)
File "/home/zjj/miniconda3/envs/mindspore2.2/lib/python3.9/site-packages/mindnlp/transformers/models/chatglm2/modeling_chatglm2.py", line 1650, in forward
transformer_outputs = self.transformer(
File "/home/zjj/miniconda3/envs/mindspore2.2/lib/python3.9/site-packages/mindnlp/core/nn/modules/module.py", line 391, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/zjj/miniconda3/envs/mindspore2.2/lib/python3.9/site-packages/mindnlp/core/nn/modules/module.py", line 402, in _call_impl
return forward_call(*args, **kwargs)
File "/home/zjj/miniconda3/envs/mindspore2.2/lib/python3.9/site-packages/mindnlp/transformers/models/chatglm2/modeling_chatglm2.py", line 1423, in forward
hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
File "/home/zjj/miniconda3/envs/mindspore2.2/lib/python3.9/site-packages/mindnlp/core/nn/modules/module.py", line 391, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/zjj/miniconda3/envs/mindspore2.2/lib/python3.9/site-packages/mindnlp/core/nn/modules/module.py", line 402, in _call_impl
return forward_call(*args, **kwargs)
File "/home/zjj/miniconda3/envs/mindspore2.2/lib/python3.9/site-packages/mindnlp/transformers/models/chatglm2/modeling_chatglm2.py", line 1108, in forward
layer_ret = layer(
File "/home/zjj/miniconda3/envs/mindspore2.2/lib/python3.9/site-packages/mindnlp/core/nn/modules/module.py", line 391, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/zjj/miniconda3/envs/mindspore2.2/lib/python3.9/site-packages/mindnlp/core/nn/modules/module.py", line 402, in _call_impl
return forward_call(*args, **kwargs)
File "/home/zjj/miniconda3/envs/mindspore2.2/lib/python3.9/site-packages/mindnlp/transformers/models/chatglm2/modeling_chatglm2.py", line 969, in forward
attention_output, kv_cache = self.self_attention(
File "/home/zjj/miniconda3/envs/mindspore2.2/lib/python3.9/site-packages/mindnlp/core/nn/modules/module.py", line 391, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/zjj/miniconda3/envs/mindspore2.2/lib/python3.9/site-packages/mindnlp/core/nn/modules/module.py", line 402, in _call_impl
return forward_call(*args, **kwargs)
File "/home/zjj/miniconda3/envs/mindspore2.2/lib/python3.9/site-packages/mindnlp/transformers/models/chatglm2/modeling_chatglm2.py", line 713, in forward
mixed_x_layer = self.query_key_value(hidden_states)
File "/home/zjj/miniconda3/envs/mindspore2.2/lib/python3.9/site-packages/mindnlp/core/nn/modules/module.py", line 391, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/zjj/miniconda3/envs/mindspore2.2/lib/python3.9/site-packages/mindnlp/core/nn/modules/module.py", line 402, in _call_impl
return forward_call(*args, **kwargs)
File "/home/zjj/miniconda3/envs/mindspore2.2/lib/python3.9/site-packages/mindnlp/peft/tuners/lora/layer.py", line 729, in forward
result = result.to(torch_result_dtype)
File "/home/zjj/miniconda3/envs/mindspore2.2/lib/python3.9/site-packages/mindspore/common/tensor.py", line 3884, in to
return tensor_operator_registry.get('to')()(self, dtype)
File "/home/zjj/miniconda3/envs/mindspore2.2/lib/python3.9/site-packages/mindspore/ops/primitive.py", line 311, in call
should_elim, output = self.check_elim(*args)
File "/home/zjj/miniconda3/envs/mindspore2.2/lib/python3.9/site-packages/mindspore/ops/operations/array_ops.py", line 357, in check_elim
if isinstance(x, Tensor) and x.dtype == dtype and not PackFunc.is_tracing():
File "/home/zjj/miniconda3/envs/mindspore2.2/lib/python3.9/site-packages/mindspore/common/_stub_tensor.py", line 94, in dtype
self.stub_dtype = self.stub.get_dtype()
TypeError: For primitive[Dense], the input type must be same.
name:[w]:Ref[Tensor[Float32]].
name:[x]:Tensor[Float16].

Additional context (Optional)
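The traceback suggests the Dense kernel received a float32 weight together with a float16 input: the ChatGLM3-6b checkpoint loads in float16, while the injected LoRA layers appear to be created in float32. A minimal workaround sketch (untested; it assumes mindnlp's from_pretrained accepts an ms_dtype argument analogous to torch_dtype in Hugging Face transformers, and that LoRA parameter names contain "lora_"):

import mindspore

# Option 1 (assumption: ms_dtype is supported by mindnlp's from_pretrained):
# load the base model in float32 so activations match the fp32 LoRA weights.
model = ChatGLM3ForConditionalGeneration.from_pretrained(
    args.model_name_or_path,
    mirror="modelscope",
    revision="master",
    ms_dtype=mindspore.float32,
)

# Option 2 (assumption: parameters expose mindspore.Parameter.set_dtype):
# keep the base model in fp16 and cast the LoRA adapter weights to match.
for name, param in model.named_parameters():
    if "lora_" in name:
        param.set_dtype(mindspore.float16)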

@dayunyan dayunyan added the bug Something isn't working label Oct 26, 2024
@dayunyan
Author

One more note:
Tracing the bug down, the last mindnlp frame is File "/home/zjj/miniconda3/envs/mindspore2.2/lib/python3.9/site-packages/mindnlp/peft/tuners/lora/layer.py", line 729, in forward: result = result.to(torch_result_dtype). When I try to inspect result's dtype, result.dtype raises the same error: TypeError: For primitive[Dense], the input type must be same. name:[w]:Ref[Tensor[Float32]]. name:[x]:Tensor[Float16].
Moreover, this error occurs in both PyNative and Graph modes.
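
For reference, the primitive-level failure can likely be reproduced without mindnlp at all (a sketch against plain MindSpore 2.2): feeding a float16 tensor into a float32 nn.Dense triggers the same TypeError, and in PyNative mode the asynchronously dispatched error only surfaces once the output stub tensor (or its dtype) is touched, which matches the behavior above.

import numpy as np
import mindspore
from mindspore import Tensor, nn

dense = nn.Dense(4, 4)  # parameters default to float32
x = Tensor(np.ones((2, 4), dtype=np.float16))
y = dense(x)            # dispatch is asynchronous in PyNative mode
print(y.dtype)          # the deferred TypeError for primitive[Dense] fires here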
