diff --git a/hydragnn/models/EGCLStack.py b/hydragnn/models/EGCLStack.py index 7109d0fc3..8ae8b2e19 100644 --- a/hydragnn/models/EGCLStack.py +++ b/hydragnn/models/EGCLStack.py @@ -26,7 +26,6 @@ def __init__( max_neighbours: Optional[int] = None, **kwargs, ): - self.edge_dim = ( 0 if edge_attr_dim is None else edge_attr_dim ) # Must be named edge_dim to trigger use by Base @@ -159,7 +158,6 @@ def __init__( self.clamp = clamp if self.equivariant: - layer = nn.Linear(hidden_channels, 1, bias=False) torch.nn.init.xavier_uniform_(layer.weight, gain=0.001) diff --git a/hydragnn/models/HybridEGCLStack.py b/hydragnn/models/HybridEGCLStack.py index 9a19d8444..8ad6ae91b 100644 --- a/hydragnn/models/HybridEGCLStack.py +++ b/hydragnn/models/HybridEGCLStack.py @@ -56,14 +56,12 @@ def loss_hpweighted(self, pred, value, head_index, var=None): # Calculate MSE if position noise elif ihead == 1: head_loss = self.mse(head_pred, head_val) - + # Add loss to total loss and list of tasks loss - tot_loss += ( - head_loss * self.loss_weights[ihead] - ) + tot_loss += head_loss * self.loss_weights[ihead] tasks_loss.append(head_loss) return tot_loss, tasks_loss - + def __str__(self): - return "HybridEGCLStack" \ No newline at end of file + return "HybridEGCLStack" diff --git a/hydragnn/utils/input_config_parsing/config_utils.py b/hydragnn/utils/input_config_parsing/config_utils.py index 85ccb8e83..849c6b3f1 100644 --- a/hydragnn/utils/input_config_parsing/config_utils.py +++ b/hydragnn/utils/input_config_parsing/config_utils.py @@ -206,9 +206,10 @@ def update_config_NN_outputs(config, data, graph_size_variable): else: for ihead in range(len(output_type)): if output_type[ihead] != "graph": - if not "dynamic_target" in config["Variables_of_interest"] or\ - ("dynamic_target" in config["Variables_of_interest"] and\ - not config["Variables_of_interest"]["dynamic_target"]): # raise ValueError if yloc missing on non-graph, with "dynamic_target" set to false or missing + if not "dynamic_target" in config["Variables_of_interest"] or ( + "dynamic_target" in config["Variables_of_interest"] + and not config["Variables_of_interest"]["dynamic_target"] + ): # raise ValueError if yloc missing on non-graph, with "dynamic_target" set to false or missing raise ValueError( "y_loc is needed for outputs that are not at graph levels", output_type[ihead],