[DispatchCreation] Enable fusion of encoding ops with multi-use producers #22444
Conversation
…cers -- This commit adds the `enableAggressiveFusion` pass option to `FuseEncodingOpsIntoDispatchRegionsPass` in order to allow fusion of encoding ops with multi-use producers. -- `enableAggressiveFusion` is set via `enableMultiUseEncodingFusion`, which in turn is set when using `-O3`.
Signed-off-by: Abhishek Varma <[email protected]>
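For context, here is a minimal sketch of how the option plumbing described in the commit message could look. The `DispatchCreationOptions` struct, the namespace, and the factory signature are assumptions for illustration, not IREE's actual code.

```cpp
// Sketch only: illustrative plumbing for forwarding an -O3-gated option into
// the pass. The factory signature and the option struct are assumptions.
#include <memory>

#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"

namespace sketch {

// Assumed factory; the real declaration lives in IREE's DispatchCreation passes.
std::unique_ptr<mlir::Pass>
createFuseEncodingOpsIntoDispatchRegionsPass(bool enableAggressiveFusion);

struct DispatchCreationOptions {
  // Hypothetical umbrella option, imagined to be flipped to true under -O3.
  bool enableMultiUseEncodingFusion = false;
};

void addEncodingFusionPass(mlir::OpPassManager &pm,
                           const DispatchCreationOptions &options) {
  // Forward the umbrella option into the pass's enableAggressiveFusion knob.
  pm.addPass(createFuseEncodingOpsIntoDispatchRegionsPass(
      /*enableAggressiveFusion=*/options.enableMultiUseEncodingFusion));
}

} // namespace sketch
```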
Can we add a separate experimental flag for now? I have concerns about controlling it with the optimization level, because data-tiling is not ready for the optimization level, and we should decouple the dependency between umbrella flags a bit. See #21935 for details.
@Abhishek-Varma can you remind me why we need this for fp8 but not fp16? E.g., can you provide an example from fp8 about the call graph (i.e., I think there are three dispatch ops) and the corresponding executables at the dispatch creation level?
I want to understand how much memory footprint is added with the flag. This is the main reason why we don't want it in the optimization level: users can be surprised by the behavior.
The other question, which you may have to check, is: do we need the flag for the 405B fp4 model? (cc @jtuyls)
(I should have added this comment in the first place, sorry about that.)
If we add a new flag, let's make sure it is very clear that it is only meant for testing purposes, and if we want to use it in a downstream project, then we should do that by tying it to the optimization level. My main concern is that these flags keep proliferating through downstream projects because everyone copy-pastes flag files, and it is hard to remove them all.
Wait, looking at the PR, it seems this does have its own flag? (Maybe the naming is just slightly confusing because the pass option has the same name as some other pre-existing flags.)
No, we don't need it for the 405b model.
I want to say that it is their issue if they use experimental flags. The experimental flags can be dropped at any time. However, the reality is that it becomes tricky when users are from our partner groups. I don't have a solution, but it is a separate problem. As developers, we do our best to signal that they are for testing purposes. I'd refuse to fix issues if they don't know why they use the flag. If they use the flag because of our suggestion, then that is a trigger point to consider whether we promote the flag to a pipeline option or add it to whatever umbrella flag. #22295 is an example where we graduated an experimental flag. I think we should only have this experimental flag in https://github.com/iree-org/iree/blob/main/compiler/src/iree/compiler/DispatchCreation/Passes.cpp, and not expose it as a pipeline option. I don't understand why we need it for the 8b fp8 model but not other llama models. Is something wrong in the input program?
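As a rough illustration of the suggestion above, an experimental, testing-only flag in `DispatchCreation/Passes.cpp` could be declared with `llvm::cl::opt` along these lines; the flag name below is hypothetical, not an existing IREE flag.

```cpp
// Sketch only: a testing-only flag declaration. The flag name is hypothetical
// and intentionally reads as experimental so it is not mistaken for a stable option.
#include "llvm/Support/CommandLine.h"

static llvm::cl::opt<bool> clExperimentalMultiUseEncodingFusion(
    "iree-dispatch-creation-experimental-multi-use-encoding-fusion",
    llvm::cl::desc("Experimental/testing only; may be removed at any time. "
                   "Allows fusing encoding ops into dispatch regions whose "
                   "producers have multiple uses."),
    llvm::cl::init(false));
```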
The graph looks very different for 405b fp4, because we have scaled matmuls with quantized weights, and I think there are requantization dispatches before every matmul. In 8b, we would have cases like: … But in 405b, we have: … So, in 405b, we don't have to fuse with any dispatches that have multiple users.
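To make the 8b-vs-405b distinction concrete, here is a minimal sketch of the single-use check that the aggressive-fusion option relaxes; the helper name and exact logic are assumptions, not IREE's actual pass code.

```cpp
// Sketch only: the kind of use-count gate that enableAggressiveFusion relaxes.
#include "mlir/IR/Value.h"

namespace sketch {

bool canFuseEncodingWithProducer(mlir::Value producerResult,
                                 bool enableAggressiveFusion) {
  // Default: only fuse the encoding op when the producer result has a single
  // use (the 405b-style graphs fall into this bucket).
  if (producerResult.hasOneUse())
    return true;
  // Aggressive mode: also allow producers with multiple users (the 8b/SDXL
  // style graphs), at the cost of a potentially larger memory footprint.
  return enableAggressiveFusion;
}

} // namespace sketch
```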
Thanks for the explanation, Max. I wanted to see the IR, but I failed. I did not see the IR pattern that you described; I don't see any encoding dispatch on ToT.
MLIR file: nod-ai/shark-ai#2548 (comment)
IR dumps: https://gist.github.com/hanhanW/f3011926ac6edd218d15c58d5c4ffa97
All the …
My compile flags are from https://github.com/nod-ai/iree-model-benchmark/blob/main/llama3/compile-8b-base.sh:
iree-compile \
--iree-hal-target-backends=rocm \
--iree-hip-target=gfx950 \
--iree-hal-target-device=hip \
--iree-opt-level=O3 \
--iree-dispatch-creation-propagate-collapse-across-expands=true \
--iree-codegen-enable-default-tuning-specs=true \
--iree-hip-enable-tensor-ukernels \
--iree-hal-indirect-command-buffers=true \
--iree-stream-resource-memory-model=discrete \
--iree-hip-specialize-dispatches \
--iree-hal-memoization=true \
-o /tmp/z.vmfb \
--iree-dispatch-creation-data-tiling \
--mlir-print-ir-after=iree-stream-conversion \
--mlir-disable-threading \
~/llama3_8b_fp8_e4m3fn.mlir \
2> ~/dump.mlir
Do we really have the issue in the new MLIR file? Maybe we can move forward without this patch?
I see, maybe something has changed then? I think we need to see the original repro for the issue this is solving, if we still have it. Side note: this IR pattern does still happen on SDXL (unet/punet) AFAIK, so IMO it's still worth having this flag for more complex models like unet, even if the problem doesn't exist in llama anymore.
Agree, but we should understand what problem we are solving and make sure we solve the issue properly. Otherwise, it may become tech debt and make future changes harder. People, like me, may ask: what does it break if we revert the change? Is it safe to drop the flag for llama? I see the value of the PR; it is good. What's missing is a more accurate PR description. (And nice work bringing the model into a reasonable state, @Abhishek-Varma!) (If SDXL is the case, we should only use it as an experimental flag, and we shouldn't share this flag with anyone.)
Maybe we can add that to the PR description and drop the third point, if @Abhishek-Varma can help verify whether this is still an issue for llama or not.
So the issue persists for the following IR: 8b_attention_kernel_torch.zip. I tried the IR which you used above, and that doesn't require this patch. To differentiate between the two, I tried using …
I've removed the third point from the PR's description, @hanhanW @Max191. Regarding the flag - do we remove this from …
@Abhishek-Varma I want to decouple the dependency between the data-tiling umbrella flag and … We also don't have an active use case for your flag, so my suggestion is moving it to …
-- This commit adds the `enableAggressiveFusion` pass option to `FuseEncodingOpsIntoDispatchRegionsPass` in order to allow fusion of encoding ops with multi-use producers.
-- `enableAggressiveFusion` is set via `enableMultiUseEncodingFusion`, which in turn is set when using `-O3`.
-- It aims to address the following patterns as observed in SDXL.
Signed-off-by: Abhishek Varma <[email protected]>