-
Notifications
You must be signed in to change notification settings - Fork 113
[Operator] index_add optimized #427
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…agGems into index_add_opt
src/flag_gems/ops/index_add.py
Outdated
code.writeline( | ||
"add_on = tl.load(src + src_idx, mask=mask, other=0) * alpha" | ||
) | ||
code.writeline("tl.atomic_add(out + input_idx, add_on, mask=input_mask)") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's safe to set sem='relaxed'
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please resolve the conflict
src/flag_gems/ops/index_add.py
Outdated
|
||
# signature | ||
code.writeline(f"def {kernel_name}(") | ||
function_ns = NameSpace() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please delete the code related to NameSpace, NameSpace is no longer used.
I see code in the generated kernel. Such access seems to be more suitable for column-major tensors? The previous #433 may be helpful. |
Thanks for all the above suggestions; the updated version has been pushed. Any further optimization ideas are welcome :) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR Category
Operator
Type of Change
Bug Fix & Performance Optimization
Description
The method of using indexing to implement
index_add
resolves the bugs present in the previous version, and I accelerate it using code generation.BTW this solution doesn't perform well enough in some shapes, thus any suggestion is appreciated :)
Issue
Progress
Performance
Before: (Bug in shape (200, 40999, 3))
After: