Skip to content

Commit 2fca4d9

Browse files
committed
promote main GUI, complete potential classifier
1 parent eb3ea28 commit 2fca4d9

12 files changed

+3459
-67
lines changed

classifier/classifier_potential.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ class PotentialClassifier(classifier.AbstractClassifier):
1717

1818
def __init__(self, X, y, classes: iter = None,
1919
potential_function=PotentialFunctions.exponential()):
20-
self.X = X
21-
self.y = y
20+
self.X = np.array(X)
21+
self.y = np.array(y)
2222
self.W = None
2323
classes = classes or set(y)
2424
self.class_binder = utils.Binder.create_standard_binder(classes)
@@ -37,15 +37,16 @@ def initialize(self):
3737

3838
point_flag = np.zeros((data_num, 1))
3939
point_potential = np.zeros((data_num, len(self.class_binder.input_names)))
40-
for i in range(data_num):
40+
i = 0
41+
while i < data_num:
4142
if point_potential[i, self.data_type(i)] == np.max(np.squeeze(point_potential[i, :])) \
4243
and point_potential[i, self.data_type(i)] != 0:
43-
pass
44+
i += 1
4445
else:
4546
point_flag[i] += 1
4647
for j in range(data_num):
4748
point_potential[j, self.data_type(i)] += potential[i, j]
48-
i = 1
49+
i = 0
4950
self.W = point_flag
5051
return point_flag
5152

@@ -62,8 +63,8 @@ def evaluate(self, test_data: np.ndarray):
6263
for i in range(test_len):
6364
for j in range(train_len):
6465
test_potential[i, self.data_type(j)] += self.W[j] * potential[i, j]
65-
66-
test_type = np.zeros((test_len, 1))
66+
67+
test_type = np.zeros((test_len, 1), dtype=np.object)
6768
for i in range(test_len):
6869
height = np.max(test_potential[i, :])
6970
test_flag = np.argmax(test_potential[i, :])

main.py

+198-22
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
1-
from PyQt5.QtWidgets import QApplication, QMainWindow, QFileDialog, QMessageBox, QTableWidgetItem, QComboBox, \
2-
QTableView, QMenu, QListWidgetItem, QPushButton
1+
from PyQt5.QtWidgets import QApplication, QMainWindow, QFileDialog, QInputDialog, QMessageBox, QTableWidgetItem, \
2+
QComboBox, QCheckBox, QTableView, QMenu, QListWidgetItem, QPushButton
33
from PyQt5.QtCore import pyqtSignal, pyqtSlot, QUrl, QModelIndex
44
from PyQt5.QtGui import QCursor
55
import sys
66
import os
77
import pandas as pd
8+
import numpy as np
89
import ui
910
import qtutils
1011
import utils
12+
import classifier
1113

1214

1315
class HelpWindow(QMainWindow, ui.Ui_HelpWindow):
@@ -41,16 +43,30 @@ def __init__(self, parent=None):
4143
self.options_tableWidget.setEditTriggers(QTableView.NoEditTriggers)
4244
self.options_tableWidget.itemSelectionChanged.connect(self.on_column_selected)
4345
self.plotQt = qtutils.PlotQt()
44-
self.xButton.clicked.connect(lambda: self.actionset_x.trigger())
45-
self.yButton.clicked.connect(lambda: self.actionset_y.trigger())
46-
self.cButton.clicked.connect(lambda: self.actionset_c.trigger())
46+
self.xButton.clicked.connect(self.actionset_x.trigger)
47+
self.yButton.clicked.connect(self.actionset_y.trigger)
48+
self.cButton.clicked.connect(self.actionset_c.trigger)
49+
self.classifier_property_pushButton.clicked.connect(self.set_property)
50+
self.create_classifier_pushButton.clicked.connect(self.create_classifier)
51+
self.show_test_pushButton.clicked.connect(self.evaluate_classifier)
52+
self.train_checkBox.toggled.connect(lambda toggled: self.on_checkBoxes_toggled(self.train_checkBox, toggled))
53+
self.test_checkBox.toggled.connect(lambda toggled: self.on_checkBoxes_toggled(self.test_checkBox, toggled))
54+
self.result_checkBox.toggled.connect(lambda toggled: self.on_checkBoxes_toggled(self.result_checkBox, toggled))
4755
"""Initialize data"""
48-
self.data = []
56+
self.data = dict()
4957
self.currentData = None
5058
self.currentDataBinder = utils.Binder(input_names=('c', 'x', 'y'))
59+
self.currentTestData = None
60+
self.currentTestResult = None
61+
self.currentTestAccuracy = None
5162
self.fileBinder = utils.Binder(input_names=('train', 'test'))
5263
self.currentTrace = []
5364
self.supportedFormat = ['csv files (*.csv)', 'Microsoft Excel (*.xls *.xlsx)']
65+
self.classifier: classifier.AbstractClassifier = None
66+
self.classifier_property = ('exp', 1.0)
67+
self.show_train_data = True
68+
self.show_test_data = False
69+
self.show_result_data = False
5470

