
Commit a932f50

committed
Compare classification models: tfidf/word2vec + lr/svm
1 parent f5d1da5 commit a932f50

File tree

3 files changed: +601 −30 lines changed


jupyter/codeTest.ipynb

+173 −30
@@ -22,7 +22,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 1,
+    "execution_count": 23,
     "metadata": {
      "collapsed": false
     },
@@ -46,7 +46,7 @@
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "Loading model cost 1.042 seconds.\n",
+      "Loading model cost 1.214 seconds.\n",
       "Prefix dict has been built succesfully.\n"
      ]
     },
@@ -63,10 +63,69 @@
     {
      "data": {
       "text/plain": [
-       "{'doc_title': ['南', '天', '信息', '管理层', '增持', '86', '万股'], 'doc_type': 'IT'}"
+       "{'doc_content': ['本期',\n",
+       "  '节目',\n",
+       "  '内容',\n",
+       "  '介绍',\n",
+       "  '关注',\n",
+       "  '机动车',\n",
+       "  '驾驶证',\n",
+       "  '申领',\n",
+       "  '和',\n",
+       "  '使用',\n",
+       "  '规定',\n",
+       "  '搜狐',\n",
+       "  '汽车',\n",
+       "  '广播',\n",
+       "  '诚邀',\n",
+       "  '全国',\n",
+       "  '各地',\n",
+       "  '强势',\n",
+       "  '电台',\n",
+       "  '真情',\n",
+       "  '加盟',\n",
+       "  '携手',\n",
+       "  '打造',\n",
+       "  '中国',\n",
+       "  '汽车',\n",
+       "  '广播',\n",
+       "  '最强',\n",
+       "  '容',\n",
+       "  '把脉',\n",
+       "  '全球',\n",
+       "  '汽车产业',\n",
+       "  '风向标',\n",
+       "  '引领',\n",
+       "  '时尚',\n",
+       "  '汽车',\n",
+       "  '消费',\n",
+       "  '的',\n",
+       "  '参考书',\n",
+       "  '搜狐',\n",
+       "  '汽车',\n",
+       "  '广播',\n",
+       "  '车旅',\n",
+       "  '杂志',\n",
+       "  '服务',\n",
+       "  '我们',\n",
+       "  '的',\n",
+       "  '汽车',\n",
+       "  '生活',\n",
+       "  '加盟',\n",
+       "  '热线',\n",
+       "  '13381202220',\n",
+       "  '010',\n",
+       "  '62729907',\n",
+       "  '独家',\n",
+       "  '出品',\n",
+       "  '搜狐',\n",
+       "  '汽车',\n",
+       "  '事业部'],\n",
+       " 'doc_title': ['搜狐', '汽车', '广播', '车旅', '杂志', '2012', '06', '20', '期'],\n",
+       " 'doc_type': '汽车'}"
       ]
      },
-     "execution_count": 1,
+     "execution_count": 23,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -97,7 +156,7 @@
      "    return tokens\n",
      "\n",
      "# Tokenize the news titles, producing the segmented news data\n",
-     "tokenSougouNews = featurelize(sougouNews, fields=['doc_title'], analyzer=Analyzer())\n",
+     "tokenSougouNews = featurelize(sougouNews, fields=['doc_title', 'doc_content'], analyzer=Analyzer())\n",
      "print('完成对新闻标题的分词')\n",
      "\n",
      "# Dump the tokenized results to a local file\n",
@@ -159,14 +218,14 @@
     "source": [
      "import pickle\n",
      "\n",
+     " \n",
+     "with open('tokenSougouNews-test.pk', 'rb') as f:\n",
+     "    testData = pickle.load(f)\n",
      "with open('tokenSougouNews-train.pk', 'rb') as f:\n",
      "    trainData = pickle.load(f)\n",
      "trainX = [dict(doc_title=' '.join(d['doc_title'])) for d in trainData]\n",
      "trainY = [d['doc_type'] for d in trainData]\n",
      "print('train size=%d' % (len(trainX)))\n",
-     " \n",
-     "with open('tokenSougouNews-test.pk', 'rb') as f:\n",
-     "    testData = pickle.load(f)\n",
      "testX = [dict(doc_title=' '.join(d['doc_title'])) for d in testData]\n",
      "testY = [d['doc_type'] for d in testData]\n",
      "print('test size=%d' % (len(testX)))"
@@ -227,7 +286,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 11,
+    "execution_count": 16,
     "metadata": {
      "collapsed": false
     },
@@ -244,6 +303,7 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
+      "tfidf+lr: trainAcc=0.913848, testAcc=0.869774\n",
       "tfidf+multiNB: trainAcc=0.867886, testAcc=0.821235\n",
       "tfidf+svm: trainAcc=0.981018, testAcc=0.895906\n"
      ]
@@ -252,23 +312,33 @@
     "source": [
      "from sklearn.naive_bayes import MultinomialNB, BernoulliNB\n",
      "from sklearn.svm import LinearSVC\n",
+     "from sklearn.linear_model import LogisticRegression\n",
      "from sklearn.pipeline import Pipeline\n",
      "from sklearn.metrics import accuracy_score\n",
      "\n",
+     "# tfidf + lr\n",
+     "lrClf = Pipeline([('tfidfVectorizor', TfidfVectorizor(['doc_title'])),\n",
+     "                  ('lr', LogisticRegression())])\n",
+     "lrClf.fit(trainX, trainY)\n",
+     "\n",
+     "trainAcc = accuracy_score(trainY, lrClf.predict(trainX))\n",
+     "testAcc = accuracy_score(testY, lrClf.predict(testX))\n",
+     "print('tfidf+lr: trainAcc=%f, testAcc=%f' % (trainAcc, testAcc))\n",
+     "\n",
+     "# tfidf + nb\n",
      "nbClf = Pipeline([('tfidfVectorizor', TfidfVectorizor(['doc_title'])),\n",
      "                  ('multinomialNB', MultinomialNB())])\n",
      "nbClf.fit(trainX, trainY)\n",
      "\n",
-     "# compute the error\n",
      "trainAcc = accuracy_score(trainY, nbClf.predict(trainX))\n",
      "testAcc = accuracy_score(testY, nbClf.predict(testX))\n",
      "print('tfidf+multiNB: trainAcc=%f, testAcc=%f' % (trainAcc, testAcc))\n",
      "\n",
+     "# tfidf + svm\n",
      "svmClf = Pipeline([('tfidfVectorizor', TfidfVectorizor(['doc_title'])),\n",
      "                  ('svm', LinearSVC())])\n",
      "svmClf.fit(trainX, trainY)\n",
      "\n",
-     "# compute the error\n",
      "trainAcc = accuracy_score(trainY, svmClf.predict(trainX))\n",
      "testAcc = accuracy_score(testY, svmClf.predict(testX))\n",
      "print('tfidf+svm: trainAcc=%f, testAcc=%f' % (trainAcc, testAcc))"
@@ -285,7 +355,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 12,
+    "execution_count": 17,
     "metadata": {
      "collapsed": false
     },
@@ -300,10 +370,10 @@
     {
      "data": {
       "text/plain": [
-       "<__main__.Doc2VecVectorizor at 0x1fd4fd75710>"
+       "<__main__.Doc2VecVectorizor at 0x1fd5007cf98>"
       ]
      },
-     "execution_count": 12,
+     "execution_count": 17,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -312,7 +382,7 @@
      "from gensim.models import Word2Vec\n",
      "\n",
      "class Doc2VecVectorizor(object):\n",
-     "    def __init__(self, fields, size=200, window=3, min_count=1):\n",
+     "    def __init__(self, fields, size=100, window=3, min_count=1):\n",
      "        self.fields = fields\n",
      "        self.size = size\n",
      "        self.window = window\n",
@@ -352,27 +422,27 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 13,
+    "execution_count": 18,
     "metadata": {
      "collapsed": false
     },
     "outputs": [
      {
       "data": {
        "text/plain": [
-        "[('老年人', 0.974733293056488),\n",
-        " ('日内瓦', 0.9729659557342529),\n",
-        " ('国际足球', 0.9727454781532288),\n",
-        " ('专访', 0.9721158146858215),\n",
-        " ('搜狐', 0.9709295034408569),\n",
-        " ('第九届', 0.9708148241043091),\n",
-        " ('舞蹈节', 0.9674550294876099),\n",
-        " ('文化周', 0.9654016494750977),\n",
-        " ('日程安排', 0.9652378559112549),\n",
-        " ('作文题', 0.9637157320976257)]"
+        "[('舞蹈节', 0.9734185934066772),\n",
+        " ('专访', 0.9699808955192566),\n",
+        " ('老年人', 0.9686485528945923),\n",
+        " ('日内瓦', 0.9671200513839722),\n",
+        " ('搜狐', 0.9666953086853027),\n",
+        " ('看车', 0.963032603263855),\n",
+        " ('国际足球', 0.9596318006515503),\n",
+        " ('广汽传祺', 0.9582968950271606),\n",
+        " ('篮联', 0.9582201242446899),\n",
+        " ('海河', 0.9577779173851013)]"
        ]
       },
-      "execution_count": 13,
+      "execution_count": 18,
       "metadata": {},
       "output_type": "execute_result"
      }
@@ -381,6 +451,28 @@
      "doc2vec.word2vec.wv.similar_by_word(word='体育', topn=10)"
     ]
    },
+   {
+    "cell_type": "code",
+    "execution_count": 20,
+    "metadata": {
+     "collapsed": false
+    },
+    "outputs": [
+     {
+      "data": {
+       "text/plain": [
+        "100"
+       ]
+      },
+      "execution_count": 20,
+      "metadata": {},
+      "output_type": "execute_result"
+     }
+    ],
+    "source": [
+     "doc2vec.word2vec.vector_size"
+    ]
+   },
    {
     "cell_type": "markdown",
     "metadata": {},
@@ -390,7 +482,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 15,
+    "execution_count": 19,
     "metadata": {
      "collapsed": false
     },
@@ -408,7 +500,7 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "doc2vec+svm: trainAcc=0.706894, testAcc=0.709253\n"
+      "doc2vec+svm: trainAcc=0.705841, testAcc=0.708672\n"
      ]
     }
    ],
