Labels: bug
Describe the bug
`TensorDict` and `@tensorclass` objects cannot be returned from functions launched with `torch.jit.fork`. This is unfortunate, because TensorDicts are a convenient way to handle multiple heterogeneous blocks of data.
To Reproduce
```python
import torch
from tensordict import tensorclass


@tensorclass
class MyClass:
    x: torch.Tensor
    y: torch.Tensor


def make_obj(x: torch.Tensor):
    return MyClass(x=x, y=x + 1)


def parallel_class():
    fut1 = torch.jit.fork(make_obj, torch.tensor(1))
    fut2 = torch.jit.fork(make_obj, torch.tensor(2))
    o1 = torch.jit.wait(fut1)
    o2 = torch.jit.wait(fut2)
    return o1, o2


print(parallel_class())  # ❌ Error
```

```python
import torch
from tensordict import TensorDict


def make_td(x: torch.Tensor):
    # returns a simple TensorDict
    return TensorDict({"x": x, "y": x + 1}, batch_size=[])


def parallel():
    # Attempt to fork two TensorDict-producing tasks
    fut1 = torch.jit.fork(make_td, torch.tensor(1))
    fut2 = torch.jit.fork(make_td, torch.tensor(2))
    td1 = torch.jit.wait(fut1)
    td2 = torch.jit.wait(fut2)
    return td1, td2


print(parallel())  # ❌ Error
```

Expected behavior
Both functions should behave as they would if the returned objects were plain `dict[str, Tensor]`. Instead, both raise the following error:
```
Only tensors and (possibly nested) tuples of tensors, lists, or dicts are supported as inputs or outputs of traced functions, but instead got value of type TensorDict.
```
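For comparison, here is a minimal sketch of the same fork/wait pattern in which the forked function returns a plain `dict[str, Tensor]` instead of a `TensorDict`. The names `make_dict` and `parallel_dict` are hypothetical, introduced only for this illustration. Since the error message lists dicts among the supported output types, this version is expected to run in eager mode, which isolates the failure to the `TensorDict`/`tensorclass` return type.

```python
import torch


def make_dict(x: torch.Tensor) -> dict[str, torch.Tensor]:
    # Hypothetical helper: same payload as make_td, but as a plain dict[str, Tensor]
    return {"x": x, "y": x + 1}


def parallel_dict():
    # Hypothetical helper: launch both tasks; each call returns a torch.jit.Future
    fut1 = torch.jit.fork(make_dict, torch.tensor(1))
    fut2 = torch.jit.fork(make_dict, torch.tensor(2))
    # Block until each result is available
    d1 = torch.jit.wait(fut1)
    d2 = torch.jit.wait(fut2)
    return d1, d2


d1, d2 = parallel_dict()
print(d1, d2)
```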
System info
- Installed via uv
- Python 3.12
- tensordict==0.10.0
Additional context
Tested on macOS and Linux, on both CPU and CUDA.
Reason and Possible fixes
N/A
Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)