Add linspace op #478
Conversation
done
src/flag_gems/ops/linspace.py
    dtype=None,
    layout=torch.strided,
    device=None,
    requires_grad=False,
The definition of linspace in ATen differs from the torch Python interface; it does not include out or requires_grad. I suggest verifying which parameters it may actually receive.
I referred to the function definition in /root/miniconda3/envs/gems-ops/lib/python3.10/site-packages/torch/_C/_VariableFunctions.pyi, which is different from the function definition in aten/src/ATen/native/native_functions.yaml. Should we keep it consistent with the one in native_functions.yaml?
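For reference, the ATen-side declaration in native_functions.yaml looks roughly like the following (paraphrased from upstream PyTorch; the exact optional/default spelling may vary by version). Note there is no requires_grad, and out only appears on the separate .out variant:

```yaml
# Paraphrased from aten/src/ATen/native/native_functions.yaml
# (assumed recent PyTorch; check your checkout for the exact form):
- func: linspace(Scalar start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

- func: linspace.out(Scalar start, Scalar end, int steps, *, Tensor(a!) out) -> Tensor(a!)
```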
Done
    return torch.fill(out, start)
else:
    if isinstance(start, torch.Tensor):
        start = start.item()
Calling .item() on the tensor and then passing the scalar to the kernel function may cost unnecessary time, since .item() forces a device-to-host synchronization.
Do we need to write 4 kernels for the case where start and end are tensors?
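One way to avoid both .item() and a 4-way kernel split is to have every case evaluate the same per-element affine formula; scalar inputs can be promoted to 0-d tensors on the host and loaded inside the kernel by pointer. A minimal pure-Python sketch of just the formula (a hypothetical reference helper, not code from this PR — the real implementation would evaluate this expression per element inside the Triton kernel):

```python
def linspace_values(start, end, steps):
    """Reference formula: out[i] = start + i * (end - start) / (steps - 1).

    A single kernel can evaluate this expression regardless of whether
    start/end were given as Python scalars or 0-d tensors: promote both
    to 0-d tensors on the host and load them on-device, avoiding the
    host sync that .item() forces. (Hypothetical illustration helper.)
    """
    if steps == 1:
        # torch.linspace returns just the start point when steps == 1
        return [float(start)]
    step = (end - start) / (steps - 1)
    return [start + i * step for i in range(steps)]
```

With this shape the four (scalar/tensor × start/end) cases differ only in how the two endpoints are loaded, not in the kernel body.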
tests/test_special_ops.py
@@ -533,6 +533,45 @@ def test_arange(start, step, end, dtype, device, pin_memory):
    gems_assert_equal(res_out, ref_out)


@pytest.mark.linspace
@pytest.mark.parametrize("start", [0, 2, 4])
@pytest.mark.parametrize("end", [1024, 2048, 4096])
I suggest also testing cases where (end - start) < steps, so the spacing between adjacent elements is fractional.
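Such cases could be added with an extra parametrization along these lines (hypothetical values and test name, assuming the test also takes a steps parameter; not code from this PR):

```python
import pytest


# Hypothetical extra cases where (end - start) < steps, so adjacent
# elements differ by a fractional step and float accumulation error
# becomes visible.
@pytest.mark.parametrize("start, end, steps", [(0, 1, 1024), (2, 4, 4096)])
def test_linspace_fractional_step(start, end, steps):
    step = (end - start) / (steps - 1)
    # the interesting regime: sub-unit spacing between elements
    assert 0 < step < 1
```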
done
LGTM
PR Category: Operator
Type of Change: New Feature
Description: Add linspace op
Issue
Progress
Performance