Skip to content

Commit ca79d16

Browse files
committed
Add plotLIFT and plotCGain
1 parent d6d71a3 commit ca79d16

File tree

7 files changed

+250
-2
lines changed

7 files changed

+250
-2
lines changed

DESCRIPTION

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ Imports: car,
1717
ggplot2,
1818
hnp,
1919
plotROC,
20+
ROCR,
2021
tseries
2122
RoxygenNote: 6.0.1
2223
Suggests: aods3,

NAMESPACE

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
# Generated by roxygen2: do not edit by hand
22

3-
S3method(plot,modelAudit)
43
export(audit)
4+
export(plot.modelAudit)
55
export(plotACF)
66
export(plotAutocorrelation)
7+
export(plotCGains)
78
export(plotCook)
89
export(plotHalfNormal)
10+
export(plotLIFT)
911
export(plotREC)
1012
export(plotROC)
1113
export(plotRROC)
@@ -17,6 +19,8 @@ export(scoreDW)
1719
export(scoreGQ)
1820
export(scoreHalfNormal)
1921
export(scoreRuns)
22+
import(ROCR)
23+
import(dplyr)
2024
import(ggplot2)
2125
import(plotROC)
2226
importFrom(car,durbinWatsonTest)

R/plotCGains.R

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
#' @title Cumulative Gains Chart
2+
#'
3+
#' @description Cumulative Gains Chart is a plot of the true positive rate against the rate of positive prediction for the different thresholds.
4+
#' It is useful for measuring and comparing the accuracy of classifiers.
5+
#' @param object An object of class modelAudit
6+
#' @param newdata optionally, a data frame in which to look for variables with which to plot CGains curve. If omitted, the data used to build model will be used.
7+
#' @param newy optionally, required if newdata used. Response vector for new data.
8+
#' @param ... other modelAudit objects to be plotted together
9+
#'
10+
#' @return ggplot object
11+
#'
12+
#' @seealso \code{\link{plot.modelAudit}}
13+
#'
14+
#' @import ggplot2
15+
#' @import ROCR
16+
#'
17+
#' @examples
18+
#' library(auditor)
19+
#' library(mlbench)
20+
#' library(randomForest)
21+
#' data("PimaIndiansDiabetes")
22+
#'
23+
#' model_rf <- randomForest(diabetes~., data=PimaIndiansDiabetes)
24+
#' au_rf <- audit(model_rf, label="rf")
25+
#' plotCGains(au_rf)
26+
#'
27+
#' model_glm <- glm(diabetes~., family=binomial, data=PimaIndiansDiabetes)
28+
#' au_glm <- audit(model_glm)
29+
#' plotCGains(au_rf, au_glm)
30+
#'
31+
#' @export
32+
33+
34+
plotCGains <- function(object, ..., newdata = NULL, newy = NULL) {
  # Draws one cumulative gains curve (true positive rate vs. rate of positive
  # prediction) per modelAudit object and returns the ggplot object.
  #
  # object  : a modelAudit object (required).
  # ...     : further objects; only those inheriting from "modelAudit" are
  #           plotted, anything else is silently ignored.
  # newdata : optional data passed on to getCGainsDF(); NULL means the values
  #           stored in each audit object are used.
  # newy    : response for newdata. It defaults to NULL so that a missing
  #           value reaches getCGainsDF() and triggers its explicit
  #           "newy must be provided." error instead of R's generic
  #           "argument is missing" failure.
  #
  # inherits() is used instead of class(x) == "..." because the latter yields
  # a length > 1 condition for objects carrying more than one class.
  if (!inherits(object, "modelAudit")) {
    stop("plotCGains requires object class modelAudit.")
  }

  # Placeholders for the column names used inside aes() below
  # (presumably to silence R CMD check notes about undefined globals).
  rpp <- tpr <- label <- NULL

  extra_audits <- Filter(function(x) inherits(x, "modelAudit"), list(...))
  audits <- c(list(object), extra_audits)

  # Build every curve first and bind once, rather than growing a data frame
  # inside a loop.
  curves <- lapply(audits, function(audit) getCGainsDF(audit, newdata, newy))
  df <- do.call(rbind, curves)

  ggplot(df, aes(x = rpp, y = tpr, color = label)) +
    geom_line() +
    xlab("Rate of Positive Prediction") +
    ylab("True Positive Rate") +
    theme_light()
}
54+
55+
getCGainsDF <- function(object, newdata, newy) {
  # Computes the points of a cumulative gains curve for one audit object.
  #
  # With newdata = NULL the scores and labels stored in the object
  # (fitted.values, y) are used; otherwise scores come from the object's own
  # predict.function applied to newdata and the labels from newy.
  # Returns a data frame with columns rpp, tpr, alpha and label.
  if (!is.null(newdata)) {
    if (is.null(newy)) stop("newy must be provided.")
    scores <- object$predict.function(object$model, newdata)
    truth <- newy
  } else {
    scores <- object$fitted.values
    truth <- object$y
  }

  # ROCR: true positive rate as a function of the rate of positive prediction.
  gains <- performance(prediction(scores, truth), "tpr", "rpp")

  data.frame(
    rpp = gains@x.values[[1]],
    tpr = gains@y.values[[1]],
    alpha = gains@alpha.values[[1]],
    label = object$label
  )
}

