
Commit 7e3f7fb

init
0 parents  commit 7e3f7fb

File tree: 4 files changed, +317 −0 lines

Datasets.py

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
import torch
from torch.utils.data import TensorDataset
from torchvision import datasets, transforms
import numpy as np

def MnistLabel(class_num):
    """Build a balanced labeled subset with `class_num` examples per digit."""
    raw_dataset = datasets.MNIST('../data', train=True, download=True,
                                 transform=transforms.Compose([
                                     transforms.ToTensor(),
                                     # transforms.Normalize((0.1307,), (0.3081,))
                                 ]))
    class_tot = [0] * 10
    data = []
    labels = []
    positive_tot = 0  # unused
    tot = 0
    perm = np.random.permutation(len(raw_dataset))
    for i in range(len(raw_dataset)):
        datum, label = raw_dataset[perm[i]]
        if class_tot[label] < class_num:
            data.append(datum.numpy())
            labels.append(label)
            class_tot[label] += 1
            tot += 1
            if tot >= 10 * class_num:
                break
    return TensorDataset(torch.FloatTensor(np.array(data)), torch.LongTensor(np.array(labels)))

def MnistUnlabel():
    """Full MNIST training set; labels are ignored by the training loop."""
    return datasets.MNIST('../data', train=True, download=True,
                          transform=transforms.Compose([
                              transforms.ToTensor(),
                              # transforms.Normalize((0.1307,), (0.3081,))
                          ]))

def MnistTest():
    """MNIST test set (note: normalized, unlike the training helpers above)."""
    return datasets.MNIST('../data', train=False, download=True,
                          transform=transforms.Compose([
                              transforms.ToTensor(),
                              transforms.Normalize((0.1307,), (0.3081,))
                          ]))

if __name__ == '__main__':
    print(dir(MnistTest()))
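
A small usage sketch (my own, not part of the commit; it assumes the default `../data` download location used above):

    from torch.utils.data import DataLoader
    from Datasets import MnistLabel, MnistUnlabel, MnistTest

    labeled = MnistLabel(10)    # 10 examples per digit -> 100 labeled images
    unlabeled = MnistUnlabel()  # full 60k training set, labels unused downstream
    test = MnistTest()          # 10k normalized test images

    loader = DataLoader(labeled, batch_size=32, shuffle=True)
    x, y = next(iter(loader))
    print(x.size(), y.size())   # torch.Size([32, 1, 28, 28]) torch.Size([32])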

ImprovedGAN.py

Lines changed: 141 additions & 0 deletions
@@ -0,0 +1,141 @@
# -*- coding:utf-8 -*-
from __future__ import print_function
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from functional import log_sum_exp
from torch.utils.data import DataLoader, TensorDataset
import sys
import argparse
from Nets import Generator, Discriminator
from Datasets import *
import pdb

class ImprovedGAN(object):
    def __init__(self, G, D, labeled, unlabeled, test, args):
        self.G = G
        self.D = D
        if args.cuda:
            self.G.cuda()
            self.D.cuda()
        self.labeled = labeled
        self.unlabeled = unlabeled
        self.test = test
        self.Doptim = optim.SGD(self.D.parameters(), lr=args.lr, momentum=args.momentum)
        self.Goptim = optim.Adam(self.G.parameters(), lr=args.lr)
        self.args = args

    def trainD(self, x_label, y, x_unlabel):
        x_label, x_unlabel, y = Variable(x_label), Variable(x_unlabel), Variable(y, requires_grad=False)
        if self.args.cuda:
            x_label, x_unlabel, y = x_label.cuda(), x_unlabel.cuda(), y.cuda()
        output_label = self.D(x_label, cuda=self.args.cuda)
        output_unlabel = self.D(x_unlabel, cuda=self.args.cuda)
        output_fake = self.D(self.G(x_unlabel.size()[0], cuda=self.args.cuda).view(x_unlabel.size()), cuda=self.args.cuda)
        # log ∑_k e^{l_k} over the class logits
        logz_label, logz_unlabel, logz_fake = log_sum_exp(output_label), log_sum_exp(output_unlabel), log_sum_exp(output_fake)
        prob_label = torch.gather(output_label, 1, y.unsqueeze(1))  # logit of the true class
        loss_supervised = -torch.mean(prob_label) + torch.mean(logz_label)  # cross-entropy over the K real classes
        loss_unsupervised = 0.5 * (-torch.mean(logz_unlabel) + torch.mean(F.softplus(logz_unlabel)) +  # real data: -log Z/(1+Z)
                                   torch.mean(F.softplus(logz_fake)))                                  # fake data: -log 1/(1+Z)
        loss = loss_supervised + self.args.unlabel_weight * loss_unsupervised
        acc = torch.mean((output_label.max(1)[1] == y).float())
        self.Doptim.zero_grad()
        loss.backward()
        if loss != loss:  # NaN guard
            pdb.set_trace()
        self.Doptim.step()
        return loss_supervised.data.cpu().numpy(), loss_unsupervised.data.cpu().numpy(), acc
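
For reference, the two terms above implement the semi-supervised discriminator objective of Improved GAN (Salimans et al., 2016), as the in-line comments suggest. Writing Z(x) = ∑_k exp(l_k(x)) over the K = 10 class logits, and treating "fake" as an implicit extra class whose logit is fixed at 0:

    loss_supervised   = E_(x,y)[ log Z(x) − l_y(x) ]                      # cross-entropy over the real classes
    loss_unsupervised = 0.5 · ( E_x[ softplus(log Z(x)) − log Z(x) ]      # unlabeled x: −log Z/(1+Z)
                              + E_z[ softplus(log Z(G(z))) ] )            # generated x: −log 1/(1+Z)

so p(fake | x) = 1 / (1 + Z(x)), and the unlabeled term is the usual real/fake GAN loss expressed through the classifier's logits.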

    def trainG(self, x_unlabel):
        # feature matching: match mean discriminator features of generated and unlabeled batches
        mom_gen = torch.mean(self.D(self.G(x_unlabel.size()[0], cuda=self.args.cuda).view(x_unlabel.size()), feature=True, cuda=self.args.cuda), dim=0)
        mom_unlabel = torch.mean(self.D(Variable(x_unlabel), feature=True, cuda=self.args.cuda), dim=0)
        loss = torch.mean((mom_gen - mom_unlabel) ** 2)
        self.Goptim.zero_grad()
        loss.backward()
        self.Goptim.step()
        a = self.G.main[0].weight != self.G.main[0].weight  # NaN check on the first generator layer
        if torch.sum(a.float()) > 0:
            pdb.set_trace()
        return loss.data.cpu().numpy()
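
A note on the generator objective: trainG uses feature matching rather than the plain adversarial loss. With f(·) the penultimate discriminator features (the feature=True path in Nets.py), it minimizes

    loss_G = mean( (E_x[f(x)] − E_z[f(G(z))])² )

i.e. the generator is pushed to reproduce the average feature statistics of unlabeled real data.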

    def train(self):
        assert len(self.unlabeled) > len(self.labeled)
        assert type(self.labeled) == TensorDataset
        # tile the small labeled set so it can be iterated in step with the unlabeled set
        times = int(np.ceil(len(self.unlabeled) * 1. / len(self.labeled)))
        t1 = self.labeled.data_tensor.clone()
        t2 = self.labeled.target_tensor.clone()
        tile_labeled = TensorDataset(t1.repeat(times, 1, 1, 1), t2.repeat(times))
        for epoch in range(self.args.epochs):
            self.G.train()
            self.D.train()
            unlabel_loader1 = DataLoader(self.unlabeled, batch_size=self.args.batch_size, shuffle=True, drop_last=True)
            unlabel_loader2 = iter(DataLoader(self.unlabeled, batch_size=self.args.batch_size, shuffle=True, drop_last=True))
            label_loader = iter(DataLoader(tile_labeled, batch_size=self.args.batch_size, shuffle=True, drop_last=True))
            batch_num = loss_supervised = loss_unsupervised = loss_gen = accuracy = 0.
            for (unlabel1, _label1) in unlabel_loader1:
                batch_num += 1
                unlabel2, _label2 = next(unlabel_loader2)
                x, y = next(label_loader)
                if self.args.cuda:
                    x, y, unlabel1, unlabel2 = x.cuda(), y.cuda(), unlabel1.cuda(), unlabel2.cuda()
                ll, lu, acc = self.trainD(x, y, unlabel1)
                loss_supervised += ll
                loss_unsupervised += lu
                accuracy += acc
                lg = self.trainG(unlabel2)
                loss_gen += lg
                if (batch_num + 1) % self.args.log_interval == 0:
                    print('Training: %d / %d' % (batch_num + 1, len(unlabel_loader1)))
                    print('Eval: correct %d / %d, batch acc %.4f' % (self.eval(), len(self.test), acc))
                    self.G.train()  # eval() switched the nets to eval mode; switch back before training continues
                    self.D.train()
            loss_supervised /= batch_num
            loss_unsupervised /= batch_num
            loss_gen /= batch_num
            accuracy /= batch_num
            print("Iteration %d, loss_supervised = %.4f, loss_unsupervised = %.4f, loss_gen = %.4f, train acc = %.4f" %
                  (epoch, loss_supervised, loss_unsupervised, loss_gen, accuracy))
            sys.stdout.flush()
            if (epoch + 1) % self.args.eval_interval == 0:
                print("Eval: correct %d / %d" % (self.eval(), len(self.test)))

    def predict(self, x):
        return torch.max(self.D(Variable(x, volatile=True), cuda=self.args.cuda), 1)[1].data

    def eval(self):
        self.G.eval()
        self.D.eval()
        d, l = [], []
        for (datum, label) in self.test:
            d.append(datum)
            l.append(label)
        x, y = torch.stack(d), torch.LongTensor(l)
        if self.args.cuda:
            x, y = x.cuda(), y.cuda()
        pred = self.predict(x)
        return torch.sum(pred == y)

    def draw(self, batch_size):
        self.G.eval()
        return self.G(batch_size, cuda=self.args.cuda)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='PyTorch Improved GAN')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr', type=float, default=0.003, metavar='LR',
                        help='learning rate (default: 0.003)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--cuda', action='store_true', default=False,
                        help='enable CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=100, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--eval-interval', type=int, default=1, metavar='N',
                        help='how many epochs to wait before evaluating on the test set')
    parser.add_argument('--unlabel-weight', type=float, default=1, metavar='N',
                        help='scale factor between the labeled and unlabeled loss terms')
    args = parser.parse_args()
    args.cuda = args.cuda and torch.cuda.is_available()
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)  # seed torch as well so --seed affects weight init and sampling
    gan = ImprovedGAN(Generator(100), Discriminator(), MnistLabel(10), MnistUnlabel(), MnistTest(), args)
    gan.train()
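
A minimal programmatic driver (my own sketch mirroring the __main__ block above; argparse.Namespace stands in for the parsed flags, and the hyperparameters are the script defaults):

    import argparse, torch
    from Nets import Generator, Discriminator
    from Datasets import MnistLabel, MnistUnlabel, MnistTest
    from ImprovedGAN import ImprovedGAN

    args = argparse.Namespace(batch_size=64, epochs=1, lr=0.003, momentum=0.5,
                              cuda=torch.cuda.is_available(), seed=1,
                              log_interval=100, eval_interval=1, unlabel_weight=1)
    gan = ImprovedGAN(Generator(100), Discriminator(), MnistLabel(10), MnistUnlabel(), MnistTest(), args)
    gan.train()
    samples = gan.draw(16)  # 16 generated digits, shape (16, 1, 28, 28)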

Nets.py

Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Variable

# Commented-out earlier architectures:
#
# class Discriminator(nn.Module):
#     def __init__(self, output_units=10):
#         super(Discriminator, self).__init__()
#         self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
#         self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
#         self.conv2_drop = nn.Dropout2d()
#         self.fc1 = nn.Linear(320, 100)
#         self.fc2 = nn.Linear(100, output_units)
#
#     def forward(self, x, feature=False, cuda=False):
#         x = F.leaky_relu(F.max_pool2d(self.conv1(x), 2))
#         x = F.leaky_relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
#         x = x.view(-1, 320)
#         x_f = self.fc1(x)
#         x = F.leaky_relu(x_f)
#         x = F.dropout(x, training=self.training)
#         x = self.fc2(x)
#         return x if not feature else x_f
#
# class Discriminator(nn.Module):
#     def __init__(self, input_dim=28 ** 2, output_dim=10):
#         super(Discriminator, self).__init__()
#         self.input_dim = input_dim
#         self.layers = torch.nn.ModuleList([
#             nn.Linear(input_dim, 1000),
#             nn.Linear(1000, 500),
#             nn.Linear(500, 250)]
#         )
#         self.bns = torch.nn.ModuleList([
#             nn.BatchNorm1d(1000, affine=False),
#             nn.BatchNorm1d(500, affine=False),
#             nn.BatchNorm1d(250, affine=True)]
#         )
#         self.final = nn.Linear(250, output_dim)
#
#     def forward(self, x, feature=False, cuda=False):
#         x = x.view(-1, self.input_dim)
#         noise = torch.randn(x.size()) * 0.3 if self.training else torch.Tensor([0])
#         if cuda:
#             noise = noise.cuda()
#         x = x + Variable(noise, requires_grad=False)
#         for i in range(len(self.layers)):
#             m = self.layers[i]
#             bn = self.bns[i]
#             x_f = F.leaky_relu(m(x))
#             noise = torch.randn(x_f.size()) * 0.5 if self.training else torch.Tensor([0])
#             if cuda:
#                 noise = noise.cuda()
#             x = (x_f + Variable(noise, requires_grad=False))
#         if feature:
#             return x_f
#         return self.final(x)
#
# class Generator(nn.Module):
#     def __init__(self, z_dim, output_dim=28 ** 2):
#         super(Generator, self).__init__()
#         self.z_dim = z_dim
#         self.fc1 = nn.Linear(z_dim, 500)
#         self.bn1 = nn.BatchNorm1d(500, affine=False)
#         self.fc2 = nn.Linear(500, 500)
#         self.bn2 = nn.BatchNorm1d(500, affine=False)
#         self.fc3 = nn.Linear(500, output_dim)
#
#     def forward(self, batch_size, cuda=False):
#         x = Variable(torch.rand(batch_size, self.z_dim), requires_grad=False)
#         if cuda:
#             x = x.cuda()
#         x = F.softplus(self.bn1(self.fc1(x)))
#         # x = F.softplus(self.bn2(self.fc2(x)))
#         x = F.sigmoid(self.fc3(x))
#         return x
class Discriminator(nn.Module):
    def __init__(self, nc=1, ndf=64, output_units=10):
        super(Discriminator, self).__init__()
        self.ndf = ndf
        self.main = nn.Sequential(
            # input: (nc) x 28 x 28
            nn.Conv2d(nc, ndf, 4, 2, 3, bias=False),
            nn.BatchNorm2d(ndf),
            nn.LeakyReLU(0.2, inplace=True),
            # state size: (ndf) x 16 x 16
            nn.Conv2d(ndf, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size: (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size: (ndf*4) x 4 x 4
            nn.Conv2d(ndf * 4, ndf * 4, 4, 1, 0, bias=False),
            # state size: (ndf*4) x 1 x 1
        )
        self.final = nn.Linear(ndf * 4, output_units, bias=False)

    def forward(self, x, feature=False, cuda=False):
        x_f = self.main(x).view(-1, self.ndf * 4)
        # feature=True returns the penultimate features (used for feature matching in trainG)
        return x_f if feature else self.final(x_f)


class Generator(nn.Module):
    def __init__(self, z_dim, ngf=64, output_dim=28 ** 2):
        super(Generator, self).__init__()
        self.z_dim = z_dim
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(z_dim, ngf * 4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size: (ngf*4) x 4 x 4
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size: (ngf*2) x 8 x 8
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size: (ngf) x 16 x 16
            nn.ConvTranspose2d(ngf, 1, 4, 2, 3, bias=False),
            # state size: 1 x 28 x 28
            nn.Sigmoid()
        )

    def forward(self, batch_size, cuda=False):
        # latent code sampled uniformly from [0, 1)
        x = Variable(torch.rand(batch_size, self.z_dim, 1, 1), requires_grad=False, volatile=self.training)
        if cuda:
            x = x.cuda()
        return self.main(x)
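
A quick shape check (a sketch of mine, not part of the commit) confirming the layer arithmetic in the comments above for MNIST-sized inputs:

    import torch
    from torch.autograd import Variable
    from Nets import Generator, Discriminator

    D, G = Discriminator(), Generator(z_dim=100)
    x = Variable(torch.rand(8, 1, 28, 28))
    print(D(x).size())                # torch.Size([8, 10])  class logits
    print(D(x, feature=True).size())  # torch.Size([8, 256]) penultimate features (ndf * 4)
    print(G(8).size())                # torch.Size([8, 1, 28, 28]) generated images in [0, 1]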

functional.py

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
import torch

def log_sum_exp(x, axis=1):
    # numerically stable log ∑_i e^{x_i} along `axis`: subtract the max before exponentiating
    m = torch.max(x, dim=axis)[0]
    return m + torch.log(torch.sum(torch.exp(x - m.unsqueeze(axis)), dim=axis))
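
The max subtraction relies on the identity log ∑_i e^{x_i} = m + log ∑_i e^{x_i − m} with m = max_i x_i, which keeps the exponentials bounded. A small numeric check (my own sketch; newer PyTorch releases also provide torch.logsumexp(x, dim=1) for the same quantity):

    import torch
    from functional import log_sum_exp

    x = torch.Tensor([[1000., 1000., 1000.]])
    print(log_sum_exp(x))                             # ~1001.0986 (= 1000 + log 3)
    print(torch.log(torch.sum(torch.exp(x), dim=1)))  # inf: the naive form overflows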
