Skip to content

Commit

Permalink
small fix for last_layer hanging gradients with PAINN (#300)
Browse files Browse the repository at this point in the history
  • Loading branch information
RylieWeaver authored Oct 29, 2024
1 parent 0fafea7 commit 46dc82f
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion hydragnn/models/PAINNStack.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ def __init__(

def _init_conv(self):
last_layer = 1 == self.num_conv_layers
self.graph_convs.append(self.get_conv(self.input_dim, self.hidden_dim))
self.graph_convs.append(
self.get_conv(self.input_dim, self.hidden_dim, last_layer)
)
self.feature_layers.append(nn.Identity())
for i in range(self.num_conv_layers - 1):
last_layer = i == self.num_conv_layers - 2
Expand Down

0 comments on commit 46dc82f

Please sign in to comment.