Skip to content

Commit

Permalink
Clean up formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
sivonxay committed Jan 6, 2024
1 parent 0fa7cc4 commit 5867e5f
Showing 1 changed file with 20 additions and 28 deletions.
48 changes: 20 additions & 28 deletions src/NanoParticleTools/machine_learning/models/gnn_model/model.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,17 @@
import torch
import torch.nn.functional as F
from torch import nn
from torch_geometric import nn as pyg_nn
import pytorch_lightning as pl

from NanoParticleTools.machine_learning.core import SpectrumModelBase
from NanoParticleTools.machine_learning.models.mlp_model.model import MLPSpectrumModel
from typing import Optional, Callable
from torch_scatter.scatter import scatter
from NanoParticleTools.machine_learning.core import SpectrumModelBase
from NanoParticleTools.machine_learning.modules.layer_interaction import InteractionBlock, InteractionConv
from NanoParticleTools.machine_learning.modules.layer_interaction import (
InteractionBlock, InteractionConv)
from NanoParticleTools.machine_learning.modules.film import FiLMLayer
from NanoParticleTools.machine_learning.modules import NonLinearMLP
from torch_geometric.data.batch import Batch
from torch_geometric.data import HeteroData

from torch_scatter.scatter import scatter
from torch.nn import functional as F
from torch import nn
import torch
import torch_geometric.nn as gnn
import warnings
from typing import Dict, List
from sklearn.metrics.pairwise import euclidean_distances, cosine_similarity
from torch.nn.functional import pairwise_distance
from torch_geometric import nn as pyg_nn

from typing import Dict, List, Optional, Callable


class GraphRepresentationModel(SpectrumModelBase):
Expand Down Expand Up @@ -73,16 +63,16 @@ def __init__(self,
self.convs = nn.ModuleList()
for _ in range(self.n_message_passing):
self.convs.append(
gnn.GATv2Conv(embed_dim,
embed_dim,
edge_dim=nsigma,
concat=False,
add_self_loops=False))
pyg_nn.GATv2Conv(embed_dim,
embed_dim,
edge_dim=nsigma,
concat=False,
add_self_loops=False))

if aggregation == 'sum':
self.aggregation = gnn.aggr.SumAggregation()
self.aggregation = pyg_nn.aggr.SumAggregation()
elif aggregation == 'mean':
self.aggregation = gnn.aggr.MeanAggregation()
self.aggregation = pyg_nn.aggr.MeanAggregation()
elif aggregation == 'set2set':
# Use the Set2Set aggregation method to pool the graph
# into a single global feature vector
Expand Down Expand Up @@ -211,14 +201,16 @@ def __init__(self,
self.convs = nn.ModuleList()
for i in range(self.n_message_passing):
if i == 0:
self.convs.append(InteractionConv(embed_dim, embed_dim, nsigma))
self.convs.append(InteractionConv(embed_dim, embed_dim,
nsigma))
else:
self.convs.append(InteractionConv(2*embed_dim, embed_dim, nsigma))
self.convs.append(
InteractionConv(2 * embed_dim, embed_dim, nsigma))

if aggregation == 'sum':
self.aggregation = gnn.aggr.SumAggregation()
self.aggregation = pyg_nn.aggr.SumAggregation()
elif aggregation == 'mean':
self.aggregation = gnn.aggr.MeanAggregation()
self.aggregation = pyg_nn.aggr.MeanAggregation()
elif aggregation == 'set2set':
# Use the Set2Set aggregation method to pool the graph
# into a single global feature vector
Expand Down

0 comments on commit 5867e5f

Please sign in to comment.