|
1 |
| -import torch |
2 |
| -import torch.nn.functional as F |
3 |
| -from torch import nn |
4 |
| -from torch_geometric import nn as pyg_nn |
5 |
| -import pytorch_lightning as pl |
6 |
| - |
7 | 1 | from NanoParticleTools.machine_learning.core import SpectrumModelBase
|
8 | 2 | from NanoParticleTools.machine_learning.models.mlp_model.model import MLPSpectrumModel
|
9 |
| -from typing import Optional, Callable |
10 |
| -from torch_scatter.scatter import scatter |
11 |
| -from NanoParticleTools.machine_learning.core import SpectrumModelBase |
12 |
| -from NanoParticleTools.machine_learning.modules.layer_interaction import InteractionBlock, InteractionConv |
| 3 | +from NanoParticleTools.machine_learning.modules.layer_interaction import ( |
| 4 | + InteractionBlock, InteractionConv) |
13 | 5 | from NanoParticleTools.machine_learning.modules.film import FiLMLayer
|
14 | 6 | from NanoParticleTools.machine_learning.modules import NonLinearMLP
|
15 |
| -from torch_geometric.data.batch import Batch |
16 |
| -from torch_geometric.data import HeteroData |
| 7 | + |
| 8 | +from torch_scatter.scatter import scatter |
17 | 9 | from torch.nn import functional as F
|
18 | 10 | from torch import nn
|
19 | 11 | import torch
|
20 |
| -import torch_geometric.nn as gnn |
21 |
| -import warnings |
22 |
| -from typing import Dict, List |
23 |
| -from sklearn.metrics.pairwise import euclidean_distances, cosine_similarity |
24 |
| -from torch.nn.functional import pairwise_distance |
| 12 | +from torch_geometric import nn as pyg_nn |
| 13 | + |
| 14 | +from typing import Dict, List, Optional, Callable |
25 | 15 |
|
26 | 16 |
|
27 | 17 | class GraphRepresentationModel(SpectrumModelBase):
|
@@ -73,16 +63,16 @@ def __init__(self,
|
73 | 63 | self.convs = nn.ModuleList()
|
74 | 64 | for _ in range(self.n_message_passing):
|
75 | 65 | self.convs.append(
|
76 |
| - gnn.GATv2Conv(embed_dim, |
77 |
| - embed_dim, |
78 |
| - edge_dim=nsigma, |
79 |
| - concat=False, |
80 |
| - add_self_loops=False)) |
| 66 | + pyg_nn.GATv2Conv(embed_dim, |
| 67 | + embed_dim, |
| 68 | + edge_dim=nsigma, |
| 69 | + concat=False, |
| 70 | + add_self_loops=False)) |
81 | 71 |
|
82 | 72 | if aggregation == 'sum':
|
83 |
| - self.aggregation = gnn.aggr.SumAggregation() |
| 73 | + self.aggregation = pyg_nn.aggr.SumAggregation() |
84 | 74 | elif aggregation == 'mean':
|
85 |
| - self.aggregation = gnn.aggr.MeanAggregation() |
| 75 | + self.aggregation = pyg_nn.aggr.MeanAggregation() |
86 | 76 | elif aggregation == 'set2set':
|
87 | 77 | # Use the Set2Set aggregation method to pool the graph
|
88 | 78 | # into a single global feature vector
|
@@ -211,14 +201,16 @@ def __init__(self,
|
211 | 201 | self.convs = nn.ModuleList()
|
212 | 202 | for i in range(self.n_message_passing):
|
213 | 203 | if i == 0:
|
214 |
| - self.convs.append(InteractionConv(embed_dim, embed_dim, nsigma)) |
| 204 | + self.convs.append(InteractionConv(embed_dim, embed_dim, |
| 205 | + nsigma)) |
215 | 206 | else:
|
216 |
| - self.convs.append(InteractionConv(2*embed_dim, embed_dim, nsigma)) |
| 207 | + self.convs.append( |
| 208 | + InteractionConv(2 * embed_dim, embed_dim, nsigma)) |
217 | 209 |
|
218 | 210 | if aggregation == 'sum':
|
219 |
| - self.aggregation = gnn.aggr.SumAggregation() |
| 211 | + self.aggregation = pyg_nn.aggr.SumAggregation() |
220 | 212 | elif aggregation == 'mean':
|
221 |
| - self.aggregation = gnn.aggr.MeanAggregation() |
| 213 | + self.aggregation = pyg_nn.aggr.MeanAggregation() |
222 | 214 | elif aggregation == 'set2set':
|
223 | 215 | # Use the Set2Set aggregation method to pool the graph
|
224 | 216 | # into a single global feature vector
|
|
0 commit comments