From 46dc82fba9c961e43abf2abcc21bace34338bb5e Mon Sep 17 00:00:00 2001 From: RylieWeaver <123048075+RylieWeaver@users.noreply.github.com> Date: Tue, 29 Oct 2024 14:15:36 -0400 Subject: [PATCH] small fix for last_layer hanging gradients with PAINN (#300) --- hydragnn/models/PAINNStack.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/hydragnn/models/PAINNStack.py b/hydragnn/models/PAINNStack.py index dd9f9ebf2..11f062461 100644 --- a/hydragnn/models/PAINNStack.py +++ b/hydragnn/models/PAINNStack.py @@ -44,7 +44,9 @@ def __init__( def _init_conv(self): last_layer = 1 == self.num_conv_layers - self.graph_convs.append(self.get_conv(self.input_dim, self.hidden_dim)) + self.graph_convs.append( + self.get_conv(self.input_dim, self.hidden_dim, last_layer) + ) self.feature_layers.append(nn.Identity()) for i in range(self.num_conv_layers - 1): last_layer = i == self.num_conv_layers - 2