Skip to content

Commit 892425d

Browse files
authoredNov 29, 2024··
Merge pull request #1234 from apache/dev-postgresql
Merge the dev branch to master branch
2 parents ae2e50f + 9fac6c1 commit 892425d

File tree

24 files changed

+3307
-13
lines changed

24 files changed

+3307
-13
lines changed
 

‎examples/cnn_ms/msmlp/model.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
singa_dtype = {"float16": tensor.float16, "float32": tensor.float32}
3434

35+
3536
#### self-defined loss begin
3637

3738
### from autograd.py
@@ -62,11 +63,13 @@ def backward(self, dy=1.0):
6263
dx *= dy
6364
return dx
6465

66+
6567
def se_loss(x):
6668
# assert x.shape == t.shape, "input and target shape different: %s, %s" % (
6769
# x.shape, t.shape)
6870
return SumError()(x)[0]
6971

72+
7073
### from layer.py
7174
class SumErrorLayer(Layer):
7275
"""
@@ -79,6 +82,7 @@ def __init__(self):
7982
def forward(self, x):
8083
return se_loss(x)
8184

85+
8286
#### self-defined loss end
8387

8488
class MSMLP(model.Model):
@@ -92,7 +96,6 @@ def __init__(self, data_size=10, perceptron_size=100, num_classes=10):
9296
self.linear1 = layer.Linear(perceptron_size)
9397
self.linear2 = layer.Linear(num_classes)
9498
self.softmax_cross_entropy = layer.SoftMaxCrossEntropy()
95-
9699
self.sum_error = SumErrorLayer()
97100

98101
def forward(self, inputs):
@@ -101,12 +104,24 @@ def forward(self, inputs):
101104
y = self.linear2(y)
102105
return y
103106

104-
def train_one_batch(self, x, y, synflow_flag, dist_option, spars):
107+
def train_one_batch(self, x, y, dist_option, spars, synflow_flag):
108+
# print ("in train_one_batch")
105109
out = self.forward(x)
106-
loss = self.softmax_cross_entropy(out, y)
110+
# print ("train_one_batch x.data: \n", x.data)
111+
# print ("train_one_batch y.data: \n", y.data)
112+
# print ("train_one_batch out.data: \n", out.data)
113+
if synflow_flag:
114+
# print ("sum_error")
115+
loss = self.sum_error(out)
116+
else: # normal training
117+
# print ("softmax_cross_entropy")
118+
loss = self.softmax_cross_entropy(out, y)
119+
# print ("train_one_batch loss.data: \n", loss.data)
107120

108121
if dist_option == 'plain':
122+
# print ("before pn_p_g_list = self.optimizer(loss)")
109123
pn_p_g_list = self.optimizer(loss)
124+
# print ("after pn_p_g_list = self.optimizer(loss)")
110125
elif dist_option == 'half':
111126
self.optimizer.backward_and_update_half(loss)
112127
elif dist_option == 'partialUpdate':
@@ -119,17 +134,24 @@ def train_one_batch(self, x, y, synflow_flag, dist_option, spars):
119134
self.optimizer.backward_and_sparse_update(loss,
120135
topK=False,
121136
spars=spars)
137+
# print ("len(pn_p_g_list): \n", len(pn_p_g_list))
138+
# print ("len(pn_p_g_list[0]): \n", len(pn_p_g_list[0]))
139+
# print ("pn_p_g_list[0][0]: \n", pn_p_g_list[0][0])
140+
# print ("pn_p_g_list[0][1].data: \n", pn_p_g_list[0][1].data)
141+
# print ("pn_p_g_list[0][2].data: \n", pn_p_g_list[0][2].data)
122142
return pn_p_g_list, out, loss
143+
# return pn_p_g_list[0], pn_p_g_list[1], pn_p_g_list[2], out, loss
123144

124145
def set_optimizer(self, optimizer):
125146
self.optimizer = optimizer
126147

127148

128149
def create_model(pretrained=False, **kwargs):
129150
"""Constructs a CNN model.
151+
130152
Args:
131153
pretrained (bool): If True, returns a pre-trained model.
132-
154+
133155
Returns:
134156
The created CNN model.
135157
"""
@@ -196,4 +218,4 @@ def create_model(pretrained=False, **kwargs):
196218
out, loss = model(tx, ty, 'fp32', spars=None)
197219

198220
if i % 100 == 0:
199-
print("training loss = ", tensor.to_numpy(loss)[0])
221+
print("training loss = ", tensor.to_numpy(loss)[0])

‎examples/cnn_ms/train_ms_model.py

Lines changed: 552 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,270 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing,
13+
# software distributed under the License is distributed on an
14+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
# KIND, either express or implied. See the License for the
16+
# specific language governing permissions and limitations
17+
# under the License.
18+
#
19+
20+
import json
21+
import os
22+
import time
23+
from glob import glob
24+
25+
import numpy as np
26+
from PIL import Image
27+
from singa import device, layer, model, opt, tensor
28+
from tqdm import tqdm
29+
30+
from transforms import Compose, Normalize, ToTensor
31+
32+
np_dtype = {"float16": np.float16, "float32": np.float32}
33+
singa_dtype = {"float16": tensor.float16, "float32": tensor.float32}
34+
35+
36+
class ClassDataset(object):
37+
"""Fetch data from file and generate batches.
38+
39+
Load data from folder as PIL.Images and convert them into batch array.
40+
41+
Args:
42+
img_folder (Str): Folder path of the training/validation images.
43+
transforms (Transform): Preprocess transforms.
44+
"""
45+
def __init__(self, img_folder, transforms):
46+
super(ClassDataset, self).__init__()
47+
48+
self.img_list = list()
49+
self.transforms = transforms
50+
51+
classes = os.listdir(img_folder)
52+
for i in classes:
53+
images = glob(os.path.join(img_folder, i, "*"))
54+
for img in images:
55+
self.img_list.append((img, i))
56+
57+
def __len__(self) -> int:
58+
return len(self.img_list)
59+
60+
def __getitem__(self, index: int):
61+
img_path, label_str = self.img_list[index]
62+
img = Image.open(img_path)
63+
img = self.transforms.forward(img)
64+
label = np.array(label_str, dtype=np.int32)
65+
66+
return img, label
67+
68+
def batchgenerator(self, indexes, batch_size, data_size):
69+
"""Generate batch arrays from transformed image list.
70+
71+
Args:
72+
indexes (Sequence): current batch indexes list, e.g. [n, n + 1, ..., n + batch_size]
73+
batch_size (int):
74+
data_size (Tuple): input image size of shape (C, H, W)
75+
76+
Return:
77+
batch_x (Numpy ndarray): batch array of input images (B, C, H, W)
78+
batch_y (Numpy ndarray): batch array of ground truth lables (B,)
79+
"""
80+
batch_x = np.zeros((batch_size,) + data_size)
81+
batch_y = np.zeros((batch_size,) + (1,), dtype=np.int32)
82+
for idx, i in enumerate(indexes):
83+
sample_x, sample_y = self.__getitem__(i)
84+
batch_x[idx, :, :, :] = sample_x
85+
batch_y[idx, :] = sample_y
86+
87+
return batch_x, batch_y
88+
89+
90+
class CNNModel(model.Model):
91+
def __init__(self, num_classes):
92+
super(CNNModel, self).__init__()
93+
self.input_size = 28
94+
self.dimension = 4
95+
self.num_classes = num_classes
96+
97+
self.layer1 = layer.Conv2d(16, kernel_size=3, activation="RELU")
98+
self.bn1 = layer.BatchNorm2d()
99+
self.layer2 = layer.Conv2d(16, kernel_size=3, activation="RELU")
100+
self.bn2 = layer.BatchNorm2d()
101+
self.pooling2 = layer.MaxPool2d(kernel_size=2, stride=2)
102+
self.layer3 = layer.Conv2d(64, kernel_size=3, activation="RELU")
103+
self.bn3 = layer.BatchNorm2d()
104+
self.layer4 = layer.Conv2d(64, kernel_size=3, activation="RELU")
105+
self.bn4 = layer.BatchNorm2d()
106+
self.layer5 = layer.Conv2d(64, kernel_size=3, padding=1, activation="RELU")
107+
self.bn5 = layer.BatchNorm2d()
108+
self.pooling5 = layer.MaxPool2d(kernel_size=2, stride=2)
109+
110+
self.flatten = layer.Flatten()
111+
112+
self.linear1 = layer.Linear(128)
113+
self.linear2 = layer.Linear(128)
114+
self.linear3 = layer.Linear(self.num_classes)
115+
116+
self.relu = layer.ReLU()
117+
118+
self.softmax_cross_entropy = layer.SoftMaxCrossEntropy()
119+
self.dropout = layer.Dropout(ratio=0.3)
120+
121+
def forward(self, x):
122+
x = self.layer1(x)
123+
x = self.bn1(x)
124+
x = self.layer2(x)
125+
x = self.bn2(x)
126+
x = self.pooling2(x)
127+
128+
x = self.layer3(x)
129+
x = self.bn3(x)
130+
x = self.layer4(x)
131+
x = self.bn4(x)
132+
x = self.layer5(x)
133+
x = self.bn5(x)
134+
x = self.pooling5(x)
135+
x = self.flatten(x)
136+
x = self.linear1(x)
137+
x = self.relu(x)
138+
x = self.linear2(x)
139+
x = self.relu(x)
140+
x = self.linear3(x)
141+
return x
142+
143+
def set_optimizer(self, optimizer):
144+
self.optimizer = optimizer
145+
146+
def train_one_batch(self, x, y, dist_option, spars):
147+
out = self.forward(x)
148+
loss = self.softmax_cross_entropy(out, y)
149+
150+
if dist_option == 'plain':
151+
self.optimizer(loss)
152+
elif dist_option == 'half':
153+
self.optimizer.backward_and_update_half(loss)
154+
elif dist_option == 'partialUpdate':
155+
self.optimizer.backward_and_partial_update(loss)
156+
elif dist_option == 'sparseTopK':
157+
self.optimizer.backward_and_sparse_update(loss,
158+
topK=True,
159+
spars=spars)
160+
elif dist_option == 'sparseThreshold':
161+
self.optimizer.backward_and_sparse_update(loss,
162+
topK=False,
163+
spars=spars)
164+
return out, loss
165+
166+
167+
def accuracy(pred, target):
168+
"""Compute recall accuracy.
169+
170+
Args:
171+
pred (Numpy ndarray): Prediction array, should be in shape (B, C)
172+
target (Numpy ndarray): Ground truth array, should be in shape (B, )
173+
174+
Return:
175+
correct (Float): Recall accuracy
176+
"""
177+
# y is network output to be compared with ground truth (int)
178+
y = np.argmax(pred, axis=1)
179+
a = (y[:,None]==target).sum()
180+
correct = np.array(a, "int").sum()
181+
return correct
182+
183+
184+
# Define pre-processing methods (transforms)
185+
transforms = Compose([
186+
ToTensor(),
187+
Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
188+
])
189+
190+
# Dataset loading
191+
dataset_path = "./bloodmnist"
192+
train_path = os.path.join(dataset_path, "train")
193+
val_path = os.path.join(dataset_path, "val")
194+
cfg_path = os.path.join(dataset_path, "param.json")
195+
196+
with open(cfg_path,'r') as load_f:
197+
num_class = json.load(load_f)["num_classes"]
198+
199+
train_dataset = ClassDataset(train_path, transforms)
200+
val_dataset = ClassDataset(val_path, transforms)
201+
202+
batch_size = 256
203+
204+
# Model configuration for CNN
205+
model = CNNModel(num_classes=num_class)
206+
criterion = layer.SoftMaxCrossEntropy()
207+
optimizer_ft = opt.Adam(lr=1e-3)
208+
209+
# Start training
210+
dev = device.create_cpu_device()
211+
dev.SetRandSeed(0)
212+
np.random.seed(0)
213+
214+
tx = tensor.Tensor(
215+
(batch_size, 3, model.input_size, model.input_size), dev,
216+
singa_dtype['float32'])
217+
ty = tensor.Tensor((batch_size,), dev, tensor.int32)
218+
219+
num_train_batch = train_dataset.__len__() // batch_size
220+
num_val_batch = val_dataset.__len__() // batch_size
221+
idx = np.arange(train_dataset.__len__(), dtype=np.int32)
222+
223+
model.set_optimizer(optimizer_ft)
224+
model.compile([tx], is_train=True, use_graph=False, sequential=False)
225+
dev.SetVerbosity(0)
226+
227+
max_epoch = 100
228+
for epoch in range(max_epoch):
229+
print(f'Epoch {epoch}:')
230+
231+
start_time = time.time()
232+
233+
train_correct = np.zeros(shape=[1], dtype=np.float32)
234+
test_correct = np.zeros(shape=[1], dtype=np.float32)
235+
train_loss = np.zeros(shape=[1], dtype=np.float32)
236+
237+
# Training part
238+
model.train()
239+
for b in tqdm(range(num_train_batch)):
240+
# Extract batch from image list
241+
x, y = train_dataset.batchgenerator(idx[b * batch_size:(b + 1) * batch_size],
242+
batch_size=batch_size, data_size=(3, model.input_size, model.input_size))
243+
x = x.astype(np_dtype['float32'])
244+
245+
tx.copy_from_numpy(x)
246+
ty.copy_from_numpy(y)
247+
248+
out, loss = model(tx, ty, dist_option="plain", spars=None)
249+
train_correct += accuracy(tensor.to_numpy(out), y)
250+
train_loss += tensor.to_numpy(loss)[0]
251+
print('Training loss = %f, training accuracy = %f' %
252+
(train_loss, train_correct /
253+
(num_train_batch * batch_size)))
254+
255+
# Validation part
256+
model.eval()
257+
for b in tqdm(range(num_val_batch)):
258+
x, y = train_dataset.batchgenerator(idx[b * batch_size:(b + 1) * batch_size],
259+
batch_size=batch_size, data_size=(3, model.input_size, model.input_size))
260+
x = x.astype(np_dtype['float32'])
261+
262+
tx.copy_from_numpy(x)
263+
ty.copy_from_numpy(y)
264+
265+
out = model(tx)
266+
test_correct += accuracy(tensor.to_numpy(out), y)
267+
268+
print('Evaluation accuracy = %f, Elapsed Time = %fs' %
269+
(test_correct / (num_val_batch * batch_size),
270+
time.time() - start_time))
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
<!--
2+
Licensed to the Apache Software Foundation (ASF) under one
3+
or more contributor license agreements. See the NOTICE file
4+
distributed with this work for additional information
5+
regarding copyright ownership. The ASF licenses this file
6+
to you under the Apache License, Version 2.0 (the
7+
"License"); you may not use this file except in compliance
8+
with the License. You may obtain a copy of the License at
9+
10+
http://www.apache.org/licenses/LICENSE-2.0
11+
12+
Unless required by applicable law or agreed to in writing,
13+
software distributed under the License is distributed on an
14+
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
KIND, either express or implied. See the License for the
16+
specific language governing permissions and limitations
17+
under the License.
18+
-->
19+
# CNN demo model on BloodMnist dataset
20+
21+
## About dataset
22+
Download address: https://drive.google.com/drive/folders/1Ze9qri1UtAsIRoI0SJ4YRpdt5kUUMBEn?usp=sharing
23+
24+
The BloodMNIST , as a sub set of [MedMNIST](https://medmnist.com/), is based on a dataset of individual normal cells, captured from individuals without infection, hematologic or oncologic disease and free of any pharmacologic treatment at the moment of blood collection.
25+
It contains a total of 17,092 images and is organized into 8 classes.
26+
it is split with a ratio of 7:1:2 into training, validation and test set.
27+
The source images with resolution 3×360×363 pixels are center-cropped into 3×200×200, and then resized into 3×28×28.
28+
29+
8 classes of the dataset:
30+
```python
31+
"0": "basophil",
32+
"1": "eosinophil",
33+
"2": "erythroblast",
34+
"3": "ig (immature granulocytes)",
35+
"4": "lymphocyte",
36+
"5": "monocyte",
37+
"6": "neutrophil",
38+
"7": "platelet"
39+
```
40+
41+
# Run the demo
42+
Run
43+
```
44+
python ClassDemo.py
45+
```
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing,
13+
# software distributed under the License is distributed on an
14+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
# KIND, either express or implied. See the License for the
16+
# specific language governing permissions and limitations
17+
# under the License.
18+
#
19+
20+
21+
import numpy as np
22+
from PIL import Image
23+
24+
25+
class Compose(object):
26+
"""Compose several transforms together.
27+
28+
Args:
29+
transforms: list of transforms to compose.
30+
31+
Example:
32+
>>> transforms.Compose([
33+
>>> transforms.ToTensor(),
34+
>>> transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
35+
>>> ])
36+
37+
"""
38+
def __init__(self, transforms):
39+
self.transforms = transforms
40+
41+
def forward(self, img):
42+
"""
43+
Args:
44+
img (PIL Image or numpy array): Image to be processed.
45+
46+
Returns:
47+
PIL Image or numpy array: Processed image.
48+
"""
49+
for t in self.transforms:
50+
img = t.forward(img)
51+
return img
52+
53+
def __repr__(self):
54+
format_string = self.__class__.__name__ + '('
55+
for t in self.transforms:
56+
format_string += '\n'
57+
format_string += ' {0}'.format(t)
58+
format_string += '\n)'
59+
return format_string
60+
61+
62+
class ToTensor(object):
63+
"""Convert a ``PIL Image`` to ``numpy.ndarray``.
64+
65+
Converts a PIL Image (H x W x C) in the range [0, 255] to a ``numpy.array`` of shape
66+
(C x H x W) in the range [0.0, 1.0]
67+
if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1).
68+
69+
In the other cases, tensors are returned without scaling.
70+
71+
.. note::
72+
Because the input image is scaled to [0.0, 1.0], this transformation should not be used when
73+
transforming target image masks.
74+
"""
75+
76+
def forward(self, pic):
77+
"""
78+
Args:
79+
pic (PIL Image): Image to be converted to array.
80+
81+
Returns:
82+
Array: Converted image.
83+
"""
84+
if not isinstance(pic, Image.Image):
85+
raise TypeError('pic should be PIL Image. Got {}'.format(type(pic)))
86+
87+
# Handle PIL Image
88+
mode_to_nptype = {'I': np.int32, 'I;16': np.int16, 'F': np.float32}
89+
img = np.array(pic, mode_to_nptype.get(pic.mode, np.uint8), copy=True)
90+
91+
if pic.mode == '1':
92+
img = 255 * img
93+
94+
# Put it from HWC to CHW format
95+
img = np.transpose(img, (2, 0, 1))
96+
97+
if img.dtype == np.uint8:
98+
return np.array(np.float32(img)/255.0, dtype=np.float)
99+
else:
100+
return np.float(img)
101+
102+
def __repr__(self):
103+
return self.__class__.__name__ + '()'
104+
105+
106+
class Normalize(object):
107+
"""Normalize a ``numpy.array`` image with mean and standard deviation.
108+
109+
This transform does not support PIL Image.
110+
Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n``
111+
channels, this transform will normalize each channel of the input
112+
``numpy.array`` i.e.,
113+
``output[channel] = (input[channel] - mean[channel]) / std[channel]``
114+
115+
.. note::
116+
This transform acts out of place, i.e., it does not mutate the input array.
117+
118+
Args:
119+
mean (Sequence): Sequence of means for each channel.
120+
std (Sequence): Sequence of standard deviations for each channel.
121+
inplace(bool, optional): Bool to make this operation in-place.
122+
123+
"""
124+
125+
def __init__(self, mean, std, inplace=False):
126+
super().__init__()
127+
self.mean = mean
128+
self.std = std
129+
self.inplace = inplace
130+
131+
def forward(self, img: np.ndarray):
132+
"""
133+
Args:
134+
img (Numpy ndarray): Array image to be normalized.
135+
136+
Returns:
137+
d_res (Numpy ndarray): Normalized Tensor image.
138+
"""
139+
if not isinstance(img, np.ndarray):
140+
raise TypeError('Input img should be a numpy array. Got {}.'.format(type(img)))
141+
142+
if not img.dtype == np.float:
143+
raise TypeError('Input array should be a float array. Got {}.'.format(img.dtype))
144+
145+
if img.ndim < 3:
146+
raise ValueError('Expected array to be an array image of size (..., C, H, W). Got img.shape = '
147+
'{}.'.format(img.shape))
148+
149+
if not self.inplace:
150+
img = img.copy()
151+
152+
dtype = img.dtype
153+
mean = np.array(self.mean, dtype=dtype)
154+
std = np.array(self.std, dtype=dtype)
155+
if (std == 0).any():
156+
raise ValueError('std evaluated to zero after conversion to {}, leading to division by zero.'.format(dtype))
157+
s_res = np.subtract(img, mean[:, None, None])
158+
d_res = np.divide(s_res, std[:, None, None])
159+
160+
return d_res
161+
162+
163+
def __repr__(self):
164+
return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
165+
166+
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
<!--
2+
Licensed to the Apache Software Foundation (ASF) under one
3+
or more contributor license agreements. See the NOTICE file
4+
distributed with this work for additional information
5+
regarding copyright ownership. The ASF licenses this file
6+
to you under the Apache License, Version 2.0 (the
7+
"License"); you may not use this file except in compliance
8+
with the License. You may obtain a copy of the License at
9+
10+
http://www.apache.org/licenses/LICENSE-2.0
11+
12+
Unless required by applicable law or agreed to in writing,
13+
software distributed under the License is distributed on an
14+
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
KIND, either express or implied. See the License for the
16+
specific language governing permissions and limitations
17+
under the License.
18+
-->
19+
20+
# Singa for Malaria Detection Task
21+
22+
## Malaria
23+
24+
Malaria is caused by parasites and could be transmitted through infected mosquitoes. There are about 200 million cases worldwide, and about 400,000 deaths per year, therefore, malaria does lots of harm to global health.
25+
26+
Although Malaria is a curable disease, inadequate diagnostics make it harder to reduce mortality, as a result, a fast and reliable diagnostic test is a promising and effective way to fight malaria.
27+
28+
To mitigate the problem, we use Singa to implement a machine learning model to help with Malaria diagnosis. The dataset is from Kaggle https://www.kaggle.com/datasets/miracle9to9/files1?resource=download. Please download the dataset before running the scripts.
29+
30+
## Structure
31+
32+
* `data` includes the scripts for preprocessing Malaria image datasets.
33+
34+
* `model` includes the CNN model construction codes by creating
35+
a subclass of `Module` to wrap the neural network operations
36+
of each model.
37+
38+
* `train_cnn.py` is the training script, which controls the training flow by
39+
doing BackPropagation and SGD update.
40+
41+
## Command
42+
```bash
43+
python train_cnn.py cnn malaria -dir pathToDataset
44+
```
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
#
18+
19+
try:
20+
import pickle
21+
except ImportError:
22+
import cPickle as pickle
23+
24+
import numpy as np
25+
import os
26+
import sys
27+
from PIL import Image
28+
29+
30+
# need to save to specific local directories
31+
def load_train_data(dir_path="/tmp/malaria", resize_size=(128, 128)):
32+
dir_path = check_dataset_exist(dirpath=dir_path)
33+
path_train_label_1 = os.path.join(dir_path, "training_set/Parasitized")
34+
path_train_label_0 = os.path.join(dir_path, "training_set/Uninfected")
35+
train_label_1 = load_image_path(os.listdir(path_train_label_1))
36+
train_label_0 = load_image_path(os.listdir(path_train_label_0))
37+
labels = []
38+
Images = np.empty((len(train_label_1) + len(train_label_0),
39+
3, resize_size[0], resize_size[1]), dtype=np.uint8)
40+
for i in range(len(train_label_0)):
41+
image_path = os.path.join(path_train_label_0, train_label_0[i])
42+
temp_image = np.array(Image.open(image_path).resize(
43+
resize_size).convert("RGB")).transpose(2, 0, 1)
44+
Images[i] = temp_image
45+
labels.append(0)
46+
for i in range(len(train_label_1)):
47+
image_path = os.path.join(path_train_label_1, train_label_1[i])
48+
temp_image = np.array(Image.open(image_path).resize(
49+
resize_size).convert("RGB")).transpose(2, 0, 1)
50+
Images[i + len(train_label_0)] = temp_image
51+
labels.append(1)
52+
53+
Images = np.array(Images, dtype=np.float32)
54+
labels = np.array(labels, dtype=np.int32)
55+
return Images, labels
56+
57+
58+
# need to save to specific local directories
59+
def load_test_data(dir_path='/tmp/malaria', resize_size=(128, 128)):
60+
dir_path = check_dataset_exist(dirpath=dir_path)
61+
path_test_label_1 = os.path.join(dir_path, "testing_set/Parasitized")
62+
path_test_label_0 = os.path.join(dir_path, "testing_set/Uninfected")
63+
test_label_1 = load_image_path(os.listdir(path_test_label_1))
64+
test_label_0 = load_image_path(os.listdir(path_test_label_0))
65+
labels = []
66+
Images = np.empty((len(test_label_1) + len(test_label_0),
67+
3, resize_size[0], resize_size[1]), dtype=np.uint8)
68+
for i in range(len(test_label_0)):
69+
image_path = os.path.join(path_test_label_0, test_label_0[i])
70+
temp_image = np.array(Image.open(image_path).resize(
71+
resize_size).convert("RGB")).transpose(2, 0, 1)
72+
Images[i] = temp_image
73+
labels.append(0)
74+
for i in range(len(test_label_1)):
75+
image_path = os.path.join(path_test_label_1, test_label_1[i])
76+
temp_image = np.array(Image.open(image_path).resize(
77+
resize_size).convert("RGB")).transpose(2, 0, 1)
78+
Images[i + len(test_label_0)] = temp_image
79+
labels.append(1)
80+
81+
Images = np.array(Images, dtype=np.float32)
82+
labels = np.array(labels, dtype=np.int32)
83+
return Images, labels
84+
85+
86+
def load_image_path(list):
87+
new_list = []
88+
for image_path in list:
89+
if (image_path.endswith(".png") or image_path.endswith(".jpg")):
90+
new_list.append(image_path)
91+
return new_list
92+
93+
94+
def check_dataset_exist(dirpath):
95+
if not os.path.exists(dirpath):
96+
print(
97+
'Please download the malaria dataset first'
98+
)
99+
sys.exit(0)
100+
return dirpath
101+
102+
103+
def normalize(train_x, val_x):
104+
mean = [0.5339, 0.4180, 0.4460] # mean for malaria dataset
105+
std = [0.3329, 0.2637, 0.2761] # std for malaria dataset
106+
train_x /= 255
107+
val_x /= 255
108+
for ch in range(0, 2):
109+
train_x[:, ch, :, :] -= mean[ch]
110+
train_x[:, ch, :, :] /= std[ch]
111+
val_x[:, ch, :, :] -= mean[ch]
112+
val_x[:, ch, :, :] /= std[ch]
113+
return train_x, val_x
114+
115+
116+
def load(dir_path):
117+
train_x, train_y = load_train_data(dir_path=dir_path)
118+
val_x, val_y = load_test_data(dir_path=dir_path)
119+
train_x, val_x = normalize(train_x, val_x)
120+
train_y = train_y.flatten()
121+
val_y = val_y.flatten()
122+
return train_x, train_y, val_x, val_y
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
#
18+
19+
from singa import layer
20+
from singa import model
21+
22+
23+
class CNN(model.Model):
24+
25+
def __init__(self, num_classes=10, num_channels=1):
26+
super(CNN, self).__init__()
27+
self.num_classes = num_classes
28+
self.input_size = 128
29+
self.dimension = 4
30+
self.conv1 = layer.Conv2d(num_channels, 32, 3, padding=0, activation="RELU")
31+
self.conv2 = layer.Conv2d(32, 64, 3, padding=0, activation="RELU")
32+
self.conv3 = layer.Conv2d(64, 64, 3, padding=0, activation="RELU")
33+
self.linear1 = layer.Linear(128)
34+
self.linear2 = layer.Linear(num_classes)
35+
self.pooling1 = layer.MaxPool2d(2, 2, padding=0)
36+
self.pooling2 = layer.MaxPool2d(2, 2, padding=0)
37+
self.pooling3 = layer.MaxPool2d(2, 2, padding=0)
38+
self.relu = layer.ReLU()
39+
self.flatten = layer.Flatten()
40+
self.softmax_cross_entropy = layer.SoftMaxCrossEntropy()
41+
self.sigmoid = layer
42+
43+
def forward(self, x):
44+
y = self.conv1(x)
45+
y = self.pooling1(y)
46+
y = self.conv2(y)
47+
y = self.pooling2(y)
48+
y = self.conv3(y)
49+
y = self.pooling3(y)
50+
y = self.flatten(y)
51+
y = self.linear1(y)
52+
y = self.relu(y)
53+
y = self.linear2(y)
54+
return y
55+
56+
def train_one_batch(self, x, y, dist_option, spars):
57+
out = self.forward(x)
58+
loss = self.softmax_cross_entropy(out, y)
59+
60+
if dist_option == 'plain':
61+
self.optimizer(loss)
62+
elif dist_option == 'half':
63+
self.optimizer.backward_and_update_half(loss)
64+
elif dist_option == 'partialUpdate':
65+
self.optimizer.backward_and_partial_update(loss)
66+
elif dist_option == 'sparseTopK':
67+
self.optimizer.backward_and_sparse_update(loss,
68+
topK=True,
69+
spars=spars)
70+
elif dist_option == 'sparseThreshold':
71+
self.optimizer.backward_and_sparse_update(loss,
72+
topK=False,
73+
spars=spars)
74+
return out, loss
75+
76+
def set_optimizer(self, optimizer):
77+
self.optimizer = optimizer
78+
79+
80+
def create_model(**kwargs):
81+
"""Constructs a CNN model.
82+
83+
Args:
84+
pretrained (bool): If True, returns a pre-trained model.
85+
86+
Returns:
87+
The created CNN model.
88+
"""
89+
model = CNN(**kwargs)
90+
91+
return model
92+
93+
94+
__all__ = ['CNN', 'create_model']
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
#
18+
19+
from singa import layer
20+
from singa import model
21+
from singa import tensor
22+
from singa import opt
23+
from singa import device
24+
import argparse
25+
import numpy as np
26+
27+
np_dtype = {"float16": np.float16, "float32": np.float32}
28+
29+
singa_dtype = {"float16": tensor.float16, "float32": tensor.float32}
30+
31+
32+
class MLP(model.Model):
33+
34+
def __init__(self, perceptron_size=100, num_classes=10):
35+
super(MLP, self).__init__()
36+
self.num_classes = num_classes
37+
self.dimension = 2
38+
39+
self.relu = layer.ReLU()
40+
self.linear1 = layer.Linear(perceptron_size)
41+
self.linear2 = layer.Linear(num_classes)
42+
self.softmax_cross_entropy = layer.SoftMaxCrossEntropy()
43+
44+
def forward(self, inputs):
45+
y = self.linear1(inputs)
46+
y = self.relu(y)
47+
y = self.linear2(y)
48+
return y
49+
50+
def train_one_batch(self, x, y, dist_option, spars):
51+
out = self.forward(x)
52+
loss = self.softmax_cross_entropy(out, y)
53+
54+
if dist_option == 'plain':
55+
self.optimizer(loss)
56+
elif dist_option == 'half':
57+
self.optimizer.backward_and_update_half(loss)
58+
elif dist_option == 'partialUpdate':
59+
self.optimizer.backward_and_partial_update(loss)
60+
elif dist_option == 'sparseTopK':
61+
self.optimizer.backward_and_sparse_update(loss,
62+
topK=True,
63+
spars=spars)
64+
elif dist_option == 'sparseThreshold':
65+
self.optimizer.backward_and_sparse_update(loss,
66+
topK=False,
67+
spars=spars)
68+
return out, loss
69+
70+
def set_optimizer(self, optimizer):
71+
self.optimizer = optimizer
72+
73+
74+
def create_model(**kwargs):
75+
"""Constructs a CNN model.
76+
77+
Returns:
78+
The created CNN model.
79+
"""
80+
model = MLP(**kwargs)
81+
82+
return model
83+
84+
85+
__all__ = ['MLP', 'create_model']
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
#
18+
19+
### malaria dataset
20+
python train_cnn.py cnn malaria -dir pathToDataset
Lines changed: 294 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,294 @@
1+
from singa import singa_wrap as singa
2+
from singa import device
3+
from singa import tensor
4+
from singa import opt
5+
import numpy as np
6+
import time
7+
import argparse
8+
import sys
9+
from PIL import Image
10+
11+
np_dtype = {"float16": np.float16, "float32": np.float32}
12+
13+
singa_dtype = {"float16": tensor.float16, "float32": tensor.float32}
14+
15+
16+
# Data augmentation
17+
def augmentation(x, batch_size):
18+
xpad = np.pad(x, [[0, 0], [0, 0], [4, 4], [4, 4]], 'symmetric')
19+
for data_num in range(0, batch_size):
20+
offset = np.random.randint(8, size=2)
21+
x[data_num, :, :, :] = xpad[data_num, :,
22+
offset[0]:offset[0] + x.shape[2],
23+
offset[1]:offset[1] + x.shape[2]]
24+
if_flip = np.random.randint(2)
25+
if (if_flip):
26+
x[data_num, :, :, :] = x[data_num, :, :, ::-1]
27+
return x
28+
29+
30+
# Calculate accuracy
31+
def accuracy(pred, target):
32+
# y is network output to be compared with ground truth (int)
33+
y = np.argmax(pred, axis=1)
34+
a = y == target
35+
correct = np.array(a, "int").sum()
36+
return correct
37+
38+
39+
# Data partition according to the rank
40+
def partition(global_rank, world_size, train_x, train_y, val_x, val_y):
41+
# Partition training data
42+
data_per_rank = train_x.shape[0] // world_size
43+
idx_start = global_rank * data_per_rank
44+
idx_end = (global_rank + 1) * data_per_rank
45+
train_x = train_x[idx_start:idx_end]
46+
train_y = train_y[idx_start:idx_end]
47+
48+
# Partition evaluation data
49+
data_per_rank = val_x.shape[0] // world_size
50+
idx_start = global_rank * data_per_rank
51+
idx_end = (global_rank + 1) * data_per_rank
52+
val_x = val_x[idx_start:idx_end]
53+
val_y = val_y[idx_start:idx_end]
54+
return train_x, train_y, val_x, val_y
55+
56+
57+
# Function to all reduce NUMPY accuracy and loss from multiple devices
58+
def reduce_variable(variable, dist_opt, reducer):
59+
reducer.copy_from_numpy(variable)
60+
dist_opt.all_reduce(reducer.data)
61+
dist_opt.wait()
62+
output = tensor.to_numpy(reducer)
63+
return output
64+
65+
66+
def resize_dataset(x, image_size):
67+
num_data = x.shape[0]
68+
dim = x.shape[1]
69+
X = np.zeros(shape=(num_data, dim, image_size, image_size),
70+
dtype=np.float32)
71+
for n in range(0, num_data):
72+
for d in range(0, dim):
73+
X[n, d, :, :] = np.array(Image.fromarray(x[n, d, :, :]).resize(
74+
(image_size, image_size), Image.BILINEAR),
75+
dtype=np.float32)
76+
return X
77+
78+
79+
def run(global_rank,
80+
world_size,
81+
dir_path,
82+
max_epoch,
83+
batch_size,
84+
model,
85+
data,
86+
sgd,
87+
graph,
88+
verbosity,
89+
dist_option='plain',
90+
spars=None,
91+
precision='float32'):
92+
# now CPU version only, could change to GPU device for GPU-support machines
93+
dev = device.get_default_device()
94+
dev.SetRandSeed(0)
95+
np.random.seed(0)
96+
if data == 'malaria':
97+
from data import malaria
98+
train_x, train_y, val_x, val_y = malaria.load(dir_path=dir_path)
99+
else:
100+
print(
101+
'Wrong dataset!'
102+
)
103+
sys.exit(0)
104+
105+
num_channels = train_x.shape[1]
106+
image_size = train_x.shape[2]
107+
data_size = np.prod(train_x.shape[1:train_x.ndim]).item()
108+
num_classes = (np.max(train_y) + 1).item()
109+
110+
if model == 'cnn':
111+
from model import cnn
112+
model = cnn.create_model(num_channels=num_channels,
113+
num_classes=num_classes)
114+
else:
115+
print(
116+
'Wrong model!'
117+
)
118+
sys.exit(0)
119+
120+
# For distributed training, sequential has better performance
121+
if hasattr(sgd, "communicator"):
122+
DIST = True
123+
sequential = True
124+
else:
125+
DIST = False
126+
sequential = False
127+
128+
if DIST:
129+
train_x, train_y, val_x, val_y = partition(global_rank, world_size,
130+
train_x, train_y, val_x,
131+
val_y)
132+
133+
if model.dimension == 4:
134+
tx = tensor.Tensor(
135+
(batch_size, num_channels, model.input_size, model.input_size), dev,
136+
singa_dtype[precision])
137+
elif model.dimension == 2:
138+
tx = tensor.Tensor((batch_size, data_size),
139+
dev, singa_dtype[precision])
140+
np.reshape(train_x, (train_x.shape[0], -1))
141+
np.reshape(val_x, (val_x.shape[0], -1))
142+
143+
ty = tensor.Tensor((batch_size,), dev, tensor.int32)
144+
num_train_batch = train_x.shape[0] // batch_size
145+
num_val_batch = val_x.shape[0] // batch_size
146+
idx = np.arange(train_x.shape[0], dtype=np.int32)
147+
148+
# Attach model to graph
149+
model.set_optimizer(sgd)
150+
model.compile([tx], is_train=True, use_graph=graph, sequential=sequential)
151+
dev.SetVerbosity(verbosity)
152+
153+
# Training and evaluation loop
154+
for epoch in range(max_epoch):
155+
start_time = time.time()
156+
np.random.shuffle(idx)
157+
158+
if global_rank == 0:
159+
print('Starting Epoch %d:' % (epoch))
160+
161+
# Training phase
162+
train_correct = np.zeros(shape=[1], dtype=np.float32)
163+
test_correct = np.zeros(shape=[1], dtype=np.float32)
164+
train_loss = np.zeros(shape=[1], dtype=np.float32)
165+
166+
model.train()
167+
for b in range(num_train_batch):
168+
# if b % 100 == 0:
169+
# print ("b: \n", b)
170+
# Generate the patch data in this iteration
171+
x = train_x[idx[b * batch_size:(b + 1) * batch_size]]
172+
if model.dimension == 4:
173+
x = augmentation(x, batch_size)
174+
if (image_size != model.input_size):
175+
x = resize_dataset(x, model.input_size)
176+
x = x.astype(np_dtype[precision])
177+
y = train_y[idx[b * batch_size:(b + 1) * batch_size]]
178+
179+
# Copy the patch data into input tensors
180+
tx.copy_from_numpy(x)
181+
ty.copy_from_numpy(y)
182+
183+
# Train the model
184+
out, loss = model(tx, ty, dist_option, spars)
185+
train_correct += accuracy(tensor.to_numpy(out), y)
186+
train_loss += tensor.to_numpy(loss)[0]
187+
188+
if DIST:
189+
# Reduce the evaluation accuracy and loss from multiple devices
190+
reducer = tensor.Tensor((1,), dev, tensor.float32)
191+
train_correct = reduce_variable(train_correct, sgd, reducer)
192+
train_loss = reduce_variable(train_loss, sgd, reducer)
193+
194+
if global_rank == 0:
195+
print('Training loss = %f, training accuracy = %f' %
196+
(train_loss, train_correct /
197+
(num_train_batch * batch_size * world_size)),
198+
flush=True)
199+
200+
# Evaluation phase
201+
model.eval()
202+
for b in range(num_val_batch):
203+
x = val_x[b * batch_size:(b + 1) * batch_size]
204+
if model.dimension == 4:
205+
if (image_size != model.input_size):
206+
x = resize_dataset(x, model.input_size)
207+
x = x.astype(np_dtype[precision])
208+
y = val_y[b * batch_size:(b + 1) * batch_size]
209+
tx.copy_from_numpy(x)
210+
ty.copy_from_numpy(y)
211+
out_test = model(tx)
212+
test_correct += accuracy(tensor.to_numpy(out_test), y)
213+
214+
if DIST:
215+
# Reduce the evaulation accuracy from multiple devices
216+
test_correct = reduce_variable(test_correct, sgd, reducer)
217+
218+
# Output the evaluation accuracy
219+
if global_rank == 0:
220+
print('Evaluation accuracy = %f, Elapsed Time = %fs' %
221+
(test_correct / (num_val_batch * batch_size * world_size),
222+
time.time() - start_time),
223+
flush=True)
224+
225+
dev.PrintTimeProfiling()
226+
227+
228+
if __name__ == '__main__':
229+
# Use argparse to get command config: max_epoch, model, data, etc., for single gpu training
230+
parser = argparse.ArgumentParser(
231+
description='Training using the autograd and graph.')
232+
parser.add_argument(
233+
'model',
234+
choices=['cnn'],
235+
default='cnn')
236+
parser.add_argument('data',
237+
choices=['malaria'],
238+
default='malaria')
239+
parser.add_argument('-p',
240+
choices=['float32', 'float16'],
241+
default='float32',
242+
dest='precision')
243+
parser.add_argument('-dir',
244+
'--dir-path',
245+
default="/tmp/malaria",
246+
type=str,
247+
help='the directory to store the malaria dataset',
248+
dest='dir_path')
249+
parser.add_argument('-m',
250+
'--max-epoch',
251+
default=100,
252+
type=int,
253+
help='maximum epochs',
254+
dest='max_epoch')
255+
parser.add_argument('-b',
256+
'--batch-size',
257+
default=64,
258+
type=int,
259+
help='batch size',
260+
dest='batch_size')
261+
parser.add_argument('-l',
262+
'--learning-rate',
263+
default=0.005,
264+
type=float,
265+
help='initial learning rate',
266+
dest='lr')
267+
parser.add_argument('-g',
268+
'--disable-graph',
269+
default='True',
270+
action='store_false',
271+
help='disable graph',
272+
dest='graph')
273+
parser.add_argument('-v',
274+
'--log-verbosity',
275+
default=0,
276+
type=int,
277+
help='logging verbosity',
278+
dest='verbosity')
279+
280+
args = parser.parse_args()
281+
282+
sgd = opt.SGD(lr=args.lr, momentum=0.9, weight_decay=1e-5,
283+
dtype=singa_dtype[args.precision])
284+
run(0,
285+
1,
286+
args.dir_path,
287+
args.max_epoch,
288+
args.batch_size,
289+
args.model,
290+
args.data,
291+
sgd,
292+
args.graph,
293+
args.verbosity,
294+
precision=args.precision)
Lines changed: 318 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,318 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
#
18+
19+
from singa import singa_wrap as singa
20+
from singa import device
21+
from singa import tensor
22+
from singa import opt
23+
import numpy as np
24+
import time
25+
import argparse
26+
import sys
27+
sys.path.append("../../..")
28+
29+
from PIL import Image
30+
31+
from healthcare.data import malaria
32+
from healthcare.models import malaria_net
33+
34+
np_dtype = {"float16": np.float16, "float32": np.float32}
35+
36+
singa_dtype = {"float16": tensor.float16, "float32": tensor.float32}
37+
38+
39+
# Data augmentation
40+
def augmentation(x, batch_size):
41+
xpad = np.pad(x, [[0, 0], [0, 0], [4, 4], [4, 4]], 'symmetric')
42+
for data_num in range(0, batch_size):
43+
offset = np.random.randint(8, size=2)
44+
x[data_num, :, :, :] = xpad[data_num, :,
45+
offset[0]:offset[0] + x.shape[2],
46+
offset[1]:offset[1] + x.shape[2]]
47+
if_flip = np.random.randint(2)
48+
if (if_flip):
49+
x[data_num, :, :, :] = x[data_num, :, :, ::-1]
50+
return x
51+
52+
53+
# Calculate accuracy
54+
def accuracy(pred, target):
55+
# y is network output to be compared with ground truth (int)
56+
y = np.argmax(pred, axis=1)
57+
a = y == target
58+
correct = np.array(a, "int").sum()
59+
return correct
60+
61+
62+
# Data partition according to the rank
63+
def partition(global_rank, world_size, train_x, train_y, val_x, val_y):
64+
# Partition training data
65+
data_per_rank = train_x.shape[0] // world_size
66+
idx_start = global_rank * data_per_rank
67+
idx_end = (global_rank + 1) * data_per_rank
68+
train_x = train_x[idx_start:idx_end]
69+
train_y = train_y[idx_start:idx_end]
70+
71+
# Partition evaluation data
72+
data_per_rank = val_x.shape[0] // world_size
73+
idx_start = global_rank * data_per_rank
74+
idx_end = (global_rank + 1) * data_per_rank
75+
val_x = val_x[idx_start:idx_end]
76+
val_y = val_y[idx_start:idx_end]
77+
return train_x, train_y, val_x, val_y
78+
79+
80+
# Function to all reduce NUMPY accuracy and loss from multiple devices
81+
def reduce_variable(variable, dist_opt, reducer):
82+
reducer.copy_from_numpy(variable)
83+
dist_opt.all_reduce(reducer.data)
84+
dist_opt.wait()
85+
output = tensor.to_numpy(reducer)
86+
return output
87+
88+
89+
def resize_dataset(x, image_size):
90+
num_data = x.shape[0]
91+
dim = x.shape[1]
92+
X = np.zeros(shape=(num_data, dim, image_size, image_size),
93+
dtype=np.float32)
94+
for n in range(0, num_data):
95+
for d in range(0, dim):
96+
X[n, d, :, :] = np.array(Image.fromarray(x[n, d, :, :]).resize(
97+
(image_size, image_size), Image.BILINEAR),
98+
dtype=np.float32)
99+
return X
100+
101+
102+
def run(global_rank,
103+
world_size,
104+
dir_path,
105+
max_epoch,
106+
batch_size,
107+
model,
108+
data,
109+
sgd,
110+
graph,
111+
verbosity,
112+
dist_option='plain',
113+
spars=None,
114+
precision='float32'):
115+
# now CPU version only, could change to GPU device for GPU-support machines
116+
dev = device.get_default_device()
117+
dev.SetRandSeed(0)
118+
np.random.seed(0)
119+
if data == 'malaria':
120+
121+
train_x, train_y, val_x, val_y = malaria.load(dir_path=dir_path)
122+
else:
123+
print(
124+
'Wrong dataset!'
125+
)
126+
sys.exit(0)
127+
128+
num_channels = train_x.shape[1]
129+
image_size = train_x.shape[2]
130+
data_size = np.prod(train_x.shape[1:train_x.ndim]).item()
131+
num_classes = (np.max(train_y) + 1).item()
132+
133+
if model == 'cnn':
134+
model = malaria_net.create_model(model_option='cnn', num_channels=num_channels,
135+
num_classes=num_classes)
136+
else:
137+
print(
138+
'Wrong model!'
139+
)
140+
sys.exit(0)
141+
142+
# For distributed training, sequential has better performance
143+
if hasattr(sgd, "communicator"):
144+
DIST = True
145+
sequential = True
146+
else:
147+
DIST = False
148+
sequential = False
149+
150+
if DIST:
151+
train_x, train_y, val_x, val_y = partition(global_rank, world_size,
152+
train_x, train_y, val_x,
153+
val_y)
154+
155+
if model.dimension == 4:
156+
tx = tensor.Tensor(
157+
(batch_size, num_channels, model.input_size, model.input_size), dev,
158+
singa_dtype[precision])
159+
elif model.dimension == 2:
160+
tx = tensor.Tensor((batch_size, data_size),
161+
dev, singa_dtype[precision])
162+
np.reshape(train_x, (train_x.shape[0], -1))
163+
np.reshape(val_x, (val_x.shape[0], -1))
164+
165+
ty = tensor.Tensor((batch_size,), dev, tensor.int32)
166+
num_train_batch = train_x.shape[0] // batch_size
167+
num_val_batch = val_x.shape[0] // batch_size
168+
idx = np.arange(train_x.shape[0], dtype=np.int32)
169+
170+
# Attach model to graph
171+
model.set_optimizer(sgd)
172+
model.compile([tx], is_train=True, use_graph=graph, sequential=sequential)
173+
dev.SetVerbosity(verbosity)
174+
175+
# Training and evaluation loop
176+
for epoch in range(max_epoch):
177+
start_time = time.time()
178+
np.random.shuffle(idx)
179+
180+
if global_rank == 0:
181+
print('Starting Epoch %d:' % (epoch))
182+
183+
# Training phase
184+
train_correct = np.zeros(shape=[1], dtype=np.float32)
185+
test_correct = np.zeros(shape=[1], dtype=np.float32)
186+
train_loss = np.zeros(shape=[1], dtype=np.float32)
187+
188+
model.train()
189+
for b in range(num_train_batch):
190+
# if b % 100 == 0:
191+
# print ("b: \n", b)
192+
# Generate the patch data in this iteration
193+
x = train_x[idx[b * batch_size:(b + 1) * batch_size]]
194+
if model.dimension == 4:
195+
x = augmentation(x, batch_size)
196+
if (image_size != model.input_size):
197+
x = resize_dataset(x, model.input_size)
198+
x = x.astype(np_dtype[precision])
199+
y = train_y[idx[b * batch_size:(b + 1) * batch_size]]
200+
201+
# Copy the patch data into input tensors
202+
tx.copy_from_numpy(x)
203+
ty.copy_from_numpy(y)
204+
205+
# Train the model
206+
out, loss = model(tx, ty, dist_option, spars)
207+
train_correct += accuracy(tensor.to_numpy(out), y)
208+
train_loss += tensor.to_numpy(loss)[0]
209+
210+
# print('batch training loss = %f' % train_loss, flush=True)
211+
212+
if DIST:
213+
# Reduce the evaluation accuracy and loss from multiple devices
214+
reducer = tensor.Tensor((1,), dev, tensor.float32)
215+
train_correct = reduce_variable(train_correct, sgd, reducer)
216+
train_loss = reduce_variable(train_loss, sgd, reducer)
217+
218+
if global_rank == 0:
219+
print('Training loss = %f, training accuracy = %f' %
220+
(train_loss, train_correct /
221+
(num_train_batch * batch_size * world_size)),
222+
flush=True)
223+
224+
# Evaluation phase
225+
model.eval()
226+
for b in range(num_val_batch):
227+
x = val_x[b * batch_size:(b + 1) * batch_size]
228+
if model.dimension == 4:
229+
if (image_size != model.input_size):
230+
x = resize_dataset(x, model.input_size)
231+
x = x.astype(np_dtype[precision])
232+
y = val_y[b * batch_size:(b + 1) * batch_size]
233+
tx.copy_from_numpy(x)
234+
ty.copy_from_numpy(y)
235+
out_test = model(tx)
236+
test_correct += accuracy(tensor.to_numpy(out_test), y)
237+
238+
if DIST:
239+
# Reduce the evaulation accuracy from multiple devices
240+
test_correct = reduce_variable(test_correct, sgd, reducer)
241+
242+
# Output the evaluation accuracy
243+
if global_rank == 0:
244+
print('Evaluation accuracy = %f, Elapsed Time = %fs' %
245+
(test_correct / (num_val_batch * batch_size * world_size),
246+
time.time() - start_time),
247+
flush=True)
248+
249+
dev.PrintTimeProfiling()
250+
251+
252+
if __name__ == '__main__':
253+
# Use argparse to get command config: max_epoch, model, data, etc., for single gpu training
254+
parser = argparse.ArgumentParser(
255+
description='Training using the autograd and graph.')
256+
parser.add_argument(
257+
'model',
258+
choices=['cnn'],
259+
default='cnn')
260+
parser.add_argument('data',
261+
choices=['malaria'],
262+
default='malaria')
263+
parser.add_argument('-p',
264+
choices=['float32', 'float16'],
265+
default='float32',
266+
dest='precision')
267+
parser.add_argument('-dir',
268+
'--dir-path',
269+
default="/tmp/malaria",
270+
type=str,
271+
help='the directory to store the malaria dataset',
272+
dest='dir_path')
273+
parser.add_argument('-m',
274+
'--max-epoch',
275+
default=100,
276+
type=int,
277+
help='maximum epochs',
278+
dest='max_epoch')
279+
parser.add_argument('-b',
280+
'--batch-size',
281+
default=64,
282+
type=int,
283+
help='batch size',
284+
dest='batch_size')
285+
parser.add_argument('-l',
286+
'--learning-rate',
287+
default=0.005,
288+
type=float,
289+
help='initial learning rate',
290+
dest='lr')
291+
parser.add_argument('-g',
292+
'--disable-graph',
293+
default='True',
294+
action='store_false',
295+
help='disable graph',
296+
dest='graph')
297+
parser.add_argument('-v',
298+
'--log-verbosity',
299+
default=0,
300+
type=int,
301+
help='logging verbosity',
302+
dest='verbosity')
303+
304+
args = parser.parse_args()
305+
306+
sgd = opt.SGD(lr=args.lr, momentum=0.9, weight_decay=1e-5,
307+
dtype=singa_dtype[args.precision])
308+
run(0,
309+
1,
310+
args.dir_path,
311+
args.max_epoch,
312+
args.batch_size,
313+
args.model,
314+
args.data,
315+
sgd,
316+
args.graph,
317+
args.verbosity,
318+
precision=args.precision);
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Convolutional Prototype Learning
2+
3+
We have successfully applied the idea of prototype loss in various medical image classification task to improve performance, for example detection thyroid eye disease from CT images. Here we provide the implementation of the convolution prototype model in Singa. Due to data privacy, we are not able to release the CT image dataset used. The training scripts `./train.py` demonstrate how to apply this model on cifar-10 dataset.
4+
5+
## run
6+
7+
At Singa project root directory `python examples/healthcare/application/TED_CT_Detection/train.py`
8+
9+
## reference
10+
11+
[Robust Classification with Convolutional Prototype Learning](https://arxiv.org/abs/1805.03438)
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
#
18+
19+
from singa import layer
20+
from singa import model
21+
import singa.tensor as tensor
22+
from singa import autograd
23+
from singa.tensor import Tensor
24+
25+
26+
class CPLayer(layer.Layer):
27+
def __init__(self, prototype_count=2, temp=10.0):
28+
super(CPLayer, self).__init__()
29+
self.prototype_count = prototype_count
30+
self.temp = temp
31+
32+
def initialize(self, x):
33+
self.feature_dim = x.shape[1]
34+
self.prototype = tensor.random(
35+
(self.feature_dim, self.prototype_count), device=x.device
36+
)
37+
38+
def forward(self, feat):
39+
self.device_check(feat, self.prototype)
40+
self.dtype_check(feat, self.prototype)
41+
42+
feat_sq = autograd.mul(feat, feat)
43+
feat_sq_sum = autograd.reduce_sum(feat_sq, axes=[1], keepdims=1)
44+
feat_sq_sum_tile = autograd.tile(feat_sq_sum, repeats=[1, self.feature_dim])
45+
46+
prototype_sq = autograd.mul(self.prototype, self.prototype)
47+
prototype_sq_sum = autograd.reduce_sum(prototype_sq, axes=[0], keepdims=1)
48+
prototype_sq_sum_tile = autograd.tile(prototype_sq_sum, repeats=feat.shape[0])
49+
50+
cross_term = autograd.matmul(feat, self.prototype)
51+
cross_term_scale = Tensor(
52+
shape=cross_term.shape, device=cross_term.device, requires_grad=False
53+
).set_value(-2)
54+
cross_term_scaled = autograd.mul(cross_term, cross_term_scale)
55+
56+
dist = autograd.add(feat_sq_sum_tile, prototype_sq_sum_tile)
57+
dist = autograd.add(dist, cross_term_scaled)
58+
59+
logits_coeff = (
60+
tensor.ones((feat.shape[0], self.prototype.shape[1]), device=feat.device)
61+
* -1.0
62+
/ self.temp
63+
)
64+
logits_coeff.requires_grad = False
65+
logits = autograd.mul(logits_coeff, dist)
66+
67+
return logits
68+
69+
def get_params(self):
70+
return {self.prototype.name: self.prototype}
71+
72+
def set_params(self, parameters):
73+
self.prototype.copy_from(parameters[self.prototype.name])
74+
75+
76+
class CPL(model.Model):
77+
78+
def __init__(
79+
self,
80+
backbone: model.Model,
81+
prototype_count=2,
82+
lamb=0.5,
83+
temp=10,
84+
label=None,
85+
prototype_weight=None,
86+
):
87+
super(CPL, self).__init__()
88+
# config
89+
self.lamb = lamb
90+
self.prototype_weight = prototype_weight
91+
self.prototype_label = label
92+
93+
# layer
94+
self.backbone = backbone
95+
self.cplayer = CPLayer(prototype_count=prototype_count, temp=temp)
96+
# optimizer
97+
self.softmax_cross_entropy = layer.SoftMaxCrossEntropy()
98+
99+
def forward(self, x):
100+
feat = self.backbone.forward(x)
101+
logits = self.cplayer(feat)
102+
return logits
103+
104+
def train_one_batch(self, x, y):
105+
out = self.forward(x)
106+
loss = self.softmax_cross_entropy(out, y)
107+
self.optimizer(loss)
108+
return out, loss
109+
110+
def set_optimizer(self, optimizer):
111+
self.optimizer = optimizer
112+
113+
114+
def create_model(backbone, prototype_count=2, lamb=0.5, temp=10.0):
115+
model = CPL(backbone, prototype_count=prototype_count, lamb=lamb, temp=temp)
116+
return model
117+
118+
119+
__all__ = ["CPL", "create_model"]
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing,
13+
# software distributed under the License is distributed on an
14+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
# KIND, either express or implied. See the License for the
16+
# specific language governing permissions and limitations
17+
# under the License.
18+
#
19+
20+
from singa import device
21+
from singa import opt
22+
from singa import tensor
23+
import argparse
24+
import numpy as np
25+
import time
26+
from PIL import Image
27+
28+
import sys
29+
30+
sys.path.append(".")
31+
print(sys.path)
32+
33+
import examples.cnn.model.cnn as cnn
34+
from examples.cnn.data import cifar10
35+
import model as cpl
36+
37+
38+
def accuracy(pred, target):
39+
# y is network output to be compared with ground truth (int)
40+
y = np.argmax(pred, axis=1)
41+
a = y == target
42+
correct = np.array(a, "int").sum()
43+
return correct
44+
45+
46+
def resize_dataset(x, image_size):
47+
num_data = x.shape[0]
48+
dim = x.shape[1]
49+
X = np.zeros(shape=(num_data, dim, image_size, image_size), dtype=np.float32)
50+
for n in range(0, num_data):
51+
for d in range(0, dim):
52+
X[n, d, :, :] = np.array(
53+
Image.fromarray(x[n, d, :, :]).resize(
54+
(image_size, image_size), Image.BILINEAR
55+
),
56+
dtype=np.float32,
57+
)
58+
return X
59+
60+
61+
def run(
62+
local_rank,
63+
max_epoch,
64+
batch_size,
65+
sgd,
66+
graph,
67+
verbosity,
68+
dist_option="plain",
69+
spars=None,
70+
):
71+
dev = device.create_cuda_gpu_on(local_rank)
72+
dev.SetRandSeed(0)
73+
np.random.seed(0)
74+
75+
train_x, train_y, val_x, val_y = cifar10.load()
76+
77+
num_channels = train_x.shape[1]
78+
data_size = np.prod(train_x.shape[1 : train_x.ndim]).item()
79+
num_classes = (np.max(train_y) + 1).item()
80+
81+
backbone = cnn.create_model(num_channels=num_channels, num_classes=num_classes)
82+
model = cpl.create_model(backbone, prototype_count=10, lamb=0.5, temp=10)
83+
84+
if backbone.dimension == 4:
85+
tx = tensor.Tensor(
86+
(batch_size, num_channels, backbone.input_size, backbone.input_size), dev
87+
)
88+
train_x = resize_dataset(train_x, backbone.input_size)
89+
val_x = resize_dataset(val_x, backbone.input_size)
90+
elif backbone.dimension == 2:
91+
tx = tensor.Tensor((batch_size, data_size), dev)
92+
np.reshape(train_x, (train_x.shape[0], -1))
93+
np.reshape(val_x, (val_x.shape[0], -1))
94+
95+
ty = tensor.Tensor((batch_size,), dev, tensor.int32)
96+
num_train_batch = train_x.shape[0] // batch_size
97+
num_val_batch = val_x.shape[0] // batch_size
98+
idx = np.arange(train_x.shape[0], dtype=np.int32)
99+
100+
model.set_optimizer(sgd)
101+
model.compile([tx], is_train=True, use_graph=graph, sequential=True)
102+
dev.SetVerbosity(verbosity)
103+
104+
for epoch in range(max_epoch):
105+
print(f"Epoch {epoch}")
106+
np.random.shuffle(idx)
107+
108+
train_correct = np.zeros(shape=[1], dtype=np.float32)
109+
test_correct = np.zeros(shape=[1], dtype=np.float32)
110+
train_loss = np.zeros(shape=[1], dtype=np.float32)
111+
112+
model.train()
113+
for b in range(num_train_batch):
114+
x = train_x[idx[b * batch_size : (b + 1) * batch_size]]
115+
y = train_y[idx[b * batch_size : (b + 1) * batch_size]]
116+
tx.copy_from_numpy(x)
117+
ty.copy_from_numpy(y)
118+
119+
out, loss = model(tx, ty, dist_option, spars)
120+
train_correct += accuracy(tensor.to_numpy(out), y)
121+
train_loss += tensor.to_numpy(loss)[0]
122+
print(
123+
"Training loss = %f, training accuracy = %f"
124+
% (train_loss, train_correct / (num_train_batch * batch_size)),
125+
flush=True,
126+
)
127+
128+
model.eval()
129+
for b in range(num_val_batch):
130+
x = val_x[b * batch_size : (b + 1) * batch_size]
131+
y = val_y[b * batch_size : (b + 1) * batch_size]
132+
133+
tx.copy_from_numpy(x)
134+
ty.copy_from_numpy(y)
135+
136+
out_test = model(tx, ty, dist_option="fp32", spars=None)
137+
test_correct += accuracy(tensor.to_numpy(out_test), y)
138+
139+
140+
if __name__ == "__main__":
141+
parser = argparse.ArgumentParser(description="Train a CPL model")
142+
parser.add_argument(
143+
"-m",
144+
"--max-epoch",
145+
default=20,
146+
type=int,
147+
help="maximum epochs",
148+
dest="max_epoch",
149+
)
150+
parser.add_argument(
151+
"-b", "--batch-size", default=64, type=int, help="batch size", dest="batch_size"
152+
)
153+
parser.add_argument(
154+
"-l",
155+
"--learning-rate",
156+
default=0.005,
157+
type=float,
158+
help="initial learning rate",
159+
dest="lr",
160+
)
161+
parser.add_argument(
162+
"-i",
163+
"--device-id",
164+
default=0,
165+
type=int,
166+
help="which GPU to use",
167+
dest="device_id",
168+
)
169+
parser.add_argument(
170+
"-g",
171+
"--disable-graph",
172+
default="True",
173+
action="store_false",
174+
help="disable graph",
175+
dest="graph",
176+
)
177+
parser.add_argument(
178+
"-v",
179+
"--log-verbosity",
180+
default=0,
181+
type=int,
182+
help="logging verbosity",
183+
dest="verbosity",
184+
)
185+
args = parser.parse_args()
186+
print(args)
187+
188+
sgd = opt.SGD(lr=args.lr, momentum=0.9, weight_decay=1e-5)
189+
run(
190+
args.device_id, args.max_epoch, args.batch_size, sgd, args.graph, args.verbosity
191+
)
Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing,
13+
# software distributed under the License is distributed on an
14+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
# KIND, either express or implied. See the License for the
16+
# specific language governing permissions and limitations
17+
# under the License.
18+
#
19+
20+
import os
21+
import json
22+
from glob import glob
23+
import numpy as np
24+
from PIL import Image
25+
26+
27+
class Compose(object):
28+
"""Compose several transforms together.
29+
30+
Args:
31+
transforms: list of transforms to compose.
32+
33+
Example:
34+
>>> transforms.Compose([
35+
>>> transforms.ToTensor(),
36+
>>> transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
37+
>>> ])
38+
39+
"""
40+
41+
def __init__(self, transforms):
42+
self.transforms = transforms
43+
44+
def forward(self, img):
45+
"""
46+
Args:
47+
img (PIL Image or numpy array): Image to be processed.
48+
49+
Returns:
50+
PIL Image or numpy array: Processed image.
51+
"""
52+
for t in self.transforms:
53+
img = t.forward(img)
54+
return img
55+
56+
def __repr__(self):
57+
format_string = self.__class__.__name__ + '('
58+
for t in self.transforms:
59+
format_string += '\n'
60+
format_string += ' {0}'.format(t)
61+
format_string += '\n)'
62+
return format_string
63+
64+
65+
class ToTensor(object):
66+
"""Convert a ``PIL Image`` to ``numpy.ndarray``.
67+
68+
Converts a PIL Image (H x W x C) in the range [0, 255] to a ``numpy.array`` of shape
69+
(C x H x W) in the range [0.0, 1.0]
70+
if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1).
71+
72+
In the other cases, tensors are returned without scaling.
73+
74+
.. note::
75+
Because the input image is scaled to [0.0, 1.0], this transformation should not be used when
76+
transforming target image masks.
77+
"""
78+
79+
def forward(self, pic):
80+
"""
81+
Args:
82+
pic (PIL Image): Image to be converted to array.
83+
84+
Returns:
85+
Array: Converted image.
86+
"""
87+
if not isinstance(pic, Image.Image):
88+
raise TypeError('pic should be PIL Image. Got {}'.format(type(pic)))
89+
90+
# Handle PIL Image
91+
mode_to_nptype = {'I': np.int32, 'I;16': np.int16, 'F': np.float32}
92+
img = np.array(pic, mode_to_nptype.get(pic.mode, np.uint8), copy=True)
93+
94+
if pic.mode == '1':
95+
img = 255 * img
96+
97+
# Put it from HWC to CHW format
98+
img = np.transpose(img, (2, 0, 1))
99+
100+
if img.dtype == np.uint8:
101+
return np.array(np.float32(img) / 255.0, dtype=np.float)
102+
else:
103+
return np.float(img)
104+
105+
def __repr__(self):
106+
return self.__class__.__name__ + '()'
107+
108+
109+
class Normalize(object):
110+
"""Normalize a ``numpy.array`` image with mean and standard deviation.
111+
112+
This transform does not support PIL Image.
113+
Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n``
114+
channels, this transform will normalize each channel of the input
115+
``numpy.array`` i.e.,
116+
``output[channel] = (input[channel] - mean[channel]) / std[channel]``
117+
118+
.. note::
119+
This transform acts out of place, i.e., it does not mutate the input array.
120+
121+
Args:
122+
mean (Sequence): Sequence of means for each channel.
123+
std (Sequence): Sequence of standard deviations for each channel.
124+
inplace(bool, optional): Bool to make this operation in-place.
125+
126+
"""
127+
128+
def __init__(self, mean, std, inplace=False):
129+
super().__init__()
130+
self.mean = mean
131+
self.std = std
132+
self.inplace = inplace
133+
134+
def forward(self, img: np.ndarray):
135+
"""
136+
Args:
137+
img (Numpy ndarray): Array image to be normalized.
138+
139+
Returns:
140+
d_res (Numpy ndarray): Normalized Tensor image.
141+
"""
142+
if not isinstance(img, np.ndarray):
143+
raise TypeError('Input img should be a numpy array. Got {}.'.format(type(img)))
144+
145+
if not img.dtype == np.float:
146+
raise TypeError('Input array should be a float array. Got {}.'.format(img.dtype))
147+
148+
if img.ndim < 3:
149+
raise ValueError('Expected array to be an array image of size (..., C, H, W). Got img.shape = '
150+
'{}.'.format(img.shape))
151+
152+
if not self.inplace:
153+
img = img.copy()
154+
155+
dtype = img.dtype
156+
mean = np.array(self.mean, dtype=dtype)
157+
std = np.array(self.std, dtype=dtype)
158+
if (std == 0).any():
159+
raise ValueError('std evaluated to zero after conversion to {}, leading to division by zero.'.format(dtype))
160+
s_res = np.subtract(img, mean[:, None, None])
161+
d_res = np.divide(s_res, std[:, None, None])
162+
163+
return d_res
164+
165+
def __repr__(self):
166+
return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
167+
168+
169+
class ClassDataset(object):
170+
"""Fetch data from file and generate batches.
171+
172+
Load data from folder as PIL.Images and convert them into batch array.
173+
174+
Args:
175+
img_folder (Str): Folder path of the training/validation images.
176+
transforms (Transform): Preprocess transforms.
177+
"""
178+
179+
def __init__(self, img_folder, transforms):
180+
super(ClassDataset, self).__init__()
181+
182+
self.img_list = list()
183+
self.transforms = transforms
184+
185+
classes = os.listdir(img_folder)
186+
for i in classes:
187+
images = glob(os.path.join(img_folder, i, "*"))
188+
for img in images:
189+
self.img_list.append((img, i))
190+
191+
def __len__(self) -> int:
192+
return len(self.img_list)
193+
194+
def __getitem__(self, index: int):
195+
img_path, label_str = self.img_list[index]
196+
img = Image.open(img_path)
197+
img = self.transforms.forward(img)
198+
label = np.array(label_str, dtype=np.int32)
199+
200+
return img, label
201+
202+
def batchgenerator(self, indexes, batch_size, data_size):
203+
"""Generate batch arrays from transformed image list.
204+
205+
Args:
206+
indexes (Sequence): current batch indexes list, e.g. [n, n + 1, ..., n + batch_size]
207+
batch_size (int):
208+
data_size (Tuple): input image size of shape (C, H, W)
209+
210+
Return:
211+
batch_x (Numpy ndarray): batch array of input images (B, C, H, W)
212+
batch_y (Numpy ndarray): batch array of ground truth lables (B,)
213+
"""
214+
batch_x = np.zeros((batch_size,) + data_size)
215+
batch_y = np.zeros((batch_size,) + (1,), dtype=np.int32)
216+
for idx, i in enumerate(indexes):
217+
sample_x, sample_y = self.__getitem__(i)
218+
batch_x[idx, :, :, :] = sample_x
219+
batch_y[idx, :] = sample_y
220+
221+
return batch_x, batch_y
222+
223+
224+
def load(dir_path="tmp/bloodmnist"):
225+
# Dataset loading
226+
train_path = os.path.join(dir_path, "train")
227+
val_path = os.path.join(dir_path, "val")
228+
cfg_path = os.path.join(dir_path, "param.json")
229+
230+
with open(cfg_path, 'r') as load_f:
231+
num_class = json.load(load_f)["num_classes"]
232+
233+
# Define pre-processing methods (transforms)
234+
transforms = Compose([
235+
ToTensor(),
236+
Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
237+
])
238+
train_dataset = ClassDataset(train_path, transforms)
239+
val_dataset = ClassDataset(val_path, transforms)
240+
return train_dataset, val_dataset, num_class

‎examples/healthcare/data/malaria.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
#
18+
19+
try:
20+
import pickle
21+
except ImportError:
22+
import cPickle as pickle
23+
24+
import numpy as np
25+
import os
26+
import sys
27+
from PIL import Image
28+
29+
30+
# need to save to specific local directories
31+
def load_train_data(dir_path="/tmp/malaria", resize_size=(128, 128)):
32+
dir_path = check_dataset_exist(dirpath=dir_path)
33+
path_train_label_1 = os.path.join(dir_path, "training_set/Parasitized")
34+
path_train_label_0 = os.path.join(dir_path, "training_set/Uninfected")
35+
train_label_1 = load_image_path(os.listdir(path_train_label_1))
36+
train_label_0 = load_image_path(os.listdir(path_train_label_0))
37+
labels = []
38+
Images = np.empty((len(train_label_1) + len(train_label_0),
39+
3, resize_size[0], resize_size[1]), dtype=np.uint8)
40+
for i in range(len(train_label_0)):
41+
image_path = os.path.join(path_train_label_0, train_label_0[i])
42+
temp_image = np.array(Image.open(image_path).resize(
43+
resize_size).convert("RGB")).transpose(2, 0, 1)
44+
Images[i] = temp_image
45+
labels.append(0)
46+
for i in range(len(train_label_1)):
47+
image_path = os.path.join(path_train_label_1, train_label_1[i])
48+
temp_image = np.array(Image.open(image_path).resize(
49+
resize_size).convert("RGB")).transpose(2, 0, 1)
50+
Images[i + len(train_label_0)] = temp_image
51+
labels.append(1)
52+
53+
Images = np.array(Images, dtype=np.float32)
54+
labels = np.array(labels, dtype=np.int32)
55+
return Images, labels
56+
57+
58+
# need to save to specific local directories
59+
def load_test_data(dir_path='/tmp/malaria', resize_size=(128, 128)):
60+
dir_path = check_dataset_exist(dirpath=dir_path)
61+
path_test_label_1 = os.path.join(dir_path, "testing_set/Parasitized")
62+
path_test_label_0 = os.path.join(dir_path, "testing_set/Uninfected")
63+
test_label_1 = load_image_path(os.listdir(path_test_label_1))
64+
test_label_0 = load_image_path(os.listdir(path_test_label_0))
65+
labels = []
66+
Images = np.empty((len(test_label_1) + len(test_label_0),
67+
3, resize_size[0], resize_size[1]), dtype=np.uint8)
68+
for i in range(len(test_label_0)):
69+
image_path = os.path.join(path_test_label_0, test_label_0[i])
70+
temp_image = np.array(Image.open(image_path).resize(
71+
resize_size).convert("RGB")).transpose(2, 0, 1)
72+
Images[i] = temp_image
73+
labels.append(0)
74+
for i in range(len(test_label_1)):
75+
image_path = os.path.join(path_test_label_1, test_label_1[i])
76+
temp_image = np.array(Image.open(image_path).resize(
77+
resize_size).convert("RGB")).transpose(2, 0, 1)
78+
Images[i + len(test_label_0)] = temp_image
79+
labels.append(1)
80+
81+
Images = np.array(Images, dtype=np.float32)
82+
labels = np.array(labels, dtype=np.int32)
83+
return Images, labels
84+
85+
86+
def load_image_path(list):
87+
new_list = []
88+
for image_path in list:
89+
if (image_path.endswith(".png") or image_path.endswith(".jpg")):
90+
new_list.append(image_path)
91+
return new_list
92+
93+
94+
def check_dataset_exist(dirpath):
95+
if not os.path.exists(dirpath):
96+
print(
97+
'Please download the malaria dataset first'
98+
)
99+
sys.exit(0)
100+
return dirpath
101+
102+
103+
def normalize(train_x, val_x):
104+
mean = [0.5339, 0.4180, 0.4460] # mean for malaria dataset
105+
std = [0.3329, 0.2637, 0.2761] # std for malaria dataset
106+
train_x /= 255
107+
val_x /= 255
108+
for ch in range(0, 2):
109+
train_x[:, ch, :, :] -= mean[ch]
110+
train_x[:, ch, :, :] /= std[ch]
111+
val_x[:, ch, :, :] -= mean[ch]
112+
val_x[:, ch, :, :] /= std[ch]
113+
return train_x, val_x
114+
115+
116+
def load(dir_path):
117+
train_x, train_y = load_train_data(dir_path=dir_path)
118+
val_x, val_y = load_test_data(dir_path=dir_path)
119+
train_x, val_x = normalize(train_x, val_x)
120+
train_y = train_y.flatten()
121+
val_y = val_y.flatten()
122+
return train_x, train_y, val_x, val_y
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
#
18+
19+
from singa import layer
20+
from singa import model
21+
from singa import tensor
22+
from singa import opt
23+
from singa import device
24+
25+
import numpy as np
26+
27+
np_dtype = {"float16": np.float16, "float32": np.float32}
28+
29+
singa_dtype = {"float16": tensor.float16, "float32": tensor.float32}
30+
31+
class CNN(model.Model):
32+
33+
def __init__(self, num_classes=10, num_channels=1):
34+
super(CNN, self).__init__()
35+
self.num_classes = num_classes
36+
self.input_size = 128
37+
self.dimension = 4
38+
self.conv1 = layer.Conv2d(num_channels, 32, 3, padding=0, activation="RELU")
39+
self.conv2 = layer.Conv2d(32, 64, 3, padding=0, activation="RELU")
40+
self.conv3 = layer.Conv2d(64, 64, 3, padding=0, activation="RELU")
41+
self.linear1 = layer.Linear(128)
42+
self.linear2 = layer.Linear(num_classes)
43+
self.pooling1 = layer.MaxPool2d(2, 2, padding=0)
44+
self.pooling2 = layer.MaxPool2d(2, 2, padding=0)
45+
self.pooling3 = layer.MaxPool2d(2, 2, padding=0)
46+
self.relu = layer.ReLU()
47+
self.flatten = layer.Flatten()
48+
self.softmax_cross_entropy = layer.SoftMaxCrossEntropy()
49+
self.sigmoid = layer
50+
51+
def forward(self, x):
52+
y = self.conv1(x)
53+
y = self.pooling1(y)
54+
y = self.conv2(y)
55+
y = self.pooling2(y)
56+
y = self.conv3(y)
57+
y = self.pooling3(y)
58+
y = self.flatten(y)
59+
y = self.linear1(y)
60+
y = self.relu(y)
61+
y = self.linear2(y)
62+
return y
63+
64+
def train_one_batch(self, x, y, dist_option, spars):
65+
out = self.forward(x)
66+
loss = self.softmax_cross_entropy(out, y)
67+
68+
if dist_option == 'plain':
69+
self.optimizer(loss)
70+
elif dist_option == 'half':
71+
self.optimizer.backward_and_update_half(loss)
72+
elif dist_option == 'partialUpdate':
73+
self.optimizer.backward_and_partial_update(loss)
74+
elif dist_option == 'sparseTopK':
75+
self.optimizer.backward_and_sparse_update(loss,
76+
topK=True,
77+
spars=spars)
78+
elif dist_option == 'sparseThreshold':
79+
self.optimizer.backward_and_sparse_update(loss,
80+
topK=False,
81+
spars=spars)
82+
return out, loss
83+
84+
def set_optimizer(self, optimizer):
85+
self.optimizer = optimizer
86+
87+
88+
class MLP(model.Model):
89+
90+
def __init__(self, perceptron_size=100, num_classes=10):
91+
super(MLP, self).__init__()
92+
self.num_classes = num_classes
93+
self.dimension = 2
94+
95+
self.relu = layer.ReLU()
96+
self.linear1 = layer.Linear(perceptron_size)
97+
self.linear2 = layer.Linear(num_classes)
98+
self.softmax_cross_entropy = layer.SoftMaxCrossEntropy()
99+
100+
def forward(self, inputs):
101+
y = self.linear1(inputs)
102+
y = self.relu(y)
103+
y = self.linear2(y)
104+
return y
105+
106+
def train_one_batch(self, x, y, dist_option, spars):
107+
out = self.forward(x)
108+
loss = self.softmax_cross_entropy(out, y)
109+
110+
if dist_option == 'plain':
111+
self.optimizer(loss)
112+
elif dist_option == 'half':
113+
self.optimizer.backward_and_update_half(loss)
114+
elif dist_option == 'partialUpdate':
115+
self.optimizer.backward_and_partial_update(loss)
116+
elif dist_option == 'sparseTopK':
117+
self.optimizer.backward_and_sparse_update(loss,
118+
topK=True,
119+
spars=spars)
120+
elif dist_option == 'sparseThreshold':
121+
self.optimizer.backward_and_sparse_update(loss,
122+
topK=False,
123+
spars=spars)
124+
return out, loss
125+
126+
def set_optimizer(self, optimizer):
127+
self.optimizer = optimizer
128+
129+
130+
def create_model(model_option='cnn', **kwargs):
131+
"""Constructs a CNN model.
132+
133+
Args:
134+
pretrained (bool): If True, returns a pre-trained model.
135+
136+
Returns:
137+
The created CNN model.
138+
"""
139+
model = CNN(**kwargs)
140+
if model_option=='mlp':
141+
model = MLP(**kwargs)
142+
143+
return model
144+
145+
146+
__all__ = ['CNN', 'MLP', 'create_model']

‎examples/malaria_cnn/train_cnn.py

Lines changed: 294 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,294 @@
1+
from singa import singa_wrap as singa
2+
from singa import device
3+
from singa import tensor
4+
from singa import opt
5+
import numpy as np
6+
import time
7+
import argparse
8+
import sys
9+
from PIL import Image
10+
11+
np_dtype = {"float16": np.float16, "float32": np.float32}
12+
13+
singa_dtype = {"float16": tensor.float16, "float32": tensor.float32}
14+
15+
16+
# Data augmentation
17+
def augmentation(x, batch_size):
18+
xpad = np.pad(x, [[0, 0], [0, 0], [4, 4], [4, 4]], 'symmetric')
19+
for data_num in range(0, batch_size):
20+
offset = np.random.randint(8, size=2)
21+
x[data_num, :, :, :] = xpad[data_num, :,
22+
offset[0]:offset[0] + x.shape[2],
23+
offset[1]:offset[1] + x.shape[2]]
24+
if_flip = np.random.randint(2)
25+
if (if_flip):
26+
x[data_num, :, :, :] = x[data_num, :, :, ::-1]
27+
return x
28+
29+
30+
# Calculate accuracy
31+
def accuracy(pred, target):
32+
# y is network output to be compared with ground truth (int)
33+
y = np.argmax(pred, axis=1)
34+
a = y == target
35+
correct = np.array(a, "int").sum()
36+
return correct
37+
38+
39+
# Data partition according to the rank
40+
def partition(global_rank, world_size, train_x, train_y, val_x, val_y):
41+
# Partition training data
42+
data_per_rank = train_x.shape[0] // world_size
43+
idx_start = global_rank * data_per_rank
44+
idx_end = (global_rank + 1) * data_per_rank
45+
train_x = train_x[idx_start:idx_end]
46+
train_y = train_y[idx_start:idx_end]
47+
48+
# Partition evaluation data
49+
data_per_rank = val_x.shape[0] // world_size
50+
idx_start = global_rank * data_per_rank
51+
idx_end = (global_rank + 1) * data_per_rank
52+
val_x = val_x[idx_start:idx_end]
53+
val_y = val_y[idx_start:idx_end]
54+
return train_x, train_y, val_x, val_y
55+
56+
57+
# Function to all reduce NUMPY accuracy and loss from multiple devices
58+
def reduce_variable(variable, dist_opt, reducer):
59+
reducer.copy_from_numpy(variable)
60+
dist_opt.all_reduce(reducer.data)
61+
dist_opt.wait()
62+
output = tensor.to_numpy(reducer)
63+
return output
64+
65+
66+
def resize_dataset(x, image_size):
67+
num_data = x.shape[0]
68+
dim = x.shape[1]
69+
X = np.zeros(shape=(num_data, dim, image_size, image_size),
70+
dtype=np.float32)
71+
for n in range(0, num_data):
72+
for d in range(0, dim):
73+
X[n, d, :, :] = np.array(Image.fromarray(x[n, d, :, :]).resize(
74+
(image_size, image_size), Image.BILINEAR),
75+
dtype=np.float32)
76+
return X
77+
78+
79+
def run(global_rank,
80+
world_size,
81+
dir_path,
82+
max_epoch,
83+
batch_size,
84+
model,
85+
data,
86+
sgd,
87+
graph,
88+
verbosity,
89+
dist_option='plain',
90+
spars=None,
91+
precision='float32'):
92+
# now CPU version only, could change to GPU device for GPU-support machines
93+
dev = device.get_default_device()
94+
dev.SetRandSeed(0)
95+
np.random.seed(0)
96+
if data == 'malaria':
97+
from data import malaria
98+
train_x, train_y, val_x, val_y = malaria.load(dir_path=dir_path)
99+
else:
100+
print(
101+
'Wrong dataset!'
102+
)
103+
sys.exit(0)
104+
105+
num_channels = train_x.shape[1]
106+
image_size = train_x.shape[2]
107+
data_size = np.prod(train_x.shape[1:train_x.ndim]).item()
108+
num_classes = (np.max(train_y) + 1).item()
109+
110+
if model == 'cnn':
111+
from model import cnn
112+
model = cnn.create_model(num_channels=num_channels,
113+
num_classes=num_classes)
114+
else:
115+
print(
116+
'Wrong model!'
117+
)
118+
sys.exit(0)
119+
120+
# For distributed training, sequential has better performance
121+
if hasattr(sgd, "communicator"):
122+
DIST = True
123+
sequential = True
124+
else:
125+
DIST = False
126+
sequential = False
127+
128+
if DIST:
129+
train_x, train_y, val_x, val_y = partition(global_rank, world_size,
130+
train_x, train_y, val_x,
131+
val_y)
132+
133+
if model.dimension == 4:
134+
tx = tensor.Tensor(
135+
(batch_size, num_channels, model.input_size, model.input_size), dev,
136+
singa_dtype[precision])
137+
elif model.dimension == 2:
138+
tx = tensor.Tensor((batch_size, data_size),
139+
dev, singa_dtype[precision])
140+
np.reshape(train_x, (train_x.shape[0], -1))
141+
np.reshape(val_x, (val_x.shape[0], -1))
142+
143+
ty = tensor.Tensor((batch_size,), dev, tensor.int32)
144+
num_train_batch = train_x.shape[0] // batch_size
145+
num_val_batch = val_x.shape[0] // batch_size
146+
idx = np.arange(train_x.shape[0], dtype=np.int32)
147+
148+
# Attach model to graph
149+
model.set_optimizer(sgd)
150+
model.compile([tx], is_train=True, use_graph=graph, sequential=sequential)
151+
dev.SetVerbosity(verbosity)
152+
153+
# Training and evaluation loop
154+
for epoch in range(max_epoch):
155+
start_time = time.time()
156+
np.random.shuffle(idx)
157+
158+
if global_rank == 0:
159+
print('Starting Epoch %d:' % (epoch))
160+
161+
# Training phase
162+
train_correct = np.zeros(shape=[1], dtype=np.float32)
163+
test_correct = np.zeros(shape=[1], dtype=np.float32)
164+
train_loss = np.zeros(shape=[1], dtype=np.float32)
165+
166+
model.train()
167+
for b in range(num_train_batch):
168+
# if b % 100 == 0:
169+
# print ("b: \n", b)
170+
# Generate the patch data in this iteration
171+
x = train_x[idx[b * batch_size:(b + 1) * batch_size]]
172+
if model.dimension == 4:
173+
x = augmentation(x, batch_size)
174+
if (image_size != model.input_size):
175+
x = resize_dataset(x, model.input_size)
176+
x = x.astype(np_dtype[precision])
177+
y = train_y[idx[b * batch_size:(b + 1) * batch_size]]
178+
179+
# Copy the patch data into input tensors
180+
tx.copy_from_numpy(x)
181+
ty.copy_from_numpy(y)
182+
183+
# Train the model
184+
out, loss = model(tx, ty, dist_option, spars)
185+
train_correct += accuracy(tensor.to_numpy(out), y)
186+
train_loss += tensor.to_numpy(loss)[0]
187+
188+
if DIST:
189+
# Reduce the evaluation accuracy and loss from multiple devices
190+
reducer = tensor.Tensor((1,), dev, tensor.float32)
191+
train_correct = reduce_variable(train_correct, sgd, reducer)
192+
train_loss = reduce_variable(train_loss, sgd, reducer)
193+
194+
if global_rank == 0:
195+
print('Training loss = %f, training accuracy = %f' %
196+
(train_loss, train_correct /
197+
(num_train_batch * batch_size * world_size)),
198+
flush=True)
199+
200+
# Evaluation phase
201+
model.eval()
202+
for b in range(num_val_batch):
203+
x = val_x[b * batch_size:(b + 1) * batch_size]
204+
if model.dimension == 4:
205+
if (image_size != model.input_size):
206+
x = resize_dataset(x, model.input_size)
207+
x = x.astype(np_dtype[precision])
208+
y = val_y[b * batch_size:(b + 1) * batch_size]
209+
tx.copy_from_numpy(x)
210+
ty.copy_from_numpy(y)
211+
out_test = model(tx)
212+
test_correct += accuracy(tensor.to_numpy(out_test), y)
213+
214+
if DIST:
215+
# Reduce the evaulation accuracy from multiple devices
216+
test_correct = reduce_variable(test_correct, sgd, reducer)
217+
218+
# Output the evaluation accuracy
219+
if global_rank == 0:
220+
print('Evaluation accuracy = %f, Elapsed Time = %fs' %
221+
(test_correct / (num_val_batch * batch_size * world_size),
222+
time.time() - start_time),
223+
flush=True)
224+
225+
dev.PrintTimeProfiling()
226+
227+
228+
if __name__ == '__main__':
229+
# Use argparse to get command config: max_epoch, model, data, etc., for single gpu training
230+
parser = argparse.ArgumentParser(
231+
description='Training using the autograd and graph.')
232+
parser.add_argument(
233+
'model',
234+
choices=['cnn'],
235+
default='cnn')
236+
parser.add_argument('data',
237+
choices=['malaria'],
238+
default='malaria')
239+
parser.add_argument('-p',
240+
choices=['float32', 'float16'],
241+
default='float32',
242+
dest='precision')
243+
parser.add_argument('-dir',
244+
'--dir-path',
245+
default="/tmp/malaria",
246+
type=str,
247+
help='the directory to store the malaria dataset',
248+
dest='dir_path')
249+
parser.add_argument('-m',
250+
'--max-epoch',
251+
default=100,
252+
type=int,
253+
help='maximum epochs',
254+
dest='max_epoch')
255+
parser.add_argument('-b',
256+
'--batch-size',
257+
default=64,
258+
type=int,
259+
help='batch size',
260+
dest='batch_size')
261+
parser.add_argument('-l',
262+
'--learning-rate',
263+
default=0.005,
264+
type=float,
265+
help='initial learning rate',
266+
dest='lr')
267+
parser.add_argument('-g',
268+
'--disable-graph',
269+
default='True',
270+
action='store_false',
271+
help='disable graph',
272+
dest='graph')
273+
parser.add_argument('-v',
274+
'--log-verbosity',
275+
default=0,
276+
type=int,
277+
help='logging verbosity',
278+
dest='verbosity')
279+
280+
args = parser.parse_args()
281+
282+
sgd = opt.SGD(lr=args.lr, momentum=0.9, weight_decay=1e-5,
283+
dtype=singa_dtype[args.precision])
284+
run(0,
285+
1,
286+
args.dir_path,
287+
args.max_epoch,
288+
args.batch_size,
289+
args.model,
290+
args.data,
291+
sgd,
292+
args.graph,
293+
args.verbosity,
294+
precision=args.precision)

‎examples/msmodel_mlp/native.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing,
13+
# software distributed under the License is distributed on an
14+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
# KIND, either express or implied. See the License for the
16+
# specific language governing permissions and limitations
17+
# under the License.
18+
#
19+
20+
from singa import tensor
21+
from singa.tensor import Tensor
22+
from singa import autograd
23+
from singa import opt
24+
import numpy as np
25+
from singa import device
26+
import argparse
27+
28+
np_dtype = {"float16": np.float16, "float32": np.float32}
29+
30+
singa_dtype = {"float16": tensor.float16, "float32": tensor.float32}
31+
32+
if __name__ == "__main__":
33+
parser = argparse.ArgumentParser()
34+
parser.add_argument('-p',
35+
choices=['float32', 'float16'],
36+
default='float32',
37+
dest='precision')
38+
parser.add_argument('-m',
39+
'--max-epoch',
40+
default=1001,
41+
type=int,
42+
help='maximum epochs',
43+
dest='max_epoch')
44+
args = parser.parse_args()
45+
46+
np.random.seed(0)
47+
48+
autograd.training = True
49+
50+
# prepare training data in numpy array
51+
52+
# generate the boundary
53+
f = lambda x: (5 * x + 1)
54+
bd_x = np.linspace(-1.0, 1, 200)
55+
bd_y = f(bd_x)
56+
57+
# generate the training data
58+
x = np.random.uniform(-1, 1, 400)
59+
y = f(x) + 2 * np.random.randn(len(x))
60+
61+
# convert training data to 2d space
62+
label = np.asarray([5 * a + 1 > b for (a, b) in zip(x, y)])
63+
data = np.array([[a, b] for (a, b) in zip(x, y)], dtype=np.float32)
64+
65+
def to_categorical(y, num_classes):
66+
"""
67+
Converts a class vector (integers) to binary class matrix.
68+
Args:
69+
y: class vector to be converted into a matrix
70+
(integers from 0 to num_classes).
71+
num_classes: total number of classes.
72+
Returns:
73+
A binary matrix representation of the input.
74+
"""
75+
y = np.array(y, dtype="int")
76+
n = y.shape[0]
77+
categorical = np.zeros((n, num_classes))
78+
categorical[np.arange(n), y] = 1
79+
return categorical
80+
81+
label = to_categorical(label, 2).astype(np.float32)
82+
print("train_data_shape:", data.shape)
83+
print("train_label_shape:", label.shape)
84+
85+
precision = singa_dtype[args.precision]
86+
np_precision = np_dtype[args.precision]
87+
88+
dev = device.create_cuda_gpu()
89+
90+
inputs = Tensor(data=data, device=dev)
91+
target = Tensor(data=label, device=dev)
92+
93+
inputs = inputs.as_type(precision)
94+
target = target.as_type(tensor.int32)
95+
96+
w0_np = np.random.normal(0, 0.1, (2, 3)).astype(np_precision)
97+
w0 = Tensor(data=w0_np,
98+
device=dev,
99+
dtype=precision,
100+
requires_grad=True,
101+
stores_grad=True)
102+
b0 = Tensor(shape=(3,),
103+
device=dev,
104+
dtype=precision,
105+
requires_grad=True,
106+
stores_grad=True)
107+
b0.set_value(0.0)
108+
109+
w1_np = np.random.normal(0, 0.1, (3, 2)).astype(np_precision)
110+
w1 = Tensor(data=w1_np,
111+
device=dev,
112+
dtype=precision,
113+
requires_grad=True,
114+
stores_grad=True)
115+
b1 = Tensor(shape=(2,),
116+
device=dev,
117+
dtype=precision,
118+
requires_grad=True,
119+
stores_grad=True)
120+
b1.set_value(0.0)
121+
122+
sgd = opt.SGD(0.05, 0.8)
123+
124+
# training process
125+
for i in range(args.max_epoch):
126+
x = autograd.matmul(inputs, w0)
127+
x = autograd.add_bias(x, b0)
128+
x = autograd.relu(x)
129+
x = autograd.matmul(x, w1)
130+
x = autograd.add_bias(x, b1)
131+
loss = autograd.softmax_cross_entropy(x, target)
132+
sgd(loss)
133+
134+
if i % 100 == 0:
135+
print("%d, training loss = " % i, tensor.to_numpy(loss)[0])

‎examples/trans/README.md

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,12 @@ optional arguments:
4646
--n_layers int transformer model n_layers default 6
4747
```
4848

49-
run the example
49+
**run the example**
50+
51+
step 1: Download the dataset to the cmn-eng directory.
52+
53+
step 2: Run the following script.
54+
5055
```
51-
python train.py --dataset cmn-2000.txt --max-epoch 100 --batch-size 32 --lr 0.01
56+
python train.py --dataset cmn-eng/cmn-2000.txt --max-epoch 100 --batch-size 32 --lr 0.01
5257
```

‎examples/trans/data.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,12 @@ def __len__(self):
5656

5757

5858
class CmnDataset:
59-
def __init__(self, path='cmn-eng/cmn.txt', shuffle=False, batch_size=32, train_ratio=0.8, random_seed=0):
59+
def __init__(self, path, shuffle=False, batch_size=32, train_ratio=0.8, random_seed=0):
6060
"""
6161
cmn dataset, download from https://www.manythings.org/anki/, contains 29909 Chinese and English translation
6262
pairs, the pair format: English + TAB + Chinese + TAB + Attribution
6363
Args:
64-
path: the path of the dataset, default 'cmn-eng/cnn.txt'
64+
path: the path of the dataset
6565
shuffle: shuffle the dataset, default False
6666
batch_size: the size of every batch, default 32
6767
train_ratio: the proportion of the training set to the total data set, default 0.8

