🐛 Bug

Related: #1415
A 2D column-major tensor seems to get its strides changed before the transform for `_scaled_mm`, so that it becomes row-major. If `f` in the following snippet simply returns `y`, the strides of `y` are kept as is.
To Reproduce

Steps to reproduce the behavior:
Code sample

```python
import torch

import thunder


def f(x, y, scale_x, scale_y):
    return torch._scaled_mm(x, y, scale_a=scale_x, scale_b=scale_y, out_dtype=torch.float32)


def main():
    device = torch.device("cuda")
    x = torch.randn(32, 64, device=device).to(dtype=torch.float8_e4m3fn)
    y = torch.randn(64, 96, device=device).to(dtype=torch.float8_e4m3fn).t().contiguous().t()
    print(f"$$$ {x.stride() = }, {y.stride() = }")
    scale_x = torch.tensor(1.0, device=device)
    scale_y = torch.tensor(1.0, device=device)

    expected = f(x, y, scale_x, scale_y)

    jitted = thunder.jit(f)
    actual = jitted(x, y, scale_x, scale_y)
    print(thunder.last_traces(jitted)[-1])

    torch.testing.assert_close(actual, expected)


if __name__ == "__main__":
    main()
```
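As an aside, the `.t().contiguous().t()` construction above is what makes `y` column-major. A quick stride check (my own illustration, not part of the original report) shows the resulting layout; it runs on CPU in float32 since the layout is independent of device and dtype:

```python
import torch

# Built the same way as `y` in the reproduction script:
# transpose -> contiguous -> transpose yields column-major strides (1, 64),
# i.e. dimension 0 has stride 1 and dimension 1 strides by the number of rows.
y = torch.randn(64, 96).t().contiguous().t()
print(y.stride())  # (1, 64)
assert y.stride() == (1, 64)
```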
The output of the reproduction script is shown below; it indicates that the transform's `if not column_major` branch at https://github.com/Lightning-AI/lightning-thunder/blob/crpa/subclass-torchao_float8tensor/thunder/executors/torchex.py#L1410 kicks in.
```
$$$ x.stride() = (64, 1), y.stride() = (1, 64)
# Constructed by Unwrap the actual return value
import torch
from torch import Tensor
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(x, y, scale_x, scale_y):
  # x: "cuda:0 f8_e4m3fn[32, 64]"
  # y: "cuda:0 f8_e4m3fn[64, 96]"
  # scale_x: "cuda:0 f32[]"
  # scale_y: "cuda:0 f32[]"
  t5 = torch.transpose(y, 0, 1)  # t5: "cuda:0 f8_e4m3fn[96, 64]"
    # t5 = ltorch.transpose(y, 0, 1)  # t5: "cuda:0 f8_e4m3fn[96, 64]"
      # t5 = prims.transpose(y, (1, 0))  # t5: "cuda:0 f8_e4m3fn[96, 64]"
  # /opt/pytorch/lightning-thunder/nvfuser_scaled_mm.py:7: return torch._scaled_mm(x, y, scale_a=scale_x, scale_b=scale_y, out_dtype=torch.float32)
  t6 = Tensor.contiguous(t5, memory_format=_torch_memory_format_0)  # t6: "cuda:0 f8_e4m3fn[96, 64]"
    # t6 = ltorch.contiguous(t5, memory_format=_torch_memory_format_0)  # t6: "cuda:0 f8_e4m3fn[96, 64]"
      # t6 = prims.stride_order(t5, (1, 0))  # t6: "cuda:0 f8_e4m3fn[96, 64]"
  del t5
  t7 = torch.transpose(t6, 0, 1)  # t7: "cuda:0 f8_e4m3fn[64, 96]"
    # t7 = ltorch.transpose(t6, 0, 1)  # t7: "cuda:0 f8_e4m3fn[64, 96]"
      # t7 = prims.transpose(t6, (1, 0))  # t7: "cuda:0 f8_e4m3fn[64, 96]"
  del t6
  # /opt/pytorch/lightning-thunder/nvfuser_scaled_mm.py:7: return torch._scaled_mm(x, y, scale_a=scale_x, scale_b=scale_y, out_dtype=torch.float32)
  t0 = torch._scaled_mm(x, t7, scale_x, scale_y, None, None, torch.float32, False)  # t0: "cuda:0 f32[32, 96]"
  del t7
  return (t0,)
```
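For context, the t5/t6/t7 sequence in the trace is the kind of rewrite that branch produces. Below is a minimal sketch of what such a check might look like; it is my own reconstruction inferred from the trace, not the actual code in thunder/executors/torchex.py, and the helper name `_to_column_major` is made up:

```python
import torch


def _to_column_major(b: torch.Tensor) -> torch.Tensor:
    # Hypothetical sketch: a 2D tensor is column-major iff its strides are
    # (1, nrows); torch._scaled_mm generally wants its second operand in
    # that layout.
    column_major = b.stride() == (1, b.size(0))
    if not column_major:
        # transpose -> contiguous -> transpose keeps the values but gives the
        # tensor column-major strides; this mirrors the t5/t6/t7 steps in the
        # trace above.
        b = b.t().contiguous().t()
    return b
```

As the report describes, `y` actually is column-major at runtime, so its strides appear to have been changed to row-major somewhere before the transform runs, which is why the branch fires.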
Expected behavior

From my perspective, it isn't intuitive that the input's strides are changed; I'd expect `y` to keep its original column-major strides.
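For comparison, the pass-through case mentioned at the top ("if `f` returns `y`, the strides of `y` are kept as is") can be checked with a sketch like the following; it reuses the same setup as the reproduction script, and the function name `g` is my own:

```python
import torch
import thunder


def g(x, y, scale_x, scale_y):
    # Pass-through: no _scaled_mm involved, so no layout transform is applied.
    return y


device = torch.device("cuda")
x = torch.randn(32, 64, device=device).to(dtype=torch.float8_e4m3fn)
y = torch.randn(64, 96, device=device).to(dtype=torch.float8_e4m3fn).t().contiguous().t()
scale_x = torch.tensor(1.0, device=device)
scale_y = torch.tensor(1.0, device=device)

out = thunder.jit(g)(x, y, scale_x, scale_y)
print(out.stride())  # expected to stay (1, 64), i.e. column-major
```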