
[WebGPU] Unexpected Output with Phi-3 Mini 4K Instruct Model from ORT GenAI #25180

Closed
Description

Honry (Contributor)

Describe the issue

WebNN developer preview provides a text-generation demo with several LLM models (Phi-3 Mini 4K Instruct, DeepSeek R1 Distill Qwen, TinyLlama, Qwen2), which were generated with ONNX Runtime GenAI.

These models share a similar architecture (GQA, MatMulNBits, RotaryEmbedding, ...). When tested with the WebGPU EP, only Phi-3 Mini 4K Instruct produced unexpected results; the other models worked fine.
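For context, a minimal sketch of how such a demo might create its inference session with onnxruntime-web on the WebGPU EP (the import path and options follow the public onnxruntime-web API; the model URL is a placeholder, not the demo's actual path):

```typescript
// Hypothetical session setup for a WebGPU text-generation demo.
import * as ort from 'onnxruntime-web/webgpu';

async function createSession(): Promise<ort.InferenceSession> {
  // 'webgpu' selects the WebGPU execution provider in the browser.
  return ort.InferenceSession.create('./phi3-mini-4k-instruct/model.onnx', {
    executionProviders: ['webgpu'],
  });
}
```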

To reproduce

Test Phi-3 Mini 4K Instruct:

Test others:

Urgency

No response

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.23.0-dev.20250612-70f14d7670

Execution Provider

'webgpu' (WebGPU)

Activity

Labels added on Jun 26, 2025:
platform:web (issues related to ONNX Runtime web; typically submitted using template)
ep:DML (issues related to the DirectML execution provider)
ep:WebGPU (ort-web webgpu provider)
ep:WebNN (WebNN execution provider)
fdwr (Contributor) commented on Jun 26, 2025

@Honry This one is tagged with ep:DML too 🤔. Should I remove that?

Honry (Contributor, Author) commented on Jun 26, 2025

> @Honry This one is tagged with ep:DML too 🤔. Should I remove that?

Yes, please!

Label removed on Jun 26, 2025:
ep:DML (issues related to the DirectML execution provider)
fs-eire (Contributor) commented on Jul 2, 2025

If only Phi-3 Mini 4K Instruct fails and the other models are good, it's probably that an incorrect tokenizer or chat template is being used.
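For reference, the chat format Phi-3 Mini 4K Instruct expects (per its model card) looks roughly like the sketch below; applying another model's template, or dropping the <|end|> markers, could produce garbled output even when the graph itself is correct:

```typescript
// Hedged sketch of the Phi-3 instruct chat template; the exact markers
// should be checked against the model card rather than taken from here.
function buildPhi3Prompt(userMessage: string): string {
  return `<|user|>\n${userMessage}<|end|>\n<|assistant|>\n`;
}

console.log(buildPhi3Prompt('Tell me a joke.'));
```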

Honry (Contributor, Author) commented on Jul 2, 2025

> If only Phi-3 Mini 4K Instruct fails and the other models are good, it's probably that an incorrect tokenizer or chat template is being used.

🤔 But with the same tokenizer, WebNN passes when backed by DML.

fs-eire (Contributor) commented on Jul 2, 2025

If I remember correctly, the Phi3-Mini-4k-Instruct models exported for the DML and WebGPU EPs are slightly different. @guschmue, do you know the details?

guschmue (Contributor) commented on Jul 3, 2025

I can reproduce this with JSEP, but it is fine with the new WebGPU EP. I think this is because of the bs=128 (the int4 block size).

I don't think we want to do anything here, because we are switching to the new WebGPU EP in days anyway.

In general, this model is optimized for DML and suboptimal for WebGPU. For WebGPU we prefer models created by model builder with '-e webgpu', i.e. bs=32, accuracy_level=4, and no RoE (rotary embedding) in GQA, which enables the fast FA2 path in WebGPU. If using a WebGPU-specific model is not feasible, the CUDA flavor from model builder is the next best choice.

Near term we are adding RoE support to GQA in the fast FA2 path, so the CUDA model will perform optimally on WebGPU.
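To make the two flavors concrete, here is a sketch of the differences described above, written as a TypeScript config object (the field names are illustrative, not a real builder schema; the WebGPU flavor corresponds to ONNX Runtime GenAI's model builder invoked with '-e webgpu'):

```typescript
// Hypothetical summary of the export flavors discussed in this thread.
const modelFlavors = {
  dml: {
    blockSize: 128,              // bs=128, tuned for DirectML
    accuracyLevel: 0,
    rotaryEmbeddingInGQA: true,  // RoE fused into GQA blocks the fast FA2 path
  },
  webgpu: {
    blockSize: 32,               // bs=32
    accuracyLevel: 4,            // enables the dp4 prefill path (see below)
    rotaryEmbeddingInGQA: false, // standalone RoE enables the fast FA2 path
  },
} as const;
```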

Honry (Contributor, Author) commented on Jul 4, 2025

@guschmue, thank you for your thorough explanation! It works with the new WebGPU EP. Let me close this issue.

Honry (Contributor, Author) commented on Jul 8, 2025

@guschmue, I did see a performance improvement when adjusting the Phi-3 configuration from bs=128, accuracy_level=0 to bs=32, accuracy_level=4. However, for the other models, changing the configuration from bs=32, accuracy_level=0 to bs=32, accuracy_level=4 did not result in a significant improvement. Is that expected? Does it imply that a smaller block size (bs) yields better performance on WebGPU?

qjia7 (Contributor) commented on Jul 9, 2025

accuracy_level=4 goes down the dp4 path for prefill; otherwise it takes the normal path. If you have a long prompt, for example > 1k tokens, the advantage will be very obvious, but if the inputs are short, the perf difference between accuracy_level=4 and accuracy_level=0 will be small. Generation uses the same path either way. Block size is related to the algorithm we use to access scales_b: our current algorithm is friendly to bs=32. If you choose another bs, the algorithm may need adjusting to get optimal performance on WebGPU.
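A minimal sketch of why block size shows up in the scales_b access pattern, assuming a MatMulNBits-style layout (two 4-bit weights packed per byte, one scale per block of blockSize weights, default zero point 8); this is illustrative TypeScript, not the actual WebGPU shader:

```typescript
// Dequantize one row of block-quantized int4 weights.
function dequantizeRow(
  packed: Uint8Array,    // ceil(cols / 2) bytes, low nibble first (assumed)
  scales: Float32Array,  // ceil(cols / blockSize) scales for this row
  cols: number,
  blockSize: number,     // 32 or 128 in this thread
): Float32Array {
  const out = new Float32Array(cols);
  for (let col = 0; col < cols; col++) {
    const byte = packed[col >> 1];
    const q = (col & 1) === 0 ? byte & 0x0f : byte >> 4;
    // The scale lookup: with blockSize=32, a tile covering 32 columns
    // touches exactly one scale; with blockSize=128 the access pattern no
    // longer lines up with a 32-wide tiling, so the kernel loses efficiency.
    const scale = scales[Math.floor(col / blockSize)];
    out[col] = (q - 8) * scale; // 8 is the assumed default zero point
  }
  return out;
}
```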

Honry (Contributor, Author) commented on Jul 9, 2025

> accuracy_level=4 goes down the dp4 path for prefill; otherwise it takes the normal path. [...]

👍 @qjia7 That makes sense, thanks very much for your answer!

guschmue (Contributor) commented on Jul 9, 2025

And we changed the default in model builder to use accuracy level 4; recently converted models on Hugging Face should come with accuracy level 4.

