Skip to content

Commit 2a850a6

Browse files
authored
Merge pull request #1376 from gzrp/dev-postgresql
Add the implementation of generative models for PEFT
2 parents 0492192 + d923d31 commit 2a850a6

File tree

1 file changed

+104
-0
lines changed

1 file changed

+104
-0
lines changed
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing,
13+
# software distributed under the License is distributed on an
14+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
# KIND, either express or implied. See the License for the
16+
# specific language governing permissions and limitations
17+
# under the License.
18+
#
19+
20+
from singa import layer
21+
from singa import model
22+
from singa import autograd
23+
24+
25+
class GAN_MLP(model.Model):
    """A minimal GAN built from two MLPs.

    The generator maps a noise vector of size ``noise_size`` through one
    hidden layer to a ``feature_size`` output squashed with a sigmoid
    (suitable for e.g. flattened 28x28 MNIST images in [0, 1]).  The
    discriminator maps a feature vector through one hidden layer to a
    single sigmoid probability of the input being real.

    Both sub-networks live in one ``model.Model`` so the parameters can be
    selected by name prefix (``gen_net`` / ``dis_net``) when updating.
    """

    def __init__(self, noise_size=100, feature_size=784, hidden_size=128):
        """Build generator and discriminator layers.

        Args:
            noise_size (int): dimension of the generator's noise input.
                NOTE(review): stored but not used to size any layer here —
                presumably the caller samples noise of this size.
            feature_size (int): dimension of generated/real samples.
            hidden_size (int): width of the single hidden layer in both nets.
        """
        super(GAN_MLP, self).__init__()
        self.noise_size = noise_size
        self.feature_size = feature_size
        self.hidden_size = hidden_size

        # Generative Net: noise -> hidden -> feature, sigmoid output in [0, 1].
        self.gen_net_fc_0 = layer.Linear(self.hidden_size)
        self.gen_net_relu_0 = layer.ReLU()
        self.gen_net_fc_1 = layer.Linear(self.feature_size)
        self.gen_net_sigmoid_1 = layer.Sigmoid()

        # Discriminative Net: feature -> hidden -> scalar real/fake probability.
        self.dis_net_fc_0 = layer.Linear(self.hidden_size)
        self.dis_net_relu_0 = layer.ReLU()
        self.dis_net_fc_1 = layer.Linear(1)
        self.dis_net_sigmoid_1 = layer.Sigmoid()
        self.binary_cross_entropy = layer.BinaryCrossEntropy()

    def forward(self, x):
        """Run the cascaded net: generate from noise, then discriminate.

        Args:
            x: batch of noise vectors.

        Returns:
            Discriminator probabilities for the generated samples.
        """
        y = self.forward_gen(x)
        y = self.forward_dis(y)
        return y

    def forward_dis(self, x):
        """Discriminator forward pass: feature batch -> probability of real."""
        y = self.dis_net_fc_0(x)
        y = self.dis_net_relu_0(y)
        y = self.dis_net_fc_1(y)
        y = self.dis_net_sigmoid_1(y)
        return y

    def forward_gen(self, x):
        """Generator forward pass: noise batch -> synthetic feature batch."""
        y = self.gen_net_fc_0(x)
        y = self.gen_net_relu_0(y)
        y = self.gen_net_fc_1(y)
        y = self.gen_net_sigmoid_1(y)
        return y

    def train_one_batch(self, x, y):
        """One generator step: fool the (frozen) discriminator.

        Args:
            x: batch of noise vectors.
            y: target labels (typically all "real") for the BCE loss.

        Returns:
            (out, loss): discriminator outputs on generated data and the loss.
        """
        out = self.forward(x)
        loss = self.binary_cross_entropy(out, y)
        # Only update the Generative Net: filter params by name prefix so
        # the discriminator stays fixed during the generator step.
        for p, g in autograd.backward(loss):
            if "gen_net" in p.name:
                self.optimizer.apply(p.name, p, g)
        return out, loss

    def train_one_batch_dis(self, x, y):
        """One discriminator step on a batch of (real or generated) samples.

        Args:
            x: batch of feature vectors.
            y: labels (real=1 / fake=0) for the BCE loss.

        Returns:
            (out, loss): discriminator outputs and the loss.
        """
        out = self.forward_dis(x)
        loss = self.binary_cross_entropy(out, y)
        # Only update the Discriminative Net, mirroring train_one_batch.
        # BUGFIX: the original additionally called ``self.optimizer(loss)``
        # after this loop, which performs a backward-and-update over ALL
        # parameters — double-stepping the discriminator and, contrary to
        # the stated intent, also updating the generator.  Removed.
        for p, g in autograd.backward(loss):
            if "dis_net" in p.name:
                self.optimizer.apply(p.name, p, g)
        return out, loss

    def set_optimizer(self, optimizer):
        """Attach the optimizer used by both per-batch training methods."""
        self.optimizer = optimizer
91+
92+
93+
def create_model(pretrained=False, **kwargs):
    """Construct a GAN_MLP model.

    Args:
        pretrained (bool): If True, returns a model pre-trained.
            NOTE(review): currently unused — no pretrained weights are
            loaded anywhere in this module; kept for API compatibility.
        **kwargs: forwarded to ``GAN_MLP`` (``noise_size``,
            ``feature_size``, ``hidden_size``).

    Returns:
        GAN_MLP: the constructed (untrained) model.
    """
    # Local name chosen so we do not shadow the imported ``singa.model``
    # module, as the original ``model = GAN_MLP(...)`` did.
    gan_mlp = GAN_MLP(**kwargs)

    return gan_mlp
102+
103+
104+
# Explicit public API of this module.
__all__ = ['GAN_MLP', 'create_model']

0 commit comments

Comments
 (0)