@@ -2372,8 +2372,8 @@ func.func @torch.aten.empty.memory_format$basic() -> !torch.vtensor<[3,4],si64>
2372
2372
// CHECK-LABEL: func.func @torch.aten.scatter.src$basic(
2373
2373
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[10,8,6],f32>,
2374
2374
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,4,3],si64>,
2375
- // CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[3 ,4,3],f32>) -> !torch.vtensor<[10,8,6],f32> {
2376
- // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[3 ,4,3],f32> -> tensor<3x4x3xf32 >
2375
+ // CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[2 ,4,3],f32>) -> !torch.vtensor<[10,8,6],f32> {
2376
+ // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[2 ,4,3],f32> -> tensor<2x4x3xf32 >
2377
2377
// CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,4,3],si64> -> tensor<2x4x3xi64>
2378
2378
// CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[10,8,6],f32> -> tensor<10x8x6xf32>
2379
2379
// CHECK: %[[VAL_6:.*]] = torch.constant.int 1
@@ -2383,8 +2383,8 @@ func.func @torch.aten.empty.memory_format$basic() -> !torch.vtensor<[3,4],si64>
2383
2383
// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{values = dense<{{\[\[}}{{\[\[}}0], [0], [0]], {{\[\[}}0], [0], [0]], {{\[\[}}0], [0], [0]], {{\[\[}}0], [0], [0]]], {{\[\[}}[1], [1], [1]], {{\[\[}}1], [1], [1]], {{\[\[}}1], [1], [1]], {{\[\[}}1], [1], [1]]]]> : tensor<2x4x3x1xi32>}> : () -> tensor<2x4x3x1xi32>
2384
2384
// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{values = dense<{{\[\[}}{{\[\[}}0], [1], [2]], {{\[\[}}0], [1], [2]], {{\[\[}}0], [1], [2]], {{\[\[}}0], [1], [2]]], {{\[\[}}[0], [1], [2]], {{\[\[}}0], [1], [2]], {{\[\[}}0], [1], [2]], {{\[\[}}0], [1], [2]]]]> : tensor<2x4x3x1xi32>}> : () -> tensor<2x4x3x1xi32>
2385
2385
// CHECK: %[[VAL_12:.*]] = tosa.concat %[[VAL_10]], %[[VAL_9]], %[[VAL_11]] {axis = 3 : i32} : (tensor<2x4x3x1xi32>, tensor<2x4x3x1xi32>, tensor<2x4x3x1xi32>) -> tensor<2x4x3x3xi32>
2386
- // CHECK: %[[VAL_13:.*]] = tosa.const_shape {values = dense<[1, 36 , 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
2387
- // CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_13]] : (tensor<3x4x3xf32 >, !tosa.shape<3>) -> tensor<1x36x1xf32 >
2386
+ // CHECK: %[[VAL_13:.*]] = tosa.const_shape {values = dense<[1, 24 , 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
2387
+ // CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_13]] : (tensor<2x4x3xf32 >, !tosa.shape<3>) -> tensor<1x24x1xf32 >
2388
2388
// CHECK: %[[VAL_15:.*]] = tosa.const_shape {values = dense<[1, 480, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
2389
2389
// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_5]], %[[VAL_15]] : (tensor<10x8x6xf32>, !tosa.shape<3>) -> tensor<1x480x1xf32>
2390
2390
// CHECK: %[[VAL_17:.*]] = tosa.const_shape {values = dense<[24, 3]> : tensor<2xindex>} : () -> !tosa.shape<2>
@@ -2397,15 +2397,15 @@ func.func @torch.aten.empty.memory_format$basic() -> !torch.vtensor<[3,4],si64>
2397
2397
// CHECK: %[[VAL_24:.*]] = tosa.reduce_sum %[[VAL_23]] {axis = 1 : i32} : (tensor<24x3xi32>) -> tensor<24x1xi32>
2398
2398
// CHECK: %[[VAL_25:.*]] = tosa.const_shape {values = dense<[1, 24]> : tensor<2xindex>} : () -> !tosa.shape<2>
2399
2399
// CHECK: %[[VAL_26:.*]] = tosa.reshape %[[VAL_24]], %[[VAL_25]] : (tensor<24x1xi32>, !tosa.shape<2>) -> tensor<1x24xi32>
2400
- // CHECK: %[[VAL_27:.*]] = tosa.scatter %[[VAL_16]], %[[VAL_26]], %[[VAL_14]] : (tensor<1x480x1xf32>, tensor<1x24xi32>, tensor<1x36x1xf32 >) -> tensor<1x480x1xf32>
2400
+ // CHECK: %[[VAL_27:.*]] = tosa.scatter %[[VAL_16]], %[[VAL_26]], %[[VAL_14]] : (tensor<1x480x1xf32>, tensor<1x24xi32>, tensor<1x24x1xf32 >) -> tensor<1x480x1xf32>
2401
2401
// CHECK: %[[VAL_28:.*]] = tosa.const_shape {values = dense<[10, 8, 6]> : tensor<3xindex>} : () -> !tosa.shape<3>
2402
2402
// CHECK: %[[VAL_29:.*]] = tosa.reshape %[[VAL_27]], %[[VAL_28]] : (tensor<1x480x1xf32>, !tosa.shape<3>) -> tensor<10x8x6xf32>
2403
2403
// CHECK: %[[VAL_30:.*]] = torch_c.from_builtin_tensor %[[VAL_29]] : tensor<10x8x6xf32> -> !torch.vtensor<[10,8,6],f32>
2404
2404
// CHECK: return %[[VAL_30]] : !torch.vtensor<[10,8,6],f32>
2405
2405
// CHECK: }
2406
- func.func @torch.aten.scatter.src$basic (%arg0: !torch.vtensor <[10 ,8 ,6 ],f32 >, %arg1: !torch.vtensor <[2 ,4 ,3 ],si64 >, %arg2: !torch.vtensor <[3 ,4 ,3 ],f32 >) -> !torch.vtensor <[10 ,8 ,6 ],f32 > {
2406
+ func.func @torch.aten.scatter.src$basic (%arg0: !torch.vtensor <[10 ,8 ,6 ],f32 >, %arg1: !torch.vtensor <[2 ,4 ,3 ],si64 >, %arg2: !torch.vtensor <[2 ,4 ,3 ],f32 >) -> !torch.vtensor <[10 ,8 ,6 ],f32 > {
2407
2407
%int1 = torch.constant.int 1
2408
- %0 = torch.aten.scatter.src %arg0 , %int1 , %arg1 , %arg2 : !torch.vtensor <[10 ,8 ,6 ],f32 >, !torch.int , !torch.vtensor <[2 ,4 ,3 ],si64 >, !torch.vtensor <[3 ,4 ,3 ],f32 > -> !torch.vtensor <[10 ,8 ,6 ],f32 >
2408
+ %0 = torch.aten.scatter.src %arg0 , %int1 , %arg1 , %arg2 : !torch.vtensor <[10 ,8 ,6 ],f32 >, !torch.int , !torch.vtensor <[2 ,4 ,3 ],si64 >, !torch.vtensor <[2 ,4 ,3 ],f32 > -> !torch.vtensor <[10 ,8 ,6 ],f32 >
2409
2409
return %0 : !torch.vtensor <[10 ,8 ,6 ],f32 >
2410
2410
}
2411
2411
0 commit comments