‎examples/trans/run.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,4 @@
1818
#
1919

2020
# run this example
21-
python train.py --dataset cmn-2000.txt --max-epoch 300 --batch-size 32 --lr 0.01
21+
python train.py --dataset cmn-eng/cmn-2000.txt --max-epoch 100 --batch-size 32 --lr 0.01

‎examples/trans/train.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def run(args):
3535
np.random.seed(args.seed)
3636

3737
batch_size = args.batch_size
38-
cmn_dataset = CmnDataset(path="cmn-eng/"+args.dataset, shuffle=args.shuffle, batch_size=batch_size, train_ratio=0.8)
38+
cmn_dataset = CmnDataset(path=args.dataset, shuffle=args.shuffle, batch_size=batch_size, train_ratio=0.8)
3939

4040
print("【step-0】 prepare dataset...")
4141
src_vocab_size, tgt_vocab_size = cmn_dataset.en_vab_size, cmn_dataset.cn_vab_size
@@ -151,8 +151,7 @@ def run(args):
151151

152152
if __name__ == '__main__':
153153
parser = argparse.ArgumentParser(description="Training Transformer Model.")
154-
parser.add_argument('--dataset', choices=['cmn.txt', 'cmn-15000.txt',
155-
'cmn-2000.txt'], default='cmn-2000.txt')
154+
parser.add_argument('--dataset', default='cmn-eng/cmn-2000.txt')
156155
parser.add_argument('--max-epoch', default=100, type=int, help='maximum epochs.', dest='max_epoch')
157156
parser.add_argument('--batch-size', default=64, type=int, help='batch size', dest='batch_size')
158157
parser.add_argument('--shuffle', default=True, type=bool, help='shuffle the dataset', dest='shuffle')

0 commit comments

Comments
 (0)
Please sign in to comment.