#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#

from singa import device
from singa import opt
from singa import tensor

import argparse
import matplotlib.pyplot as plt
import numpy as np
import os
from model import lsgan_mlp
from utils import load_data
from utils import print_log

| 32 | + |
class LSGAN():
    """Least-Squares GAN (LSGAN) trained on MNIST with MLP generator and
    discriminator built via ``lsgan_mlp.create_model``.

    The discriminator is trained with least-squares targets (+1 for real,
    -1 for fake); the generator is trained to push its outputs toward the
    0 target (see ``train``).  Sample grids are periodically written to
    ``file_dir`` as PNG images.
    """

    def __init__(self,
                 dev,
                 rows=28,
                 cols=28,
                 channels=1,
                 noise_size=100,
                 hidden_size=128,
                 batch=128,
                 interval=1000,
                 learning_rate=0.001,
                 iterations=1000000,
                 d_steps=3,
                 g_steps=1,
                 dataset_filepath='mnist.pkl.gz',
                 file_dir='lsgan_images/'):
        """Store hyper-parameters and build the generator/discriminator model.

        Args:
            dev: SINGA device (CPU or GPU) on which tensors are allocated.
            rows, cols, channels: image geometry; feature_size is their product.
            noise_size: dimensionality of the generator's latent input.
            hidden_size: hidden-layer width of the MLP model.
            batch: nominal batch size; each half-batch (batch // 2) is used
                per real/fake discriminator pass.
            interval: iterations between logging / image snapshots.
            learning_rate: Adam learning rate.
            iterations: total training iterations.
            d_steps, g_steps: discriminator / generator updates per iteration.
            dataset_filepath: path to the pickled MNIST archive.
            file_dir: output directory for sample image grids.
        """
        self.dev = dev
        self.rows = rows
        self.cols = cols
        self.channels = channels
        self.feature_size = self.rows * self.cols * self.channels
        self.noise_size = noise_size
        self.hidden_size = hidden_size
        self.batch = batch
        # Half-batch: real and fake samples each fill one half per d-step.
        self.batch_size = self.batch // 2
        self.interval = interval
        self.learning_rate = learning_rate
        self.iterations = iterations
        self.d_steps = d_steps
        self.g_steps = g_steps
        self.dataset_filepath = dataset_filepath
        self.file_dir = file_dir
        self.model = lsgan_mlp.create_model(noise_size=self.noise_size,
                                            feature_size=self.feature_size,
                                            hidden_size=self.hidden_size)

    def train(self):
        """Run the adversarial training loop.

        Alternates ``d_steps`` discriminator updates (real images toward
        target +1, generated images toward target -1) with ``g_steps``
        generator updates (generated images toward target 0), logging and
        saving a sample grid every ``interval`` iterations.
        """
        train_data, _, _, _, _, _ = load_data(self.dataset_filepath)
        # BUGFIX: use the device the caller selected (self.dev) instead of
        # unconditionally creating CUDA GPU 0, which ignored the --use_gpu
        # choice and crashed on machines without CUDA.
        dev = self.dev
        dev.SetRandSeed(0)
        np.random.seed(0)

        sgd = opt.Adam(lr=self.learning_rate)

        noise = tensor.Tensor((self.batch_size, self.noise_size), dev,
                              tensor.float32)
        real_images = tensor.Tensor((self.batch_size, self.feature_size), dev,
                                    tensor.float32)
        real_labels = tensor.Tensor((self.batch_size, 1), dev, tensor.float32)
        fake_labels = tensor.Tensor((self.batch_size, 1), dev, tensor.float32)
        subtrahend_labels = tensor.Tensor((self.batch_size, 1), dev,
                                          tensor.float32)

        # Attach the model to the computational graph.
        self.model.set_optimizer(sgd)
        self.model.compile([noise],
                           is_train=True,
                           use_graph=False,
                           sequential=True)

        # LSGAN targets: +1 real, -1 fake; generator aims for 0.
        real_labels.set_value(1.0)
        fake_labels.set_value(-1.0)
        subtrahend_labels.set_value(0.0)

        for iteration in range(self.iterations):

            for d_step in range(self.d_steps):
                # Sample a random half-batch of real training images.
                idx = np.random.randint(0, train_data.shape[0],
                                        self.batch_size)
                real_images.copy_from_numpy(train_data[idx])

                self.model.train()

                # Train the discriminator on real images (target +1) ...
                _, d_loss_real = self.model.train_one_batch_dis(
                    real_images, real_labels)

                # ... and on freshly generated fake images (target -1).
                noise.uniform(-1, 1)
                fake_images = self.model.forward_gen(noise)
                _, d_loss_fake = self.model.train_one_batch_dis(
                    fake_images, fake_labels)

                d_loss = tensor.to_numpy(d_loss_real)[0] + tensor.to_numpy(
                    d_loss_fake)[0]

            for g_step in range(self.g_steps):
                # Train the generator: push D(G(z)) toward the 0 target.
                noise.uniform(-1, 1)
                _, g_loss_tensor = self.model.train_one_batch(
                    noise, subtrahend_labels)

                g_loss = tensor.to_numpy(g_loss_tensor)[0]

            if iteration % self.interval == 0:
                self.model.eval()
                self.save_image(iteration)
                print_log(' The {} iteration, G_LOSS: {}, D_LOSS: {}'.format(
                    iteration, g_loss, d_loss))

    def save_image(self, iteration):
        """Generate a 5x5 grid of samples and save it as ``<file_dir><iteration>.png``.

        A fixed noise tensor is drawn once on the first call and reused
        afterwards so successive grids show the generator's progress on
        the same latent points.
        """
        demo_row = 5
        demo_col = 5
        if not hasattr(self, "demo_noise"):
            # BUGFIX: allocate on self.dev; the original referenced the bare
            # global name `dev`, which only existed when run as a script.
            self.demo_noise = tensor.Tensor(
                (demo_col * demo_row, self.noise_size), self.dev,
                tensor.float32)
            self.demo_noise.uniform(-1, 1)
        gen_imgs = self.model.forward_gen(self.demo_noise)
        gen_imgs = tensor.to_numpy(gen_imgs)
        show_imgs = np.reshape(
            gen_imgs, (gen_imgs.shape[0], self.rows, self.cols, self.channels))
        fig, axs = plt.subplots(demo_row, demo_col)
        cnt = 0
        for r in range(demo_row):
            for c in range(demo_col):
                axs[r, c].imshow(show_imgs[cnt, :, :, 0], cmap='gray')
                axs[r, c].axis('off')
                cnt += 1
        fig.savefig("{}{}.png".format(self.file_dir, iteration))
        plt.close()
