Skip to content

Commit 8346e5f

Browse files
committed
make plot labels more controllable
1 parent 24f0018 commit 8346e5f

File tree

9 files changed

+159
-78
lines changed

9 files changed

+159
-78
lines changed

README.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,4 +356,4 @@
356356
},
357357
"nbformat": 4,
358358
"nbformat_minor": 4
359-
}
359+
}

pkg/build/lib/wvpy/util.py

Lines changed: 56 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -99,19 +99,20 @@ def matching_roc_area_curve(auc):
9999
q_eps = 1e-6
100100
q_low = 0
101101
q_high = 1
102-
while(q_low + q_eps < q_high):
103-
q_mid = (q_low + q_high)/2.0
104-
q_mid_area = numpy.mean(
105-
1 - (1 - (1 - eval_pts)**q_mid)**(1/q_mid))
102+
while q_low + q_eps < q_high:
103+
q_mid = (q_low + q_high) / 2.0
104+
q_mid_area = numpy.mean(1 - (1 - (1 - eval_pts) ** q_mid) ** (1 / q_mid))
106105
if q_mid_area <= auc:
107106
q_high = q_mid
108107
else:
109108
q_low = q_mid
110109
q = (q_low + q_high) / 2.0
111-
return {'auc': auc,
112-
'q': q,
113-
'x': 1 - eval_pts,
114-
'y': 1 - (1 - (1 - eval_pts)**q)**(1/q)}
110+
return {
111+
"auc": auc,
112+
"q": q,
113+
"x": 1 - eval_pts,
114+
"y": 1 - (1 - (1 - eval_pts) ** q) ** (1 / q),
115+
}
115116

116117

