Skip to content

Add lerp operator #535

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jun 23, 2025
Merged

Add lerp operator #535

merged 7 commits into from
Jun 23, 2025

Conversation

MARD1NO
Copy link
Collaborator

@MARD1NO MARD1NO commented Apr 11, 2025

PR Category

Type of Change

Description

Issue

Progress

  • Change is properly reviewed (1 reviewer required, 2 recommended).
  • Change is responded to an issue.
  • Change is fully covered by a UT.

Performance

PyTorch lerp 公式hu会根据weight的abs值采取不同的计算公式
image

benchmark
测试机器4090
image

@MARD1NO MARD1NO marked this pull request as ready for review May 12, 2025 07:16
@MARD1NO MARD1NO requested a review from StrongSpoon May 12, 2025 07:17
StrongSpoon
StrongSpoon previously approved these changes May 23, 2025
Copy link
Collaborator

@StrongSpoon StrongSpoon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lg

@ch1y0q
Copy link
Contributor

ch1y0q commented Jun 11, 2025

本地triton 3.3.0环境里复现,看到的是其他报错。
@StrongSpoon ci里用的triton版本是什么?

@tongxin
Copy link
Contributor

tongxin commented Jun 13, 2025

本地triton 3.3.0环境里复现,看到的是其他报错。 @StrongSpoon ci里用的triton版本是什么?

2.2.0, but we're working on upgrade right now.

@ch1y0q
Copy link
Contributor

ch1y0q commented Jun 17, 2025

本地triton 3.3.0环境里复现,看到的是其他报错。 @StrongSpoon ci里用的triton版本是什么?

2.2.0, but we're working on upgrade right now.

Is upgrade finished? Which triton version are we using now?

Comment on lines 24 to 29
@triton.jit
def lerp_scalar_kernel(input, end, weight):
if tl.abs(weight) < 0.5:
return input + weight * (end - input)
else:
return end - (end - input) * (1 - weight)
Copy link
Contributor

@ch1y0q ch1y0q Jun 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为确保两个分支返回的数据类型一致,先显式用f32算,再统一转换回input的dtype

Suggested change
@triton.jit
def lerp_scalar_kernel(input, end, weight):
if tl.abs(weight) < 0.5:
return input + weight * (end - input)
else:
return end - (end - input) * (1 - weight)
@triton.jit
def lerp_scalar_kernel(input, end, weight):
input_f32 = input.to(tl.float32)
end_f32 = end.to(tl.float32)
if tl.abs(weight) < 0.5:
result = input_f32 + weight * (end_f32 - input_f32)
else:
result = end_f32 - (end_f32 - input_f32) * (1 - weight)
return result.to(input.dtype)

Copy link
Collaborator

@StrongSpoon StrongSpoon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm



def lerp_tensor(input, end, weight):
logging.debug("GEMS LERP TENSOR")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

change to logger plz

@StrongSpoon StrongSpoon merged commit 2cc97e2 into master Jun 23, 2025
8 of 14 checks passed
@StrongSpoon StrongSpoon deleted the add_lerp branch June 23, 2025 05:59
StrongSpoon added a commit that referenced this pull request Jun 23, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants