[Relax] Add gather_elements and gather_nd operators (#17523)
Add gather_elements and gather_nd operators to Relax, together with the
corresponding ONNX frontend converters.
Hzfengsy authored Nov 13, 2024
1 parent b96ee76 commit 59a2256
Showing 12 changed files with 514 additions and 5 deletions.
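With this commit in place, both operators are reachable from TVMScript and the op namespace. A minimal end-to-end sketch (assuming a TVM build with LLVM enabled; the module below is illustrative, not part of the commit):

```python
import numpy as np
import tvm
from tvm import relax
from tvm.script import ir_module
from tvm.script import relax as R

@ir_module
class GatherExample:
    @R.function
    def main(
        data: R.Tensor((2, 2), "float32"), indices: R.Tensor((2, 2), "int64")
    ) -> R.Tensor((2, 2), "float32"):
        out = R.gather_elements(data, indices, axis=1)
        return out

# Lower the new op to TOPI, build, and run on CPU.
mod = relax.transform.LegalizeOps()(GatherExample)
ex = relax.build(mod, target="llvm")
vm = relax.VirtualMachine(ex, tvm.cpu())
out = vm["main"](
    tvm.nd.array(np.array([[1, 2], [3, 4]], dtype="float32")),
    tvm.nd.array(np.array([[0, 0], [1, 0]], dtype="int64")),
)
print(out.numpy())  # [[1. 1.], [4. 3.]], matching the docstring example below
```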
17 changes: 17 additions & 0 deletions include/tvm/relax/attrs/manipulate.h
@@ -152,6 +152,23 @@ struct FlipAttrs : public tvm::AttrsNode<FlipAttrs> {
}
}; // struct FlipAttrs

/*! \brief Attributes used in gather_elements operators */
struct GatherElementsAttrs : public tvm::AttrsNode<GatherElementsAttrs> {
Integer axis;

TVM_DECLARE_ATTRS(GatherElementsAttrs, "relax.attrs.GatherElementsAttrs") {
TVM_ATTR_FIELD(axis).set_default(0).describe("The axis along which to index.");
}
}; // struct GatherElementsAttrs

/*! \brief Attributes used in gather_nd operators */
struct GatherNDAttrs : public tvm::AttrsNode<GatherNDAttrs> {
Integer batch_dims;
TVM_DECLARE_ATTRS(GatherNDAttrs, "relax.attrs.GatherNDAttrs") {
TVM_ATTR_FIELD(batch_dims).set_default(Integer(0)).describe("The number of batch dims.");
}
}; // struct GatherNDAttrs

/*! \brief Attributes used in scatter_elements operators */
struct ScatterElementsAttrs : public tvm::AttrsNode<ScatterElementsAttrs> {
Integer axis;
Expand Down
22 changes: 20 additions & 2 deletions python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -781,6 +781,24 @@ def _impl_v13(cls, bb, inputs, attr, params):
return relax.op.take(data, indices, axis)


class GatherElements(OnnxOpConverter):
"""Convert an onnx GatherElements node into an equivalent Relax expression."""

@classmethod
def _impl_v13(cls, bb, inputs, attr, params):
axis = attr.get("axis", 0)
return relax.op.gather_elements(inputs[0], inputs[1], axis)


class GatherND(OnnxOpConverter):
"""Convert an onnx GatherND node into an equivalent Relax expression."""

@classmethod
def _impl_v13(cls, bb, inputs, attr, params):
batch_dims = attr.get("batch_dims", 0)
return relax.op.gather_nd(inputs[0], inputs[1], batch_dims)
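Both converters are straight attribute-to-argument translations. A sketch of a round trip through the frontend (assuming the onnx package is installed; graph names and shapes are illustrative):

```python
import onnx
from onnx import TensorProto, helper
from tvm.relax.frontend.onnx import from_onnx

# One-node graph: out = GatherND(data, indices) with batch_dims = 0.
node = helper.make_node("GatherND", ["data", "indices"], ["out"], batch_dims=0)
graph = helper.make_graph(
    [node],
    "gather_nd_example",
    inputs=[
        helper.make_tensor_value_info("data", TensorProto.FLOAT, [2, 2]),
        helper.make_tensor_value_info("indices", TensorProto.INT64, [2, 2]),
    ],
    outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, [2])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
mod = from_onnx(model)
mod.show()  # main now contains a call to relax.op.gather_nd
```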


class Scatter(OnnxOpConverter):
"""Convert an onnx Scatter node into an equivalent Relax expression."""

@@ -3116,8 +3134,8 @@ def _get_convert_map():
"Squeeze": Squeeze,
"Constant": Constant,
"Gather": Gather,
# "GatherElements": GatherElements,
# "GatherND": GatherND,
"GatherElements": GatherElements,
"GatherND": GatherND,
"Scatter": Scatter,
"ScatterElements": ScatterElements,
"ScatterND": ScatterND,
2 changes: 2 additions & 0 deletions python/tvm/relax/op/__init__.py
@@ -92,6 +92,8 @@
expand_dims,
flatten,
flip,
gather_elements,
gather_nd,
layout_transform,
one_hot,
permute_dims,
73 changes: 73 additions & 0 deletions python/tvm/relax/op/manipulate.py
@@ -435,6 +435,79 @@ def flip(data, axis):
return _ffi_api.flip(data, axis) # type: ignore


def gather_elements(data: Expr, indices: Expr, axis: int = 0) -> Expr:
"""Gather elements from data according to indices along the specified axis.
Parameters
----------
data : relax.Expr
The input data to the operator.
indices : relax.Expr
The indices tensor, must have integer type.
axis : int
The axis along which to index. Default is 0.
Returns
-------
ret : relax.Expr
The computed result.
Examples
--------
.. code-block:: python
data = [[1, 2], [3, 4]]
indices = [[0, 0], [1, 0]]
axis = 1
output = [[1, 1], [4, 3]]
data = [[1, 2, 3], [4, 5, 6]]
indices = [[1, 1, 1]]
axis = 0
output = [[4, 5, 6]]
"""
return _ffi_api.gather_elements(data, indices, axis) # type: ignore
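The semantics coincide with NumPy's take_along_axis, which gives a quick way to cross-check the docstring examples (illustrative only, not part of the commit):

```python
import numpy as np

data = np.array([[1, 2], [3, 4]])
indices = np.array([[0, 0], [1, 0]])
print(np.take_along_axis(data, indices, axis=1))  # [[1 1], [4 3]]
```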


def gather_nd(data: Expr, indices: Expr, batch_dims: int = 0) -> Expr:
"""Update data at positions defined by indices with values in updates.
Parameters
----------
data : relax.Expr
The input data to the operator.
indices : relax.Expr
The indices tensor, must have integer type.
batch_dims : int
The number of batch dimensions. Default is 0.
Returns
-------
ret : relax.Expr
The computed result.
Examples
--------
.. code-block:: python
batch_dims = 0
data = [[0,1],[2,3]] # data_shape = [2, 2]
indices = [[0,0],[1,1]] # indices_shape = [2, 2]
output = [0,3] # output_shape = [2]
batch_dims = 1
data = [[[0,1],[2,3]],[[4,5],[6,7]]] # data_shape = [2, 2, 2]
indices = [[1],[0]] # indices_shape = [2, 1]
output = [[2,3],[4,5]] # output_shape = [2, 2]
"""
return _ffi_api.gather_nd(data, indices, batch_dims) # type: ignore
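For reference, the semantics above can be written as a plain NumPy sketch (illustrative only; assumes integer indices and 0 <= batch_dims < indices.ndim):

```python
import numpy as np

def gather_nd_ref(data, indices, batch_dims=0):
    # The last axis of `indices` supplies a coordinate into `data` after the
    # shared batch dimensions; all leading index positions are kept.
    out_shape = indices.shape[:-1] + data.shape[batch_dims + indices.shape[-1]:]
    out = np.empty(out_shape, dtype=data.dtype)
    for pos in np.ndindex(*indices.shape[:-1]):
        coord = pos[:batch_dims] + tuple(indices[pos])
        out[pos] = data[coord]
    return out

# Reproduces the batch_dims = 1 example above.
data = np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])
indices = np.array([[1], [0]])
print(gather_nd_ref(data, indices, batch_dims=1))  # [[2 3], [4 5]]
```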


def scatter_elements(
data: Expr, indices: Expr, updates: Expr, axis: int = 0, reduction: str = "update"
):
16 changes: 16 additions & 0 deletions python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -156,6 +156,22 @@ def _flip(bb: BlockBuilder, call: Call) -> Expr:
return bb.call_te(topi.flip, call.args[0], int(call.attrs.axis))


@register_legalize("relax.gather_elements")
def _gather_elements(bb: BlockBuilder, call: Call) -> Expr:
return bb.call_te(topi.gather, call.args[0], int(call.attrs.axis), call.args[1])


@register_legalize("relax.gather_nd")
def _gather_nd(bb: BlockBuilder, call: Call) -> Expr:
def te_gather_nd(data, indices, batch_dims):
# Relax (like ONNX) puts the coordinate tuple on the last axis of `indices`,
# while topi.gather_nd expects it on the first axis, so move the last axis
# to the front before delegating.
indices_ndim = len(indices.shape)
axes = [indices_ndim - 1] + list(range(indices_ndim - 1))
indices = topi.transpose(indices, axes)
return topi.gather_nd(data, indices, batch_dims)

return bb.call_te(te_gather_nd, call.args[0], call.args[1], int(call.attrs.batch_dims))
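For intuition, the permutation built in the wrapper moves the last axis of indices to the front, which is the layout topi.gather_nd expects. A NumPy sketch with a hypothetical rank-3 indices tensor:

```python
import numpy as np

idx = np.zeros((2, 3, 4), dtype="int64")  # ONNX layout: coordinates on the last axis
axes = [idx.ndim - 1] + list(range(idx.ndim - 1))  # [2, 0, 1]
print(np.transpose(idx, axes).shape)  # (4, 2, 3): coordinates now on the first axis
```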


@register_legalize("relax.scatter_elements")
def _scatter_elements(bb: BlockBuilder, call: Call) -> Expr:
return bb.call_te(
4 changes: 4 additions & 0 deletions python/tvm/script/ir_builder/relax/ir.py
@@ -94,6 +94,8 @@
floor_mod,
full,
full_like,
gather_elements,
gather_nd,
grad,
greater,
greater_equal,
@@ -772,6 +774,8 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"func_ret_struct_info",
"func_ret_value",
"function",
"gather_elements",
"gather_nd",
"gpu",
"grad",
"greater",
4 changes: 2 additions & 2 deletions python/tvm/topi/transform.py
@@ -528,7 +528,7 @@ def gather(data, axis, indices):
return cpp.gather(data, axis, indices)


def gather_nd(a, indices):
def gather_nd(a, indices, batch_dims=0):
"""Gather elements from a n-dimension array..
Parameters
@@ -543,7 +543,7 @@ def gather_nd(a, indices):
-------
ret : tvm.te.Tensor
"""
return cpp.gather_nd(a, indices)
return cpp.gather_nd(a, indices, batch_dims)
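At the TE level the extended signature can be exercised directly; a sketch (placeholder shapes are illustrative, and note that topi.gather_nd takes the coordinate tuple along the first axis of indices):

```python
import tvm
from tvm import te, topi

data = te.placeholder((2, 2), name="data", dtype="float32")
indices = te.placeholder((2, 2), name="indices", dtype="int64")  # 2-tuples along axis 0
out = topi.gather_nd(data, indices)  # batch_dims defaults to 0, preserving old callers
mod = tvm.build(te.create_prim_func([data, indices, out]), target="llvm")
```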


def matmul(a, b, transp_a=False, transp_b=False):
163 changes: 163 additions & 0 deletions src/relax/op/tensor/manipulate.cc
@@ -1418,6 +1418,169 @@ TVM_REGISTER_OP("relax.flip")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoFlip)
.set_attr<Bool>("FPurity", Bool(true));

/* relax.gather_elements */
TVM_REGISTER_NODE_TYPE(GatherElementsAttrs);

Expr gather_elements(Expr data, Expr indices, int axis) {
auto attrs = make_object<GatherElementsAttrs>();
attrs->axis = Integer(axis);
static const Op& op = Op::Get("relax.gather_elements");
return Call(op, {data, indices}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relax.op.gather_elements").set_body_typed(gather_elements);

StructInfo InferStructInfoGatherElements(const Call& call, const BlockBuilder& ctx) {
const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
const auto* indices_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[1]);
const auto* attrs = call->attrs.as<GatherElementsAttrs>();

if (data_sinfo == nullptr) {
ctx->ReportFatal(
Diagnostic::Error(call)
<< "GatherElements requires the input data to be a Tensor. However, the given one is "
<< call->args[0]->struct_info_->GetTypeKey());
}
if (indices_sinfo == nullptr) {
ctx->ReportFatal(
Diagnostic::Error(call)
<< "GatherElements requires the input indices to be a Tensor. However, the given one is "
<< call->args[1]->struct_info_->GetTypeKey());
}

if (!indices_sinfo->IsUnknownDtype() && !indices_sinfo->dtype.is_int()) {
ctx->ReportFatal(
Diagnostic::Error(call)
<< "GatherElements requires the input indices to have int64 dtype. However, the "
<< "given indices dtype is " << indices_sinfo->dtype);
}

if (data_sinfo->IsUnknownNdim() || indices_sinfo->IsUnknownNdim()) {
return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice);
}

int axis = attrs->axis.IntValue();
if (axis < -data_sinfo->ndim || axis >= data_sinfo->ndim) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "GatherElements requires axis to be within the input dimension range ["
<< -data_sinfo->ndim << ", " << data_sinfo->ndim - 1 << "]. However, the "
<< "given axis is " << axis);
}

if (data_sinfo->ndim != indices_sinfo->ndim) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "GatherElements requires data and indices to have the same rank. However, "
<< "data rank is " << data_sinfo->ndim << " while indices rank is "
<< indices_sinfo->ndim);
}
if (indices_sinfo->shape.defined()) {
return TensorStructInfo(indices_sinfo->shape.value(), data_sinfo->dtype, data_sinfo->vdevice);
}
return TensorStructInfo(data_sinfo->dtype, indices_sinfo->ndim, data_sinfo->vdevice);
}

TVM_REGISTER_OP("relax.gather_elements")
.set_attrs_type<GatherElementsAttrs>()
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("indices", "Tensor", "The indices tensor.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoGatherElements)
.set_attr<Bool>("FPurity", Bool(true));

/* relax.gather_nd */
TVM_REGISTER_NODE_TYPE(GatherNDAttrs);

Expr gather_nd(Expr data, Expr indices, int batch_dims) {
auto attrs = make_object<GatherNDAttrs>();
attrs->batch_dims = Integer(batch_dims);
static const Op& op = Op::Get("relax.gather_nd");
return Call(op, {data, indices}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relax.op.gather_nd").set_body_typed(gather_nd);

StructInfo InferStructInfoGatherND(const Call& call, const BlockBuilder& ctx) {
const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
const auto* indices_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[1]);
const auto* attrs = call->attrs.as<GatherNDAttrs>();

if (data_sinfo == nullptr) {
ctx->ReportFatal(
Diagnostic::Error(call)
<< "GatherND requires the input data to be a Tensor. However, the given one is "
<< call->args[0]->struct_info_->GetTypeKey());
}
if (indices_sinfo == nullptr) {
ctx->ReportFatal(
Diagnostic::Error(call)
<< "GatherND requires the input indices to be a Tensor. However, the given one is "
<< call->args[1]->struct_info_->GetTypeKey());
}
ICHECK_GE(attrs->batch_dims.IntValue(), 0);
int batch_dims = attrs->batch_dims.IntValue();
int input_dims = data_sinfo->ndim;
if (!indices_sinfo->IsUnknownDtype() && indices_sinfo->dtype != DataType::Int(64)) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "GatherND requires the input indices to have int64 dtype. However, the "
<< "given indices dtype is " << indices_sinfo->dtype);
}

if (data_sinfo->IsUnknownNdim() || indices_sinfo->IsUnknownNdim()) {
return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice);
}

if (batch_dims < 0 || batch_dims > data_sinfo->ndim) {
ctx->ReportFatal(
Diagnostic::Error(call)
<< "GatherND batch_dims must be in range [0, data.ndim]. However, got batch_dims="
<< batch_dims << ", data.ndim=" << input_dims);
}

if (batch_dims > indices_sinfo->ndim - 1) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "GatherND batch_dims cannot exceed indices.ndim-1. However, got batch_dims="
<< batch_dims << ", indices.ndim=" << indices_sinfo->ndim);
}

// Check if indices shape is known
const auto* indices_shape = indices_sinfo->shape.as<ShapeExprNode>();
const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
if (!indices_shape || !indices_shape->values.back()->IsInstance<IntImmNode>()) {
return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice);
}
int l = indices_shape->values.back().as<IntImmNode>()->value;
int output_ndim = indices_sinfo->ndim + input_dims - l - 1 - batch_dims;
if (!data_shape) {
return TensorStructInfo(data_sinfo->dtype, output_ndim, data_sinfo->vdevice);
}

// In this condition, all input shapes are known
Array<PrimExpr> out_shape;
if (l > input_dims - batch_dims) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "GatherND requires the last dimension of indices to be less than or "
"equal to the rank of data minus batch_dims. However, the given shapes are "
<< "indices: " << ShapeExpr(indices_shape->values) << ", data: "
<< ShapeExpr(data_shape->values) << ", with batch_dims=" << batch_dims);
}
for (int i = 0; i < indices_sinfo->ndim - 1; ++i) {
out_shape.push_back(indices_shape->values[i]);
}
for (int i = batch_dims + l; i < input_dims; ++i) {
out_shape.push_back(data_shape->values[i]);
}
ICHECK_EQ(out_shape.size(), output_ndim);
return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice);
}

TVM_REGISTER_OP("relax.gather_nd")
.set_attrs_type<GatherNDAttrs>()
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("indices", "Tensor", "The indices tensor.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoGatherND)
.set_attr<Bool>("FPurity", Bool(true));

/* relax.scatter_elements */
TVM_REGISTER_NODE_TYPE(ScatterElementsAttrs);
