Skip to content

Commit 2d0199e

Browse files
authored
[nvfuser] Don't allow shape only nvfuser region (#1559)
Related: #1251 Microbenchmark ```python import torch import torch.utils.benchmark import thunder # Seen in HF-Qwen2 Model # [t6866] = nvFusion4(t6857) # # lora_res_12 = prims.reshape(t6857, (1, 4096, 16)) # lora_res_12: "cuda:0 bf16[1, 4096, 16]" # # t6866 = prims.reshape(lora_res_12, (4096, 16)) # t6866: "cuda:0 bf16[4096, 16]" def fn(x): x = x.reshape(-1, 4096, 16) x = x.reshape(4096, 16) return x x = torch.randn(4096, 16, device="cuda") jfn = thunder.jit(fn, nv_enable_shape_only_fusion=False) jfn(x) timer = torch.utils.benchmark.Timer("jfn(x)", globals={"jfn": jfn, "x": x}) print(timer.blocked_autorange(min_run_time=1)) # print(thunder.last_traces(jfn)[-1]) ``` Before PR ```python <torch.utils.benchmark.utils.common.Measurement object at 0x7305591c54e0> jfn(x) Median: 36.18 us 3 measurements, 10000 runs per measurement, 1 thread ``` After PR ```python <torch.utils.benchmark.utils.common.Measurement object at 0x7e533962cc10> jfn(x) Median: 18.95 us IQR: 0.02 us (18.93 to 18.96) 6 measurements, 10000 runs per measurement, 1 thread ```
1 parent e37bec2 commit 2d0199e

File tree

4 files changed

+70
-23
lines changed

4 files changed

+70
-23
lines changed

thunder/core/prims.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3585,7 +3585,7 @@ def stride_order_meta(a: TensorProxy, /, order: Sequence[int]) -> TensorProxy:
35853585
# TODO Consider a more general stride manipulation primitive, like PyTorch's
35863586
# as_strided or set_strided operations
35873587
# See clang.stride_order for this prim's documentation
3588-
stride_order = make_prim(PrimIDs.STRIDE_ORDER, "stride_order", meta=stride_order_meta)
3588+
stride_order = make_prim(PrimIDs.STRIDE_ORDER, "stride_order", meta=stride_order_meta, tags=(OpTags.SHAPE_OP,))
35893589

35903590
#
35913591
# Reduction prims

thunder/executors/nvfuserex_impl.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,18 @@ def __repr__(self):
488488
return f"FusionDefinitionWrapper({self.name})"
489489

490490

491+
def all_tagged(bsym: BoundSymbol, tag: prims.OpTags) -> bool:
492+
""":obj:`True` if `bsym` and its subsymbols all are tagged with ``tag``."""
493+
if not has_tags(bsym, {tag}):
494+
return False
495+
496+
for sbsym in bsym.subsymbols:
497+
if not has_tags(sbsym, {tag}):
498+
return False
499+
500+
return True
501+
502+
491503
# Group bookend meta operations into separate regions
492504
# This function returns a List[Region] which changes the executor of meta regions to torchex
493505
#
@@ -524,20 +536,10 @@ def can_move_to_rear(bsym: BoundSymbol) -> bool:
524536
return False
525537
return True
526538

