[AdvancedCompiler] Add polar#543
Conversation
src/flag_gems/ops/polar.py
Outdated
| def polar(abs, angle): | ||
| logging.debug("GEMS POLAR") | ||
| real, imag = polar_kernel(abs, angle) | ||
| return torch.complex(real, imag) |
There was a problem hiding this comment.
torch.complex calls complex_kernel_cuda. as the result, your implementation calls two kernels and performs worse than baseline. please try torch.view_as_complex and consider if there is a way to compute polar and create a complex tensor in one kernel.
There was a problem hiding this comment.
Try pointwise_dynamic function with output parameters. Pass out0=<>, out1=<> to the function as output parameters. But output allocation has to be done manually(device, dtype, shape, stride inference)
tests/test_binary_pointwise_ops.py
Outdated
|
|
||
| @pytest.mark.polar | ||
| @pytest.mark.parametrize("shape", POINTWISE_SHAPES) | ||
| @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) |
There was a problem hiding this comment.
some vendors don't support torch.float64 on their chips. and we don't require them to run tests of torch.float64 successfully. it's okay to check torch.float32 only.
src/flag_gems/testing/__init__.py
Outdated
| torch.float16: 1e-3, | ||
| torch.float32: 1.3e-6, | ||
| torch.bfloat16: 0.016, | ||
| torch.float64: 1e-12, |
There was a problem hiding this comment.
we set the tolerance values of different data types according to pytorch/torch/testing/_comparison.py. it's recommended to choose 1e-7 for torch.float64.
* wei wang add polar * remove result.json * de flag_gems.ops.polar * fix * 421fix * add to_reference * opt --------- Co-authored-by: zwhangang123 <[email protected]>
PR Category
Operator
Type of Change
New Feature
Description
Implement the function of polar op.
Issue
Progress
Performance
update the performance of modified implementation
