Skip to content

Commit 56b4f18

Browse files
authored
Merge pull request #1371 from npcmaci/format21
Add the implementations for the gan model in the peft
2 parents 44b37bb + fd4d745 commit 56b4f18

File tree

1 file changed

+186
-0
lines changed
  • examples/singa_peft/examples/model

1 file changed

+186
-0
lines changed
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing,
13+
# software distributed under the License is distributed on an
14+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
# KIND, either express or implied. See the License for the
16+
# specific language governing permissions and limitations
17+
# under the License.
18+
#
19+
20+
from singa import device
21+
from singa import opt
22+
from singa import tensor
23+
24+
import argparse
25+
import matplotlib.pyplot as plt
26+
import numpy as np
27+
import os
28+
from model import lsgan_mlp
29+
from utils import load_data
30+
from utils import print_log
31+
32+
33+
class LSGAN():
    """Least-Squares GAN (LSGAN) trainer for MNIST-style images.

    Wraps the generator/discriminator pair built by
    ``lsgan_mlp.create_model`` and alternates discriminator and generator
    updates, periodically saving a grid of generated samples to
    ``file_dir``.
    """

    def __init__(self,
                 dev,
                 rows=28,
                 cols=28,
                 channels=1,
                 noise_size=100,
                 hidden_size=128,
                 batch=128,
                 interval=1000,
                 learning_rate=0.001,
                 iterations=1000000,
                 d_steps=3,
                 g_steps=1,
                 dataset_filepath='mnist.pkl.gz',
                 file_dir='lsgan_images/'):
        # Device on which all tensors live (from singa.device).
        self.dev = dev
        self.rows = rows
        self.cols = cols
        self.channels = channels
        # Flattened image size fed to the MLP networks.
        self.feature_size = self.rows * self.cols * self.channels
        self.noise_size = noise_size
        self.hidden_size = hidden_size
        self.batch = batch
        # Each discriminator update uses half real / half fake samples.
        self.batch_size = self.batch // 2
        self.interval = interval
        self.learning_rate = learning_rate
        self.iterations = iterations
        self.d_steps = d_steps
        self.g_steps = g_steps
        self.dataset_filepath = dataset_filepath
        self.file_dir = file_dir
        self.model = lsgan_mlp.create_model(noise_size=self.noise_size,
                                            feature_size=self.feature_size,
                                            hidden_size=self.hidden_size)

    def train(self):
        """Run the adversarial training loop.

        Alternates ``d_steps`` discriminator updates with ``g_steps``
        generator updates per iteration; every ``interval`` iterations it
        saves a sample image grid and logs both losses.
        """
        train_data, _, _, _, _, _ = load_data(self.dataset_filepath)
        # BUG FIX: honor the device chosen by the caller instead of
        # unconditionally allocating CUDA device 0, which ignored the
        # --use_gpu flag and crashed on CPU-only machines.
        dev = self.dev
        dev.SetRandSeed(0)
        np.random.seed(0)

        sgd = opt.Adam(lr=self.learning_rate)

        noise = tensor.Tensor((self.batch_size, self.noise_size), dev,
                              tensor.float32)
        real_images = tensor.Tensor((self.batch_size, self.feature_size), dev,
                                    tensor.float32)
        real_labels = tensor.Tensor((self.batch_size, 1), dev, tensor.float32)
        fake_labels = tensor.Tensor((self.batch_size, 1), dev, tensor.float32)
        substrahend_labels = tensor.Tensor((self.batch_size, 1), dev,
                                           tensor.float32)

        # Attach the model to the computational graph.
        self.model.set_optimizer(sgd)
        self.model.compile([noise],
                           is_train=True,
                           use_graph=False,
                           sequential=True)

        # LSGAN targets: real -> 1, fake -> -1; the generator is trained
        # to pull the discriminator's output on fakes toward 0.
        real_labels.set_value(1.0)
        fake_labels.set_value(-1.0)
        substrahend_labels.set_value(0.0)

        for iteration in range(self.iterations):

            for d_step in range(self.d_steps):
                idx = np.random.randint(0, train_data.shape[0],
                                        self.batch_size)
                real_images.copy_from_numpy(train_data[idx])

                self.model.train()

                # Discriminator step on a real batch...
                _, d_loss_real = self.model.train_one_batch_dis(
                    real_images, real_labels)

                # ...and on a freshly generated fake batch.
                noise.uniform(-1, 1)
                fake_images = self.model.forward_gen(noise)
                _, d_loss_fake = self.model.train_one_batch_dis(
                    fake_images, fake_labels)

                d_loss = tensor.to_numpy(d_loss_real)[0] + tensor.to_numpy(
                    d_loss_fake)[0]

            for g_step in range(self.g_steps):
                # Generator step: push D(G(z)) toward the 0 target.
                noise.uniform(-1, 1)
                _, g_loss_tensor = self.model.train_one_batch(
                    noise, substrahend_labels)

                g_loss = tensor.to_numpy(g_loss_tensor)[0]

            if iteration % self.interval == 0:
                self.model.eval()
                self.save_image(iteration)
                print_log(' The {} iteration, G_LOSS: {}, D_LOSS: {}'.format(
                    iteration, g_loss, d_loss))

    def save_image(self, iteration):
        """Sample a 5x5 grid from the generator and save it as a PNG.

        The noise tensor is created lazily on first call and refilled with
        fresh uniform noise on every call.
        """
        demo_row = 5
        demo_col = 5
        if not hasattr(self, "demo_noise"):
            # BUG FIX: use self.dev — the original referenced the bare name
            # `dev`, which only resolved through a module-level global set by
            # the __main__ script and raised NameError when the class was
            # used as a library.
            self.demo_noise = tensor.Tensor(
                (demo_col * demo_row, self.noise_size), self.dev,
                tensor.float32)
        self.demo_noise.uniform(-1, 1)
        gen_imgs = self.model.forward_gen(self.demo_noise)
        gen_imgs = tensor.to_numpy(gen_imgs)
        show_imgs = np.reshape(
            gen_imgs, (gen_imgs.shape[0], self.rows, self.cols, self.channels))
        fig, axs = plt.subplots(demo_row, demo_col)
        cnt = 0
        for r in range(demo_row):
            for c in range(demo_col):
                axs[r, c].imshow(show_imgs[cnt, :, :, 0], cmap='gray')
                axs[r, c].axis('off')
                cnt += 1
        fig.savefig("{}{}.png".format(self.file_dir, iteration))
        # BUG FIX: close the figure we created explicitly; bare plt.close()
        # only closes the "current" figure, which leaks figures if the
        # current figure has changed between calls.
        plt.close(fig)