5571
def resizeEvent(self, event):
5672
# self.root_horizontal.resize(event.size().width(), event.size().height() - 25)
@@ -61,36 +77,56 @@ def open_help_window(self):
6177
self.helpWindow.show()
6278

6379
def open_open_dialog(self):
80+
self.statusbar.showMessage('正在选择文件……')
6481
fname, format = QFileDialog.getOpenFileName(self, caption='打开文件', filter=';;'.join(self.supportedFormat))
6582
if format == self.supportedFormat[0]:
83+
self.statusbar.showMessage('正在打开文件……')
6684
data = pd.read_csv(fname)
6785
elif format == self.supportedFormat[1]:
86+
self.statusbar.showMessage('正在打开文件……')
6887
data = pd.read_excel(fname)
6988
elif fname and format:
7089
QMessageBox.warning(self, '文件格式不正确', fname + '的文件格式不正确!', QMessageBox.Ok)
90+
self.statusbar.clearMessage()
7191
return
7292
else:
93+
self.statusbar.clearMessage()
7394
return
7495

75-
self.data.append(data)
76-
self.action_insert_file(os.path.split(fname)[1])
96+
file_name = os.path.split(fname)[1]
97+
self.data[file_name] = data
98+
self.insert_file(file_name)
99+
self.statusbar.clearMessage()
77100

78-
def action_insert_file(self, file_name):
101+
def insert_file(self, file_name):
79102
self.fileBinder.output_names += (file_name,)
80103
self.files_tableWidget.setRowCount(self.files_tableWidget.rowCount() + 1)
81104
self.files_tableWidget.setItem(self.files_tableWidget.rowCount() - 1, 0, QTableWidgetItem(file_name))
82105
self.files_tableWidget.setItem(self.files_tableWidget.rowCount() - 1, 1, QTableWidgetItem())
83106

84107
@pyqtSlot(int, int)
85108
def on_file_dblclicked(self, i, j):
86-
self.currentData = self.data[i]
109+
self.statusbar.showMessage('正在加载文件……')
110+
current_index = list(self.data.keys())[i]
111+
self.currentData = self.data[current_index]
87112
self.options_tableWidget.setRowCount(0)
113+
self.currentDataBinder.output_names = self.currentData.columns
114+
self.actionset_train.trigger()
88115
for c in self.currentData.columns:
89-
self.currentDataBinder.output_names += (c,)
90116
self.options_tableWidget.setRowCount(self.options_tableWidget.rowCount() + 1)
91117
self.options_tableWidget.setItem(self.options_tableWidget.rowCount() - 1, 0, QTableWidgetItem(c))
92118
self.options_tableWidget.setItem(self.options_tableWidget.rowCount() - 1, 1, QTableWidgetItem())
119+
self.statusbar.clearMessage()
93120

121+
def on_checkBoxes_toggled(self, cb: QCheckBox, toggled):
122+
if cb == self.train_checkBox:
123+
self.show_train_data = toggled
124+
elif cb == self.test_checkBox:
125+
self.show_test_data = toggled
126+
elif cb == self.result_checkBox:
127+
self.show_result_data = toggled
128+
self.plot_refresh()
129+
94130
def on_file_selected(self):
95131
have_selected = len(self.files_tableWidget.selectedItems()) != 0
96132
self.menuset_current_file_as.setEnabled(have_selected)
@@ -117,8 +153,16 @@ def on_file_set(self, option, trig):
117153
selected_item: QTableWidgetItem = self.files_tableWidget.item(selected_index.row(), 0)
118154
if trig:
119155
self.fileBinder.connect(option, selected_item.text())
156+
if option == 'train':
157+
self.currentData = self.data[selected_item.text()]
158+
elif option == 'test':
159+
self.currentTestData = self.data[selected_item.text()]
120160
else:
121161
self.fileBinder.disconnect(option, selected_item.text())
162+
if option == 'train':
163+
self.currentData = None
164+
elif option == 'test':
165+
self.currentTestData = None
122166
for i in range(self.files_tableWidget.rowCount()):
123167
col: QTableWidgetItem = self.files_tableWidget.item(i, 0)
124168
op: QTableWidgetItem = self.files_tableWidget.item(i, 1)
@@ -133,17 +177,7 @@ def on_option_set(self, option, trig):
133177
self.currentDataBinder.connect(option, selected_item.text())
134178
else:
135179
self.currentDataBinder.disconnect(option, selected_item.text())
136-
data_x = self.currentDataBinder.infer('x')
137-
data_y = self.currentDataBinder.infer('y')
138-
data_c = self.currentDataBinder.infer('c')
139-
if data_x and data_y:
140-
if data_c:
141-
html_path, self.currentTrace = self.plotQt.scatter(self.currentData[data_x], self.currentData[data_y],
142-
self.currentData[data_c])
143-
self.plotter.load(QUrl.fromLocalFile(html_path))
144-
else:
145-
html_path, self.currentTrace = self.plotQt.scatter(self.currentData[data_x], self.currentData[data_y])
146-
self.plotter.load(QUrl.fromLocalFile(html_path))
180+
self.plot_refresh()
147181
for i in range(self.options_tableWidget.rowCount()):
148182
col: QTableWidgetItem = self.options_tableWidget.item(i, 0)
149183
op: QTableWidgetItem = self.options_tableWidget.item(i, 1)
@@ -180,6 +214,148 @@ def set_column_menu_status(self, option):
180214
self.actionset_y.setChecked(False)
181215
self.actionset_c.setChecked(False)
182216

