diff --git a/src/NanoParticleTools/machine_learning/models/gnn_model/model.py b/src/NanoParticleTools/machine_learning/models/gnn_model/model.py index 51174f5..5bd43a7 100644 --- a/src/NanoParticleTools/machine_learning/models/gnn_model/model.py +++ b/src/NanoParticleTools/machine_learning/models/gnn_model/model.py @@ -1,27 +1,17 @@ -import torch -import torch.nn.functional as F -from torch import nn -from torch_geometric import nn as pyg_nn -import pytorch_lightning as pl - from NanoParticleTools.machine_learning.core import SpectrumModelBase from NanoParticleTools.machine_learning.models.mlp_model.model import MLPSpectrumModel -from typing import Optional, Callable -from torch_scatter.scatter import scatter -from NanoParticleTools.machine_learning.core import SpectrumModelBase -from NanoParticleTools.machine_learning.modules.layer_interaction import InteractionBlock, InteractionConv +from NanoParticleTools.machine_learning.modules.layer_interaction import ( + InteractionBlock, InteractionConv) from NanoParticleTools.machine_learning.modules.film import FiLMLayer from NanoParticleTools.machine_learning.modules import NonLinearMLP -from torch_geometric.data.batch import Batch -from torch_geometric.data import HeteroData + +from torch_scatter.scatter import scatter from torch.nn import functional as F from torch import nn import torch -import torch_geometric.nn as gnn -import warnings -from typing import Dict, List -from sklearn.metrics.pairwise import euclidean_distances, cosine_similarity -from torch.nn.functional import pairwise_distance +from torch_geometric import nn as pyg_nn + +from typing import Dict, List, Optional, Callable class GraphRepresentationModel(SpectrumModelBase): @@ -73,16 +63,16 @@ def __init__(self, self.convs = nn.ModuleList() for _ in range(self.n_message_passing): self.convs.append( - gnn.GATv2Conv(embed_dim, - embed_dim, - edge_dim=nsigma, - concat=False, - add_self_loops=False)) + pyg_nn.GATv2Conv(embed_dim, + embed_dim, + edge_dim=nsigma, + concat=False, + add_self_loops=False)) if aggregation == 'sum': - self.aggregation = gnn.aggr.SumAggregation() + self.aggregation = pyg_nn.aggr.SumAggregation() elif aggregation == 'mean': - self.aggregation = gnn.aggr.MeanAggregation() + self.aggregation = pyg_nn.aggr.MeanAggregation() elif aggregation == 'set2set': # Use the Set2Set aggregation method to pool the graph # into a single global feature vector @@ -211,14 +201,16 @@ def __init__(self, self.convs = nn.ModuleList() for i in range(self.n_message_passing): if i == 0: - self.convs.append(InteractionConv(embed_dim, embed_dim, nsigma)) + self.convs.append(InteractionConv(embed_dim, embed_dim, + nsigma)) else: - self.convs.append(InteractionConv(2*embed_dim, embed_dim, nsigma)) + self.convs.append( + InteractionConv(2 * embed_dim, embed_dim, nsigma)) if aggregation == 'sum': - self.aggregation = gnn.aggr.SumAggregation() + self.aggregation = pyg_nn.aggr.SumAggregation() elif aggregation == 'mean': - self.aggregation = gnn.aggr.MeanAggregation() + self.aggregation = pyg_nn.aggr.MeanAggregation() elif aggregation == 'set2set': # Use the Set2Set aggregation method to pool the graph # into a single global feature vector