Skip to content

Commit

Permalink
black formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
zachfox committed Oct 26, 2024
1 parent b58c053 commit 0c63fd6
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 11 deletions.
2 changes: 0 additions & 2 deletions hydragnn/models/EGCLStack.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ def __init__(
max_neighbours: Optional[int] = None,
**kwargs,
):

self.edge_dim = (
0 if edge_attr_dim is None else edge_attr_dim
) # Must be named edge_dim to trigger use by Base
Expand Down Expand Up @@ -159,7 +158,6 @@ def __init__(
self.clamp = clamp

if self.equivariant:

layer = nn.Linear(hidden_channels, 1, bias=False)
torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)

Expand Down
10 changes: 4 additions & 6 deletions hydragnn/models/HybridEGCLStack.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,12 @@ def loss_hpweighted(self, pred, value, head_index, var=None):
# Calculate MSE if position noise
elif ihead == 1:
head_loss = self.mse(head_pred, head_val)

# Add loss to total loss and list of tasks loss
tot_loss += (
head_loss * self.loss_weights[ihead]
)
tot_loss += head_loss * self.loss_weights[ihead]
tasks_loss.append(head_loss)

return tot_loss, tasks_loss

def __str__(self):
return "HybridEGCLStack"
return "HybridEGCLStack"
7 changes: 4 additions & 3 deletions hydragnn/utils/input_config_parsing/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,10 @@ def update_config_NN_outputs(config, data, graph_size_variable):
else:
for ihead in range(len(output_type)):
if output_type[ihead] != "graph":
if not "dynamic_target" in config["Variables_of_interest"] or\
("dynamic_target" in config["Variables_of_interest"] and\
not config["Variables_of_interest"]["dynamic_target"]): # raise ValueError if yloc missing on non-graph, with "dynamic_target" set to false or missing
if not "dynamic_target" in config["Variables_of_interest"] or (
"dynamic_target" in config["Variables_of_interest"]
and not config["Variables_of_interest"]["dynamic_target"]
): # raise ValueError if yloc missing on non-graph, with "dynamic_target" set to false or missing
raise ValueError(
"y_loc is needed for outputs that are not at graph levels",
output_type[ihead],
Expand Down

0 comments on commit 0c63fd6

Please sign in to comment.