-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmodel.py
231 lines (188 loc) · 10.1 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
"""
Implementation of the ProGAN generator and discriminator with the key
contributions from the paper. We have tried to keep the implementation
compact while also keeping it readable and understandable.
Specifically, the key points implemented are:
1) Progressive growing (of model and layers)
2) Minibatch standard deviation in the Discriminator
3) Normalization with PixelNorm
4) Equalized Learning Rate (in this implementation applied only to the WSConv2d layers)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from math import log2
"""
factors gives, for each resolution 4 * 2**i, the multiplier applied to the
base channel count (in_channels) in both the Generator and the Discriminator.
The first four entries are 1, so the 4x4 through 32x32 stages keep the full
channel count; each later (higher-resolution) stage halves it again:
1/2, 1/4, ..., 1/32 (the last entry corresponds to 1024x1024).
"""
# Channel multiplier per resolution (index i -> 4 * 2**i pixels): four
# full-width stages followed by five successive halvings (1/2 ... 1/32).
factors = [1] * 4 + [1 / 2 ** k for k in range(1, 6)]
class WSConv2d(nn.Module):
    """
    Conv2d with equalized learning rate (weight scaling).

    Weights are drawn from N(0, 1) and, rather than rescaling the weights
    themselves, the *input* is multiplied by sqrt(gain / fan_in) on every
    forward pass, where fan_in = in_channels * kernel_size**2. The wrapped
    convolution has its bias removed (the bias is owned by this module and
    added afterwards), so the convolution is linear and scaling its input
    is equivalent to scaling its weights.

    inspired by:
    https://github.com/nvnbny/progressive_growing_of_gans/blob/master/modelUtils.py
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, gain=2):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        fan_in = in_channels * kernel_size * kernel_size
        self.scale = (gain / fan_in) ** 0.5
        # Take ownership of the bias so it is added after the scaled convolution
        # (and therefore is not affected by self.scale).
        self.bias = self.conv.bias
        self.conv.bias = None
        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        """Convolve the scaled input, then add the per-output-channel bias."""
        bias = self.bias.reshape(1, -1, 1, 1)
        return self.conv(x * self.scale) + bias
class PixelNorm(nn.Module):
    """
    Pixelwise feature-vector normalization (ProGAN, Karras et al., 2018).

    Each spatial position's feature vector is divided by its root-mean-square
    over dim=1 (the channel dimension for an (N, C, H, W) input), so that the
    mean of the squared features at every pixel becomes ~1. The paper
    describes this as a variant of local response normalization
    (Krizhevsky et al., 2012); it has no learnable parameters.

    Args:
        epsilon: small constant added under the square root so that an
            all-zero feature vector does not cause a division by zero.
            Defaults to 1e-8, the value used in the paper.
    """

    def __init__(self, epsilon=1e-8):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, x):
        """Return ``x`` normalized by its RMS over dim=1 (kept for broadcasting)."""
        return x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.epsilon)
class ConvBlock(nn.Module):
    '''
    Two stacked 3x3 equalized-LR convolutions used by every progressive stage.

    Each convolution is followed by LeakyReLU(0.2) and, when
    ``use_pixelnorm`` is true, by PixelNorm. Padding of 1 keeps the spatial
    size unchanged; only the channel count goes from in_channels to
    out_channels (in the first convolution).
    '''

    def __init__(self, in_channels, out_channels, use_pixelnorm=True):
        super().__init__()
        self.use_pn = use_pixelnorm
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2)
        self.pn = PixelNorm()

    def forward(self, x):
        # Apply the same conv -> activation -> (optional) norm recipe twice.
        for layer in (self.conv1, self.conv2):
            x = self.leaky(layer(x))
            if self.use_pn:
                x = self.pn(x)
        return x
class Generator(nn.Module):
    """
    Progressive-growing generator.

    A latent tensor (presumably shaped (N, z_dim, 1, 1), as in the smoke test
    at the bottom of this file) is mapped to a 4x4 feature map by
    ``initial``. Each progressive step doubles the resolution with a
    nearest-neighbour upsample followed by a ConvBlock, so ``steps`` selects
    an output resolution of 4 * 2**steps. ``alpha`` blends the newest block's
    RGB output with the RGB of the upsampled previous resolution.

    Args:
        z_dim: number of latent channels.
        in_channels: base channel width; stage i uses in_channels * factors[i].
        img_channels: channels of the produced image (3 for RGB).
    """

    def __init__(self, z_dim, in_channels, img_channels=3):
        super(Generator, self).__init__()
        # 1x1 latent -> 4x4 feature map
        self.initial = nn.Sequential(
            PixelNorm(),
            nn.ConvTranspose2d(z_dim, in_channels, 4, 1, 0),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2),
            PixelNorm(),
        )
        # rgb_layers[i] turns the feature map of resolution 4 * 2**i into an image;
        # rgb_layers[0] is the same module as initial_rgb.
        self.initial_rgb = WSConv2d(in_channels, img_channels, kernel_size=1, stride=1, padding=0)
        self.prog_blocks, self.rgb_layers = (nn.ModuleList([]), nn.ModuleList([self.initial_rgb]))
        # -1 because each block reads factors[i + 1]
        for i in range(len(factors) - 1):
            conv_in_c = int(in_channels * factors[i])
            conv_out_c = int(in_channels * factors[i + 1])
            self.prog_blocks.append(ConvBlock(conv_in_c, conv_out_c))
            self.rgb_layers.append(WSConv2d(conv_out_c, img_channels, kernel_size=1, stride=1, padding=0))

    def fade_in(self, alpha, upscaled, generated):
        '''
        Blend the two RGB images by alpha and squash the result with tanh.

        alpha should be a scalar in [0, 1] (1 -> only the new block's output),
        and upscaled.shape must equal generated.shape.
        '''
        return torch.tanh(alpha * generated + (1 - alpha) * upscaled)

    def forward(self, x, alpha, steps):
        """
        Generate an image of resolution 4 * 2**steps.

        Raises:
            ValueError: if steps is outside [0, len(self.prog_blocks)].
        """
        if not 0 <= steps <= len(self.prog_blocks):
            raise ValueError(f"steps must be in [0, {len(self.prog_blocks)}], got {steps}")
        out = self.initial(x)
        # base case: 4x4, nothing to fade with. Apply tanh so the output range
        # matches the (tanh-squashed) faded output of every other resolution.
        if steps == 0:
            return torch.tanh(self.initial_rgb(out))
        # upscale, then feed into the progressive blocks
        for step in range(steps):
            upscaled = F.interpolate(out, scale_factor=2, mode="nearest")
            out = self.prog_blocks[step](upscaled)
        # `upscaled` still has the previous stage's channel count while `out` has
        # the current one, so each needs its own to-RGB layer:
        # (steps - 1) for upscaled and steps for out.
        final_upscaled = self.rgb_layers[steps - 1](upscaled)
        final_out = self.rgb_layers[steps](out)
        return self.fade_in(alpha, final_upscaled, final_out)
class Discriminator(nn.Module):
    """
    Progressive-growing discriminator (critic) that mirrors ``Generator``.

    For an input image of resolution 4 * 2**steps, the matching from-RGB layer
    is applied, followed by the progressive blocks (each followed by 2x2
    average pooling) down to 4x4. The result then passes through a minibatch
    standard-deviation channel and ``final_block``. The returned scores have
    shape (N, 1).

    Args:
        z_dim: unused; kept so the constructor mirrors ``Generator``.
        in_channels: base channel width; stage i uses in_channels * factors[i].
        img_channels: channels of the input image (3 for RGB).
    """

    def __init__(self, z_dim, in_channels, img_channels=3):
        super(Discriminator, self).__init__()
        self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([])
        self.leaky = nn.LeakyReLU(0.2)
        # Walk factors backwards to mirror the generator: prog_blocks[0] and
        # rgb_layers[0] handle the highest resolution, and the last prog_block
        # maps the 8x8 stage's channels to the 4x4 stage's channels.
        for i in range(len(factors) - 1, 0, -1):
            conv_in = int(in_channels * factors[i])
            conv_out = int(in_channels * factors[i - 1])
            self.prog_blocks.append(ConvBlock(conv_in, conv_out, use_pixelnorm=False))
            self.rgb_layers.append(WSConv2d(img_channels, conv_in, kernel_size=1, stride=1, padding=0))
        # The from-RGB layer for 4x4 inputs goes last, so rgb_layers[i] pairs with
        # prog_blocks[i] and rgb_layers[-1] serves the 4x4 base case.
        self.initial_rgb = WSConv2d(img_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.rgb_layers.append(self.initial_rgb)
        self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)  # 2x downsampling
        # block for the 4x4 feature map
        self.final_block = nn.Sequential(
            # +1 input channel for the minibatch std channel concatenated in forward
            WSConv2d(in_channels + 1, in_channels, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2),
            # A 4x4 kernel without padding over the 4x4 map yields a 1x1 output;
            # at this input size the stride has no effect on the result.
            WSConv2d(in_channels, in_channels, kernel_size=4, stride=2, padding=0),
            nn.LeakyReLU(0.2),
            # 1x1 conv used in place of a linear layer
            WSConv2d(in_channels, 1, kernel_size=1, stride=1, padding=0, gain=1),
        )

    def fade_in(self, alpha, downscaled, out):
        """
        Blend the newest block's output with the average-pooled shortcut.

        alpha should be a scalar in [0, 1] (1 -> only ``out``), and
        downscaled.shape must equal out.shape.
        """
        return alpha * out + (1 - alpha) * downscaled

    def minibatch_std(self, x, epsilon=1e-8):
        """
        Append one channel holding the average across-batch standard deviation.

        For every (channel, row, col) position, the std over the batch (dim 0)
        is computed as a population std (dividing by N) with ``epsilon`` under
        the square root, as in ProGAN. The values are then averaged into one
        scalar that is broadcast to an extra channel. Unlike the unbiased
        estimator, this stays finite for a batch of a single image.
        """
        variance = torch.mean((x - x.mean(dim=0, keepdim=True)) ** 2, dim=0)
        batch_stat = torch.sqrt(variance + epsilon).mean()
        stat_map = batch_stat.view(1, 1, 1, 1).expand(x.shape[0], 1, x.shape[2], x.shape[3])
        return torch.cat([x, stat_map], dim=1)

    def forward(self, x, alpha, steps):
        """
        Score images of resolution 4 * 2**steps; returns a tensor of shape (N, 1).

        Raises:
            ValueError: if steps is outside [0, len(self.prog_blocks)] (a larger
                value would otherwise silently select wrong layers via negative
                indexing).
        """
        n_blocks = len(self.prog_blocks)
        if not 0 <= steps <= n_blocks:
            raise ValueError(f"steps must be in [0, {n_blocks}], got {steps}")
        # Index of the first prog_block to use, e.g. steps == 1 (8x8 input)
        # starts at the last block; steps == 0 uses only the final block.
        cur_step = n_blocks - steps
        # from-RGB for this input size (each size has its own rgb layer)
        out = self.leaky(self.rgb_layers[cur_step](x))
        # base case: 4x4 input
        if steps == 0:
            out = self.minibatch_std(out)
            return self.final_block(out).view(out.shape[0], -1)
        # The shortcut path skips prog_blocks[cur_step], so it needs the
        # from-RGB layer of the next-smaller size: cur_step + 1.
        downscaled = self.leaky(self.rgb_layers[cur_step + 1](self.avg_pool(x)))
        out = self.avg_pool(self.prog_blocks[cur_step](out))
        # Fade in right after the first block (the generator fades at its end).
        out = self.fade_in(alpha, downscaled, out)
        # remaining progressive blocks down to 4x4
        for step in range(cur_step + 1, n_blocks):
            out = self.prog_blocks[step](out)
            out = self.avg_pool(out)
        out = self.minibatch_std(out)
        return self.final_block(out).view(out.shape[0], -1)
if __name__ == "__main__":
    # Smoke test: check the output shapes of both networks at every resolution.
    Z_DIM = 100
    IN_CHANNELS = 256
    gen = Generator(Z_DIM, IN_CHANNELS, img_channels=3)
    critic = Discriminator(Z_DIM, IN_CHANNELS, img_channels=3)
    # Only shapes are checked, so skip autograd bookkeeping. It would otherwise
    # keep every activation alive, up to 1024x1024 feature maps.
    with torch.no_grad():
        for img_size in [4, 8, 16, 32, 64, 128, 256, 512, 1024]:
            num_steps = int(log2(img_size / 4))
            x = torch.randn((1, Z_DIM, 1, 1))
            z = gen(x, 0.5, steps=num_steps)
            assert z.shape == (1, 3, img_size, img_size)
            out = critic(z, alpha=0.5, steps=num_steps)
            assert out.shape == (1, 1)
            print(f"Success! At img size: {img_size}")