-
Notifications
You must be signed in to change notification settings - Fork 113
Add lerp operator #535
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add lerp operator #535
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lg
本地triton 3.3.0环境里复现,看到的是其他报错。 |
2.2.0, but we're working on upgrade right now. |
Is upgrade finished? Which triton version are we using now? |
src/flag_gems/ops/lerp.py
Outdated
@triton.jit | ||
def lerp_scalar_kernel(input, end, weight): | ||
if tl.abs(weight) < 0.5: | ||
return input + weight * (end - input) | ||
else: | ||
return end - (end - input) * (1 - weight) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
为确保两个分支返回的数据类型一致,先显式用f32算,再统一转换回input的dtype
@triton.jit | |
def lerp_scalar_kernel(input, end, weight): | |
if tl.abs(weight) < 0.5: | |
return input + weight * (end - input) | |
else: | |
return end - (end - input) * (1 - weight) | |
@triton.jit | |
def lerp_scalar_kernel(input, end, weight): | |
input_f32 = input.to(tl.float32) | |
end_f32 = end.to(tl.float32) | |
if tl.abs(weight) < 0.5: | |
result = input_f32 + weight * (end_f32 - input_f32) | |
else: | |
result = end_f32 - (end_f32 - input_f32) * (1 - weight) | |
return result.to(input.dtype) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm
|
||
|
||
def lerp_tensor(input, end, weight): | ||
logging.debug("GEMS LERP TENSOR") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
change to logger plz
This reverts commit 2cc97e2.
PR Category
Type of Change
Description
Issue
Progress
Performance
PyTorch lerp 公式hu会根据weight的abs值采取不同的计算公式

benchmark

测试机器4090