Skip to content

Commit a373f2b

Browse files
authored
Merge pull request #1374 from xiezl/patch-4
Add the implementations for the gan model in the peft
2 parents 63c6247 + d1c0de6 commit a373f2b

File tree

1 file changed

+101
-0
lines changed

1 file changed

+101
-0
lines changed
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing,
13+
# software distributed under the License is distributed on an
14+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
# KIND, either express or implied. See the License for the
16+
# specific language governing permissions and limitations
17+
# under the License.
18+
#
19+
20+
from singa import layer
21+
from singa import model
22+
from singa import autograd
23+
24+
25+
class LSGAN_MLP(model.Model):
26+
27+
def __init__(self, noise_size=100, feature_size=784, hidden_size=128):
28+
super(LSGAN_MLP, self).__init__()
29+
self.noise_size = noise_size
30+
self.feature_size = feature_size
31+
self.hidden_size = hidden_size
32+
33+
# Generative Net
34+
self.gen_net_fc_0 = layer.Linear(self.hidden_size)
35+
self.gen_net_relu_0 = layer.ReLU()
36+
self.gen_net_fc_1 = layer.Linear(self.feature_size)
37+
self.gen_net_sigmoid_1 = layer.Sigmoid()
38+
39+
# Discriminative Net
40+
self.dis_net_fc_0 = layer.Linear(self.hidden_size)
41+
self.dis_net_relu_0 = layer.ReLU()
42+
self.dis_net_fc_1 = layer.Linear(1)
43+
self.mse_loss = layer.MeanSquareError()
44+
45+
def forward(self, x):
46+
# Cascaded Net
47+
y = self.forward_gen(x)
48+
y = self.forward_dis(y)
49+
return y
50+
51+
def forward_dis(self, x):
52+
# Discriminative Net
53+
y = self.dis_net_fc_0(x)
54+
y = self.dis_net_relu_0(y)
55+
y = self.dis_net_fc_1(y)
56+
return y
57+
58+
def forward_gen(self, x):
59+
# Generative Net
60+
y = self.gen_net_fc_0(x)
61+
y = self.gen_net_relu_0(y)
62+
y = self.gen_net_fc_1(y)
63+
y = self.gen_net_sigmoid_1(y)
64+
return y
65+
66+
def train_one_batch(self, x, y):
67+
# Training the Generative Net
68+
out = self.forward(x)
69+
loss = self.mse_loss(out, y)
70+
# Only update the Generative Net
71+
for p, g in autograd.backward(loss):
72+
if "gen_net" in p.name:
73+
self.optimizer.apply(p.name, p, g)
74+
return out, loss
75+
76+
def train_one_batch_dis(self, x, y):
77+
# Training the Discriminative Net
78+
out = self.forward_dis(x)
79+
loss = self.mse_loss(out, y)
80+
# Only update the Discriminative Net
81+
for p, g in autograd.backward(loss):
82+
if "dis_net" in p.name:
83+
self.optimizer.apply(p.name, p, g)
84+
return out, loss
85+
86+
def set_optimizer(self, optimizer):
87+
self.optimizer = optimizer
88+
89+
90+
def create_model(pretrained=False, **kwargs):
91+
"""Construct a CNN model.
92+
93+
Args:
94+
pretrained (bool): If True, returns a model pre-trained
95+
"""
96+
model = LSGAN_MLP(**kwargs)
97+
98+
return model
99+
100+
101+
__all__ = ['LSGAN_MLP', 'create_model']

0 commit comments

Comments
 (0)