527-
def all_tagged(bsym: BoundSymbol, tags: set[prims.OpTags]) -> bool:
528-
if not has_tags(bsym, tags):
529-
return False
530-
531-
for sbsym in bsym.subsymbols:
532-
if not has_tags(sbsym, tags):
533-
return False
534-
535-
return True
536-
537539
# traversing all bound_symbols in topo order
538540
for bsym in region.bound_symbols:
539541
# we look at meta operations that can be moved to the front
540-
if all_tagged(bsym, {prims.OpTags.SHAPE_OP}) and can_move_to_front(bsym):
542+
if all_tagged(bsym, prims.OpTags.SHAPE_OP) and can_move_to_front(bsym):
541543
# when we remove a node, we add all the bsym's flat_outs to region_inputs
542544
front_meta_cluster.append(bsym)
543545
for out in bsym.flat_outs:
@@ -548,7 +550,7 @@ def all_tagged(bsym: BoundSymbol, tags: set[prims.OpTags]) -> bool:
548550
middle_cluster.append(bsym)
549551
# traversing all bound_symbols in reverse topo order
550552
for bsym in reversed(copy(middle_cluster)):
551-
if all_tagged(bsym, {prims.OpTags.SHAPE_OP}) and can_move_to_rear(bsym):
553+
if all_tagged(bsym, prims.OpTags.SHAPE_OP) and can_move_to_rear(bsym):
552554
middle_cluster.remove(bsym)
553555
rear_meta_cluster.insert(0, bsym)
554556

@@ -898,6 +900,12 @@ def _can_fuse_node(n: Node):
898900
else:
899901
bookend_result = {"front_bsyms": [], "fusion": region, "rear_bsyms": []}
900902

903+
# Don't fuse a region which has only Shape Operations.
904+
all_shape_ops = all(map(lambda bsym: all_tagged(bsym, prims.OpTags.SHAPE_OP), bsyms))
905+
if all_shape_ops:
906+
fused_bsyms.extend(bsyms)
907+
continue
908+
901909
if len(bsyms) == 1:
902910
bsym: BoundSymbol = bsyms[0]
903911
can_fuse: bool = self.can_fuse(bsym)

thunder/tests/test_networks.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -517,7 +517,6 @@ def test_hf_llama():
517517

518518
# transformers logs a cache deprecation warning
519519
llama_logger.setLevel(logging.CRITICAL)
520-
model_id = "meta-llama/Llama-3.2-1B"
521520

522521
config_args = LLAMA_3_2_1B_CFG.copy()
523522
config_args["num_hidden_layers"] = 1
@@ -554,4 +553,4 @@ def test_hf_llama():
554553

555554
top_level_symbol_names = {bsym.sym.name for bsym in thunder.last_traces(jm)[-1].bound_symbols}
556555
# changes this to fewer as needed, the goal is to not have too many fusions
557-
assert len([s for s in top_level_symbol_names if s.startswith("nvFusion")]) == 7
556+
assert len([s for s in top_level_symbol_names if s.startswith("nvFusion")]) == 6

thunder/tests/test_nvfuser.py

Lines changed: 48 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -356,28 +356,28 @@ def test_cse_rematerialization(executor, device, _):
356356

357357
fw_trace = thunder.last_traces(compiled_func)[-1]
358358
fusion_bsyms = tuple(filter(lambda a: a.sym.is_fusion, fw_trace.bound_symbols))
359-
assert len(fusion_bsyms) == 11
359+
assert len(fusion_bsyms) == 9
360360
# fusion groups 1 and 6 correspond with the apply_rotary_emb function
361361
# Nvfuser with recomputation should use precomputed cos and sin values.
362-
assert len(fusion_bsyms[1].args) == len(fusion_bsyms[6].args)
362+
assert len(fusion_bsyms[1].args) == len(fusion_bsyms[5].args)
363363

364364
# Below, we check that freqs_sin and freqs_cos are used
365365
# in the same operation in both fusions.
366366
(fusion1_freqs_sin_arg,) = (a for a in fusion_bsyms[1].args if a.name == "freqs_sin")
367367
(fusion1_freqs_cos_arg,) = (a for a in fusion_bsyms[1].args if a.name == "freqs_cos")
368-
(fusion6_freqs_sin_arg,) = (a for a in fusion_bsyms[6].args if a.name == "freqs_sin")
369-
(fusion6_freqs_cos_arg,) = (a for a in fusion_bsyms[6].args if a.name == "freqs_cos")
368+
(fusion5_freqs_sin_arg,) = (a for a in fusion_bsyms[5].args if a.name == "freqs_sin")
369+
(fusion5_freqs_cos_arg,) = (a for a in fusion_bsyms[5].args if a.name == "freqs_cos")
370370

