Skip to content

Commit d8f6f57

Browse files
committed
Pull thunder PR "Use torch._grouped_mm in eager mode"
Lightning-AI/lightning-thunder#2721
1 parent 23a60b0 commit d8f6f57

File tree

1 file changed

+4
-6
lines changed

1 file changed

+4
-6
lines changed

benchmarks/python/layers_for_inference_benchmark.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -297,16 +297,14 @@ def _group_sizes_from_offsets(offsets: torch.Tensor) -> list[int]:
297297

298298

299299
if LooseVersion(torch.__version__) >= LooseVersion("2.8.0"):
300-
# Required otherwise, there is a graph-break.
300+
# Required -- otherwise there is a graph-break.
301301
_grouped_mm = torch.compiler.allow_in_graph(torch._grouped_mm)
302+
else:
303+
_grouped_mm = None
302304

303305

304-
# This function should be replaced with torch._grouped_mm. However,
305-
# torch._grouped_mm is yet to be usable because it requires offsets being
306-
# multiples of 16.
307306
def grouped_mm(a: torch.Tensor, b: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
308-
if torch.compiler.is_compiling():
309-
# NOTE: This path also works for `thunder.jit` as it has a lookaside for `torch.compiler.is_compiling`.
307+
if _grouped_mm is not None:
310308
return _grouped_mm(a, b, offsets)
311309

312310
group_sizes = _group_sizes_from_offsets(offsets)

0 commit comments

Comments (0)