[AdvancedCompiler]Contiguous(cpp wrapper) #750
Conversation
lib/contiguous.cpp
Outdated
at::Tensor input_sizes = torch::tensor(input.sizes(), options);
at::Tensor input_strides = torch::tensor(input.strides(), options);
at::Tensor out_strides = torch::tensor(out.strides(), options);
Interesting, torch has a factory function for creating tensors like this.
lib/contiguous.cpp
Outdated
namespace flag_gems {
using namespace triton_jit;

at::Tensor contiguous(at::Tensor &input, at::MemoryFormat memory_format) {
Use the signature
at::Tensor contiguous(const at::Tensor & self, at::MemoryFormat memory_format=c10::MemoryFormat::Contiguous);
since this is the signature translated from the schema
contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a)
Check this file torch/include/ATen/ops/contiguous_native.h
src/flag_gems/csrc/cstub.cpp
Outdated
@@ -20,6 +20,7 @@ TORCH_LIBRARY(flag_gems, m) { | |||
m.def( | |||
"rotary_embedding(Tensor q, Tensor k, Tensor cos, Tensor sin, Tensor? position_ids=None, " | |||
"bool rotary_interleaved=False) -> (Tensor, Tensor)"); // q and k may be view to other size | |||
m.def("contiguous(Tensor self, *, MemoryFormat memory_format=contiguous_format) -> Tensor"); |
Use the schema from native_functions.yaml if we are implementing the operator with the same semantics:
contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a)
LGTM
src/flag_gems/csrc/cstub.cpp
Outdated
@@ -20,6 +20,7 @@ TORCH_LIBRARY(flag_gems, m) {
  m.def(
      "rotary_embedding(Tensor q, Tensor k, Tensor cos, Tensor sin, Tensor? position_ids=None, "
      "bool rotary_interleaved=False) -> (Tensor, Tensor)");  // q and k may be view to other size
  m.def("contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor");
Please use the same schema as the one in native_functions.yaml:
contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a)
(a) here means aliasing.
namespace flag_gems {
using namespace triton_jit;

at::Tensor contiguous(const at::Tensor &self, at::MemoryFormat memory_format) {
Since our current implementation does not support the ChannelsLast and ChannelsLast3d formats, it is necessary to raise an error when an unsupported memory format is passed in. We can leave that for future work.
PR Category: Operator
Type of Change: New Feature
Description: CPP wrapper packaging for the Contiguous op.
Issue:
Progress:
Performance: