
Commit c949816

fix testcode & change to softplus
1 parent 7e3f7fb commit c949816

File tree

4 files changed: +188 -130 lines changed


Datasets.py

Lines changed: 3 additions & 3 deletions
@@ -26,17 +26,17 @@ def MnistLabel(class_num):
     return TensorDataset(torch.FloatTensor(np.array(data)), torch.LongTensor(np.array(labels)))
 
 def MnistUnlabel():
-    return datasets.MNIST('../data', train=True, download=True,
+    raw_dataset = datasets.MNIST('../data', train=True, download=True,
                        transform=transforms.Compose([
                            transforms.ToTensor(),
                            #transforms.Normalize((0.1307,), (0.3081,))
                        ]))
-
+    return raw_dataset
 def MnistTest():
     return datasets.MNIST('../data', train=False, download=True,
                        transform=transforms.Compose([
                            transforms.ToTensor(),
-                           transforms.Normalize((0.1307,), (0.3081,))
+                           #transforms.Normalize((0.1307,), (0.3081,))
                        ]))
 
 if __name__ == '__main__':
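
With the Normalize transform commented out in both MnistUnlabel and MnistTest, the loaders now return raw pixel values in [0, 1]. A minimal sanity-check sketch (assuming torchvision can download MNIST to '../data'; this snippet is not part of the commit):

# Sketch: verify the un-normalized datasets yield tensors in [0, 1].
from Datasets import MnistUnlabel, MnistTest

unlabeled = MnistUnlabel()
test = MnistTest()
img, _ = unlabeled[0]      # ToTensor() only, so values should lie in [0, 1]
print(img.min(), img.max())
img_t, _ = test[0]
print(img_t.min(), img_t.max())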

ImprovedGAN.py

Lines changed: 49 additions & 31 deletions
@@ -13,24 +13,26 @@
 from Nets import Generator, Discriminator
 from Datasets import *
 import pdb
+import tensorboardX
 class ImprovedGAN(object):
     def __init__(self, G, D, labeled, unlabeled, test, args):
         self.G = G
         self.D = D
-        if args.cuda:
-            self.G.cuda()
-            self.D.cuda()
+        self.writer = tensorboardX.SummaryWriter(log_dir='logfile')
+        if args.cuda:
+            self.G.cuda()
+            self.D.cuda()
         self.labeled = labeled
         self.unlabeled = unlabeled
         self.test = test
-        self.Doptim = optim.SGD(self.D.parameters(), lr=args.lr, momentum = args.momentum)
-        self.Goptim = optim.Adam(self.G.parameters(), lr=args.lr)
+        self.Doptim = optim.Adam(self.D.parameters(), lr=args.lr, betas= (args.momentum, 0.999))
+        self.Goptim = optim.Adam(self.G.parameters(), lr=args.lr, betas = (args.momentum,0.999))
         self.args = args
     def trainD(self, x_label, y, x_unlabel):
         x_label, x_unlabel, y = Variable(x_label), Variable(x_unlabel), Variable(y, requires_grad = False)
         if self.args.cuda:
             x_label, x_unlabel, y = x_label.cuda(), x_unlabel.cuda(), y.cuda()
-        output_label, output_unlabel, output_fake = self.D(x_label, cuda=self.args.cuda), self.D(x_unlabel, cuda=self.args.cuda), self.D(self.G(x_unlabel.size()[0], cuda = self.args.cuda).view(x_unlabel.size()), cuda=self.args.cuda)
+        output_label, output_unlabel, output_fake = self.D(x_label, cuda=self.args.cuda), self.D(x_unlabel, cuda=self.args.cuda), self.D(self.G(x_unlabel.size()[0], cuda = self.args.cuda).view(x_unlabel.size()).detach(), cuda=self.args.cuda)
         logz_label, logz_unlabel, logz_fake = log_sum_exp(output_label), log_sum_exp(output_unlabel), log_sum_exp(output_fake) # log ∑e^x_i
         prob_label = torch.gather(output_label, 1, y.unsqueeze(1)) # log e^x_label = x_label
         loss_supervised = -torch.mean(prob_label) + torch.mean(logz_label)
@@ -40,53 +42,69 @@ def trainD(self, x_label, y, x_unlabel):
         acc = torch.mean((output_label.max(1)[1] == y).float())
         self.Doptim.zero_grad()
         loss.backward()
-        if loss != loss:
-            pdb.set_trace()
         self.Doptim.step()
         return loss_supervised.data.cpu().numpy(), loss_unsupervised.data.cpu().numpy(), acc
 
     def trainG(self, x_unlabel):
-        mom_gen = torch.mean(self.D(self.G(x_unlabel.size()[0], cuda = self.args.cuda).view(x_unlabel.size()), feature=True, cuda=self.args.cuda), dim = 0)
-        mom_unlabel = torch.mean(self.D(Variable(x_unlabel), feature=True, cuda=self.args.cuda), dim = 0)
-        loss = torch.mean((mom_gen - mom_unlabel) ** 2)
+        fake = self.G(x_unlabel.size()[0], cuda = self.args.cuda).view(x_unlabel.size())
+        # fake.retain_grad()
+        mom_gen, output_fake = self.D(fake, feature=True, cuda=self.args.cuda)
+        mom_unlabel, _ = self.D(Variable(x_unlabel), feature=True, cuda=self.args.cuda)
+        mom_gen = torch.mean(mom_gen, dim = 0)
+        mom_unlabel = torch.mean(mom_unlabel, dim = 0)
+        loss_fm = torch.mean((mom_gen - mom_unlabel) ** 2)
+        #loss_adv = -torch.mean(F.softplus(log_sum_exp(output_fake)))
+        loss = loss_fm #+ 1. * loss_adv
         self.Goptim.zero_grad()
+        self.Doptim.zero_grad()
         loss.backward()
         self.Goptim.step()
-        a = self.G.main[0].weight != self.G.main[0].weight
-        if torch.sum(a.float()) > 0:
-            pdb.set_trace()
         return loss.data.cpu().numpy()
 
     def train(self):
         assert self.unlabeled.__len__() > self.labeled.__len__()
         assert type(self.labeled) == TensorDataset
         times = int(np.ceil(self.unlabeled.__len__() * 1. / self.labeled.__len__()))
-        t1 = self.labeled.data_tensor.clone()
-        t2 = self.labeled.target_tensor.clone()
-        #tile_labeled = TensorDataset(self.labeled.data_tensor.repeat(times, 1, 1, 1), self.labeled.target_tensor.repeat(times))
-        tile_labeled = TensorDataset(t1.repeat(times,1,1,1),t2.repeat(times))
+        t1 = self.labeled.data_tensor.clone()
+        t2 = self.labeled.target_tensor.clone()
+        tile_labeled = TensorDataset(t1.repeat(times,1,1,1),t2.repeat(times))
+        gn = 0
         for epoch in range(self.args.epochs):
             self.G.train()
             self.D.train()
-            unlabel_loader1 = DataLoader(self.unlabeled, batch_size = self.args.batch_size, shuffle=True, drop_last=True)
-            unlabel_loader2 = DataLoader(self.unlabeled, batch_size = self.args.batch_size, shuffle=True, drop_last=True).__iter__()
-            label_loader = DataLoader(tile_labeled, batch_size = self.args.batch_size, shuffle=True, drop_last=True).__iter__()
-            batch_num = loss_supervised = loss_unsupervised = loss_gen = accuracy = 0.
+            unlabel_loader1 = DataLoader(self.unlabeled, batch_size = self.args.batch_size, shuffle=True, drop_last=True, num_workers = 4)
+            unlabel_loader2 = DataLoader(self.unlabeled, batch_size = self.args.batch_size, shuffle=True, drop_last=True, num_workers = 4).__iter__()
+            label_loader = DataLoader(tile_labeled, batch_size = self.args.batch_size, shuffle=True, drop_last=True, num_workers = 4).__iter__()
+            loss_supervised = loss_unsupervised = loss_gen = accuracy = 0.
+            batch_num = 0
             for (unlabel1, _label1) in unlabel_loader1:
+                # pdb.set_trace()
                 batch_num += 1
                 unlabel2, _label2 = unlabel_loader2.next()
                 x, y = label_loader.next()
-                if args.cuda:
-                    x, y, unlabel1, unlabel2 = x.cuda(), y.cuda(), unlabel1.cuda(), unlabel2.cuda()
+                if args.cuda:
+                    x, y, unlabel1, unlabel2 = x.cuda(), y.cuda(), unlabel1.cuda(), unlabel2.cuda()
                 ll, lu, acc = self.trainD(x, y, unlabel1)
                 loss_supervised += ll
                 loss_unsupervised += lu
                 accuracy += acc
                 lg = self.trainG(unlabel2)
+                if epoch > 1 and lg > 1:
+                    # pdb.set_trace()
+                    lg = self.trainG(unlabel2)
                 loss_gen += lg
                 if (batch_num + 1) % self.args.log_interval == 0:
                     print('Training: %d / %d' % (batch_num + 1, len(unlabel_loader1)))
-                    print('Eval: correct %d/%d, %.4f' % (self.eval(), self.test.__len__(), acc))
+                    gn += 1
+                    self.writer.add_scalars('loss', {'loss_supervised':ll, 'loss_unsupervised':lu, 'loss_gen':lg}, gn)
+                    self.writer.add_histogram('real_feature', self.D(Variable(x, volatile = True), cuda=self.args.cuda, feature = True)[0], gn)
+                    self.writer.add_histogram('fake_feature', self.D(self.G(self.args.batch_size, cuda = self.args.cuda), cuda=self.args.cuda, feature = True)[0], gn)
+                    self.writer.add_histogram('fc3_bias', self.G.fc3.bias, gn)
+                    self.writer.add_histogram('D_feature_weight', self.D.layers[-1].weight, gn)
+                    # self.writer.add_histogram('D_feature_bias', self.D.layers[-1].bias, gn)
+                    #print('Eval: correct %d/%d, %.4f' % (self.eval(), self.test.__len__(), acc))
+                    self.D.train()
+                    self.G.train()
             loss_supervised /= batch_num
             loss_unsupervised /= batch_num
             loss_gen /= batch_num
@@ -106,16 +124,16 @@ def eval(self):
             d.append(datum)
             l.append(label)
         x, y = torch.stack(d), torch.LongTensor(l)
-        if self.args.cuda:
-            x, y = x.cuda(), y.cuda()
+        if self.args.cuda:
+            x, y = x.cuda(), y.cuda()
         pred = self.predict(x)
         return torch.sum(pred == y)
     def draw(self, batch_size):
-        self.G.eval()
-        return self.G(batch_size, cuda=self.args.cuda)
+        self.G.eval()
+        return self.G(batch_size, cuda=self.args.cuda)
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='PyTorch Improved GAN')
-    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
+    parser.add_argument('--batch-size', type=int, default=100, metavar='N',
                         help='input batch size for training (default: 64)')
     parser.add_argument('--epochs', type=int, default=10, metavar='N',
                         help='number of epochs to train (default: 10)')
@@ -131,7 +149,7 @@ def draw(self, batch_size):
                         help='how many batches to wait before logging training status')
     parser.add_argument('--eval-interval', type=int, default=1, metavar='N',
                         help='how many batches to wait before evaling training status')
-    parser.add_argument('--unlabel-weight', type=int, default=1, metavar='N',
+    parser.add_argument('--unlabel-weight', type=float, default=1, metavar='N',
                         help='scale factor between labeled and unlabeled data')
     args = parser.parse_args()
     args.cuda = args.cuda and torch.cuda.is_available()
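
trainD computes the semi-supervised losses through a log_sum_exp helper that this commit does not touch. For reference, a numerically stable sketch consistent with how it is called above (reducing the class dimension of the discriminator output) could look like this; it is an illustration, not the repository's exact implementation:

# Numerically stable log ∑_i exp(x_i) over the class dimension (dim=1),
# matching the "# log ∑e^x_i" comment in trainD. Sketch only; the real
# helper is imported from elsewhere in the repository.
import torch

def log_sum_exp(x, dim=1):
    m, _ = torch.max(x, dim=dim, keepdim=True)   # subtract the max for stability
    return (m + torch.log(torch.sum(torch.exp(x - m), dim=dim, keepdim=True))).squeeze(dim)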

Nets.py

Lines changed: 102 additions & 94 deletions
@@ -1,8 +1,10 @@
 import torch
+from torch.nn.parameter import Parameter
 from torch import nn
 from torch.nn import functional as F
 from torch.autograd import Variable
-
+import pdb
+from functional import reset_normal_param, LinearWeightNorm
 # class Discriminator(nn.Module):
 #     def __init__(self, output_units = 10):
 #         super(Discriminator, self).__init__()
@@ -22,107 +24,113 @@
 #         x = self.fc2(x)
 #         return x if not feature else x_f
 
-#class Discriminator(nn.Module):
-#    def __init__(self, input_dim = 28 ** 2, output_dim = 10):
-#        super(Discriminator, self).__init__()
-#        self.input_dim = input_dim
-#        self.layers = torch.nn.ModuleList([
-#            nn.Linear(input_dim, 1000),
-#            nn.Linear(1000, 500),
-#            nn.Linear(500, 250)]
-#        )
-#        self.bns = torch.nn.ModuleList([
-#            nn.BatchNorm1d(1000, affine=False),
-#            nn.BatchNorm1d(500, affine=False),
-#            nn.BatchNorm1d(250, affine=True)]
-#        )
-#        self.final = nn.Linear(250, output_dim)
-#    def forward(self, x, feature = False, cuda = False):
-#        x = x.view(-1, self.input_dim)
-#        noise = torch.randn(x.size()) * 0.3 if self.training else torch.Tensor([0])
-#        if cuda:
-#            noise = noise.cuda()
-#        x = x + Variable(noise, requires_grad = False)
-#        for i in range(len(self.layers)):
-#            m = self.layers[i]
-#            bn = self.bns[i]
-#            x_f = F.leaky_relu(m(x))
-#            noise = torch.randn(x_f.size()) * 0.5 if self.training else torch.Tensor([0])
-#            if cuda:
-#                noise = noise.cuda()
-#            x = (x_f + Variable(noise, requires_grad = False))
-#        if feature:
-#            return x_f
-#        return self.final(x)
-
-
-# class Generator(nn.Module):
-#     def __init__(self, z_dim, output_dim = 28 ** 2):
-#         super(Generator, self).__init__()
-#         self.z_dim = z_dim
-#         self.fc1 = nn.Linear(z_dim, 500)
-#         self.bn1 = nn.BatchNorm1d(500, affine = False)
-#         self.fc2 = nn.Linear(500, 500)
-#         self.bn2 = nn.BatchNorm1d(500, affine = False)
-#         self.fc3 = nn.Linear(500, output_dim)
-#     def forward(self, batch_size, cuda = False):
-#         x = Variable(torch.rand(batch_size, self.z_dim), requires_grad = False)
-#         if cuda:
-#             x = x.cuda()
-#         x = F.softplus(self.bn1(self.fc1(x)))
-#         #x = F.softplus(self.bn2(self.fc2(x)))
-#         x = F.sigmoid(self.fc3(x))
-#         return x
-
 class Discriminator(nn.Module):
-    def __init__(self, nc = 1, ndf = 64, output_units = 10):
+    def __init__(self, input_dim = 28 ** 2, output_dim = 10):
         super(Discriminator, self).__init__()
-        self.ndf = ndf
-        self.main = nn.Sequential(
-            # state size. (nc) x 28 x 28
-            nn.Conv2d(nc, ndf, 4, 2, 3, bias=False),
-            nn.BatchNorm2d(ndf),
-            nn.LeakyReLU(0.2, inplace=True),
-            # state size. (ndf) x 16 x 16
-            nn.Conv2d(ndf, ndf * 4, 4, 2, 1, bias=False),
-            nn.BatchNorm2d(ndf * 4),
-            nn.LeakyReLU(0.2, inplace=True),
-            # state size. (ndf*2) x 8 x 8
-            nn.Conv2d(ndf * 4, ndf * 4, 4, 2, 1, bias=False),
-            nn.BatchNorm2d(ndf * 4),
-            nn.LeakyReLU(0.2, inplace=True),
-            # state size. (ndf*4) x 4 x 4
-            nn.Conv2d(ndf * 4, ndf * 4, 4, 1, 0, bias=False),
+        self.input_dim = input_dim
+        self.layers = torch.nn.ModuleList([
+            LinearWeightNorm(input_dim, 1000),
+            LinearWeightNorm(1000, 500),
+            LinearWeightNorm(500, 250),
+            LinearWeightNorm(250, 250),
+            LinearWeightNorm(250, 250)]
         )
-        self.final = nn.Linear(ndf * 4, output_units, bias=False)
+        self.final = LinearWeightNorm(250, output_dim, weight_scale=1)
+        #for layer in self.layers:
+        #    reset_normal_param(layer, 0.1)
+        #reset_normal_param(self.final, 0.1, 5)
     def forward(self, x, feature = False, cuda = False):
-        x_f = self.main(x).view(-1, self.ndf * 4)
-        return x_f if feature else self.final(x_f)
+        x = x.view(-1, self.input_dim)
+        noise = torch.randn(x.size()) * 0.3 if self.training else torch.Tensor([0])
+        if cuda:
+            noise = noise.cuda()
+        x = x + Variable(noise, requires_grad = False)
+        for i in range(len(self.layers)):
+            m = self.layers[i]
+            x_f = F.relu(m(x))
+            noise = torch.randn(x_f.size()) * 0.5 if self.training else torch.Tensor([0])
+            if cuda:
+                noise = noise.cuda()
+            x = (x_f + Variable(noise, requires_grad = False))
+        if feature:
+            return x_f, self.final(x)
+        return self.final(x)
+
 
 class Generator(nn.Module):
-    def __init__(self, z_dim, ngf = 64, output_dim = 28 ** 2):
+    def __init__(self, z_dim, output_dim = 28 ** 2):
         super(Generator, self).__init__()
         self.z_dim = z_dim
-        self.main = nn.Sequential(
-            # input is Z, going into a convolution
-            nn.ConvTranspose2d(z_dim, ngf * 4, 4, 1, 0, bias=False),
-            nn.BatchNorm2d(ngf * 4),
-            nn.ReLU(True),
-            # state size. (ngf*8) x 4 x 4
-            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
-            nn.BatchNorm2d(ngf * 2),
-            nn.ReLU(True),
-            # state size. (ngf*4) x 8 x 8
-            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
-            nn.BatchNorm2d(ngf),
-            nn.ReLU(True),
-            # state size. (ngf*2) x 16 x 16
-            nn.ConvTranspose2d(ngf, 1, 4, 2, 3, bias=False),
-            # state size. (ngf) x 32 x 32
-            nn.Sigmoid()
-        )
+        self.fc1 = nn.Linear(z_dim, 500, bias = False)
+        self.bn1 = nn.BatchNorm1d(500, affine = False, eps=1e-6, momentum = 0.5)
+        self.fc2 = nn.Linear(500, 500, bias = False)
+        self.bn2 = nn.BatchNorm1d(500, affine = False, eps=1e-6, momentum = 0.5)
+        self.fc3 = LinearWeightNorm(500, output_dim, weight_scale = 1)
+        self.bn1_b = Parameter(torch.zeros(500))
+        self.bn2_b = Parameter(torch.zeros(500))
+        nn.init.xavier_uniform(self.fc1.weight)
+        nn.init.xavier_uniform(self.fc2.weight)
+        #reset_normal_param(self.fc1, 0.1)
+        #reset_normal_param(self.fc2, 0.1)
+        #reset_normal_param(self.fc3, 0.1)
     def forward(self, batch_size, cuda = False):
-        x = Variable(torch.rand(batch_size, self.z_dim, 1, 1), requires_grad = False, volatile = self.training)
+        x = Variable(torch.rand(batch_size, self.z_dim), requires_grad = False, volatile = not self.training)
         if cuda:
             x = x.cuda()
-        return self.main(x)
+        x = F.softplus(self.bn1(self.fc1(x)) + self.bn1_b)
+        x = F.softplus(self.bn2(self.fc2(x)) + self.bn2_b)
+        x = F.softplus(self.fc3(x))
+        return x
+
+#class Discriminator(nn.Module):
+#    def __init__(self, nc = 1, ndf = 64, output_units = 10):
+#        super(Discriminator, self).__init__()
+#        self.ndf = ndf
+#        self.main = nn.Sequential(
+#            # state size. (nc) x 28 x 28
+#            nn.Conv2d(nc, ndf, 4, 2, 3, bias=False),
+#            nn.BatchNorm2d(ndf),
+#            nn.LeakyReLU(0.2, inplace=True),
+#            # state size. (ndf) x 16 x 16
+#            nn.Conv2d(ndf, ndf * 4, 4, 2, 1, bias=False),
+#            nn.BatchNorm2d(ndf * 4),
+#            nn.LeakyReLU(0.2, inplace=True),
+#            # state size. (ndf*2) x 8 x 8
+#            nn.Conv2d(ndf * 4, ndf * 4, 4, 2, 1, bias=False),
+#            nn.BatchNorm2d(ndf * 4),
+#            nn.LeakyReLU(0.2, inplace=True),
+#            # state size. (ndf*4) x 4 x 4
+#            nn.Conv2d(ndf * 4, ndf * 4, 4, 1, 0, bias=False),
+#        )
+#        self.final = nn.Linear(ndf * 4, output_units, bias=False)
+#    def forward(self, x, feature = False, cuda = False):
+#        x_f = self.main(x).view(-1, self.ndf * 4)
+#        return x_f if feature else self.final(x_f)
+
+#class Generator(nn.Module):
+#    def __init__(self, z_dim, ngf = 64, output_dim = 28 ** 2):
+#        super(Generator, self).__init__()
+#        self.z_dim = z_dim
+#        self.main = nn.Sequential(
+#            # input is Z, going into a convolution
+#            nn.ConvTranspose2d(z_dim, ngf * 4, 4, 1, 0, bias=False),
+#            nn.BatchNorm2d(ngf * 4),
+#            nn.ReLU(True),
+#            # state size. (ngf*8) x 4 x 4
+#            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
+#            nn.BatchNorm2d(ngf * 2),
+#            nn.ReLU(True),
+#            # state size. (ngf*4) x 8 x 8
+#            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
+#            nn.BatchNorm2d(ngf),
+#            nn.ReLU(True),
+#            # state size. (ngf*2) x 16 x 16
+#            nn.ConvTranspose2d(ngf, 1, 4, 2, 3, bias=False),
+#            # state size. (ngf) x 32 x 32
+#            nn.Sigmoid()
+#        )
+#    def forward(self, batch_size, cuda = False):
+#        x = Variable(torch.rand(batch_size, self.z_dim, 1, 1), requires_grad = False, volatile = not self.training)
+#        if cuda:
+#            x = x.cuda()
+#        return self.main(x)
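
The new Discriminator and Generator rely on LinearWeightNorm and reset_normal_param from functional.py, which this commit does not modify. A minimal weight-normalized linear layer consistent with how it is constructed above (in_features, out_features, optional weight_scale) might look like the sketch below; it is an illustration under those assumptions, not the repository's actual implementation:

# Sketch of a weight-normalized linear layer: each output unit's weight
# vector is rescaled to a fixed norm (weight_scale) before the affine
# transform. Illustrative only; the real LinearWeightNorm lives in
# functional.py, which is untouched by this commit.
import torch
from torch import nn
from torch.nn import functional as F

class LinearWeightNorm(nn.Module):
    def __init__(self, in_features, out_features, weight_scale=1.0):
        super(LinearWeightNorm, self).__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.05)
        self.bias = nn.Parameter(torch.zeros(out_features))
        self.weight_scale = weight_scale

    def forward(self, x):
        # Normalize each row of the weight matrix, then scale by weight_scale.
        norm = torch.sqrt(torch.sum(self.weight ** 2, dim=1, keepdim=True))
        w = self.weight * (self.weight_scale / norm)
        return F.linear(x, w, self.bias)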
