
Missing W_f Implementation in ChildSumTreeLSTMCell: Alignment with Paper and Proposed Fix #7848

Open
scopeofaperture opened this issue Dec 18, 2024 · 1 comment

📚 Documentation

The issue is in the examples and thus in the tutorials.

I believe the current implementation of ChildSumTreeLSTMCell in examples/pytorch/tree_lstm/tree_lstm.py does not match the Child-Sum Tree-LSTM described in the original paper (Tai et al., 2015) or in the documentation. Specifically, the implementation lacks the weight matrix $W^{(f)}$ from Equation 4 of the paper.

Evidence:

Equation 4 specifies a weight matrix $W^{(f)}$ that contributes to the forget gate computed for each child node; this matrix is not present in the code.

$$f_{jk} = \sigma\left(W^{(f)} x_j + U^{(f)} h_k + b^{(f)}\right) \qquad \text{(Equation 4)}$$

The current code has no linear layer implementing $W^{(f)}$. As a result, the forget gates are not computed as in the paper, which can lead to incorrect behavior.

Proposed Fix:

I have drafted an alternative implementation, based on the example's ChildSumTreeLSTMCell (old API), that incorporates $W^{(f)}$:

class ChildSumTreeLSTMCell(nn.Module):
    def __init__(self, x_size, h_size):
        super(ChildSumTreeLSTMCell, self).__init__()
        self.W_iou = nn.Linear(x_size, 3 * h_size, bias=False)
        self.U_iou = nn.Linear(h_size, 3 * h_size, bias=False)
        self.b_iou = nn.Parameter(torch.zeros(1, 3 * h_size))

        self.W_f = nn.Linear(x_size, h_size, bias=False)
        self.U_f = nn.Linear(h_size, h_size, bias=False)
        self.b_f = nn.Parameter(torch.zeros(1, h_size))

    def message_func(self, edges):
        # pass each child's hidden and cell state to its parent
        return {"h": edges.src["h"], "c": edges.src["c"]}

    def reduce_func(self, nodes):
        # equation (2): sum of the children's hidden states
        h_tild = torch.sum(nodes.mailbox["h"], 1)
        # equation (4): a separate forget gate f_jk per child k, using the
        # parent's input x_j (shared across children) and each child's h_k
        wx = self.W_f(nodes.data["x"]).unsqueeze(1)
        uh = self.U_f(nodes.mailbox["h"])
        f = torch.sigmoid(wx + uh + self.b_f.unsqueeze(1))
        # second term of equation (7): gated sum of child cell states
        c_tild = torch.sum(f * nodes.mailbox["c"], 1)
        return {"h_tild": h_tild, "c_tild": c_tild}

    def apply_node_func(self, nodes):
        # equation (3), (5), (6)
        iou = self.W_iou(nodes.data["x"]) + self.b_iou
        if "h_tild" in nodes.data:
            iou += self.U_iou(nodes.data["h_tild"])
        i, o, u = torch.chunk(iou, 3, 1)
        i, o, u = torch.sigmoid(i), torch.sigmoid(o), torch.tanh(u)
        # equation (7)
        c = i * u
        if "c_tild" in nodes.data:
            c += nodes.data["c_tild"]
        # equation (8)
        h = o * torch.tanh(c)
        return {"h": h, "c": c}
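As a sanity check on the broadcasting in `reduce_func`, here is a minimal standalone sketch of the forget-gate computation outside of DGL. The sizes (`n_nodes`, `n_children`, `x_size`, `h_size`) and the random tensors standing in for `nodes.data["x"]` and the mailbox are illustrative assumptions, not part of the example code:

```python
import torch
import torch.nn as nn

# hypothetical sizes for illustration only
x_size, h_size = 5, 4
n_nodes, n_children = 2, 3

W_f = nn.Linear(x_size, h_size, bias=False)
U_f = nn.Linear(h_size, h_size, bias=False)
b_f = nn.Parameter(torch.zeros(1, h_size))

x = torch.randn(n_nodes, x_size)                       # parent inputs x_j
h_children = torch.randn(n_nodes, n_children, h_size)  # stands in for mailbox "h"
c_children = torch.randn(n_nodes, n_children, h_size)  # stands in for mailbox "c"

# W^(f) x_j is shared across children: (n_nodes, 1, h_size)
wx = W_f(x).unsqueeze(1)
# U^(f) h_k is computed per child: (n_nodes, n_children, h_size)
uh = U_f(h_children)
# broadcasting yields one forget gate per child (equation 4)
f = torch.sigmoid(wx + uh + b_f.unsqueeze(1))
assert f.shape == (n_nodes, n_children, h_size)

# second term of equation (7): gated sum of child cell states
c_tild = torch.sum(f * c_children, 1)
assert c_tild.shape == (n_nodes, h_size)
```

The key point is that `wx.unsqueeze(1)` broadcasts the parent's contribution against the per-child `uh`, so each child gets its own gate rather than a single shared one.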

Observed Issue with Training:

Currently, the cell is not training as expected for my use case, so I am hoping both to help fix the example and to get confirmation that the cell is implemented correctly.

Request for Feedback:

  1. Could someone confirm whether the proposed implementation aligns with the Child-Sum Tree LSTM described in the paper?
  2. Any suggestions to align it with DGL best practices?

Thank you for your time and assistance. I am happy to further contribute by refining the implementation or submitting a PR if this approach is confirmed to be correct (new to Graphs and DGL).


This issue has been automatically marked as stale due to lack of activity. It will be closed if no further activity occurs. Thank you