R/plotLift.R

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
#' @title Lift Chart
2+
#'
3+
#' @description Lift Chart shows the ratio of a model to a random guess.
4+
#'
5+
#' @param object An object of class modelAudit
6+
#' @param newdata optionally, a data frame in which to look for variables with which to plot the lift chart. If omitted, the data used to build the model will be used.
7+
#' @param newy optionally, required if newdata used. Response vector for new data.
8+
#' @param groups number of groups
9+
#' @param cumulative boolean. If TRUE cumulative lift curve will be plotted.
10+
#' @param ... other modelAudit objects to be plotted together
11+
#'
12+
#' @return ggplot object
13+
#'
14+
#' @seealso \code{\link{plot.modelAudit}}
15+
#'
16+
#' @import ggplot2
17+
#' @import dplyr
18+
#'
19+
#' @examples
20+
#' library(auditor)
21+
#' library(mlbench)
22+
#' library(randomForest)
23+
#' data("PimaIndiansDiabetes")
24+
#'
25+
#' model_rf <- randomForest(diabetes~., data=PimaIndiansDiabetes)
26+
#' au_rf <- audit(model_rf, label="rf")
27+
#' plotLIFT(au_rf)
28+
#'
29+
#' model_glm <- glm(diabetes~., family=binomial, data=PimaIndiansDiabetes)
30+
#' au_glm <- audit(model_glm)
31+
#' plotLIFT(au_rf, au_glm)
32+
#'
33+
#' @export
34+
35+
36+
plotLIFT <- function(object, ..., newdata = NULL, newy = NULL, groups = 10, cumulative = TRUE) {
  # Draws one lift curve per modelAudit object and returns the ggplot object.
  #
  # object     : a modelAudit object (required).
  # ...        : further objects; only those inheriting from "modelAudit" are
  #              plotted, anything else is silently ignored.
  # newdata    : optional data passed on to getLIFTDF(); NULL means the values
  #              stored in each audit object are used.
  # newy       : response for newdata. It defaults to NULL so that a missing
  #              value reaches getLIFTDF() and triggers its explicit
  #              "newy must be provided." error instead of R's generic
  #              "argument is missing" failure.
  # groups     : number of groups, forwarded to getLIFTDF().
  # cumulative : forwarded to getLIFTDF(); TRUE requests the cumulative curve.
  #
  # inherits() is used instead of class(x) == "..." because the latter yields
  # a length > 1 condition for objects carrying more than one class.
  # The message now names this function (it previously said "plotCGains").
  if (!inherits(object, "modelAudit")) {
    stop("plotLIFT requires object class modelAudit.")
  }

  # Placeholders for the column names used inside aes() below
  # (presumably to silence R CMD check notes about undefined globals).
  depth <- lift <- label <- NULL

  extra_audits <- Filter(function(x) inherits(x, "modelAudit"), list(...))
  audits <- c(list(object), extra_audits)

  # Build every curve first and bind once, rather than growing a data frame
  # inside a loop.
  curves <- lapply(
    audits,
    function(audit) getLIFTDF(audit, newdata, newy, groups, cumulative)
  )
  df <- do.call(rbind, curves)

  ggplot(df, aes(x = depth, y = lift, color = label)) +
    geom_line() +
    xlab("Percentage of observations") +
    ylab("Lift") +
    theme_light()
}
56+
57+
getLIFTDF <- function(object, newdata, newy, n.groups, cumulative = TRUE) {
  # Computes the points of a lift curve for one audit object.
  #
  # Observations are sorted by decreasing prediction and cut into n.groups
  # equally sized groups (trailing observations that do not fill a whole
  # group are dropped). For every group the mean response is divided by the
  # overall mean response; with cumulative = TRUE the running mean over the
  # groups seen so far is used instead of the per-group mean.
  #
  # object     : list-like audit object; fitted.values, y, label and, when
  #              newdata is given, predict.function and model are read.
  # newdata    : NULL to use the stored values, otherwise data handed to
  #              object$predict.function.
  # newy       : response for newdata; required when newdata is not NULL.
  # n.groups   : number of groups, between 1 and the number of observations.
  # cumulative : TRUE for the cumulative lift curve.
  #
  # Returns a data frame with columns depth (percentage of observations),
  # lift and label.
  if (is.null(newdata)) {
    predictions <- object$fitted.values
    y <- as.numeric(as.character(object$y))
  } else {
    if (is.null(newy)) stop("newy must be provided.")
    predictions <- object$predict.function(object$model, newdata)
    y <- as.numeric(as.character(newy))
  }

  scored <- data.frame(pred = predictions, y = y)
  # Descending, tie-stable sort by prediction.
  scored <- scored[order(-xtfrm(scored$pred)), , drop = FALSE]

  # Without this guard floor(nrow / n.groups) becomes 0, which leads to a
  # division by zero and a meaningless result further down.
  if (n.groups < 1 || n.groups > nrow(scored)) {
    stop("n.groups must be between 1 and the number of observations.")
  }

  group_size <- floor(nrow(scored) / n.groups)
  cap <- group_size * n.groups
  kept <- seq_len(cap)
  group <- ceiling(kept / group_size)

  res <- stats::aggregate(scored[kept, 2], by = list(group), mean)

  if (cumulative == TRUE) {
    res[, 2] <- cumsum(res[, 2]) / seq_along(res[, 2])
  }
  colnames(res) <- c("depth", "lift")
  # The baseline is the mean response over all observations, including any
  # that were dropped from the last incomplete group.
  res$lift <- res$lift / mean(y)
  res$depth <- 100 * res$depth / n.groups
  res$label <- object$label
  res
}
84+

man/plot.modelAudit.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/plotCGains.Rd

Lines changed: 42 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/plotLIFT.Rd

Lines changed: 46 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)