
Commit 7834673

Merge pull request #1157 from liye-li/dev-liye
2 parents 946c637 + af0ddee commit 7834673

File tree

1 file changed: +202 -0 lines changed

examples/msmlp/model.py

Lines changed: 202 additions & 0 deletions
@@ -0,0 +1,202 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

from singa import layer
from singa import model
from singa import tensor
from singa import opt
from singa import device
from singa.autograd import Operator
from singa.layer import Layer
from singa import singa_wrap as singa
import argparse
import numpy as np

np_dtype = {"float16": np.float16, "float32": np.float32}

singa_dtype = {"float16": tensor.float16, "float32": tensor.float32}

#### self-defined loss begin

### from autograd.py
class SumError(Operator):
    """Sum all elements of the input; adapted from MeanSquareError in
    autograd.py, with the target and averaging terms dropped."""

    def __init__(self):
        super(SumError, self).__init__()

    def forward(self, x):
        # keep the input so backward() knows the gradient shape
        self.data_x = x
        loss = singa.SumAll(x)
        return loss

    def backward(self, dy=1.0):
        # d(sum(x))/dx is a tensor of ones, scaled by the upstream gradient.
        # CTensor exposes its shape as a method, hence shape(); the ones are
        # created as float32 to match the tensor's dtype.
        dev = device.get_default_device()
        dx = tensor.Tensor(self.data_x.shape(), dev, singa_dtype['float32'])
        dx.copy_from_numpy(np.ones(self.data_x.shape(), dtype=np.float32))
        dx *= dy
        return dx


def se_loss(x):
    return SumError()(x)[0]


### from layer.py
class SumErrorLayer(Layer):
    """
    Generate a SumError operator
    """

    def __init__(self):
        super(SumErrorLayer, self).__init__()

    def forward(self, x):
        return se_loss(x)

#### self-defined loss end
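# Why a plain sum works as a "loss": backpropagating d(sum(out))/dw yields a
# per-parameter gradient that can serve as a SynFlow-style pruning score; the
# `synflow_flag` branch in MSMLP.train_one_batch below selects this loss for
# that purpose (a reading inferred from the flag name, not stated in the code).
# For intuition: if x = [[1., 2.], [3., 4.]], forward(x) returns 10.0 and
# backward(1.0) returns a 2x2 tensor of ones, since d(sum(x))/dx_ij = 1.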

class MSMLP(model.Model):

    def __init__(self, data_size=10, perceptron_size=100, num_classes=10):
        super(MSMLP, self).__init__()
        self.num_classes = num_classes
        self.dimension = 2

        self.relu = layer.ReLU()
        self.linear1 = layer.Linear(perceptron_size)
        self.linear2 = layer.Linear(num_classes)
        self.softmax_cross_entropy = layer.SoftMaxCrossEntropy()
        self.sum_error = SumErrorLayer()

    def forward(self, inputs):
        y = self.linear1(inputs)
        y = self.relu(y)
        y = self.linear2(y)
        return y

    def train_one_batch(self, x, y, synflow_flag, dist_option, spars):
        out = self.forward(x)
        if synflow_flag:
            # score with the sum of the raw outputs instead of a target loss
            loss = self.sum_error(out)
        else:  # normal training
            loss = self.softmax_cross_entropy(out, y)

        pn_p_g_list = None  # only the 'plain' path produces this list
        if dist_option == 'plain':
            pn_p_g_list = self.optimizer(loss)
        elif dist_option == 'half':
            self.optimizer.backward_and_update_half(loss)
        elif dist_option == 'partialUpdate':
            self.optimizer.backward_and_partial_update(loss)
        elif dist_option == 'sparseTopK':
            self.optimizer.backward_and_sparse_update(loss,
                                                      topK=True,
                                                      spars=spars)
        elif dist_option == 'sparseThreshold':
            self.optimizer.backward_and_sparse_update(loss,
                                                      topK=False,
                                                      spars=spars)
        return pn_p_g_list, out, loss

    def set_optimizer(self, optimizer):
        self.optimizer = optimizer


def create_model(pretrained=False, **kwargs):
    """Constructs an MLP model.

    Args:
        pretrained (bool): If True, returns a pre-trained model.

    Returns:
        The created MLP model.
    """
    model = MSMLP(**kwargs)

    return model


__all__ = ['MSMLP', 'create_model']

if __name__ == "__main__":
    np.random.seed(0)

    parser = argparse.ArgumentParser()
    parser.add_argument('-p',
                        choices=['float32', 'float16'],
                        default='float32',
                        dest='precision')
    parser.add_argument('-g',
                        '--disable-graph',
                        default=True,
                        action='store_false',
                        help='disable graph',
                        dest='graph')
    parser.add_argument('-m',
                        '--max-epoch',
                        default=1001,
                        type=int,
                        help='maximum epochs',
                        dest='max_epoch')
    args = parser.parse_args()

    # generate the boundary
    f = lambda x: (5 * x + 1)
    bd_x = np.linspace(-1.0, 1, 200)
    bd_y = f(bd_x)

    # generate the training data
    x = np.random.uniform(-1, 1, 400)
    y = f(x) + 2 * np.random.randn(len(x))

    # choose one precision
    precision = singa_dtype[args.precision]
    np_precision = np_dtype[args.precision]

    # convert training data to 2d space
    label = np.asarray([5 * a + 1 > b for (a, b) in zip(x, y)]).astype(np.int32)
    data = np.array([[a, b] for (a, b) in zip(x, y)], dtype=np_precision)

    dev = device.create_cuda_gpu_on(0)
    sgd = opt.SGD(0.1, 0.9, 1e-5, dtype=singa_dtype[args.precision])
    tx = tensor.Tensor((400, 2), dev, precision)
    ty = tensor.Tensor((400,), dev, tensor.int32)
    model = MSMLP(data_size=2, perceptron_size=3, num_classes=2)

    # attach model to graph
    model.set_optimizer(sgd)
    model.compile([tx], is_train=True, use_graph=args.graph, sequential=True)
    model.train()

    for i in range(args.max_epoch):
        tx.copy_from_numpy(data)
        ty.copy_from_numpy(label)
        # train_one_batch expects (x, y, synflow_flag, dist_option, spars);
        # 'plain' matches the branch names above (the old 'fp32' value is stale)
        pn_p_g_list, out, loss = model(tx, ty, False, 'plain', spars=None)

        if i % 100 == 0:
            print("training loss = ", tensor.to_numpy(loss)[0])
