diff --git a/thunder/benchmarks/layers_for_inference_benchmark.py b/thunder/benchmarks/layers_for_inference_benchmark.py index 32ea0a574c..66f896c903 100644 --- a/thunder/benchmarks/layers_for_inference_benchmark.py +++ b/thunder/benchmarks/layers_for_inference_benchmark.py @@ -348,16 +348,14 @@ def _group_sizes_from_offsets(offsets: torch.Tensor) -> list[int]: if LooseVersion(torch.__version__) >= LooseVersion("2.8.0"): - # Required otherwise, there is a graph-break. + # Required -- otherwise there is a graph-break. _grouped_mm = torch.compiler.allow_in_graph(torch._grouped_mm) +else: + _grouped_mm = None -# This function should be replaced with torch._grouped_mm. However, -# torch._grouped_mm is yet to be usable because it requires offsets being -# multiples of 16. def grouped_mm(a: torch.Tensor, b: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor: - if torch.compiler.is_compiling(): - # NOTE: This path also works for `thunder.jit` as it has a lookaside for `torch.compiler.is_compiling`. + if _grouped_mm is not None: return _grouped_mm(a, b, offsets) group_sizes = _group_sizes_from_offsets(offsets)