[Question] sparse batching for jagged inputs? #1247
I don't really get what the problem is, so let me ask a couple of clarifying questions: can you elaborate on why you can't reshape?
Yeah, sorry, it's not clear. It's a bit of a different data model for batching compared to sequences.

Each graph in the batch might be a different size (number of nodes). Rather than padding every graph up to the maximum size and reshaping into a dense batch, the graphs are concatenated into one large graph. Ex: consider two line graphs, one with 3 vertices and one with 2 (see the sketch below for their adjacency matrices). Rather than padding the second graph to have 3 vertices (and its adjacency matrix to have shape `3 x 3`), the two adjacency matrices are placed on a block diagonal. Note that I've changed the node labels of the second graph so the adjacency matrix of the sparsely batched graph is correct. So logically this graph contains two inputs, but concretely the code just sees a single graph with 5 nodes.
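A minimal sketch of that example, assuming the 3-node and 2-node line graphs above (`torch.block_diag` stands in for whatever concatenation the real code uses):

```python
import torch

# Adjacency matrices of the two line graphs: 0-1-2 and 0-1.
A1 = torch.tensor([[0, 1, 0],
                   [1, 0, 1],
                   [0, 1, 0]])
A2 = torch.tensor([[0, 1],
                   [1, 0]])

# Sparse batching places them on a block diagonal: the second graph's
# nodes are relabelled 3 and 4, giving a single 5-node graph.
A = torch.block_diag(A1, A2)  # shape (5, 5)
```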
I really would like to be able to keep track of the batch size somehow. As an added detail, the tensordict will also contain some tensor(s) of target labels whose leading dimension is the number of graphs (the batch size).
Is a lazy stack an option? Or a nested tensor?

```python
import torch
from tensordict import TensorDict, lazy_stack

A1 = torch.eye(3)
A2 = torch.eye(2)
td1 = TensorDict(g=A1)
td2 = TensorDict(g=A2)
td = lazy_stack([td1, td2])

# nested tensor version
td = td.densify(layout=torch.jagged)
```

Or do you really want some map of batch-size elements to slices?
I think it's much simpler than that, actually. I'm really just searching for a way to keep track of the batch size of the tensordict while also using a sparsely batched input. Abstractly, I'm searching for a way to override the `batch_size` behavior.
Are you happy with losing the capacity for indexing and such? What do you want to get from the batch-size attribute (besides tracking metadata)? You could store your batch size as non-tensor data.
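For example (a minimal sketch, assuming a recent tensordict release with the non-tensor-data API):

```python
import torch
from tensordict import TensorDict

# Store the logical batch size (number of graphs) as metadata rather
# than as a batch dimension.
td = TensorDict({"node_features": torch.rand(100, 8)}, batch_size=[])
td.set_non_tensor("n_graphs", 10)      # kept as a plain Python int
print(td.get_non_tensor("n_graphs"))   # 10
```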
Thanks for the reply, and sorry for my delayed response.

At some level, yes haha. Again, the challenge is that my data model (and GNNs in general) relies on a sparsely batched input, so the typical notion of having a "batch index" as the 0th dimension of your tensor doesn't really apply here. Rather, all the inputs are concatenated along the node dimension into a single graph.
I think this is maybe my best course of action. Really, to sum up, what I think I want is literally just to use `@tensorclass(shadow=True, nocast=True)`:

```python
import torch
from torch import Tensor
from tensordict import tensorclass

@tensorclass(shadow=True, nocast=True)
class Graph:
    X: Tensor
    batch_size: Tensor

G = Graph(torch.rand(10, 4), batch_size=4)
print(len(G))  # should be 4!
```
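(As I understand it, `shadow=True` lets the `batch_size` field shadow the reserved `TensorDict` attribute of the same name, and `nocast=True` keeps the plain `4` from being cast to a tensor; both are assumptions about recent tensordict versions.)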
I'm currently trying to use `tensordict` for GNN models.

Background

Each sample in my dataset is variably sized with respect to the number of nodes. Briefly, a `Graph` data structure looks like this:
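(A hypothetical sketch with assumed field names:)

```python
from dataclasses import dataclass
import torch

@dataclass
class Graph:
    node_features: torch.Tensor  # (n_nodes, d)
    edge_index: torch.Tensor     # (2, n_edges), source/target node ids
    n_nodes: int
    n_edges: int
```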
I batch graphs sparsely. That is, because `n_nodes`/`n_edges` can vary so widely, I compress them into a single graph with multiple components (a sketch follows):
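(A minimal sketch of such a sparse collate, reusing the hypothetical `Graph` fields above:)

```python
import torch

def collate(graphs: list[Graph]) -> Graph:
    # Concatenate all node features into one (total_nodes, d) tensor.
    node_features = torch.cat([g.node_features for g in graphs], dim=0)
    # Offset each graph's edge indices by the number of preceding nodes
    # so the edges point into the concatenated node tensor.
    offsets = torch.tensor([0] + [g.n_nodes for g in graphs[:-1]]).cumsum(0)
    edge_index = torch.cat(
        [g.edge_index + off for g, off in zip(graphs, offsets)], dim=1
    )
    return Graph(node_features, edge_index,
                 n_nodes=node_features.shape[0],
                 n_edges=edge_index.shape[1])
```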
What I would like to do

I'm currently trying to transition these objects to something `TensorDict`-like. That is, I'd like to be able to batch my graphs and access a given attribute through a key-like interface rather than an attribute-like one:
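(A hypothetical example of the desired access pattern:)

```python
import torch
from tensordict import TensorDict

# A batched graph exposed as a TensorDict (field names are assumptions).
g = TensorDict({"node_features": torch.rand(5, 8)}, batch_size=[])
x = g["node_features"]   # key-like access, instead of g.node_features
```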
Unfortunately, the current mechanics of the `batch_size` parameter in a `TensorDict` make this impossible. This is because if my sparsely batched graph is composed of 10 graphs with 10 nodes each, `node_features` will have a shape of `100 x d`, but the corresponding `batch_size` should be `10`. When I try to do this, the package raises an exception:
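(The original traceback isn't shown; a minimal repro sketch, with `d = 8` as an arbitrary choice:)

```python
import torch
from tensordict import TensorDict

node_features = torch.rand(100, 8)  # 10 graphs x 10 nodes each, d = 8
# Raises a RuntimeError: the tensor's leading dimension (100) must
# match the declared batch_size (10).
td = TensorDict({"node_features": node_features}, batch_size=[10])
```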
It's clear why I'm having this problem, so my question is whether the `TensorDict` object model and this use case are compatible. If not, do you have other suggestions? Moving over to PyG won't help me, as it would still leave me with the same attribute-like interface that I currently use. Thanks for any input!