Describe the bug (Mandatory)
When fine-tuning Qwen2.5-3B with LoRA, the first ~10 training steps are fast (about 1-2 s/step), but training then gradually slows to more than 10 s/step. GPU utilization reaches 100% in the early phase, yet after roughly 100 steps it sits at 2% for long stretches.
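For reference, the per-step times quoted above come from a simple wall-clock measurement like the sketch below (it reuses `train_step` and `train_dataset` from the reproduce script further down; `loss.asnumpy()` forces device synchronization so the timing covers the full step):

```python
import time

start = time.perf_counter()
for step, pack in enumerate(train_dataset.create_dict_iterator()):
    loss, _ = train_step(
        input_ids=pack["input_ids"],
        attention_mask=pack["attention_mask"],
        labels=pack["labels"],
    )
    loss.asnumpy()  # sync with the device before reading the clock
    now = time.perf_counter()
    print(f"step {step}: {now - start:.2f} s/step")
    start = now
```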
- Hardware Environment (Ascend/GPU/CPU): GPU
- Software Environment (Mandatory):
-- MindSpore version (e.g., 1.7.0.Bxxx): 2.2.14
-- Python version (e.g., Python 3.7.5): 3.9
-- OS platform and distribution (e.g., Linux Ubuntu 16.04): Ubuntu 22.04
-- GCC/Compiler version (if compiled from source): -
- Execute Mode (Mandatory) (PyNative/Graph): PyNative
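For completeness, the script selects PyNative mode on GPU roughly as follows (a minimal sketch; any additional context options in the actual script are omitted):

```python
import mindspore as ms

# Eager (PyNative) execution on the GPU backend
ms.set_context(mode=ms.PYNATIVE_MODE, device_target="GPU")
```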
To Reproduce (Mandatory)
Steps to reproduce the behavior:
```python
import mindspore as ms
import numpy as np
from tqdm import tqdm

# model (Qwen2.5-3B with LoRA adapters), tokenizer, optimizer, lr_scheduler,
# train_dataset / eval_dataset, compute_ce_loss and compute_bleu_metrics are
# constructed as usual; that setup is omitted here.

def forward_fn(input_ids, attention_mask, labels):
    # Forward pass; the model computes the loss internally
    output = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        labels=labels,
    )
    # loss = compute_ce_loss(output.logits, labels)
    return output.loss, output.logits

# Differentiate forward_fn w.r.t. the trainable (LoRA) parameters only
grad_fn = ms.value_and_grad(
    forward_fn, None, model.trainable_params(), has_aux=True
)

def train_step(input_ids, attention_mask, labels):
    (loss, logits), grads = grad_fn(input_ids, attention_mask, labels)
    optimizer.step(grads)
    return loss, logits

for epoch in tqdm(range(num_epochs), desc="Epoch"):
    # ---- training ----
    model.set_train(True)
    total_loss, total_step = 0, 0
    with tqdm(total=num_batches, leave=False, position=1, desc="train_step") as t:
        for step, pack in enumerate(train_dataset.create_dict_iterator()):
            input_ids = pack["input_ids"]
            attention_mask = pack["attention_mask"]
            labels = pack["labels"]
            loss, logits = train_step(
                input_ids=input_ids, attention_mask=attention_mask, labels=labels
            )
            total_loss += loss.asnumpy()
            lr_scheduler.step()
            total_step += 1
            curr_loss = total_loss / total_step
            t.set_postfix({"train-loss": f"{curr_loss:.2f}"})
            t.update(1)
            # if profiler is not None:
            #     if step == 10:
            #         profiler.start()
            #     if step == 100:
            #         profiler.stop()

    # ---- evaluation ----
    model.set_train(False)
    eval_loss = 0
    total_step = 0
    eval_preds = []
    total_text_labels = []
    with tqdm(
        total=num_batches_eval, leave=False, position=1, desc="eval_step"
    ) as t:
        for step, pack in enumerate(eval_dataset.create_dict_iterator()):
            input_ids = pack["input_ids"]
            attention_mask = pack["attention_mask"]
            labels = pack["labels"]
            text_inputs = pack["text_inputs"]
            text_labels = pack["text_labels"]
            with ms._no_grad():  # disable gradient tracking for evaluation
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                )
            loss = compute_ce_loss(outputs.logits, labels)
            eval_loss += loss.asnumpy()
            total_step += 1
            curr_eval_loss = eval_loss / total_step
            eval_preds.extend(
                tokenizer.batch_decode(
                    outputs.logits.argmax(axis=-1).asnumpy(),
                    skip_special_tokens=True,
                )
            )
            total_text_labels.extend(text_labels.tolist())
            t.set_postfix({"eval-loss": f"{curr_eval_loss:.2f}"})
            t.update(1)

    bleu_avg = compute_bleu_metrics(eval_preds, total_text_labels)
    # accuracy = correct / total * 100
    # print(f"{accuracy=} % on the evaluation dataset")
    eval_epoch_loss = eval_loss / eval_dataset.get_dataset_size()
    eval_ppl = np.exp(eval_epoch_loss)
    train_epoch_loss = total_loss / train_dataset.get_dataset_size()
    train_ppl = np.exp(train_epoch_loss)
    tqdm.write(
        f"{epoch=}: {train_ppl=} {train_epoch_loss=} {eval_ppl=} {eval_epoch_loss=} {bleu_avg=}"
    )
```
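The commented-out profiler hooks in the training loop assume a `mindspore.Profiler` created with a delayed start so that only steps 10-100 are captured; a minimal sketch (the `output_path` value is illustrative):

```python
import mindspore as ms

# Created before training; profiling is started/stopped inside the loop
profiler = ms.Profiler(start_profile=False, output_path="./profiler_data")

# ... run the training loop above with the profiler hooks uncommented ...

profiler.analyse()  # parse and save the collected performance data
```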
Expected behavior (Mandatory)
Training speed should remain stable and fast, and GPU utilization should stay steady rather than dropping to very low levels.
Screenshots / Logs (Mandatory)
Additional context (Optional)
Add any other context about the problem here.