Description
📚 Documentation
The issue is in the examples and therefore also in the tutorials.
I believe the current implementation of ChildSumTreeLSTMCell in examples/pytorch/tree_lstm/tree_lstm.py does not conform to the Child-Sum Tree LSTM described in the original paper or the documentation. Specifically, the forget gate lacks the weight matrix applied to the node input x (W_f in the proposed code below).
Evidence:
The forget-gate equation specifies a weight matrix that multiplies the node input x (W_f in the proposed code below).
The current code does not include a linear layer to compute this input term.
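For reference, this is the per-child forget gate as I understand it from the paper. The notation is my own shorthand: x_j is the input at node j, h_k is the hidden state of child k, and C(j) is the set of children of j.

```latex
f_{jk} = \sigma\left( W^{(f)} x_j + U^{(f)} h_k + b^{(f)} \right), \qquad k \in C(j)
```

The term W^{(f)} x_j is the one in question: it depends on the parent's input, so it is shared across all children of the same node, while U^{(f)} h_k differs per child.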
Proposed Fix:
I have drafted an alternative implementation, based on the original ChildSumTreeLSTMCell, that incorporates this weight matrix (W_f) into the forget gate:
import torch
import torch.nn as nn


class ChildSumTreeLSTMCell(nn.Module):
    def __init__(self, x_size, h_size):
        super(ChildSumTreeLSTMCell, self).__init__()
        self.W_iou = nn.Linear(x_size, 3 * h_size, bias=False)
        self.U_iou = nn.Linear(h_size, 3 * h_size, bias=False)
        self.b_iou = nn.Parameter(torch.zeros(1, 3 * h_size))
        # forget gate: W_f acts on the node input, U_f on each child's h
        self.W_f = nn.Linear(x_size, h_size, bias=False)
        self.U_f = nn.Linear(h_size, h_size, bias=False)
        self.b_f = nn.Parameter(torch.zeros(1, h_size))

    def message_func(self, edges):
        # send each child's hidden and cell state to its parent
        return {"h": edges.src["h"], "c": edges.src["c"]}

    def reduce_func(self, nodes):
        # equation (2): sum of the children's hidden states
        h_tild = torch.sum(nodes.mailbox["h"], 1)
        # equation (4): one forget gate per child;
        # wx is (N, 1, h_size) and broadcasts over the child dimension of uh
        wx = self.W_f(nodes.data["x"]).unsqueeze(1)
        uh = self.U_f(nodes.mailbox["h"])
        f = torch.sigmoid(wx + uh + self.b_f.unsqueeze(1))
        c_tild = torch.sum(f * nodes.mailbox["c"], 1)
        return {"h_tild": h_tild, "c_tild": c_tild}

    def apply_node_func(self, nodes):
        # equation (3), (5), (6)
        iou = self.W_iou(nodes.data["x"]) + self.b_iou
        if "h_tild" in nodes.data:
            iou += self.U_iou(nodes.data["h_tild"])
        i, o, u = torch.chunk(iou, 3, 1)
        i, o, u = torch.sigmoid(i), torch.sigmoid(o), torch.tanh(u)
        # equation (7)
        c = i * u
        if "c_tild" in nodes.data:
            c += nodes.data["c_tild"]
        # equation (8)
        h = o * torch.tanh(c)
        return {"h": h, "c": c}
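To have something to compare the cell against by hand, here is a dependency-free sketch of a single node update with size-1 (scalar) states. It is only an illustration under stated assumptions: the function name child_sum_node and the parameter keys are hypothetical, the states are scalars rather than vectors, and the equation numbers follow the comments in the code above.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def child_sum_node(x, children, p):
    """One Child-Sum Tree-LSTM node update with scalar states.

    x        -- input at this node
    children -- list of (h_k, c_k) pairs, empty for a leaf
    p        -- dict of scalar weights (hypothetical key names)
    """
    # equation (2): sum of the children's hidden states (0 for a leaf)
    h_tild = sum(h_k for h_k, _ in children)
    # equations (3), (5), (6): input gate, output gate, candidate value
    i = sigmoid(p["W_i"] * x + p["U_i"] * h_tild + p["b_i"])
    o = sigmoid(p["W_o"] * x + p["U_o"] * h_tild + p["b_o"])
    u = math.tanh(p["W_u"] * x + p["U_u"] * h_tild + p["b_u"])
    # equation (4): one forget gate per child, using x AND that child's h_k
    f = [sigmoid(p["W_f"] * x + p["U_f"] * h_k + p["b_f"]) for h_k, _ in children]
    # equation (7): gated sum of the children's cell states plus the new candidate
    c = i * u + sum(f_k * c_k for f_k, (_, c_k) in zip(f, children))
    # equation (8)
    h = o * math.tanh(c)
    return h, c, f


# Example: two children; only the W_f term is switched on, so both
# forget gates equal sigmoid(W_f * x) regardless of the children's h.
params = {k: 0.0 for k in (
    "W_i", "U_i", "b_i", "W_o", "U_o", "b_o",
    "W_u", "U_u", "b_u", "W_f", "U_f", "b_f")}
params["W_f"] = 1.0
h, c, f = child_sum_node(2.0, [(1.0, 2.0), (3.0, 4.0)], params)
print(h, c, f)
```

Setting W_f to zero in this sketch makes the forget gates independent of x, which is the behaviour I am attributing to the current example.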
Observed Issue with Training:
Currently, the cell is not training as I expected for my use case. I am hoping both to help fix the example and to get confirmation that the cell is implemented correctly.
Request for Feedback:
- Could someone confirm whether the proposed implementation aligns with the Child-Sum Tree LSTM described in the paper?
- Any suggestions to align it with DGL best practices?
Thank you for your time and assistance. I am new to graphs and DGL, but I am happy to contribute further by refining the implementation or submitting a PR if this approach is confirmed to be correct.