From 2d0199e1d09ea3c7d1b271207962256b01cd3e7f Mon Sep 17 00:00:00 2001 From: Kshiteej K Date: Wed, 18 Dec 2024 13:23:46 +0100 Subject: [PATCH] [nvfuser] Don't allow shape only nvfuser region (#1559) Related: https://github.com/Lightning-AI/lightning-thunder/issues/1251 Microbenchmark ```python import torch import torch.utils.benchmark import thunder # Seen in HF-Qwen2 Model # [t6866] = nvFusion4(t6857) # # lora_res_12 = prims.reshape(t6857, (1, 4096, 16)) # lora_res_12: "cuda:0 bf16[1, 4096, 16]" # # t6866 = prims.reshape(lora_res_12, (4096, 16)) # t6866: "cuda:0 bf16[4096, 16]" def fn(x): x = x.reshape(-1, 4096, 16) x = x.reshape(4096, 16) return x x = torch.randn(4096, 16, device="cuda") jfn = thunder.jit(fn, nv_enable_shape_only_fusion=False) jfn(x) timer = torch.utils.benchmark.Timer("jfn(x)", globals={"jfn": jfn, "x": x}) print(timer.blocked_autorange(min_run_time=1)) # print(thunder.last_traces(jfn)[-1]) ``` Before PR ```python jfn(x) Median: 36.18 us 3 measurements, 10000 runs per measurement, 1 thread ``` After PR ```python jfn(x) Median: 18.95 us IQR: 0.02 us (18.93 to 18.96) 6 measurements, 10000 runs per measurement, 1 thread ``` --- thunder/core/prims.py | 2 +- thunder/executors/nvfuserex_impl.py | 32 ++++++++++------- thunder/tests/test_networks.py | 3 +- thunder/tests/test_nvfuser.py | 56 ++++++++++++++++++++++++----- 4 files changed, 70 insertions(+), 23 deletions(-) diff --git a/thunder/core/prims.py b/thunder/core/prims.py index c17a28296c..da46147155 100644 --- a/thunder/core/prims.py +++ b/thunder/core/prims.py @@ -3585,7 +3585,7 @@ def stride_order_meta(a: TensorProxy, /, order: Sequence[int]) -> TensorProxy: # TODO Consider a more general stride manipulation primitive, like PyTorch's # as_strided or set_strided operations # See clang.stride_order for this prim's documentation -stride_order = make_prim(PrimIDs.STRIDE_ORDER, "stride_order", meta=stride_order_meta) +stride_order = make_prim(PrimIDs.STRIDE_ORDER, "stride_order", meta=stride_order_meta, tags=(OpTags.SHAPE_OP,)) # # Reduction prims diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py index 4bd78b72b5..9abab32c79 100644 --- a/thunder/executors/nvfuserex_impl.py +++ b/thunder/executors/nvfuserex_impl.py @@ -488,6 +488,18 @@ def __repr__(self): return f"FusionDefinitionWrapper({self.name})" +def all_tagged(bsym: BoundSymbol, tag: prims.OpTags) -> bool: + """:obj:`True` if `bsym` and its subsymbols all are tagged with ``tag``.""" + if not has_tags(bsym, {tag}): + return False + + for sbsym in bsym.subsymbols: + if not has_tags(sbsym, {tag}): + return False + + return True + + # Group bookend meta operations into separate regions # This function returns a List[Region] which changes the executor of meta regions to torchex # @@ -524,20 +536,10 @@ def can_move_to_rear(bsym: BoundSymbol) -> bool: return False return True - def all_tagged(bsym: BoundSymbol, tags: set[prims.OpTags]) -> bool: - if not has_tags(bsym, tags): - return False - - for sbsym in bsym.subsymbols: - if not has_tags(sbsym, tags): - return False - - return True - # traversing all bound_symbols in topo order for bsym in region.bound_symbols: # we look at meta operations that can be moved to the front - if all_tagged(bsym, {prims.OpTags.SHAPE_OP}) and can_move_to_front(bsym): + if all_tagged(bsym, prims.OpTags.SHAPE_OP) and can_move_to_front(bsym): # when we remove a node, we add all the bsym's flat_outs to region_inputs front_meta_cluster.append(bsym) for out in bsym.flat_outs: @@ -548,7 +550,7 @@ def all_tagged(bsym: BoundSymbol, tags: set[prims.OpTags]) -> bool: middle_cluster.append(bsym) # traversing all bound_symbols in reverse topo order for bsym in reversed(copy(middle_cluster)): - if all_tagged(bsym, {prims.OpTags.SHAPE_OP}) and can_move_to_rear(bsym): + if all_tagged(bsym, prims.OpTags.SHAPE_OP) and can_move_to_rear(bsym): middle_cluster.remove(bsym) rear_meta_cluster.insert(0, bsym) @@ -898,6 +900,12 @@ def _can_fuse_node(n: Node): else: bookend_result = {"front_bsyms": [], "fusion": region, "rear_bsyms": []} + # Don't fuse a region which has only Shape Operations. + all_shape_ops = all(map(lambda bsym: all_tagged(bsym, prims.OpTags.SHAPE_OP), bsyms)) + if all_shape_ops: + fused_bsyms.extend(bsyms) + continue + if len(bsyms) == 1: bsym: BoundSymbol = bsyms[0] can_fuse: bool = self.can_fuse(bsym) diff --git a/thunder/tests/test_networks.py b/thunder/tests/test_networks.py index ef09fd834b..3d80b5473f 100644 --- a/thunder/tests/test_networks.py +++ b/thunder/tests/test_networks.py @@ -517,7 +517,6 @@ def test_hf_llama(): # transformers logs a cache deprecation warning llama_logger.setLevel(logging.CRITICAL) - model_id = "meta-llama/Llama-3.2-1B" config_args = LLAMA_3_2_1B_CFG.copy() config_args["num_hidden_layers"] = 1 @@ -554,4 +553,4 @@ def test_hf_llama(): top_level_symbol_names = {bsym.sym.name for bsym in thunder.last_traces(jm)[-1].bound_symbols} # changes this to fewer as needed, the goal is to not have too many fusions - assert len([s for s in top_level_symbol_names if s.startswith("nvFusion")]) == 7 + assert len([s for s in top_level_symbol_names if s.startswith("nvFusion")]) == 6 diff --git a/thunder/tests/test_nvfuser.py b/thunder/tests/test_nvfuser.py index 2f8cd0ed09..6ac0d1212e 100644 --- a/thunder/tests/test_nvfuser.py +++ b/thunder/tests/test_nvfuser.py @@ -356,28 +356,28 @@ def test_cse_rematerialization(executor, device, _): fw_trace = thunder.last_traces(compiled_func)[-1] fusion_bsyms = tuple(filter(lambda a: a.sym.is_fusion, fw_trace.bound_symbols)) - assert len(fusion_bsyms) == 11 + assert len(fusion_bsyms) == 9 # fusion groups 1 and 6 correspond with the apply_rotary_emb function # Nvfuser with recomputation should use precomputed cos and sin values. - assert len(fusion_bsyms[1].args) == len(fusion_bsyms[6].args) + assert len(fusion_bsyms[1].args) == len(fusion_bsyms[5].args) # Below, we check that freqs_sin and freqs_cos are used # in the same operation in both fusions. (fusion1_freqs_sin_arg,) = (a for a in fusion_bsyms[1].args if a.name == "freqs_sin") (fusion1_freqs_cos_arg,) = (a for a in fusion_bsyms[1].args if a.name == "freqs_cos") - (fusion6_freqs_sin_arg,) = (a for a in fusion_bsyms[6].args if a.name == "freqs_sin") - (fusion6_freqs_cos_arg,) = (a for a in fusion_bsyms[6].args if a.name == "freqs_cos") + (fusion5_freqs_sin_arg,) = (a for a in fusion_bsyms[5].args if a.name == "freqs_sin") + (fusion5_freqs_cos_arg,) = (a for a in fusion_bsyms[5].args if a.name == "freqs_cos") (fusion1_freqs_sin_user,) = (s for s in fusion_bsyms[1].subsymbols if s.args[0] is fusion1_freqs_sin_arg) - (fusion6_freqs_sin_user,) = (s for s in fusion_bsyms[6].subsymbols if s.args[0] is fusion6_freqs_sin_arg) + (fusion6_freqs_sin_user,) = (s for s in fusion_bsyms[5].subsymbols if s.args[0] is fusion5_freqs_sin_arg) assert fusion1_freqs_sin_user.sym is fusion6_freqs_sin_user.sym assert fusion1_freqs_sin_user.args[1:] == fusion6_freqs_sin_user.args[1:] (fusion1_freqs_cos_user,) = (s for s in fusion_bsyms[1].subsymbols if s.args[0] is fusion1_freqs_cos_arg) - (fusion6_freqs_cos_user,) = (s for s in fusion_bsyms[6].subsymbols if s.args[0] is fusion1_freqs_cos_arg) + (fusion5_freqs_cos_user,) = (s for s in fusion_bsyms[5].subsymbols if s.args[0] is fusion5_freqs_cos_arg) - assert fusion1_freqs_cos_user.sym is fusion6_freqs_cos_user.sym - assert fusion1_freqs_cos_user.args[1:] == fusion6_freqs_cos_user.args[1:] + assert fusion1_freqs_cos_user.sym is fusion5_freqs_cos_user.sym + assert fusion1_freqs_cos_user.args[1:] == fusion5_freqs_cos_user.args[1:] # Tests that two separated nvFuser regions can be merged when they don't depend @@ -1117,3 +1117,43 @@ def fn(a, b): # verify the functionality of the above flags. with pytest.raises(RuntimeError, match="Can not find a scheduler to schedule fusion segment"): out = compiled_func(*inps) + + +@instantiate( + dtypes=(thunder.float32,), + devicetypes=(devices.DeviceType.CUDA,), + executors=(nvFuserExecutor,), +) +def test_no_shape_only_fusion_region(executor, device: str, thunder_dtype: dtypes.dtype): + x = make_tensor(2, 2, 2, device=device, dtype=ltorch.to_torch_dtype(thunder_dtype)) + + def fn(x): + return x.view(4, -1).transpose(0, 1) + + jfn = thunder.jit(fn) + + expected = fn(x) + actual = jfn(x) + + torch.testing.assert_close(actual, expected) + + fwd_trace = thunder.last_traces(jfn)[-1] + + # Make sure there are no fusion symbols. + assert all(not bsym.sym.is_fusion for bsym in fwd_trace.bound_symbols) + + # Verify that we create fusion even if we have a single compute op. + def fn(x): + # There is a `sin` which is not a shape op. + return x.view(4, -1).transpose(0, 1).sin().transpose(0, 1).view(2, 2, 2) + + jfn = thunder.jit(fn) + expected = fn(x) + actual = jfn(x) + + torch.testing.assert_close(actual, expected) + + fwd_trace = thunder.last_traces(jfn)[-1] + + # Make sure there is a fusion symbol. + assert any(bsym.sym.is_fusion for bsym in fwd_trace.bound_symbols)