1
- from PyQt5 .QtWidgets import QApplication , QMainWindow , QFileDialog , QMessageBox , QTableWidgetItem , QComboBox , \
2
- QTableView , QMenu , QListWidgetItem , QPushButton
1
+ from PyQt5 .QtWidgets import QApplication , QMainWindow , QFileDialog , QInputDialog , QMessageBox , QTableWidgetItem , \
2
+ QComboBox , QCheckBox , QTableView , QMenu , QListWidgetItem , QPushButton
3
3
from PyQt5 .QtCore import pyqtSignal , pyqtSlot , QUrl , QModelIndex
4
4
from PyQt5 .QtGui import QCursor
5
5
import sys
6
6
import os
7
7
import pandas as pd
8
+ import numpy as np
8
9
import ui
9
10
import qtutils
10
11
import utils
12
+ import classifier
11
13
12
14
13
15
class HelpWindow (QMainWindow , ui .Ui_HelpWindow ):
@@ -41,16 +43,30 @@ def __init__(self, parent=None):
41
43
self .options_tableWidget .setEditTriggers (QTableView .NoEditTriggers )
42
44
self .options_tableWidget .itemSelectionChanged .connect (self .on_column_selected )
43
45
self .plotQt = qtutils .PlotQt ()
44
- self .xButton .clicked .connect (lambda : self .actionset_x .trigger ())
45
- self .yButton .clicked .connect (lambda : self .actionset_y .trigger ())
46
- self .cButton .clicked .connect (lambda : self .actionset_c .trigger ())
46
+ self .xButton .clicked .connect (self .actionset_x .trigger )
47
+ self .yButton .clicked .connect (self .actionset_y .trigger )
48
+ self .cButton .clicked .connect (self .actionset_c .trigger )
49
+ self .classifier_property_pushButton .clicked .connect (self .set_property )
50
+ self .create_classifier_pushButton .clicked .connect (self .create_classifier )
51
+ self .show_test_pushButton .clicked .connect (self .evaluate_classifier )
52
+ self .train_checkBox .toggled .connect (lambda toggled : self .on_checkBoxes_toggled (self .train_checkBox , toggled ))
53
+ self .test_checkBox .toggled .connect (lambda toggled : self .on_checkBoxes_toggled (self .test_checkBox , toggled ))
54
+ self .result_checkBox .toggled .connect (lambda toggled : self .on_checkBoxes_toggled (self .result_checkBox , toggled ))
47
55
"""Initialize data"""
48
- self .data = []
56
+ self .data = dict ()
49
57
self .currentData = None
50
58
self .currentDataBinder = utils .Binder (input_names = ('c' , 'x' , 'y' ))
59
+ self .currentTestData = None
60
+ self .currentTestResult = None
61
+ self .currentTestAccuracy = None
51
62
self .fileBinder = utils .Binder (input_names = ('train' , 'test' ))
52
63
self .currentTrace = []
53
64
self .supportedFormat = ['csv files (*.csv)' , 'Microsoft Excel (*.xls *.xlsx)' ]
65
+ self .classifier : classifier .AbstractClassifier = None
66
+ self .classifier_property = ('exp' , 1.0 )
67
+ self .show_train_data = True
68
+ self .show_test_data = False
69
+ self .show_result_data = False
54
70
55
71
def resizeEvent (self , event ):
56
72
# self.root_horizontal.resize(event.size().width(), event.size().height() - 25)
@@ -61,36 +77,56 @@ def open_help_window(self):
61
77
self .helpWindow .show ()
62
78
63
79
def open_open_dialog (self ):
80
+ self .statusbar .showMessage ('正在选择文件……' )
64
81
fname , format = QFileDialog .getOpenFileName (self , caption = '打开文件' , filter = ';;' .join (self .supportedFormat ))
65
82
if format == self .supportedFormat [0 ]:
83
+ self .statusbar .showMessage ('正在打开文件……' )
66
84
data = pd .read_csv (fname )
67
85
elif format == self .supportedFormat [1 ]:
86
+ self .statusbar .showMessage ('正在打开文件……' )
68
87
data = pd .read_excel (fname )
69
88
elif fname and format :
70
89
QMessageBox .warning (self , '文件格式不正确' , fname + '的文件格式不正确!' , QMessageBox .Ok )
90
+ self .statusbar .clearMessage ()
71
91
return
72
92
else :
93
+ self .statusbar .clearMessage ()
73
94
return
74
95
75
- self .data .append (data )
76
- self .action_insert_file (os .path .split (fname )[1 ])
96
+ file_name = os .path .split (fname )[1 ]
97
+ self .data [file_name ] = data
98
+ self .insert_file (file_name )
99
+ self .statusbar .clearMessage ()
77
100
78
- def action_insert_file (self , file_name ):
101
+ def insert_file (self , file_name ):
79
102
self .fileBinder .output_names += (file_name ,)
80
103
self .files_tableWidget .setRowCount (self .files_tableWidget .rowCount () + 1 )
81
104
self .files_tableWidget .setItem (self .files_tableWidget .rowCount () - 1 , 0 , QTableWidgetItem (file_name ))
82
105
self .files_tableWidget .setItem (self .files_tableWidget .rowCount () - 1 , 1 , QTableWidgetItem ())
83
106
84
107
@pyqtSlot (int , int )
85
108
def on_file_dblclicked (self , i , j ):
86
- self .currentData = self .data [i ]
109
+ self .statusbar .showMessage ('正在加载文件……' )
110
+ current_index = list (self .data .keys ())[i ]
111
+ self .currentData = self .data [current_index ]
87
112
self .options_tableWidget .setRowCount (0 )
113
+ self .currentDataBinder .output_names = self .currentData .columns
114
+ self .actionset_train .trigger ()
88
115
for c in self .currentData .columns :
89
- self .currentDataBinder .output_names += (c ,)
90
116
self .options_tableWidget .setRowCount (self .options_tableWidget .rowCount () + 1 )
91
117
self .options_tableWidget .setItem (self .options_tableWidget .rowCount () - 1 , 0 , QTableWidgetItem (c ))
92
118
self .options_tableWidget .setItem (self .options_tableWidget .rowCount () - 1 , 1 , QTableWidgetItem ())
119
+ self .statusbar .clearMessage ()
93
120
121
+ def on_checkBoxes_toggled (self , cb : QCheckBox , toggled ):
122
+ if cb == self .train_checkBox :
123
+ self .show_train_data = toggled
124
+ elif cb == self .test_checkBox :
125
+ self .show_test_data = toggled
126
+ elif cb == self .result_checkBox :
127
+ self .show_result_data = toggled
128
+ self .plot_refresh ()
129
+
94
130
def on_file_selected (self ):
95
131
have_selected = len (self .files_tableWidget .selectedItems ()) != 0
96
132
self .menuset_current_file_as .setEnabled (have_selected )
@@ -117,8 +153,16 @@ def on_file_set(self, option, trig):
117
153
selected_item : QTableWidgetItem = self .files_tableWidget .item (selected_index .row (), 0 )
118
154
if trig :
119
155
self .fileBinder .connect (option , selected_item .text ())
156
+ if option == 'train' :
157
+ self .currentData = self .data [selected_item .text ()]
158
+ elif option == 'test' :
159
+ self .currentTestData = self .data [selected_item .text ()]
120
160
else :
121
161
self .fileBinder .disconnect (option , selected_item .text ())
162
+ if option == 'train' :
163
+ self .currentData = None
164
+ elif option == 'test' :
165
+ self .currentTestData = None
122
166
for i in range (self .files_tableWidget .rowCount ()):
123
167
col : QTableWidgetItem = self .files_tableWidget .item (i , 0 )
124
168
op : QTableWidgetItem = self .files_tableWidget .item (i , 1 )
@@ -133,17 +177,7 @@ def on_option_set(self, option, trig):
133
177
self .currentDataBinder .connect (option , selected_item .text ())
134
178
else :
135
179
self .currentDataBinder .disconnect (option , selected_item .text ())
136
- data_x = self .currentDataBinder .infer ('x' )
137
- data_y = self .currentDataBinder .infer ('y' )
138
- data_c = self .currentDataBinder .infer ('c' )
139
- if data_x and data_y :
140
- if data_c :
141
- html_path , self .currentTrace = self .plotQt .scatter (self .currentData [data_x ], self .currentData [data_y ],
142
- self .currentData [data_c ])
143
- self .plotter .load (QUrl .fromLocalFile (html_path ))
144
- else :
145
- html_path , self .currentTrace = self .plotQt .scatter (self .currentData [data_x ], self .currentData [data_y ])
146
- self .plotter .load (QUrl .fromLocalFile (html_path ))
180
+ self .plot_refresh ()
147
181
for i in range (self .options_tableWidget .rowCount ()):
148
182
col : QTableWidgetItem = self .options_tableWidget .item (i , 0 )
149
183
op : QTableWidgetItem = self .options_tableWidget .item (i , 1 )
@@ -180,6 +214,148 @@ def set_column_menu_status(self, option):
180
214
self .actionset_y .setChecked (False )
181
215
self .actionset_c .setChecked (False )
182
216
217
+ def set_property (self ):
218
+ func , ok = QInputDialog .getItem (self , 'Parameters' , '势函数' , ['exp' , 'cauchy' ])
219
+ if ok :
220
+ alpha , ok = QInputDialog .getDouble (self , 'Parameters' , 'alpha' , 1.0 )
221
+ if ok :
222
+ self .classifier_property = (func , alpha )
223
+
224
+ def plot_new (self , x , y , c = None , marker = None ):
225
+ html_path , self .currentTrace = self .plotQt .scatter (x , y , c , marker = marker )
226
+ self .plotter .load (QUrl .fromLocalFile (html_path ))
227
+
228
+ def plot_append (self , x , y , c = None , marker = None ):
229
+ html_path , self .currentTrace = self .plotQt .scatter (x , y , c , traces = self .currentTrace , marker = marker )
230
+ self .plotter .load (QUrl .fromLocalFile (html_path ))
231
+
232
+ def plot_clear (self ):
233
+ self .plotter .setHtml ('' )
234
+
235
+ def plot_refresh (self ):
236
+ self .statusbar .showMessage ('正在自动绘图……' )
237
+ data_x = self .currentDataBinder .infer ('x' )
238
+ data_y = self .currentDataBinder .infer ('y' )
239
+ data_c = self .currentDataBinder .infer ('c' )
240
+ self .plot_clear ()
241
+ self .currentTrace = None
242
+ if data_x and data_y :
243
+ if self .show_train_data :
244
+ if data_c :
245
+ self .plot_append (
246
+ self .currentData [data_x ],
247
+ self .currentData [data_y ],
248
+ self .currentData [data_c ],
249
+ marker = dict (
250
+ size = 3
251
+ )
252
+ )
253
+ else :
254
+ self .plot_append (
255
+ self .currentData [data_x ],
256
+ self .currentData [data_y ],
257
+ marker = dict (
258
+ size = 3
259
+ )
260
+ )
261
+ if self .show_test_data :
262
+ if self .currentTestData is not None :
263
+ if data_c :
264
+ self .plot_append (
265
+ self .currentTestData [data_x ],
266
+ self .currentTestData [data_y ],
267
+ self .currentTestData [data_c ],
268
+ marker = dict (
269
+ size = 6 ,
270
+ symbol = 'circle-open'
271
+ )
272
+ )
273
+ else :
274
+ self .plot_append (
275
+ self .currentTestData [data_x ],
276
+ self .currentTestData [data_y ],
277
+ marker = dict (
278
+ size = 6 ,
279
+ symbol = 'circle-open'
280
+ )
281
+ )
282
+ if self .show_result_data :
283
+ if self .currentTestResult is not None :
284
+ test_result = self .currentTestData [[data_x , data_y ]]
285
+ test_result [data_c ] = self .currentTestResult
286
+ if data_c :
287
+ self .plot_append (
288
+ test_result [data_x ],
289
+ test_result [data_y ],
290
+ test_result [data_c ],
291
+ marker = dict (
292
+ size = 6 ,
293
+ symbol = 'cross'
294
+ )
295
+ )
296
+ else :
297
+ self .plot_append (
298
+ test_result [data_x ],
299
+ test_result [data_y ],
300
+ marker = dict (
301
+ size = 6 ,
302
+ symbol = 'cross'
303
+ )
304
+ )
305
+ self .statusbar .clearMessage ()
306
+
307
+ def create_classifier (self ):
308
+ self .statusbar .showMessage ('正在创建分类器……' )
309
+ if not self .currentDataBinder .infer ('c' ):
310
+ QMessageBox .warning (self , '警告' , '请指定输入数据的类别列!' , QMessageBox .Ok )
311
+ self .statusbar .clearMessage ()
312
+ return
313
+ # index = list(set(self.currentData.columns) - {self.currentDataBinder.infer('c')})
314
+ index = list (self .currentData .columns ).copy ()
315
+ index .remove (self .currentDataBinder .infer ('c' ))
316
+ func , alpha = self .classifier_property
317
+ if func == 'exp' :
318
+ self .classifier = classifier .PotentialClassifier (
319
+ self .currentData [index ],
320
+ self .currentData [self .currentDataBinder .infer ('c' )],
321
+ potential_function = classifier .PotentialFunctions .exponential (alpha )
322
+ )
323
+ elif func == 'cauchy' :
324
+ self .classifier = classifier .PotentialClassifier (
325
+ self .currentData [index ],
326
+ self .currentData [self .currentDataBinder .infer ('c' )],
327
+ potential_function = classifier .PotentialFunctions .cauchy (alpha )
328
+ )
329
+ QMessageBox .information (self , '提示' , '设置成功!' , QMessageBox .Ok )
330
+ self .statusbar .clearMessage ()
331
+
332
+ def evaluate_classifier (self ):
333
+ if self .classifier is not None :
334
+ file_test = self .fileBinder .infer ('test' )
335
+ if file_test :
336
+ self .currentTestData = self .data [file_test ]
337
+ index = list (self .currentData .columns ).copy ()
338
+ has_answer = self .currentDataBinder .infer ('c' ) in self .currentTestData .columns
339
+ index .remove (self .currentDataBinder .infer ('c' ))
340
+ for k in index :
341
+ if k not in self .currentTestData .columns :
342
+ QMessageBox .warning (self , '警告' , '测试数据格式不完整!' , QMessageBox .Ok )
343
+ return
344
+ test_data = self .currentTestData [index ].to_numpy ()
345
+ self .currentTestResult = self .classifier .evaluate (test_data )
346
+ if has_answer :
347
+ correct : np .ndarray = self .currentTestResult == self .currentTestData [
348
+ self .currentDataBinder .infer ('c' )].to_numpy ().reshape (- 1 , 1 )
349
+ self .currentTestAccuracy = correct .sum () / self .currentTestResult .shape [0 ]
350
+ QMessageBox .information (self , '提示' , '测试成功!正确率%.3f' % self .currentTestAccuracy , QMessageBox .Ok )
351
+ else :
352
+ QMessageBox .information (self , '提示' , '测试成功!' , QMessageBox .Ok )
353
+ self .plot_refresh ()
354
+ else :
355
+ QMessageBox .warning (self , '警告' , '请先指定测试集文件!' , QMessageBox .Ok )
356
+ else :
357
+ QMessageBox .warning (self , '警告' , '请先创建分类器!' , QMessageBox .Ok )
358
+
183
359
184
360
if __name__ == '__main__' :
185
361
app = QApplication (sys .argv )
0 commit comments