
Missing W_f Implementation in ChildSumTreeLSTMCell: Alignment with Paper and Proposed Fix #7848

Open
@scopeofaperture

Description


📚 Documentation

Issue is in the examples and thus in the tutorials

I believe the current implementation of the ChildSumTreeLSTMCell in examples/pytorch/tree_lstm/tree_lstm.py does not conform to the Child-Sum Tree-LSTM described in the original paper (Tai et al., 2015) or in the documentation. Specifically, the implementation lacks the weight matrix $W^{(f)}$ from Equation 4 of the paper.

Evidence:

The equation specifies a weight matrix $W^{(f)}$, which contributes to the calculation of forget gates for each child node. However, this matrix is not implemented in the code.

$$f_{jk} = \sigma\left(W^{(f)} x_j + U^{(f)} h_k + b^{(f)}\right) \tag{4}$$

The current code does not include a linear layer that applies $W^{(f)}$ to the node input, so the forget gates are not computed as specified in the paper, which can lead to incorrect behavior.
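To make the role of $W^{(f)}$ concrete, here is a minimal PyTorch sketch of Equation 4 (all sizes below are made up for illustration): the $W^{(f)} x_j$ term is shared across a node's children while $U^{(f)} h_k$ differs per child, so dropping $W^{(f)}$ removes the input's influence on every child's forget gate.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x_size, h_size, n_children = 4, 5, 3  # illustrative sizes

W_f = nn.Linear(x_size, h_size, bias=False)  # the projection missing from the example
U_f = nn.Linear(h_size, h_size, bias=False)
b_f = torch.zeros(h_size)

x_j = torch.randn(1, x_size)           # parent node's input
h_k = torch.randn(n_children, h_size)  # one hidden state per child

# Eq. 4: f_jk = sigmoid(W_f x_j + U_f h_k + b_f); W_f(x_j) has shape
# (1, h_size) and broadcasts over the children, so every child's gate
# sees the same input-dependent term but a different hidden-state term.
f = torch.sigmoid(W_f(x_j) + U_f(h_k) + b_f)
assert f.shape == (n_children, h_size)
```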

Proposed Fix:

I have drafted an alternative implementation, based on the authors' original reference implementation of the Child-Sum Tree-LSTM, that incorporates $W^{(f)}$ (still written against the old DGL message-passing API):

```python
import torch
import torch.nn as nn


class ChildSumTreeLSTMCell(nn.Module):
    def __init__(self, x_size, h_size):
        super(ChildSumTreeLSTMCell, self).__init__()
        self.W_iou = nn.Linear(x_size, 3 * h_size, bias=False)
        self.U_iou = nn.Linear(h_size, 3 * h_size, bias=False)
        self.b_iou = nn.Parameter(torch.zeros(1, 3 * h_size))

        self.W_f = nn.Linear(x_size, h_size, bias=False)
        self.U_f = nn.Linear(h_size, h_size, bias=False)
        self.b_f = nn.Parameter(torch.zeros(1, h_size))

    def message_func(self, edges):
        return {"h": edges.src["h"], "c": edges.src["c"]}

    def reduce_func(self, nodes):
        # equation (2)
        h_tild = torch.sum(nodes.mailbox["h"], 1)
        # equation (4): f_jk = sigmoid(W_f x_j + U_f h_k + b_f)
        wx = self.W_f(nodes.data["x"]).unsqueeze(1)
        uh = self.U_f(nodes.mailbox["h"])
        f = torch.sigmoid(wx + uh + self.b_f.unsqueeze(1))
        # child-sum term of equation (7)
        c_tild = torch.sum(f * nodes.mailbox["c"], 1)
        return {"h_tild": h_tild, "c_tild": c_tild}

    def apply_node_func(self, nodes):
        # equations (3), (5), (6)
        iou = self.W_iou(nodes.data["x"]) + self.b_iou
        if "h_tild" in nodes.data:
            iou += self.U_iou(nodes.data["h_tild"])
        i, o, u = torch.chunk(iou, 3, 1)
        i, o, u = torch.sigmoid(i), torch.sigmoid(o), torch.tanh(u)
        # equation (7)
        c = i * u
        if "c_tild" in nodes.data:
            c += nodes.data["c_tild"]
        # equation (8)
        h = o * torch.tanh(c)
        return {"h": h, "c": c}
```
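As a self-contained sanity check of the tensor shapes (independent of DGL; all sizes and tensors below are made up), the reduce-then-apply math above can be simulated directly in PyTorch for a batch of parents that each have the same number of children, which is the batched shape `reduce_func` receives in its mailbox:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_parents, n_children, x_size, h_size = 2, 3, 4, 5  # illustrative sizes

W_iou = nn.Linear(x_size, 3 * h_size, bias=False)
U_iou = nn.Linear(h_size, 3 * h_size, bias=False)
b_iou = torch.zeros(1, 3 * h_size)
W_f = nn.Linear(x_size, h_size, bias=False)
U_f = nn.Linear(h_size, h_size, bias=False)
b_f = torch.zeros(1, h_size)

x = torch.randn(n_parents, x_size)                   # nodes.data["x"]
mail_h = torch.randn(n_parents, n_children, h_size)  # nodes.mailbox["h"]
mail_c = torch.randn(n_parents, n_children, h_size)  # nodes.mailbox["c"]

# reduce_func: eq. (2), eq. (4), and the child-sum term of eq. (7)
h_tild = mail_h.sum(1)
f = torch.sigmoid(W_f(x).unsqueeze(1) + U_f(mail_h) + b_f.unsqueeze(1))
c_tild = (f * mail_c).sum(1)

# apply_node_func: eqs. (3), (5), (6), then (7) and (8)
i, o, u = torch.chunk(W_iou(x) + U_iou(h_tild) + b_iou, 3, 1)
c = torch.sigmoid(i) * torch.tanh(u) + c_tild
h = torch.sigmoid(o) * torch.tanh(c)

assert f.shape == (n_parents, n_children, h_size)
assert h.shape == c.shape == (n_parents, h_size)
```

The per-child forget gates `f` keep the mailbox's three dimensions, while the summed states collapse the child dimension, matching what `reduce_func` returns to the node frame.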

Observed Issue with Training:

Currently, the cell is not training as I expected for my use case, so I am hoping both to help fix the example and to get confirmation that the cell is implemented correctly.

Request for Feedback:

  1. Could someone confirm whether the proposed implementation aligns with the Child-Sum Tree LSTM described in the paper?
  2. Any suggestions to align it with DGL best practices?

Thank you for your time and assistance. I am new to graphs and DGL, but I am happy to contribute further by refining the implementation or submitting a PR if this approach is confirmed to be correct.
