[BUG] TensorDict and TensorClass objects are not supported in torch.jit.fork #1474

@granthamtaylor

Describe the bug

TensorDict and tensorclass objects cannot be returned from functions launched with torch.jit.fork. This is unfortunate, because tensordicts are such a great way to handle multiple, heterogeneous blocks of data.

To Reproduce

import torch
from tensordict import tensorclass

@tensorclass
class MyClass:
    x: torch.Tensor
    y: torch.Tensor

def make_obj(x: torch.Tensor):
    return MyClass(x=x, y=x + 1)

def parallel_class():
    fut1 = torch.jit.fork(make_obj, torch.tensor(1))
    fut2 = torch.jit.fork(make_obj, torch.tensor(2))
    o1 = torch.jit.wait(fut1)
    o2 = torch.jit.wait(fut2)
    return o1, o2

print(parallel_class())  # ❌ Error

import torch
from tensordict import TensorDict

def make_td(x: torch.Tensor):
    # returns a simple TensorDict
    return TensorDict({"x": x, "y": x + 1}, batch_size=[])

def parallel():
    # Attempt to fork two TensorDict-producing tasks
    fut1 = torch.jit.fork(make_td, torch.tensor(1))
    fut2 = torch.jit.fork(make_td, torch.tensor(2))
    td1 = torch.jit.wait(fut1)
    td2 = torch.jit.wait(fut2)
    return td1, td2

print(parallel())   # ❌ Error

Expected behavior

Both of these functions should work as expected (as if the returned objects were of type dict[str, Tensor]), but instead they raise the following error:

Only tensors and (possibly nested) tuples of tensors, lists, or dicts are supported as inputs or outputs of traced functions, but instead got value of type TensorDict.

System info

  • Installed via UV
  • Python 3.12
  • tensordict==0.10.0

Additional context

Tested on Mac and Linux, with both CPU and CUDA.

Reason and Possible fixes

N/A

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)

Labels: bug