Skip to content

Commit

Permalink
Add clone -> fd.ops.set translation for nvFuser (#1685)
Browse files Browse the repository at this point in the history
  • Loading branch information
IvanYashchuk authored Jan 23, 2025
1 parent 4672a5d commit c951fdb
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
8 changes: 8 additions & 0 deletions thunder/executors/nvfuserex_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1709,6 +1709,14 @@ def trunc(a: TensorProxy | Number, *, fd: FusionDefinition, lc_to_nv_map: dict)
register_supported(PrimIDs.TRUNC, trunc, _elementwise_unary_check)


def clone(a: TensorProxy, *, fd: FusionDefinition, lc_to_nv_map: dict) -> Any:
nva = getnv(a, fd, lc_to_nv_map)

return fd.ops.set(nva)


register_supported(PrimIDs.CLONE, clone, _elementwise_unary_check)

#
# Elementwise binary operations
#
Expand Down
4 changes: 2 additions & 2 deletions thunder/tests/test_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,7 @@ def test_hf_for_nemo(model_id):
def test_hf_llama():
from transformers.models.llama import LlamaForCausalLM, LlamaConfig
from transformers.models.llama.modeling_llama import logger as llama_logger
from thunder.examine import get_fusion_symbols
import logging

# transformers logs a cache deprecation warning
Expand Down Expand Up @@ -548,9 +549,8 @@ def test_hf_llama():
expected2 = model(past_key_values=res["past_key_values"], **args2)
assert_close(res2, expected2, rtol=1e-1, atol=1e-1)

top_level_symbol_names = {bsym.sym.name for bsym in thunder.last_traces(jm)[-1].bound_symbols}
# changes this to fewer as needed, the goal is to not have too many fusions
assert len([s for s in top_level_symbol_names if s.startswith("nvFusion")]) == 7
assert len(get_fusion_symbols(thunder.last_traces(jm)[-1])) == 7


@requiresCUDA
Expand Down

0 comments on commit c951fdb

Please sign in to comment.