217+
def set_property(self):
218+
func, ok = QInputDialog.getItem(self, 'Parameters', '势函数', ['exp', 'cauchy'])
219+
if ok:
220+
alpha, ok = QInputDialog.getDouble(self, 'Parameters', 'alpha', 1.0)
221+
if ok:
222+
self.classifier_property = (func, alpha)
223+
224+
def plot_new(self, x, y, c=None, marker=None):
225+
html_path, self.currentTrace = self.plotQt.scatter(x, y, c, marker=marker)
226+
self.plotter.load(QUrl.fromLocalFile(html_path))
227+
228+
def plot_append(self, x, y, c=None, marker=None):
229+
html_path, self.currentTrace = self.plotQt.scatter(x, y, c, traces=self.currentTrace, marker=marker)
230+
self.plotter.load(QUrl.fromLocalFile(html_path))
231+
232+
def plot_clear(self):
233+
self.plotter.setHtml('')
234+
235+
def plot_refresh(self):
236+
self.statusbar.showMessage('正在自动绘图……')
237+
data_x = self.currentDataBinder.infer('x')
238+
data_y = self.currentDataBinder.infer('y')
239+
data_c = self.currentDataBinder.infer('c')
240+
self.plot_clear()
241+
self.currentTrace = None
242+
if data_x and data_y:
243+
if self.show_train_data:
244+
if data_c:
245+
self.plot_append(
246+
self.currentData[data_x],
247+
self.currentData[data_y],
248+
self.currentData[data_c],
249+
marker=dict(
250+
size=3
251+
)
252+
)
253+
else:
254+
self.plot_append(
255+
self.currentData[data_x],
256+
self.currentData[data_y],
257+
marker=dict(
258+
size=3
259+
)
260+
)
261+
if self.show_test_data:
262+
if self.currentTestData is not None:
263+
if data_c:
264+
self.plot_append(
265+
self.currentTestData[data_x],
266+
self.currentTestData[data_y],
267+
self.currentTestData[data_c],
268+
marker=dict(
269+
size=6,
270+
symbol='circle-open'
271+
)
272+
)
273+
else:
274+
self.plot_append(
275+
self.currentTestData[data_x],
276+
self.currentTestData[data_y],
277+
marker=dict(
278+
size=6,
279+
symbol='circle-open'
280+
)
281+
)
282+
if self.show_result_data:
283+
if self.currentTestResult is not None:
284+
test_result = self.currentTestData[[data_x, data_y]]
285+
test_result[data_c] = self.currentTestResult
286+
if data_c:
287+
self.plot_append(
288+
test_result[data_x],
289+
test_result[data_y],
290+
test_result[data_c],
291+
marker=dict(
292+
size=6,
293+
symbol='cross'
294+
)
295+
)
296+
else:
297+
self.plot_append(
298+
test_result[data_x],
299+
test_result[data_y],
300+
marker=dict(
301+
size=6,
302+
symbol='cross'
303+
)
304+
)
305+
self.statusbar.clearMessage()
306+
307+
def create_classifier(self):
308+
self.statusbar.showMessage('正在创建分类器……')
309+
if not self.currentDataBinder.infer('c'):
310+
QMessageBox.warning(self, '警告', '请指定输入数据的类别列!', QMessageBox.Ok)
311+
self.statusbar.clearMessage()
312+
return
313+
# index = list(set(self.currentData.columns) - {self.currentDataBinder.infer('c')})
314+
index = list(self.currentData.columns).copy()
315+
index.remove(self.currentDataBinder.infer('c'))
316+
func, alpha = self.classifier_property
317+
if func == 'exp':
318+
self.classifier = classifier.PotentialClassifier(
319+
self.currentData[index],
320+
self.currentData[self.currentDataBinder.infer('c')],
321+
potential_function=classifier.PotentialFunctions.exponential(alpha)
322+
)
323+
elif func == 'cauchy':
324+
self.classifier = classifier.PotentialClassifier(
325+
self.currentData[index],
326+
self.currentData[self.currentDataBinder.infer('c')],
327+
potential_function=classifier.PotentialFunctions.cauchy(alpha)
328+
)
329+
QMessageBox.information(self, '提示', '设置成功!', QMessageBox.Ok)
330+
self.statusbar.clearMessage()
331+
332+
def evaluate_classifier(self):
333+
if self.classifier is not None:
334+
file_test = self.fileBinder.infer('test')
335+
if file_test:
336+
self.currentTestData = self.data[file_test]
337+
index = list(self.currentData.columns).copy()
338+
has_answer = self.currentDataBinder.infer('c') in self.currentTestData.columns
339+
index.remove(self.currentDataBinder.infer('c'))
340+
for k in index:
341+
if k not in self.currentTestData.columns:
342+
QMessageBox.warning(self, '警告', '测试数据格式不完整!', QMessageBox.Ok)
343+
return
344+
test_data = self.currentTestData[index].to_numpy()
345+
self.currentTestResult = self.classifier.evaluate(test_data)
346+
if has_answer:
347+
correct: np.ndarray = self.currentTestResult == self.currentTestData[
348+
self.currentDataBinder.infer('c')].to_numpy().reshape(-1, 1)
349+
self.currentTestAccuracy = correct.sum() / self.currentTestResult.shape[0]
350+
QMessageBox.information(self, '提示', '测试成功!正确率%.3f' % self.currentTestAccuracy, QMessageBox.Ok)
351+
else:
352+
QMessageBox.information(self, '提示', '测试成功!', QMessageBox.Ok)
353+
self.plot_refresh()
354+
else:
355+
QMessageBox.warning(self, '警告', '请先指定测试集文件!', QMessageBox.Ok)
356+
else:
357+
QMessageBox.warning(self, '警告', '请先创建分类器!', QMessageBox.Ok)
358+
183359

