
Commit 28bace8

Authored Mar 1, 2020
Add files via upload
1 parent 268b76e commit 28bace8

28 files changed: +13,310 −0 lines changed
 

‎Classifier/Decision_Visualization.py

(+132 lines)

@@ -0,0 +1,132 @@
#!/usr/bin/env python
# coding: utf-8

import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = 8,8

from scipy.ndimage.interpolation import zoom
import VGG
import ResNet
import numpy as np
import os
import gradcamutils
from DenseNet import densenet
from PIL import Image

# use this environment flag to change which GPU to use
os.environ["CUDA_VISIBLE_DEVICES"]="-1" # specify which GPU(s) to use; "-1" hides all GPUs, so this script runs on the CPU

vggModel = VGG.VGG19((352,320,1),4) # set up the model architecture
vggModel.summary()
vggModel.load_weights("/home/reza/DeepKneeExplainer/resources/old_models/balanced JSN/VGG19-JSNnewbalance-front-0.8896.h5") # load weights

denseNetModel = densenet.DenseNetImageNet161(input_shape=(352,320,1),classes=4, weights=None)
denseNetModel.summary()
denseNetModel.load_weights("/home/reza/DeepKneeExplainer/resources/old_models/balanced JSN/DenseNet161-JSNnewbalance-XRfront-0.8965.h5")

resNetModel = ResNet.ResNet34(input_shape=(352,320,1),classes=4)
resNetModel.summary()
resNetModel.load_weights("/home/reza/DeepKneeExplainer/resources/old_models/balanced JSN/ResNet34-JSNnewbalance-front-0.8395.h5")

img=Image.open("/home/reza/DeepKneeExplainer/resources/Data/XR/balancedXR ROI/front/validation/0082_R.png") # open the image you want to visualize
img=np.array(img.resize((320,352), Image.ANTIALIAS))
im = img.reshape(1,352,320,1)
gradcam=gradcamutils.grad_cam(vggModel,im,layer_name='block5_conv4') # for VGG; grad_cam also takes parameters to set the image width (W) and height (H)
gradcamplus=gradcamutils.grad_cam_plus(vggModel,im,layer_name='block5_conv4')

fig, ax = plt.subplots(nrows=1,ncols=3)
plt.subplot(131)
plt.imshow(img, cmap ='gray')
plt.title("input image")
plt.subplot(132)
plt.imshow(img, cmap ='gray')
plt.imshow(gradcam,alpha=0.5,cmap="jet")
plt.title("Grad-CAM")
plt.subplot(133)
plt.imshow(img, cmap ='gray')
plt.imshow(gradcamplus,alpha=0.45,cmap="jet")
plt.title("Grad-CAM++")
#plt.show()

plt.savefig('VGGPlots.png')

plt.figure(figsize=(7,1))

probs = [0.12, 0.30, 0.48, 0.10]
probs = np.asarray(probs, dtype=np.float32)
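# Editorial sketch (not in the original commit): the values above appear to be hard-coded
# for illustration; in practice the class probabilities would presumably come from the
# model itself, e.g.:
#   probs = vggModel.predict(im)[0]  # softmax output over the 4 JSN grades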
57+
58+
objects = ('JSN0','JSN1','JSN2','JSN3')
59+
y_pos = np.arange(len(objects))
60+
61+
for jsn in range(4):
62+
plt.text(jsn-0.2, 0.35, "%.2f" % np.round(probs[jsn],2), fontsize=10)
63+
64+
plt.bar(np.array([0, 1, 2, 3]), probs, color='red',align='center',tick_label=['JSN0','JSN1','JSN2','JSN3'],alpha=0.3)
65+
plt.ylim(0,1)
66+
plt.yticks([])
67+
plt.xticks(y_pos, objects)
68+
plt.savefig('VGGNetprob.png')
69+
70+
gradcam=gradcamutils.grad_cam(denseNetModel,im,layer_name='feature') #for DenseNet
71+
gradcamplus=gradcamutils.grad_cam_plus(denseNetModel,im,layer_name='feature')
72+
73+
fig, ax = plt.subplots(nrows=1,ncols=3)
74+
plt.subplot(131)
75+
plt.imshow(img, cmap ='gray')
76+
plt.title("input image")
77+
plt.subplot(132)
78+
plt.imshow(img, cmap ='gray')
79+
plt.imshow(gradcam,alpha=0.5,cmap="jet")
80+
plt.title("Grad-CAM")
81+
plt.subplot(133)
82+
plt.imshow(img, cmap ='gray')
83+
plt.imshow(gradcamplus,alpha=0.45,cmap="jet")
84+
plt.title("Grad-CAM++")
85+
86+
plt.savefig('DenseNetPlots.png')
87+
88+
plt.figure(figsize=(7,1))
89+
90+
probs = [0.03, 0.22, 0.51, 0.25]
91+
probs = np.asarray(probs, dtype=np.float32)
92+
93+
for jsn in range(4):
94+
plt.text(jsn-0.2, 0.35, "%.2f" % np.round(probs[jsn],2), fontsize=10)
95+
96+
plt.bar(np.array([0, 1, 2, 3]), probs, color='red',align='center',tick_label=['JSN0','JSN1','JSN2','JSN3'],alpha=0.3)
97+
plt.ylim(0,1)
98+
plt.yticks([])
99+
plt.xticks(y_pos, objects)
100+
plt.savefig('DenseNetprob.png')
101+
102+
gradcam=gradcamutils.grad_cam(resNetModel,im,layer_name='conv2d_196') #for DenseNet
103+
gradcamplus=gradcamutils.grad_cam_plus(resNetModel,im,layer_name='conv2d_196')
104+
105+
fig, ax = plt.subplots(nrows=1,ncols=3)
106+
plt.subplot(131)
107+
plt.imshow(img, cmap ='gray')
108+
plt.title("input image")
109+
plt.subplot(132)
110+
plt.imshow(img, cmap ='gray')
111+
plt.imshow(gradcam,alpha=0.5,cmap="jet")
112+
plt.title("Grad-CAM")
113+
plt.subplot(133)
114+
plt.imshow(img, cmap ='gray')
115+
plt.imshow(gradcamplus,alpha=0.45,cmap="jet")
116+
plt.title("Grad-CAM++")
117+
118+
plt.savefig('ResNetPlots.png')
119+
120+
plt.figure(figsize=(7,1))
121+
122+
probs = [0.01, 0.25, 0.55, 0.10]
123+
probs = np.asarray(probs, dtype=np.float32)
124+
125+
for jsn in range(4):
126+
plt.text(jsn-0.2, 0.35, "%.2f" % np.round(probs[jsn],2), fontsize=10)
127+
128+
plt.bar(np.array([0, 1, 2, 3]), probs, color='red',align='center',tick_label=['JSN0','JSN1','JSN2','JSN3'],alpha=0.3)
129+
plt.ylim(0,1)
130+
plt.yticks([])
131+
plt.xticks(y_pos, objects)
132+
plt.savefig('ResNetprob.png')

(+164 lines)

@@ -0,0 +1,164 @@
from __future__ import print_function
import numpy as np

np.random.seed(3768) # for reproducibility
from keras.preprocessing import sequence
from keras.utils import np_utils
from keras.models import Sequential,load_model,Model
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import *
from keras.optimizers import SGD
from random import shuffle
import time
import csv
import os
import densenet
from keras.callbacks import CSVLogger
from keras import callbacks
from PIL import Image
from keras.preprocessing.image import ImageDataGenerator
import tensorflow as tf
import keras
from sklearn.metrics import classification_report
import sklearn.metrics as sklm
from keras.callbacks import EarlyStopping
from keras.callbacks import LearningRateScheduler
from keras import initializers

def get_session():
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    return tf.Session(config=config)

# use this environment flag to change which GPU to use
#os.environ["CUDA_VISIBLE_DEVICES"] = ""
# set the modified tf session as backend in keras
keras.backend.tensorflow_backend.set_session(get_session())

def dense_to_one_hot(labels_dense, num_classes=5):
    return np.eye(num_classes)[labels_dense]
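# For example, dense_to_one_hot(2, 4) returns array([0., 0., 1., 0.]):
# np.eye(num_classes) is the identity matrix, and indexing row labels_dense
# picks out the one-hot vector for that class.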

def load():
    imgList=[]
    labelList=[]
    reader = open("/data/jiao/newlabel.csv") # label file path
    data=reader.readlines()
    files = os.listdir('/data/jiao/XR/ROI_resize/front/training/') # training path for ROIs/MRIs
    shuffle(files)
    for file in files:
        if file.endswith(".xml"): continue
        fi_d = os.path.join('/data/jiao/XR/ROI_resize/front/training/',file) # training path for ROIs/MRIs
        img=Image.open(fi_d).convert('L')
        im=np.array(img.resize((320,352), Image.ANTIALIAS))
        patient=file.split('_')[0]
        direction=file.split('_')[1].split('.')[0]
        label="q"
        for row in data:
            if patient in row.split(",")[0]:
                if "L" in direction:
                    label=row.split(",")[3]
                else:
                    label=row.split(",")[6]
                break
        if "V" in file: # for dataset balance, Grade 3 images taken from other stages are used; their file names contain the stage name, e.g. V3
            label="3"
        if "8" not in label and "9" not in label and "X" not in label and '.' not in label: # labels 8, 9 and X are not usable in our case, so those samples are skipped
            #if "." in label:
            #label='4'
            label= dense_to_one_hot(int(label),4)
            imgList.append(im)
            labelList.append(label)
    return np.array(imgList),np.array(labelList)
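# Editorial note (not in the original commit): file names are assumed to follow the
# "<patient>_<side>.<ext>" pattern seen in Decision_Visualization.py, e.g. "0082_R.png"
# -> patient "0082", direction "R". Files with "L" in the direction read column index 3
# of newlabel.csv; all other files read column index 6.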

def load_val():
    imgList=[]
    labelList=[]
    reader = open("/data/jiao/newlabel.csv") # label file path
    data=reader.readlines()
    files = os.listdir('/data/jiao/XR/ROI_resize/front/validation/') # test path for ROIs/MRIs
    for file in files:
        if file.endswith(".xml"): continue
        fi_d = os.path.join('/data/jiao/XR/ROI_resize/front/validation/',file) # test path for ROIs/MRIs
        img=Image.open(fi_d).convert('L')
        im=np.array(img.resize((320,352), Image.ANTIALIAS))
        patient=file.split('_')[0]
        direction=file.split('_')[1].split('.')[0]
        label="q"
        for row in data:
            if patient in row.split(",")[0]:
                if "L" in direction:
                    label=row.split(",")[3]
                else:
                    label=row.split(",")[6]
                break
        if "V" in file:
            label="3"
        if "8" not in label and "9" not in label and "X" not in label and '.' not in label:
            #if "." in label:
            #label='4'
            label= dense_to_one_hot(int(label),4)
            imgList.append(im)
            labelList.append(label)
    return np.array(imgList),np.array(labelList)

def load_valY(): # load validation labels as integer class indices (not one-hot)
    imgList=[]
    labelList=[]
    reader = open("/data/jiao/newlabel.csv")
    data=reader.readlines()
    files = os.listdir('/data/jiao/XR/ROI_resize/front/validation/')
    for file in files:
        if file.endswith(".xml"): continue
        patient=file.split('_')[0]
        direction=file.split('_')[1].split('.')[0]
        label="q"
        for row in data:
            if patient in row.split(",")[0]:
                if "L" in direction:
                    label=row.split(",")[3]
                else:
                    label=row.split(",")[6]
                break
        if "V" in file:
            label="3"
        if "8" not in label and "9" not in label and "X" not in label and '.' not in label:
            #if "." in label:
            #label='4'
            labelList.append(int(label))
    return np.array(labelList)


batch_size=32
# Here you can switch DenseNet between the 121, 161, 169 and 201 variants, or plug in your own
# architecture. The available settings are: input_shape=None, bottleneck=True, reduction=0.5,
# dropout_rate=0.0, weight_decay=1e-6, include_top=True, weights='imagenet', input_tensor=None,
# classes=1000, activation='softmax'.
model = densenet.DenseNetImageNet201(input_shape=(352,320,1),classes=4, weights=None)
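# Editorial sketch (not in the original commit): based on the settings listed above, another
# variant could presumably be swapped in the same way, e.g.:
#   model = densenet.DenseNetImageNet121(input_shape=(352,320,1), classes=4, weights=None,
#                                        dropout_rate=0.0, weight_decay=1e-6, activation='softmax')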
sgd = SGD(lr=0.01, decay=1e-6, momentum=0.95, nesterov=True)
model.compile(optimizer=sgd, loss='mse',metrics=['accuracy'])

datagen = ImageDataGenerator(
    featurewise_center=True,
    samplewise_center=False, # set each sample mean to 0
    featurewise_std_normalization=True,
    samplewise_std_normalization=False)
X_train, Y_train = load()
X_test, Y_test = load_val()
X_train = X_train.reshape( len(X_train), len(X_train[0]), len(X_train[0][0]),1)
X_test = X_test.reshape( len(X_test), len(X_test[0]), len(X_test[0][0]),1)
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255
X_test /= 255
datagen.fit(X_train)
for i in range(len(X_test)):
    X_test[i] = datagen.standardize(X_test[i])
earlystop=EarlyStopping(monitor='val_acc', min_delta=0, patience=300, verbose=1, mode='auto', restore_best_weights=True)
history = model.fit_generator(datagen.flow(X_train, Y_train,batch_size=batch_size),steps_per_epoch=32,epochs=4096,shuffle=True,validation_data=(X_test, Y_test), verbose=1,callbacks=[earlystop])
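# Editorial note (not in the original commit): datagen.fit(X_train) computes the featurewise
# mean/std on the training set, and the loop above applies those same training statistics to
# X_test via datagen.standardize. With restore_best_weights=True, if early stopping triggers,
# the evaluation below reflects the weights from the epoch with the best val_acc rather than
# the last epoch.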
score, acc = model.evaluate(X_test,Y_test,batch_size=batch_size)
print("Accuracy:",acc)
if acc>0.6: # if the accuracy is higher than 60%, the model weights are saved
    model.save_weights("DenseNet-JSNnew-front.h5")
y_pred = model.predict(X_test)
Y_predict = y_pred.argmax(axis=-1)
f=open('DenseNetRESULTS-JSNnew-front.txt','a') # append the performance report
f.write(classification_report(load_valY(), Y_predict))
f.write(str(sklm.cohen_kappa_score(load_valY(), Y_predict))+","+str(acc)+","+str(score)+"\n")
print(classification_report(load_valY(), Y_predict))
