
[bug][relax.frontend.torch] FFI segfault in tvm::relax::Tuple::Tuple when importing torch.export graph with 4D advanced-indexing write (aten.index_put_) and tuple outputs #18363

@tinywisdom

Description

Summary

Importing a program exported with torch.export into TVM Relax (via from_exported_program) triggers a segmentation fault inside the TVM FFI while a Relax Tuple is being constructed. The minimal model below performs a 4D advanced-indexing write that uses two integer index tensors on the last two dimensions (L[..., idx, idx] = ..., which appears as aten.index_put_ in the exported graph) and returns a Python tuple of two tensors (x[..., :1], L). The exported graph contains no RNG ops (no randn), so the crash appears to stem from the combination of the aten.index_put_ lowering and the tuple-output construction.
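
For reference, here is a pure-Python sketch of what the indexing write means. It is not PyTorch's or TVM's implementation, and the helpers zeros4d, elu, gather_last_two and index_put_last_two are hypothetical names used only for illustration. The key point is that the two index tensors are paired elementwise rather than combined as a Cartesian product. With idx = arange(N), the read therefore gathers the diagonal of every N x N matrix, and the write scatters back to that same diagonal. The sketch omits the + 1e-8 term for readability.

```python
import math
from typing import List

Tensor4D = List[List[List[List[float]]]]  # nested lists shaped [B][K][N][N]
Tensor3D = List[List[List[float]]]        # nested lists shaped [B][K][M]


def zeros4d(B: int, K: int, N: int) -> Tensor4D:
    # Fresh inner lists everywhere, so a write to one matrix never aliases another.
    return [[[[0.0] * N for _ in range(N)] for _ in range(K)] for _ in range(B)]


def elu(v: float, alpha: float = 1.0) -> float:
    # Scalar ELU with F.elu's default alpha=1.0.
    return v if v > 0 else alpha * math.expm1(v)


def gather_last_two(L: Tensor4D, rows: List[int], cols: List[int]) -> Tensor3D:
    # Sketch of L[..., rows, cols]: entry i reads (rows[i], cols[i]) of each matrix.
    return [[[mat[r][c] for r, c in zip(rows, cols)] for mat in plane] for plane in L]


def index_put_last_two(L: Tensor4D, rows: List[int], cols: List[int],
                       values: Tensor3D) -> Tensor4D:
    # Sketch of L[..., rows, cols] = values, done in place like aten.index_put_.
    # The index lists are zipped, not crossed; this sketch does no broadcasting.
    if len(rows) != len(cols):
        raise ValueError("rows and cols must have the same length")
    for plane, vplane in zip(L, values):
        for mat, vrow in zip(plane, vplane):
            for (r, c), v in zip(zip(rows, cols), vrow):
                mat[r][c] = v
    return L


if __name__ == "__main__":
    B, K, N = 2, 3, 5
    L = zeros4d(B, K, N)
    idx = list(range(N))
    diag = gather_last_two(L, idx, idx)  # shape [B][K][N]
    diag = [[[elu(v) + 1.0 for v in row] for row in plane] for plane in diag]
    index_put_last_two(L, idx, idx, diag)
    print(L[0][0])  # identity matrix: ones on the diagonal, zeros elsewhere
```

Because the index lists are zipped, a write such as L[..., [0, 2], [1, 0]] touches only positions (0, 1) and (2, 0), not a 2 x 2 block. A correct lowering of aten.index_put_ has to preserve this pairing.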

Actual behavior

[1] torch.export ...
=== Exported ops ===
... (as above)
[2] tvm.relax.frontend.torch.from_exported_program ...
!!!!!!! TVM FFI encountered a Segfault !!!!!!! 
  ...
  tvm::relax::Tuple::Tuple(tvm::ffi::Array<tvm::RelaxExpr, void>, tvm::Span) [clone .cold]
  ...
Segmentation fault (core dumped)
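
The trace above shows only the native frame. CPython's stdlib faulthandler module can also dump the Python stack on SIGSEGV, which would show which frontend converter was running when the crash hit. Below is a minimal sketch; enable_crash_tracebacks is a hypothetical helper. Call it after import tvm and before from_exported_program. The log suggests the TVM FFI installs its own segfault handler. As we understand CPython's implementation, faulthandler restores the previously installed handler after dumping and then re-raises the signal, so the TVM FFI message should still be printed.

```python
import faulthandler
import sys


def enable_crash_tracebacks() -> None:
    # Dump the Python traceback of every thread to stderr on SIGSEGV, SIGFPE,
    # SIGABRT, SIGBUS or SIGILL. Call this after `import tvm` so that it sits on
    # top of any handler TVM registered at import time (assumption).
    faulthandler.enable(file=sys.stderr, all_threads=True)


if __name__ == "__main__":
    enable_crash_tracebacks()
    print("faulthandler enabled:", faulthandler.is_enabled())
```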

Environment

  • OS: Ubuntu 22.04.4 LTS (x86_64)
  • TVM version: v0.21.0 (release)
  • Python: 3.10.16
  • LLVM: 17.0.6
  • PyTorch: 2.8.0
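
The following is a small sketch for filling in most of these fields from the running interpreter. collect_env is a hypothetical helper, and it assumes each package exposes __version__, which torch and tvm both do. Note that platform.platform() reports the kernel and libc rather than the distribution name. The LLVM version is left out because this sketch avoids relying on any TVM-specific API to obtain it.

```python
import importlib
import platform
from typing import Dict, Iterable


def collect_env(packages: Iterable[str] = ("torch", "tvm")) -> Dict[str, str]:
    # OS and Python come from the stdlib; package versions from __version__.
    info = {"OS": platform.platform(), "Python": platform.python_version()}
    for name in packages:
        try:
            module = importlib.import_module(name)
        except ImportError:
            info[name] = "not installed"
            continue
        info[name] = str(getattr(module, "__version__", "unknown"))
    return info


if __name__ == "__main__":
    for key, value in collect_env().items():
        print(f"  • {key}: {value}")
```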

Steps to reproduce

import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""  # avoid GPU warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.export import export as torch_export
from tvm.relax.frontend.torch import from_exported_program

class M4D(nn.Module):
    def forward(self, x):
        B, K, N = 2, 3, 5
        L = x.new_zeros(B, K, N, N)       # tensor construct only; no randomness
        idx = torch.arange(N, device=x.device)

        # key trigger: gather diagonal, apply smooth monotonic transform, scatter back
        diag = L[..., idx, idx]           # shape: [B, K, N]
        diag = F.elu(diag) + 1.0 + 1e-8   # avoid all-zero; any smooth transform works
        L[..., idx, idx] = diag           # advanced indexing write (two int index tensors)

        # key trigger: return a Python-level tuple (two tensors)
        return x[..., :1], L

if __name__ == "__main__":
    torch.manual_seed(0)
    m = M4D().eval()
    ex_in = torch.zeros(2, 3, 5)  # any input works; the model itself contains no RNG ops

    print("[1] torch.export ...")
    ep = torch_export(m, (ex_in,))

    # sanity: list exported ops
    try:
        print("=== Exported ops ===")
        for n in ep.graph.nodes:
            print(getattr(n, "op", None), getattr(n, "target", None))
    except Exception:
        pass

    print("[2] tvm.relax.frontend.torch.from_exported_program ...")
    mod = from_exported_program(ep)  # <-- segfaults inside FFI Tuple construction
    print("[OK] Converted without segfault (if you see this, env may differ)")
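
The hypothesis in the summary is that the crash needs both triggers. One way to test it is a bisection driver, sketched below. This is not part of the original report, and run_isolated, PRELUDE, MODEL and VARIANTS are hypothetical names. Each variant keeps only part of the trigger and is converted in a fresh interpreter via subprocess, so a native segfault kills only the child process. On POSIX, a negative return code means the child died from a signal (for example, -11 for SIGSEGV). The variants need the same torch/tvm installation as the repro, and no results for them are claimed here.

```python
import signal
import subprocess
import sys
import textwrap
from typing import Tuple

# Imports shared by every child process (assumes torch and tvm are installed).
PRELUDE = textwrap.dedent("""
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.export import export
    from tvm.relax.frontend.torch import from_exported_program
""")

# Model template: {write} toggles the advanced-indexing write, {ret} picks the outputs.
MODEL = textwrap.dedent("""
    class M(nn.Module):
        def forward(self, x):
            L = x.new_zeros(2, 3, 5, 5)
            idx = torch.arange(5, device=x.device)
            if {write}:
                L[..., idx, idx] = F.elu(L[..., idx, idx]) + 1.0
            return {ret}

    ep = export(M().eval(), (torch.zeros(2, 3, 5),))
    from_exported_program(ep)
    print("converted")
""")

# Hypothetical bisection: (label, do the index_put_ write?, return expression).
VARIANTS = [
    ("index_put_ + tuple output (original)", True, "x[..., :1], L"),
    ("index_put_ only, single tensor output", True, "L"),
    ("tuple output only, no index_put_", False, "x[..., :1], L"),
]


def run_isolated(code: str, timeout: float = 600.0) -> Tuple[int, str]:
    """Run code in a fresh interpreter; a native crash only kills the child."""
    proc = subprocess.run([sys.executable, "-c", code],
                          capture_output=True, text=True, timeout=timeout)
    rc = proc.returncode
    if rc == 0:
        return rc, "ok"
    if rc < 0:  # POSIX: child terminated by signal -rc (e.g. -11 is SIGSEGV)
        return rc, f"killed by {signal.Signals(-rc).name}"
    return rc, f"python error (exit {rc})"


if __name__ == "__main__":
    for label, write, ret in VARIANTS:
        rc, status = run_isolated(PRELUDE + MODEL.format(write=write, ret=ret))
        print(f"{label:40s} -> {status}")
```

Isolating each conversion in its own process keeps the driver alive when a variant segfaults, so all three results are reported in a single run.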

Triage

  • needs-triage
  • bug
