
[BUG] tensordict.nn.dispatch does not have expected behavior #1459

@RobinSobczyk

Description

🐛 Describe the bug

First, here are the imports and the test function:

import torch
import torch.nn as nn
from tensordict import NonTensorData, TensorDict
from tensordict.nn import dispatch


def test(cls: type[nn.Module]) -> None:
    instance = cls()
    tensordict = TensorDict(
        {"a": torch.zeros(3, 5), "b": torch.ones(3, 2)}, batch_size=(3,)
    )

    print("Mode 1")
    print("legacy\n", instance(tensordict["a"], tensordict["b"]))
    print("up to date\n", instance(tensordict))

    print("Mode 2")
    instance.mode = "mode2"
    print("legacy\n", instance(tensordict["a"], tensordict["b"]))
    print("up to date\n", instance(tensordict))

I'm working with multimodal models and want them to stay backward compatible with old scripts. To do so, I wanted to use dispatch from tensordict.nn, which lets a module that does its computations on a TensorDict also accept its inputs as positional arguments. I wanted it to map the legacy inputs to different tensordict keys depending on the modality. In code, the logic I wanted to implement was along the lines of the following:

class MyModule1(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.mode = "mode1"
        self.key_map = {
            "mode1": ["a", "z"],
            "mode2": ["a", "b"],
            # [...]
        }

    @dispatch(source=self.key_map[self.mode], dest=["c", "d"])
    def forward(self, tensordict):
        tensordict["c"] = tensordict["a"] - 1
        print("b in dict : ", "b" in tensordict.keys())
        if "b" in tensordict.keys():
            tensordict["d"] = tensordict["b"] * 2
        elif "z" in tensordict.keys():
            tensordict["d"] = tensordict["z"] ** 2
        else:
            tensordict["d"] = NonTensorData(None)
        return tensordict

which obviously does not work, because self cannot be accessed in the class body where the decorator arguments are evaluated (classic Python).
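To make the failure mode concrete, here is a minimal stdlib-only sketch (record_args is a hypothetical stand-in for dispatch): decorator arguments in a class body are evaluated while the class statement runs, before any instance exists, so self is simply not in scope.

```python
def record_args(source=None, dest=None):
    def wrap(fn):
        fn.source = source  # stash the decorator argument on the function
        return fn
    return wrap

try:
    class Broken:
        # `self` does not exist yet: decorator arguments are evaluated
        # while the class body runs, before any instance is created
        @record_args(source=self.key_map["mode1"], dest=["c", "d"])
        def forward(self, tensordict):
            return tensordict
except NameError as err:
    print("raised at class definition time:", err)
```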

So I moved the body of forward to another method (inner_forward), thinking that forward would then make use of dispatch, as follows:

class MyModule2(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.mode = "mode1"
        self.key_map = {
            "mode1": ["a", "z"],
            "mode2": ["a", "b"],
            # [...]
        }

    def inner_forward(self, tensordict):
        tensordict["c"] = tensordict["a"] - 1
        print("b in dict : ", "b" in tensordict.keys())
        if "b" in tensordict.keys():
            tensordict["d"] = tensordict["b"] * 2
        elif "z" in tensordict.keys():
            tensordict["d"] = tensordict["z"] ** 2
        else:
            tensordict["d"] = NonTensorData(None)
        return tensordict

    def forward(self, *args, **kwargs):
        return dispatch(source=self.key_map[self.mode], dest=["c", "d"])(
            self.inner_forward
        )(*args, **kwargs)

This one raised TypeError: MyModule2.inner_forward() takes 2 positional arguments but 3 were given

So I iterated and wrote the following to understand what was happening:

class MyModule2_2(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.mode = "mode1"
        self.key_map = {
            "mode1": ["a", "z"],
            "mode2": ["a", "b"],
            # [...]
        }

    def inner_forward(self, whatever, tensordict):
        print("self", self)
        print("whatever", whatever)
        print("tensordict", tensordict)
        tensordict["c"] = tensordict["a"] - 1
        print("b in dict : ", "b" in tensordict.keys())
        if "b" in tensordict.keys():
            tensordict["d"] = tensordict["b"] * 2
        elif "z" in tensordict.keys():
            tensordict["d"] = tensordict["z"] ** 2
        else:
            tensordict["d"] = NonTensorData(None)
        return tensordict

    def forward(self, *args, **kwargs):
        return dispatch(source=self.key_map[self.mode], dest=["c", "d"])(
            self.inner_forward
        )(*args, **kwargs)

Running the test showed that, in the legacy call, the first positional argument was not aggregated into the tensordict:

Mode 1
self MyModule2_2()
whatever tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])
tensordict TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3, 2]),
    device=None,
    is_shared=False)
b in dict :  False
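The output is consistent with dispatch reserving a "self slot": since it is designed to decorate an unbound forward in a class body, it seems to pass the first positional argument through untouched and only aggregate the rest. The following toy reimplementation (my guess at the behavior, not the real tensordict code; dest is ignored here) reproduces the observation that the a tensor lands in the self slot and the b tensor ends up under the key "a":

```python
def toy_dispatch(source, dest):
    def deco(fn):
        def wrapper(first, *args, **kwargs):
            # the first positional argument is forwarded untouched, as if
            # it were `self`; only the remaining ones are zipped with the
            # source keys (dest handling omitted in this toy)
            td = dict(zip(source, args))
            return fn(first, td)
        return wrapper
    return deco

def inner(self_slot, td):
    return self_slot, td

wrapped = toy_dispatch(source=["a", "b"], dest=["c", "d"])(inner)
# calling with two "tensors": the first is swallowed by the self slot,
# so only the second is aggregated, under the first source key
slot, td = wrapped("tensor_a", "tensor_b")
print(slot)  # -> 'tensor_a'
print(td)    # -> {'a': 'tensor_b'}
```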

So I went with it and designed this:

class MyModule2_3(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.mode = "mode1"
        self.key_map = {
            "mode1": ["a", "z"],
            "mode2": ["a", "b"],
            # [...]
        }

    def inner_forward(self, whatever, tensordict):
        print("self", self)
        print("whatever", whatever)
        print("tensordict", tensordict)
        tensordict["c"] = tensordict["a"] - 1
        print("b in dict : ", "b" in tensordict.keys())
        if "b" in tensordict.keys():
            tensordict["d"] = tensordict["b"] * 2
        elif "z" in tensordict.keys():
            tensordict["d"] = tensordict["z"] ** 2
        else:
            tensordict["d"] = NonTensorData(None)
        return tensordict

    def forward(self, *args, **kwargs):
        return dispatch(source=self.key_map[self.mode], dest=["c", "d"])(
            self.inner_forward
        )("filler", *args, **kwargs)

which works as planned. However, I don't know how stable this is, nor where the extra positional argument comes from.

Finally, I found another solution that is less elegant but leverages dispatch without bugs:

class MyModule3(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.mode = "mode1"
        self.key_map = {
            "mode1": ["a", "z"],
            "mode2": ["a", "b"],
            # [...]
        }

        def inner_forward(tensordict):
            tensordict["c"] = tensordict["a"] - 1
            print("b in dict : ", "b" in tensordict.keys())
            if "b" in tensordict.keys():
                tensordict["d"] = tensordict["b"] * 2
            elif "z" in tensordict.keys():
                tensordict["d"] = tensordict["z"] ** 2
            else:
                tensordict["d"] = NonTensorData(None)
            return tensordict

        self.inner_forward = inner_forward

    def forward(self, *args, **kwargs):
        return dispatch(source=self.key_map[self.mode], dest=["c", "d"])(
            self.inner_forward
        )(*args, **kwargs)

Here, inner_forward is defined as a function inside __init__ and registered as an attribute rather than as a method. This works as intended.
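A plausible (unverified) reason this works: even on a bound method, the underlying code object still lists self as its first variable, so any signature inspection based on __code__ would be fooled into skipping one argument, while a plain function stored as an attribute has no self parameter at all. This stdlib-only sketch shows the difference:

```python
class M:
    def __init__(self):
        def inner_forward(tensordict):
            return tensordict
        # plain function stored as an attribute: lookup does not create
        # a bound method, and no `self` appears anywhere in its code
        self.inner_forward = inner_forward

    def method_forward(self, tensordict):
        return tensordict

m = M()
# a bound method proxies __code__ to the underlying function, whose
# first variable is still `self`
print(m.method_forward.__code__.co_varnames[0])  # -> 'self'
print(m.inner_forward.__code__.co_varnames[0])   # -> 'tensordict'
```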

I didn't dive further into the code of dispatch to understand where this comes from. On a side note, changing the typing of dispatch so that it matches its documentation would also be nice: currently, source and dest are typed as str, while list[str] is also accepted.
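For the typing side note, the suggested annotation could look like the following stub (a hypothetical signature sketch, not the actual tensordict source; dispatch_stub and NestedKey are illustrative names):

```python
from typing import List, Union

NestedKey = str  # simplified; real tensordict keys can also be tuples

def dispatch_stub(
    source: Union[str, List[NestedKey]],
    dest: Union[str, List[NestedKey]],
):
    # a real implementation would normalize a bare str into a
    # one-element list before doing anything else
    if isinstance(source, str):
        source = [source]
    if isinstance(dest, str):
        dest = [dest]
    return source, dest

print(dispatch_stub("a", ["c", "d"]))  # -> (['a'], ['c', 'd'])
```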

Versions

Versions of relevant libraries:
[pip3] numpy==2.0.2
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] torch==2.5.1
[pip3] torchvision==0.20.1
[pip3] triton==3.1.0
[conda] Could not collect

to which I would add: tensordict==0.10.0

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
