Skip to content

Commit

Permalink
Embedding support for nvfuser in inference (#1674)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Priya2698 and pre-commit-ci[bot] authored Jan 23, 2025
1 parent d2b5dfd commit 4672a5d
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 2 deletions.
49 changes: 47 additions & 2 deletions thunder/executors/nvfuserex_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,8 +946,9 @@ def fusion_pass(self, trace: TraceCtx) -> TraceCtx:
register_executor(ex)


def register_supported(id: Hashable, translator: Callable, checker: Callable):
ex.register_supported(id, checker)
def register_supported(sym_or_id: Hashable, translator: Callable, checker: Callable):
    """Register `sym_or_id` as executable by nvFuser.

    `checker` decides at fusion time whether a given invocation is supported;
    `translator` lowers the symbol into FusionDefinition ops. Accepts either a
    Symbol (its `.id` is used as the translation-map key) or a raw id.
    """
    ex.register_supported(sym_or_id, checker)
    # Avoid shadowing the builtin `id`; normalize Symbol -> its id for the map key.
    key = sym_or_id.id if isinstance(sym_or_id, Symbol) else sym_or_id
    _translation_map[key] = translator


Expand Down Expand Up @@ -2582,3 +2583,47 @@ def scaled_dot_product_flash_attention_grad(
execution_transform=scaled_dot_product_flash_attention,
grad_transform=scaled_dot_product_flash_attention_grad,
)


def _embedding_check(
    input: TensorProxy,
    weight: TensorProxy,
    padding_idx: None | int,
    max_norm: None | float,
    norm_type: None | float,
    scale_grad_by_freq: None | bool,
    sparse: None | bool,
) -> bool:
    """Return True when this embedding call can be lowered to nvFuser.

    Requires nvFuser >= 0.2.25 (which added `embedding_fwd`), explicit opt-in
    via the `nv_enable_embedding` compile option, supported input/weight
    tensors, and a 2D weight.
    """
    # `embedding_fwd` only exists in nvFuser 0.2.25 and later.
    if nvfuser_version() < LooseVersion("0.2.25"):
        return False
    # Opt-in gate; queried only after the version check passes.
    opt_in: None | bool = get_compile_option("nv_enable_embedding", "Enable nvFuser embedding.")
    if not opt_in:
        return False
    return are_supported_tensors(input, weight) and weight.ndim == 2


def embedding(
    input: TensorProxy,
    weight: TensorProxy,
    padding_idx: None | int = None,
    max_norm: None | float = None,
    norm_type: None | float = 2.0,
    scale_grad_by_freq: None | bool = False,
    sparse: None | bool = False,
    *,
    fd: FusionDefinition,
    lc_to_nv_map: dict,
) -> Any:
    """Lower a Thunder embedding call to nvFuser's `embedding_fwd` op.

    Each argument is mapped to its nvFuser counterpart via `getnv`; `None`
    arguments are forwarded as-is (nvFuser treats them as defaults).
    """
    args = (input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
    nv_args = [None if arg is None else getnv(arg, fd, lc_to_nv_map) for arg in args]
    return fd.ops.embedding_fwd(*nv_args)


# Route both the primitive and the torch-level symbol through the same
# nvFuser embedding lowering, gated by the same support checker.
register_supported(PrimIDs.EMBEDDING, embedding, _embedding_check)
register_supported(ltorch.embedding, embedding, _embedding_check)
32 changes: 32 additions & 0 deletions thunder/tests/test_nvfuser.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
get_opinfo,
linear_opinfo,
matmul_opinfo,
embedding_opinfo,
)
from looseversion import LooseVersion

Expand Down Expand Up @@ -1214,3 +1215,34 @@ def fn(x):

# Make sure there is a fusion symbol.
assert any(bsym.sym.is_fusion for bsym in fwd_trace.bound_symbols)


@instantiate(
    dtypes=(thunder.float16, thunder.bfloat16),
    devicetypes=(devices.DeviceType.CUDA,),
    executors=(nvFuserExecutor,),
    decorators=(
        pytest.mark.skipif(
            nvfuser_version() is None or nvfuser_version() < LooseVersion("0.2.25"),
            reason="Requires nvFuser version 0.2.25 or later",
        ),
    ),
)
def test_embedding(
    executor,
    device: str,
    dtype: dtypes.dtype,
):
    """Check that embedding is fused by nvFuser and matches eager PyTorch."""

    def embedding_fn(inputs):
        return torch.nn.functional.embedding(*inputs)

    for sample in embedding_opinfo.sample_inputs(device, dtype):
        jitted = thunder.jit(embedding_fn, executors_list=executor.executors_list(), nv_enable_embedding=True)
        actual = jitted(sample.args)
        expected = torch.nn.functional.embedding(*sample.args)

        # The forward trace must contain exactly one nvFuser fusion region.
        fwd_trace = thunder.last_traces(jitted)[-1]
        fusions = examine.get_fusions(fwd_trace)
        assert len(fusions) == 1
        torch.testing.assert_close(actual, expected)

0 comments on commit 4672a5d

Please sign in to comment.