184360
if __name__ == '__main__':
185361
app = QApplication(sys.argv)

qtutils/ParameterDialog.py

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from PyQt5.QtWidgets import QDialog, QComboBox, QFormLayout, QHBoxLayout, QLabel, QSlider, QLineEdit
2+
from PyQt5.QtGui import QDoubleValidator
3+
from PyQt5.QtCore import Qt
4+
5+
6+
class ParameterDialog(QDialog):
7+
ComboBox = 0
8+
IntSlider = 1
9+
FloatSlider = 2
10+
11+
def __init__(self, parameter_dict: dict, parent=None):
12+
super(ParameterDialog, self).__init__(parent)
13+
self.setWindowTitle('Parameters')
14+
self.setWindowModality(Qt.NonModal)
15+
self.parameters = parameter_dict
16+
self.ui()
17+
18+
def ui(self):
19+
layout = QFormLayout(self)
20+
for k, v in self.parameters.items():
21+
if v['kind'] == ParameterDialog.ComboBox:
22+
comboBox = QComboBox()
23+
comboBox.addItems(v['items'])
24+
layout.addRow(k, comboBox)
25+
elif v['kind'] == ParameterDialog.FloatSlider:
26+
f = lambda n: v['min'] + n / 100 * (v['max'] - v['min'])
27+
g = lambda x: int((x - v['min']) / (v['max'] - v['min']) * 100)
28+
slider = QSlider(Qt.Horizontal)
29+
slider.setMinimum(0)
30+
slider.setMaximum(100)
31+
slider.setSingleStep(1)
32+
slider.setValue(g(v['init']))
33+
le = QLineEdit()
34+
le.setPlaceholderText(v['init'])
35+
dv = QDoubleValidator(self)
36+
dv.setRange(v['min'], v['max'])
37+
dv.setNotation(QDoubleValidator.StandardNotation)
38+
dv.setDecimals(2)
39+
le.setValidator(dv)
40+
# slider.valueChanged.connect(lambda n: le.setText(f(n)))
41+
# le.textChanged.connect(lambda x: slider.setValue(g(x)))
42+
hlayout = QHBoxLayout()
43+
hlayout.addWidget(slider, alignment=Qt.AlignLeft)
44+
hlayout.addWidget(le, alignment=Qt.AlignRight)
45+
layout.addRow(k, hlayout)
46+
self.setLayout(layout)

0 commit comments

Comments
 (0)