117118
# https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html
@@ -160,7 +161,7 @@ def plot_roc(
160161
lw = 2
161162
matplotlib.pyplot.gcf().clear()
162163
fig1, ax1 = matplotlib.pyplot.subplots()
163-
ax1.set_aspect('equal')
164+
ax1.set_aspect("equal")
164165
matplotlib.pyplot.plot(
165166
fpr,
166167
tpr,
@@ -172,10 +173,8 @@ def plot_roc(
172173
matplotlib.pyplot.plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")
173174
if ideal_curve is not None:
174175
matplotlib.pyplot.plot(
175-
ideal_curve['x'],
176-
ideal_curve['y'],
177-
linestyle='--',
178-
color=ideal_line_color)
176+
ideal_curve["x"], ideal_curve["y"], linestyle="--", color=ideal_line_color
177+
)
179178
matplotlib.pyplot.xlim([0.0, 1.0])
180179
matplotlib.pyplot.ylim([0.0, 1.0])
181180
matplotlib.pyplot.xlabel("False Positive Rate (1-Specificity)")
@@ -186,14 +185,28 @@ def plot_roc(
186185
return auc
187186

188187

189-
def dual_density_plot(probs, istrue, title="Double density plot", *, truth_target=True):
188+
def dual_density_plot(
189+
probs,
190+
istrue,
191+
title="Double density plot",
192+
*,
193+
truth_target=True,
194+
positive_label="positive examples",
195+
negative_lable="negative examples",
196+
ylable="density of examples",
197+
xlabel="model score"
198+
):
190199
"""
191200
Plot a dual density plot of numeric prediction probs against boolean istrue.
192201
193202
:param probs: vector of numeric predictions.
194203
:param istrue: truth vector
195204
:param title: title of plot
196205
:param truth_target: value considered true
206+
:param positive_label: label for positive class
207+
:param negative_lable: label for negative class
208+
:param ylable: y axis label
209+
:param xlabel: x axis label
197210
:return: None, plot produced by function call.
198211
199212
Example:
@@ -220,10 +233,10 @@ def dual_density_plot(probs, istrue, title="Double density plot", *, truth_targe
220233
preds_on_negative = [
221234
probs[i] for i in range(len(probs)) if not istrue[i] == truth_target
222235
]
223-
seaborn.kdeplot(preds_on_positive, label="positive examples", shade=True)
224-
seaborn.kdeplot(preds_on_negative, label="negative examples", shade=True)
225-
matplotlib.pyplot.ylabel("density of examples")
226-
matplotlib.pyplot.xlabel("model score")
236+
seaborn.kdeplot(preds_on_positive, label=positive_label, shade=True)
237+
seaborn.kdeplot(preds_on_negative, label=negative_lable, shade=True)
238+
matplotlib.pyplot.ylabel(ylable)
239+
matplotlib.pyplot.xlabel(xlabel)
227240
matplotlib.pyplot.title(title)
228241
matplotlib.pyplot.show()
229242

@@ -242,7 +255,15 @@ def dual_hist_plot(probs, istrue, title="Dual Histogram Plot"):
242255

243256

244257
def dual_density_plot_proba1(
245-
probs, istrue, title="Double density plot", *, truth_target=True
258+
probs,
259+
istrue,
260+
title="Double density plot",
261+
*,
262+
truth_target=True,
263+
positive_label="positive examples",
264+
negative_lable="negative examples",
265+
ylable="density of examples",
266+
xlabel="model score"
246267
):
247268
"""
248269
Plot a dual density plot of numeric prediction probs[:,1] against boolean istrue.
@@ -251,6 +272,10 @@ def dual_density_plot_proba1(
251272
:param istrue: truth target
252273
:param title: title of plot
253274
:param truth_target: value considered true
275+
:param positive_label: label for positive class
276+
:param negative_lable: label for negative class
277+
:param ylable: y axis label
278+
:param xlabel: x axis label
254279
:return: None, plot produced by call.
255280
"""
256281
istrue = [v for v in istrue]
@@ -261,10 +286,10 @@ def dual_density_plot_proba1(
261286
preds_on_negative = [
262287
probs[i, 1] for i in range(len(probs)) if not istrue[i] == truth_target
263288
]
264-
seaborn.kdeplot(preds_on_positive, label="positive examples", shade=True)
265-
seaborn.kdeplot(preds_on_negative, label="negative examples", shade=True)
266-
matplotlib.pyplot.ylabel("density of examples")
267-
matplotlib.pyplot.xlabel("model score")
289+
seaborn.kdeplot(preds_on_positive, label=positive_label, shade=True)
290+
seaborn.kdeplot(preds_on_negative, label=negative_lable, shade=True)
291+
matplotlib.pyplot.ylabel(ylable)
292+
matplotlib.pyplot.xlabel(xlabel)
268293
matplotlib.pyplot.title(title)
269294
matplotlib.pyplot.show()
270295

@@ -471,12 +496,14 @@ def threshold_statistics(
471496
# basic cumulative facts
472497
sorted_frame["count"] = sorted_frame["one"].cumsum() # predicted true so far
473498
sorted_frame["fraction"] = sorted_frame["count"] / max(1, sorted_frame["one"].sum())
474-
sorted_frame["precision"] = sorted_frame["truth"].cumsum() / sorted_frame["count"].clip(lower=1)
475-
sorted_frame["true_positive_rate"] = (
476-
sorted_frame["truth"].cumsum() / max(1, sorted_frame["truth"].sum())
499+
sorted_frame["precision"] = sorted_frame["truth"].cumsum() / sorted_frame[
500+
"count"
501+
].clip(lower=1)
502+
sorted_frame["true_positive_rate"] = sorted_frame["truth"].cumsum() / max(
503+
1, sorted_frame["truth"].sum()
477504
)
478-
sorted_frame["false_positive_rate"] = (
479-
sorted_frame["notY"].cumsum() / max(1, sorted_frame["notY"].sum())
505+
sorted_frame["false_positive_rate"] = sorted_frame["notY"].cumsum() / max(
506+
1, sorted_frame["notY"].sum()
480507
)
481508
sorted_frame["true_negative_rate"] = (
482509
sorted_frame["notY"].sum() - sorted_frame["notY"].cumsum()
@@ -486,7 +513,7 @@ def threshold_statistics(
486513
) / max(1, sorted_frame["truth"].sum())
487514

488515
# approximate cdf work
489-
sorted_frame['cdf'] = 1 - sorted_frame['fraction']
516+
sorted_frame["cdf"] = 1 - sorted_frame["fraction"]
490517

491518
# derived facts and synonyms
492519
sorted_frame["recall"] = sorted_frame["true_positive_rate"]
77 Bytes
Binary file not shown.

pkg/dist/wvpy-0.2.3.tar.gz

84 Bytes
Binary file not shown.

pkg/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
author_email="[email protected]",
1414
url="https://github.com/WinVector/wvpy",
1515
packages=setuptools.find_packages(),
16-
install_requires=["numpy", "pandas", "scikit-learn", "matplotlib", "data_algebra"],
16+
install_requires=["numpy", "pandas", "sklearn", "matplotlib", "data_algebra"],
1717
platforms=["any"],
1818
license="License :: OSI Approved :: BSD 3-clause License",
1919
description=DESCRIPTION,

pkg/tests/test_cross_plan1.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
21
import wvpy.util
32

3+
44
def test_cross_plan1():
55
n = 10
66
k = 3
@@ -10,8 +10,8 @@ def test_cross_plan1():
1010
universe = set(range(n))
1111
saw = set()
1212
for split in plan:
13-
train = split['train']
14-
test = split['test']
13+
train = split["train"]
14+
test = split["test"]
1515
assert len(train) > 0
1616
assert len(test) > 0
1717
assert len(set(train) - universe) == 0

pkg/tests/test_stats1.py

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,25 +3,52 @@
33
import data_algebra.test_util
44
import data_algebra.util
55

6+
67
def test_stats1():
78
d = pandas.DataFrame({"x": [1, 2, 3, 4, 5], "y": [False, False, True, True, False]})
89

910
stats = wvpy.util.threshold_statistics(d, model_predictions="x", yvalues="y",)
1011
# print(data_algebra.util.pandas_to_example_str(stats))
1112

12-
expect = pandas.DataFrame({
13-
'threshold': [0.999999, 1.0, 2.0, 3.0, 4.0, 5.0, 5.000001],
14-
'count': [5, 5, 4, 3, 2, 1, 0],
15-
'fraction': [1.0, 1.0, 0.8, 0.6, 0.4, 0.2, 0.0],
16-
'precision': [0.4, 0.4, 0.5, 0.6666666666666666, 0.5, 0.0, 0.0],
17-
'true_positive_rate': [1.0, 1.0, 1.0, 1.0, 0.5, 0.0, 0.0],
18-
'false_positive_rate': [1.0, 1.0, 0.6666666666666666, 0.3333333333333333, 0.3333333333333333, 0.3333333333333333, 0.0],
19-
'true_negative_rate': [0.0, 0.0, 0.3333333333333333, 0.6666666666666666, 0.6666666666666666, 0.6666666666666666, 1.0],
20-
'false_negative_rate': [0.0, 0.0, 0.0, 0.0, 0.5, 1.0, 1.0],
21-
'cdf': [0.0, 0.0, 0.19999999999999996, 0.4, 0.6, 0.8, 1.0],
22-
'recall': [1.0, 1.0, 1.0, 1.0, 0.5, 0.0, 0.0],
23-
'sensitivity': [1.0, 1.0, 1.0, 1.0, 0.5, 0.0, 0.0],
24-
'specificity': [0.0, 0.0, 0.33333333333333337, 0.6666666666666667, 0.6666666666666667, 0.6666666666666667, 1.0],
25-
})
13+
expect = pandas.DataFrame(
14+
{
15+
"threshold": [0.999999, 1.0, 2.0, 3.0, 4.0, 5.0, 5.000001],
16+
"count": [5, 5, 4, 3, 2, 1, 0],
17+
"fraction": [1.0, 1.0, 0.8, 0.6, 0.4, 0.2, 0.0],
18+
"precision": [0.4, 0.4, 0.5, 0.6666666666666666, 0.5, 0.0, 0.0],
19+
"true_positive_rate": [1.0, 1.0, 1.0, 1.0, 0.5, 0.0, 0.0],
20+
"false_positive_rate": [
21+
1.0,
22+
1.0,
23+
0.6666666666666666,
24+
0.3333333333333333,
25+
0.3333333333333333,
26+
0.3333333333333333,
27+
0.0,
28+
],
29+
"true_negative_rate": [
30+
0.0,
31+
0.0,
32+
0.3333333333333333,
33+
0.6666666666666666,
34+
0.6666666666666666,
35+
0.6666666666666666,
36+
1.0,
37+
],
38+
"false_negative_rate": [0.0, 0.0, 0.0, 0.0, 0.5, 1.0, 1.0],
39+
"cdf": [0.0, 0.0, 0.19999999999999996, 0.4, 0.6, 0.8, 1.0],
40+
"recall": [1.0, 1.0, 1.0, 1.0, 0.5, 0.0, 0.0],
41+
"sensitivity": [1.0, 1.0, 1.0, 1.0, 0.5, 0.0, 0.0],
42+
"specificity": [
43+
0.0,
44+
0.0,
45+
0.33333333333333337,
46+
0.6666666666666667,
47+
0.6666666666666667,
48+
0.6666666666666667,
49+
1.0,
50+
],
51+
}
52+
)
2653

2754
assert data_algebra.test_util.equivalent_frames(stats, expect)

pkg/wvpy.egg-info/requires.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
numpy
22
pandas
3-
scikit-learn
3+
sklearn
44
matplotlib
55
data_algebra

0 commit comments

Comments
 (0)