@@ -424,6 +516,57 @@
      "testAcc = accuracy_score(testY, svmClf.predict(testX))\n",
      "print('doc2vec+svm: trainAcc=%f, testAcc=%f' % (trainAcc, testAcc))"
     ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "### tf-idf-weighted word2vec + classification\n",
+     "#### tf-idf-weighted word2vec"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {
+     "collapsed": true
+    },
+    "outputs": [],
+    "source": [
+     "from gensim.models import Word2Vec\n",
+     "\n",
+     "class Doc2VecVectorizor(object):\n",
+     "    def __init__(self, tfidfVectorizor, word2vecVectorizor, fields):\n",
+     "        self.tfidfVectorizor = tfidfVectorizor\n",
+     "        self.word2vecVectorizor = word2vecVectorizor\n",
+     "        self.fields = fields\n",
+     "    \n",
+     "    def fit(self, X, y=None):\n",
+     "        return self\n",
+     "    \n",
+     "    def transform(self, X):\n",
+     "        \"\"\"\n",
+     "        Compute the feature vector of each document:\n",
+     "        1. For each field, look up every word's tf-idf weight and word vector, and take the weighted average of the word vectors as that field's vector.\n",
+     "        2. Flatten the per-field vectors into one wide vector, which is the document's feature vector.\n",
+     "        \"\"\"\n",
+     "        return np.array([self.__doc2vec(x) for x in X])\n",
+     "    \n",
+     "    def __sentence2vec(self, sentence):\n",
+     "        if len(sentence.strip()) == 0:\n",
+     "            return np.zeros(self.size)\n",
+     "        vectors = [self.word2vecVectorizor[word]*self.tfidfVectorizor.transform() \n",
+     "                   if word in self.word2vecVectorizor else np.zeros(self.size) \n",
+     "                   for word in sentence.split()]\n",
+     "        return np.mean(vectors, axis=0)\n",
+     "    \n",
+     "    def __doc2vec(self, doc):\n",
+     "        vectors = np.array([self.__sentence2vec(doc[field]) for field in self.fields])\n",
+     "        return vectors.flatten()\n",
+     "    \n",
+     "doc2vec = Doc2VecVectorizor(fields=['doc_title'])\n",
+     "doc2vec.fit(trainX)"
+    ]
   }
  ],
  "metadata": {
