Skip to content

Commit 4672a5d

Browse files
Embedding support for nvfuser in inference (#1674)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent d2b5dfd commit 4672a5d

File tree

2 files changed

+79
-2
lines changed

2 files changed

+79
-2
lines changed

thunder/executors/nvfuserex_impl.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -946,8 +946,9 @@ def fusion_pass(self, trace: TraceCtx) -> TraceCtx:
946946
register_executor(ex)
947947

948948

949-
def register_supported(id: Hashable, translator: Callable, checker: Callable):
950-
ex.register_supported(id, checker)
949+
def register_supported(sym_or_id: Hashable, translator: Callable, checker: Callable):
950+
ex.register_supported(sym_or_id, checker)
951+
id = sym_or_id.id if isinstance(sym_or_id, Symbol) else sym_or_id
951952
_translation_map[id] = translator
952953

953954

@@ -2582,3 +2583,47 @@ def scaled_dot_product_flash_attention_grad(
25822583
execution_transform=scaled_dot_product_flash_attention,
25832584
grad_transform=scaled_dot_product_flash_attention_grad,
25842585
)
2586+
2587+
2588+
def _embedding_check(
    input: TensorProxy,
    weight: TensorProxy,
    padding_idx: None | int,
    max_norm: None | float,
    norm_type: None | float,
    scale_grad_by_freq: None | bool,
    sparse: None | bool,
) -> bool:
    """Decide whether this embedding call can be lowered to nvFuser.

    Returns True only when the installed nvFuser supports the op, the user
    opted in via the ``nv_enable_embedding`` compile option, and the tensor
    arguments are of a supported form.
    """
    # nvFuser gained an embedding forward op in 0.2.25 — TODO confirm exact version.
    if nvfuser_version() < LooseVersion("0.2.25"):
        return False
    # Lowering is opt-in: the compile option must be explicitly enabled.
    enable_embedding: None | bool = get_compile_option("nv_enable_embedding", "Enable nvFuser embedding.")
    if not enable_embedding:
        return False
    # Both tensors must be supported by nvFuser and the weight must be a 2-D table.
    return are_supported_tensors(input, weight) and weight.ndim == 2
2606+
2607+
2608+
def embedding(
    input: TensorProxy,
    weight: TensorProxy,
    padding_idx: None | int = None,
    max_norm: None | float = None,
    norm_type: None | float = 2.0,
    scale_grad_by_freq: None | bool = False,
    sparse: None | bool = False,
    *,
    fd: FusionDefinition,
    lc_to_nv_map: dict,
) -> Any:
    """Translate a thunder embedding to nvFuser's ``embedding_fwd``.

    Each non-None argument is mapped to its nvFuser counterpart through
    ``lc_to_nv_map``; ``None`` arguments are forwarded unchanged so nvFuser
    applies its own defaults.
    """
    lc_args = (input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
    nv_args = [None if arg is None else getnv(arg, fd, lc_to_nv_map) for arg in lc_args]
    return fd.ops.embedding_fwd(*nv_args)
2626+
2627+
2628+
# Register the embedding translator for both the primitive and the torch-level symbol.
for _embedding_sym in (PrimIDs.EMBEDDING, ltorch.embedding):
    register_supported(_embedding_sym, embedding, _embedding_check)

thunder/tests/test_nvfuser.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
get_opinfo,
4141
linear_opinfo,
4242
matmul_opinfo,
43+
embedding_opinfo,
4344
)
4445
from looseversion import LooseVersion
4546

@@ -1214,3 +1215,34 @@ def fn(x):
12141215

12151216
# Make sure there is a fusion symbol.
12161217
assert any(bsym.sym.is_fusion for bsym in fwd_trace.bound_symbols)
1218+
1219+
1220+
@instantiate(
    dtypes=(thunder.float16, thunder.bfloat16),
    devicetypes=(devices.DeviceType.CUDA,),
    executors=(nvFuserExecutor,),
    decorators=(
        pytest.mark.skipif(
            nvfuser_version() is None or nvfuser_version() < LooseVersion("0.2.25"),
            reason="Requires nvFuser version 0.2.25 or later",
        ),
    ),
)
def test_embedding(
    executor,
    device: str,
    dtype: dtypes.dtype,
):
    """Check that F.embedding is fused by nvFuser and matches eager results."""

    def fn(args):
        return torch.nn.functional.embedding(*args)

    for sample in embedding_opinfo.sample_inputs(device, dtype):
        jitted = thunder.jit(fn, executors_list=executor.executors_list(), nv_enable_embedding=True)
        actual = jitted(sample.args)
        expected = torch.nn.functional.embedding(*sample.args)
        # The forward trace should contain exactly one nvFuser fusion region.
        forward_trace = thunder.last_traces(jitted)[-1]
        fusions = examine.get_fusions(forward_trace)

        assert len(fusions) == 1
        torch.testing.assert_close(actual, expected)

0 commit comments

Comments
 (0)