
Slow training speed and low GPU utilization #1793

Open
@dayunyan

Description


Describe the bug (Mandatory)
When fine-tuning Qwen2.5-3B with LoRA, the first ~10 training steps are fast (1–2 s/step), but training then gradually slows down to more than 10 s/step. GPU utilization reaches 100% early on, but after about 100 steps it stays at 2% for long stretches.

  • Hardware Environment (Ascend/GPU/CPU):

GPU

  • Software Environment (Mandatory):
    -- MindSpore version: 2.2.14
    -- Python version: 3.9
    -- OS platform and distribution: Ubuntu 22.04
    -- GCC/Compiler version (if compiled from source):

  • Execute Mode (Mandatory) (PyNative/Graph):

/mode pynative

To Reproduce (Mandatory)
Steps to reproduce the behavior:

    import mindspore as ms
    import numpy as np
    from tqdm import tqdm

    # model, tokenizer, optimizer, lr_scheduler, train_dataset, eval_dataset,
    # compute_ce_loss and compute_bleu_metrics are created earlier (omitted here).

    def forward_fn(input_ids, attention_mask, labels):
        output = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        # loss = compute_ce_loss(output.logits, labels)
        return output.loss, output.logits

    grad_fn = ms.value_and_grad(
        forward_fn, None, model.trainable_params(), has_aux=True
    )

    def train_step(input_ids, attention_mask, labels):
        (loss, logits), grads = grad_fn(input_ids, attention_mask, labels)
        optimizer.step(grads)
        return loss, logits

    for epoch in tqdm(range(num_epochs), desc="Epoch"):
        model.set_train(True)
        total_loss, total_step = 0, 0
        with tqdm(total=num_batches, leave=False, position=1, desc="train_step") as t:
            for step, pack in enumerate(train_dataset.create_dict_iterator()):
                input_ids = pack["input_ids"]
                attention_mask = pack["attention_mask"]
                labels = pack["labels"]
                loss, logits = train_step(
                    input_ids=input_ids, attention_mask=attention_mask, labels=labels
                )
                total_loss += loss.asnumpy()
                lr_scheduler.step()
                total_step += 1
                curr_loss = total_loss / total_step
                t.set_postfix({"train-loss": f"{curr_loss:.2f}"})
                t.update(1)
                # if profiler is not None:
                #     if step == 10:
                #         profiler.start()
                #     if step == 100:
                #         profiler.stop()

        model.set_train(False)
        eval_loss = 0
        total_step = 0
        eval_preds = []
        total_text_labels = []
        with tqdm(
            total=num_batches_eval, leave=False, position=1, desc="eval_step"
        ) as t:
            for step, pack in enumerate(eval_dataset.create_dict_iterator()):
                input_ids = pack["input_ids"]
                attention_mask = pack["attention_mask"]
                labels = pack["labels"]
                text_inputs = pack["text_inputs"]
                text_labels = pack["text_labels"]
                with ms._no_grad():
                    outputs = model(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                    )
                loss = compute_ce_loss(outputs.logits, labels)
                eval_loss += loss.asnumpy()
                total_step += 1
                curr_eval_loss = eval_loss / total_step
                eval_preds.extend(
                    tokenizer.batch_decode(
                        outputs.logits.argmax(axis=-1).asnumpy(),
                        skip_special_tokens=True,
                    )
                )

                total_text_labels.extend(text_labels.tolist())
                t.set_postfix({"eval-loss": f"{curr_eval_loss:.2f}"})
                t.update(1)
        bleu_avg = compute_bleu_metrics(eval_preds, total_text_labels)
        # accuracy = correct / total * 100
        # print(f"{accuracy=} % on the evaluation dataset")
        eval_epoch_loss = eval_loss / eval_dataset.get_dataset_size()
        eval_ppl = np.exp(eval_epoch_loss)
        train_epoch_loss = total_loss / train_dataset.get_dataset_size()
        train_ppl = np.exp(train_epoch_loss)
        tqdm.write(
            f"{epoch=}: {train_ppl=} {train_epoch_loss=} {eval_ppl=} {eval_epoch_loss=} {bleu_avg=}"
        )
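
For reference, the commented-out profiler hooks in the training loop above could be wired to MindSpore's Profiler roughly as follows (a minimal sketch, assuming the GPU backend; the output_path and the step window are arbitrary choices, not part of the original script):

    from mindspore import Profiler

    # Create the profiler before the training loop; start_profile=False defers
    # data collection until profiler.start() is called inside the loop.
    profiler = Profiler(output_path="./profiler_data", start_profile=False)

    # Inside the step loop, mirroring the commented-out hooks above:
    #     if step == 10:
    #         profiler.start()
    #     if step == 100:
    #         profiler.stop()

    # After training, write out operator/timeline statistics for analysis.
    profiler.analyse()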

Expected behavior (Mandatory)
Training speed should stay stable and fast, and GPU utilization should remain steady rather than dropping to very low levels.

Screenshots / Logs (Mandatory)
(two screenshots attached)

Additional context (Optional)
Add any other context about the problem here.
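
Since the issue is step time drifting from 1–2 s/step to more than 10 s/step, a per-step wall-clock log can make the drift concrete. A minimal sketch, assuming the same train_step and dataset as above; the logging interval and the extra loss.asnumpy() sync point are illustrative choices, not part of the original script:

    import time

    step_times = []
    start = time.perf_counter()
    for step, pack in enumerate(train_dataset.create_dict_iterator()):
        loss, _ = train_step(
            input_ids=pack["input_ids"],
            attention_mask=pack["attention_mask"],
            labels=pack["labels"],
        )
        loss.asnumpy()  # blocks until the step has actually finished on device
        now = time.perf_counter()
        step_times.append(now - start)
        start = now
        if step % 50 == 0:  # log every 50 steps (arbitrary interval)
            print(f"step {step}: {step_times[-1]:.2f} s/step")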
