Skip to content

Commit 7e46468

Browse files
committed
Update cnn-svm.py
1 parent 4bbfa75 commit 7e46468

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

DeepLearning Tutorials/dive_into_keras/cnn-svm.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import theano
1010
from sklearn.ensemble import RandomForestClassifier
1111
from sklearn.svm import SVC
12+
from sklearn.preprocessing import MinMaxScaler
1213
from data import load_data
1314
import random
1415

@@ -46,12 +47,15 @@ def rf(traindata,trainlabel,testdata,testlabel):
4647
(trainlabel,testlabel) = (label[0:30000],label[30000:])
4748
#use origin_model to predict testdata
4849
origin_model = cPickle.load(open("model.pkl","rb"))
50+
#print(origin_model.layers)
4951
pred_testlabel = origin_model.predict_classes(testdata,batch_size=1, verbose=1)
5052
num = len(testlabel)
5153
accuracy = len([1 for i in range(num) if testlabel[i]==pred_testlabel[i]])/float(num)
5254
print(" Origin_model Accuracy:",accuracy)
5355
#define theano funtion to get output of FC layer
54-
get_feature = theano.function([origin_model.layers[0].input],origin_model.layers[11].get_output(train=False),allow_input_downcast=False)
56+
get_feature = theano.function([origin_model.layers[0].input],origin_model.layers[9].output,allow_input_downcast=False)
5557
feature = get_feature(data)
5658
#train svm using FC-layer feature
59+
scaler = MinMaxScaler()
60+
feature = scaler.fit_transform(feature)
5761
svc(feature[0:30000],label[0:30000],feature[30000:],label[30000:])

0 commit comments

Comments
 (0)