diff --git a/hydragnn/models/PAINNStack.py b/hydragnn/models/PAINNStack.py index dd9f9ebf2..11f062461 100644 --- a/hydragnn/models/PAINNStack.py +++ b/hydragnn/models/PAINNStack.py @@ -44,7 +44,9 @@ def __init__( def _init_conv(self): last_layer = 1 == self.num_conv_layers - self.graph_convs.append(self.get_conv(self.input_dim, self.hidden_dim)) + self.graph_convs.append( + self.get_conv(self.input_dim, self.hidden_dim, last_layer) + ) self.feature_layers.append(nn.Identity()) for i in range(self.num_conv_layers - 1): last_layer = i == self.num_conv_layers - 2