Skip to content

Commit

Permalink
[nvfuser] Don't allow shape only nvfuser region (#1559)
Browse files Browse the repository at this point in the history
Related: #1251

Microbenchmark
```python
import torch
import torch.utils.benchmark
import thunder

# Seen in HF-Qwen2 Model
# [t6866] = nvFusion4(t6857)
#     # lora_res_12 = prims.reshape(t6857, (1, 4096, 16))  # lora_res_12: "cuda:0 bf16[1, 4096, 16]"
#     # t6866 = prims.reshape(lora_res_12, (4096, 16))  # t6866: "cuda:0 bf16[4096, 16]"

def fn(x):
    x = x.reshape(-1, 4096, 16)
    x = x.reshape(4096, 16)
    return x


x = torch.randn(4096, 16, device="cuda")
jfn = thunder.jit(fn, nv_enable_shape_only_fusion=False)

jfn(x)

timer = torch.utils.benchmark.Timer("jfn(x)", globals={"jfn": jfn, "x": x})

print(timer.blocked_autorange(min_run_time=1))
# print(thunder.last_traces(jfn)[-1])
```

Before PR
```python
<torch.utils.benchmark.utils.common.Measurement object at 0x7305591c54e0>
jfn(x)
  Median: 36.18 us
  3 measurements, 10000 runs per measurement, 1 thread
```

After PR
```python
<torch.utils.benchmark.utils.common.Measurement object at 0x7e533962cc10>
jfn(x)
  Median: 18.95 us
  IQR:    0.02 us (18.93 to 18.96)
  6 measurements, 10000 runs per measurement, 1 thread
```
  • Loading branch information
kshitij12345 authored Dec 18, 2024
1 parent e37bec2 commit 2d0199e
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 23 deletions.
2 changes: 1 addition & 1 deletion thunder/core/prims.py
Original file line number Diff line number Diff line change
Expand Up @@ -3585,7 +3585,7 @@ def stride_order_meta(a: TensorProxy, /, order: Sequence[int]) -> TensorProxy:
# TODO Consider a more general stride manipulation primitive, like PyTorch's
# as_strided or set_strided operations
# See clang.stride_order for this prim's documentation
stride_order = make_prim(PrimIDs.STRIDE_ORDER, "stride_order", meta=stride_order_meta)
stride_order = make_prim(PrimIDs.STRIDE_ORDER, "stride_order", meta=stride_order_meta, tags=(OpTags.SHAPE_OP,))

#
# Reduction prims
Expand Down
32 changes: 20 additions & 12 deletions thunder/executors/nvfuserex_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,18 @@ def __repr__(self):
return f"FusionDefinitionWrapper({self.name})"


def all_tagged(bsym: BoundSymbol, tag: prims.OpTags) -> bool:
""":obj:`True` if `bsym` and its subsymbols all are tagged with ``tag``."""
if not has_tags(bsym, {tag}):
return False

for sbsym in bsym.subsymbols:
if not has_tags(sbsym, {tag}):
return False

return True


# Group bookend meta operations into separate regions
# This function returns a List[Region] which changes the executor of meta regions to torchex
#
Expand Down Expand Up @@ -524,20 +536,10 @@ def can_move_to_rear(bsym: BoundSymbol) -> bool:
return False
return True

def all_tagged(bsym: BoundSymbol, tags: set[prims.OpTags]) -> bool:
if not has_tags(bsym, tags):
return False

for sbsym in bsym.subsymbols:
if not has_tags(sbsym, tags):
return False

return True

# traversing all bound_symbols in topo order
for bsym in region.bound_symbols:
# we look at meta operations that can be moved to the front
if all_tagged(bsym, {prims.OpTags.SHAPE_OP}) and can_move_to_front(bsym):
if all_tagged(bsym, prims.OpTags.SHAPE_OP) and can_move_to_front(bsym):
# when we remove a node, we add all the bsym's flat_outs to region_inputs
front_meta_cluster.append(bsym)
for out in bsym.flat_outs:
Expand All @@ -548,7 +550,7 @@ def all_tagged(bsym: BoundSymbol, tags: set[prims.OpTags]) -> bool:
middle_cluster.append(bsym)
# traversing all bound_symbols in reverse topo order
for bsym in reversed(copy(middle_cluster)):
if all_tagged(bsym, {prims.OpTags.SHAPE_OP}) and can_move_to_rear(bsym):
if all_tagged(bsym, prims.OpTags.SHAPE_OP) and can_move_to_rear(bsym):
middle_cluster.remove(bsym)
rear_meta_cluster.insert(0, bsym)

Expand Down Expand Up @@ -898,6 +900,12 @@ def _can_fuse_node(n: Node):
else:
bookend_result = {"front_bsyms": [], "fusion": region, "rear_bsyms": []}

# Don't fuse a region which has only Shape Operations.
all_shape_ops = all(map(lambda bsym: all_tagged(bsym, prims.OpTags.SHAPE_OP), bsyms))
if all_shape_ops:
fused_bsyms.extend(bsyms)
continue

if len(bsyms) == 1:
bsym: BoundSymbol = bsyms[0]
can_fuse: bool = self.can_fuse(bsym)
Expand Down
3 changes: 1 addition & 2 deletions thunder/tests/test_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,6 @@ def test_hf_llama():

# transformers logs a cache deprecation warning
llama_logger.setLevel(logging.CRITICAL)
model_id = "meta-llama/Llama-3.2-1B"

