diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py index 1c182f92aa..3925aad5d2 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py +++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py @@ -203,11 +203,6 @@ def slice_scatter_decomposition( if step is None: step = 1 - # Ensure start, end, and step are all integers - assert isinstance(start, int), "start must be an integer" - assert isinstance(end, int), "end must be an integer" - assert isinstance(step, int), "step must be an integer" - src_dim = src_tensor.shape # step == 0 is not a valid torch case # also src_dim should be equal to slice dimension @@ -215,6 +210,11 @@ def slice_scatter_decomposition( if start == 0 and end == dim_size and step == 1: return src_tensor + # Ensure start, end, and step are all integers + assert isinstance(start, int), "start must be an integer" + assert isinstance(end, int), "end must be an integer" + assert isinstance(step, int), "step must be an integer" + cat_tensors = [] index_tensor_shape = [] for i, src_each_dim in enumerate(list(src_dim)):