-
Notifications
You must be signed in to change notification settings - Fork 1
/
models_gan.py
73 lines (55 loc) · 2.84 KB
/
models_gan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# File : models_gan.py
# Author : Yuanfei Wang <[email protected]>
# Date : 05.22.2022
# Last Modified Date: 05.22.2022
# Last Modified By : Yuanfei Wang <[email protected]>
import torch
import torch.nn as nn
import torch.nn.functional as F
from layers import GraphConvolution, GraphAggregation, MultiDenseLayer
class Generator(nn.Module):
    """Generator network mapping a latent batch to molecular-graph logits.

    The latent input goes through ``MultiDenseLayer`` and is then projected by
    two linear heads: one producing edge (bond) logits, one producing node
    (atom) logits.

    Args:
        conv_dims: Sequence passed to ``MultiDenseLayer`` (presumably the dense
            layer widths); ``conv_dims[-1]`` is used as the input width of both
            output heads.
        z_dim: Input width passed to ``MultiDenseLayer`` (the latent size).
        vertexes: Number of atoms per molecule (V).
        edges: Number of bond kinds (E).
        nodes: Number of atom kinds (N).
        dropout_rate: Dropout probability applied to the edge and node logits.
    """

    def __init__(self, conv_dims, z_dim, vertexes, edges, nodes, dropout_rate):
        super(Generator, self).__init__()
        self.activation_f = torch.nn.Tanh()
        # Stack of dense layers from ``layers``; the tanh module is handed to it.
        # NOTE(review): unlike Discriminator, dropout_rate is not forwarded to
        # MultiDenseLayer here, so dropout only acts on the output logits --
        # confirm this is intended.
        self.multi_dense_layer = MultiDenseLayer(z_dim, conv_dims, self.activation_f)
        # Molecule dataset attributes.
        self.vertexes = vertexes  # number of atoms
        self.edges = edges  # kinds of bonds
        self.nodes = nodes  # kinds of atoms
        self.edges_layer = nn.Linear(conv_dims[-1], edges * vertexes * vertexes)
        self.nodes_layer = nn.Linear(conv_dims[-1], vertexes * nodes)
        # nn.Dropout has no parameters or buffers, so this attribute name does
        # not appear in state_dict; fixing the old "dropoout" spelling keeps
        # existing checkpoints loadable.
        self.dropout = nn.Dropout(p=dropout_rate)

    @property
    def dropoout(self):
        """Deprecated misspelled alias of ``self.dropout``, kept for old callers."""
        return self.dropout

    def forward(self, x):
        """Produce edge and node logits from the latent batch ``x``.

        Args:
            x: Input to ``multi_dense_layer`` -- presumably of shape
                (B, z_dim); TODO confirm against MultiDenseLayer.

        Returns:
            A pair ``(edges_logits, nodes_logits)``:
            ``edges_logits`` has shape (B, V, V, E). Before dropout it is
            symmetric in the two V axes; in training mode dropout masks
            elements independently, so exact symmetry is not preserved.
            ``nodes_logits`` has shape (B, V, N).
        """
        output = self.multi_dense_layer(x)
        edges_logits = self.edges_layer(output).view(-1, self.edges, self.vertexes, self.vertexes)
        # The edge matrix should be symmetric: average each bond-kind slice with
        # its transpose.
        edges_logits = (edges_logits + edges_logits.permute(0, 1, 3, 2)) / 2
        # Move the bond-kind axis last: (B, E, V, V) -> (B, V, V, E).
        edges_logits = self.dropout(edges_logits.permute(0, 2, 3, 1))
        nodes_logits = self.nodes_layer(output).view(-1, self.vertexes, self.nodes)
        nodes_logits = self.dropout(nodes_logits)
        return edges_logits, nodes_logits
class Discriminator(nn.Module):
    """Graph discriminator: graph convolution, aggregation, dense stack, score.

    The final ``output_layer`` is a linear map to a single unit.

    Args:
        conv_dim: Triple ``(graph_conv_dim, aux_dim, linear_dim)``.
            ``graph_conv_dim`` and ``linear_dim`` are sequences whose last
            entries size the following layer; ``aux_dim`` is the aggregation
            output width and the dense-stack input width.
        m_dim: Node feature width given to the graph convolution; it is also
            added to the last convolution width to size the aggregation input.
        b_dim: Forwarded to ``GraphConvolution`` (presumably the number of
            non-empty bond kinds -- TODO confirm).
        with_features: Forwarded to the convolution and aggregation layers.
        f_dim: Forwarded to the convolution and aggregation layers.
        dropout_rate: Forwarded to the convolution, aggregation and dense stack.
    """

    def __init__(self, conv_dim, m_dim, b_dim, with_features=False, f_dim=0, dropout_rate=0.):
        super(Discriminator, self).__init__()
        self.activation_f = torch.nn.Tanh()
        gcn_dims, agg_dim, dense_dims = conv_dim
        # Layers are built in this fixed order so parameter initialisation and
        # state_dict ordering stay stable.
        self.gcn_layer = GraphConvolution(m_dim, gcn_dims, b_dim, with_features, f_dim, dropout_rate)
        agg_in = gcn_dims[-1] + m_dim
        self.agg_layer = GraphAggregation(agg_in, agg_dim, self.activation_f, with_features, f_dim,
                                          dropout_rate)
        self.multi_dense_layer = MultiDenseLayer(agg_dim, dense_dims, self.activation_f,
                                                 dropout_rate=dropout_rate)
        self.output_layer = nn.Linear(dense_dims[-1], 1)

    def forward(self, adj, hidden, node, activation=None):
        """Score a batch of graphs.

        Args:
            adj: 4-D adjacency tensor, presumably (B, V, V, E) with channel 0
                of the last axis meaning "no bond" -- TODO confirm.
            hidden: Forwarded unchanged to the convolution and aggregation.
            node: Node tensor forwarded to the convolution and aggregation.
            activation: Optional callable applied to the raw scores.

        Returns:
            ``(scores, features)`` -- the (optionally activated) output of
            ``output_layer`` and the dense-stack features it was computed from.
        """
        # Drop the empty-edge channel, then move bond channels to axis 1.
        bonds = adj[:, :, :, 1:].permute(0, 3, 1, 2)
        node_repr = self.gcn_layer(node, bonds, hidden)
        graph_repr = self.agg_layer(node_repr, node, hidden)
        features = self.multi_dense_layer(graph_repr)
        scores = self.output_layer(features)
        if activation is not None:
            scores = activation(scores)
        return scores, features