config_args = LLAMA_3_2_1B_CFG.copy()
config_args["num_hidden_layers"] = 1
Expand Down Expand Up @@ -554,4 +553,4 @@ def test_hf_llama():

top_level_symbol_names = {bsym.sym.name for bsym in thunder.last_traces(jm)[-1].bound_symbols}
# changes this to fewer as needed, the goal is to not have too many fusions
assert len([s for s in top_level_symbol_names if s.startswith("nvFusion")]) == 7
assert len([s for s in top_level_symbol_names if s.startswith("nvFusion")]) == 6
56 changes: 48 additions & 8 deletions thunder/tests/test_nvfuser.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,28 +356,28 @@ def test_cse_rematerialization(executor, device, _):

fw_trace = thunder.last_traces(compiled_func)[-1]
fusion_bsyms = tuple(filter(lambda a: a.sym.is_fusion, fw_trace.bound_symbols))
assert len(fusion_bsyms) == 11
assert len(fusion_bsyms) == 9
# fusion groups 1 and 6 correspond with the apply_rotary_emb function
# Nvfuser with recomputation should use precomputed cos and sin values.
assert len(fusion_bsyms[1].args) == len(fusion_bsyms[6].args)
assert len(fusion_bsyms[1].args) == len(fusion_bsyms[5].args)

# Below, we check that freqs_sin and freqs_cos are used
# in the same operation in both fusions.
(fusion1_freqs_sin_arg,) = (a for a in fusion_bsyms[1].args if a.name == "freqs_sin")
(fusion1_freqs_cos_arg,) = (a for a in fusion_bsyms[1].args if a.name == "freqs_cos")
(fusion6_freqs_sin_arg,) = (a for a in fusion_bsyms[6].args if a.name == "freqs_sin")
(fusion6_freqs_cos_arg,) = (a for a in fusion_bsyms[6].args if a.name == "freqs_cos")
(fusion5_freqs_sin_arg,) = (a for a in fusion_bsyms[5].args if a.name == "freqs_sin")
(fusion5_freqs_cos_arg,) = (a for a in fusion_bsyms[5].args if a.name == "freqs_cos")

(fusion1_freqs_sin_user,) = (s for s in fusion_bsyms[1].subsymbols if s.args[0] is fusion1_freqs_sin_arg)
(fusion6_freqs_sin_user,) = (s for s in fusion_bsyms[6].subsymbols if s.args[0] is fusion6_freqs_sin_arg)
(fusion6_freqs_sin_user,) = (s for s in fusion_bsyms[5].subsymbols if s.args[0] is fusion5_freqs_sin_arg)

assert fusion1_freqs_sin_user.sym is fusion6_freqs_sin_user.sym
assert fusion1_freqs_sin_user.args[1:] == fusion6_freqs_sin_user.args[1:]
(fusion1_freqs_cos_user,) = (s for s in fusion_bsyms[1].subsymbols if s.args[0] is fusion1_freqs_cos_arg)
(fusion6_freqs_cos_user,) = (s for s in fusion_bsyms[6].subsymbols if s.args[0] is fusion1_freqs_cos_arg)
(fusion5_freqs_cos_user,) = (s for s in fusion_bsyms[5].subsymbols if s.args[0] is fusion5_freqs_cos_arg)

assert fusion1_freqs_cos_user.sym is fusion6_freqs_cos_user.sym
assert fusion1_freqs_cos_user.args[1:] == fusion6_freqs_cos_user.args[1:]
assert fusion1_freqs_cos_user.sym is fusion5_freqs_cos_user.sym
assert fusion1_freqs_cos_user.args[1:] == fusion5_freqs_cos_user.args[1:]


# Tests that two separated nvFuser regions can be merged when they don't depend
Expand Down Expand Up @@ -1117,3 +1117,43 @@ def fn(a, b):
# verify the functionality of the above flags.
with pytest.raises(RuntimeError, match="Can not find a scheduler to schedule fusion segment"):
out = compiled_func(*inps)


@instantiate(
dtypes=(thunder.float32,),
devicetypes=(devices.DeviceType.CUDA,),
executors=(nvFuserExecutor,),
)
def test_no_shape_only_fusion_region(executor, device: str, thunder_dtype: dtypes.dtype):
x = make_tensor(2, 2, 2, device=device, dtype=ltorch.to_torch_dtype(thunder_dtype))

def fn(x):
return x.view(4, -1).transpose(0, 1)

jfn = thunder.jit(fn)

expected = fn(x)
actual = jfn(x)

torch.testing.assert_close(actual, expected)

fwd_trace = thunder.last_traces(jfn)[-1]

# Make sure there are no fusion symbols.
assert all(not bsym.sym.is_fusion for bsym in fwd_trace.bound_symbols)

# Verify that we create fusion even if we have a single compute op.
def fn(x):
# There is a `sin` which is not a shape op.
return x.view(4, -1).transpose(0, 1).sin().transpose(0, 1).view(2, 2, 2)

jfn = thunder.jit(fn)
expected = fn(x)
actual = jfn(x)

torch.testing.assert_close(actual, expected)

fwd_trace = thunder.last_traces(jfn)[-1]

# Make sure there is a fusion symbol.
assert any(bsym.sym.is_fusion for bsym in fwd_trace.bound_symbols)

0 comments on commit 2d0199e

Please sign in to comment.