Skip to content

Commit 5867e5f

Browse files
committed
Clean up formatting
1 parent 0fa7cc4 commit 5867e5f

File tree

1 file changed

+20
-28
lines changed
  • src/NanoParticleTools/machine_learning/models/gnn_model

1 file changed

+20
-28
lines changed

src/NanoParticleTools/machine_learning/models/gnn_model/model.py

Lines changed: 20 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,17 @@
1-
import torch
2-
import torch.nn.functional as F
3-
from torch import nn
4-
from torch_geometric import nn as pyg_nn
5-
import pytorch_lightning as pl
6-
71
from NanoParticleTools.machine_learning.core import SpectrumModelBase
82
from NanoParticleTools.machine_learning.models.mlp_model.model import MLPSpectrumModel
9-
from typing import Optional, Callable
10-
from torch_scatter.scatter import scatter
11-
from NanoParticleTools.machine_learning.core import SpectrumModelBase
12-
from NanoParticleTools.machine_learning.modules.layer_interaction import InteractionBlock, InteractionConv
3+
from NanoParticleTools.machine_learning.modules.layer_interaction import (
4+
InteractionBlock, InteractionConv)
135
from NanoParticleTools.machine_learning.modules.film import FiLMLayer
146
from NanoParticleTools.machine_learning.modules import NonLinearMLP
15-
from torch_geometric.data.batch import Batch
16-
from torch_geometric.data import HeteroData
7+
8+
from torch_scatter.scatter import scatter
179
from torch.nn import functional as F
1810
from torch import nn
1911
import torch
20-
import torch_geometric.nn as gnn
21-
import warnings
22-
from typing import Dict, List
23-
from sklearn.metrics.pairwise import euclidean_distances, cosine_similarity
24-
from torch.nn.functional import pairwise_distance
12+
from torch_geometric import nn as pyg_nn
13+
14+
from typing import Dict, List, Optional, Callable
2515

2616

2717
class GraphRepresentationModel(SpectrumModelBase):
@@ -73,16 +63,16 @@ def __init__(self,
7363
self.convs = nn.ModuleList()
7464
for _ in range(self.n_message_passing):
7565
self.convs.append(
76-
gnn.GATv2Conv(embed_dim,
77-
embed_dim,
78-
edge_dim=nsigma,
79-
concat=False,
80-
add_self_loops=False))
66+
pyg_nn.GATv2Conv(embed_dim,
67+
embed_dim,
68+
edge_dim=nsigma,
69+
concat=False,
70+
add_self_loops=False))
8171

8272
if aggregation == 'sum':
83-
self.aggregation = gnn.aggr.SumAggregation()
73+
self.aggregation = pyg_nn.aggr.SumAggregation()
8474
elif aggregation == 'mean':
85-
self.aggregation = gnn.aggr.MeanAggregation()
75+
self.aggregation = pyg_nn.aggr.MeanAggregation()
8676
elif aggregation == 'set2set':
8777
# Use the Set2Set aggregation method to pool the graph
8878
# into a single global feature vector
@@ -211,14 +201,16 @@ def __init__(self,
211201
self.convs = nn.ModuleList()
212202
for i in range(self.n_message_passing):
213203
if i == 0:
214-
self.convs.append(InteractionConv(embed_dim, embed_dim, nsigma))
204+
self.convs.append(InteractionConv(embed_dim, embed_dim,
205+
nsigma))
215206
else:
216-
self.convs.append(InteractionConv(2*embed_dim, embed_dim, nsigma))
207+
self.convs.append(
208+
InteractionConv(2 * embed_dim, embed_dim, nsigma))
217209

218210
if aggregation == 'sum':
219-
self.aggregation = gnn.aggr.SumAggregation()
211+
self.aggregation = pyg_nn.aggr.SumAggregation()
220212
elif aggregation == 'mean':
221-
self.aggregation = gnn.aggr.MeanAggregation()
213+
self.aggregation = pyg_nn.aggr.MeanAggregation()
222214
elif aggregation == 'set2set':
223215
# Use the Set2Set aggregation method to pool the graph
224216
# into a single global feature vector

0 commit comments

Comments
 (0)