
[AdvancedCompiler] Add polar #543


Merged

merged 11 commits into FlagOpen:master on Apr 28, 2025
Conversation

Contributor

@AdvancedCompiler AdvancedCompiler commented Apr 16, 2025

PR Category

Operator

Type of Change

New Feature

Description

Implement the polar op.
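
For reference, torch.polar builds a complex tensor from a magnitude and a phase, computing abs * (cos(angle) + i * sin(angle)) elementwise:

import torch

abs = torch.tensor([1.0, 2.0])
angle = torch.tensor([0.0, torch.pi / 2])
z = torch.polar(abs, angle)  # approximately tensor([1+0j, 0+2j]), up to floating-point error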

Issue

Progress

  • Change is properly reviewed (1 reviewer required, 2 recommended).
  • Change responds to an issue.
  • Change is fully covered by a UT.

Performance

Performance of the modified implementation is shown in the benchmark screenshot attached to the PR.

import logging

import torch


def polar(abs, angle):
    logging.debug("GEMS POLAR")
    # run the Triton kernel to get real/imag parts, then pack them into a complex tensor
    real, imag = polar_kernel(abs, angle)
    return torch.complex(real, imag)
Collaborator

torch.complex calls complex_kernel_cuda. As a result, your implementation launches two kernels and performs worse than the baseline. Please try torch.view_as_complex and consider whether there is a way to compute polar and create the complex tensor in one kernel.
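
For illustration, here is a minimal sketch of the suggested approach, using plain PyTorch ops as a stand-in for a fused kernel (the interleaved-output layout and the helper name polar_view_as_complex are assumptions, not the PR's actual code):

import torch

def polar_view_as_complex(abs, angle):
    # One float buffer with a trailing dim of 2, so (real, imag) pairs are
    # interleaved in memory as torch.view_as_complex requires.
    out = torch.empty(*abs.shape, 2, device=abs.device, dtype=abs.dtype)
    # A fused kernel would fill both components in a single pass; these two
    # elementwise ops stand in for it here.
    torch.mul(abs, torch.cos(angle), out=out[..., 0])
    torch.mul(abs, torch.sin(angle), out=out[..., 1])
    # Reinterpret the trailing dim as complex without copying.
    return torch.view_as_complex(out)

This avoids the extra copy that torch.complex(real, imag) performs when packing two separate real tensors into a freshly allocated complex one.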

Collaborator

Try the pointwise_dynamic function with output parameters: pass out0=<>, out1=<> to the function. Note that output allocation then has to be done manually (device, dtype, shape, and stride inference).
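
A rough sketch of what that could look like, modeled on how other multi-output ops in the repo use pointwise_dynamic (the decorator arguments and import path are assumptions):

import torch
import triton
import triton.language as tl
from flag_gems.utils import pointwise_dynamic

# promotion_methods here is an assumption based on how other two-output
# ops configure pointwise_dynamic.
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT"), (0, 1, "DEFAULT")])
@triton.jit
def polar_kernel(abs, angle):
    real = abs * tl.cos(angle)
    imag = abs * tl.sin(angle)
    return real, imag

def polar(abs, angle):
    # Manual output allocation: device/dtype/shape are inferred from the
    # inputs; the interleaved layout lets view_as_complex reuse the buffer.
    out = torch.empty(*abs.shape, 2, device=abs.device, dtype=abs.dtype)
    polar_kernel(abs, angle, out0=out[..., 0], out1=out[..., 1])
    return torch.view_as_complex(out)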


@pytest.mark.polar
@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
Collaborator

Some vendors don't support torch.float64 on their chips, and we don't require them to pass torch.float64 tests. It's okay to check torch.float32 only.
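
With that change, the parametrization would drop float64. The test body below is a hypothetical sketch following the repo's usual accuracy-test pattern (POINTWISE_SHAPES and gems_assert_close come from the test utilities):

import pytest
import torch
import flag_gems

@pytest.mark.polar
@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
@pytest.mark.parametrize("dtype", [torch.float32])  # float64 dropped for vendor portability
def test_accuracy_polar(shape, dtype):
    abs = torch.rand(shape, dtype=dtype, device="cuda")
    angle = torch.rand(shape, dtype=dtype, device="cuda") * 2 * torch.pi
    ref_out = torch.polar(abs, angle)
    with flag_gems.use_gems():
        res_out = torch.polar(abs, angle)
    gems_assert_close(res_out, ref_out, dtype)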

@@ -7,6 +7,7 @@
     torch.float16: 1e-3,
     torch.float32: 1.3e-6,
     torch.bfloat16: 0.016,
+    torch.float64: 1e-12,
Collaborator

We set the tolerance values for different data types according to pytorch/torch/testing/_comparison.py. It's recommended to choose 1e-7 for torch.float64.
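
Following that suggestion, the entry would change as below; the surrounding dict is reproduced from the diff above, and its name in the test utilities is assumed:

RESOLUTION = {
    torch.float16: 1e-3,
    torch.float32: 1.3e-6,
    torch.bfloat16: 0.016,
    torch.float64: 1e-7,  # matches the float64 rtol in torch/testing/_comparison.py
}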

@StrongSpoon StrongSpoon self-assigned this Apr 21, 2025
@AdvancedCompiler AdvancedCompiler changed the title Add polar [AdvancedCompiler] Add polar Apr 23, 2025
Collaborator

@iclementine iclementine left a comment

LGTM

Collaborator

@StrongSpoon StrongSpoon left a comment

lgtm

@StrongSpoon StrongSpoon merged commit bd692eb into FlagOpen:master Apr 28, 2025
12 of 13 checks passed