-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathlsgan.py
57 lines (41 loc) · 1.74 KB
/
lsgan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import sys
import math
import torch
from torch.autograd import Variable
import torch.utils.data
import torchvision.utils as vutils
from gan import GAN_base
class LSGAN(GAN_base):
    """Least-squares GAN objective layered on top of ``GAN_base``.

    Discriminator scores are multiplied by a signed label (``real_label`` is
    +1, ``fake_label`` is -1) and the product is regressed towards 1 with a
    mean-squared error.  The generator is trained with ``generator_label``
    (+1), i.e. its samples are treated as real for the generator cost.

    The attributes ``is_cuda``, ``conditional``, ``netD`` and the method
    ``join_xy`` are read here but not defined in this class; they are
    presumably provided by ``GAN_base`` -- confirm against that class.
    """

    def __init__(self, netG, netD, optimizerD, optimizerG, opt):
        """Initialise the base trainer and the LSGAN labels.

        Args:
            netG, netD, optimizerD, optimizerG: forwarded unchanged to
                ``GAN_base.__init__``.
            opt: options object, also forwarded to ``GAN_base.__init__``;
                ``opt.batch_size`` is read here to size ``self.shift``.
        """
        GAN_base.__init__(self, netG, netD, optimizerD, optimizerG, opt)
        # criterion for training
        self.criterion = self.LSGANLoss
        self.real_label = 1
        self.fake_label = -1
        self.generator_label = 1  # fake labels are real for generator cost
        # Vector of ones with opt.batch_size entries.  LSGANLoss now builds its
        # target from the actual score shape instead of slicing this buffer
        # (which broke for batches larger than opt.batch_size); the attribute
        # is kept so existing code that reads ``self.shift`` keeps working.
        shift = torch.ones(opt.batch_size)
        self.shift = Variable(shift.cuda()) if self.is_cuda else Variable(shift)

    def _constant_like(self, reference, value):
        """Return a float Variable shaped like ``reference`` filled with ``value``.

        The tensor is allocated through ``torch.cuda`` when ``self.is_cuda``
        is true and through ``torch`` otherwise.
        """
        th = torch.cuda if self.is_cuda else torch
        return Variable(th.FloatTensor(reference.size()).fill_(value))

    def compute_disc_score(self, data_a, data_b):
        """Return the discriminator loss for a pair of batches.

        ``data_a`` is scored against ``self.real_label`` and ``data_b``
        against ``self.fake_label``; the two losses are summed.  When
        ``self.conditional`` is true both inputs first go through
        ``self.join_xy``.

        Args:
            data_a: batch passed to ``self.netD`` and labelled as real.
            data_b: batch passed to ``self.netD`` and labelled as fake.

        Returns:
            The sum of the two criterion values.
        """
        if self.conditional:
            data_a = self.join_xy(data_a)
            data_b = self.join_xy(data_b)
        scores_a = self.netD(data_a)
        scores_b = self.netD(data_b)
        # Labels take the full shape of the scores (as compute_gen_score
        # already did) so a (N, 1) discriminator output is not broadcast
        # against a (N,) label vector into a (N, N) product.
        labels_a = self._constant_like(scores_a, self.real_label)
        errD_a = self.criterion(scores_a, labels_a)
        labels_b = self._constant_like(scores_b, self.fake_label)
        errD_b = self.criterion(scores_b, labels_b)
        return errD_a + errD_b

    def compute_gen_score(self, data):
        """Return the generator loss for a batch.

        The batch is scored by ``self.netD`` and compared against
        ``self.generator_label``.  When ``self.conditional`` is true the
        input first goes through ``self.join_xy``.

        Args:
            data: batch passed to ``self.netD``.

        Returns:
            The criterion value for the batch.
        """
        if self.conditional:
            data = self.join_xy(data)
        scores = self.netD(data)
        labels = self._constant_like(scores, self.generator_label)
        return self.criterion(scores, labels)

    def LSGANLoss(self, scores, labels):
        """Mean-squared error between ``scores * labels`` and 1.

        Args:
            scores: discriminator outputs.
            labels: signed labels multiplied element-wise with ``scores``.

        Returns:
            The mean over all elements of ``(scores * labels - 1) ** 2``.
        """
        signed_scores = scores * labels
        # The target matches the product's shape, so any batch size and any
        # score rank is handled without relying on the fixed-length shift.
        target = self._constant_like(signed_scores, 1)
        # MSELoss averages by default; the deprecated ``size_average=True``
        # keyword requested the same thing.
        return torch.nn.MSELoss()(signed_scores, target)