
[Question] sparse batching for jagged inputs? #1247

davidegraff opened this issue Mar 4, 2025 · 6 comments


@davidegraff

I'm currently trying to use tensordict for GNN models.

Background
Each sample in my dataset is variably sized with respect to the number of nodes. Briefly, a Graph data structure looks like this:

from jaxtyping import Float, Int
from tensordict import NonTensorData
from torch import Tensor

class Graph:
    node_feats: Float[Tensor, "n_nodes d_v"]
    edge_feats: Float[Tensor, "n_edges d_e"]
    edge_index: NonTensorData[Int[Tensor, "2 n_edges"]]

I batch graphs sparsely. That is, because the n_nodes/n_edges can vary so widely, I compress them into a single graph with multiple components:

class BatchedGraph(Graph):
    """A :class:`BatchedGraph` represents a batch of individual :class:`Graph`s."""

    batch_index: Int[Tensor, "n_nodes"]
    """A tensor of shape ``n_nodes`` containing the index of the parent :class:`Graph` of each node
    in the batched graph."""
    num_graphs: int | None
    """The number of independent graphs (i.e., components) in the :class:`BatchedGraph`."""

What I would like to do
I'm currently trying to transition these objects to something TensorDict-like. That is, I'd like to be able to nest a Graph in a TensorDict and retrieve a given attribute through a key-like interface rather than an attribute-like one:

td = TensorDict(G=Graph(...))
td['G', 'node_feats'] # like this!
td['G'].node_feats # instead of this...

Unfortunately, the current mechanics of the batch_size parameter in a TensorDict make this impossible: if my sparsely batched graph is composed of 10 graphs with 10 nodes each, node_feats will have shape 100 x d, but the corresponding batch_size should be 10. When I try to do this, the package raises an exception:

>>> n_nodes, d = 100, 16
>>> batch_size = 10
>>> td = TensorDict(V=torch.randn(n_nodes, d), batch_index=torch.randint(0, batch_size, (n_nodes,)), batch_size=[10])
RuntimeError: batch dimension mismatch, got self.batch_size=torch.Size([10]) and value.shape=torch.Size([100, 16]).
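
(For what it's worth, the only batch size the constructor accepts here without complaint is an empty one, which records nothing:)

>>> td = TensorDict(V=torch.randn(n_nodes, d), batch_index=torch.randint(0, batch_size, (n_nodes,)), batch_size=[])
>>> td.batch_size
torch.Size([])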

It's clear why I'm having this problem, so my question is whether the TensorDict object model and this use-case are compatible. If not, do you have other suggestions? Moving over to PyG won't help me, as it would leave me with the same attribute-like interface I currently use. Thanks for any input!

@vmoens
Collaborator

vmoens commented Mar 4, 2025

I don't really get what the problem is, so let me ask a couple of clarifying questions:

1. Can you elaborate on why you can't reshape V and batch_index as (batch_size, n_nodes // batch_size)?
2. What you want out of the batch size is the capacity to manipulate the shape (reshape, index, etc.) - is that right?

@davidegraff
Author

Yeah, sorry it's not clear. It's a bit of a different data model for batching compared to sequences.

Can you elaborate on why you can't reshape V and batch_index as (batch_size, n_nodes//batch_size)?

Each graph in the batch might be a different size (number of nodes). Rather than padding up to the largest graph and reshaping into (b, n_nodes_max, d), where n_nodes_max = max(graph.n_nodes for graph in graphs), I instead concatenate along the node dimension into a single tensor of shape (n_nodes_total, d), where n_nodes_total = sum(graph.n_nodes for graph in graphs). Logically, this corresponds to a single graph with b disconnected components instead of a batch of b graphs.

Ex: consider two line graphs: 1-2-3 and 1-2.

These two graphs have the following adjacency matrices:

$$ A_1 = \begin{bmatrix} 0 & 1 & 0\\ 1 & 0 & 1\\ 0 & 1 & 0\\ \end{bmatrix} $$ $$ A_2 = \begin{bmatrix} 0 & 1\\ 1 & 0\\ \end{bmatrix} $$

Rather than padding the second graph to have 3 vertices (and its adjacency matrix to have shape $(3, 3)$), I instead concatenate them into a single graph 1-2-3 4-5 with the adjacency matrix

$$ A_{12} = \begin{bmatrix} 0 & 1 & 0 & 0 & 0\\ 1 & 0 & 1 & 0 & 0\\ 0 & 1 & 0 & 0 & 0\\ 0 & 0 & 0 & 0 & 1\\ 0 & 0 & 0 & 1 & 0\\ \end{bmatrix} $$

Note that I've changed the node labels of the second graph so the adjacency matrix of the sparsely batched graph is correct. So logically this graph contains two inputs, but concretely the code just sees a single graph with node_feats of shape (5, d_v) and edge_feats of shape (6, d_e). Absent padding (which would require a whole rework of my data model), there's no way to reshape these tensors into (2, *, d).
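
In edge_index (COO) form rather than dense adjacency matrices, the same relabeling looks like this (worked out by hand for the example above, with both edge directions stored):

import torch

# 1-2-3 as nodes 0, 1, 2
edge_index_1 = torch.tensor([[0, 1, 1, 2],
                             [1, 0, 2, 1]])
# 1-2 as nodes 0, 1
edge_index_2 = torch.tensor([[0, 1],
                             [1, 0]])

# shift the second graph's node labels by the first graph's 3 nodes
edge_index_12 = torch.cat([edge_index_1, edge_index_2 + 3], dim=1)
# tensor([[0, 1, 1, 2, 3, 4],
#         [1, 0, 2, 1, 4, 3]])

batch_index = torch.tensor([0, 0, 0, 1, 1])  # node -> parent graph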

What you want out of the batch-size is the capacity to manipulate the shape (reshape, index etc) - is that right?

I really would like to be able to keep track of the batch size somehow. An added detail: the tensordict will also contain some tensor(s) of target labels of shape (b, num_tasks), but I don't know the keys of these tensors. I would really like to get the batch size via len(tensordict) or tensordict.batch_size[0] without having to figure out the target key. Sorry again for such an open-ended question, but any input would be appreciated. Thanks!

@vmoens
Collaborator

vmoens commented Mar 5, 2025

Is a lazy stack an option? Or a nested tensor?

import torch
from tensordict import TensorDict, lazy_stack

A1 = torch.eye(3)
A2 = torch.eye(2)
td1 = TensorDict(g=A1)
td2 = TensorDict(g=A2)
td = lazy_stack([td1, td2])
# nested tensor version
td = td.densify(layout=torch.jagged)

Or do you really want some map of batch-size elements to slices?
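
To sketch what the lazy stack buys you (assuming a recent tensordict version):

import torch
from tensordict import TensorDict, lazy_stack

td = lazy_stack([TensorDict(g=torch.eye(3)), TensorDict(g=torch.eye(2))])
td.batch_size      # torch.Size([2]) -- len(td) == 2, no padding involved
td[0]["g"].shape   # torch.Size([3, 3]) -- per-graph indexing still works

# nested-tensor version: one jagged tensor per key
ntd = td.densify(layout=torch.jagged)
ntd["g"].is_nested  # True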

@davidegraff
Author

I think it's much simpler than that, actually. I'm really just searching for a way to keep track of the batch size of the tensordict while also using a sparsely batched input. Abstractly, I'm searching for a way to override the batch_size or __len__ method of the BatchedGraph tensorclass. To reiterate the above: the resulting BatchedGraph has a few 2D tensor attributes, but the size of each of these dimensions is not the batch size. I know the batch size of a BatchedGraph when creating one, but tensordict is unable to infer this from any of the data, so I'm wondering if there's some way to manually specify it. Does that make sense?

@vmoens
Collaborator

vmoens commented Mar 6, 2025

Are you happy with losing the capacity for indexing and such?

What do you want to get from the batch_size attribute (besides tracking metadata)?
From what I understand, you're just bothered that you can't set batch_size to be whatever value you want, but you don't want the batch size to change when you stack, index, cat, etc.

You could store your batch size as NonTensorData.
You could also use tensorclass and pass shadow=True, which allows you to have fields with usually-protected names such as batch_size.
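
Something like this (untested sketch):

import torch
from tensordict import TensorDict, NonTensorData

td = TensorDict(
    node_feats=torch.randn(100, 16),
    batch_index=torch.randint(0, 10, (100,)),
    num_graphs=NonTensorData(10),  # metadata, exempt from batch-size checks
)
td["num_graphs"]  # 10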

@davidegraff
Author

Thanks for the reply, and sorry for my delayed response.

Are you happy with losing the capacity of indexing and such?

What do you want to get from the batch size attribute (besides tracking metadata?) From what I understand you're just bothered that you can't set batch_size to be whatever value you want, but you don't want the batch size to change when you stack, index, cat etc.

At some level, yes haha. Again, the challenge is that my data model (and GNNs in general) relies on a sparsely batched input, so the typical notion of having a "batch index" as the 0th dimension of your tensor doesn't really apply here. Rather, all the inputs are torch.cat-ed along the 0th dimension as opposed to being padded and then stacked. So there's a separate batch_index tensor which tracks the batch index of each item in the batch (i.e., for a given node, which graph it belongs to). All that said, there's really no simple way for tensordict to handle this use case naturally (nor am I asking you to!) when it comes to stacking, indexing, concatenation, etc.
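
For context, batch_index is what drives the readout step downstream, e.g. a per-graph sum pooling (the standard GNN pattern, just for illustration):

import torch

node_feats = torch.randn(5, 8)               # 5 nodes across 2 graphs
batch_index = torch.tensor([0, 0, 0, 1, 1])  # node -> parent graph
num_graphs = 2

# sum each graph's node embeddings -> (num_graphs, d)
graph_feats = torch.zeros(num_graphs, 8).index_add_(0, batch_index, node_feats)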

You could store your batch size as a non tensor data. You could also use tensorclass and put "shadow=True" which allows you to have fields named batch_size and such names which are usually protected

I think this is maybe my best course of action. To sum up, what I really want is just to use the dict parts of tensordict to store my data and to use it with TensorDictModules. I don't really need any of the package's other utilities (for this specific use-case); the sticking point is that I need a way to manually set or override the batch_size attribute of a TensorDict. Setting shadow=True solves my immediate needs, but I was wondering if there could be future support for manually implementing the batch_size field of a tensorclass, rather than just adding an ordinary field that happens to share the name:

import torch
from torch import Tensor
from tensordict import tensorclass

@tensorclass(shadow=True, nocast=True)
class Graph:
    X: Tensor
    batch_size: int

G = Graph(torch.rand(10, 4), batch_size=4)
print(len(G))  # prints 0, but should be 4!
