
INTERNAL_ASSERT running Qwen-3-4B from litgpt through thunder #5358

@t-vi

Description


Running Qwen-3-4B from litgpt through thunder with the nvFuser executor fails with an nvFuser internal assert.

Reproducer:
import litgpt
import torch
import thunder

with torch.device("cuda"):
    # build the litgpt Qwen3 model with a KV cache for batch size 1
    m = litgpt.model.GPT.from_name('Qwen3-4B-Instruct-2507')
    m.max_seq_length = 1024
    m.set_kv_cache(1)
    # prefill inputs: 16 tokens with their positions
    inp = torch.randint(1, 5000, (1, 16), dtype=torch.int64)
    inp_pos = torch.arange(16)
    # single-token decode inputs (not needed to hit the error below)
    inp2 = torch.ones(1, 1, dtype=torch.int64) * 200
    inp_pos2 = torch.tensor([16])

# the prefill call through thunder triggers the nvFuser error
tm = thunder.jit(m)
tm(inp, inp_pos)

An error occurred while executing nvFuser FusionDefinition 1.
If you believe this is a bug or need assistance, please file an issue at https://github.com/NVIDIA/Fuser/issues/new
Here's a script to reproduce the error:

# CUDA devices:
#  0: NVIDIA H100 80GB HBM3
# torch version: 2.8.0+cu128
# cuda version: 12.8
# nvfuser version: 0.2.34+gitb90eb75
import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id1(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[1, 16, 6144], contiguity=[None, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T1 = fd.define_tensor(shape=[128], contiguity=[True], dtype=DataType.Float, is_cpu=False, stride_order=[0])
    T2 = fd.define_tensor(shape=[128], contiguity=[True], dtype=DataType.Float, is_cpu=False, stride_order=[0])
    T3 = fd.define_tensor(shape=[1, 16, 128], contiguity=[None, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T4 = fd.define_tensor(shape=[1, 16, 128], contiguity=[None, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T17 = fd.ops.slice(T0, start_indices=[0, 0, 0], end_indices=[1, 16, 4096], strides=[1, 1, 1], manual_normalization=0)
    T30 = fd.ops.slice(T0, start_indices=[0, 0, 4096], end_indices=[1, 16, 5120], strides=[1, 1, 1], manual_normalization=0)
    T43 = fd.ops.slice(T0, start_indices=[0, 0, 5120], end_indices=[1, 16, 6144], strides=[1, 1, 1], manual_normalization=0)
    T49 = fd.ops.reshape(T17, new_shape=[1, 16, 32, 128])
    T55 = fd.ops.reshape(T30, new_shape=[1, 16, 8, 128])
    T61 = fd.ops.reshape(T43, new_shape=[1, 16, 8, 128])
    T62 = fd.ops.permute(T49, dims=[0, 2, 1, 3])
    T63 = fd.ops.permute(T55, dims=[0, 2, 1, 3])
    T64 = fd.ops.permute(T61, dims=[0, 2, 1, 3])
    T65 = fd.ops.mul(T62, T62)
    T66 = fd.ops.sum(T65, dims=[3], keepdim=False, dtype=DataType.Null)
    T72 = fd.ops.broadcast_in_dim(T66, shape=[1, 32, 16, 1], broadcast_dims=[0, 1, 2])
    S73 = fd.define_scalar(128.000, dtype=DataType.Double)
    S74 = fd.ops.reciprocal(S73)
    T75 = fd.ops.mul(T72, S74)
    S76 = fd.define_scalar(1.00000e-06, dtype=DataType.Double)
    T77 = fd.ops.add(T75, S76)
    T78 = fd.ops.rsqrt(T77)
    T84 = fd.ops.broadcast_in_dim(T78, shape=[1, 32, 16, 128], broadcast_dims=[0, 1, 2, 3])
    T85 = fd.ops.mul(T62, T84)
    T91 = fd.ops.broadcast_in_dim(T1, shape=[1, 32, 16, 128], broadcast_dims=[3])
    T92 = fd.ops.mul(T85, T91)
    T93 = fd.ops.mul(T63, T63)
    T94 = fd.ops.sum(T93, dims=[3], keepdim=False, dtype=DataType.Null)
    T100 = fd.ops.broadcast_in_dim(T94, shape=[1, 8, 16, 1], broadcast_dims=[0, 1, 2])
    S101 = fd.define_scalar(128.000, dtype=DataType.Double)
    S102 = fd.ops.reciprocal(S101)
    T103 = fd.ops.mul(T100, S102)
    S104 = fd.define_scalar(1.00000e-06, dtype=DataType.Double)
    T105 = fd.ops.add(T103, S104)
    T106 = fd.ops.rsqrt(T105)
    T112 = fd.ops.broadcast_in_dim(T106, shape=[1, 8, 16, 128], broadcast_dims=[0, 1, 2, 3])
    T113 = fd.ops.mul(T63, T112)
    T119 = fd.ops.broadcast_in_dim(T2, shape=[1, 8, 16, 128], broadcast_dims=[3])
    T120 = fd.ops.mul(T113, T119)
    T136 = fd.ops.slice(T92, start_indices=[0, 0, 0, 0], end_indices=[1, 32, 16, 64], strides=[1, 1, 1, 1], manual_normalization=0)
    T152 = fd.ops.slice(T92, start_indices=[0, 0, 0, 64], end_indices=[1, 32, 16, 128], strides=[1, 1, 1, 1], manual_normalization=0)
    T153 = fd.ops.neg(T152)
    T154 = fd.ops.cat([T153, T136], dim=-1, manual_padding=0)
    T160 = fd.ops.reshape(T3, new_shape=[1, 1, 16, 128])
    T166 = fd.ops.reshape(T4, new_shape=[1, 1, 16, 128])
    T172 = fd.ops.broadcast_in_dim(T160, shape=[1, 32, 16, 128], broadcast_dims=[0, 1, 2, 3])
    T173 = fd.ops.mul(T92, T172)
    T179 = fd.ops.broadcast_in_dim(T166, shape=[1, 32, 16, 128], broadcast_dims=[0, 1, 2, 3])
    T180 = fd.ops.mul(T154, T179)
    T181 = fd.ops.add(T173, T180)
    T197 = fd.ops.slice(T120, start_indices=[0, 0, 0, 0], end_indices=[1, 8, 16, 64], strides=[1, 1, 1, 1], manual_normalization=0)
    T213 = fd.ops.slice(T120, start_indices=[0, 0, 0, 64], end_indices=[1, 8, 16, 128], strides=[1, 1, 1, 1], manual_normalization=0)
    T214 = fd.ops.neg(T213)
    T215 = fd.ops.cat([T214, T197], dim=-1, manual_padding=0)
    T221 = fd.ops.broadcast_in_dim(T160, shape=[1, 8, 16, 128], broadcast_dims=[0, 1, 2, 3])
    T222 = fd.ops.mul(T120, T221)
    T228 = fd.ops.broadcast_in_dim(T166, shape=[1, 8, 16, 128], broadcast_dims=[0, 1, 2, 3])
    T229 = fd.ops.mul(T215, T228)
    T230 = fd.ops.add(T222, T229)
    T246 = fd.ops.slice(T92, start_indices=[0, 0, 0, 0], end_indices=[1, 32, 16, 0], strides=[1, 1, 1, 1], manual_normalization=0)
    T247 = fd.ops.cat([T181, T246], dim=-1, manual_padding=0)
    T263 = fd.ops.slice(T120, start_indices=[0, 0, 0, 0], end_indices=[1, 8, 16, 0], strides=[1, 1, 1, 1], manual_normalization=0)
    T264 = fd.ops.cat([T230, T263], dim=-1, manual_padding=0)
    fd.add_output(T64)
    fd.add_output(T78)
    fd.add_output(T106)
    fd.add_output(T247)
    fd.add_output(T264)

with FusionDefinition() as fd:
    nvfuser_fusion_id1(fd)

inputs = [
    torch.testing.make_tensor((1, 16, 6144), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((128,), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((128,), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((1, 16, 128), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((1, 16, 128), dtype=torch.float32, device='cuda:0'),
]
fd.execute(inputs)
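
For reference, the generated fusion above contains two zero-width slices (T246 and T263, with end index 0 in the last dimension) whose results are then passed to fd.ops.cat. The sketch below isolates just that pattern; it assumes, without confirmation from the assert message, that the empty slice feeding cat is what trips the INTERNAL_ASSERT, and the tensor shape is picked arbitrarily to match T92.

# Hypothetical reduced repro (not the generated script above): a zero-width
# slice concatenated onto the full tensor, mirroring T246 -> T247.
import torch
from nvfuser import FusionDefinition, DataType

def reduced_fusion(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[1, 32, 16, 128], contiguity=[None, True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[3, 2, 1, 0])
    # zero-width slice along the last dimension (end index 0), as in T246/T263
    T1 = fd.ops.slice(T0, start_indices=[0, 0, 0, 0], end_indices=[1, 32, 16, 0], strides=[1, 1, 1, 1], manual_normalization=0)
    # concatenate the empty slice back onto the full tensor, as in T247/T264
    T2 = fd.ops.cat([T0, T1], dim=-1, manual_padding=0)
    fd.add_output(T2)

with FusionDefinition() as fd:
    reduced_fusion(fd)

inputs = [
    torch.testing.make_tensor((1, 32, 16, 128), dtype=torch.float32, device='cuda:0'),
]
fd.execute(inputs)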
