From 0f57b0f4f3837e0fc5c485a64b42e6ddf7c7b703 Mon Sep 17 00:00:00 2001 From: Minkyu Kim Date: Thu, 27 Feb 2025 16:19:18 +0900 Subject: [PATCH] fix: allow slice_scatter decomposition with SymInt parameters for the case when returning the source tensor --- py/torch_tensorrt/dynamo/lowering/_decompositions.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py index 1c182f92aa..3925aad5d2 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py +++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py @@ -203,11 +203,6 @@ def slice_scatter_decomposition( if step is None: step = 1 - # Ensure start, end, and step are all integers - assert isinstance(start, int), "start must be an integer" - assert isinstance(end, int), "end must be an integer" - assert isinstance(step, int), "step must be an integer" - src_dim = src_tensor.shape # step == 0 is not a valid torch case # also src_dim should be equal to slice dimension @@ -215,6 +210,11 @@ def slice_scatter_decomposition( if start == 0 and end == dim_size and step == 1: return src_tensor + # Ensure start, end, and step are all integers + assert isinstance(start, int), "start must be an integer" + assert isinstance(end, int), "end must be an integer" + assert isinstance(step, int), "step must be an integer" + cat_tensors = [] index_tensor_shape = [] for i, src_each_dim in enumerate(list(src_dim)):