Skip to content

Commit 2e302a9

Browse files
authoredMar 18, 2025··
Fix assertions in slice_scatter decomposition (#3420)
1 parent 7dbd4cb commit 2e302a9

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed
 

‎py/torch_tensorrt/dynamo/lowering/_decompositions.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -203,18 +203,18 @@ def slice_scatter_decomposition(
203203
if step is None:
204204
step = 1
205205

206-
# Ensure start, end, and step are all integers
207-
assert isinstance(start, int), "start must be an integer"
208-
assert isinstance(end, int), "end must be an integer"
209-
assert isinstance(step, int), "step must be an integer"
210-
211206
src_dim = src_tensor.shape
212207
# step == 0 is not a valid torch case
213208
# also src_dim should be equal to slice dimension
214209

215210
if start == 0 and end == dim_size and step == 1:
216211
return src_tensor
217212

213+
# Ensure start, end, and step are all integers
214+
assert isinstance(start, int), "start must be an integer"
215+
assert isinstance(end, int), "end must be an integer"
216+
assert isinstance(step, int), "step must be an integer"
217+
218218
cat_tensors = []
219219
index_tensor_shape = []
220220
for i, src_each_dim in enumerate(list(src_dim)):

0 commit comments

Comments
 (0)
Please sign in to comment.