152+
153+
154+
if __name__ == '__main__':
    # Command-line entry point: train an LSGAN on an MNIST-format dataset.
    parser = argparse.ArgumentParser(description='Train GAN over MNIST')
    parser.add_argument('filepath', type=str, help='the dataset path')
    parser.add_argument('--use_gpu', action='store_true')
    args = parser.parse_args()

    # Select the compute device requested on the command line.
    if args.use_gpu:
        print('Using GPU')
        dev = device.create_cuda_gpu()
    else:
        print('Using CPU')
        dev = device.get_default_device()

    file_dir = 'lsgan_images/'
    if not os.path.exists(file_dir):
        os.makedirs(file_dir)

    # Hyper-parameters for this run.
    rows = 28
    cols = 28
    channels = 1
    noise_size = 100
    hidden_size = 128
    batch = 128
    interval = 1000
    learning_rate = 0.0005
    iterations = 1000000
    d_steps = 1
    g_steps = 1
    # BUG FIX: use the dataset path supplied on the command line; the
    # original parsed the required `filepath` argument but then ignored it
    # and always loaded the hard-coded 'mnist.pkl.gz'.
    dataset_filepath = args.filepath
    lsgan = LSGAN(dev, rows, cols, channels, noise_size, hidden_size, batch,
                  interval, learning_rate, iterations, d_steps, g_steps,
                  dataset_filepath, file_dir)
    lsgan.train()

0 commit comments

Comments
 (0)