Skip to content

Commit b3e062a

Browse files
authored
Plumbing the topk to the nvFuser executor (#2237)
1 parent 0ba1063 commit b3e062a

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

thunder/executors/nvfuserex_impl.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3016,6 +3016,34 @@ def cross_entropy_grad(
30163016
)
30173017

30183018

3019+
def _topk_check_(
3020+
a: TensorProxy, /, k: int, dim: int | None = None, largest: Number = 1, sorted: Number = 1, *args
3021+
) -> bool:
3022+
if a.ndim <= 0:
3023+
return False
3024+
if dim >= a.ndim or (dim is not None and dim < -a.ndim):
3025+
return False
3026+
return True
3027+
3028+
3029+
def topk_transform(
3030+
a: TensorProxy,
3031+
/,
3032+
k: int,
3033+
dim: int | None = None,
3034+
largest: Number = 1,
3035+
sorted: Number = 1,
3036+
*,
3037+
fd: FusionDefinition,
3038+
lc_to_nv_map: dict,
3039+
) -> any:
3040+
nva = getnv(a, fd, lc_to_nv_map)
3041+
nvk = getnv(k, fd, lc_to_nv_map)
3042+
return fd.ops.topk(nva, nvk, dim, bool(largest), bool(sorted))
3043+
3044+
3045+
register_supported(prims.topk, topk_transform, _topk_check_)
3046+
30193047
# At module/class level
30203048
NVFUSER_SUPPORTS_OPTIONS = nvfuser_version() >= LooseVersion("0.2.23")
30213049
assert NVFUSER_SUPPORTS_OPTIONS, (

0 commit comments

Comments
 (0)