@@ -356,28 +356,28 @@ def test_cse_rematerialization(executor, device, _):
356
356
357
357
fw_trace = thunder .last_traces (compiled_func )[- 1 ]
358
358
fusion_bsyms = tuple (filter (lambda a : a .sym .is_fusion , fw_trace .bound_symbols ))
359
- assert len (fusion_bsyms ) == 11
359
+ assert len (fusion_bsyms ) == 9
360
360
# fusion groups 1 and 6 correspond with the apply_rotary_emb function
361
361
# Nvfuser with recomputation should use precomputed cos and sin values.
362
- assert len (fusion_bsyms [1 ].args ) == len (fusion_bsyms [6 ].args )
362
+ assert len (fusion_bsyms [1 ].args ) == len (fusion_bsyms [5 ].args )
363
363
364
364
# Below, we check that freqs_sin and freqs_cos are used
365
365
# in the same operation in both fusions.
366
366
(fusion1_freqs_sin_arg ,) = (a for a in fusion_bsyms [1 ].args if a .name == "freqs_sin" )
367
367
(fusion1_freqs_cos_arg ,) = (a for a in fusion_bsyms [1 ].args if a .name == "freqs_cos" )
368
- (fusion6_freqs_sin_arg ,) = (a for a in fusion_bsyms [6 ].args if a .name == "freqs_sin" )
369
- (fusion6_freqs_cos_arg ,) = (a for a in fusion_bsyms [6 ].args if a .name == "freqs_cos" )
368
+ (fusion5_freqs_sin_arg ,) = (a for a in fusion_bsyms [5 ].args if a .name == "freqs_sin" )
369
+ (fusion5_freqs_cos_arg ,) = (a for a in fusion_bsyms [5 ].args if a .name == "freqs_cos" )
370
370
371
371
(fusion1_freqs_sin_user ,) = (s for s in fusion_bsyms [1 ].subsymbols if s .args [0 ] is fusion1_freqs_sin_arg )
372
- (fusion6_freqs_sin_user ,) = (s for s in fusion_bsyms [6 ].subsymbols if s .args [0 ] is fusion6_freqs_sin_arg )
372
+ (fusion6_freqs_sin_user ,) = (s for s in fusion_bsyms [5 ].subsymbols if s .args [0 ] is fusion5_freqs_sin_arg )
373
373
374
374
assert fusion1_freqs_sin_user .sym is fusion6_freqs_sin_user .sym
375
375
assert fusion1_freqs_sin_user .args [1 :] == fusion6_freqs_sin_user .args [1 :]
376
376
(fusion1_freqs_cos_user ,) = (s for s in fusion_bsyms [1 ].subsymbols if s .args [0 ] is fusion1_freqs_cos_arg )
377
- (fusion6_freqs_cos_user ,) = (s for s in fusion_bsyms [6 ].subsymbols if s .args [0 ] is fusion1_freqs_cos_arg )
377
+ (fusion5_freqs_cos_user ,) = (s for s in fusion_bsyms [5 ].subsymbols if s .args [0 ] is fusion5_freqs_cos_arg )
378
378
379
- assert fusion1_freqs_cos_user .sym is fusion6_freqs_cos_user .sym
380
- assert fusion1_freqs_cos_user .args [1 :] == fusion6_freqs_cos_user .args [1 :]
379
+ assert fusion1_freqs_cos_user .sym is fusion5_freqs_cos_user .sym
380
+ assert fusion1_freqs_cos_user .args [1 :] == fusion5_freqs_cos_user .args [1 :]
381
381
382
382
383
383
# Tests that two separated nvFuser regions can be merged when they don't depend
@@ -1117,3 +1117,43 @@ def fn(a, b):
1117
1117
# verify the functionality of the above flags.
1118
1118
with pytest .raises (RuntimeError , match = "Can not find a scheduler to schedule fusion segment" ):
1119
1119
out = compiled_func (* inps )
1120
+
1121
+
1122
+ @instantiate (
1123
+ dtypes = (thunder .float32 ,),
1124
+ devicetypes = (devices .DeviceType .CUDA ,),
1125
+ executors = (nvFuserExecutor ,),
1126
+ )
1127
+ def test_no_shape_only_fusion_region (executor , device : str , thunder_dtype : dtypes .dtype ):
1128
+ x = make_tensor (2 , 2 , 2 , device = device , dtype = ltorch .to_torch_dtype (thunder_dtype ))
1129
+
1130
+ def fn (x ):
1131
+ return x .view (4 , - 1 ).transpose (0 , 1 )
1132
+
1133
+ jfn = thunder .jit (fn )
1134
+
1135
+ expected = fn (x )
1136
+ actual = jfn (x )
1137
+
1138
+ torch .testing .assert_close (actual , expected )
1139
+
1140
+ fwd_trace = thunder .last_traces (jfn )[- 1 ]
1141
+
1142
+ # Make sure there are no fusion symbols.
1143
+ assert all (not bsym .sym .is_fusion for bsym in fwd_trace .bound_symbols )
1144
+
1145
+ # Verify that we create fusion even if we have a single compute op.
1146
+ def fn (x ):
1147
+ # There is a `sin` which is not a shape op.
1148
+ return x .view (4 , - 1 ).transpose (0 , 1 ).sin ().transpose (0 , 1 ).view (2 , 2 , 2 )
1149
+
1150
+ jfn = thunder .jit (fn )
1151
+ expected = fn (x )
1152
+ actual = jfn (x )
1153
+
1154
+ torch .testing .assert_close (actual , expected )
1155
+
1156
+ fwd_trace = thunder .last_traces (jfn )[- 1 ]
1157
+
1158
+ # Make sure there is a fusion symbol.
1159
+ assert any (bsym .sym .is_fusion for bsym in fwd_trace .bound_symbols )
0 commit comments