| 152 | + |
| 153 | + |
if __name__ == '__main__':
    # CLI entry point: pick a device from --use_gpu, prepare the output
    # directory, and launch LSGAN training on the given MNIST pickle.
    parser = argparse.ArgumentParser(description='Train GAN over MNIST')
    parser.add_argument('filepath', type=str, help='the dataset path')
    parser.add_argument('--use_gpu', action='store_true')
    args = parser.parse_args()

    if args.use_gpu:
        print('Using GPU')
        dev = device.create_cuda_gpu()
    else:
        print('Using CPU')
        dev = device.get_default_device()

    # exist_ok avoids the check-then-create race of the original
    # `if not os.path.exists(...)` idiom.
    os.makedirs('lsgan_images/', exist_ok=True)

    rows = 28
    cols = 28
    channels = 1
    noise_size = 100
    hidden_size = 128
    batch = 128
    interval = 1000
    learning_rate = 0.0005
    iterations = 1000000
    d_steps = 1
    g_steps = 1
    dataset_filepath = 'mnist.pkl.gz'
    file_dir = 'lsgan_images/'
    lsgan = LSGAN(dev,
                  rows=rows,
                  cols=cols,
                  channels=channels,
                  noise_size=noise_size,
                  hidden_size=hidden_size,
                  batch=batch,
                  interval=interval,
                  learning_rate=learning_rate,
                  iterations=iterations,
                  d_steps=d_steps,
                  g_steps=g_steps,
                  dataset_filepath=dataset_filepath,
                  file_dir=file_dir)
    lsgan.train()