# -*- coding:utf-8 -*-
from __future__ import print_function
import argparse
import sys
import pdb

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader, TensorDataset

from functional import log_sum_exp
from Nets import Generator, Discriminator
from Datasets import *
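
# This script appears to implement the semi-supervised GAN of Salimans et al.,
# "Improved Techniques for Training GANs" (2016): the discriminator is a
# K-class classifier whose logits double as an implicit real/fake score, and
# the generator is trained by feature matching. It targets the pre-0.4 PyTorch
# API (Variable, volatile, TensorDataset.data_tensor), which the code below
# already assumes.
#
# Hypothetical invocation (the script name is an assumption):
#   python ImprovedGAN.py --cuda --epochs 10 --unlabel-weight 1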
class ImprovedGAN(object):
    def __init__(self, G, D, labeled, unlabeled, test, args):
        self.G = G
        self.D = D
        if args.cuda:
            self.G.cuda()
            self.D.cuda()
        self.labeled = labeled      # small labeled dataset (a TensorDataset)
        self.unlabeled = unlabeled  # large unlabeled dataset
        self.test = test            # held-out test dataset
        self.Doptim = optim.SGD(self.D.parameters(), lr=args.lr, momentum=args.momentum)
        self.Goptim = optim.Adam(self.G.parameters(), lr=args.lr)
        self.args = args
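    # Loss derivation (a sketch of the math the method below computes): the
    # discriminator emits K real-class logits l_1..l_K, and the fake class has
    # an implicit logit of 0. With Z = sum_k exp(l_k):
    #   p(real | x)           = Z / (1 + Z)
    #   supervised loss       = -l_y + log Z               (cross-entropy over real classes)
    #   unlabeled (real) loss = -log(Z / (1 + Z)) = softplus(log Z) - log Z
    #   fake loss             = -log(1 / (1 + Z)) = softplus(log Z)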
    def trainD(self, x_label, y, x_unlabel):
        x_label, x_unlabel, y = Variable(x_label), Variable(x_unlabel), Variable(y, requires_grad=False)
        if self.args.cuda:
            x_label, x_unlabel, y = x_label.cuda(), x_unlabel.cuda(), y.cuda()
        fake = self.G(x_unlabel.size()[0], cuda=self.args.cuda).view(x_unlabel.size())
        output_label = self.D(x_label, cuda=self.args.cuda)
        output_unlabel = self.D(x_unlabel, cuda=self.args.cuda)
        output_fake = self.D(fake, cuda=self.args.cuda)
        # log Z = log sum_k exp(l_k) for each of the three batches
        logz_label, logz_unlabel, logz_fake = log_sum_exp(output_label), log_sum_exp(output_unlabel), log_sum_exp(output_fake)
        prob_label = torch.gather(output_label, 1, y.unsqueeze(1))  # logit of the true class, l_y
        loss_supervised = -torch.mean(prob_label) + torch.mean(logz_label)
        loss_unsupervised = 0.5 * (-torch.mean(logz_unlabel) + torch.mean(F.softplus(logz_unlabel))  # real: -log Z/(1+Z)
                                   + torch.mean(F.softplus(logz_fake)))                              # fake: -log 1/(1+Z)
        loss = loss_supervised + self.args.unlabel_weight * loss_unsupervised
        acc = torch.mean((output_label.max(1)[1] == y).float())
        self.Doptim.zero_grad()
        loss.backward()
        if loss != loss:  # NaN check (NaN != NaN); drop into the debugger before stepping
            pdb.set_trace()
        self.Doptim.step()
        return loss_supervised.data.cpu().numpy(), loss_unsupervised.data.cpu().numpy(), acc.data.cpu().numpy()

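    # Generator update by feature matching (Salimans et al., sec. 3.1): push the
    # mean discriminator feature of generated samples toward the mean feature of
    # real unlabeled data, i.e. minimize mean((E[f(G(z))] - E[f(x)])^2), where
    # f is an intermediate layer of D exposed here via D(..., feature=True).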
    def trainG(self, x_unlabel):
        fake = self.G(x_unlabel.size()[0], cuda=self.args.cuda).view(x_unlabel.size())
        mom_gen = torch.mean(self.D(fake, feature=True, cuda=self.args.cuda), dim=0)
        mom_unlabel = torch.mean(self.D(Variable(x_unlabel), feature=True, cuda=self.args.cuda), dim=0)
        loss = torch.mean((mom_gen - mom_unlabel) ** 2)
        self.Goptim.zero_grad()
        loss.backward()
        self.Goptim.step()
        # NaN check on the generator's first layer after the update
        nan_mask = self.G.main[0].weight != self.G.main[0].weight
        if torch.sum(nan_mask.float()) > 0:
            pdb.set_trace()
        return loss.data.cpu().numpy()

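    # One epoch is one pass over the unlabeled set. Because the labeled set is
    # much smaller, it is tiled `times` times so that every unlabeled batch can
    # be paired with a labeled batch of the same size.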
    def train(self):
        assert len(self.unlabeled) > len(self.labeled)
        assert isinstance(self.labeled, TensorDataset)
        times = int(np.ceil(len(self.unlabeled) * 1. / len(self.labeled)))
        t1 = self.labeled.data_tensor.clone()
        t2 = self.labeled.target_tensor.clone()
        tile_labeled = TensorDataset(t1.repeat(times, 1, 1, 1), t2.repeat(times))
        for epoch in range(self.args.epochs):
            self.G.train()
            self.D.train()
            unlabel_loader1 = DataLoader(self.unlabeled, batch_size=self.args.batch_size, shuffle=True, drop_last=True)
            unlabel_loader2 = iter(DataLoader(self.unlabeled, batch_size=self.args.batch_size, shuffle=True, drop_last=True))
            label_loader = iter(DataLoader(tile_labeled, batch_size=self.args.batch_size, shuffle=True, drop_last=True))
            batch_num = loss_supervised = loss_unsupervised = loss_gen = accuracy = 0.
            for (unlabel1, _label1) in unlabel_loader1:
                batch_num += 1
                unlabel2, _label2 = next(unlabel_loader2)
                x, y = next(label_loader)
                if self.args.cuda:
                    x, y, unlabel1, unlabel2 = x.cuda(), y.cuda(), unlabel1.cuda(), unlabel2.cuda()
                ll, lu, acc = self.trainD(x, y, unlabel1)
                loss_supervised += ll
                loss_unsupervised += lu
                accuracy += acc
                # train G on an independently shuffled unlabeled batch
                lg = self.trainG(unlabel2)
                loss_gen += lg
                if (batch_num + 1) % self.args.log_interval == 0:
                    print('Training: %d / %d' % (batch_num + 1, len(unlabel_loader1)))
                    print('Eval: correct %d / %d, batch acc %.4f' % (self.eval(), len(self.test), acc))
                    # self.eval() switches the nets to eval mode; switch back before continuing
                    self.G.train()
                    self.D.train()
            loss_supervised /= batch_num
            loss_unsupervised /= batch_num
            loss_gen /= batch_num
            accuracy /= batch_num
            print("Epoch %d, loss_supervised = %.4f, loss_unsupervised = %.4f, loss_gen = %.4f, train acc = %.4f" % (epoch, loss_supervised, loss_unsupervised, loss_gen, accuracy))
            sys.stdout.flush()
            if (epoch + 1) % self.args.eval_interval == 0:
                print("Eval: correct %d / %d" % (self.eval(), len(self.test)))

    def predict(self, x):
        # volatile=True disables graph construction at inference time (pre-0.4 API)
        return torch.max(self.D(Variable(x, volatile=True), cuda=self.args.cuda), 1)[1].data
    def eval(self):
        self.G.eval()
        self.D.eval()
        # stack the whole test set into a single batch and count correct predictions
        d, l = [], []
        for (datum, label) in self.test:
            d.append(datum)
            l.append(label)
        x, y = torch.stack(d), torch.LongTensor(l)
        if self.args.cuda:
            x, y = x.cuda(), y.cuda()
        pred = self.predict(x)
        return torch.sum(pred == y)
    def draw(self, batch_size):
        self.G.eval()
        return self.G(batch_size, cuda=self.args.cuda)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='PyTorch Improved GAN')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr', type=float, default=0.003, metavar='LR',
                        help='learning rate (default: 0.003)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--cuda', action='store_true', default=False,
                        help='enable CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=100, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--eval-interval', type=int, default=1, metavar='N',
                        help='how many epochs to wait before evaluating on the test set')
    parser.add_argument('--unlabel-weight', type=float, default=1., metavar='W',
                        help='weight of the unsupervised loss relative to the supervised loss')
    args = parser.parse_args()
    args.cuda = args.cuda and torch.cuda.is_available()
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)  # seed torch as well, so --seed actually controls init/sampling
    gan = ImprovedGAN(Generator(100), Discriminator(), MnistLabel(10), MnistUnlabel(), MnistTest(), args)
    gan.train()