|
7 | 7 | import sklearn.metrics |
8 | 8 | import itertools |
9 | 9 | import pandas |
| 10 | +import math |
| 11 | +from data_algebra.cdata import * |
10 | 12 |
|
11 | 13 |
|
12 | 14 | # noinspection PyPep8Naming |
@@ -252,3 +254,61 @@ def perm_score_var_once(): |
252 | 254 | vf["importance_dev"] = [di[1] for di in stats] |
253 | 255 | vf.sort_values(by=["importance"], ascending=False, inplace=True) |
254 | 256 | return vf |
| 257 | + |
| 258 | + |
def threshold_statistics(d, model_predictions, yvalues):
    """Compute classifier threshold statistics from scored data.

    Rows are sorted by prediction score (descending); each row's statistics
    describe the classifier obtained by thresholding at that row's score,
    i.e. predicting "true" for that row and every higher-scoring row.

    :param d: pandas.DataFrame holding predictions and outcomes
    :param model_predictions: name of the numeric prediction/score column
    :param yvalues: name of the outcome column; assumed 0/1 or boolean
                    (it is summed and averaged) -- TODO confirm with callers
    :return: new pandas.DataFrame with columns threshold, precision,
             enrichment, recall, sensitivity, specificity,
             false_positive_rate; one row per input row, highest score first.
             The input frame is not modified.
    """
    sorted_frame = d.sort_values([model_predictions], ascending=[False], inplace=False)
    sorted_frame.reset_index(inplace=True, drop=True)

    sorted_frame["precision"] = sorted_frame[yvalues].cumsum()  # predicted true AND true (so far)
    sorted_frame["running"] = sorted_frame.index + 1  # predicted true so far
    sorted_frame["precision"] = sorted_frame["precision"] / sorted_frame["running"]
    sorted_frame["recall"] = sorted_frame[yvalues].cumsum() / sorted_frame[yvalues].sum()  # denom = total true
    # enrichment: precision relative to the base rate of positives
    sorted_frame["enrichment"] = sorted_frame["precision"] / sorted_frame[yvalues].mean()
    sorted_frame["sensitivity"] = sorted_frame["recall"]

    sorted_frame["notY"] = 1 - sorted_frame[yvalues]  # falses

    # num = predicted true AND false, denom = total false
    sorted_frame["false_positive_rate"] = sorted_frame["notY"].cumsum() / sorted_frame["notY"].sum()
    sorted_frame["specificity"] = 1 - sorted_frame["false_positive_rate"]

    # BUG FIX: was rename(columns={"prediction": "threshold"}), which only
    # worked when the prediction column happened to be named "prediction";
    # for any other name the rename was a no-op and the .loc selection below
    # raised KeyError. Rename the actual prediction column instead.
    sorted_frame.rename(columns={model_predictions: "threshold"}, inplace=True)
    columns_I_want = ["threshold", "precision", "enrichment", "recall", "sensitivity", "specificity",
                      "false_positive_rate"]
    sorted_frame = sorted_frame.loc[:, columns_I_want].copy()
    return sorted_frame
| 281 | + |
| 282 | + |
def threshold_plot(d, pred_var, truth_var, truth_target,
                   threshold_range=(-math.inf, math.inf),
                   plotvars=("precision", "recall"),
                   title="Measures as a function of threshold"
                   ):
    """Plot classifier measures as curves over the decision threshold.

    Builds threshold statistics via threshold_statistics() and renders one
    facet per requested measure with seaborn, sharing the threshold x-axis.
    Displays the plot as a side effect; returns None.

    NOTE(review): relies on module-level seaborn / matplotlib.pyplot imports
    not visible in this chunk -- confirm they exist at file top.

    :param d: pandas.DataFrame holding predictions and outcomes
    :param pred_var: name of the numeric prediction/score column
    :param truth_var: name of the outcome column
    :param truth_target: value of truth_var treated as the positive class
    :param threshold_range: (low, high) inclusive range of thresholds to plot
    :param plotvars: sequence of measure column names to facet on
                     (from threshold_statistics's output columns)
    :param title: overall figure title
    :return: None
    """
    frame = d.copy()
    # recode outcome to boolean positive-class indicator
    frame["outcol"] = frame[truth_var] == truth_target

    prt_frame = threshold_statistics(frame, pred_var, "outcol")

    # keep only thresholds inside the requested (inclusive) range
    selector = (threshold_range[0] <= prt_frame.threshold) & \
               (prt_frame.threshold <= threshold_range[1])
    to_plot = prt_frame.loc[selector, :]

    # pivot wide measure columns into long (threshold, measure, value) form
    # so seaborn can facet one row per measure
    reshaper = RecordMap(
        blocks_out=RecordSpecification(
            pandas.DataFrame({
                'measure': plotvars,
                'value': plotvars,
            }),
            record_keys=['threshold']
        )
    )
    prtlong = reshaper.transform(to_plot)
    # (removed dead `prtlong.head()` call -- notebook leftover, result discarded)

    grid = seaborn.FacetGrid(prtlong, row="measure", row_order=plotvars, aspect=2, sharey=False)
    grid = grid.map(matplotlib.pyplot.plot, "threshold", "value")
    matplotlib.pyplot.subplots_adjust(top=0.9)  # leave room for suptitle
    grid.fig.suptitle(title)
    matplotlib.pyplot.show()
0 commit comments