Skip to content

Commit e217b4d

Browse files
authored
Merge pull request #1243 from streamjoin/script-hematologic-disease-prediction
Adding the training script for the hematologic disease prediction
2 parents 4575595 + a04cb9f commit e217b4d

File tree

1 file changed

+211
-0
lines changed
  • examples/healthcare/application/Hematologic_Disease

1 file changed

+211
-0
lines changed
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,211 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing,
13+
# software distributed under the License is distributed on an
14+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
# KIND, either express or implied. See the License for the
16+
# specific language governing permissions and limitations
17+
# under the License.
18+
#
19+
20+
import time
21+
from singa import singa_wrap as singa
22+
from singa import device
23+
from singa import tensor
24+
from singa import opt
25+
import numpy as np
26+
from tqdm import tqdm
27+
import argparse
28+
import sys
29+
sys.path.append("../../..")
30+
31+
from healthcare.data import bloodmnist
32+
from healthcare.models import hematologic_net
33+
34+
# Lookup tables keyed by the precision name ('float16' / 'float32'):
# np_dtype gives the NumPy dtype used to cast input batches, and
# singa_dtype gives the matching SINGA tensor dtype for the input tensor.
np_dtype = dict(float16=np.float16, float32=np.float32)
singa_dtype = dict(float16=tensor.float16, float32=tensor.float32)
36+
37+
38+
def accuracy(pred, target):
    """Count the predictions that match the ground truth.

    Args:
        pred (Numpy ndarray): Prediction array, should be in shape (B, C)
        target (Numpy ndarray): Ground truth array, in shape (B, ) or (B, 1)

    Return:
        correct (int): Number of samples whose arg-max class equals the
            ground-truth label (divide by B to obtain the accuracy)
    """
    # Predicted class index for every sample in the batch.
    y = np.argmax(pred, axis=1)
    # Flatten the labels so the comparison is strictly element-wise.
    # Comparing y[:, None] with a (B, ) target would broadcast to a (B, B)
    # matrix and count every (prediction, label) pair that happens to match.
    labels = np.asarray(target).reshape(-1)
    correct = np.sum(y == labels)
    return correct
53+
54+
def run(dir_path,
        max_epoch,
        batch_size,
        model,
        data,
        lr,
        graph,
        verbosity,
        dist_option='plain',
        spars=None,
        precision='float32'):
    """Train the selected model on the selected dataset and evaluate it.

    Every epoch runs one pass over the training batches followed by one pass
    over the validation batches, printing the loss and accuracies.

    Args:
        dir_path (str): Directory passed to the dataset loader.
        max_epoch (int): Number of epochs to run.
        batch_size (int): Number of samples per batch; trailing samples that
            do not fill a whole batch are skipped.
        model (str): Model name; only 'cnn' is accepted.
        data (str): Dataset name; only 'bloodmnist' is accepted.
        lr (float): Learning rate handed to the Adam optimizer.
        graph (bool): Forwarded to ``model.compile`` as ``use_graph``.
        verbosity (int): Forwarded to ``dev.SetVerbosity``.
        dist_option (str): Forwarded to the model's training call.
        spars: Forwarded to the model's training call.
        precision (str): 'float32' or 'float16'; selects the dtype of the
            input batches and of the input tensor.

    Raises:
        SystemExit: With status 1 if ``data`` or ``model`` is not supported.
    """
    # Start training
    dev = device.create_cpu_device()
    dev.SetRandSeed(0)
    np.random.seed(0)

    if data == 'bloodmnist':
        train_dataset, val_dataset, num_class = bloodmnist.load(
            dir_path=dir_path)
    else:
        print('Wrong dataset!')
        # Non-zero status so that callers/shells can detect the failure.
        sys.exit(1)

    if model == 'cnn':
        model = hematologic_net.create_model(num_classes=num_class)
    else:
        print('Wrong model!')
        sys.exit(1)

    # Model configuration for CNN
    optimizer_ft = opt.Adam(lr)

    data_size = (3, model.input_size, model.input_size)
    tx = tensor.Tensor((batch_size,) + data_size, dev, singa_dtype[precision])
    ty = tensor.Tensor((batch_size,), dev, tensor.int32)

    num_train_batch = len(train_dataset) // batch_size
    num_val_batch = len(val_dataset) // batch_size
    train_idx = np.arange(len(train_dataset), dtype=np.int32)
    # The validation split needs its own indices: it is a separate dataset
    # and may differ in size from the training split.
    val_idx = np.arange(len(val_dataset), dtype=np.int32)

    # Attach model to graph
    model.set_optimizer(optimizer_ft)
    model.compile([tx], is_train=True, use_graph=graph, sequential=False)
    dev.SetVerbosity(verbosity)

    # Training and evaluation loop
    for epoch in range(max_epoch):
        print(f'Epoch {epoch}:')

        start_time = time.time()

        train_correct = np.zeros(shape=[1], dtype=np.float32)
        test_correct = np.zeros(shape=[1], dtype=np.float32)
        train_loss = np.zeros(shape=[1], dtype=np.float32)

        # Training part
        model.train()
        for b in tqdm(range(num_train_batch)):
            # Extract batch from image list
            batch_idx = train_idx[b * batch_size:(b + 1) * batch_size]
            x, y = train_dataset.batchgenerator(batch_idx,
                                                batch_size=batch_size,
                                                data_size=data_size)
            x = x.astype(np_dtype[precision])

            tx.copy_from_numpy(x)
            ty.copy_from_numpy(y)

            out, loss = model(tx, ty, dist_option, spars)
            train_correct += accuracy(tensor.to_numpy(out), y)
            train_loss += tensor.to_numpy(loss)[0]
        # Index [0] to format plain scalars instead of one-element arrays.
        print('Training loss = %f, training accuracy = %f' %
              (train_loss[0], train_correct[0] /
               (num_train_batch * batch_size)))

        # Validation part: batches must come from the validation split,
        # otherwise the reported accuracy is measured on training data.
        model.eval()
        for b in tqdm(range(num_val_batch)):
            batch_idx = val_idx[b * batch_size:(b + 1) * batch_size]
            x, y = val_dataset.batchgenerator(batch_idx,
                                              batch_size=batch_size,
                                              data_size=data_size)
            x = x.astype(np_dtype[precision])

            tx.copy_from_numpy(x)
            ty.copy_from_numpy(y)

            out = model(tx)
            test_correct += accuracy(tensor.to_numpy(out), y)

        print('Evaluation accuracy = %f, Elapsed Time = %fs' %
              (test_correct[0] / (num_val_batch * batch_size),
               time.time() - start_time))
147+
148+
149+
def build_parser():
    """Build the command-line parser for the training script.

    Returns:
        argparse.ArgumentParser: Parser exposing the positional ``model`` and
            ``data`` arguments (both optional, falling back to their
            defaults) and the options for precision, dataset directory,
            epochs, batch size, learning rate, graph usage and verbosity.
    """
    parser = argparse.ArgumentParser(
        description='Training using the autograd and graph.')
    # nargs='?' is required for a positional argument to honour its default;
    # without it the argument is mandatory and the default is never used.
    parser.add_argument('model',
                        nargs='?',
                        choices=['cnn'],
                        default='cnn')
    parser.add_argument('data',
                        nargs='?',
                        choices=['bloodmnist'],
                        default='bloodmnist')
    parser.add_argument('-p',
                        choices=['float32', 'float16'],
                        default='float32',
                        dest='precision')
    parser.add_argument('-dir',
                        '--dir-path',
                        default="/tmp/bloodmnist",
                        type=str,
                        help='the directory to store the bloodmnist dataset',
                        dest='dir_path')
    parser.add_argument('-m',
                        '--max-epoch',
                        default=100,
                        type=int,
                        help='maximum epochs',
                        dest='max_epoch')
    parser.add_argument('-b',
                        '--batch-size',
                        default=256,
                        type=int,
                        help='batch size',
                        dest='batch_size')
    parser.add_argument('-l',
                        '--learning-rate',
                        default=0.003,
                        type=float,
                        help='initial learning rate',
                        dest='lr')
    # The default is the boolean True (not the string 'True') so that
    # args.graph is a bool whether or not the flag is given.
    parser.add_argument('-g',
                        '--disable-graph',
                        default=True,
                        action='store_false',
                        help='disable graph',
                        dest='graph')
    parser.add_argument('-v',
                        '--log-verbosity',
                        default=0,
                        type=int,
                        help='logging verbosity',
                        dest='verbosity')
    return parser


if __name__ == '__main__':
    # Use argparse to get command config: max_epoch, model, data, etc.,
    # for single-device training
    args = build_parser().parse_args()

    run(args.dir_path,
        args.max_epoch,
        args.batch_size,
        args.model,
        args.data,
        args.lr,
        args.graph,
        args.verbosity,
        precision=args.precision)

0 commit comments

Comments
 (0)