🐛 Describe the bug
First, here are the imports and the test function:
```python
import torch
import torch.nn as nn
from tensordict import NonTensorData, TensorDict
from tensordict.nn import dispatch


def test(cls: type[nn.Module]) -> None:
    instance = cls()
    tensordict = TensorDict(
        {"a": torch.zeros(3, 5), "b": torch.ones(3, 2)}, batch_size=(3,)
    )
    print("Mode 1")
    print("legacy\n", instance(tensordict["a"], tensordict["b"]))
    print("up to date\n", instance(tensordict))
    print("Mode 2")
    instance.mode = "mode2"
    print("legacy\n", instance(tensordict["a"], tensordict["b"]))
    print("up to date\n", instance(tensordict))
```

I'm working with multimodal models but want them to stay backward compatible with old scripts. To do so, I wanted to use `dispatch` from `tensordict.nn`, which lets a module that does its computations on a `TensorDict` also accept its inputs as legacy positional arguments. I wanted it to map those legacy inputs to different tensordict keys depending on the modality. In code, the logic I wanted to implement was along the lines of the following:
```python
class MyModule1(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.mode = "mode1"
        self.key_map = {
            "mode1": ["a", "z"],
            "mode2": ["a", "b"],
            # [...]
        }

    @dispatch(source=self.key_map[self.mode], dest=["c", "d"])
    def forward(self, tensordict):
        tensordict["c"] = tensordict["a"] - 1
        print("b in dict : ", "b" in tensordict.keys())
        if "b" in tensordict.keys():
            tensordict["d"] = tensordict["b"] * 2
        elif "z" in tensordict.keys():
            tensordict["d"] = tensordict["z"] ** 2
        else:
            tensordict["d"] = NonTensorData(None)
        return tensordict
```

This obviously does not work, because `self` cannot be accessed in the class body outside of methods (classic Python).
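For contrast, the static pattern that `dispatch` is documented for, with the keys fixed at class-definition time, needs no `self` at all. A minimal sketch of my understanding of that usage:

```python
import torch
import torch.nn as nn
from tensordict.nn import dispatch


class StaticModule(nn.Module):
    # The keys are literals here, so nothing instance-dependent is needed
    # at decoration time.
    @dispatch(source=["a", "b"], dest=["c"])
    def forward(self, tensordict):
        tensordict["c"] = tensordict["a"] + tensordict["b"]
        return tensordict


module = StaticModule()
# Legacy positional call: dispatch packs the tensors under the "source"
# keys and, as I understand it, returns the value at the "dest" key.
out = module(torch.zeros(3), torch.ones(3))
```

My problem is precisely that the keys must be dynamic.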
So I moved the computation to another method (inner_forward) and thought that forward could then apply dispatch on the fly, as follows:
```python
class MyModule2(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.mode = "mode1"
        self.key_map = {
            "mode1": ["a", "z"],
            "mode2": ["a", "b"],
            # [...]
        }

    def inner_forward(self, tensordict):
        tensordict["c"] = tensordict["a"] - 1
        print("b in dict : ", "b" in tensordict.keys())
        if "b" in tensordict.keys():
            tensordict["d"] = tensordict["b"] * 2
        elif "z" in tensordict.keys():
            tensordict["d"] = tensordict["z"] ** 2
        else:
            tensordict["d"] = NonTensorData(None)
        return tensordict

    def forward(self, *args, **kwargs):
        return dispatch(source=self.key_map[self.mode], dest=["c", "d"])(
            self.inner_forward
        )(*args, **kwargs)
```

This one raised `TypeError: MyModule2.inner_forward() takes 2 positional arguments but 3 were given`.
So I iterated, and did the following to understand:
```python
class MyModule2_2(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.mode = "mode1"
        self.key_map = {
            "mode1": ["a", "z"],
            "mode2": ["a", "b"],
            # [...]
        }

    def inner_forward(self, whatever, tensordict):
        print("self", self)
        print("whatever", whatever)
        print("tensordict", tensordict)
        tensordict["c"] = tensordict["a"] - 1
        print("b in dict : ", "b" in tensordict.keys())
        if "b" in tensordict.keys():
            tensordict["d"] = tensordict["b"] * 2
        elif "z" in tensordict.keys():
            tensordict["d"] = tensordict["z"] ** 2
        else:
            tensordict["d"] = NonTensorData(None)
        return tensordict

    def forward(self, *args, **kwargs):
        return dispatch(source=self.key_map[self.mode], dest=["c", "d"])(
            self.inner_forward
        )(*args, **kwargs)
```

The test then showed me that, with legacy arguments, the first positional argument was not aggregated into the tensordict:
```
Mode 1
self MyModule2_2()
whatever tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])
tensordict TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3, 2]),
    device=None,
    is_shared=False)
b in dict : False
```
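My reading of this (a guess; I have not checked dispatch's source) is that `dispatch` assumes it is decorating a function defined in a class body, whose first parameter is `self`, so it reserves the first positional slot for the instance and only maps the remaining arguments to the `source` keys. Consistent with that, and with what MyModule3 below ends up exploiting, wrapping a plain function whose first parameter is the tensordict shows no phantom slot:

```python
import torch
from tensordict.nn import dispatch


# Hypothetical standalone probe: a plain function whose first parameter is
# the tensordict itself, so (if the guess above is right) nothing is skipped.
@dispatch(source=["a", "b"], dest=["c"])
def plain_fn(tensordict):
    tensordict["c"] = tensordict["a"] + tensordict["b"]
    return tensordict


# Both positional tensors should be packed under "a" and "b".
print(plain_fn(torch.zeros(3), torch.ones(3)))
```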
So I went along with this behavior and designed the following:
```python
class MyModule2_3(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.mode = "mode1"
        self.key_map = {
            "mode1": ["a", "z"],
            "mode2": ["a", "b"],
            # [...]
        }

    def inner_forward(self, whatever, tensordict):
        print("self", self)
        print("whatever", whatever)
        print("tensordict", tensordict)
        tensordict["c"] = tensordict["a"] - 1
        print("b in dict : ", "b" in tensordict.keys())
        if "b" in tensordict.keys():
            tensordict["d"] = tensordict["b"] * 2
        elif "z" in tensordict.keys():
            tensordict["d"] = tensordict["z"] ** 2
        else:
            tensordict["d"] = NonTensorData(None)
        return tensordict

    def forward(self, *args, **kwargs):
        return dispatch(source=self.key_map[self.mode], dest=["c", "d"])(
            self.inner_forward
        )("filler", *args, **kwargs)
```

This works as planned. However, I don't know how stable it is, nor where the additional argument slot comes from.
Finally, I found another solution that is less elegant but leverages dispatch without the issue:
```python
class MyModule3(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.mode = "mode1"
        self.key_map = {
            "mode1": ["a", "z"],
            "mode2": ["a", "b"],
            # [...]
        }

        def inner_forward(tensordict):
            tensordict["c"] = tensordict["a"] - 1
            print("b in dict : ", "b" in tensordict.keys())
            if "b" in tensordict.keys():
                tensordict["d"] = tensordict["b"] * 2
            elif "z" in tensordict.keys():
                tensordict["d"] = tensordict["z"] ** 2
            else:
                tensordict["d"] = NonTensorData(None)
            return tensordict

        self.inner_forward = inner_forward

    def forward(self, *args, **kwargs):
        return dispatch(source=self.key_map[self.mode], dest=["c", "d"])(
            self.inner_forward
        )(*args, **kwargs)
```

Here, inner_forward is defined as a plain function inside `__init__` and registered as an attribute rather than a method. This works as intended.
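One possible refinement of this pattern (my own untested sketch): since the wrapper only depends on `self.mode`, it can be built once per mode and cached instead of being re-created on every call. `_dispatched` is a hypothetical cache attribute, initialized as an empty dict in `__init__`:

```python
def forward(self, *args, **kwargs):
    # Build the dispatch-wrapped function lazily, once per mode.
    if self.mode not in self._dispatched:
        self._dispatched[self.mode] = dispatch(
            source=self.key_map[self.mode], dest=["c", "d"]
        )(self.inner_forward)
    return self._dispatched[self.mode](*args, **kwargs)
```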
I didn't dive further into the code of dispatch to understand where this behavior comes from. Also, on a side note, changing dispatch's type hints so they match its documentation would be nice (currently, source and dest are typed as str, while list[str] is also accepted).
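For reference, the kind of annotation that side note is asking for might look like the following (an illustrative stub only, not dispatch's actual full signature):

```python
from collections.abc import Sequence
from typing import Callable

# Illustrative stub: the union is the point -- both a single key and a
# sequence of keys should type-check, matching the documented behavior.
def dispatch(
    source: str | Sequence[str] = ...,
    dest: str | Sequence[str] = ...,
) -> Callable[[Callable], Callable]:
    ...
```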
Versions
```
Versions of relevant libraries:
[pip3] numpy==2.0.2
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] torch==2.5.1
[pip3] torchvision==0.20.1
[pip3] triton==3.1.0
[conda] Could not collect
```
To this I would add: tensordict==0.10.0.
Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)