371371
(fusion1_freqs_sin_user,) = (s for s in fusion_bsyms[1].subsymbols if s.args[0] is fusion1_freqs_sin_arg)
372-
(fusion6_freqs_sin_user,) = (s for s in fusion_bsyms[6].subsymbols if s.args[0] is fusion6_freqs_sin_arg)
372+
(fusion6_freqs_sin_user,) = (s for s in fusion_bsyms[5].subsymbols if s.args[0] is fusion5_freqs_sin_arg)
373373

374374
assert fusion1_freqs_sin_user.sym is fusion6_freqs_sin_user.sym
375375
assert fusion1_freqs_sin_user.args[1:] == fusion6_freqs_sin_user.args[1:]
376376
(fusion1_freqs_cos_user,) = (s for s in fusion_bsyms[1].subsymbols if s.args[0] is fusion1_freqs_cos_arg)
377-
(fusion6_freqs_cos_user,) = (s for s in fusion_bsyms[6].subsymbols if s.args[0] is fusion1_freqs_cos_arg)
377+
(fusion5_freqs_cos_user,) = (s for s in fusion_bsyms[5].subsymbols if s.args[0] is fusion5_freqs_cos_arg)
378378

379-
assert fusion1_freqs_cos_user.sym is fusion6_freqs_cos_user.sym
380-
assert fusion1_freqs_cos_user.args[1:] == fusion6_freqs_cos_user.args[1:]
379+
assert fusion1_freqs_cos_user.sym is fusion5_freqs_cos_user.sym
380+
assert fusion1_freqs_cos_user.args[1:] == fusion5_freqs_cos_user.args[1:]
381381

382382

383383
# Tests that two separated nvFuser regions can be merged when they don't depend
@@ -1117,3 +1117,43 @@ def fn(a, b):
11171117
# verify the functionality of the above flags.
11181118
with pytest.raises(RuntimeError, match="Can not find a scheduler to schedule fusion segment"):
11191119
out = compiled_func(*inps)
1120+
1121+
1122+
@instantiate(
1123+
dtypes=(thunder.float32,),
1124+
devicetypes=(devices.DeviceType.CUDA,),
1125+
executors=(nvFuserExecutor,),
1126+
)
1127+
def test_no_shape_only_fusion_region(executor, device: str, thunder_dtype: dtypes.dtype):
1128+
x = make_tensor(2, 2, 2, device=device, dtype=ltorch.to_torch_dtype(thunder_dtype))
1129+
1130+
def fn(x):
1131+
return x.view(4, -1).transpose(0, 1)
1132+
1133+
jfn = thunder.jit(fn)
1134+
1135+
expected = fn(x)
1136+
actual = jfn(x)
1137+
1138+
torch.testing.assert_close(actual, expected)
1139+
1140+
fwd_trace = thunder.last_traces(jfn)[-1]
1141+
1142+
# Make sure there are no fusion symbols.
1143+
assert all(not bsym.sym.is_fusion for bsym in fwd_trace.bound_symbols)
1144+
1145+
# Verify that we create fusion even if we have a single compute op.
1146+
def fn(x):
1147+
# There is a `sin` which is not a shape op.
1148+
return x.view(4, -1).transpose(0, 1).sin().transpose(0, 1).view(2, 2, 2)
1149+
1150+
jfn = thunder.jit(fn)
1151+
expected = fn(x)
1152+
actual = jfn(x)
1153+
1154+
torch.testing.assert_close(actual, expected)
1155+
1156+
fwd_trace = thunder.last_traces(jfn)[-1]
1157+
1158+
# Make sure there is a fusion symbol.
1159+
assert any(bsym.sym.is_fusion for bsym in fwd_trace.bound_symbols)

0 commit comments

Comments
 (0)