Skip to content

Commit 1f437a9

Browse files
Groverkssrsuderman
andauthored
[tmtensor] Add support for i64 index type for tm_tensor.scatter (#4221)
using i32 indexing for large memory blocks is wrong --------- Co-authored-by: Rob Suderman <[email protected]>
1 parent 716303a commit 1f437a9

File tree

2 files changed

+3
-8
lines changed

2 files changed

+3
-8
lines changed

lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -841,7 +841,6 @@ class ConvertAtenIndexPutHackedTwinOp
841841
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
842842
return failure();
843843
Location loc = op.getLoc();
844-
MLIRContext *context = op->getContext();
845844
Value input = op.getSelf();
846845
Value values = op.getValues();
847846
auto inputType = cast<ValueTensorType>(input.getType());
@@ -962,11 +961,6 @@ class ConvertAtenIndexPutHackedTwinOp
962961
values =
963962
rewriter.create<AtenViewOp>(loc, valuesType, values, valuesDimsList);
964963

965-
// `TMTensor::ScatterOp` expects indices of element type i32.
966-
indices = convertTensorToDtype(
967-
rewriter, loc, indices,
968-
mlir::IntegerType::get(context, 32, mlir::IntegerType::Signed));
969-
970964
input = typeConverter->materializeTargetConversion(
971965
rewriter, loc, typeConverter->convertType(input.getType()), input);
972966
values = typeConverter->materializeTargetConversion(

lib/Dialect/TMTensor/IR/TMTensorOps.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -623,8 +623,9 @@ LogicalResult ScatterOp::verify() {
623623

624624
auto indicesType = getIndicesType();
625625
if (indicesType.getRank() != 2 ||
626-
!indicesType.getElementType().isInteger(32)) {
627-
return emitOpError("expected indices to be of rank 2 of i32 element type");
626+
!isa<IntegerType>(indicesType.getElementType())) {
627+
return emitOpError(
628+
"expected indices to be of rank 2 of integer element type");
628629
}
629630
auto indexDepth = getIndexDepth();
630631
if (ShapedType::isDynamic(indexDepth)) {

0 commit comments

Comments
 (0)